diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..a9fa0b90a63aa634f8a98da808c146735d1cff71 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,32 @@ +# Build-context / image trimming for the Learn2Splat demo Space. +# The Dockerfile needs: demo.py, optgs/, submodules/, requirements.txt, +# pyproject.toml, LICENSE — keep those; drop everything below. + +# Secrets — never copy into the image. +.env +.env.* +/wandb/ + +# Git + Python build droppings. +.git/ +.gitignore +**/__pycache__/ +**/*.pyc +**/*.egg-info/ +submodules/*/build/ + +# Large runtime artefacts — fetched into the container on first run. +/data/ +/checkpoints/ +/results/ + +# Repo material the demo doesn't use. +/assets/ +/docs/ +/figures/ +/tests/ +/scripts/ +/mlcloud_scripts/ +/visualization/ +/todo/ +huggingface_space/ diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..6f2964f366600daa0074562c00926f641007d2bc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +submodules/fused-ssim/images/albert.jpg filter=lfs diff=lfs merge=lfs -text +submodules/fused-ssim/images/inference_time.png filter=lfs diff=lfs merge=lfs -text +submodules/fused-ssim/images/inference_time_4090.png filter=lfs diff=lfs merge=lfs -text +submodules/fused-ssim/images/predicted.jpg filter=lfs diff=lfs merge=lfs -text +submodules/fused-ssim/images/training_time.png filter=lfs diff=lfs merge=lfs -text +submodules/fused-ssim/images/training_time_4090.png filter=lfs diff=lfs merge=lfs -text diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..063ef7dd0ef2d24410c1759b7cf676a1817b30c7 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,76 @@ +# Learn2Splat — interactive demo for a Hugging Face Space 
(Docker SDK, GPU). +# +# Builds the optgs package + its CUDA extensions and runs demo.py's viser GUI: +# SfM-initialize a COLMAP scene, then refine the Gaussians with the learned +# optimizer live in the browser. Mirrors setup.sh, minus conda — the CUDA +# toolkit ships in the base image. +# +# Build context = the optgs repo root (see huggingface_space/DEPLOY.md). +# Hardware: pick a GPU in the Space settings — A10G (24 GB) recommended; the +# GUI holds the dense and sparse checkpoints in VRAM at once. + +# CUDA 12.8 devel (nvcc + headers); Ubuntu 22.04 — the OS setup.sh is tested on. +# A devel base is required: gsplat / nerfacc JIT-compile CUDA on first use, so +# nvcc must also be present at runtime. +FROM nvidia/cuda:12.8.0-devel-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 \ + # Compile the CUDA extensions for every GPU a Space may run on + # (T4 7.5 · A100 8.0 · A10G 8.6 · L4/L40S 8.9 · H100 9.0). Trim this to + # your chosen GPU to shorten the build. + TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6 8.9 9.0+PTX" + +# Build tools + extension headers (libglm-dev) and the OpenCV runtime libs +# (libgl1, libglib2.0-0 — optgs's COLMAP loader imports cv2). +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 python3-dev python3-venv \ + git build-essential ninja-build libglm-dev \ + libgl1 libglib2.0-0 ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# HF Spaces convention: run as a non-root user (UID 1000). +RUN useradd -m -u 1000 user +USER user +ENV HOME=/home/user \ + HF_HOME=/home/user/.cache/huggingface \ + TORCH_HOME=/home/user/.cache/torch +WORKDIR /home/user/app + +# All Python work happens in a venv on PATH (no system-Python writes). +RUN python3 -m venv /home/user/venv +ENV PATH=/home/user/venv/bin:$PATH +RUN pip install --upgrade pip setuptools wheel + +# PyTorch (CUDA 12.8) — pinned to setup.sh. 
+RUN pip install torch==2.7.1 torchvision==0.22.1 torchaudio==2.7.1 \ + --index-url https://download.pytorch.org/whl/cu128 + +# Python requirements (copied first so this layer caches across code edits). +COPY --chown=user:user requirements.txt . +RUN pip install -r requirements.txt + +# gsplat + nerfacc — built from git against the torch installed above. +RUN pip install --no-build-isolation \ + git+https://github.com/nerfstudio-project/nerfacc \ + git+https://github.com/nerfstudio-project/gsplat.git + +# The optgs repo. +COPY --chown=user:user . . + +# CUDA-extension submodules, then optgs itself. pycolmap is the pure-Python +# COLMAP reader (no C++ build); the other four compile CUDA kernels. +RUN pip install submodules/pycolmap \ + && pip install --no-build-isolation submodules/fused-ssim \ + && pip install --no-build-isolation submodules/simple-knn \ + && pip install --no-build-isolation submodules/pointops \ + && pip install --no-build-isolation submodules/fused_knn_attn \ + && pip install --no-build-isolation --no-deps -e . + +# viser serves the GUI here — must equal app_port in README.md. +EXPOSE 7860 + +# client mode: viser ships the splats to the browser's WebGL renderer, so the +# GPU is used only for optimization. viser binds 0.0.0.0 by default. 
+CMD ["python", "demo.py", "--with-gui", "client", "--gui-port", "7860"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..bbbeaa0cdeb9e77da53b0011ec4163cadcc0d950 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Naama Pearl and Stefano Esposito + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md index 85fd0776496d6cb718e7750b058f3322825c2730..33ec5cfd686e01690bf8a78a40a1d3187ec90ada 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,24 @@ --- title: Learn2Splat -emoji: 😻 -colorFrom: indigo -colorTo: pink -sdk: gradio -sdk_version: 6.14.0 -python_version: '3.13' -app_file: app.py +emoji: 🪴 +colorFrom: green +colorTo: indigo +sdk: docker +app_port: 7860 pinned: false +short_description: Interactive demo of the Learn2Splat learned 3DGS optimizer --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# Learn2Splat — interactive demo + +A learned optimizer for 3D Gaussian Splatting. This Space SfM-initializes a +COLMAP scene and refines the Gaussians live in your browser: pick the +Learn2Splat optimizer (dense or sparse checkpoint) or a 3DGS Adam baseline, +press **Start**, and watch the splats converge. + +Runs `demo.py --with-gui client` from the +[Learn2Splat repository](https://github.com/autonomousvision/learn2splat); +the splats are drawn by viser's in-browser WebGL renderer. + +> Requires GPU hardware. The demo holds two checkpoints in VRAM at once — +> an A10G (24 GB) is recommended. diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..13325f206b2417b85d9211e37cd405ef80e908fb --- /dev/null +++ b/demo.py @@ -0,0 +1,766 @@ +"""End-to-end OptGS demo on a COLMAP scene. 
+ +Main-codebase port of ``baselines/gsplat/examples/simple_trainer_optgs.py``: +same flow — SfM-initialize Gaussians, refine them with the learned optimizer +via the ``OptGS`` API, evaluate on held-out views — but using only the +``optgs`` package (no gsplat / gsplat-examples dependency): + + from optgs.experimental.api import OptGS + + optgs = OptGS(checkpoint="hf://org/repo/model.ckpt", device="cuda") + optgs.initialize_from_tensors(gaussians, batched_views) + refined = optgs.optimize() # learned optimization + +COLMAP loading uses ``optgs.dataset.colmap``; the SfM init builds an optgs +``Gaussians`` directly via ``points_to_gaussians``; evaluation renders with +the optimizer's own decoder. + +The scene is refined three ways and compared on held-out views: the learned +optimizer (Learn2Splat) with the *dense* and the *sparse* checkpoint, and a +3DGS Adam baseline (gsplat hyperparameters). All run through the same +``optimize()`` path with identical SfM init, view minibatches and step budget. +Each uses its checkpoint's gsplat renderer; ``--rasterize-mode`` / ``--eps2d`` +pin one renderer across all runs. + +Usage (run from the repo root, with ``optgs`` importable): + + python demo.py # headless: dense + sparse checkpoints + an Adam baseline + python demo.py --with-gui server # interactive viser GUI (frames rendered by the decoder) + python demo.py --with-gui client # interactive viser GUI (viser's WebGL splat renderer) + +The demo scene and the checkpoints are fetched from the Hugging Face Hub on +first run (cached under ./data and ./checkpoints). A CUDA device is required. +""" + +import warnings + +# Demo: silence ALL warnings process-wide for clean output. The targets are third-party +# UserWarnings (xFormers/flash-attn not installed, Hydra's _self_ notice, pointops' deprecated tensor constructors), but the filter below is not category-scoped.
+warnings.filterwarnings("ignore") + +import json +import os +import time +from dataclasses import dataclass +from typing import Dict, List, Literal, Optional, Tuple + +import imageio.v2 as imageio +import numpy as np +import torch +import torch.nn.functional as F +import tyro +from rich.console import Console +from rich.table import Table +from torch import Tensor + +console = Console() + +from optgs.dataset.colmap.utils import Dataset, Parser +from optgs.experimental.initializers_utils import knn, points_to_gaussians +from optgs.model.types import Gaussians +from optgs.scene_trainer.common.gaussian_adapter import build_covariance + +# Camera near/far planes — inria's znear/zfar (also the optgs colmap-dataset +# constants). Fixed; not a user knob. +NEAR_PLANE = 0.01 +FAR_PLANE = 100.0 + +# Spherical-harmonics DC -> RGB (3DGS convention: rgb = 0.5 + C0 * dc). Colours +# the splats for viser's client-side renderer. +SH_C0 = 0.28209479177387814 + +# The demo scene is fetched from this Hugging Face repo on first run. The repo +# mirrors the local layout, so e.g. ``data/mip360/garden`` in the repo lands at +# ``./data/mip360/garden``. +DEMO_DATA_REPO = "autonomousvision/learn2splat" + +# Learned-optimizer checkpoints on the Hugging Face Hub. hf:// refs are fetched +# and cached under ./checkpoints on first use (see optgs.misc.hf_ckpt). 
+CHECKPOINTS = { + "dense": "hf://autonomousvision/learn2splat/dense/checkpoints/epoch_5-step_50000.ckpt", + "sparse": "hf://autonomousvision/learn2splat/sparse/checkpoints/epoch_9-step_90000.ckpt", +} + + +def ensure_data(data_dir: str) -> None: + """Download the demo scene from the Hugging Face Hub if it is not present.""" + if os.path.isdir(data_dir) and os.listdir(data_dir): + return + from huggingface_hub import snapshot_download + + console.print( + f"[yellow]{data_dir}[/] not found — downloading from " + f"[cyan]hf://{DEMO_DATA_REPO}[/] …" + ) + snapshot_download( + repo_id=DEMO_DATA_REPO, + allow_patterns=[f"{data_dir.rstrip('/')}/**"], + local_dir=".", + ) + console.print(f"[green]✓[/] scene ready at [yellow]{data_dir}[/]") + + +@dataclass +class Config: + # Path to the COLMAP dataset (expects images/ + sparse/0/). + data_dir: str = "data/mip360/garden" + # Downsample factor for the dataset. + data_factor: int = 4 + # Global multiplier on scene-size-related parameters. + global_scale: float = 1.0 + # Normalize the world space. + normalize_world_space: bool = True + # Every N images is a test image, held out for evaluation. + test_every: int = 8 + # Directory to save renders / stats / the refined PLY. + result_dir: str = "results/demo" + # Random seed. + seed: int = 42 + + # --- Interactive GUI --- + # Launch a viser GUI instead of the headless comparison. "server" renders + # frames with the optgs decoder; "client" uses viser's built-in WebGL + # Gaussian-splat renderer. Unset = headless run. + with_gui: Optional[Literal["client", "server"]] = None + # Port for the viser GUI web server (--with-gui only). + gui_port: int = 8080 + + # --- OptGS learned optimizer --- + # Compute device (OptGS requires CUDA). + device: str = "cuda" + # Number of learned refinement steps. + max_steps: int = 100 + # Views the optimizer sees per refinement step (the view minibatch). 
+ opt_batch_size: int = 8 + # View-minibatch sampling strategy: "random", "sequential", or "fps" + # (farthest-point sampling over camera positions). + opt_batch_strategy: Literal["random", "sequential", "fps"] = "fps" + + # --- gsplat renderer --- + # rasterize_mode / eps2d: when set, applied to every run (dense, sparse, + # Adam), overriding each checkpoint's decoder config so the comparison uses + # one renderer. Left unset, each run uses its own checkpoint's value. + rasterize_mode: Optional[Literal["classic", "antialiased"]] = None + eps2d: Optional[float] = None + + # --- Initialization --- + # Initialization strategy: "sfm" or "random". + init_type: str = "sfm" + # Initial number of GSs. Ignored when init_type="sfm". + init_num_pts: int = 100_000 + # Initial extent of GSs as a multiple of the scene extent (random init). + init_extent: float = 3.0 + # Initial opacity of each GS; init_scale multiplies its nearest-neighbour-derived initial scale. + init_opa: float = 0.1 + init_scale: float = 1.0 + + +def scene_extent(parser: Parser, global_scale: float) -> float: + """Scene-size scalar: parser extent x 1.1 x global_scale.""" + return parser.scene_scale * 1.1 * global_scale + + +def sfm_initialization( + parser: Parser, cfg: Config, sh_degree: int, device: torch.device, dtype: torch.dtype +) -> Gaussians: + """SfM (or random) Gaussian init -> an optgs ``Gaussians`` (batch=1). + + Builds the parameter tensors with the same heuristics as 3DGS / the optgs + COLMAP initializer, then assembles them through ``points_to_gaussians``.
+ """ + if cfg.init_type == "sfm": + points = torch.from_numpy(parser.points).float() + rgbs = torch.from_numpy(parser.points_rgb / 255.0).float() + elif cfg.init_type == "random": + extent = scene_extent(parser, cfg.global_scale) + points = cfg.init_extent * extent * ( + torch.rand((cfg.init_num_pts, 3)) * 2 - 1 + ) + rgbs = torch.rand((cfg.init_num_pts, 3)) + else: + raise ValueError(f"unknown init_type: {cfg.init_type!r} (sfm | random)") + + # GS size = average distance to the 3 nearest neighbours ([:, 1:] drops self). + dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) + scales = (torch.sqrt(dist2_avg) * cfg.init_scale).unsqueeze(-1).repeat(1, 3) + opacities = torch.full((points.shape[0],), cfg.init_opa) + + # points_to_gaussians returns pre-activation params (log scales, logit + # opacity, sh0/shN, random quats). + g = points_to_gaussians( + {"xyz": points, "rgb": rgbs, "scales": scales, "opacities": opacities}, + sh_degree=sh_degree, + device=device, + ) + sh0, shN = g["sh0"], g["shN"] + harmonics = torch.cat([sh0, shN], dim=1) if shN is not None else sh0 # [N, K, 3] + harmonics = harmonics.permute(0, 2, 1) # -> [N, 3, K] + + scales_act = torch.exp(g["scales_raw"]) + opacities_act = torch.sigmoid(g["opacities_raw"]) + rotations = F.normalize(g["rotations_unnorm"], dim=-1) + covariances = build_covariance(scale=scales_act, rotation_xyzw=rotations) + + def _b(t: Tensor) -> Tensor: # add the batch dimension and cast + return t.unsqueeze(0).to(dtype) + + return Gaussians( + means=_b(g["xyz"]), + covariances=_b(covariances), + harmonics=_b(harmonics), + opacities=_b(opacities_act), + scales=_b(scales_act), + rotations=_b(rotations), + rotations_unnorm=_b(g["rotations_unnorm"]), + ) + + +def collect_cameras( + dataset: Dataset, indices: List[int] +) -> Tuple[Tensor, Tensor, Tensor]: + """Stack the selected views into ``(camtoworlds, Ks, images)``. + + ``images`` is returned in [0, 1]. 
All views must share one (H, W) — the + optgs renderer takes a single image shape. + """ + c2ws, ks, imgs = [], [], [] + hw = None + for i in indices: + data = dataset[i] + img = data["image"] / 255.0 # [H, W, 3], float + if hw is None: + hw = img.shape[:2] + elif img.shape[:2] != hw: + raise ValueError( + f"all views must share one (H, W); got {tuple(img.shape[:2])} " + f"vs {tuple(hw)}. Render the dataset at a single resolution." + ) + c2ws.append(data["camtoworld"]) + ks.append(data["K"]) + imgs.append(img) + return torch.stack(c2ws), torch.stack(ks), torch.stack(imgs) + + +def build_batched_views( + camtoworlds: Tensor, + Ks: Tensor, + images: Tensor, + scene_scale: float, + device: torch.device, + dtype: torch.dtype, +) -> dict: + """COLMAP cameras -> an optgs ``BatchedViews`` dict (batch=1). + + COLMAP ``camtoworld`` is already optgs's extrinsics convention (OpenCV + camera->world). ``K`` is pixel-space; optgs wants it normalized by image + width/height. + """ + v, h, w = images.shape[0], images.shape[1], images.shape[2] + + Ks_norm = Ks.clone() + Ks_norm[:, 0, :] /= w # normalized focal / principal point + Ks_norm[:, 1, :] /= h + + image = images.permute(0, 3, 1, 2) # [V, 3, H, W] + + def _b(t: Tensor) -> Tensor: # add the batch dimension and move to device + return t.unsqueeze(0).to(device=device, dtype=dtype) + + return { + "extrinsics": _b(camtoworlds), + "intrinsics": _b(Ks_norm), + "image": _b(image), + "near": torch.full((1, v), NEAR_PLANE, device=device, dtype=dtype), + "far": torch.full((1, v), FAR_PLANE, device=device, dtype=dtype), + "index": torch.arange(v, device=device).unsqueeze(0), + "scene_scale": torch.tensor([scene_scale], device=device, dtype=dtype), + } + + +@torch.no_grad() +def render_and_score( + optgs, + refined: Gaussians, + val_bv: dict, + val_images: Tensor, + out_dir: str, + device: torch.device, +) -> dict: + """Render one optimizer's result on the held-out views; report mean PSNR. 
+ + Saves a ``gt | pred`` strip per view under ``out_dir/renders``. + """ + render_dir = os.path.join(out_dir, "renders") + os.makedirs(render_dir, exist_ok=True) + h, w = val_images.shape[1], val_images.shape[2] + + out = optgs.decoder.forward( + refined, val_bv["extrinsics"], val_bv["intrinsics"], + val_bv["near"], val_bv["far"], image_shape=(h, w), + ) + colors = out.color[0].clamp(0.0, 1.0) # [V, 3, H, W] + + psnrs = [] + for i in range(colors.shape[0]): + gt = val_images[i].to(device) # [H, W, 3] + pred = colors[i].permute(1, 2, 0) + psnrs.append(-10.0 * torch.log10(torch.mean((pred - gt) ** 2)).item()) + + canvas = torch.cat([gt, pred], dim=1).cpu().numpy() # gt | pred + imageio.imwrite( + os.path.join(render_dir, f"val_{i:04d}.png"), + (canvas * 255).astype(np.uint8), + ) + + return {"psnr": float(np.mean(psnrs)), "num_views": int(colors.shape[0])} + + +@torch.no_grad() +def render_view( + optgs, gaussians: Gaussians, camera, height: int, + device: torch.device, dtype: torch.dtype, +) -> np.ndarray: + """Render ``gaussians`` from a viser camera into an ``[H, W, 3]`` uint8 image. + + viser cameras follow OpenCV conventions, so ``(wxyz, position)`` is directly + the camera-to-world transform the optgs decoder expects — no axis flip. + """ + import viser.transforms as vtf + + from optgs.misc.image_io import prep_image + + h = int(height) + w = max(1, round(h * camera.aspect)) # camera.aspect = width / height + + c2w = torch.eye(4, device=device, dtype=dtype) + c2w[:3, :3] = torch.tensor( + vtf.SO3(camera.wxyz).as_matrix(), device=device, dtype=dtype + ) + c2w[:3, 3] = torch.tensor(camera.position, device=device, dtype=dtype) + + # Normalized intrinsics from the vertical fov; the decoder un-normalizes by + # the image width/height. 
+ fy = (h / 2.0) / float(np.tan(camera.fov / 2.0)) + K = torch.eye(3, device=device, dtype=dtype) + K[0, 0] = fy / w + K[1, 1] = fy / h + K[0, 2] = 0.5 + K[1, 2] = 0.5 + + near = torch.full((1, 1), NEAR_PLANE, device=device, dtype=dtype) + far = torch.full((1, 1), FAR_PLANE, device=device, dtype=dtype) + out = optgs.decoder.forward( + gaussians, c2w[None, None], K[None, None], near, far, image_shape=(h, w), + ) + return prep_image(out.color[0, 0]) # [H, W, 3] uint8 + + +def gaussians_to_splat_data(gaussians: Gaussians) -> dict: + """An optgs ``Gaussians`` (batch=1) -> numpy arrays for viser's splat viewer. + + Covariances are recomputed from scale/rotation (the optimizer updates those + but may leave the optional ``Gaussians.covariances`` field stale); colours + come from the SH DC term (degree 0 — viser's renderer is not view-dependent). + """ + scales = gaussians.scales[0] + opacities = gaussians.opacities[0] + if not gaussians.stores_activated: + scales = torch.exp(scales) + opacities = torch.sigmoid(opacities) + rotations = F.normalize(gaussians.rotations_unnorm[0], dim=-1) + covariances = build_covariance(scale=scales, rotation_xyzw=rotations) + rgbs = (0.5 + SH_C0 * gaussians.harmonics[0, :, :, 0]).clamp(0.0, 1.0) + + def _np(t: Tensor) -> np.ndarray: + return t.detach().cpu().numpy().astype(np.float32) + + return { + "centers": _np(gaussians.means[0]), # (N, 3) + "covariances": _np(covariances), # (N, 3, 3) + "rgbs": _np(rgbs), # (N, 3) + "opacities": _np(opacities.reshape(-1, 1)), # (N, 1) + } + + +def run_gui( + instances: dict, + gaussians: Gaussians, + train_bv: dict, + cfg: Config, + device: torch.device, + dtype: torch.dtype, +) -> None: + """Interactive viser GUI: watch the optimization, pick an optimizer, reset. 
+ + The initialization is shown first; the user picks an optimizer — the + Learn2Splat learned optimizer (dense or sparse checkpoint) or a 3DGS Adam + baseline — and clicks Start; every optimizer step is rendered and displayed; + Reset restores the initialization. ``cfg.with_gui`` chooses the renderer — + "server" (optgs decoder, frames streamed as images) or "client" (viser's + WebGL splats). + + ``instances`` maps "dense"/"sparse" to their initialized ``OptGS``. + """ + import threading + + import viser + import viser.transforms as vtf + + from optgs.experimental.api.integration.config_bridge import build_adam_baseline + + mode = cfg.with_gui # "server" | "client" + server = viser.ViserServer(port=cfg.gui_port) + + # Optimizer dropdown label -> (instances key, whether to swap in Adam). + # "dense"/"sparse" run that checkpoint's own learned optimizer; "Adam" runs + # a 3DGS Adam baseline on the dense checkpoint's pipeline. + OPTIONS: Dict[str, Tuple[str, bool]] = { + "Learn2Splat (dense)": ("dense", False), + "Learn2Splat (sparse)": ("sparse", False), + "Adam (3DGS)": ("dense", True), + } + + optimizer_dd = server.gui.add_dropdown("Optimizer", tuple(OPTIONS)) + + # Optimization controls — applied to the picked OptGS at Start; frozen + # while optimizing, unfrozen by Reset. opt_batch_size is capped at the + # number of training views (the per-step view minibatch can't exceed them). 
+ n_train_views = int(train_bv["image"].shape[1]) + max_steps_input = server.gui.add_number( + "Max steps", min=1, max=1000, step=1, initial_value=cfg.max_steps + ) + batch_size_input = server.gui.add_number( + "Opt batch size", min=1, max=n_train_views, step=1, + initial_value=min(cfg.opt_batch_size, n_train_views), + ) + strategy_dd = server.gui.add_dropdown( + "Opt batch strategy", ("random", "sequential", "fps"), + initial_value=cfg.opt_batch_strategy, + ) + opt_controls = (max_steps_input, batch_size_input, strategy_dd) + + start_btn = server.gui.add_button("Start optimization") + reset_btn = server.gui.add_button("Reset to initialization") + status = server.gui.add_markdown("**initialized** — pick an optimizer, then Start") + res_slider = ( + server.gui.add_slider( + "Render height", min=240, max=1080, step=60, initial_value=540 + ) + if mode == "server" + else None + ) + + init_gaussians = gaussians.clone() # pristine copy, for Reset + current = init_gaussians # Gaussians currently displayed + active = instances["dense"] # OptGS used to render + to optimize next + gen = None # optimize_iter generator while running + last_cam_ts: dict = {} # client id -> last-rendered camera stamp + lock = threading.Lock() + state = { + "mode": "init", # "init" | "optimizing" | "done" + "step": 0, + "start": False, + "reset": False, + "rerender": False, # a GUI control changed -> re-render once + "selected": next(iter(OPTIONS)), + } + + @start_btn.on_click + def _(_) -> None: + with lock: + if state["mode"] in ("init", "done"): + state["selected"] = optimizer_dd.value + state["start"] = True + + @reset_btn.on_click + def _(_) -> None: + with lock: + state["reset"] = True + + # The render-height slider only affects server-rendered frames; re-render + # on change so the new resolution takes effect without a camera move. 
+ if res_slider is not None: + + @res_slider.on_update + def _(_) -> None: + with lock: + state["rerender"] = True + + # Frame newly-connected clients on the first training camera (viser and + # optgs share the OpenCV camera-to-world convention). + cam_extr = train_bv["extrinsics"][0, 0].detach().cpu().numpy() + + @server.on_client_connect + def _(client) -> None: + try: + client.camera.position = cam_extr[:3, 3] + client.camera.wxyz = vtf.SO3.from_matrix(cam_extr[:3, :3]).wxyz + except Exception: + pass + + if mode == "client": # show the initialization immediately + # Black backdrop for the WebGL splat renderer (viser's canvas is not + # black by default); on server.scene so late-joining clients get it. + server.scene.set_background_image(np.zeros((8, 8, 3), dtype=np.uint8)) + server.scene.add_gaussian_splats( + "/optgs/splats", **gaussians_to_splat_data(current) + ) + + console.print( + f"[green]✓[/] viser GUI ([cyan]{mode}[/]) on port [cyan]{cfg.gui_port}[/]" + f" — forward the port over SSH and open the printed URL" + ) + + try: + while True: + changed = False + + with lock: + do_reset, do_start = state["reset"], state["start"] + do_rerender = state["rerender"] + state["reset"] = state["start"] = state["rerender"] = False + selected = state["selected"] + + if do_rerender: + changed = True # server mode re-renders every connected client + + if do_reset: + if gen is not None: + gen.close() # runs optimize_iter's finally -> on_scene_end() + gen = None + current = init_gaussians + with lock: + state["mode"], state["step"] = "init", 0 + optimizer_dd.disabled = start_btn.disabled = False + for c in opt_controls: + c.disabled = False + changed = True + + if do_start and gen is None: + name, use_adam = OPTIONS[selected] + active = instances[name] + # Apply the GUI optimization controls before the run starts. 
+ active.num_refine = int(max_steps_input.value) + active.opt_batch_size = int(batch_size_input.value) + active.opt_batch_strategy = strategy_dd.value + opt = ( + build_adam_baseline(active.num_refine).to(device) + if use_adam + else None + ) + gen = active.optimize_iter(optimizer=opt) + with lock: + state["mode"], state["step"] = "optimizing", 0 + optimizer_dd.disabled = start_btn.disabled = True + for c in opt_controls: + c.disabled = True + + if gen is not None: + try: + step, current = next(gen) + changed = True + with lock: + state["step"] = step + 1 + except StopIteration: + gen = None + with lock: + state["mode"] = "done" + optimizer_dd.disabled = start_btn.disabled = False + + if mode == "server": + for cid, client in server.get_clients().items(): + try: + cam_ts = client.camera.update_timestamp + if last_cam_ts.get(cid) != cam_ts or changed: + last_cam_ts[cid] = cam_ts + image = render_view( + active, current, client.camera, + res_slider.value, device, dtype, + ) + client.scene.set_background_image(image, format="jpeg") + except Exception: + continue # no camera message from this client yet + elif changed: # client mode — re-push splats when the Gaussians change + server.scene.add_gaussian_splats( + "/optgs/splats", **gaussians_to_splat_data(current) + ) + + with lock: + status.content = ( + f"**{state['mode']}** — step " + f"{state['step']}/{active.num_refine} — " + f"{current.means.shape[1]} Gaussians" + ) + + if gen is None: + time.sleep(1 / 30) # idle: poll cameras at ~30 Hz + except KeyboardInterrupt: + if gen is not None: + gen.close() + console.print("\n[yellow]GUI stopped.[/]") + + +def main(cfg: Config) -> None: + # Fetch the demo scene on first run, before anything else touches it. 
+ ensure_data(cfg.data_dir) + + from optgs.experimental.api import OptGS, OptGSError + from optgs.experimental.api.integration.config_bridge import build_adam_baseline + + os.makedirs(cfg.result_dir, exist_ok=True) + device = torch.device(cfg.device) + dtype = torch.float32 + + console.rule("[bold cyan]OptGS demo[/] · Learn2Splat vs Adam") + + # --- COLMAP scene, train/val split --- + parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=cfg.normalize_world_space, + verbose=False, + ) + dataset = Dataset(parser) + val_idx = [i for i in range(len(dataset)) if i % cfg.test_every == 0] + train_idx = [i for i in range(len(dataset)) if i % cfg.test_every != 0] + scene_scale = scene_extent(parser, cfg.global_scale) + console.print( + f"scene scale [cyan]{scene_scale:.4f}[/] · " + f"train [cyan]{len(train_idx)}[/] · val [cyan]{len(val_idx)}[/]" + ) + train_bv = build_batched_views( + *collect_cameras(dataset, train_idx), scene_scale, device, dtype + ) + + # --- Interactive GUI: build both learned-optimizer checkpoints (dense and + # sparse), initialize each, and hand off to the viser GUI instead of the + # headless comparison. The GUI's Optimizer dropdown picks between them. --- + if cfg.with_gui is not None: + instances = {} + for name in ("dense", "sparse"): + try: + instances[name] = OptGS( + checkpoint=CHECKPOINTS[name], + device=cfg.device, + num_refine=cfg.max_steps, + opt_batch_size=cfg.opt_batch_size, + opt_batch_strategy=cfg.opt_batch_strategy, + rasterize_mode=cfg.rasterize_mode, + eps2d=cfg.eps2d, + ) + except OptGSError as e: + console.print(f"[bold red]OptGS error ({name}):[/] {e}") + raise SystemExit(1) + + # One SfM init shared by both checkpoints: dense and sparse get an + # identical starting point, and the GUI shows a single initialization + # regardless of which optimizer is picked. 
+ torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + gaussians = sfm_initialization( + parser, cfg, instances["dense"].sh_degree, device, dtype + ) + for inst in instances.values(): + inst.initialize_from_tensors(gaussians, train_bv) + + run_gui(instances, gaussians, train_bv, cfg, device, dtype) + return + + val_c2w, val_Ks, val_images = collect_cameras(dataset, val_idx) + val_bv = build_batched_views(val_c2w, val_Ks, val_images, scene_scale, device, dtype) + + results: dict = {} + + def finish(optgs, refined, name: str, elapsed: float) -> None: + """Persist + evaluate one run's result under ``{cfg.result_dir}/{name}/``.""" + out_dir = os.path.join(cfg.result_dir, name) + os.makedirs(out_dir, exist_ok=True) + optgs.export_ply(os.path.join(out_dir, "point_cloud.ply")) + ev = render_and_score(optgs, refined, val_bv, val_images, out_dir, device) + results[name] = { + "psnr": ev["psnr"], "time": elapsed, + "num_views": ev["num_views"], "num_GS": int(refined.means.shape[1]), + } + console.print( + f"[green]✓[/] [bold]{name}[/] — PSNR [cyan]{ev['psnr']:.3f}[/] · " + f"[cyan]{elapsed:.1f}s[/] → [yellow]{out_dir}[/]" + ) + + # --- Learned optimizer (Learn2Splat): dense, then sparse --- + optgs = None + for name in ("dense", "sparse"): + optgs = None # free the previous instance before building the next + torch.cuda.empty_cache() + try: + optgs = OptGS( + checkpoint=CHECKPOINTS[name], + device=cfg.device, + num_refine=cfg.max_steps, + opt_batch_size=cfg.opt_batch_size, + opt_batch_strategy=cfg.opt_batch_strategy, + rasterize_mode=cfg.rasterize_mode, + eps2d=cfg.eps2d, + ) + except OptGSError as e: + console.print(f"[bold red]OptGS error ({name}):[/] {e}") + raise SystemExit(1) + # Seed *after* construction so dense and sparse get an identical SfM init.
+ torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + gaussians = sfm_initialization(parser, cfg, optgs.sh_degree, device, dtype) + optgs.initialize_from_tensors(gaussians, train_bv) + + torch.cuda.synchronize() # drain setup GPU work so it isn't timed + tic = time.time() + refined = optgs.optimize() + torch.cuda.synchronize() + finish(optgs, refined, name, time.time() - tic) + + # --- Fair Adam baseline: same SfM init / views / step budget / gsplat + # renderer, run through the same optimize() path on the last OptGS + # instance — only the update rule differs. --- + adam = build_adam_baseline(optgs.num_refine).to(device) + torch.cuda.synchronize() # drain setup GPU work so it isn't timed + tic = time.time() + refined_adam = optgs.optimize(optimizer=adam) + torch.cuda.synchronize() + finish(optgs, refined_adam, "adam", time.time() - tic) + + # --- Comparison table --- + table = Table( + title=( + f"Novel-view PSNR · {results['dense']['num_views']} held-out " + f"views · {cfg.max_steps} steps · " + f"{results['dense']['num_GS']} Gaussians" + ), + title_style="bold", + caption=( + f"gsplat renderer · " + f"rasterize_mode={cfg.rasterize_mode or 'per-checkpoint'} · " + f"eps2d={cfg.eps2d if cfg.eps2d is not None else 'per-checkpoint'}" + ), + ) + table.add_column("Optimizer") + table.add_column("PSNR (dB)", justify="right") + table.add_column("Time (s)", justify="right") + best = max(results, key=lambda k: results[k]["psnr"]) + for key, label in ( + ("dense", "Learn2Splat (dense)"), + ("sparse", "Learn2Splat (sparse)"), + ("adam", "Adam"), + ): + table.add_row( + label, + f"{results[key]['psnr']:.3f}", + f"{results[key]['time']:.1f}", + style="bold green" if key == best else None, + ) + console.print(table) + + with open(os.path.join(cfg.result_dir, "stats.json"), "w") as f: + json.dump(results, f, indent=2) + console.print(f"[green]✓[/] results written to [yellow]{cfg.result_dir}[/]") + + +if __name__ == "__main__": + main(tyro.cli(Config)) diff --git 
a/optgs/__init__.py b/optgs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1efed167c6a7061a3c01fbee11c8d0dfcc2dc58b --- /dev/null +++ b/optgs/__init__.py @@ -0,0 +1 @@ +"""optgs — learned optimization for 3D Gaussian Splatting.""" diff --git a/optgs/config.py b/optgs/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9ded2f8a1407b5ab718005ae1b94a9bff605284d --- /dev/null +++ b/optgs/config.py @@ -0,0 +1,770 @@ +import importlib +from copy import deepcopy +from dataclasses import dataclass +from pathlib import Path +from typing import Literal, Optional, Type, TypeVar, Any, Callable + +import hydra +import torch +from dacite import Config, from_dict, UnionMatchError +from hydra.core.global_hydra import GlobalHydra +from hydra.core.hydra_config import HydraConfig +from hydra.types import RunMode +from omegaconf import DictConfig +from omegaconf import OmegaConf +from pytorch_lightning.strategies import DDPStrategy, FSDPStrategy + +from .config_migrate import migrate, CURRENT_CFG_VERSION +from .dataset.data_module import DataLoaderCfg, DatasetCfg +from .global_cfg import set_cfg +from .loss import LossCfgWrapper +from .misc.io import CustomPath +from .misc.io import cyan, read_omega_cfg +from .misc.checkpointing import find_latest_ckpt +from .misc.hf_ckpt import maybe_resolve_hf_ref +from .paths import CKPT_DIR, RESULTS_DIR +from .scene_trainer.scene_trainer_cfg import SceneTrainerCfg, MetaOptimizerCfg, TestCfg, TrainCfg + + +# In order to extract filename or dirname from a path in the config +def checkpoint_rel_dir(path): + rel_dir = CustomPath(path) - CKPT_DIR # dir_path / checkpoints / epoch_x-step_xxxxx.ckpt + dir_path = rel_dir.parent.parent + return str(dir_path) + + +OmegaConf.register_new_resolver("checkpoint_rel_dir", checkpoint_rel_dir) +OmegaConf.register_new_resolver("parent_dir", lambda path: str(CustomPath(path).parent)) + + +@dataclass +class CheckpointingCfg: + load: Optional[str] # Not a 
path, since it could be something like wandb://... + every_n_train_steps: int + save_top_k: int + pretrained_model: Optional[str] + pretrained_monodepth: Optional[str] + pretrained_mvdepth: Optional[str] + pretrained_depth: Optional[str] + pretrained_scale_predictor: Optional[str] + pretrained_depth_teacher: Optional[str] + no_strict_load: bool + resume: bool + no_resume_upsampler: bool + partial_load: bool + freeze_mono_vit: bool + pretrained_initializer: Optional[str] + pretrained_optimizer: Optional[str] + resume_update_module: str | None + load_existing_cfg: bool + + def __post_init__(self): + # Resolve any Hugging Face Hub references (hf://org/repo/file[@rev]) to + # local cached paths so all downstream torch.load calls work unchanged. + for attr in ("pretrained_model", "pretrained_optimizer", "pretrained_initializer", + "pretrained_monodepth", "pretrained_mvdepth", "pretrained_depth", + "pretrained_scale_predictor", "pretrained_depth_teacher", + "resume_update_module"): + resolved = maybe_resolve_hf_ref(getattr(self, attr)) + if resolved != getattr(self, attr): + setattr(self, attr, resolved) + + for attr in ("pretrained_model", "pretrained_optimizer", "pretrained_initializer"): + path = getattr(self, attr) + if path is not None and Path(path).name == "last": + try: + resolved = find_latest_ckpt(Path(path).parent) + setattr(self, attr, resolved) + print(f"Replacing {attr} to last checkpoint: {resolved}") + except Exception as e: + print(cyan(f"Warning: {e}. 
Continuing with 'last' as {attr}.")) + + +@dataclass +class MetaTrainerCfg: + max_steps: int + val_check_interval: int | float | None + gradient_clip_val: int | float | None + num_sanity_val_steps: int + num_nodes: int + eval_index: str | None + limit_test_batches: int | float + limit_train_batches: int | float + test: TestCfg + train: TrainCfg + + def get_dist_strategy(self, scene_trainer_cfg: SceneTrainerCfg): + from .scene_trainer.initializer.initializer_resplat import ResplatInitializerCfg + dist_strategy = "auto" + if torch.cuda.device_count() > 1: + dist_strategy = 'ddp' + if isinstance(scene_trainer_cfg.scene_optimizer, ResplatInitializerCfg): + if scene_trainer_cfg.scene_initializer.use_gt_depth: + dist_strategy = 'ddp_find_unused_parameters_true' + if scene_trainer_cfg.scene_initializer.use_checkpointing or scene_trainer_cfg.scene_initializer.init_use_checkpointing: + dist_strategy = DDPStrategy(static_graph=True) + if scene_trainer_cfg.use_fsdp: + def only_wrap_trainable(module, recurse, nonwrapped_numel): + has_trainable = any(p.requires_grad for p in module.parameters()) + return has_trainable + + dist_strategy = FSDPStrategy(auto_wrap_policy=only_wrap_trainable) + if self.train.use_replay_buffer: + # When resampling from the replay buffer, + # we don't project the condition_features to state, so the update_proj is not used + dist_strategy = "ddp_find_unused_parameters_true" + return dist_strategy + + +@dataclass +class RootCfg: + wandb: dict + mode: Literal["train", "test"] + dataset: DatasetCfg + data_loader: DataLoaderCfg + scene_trainer: SceneTrainerCfg + meta_optimizer: MetaOptimizerCfg ## TODO Naama: should we move under meta trainer config? 
+ checkpointing: CheckpointingCfg + meta_trainer: MetaTrainerCfg + loss: list[LossCfgWrapper] + seed: int + use_plugins: bool + output_dir: str + version: int | None + debug_cfg: bool + + def __post_init__(self): + if self.mode == "test": + self._setup_test_output_dir() + + def _setup_test_output_dir(self): + base_res_dir = RESULTS_DIR + if self.meta_trainer.limit_test_batches != 1.0: + base_res_dir = RESULTS_DIR + f"_{self.meta_trainer.limit_test_batches}_scenes" + if self.output_dir == "placeholder": + if self.meta_trainer.test.postprocessing is not None and self.meta_trainer.test.postprocessing.is_active: + self.output_dir = (base_res_dir / + "nonlearned" / + "vanilla_3dgs" / + self.meta_trainer.test.postprocessing.name / + self.meta_trainer.test.postprocessing.get_dir_name(with_name=False)) + else: + ckpt_path = self.checkpointing.pretrained_model or self.checkpointing.pretrained_optimizer + pretrained_model_rel_dir = checkpoint_rel_dir(ckpt_path) + self.output_dir = (base_res_dir / + "optgs" / + pretrained_model_rel_dir) + elif 'experimental' in str(self.output_dir): # TODO (release): remove + self._setup_experimental_output_dir() + + def _setup_experimental_output_dir(self): + resplat_str = [] + grad_str = [] + normgrad_str = [] + assert self.scene_trainer.scene_optimizer.experimental_run + for p in self.scene_trainer.scene_optimizer.experimental_update.param_names: + update = getattr(self.scene_trainer.scene_optimizer.experimental_update, p) + use_norm_grad = getattr(self.scene_trainer.scene_optimizer.experimental_use_norm_grads, p) + use_grad = self.scene_trainer.scene_optimizer.experimental_use_grads and not use_norm_grad + use_resplat = update and not use_grad and not use_norm_grad + if update: + assert use_grad ^ use_norm_grad ^ use_resplat, f"Invalid combination for {p}: use_resplat={use_resplat}, use_grad={use_grad}, use_norm_grad={use_norm_grad}" + if use_resplat: + resplat_str.append(p) + if use_grad: + grad_str.append(p) + if use_norm_grad: + 
def get_class_by_path(path: str):
    """Import and return the attribute named by a fully-qualified dotted path.

    Everything before the last dot is imported as a module with
    ``importlib.import_module``; the last component is looked up on that
    module with ``getattr``.

    Args:
        path: Dotted path such as ``"package.module.ClassName"``.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: If ``path`` contains no dot, so there is no module part.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_path, dot, class_name = path.rpartition(".")
    if not dot:
        # A dot-less path previously failed with an opaque tuple-unpacking
        # ValueError; keep the exception type but name the offending input.
        raise ValueError(
            f"Expected a dotted path like 'package.module.ClassName', got {path!r}"
        )
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
def load_typed_config(
    cfg: DictConfig,
    data_class: Type[T],
    extra_type_hooks: dict | None = None,
) -> T:
    """Build an instance of ``data_class`` from an OmegaConf config via dacite.

    The config is converted to a plain container with
    ``OmegaConf.to_container`` and passed to ``dacite.from_dict``.

    Args:
        cfg: The OmegaConf config to convert.
        data_class: Target dataclass type.
        extra_type_hooks: Optional dacite type hooks merged on top of the
            module-level ``TYPE_HOOKS``; entries here win on key clashes.
            ``None`` (the default) means no extra hooks.

    Returns:
        The populated ``data_class`` instance.

    Raises:
        UnionMatchError: Re-raised unchanged after a per-member diagnostic
            has been printed to stdout.
    """
    # `None` instead of a shared `{}` default: avoids the mutable-default
    # pitfall while keeping the call signature backward-compatible.
    hooks = {**TYPE_HOOKS, **(extra_type_hooks or {})}
    dacite_config = Config(type_hooks=hooks)
    try:
        return from_dict(
            data_class,
            OmegaConf.to_container(cfg),
            config=dacite_config,
        )
    except UnionMatchError as e:
        # dacite's own message does not say why each union member failed;
        # print the detailed breakdown before propagating the error.
        diagnostic = _diagnose_union_error(e, e.value, dacite_config)
        print(f"\n{'='*60}\n"
              f"Current config: {e.value}\n"
              "\n"
              "\n"
              f"UnionMatchError diagnostic:\n{diagnostic}\n{'='*60}"
              f"\n",
              flush=True)
        raise
def flatten_wandb(cfg):
    """Strip ``{'desc': ..., 'value': v}`` wrappers from a nested config.

    A dict counts as a wrapper only when it has exactly the two keys
    ``"desc"`` and ``"value"``; it is replaced by its (recursively
    flattened) ``"value"``. Other dicts and lists are rebuilt with their
    children flattened; any other object is returned unchanged.
    """
    if isinstance(cfg, list):
        return [flatten_wandb(item) for item in cfg]
    if not isinstance(cfg, dict):
        return cfg

    is_wrapper = len(cfg) == 2 and "value" in cfg and "desc" in cfg
    if is_wrapper:
        return flatten_wandb(cfg["value"])
    return {key: flatten_wandb(child) for key, child in cfg.items()}
This correctly handles: + - Group overrides (e.g. dataset/view_sampler=evaluation) → replace subtree from cli + - Complex values (e.g. loss=[mse,ssim]) → replace subtree from cli + - Interpolated values (e.g. output_dir=${...}) → take resolved value from cli + - Defaults-list overrides (+experiment=re10k) → skip (already baked into orig_cli_cfg) + """ + if not raw_overrides: + return merged_cfg + + from hydra.core.override_parser.overrides_parser import OverridesParser + parser = OverridesParser.create() + parsed = parser.parse_overrides(raw_overrides) + + print(cyan(f"Re-applying {len(raw_overrides)} CLI overrides onto merged config.")) + OmegaConf.set_struct(merged_cfg, False) + + # Architecture subtrees: CLI group default fills in *new* fields only; + # checkpoint values win for fields that already exist. + ARCH_KEYS = {"scene_optimizer", "scene_initializer"} + # Sub-keys within ARCH_KEYS where CLI should always win over checkpoint values. + CLI_WINS_SUBKEYS = {"refiner"} + + for override in parsed: + key = override.key_or_group + dotkey = key.replace("/", ".") + + cli_val = OmegaConf.select(orig_cli_cfg, dotkey, default=None, throw_on_resolution_failure=False) + + if cli_val is None: + # No direct config path — e.g. +experiment=re10k is a defaults-list override + # whose effect is already baked into orig_cli_cfg; nothing to apply. + print(cyan(f" Skipping '{key}' (no direct config path in cli)")) + continue + + # For architecture group overrides: fill in missing fields from CLI defaults + # without overriding checkpoint values for fields that already exist. + is_group_override = "/" in key or isinstance(cli_val, (DictConfig, dict, list)) + if is_group_override and any(arch_key in dotkey for arch_key in ARCH_KEYS): + # If the override targets a CLI-wins sub-key directly, CLI wins entirely. 
+ dotkey_parts = set(dotkey.split(".")) + if dotkey_parts & CLI_WINS_SUBKEYS: + OmegaConf.update(merged_cfg, dotkey, cli_val, merge=False) + print(cyan(f" '{dotkey}': replace from cli (CLI wins)")) + continue + + existing_val = OmegaConf.select(merged_cfg, dotkey, default=None) + if existing_val is not None: + # cli_val provides new defaults; existing_val (checkpoint) wins for shared fields + new_val = OmegaConf.merge(cli_val, existing_val) + # Re-apply CLI-wins sub-keys so they override checkpoint values. + for subkey in CLI_WINS_SUBKEYS: + cli_subval = OmegaConf.select(cli_val, subkey, default=None) + if cli_subval is not None: + OmegaConf.set_struct(new_val, False) + OmegaConf.update(new_val, subkey, cli_subval, merge=False) + print(cyan(f" '{dotkey}.{subkey}': CLI override applied (CLI wins)")) + OmegaConf.update(merged_cfg, dotkey, new_val, merge=False) + print(cyan(f" '{dotkey}': fill-missing from cli (checkpoint values preserved)")) + continue + + # Group overrides and complex values replace the whole subtree; + # scalars are merged so sibling keys are preserved. 
+ replace = is_group_override + print(cyan(f" '{dotkey}': {'replace' if replace else 'update'} from cli")) + OmegaConf.update(merged_cfg, dotkey, cli_val, merge=not replace) + + OmegaConf.set_struct(merged_cfg, True) + return merged_cfg + + +def _print_cfg_diff(before: dict, after: dict, prefix: str = "") -> None: + """Recursively print keys that differ between two plain-dict config snapshots.""" + all_keys = set(before) | set(after) + diffs = [] + for k in sorted(all_keys): + full_key = f"{prefix}.{k}" if prefix else k + b_val = before.get(k, "") + a_val = after.get(k, "") + if isinstance(b_val, dict) and isinstance(a_val, dict): + _print_cfg_diff(b_val, a_val, prefix=full_key) + elif b_val != a_val: + diffs.append((full_key, b_val, a_val)) + for full_key, b_val, a_val in diffs: + print(cyan(f" [cfg diff] {full_key}: {b_val!r} → {a_val!r}")) + + +def _find_config_for_checkpoint(ckpt_path) -> Path | None: + """Return the config.yaml path for a given checkpoint, or None.""" + p = Path(ckpt_path).parent.parent / "config.yaml" + if p.exists(): + return p + # Fall back to wandb latest-run + p = Path(ckpt_path).parent.parent / "wandb" / "latest-run" / "files" / "config.yaml" + if p.exists(): + return p + return None + + +def _load_checkpoint_cfg(config_path: Path) -> DictConfig: + """Load, migrate, and (if from wandb) flatten a checkpoint config file.""" + cfg = read_omega_cfg(config_path) + cfg = migrate(cfg) + if "wandb" in str(config_path): + cfg = OmegaConf.create(flatten_wandb(OmegaConf.to_container(cfg, resolve=True))) + return cfg + + +def _patch_scene_initializer(target_cfg: DictConfig, init_config_path: Path, context: str) -> None: + """ + Load scene_trainer.scene_initializer from init_config_path and patch it into target_cfg in-place. + target_cfg must not be struct-protected when this is called. 
+ """ + init_cfg = _load_checkpoint_cfg(init_config_path) + initializer_subcfg = OmegaConf.select(init_cfg, "scene_trainer.scene_initializer", default=None) + if initializer_subcfg is not None: + print(cyan(f"{context}: patching scene_trainer.scene_initializer from pretrained_initializer config.")) + OmegaConf.update(target_cfg, "scene_trainer.scene_initializer", initializer_subcfg, merge=True) + else: + print(cyan("pretrained_initializer config has no scene_trainer.scene_initializer key; skipping patch.")) + + +def _resolve_config_paths(cli_cfg) -> tuple[Path | None, Path | None]: + """ + Determine which config files to load based on CLI checkpointing settings. + + Returns: + config_path: main checkpoint config (optimizer + initializer architecture), or None + initializer_config_path: separate initializer checkpoint config (overrides main for initializer), or None + + Priority for config_path: + resume > pretrained_model > pretrained_optimizer (> pretrained_initializer sets initializer_config_path only) + """ + pretrained_model = cli_cfg.checkpointing.pretrained_model + pretrained_optimizer = cli_cfg.checkpointing.pretrained_optimizer + pretrained_initializer = cli_cfg.checkpointing.pretrained_initializer + should_load = cli_cfg.mode == "test" or cli_cfg.checkpointing.load_existing_cfg + + config_path = None + initializer_config_path = None + + if pretrained_model is not None: + if should_load: + config_path = _find_config_for_checkpoint(pretrained_model) + print(cyan(f"Loading config from pretrained_model checkpoint {config_path}" + if config_path else f"No config found for pretrained_model {pretrained_model}.")) + + elif pretrained_optimizer is not None: + if should_load: + config_path = _find_config_for_checkpoint(pretrained_optimizer) + print(cyan(f"Loading config from pretrained_optimizer checkpoint {config_path}" + if config_path else f"No config found for pretrained_optimizer {pretrained_optimizer}.")) + if pretrained_initializer is not None: + 
initializer_config_path = _find_config_for_checkpoint(pretrained_initializer) + print(cyan(f"Loading initializer config from pretrained_initializer checkpoint {initializer_config_path}" + if initializer_config_path else f"No config found for pretrained_initializer {pretrained_initializer}.")) + + elif pretrained_initializer is not None: + if should_load: + initializer_config_path = _find_config_for_checkpoint(pretrained_initializer) + print(cyan(f"Loading initializer-only config from pretrained_initializer checkpoint {initializer_config_path}" + if initializer_config_path else f"No config found for pretrained_initializer {pretrained_initializer}.")) + + else: + print(cyan("No pretrained_model, pretrained_optimizer, or pretrained_initializer specified, using cli config only.")) + + # Resume overrides config_path to point at the output directory's saved config. + if cli_cfg.checkpointing.resume and cli_cfg.checkpointing.load_existing_cfg: + config_path = Path(cli_cfg.output_dir) / "config.yaml" + print(cyan(f"Resuming: loading config from cfg.output_dir {config_path}")) + else: + print(cyan("Not resuming..")) + + if config_path is not None and not config_path.exists(): + print(cyan(f"Config file {config_path} does not exist. Continuing with cli config only.")) + config_path = None + elif config_path is not None: + print(cyan(f"Found config file {config_path}.")) + + return config_path, initializer_config_path + + +def _merge_test_mode( + cli_cfg: DictConfig, + loaded_cfg: DictConfig, + initializer_config_path: Path | None, + pretrained_initializer: str | None, +) -> tuple[DictConfig, DictConfig]: + """ + Test mode: CLI config is the base for all settings (dataset, test flags, etc.). + Only optimizer and initializer *architecture* are patched in from checkpoint configs. + + Initializer source priority: + 1. separate initializer_config_path (pretrained_initializer ckpt with a config file) + 2. main loaded_cfg (optimizer checkpoint's bundled initializer) + 3. 
CLI config as-is (pretrained_initializer set but has no config file) + + Returns (merged_cfg, orig_cli_cfg); orig_cli_cfg is the snapshot taken before any + checkpoint patches so that _apply_cli_overrides can restore explicit CLI values. + """ + OmegaConf.set_struct(cli_cfg, False) + # Snapshot BEFORE patching: merged_cfg aliases cli_cfg, so patches below also mutate + # cli_cfg. _apply_cli_overrides must see the original CLI values, not the patched ones. + orig_cli_cfg = OmegaConf.create( + OmegaConf.to_container(cli_cfg, resolve=False, throw_on_missing=False) + ) + merged_cfg = cli_cfg # patched in-place + + # Patch optimizer architecture from checkpoint + optimizer_subcfg = OmegaConf.select(loaded_cfg, "scene_trainer.scene_optimizer", default=None) + if optimizer_subcfg is not None: + print(cyan("Test mode: patching scene_trainer.scene_optimizer from checkpoint config.")) + OmegaConf.update(merged_cfg, "scene_trainer.scene_optimizer", optimizer_subcfg, merge=True) + + # Patch initializer architecture (priority order above) + if initializer_config_path is not None and initializer_config_path.exists(): + _patch_scene_initializer(merged_cfg, initializer_config_path, context="Test mode") + elif pretrained_initializer is None: + pass + # TODO Naama + # No explicit initializer checkpoint — fall back to the optimizer checkpoint's initializer + # initializer_subcfg = OmegaConf.select(loaded_cfg, "scene_trainer.scene_initializer", default=None) + # if initializer_subcfg is not None: + # print(cyan("Test mode: patching scene_trainer.scene_initializer from checkpoint config.")) + # OmegaConf.update(merged_cfg, "scene_trainer.scene_initializer", initializer_subcfg, merge=True) + else: + print(cyan("pretrained_initializer set but has no config file; using CLI scene_initializer config.")) + + OmegaConf.set_struct(merged_cfg, True) + return merged_cfg, orig_cli_cfg + + +def _merge_train_mode( + cli_cfg: DictConfig, + loaded_cfg: DictConfig, + initializer_config_path: Path | 
None, +) -> tuple[DictConfig, DictConfig]: + """ + Train mode: checkpoint config takes priority over CLI for all existing fields + (preserves the trained architecture). CLI fills in any new fields added since training. + + If a separate initializer checkpoint is given, its scene_initializer replaces the one + inside loaded_cfg before the full merge, so the right initializer architecture is used. + + Returns (merged_cfg, orig_cli_cfg); orig_cli_cfg is the pre-merge snapshot used + by _apply_cli_overrides to restore explicit CLI values. + """ + if initializer_config_path is not None and initializer_config_path.exists(): + init_cfg = _load_checkpoint_cfg(initializer_config_path) + initializer_subcfg = OmegaConf.select(init_cfg, "scene_trainer.scene_initializer", default=None) + if initializer_subcfg is not None: + print(cyan("Replacing scene_trainer.scene_initializer in loaded config with initializer config.")) + OmegaConf.update(loaded_cfg, "scene_trainer.scene_initializer", initializer_subcfg, merge=False) + else: + print(cyan("pretrained_initializer config has no scene_trainer.scene_initializer key; skipping patch.")) + + orig_cli_cfg = OmegaConf.create( + OmegaConf.to_container(cli_cfg, resolve=False, throw_on_missing=False) + ) + OmegaConf.set_struct(cli_cfg, False) + merged_cfg = OmegaConf.merge(cli_cfg, loaded_cfg) # loaded_cfg wins for existing fields + OmegaConf.set_struct(merged_cfg, True) + return merged_cfg, orig_cli_cfg + + +def merge_config_from_file(cli_cfg): + # 1. Determine which config files to load. + config_path, initializer_config_path = _resolve_config_paths(cli_cfg) + + # 2. No checkpoint config: use CLI as-is, optionally patching in initializer architecture. + if config_path is None: + print(cyan(f"No config file found, using cli config only. 
class SkipRun(Exception):
    """Plain ``Exception`` subclass with no added behaviour.

    NOTE(review): the name suggests it signals that a run should be skipped,
    but no raise or catch site is visible in this part of the file — confirm
    against callers.
    """
+ cfg.output_dir = CustomPath(cfg_dict.output_dir) + output_dir = cfg.output_dir + if output_dir is None: + output_dir = CustomPath( + HydraConfig.get()["runtime"]["output_dir"] + ) + else: # for resuming + output_dir = CustomPath(output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + + if HydraConfig.get().mode == RunMode.MULTIRUN and output_dir == "placeholder": + # Hack to overcome multirun issues + # TODO Naama, need to move to post_init of cfg + output_dir = CustomPath(hydra.core.hydra_config.HydraConfig.get()["run"]["dir"]) + print(cyan(f"Multirun detected, setting output_dir to {CustomPath(output_dir):link}")) + # save checkoint path to a file for debugging + ckpt_path = cfg.checkpointing.pretrained_model or cfg.checkpointing.pretrained_optimizer + (output_dir / "ckpt_dir.txt").write_text(str(ckpt_path)) + cfg_dict.output_dir = output_dir + cfg.output_dir = output_dir + output_dir.mkdir(exist_ok=True, parents=True) + + if cfg.mode == 'test': + if cfg.meta_trainer.test.output_path is None or str(cfg.meta_trainer.test.output_path) in ['placeholder', 'outputs/test']: + cfg.meta_trainer.test.output_path = output_dir + if cfg.meta_trainer.test.compute_scores: + (cfg.meta_trainer.test.output_path / "metrics").mkdir(exist_ok=True, parents=True) + print(cyan(f"Saving outputs to {CustomPath(output_dir):link}.")) + + # Save the config to the output directory. 
+ cfg_dict_path = output_dir / "config.yaml" + + with open(cfg_dict_path, "w") as f: + OmegaConf.save(cfg_dict, f) + + +def get_eval_cfg(cfg_dict): + if "meta_trainer" in cfg_dict: + meta_trainer_dict = cfg_dict["meta_trainer"] + else: + raise ValueError("No trainer or meta_trainer in cfg_dict") + + if cfg_dict["mode"] == "train" and meta_trainer_dict["train"]["eval_model_every_n_val"] > 0: + eval_cfg_dict = deepcopy(cfg_dict) + dataset_dir = str(cfg_dict["dataset"]["roots"]).lower() + if "re10k" in dataset_dir: + if cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 2: + eval_path = "assets/evaluation_index_re10k.json" + elif cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 4: + eval_path = "assets/re10k_start_0_distance_150_ctx_4v_tgt_6v.json" + elif cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 6: + eval_path = "assets/re10k_start_0_distance_200_ctx_6v_tgt_6v.json" + else: + if meta_trainer_dict["eval_index"] is not None: + eval_path = None # placeholder + else: + raise ValueError("unsupported number of views for re10k") + elif "dl3dv" in dataset_dir: + if cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 6: + eval_path = "assets/dl3dv_start_0_distance_50_ctx_6v_tgt_8v.json" + elif cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 2: + eval_path = "assets/dl3dv_start_0_distance_20_ctx_2v_tgt_4v.json" + elif cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 8: + eval_path = "assets/dl3dv_evaluation/dl3dv_start_0_distance_40_ctx_8v_tgt_8v.json" + elif cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 16: + eval_path = "assets/dl3dv_evaluation/dl3dv_start_0_distance_80_ctx_16v_tgt_16v.json" + elif cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 32: + eval_path = "assets/dl3dv_evaluation/dl3dv_start_0_distance_160_ctx_32v_tgt_24v.json" + elif cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 64: + eval_path = "assets/dl3dv_benchmark/dl3dv_ctx_64v_tgt_every8th.json" + elif 
cfg_dict["dataset"]["view_sampler"]["num_context_views"] == -1: + print("Setting manually eval_path, num_context_views remains -1 for dl3dv eval") + eval_path = "assets/dl3dv_evaluation/dl3dv_start_0_distance_40_ctx_8v_tgt_8v.json" + else: + raise ValueError("unsupported number of views for dl3dv") + elif "scannet" in dataset_dir: + if cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 2: + eval_path = "assets/evaluation_index_scannet_view2.json" + else: + raise ValueError("unsupported number of views for scannet") + elif "tartanair" in dataset_dir: + if cfg_dict["dataset"]["view_sampler"]["num_context_views"] == 2: + eval_path = 'assets/evaluation_index_tartanair_view2.json' + else: + raise ValueError("unsupported number of views for tartanair") + else: + raise Exception("Fail to load eval index path") + eval_cfg_dict["dataset"]["view_sampler"] = { + "name": "evaluation", + "index_path": eval_path, + "num_context_views": cfg_dict["dataset"]["view_sampler"]["num_context_views"], + } + + # specify eval index + if meta_trainer_dict["eval_index"] is not None: + eval_cfg_dict["dataset"]["view_sampler"]["index_path"] = meta_trainer_dict["eval_index"] + + eval_cfg = load_typed_root_config(eval_cfg_dict) + else: + eval_cfg = None + return eval_cfg diff --git a/optgs/config/dataset/base.yaml b/optgs/config/dataset/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a1fe6f750a001cf482ddee99afc7c05623328402 --- /dev/null +++ b/optgs/config/dataset/base.yaml @@ -0,0 +1,8 @@ +image_shape: [0, 0] +background_color: [0.0, 0.0, 0.0] +cameras_are_circular: false +overfit_to_scene: null +opencv_pose_format: false +pose_align_middle_view: false + +test_start_idx: 0 \ No newline at end of file diff --git a/optgs/config/dataset/colmap.yaml b/optgs/config/dataset/colmap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a013e5eda245c48782642065c2448c3d9e2a9bc2 --- /dev/null +++ b/optgs/config/dataset/colmap.yaml @@ -0,0 +1,12 
@@ +defaults: + - base + - view_sampler: dense + +name: colmap +roots: null +scene_name: null +normalize_world_space: false +subsample_factor: 8 +symmetric_principal_point: false + +crop_size: null \ No newline at end of file diff --git a/optgs/config/dataset/dl3dv.yaml b/optgs/config/dataset/dl3dv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f709e9e48a9eda6ff5a4c80ebab1f91c23c62228 --- /dev/null +++ b/optgs/config/dataset/dl3dv.yaml @@ -0,0 +1,61 @@ +defaults: + - base + - view_sampler: boundedv2_360 + +name: dl3dv +roots: [datasets/dl3dv] +make_baseline_1: false +augment: true + + +image_shape: [270, 480] + +baseline_epsilon: 1e-3 +max_fov: 100.0 + +skip_bad_shape: true +near: -1. +far: -1. +baseline_scale_bounds: false +shuffle_val: true +test_len: -1 +test_chunk_interval: 1 +sort_target_index: true +sort_context_index: true + +train_times_per_scene: 1 +test_times_per_scene: 1 +ori_image_shape: [270, 480] +overfit_max_views: 148 +use_index_to_load_chunk: false + +mix_tartanair: false +no_mix_test_set: true +load_depth: false +center_pose: false + +pose_align_first_view: false + +scale_extrinsics: 1. 
+metric_scale_align_dl3dv: false + +# view filtering +min_views: 0 +max_views: 0 +highres: false + +# mix re10k & dl3dv +mix_re10k: false +re10k_min_view_dist: 40 +re10k_max_view_dist: 300 + +# load remaining context views +load_remain_context: false +num_remain_context: 8 + +# random crop in training +random_crop: false +min_size: null +max_size: null + +index_name: index.json \ No newline at end of file diff --git a/optgs/config/dataset/re10k.yaml b/optgs/config/dataset/re10k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..be4e8f29ba74f7e8fd6f946c575183d441cb865d --- /dev/null +++ b/optgs/config/dataset/re10k.yaml @@ -0,0 +1,27 @@ +defaults: + - base + - view_sampler: bounded + +name: re10k +roots: [datasets/re10k] +make_baseline_1: false +augment: true + +image_shape: [180, 320] +highres: false + +baseline_epsilon: 1e-3 +max_fov: 100.0 + +skip_bad_shape: true +near: -1. +far: -1. +baseline_scale_bounds: true +shuffle_val: true +test_len: -1 +test_chunk_interval: 1 + +use_index_to_load_chunk: false + +average_pose: false +center_pose: false \ No newline at end of file diff --git a/optgs/config/dataset/scannet.yaml b/optgs/config/dataset/scannet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ee12cc6ff771b572c70769fe0958cf719bfb8d9 --- /dev/null +++ b/optgs/config/dataset/scannet.yaml @@ -0,0 +1,13 @@ +defaults: + - base + - view_sampler: ids + +name: scannet +roots: datasets/quicksplat_spp_data_processed +scene_name: null +split: test +subsample_factor: 1 +num_context_views: 100 +filter_bad_frames: true + +crop_size: null diff --git a/optgs/config/dataset/view_sampler/all.yaml b/optgs/config/dataset/view_sampler/all.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9c49fb4661ff2e30ecd98c7e233e2835c9071014 --- /dev/null +++ b/optgs/config/dataset/view_sampler/all.yaml @@ -0,0 +1 @@ +name: all diff --git a/optgs/config/dataset/view_sampler/arbitrary.yaml 
b/optgs/config/dataset/view_sampler/arbitrary.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c947c2ce3ddf89fac2136ee8133c6662536af3d5 --- /dev/null +++ b/optgs/config/dataset/view_sampler/arbitrary.yaml @@ -0,0 +1,7 @@ +name: arbitrary + +num_target_views: 1 +num_context_views: 2 + +# If you want to hard-code context views, do so here. +context_views: null diff --git a/optgs/config/dataset/view_sampler/bounded.yaml b/optgs/config/dataset/view_sampler/bounded.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e7b67d31a27bfe83598a980cf73ab6a2bd581a9 --- /dev/null +++ b/optgs/config/dataset/view_sampler/bounded.yaml @@ -0,0 +1,12 @@ +name: bounded + +num_target_views: 1 +num_context_views: 2 + +min_distance_between_context_views: 2 +max_distance_between_context_views: 6 +min_distance_to_context_views: 0 + +warm_up_steps: 0 +initial_min_distance_between_context_views: 2 +initial_max_distance_between_context_views: 6 \ No newline at end of file diff --git a/optgs/config/dataset/view_sampler/boundedv2.yaml b/optgs/config/dataset/view_sampler/boundedv2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..065e7ea4c7a9eedccb017c1f3bee916da4fc1648 --- /dev/null +++ b/optgs/config/dataset/view_sampler/boundedv2.yaml @@ -0,0 +1,15 @@ +name: boundedv2 + +num_target_views: 1 +num_context_views: 2 + +min_distance_between_context_views: 2 +max_distance_between_context_views: 6 +max_distance_to_context_views: 0 + +context_gap_warm_up_steps: 0 +target_gap_warm_up_steps: 0 + +initial_min_distance_between_context_views: 2 +initial_max_distance_between_context_views: 6 +initial_max_distance_to_context_views: 0 diff --git a/optgs/config/dataset/view_sampler/boundedv2_360.yaml b/optgs/config/dataset/view_sampler/boundedv2_360.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b6c2609aa231ed1b995249c319ab4992cbc3dd93 --- /dev/null +++ b/optgs/config/dataset/view_sampler/boundedv2_360.yaml @@ 
-0,0 +1,17 @@ +name: boundedv2 + +num_target_views: 4 +num_context_views: 4 + +min_distance_between_context_views: 20 +max_distance_between_context_views: 50 +max_distance_to_context_views: 0 + +context_gap_warm_up_steps: 10000 +target_gap_warm_up_steps: 0 + +initial_min_distance_between_context_views: 15 +initial_max_distance_between_context_views: 30 +initial_max_distance_to_context_views: 0 +extra_views_sampling_strategy: farthest_point +target_views_replace_sample: false diff --git a/optgs/config/dataset/view_sampler/dense.yaml b/optgs/config/dataset/view_sampler/dense.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c8e6dc693bfc35ec76dde13fbf71b81b2cddcecb --- /dev/null +++ b/optgs/config/dataset/view_sampler/dense.yaml @@ -0,0 +1,6 @@ +name: dense + +target_every: 8 +context_every: -1 +num_target_views: -1 +num_context_views: -1 \ No newline at end of file diff --git a/optgs/config/dataset/view_sampler/evaluation.yaml b/optgs/config/dataset/view_sampler/evaluation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3931de80100fc279027bc7285893d11871d546ab --- /dev/null +++ b/optgs/config/dataset/view_sampler/evaluation.yaml @@ -0,0 +1,4 @@ +name: evaluation + +index_path: assets/evaluation_index_re10k_video.json +num_context_views: 2 diff --git a/optgs/config/dataset/view_sampler/ids.yaml b/optgs/config/dataset/view_sampler/ids.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2e0c542e2b3494ebf3e721827af6c960af766284 --- /dev/null +++ b/optgs/config/dataset/view_sampler/ids.yaml @@ -0,0 +1,4 @@ +name: ids + +context_views_ids: [] +target_views_ids: [] \ No newline at end of file diff --git a/optgs/config/dataset/view_sampler_dataset_specific_config/bounded_re10k.yaml b/optgs/config/dataset/view_sampler_dataset_specific_config/bounded_re10k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c8030734831b35eb63677f8f912cd88369e7edf --- /dev/null +++ 
b/optgs/config/dataset/view_sampler_dataset_specific_config/bounded_re10k.yaml @@ -0,0 +1,11 @@ +# @package _global_ + +dataset: + view_sampler: + min_distance_between_context_views: 45 + max_distance_between_context_views: 135 + min_distance_to_context_views: 0 + warm_up_steps: 30000 + initial_min_distance_between_context_views: 25 + initial_max_distance_between_context_views: 45 + num_target_views: 4 diff --git a/optgs/config/dataset/view_sampler_dataset_specific_config/boundedv2_dl3dv.yaml b/optgs/config/dataset/view_sampler_dataset_specific_config/boundedv2_dl3dv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d97b060f90c11180b5c6335b6e6f084422a22acb --- /dev/null +++ b/optgs/config/dataset/view_sampler_dataset_specific_config/boundedv2_dl3dv.yaml @@ -0,0 +1,14 @@ +# @package _global_ + +dataset: + view_sampler: + min_distance_between_context_views: 20 + max_distance_between_context_views: 50 + max_distance_to_context_views: 0 + context_gap_warm_up_steps: 10000 + target_gap_warm_up_steps: 0 + initial_min_distance_between_context_views: 15 + initial_max_distance_between_context_views: 30 + initial_max_distance_to_context_views: 0 + extra_views_sampling_strategy: farthest_point + num_target_views: 4 diff --git a/optgs/config/dataset/view_sampler_dataset_specific_config/evaluation_dl3dv.yaml b/optgs/config/dataset/view_sampler_dataset_specific_config/evaluation_dl3dv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2502f8cb73f52e3f43a00463ef6bcf5108cc6ffa --- /dev/null +++ b/optgs/config/dataset/view_sampler_dataset_specific_config/evaluation_dl3dv.yaml @@ -0,0 +1,5 @@ +# @package _global_ + +dataset: + view_sampler: + index_path: assets/dl3dv_360_v5.json diff --git a/optgs/config/dataset/view_sampler_dataset_specific_config/evaluation_re10k.yaml b/optgs/config/dataset/view_sampler_dataset_specific_config/evaluation_re10k.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..8c9235712813bac9403dc56d99c7d2ca36a10547 --- /dev/null +++ b/optgs/config/dataset/view_sampler_dataset_specific_config/evaluation_re10k.yaml @@ -0,0 +1,5 @@ +# @package _global_ + +dataset: + view_sampler: + index_path: assets/evaluation_index_re10k.json diff --git a/optgs/config/experiment/re10k_unified.yaml b/optgs/config/experiment/re10k_unified.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d9b1a659a9e2b8934566b1aa7338ba23365fdcc --- /dev/null +++ b/optgs/config/experiment/re10k_unified.yaml @@ -0,0 +1,78 @@ +# @package _global_ + +defaults: + - override /dataset: re10k + - override /scene_trainer/scene_initializer: resplat_v1 + - override /scene_trainer/scene_optimizer: learn2splat + - override /loss: [mse, lpips] + +wandb: + name: re10k + tags: [re10k, 256x256] + +data_loader: + train: + batch_size: 14 + +meta_trainer: + max_steps: 300_001 + num_nodes: 1 + test: + eval_time_skip_steps: 5 + compute_scores: true + compute_scores_metrics: [psnr,ssim,lpips] + metrics_batch_size: 32 + +scene_trainer: + initializer: + num_depth_candidates: 128 + costvolume_unet_feat_dim: 128 + costvolume_unet_channel_mult: [1,1,1] + costvolume_unet_attn_res: [4] + gaussians_per_pixel: 1 + depth_unet_feat_dim: 32 + depth_unet_attn_res: [16] + depth_unet_channel_mult: [1,1,1,1,1] + shim_patch_size: 16 + use_fsdp: false + train_scene_init: false + train_scene_opt: false + num_update_steps: 0 + iter_batch_size: -1 + opt_batch_size: -1 + train_min_refine: 0 + train_max_refine: 0 + + +# lpips loss +loss: + lpips: + apply_after_step: 0 + weight: 0.5 + perceptual_loss: true + deltas: + weight: 0.0 + exclude_by_norm_grad: false + exclude_by_norm_grad_opposite: true + eps: 1e-8 + apply_after_step: 10000000 + + +dataset: + image_shape: [256, 256] + roots: [datasets/re10k] + near: 0.01 + far: 100. 
+ baseline_scale_bounds: false + make_baseline_1: false + train_times_per_scene: 1 + highres: false + scannet: false + tartanair: false + load_depth: false + pose_align_first_view: false + scale_extrinsics: 1. + load_remain_context: false + pose_align_middle_view: false + overfit_to_scene: null + opencv_pose_format: false diff --git a/optgs/config/experiment/test_colmap.yaml b/optgs/config/experiment/test_colmap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..faca0e651c688e7d3200e5d59cd6785194d29b3c --- /dev/null +++ b/optgs/config/experiment/test_colmap.yaml @@ -0,0 +1,32 @@ +# @package _global_ + +defaults: + - override /dataset: colmap + - override /scene_trainer/scene_initializer: null # overridden by init_opts.sh + - override /scene_trainer/scene_optimizer: null # overridden by checkpoint (ours) or CLI (baselines) + - override /scene_trainer/decoder: gsplat + - override /loss: [mse] + - override /meta_trainer/test/postprocessing: none + +mode: test + +scene_trainer: + train_scene_init: false + train_scene_opt: false + opt_batch_strategy: fps + +checkpointing: + pretrained_model: null + pretrained_depth: null + +meta_trainer: + test: + compute_scores: true + skip_if_outputs_exist: true + save_cameras_json: false + save_render_image: false + save_gaussian: false + eval_initialization: false + +output_dir: placeholder +log_slurm_id: true diff --git a/optgs/config/experiment/test_dl3dv.yaml b/optgs/config/experiment/test_dl3dv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ae67b044ab3eb7910726f0cc25bf107048a9f9a --- /dev/null +++ b/optgs/config/experiment/test_dl3dv.yaml @@ -0,0 +1,38 @@ +# @package _global_ + +defaults: + - override /dataset: dl3dv + - override /scene_trainer/scene_initializer: null # overridden by checkpoint (ours) or init_opts.sh + - override /scene_trainer/scene_optimizer: null # overridden by checkpoint (ours) or CLI (baselines) + - override /scene_trainer/decoder: gsplat + - override 
/meta_trainer/test/postprocessing: none + +mode: test + +dataset: + roots: [datasets/dl3dv-480p-chunks] + near: 0.01 + far: 200. + opencv_pose_format: false + image_shape: [256, 448] + +scene_trainer: + train_scene_init: false + train_scene_opt: false + opt_batch_strategy: fps + +checkpointing: + pretrained_model: null + pretrained_depth: null + +meta_trainer: + test: + compute_scores: true + skip_if_outputs_exist: false + save_cameras_json: false + save_render_image: false + save_gaussian: false + eval_initialization: false + +output_dir: placeholder +log_slurm_id: true diff --git a/optgs/config/experiment/test_re10k.yaml b/optgs/config/experiment/test_re10k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e9b0c8d6a9f39fd9a37d8e1032b3ee900ee3880 --- /dev/null +++ b/optgs/config/experiment/test_re10k.yaml @@ -0,0 +1,36 @@ +# @package _global_ + +defaults: + - override /dataset: re10k + - override /scene_trainer/scene_initializer: resplat_v1 + - override /scene_trainer/scene_optimizer: knn_based + - override /scene_trainer/decoder: gsplat + - override /loss: [mse] + - override /meta_trainer/test/postprocessing: none + +mode: test + +dataset: + image_shape: [512, 960] + ori_image_shape: [512, 960] + +scene_trainer: + train_scene_init: false + train_scene_opt: false + opt_batch_strategy: fps + +checkpointing: + pretrained_model: null + pretrained_depth: null + +meta_trainer: + test: + compute_scores: true + skip_if_outputs_exist: true + save_cameras_json: false + save_render_image: false + save_gaussian: false + eval_initialization: false + +output_dir: placeholder +log_slurm_id: true diff --git a/optgs/config/experiment/train_dl3dv.yaml b/optgs/config/experiment/train_dl3dv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..160f746688a3b752b777d155a54fb4666da47dbb --- /dev/null +++ b/optgs/config/experiment/train_dl3dv.yaml @@ -0,0 +1,55 @@ +# @package _global_ + +# A shared config for training on dl3dv, used by both 
resplat initializer, resplat optimizer, and learn2splat optimizer. + +defaults: + - override /dataset: dl3dv + - override /scene_trainer/scene_initializer: resplat_v1 + - override /scene_trainer/scene_optimizer: learn2splat + - override /loss: [ mse, lpips ] + - override /dataset/view_sampler: boundedv2_360 + +wandb: + name: dl3dv + tags: [ dl3dv, 270x480 ] + +data_loader: + train: + batch_size: 1 + +meta_trainer: + max_steps: 50_000 + val_check_interval: 0.25 + train: + l1_loss: true + depth_smooth_loss_weight: 0.0 + test: + eval_time_skip_steps: 0 + dec_chunk_size: 30 + save_every_freq: [ 1, 10, 100, 500 ] + save_every_steps: [ 0, 10, 100, 1000 ] + +# lpips loss +loss: + lpips: + apply_after_step: 0 + weight: 0.5 + perceptual_loss: true + +dataset: + roots: [ datasets/dl3dv-480p-chunks ] + near: 0.01 + far: 200. + min_size: [ 384,512 ] + max_size: [ 512,960 ] + image_shape: [ 256, 448 ] + view_sampler: + num_context_views: 8 + num_target_views: 6 + min_distance_between_context_views: 24 + max_distance_between_context_views: 45 + initial_min_distance_between_context_views: 20 + initial_max_distance_between_context_views: 30 + +output_dir: placeholder +log_slurm_id: true \ No newline at end of file diff --git a/optgs/config/experiment/train_l2s_sparse_dl3dv.yaml b/optgs/config/experiment/train_l2s_sparse_dl3dv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..232e95237564afac18aae815bf1da5e44a518500 --- /dev/null +++ b/optgs/config/experiment/train_l2s_sparse_dl3dv.yaml @@ -0,0 +1,41 @@ +# @package _global_ + +defaults: + - train_dl3dv + - override /meta_trainer/train/replay_buffer_cfg: default + - override /loss: [ mse, lpips, deltas ] + +loss: + mse: + weight: 1.0 + lpips: + apply_after_step: 0 + weight: 0.5 + perceptual_loss: true + deltas: + weight: 1 + exclude_by_norm_grad: true + exclude_by_norm_grad_opposite: true + eps: 1e-8 + apply_after_step: 100 + +meta_trainer: + train: + loss_on_input_views: true + loss_on_input_views_num: 4 + 
use_replay_buffer: true + +scene_trainer: + train_scene_opt: true + num_update_steps: 4 + train_max_refine: 6 + train_min_refine: 1 + +meta_optimizer: + lr: 1e-4 + lr_monodepth: 0.0 + + +checkpointing: + pretrained_initializer: checkpoints/optgs/unified-dl3dv-8views/init/checkpoints/epoch_20-step_100000.ckpt # resplat initializer + no_strict_load: false \ No newline at end of file diff --git a/optgs/config/experiment/train_l2s_sparse_dl3dv_no_delta.yaml b/optgs/config/experiment/train_l2s_sparse_dl3dv_no_delta.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc160feac7dd19ee8f524fe16e5e2a1dd440ebc4 --- /dev/null +++ b/optgs/config/experiment/train_l2s_sparse_dl3dv_no_delta.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +defaults: + - train_dl3dv + - override /meta_trainer/train/replay_buffer_cfg: default + - override /loss: [ mse, lpips ] + +loss: + mse: + weight: 1.0 + lpips: + apply_after_step: 0 + weight: 0.5 + perceptual_loss: true + +meta_trainer: + train: + loss_on_input_views: true + loss_on_input_views_num: 4 + use_replay_buffer: true + +scene_trainer: + train_scene_opt: true + num_update_steps: 4 + train_max_refine: 6 + train_min_refine: 1 + +meta_optimizer: + lr: 1e-4 + lr_monodepth: 0.0 + + +checkpointing: + pretrained_initializer: checkpoints/optgs/unified-dl3dv-8views/init/checkpoints/epoch_20-step_100000.ckpt # resplat initializer + no_strict_load: false \ No newline at end of file diff --git a/optgs/config/experiment/train_l2s_sparse_dl3dv_no_loss.yaml b/optgs/config/experiment/train_l2s_sparse_dl3dv_no_loss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc160feac7dd19ee8f524fe16e5e2a1dd440ebc4 --- /dev/null +++ b/optgs/config/experiment/train_l2s_sparse_dl3dv_no_loss.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +defaults: + - train_dl3dv + - override /meta_trainer/train/replay_buffer_cfg: default + - override /loss: [ mse, lpips ] + +loss: + mse: + weight: 1.0 + lpips: + apply_after_step: 0 +
weight: 0.5 + perceptual_loss: true + +meta_trainer: + train: + loss_on_input_views: true + loss_on_input_views_num: 4 + use_replay_buffer: true + +scene_trainer: + train_scene_opt: true + num_update_steps: 4 + train_max_refine: 6 + train_min_refine: 1 + +meta_optimizer: + lr: 1e-4 + lr_monodepth: 0.0 + + +checkpointing: + pretrained_initializer: checkpoints/optgs/unified-dl3dv-8views/init/checkpoints/epoch_20-step_100000.ckpt # resplat initializer + no_strict_load: false \ No newline at end of file diff --git a/optgs/config/loss/deltas.yaml b/optgs/config/loss/deltas.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e8f1792e364044ed28e7370a03b643aab16f4529 --- /dev/null +++ b/optgs/config/loss/deltas.yaml @@ -0,0 +1,6 @@ +deltas: + weight: 1.0 + exclude_by_norm_grad: false + exclude_by_norm_grad_opposite: true + eps: 0.1 + apply_after_step: 100 diff --git a/optgs/config/loss/gaussians.yaml b/optgs/config/loss/gaussians.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d9816234657459490fd157de00b1593e620a91a1 --- /dev/null +++ b/optgs/config/loss/gaussians.yaml @@ -0,0 +1,6 @@ +gaussians: + weight: 1.0 + weight_scales: 0.01 + weight_opacities: 0.0 + weight_sh: 0.005 + sh_alpha: 1.0 # 1.0 = uniform; >1.0 = penalize higher SH degrees more diff --git a/optgs/config/loss/iso_scales.yaml b/optgs/config/loss/iso_scales.yaml new file mode 100644 index 0000000000000000000000000000000000000000..481f86e547ba2a85c8ef161b90d6962fc53e4d70 --- /dev/null +++ b/optgs/config/loss/iso_scales.yaml @@ -0,0 +1,2 @@ +iso_scales: + weight: 1.0 diff --git a/optgs/config/loss/lpips.yaml b/optgs/config/loss/lpips.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cb3aafcc76f7a2219e6c2ab178bd37a2af9f6472 --- /dev/null +++ b/optgs/config/loss/lpips.yaml @@ -0,0 +1,4 @@ +lpips: + weight: 0.05 + apply_after_step: 150_000 + perceptual_loss: false diff --git a/optgs/config/loss/mse.yaml b/optgs/config/loss/mse.yaml new file
mode 100644 index 0000000000000000000000000000000000000000..80cc0be6dc7950661998336bd2bb5cb4ff06ba07 --- /dev/null +++ b/optgs/config/loss/mse.yaml @@ -0,0 +1,2 @@ +mse: + weight: 1.0 diff --git a/optgs/config/loss/sgd.yaml b/optgs/config/loss/sgd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cb59020610feeeaabd07e7c6ac05b11ad8094712 --- /dev/null +++ b/optgs/config/loss/sgd.yaml @@ -0,0 +1,2 @@ +sgd: + weight: 1.0 \ No newline at end of file diff --git a/optgs/config/loss/sh0.yaml b/optgs/config/loss/sh0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..80cc0be6dc7950661998336bd2bb5cb4ff06ba07 --- /dev/null +++ b/optgs/config/loss/sh0.yaml @@ -0,0 +1,2 @@ +mse: + weight: 1.0 diff --git a/optgs/config/loss/ssim.yaml b/optgs/config/loss/ssim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..92534c7128150e54ea745e651c808d436a4f8cb7 --- /dev/null +++ b/optgs/config/loss/ssim.yaml @@ -0,0 +1,2 @@ +ssim: + weight: 0.2 # default in 3dgs \ No newline at end of file diff --git a/optgs/config/loss/stability.yaml b/optgs/config/loss/stability.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4a9abc4af85dec6c037aeb1b029e297ca401bbe --- /dev/null +++ b/optgs/config/loss/stability.yaml @@ -0,0 +1,2 @@ +stability: + weight: 1.0 diff --git a/optgs/config/main.yaml b/optgs/config/main.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bbd128507fc8b4c60e520875c503df4445e17ce9 --- /dev/null +++ b/optgs/config/main.yaml @@ -0,0 +1,195 @@ +defaults: + - loss: [ mse ] + - dataset: null + - scene_trainer/scene_initializer: null + - scene_trainer/scene_optimizer: null + - scene_trainer/decoder: gsplat + - meta_trainer/test/postprocessing: none + - meta_trainer/train/replay_buffer_cfg: none + +wandb: + project: placeholder + entity: placeholder + name: placeholder + mode: online + id: null + notes: null + +mode: train + +data_loader: + train: + num_workers: 10 + 
persistent_workers: true + batch_size: 4 + seed: 1234 + test: + num_workers: 4 + persistent_workers: false + batch_size: 1 + seed: 2345 + val: + num_workers: 1 + persistent_workers: true + batch_size: 1 + seed: 3456 + +meta_optimizer: + lr: 2.e-4 + lr_monodepth: 2.e-6 + lr_depth: 0. + warm_up_steps: 2000 + weight_decay: 0.01 + warm_up_ratio: 0.01 + adamw_8bit: false + +checkpointing: + load: null + every_n_train_steps: 1000 + save_top_k: 5 + pretrained_model: null + pretrained_model_rel_dir: ${checkpoint_rel_dir:${checkpointing.pretrained_model}} + pretrained_monodepth: null + pretrained_mvdepth: null + pretrained_depth: null + pretrained_scale_predictor: null + pretrained_depth_teacher: null + no_strict_load: false + resume: false + no_resume_upsampler: false + partial_load: false + freeze_mono_vit: false + resume_update_module: null + pretrained_initializer: null + pretrained_optimizer: null + load_existing_cfg: false + +seed: 111123 + +meta_trainer: + max_steps: -1 + val_check_interval: 0.5 + gradient_clip_val: 0.5 + num_sanity_val_steps: 2 + eval_index: null + limit_test_batches: 1.0 + limit_train_batches: 1.0 + num_nodes: 1 + train: + depth_mode: null + extended_visualization: false + print_log_every_n_steps: 100 + eval_model_every_n_val: 2 # quantitative evaluation every n val + eval_data_length: 999999 + eval_deterministic: false + eval_time_skip_steps: 3 + eval_save_model: true + l1_loss: false + intermediate_loss_weight: 0.9 + no_viz_video: false + eval_depth: false + train_ignore_large_loss: 0. + no_log_projections: true + no_log_video: true + depth_loss_weight: 0. + log_depth_loss: true + depth_smooth_loss_weight: 0.01 + depth_smooth_loss_nonorm: false + depth_smooth_loss_weight_nvs: 0. # for novel views + monodepth_loss_weight: 0. # for monocular depth loss + depth_teacher_loss_weight: 0. + viz_depth_teacher: false + eval_render_depth: false + render_depth_loss_weight: 0. 
+ viz_render_depth: false + use_gt_depth_range: false + depth_range_from_disparity: false + max_disparity: 128. + min_disparity: 4. + loss_on_input_views: false + loss_on_target_views: true + loss_on_input_views_num: 1 + loss_on_target_views_num: -1 + train_window_size: null + half_res_lpips_loss: false + viz_depth_separate: false + # L2 weight decay on Gaussian properties (meta-loss) + scale_l2_loss_weight: 0. + sh_l2_loss_weight: 0. + opacity_l2_loss_weight: 0. + use_replay_buffer: false + test: + output_path: null + compute_scores: true + compute_scores_metrics: [psnr,ssim,lpips] + metrics_batch_size: 32 + eval_time_skip_steps: 0 + eval_initialization: true + save_render_image: false + save_render_image_last_only: false + save_gt_image: false + save_render_depth: false + save_gt_depth: false + save_error_image: false + save_video: false + save_video_fixed_view: false + save_video_fixed_view_index: 0 + save_video_fixed_view_duplicate: 0 + save_video_fixed_iteration: false + save_video_fixed_iteration_indices: null + save_video_fixed_iteration_render_fixed_view: false + save_video_combined: false + save_video_combined_iterations: null + save_video_combined_fixed_iteration_length: 50 + save_gaussian: false + save_poses: false + save_cameras_json: true + save_cameras_npz: true + save_point_cloud: false + render_chunk_size: null + dec_chunk_size: null + stablize_camera: false + stab_camera_kernel: 50 + eval_context_views: false + inference_window_size: null + profile_model: false + save_colmap_train_test_views: false + ori_colmap_data_path: null + adam_optimizer_step: 0 + save_at_iters: null + save_every_freq: null + save_every_steps: null + skip_if_outputs_exist: false + scenes_filter: null + + experimental_add_noise_to_images: false + experimental_add_noise_to_images_std: null + +scene_trainer: + use_fsdp: false + train_scene_init: false + train_scene_opt: false + train_min_refine: 0 + train_max_refine: 0 + num_update_steps: 0 + iter_batch_size: -1 + 
opt_batch_size: -1 + opt_batch_size_min: 0 + opt_batch_size_max: 0 + opt_batch_strategy: random + sh_degree_interval: 0 + +output_dir: null + +use_plugins: false + +log_slurm_id: false + +version: null + +profiling: + # one of: none, basic, advanced, pytorch + # advanced profiling requires pytorch-lightning-2.5.3 (default: 2.4.0) + mode: none + +debug_cfg: false diff --git a/optgs/config/meta_trainer/test/postprocessing/adam.yaml b/optgs/config/meta_trainer/test/postprocessing/adam.yaml new file mode 100644 index 0000000000000000000000000000000000000000..56ea1d0dddecba1c64f0fbaa59123ad8adcd8603 --- /dev/null +++ b/optgs/config/meta_trainer/test/postprocessing/adam.yaml @@ -0,0 +1,10 @@ +defaults: + - base + +name: adam +lr_data: + _base: 0.001 +betas: [0.9, 0.999] +weight_decay: 0.0 +amsgrad: false +eps: 1e-08 diff --git a/optgs/config/meta_trainer/test/postprocessing/base.yaml b/optgs/config/meta_trainer/test/postprocessing/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ff3162fa7786b9f40fc243af2c1cdc2f115b2959 --- /dev/null +++ b/optgs/config/meta_trainer/test/postprocessing/base.yaml @@ -0,0 +1,24 @@ +steps: 2000 +compute_metrics_every: 100 +lr_data: + _base: 1 + _means: 1 + _scales: 1 + _opacities: 1 + _quats: 1 + _sh0: 1 + _shN: 1 +scheduler: null +scheduler_warm_up_ratio: 0.01 +prior_steps: 0 + +# Means LR scheduling (defaults match vanilla optimizer) +means_lr_final_ratio: 0.0625 # ratio of final/initial means LR (vanilla: 1e-5 / 1.6e-4) +means_lr_delay_mult: 0.01 # ramp-up delay multiplier (vanilla default) +means_lr_scale_by_scene_extent: true + +# View chunking for gradient accumulation +chunk_size: -1 # -1 = all views at once + +# ADC (Adaptive Density Control) - null = disabled +adc: null \ No newline at end of file diff --git a/optgs/config/meta_trainer/test/postprocessing/none.yaml b/optgs/config/meta_trainer/test/postprocessing/none.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..72ad4973fd31722f879602949c6cc4363570f493 --- /dev/null +++ b/optgs/config/meta_trainer/test/postprocessing/none.yaml @@ -0,0 +1,5 @@ +defaults: + - base + +name: none +steps: 0 diff --git a/optgs/config/meta_trainer/test/postprocessing/sgd.yaml b/optgs/config/meta_trainer/test/postprocessing/sgd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..56ce522a0138a50723a718435a12530673c7dd86 --- /dev/null +++ b/optgs/config/meta_trainer/test/postprocessing/sgd.yaml @@ -0,0 +1,7 @@ +defaults: + - base + +name: sgd +momentum: 0.0 +weight_decay: 0.0 +nesterov: false diff --git a/optgs/config/meta_trainer/test/postprocessing/vanilla_3dgs.yaml b/optgs/config/meta_trainer/test/postprocessing/vanilla_3dgs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5d2b45aeb5ef632d2019b61bcdfb8673e93ce4df --- /dev/null +++ b/optgs/config/meta_trainer/test/postprocessing/vanilla_3dgs.yaml @@ -0,0 +1,12 @@ +defaults: + - base + - adam + +lr_data: + _base: 1 + _means: 1.6e-4 + _scales: 5e-3 + _opacities: 5e-2 + _quats: 1e-3 + _sh0: 2.5e-3 + _shN: 1.25e-4 # 2.5e-3 / 20 \ No newline at end of file diff --git a/optgs/config/meta_trainer/test/postprocessing/vanilla_3dgs_sgd.yaml b/optgs/config/meta_trainer/test/postprocessing/vanilla_3dgs_sgd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd4c71627d9aa436cda54a3c421a23e1d33f02a9 --- /dev/null +++ b/optgs/config/meta_trainer/test/postprocessing/vanilla_3dgs_sgd.yaml @@ -0,0 +1,12 @@ +defaults: + - base + - sgd + +lr_data: + _base: 1 + _means: 1.6e-4 + _scales: 5e-3 + _opacities: 5e-2 + _quats: 1e-3 + _sh0: 2.5e-3 + _shN: 1.25e-4 # 2.5e-3 / 20 \ No newline at end of file diff --git a/optgs/config/meta_trainer/train/replay_buffer_cfg/default.yaml b/optgs/config/meta_trainer/train/replay_buffer_cfg/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..df3d239d2ea4207f77ce8208ce39c58a9d7fa87a --- /dev/null 
+++ b/optgs/config/meta_trainer/train/replay_buffer_cfg/default.yaml @@ -0,0 +1,12 @@ +capacity: 20 +sample_batch_size: 1 +sample_prob: 0.7 +insert_prob: 0.7 +return_prob: 0.99 +simulate_ahead: true +simulate_ahead_min_steps: 1 +simulate_ahead_max_steps: 50 +simulate_ahead_grow: 10000 +max_t: null +push_only_if_not_full: false +remove_strategy_when_full: oldest \ No newline at end of file diff --git a/optgs/config/meta_trainer/train/replay_buffer_cfg/none.yaml b/optgs/config/meta_trainer/train/replay_buffer_cfg/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..06e1f12d212f6187e107b6175a85a0dd0ae4f724 --- /dev/null +++ b/optgs/config/meta_trainer/train/replay_buffer_cfg/none.yaml @@ -0,0 +1,12 @@ +capacity: 0 +sample_batch_size: 1 +sample_prob: 0.0 +insert_prob: 0.0 +return_prob: 0.0 +simulate_ahead: false +simulate_ahead_min_steps: 0 +simulate_ahead_max_steps: 0 +simulate_ahead_grow: 0 +max_t: null +push_only_if_not_full: false +remove_strategy_when_full: oldest \ No newline at end of file diff --git a/optgs/config/scene_trainer/decoder/gsplat.yaml b/optgs/config/scene_trainer/decoder/gsplat.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e8cfaa75b43416e67b455c371a87f4aabd67c884 --- /dev/null +++ b/optgs/config/scene_trainer/decoder/gsplat.yaml @@ -0,0 +1,4 @@ +name: gsplat +use_covariances: false +rasterize_mode: antialiased +eps2d: 0.3 \ No newline at end of file diff --git a/optgs/config/scene_trainer/decoder/inria.yaml b/optgs/config/scene_trainer/decoder/inria.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b04fae650230cf5cf5f48801b0bc34eff9567806 --- /dev/null +++ b/optgs/config/scene_trainer/decoder/inria.yaml @@ -0,0 +1,3 @@ +name: inria +scale_invariant: false +use_covariances: false diff --git a/optgs/config/scene_trainer/decoder/splatting_cuda.yaml b/optgs/config/scene_trainer/decoder/splatting_cuda.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..6c7a90a1e571354d6720f80419c4c9a5359afee9 --- /dev/null +++ b/optgs/config/scene_trainer/decoder/splatting_cuda.yaml @@ -0,0 +1,2 @@ +name: gsplat +scale_invariant: false \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_initializer/base.yaml b/optgs/config/scene_trainer/scene_initializer/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d6197fc94d11a622ddd35f4ff6aaf88744af6af4 --- /dev/null +++ b/optgs/config/scene_trainer/scene_initializer/base.yaml @@ -0,0 +1,9 @@ +per_pixel: false +per_view: false + +train_min_gaussians_subsample: null +train_max_gaussians_subsample: null +eval_min_gaussians_subsample: null +eval_max_gaussians_subsample: null +train_fixed_gaussians_num: null +eval_fixed_gaussians_num: null \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_initializer/colmap.yaml b/optgs/config/scene_trainer/scene_initializer/colmap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9a630c6f96dbeb231836daf6967f736d2e863efe --- /dev/null +++ b/optgs/config/scene_trainer/scene_initializer/colmap.yaml @@ -0,0 +1,24 @@ +defaults: + - base + +name: colmap + +path: null +normalize_world_space: false +scaling_factor: 1.0 +init_opacity: 0.1 +sh_degree: 3 +dl3dv_settings: false + +train_fixed_gaussians_num: 70_000 # For DDP training, number of Gaussians should be the same across all processes. +# By default, testing should use all gaussians, but in validation during training, we should still use a fixed number. +# Should be set in the training script +eval_fixed_gaussians_num: null +filter_zero_rgb: false +points3d_subdir: null # if set, loads points3D from this subdir instead of the default colmap dir (cameras/images unaffected) +points3d_ply_filename: null # if set, loads points from this PLY file (relative to scene dir, e.g. 
"input.ply") instead of COLMAP binary +randomize_opacity: false # When true, randomizes opacity values +randomize_opacity_distribution: "uniform" # Options: "uniform" (min to init_opacity) or "gaussian" (around mean) +randomize_opacity_min: 0.0 # Minimum value for uniform distribution (only used when distribution is "uniform") +randomize_opacity_std: 0.05 # Standard deviation for gaussian distribution (only used when distribution is "gaussian") +override_dataset_poses: true # When true, overrides dataset poses with COLMAP poses diff --git a/optgs/config/scene_trainer/scene_initializer/edgs.yaml b/optgs/config/scene_trainer/scene_initializer/edgs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..265bae8d718fae849f69ec40ae72f0a087f77d7e --- /dev/null +++ b/optgs/config/scene_trainer/scene_initializer/edgs.yaml @@ -0,0 +1,10 @@ +defaults: + - base + +name: edgs + +sh_degree: 3 +init_opacity: 0.5 +scaling_factor: 0.5 +roma_model_type: outdoors +sample_init_gaussians: -1 \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_initializer/ply.yaml b/optgs/config/scene_trainer/scene_initializer/ply.yaml new file mode 100644 index 0000000000000000000000000000000000000000..86fa26d1a2d6683f1605b253645a04715e4a0780 --- /dev/null +++ b/optgs/config/scene_trainer/scene_initializer/ply.yaml @@ -0,0 +1,8 @@ +defaults: + - base + +name: ply + +path: null +sh_degree: 3 +ply_filename: "gaussians.ply" \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_initializer/pointcloud.yaml b/optgs/config/scene_trainer/scene_initializer/pointcloud.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f977e6b8662b3fcd1b8ce6f930891344945e08f5 --- /dev/null +++ b/optgs/config/scene_trainer/scene_initializer/pointcloud.yaml @@ -0,0 +1,17 @@ +defaults: + - base + +name: pointcloud + +path: null +scaling_factor: 1.0 +init_opacity: 0.1 +sh_degree: 3 +filter_zero_rgb: true +# Transform from PLY coordinate system to 
camera coordinate system. +# For ScanNet++/NeRFstudio: (x,y,z) -> (y,x,-z) +world_transform: + - [0, 1, 0, 0] + - [1, 0, 0, 0] + - [0, 0, -1, 0] + - [0, 0, 0, 1] diff --git a/optgs/config/scene_trainer/scene_initializer/random.yaml b/optgs/config/scene_trainer/scene_initializer/random.yaml new file mode 100644 index 0000000000000000000000000000000000000000..97c6b1faab9bbbca0bbb5ba561a97e8f043730c2 --- /dev/null +++ b/optgs/config/scene_trainer/scene_initializer/random.yaml @@ -0,0 +1,10 @@ +defaults: + - base + +name: random + +init_num_pts: 100000 +init_extent: 3.0 +scaling_factor: 1.0 +init_opacity: 0.1 +sh_degree: 3 \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_initializer/resplat_v1.yaml b/optgs/config/scene_trainer/scene_initializer/resplat_v1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce18ce1a2f8a9c0e340c799a4aa4e8f698eac0af --- /dev/null +++ b/optgs/config/scene_trainer/scene_initializer/resplat_v1.yaml @@ -0,0 +1,151 @@ +defaults: + - base + +per_pixel: true +per_view: true + +name: resplat_v1 + +num_depth_candidates: 128 +num_surfaces: 1 + +gaussians_per_pixel: 1 + +gaussian_adapter: + gaussian_scale_min: 0.5 + gaussian_scale_max: 0.3 + sh_degree: 3 + exp_scale: false + softplus_scale: true + clamp_min_scale: 1e-6 + scale_detach_depth: false + exp_scale_bias: 4. 
+ no_rotate_sh: true + no_sh_mask: true + init_rotation_identity: false + +d_feature: 128 + +visualizer: + num_samples: 8 + min_resolution: 256 + export_ply: false + +unimatch_weights_path: "pretrained/gmdepth-scale1-resumeflowthings-scannet-5d9d7964.pth" +multiview_trans_attn_split: 2 +downscale_factor: 4 +shim_patch_size: 16 + +local_mv_match: 2 + +# monodepth +monodepth_vit_type: vits + +# return depth +return_depth: true + +# mv_unimatch +num_scales: 1 +upsample_factor: 8 +lowest_feature_resolution: 8 +depth_unet_channels: 128 +grid_sample_disable_cudnn: false + +# depthsplat color branch +large_gaussian_head: false +color_large_unet: false +init_sh_input_img: true +feature_upsampler_channels: 64 +gaussian_regressor_channels: 256 +unet_gaussian_regressor: false +resnet_gaussian_regressor: false + +# only depth +train_depth_only: false + +# point transformer +pt_head: true +pt_heads: 1 +init_pt_with_mv_attn: false +init_pt_with_mv_attn_lowres: false +pt_head_channels: null +pt_head_concat_img: false +pt_head_conv: false +multi_scale_pt: false +attn_proj_channels: 64 +fps_num_samples: null +knn_samples: 16 +post_norm: false +no_rpe: true +no_knn_attn: false +num_blocks: 4 +pt_downsample: 0 +fps_agg_func: attn +subsample_method: fps +add_pt_residual: true +pt_pred_residual_position: false + +# freeze depth +freeze_depth: false +use_gt_depth: false + +# separate depth & color +separate_depth_color: false +separate_depth_type: small +separate_depth_gaussian_scale: false + +sample_log_depth: true +bilinear_upsample_depth: false +no_upsample_depth: false +return_lowres_depth: false + +# lvsm gaussian regressor +lvsm_gaussian_regressor: false +lvsm_layers: 6 + +# latent gaussian instead of pixel aligned gaussians +latent_gs: true +latent_downsample: 4 +fixed_latent_size: true +latent_gs_img_interp: area +dpt_head_depth: false +latent_dpt_upsampler: false +latent_dpt_upsampler_no_concat: false +light_dpt_feature: false +avgpool_depth: false +nearest_down_depth: false + 
+# predict scene scale with a pretrained depth model +predict_scale: false +norm_by_points: false +no_pred_depth_range: false + +point_dist_init_gaussian_scale: false + +resizeconv_upsampler: false + +# handle high resolution images +depth_pred_half_res: false + +use_amp: true +pt_head_amp: true + +use_fsdp: false +use_checkpointing: false +init_use_checkpointing: false + +# refactor: new gaussian parameter order +rotate_quat_to_world: false +refine_rotate_quat_to_world: false +refine_no_use_covariance: false +latent_new_reshape: false + + + +no_pixel_offset: false + +init_gaussian_multiple: 1 +deform_sample_depth: false +deform_sample_depth_debug: false + + diff --git a/optgs/config/scene_trainer/scene_initializer/resplat_v2.yaml b/optgs/config/scene_trainer/scene_initializer/resplat_v2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7fe1337fc7f108d18dcbef59cd09977c307cf263 --- /dev/null +++ b/optgs/config/scene_trainer/scene_initializer/resplat_v2.yaml @@ -0,0 +1,157 @@ +defaults: + - base + +per_pixel: true +per_view: true + +name: resplat_v2 + +num_depth_candidates: 128 +num_surfaces: 1 + +gaussians_per_pixel: 1 + +gaussian_adapter: + gaussian_scale_min: 0.5 + gaussian_scale_max: 1.0 + sh_degree: 3 + exp_scale: false + softplus_scale: true + clamp_min_scale: 1e-6 + scale_detach_depth: false + exp_scale_bias: 4. 
+ no_rotate_sh: true + no_sh_mask: true + init_rotation_identity: false + +d_feature: 128 + +visualizer: + num_samples: 8 + min_resolution: 256 + export_ply: false + +unimatch_weights_path: "pretrained/gmdepth-scale1-resumeflowthings-scannet-5d9d7964.pth" +multiview_trans_attn_split: 2 +costvolume_unet_feat_dim: 128 +costvolume_unet_channel_mult: [1,1,1] +costvolume_unet_attn_res: [] +depth_unet_feat_dim: 64 +depth_unet_attn_res: [] +depth_unet_channel_mult: [1, 1, 1] +downscale_factor: 4 +shim_patch_size: 4 + +local_mv_match: 2 + +# monodepth +monodepth_vit_type: vitb + +# return depth +return_depth: true + +# mv_unimatch +num_scales: 1 +upsample_factor: 8 +lowest_feature_resolution: 8 +depth_unet_channels: 128 +grid_sample_disable_cudnn: false + +# depthsplat color branch +large_gaussian_head: false +color_large_unet: false +init_sh_input_img: true +feature_upsampler_channels: 64 +gaussian_regressor_channels: 512 +unet_gaussian_regressor: false +resnet_gaussian_regressor: false + +# only depth +train_depth_only: false + +# point transformer +pt_head: true +pt_heads: 1 +init_pt_with_mv_attn: true +init_pt_with_mv_attn_lowres: true +pt_head_channels: null +pt_head_concat_img: false +pt_head_conv: false +multi_scale_pt: false +attn_proj_channels: 64 +fps_num_samples: null +knn_samples: 16 +post_norm: false +no_rpe: true +no_knn_attn: false +num_blocks: 6 +pt_downsample: 0 +fps_agg_func: attn +subsample_method: fps +add_pt_residual: true +pt_pred_residual_position: false + +# freeze depth +freeze_depth: false +use_gt_depth: false + +# separate depth & color +separate_depth_color: false +separate_depth_type: small +separate_depth_gaussian_scale: false + +sample_log_depth: true +bilinear_upsample_depth: false +no_upsample_depth: false +return_lowres_depth: false + +# lvsm gaussian regressor +lvsm_gaussian_regressor: false +lvsm_layers: 6 + +# latent gaussian instead of pixel aligned gaussians +latent_gs: true +latent_downsample: 4 +fixed_latent_size: true 
+latent_gs_img_interp: area +dpt_head_depth: false +latent_dpt_upsampler: false +latent_dpt_upsampler_no_concat: false +light_dpt_feature: false +avgpool_depth: false +nearest_down_depth: false + +# predict scene scale with a pretrained depth model +predict_scale: false +norm_by_points: false +no_pred_depth_range: false + +point_dist_init_gaussian_scale: false + +resizeconv_upsampler: false + +# handle high resolution images +depth_pred_half_res: false + +use_amp: true +pt_head_amp: true + +use_fsdp: false +use_checkpointing: false +init_use_checkpointing: false + +# refactor: new gaussian parameter order +rotate_quat_to_world: false +refine_rotate_quat_to_world: false +refine_no_use_covariance: false +latent_new_reshape: false + + + +no_pixel_offset: false + +init_gaussian_multiple: 1 +deform_sample_depth: false +deform_sample_depth_debug: false + + diff --git a/optgs/config/scene_trainer/scene_optimizer/3dgs.yaml b/optgs/config/scene_trainer/scene_optimizer/3dgs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..46f55dac1b966cb9a3491ec4ef8f3357fcccefbd --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/3dgs.yaml @@ -0,0 +1,22 @@ +defaults: + - base + - override refiner: default + +name: adam + +# Adam optimizer +betas: [0.9, 0.999] +eps: 1e-15 +weight_decay: 0.0 + +# learning rates (gsplat) +base_lr: 1 +means_lr_init: 1.6e-4 +means_lr_final: 1.6e-6 +means_lr_delay_mult: 1.0 +means_lr_max_steps: 30000 # should be equal to total optimization steps +scales_lr: 5e-3 +rotations_lr: 1e-3 +opacities_lr: 5e-2 +sh0s_lr: 2.5e-3 +shNs_lr: 1.25e-4 diff --git a/optgs/config/scene_trainer/scene_optimizer/3dgs_star.yaml b/optgs/config/scene_trainer/scene_optimizer/3dgs_star.yaml new file mode 100644 index 0000000000000000000000000000000000000000..44d4c33fd69fadee8ea7af12f94b0867606b19f0 --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/3dgs_star.yaml @@ -0,0 +1,22 @@ +defaults: + - base + - override refiner: default + +name: adam 
+ +# Adam optimizer +betas: [0.99, 0.999] +eps: 1e-15 +weight_decay: 0.0 + +base_lr: 5 +# 3dgs defaults +means_lr_init: 1.6e-4 # Setting same as final to have constant LR +means_lr_final: 1.6e-4 +means_lr_delay_mult: 0.0 +means_lr_max_steps: 30000 # should be equal to total optimization steps +scales_lr: 5e-3 +rotations_lr: 1e-3 +opacities_lr: 5e-2 +sh0s_lr: 2.5e-3 +shNs_lr: 1.25e-4 \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_optimizer/base.yaml b/optgs/config/scene_trainer/scene_optimizer/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..813184bc9dfd4a981b75d8f2767e20850c5b25fc --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/base.yaml @@ -0,0 +1,19 @@ +defaults: + - refiner: none + - lr_scheduler: none + +# gradients +input_gradients_chunk_size: -1 + +# iterative refine +no_refine_mean: false +no_refine_scale: false +no_refine_rotation: false +no_refine_opacity: false +no_refine_sh0: false +no_refine_shN: false + +zero_state_on_densify: false + +# L1 opacity regularization from 3DGS-MCMC (arXiv:2404.09591); set > 0 (e.g. 
1e-3) to enable +opacity_reg_lambda: 0.0 \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_optimizer/knn_based.yaml b/optgs/config/scene_trainer/scene_optimizer/knn_based.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4ff20fd09c73e600eb5f0f3e32e89d46c39c7a16 --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/knn_based.yaml @@ -0,0 +1,210 @@ +defaults: + - base + +name: knn_based + +# iterative refine +no_render_error: false +refine_sh_only: false +num_basic_refine_blocks: 4 +num_refine_blocks: 1 +concat_init_state: false +replace_init_state: false +state_channels: 256 +refine_block_rmsnorm: false +refine_block_layernorm: false +pt_qk_norm: false +norm_pt_block: false +refine_gaussian_multiple: 1 +refine_residual_init_state: false +clamp_refine_max_scale: 3.0 +clamp_min_scale: 1e-6 +clamp_min_raw_scales: -1e10 +clamp_max_raw_scales: 1e10 +clamp_min_raw_opacities: -7 +clamp_max_raw_opacities: 7 +gaussian_head_multiple: 1 +clamp_min_sh0: -1e10 +clamp_max_sh0: 1e10 +clamp_min_shs: -1e10 +clamp_max_shs: 1e10 +clamp_shs_soft: false + +update_attn_proj_channels: 64 +update_no_knn_attn: false +update_no_tran_block_norm: false +update_tran_block_act: gelu +multi_gaussian_scale_smaller: false +init_gaussian_multiple: 1 +refine_condition_pt_feature: true +reinit_gaussian_when_refine_multiple: false +refine_same_num_points: false +input_error_rgb_no_shuffle: false +input_error_cache_resnet_feature: false + +init_state_wo_features: false +init_state_type: constant +init_state_scale: 0 + +# point transformer +pt_heads: 1 + +# refine with mv attention +refine_with_mv_attn: false +refine_with_mv_attn_lowres: false +refine_no_mv_attn: false +mv_attn_conv_with_norm: false +refine_mv_shuffle_attn: false +refine_mv_attn_with_pos_enc: false +refine_shuffle_attn_no_norm: false +refine_mv_unimatch_attn: false +refine_knn_samples: 16 +refine_multi_scale_pt: false + +# KNN +use_fused_attn: true # fused KNN gather + attention 
CUDA kernel (faster, less memory) +prune_invisible_gaussians: false +knn_idx_update_every: 1 + +# inputs +input_alpha: false +input_depth: false +input_depth_smooth_error: false +input_error: false + +input_error_add_rgb_feature: false +input_error_resnet_feature: true +input_error_no_freeze_resnet_feature: false +input_error_shallow_resnet_feature: false +input_error_resnet_feature_layers: 18 + +# cross attention render error +input_error_additional_cross_attn: false +input_error_num_intermediate_views: 8 + +# add global attention to the rendered error to exchange info across views +input_error_mv_attn: false +input_error_mv_attn_blocks: 2 + +# number of views to render error +input_error_num_views: 0 + +# render error based on remaining context views +input_error_remain_context: false +input_error_merge_remain_context: false +input_error_warp_remain_context: false +input_error_random_num_remain_context: false +input_error_num_remain_context_test: 0 + + + +# explicit gradient +input_gradient: false +input_gradient_log: false +input_gradient_log_clip_deltas: 0.001 +input_gradient_scale: 1. +gradient_update_scale: 1. 
+input_gradient_with_ssim_loss: false +input_gradient_same_loss: false +input_gradient_loss_reduction: mean +scale_residual_grads: false + +window_local_refine: false +window_global_refine: false +window_local_global_refine: false + +# sliding window update to save training memory +update_window_size: 0 +local_gaussian_render: false + +train_global_update_only: false + +# random size refine +# update more for low resolution, less for high +random_update_with_size: false + + +use_amp: true +pt_head_amp: true +pt_update_amp: true + +use_checkpointing: false +recurrent_use_checkpointing: false + +# Debugging +debug_refine_update_module: true + +# Normalizing input +input_gradient_normalize: false +input_gradient_normalize_type: layer +input_normalize_state: false +input_normalize_gaussians: false + +# Scaling +residual_state: false +predict_state_scale: false +predict_state_scale_norm: false + +# Update head +update_head_concat_img: false +update_head_layer_num: 2 +update_head_act: gelu +update_head_final_act: identity +update_head_hidden_dim_matches: "input" # rebuttal version. 
switch to "output" for submission version + +update_head_scale_mag: false +update_head_scalar_scale: false +update_head_scalar_scale_act: relu + +# Per-parameter-group heads (Feature A): separate heads per param group, each with own normalize+scale +update_head_per_param_heads: false +update_head_per_param_hidden_dim: 48 # tuned so total params ≈ baseline head (~81K) +# Per-parameter scalar scales (Feature B): per-group scalar scales (requires update_head_scalar_scale=true) +update_head_per_param_scales: false + +opt_scales_before_act: false + +# Preprocessing the init gaussians +scale_initial_opacities: 1.0 +sh_d: null + +# Deactivate gaussians +local_prune_zero_radii: false +local_prune_low_weights: false +local_prune_low_weights_thresh: -1 +update_only_nonzero_grad: false + +# Experiments +experimental_run: false +experimental_update: + _base: true + _means: true + _scales: true + _quats: true + _opacities: true + _sh0: true + _shN: true + +experimental_use_grads: false + +experimental_use_norm_grads: + _base: false + _means: false + _scales: false + _quats: false + _opacities: false + _sh0: false + _shN: false + +experimental_lr: + _base: 1 + _means: 1.6e-4 + _scales: 5e-3 + _opacities: 5e-2 + _quats: 1e-3 + _sh0: 2.5e-3 + _shN: 1.25e-4 # 2.5e-3 / 20 + +# time encoding +use_time_encoding: false +time_encoding_max_steps: 2000 diff --git a/optgs/config/scene_trainer/scene_optimizer/learn2splat.yaml b/optgs/config/scene_trainer/scene_optimizer/learn2splat.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e730c789a0af3b7ae89747a91d54eca1273c68f5 --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/learn2splat.yaml @@ -0,0 +1,30 @@ +defaults: + - knn_based + +name: l2s + +# General optimization settings +opt_scales_before_act: true +sh_d: 16 # should be the default, but just in case + +# Input gradient settings +input_gradient: true +input_gradient_normalize: true +input_gradient_normalize_type: adam +input_gradient_with_ssim_loss: 
true + +# Freeze zero-grad gaussians ❄️($G_{∇=0}$) +update_only_nonzero_grad: true + +# state scale +predict_state_scale: true + +# Delta scale +update_head_scalar_scale: true +update_head_scalar_scale_act: relu # should be the default, but just in case + + + + + + diff --git a/optgs/config/scene_trainer/scene_optimizer/lr_scheduler/base.yaml b/optgs/config/scene_trainer/scene_optimizer/lr_scheduler/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..41d0251d80485d153928294b98aec3ed18d225e0 --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/lr_scheduler/base.yaml @@ -0,0 +1,17 @@ +lr_data: + _base: 1 + _means: 1 + _scales: 1 + _quats: 1 + _opacities: 1 + _sh0: 1 + _shN: 1 + +apply_scheduler: + _base: true + _means: true + _scales: true + _quats: true + _opacities: true + _sh0: true + _shN: true \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_optimizer/lr_scheduler/ddim.yaml b/optgs/config/scene_trainer/scene_optimizer/lr_scheduler/ddim.yaml new file mode 100644 index 0000000000000000000000000000000000000000..893a5172097a9dc94dfd039722b30cc0d5faba87 --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/lr_scheduler/ddim.yaml @@ -0,0 +1,7 @@ +defaults: + - base + +name: ddim +T: 1000 +min_lr: 0.0 +s: 0.008 \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_optimizer/lr_scheduler/none.yaml b/optgs/config/scene_trainer/scene_optimizer/lr_scheduler/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4d834255d0d2451c9b5a0939e235b79e0ebddd56 --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/lr_scheduler/none.yaml @@ -0,0 +1,12 @@ +defaults: + - base + +name: none +apply_scheduler: + _base: false + _means: false + _scales: false + _quats: false + _opacities: false + _sh0: false + _shN: false \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_optimizer/none.yaml b/optgs/config/scene_trainer/scene_optimizer/none.yaml 
new file mode 100644 index 0000000000000000000000000000000000000000..aaf371c4137c3d1bbbf7fc625d6862641453137f --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/none.yaml @@ -0,0 +1,4 @@ +defaults: + - base + +name: none diff --git a/optgs/config/scene_trainer/scene_optimizer/refiner/default.yaml b/optgs/config/scene_trainer/scene_optimizer/refiner/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bed54bcb01fe2dbbdd685b5542ea04611ff84996 --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/refiner/default.yaml @@ -0,0 +1,35 @@ +name: default + +do_densify: true +do_prune: true +do_opacity_reset: true + +# +cap_max: -1 # no cap +noise_lr: 0.0 + +# scheduling +pause_refine_after_reset: 0 +refine_every: 100 +reset_every: 3000 +refine_start_iter: 500 +refine_stop_iter: 15000 +refine_scale2d_stop_iter: 0 + +# thresholds +grow_grad2d: 0.0002 +grow_scale3d: 0.01 # aka. percent_dense +prune_scale3d: 0.1 +prune_scale2d: 0.15 +grow_scale2d: 0.05 +min_opacity: 0.005 + +# Pruning parameters +prune_zero_radii: false + +# Slightly reduce opacity every few steps +reduce_opacity: false +reduce_factor: 0.0 # not used +reduce_every: 0 # not used + +fallback_means_lr: 0.0 # unused (noise_lr=0.0) \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_optimizer/refiner/edgs.yaml b/optgs/config/scene_trainer/scene_optimizer/refiner/edgs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..20f868d283a9ac335e211bc96dabebd2e6171e30 --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/refiner/edgs.yaml @@ -0,0 +1,35 @@ +name: edgs + +do_densify: false +do_prune: true +do_opacity_reset: false + +# +cap_max: -1 # no cap +noise_lr: 0.0 + +# scheduling +pause_refine_after_reset: 0 +refine_every: 100 +reset_every: 1000000 # effectively disable +refine_start_iter: 500 +refine_stop_iter: 15000 +refine_scale2d_stop_iter: 0 + +# thresholds +grow_grad2d: 0.0 # disabled +grow_scale3d: 0.0 # disabled 
+prune_scale3d: 0.0 # disabled +prune_scale2d: 0.0 # disabled +grow_scale2d: 0.0 # disabled +min_opacity: 0.005 + +# Pruning parameters +prune_zero_radii: false + +# Slightly reduce opacity every few steps +reduce_opacity: true +reduce_factor: 0.99 +reduce_every: 10 + +fallback_means_lr: 0.0 # unused (noise_lr=0.0) \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_optimizer/refiner/mcmc.yaml b/optgs/config/scene_trainer/scene_optimizer/refiner/mcmc.yaml new file mode 100644 index 0000000000000000000000000000000000000000..87fe09d5836bd2fc46f4b8b9f34f2b2fa8b58c62 --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/refiner/mcmc.yaml @@ -0,0 +1,45 @@ +name: mcmc + +do_densify: true +do_prune: true +do_opacity_reset: true + +# Population growth cap. -1 disables growth (add_new is a no-op). +# Set to e.g. 1000000 to allow growing up to 1M Gaussians. +cap_max: -1 +noise_lr: 5e5 + +# scheduling +pause_refine_after_reset: 0 +refine_every: 100 +reset_every: 999999999 # effectively disable resets +refine_start_iter: 500 +refine_stop_iter: 25000 +refine_scale2d_stop_iter: 0 # disable scale2d + +# thresholds +grow_grad2d: 0.0 # not used +grow_scale3d: 0.0 # aka. percent_dense, not used +prune_scale3d: 0.0 # not used +prune_scale2d: 0.0 # not used +grow_scale2d: 0.0 # not used +min_opacity: 0.005 + +# Pruning parameters +prune_zero_radii: false + +# Slightly reduce opacity every few steps +reduce_opacity: false +reduce_factor: 0.0 # not used +reduce_every: 0 # not used + +# Fallback means lr for noise injection when optimizer has no means_lr_scheduler (e.g. Learn2SplatOptimizer). +# Original paper: means_lr (~1.6e-4) * noise_lr (5e5) ≈ 80 world units of noise. +fallback_means_lr: 1.6e-4 +relocate_copy_state: true # inherit alive Gaussian's optimizer state (better than zeroing) + +# Cap (in scale units) applied to the *noise-only* covariance computation. Does NOT change the +# rendered Gaussian scales. 
Needed for knn_based whose network saturates clamp_refine_max_scale, +# which produces covariances 10²-10⁴× larger than vanilla's and makes the noise overflow the +# renderer's tile-binning math (silent CUDA OOB downstream). Tune to ~scene_scale / 5. +noise_scale_cap: 1.0 \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_optimizer/refiner/none.yaml b/optgs/config/scene_trainer/scene_optimizer/refiner/none.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a7f7e8cd500d4115fa8e74023e9437da23720eae --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/refiner/none.yaml @@ -0,0 +1,35 @@ +name: none + +do_densify: false +do_prune: false +do_opacity_reset: false + +# +cap_max: -1 # no cap +noise_lr: 0.0 + +# scheduling +pause_refine_after_reset: 0 +refine_every: 999999999 +reset_every: 999999999 +refine_start_iter: 999999999 +refine_stop_iter: 999999999 +refine_scale2d_stop_iter: 0 + +# thresholds +grow_grad2d: 0.0 +grow_scale3d: 0.0 +prune_scale3d: 0.0 +prune_scale2d: 0.0 +grow_scale2d: 0.0 +min_opacity: 0.0 + +# Pruning parameters +prune_zero_radii: false + +# Slightly reduce opacity every few steps +reduce_opacity: false +reduce_factor: 0.0 # not used +reduce_every: 0 # not used + +fallback_means_lr: 0.0 # unused (noise_lr=0.0) \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_optimizer/resplat_v1.yaml b/optgs/config/scene_trainer/scene_optimizer/resplat_v1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1082080378935d93a3a5628da0a7e46ad1286107 --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/resplat_v1.yaml @@ -0,0 +1,6 @@ +defaults: + - knn_based + +name: resplat_v1 +input_error: true +input_error_mv_attn: true \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_optimizer/resplat_v2.yaml b/optgs/config/scene_trainer/scene_optimizer/resplat_v2.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..2203843e3161607169721fc6aac39e66ad98daed --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/resplat_v2.yaml @@ -0,0 +1,11 @@ +defaults: + - knn_based + +name: resplat_v2 +input_error: true +input_error_mv_attn: true +input_error_add_rgb_feature: true +refine_knn_samples: 8 +state_channels: 512 +residual_state: true +update_head_layer_num: 4 \ No newline at end of file diff --git a/optgs/config/scene_trainer/scene_optimizer/sgd.yaml b/optgs/config/scene_trainer/scene_optimizer/sgd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..42e0116ff8319a9ff0835b15af629d8cb659b5d7 --- /dev/null +++ b/optgs/config/scene_trainer/scene_optimizer/sgd.yaml @@ -0,0 +1,22 @@ +defaults: + - base + - override refiner: default + +name: sgd + +# Optimizer hyperparameters (NOTE(review): betas/eps are Adam-style settings; confirm the sgd optimizer reads them) +betas: [0.9, 0.999] +eps: 1e-15 +weight_decay: 0.0 + +# learning rates (gsplat) +base_lr: 1 +means_lr_init: 1.6e-4 +means_lr_final: 1e-5 +means_lr_delay_mult: 0.01 +means_lr_max_steps: 30000 # should be equal to total optimization steps +scales_lr: 5e-3 +rotations_lr: 1e-3 +opacities_lr: 5e-2 +sh0s_lr: 2.5e-3 +shNs_lr: 1.25e-4 \ No newline at end of file diff --git a/optgs/config_migrate.py b/optgs/config_migrate.py new file mode 100644 index 0000000000000000000000000000000000000000..e8dff60d8f46234ff71f97353da649caebaf57cf --- /dev/null +++ b/optgs/config_migrate.py @@ -0,0 +1,188 @@ +from omegaconf import OmegaConf + + +CURRENT_CFG_VERSION = 2 + +def migrate(cfg_dict): + was_omega = not isinstance(cfg_dict, dict) + version = cfg_dict.get("version", 0) + + # null means a fresh run from main.yaml — treat as current version. + if version is None: + version = CURRENT_CFG_VERSION + + if version == 0: + # Heuristic: configs that were partially migrated may have version=0 but a + # non-depthsplat optimizer name (already renamed during v0→v1), so skip v0→v1. 
def migrate_v1_to_v2(cfg_dict):
    """
    Migration from v1 to v2: move top-level 'train' and 'test' under 'meta_trainer'.

    Args:
        cfg_dict: Version-1 config. A plain ``dict`` is copied; anything else
            is converted with ``OmegaConf.to_container(..., resolve=False)``.

    Returns:
        A plain ``dict`` with ``version`` set to 2. A top-level 'train' or
        'test' entry is moved only when 'meta_trainer' does not already hold
        that key; otherwise the top-level entry is left in place.
    """
    cfg = OmegaConf.to_container(cfg_dict, resolve=False) if not isinstance(cfg_dict, dict) else dict(cfg_dict)

    # dict(cfg_dict) above is only a shallow copy, so the nested 'meta_trainer'
    # section would still be the caller's object. Copy it before writing into
    # it, and treat a missing / None section as empty.
    meta_trainer = dict(cfg.get("meta_trainer") or {})
    cfg["meta_trainer"] = meta_trainer

    for key in ("train", "test"):
        if key in cfg and key not in meta_trainer:
            meta_trainer[key] = cfg.pop(key)

    cfg["version"] = 2
    return cfg


def migrate_v0_to_v1(cfg):
    """
    Migration from submission v0 (refine_*) to rebuttal v1 (input_error_*).

    Renames the 'depthsplat' initializer/optimizer modules, renames the
    scene-optimizer keys listed in ``RENAME_MAP``, fills in the new defaults
    and bumps ``version`` to 1.

    Args:
        cfg: Version-0 config, either an OmegaConf container or a plain dict
            (plain dicts are wrapped first, matching ``migrate_v1_to_v2``).

    Returns:
        A new OmegaConf config; the input is not modified.

    Raises:
        KeyError: if ``scene_trainer.scene_optimizer`` / ``scene_initializer``
            or one of the keys read below (``name``, ``state_channels``,
            ``refine_input_gradient`` for a 'depthsplat' optimizer) is missing.
        NotImplementedError: if the optimizer name after renaming is not one
            of 'clogs', 'learn2splat', 'resplat_v1'.
    """

    # OmegaConf.to_container only accepts OmegaConf objects; wrapping a plain
    # dict first lets this function take the same inputs as migrate_v1_to_v2
    # and also gives us a private copy to mutate.
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    cfg = OmegaConf.to_container(cfg, resolve=False)

    so = cfg["scene_trainer"]["scene_optimizer"]
    si = cfg["scene_trainer"]["scene_initializer"]

    # ------------------------------------------------------------------
    # Module renames
    # ------------------------------------------------------------------
    if si["name"] == "depthsplat":
        si["name"] = "resplat_v1"
    if so["name"] == "depthsplat":
        # The old optimizer name covered two variants; the gradient-input
        # flag decides which of the new names applies.
        if so["refine_input_gradient"]:
            so["name"] = "learn2splat"
        else:
            so["name"] = "resplat_v1"

    # ------------------------------------------------------------------
    # Key renames (declarative)
    # ------------------------------------------------------------------
    RENAME_MAP = {
        # feature extraction
        "refine_lpips_error": "input_error_lpips_features",
        "refine_pool_vgg_features": "input_error_pool_vgg_features",
        "refine_use_all_vgg_features": "input_error_use_all_vgg_features",
        "refine_vit_feature": "input_error_vit_feature",
        "refine_resnet_feature": "input_error_resnet_feature",
        "no_freeze_resnet_feature": "input_error_no_freeze_resnet_feature",
        "shallow_resnet_feature": "input_error_shallow_resnet_feature",
        "resnet_feature_layers": "input_error_resnet_feature_layers",
        "refine_convnext_feature": "input_error_convnext_feature",
        "convnext_feature_size": "input_error_convnext_feature_size",
        "refine_concat_feature": "input_error_concat_feature",
        "refine_concat_feature_cosine": "input_error_concat_feature_cosine",
        "refine_cosine_feature": "input_error_cosine_feature",
        "refine_add_feature": "input_error_add_feature",
        "refine_concat_rgb_feature_error": "input_error_concat_rgb_feature_error",

        # render error → input error
        "render_error_no_abs": "input_error_no_abs",
        "render_error_no_shuffle": "input_error_no_shuffle",
        "render_cache_resnet_feature": "input_error_cache_resnet_feature",
        "render_view_pool_resnet_feature": "input_error_view_pool_resnet_feature",
        "render_global_pool_resnet_feature": "input_error_global_pool_resnet_feature",

        # input toggles
        "refine_input_alpha": "input_alpha",
        "refine_input_depth": "input_depth",
        "refine_input_depth_smooth_error": "input_depth_smooth_error",
        "refine_input_error": "input_error",

        # attention (input error)
        "radii_averaged_render_error": "input_error_radii_averaged",
        "cross_attn_additional_render_error": "input_error_additional_cross_attn",
        "num_intermediate_views": "input_error_num_intermediate_views",
        "render_error_mv_attn_blocks": "input_error_mv_attn_blocks",

        # context handling
        "render_error_num_views": "input_error_num_views",
        "render_error_remain_context": "input_error_remain_context",
        "render_error_merge_remain_context": "input_error_merge_remain_context",
        "render_error_warp_remain_context": "input_error_warp_remain_context",
        "render_error_random_num_remain_context": "input_error_random_num_remain_context",
        "render_error_num_remain_context_test": "input_error_num_remain_context_test",
        "render_error_warp_input_view": "input_error_warp_input_view",

        # input gradient
        "refine_input_gradient": "input_gradient",
        "refine_input_gradient_log": "input_gradient_log",
        "refine_input_gradient_log_clip_deltas": "input_gradient_log_clip_deltas",
        "refine_input_gradient_scale": "input_gradient_scale",

        # normalize input
        "normalize_update_input": "input_gradient_normalize",
        "normalize_update_input_type": "input_gradient_normalize_type",
        "normalize_state": "input_normalize_state",
        "normalize_gaussians": "input_normalize_gaussians",

        # update head
        "final_head_act": "update_head_final_act",
        "refine_output_scale_mag": "update_head_scale_mag",
        "scalar_scale_out": "update_head_scalar_scale",
        "scalar_scale_out_act": "update_head_scalar_scale_act",
    }

    # Only keys actually present are moved; an existing value under the new
    # name is overwritten by the old key's value.
    for old, new in RENAME_MAP.items():
        if old in so:
            so[new] = so.pop(old)

    # ------------------------------------------------------------------
    # New / fixed defaults
    # ------------------------------------------------------------------
    if so["name"] in ["clogs", "learn2splat", "resplat_v1"]:
        so["update_head_hidden_dim_matches"] = "output"
    else:
        raise NotImplementedError(
            f"No v0 -> v1 migration defined for scene_optimizer {so['name']!r}"
        )

    if so["state_channels"] == 0:
        so["state_channels"] = 256

    # ------------------------------------------------------------------
    # Version bump
    # ------------------------------------------------------------------
    cfg["version"] = 1

    return OmegaConf.create(cfg)
def get_scene_scale(camtoworlds: Float[np.ndarray, "N 4 4"]) -> float:
    """Return the scene size as measured by the cameras (as in gsplat).

    Args:
        camtoworlds: [N, 4, 4] camera-to-world matrices.

    Returns:
        1.1 times the largest distance between a camera position and the
        mean camera position.
    """
    positions = camtoworlds[:, :3, 3]
    centroid = np.mean(positions, axis=0)
    farthest = np.max(np.linalg.norm(positions - centroid, axis=1))
    return float(farthest) * 1.1


class Camera(nn.Module):
    """
    A camera class that stores the camera parameters and the image for Re10k dataset.

    Attributes:
        image_name: Name of the image; also used as the sub-directory in ``save``.
        extrinsics: C2W matrix (4x4 torch.Tensor)
        intrinsics: K matrix (3x3 torch.Tensor)
        extrinsics_render_view: C2W-style matrix of the render view (4x4 torch.Tensor)
        intrinsics_render_view: K matrix of the render view (3x3 torch.Tensor)
        scale_matrix: 4x4 torch.Tensor, stored as given
        trans_matrix: 4x4 torch.Tensor, stored as given
        znear: Near clipping plane distance
        zfar: Far clipping plane distance
        original_image: RGB image clamped to [0, 1] (3xHxW torch.Tensor)
        gt_alpha_mask: Optional alpha mask (1xHxW torch.Tensor) or None
        FoVx: Field of view in x direction
        FoVy: Field of view in y direction
        image_height: Height of the image
        image_width: Width of the image
        world_view_transform: Transposed inverse of ``extrinsics`` (4x4 torch.Tensor)
        projection_matrix: Projection matrix built from znear/zfar/FoV
        full_proj_transform: ``world_view_transform @ projection`` (4x4 torch.Tensor)
        camera_center: Camera center (3 torch.Tensor)
    """
    def __init__(
        self,
        colmap_id: str,
        extrinsics: Float[Tensor, "4 4"],
        intrinsics: Float[Tensor, "3 3"],
        extrinsics_render_view: Float[Tensor, "4 4"],
        intrinsics_render_view: Float[Tensor, "3 3"],
        scale_matrix: Float[Tensor, "4 4"],
        trans_matrix: Float[Tensor, "4 4"],
        image: Float[Tensor, "3 h w"],
        raw_image_shape: tuple[int, int],
        image_name: str,
        uid: int,
        near: Float[Tensor, "1"],
        far: Float[Tensor, "1"],
        data_device: torch.device,
        gt_alpha_mask: Float[Tensor, "1 h w"] | None = None,
        trans=np.array([0.0, 0.0, 0.0]),
        scale=1.0
    ):
        super(Camera, self).__init__()

        self.idx = -1
        self.uid = uid
        self.colmap_id = colmap_id
        self.image_name = image_name

        # Validate the device up front. A bare assignment can never raise, so
        # without the torch.device(...) call the fallback below was dead code.
        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
            self.data_device = torch.device("cuda")

        self.extrinsics = extrinsics.to(self.data_device)  # C2W matrix! (not really extrinsics)
        self.intrinsics = intrinsics.to(self.data_device)
        self.extrinsics_render_view = extrinsics_render_view.to(self.data_device)
        self.intrinsics_render_view = intrinsics_render_view.to(self.data_device)
        self.scale_matrix = scale_matrix.to(self.data_device)
        self.trans_matrix = trans_matrix.to(self.data_device)

        self.raw_image_shape = raw_image_shape

        # The image is clamped but neither moved to data_device nor multiplied
        # by the alpha mask here.
        self.original_image = image.clamp(0.0, 1.0)
        self.image_width = self.original_image.shape[2]
        self.image_height = self.original_image.shape[1]

        if gt_alpha_mask is not None:
            self.gt_alpha_mask = gt_alpha_mask.to(self.data_device)
        else:
            self.gt_alpha_mask = None

        self.zfar = far.to(self.data_device)
        self.znear = near.to(self.data_device)

        self.trans = trans
        self.scale = scale

        fov_x, fov_y = get_fov(self.intrinsics.unsqueeze(0)).unbind(dim=-1)

        self.FoVx = fov_x.item()
        self.FoVy = fov_y.item()

        # Both matrices are stored transposed: the view matrix is the
        # transposed inverse of the C2W matrix, so they compose as
        # view @ projection.
        projection_matrix = get_projection_matrix(self.znear, self.zfar, fov_x, fov_y)
        projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
        view_matrix = rearrange(self.extrinsics.inverse(), "i j -> j i")
        full_projection = (view_matrix.unsqueeze(0) @ projection_matrix)[0]

        self.camera_center = self.extrinsics[:3, 3]
        self.projection_matrix = projection_matrix[0].transpose(0, 1)
        self.world_view_transform = view_matrix
        self.full_proj_transform = full_projection

    def save(self, save_dir: Path):
        """Write this camera to ``save_dir / image_name``.

        Tensors go to individual ``.pt`` files and scalar metadata to
        ``cam_info.json``. The render-view and scale/translation matrices are
        written as well so that ``load_camera`` can rebuild the camera.
        """
        cam_dir = save_dir / self.image_name
        os.makedirs(cam_dir, exist_ok=True)

        torch.save(self.extrinsics, cam_dir / "extrinsics.pt")
        torch.save(self.intrinsics, cam_dir / "intrinsics.pt")
        torch.save(self.original_image, cam_dir / "image.pt")
        torch.save(self.extrinsics_render_view, cam_dir / "extrinsics_render_view.pt")
        torch.save(self.intrinsics_render_view, cam_dir / "intrinsics_render_view.pt")
        torch.save(self.scale_matrix, cam_dir / "scale_matrix.pt")
        torch.save(self.trans_matrix, cam_dir / "trans_matrix.pt")

        if self.gt_alpha_mask is not None:
            torch.save(self.gt_alpha_mask, cam_dir / "gt_alpha_mask.pt")

        with open(cam_dir / "cam_info.json", "w") as f:
            json.dump(
                {
                    "colmap_id": self.colmap_id,
                    "image_name": self.image_name,
                    "uid": self.uid,
                    "raw_image_shape": self.raw_image_shape,
                    "near": self.znear.item(),
                    "far": self.zfar.item()
                },
                f,
                indent=4,
            )

    @classmethod
    def load_camera(cls, cam_dir: Path, data_device: torch.device):
        """Rebuild a camera from a directory written by ``save``.

        Args:
            cam_dir: Directory holding the ``.pt`` files and ``cam_info.json``.
            data_device: Device the tensors are loaded onto.

        Returns:
            The reconstructed ``Camera``.
        """
        def _read(name):
            # map_location lets a camera saved on one device load on another.
            return torch.load(cam_dir / name, map_location=data_device)

        def _read_or(name, fallback):
            return _read(name) if (cam_dir / name).exists() else fallback

        extrinsics = _read("extrinsics.pt")
        intrinsics = _read("intrinsics.pt")
        image = _read("image.pt")
        gt_alpha_mask = _read_or("gt_alpha_mask.pt", None)

        # These four are required constructor arguments; omitting them made
        # every load fail with a TypeError. Directories written before save()
        # stored them have no such files.
        # NOTE(review): for those, reusing the main view and identity
        # matrices is an assumption - confirm it suits the callers.
        identity = torch.eye(4, dtype=extrinsics.dtype, device=extrinsics.device)
        extrinsics_render_view = _read_or("extrinsics_render_view.pt", extrinsics)
        intrinsics_render_view = _read_or("intrinsics_render_view.pt", intrinsics)
        scale_matrix = _read_or("scale_matrix.pt", identity)
        trans_matrix = _read_or("trans_matrix.pt", identity)

        with open(cam_dir / "cam_info.json", "r") as f:
            cam_info = json.load(f)

        return cls(
            colmap_id=cam_info["colmap_id"],
            extrinsics=extrinsics.to(data_device),
            intrinsics=intrinsics.to(data_device),
            extrinsics_render_view=extrinsics_render_view.to(data_device),
            intrinsics_render_view=intrinsics_render_view.to(data_device),
            scale_matrix=scale_matrix.to(data_device),
            trans_matrix=trans_matrix.to(data_device),
            image=image.to(data_device),
            gt_alpha_mask=gt_alpha_mask.to(data_device) if gt_alpha_mask is not None else None,
            raw_image_shape=tuple(cam_info["raw_image_shape"]),
            image_name=cam_info["image_name"],
            uid=cam_info["uid"],
            near=torch.Tensor([cam_info["near"]]).to(data_device),
            far=torch.Tensor([cam_info["far"]]).to(data_device),
            data_device=data_device,
        ).to(data_device)


def generate_cam_params_for_wobble(t: Tensor, cam_a: Camera, cam_b: Camera):
    """Camera parameters for a wobbling trajectory between two cameras.

    Args:
        t: Trajectory parameter passed to the wobble and interpolation helpers.
        cam_a: Camera the interpolation starts from.
        cam_b: Camera the interpolation ends at.

    Returns:
        ``(extrinsics, intrinsics)``: the interpolated extrinsics composed
        with the wobble transformation, and the interpolated intrinsics.
    """
    # The wobble radius is half the distance between the two camera centres.
    baseline = (cam_a.extrinsics[:3, 3] - cam_b.extrinsics[:3, 3]).norm(dim=-1)

    tf = generate_wobble_transformation(
        radius=baseline * 0.5,
        t=t,
        num_rotations=1,
        scale_radius_with_t=False,
    )

    # NOTE(review): the interpolation is driven by (t - 2), not t - confirm
    # this offset against the interpolation helpers.
    extrinsics = interpolate_extrinsics(
        initial=cam_a.extrinsics,
        final=cam_b.extrinsics,
        t=(t - 2),
    )
    intrinsics = interpolate_intrinsics(
        initial=cam_a.intrinsics,
        final=cam_b.intrinsics,
        t=(t - 2),
    )
    return extrinsics @ tf, intrinsics
def get_intermediate_cameras(cam_a: "Camera", cam_b: "Camera", num_frames: int = 150, smooth: bool = False):
    """Build ``num_frames`` cameras interpolated between ``cam_a`` and ``cam_b``.

    Args:
        cam_a: Start camera; also supplies everything that is not
            interpolated (image, near/far, device, scale/translation matrices).
        cam_b: End camera.
        num_frames: Number of cameras to generate.
        smooth: If True, remap the linear parameter with a cosine ease.

    Returns:
        List of ``num_frames`` cameras named ``{cam_a.image_name}_{index:04d}``.
    """
    t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=cam_a.data_device)
    if smooth:
        t = (torch.cos(torch.pi * (t + 1)) + 1) / 2

    extrinsics, intrinsics, extrinsics_render_view, intrinsics_render_view = (
        generate_cam_params_for_interpolation(t, cam_a, cam_b)
    )
    extrinsics = extrinsics.squeeze(0)
    intrinsics = intrinsics.squeeze(0)
    extrinsics_render_view = extrinsics_render_view.squeeze(0)
    intrinsics_render_view = intrinsics_render_view.squeeze(0)

    cameras = []
    for index in range(num_frames):
        cameras.append(
            Camera(
                colmap_id=cam_a.colmap_id,
                image_name=f"{cam_a.image_name}_{index:04d}",
                uid=index,
                near=cam_a.znear,
                far=cam_a.zfar,
                data_device=cam_a.data_device,
                # These views have no ground truth image; cam_a's image is a
                # placeholder, and mesh views should never require images.
                image=cam_a.original_image,
                raw_image_shape=cam_a.raw_image_shape,
                extrinsics=extrinsics[index],
                intrinsics=intrinsics[index],
                extrinsics_render_view=extrinsics_render_view[index],
                intrinsics_render_view=intrinsics_render_view[index],
                scale_matrix=cam_a.scale_matrix,
                trans_matrix=cam_a.trans_matrix,
                gt_alpha_mask=None
            )
        )
    return cameras


def patch_shim(cams: "list[Camera]", patch_size: int) -> "list[Camera]":
    """Center-crop every camera image to a multiple of ``patch_size``.

    Both intrinsics matrices have their principal point shifted by the crop
    offset, and the alpha mask (when present) is cropped with the image.

    Args:
        cams: Cameras to crop; they are not modified.
        patch_size: The cropped height and width are multiples of this value.

    Returns:
        New cameras with cropped images and adjusted intrinsics.

    Raises:
        AssertionError: if an image has odd height or width.
    """
    new_cams = []

    for cam in cams:
        _, h, w = cam.original_image.shape

        assert h % 2 == 0 and w % 2 == 0

        h_new = (h // patch_size) * patch_size
        row = (h - h_new) // 2
        w_new = (w // patch_size) * patch_size
        col = (w - w_new) // 2

        # Center-crop the image.
        new_original_image = cam.original_image[:, row : row + h_new, col : col + w_new]

        # Crop the alpha mask the same way so it keeps matching the image.
        new_gt_alpha_mask = cam.gt_alpha_mask
        if new_gt_alpha_mask is not None:
            new_gt_alpha_mask = new_gt_alpha_mask[:, row : row + h_new, col : col + w_new]

        # Adjust the intrinsics to account for the cropping.
        new_intrinsics = cam.intrinsics.clone()
        new_intrinsics[0, 2] -= col
        new_intrinsics[1, 2] -= row

        # Same adjustment for the render view. Only the principal point
        # moves; subtracting from the whole row would also change the focal
        # lengths.
        new_render_view_intrinsics = cam.intrinsics_render_view.clone()
        new_render_view_intrinsics[0, 2] -= col
        new_render_view_intrinsics[1, 2] -= row

        new_cams.append(
            Camera(
                colmap_id=cam.colmap_id,
                image_name=cam.image_name,
                uid=cam.uid,
                near=cam.znear,
                far=cam.zfar,
                data_device=cam.data_device,
                raw_image_shape=cam.raw_image_shape,
                image=new_original_image,
                extrinsics=cam.extrinsics,
                intrinsics=new_intrinsics,
                extrinsics_render_view=cam.extrinsics_render_view,
                intrinsics_render_view=new_render_view_intrinsics,
                scale_matrix=cam.scale_matrix,
                trans_matrix=cam.trans_matrix,
                gt_alpha_mask=new_gt_alpha_mask
            )
        )

    return new_cams


def calculate_cameras_extent(cam_centers: Tensor):
    """Compute a recentering translation and a radius from camera centres.

    Args:
        cam_centers: Camera positions with the cameras along dim 0.

    Returns:
        ``(translate, radius)``: the negated mean camera centre (flattened),
        and 1.1 times the largest distance from that mean, as a Python float.
    """
    avg_cam_center = cam_centers.mean(dim=0, keepdim=True)
    dist = torch.norm(cam_centers - avg_cam_center, dim=-1, keepdim=True)
    diagonal = dist.max()

    center = avg_cam_center.flatten()
    radius = diagonal * 1.1

    translate = -center
    return translate, radius.item()
def load_cameras(cam_dir: Path, device: torch.device) -> "list[Camera]":
    """Load the cameras stored in ``cam_dir`` as stacked tensors.

    Reads ``extrinsics.pt``, ``intrinsics.pt``, ``images.pt``, the optional
    ``gt_alpha_masks.pt`` and ``cam_info.json``.

    Args:
        cam_dir: Directory containing the files listed above.
        device: Device the tensors are loaded onto.

    Returns:
        One ``Camera`` per entry, named ``image_{idx:04d}``.
    """
    cameras = []

    # map_location lets files saved on one device load on another.
    extrinsics = torch.load(cam_dir / "extrinsics.pt", map_location=device)
    intrinsics = torch.load(cam_dir / "intrinsics.pt", map_location=device)
    images = torch.load(cam_dir / "images.pt", map_location=device)

    if (cam_dir / "gt_alpha_masks.pt").exists():
        gt_alpha_masks = torch.load(cam_dir / "gt_alpha_masks.pt", map_location=device)
    else:
        gt_alpha_masks = [None] * len(images)

    with open(cam_dir / "cam_info.json", "r") as f:
        cam_info = json.load(f)

    for idx in range(cam_info["num_cameras"]):
        cam_extrinsics = extrinsics[idx].to(device)
        cam_intrinsics = intrinsics[idx].to(device)
        cam_mask = gt_alpha_masks[idx]

        # Camera requires render-view and scale/translation matrices, but the
        # files read here do not contain them, so construction used to fail
        # with a TypeError.
        # NOTE(review): reusing the main view and identity matrices is an
        # assumption - confirm it suits the callers.
        identity = torch.eye(4, dtype=cam_extrinsics.dtype, device=cam_extrinsics.device)

        cameras.append(
            Camera(
                colmap_id=cam_info["colmap_ids"][idx],
                image_name=f"image_{idx:04d}",
                uid=cam_info["uids"][idx],
                near=torch.Tensor([cam_info["znear"][idx]]).to(device),
                far=torch.Tensor([cam_info["zfar"][idx]]).to(device),
                data_device=device,
                image=images[idx].to(device),
                extrinsics=cam_extrinsics,
                intrinsics=cam_intrinsics,
                extrinsics_render_view=cam_extrinsics,
                intrinsics_render_view=cam_intrinsics,
                scale_matrix=identity,
                trans_matrix=identity.clone(),
                raw_image_shape=tuple(cam_info["raw_image_shapes"][idx]),
                gt_alpha_mask=cam_mask.to(device) if cam_mask is not None else None
            )
        )

    return cameras
def align_principal_axes(point_cloud):
    """Build a rigid transform that aligns a point cloud with its principal axes.

    The cloud is centred on its per-axis median, and the eigenvectors of its
    covariance, ordered by decreasing eigenvalue, become the rows of the
    rotation. The direction of least variance therefore ends up on the last
    axis.

    Args:
        point_cloud: Nx3 array of points.

    Returns:
        4x4 matrix whose rotation is the eigenvector basis and whose
        translation maps the median to the origin.
    """
    median = np.median(point_cloud, axis=0)
    centered = point_cloud - median

    cov = np.cov(centered, rowvar=False)
    evals, evecs = np.linalg.eigh(cov)

    # eigh gives no particular guarantee we rely on here; order the columns
    # explicitly by decreasing eigenvalue.
    order = evals.argsort()[::-1]
    evecs = evecs[:, order]

    # A negative determinant means the basis is a reflection; flipping one
    # eigenvector turns it into a proper rotation.
    if np.linalg.det(evecs) < 0:
        evecs[:, 0] *= -1

    rotation = evecs.T

    transform = np.eye(4)
    transform[:3, :3] = rotation
    transform[:3, 3] = -rotation @ median
    return transform


def transform_points(matrix, points):
    """Transform points using an SE(3) matrix.

    Args:
        matrix: 4x4 SE(3) matrix
        points: Nx3 array of points

    Returns:
        Nx3 array of transformed points
    """
    assert matrix.shape == (4, 4)
    assert len(points.shape) == 2 and points.shape[1] == 3
    linear = matrix[:3, :3]
    offset = matrix[:3, 3]
    return points @ linear.T + offset


def transform_cameras(matrix, camtoworlds):
    """Transform cameras using an SE(3) matrix.

    Args:
        matrix: 4x4 SE(3) matrix
        camtoworlds: Nx4x4 array of camera-to-world matrices

    Returns:
        Nx4x4 array of transformed camera-to-world matrices
    """
    assert matrix.shape == (4, 4)
    assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4)
    # Left-multiply every camera-to-world matrix by `matrix`.
    moved = np.einsum("nij, ki -> nkj", camtoworlds, matrix)
    # Divide the rotation block by the norm of its first row, which removes
    # any uniform scale that `matrix` introduced.
    row_norm = np.linalg.norm(moved[:, 0, :3], axis=1)
    moved[:, :3, :3] = moved[:, :3, :3] / row_norm[:, None, None]
    return moved


def normalize(camtoworlds, points=None):
    """Normalize cameras, and optionally points, into a canonical frame.

    Args:
        camtoworlds: Nx4x4 array of camera-to-world matrices.
        points: Optional Nx3 array of points.

    Returns:
        ``(camtoworlds, T1)`` when ``points`` is None, where ``T1`` is the
        camera-based similarity transform. Otherwise
        ``(camtoworlds, points, T2 @ T1)``, where ``T2`` additionally aligns
        the points with their principal axes.
    """
    T1 = similarity_from_cameras(camtoworlds)
    camtoworlds = transform_cameras(T1, camtoworlds)
    if points is None:
        return camtoworlds, T1

    points = transform_points(T1, points)
    T2 = align_principal_axes(points)
    camtoworlds = transform_cameras(T2, camtoworlds)
    points = transform_points(T2, points)
    return camtoworlds, points, T2 @ T1
input_file): + self.images = OrderedDict() + with open(input_file, "r") as f: + lines = [line.rstrip("\n") for line in f] + + idx = 0 + num_lines = len(lines) + + while idx < num_lines: + line = lines[idx].strip() + + # Skip comments + if not line or line.startswith("#"): + idx += 1 + continue + + # ------------------------- + # Line 1: image metadata + # ------------------------- + data = line.split() + + image_id = int(data[0]) + qvec = np.array(data[1:5], dtype=float) + tvec = np.array(data[5:8], dtype=float) + camera_id = int(data[8]) + image_name = data[9] + + image = ColmapImage( + image_name, + camera_id, + Quaternion(qvec), + tvec + ) + + # ------------------------- + # Line 2: POINTS2D (may be empty) + # ------------------------- + idx += 1 + if idx >= num_lines: + raise ValueError("Unexpected EOF while reading POINTS2D") + + line = lines[idx].strip() + + if not line: + image.points2D = np.empty((0, 2), dtype=float) + image.point3D_ids = np.empty((0,), dtype=np.uint64) + else: + data = line.split() + + x = np.array(data[0::3], dtype=float) + y = np.array(data[1::3], dtype=float) + image.points2D = np.stack([x, y], axis=1) + + image.point3D_ids = np.array(data[2::3], dtype=np.uint64) + + # ------------------------- + # Store image + # ------------------------- + self.images[image_id] = image + self.name_to_image_id[image.name] = image_id + self.last_image_id = max(self.last_image_id, image_id) + + idx += 1 + + +SceneManager._load_images_txt = new_load_images_txt + + +def _get_rel_paths(path_dir: str) -> List[str]: + """Recursively get relative paths of files in a directory.""" + paths = [] + for dp, dn, fn in os.walk(path_dir): + for f in fn: + paths.append(os.path.relpath(os.path.join(dp, f), path_dir)) + return paths + + +def _resize_image_folder(image_dir: str, resized_dir: str, factor: int) -> str: + """Resize image folder.""" + print(f"Downscaling images by {factor}x from {image_dir} to {resized_dir}.") + os.makedirs(resized_dir, exist_ok=True) + + 
class SilentSceneManager(SceneManager):
    """A silent version of SceneManager that suppresses print statements."""

    def load_colmap_project_file(self, project_file=None, image_path=None):
        """Set ``self.image_path`` from the argument or a COLMAP project file.

        Args:
            project_file: Path to the project file; defaults to
                ``self.folder + 'project.ini'``.
            image_path: Explicit image path. When given, the project file is
                not read.

        After the call ``self.image_path`` is either None (nothing found) or
        a string ending in '/'.
        """
        if project_file is None:
            project_file = self.folder + 'project.ini'

        self.image_path = image_path

        if self.image_path is None:
            try:
                with open(project_file, 'r') as f:
                    for line in f:
                        if line.startswith('image_path'):
                            # Skip the 11-character 'image_path=' prefix.
                            self.image_path = line[11:].strip()
                            break
            except Exception:
                # Best effort: a missing or unreadable project file leaves
                # image_path as None. Catching Exception rather than using a
                # bare except lets KeyboardInterrupt / SystemExit through.
                pass

        # Difference from parent class: no print statement when image_path
        # is still None.
        if self.image_path is not None and not self.image_path.endswith('/'):
            self.image_path += '/'
+ + if verbose: + manager = SceneManager(colmap_dir) + else: + manager = SilentSceneManager(colmap_dir) + manager.load_cameras() + manager.load_images() + + # Load points3D — optionally from a different subfolder + if points3d_subdir is not None: + points3d_dir = os.path.join(data_dir, points3d_subdir) + points3d_bin = os.path.join(points3d_dir, "points3D.bin") + points3d_txt = os.path.join(points3d_dir, "points3D.txt") + if os.path.exists(points3d_bin): + manager.load_points3D(points3d_bin) + elif os.path.exists(points3d_txt): + manager.load_points3D(points3d_txt) + else: + raise IOError(f"No points3D file found in {points3d_dir}") + else: + manager.load_points3D() + + # Extract extrinsic matrices in world-to-camera format. + imdata = manager.images + w2c_mats = [] + camera_ids = [] + Ks_dict = dict() + params_dict = dict() + imsize_dict = dict() # width, height + mask_dict = dict() + bottom = np.array([0, 0, 0, 1]).reshape(1, 4) + for k in tqdm(imdata, disable=not verbose): + im = imdata[k] + rot = im.R() + trans = im.tvec.reshape(3, 1) + w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) + w2c_mats.append(w2c) + + # support different camera intrinsics + camera_id = im.camera_id + camera_ids.append(camera_id) + + # camera intrinsics + cam = manager.cameras[camera_id] + fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + K[:2, :] /= factor + Ks_dict[camera_id] = K + + # Get distortion parameters. 
+ type_ = cam.camera_type + if type_ == 0 or type_ == "SIMPLE_PINHOLE": + params = np.empty(0, dtype=np.float32) + camtype = "perspective" + elif type_ == 1 or type_ == "PINHOLE": + params = np.empty(0, dtype=np.float32) + camtype = "perspective" + if type_ == 2 or type_ == "SIMPLE_RADIAL": + params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32) + camtype = "perspective" + elif type_ == 3 or type_ == "RADIAL": + params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32) + camtype = "perspective" + elif type_ == 4 or type_ == "OPENCV": + params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32) + camtype = "perspective" + elif type_ == 5 or type_ == "OPENCV_FISHEYE": + params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32) + camtype = "fisheye" + assert ( + camtype == "perspective" or camtype == "fisheye" + ), f"Only perspective and fisheye cameras are supported, got {type_}" + + params_dict[camera_id] = params + imsize_dict[camera_id] = (cam.width // factor, cam.height // factor) + mask_dict[camera_id] = None + if verbose: + print( + f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras." + ) + + if len(imdata) == 0: + raise ValueError("No images found in COLMAP.") + if not (type_ == 0 or type_ == 1): + if verbose: + print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.") + + w2c_mats = np.stack(w2c_mats, axis=0) + + # Convert extrinsics to camera-to-world. + camtoworlds = np.linalg.inv(w2c_mats) + + # Image names from COLMAP. No need for permuting the poses according to + # image names anymore. + image_names = [imdata[k].name for k in imdata] + + # Previous Nerf results were generated with images sorted by filename, + # ensure metrics are reported on the same test set. + inds = np.argsort(image_names) + image_names = [image_names[i] for i in inds] + camtoworlds = camtoworlds[inds] + camera_ids = [camera_ids[i] for i in inds] + + # Load extended metadata. Used by Bilarf dataset. 
+ self.extconf = { + "spiral_radius_scale": 1.0, + "no_factor_suffix": False, + } + extconf_file = os.path.join(data_dir, "ext_metadata.json") + if os.path.exists(extconf_file): + with open(extconf_file) as f: + self.extconf.update(json.load(f)) + + # Load bounds if possible (only used in forward facing scenes). + self.bounds = np.array([0.01, 1.0]) + posefile = os.path.join(data_dir, "poses_bounds.npy") + if os.path.exists(posefile): + self.bounds = np.load(posefile)[:, -2:] + + # Load images. + if dl3dv_settings: + # DL3DV settings + image_dir_suffix = "_train" + colmap_image_suffix = "_train" + else: + colmap_image_suffix = "" + if factor > 1 and not self.extconf["no_factor_suffix"]: + image_dir_suffix = f"_{factor}" + else: + image_dir_suffix = "" + + if load_images: + colmap_image_dir = os.path.join(data_dir, "images" + colmap_image_suffix) + print("COLMAP image dir:", colmap_image_dir) + + image_dir = os.path.join(data_dir, "images" + image_dir_suffix) + + # Prefer an existing (non-empty) images_{factor}/ directory. Only + # fall back to images_{factor}_png/ — resizing from the full-res + # colmap image dir when even that is missing — if it is absent. + if factor > 1 and not (os.path.isdir(image_dir) and os.listdir(image_dir)): + image_dir = image_dir + "_png" + if not (os.path.isdir(image_dir) and os.listdir(image_dir)): + image_dir = _resize_image_folder( + colmap_image_dir, image_dir, factor=factor + ) + + print("Image dir:", image_dir) + if not os.path.exists(image_dir): + raise ValueError(f"Image folder {image_dir} does not exist.") + + # Build stem -> relative path mapping for files in image_dir + image_files_by_stem = {} + for f in _get_rel_paths(image_dir): + stem = os.path.splitext(f)[0] + image_files_by_stem[stem] = f + + # Match colmap image entries to image_dir files by filename stem, so + # images load regardless of their on-disk extension (.JPG/.jpg/.png/…) + # and whether or not the original colmap image dir is present. 
+ colmap_to_image = { + cf: image_files_by_stem[os.path.splitext(cf)[0]] + for cf in image_names + if os.path.splitext(cf)[0] in image_files_by_stem + } + + image_files = sorted(_get_rel_paths(image_dir)) + image_paths = [ + os.path.join(image_dir, colmap_to_image[f]) + if f in colmap_to_image + else os.path.join(image_dir, image_files_by_stem.get(os.path.splitext(f)[0], f)) + for f in image_names + ] + + # Filter out views that don't have corresponding images in the image folder + existing_mask = [os.path.exists(p) for p in image_paths] + if not all(existing_mask): + num_missing = sum(1 for m in existing_mask if not m) + if verbose: + print(f"[Parser] Filtering out {num_missing} views without corresponding images.") + existing_indices = [i for i, m in enumerate(existing_mask) if m] + image_names = [image_names[i] for i in existing_indices] + image_paths = [image_paths[i] for i in existing_indices] + camtoworlds = camtoworlds[existing_indices] + camera_ids = [camera_ids[i] for i in existing_indices] + if verbose: + print(f"[Parser] Remaining {len(image_names)} images after filtering.") + if len(image_names) == 0: + raise ValueError( + f"[Parser] Remaining 0 images after filtering: all {num_missing} " + f"views were dropped because their images are missing from {image_dir}." 
+ ) + + else: + + image_paths = None + + # 3D points and {image_name -> [point_idx]} + points = manager.points3D.astype(np.float32) + points_err = manager.point3D_errors.astype(np.float32) + points_rgb = manager.point3D_colors.astype(np.uint8) + point_indices = dict() + + image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()} + for point_id, data in manager.point3D_id_to_images.items(): + for image_id, _ in data: + image_name = image_id_to_name[image_id] + point_idx = manager.point3D_id_to_point3D_idx[point_id] + point_indices.setdefault(image_name, []).append(point_idx) + point_indices = { + k: np.array(v).astype(np.int32) for k, v in point_indices.items() + } + + # Normalize the world space. + if normalize: + T1 = similarity_from_cameras(camtoworlds) + camtoworlds = transform_cameras(T1, camtoworlds) + points = transform_points(T1, points) + + T2 = align_principal_axes(points) + camtoworlds = transform_cameras(T2, camtoworlds) + points = transform_points(T2, points) + + transform = T2 @ T1 + + # Fix for up side down. We assume more points towards + # the bottom of the scene which is true when ground floor is + # present in the images. 
+ if np.median(points[:, 2]) > np.mean(points[:, 2]): + # rotate 180 degrees around x axis such that z is flipped + T3 = np.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, -1.0, 0.0, 0.0], + [0.0, 0.0, -1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + camtoworlds = transform_cameras(T3, camtoworlds) + points = transform_points(T3, points) + transform = T3 @ transform + else: + transform = np.eye(4) + + self.image_names = image_names # List[str], (num_images,) + self.image_paths = image_paths # List[str], (num_images,) + self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) + self.camera_ids = camera_ids # List[int], (num_images,) + self.Ks_dict = Ks_dict # Dict of camera_id -> K + self.params_dict = params_dict # Dict of camera_id -> params + self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height) + self.mask_dict = mask_dict # Dict of camera_id -> mask + self.points = points # np.ndarray, (num_points, 3) + self.points_err = points_err # np.ndarray, (num_points,) + self.points_rgb = points_rgb # np.ndarray, (num_points, 3) + self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,] + self.transform = transform # np.ndarray, (4, 4) + + # load one image to check the size. In the case of tanksandtemples dataset, the + # intrinsics stored in COLMAP corresponds to 2x upsampled images. 
+ if load_images: + actual_image = imageio.imread(self.image_paths[0])[..., :3] + actual_height, actual_width = actual_image.shape[:2] + else: + actual_width, actual_height = self.imsize_dict[self.camera_ids[0]] + colmap_width, colmap_height = self.imsize_dict[self.camera_ids[0]] + s_height, s_width = actual_height / colmap_height, actual_width / colmap_width + for camera_id, K in self.Ks_dict.items(): + K[0, :] *= s_width + K[1, :] *= s_height + self.Ks_dict[camera_id] = K + width, height = self.imsize_dict[camera_id] + self.imsize_dict[camera_id] = (int(width * s_width), int(height * s_height)) + + # undistortion + self.mapx_dict = dict() + self.mapy_dict = dict() + self.roi_undist_dict = dict() + for camera_id in self.params_dict.keys(): + params = self.params_dict[camera_id] + if len(params) == 0: + continue # no distortion + assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}" + assert ( + camera_id in self.params_dict + ), f"Missing params for camera {camera_id}" + K = self.Ks_dict[camera_id] + width, height = self.imsize_dict[camera_id] + + if camtype == "perspective": + K_undist, roi_undist = cv2.getOptimalNewCameraMatrix( + K, params, (width, height), 0 + ) + mapx, mapy = cv2.initUndistortRectifyMap( + K, params, None, K_undist, (width, height), cv2.CV_32FC1 + ) + mask = None + elif camtype == "fisheye": + fx = K[0, 0] + fy = K[1, 1] + cx = K[0, 2] + cy = K[1, 2] + grid_x, grid_y = np.meshgrid( + np.arange(width, dtype=np.float32), + np.arange(height, dtype=np.float32), + indexing="xy", + ) + x1 = (grid_x - cx) / fx + y1 = (grid_y - cy) / fy + theta = np.sqrt(x1 ** 2 + y1 ** 2) + r = ( + 1.0 + + params[0] * theta ** 2 + + params[1] * theta ** 4 + + params[2] * theta ** 6 + + params[3] * theta ** 8 + ) + mapx = (fx * x1 * r + width // 2).astype(np.float32) + mapy = (fy * y1 * r + height // 2).astype(np.float32) + + # Use mask to define ROI + mask = np.logical_and( + np.logical_and(mapx > 0, mapy > 0), + np.logical_and(mapx < width - 1, mapy 
< height - 1), + ) + y_indices, x_indices = np.nonzero(mask) + y_min, y_max = y_indices.min(), y_indices.max() + 1 + x_min, x_max = x_indices.min(), x_indices.max() + 1 + mask = mask[y_min:y_max, x_min:x_max] + K_undist = K.copy() + K_undist[0, 2] -= x_min + K_undist[1, 2] -= y_min + roi_undist = [x_min, y_min, x_max - x_min, y_max - y_min] + else: + assert_never(camtype) + + self.mapx_dict[camera_id] = mapx + self.mapy_dict[camera_id] = mapy + self.Ks_dict[camera_id] = K_undist + self.roi_undist_dict[camera_id] = roi_undist + self.imsize_dict[camera_id] = (roi_undist[2], roi_undist[3]) + self.mask_dict[camera_id] = mask + + # size of the scene measured by cameras + camera_locations = camtoworlds[:, :3, 3] + scene_center = np.mean(camera_locations, axis=0) + dists = np.linalg.norm(camera_locations - scene_center, axis=1) + self.scene_scale = np.max(dists) + + # set height and width from the first image + first_camera_id = self.camera_ids[0] + self.height, self.width = self.imsize_dict[first_camera_id] + + +class Dataset: + """A simple dataset class.""" + + def __init__( + self, + parser: Parser, + split: str = "train", + patch_size: Optional[int] = None, + load_depths: bool = False, + ): + self.parser = parser + self.split = split + self.patch_size = patch_size + self.load_depths = load_depths + self.indices = np.arange(len(self.parser.image_names)) + + def __len__(self): + return len(self.indices) + + def __getitem__(self, item: int) -> Dict[str, Any]: + index = self.indices[item] + image = imageio.imread(self.parser.image_paths[index])[..., :3] + camera_id = self.parser.camera_ids[index] + K = self.parser.Ks_dict[camera_id].copy() # undistorted K + params = self.parser.params_dict[camera_id] + camtoworlds = self.parser.camtoworlds[index] + mask = self.parser.mask_dict[camera_id] + + if len(params) > 0: + # Images are distorted. Undistort them. 
+ mapx, mapy = ( + self.parser.mapx_dict[camera_id], + self.parser.mapy_dict[camera_id], + ) + image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR) + x, y, w, h = self.parser.roi_undist_dict[camera_id] + image = image[y: y + h, x: x + w] + + if self.patch_size is not None: + # Random crop. + h, w = image.shape[:2] + x = np.random.randint(0, max(w - self.patch_size, 1)) + y = np.random.randint(0, max(h - self.patch_size, 1)) + image = image[y: y + self.patch_size, x: x + self.patch_size] + K[0, 2] -= x + K[1, 2] -= y + + data = { + "K": torch.from_numpy(K).float(), + "camtoworld": torch.from_numpy(camtoworlds).float(), + "image": torch.from_numpy(image).float(), + "image_id": item, # the index of the image in the dataset + } + if mask is not None: + data["mask"] = torch.from_numpy(mask).bool() + + if self.load_depths: + # projected points to image plane to get depths + worldtocams = np.linalg.inv(camtoworlds) + image_name = self.parser.image_names[index] + point_indices = self.parser.point_indices[image_name] + points_world = self.parser.points[point_indices] + points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T + points_proj = (K @ points_cam.T).T + points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2) + depths = points_cam[:, 2] # (M,) + # filter out points outside the image + selector = ( + (points[:, 0] >= 0) + & (points[:, 0] < image.shape[1]) + & (points[:, 1] >= 0) + & (points[:, 1] < image.shape[0]) + & (depths > 0) + ) + points = points[selector] + depths = depths[selector] + data["points"] = torch.from_numpy(points).float() + data["depths"] = torch.from_numpy(depths).float() + + return data + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--data_dir", type=str, default="data/360_v2/garden") + parser.add_argument("--factor", type=int, default=4) + args = parser.parse_args() + + # Parse COLMAP data. 
+ parser = Parser(data_dir=args.data_dir, factor=args.factor, normalize=True) + dataset = Dataset(parser, split="train", load_depths=True) + print(f"Dataset: {len(dataset)} images.") + + writer = imageio.get_writer("results/points.mp4", fps=30) + for data in tqdm(dataset, desc="Plotting points"): + image = data["image"].numpy().astype(np.uint8) + points = data["points"].numpy() + depths = data["depths"].numpy() + for x, y in points: + cv2.circle(image, (int(x), int(y)), 2, (255, 0, 0), -1) + writer.append_data(image) + writer.close() diff --git a/optgs/dataset/data_module.py b/optgs/dataset/data_module.py new file mode 100644 index 0000000000000000000000000000000000000000..3e1e16362f031ec58f5990520949b9d85eeecf34 --- /dev/null +++ b/optgs/dataset/data_module.py @@ -0,0 +1,147 @@ +import random +from dataclasses import dataclass +from typing import Callable + +import numpy as np +import torch +from pytorch_lightning import LightningDataModule +from torch import Generator, nn +from torch.utils.data import DataLoader, Dataset, IterableDataset + +from . import DatasetCfg, get_dataset +from .data_types import DataShim, Stage +from .validation_wrapper import ValidationWrapper +from ..misc.step_tracker import StepTracker + + +def get_data_shim(encoder: nn.Module) -> DataShim: + """Get functions that modify the batch. It's sometimes necessary to modify batches + outside the data loader because GPU computations are required to modify the batch or + because the modification depends on something outside the data loader. 
+ """ + + shims: list[DataShim] = [] + if hasattr(encoder, "get_data_shim"): + shims.append(encoder.get_data_shim()) + + def combined_shim(batch): + for shim in shims: + batch = shim(batch) + return batch + + return combined_shim + + +@dataclass +class DataLoaderStageCfg: + batch_size: int + num_workers: int + persistent_workers: bool + seed: int | None + + +@dataclass +class DataLoaderCfg: + train: DataLoaderStageCfg + test: DataLoaderStageCfg + val: DataLoaderStageCfg + + +DatasetShim = Callable[[Dataset, Stage], Dataset] + + +def worker_init_fn(worker_id: int) -> None: + random.seed(int(torch.utils.data.get_worker_info().seed) % (2 ** 32 - 1)) + np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2 ** 32 - 1)) + + +class DataModule(LightningDataModule): + dataset_cfg: DatasetCfg + data_loader_cfg: DataLoaderCfg + step_tracker: StepTracker | None + dataset_shim: DatasetShim + global_rank: int + + def __init__( + self, + dataset_cfg: DatasetCfg, + data_loader_cfg: DataLoaderCfg, + step_tracker: StepTracker | None = None, + dataset_shim: DatasetShim = lambda dataset, _: dataset, + global_rank: int = 0, + ) -> None: + super().__init__() + self.dataset_cfg = dataset_cfg + self.data_loader_cfg = data_loader_cfg + self.step_tracker = step_tracker + self.dataset_shim = dataset_shim + self.global_rank = global_rank + + def get_persistent(self, loader_cfg: DataLoaderStageCfg) -> bool | None: + return None if loader_cfg.num_workers == 0 else loader_cfg.persistent_workers + + def get_generator(self, loader_cfg: DataLoaderStageCfg) -> torch.Generator | None: + if loader_cfg.seed is None: + return None + generator = Generator() + generator.manual_seed(loader_cfg.seed + self.global_rank) + return generator + + def train_dataloader(self): + loader_cfg = self.data_loader_cfg.train + + dataset = get_dataset( + self.dataset_cfg, + "train", + self.step_tracker, + ) + dataset = self.dataset_shim(dataset, "train") + + return DataLoader( + dataset, + loader_cfg.batch_size, 
+ shuffle=not isinstance(dataset, IterableDataset), + num_workers=loader_cfg.num_workers, + generator=self.get_generator(loader_cfg), + worker_init_fn=worker_init_fn, + persistent_workers=self.get_persistent(loader_cfg), + ) + + def val_dataloader(self): + loader_cfg = self.data_loader_cfg.val + + dataset = get_dataset( + self.dataset_cfg, + "val", + self.step_tracker, + ) + dataset = self.dataset_shim(dataset, "val") + + return DataLoader( + ValidationWrapper(dataset, 1), + loader_cfg.batch_size, + num_workers=loader_cfg.num_workers, + generator=self.get_generator(loader_cfg), + worker_init_fn=worker_init_fn, + persistent_workers=self.get_persistent(loader_cfg), + ) + + def test_dataloader(self, dataset_cfg=None): + loader_cfg = self.data_loader_cfg.test + + dataset = get_dataset( + self.dataset_cfg if dataset_cfg is None else dataset_cfg, + "test", + self.step_tracker, + ) + dataset = self.dataset_shim(dataset, "test") + + return DataLoader( + dataset, + loader_cfg.batch_size, + num_workers=loader_cfg.num_workers, + generator=self.get_generator(loader_cfg), + worker_init_fn=worker_init_fn, + persistent_workers=self.get_persistent(loader_cfg), + shuffle=False, + ) diff --git a/optgs/dataset/data_types.py b/optgs/dataset/data_types.py new file mode 100644 index 0000000000000000000000000000000000000000..b0aae8d6679d59976a01c566ce873c0d69beb0d6 --- /dev/null +++ b/optgs/dataset/data_types.py @@ -0,0 +1,143 @@ +from dataclasses import dataclass +from typing import Callable, Literal, TypedDict + +import torch +from jaxtyping import Float, Int64, Bool +from torch import Tensor + +Stage = Literal["train", "val", "test"] + + +# The following types mainly exist to make type-hinted keys show up in VS Code. Some +# dimensions are annotated as "_" because either: +# 1. They're expected to change as part of a function call (e.g., resizing the dataset). +# 2. 
They're expected to vary within the same function call (e.g., the number of views, +# which differs between context and target BatchedViews). + +class BatchedViewsDict(TypedDict, total=False): + extrinsics: Float[Tensor, "batch view 4 4"] # batch view 4 4 + intrinsics: Float[Tensor, "batch view 3 3"] # batch view 3 3 + image: Float[Tensor, "batch view channel height width"] # batch view channel height width + near: Float[Tensor, "batch view"] # batch view + far: Float[Tensor, "batch view"] # batch view + index: Int64[Tensor, "batch view"] # batch view + scene_scale: Float[Tensor, "batch"] # batch + + +@dataclass +class BatchedViews: + """ + BatchedViews represents a batch of views. The batch dimension allows for efficient processing of multiple + scenes simultaneously. + + This class is a wrapper around the BatchedViewsDict TypedDict. + Some dict like behavior is still missing, can be added as needed. + """ + extrinsics: Float[Tensor, "batch _ 4 4"] # batch view 4 4 + intrinsics: Float[Tensor, "batch _ 3 3"] # batch view 3 3 + image: Float[Tensor, "batch _ _ _ _"] # batch view channel height width + near: Float[Tensor, "batch _"] # batch view + far: Float[Tensor, "batch _"] # batch view + index: Int64[Tensor, "batch _"] # batch view + scene_scale: Float[Tensor, "batch "] | None + + x_flipped: Bool[Tensor, "batch"] + + viewpoint_stack: Int64[Tensor, "batch _"] | None = None # batch subset + used_indices_list: list[Int64[Tensor, "_"]] | None = None # list of tensors of shape (subset,) + + def __contains__(self, key): + return hasattr(self, key) + + def get(self, key, default=None): + return getattr(self, key, default) + + def __getitem__(self, item: str): + return getattr(self, item) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def batchify_views(self, indices: Int64[Tensor, "batch _"]) -> "BatchedViews": + """Select a subset of views for each example in the batch.""" + + scene_batch = indices.size(0) + scene_batch_idx = 
torch.arange(scene_batch, device=indices.device)[:, None] + + return BatchedViews( + extrinsics=self.extrinsics[scene_batch_idx, indices], + intrinsics=self.intrinsics[scene_batch_idx, indices], + image=self.image[scene_batch_idx, indices], + near=self.near[scene_batch_idx, indices], + far=self.far[scene_batch_idx, indices], + index=self.index[scene_batch_idx, indices], + x_flipped=self.x_flipped, + scene_scale=self.scene_scale + ) + + @classmethod + def from_dict(cls, data: BatchedViewsDict) -> "BatchedViews": + b = data["extrinsics"].size(0) + device = data["extrinsics"].device + return cls( + extrinsics=data["extrinsics"], + intrinsics=data["intrinsics"], + image=data["image"], + near=data["near"], + far=data["far"], + index=data["index"], + x_flipped=data.get("x_flipped", torch.zeros(b, dtype=torch.bool, device=device)), + viewpoint_stack=data.get("viewpoint_stack", None), + used_indices_list=data.get("used_indices_list", None), + scene_scale=data.get("scene_scale", None) + ) + + def reset_viewpoint_stack_if_needed(self, strategy, batch_size) -> None: + + if self.viewpoint_stack is None or self.viewpoint_stack.size(1) < batch_size: + # Create a new viewpoint stack + batch = self.extrinsics.size(0) + num_views = self.extrinsics.size(1) + + if strategy == "random": + # Create a random permutation of viewpoints for each example in the batch, by sorting random values + rand_matrix = torch.rand(batch, num_views, device=self.extrinsics.device) + new_stack = torch.argsort(rand_matrix, dim=1) + else: + new_stack = torch.arange(num_views, device=self.extrinsics.device).unsqueeze(0).expand(batch, + -1).clone() + + # Assign the new viewpoint stack + if self.viewpoint_stack is None or self.viewpoint_stack.size(1) == 0: + self.viewpoint_stack = new_stack + else: + # Concatenate the existing stack with the new one + self.viewpoint_stack = torch.cat([self.viewpoint_stack, new_stack], dim=1) + + +class BatchedExample(TypedDict, total=False): + target: BatchedViews | 
BatchedViewsDict + context: BatchedViews | BatchedViewsDict + scene: list[str] + + +class UnbatchedViews(TypedDict, total=False): + extrinsics: Float[Tensor, "_ 4 4"] + intrinsics: Float[Tensor, "_ 3 3"] + image: Float[Tensor, "_ 3 height width"] + near: Float[Tensor, " _"] + far: Float[Tensor, " _"] + index: Int64[Tensor, " _"] + + +class UnbatchedExample(TypedDict, total=False): + target: UnbatchedViews + context: UnbatchedViews + scene: str + + +# A data shim modifies the example after it's been returned from the data loader. +DataShim = Callable[[BatchedExample], BatchedExample] + +AnyExample = BatchedExample | UnbatchedExample +AnyViews = BatchedViews | UnbatchedViews diff --git a/optgs/dataset/dataset.py b/optgs/dataset/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..834d941f53ffb42fa55c0db56aa08f3d003151d6 --- /dev/null +++ b/optgs/dataset/dataset.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + +from .view_sampler import ViewSamplerCfg + + +@dataclass +class DatasetCfgCommon: + image_shape: list[int] + background_color: list[float] + cameras_are_circular: bool + pose_align_middle_view: bool + overfit_to_scene: str | None + view_sampler: ViewSamplerCfg + opencv_pose_format: bool | None + + test_start_idx: int # skip the first N scenes during test (for scene-chunked SLURM jobs) \ No newline at end of file diff --git a/optgs/dataset/dataset_colmap.py b/optgs/dataset/dataset_colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..18423056bf3eea1b1fda0e2cf6c57bc2fc2afc33 --- /dev/null +++ b/optgs/dataset/dataset_colmap.py @@ -0,0 +1,271 @@ +# Adapted from https://github.com/nerfstudio-project/gsplat/blob/b5392febf6097655c18db17693636cd21bbe58c0/examples/datasets/colmap.py + +from dataclasses import dataclass +from pathlib import Path +from typing import List, Literal, Optional + +import imageio +import numpy as np +import torch +import torchvision.transforms as tf +from einops import repeat +from 
jaxtyping import Float +from torch import Tensor +from torch.utils.data import IterableDataset + +from .colmap.utils import Parser +from .data_types import Stage +from .dataset import DatasetCfgCommon +from .shims.patch_shim import apply_patch_shim +from .view_sampler import ViewSampler +from .view_sampler.view_sampler_all import ViewSamplerAll +from .view_sampler.view_sampler_dense import ViewSamplerDense +from .view_sampler.view_sampler_evaluation import ViewSamplerEvaluation +from .view_sampler.view_sampler_ids import ViewSamplerIDs + + +@dataclass +class DatasetColmapCfg(DatasetCfgCommon): + name: Literal["colmap"] + roots: Path + scene_name: Optional[str] # If None, iterate over all scenes in roots + normalize_world_space: bool + subsample_factor: int + crop_size: None | int | list[int] + symmetric_principal_point: bool = False # override cx, cy to image center (matches 3DGS getProjectionMatrix) + + +class DatasetColmap(IterableDataset): + cfg: DatasetColmapCfg + stage: Stage + view_sampler: ViewSampler + + to_tensor: tf.ToTensor + near: float = 0.01 + far: float = 100.0 + + def __init__( + self, + cfg: DatasetColmapCfg, + stage: Stage, + view_sampler: ViewSampler, + ) -> None: + super().__init__() + + # COLMAP datasets should only be used for testing/validation, not training + if stage == "train": + raise ValueError( + "COLMAP dataset does not support training stage. " + "Use 'test' or 'val' stage instead. " + "COLMAP scenes are typically small and meant for evaluation." + ) + + self.cfg = cfg + self.stage = stage + self.view_sampler = view_sampler + + # check if view_sampler is supported + assert isinstance(self.view_sampler, (ViewSamplerDense, ViewSamplerIDs, ViewSamplerAll, ViewSamplerEvaluation)), \ + "COLMAP dataset requires ViewSamplerDense, ViewSamplerIDs, ViewSamplerAll, or ViewSamplerEvaluation." 
+ self.to_tensor = tf.ToTensor() + + # Discover available scenes + if cfg.scene_name is not None: + # Single scene mode (backward compatible) + self.scene_names = [cfg.scene_name] + else: + # Multi-scene mode: list all subdirectories that contain COLMAP data + self.scene_names = self._discover_scenes(cfg.roots) + + print(f"Found {len(self.scene_names)} scene(s) in {cfg.roots}: {self.scene_names}") + + # Image shape will be set when loading the first scene + self.image_shape = None + + @staticmethod + def _discover_scenes(roots: Path) -> List[str]: + """Discover all valid COLMAP scenes in the roots directory.""" + scenes = [] + for subdir in sorted(roots.iterdir()): + if subdir.is_dir(): + # Check if this looks like a COLMAP scene (has sparse folder or images folder) + if (subdir / "sparse").exists() or (subdir / "images").exists(): + scenes.append(subdir.name) + return scenes + + def _load_scene(self, scene_name: str) -> dict: + """Load a single scene and return it in chunk format.""" + colmap_root = self.cfg.roots / scene_name + assert colmap_root.exists(), f"COLMAP root {colmap_root} does not exist." 
+ + print( + f"Loading COLMAP scene '{scene_name}' from {colmap_root} with subsample factor {self.cfg.subsample_factor}") + + # Create parser for this scene + print(f"in dataset NORMALIZE {self.cfg.normalize_world_space}") + parser = Parser( + data_dir=str(colmap_root), + factor=self.cfg.subsample_factor, + normalize=self.cfg.normalize_world_space, + ) + print(f"parser scene scale {parser.scene_scale * 1.1}") + + # Update image shape from first loaded scene + if self.image_shape is None: + self.image_shape = [parser.height, parser.width] + + # Convert to chunk format + return self._create_chunk_from_parser(parser, scene_name) + + def _create_chunk_from_parser(self, parser: Parser, scene_name: str) -> dict: + """Convert COLMAP parser data to DL3DV-style chunk format.""" + + # Collect all camera data (both context and target) + all_indices = list(range(len(parser.image_names))) + + # Build cameras tensor (fx, fy, cx, cy, 4x4 w2c matrix) + extrinsics_list = [] + intrinsics_list = [] + images_list = [] + + for idx in all_indices: + camera_id = parser.camera_ids[idx] + + # Get image dimensions + w, h = parser.imsize_dict[camera_id] + + # Get camera intrinsics + K = parser.Ks_dict[camera_id].copy() + + if self.cfg.symmetric_principal_point: + K[0, 2] = w / 2.0 + K[1, 2] = h / 2.0 + + # Normalize camera intrinsics + K[0, :] /= w + K[1, :] /= h + + # check if K is invertible + if np.linalg.matrix_rank(K) < 3: + print(K) + raise ValueError(f"Camera intrinsic matrix for image {parser.image_names[idx]} is not invertible.") + + # Get camera-to-world matrix + c2w = parser.camtoworlds[idx] + + # Pack + extrinsics = torch.from_numpy(c2w).float() + intrinsics = torch.from_numpy(K).float() + + extrinsics_list.append(extrinsics) + intrinsics_list.append(intrinsics) + + # Load image + image = imageio.imread(parser.image_paths[idx])[..., :3] + image = torch.from_numpy(image).permute(2, 0, 1) # C, H, W + + images_list.append(image) # list of C, H, W tensors + + extrinsics = 
torch.stack(extrinsics_list, dim=0) + intrinsics = torch.stack(intrinsics_list, dim=0) + + chunk = { + "key": scene_name, + "cameras": (extrinsics, intrinsics), + "images": images_list, + "scene_scale": parser.scene_scale * 1.1 + } + + return chunk + + def _process_scene(self, chunk: dict): + """Process a single scene chunk and yield examples.""" + extrinsics, intrinsics = chunk["cameras"] + scene = chunk["key"] + + out_data = self.view_sampler.sample( + scene, + extrinsics, + intrinsics, + ) + + context_indices, target_indices = out_data[:2] + + c_list = [context_indices] + t_list = [target_indices] + for context_indices, target_indices in zip(c_list, t_list): + # Load the images + context_images = [ + chunk["images"][index.item()] for index in context_indices + ] + context_images = torch.stack(context_images).float() / 255.0 + + target_images = [ + chunk["images"][index.item()] for index in target_indices + ] + target_images = torch.stack(target_images).float() / 255.0 + + example_out = { + "context": { + "extrinsics": extrinsics[context_indices], + "intrinsics": intrinsics[context_indices], + "image": context_images, + "near": self.get_bound("near", len(context_indices)), + "far": self.get_bound("far", len(context_indices)), + "index": context_indices, + "scene_scale": chunk["scene_scale"], + }, + "target": { + "extrinsics": extrinsics[target_indices], + "intrinsics": intrinsics[target_indices], + "image": target_images, + "near": self.get_bound("near", len(target_indices)), + "far": self.get_bound("far", len(target_indices)), + "index": target_indices, + "scene_scale": chunk["scene_scale"], + }, + "scene": scene, + } + + if self.cfg.crop_size is not None: + example_out = apply_patch_shim(example_out, self.cfg.crop_size) + + yield example_out + + def __iter__(self): + # Handle multiple workers - each worker should only process a subset of scenes + worker_info = torch.utils.data.get_worker_info() + if self.stage == "test" and worker_info is not None: + # Split 
scenes among workers + scene_names = [ + scene_name + for scene_index, scene_name in enumerate(self.scene_names) + if scene_index % worker_info.num_workers == worker_info.id + ] + else: + scene_names = self.scene_names + + # Iterate over assigned scenes + test_scene_counter = 0 + for i, scene_name in enumerate(scene_names): + # Skip scenes before test_start_idx (for scene-chunked SLURM jobs) + if self.stage == "test" and test_scene_counter < self.cfg.test_start_idx: + test_scene_counter += 1 + continue + test_scene_counter += 1 + + # Load the scene data + chunk = self._load_scene(scene_name) + # Process and yield examples from this scene + yield from self._process_scene(chunk) + + def get_bound( + self, + bound: Literal["near", "far"], + num_views: int, + ) -> Float[Tensor, " view"]: + value = torch.tensor(getattr(self, bound), dtype=torch.float32) + return repeat(value, "-> v", v=num_views) + + def __len__(self) -> int: + return len(self.scene_names) diff --git a/optgs/dataset/dataset_dl3dv.py b/optgs/dataset/dataset_dl3dv.py new file mode 100644 index 0000000000000000000000000000000000000000..87f4f78ead697ed0baba2bba86b8ab29b6d32ac1 --- /dev/null +++ b/optgs/dataset/dataset_dl3dv.py @@ -0,0 +1,640 @@ +import json +from dataclasses import dataclass +from functools import cached_property +from io import BytesIO +from pathlib import Path +from typing import Literal, Optional + +import torch +import torchvision.transforms as tf +from einops import rearrange, repeat +from jaxtyping import Float, UInt8 +from PIL import Image +from torch import Tensor +from torch.utils.data import IterableDataset +import numpy as np +import os +import random + +from ..geometry.projection import get_fov +from .dataset import DatasetCfgCommon +from .shims.augmentation_shim import apply_augmentation_shim +from .shims.crop_shim import apply_crop_shim +from .data_types import Stage +from .view_sampler import ViewSampler + + +@dataclass +class DatasetDL3DVCfg(DatasetCfgCommon): + name: 
Literal["dl3dv"] + roots: list[Path] + baseline_epsilon: float + max_fov: float + make_baseline_1: bool + augment: bool + test_len: int + test_chunk_interval: int + train_times_per_scene: int + test_times_per_scene: int + ori_image_shape: list[int] + # random crop training + random_crop: bool + max_size: list[int] | None + min_size: list[int] | None + + skip_bad_shape: bool = True + near: float = -1.0 + far: float = -1.0 + baseline_scale_bounds: bool = True + shuffle_val: bool = True + no_mix_test_set: bool = True + load_depth: bool = False + min_views: int = 0 + max_views: int = 0 + highres: bool = False + sort_target_index: Optional[bool] = False + overfit_max_views: Optional[int] = None + sort_context_index: Optional[bool] = False + use_index_to_load_chunk: Optional[bool] = False + pose_align_first_view: bool = False # align the camera pose to the first view + scale_extrinsics: float = 1. + metric_scale_align_dl3dv: bool = False + center_pose: bool = False # center and normalize the pose by the distance to the center + + # mix re10k & dl3dv + mix_re10k: bool = False + re10k_min_view_dist: int = 40 + re10k_max_view_dist: int = 300 + + # load remaining context views + load_remain_context: bool = False + num_remain_context: int = 8 + + index_name: str = "index.json" + + +class DatasetDL3DV(IterableDataset): + cfg: DatasetDL3DVCfg + stage: Stage + view_sampler: ViewSampler + + to_tensor: tf.ToTensor + chunks: list[Path] + near: float = 0.1 + far: float = 1000.0 + + def __init__( + self, + cfg: DatasetDL3DVCfg, + stage: Stage, + view_sampler: ViewSampler, + ) -> None: + super().__init__() + + self.cfg = cfg + self.stage = stage + self.view_sampler = view_sampler + self.to_tensor = tf.ToTensor() + if cfg.near != -1: + self.near = cfg.near + if cfg.far != -1: + self.far = cfg.far + + # Collect chunks. 
+ self.chunks = [] + for i, root in enumerate(cfg.roots): + root = root / self.data_stage + if self.cfg.use_index_to_load_chunk: + with open(root / self.cfg.index_name, "r") as f: + json_dict = json.load(f) + root_chunks = sorted(list(set(json_dict.values()))) + else: + root_chunks = sorted( + [path for path in root.iterdir() if path.suffix == ".torch"] + ) + + # mixed data training only evaluate on a single test set + if cfg.no_mix_test_set and self.data_stage in ['val', 'test'] and i > 0: + continue + + # balance the datasets for mixed dataset training + # for gs: mix re10k, dl3dv + if len(cfg.roots) > 1 and self.data_stage == 'train': + if 'dl3dv' in str(root): + root_chunks = root_chunks * 8 + + self.chunks.extend(root_chunks) + if self.cfg.overfit_to_scene is not None: + chunk_path = self.index[self.cfg.overfit_to_scene] + self.chunks = [chunk_path] * len(self.chunks) + if self.stage == "test": + # fast testing + self.chunks = self.chunks[:: cfg.test_chunk_interval] + if self.stage == "val": + self.chunks = self.chunks * int(1e6 // len(self.chunks)) + + if self.cfg.metric_scale_align_dl3dv: + # read invalid scales + scale_dir = '/cluster/project/cvg/haofei/datasets/depthsplat/dl3dv_metric_scale_factor' + filename = os.path.join(scale_dir, 'dl3dv_invalid.txt') + with open(filename, "r") as f: + self.invalid_scale_scenes = [line.strip() for line in f] + + # Calculate actual number of scenes (keys of index file that their chunks exist). 
+ num_available_scenes = 0 + for scene, chunk_path in self.index.items(): + if chunk_path in self.chunks: + num_available_scenes += 1 + self.num_available_scenes = num_available_scenes + + def shuffle(self, lst: list) -> list: + indices = torch.randperm(len(lst)) + return [lst[x] for x in indices] + + def _get_scale_factor(self, scene: str) -> float: + """Get the scale factor for a scene.""" + if self.cfg.metric_scale_align_dl3dv: + scale_dir = '/cluster/project/cvg/haofei/datasets/depthsplat/dl3dv_metric_scale_factor' + if self.stage == 'train': + folder = scene.split('_')[1] + else: + folder = scene + filename = os.path.join(scale_dir, folder, 'scale_factor.txt') + + if not os.path.exists(filename) or folder in self.invalid_scale_scenes: + return self.cfg.scale_extrinsics + else: + with open(filename, "r") as f: + return float(f.read().strip()) + else: + return self.cfg.scale_extrinsics + + def _process_example_to_batch( + self, + example: dict, + extrinsics: Tensor, + intrinsics: Tensor, + context_indices: Tensor, + target_indices: Tensor, + scale_factor: float, + ) -> Optional[dict]: + """ + Process an example into a batch dict (original behavior). + Returns None if the example should be skipped. 
+ """ + scene = example["key"] + + # Load remaining context views if configured + if self.cfg.load_remain_context: + remaining_indices = get_remaining_indices( + context_indices, target_indices, self.cfg.num_remain_context + ) + remain_context_images = [ + example["images"][index.item()] for index in remaining_indices + ] + try: + remain_context_images = self.convert_images(remain_context_images) + except OSError: + return None + + # Load context images + context_images = [ + example["images"][index.item()] for index in context_indices + ] + try: + context_images = self.convert_images(context_images) + except OSError: + return None + + # Load target images + target_images = [ + example["images"][index.item()] for index in target_indices + ] + try: + target_images = self.convert_images(target_images) + except OSError: + return None + + # Validate image shapes + if self.cfg.mix_re10k and 'dl3dv' not in scene: + if self.cfg.highres: + expected_shape = (3, 720, 1280) + else: + expected_shape = (3, 360, 640) + else: + expected_shape = tuple([3, *self.cfg.ori_image_shape]) + + if self.stage in ['test', 'val'] or 'dl3dv' in scene: + expected_shape = tuple([3, *self.cfg.ori_image_shape]) + + if self.cfg.skip_bad_shape: + if context_images.shape[1:] != expected_shape or target_images.shape[1:] != expected_shape: + print( + f"Skipped bad example {scene}. 
Context shape was " + f"{context_images.shape}, target shape was " + f"{target_images.shape}, and expected shape was {expected_shape}" + ) + return None + + if self.cfg.load_remain_context and remain_context_images.shape[1:] != expected_shape: + return None + + # Apply pose transformations + if self.cfg.pose_align_middle_view: + mid_index = context_indices.shape[0] // 2 + extrinsics = camera_normalization( + extrinsics[context_indices][mid_index:mid_index + 1], extrinsics + ) + + if self.cfg.pose_align_first_view: + extrinsics = camera_normalization(extrinsics[context_indices][0:1], extrinsics) + + if self.cfg.center_pose: + extrinsics = center_norm_pose(extrinsics) + + # Validate extrinsics + if any(torch.isnan(torch.det(extrinsics[context_indices][:, :3, :3]))): + return None + if any(torch.isnan(torch.det(extrinsics[target_indices][:, :3, :3]))): + return None + if (extrinsics[context_indices][:, :3, 3].abs() > 1e3).any(): + return None + if (extrinsics[target_indices][:, :3, 3].abs() > 1e3).any(): + return None + if not torch.allclose( + torch.det(extrinsics[context_indices][:, :3, :3]), + extrinsics[context_indices][:, :3, :3].new_tensor(1) + ): + return None + if not torch.allclose( + torch.det(extrinsics[target_indices][:, :3, :3]), + extrinsics[target_indices][:, :3, :3].new_tensor(1) + ): + return None + + if self.cfg.load_remain_context: + if any(torch.isnan(torch.det(extrinsics[remaining_indices][:, :3, :3]))): + return None + if (extrinsics[remaining_indices][:, :3, 3] > 1e3).any(): + return None + if not torch.allclose( + torch.det(extrinsics[remaining_indices][:, :3, :3]), + extrinsics[remaining_indices][:, :3, :3].new_tensor(1) + ): + return None + + # Apply scale factor + extrinsics[:, :3, 3] *= scale_factor + + # Build output + example_out = { + "context": { + "extrinsics": extrinsics[context_indices], + "intrinsics": intrinsics[context_indices], + "image": context_images, + "near": self.get_bound("near", len(context_indices)), + "far": 
self.get_bound("far", len(context_indices)), + "index": context_indices, + }, + "target": { + "extrinsics": extrinsics[target_indices], + "intrinsics": intrinsics[target_indices], + "image": target_images, + "near": self.get_bound("near", len(target_indices)), + "far": self.get_bound("far", len(target_indices)), + "index": target_indices, + }, + "scene": scene, + } + + if self.cfg.load_remain_context: + example_out["context_remain"] = { + "extrinsics": extrinsics[remaining_indices], + "intrinsics": intrinsics[remaining_indices], + "image": remain_context_images, + "near": self.get_bound("near", len(remaining_indices)), + "far": self.get_bound("far", len(remaining_indices)), + "index": remaining_indices, + } + + return example_out + + def __iter__(self): + # Chunks must be shuffled here (not inside __init__) for validation to show + # random chunks. + if self.stage in (("train", "val") if self.cfg.shuffle_val else ("train")): + self.chunks = self.shuffle(self.chunks) + + # When testing, the data loaders alternate chunks. + worker_info = torch.utils.data.get_worker_info() + if self.stage == "test" and worker_info is not None: + self.chunks = [ + chunk + for chunk_index, chunk in enumerate(self.chunks) + if chunk_index % worker_info.num_workers == worker_info.id + ] + + # Counter for skipping the first test_start_idx scenes (used for scene-chunked SLURM jobs). + test_scene_counter = 0 + + # Iterate over chunks. + for chunk_path in self.chunks: + # Load the chunk. 
+ chunk = torch.load(chunk_path) + + if self.cfg.overfit_to_scene is not None: + item = [x for x in chunk if x["key"] == self.cfg.overfit_to_scene] + assert len(item) == 1 + if self.stage == "test": + chunk = item + else: + chunk = item * len(chunk) + + if self.stage in (("train", "val") if self.cfg.shuffle_val else ("train")): + chunk = self.shuffle(chunk) + + times_per_scene = ( + self.cfg.test_times_per_scene + if self.stage == "test" + else self.cfg.train_times_per_scene + ) + + # Iterate over examples in the chunk. + for run_idx in range(int(times_per_scene * len(chunk))): + example = chunk[run_idx // times_per_scene] + + if example["key"] not in self.index: + continue + + extrinsics, intrinsics = self.convert_poses(example["cameras"]) + scene = example["key"] + + # Skip if field of view is too wide + if (get_fov(intrinsics).rad2deg() > self.cfg.max_fov).any(): + continue + + scale_factor = self._get_scale_factor(scene) + + try: + extra_kwargs = {} + if self.cfg.overfit_to_scene is not None and self.stage != "test": + extra_kwargs["max_num_views"] = ( + 148 if self.cfg.overfit_max_views is None + else self.cfg.overfit_max_views + ) + + is_re10k = self.cfg.mix_re10k and 'dl3dv' not in scene and self.stage == 'train' + + out_data = self.view_sampler.sample( + scene, + extrinsics, + intrinsics, + min_context_views=self.cfg.min_views, + max_context_views=self.cfg.max_views, + min_view_dist=self.cfg.re10k_min_view_dist if is_re10k else None, + max_view_dist=self.cfg.re10k_max_view_dist if is_re10k else None, + **extra_kwargs, + ) + + if isinstance(out_data, tuple): + context_indices, target_indices = out_data[:2] + c_list = [ + context_indices.sort()[0] if self.cfg.sort_context_index else context_indices + ] + t_list = [ + target_indices.sort()[0] if self.cfg.sort_target_index else target_indices + ] + elif isinstance(out_data, list): + c_list = [ + a.context.sort()[0] if self.cfg.sort_context_index else a.context + for a in out_data + ] + t_list = [ + 
a.target.sort()[0] if self.cfg.sort_target_index else a.target + for a in out_data + ] + + except ValueError: + # Skip because the example doesn't have enough frames. + continue + + for context_indices, target_indices in zip(c_list, t_list): + example_out = self._process_example_to_batch( + example, extrinsics.clone(), intrinsics, + context_indices, target_indices, scale_factor + ) + + if example_out is None: + continue + + # Apply augmentation and cropping + if self.stage == "train" and self.cfg.augment: + example_out = apply_augmentation_shim(example_out) + + # Skip scenes before test_start_idx (for scene-chunked SLURM jobs) + if self.stage == "test" and test_scene_counter < self.cfg.test_start_idx: + test_scene_counter += 1 + continue + + context_images = example_out["context"]["image"] + if self.cfg.image_shape == list(context_images.shape[2:]): + yield example_out + else: + if self.stage == "train" and self.cfg.random_crop: + crop_h = random.randint(self.cfg.min_size[0], self.cfg.max_size[0] + 1) // 64 * 64 + crop_w = random.randint(self.cfg.min_size[1], self.cfg.max_size[1] + 1) // 64 * 64 + crop_size = (crop_h, crop_w) + yield apply_crop_shim(example_out, crop_size) + else: + yield apply_crop_shim(example_out, tuple(self.cfg.image_shape)) + + def convert_poses( + self, + poses: Float[Tensor, "batch 18"], + ) -> tuple[ + Float[Tensor, "batch 4 4"], # extrinsics + Float[Tensor, "batch 3 3"], # intrinsics + ]: + b, _ = poses.shape + + # Convert the intrinsics to a 3x3 normalized K matrix. + intrinsics = torch.eye(3, dtype=torch.float32) + intrinsics = repeat(intrinsics, "h w -> b h w", b=b).clone() + fx, fy, cx, cy = poses[:, :4].T + intrinsics[:, 0, 0] = fx + intrinsics[:, 1, 1] = fy + intrinsics[:, 0, 2] = cx + intrinsics[:, 1, 2] = cy + + # Convert the extrinsics to a 4x4 OpenCV-style C2W matrix. 
def opengl_to_opencv(self, c2w):
    """Re-express a batch of camera-to-world matrices in the OpenCV convention.

    The steps follow
    https://github.com/DL3DV-10K/Dataset/issues/4#issuecomment-2019441741:
    right-multiply by diag(1, -1, -1, 1), negate row 2, swap rows 0 and 1,
    then negate columns 1-2 of the top three rows. The input tensor is not
    modified; a new tensor is returned.
    """
    axis_flip = torch.diag(
        torch.tensor([1, -1, -1, 1], dtype=c2w.dtype, device=c2w.device)
    ).unsqueeze(0)
    converted = c2w @ axis_flip

    # Negate the third row of every matrix.
    converted[:, 2, :] = -converted[:, 2, :]

    # Swap the first two rows (indexing with a tensor yields a fresh copy).
    row_order = torch.tensor([1, 0, 2, 3])
    converted = converted[:, row_order, :]

    # Negate columns 1 and 2 of the top three rows.
    converted[:, :3, 1:3] = -converted[:, :3, 1:3]

    return converted
def get_remaining_indices(context_indices: torch.Tensor,
                          target_indices: torch.Tensor,
                          num_remain_context: int) -> torch.Tensor:
    """
    Return every frame index inside [min(context), max(context)] that is used
    neither as a context view nor as a target view.

    Note that the result is NOT resampled to a fixed size: all unused indices
    are returned in ascending order. `num_remain_context` only controls the
    fallback when no unused index exists, in which case the first context
    index is repeated `num_remain_context` times.

    Args:
        context_indices (torch.Tensor): 1D tensor of context indices.
        target_indices (torch.Tensor): 1D tensor of target indices.
        num_remain_context (int): Length of the fallback result.

    Returns:
        torch.Tensor: 1D tensor of unused indices (ascending), or the
        repeated first context index when there are none.

    Raises:
        ValueError: If `context_indices` is empty.
    """
    if context_indices.numel() == 0:
        raise ValueError("context_indices must not be empty.")

    lowest = context_indices.min().item()
    highest = context_indices.max().item()

    # Every frame index spanned by the context views, inclusive on both ends.
    candidates = torch.arange(lowest, highest + 1, dtype=torch.long)
    already_used = torch.cat([context_indices, target_indices])
    unused = candidates[torch.isin(candidates, already_used, invert=True)]

    if unused.numel() > 0:
        return torch.sort(unused).values

    # Nothing to sample from; repeat the first context index as a fallback.
    return context_indices[0].repeat(num_remain_context)
functools import cached_property +from io import BytesIO +from pathlib import Path +from typing import Literal, Optional + +import numpy as np +import torch +import torchvision.transforms as tf +import torch.nn.functional as F +from einops import rearrange, repeat +from jaxtyping import Float, UInt8 +from PIL import Image +from torch import Tensor +from torch.utils.data import IterableDataset +import cv2 + +from ..geometry.projection import get_fov +from .dataset import DatasetCfgCommon +from .shims.augmentation_shim import apply_augmentation_shim +from .shims.crop_shim import apply_crop_shim +from .data_types import Stage +from .view_sampler import ViewSampler +from .dataset_dl3dv import get_remaining_indices + + +@dataclass +class DatasetRE10kCfg(DatasetCfgCommon): + name: Literal["re10k"] + roots: list[Path] + baseline_epsilon: float + max_fov: float + make_baseline_1: bool + augment: bool + test_len: int + test_chunk_interval: int + average_pose: bool + skip_bad_shape: bool = True + near: float = -1.0 + far: float = -1.0 + baseline_scale_bounds: bool = True + shuffle_val: bool = True + train_times_per_scene: int = 1 + highres: bool = False + scannet: bool = False + tartanair: bool = False + use_index_to_load_chunk: Optional[bool] = False + load_depth: bool = False + pose_align_first_view: bool = False # align the camera pose to the first view + center_pose: bool = False # center and normalize the pose by the distance to the center + + scale_extrinsics: float = 1. 
def __iter__(self):
    """Yield cropped training/evaluation examples, one per sampled scene pass.

    For every chunk file in `self.chunks` the chunk is loaded with
    `torch.load`, and each scene in it is visited `cfg.train_times_per_scene`
    times (once in the test stage). An example is skipped when:

    * the view sampler raises `ValueError` (not enough frames),
    * any camera field of view exceeds `cfg.max_fov` degrees,
    * an image fails to decode (`OSError`),
    * `cfg.load_remain_context` is set but no remaining view exists,
    * image shapes do not match the expected resolution
      (when `cfg.skip_bad_shape` is set),
    * a context/target rotation has a NaN determinant,
    * it is one of the first `cfg.test_start_idx` examples of the test stage.

    Yields:
        The example dict after `apply_crop_shim(..., cfg.image_shape)`.
    """
    shuffled_stages = ("train", "val") if self.cfg.shuffle_val else ("train",)

    # Chunks must be shuffled here (not inside __init__) for validation to show
    # random chunks.
    if self.stage in shuffled_stages:
        self.chunks = self.shuffle(self.chunks)

    # When testing, the data loaders alternate chunks. The per-worker shard is
    # kept in a local variable so that iterating the dataset a second time
    # does not shard an already-sharded list.
    chunks = self.chunks
    worker_info = torch.utils.data.get_worker_info()
    if self.stage == "test" and worker_info is not None:
        chunks = [
            chunk
            for chunk_index, chunk in enumerate(chunks)
            if chunk_index % worker_info.num_workers == worker_info.id
        ]

    # Counter for skipping the first test_start_idx scenes (used for scene-chunked SLURM jobs).
    test_scene_counter = 0

    for chunk_path in chunks:
        # Load the chunk.
        chunk = torch.load(chunk_path)

        if self.cfg.overfit_to_scene is not None:
            item = [x for x in chunk if x["key"] == self.cfg.overfit_to_scene]
            assert len(item) == 1
            chunk = item * len(chunk)

        if self.stage in shuffled_stages:
            chunk = self.shuffle(chunk)

        times_per_scene = (
            1
            if self.stage == "test"
            else self.cfg.train_times_per_scene
        )

        for run_idx in range(int(times_per_scene * len(chunk))):
            example = chunk[run_idx // times_per_scene]
            extrinsics, intrinsics = self.convert_poses(example["cameras"])
            scene = example["key"]

            try:
                context_indices, target_indices = self.view_sampler.sample(
                    scene,
                    extrinsics,
                    intrinsics,
                )
            except ValueError:
                # Skip because the example doesn't have enough frames.
                print(f"Skipped example {example['key']} due to not enough frames.")
                continue

            # Skip the example if the field of view is too wide.
            if (get_fov(intrinsics).rad2deg() > self.cfg.max_fov).any():
                continue

            # load remaining context views
            if self.cfg.load_remain_context:
                remaining_indices = get_remaining_indices(context_indices, target_indices,
                                                          0)

                # With a fallback length of 0 the helper returns an empty
                # tensor when no unused view exists; stacking an empty image
                # list would raise, so skip the example instead.
                if remaining_indices.numel() == 0:
                    continue

                # Load the images.
                remain_context_images = [
                    example["images"][index.item()] for index in remaining_indices
                ]

                try:
                    remain_context_images = self.convert_images(remain_context_images)
                except OSError:
                    # some data might be corrupted
                    continue

            # Load the images.
            context_images = [
                example["images"][index.item()] for index in context_indices
            ]
            target_images = [
                example["images"][index.item()] for index in target_indices
            ]
            try:
                context_images = self.convert_images(context_images)
                target_images = self.convert_images(target_images)
            except OSError:
                # Corrupted context/target frames are skipped the same way as
                # corrupted remaining-context frames above.
                continue

            # Skip the example if the images don't have the right shape.
            if self.cfg.highres:
                expected_shape = (3, 720, 1280)
            elif self.cfg.scannet or self.cfg.tartanair:
                expected_shape = (3, 480, 640)
            else:
                expected_shape = (3, 360, 640)
            context_image_invalid = context_images.shape[1:] != expected_shape
            target_image_invalid = target_images.shape[1:] != expected_shape
            if self.cfg.skip_bad_shape and (context_image_invalid or target_image_invalid):
                print(
                    f"Skipped bad example {example['key']}. Context shape was "
                    f"{context_images.shape} and target shape was "
                    f"{target_images.shape}."
                )
                continue

            if self.cfg.load_remain_context:
                remain_context_invalid = remain_context_images.shape[1:] != expected_shape

                if self.cfg.skip_bad_shape and remain_context_invalid:
                    continue

            # check the extrinsics
            if any(torch.isnan(torch.det(extrinsics[context_indices][:, :3, :3]))):
                continue

            if any(torch.isnan(torch.det(extrinsics[target_indices][:, :3, :3]))):
                continue

            if self.cfg.average_pose:
                extrinsics = self.preprocess_poses(extrinsics)

            # load depth
            if self.cfg.load_depth:
                context_depths = [
                    example["depths"][index.item()] for index in context_indices
                ]
                if self.cfg.scannet:
                    context_depths = self.convert_scannet_depths(context_depths)
                elif self.cfg.tartanair:
                    context_depths = self.convert_tartanair_depths(context_depths)
                else:
                    raise NotImplementedError

                target_depths = [
                    example["depths"][index.item()] for index in target_indices
                ]
                if self.cfg.scannet:
                    target_depths = self.convert_scannet_depths(target_depths)
                elif self.cfg.tartanair:
                    target_depths = self.convert_tartanair_depths(target_depths)
                else:
                    raise NotImplementedError

            # align pose to the first view
            if self.cfg.pose_align_first_view:
                extrinsics = camera_normalization(extrinsics[context_indices][0:1], extrinsics)

            if self.cfg.center_pose:
                extrinsics = center_norm_pose(extrinsics)

            # scale the scene when necessary: only scale the extrinsics
            extrinsics[:, :3, 3] *= self.cfg.scale_extrinsics

            example = {
                "context": {
                    "extrinsics": extrinsics[context_indices],
                    "intrinsics": intrinsics[context_indices],
                    "image": context_images,
                    "near": self.get_bound("near", len(context_indices)),
                    "far": self.get_bound("far", len(context_indices)),
                    "index": context_indices,
                },
                "target": {
                    "extrinsics": extrinsics[target_indices],
                    "intrinsics": intrinsics[target_indices],
                    "image": target_images,
                    "near": self.get_bound("near", len(target_indices)),
                    "far": self.get_bound("far", len(target_indices)),
                    "index": target_indices,
                },
                "scene": scene,
            }

            if self.cfg.load_remain_context:
                example["context_remain"] = {
                    "extrinsics": extrinsics[remaining_indices],
                    "intrinsics": intrinsics[remaining_indices],
                    "image": remain_context_images,
                    "near": self.get_bound("near", len(remaining_indices)),
                    "far": self.get_bound("far", len(remaining_indices)),
                    "index": remaining_indices,
                }

            if self.cfg.load_depth:
                example['context']['depth'] = context_depths
                example['target']['depth'] = target_depths

            if self.stage == "train" and self.cfg.augment:
                example = apply_augmentation_shim(example)

            # Skip scenes before test_start_idx (for scene-chunked SLURM jobs)
            if self.stage == "test" and test_scene_counter < self.cfg.test_start_idx:
                test_scene_counter += 1
                print(f"Skipping test example {example['scene']} because test_start_idx is {self.cfg.test_start_idx}")
                continue

            yield apply_crop_shim(example, tuple(self.cfg.image_shape))
+ Float[Tensor, "batch 3 3"], # intrinsics + ]: + b, _ = poses.shape + + # Convert the intrinsics to a 3x3 normalized K matrix. + intrinsics = torch.eye(3, dtype=torch.float32) + intrinsics = repeat(intrinsics, "h w -> b h w", b=b).clone() + fx, fy, cx, cy = poses[:, :4].T + intrinsics[:, 0, 0] = fx + intrinsics[:, 1, 1] = fy + intrinsics[:, 0, 2] = cx + intrinsics[:, 1, 2] = cy + + # Convert the extrinsics to a 4x4 OpenCV-style C2W matrix. + w2c = repeat(torch.eye(4, dtype=torch.float32), "h w -> b h w", b=b).clone() + w2c[:, :3] = rearrange(poses[:, 6:], "b (h w) -> b h w", h=3, w=4) + return w2c.inverse(), intrinsics + + def convert_images( + self, + images: list[UInt8[Tensor, "..."]], + ) -> Float[Tensor, "batch 3 height width"]: + torch_images = [] + for image in images: + image = Image.open(BytesIO(image.numpy().tobytes())) + torch_images.append(self.to_tensor(image)) + return torch.stack(torch_images) + + def convert_scannet_depths( + self, + depths: list[UInt8[Tensor, "..."]] | list[Tensor], + ) -> Float[Tensor, "batch height width"]: + torch_depths = [] + for depth in depths: + depth = Image.open(BytesIO(depth.numpy().tobytes())) + # mm to meter depth + torch_depths.append(self.to_tensor(depth) / 1000.) 
+ return torch.stack(torch_depths).squeeze(1) + + def convert_tartanair_depths( + self, + depths: list[UInt8[Tensor, "..."]] | list[Tensor], + ) -> Float[Tensor, "batch height width"]: + torch_depths = [] + for depth in depths: + depth = np.load(BytesIO(depth.numpy().tobytes())) + torch_depths.append(self.to_tensor(depth)) + return torch.stack(torch_depths).squeeze(1) + + def get_bound( + self, + bound: Literal["near", "far"], + num_views: int, + ) -> Float[Tensor, " view"]: + value = torch.tensor(getattr(self, bound), dtype=torch.float32) + return repeat(value, "-> v", v=num_views) + + @property + def data_stage(self) -> Stage: + if self.cfg.overfit_to_scene is not None: + return "test" + if self.stage == "val": + return "test" + return self.stage + + @cached_property + def index(self) -> dict[str, Path]: + merged_index = {} + data_stages = [self.data_stage] + if self.cfg.overfit_to_scene is not None: + data_stages = ("test", "train") + for data_stage in data_stages: + for i, root in enumerate(self.cfg.roots): + # Load the root's index. + with (root / data_stage / "index.json").open("r") as f: + index = json.load(f) + index = {k: Path(root / data_stage / v) for k, v in index.items()} + + # The constituent datasets should have unique keys. + assert not (set(merged_index.keys()) & set(index.keys())) + + # Merge the root's index into the main index. + merged_index = {**merged_index, **index} + return merged_index + + def __len__(self) -> int: + return ( + min(len(self.index.keys()), self.cfg.test_len) + if self.stage == "test" and self.cfg.test_len > 0 + else len(self.index.keys()) * self.cfg.train_times_per_scene + ) + + def preprocess_poses( + self, + in_c2ws: torch.Tensor, + scene_scale_factor=1.35, + ): + """ + Ref: https://github.com/Haian-Jin/LVSM/blob/main/data/dataset_scene.py + Preprocess the poses to: + 1. translate and rotate the scene to align the average camera direction and position + 2. 
def center_norm_pose(extrinsics):
    """Centre the camera positions and scale them to a unit radius.

    The mean camera centre is subtracted from every translation, and the
    result is divided by the largest distance of any camera from that mean.
    Rotations are left untouched and the input tensor is not modified.

    When all camera centres coincide (for example a single view), the largest
    distance is zero; the division is then skipped so the centred translations
    stay at zero instead of becoming NaN.

    Args:
        extrinsics: [V, 4, 4] camera matrices whose translation is read from
            ``[:, :3, 3]``.

    Returns:
        A new [V, 4, 4] tensor with centred and rescaled translations.
    """
    # extrinsics: [V, 4, 4]
    cam_centers = extrinsics[:, :3, 3]  # [V, 3]
    avg_center = cam_centers.mean(dim=0, keepdim=True)  # [1, 3]
    dist = (cam_centers - avg_center).norm(dim=1, keepdim=True)  # [V, 1]
    scale = dist.max()

    # translate
    extrinsics = extrinsics.clone()
    extrinsics[:, :3, 3] -= avg_center

    # rescale; a zero scale means every camera sits at the mean centre
    if scale > 0:
        extrinsics[:, :3, 3] /= scale

    return extrinsics
index 0000000000000000000000000000000000000000..b1fa5dae3f85715fceb5603d96dfb9f3eae9fd4e --- /dev/null +++ b/optgs/dataset/dataset_scannet.py @@ -0,0 +1,327 @@ +import json +from dataclasses import dataclass +from pathlib import Path +from typing import List, Literal, Optional + +import imageio +import numpy as np +import torch +import torchvision.transforms as tf +from einops import repeat +from jaxtyping import Float +from torch import Tensor +from torch.utils.data import IterableDataset + +from .data_types import Stage +from .dataset import DatasetCfgCommon +from .shims.patch_shim import apply_patch_shim +from .view_sampler import ViewSampler +from .view_sampler.view_sampler_all import ViewSamplerAll +from .view_sampler.view_sampler_dense import ViewSamplerDense +from .view_sampler.view_sampler_ids import ViewSamplerIDs + + +# OpenGL to OpenCV conversion: flip Y and Z axes +_GL_TO_CV = np.diag([1.0, -1.0, -1.0, 1.0]).astype(np.float32) + + +@dataclass +class DatasetScannetCfg(DatasetCfgCommon): + name: Literal["scannet"] + roots: Path + scene_name: Optional[str] # If None, iterate over all scenes from split + split: str # "test", "val", "train", "test_debug" -> splits/{split}_scene_ids.txt + subsample_factor: int + crop_size: None | int | list[int] + num_context_views: int # Max context views to select via FPS + filter_bad_frames: bool + + +class DatasetScannet(IterableDataset): + cfg: DatasetScannetCfg + stage: Stage + view_sampler: ViewSampler + + to_tensor: tf.ToTensor + near: float = 0.01 + far: float = 100.0 + + def __init__( + self, + cfg: DatasetScannetCfg, + stage: Stage, + view_sampler: ViewSampler, + ) -> None: + super().__init__() + + if stage == "train": + raise ValueError( + "ScanNet dataset does not support training stage. " + "Use 'test' or 'val' stage instead." 
+ ) + + self.cfg = cfg + self.stage = stage + self.view_sampler = view_sampler + + assert isinstance(self.view_sampler, (ViewSamplerDense, ViewSamplerIDs, ViewSamplerAll)), \ + "ScanNet dataset requires ViewSamplerDense, ViewSamplerIDs, or ViewSamplerAll." + self.to_tensor = tf.ToTensor() + + # Discover available scenes + if cfg.scene_name is not None: + self.scene_names = [cfg.scene_name] + else: + self.scene_names = self._discover_scenes() + + print(f"Found {len(self.scene_names)} scene(s) for split '{cfg.split}': {self.scene_names}") + + self.image_shape = None + + @staticmethod + def _read_split_file(roots: Path, split: str) -> List[str]: + """Read scene IDs from a split file.""" + split_path = roots / "splits" / f"{split}_scene_ids.txt" + with open(split_path) as f: + return [line.strip() for line in f if line.strip()] + + def _discover_scenes(self) -> List[str]: + """Discover valid scenes: read split file, filter to scenes that exist in data/.""" + scene_ids = self._read_split_file(self.cfg.roots, self.cfg.split) + data_dir = self.cfg.roots / "data" + valid = [s for s in scene_ids if (data_dir / s).exists()] + if len(valid) < len(scene_ids): + print(f"Warning: {len(scene_ids) - len(valid)} scenes from split not found in data/") + return valid + + @staticmethod + def _fps_select(positions: np.ndarray, num_select: int) -> np.ndarray: + """Furthest point sampling on 3D camera positions. + + Greedily selects points that maximize the minimum distance to + the already-selected set, starting from the first point. + + Args: + positions: [N, 3] array of camera positions. + num_select: Number of points to select. + + Returns: + [num_select] array of selected indices. 
+ """ + n = len(positions) + if num_select >= n: + return np.arange(n) + + selected = [0] + min_dists = np.full(n, np.inf) + + for _ in range(num_select - 1): + last = positions[selected[-1]] + dists = np.linalg.norm(positions - last, axis=1) + min_dists = np.minimum(min_dists, dists) + min_dists[selected] = -1 # exclude already selected + selected.append(int(np.argmax(min_dists))) + + return np.array(selected) + + def _parse_frames( + self, + frames: list[dict], + scene_dir: Path, + w: int, + h: int, + fl_x: float, + fl_y: float, + cx: float, + cy: float, + ) -> tuple[list[Tensor], list[Tensor], list[Tensor]]: + """Parse a list of frames into extrinsics, intrinsics, and images. + + Returns: + extrinsics_list: list of [4, 4] tensors (c2w in OpenCV convention) + intrinsics_list: list of [3, 3] tensors (normalized) + images_list: list of [C, H, W] tensors (uint8) + """ + # Build normalized intrinsic matrix (same for all frames in a scene) + K = np.array([ + [fl_x / w, 0.0, cx / w], + [0.0, fl_y / h, cy / h], + [0.0, 0.0, 1.0], + ], dtype=np.float32) + intrinsics_tensor = torch.from_numpy(K) + + extrinsics_list = [] + intrinsics_list = [] + images_list = [] + + for frame in frames: + if self.cfg.filter_bad_frames and frame.get("is_bad", False): + continue + + # Parse c2w and convert OpenGL -> OpenCV + c2w_gl = np.array(frame["transform_matrix"], dtype=np.float32) + c2w_cv = c2w_gl @ _GL_TO_CV + + extrinsics_list.append(torch.from_numpy(c2w_cv)) + intrinsics_list.append(intrinsics_tensor.clone()) + + # Load image + img_path = scene_dir / "images" / frame["file_path"] + image = imageio.imread(str(img_path))[..., :3] + + # Subsample if needed + if self.cfg.subsample_factor > 1: + factor = self.cfg.subsample_factor + image = image[::factor, ::factor] + + image = torch.from_numpy(image).permute(2, 0, 1) # [C, H, W] + images_list.append(image) + + return extrinsics_list, intrinsics_list, images_list + + def _load_scene(self, scene_name: str) -> dict: + """Load a single 
scene and return it in chunk format.""" + scene_dir = self.cfg.roots / "data" / scene_name + assert scene_dir.exists(), f"Scene directory {scene_dir} does not exist." + + print(f"Loading ScanNet scene '{scene_name}' from {scene_dir}") + + # Load transforms.json + transforms_path = scene_dir / "transforms.json" + with open(transforms_path) as f: + transforms = json.load(f) + + w, h = transforms["w"], transforms["h"] + fl_x, fl_y = transforms["fl_x"], transforms["fl_y"] + cx, cy = transforms["cx"], transforms["cy"] + + train_frames = transforms["frames"] + test_frames = transforms.get("test_frames", []) + + # Filter bad frames before FPS (to get correct positions) + if self.cfg.filter_bad_frames: + train_frames_valid = [f for f in train_frames if not f.get("is_bad", False)] + else: + train_frames_valid = train_frames + + # FPS on camera positions to select context views + if len(train_frames_valid) > self.cfg.num_context_views: + positions = np.array([ + np.array(f["transform_matrix"], dtype=np.float32)[:3, 3] + for f in train_frames_valid + ]) + fps_indices = self._fps_select(positions, self.cfg.num_context_views) + selected_train_frames = [train_frames_valid[i] for i in fps_indices] + print(f" FPS selected {len(selected_train_frames)}/{len(train_frames_valid)} context views") + else: + selected_train_frames = train_frames_valid + print(f" Using all {len(selected_train_frames)} context views (< {self.cfg.num_context_views})") + + # Parse context frames (selected training frames) + ctx_ext, ctx_int, ctx_imgs = self._parse_frames( + selected_train_frames, scene_dir, w, h, fl_x, fl_y, cx, cy + ) + + # Parse target frames (test frames) + tgt_ext, tgt_int, tgt_imgs = self._parse_frames( + test_frames, scene_dir, w, h, fl_x, fl_y, cx, cy + ) + + context_end_idx = len(ctx_ext) + all_ext = ctx_ext + tgt_ext + all_int = ctx_int + tgt_int + all_imgs = ctx_imgs + tgt_imgs + + extrinsics = torch.stack(all_ext, dim=0) + intrinsics = torch.stack(all_int, dim=0) + + if 
self.image_shape is None and len(all_imgs) > 0: + self.image_shape = [all_imgs[0].shape[1], all_imgs[0].shape[2]] + + print(f" Loaded {context_end_idx} context + {len(tgt_ext)} target views") + + return { + "key": scene_name, + "cameras": (extrinsics, intrinsics), + "images": all_imgs, + "context_end_idx": context_end_idx, + } + + def _process_scene(self, chunk: dict): + """Process a single scene chunk and yield examples.""" + extrinsics, intrinsics = chunk["cameras"] + scene = chunk["key"] + + # Delegate to view sampler to determine context/target split + context_indices, target_indices = self.view_sampler.sample( + scene, extrinsics, intrinsics, + ) + + # Assert no overlap between context and target views + context_set = set(context_indices.tolist()) + target_set = set(target_indices.tolist()) + overlap = context_set & target_set + assert len(overlap) == 0, ( + f"Scene '{scene}': {len(overlap)} target views leaked into context: {overlap}" + ) + + # Load and normalize images + context_images = torch.stack( + [chunk["images"][i.item()] for i in context_indices] + ).float() / 255.0 + + target_images = torch.stack( + [chunk["images"][i.item()] for i in target_indices] + ).float() / 255.0 + + example_out = { + "context": { + "extrinsics": extrinsics[context_indices], + "intrinsics": intrinsics[context_indices], + "image": context_images, + "near": self.get_bound("near", len(context_indices)), + "far": self.get_bound("far", len(context_indices)), + "index": context_indices, + }, + "target": { + "extrinsics": extrinsics[target_indices], + "intrinsics": intrinsics[target_indices], + "image": target_images, + "near": self.get_bound("near", len(target_indices)), + "far": self.get_bound("far", len(target_indices)), + "index": target_indices, + }, + "scene": scene, + } + + if self.cfg.crop_size is not None: + example_out = apply_patch_shim(example_out, self.cfg.crop_size) + + yield example_out + + def __iter__(self): + # Handle multiple workers - each worker should only 
process a subset of scenes + worker_info = torch.utils.data.get_worker_info() + if self.stage == "test" and worker_info is not None: + scene_names = [ + scene_name + for scene_index, scene_name in enumerate(self.scene_names) + if scene_index % worker_info.num_workers == worker_info.id + ] + else: + scene_names = self.scene_names + + for scene_name in scene_names: + chunk = self._load_scene(scene_name) + yield from self._process_scene(chunk) + + def get_bound( + self, + bound: Literal["near", "far"], + num_views: int, + ) -> Float[Tensor, " view"]: + value = torch.tensor(getattr(self, bound), dtype=torch.float32) + return repeat(value, "-> v", v=num_views) + + def __len__(self) -> int: + return len(self.scene_names) diff --git a/optgs/dataset/shims/__init__.py b/optgs/dataset/shims/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/dataset/shims/augmentation_shim.py b/optgs/dataset/shims/augmentation_shim.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b2d66a08e40b8257b4f6cb72e97b9741fd0795 --- /dev/null +++ b/optgs/dataset/shims/augmentation_shim.py @@ -0,0 +1,46 @@ +import torch +from jaxtyping import Float +from torch import Tensor + +from ..data_types import AnyExample, AnyViews + + +def reflect_extrinsics( + extrinsics: Float[Tensor, "*batch 4 4"], +) -> Float[Tensor, "*batch 4 4"]: + reflect = torch.eye(4, dtype=torch.float32, device=extrinsics.device) + reflect[0, 0] = -1 + return reflect @ extrinsics @ reflect + + +def reflect_views(views: AnyViews) -> AnyViews: + return { + **views, + "image": views["image"].flip(-1), + "extrinsics": reflect_extrinsics(views["extrinsics"]), + "x_flipped": True, + } + + +def apply_augmentation_shim( + example: AnyExample, + generator: torch.Generator | None = None, +) -> AnyExample: + """Randomly augment the training images.""" + # Do not augment with 50% chance. 
+ if torch.rand(tuple(), generator=generator) < 0.5: + return example + + if "context_remain" in example: + return { + **example, + "context": reflect_views(example["context"]), + "target": reflect_views(example["target"]), + "context_remain": reflect_views(example["context_remain"]), + } + + return { + **example, + "context": reflect_views(example["context"]), + "target": reflect_views(example["target"]), + } diff --git a/optgs/dataset/shims/bounds_shim.py b/optgs/dataset/shims/bounds_shim.py new file mode 100644 index 0000000000000000000000000000000000000000..b3eb601436396b4d1bcc41642203b60693d1a15b --- /dev/null +++ b/optgs/dataset/shims/bounds_shim.py @@ -0,0 +1,80 @@ +import torch +from einops import einsum, reduce, repeat +from jaxtyping import Float +from torch import Tensor + +from ..data_types import BatchedExample + + +def compute_depth_for_disparity( + extrinsics: Float[Tensor, "batch view 4 4"], + intrinsics: Float[Tensor, "batch view 3 3"], + image_shape: tuple[int, int], + disparity: float, + delta_min: float = 1e-6, # This prevents motionless scenes from lacking depth. +) -> Float[Tensor, " batch"]: + """Compute the depth at which moving the maximum distance between cameras + corresponds to the specified disparity (in pixels). + """ + + # Use the furthest distance between cameras as the baseline. + origins = extrinsics[:, :, :3, 3] + deltas = (origins[:, None, :, :] - origins[:, :, None, :]).norm(dim=-1) + deltas = deltas.clip(min=delta_min) + baselines = reduce(deltas, "b v ov -> b", "max") + + # Compute a single pixel's size at depth 1. + h, w = image_shape + pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device) + pixel_size = einsum( + intrinsics[..., :2, :2].inverse(), pixel_size, "... i j, j -> ... i" + ) + + # This wouldn't make sense with non-square pixels, but then again, non-square pixels + # don't make much sense anyway. 
+ mean_pixel_size = reduce(pixel_size, "b v xy -> b", "mean") + + return baselines / (disparity * mean_pixel_size) + + +def apply_bounds_shim( + batch: BatchedExample, + near_disparity: float, + far_disparity: float, +) -> BatchedExample: + """Compute reasonable near and far planes (lower and upper bounds on depth). This + assumes that all of an example's views are of roughly the same thing. + """ + + context = batch["context"] + _, cv, _, h, w = context["image"].shape + + # Compute near and far planes using the context views. + near = compute_depth_for_disparity( + context["extrinsics"], + context["intrinsics"], + (h, w), + near_disparity, + ) + far = compute_depth_for_disparity( + context["extrinsics"], + context["intrinsics"], + (h, w), + far_disparity, + ) + + target = batch["target"] + _, tv, _, _, _ = target["image"].shape + return { + **batch, + "context": { + **context, + "near": repeat(near, "b -> b v", v=cv), + "far": repeat(far, "b -> b v", v=cv), + }, + "target": { + **target, + "near": repeat(near, "b -> b v", v=tv), + "far": repeat(far, "b -> b v", v=tv), + }, + } diff --git a/optgs/dataset/shims/crop_shim.py b/optgs/dataset/shims/crop_shim.py new file mode 100644 index 0000000000000000000000000000000000000000..1f6571b837f784138be58fae8b44d23441ce48ce --- /dev/null +++ b/optgs/dataset/shims/crop_shim.py @@ -0,0 +1,143 @@ +import numpy as np +import torch +from einops import rearrange +from jaxtyping import Float +from PIL import Image +from torch import Tensor +import torch.nn.functional as F + +from ..data_types import AnyExample, AnyViews + + +def rescale( + image: Float[Tensor, "3 h_in w_in"], + shape: tuple[int, int], +) -> Float[Tensor, "3 h_out w_out"]: + h, w = shape + image_new = (image * 255).clip(min=0, max=255).type(torch.uint8) + image_new = rearrange(image_new, "c h w -> h w c").detach().cpu().numpy() + image_new = Image.fromarray(image_new) + image_new = image_new.resize((w, h), Image.LANCZOS) + image_new = np.array(image_new) / 255 + 
image_new = torch.tensor(image_new, dtype=image.dtype, device=image.device) + return rearrange(image_new, "h w c -> c h w") + + +def center_crop( + images: Float[Tensor, "*#batch c h w"], + intrinsics: Float[Tensor, "*#batch 3 3"], + shape: tuple[int, int], + depths: None | Float[Tensor, "*#batch h w"], +) -> ( + tuple[ + Float[Tensor, "*#batch c h_out w_out"], # updated images + Float[Tensor, "*#batch 3 3"], # updated intrinsics + ] + | tuple[ + Float[Tensor, "*#batch c h_out w_out"], # updated images + Float[Tensor, "*#batch 3 3"], # updated intrinsics + Float[Tensor, "*#batch h_out w_out"], # updated depths + ] +): + *_, h_in, w_in = images.shape + h_out, w_out = shape + + # Note that odd input dimensions induce half-pixel misalignments. + row = (h_in - h_out) // 2 + col = (w_in - w_out) // 2 + + # Center-crop the image. + images = images[..., :, row : row + h_out, col : col + w_out] + + # Adjust the intrinsics to account for the cropping. + intrinsics = intrinsics.clone() + intrinsics[..., 0, 0] *= w_in / w_out # fx + intrinsics[..., 1, 1] *= h_in / h_out # fy + + if depths is not None: + depths = depths[..., :, row : row + h_out, col : col + w_out] + return images, intrinsics, depths + + return images, intrinsics + + +def rescale_and_crop( + images: Float[Tensor, "*#batch c h w"], + intrinsics: Float[Tensor, "*#batch 3 3"], + shape: tuple[int, int], + depths: None | Float[Tensor, "*#batch h w"], +) -> ( + tuple[ + Float[Tensor, "*#batch c h_out w_out"], # updated images + Float[Tensor, "*#batch 3 3"], # updated intrinsics + ] + | tuple[ + Float[Tensor, "*#batch c h_out w_out"], # updated images + Float[Tensor, "*#batch 3 3"], # updated intrinsics + Float[Tensor, "*#batch h_out w_out"], # updated depths + ] +): + *_, h_in, w_in = images.shape + h_out, w_out = shape + assert h_out <= h_in and w_out <= w_in + + scale_factor = max(h_out / h_in, w_out / w_in) + h_scaled = round(h_in * scale_factor) + w_scaled = round(w_in * scale_factor) + assert h_scaled == h_out 
or w_scaled == w_out + + # Reshape the images to the correct size. Assume we don't have to worry about + # changing the intrinsics based on how the images are rounded. + *batch, c, h, w = images.shape + images = images.reshape(-1, c, h, w) + images = torch.stack([rescale(image, (h_scaled, w_scaled)) for image in images]) + images = images.reshape(*batch, c, h_scaled, w_scaled) + + # reshape and crop depth as well when available + if depths is not None: + depths = F.interpolate( + depths.unsqueeze(1), + size=(h_scaled, w_scaled), + mode="bilinear", + align_corners=True, + ).squeeze(1) + + return center_crop(images, intrinsics, shape, depths=depths) + + +def apply_crop_shim_to_views(views: AnyViews, shape: tuple[int, int]) -> AnyViews: + depths = views["depth"] if "depth" in views else None + if depths is not None: + images, intrinsics, depths = rescale_and_crop(views["image"], views["intrinsics"], shape, + depths=depths) + return { + **views, + "image": images, + "depth": depths, + "intrinsics": intrinsics, + } + else: + images, intrinsics = rescale_and_crop(views["image"], views["intrinsics"], shape, + depths=None) + return { + **views, + "image": images, + "intrinsics": intrinsics, + } + + +def apply_crop_shim(example: AnyExample, shape: tuple[int, int]) -> AnyExample: + """Crop images in the example.""" + if "context_remain" in example: + return { + **example, + "context": apply_crop_shim_to_views(example["context"], shape), + "target": apply_crop_shim_to_views(example["target"], shape), + "context_remain": apply_crop_shim_to_views(example["context_remain"], shape), + } + + return { + **example, + "context": apply_crop_shim_to_views(example["context"], shape), + "target": apply_crop_shim_to_views(example["target"], shape), + } diff --git a/optgs/dataset/shims/patch_shim.py b/optgs/dataset/shims/patch_shim.py new file mode 100644 index 0000000000000000000000000000000000000000..c57c6f977b943d250eddcadf3c16f72095a30246 --- /dev/null +++ 
b/optgs/dataset/shims/patch_shim.py @@ -0,0 +1,52 @@ +from dataclasses import asdict + +from ..data_types import BatchedExample, BatchedViews, UnbatchedViews, BatchedViewsDict, UnbatchedExample + + +def apply_patch_shim_to_views(views: BatchedViews | UnbatchedViews | BatchedViewsDict, + patch_size: int | list[int]) -> BatchedViews | UnbatchedViews | BatchedViewsDict: + *_, h, w = views["image"].shape + + if isinstance(patch_size, int): + patch_size_x = patch_size + patch_size_y = patch_size + else: + patch_size_x, patch_size_y = patch_size + + h_new = (h // patch_size_x) * patch_size_x + row = (h - h_new) // 2 + w_new = (w // patch_size_y) * patch_size_y + col = (w - w_new) // 2 + + # Center-crop the image. + image = views["image"][..., row: row + h_new, col: col + w_new] + + # Adjust the intrinsics to account for the cropping. + intrinsics = views["intrinsics"].clone() + intrinsics[..., 0, 0] *= w / w_new # fx + intrinsics[..., 1, 1] *= h / h_new # fy + + if isinstance(views, BatchedViews): + return BatchedViews.from_dict({ + **asdict(views), + "image": image, + "intrinsics": intrinsics, + }) + else: + return { + **views, + "image": image, + "intrinsics": intrinsics, + } + + +def apply_patch_shim(batch: BatchedExample | UnbatchedExample, + patch_size: int | list[int]) -> BatchedExample | UnbatchedExample: + """Crop images in the batch so that their dimensions are cleanly divisible by the + specified patch size. 
+ """ + return { + **batch, + "context": apply_patch_shim_to_views(batch["context"], patch_size), + "target": apply_patch_shim_to_views(batch["target"], patch_size), + } diff --git a/optgs/dataset/validation_wrapper.py b/optgs/dataset/validation_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..9771125f0f38059eb43469174c84122542d8c863 --- /dev/null +++ b/optgs/dataset/validation_wrapper.py @@ -0,0 +1,32 @@ +from typing import Iterator, Optional + +import torch +from torch.utils.data import Dataset, IterableDataset + + +class ValidationWrapper(Dataset): + """Wraps a dataset so that PyTorch Lightning's validation step can be turned into a + visualization step. + """ + + dataset: Dataset + dataset_iterator: Optional[Iterator] + length: int + + def __init__(self, dataset: Dataset, length: int) -> None: + super().__init__() + self.dataset = dataset + self.length = length + self.dataset_iterator = None + + def __len__(self): + return self.length + + def __getitem__(self, index: int): + if isinstance(self.dataset, IterableDataset): + if self.dataset_iterator is None: + self.dataset_iterator = iter(self.dataset) + return next(self.dataset_iterator) + + random_index = torch.randint(0, len(self.dataset), tuple()) + return self.dataset[random_index.item()] diff --git a/optgs/dataset/view_sampler/__init__.py b/optgs/dataset/view_sampler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a71d4214e13be168864e780d2705732fd1d25edf --- /dev/null +++ b/optgs/dataset/view_sampler/__init__.py @@ -0,0 +1,49 @@ +from typing import Any + +from ...misc.step_tracker import StepTracker +from ..data_types import Stage +from .view_sampler import ViewSampler +from .view_sampler_all import ViewSamplerAll, ViewSamplerAllCfg +from .view_sampler_ids import ViewSamplerIDs, ViewSamplerIDsCfg +from .view_sampler_arbitrary import ViewSamplerArbitrary, ViewSamplerArbitraryCfg +from .view_sampler_bounded import ViewSamplerBounded, 
ViewSamplerBoundedCfg +from .view_sampler_evaluation import ViewSamplerEvaluation, ViewSamplerEvaluationCfg +from .view_sampler_bounded_v2 import ViewSamplerBoundedV2, ViewSamplerBoundedV2Cfg +from optgs.dataset.view_sampler.view_sampler_dense import ViewSamplerDense, ViewSamplerDenseCfg + +VIEW_SAMPLERS: dict[str, ViewSampler[Any]] = { + "all": ViewSamplerAll, + "ids": ViewSamplerIDs, + "dense": ViewSamplerDense, # colmap datasets + "arbitrary": ViewSamplerArbitrary, + "bounded": ViewSamplerBounded, + "evaluation": ViewSamplerEvaluation, # during evaluation + "boundedv2": ViewSamplerBoundedV2, # during training +} + +ViewSamplerCfg = ( + ViewSamplerArbitraryCfg + | ViewSamplerBoundedCfg + | ViewSamplerEvaluationCfg + | ViewSamplerAllCfg + | ViewSamplerBoundedV2Cfg + | ViewSamplerDenseCfg + | ViewSamplerIDsCfg +) + + +def get_view_sampler( + cfg: ViewSamplerCfg, + stage: Stage, + overfit: bool, + cameras_are_circular: bool, + step_tracker: StepTracker | None, +) -> ViewSampler[Any]: + print("Using view sampler:", cfg.name) + return VIEW_SAMPLERS[cfg.name]( + cfg, + stage, + overfit, + cameras_are_circular, + step_tracker, + ) diff --git a/optgs/dataset/view_sampler/view_sampler.py b/optgs/dataset/view_sampler/view_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..4677b9eb34de4b923172341a0db23273c68112c6 --- /dev/null +++ b/optgs/dataset/view_sampler/view_sampler.py @@ -0,0 +1,135 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Generic, TypeVar, Literal + +import torch +from jaxtyping import Float, Int64 +from torch import Tensor +from typeguard import value + +from ...misc.step_tracker import StepTracker +from ..data_types import Stage + +T = TypeVar("T") + + +@dataclass +class ViewSamplerCfg: + name: Literal["base"] + num_context_views: int + num_target_views: int + + +class ViewSampler(ABC, Generic[T]): + cfg: T + stage: Stage + is_overfitting: bool + cameras_are_circular: bool + 
step_tracker: StepTracker | None + + def __init__( + self, + cfg: T, + stage: Stage, + is_overfitting: bool, + cameras_are_circular: bool, + step_tracker: StepTracker | None, + ) -> None: + self.cfg = cfg + self.stage = stage + self.is_overfitting = is_overfitting + self.cameras_are_circular = cameras_are_circular + self.step_tracker = step_tracker + + self._all_context_indices = None + self._all_target_indices = None + + @property + def all_context_indices(self) -> Int64[Tensor, " context_view"]: + return self._all_context_indices + + @property + def context_indices(self) -> Int64[Tensor, " target_view"]: + return self._all_context_indices + + @context_indices.setter + def context_indices(self, indices: Int64[Tensor, " context_view"]): + if self._all_context_indices is None: + self._all_context_indices = indices + else: + raise RuntimeError("Context indices have already been set.") + + @property + def target_indices(self) -> Int64[Tensor, " target_view"]: + return self._all_target_indices + + @target_indices.setter + def target_indices(self, indices: Int64[Tensor, " target_view"]): + if self._all_target_indices is None: + self._all_target_indices = indices + else: + raise RuntimeError("Target indices have already been set.") + + def sample_subset(self, extrinsics, intrinsics, device): + pass + + @abstractmethod + def _sample_impl( + self, + scene: str, + extrinsics: Float[Tensor, "view 4 4"], + intrinsics: Float[Tensor, "view 3 3"], + device: torch.device = torch.device("cpu"), + **kwargs, + ) -> tuple[ + Int64[Tensor, " context_view"], # indices for context views + Int64[Tensor, " target_view"], # indices for target views + ]: + pass + + def sample( + self, + scene: str, + extrinsics: Float[Tensor, "view 4 4"], + intrinsics: Float[Tensor, "view 3 3"], + device: torch.device = torch.device("cpu"), + **kwargs, + ) -> tuple[ + Int64[Tensor, " context_view"], # indices for context views + Int64[Tensor, " target_view"], # indices for target views + ]: + 
context_indices, target_indices = self._sample_impl( + scene=scene, + extrinsics=extrinsics, + intrinsics=intrinsics, + device=device, + **kwargs, + ) + # self.context_indices = context_indices + # self.target_indices = target_indices + + return context_indices, target_indices + + @property + @abstractmethod + def num_target_views(self) -> int: + pass + + @property + @abstractmethod + def num_context_views(self) -> int: + pass + + @property + def global_step(self) -> int: + return 0 if self.step_tracker is None else self.step_tracker.get_step() + + def new_instance(self) -> "ViewSampler": + """Create a new instance of the same ViewSampler class with the same configuration.""" + return value(self.__class__)( + cfg=self.cfg, + stage=self.stage, + is_overfitting=self.is_overfitting, + cameras_are_circular=self.cameras_are_circular, + step_tracker=self.step_tracker, + ) diff --git a/optgs/dataset/view_sampler/view_sampler_all.py b/optgs/dataset/view_sampler/view_sampler_all.py new file mode 100644 index 0000000000000000000000000000000000000000..e80656626a69b805134939668bc5401d4435dac8 --- /dev/null +++ b/optgs/dataset/view_sampler/view_sampler_all.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from typing import Literal + +import torch +from jaxtyping import Float, Int64 +from torch import Tensor + +from .view_sampler import ViewSampler + + +@dataclass +class ViewSamplerAllCfg: + name: Literal["all"] + + +class ViewSamplerAll(ViewSampler[ViewSamplerAllCfg]): + def _sample_impl( + self, + scene: str, + extrinsics: Float[Tensor, "view 4 4"], + intrinsics: Float[Tensor, "view 3 3"], + device: torch.device = torch.device("cpu"), + **kwargs, + ) -> tuple[ + Int64[Tensor, " context_view"], # indices for context views + Int64[Tensor, " target_view"], # indices for target views + ]: + v, _, _ = extrinsics.shape + all_frames = torch.arange(v, device=device) + return all_frames, all_frames + + @property + def num_context_views(self) -> int: + return 0 + + @property + 
def num_target_views(self) -> int: + return 0 diff --git a/optgs/dataset/view_sampler/view_sampler_arbitrary.py b/optgs/dataset/view_sampler/view_sampler_arbitrary.py new file mode 100644 index 0000000000000000000000000000000000000000..38d7a5769410719fc68584a8ae6abda6b730a935 --- /dev/null +++ b/optgs/dataset/view_sampler/view_sampler_arbitrary.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass +from typing import Literal + +import torch +from jaxtyping import Float, Int64 +from torch import Tensor + +from .view_sampler import ViewSampler + + +@dataclass +class ViewSamplerArbitraryCfg: + name: Literal["arbitrary"] + num_context_views: int + num_target_views: int + context_views: list[int] | None + target_views: list[int] | None + + +class ViewSamplerArbitrary(ViewSampler[ViewSamplerArbitraryCfg]): + def _sample_impl( + self, + scene: str, + extrinsics: Float[Tensor, "view 4 4"], + intrinsics: Float[Tensor, "view 3 3"], + device: torch.device = torch.device("cpu"), + **kwargs, + ) -> tuple[ + Int64[Tensor, " context_view"], # indices for context views + Int64[Tensor, " target_view"], # indices for target views + ]: + """Arbitrarily sample context and target views.""" + num_views, _, _ = extrinsics.shape + + index_context = torch.randint( + 0, + num_views, + size=(self.cfg.num_context_views,), + device=device, + ) + + # Allow the context views to be fixed. + if self.cfg.context_views is not None: + assert len(self.cfg.context_views) == self.cfg.num_context_views + index_context = torch.tensor( + self.cfg.context_views, dtype=torch.int64, device=device + ) + + index_target = torch.randint( + 0, + num_views, + size=(self.cfg.num_target_views,), + device=device, + ) + + # Allow the target views to be fixed. 
+ if self.cfg.target_views is not None: + assert len(self.cfg.target_views) == self.cfg.num_target_views + index_target = torch.tensor( + self.cfg.target_views, dtype=torch.int64, device=device + ) + + return index_context, index_target + + @property + def num_context_views(self) -> int: + return self.cfg.num_context_views + + @property + def num_target_views(self) -> int: + return self.cfg.num_target_views diff --git a/optgs/dataset/view_sampler/view_sampler_bounded.py b/optgs/dataset/view_sampler/view_sampler_bounded.py new file mode 100644 index 0000000000000000000000000000000000000000..47359882e310c7ae74f6bf714d7f11103747f097 --- /dev/null +++ b/optgs/dataset/view_sampler/view_sampler_bounded.py @@ -0,0 +1,132 @@ +from dataclasses import dataclass +from typing import Literal + +import torch +from jaxtyping import Float, Int64 +from torch import Tensor + +from .view_sampler import ViewSampler + + +@dataclass +class ViewSamplerBoundedCfg: + name: Literal["bounded"] + num_context_views: int + num_target_views: int + min_distance_between_context_views: int + max_distance_between_context_views: int + min_distance_to_context_views: int + warm_up_steps: int + initial_min_distance_between_context_views: int + initial_max_distance_between_context_views: int + + +class ViewSamplerBounded(ViewSampler[ViewSamplerBoundedCfg]): + def schedule(self, initial: int, final: int) -> int: + fraction = self.global_step / self.cfg.warm_up_steps + return min(initial + int((final - initial) * fraction), final) + + def _sample_impl( + self, + scene: str, + extrinsics: Float[Tensor, "view 4 4"], + intrinsics: Float[Tensor, "view 3 3"], + device: torch.device = torch.device("cpu"), + min_view_dist: int | None = None, + max_view_dist: int | None = None, + **kwargs, + ) -> tuple[ + Int64[Tensor, " context_view"], # indices for context views + Int64[Tensor, " target_view"], # indices for target views + ]: + num_views, _, _ = extrinsics.shape + + # Compute the context view spacing based on 
the current global step. + if self.stage == "test": + # When testing, always use the full gap. + max_gap = self.cfg.max_distance_between_context_views + min_gap = self.cfg.max_distance_between_context_views + elif self.cfg.warm_up_steps > 0: + max_gap = self.schedule( + self.cfg.initial_max_distance_between_context_views, + self.cfg.max_distance_between_context_views, + ) + min_gap = self.schedule( + self.cfg.initial_min_distance_between_context_views, + self.cfg.min_distance_between_context_views, + ) + else: + max_gap = self.cfg.max_distance_between_context_views + min_gap = self.cfg.min_distance_between_context_views + + # Pick the gap between the context views. + if not self.cameras_are_circular: + max_gap = min(num_views - 1, max_gap) + min_gap = max(2 * self.cfg.min_distance_to_context_views, min_gap) + + # overwrite min_gap and max_gap, useful for mixed dataset training + # use different view distance for different dataset + if min_view_dist is not None: + min_gap = min_view_dist + + if max_view_dist is not None: + max_gap = max_view_dist + + if max_gap < min_gap: + raise ValueError("Example does not have enough frames!") + context_gap = torch.randint( + min_gap, + max_gap + 1, + size=tuple(), + device=device, + ).item() + + # Pick the left and right context indices. + index_context_left = torch.randint( + num_views if self.cameras_are_circular else num_views - context_gap, + size=tuple(), + device=device, + ).item() + if self.stage == "test": + index_context_left = index_context_left * 0 + index_context_right = index_context_left + context_gap + + if self.is_overfitting: + index_context_left *= 0 + index_context_right *= 0 + index_context_right += max_gap + + # Pick the target view indices. + if self.stage == "test": + # When testing, pick all. + index_target = torch.arange( + index_context_left, + index_context_right + 1, + device=device, + ) + else: + # When training or validating (visualizing), pick at random. 
def farthest_point_sample(xyz, npoint, first_idx_strategy="max_dist"):
    """Greedy farthest-point sampling over a batch of point sets.

    Starting from a seed point, repeatedly picks the point whose distance to
    the already selected set is largest.

    Input:
        xyz: pointcloud data, [B, N, C] (C is usually 3, any C is accepted)
        npoint: number of samples
        first_idx_strategy: 'max_dist' seeds with the point farthest from the
            barycenter, 'random' seeds with a uniformly drawn point.
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    Raises:
        ValueError: if ``first_idx_strategy`` is not a known strategy.

    Note: when ``npoint`` exceeds N the trailing indices repeat, because no
    unselected point is left.
    """
    device = xyz.device
    B, N, C = xyz.shape
    if not xyz.is_floating_point():
        # Distances are accumulated in floating point; integer clouds are promoted.
        xyz = xyz.float()

    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Running squared distance of every point to the selected set. It starts at
    # +inf (not a finite sentinel) so arbitrarily large scenes are handled, and
    # uses the input dtype so float64 / float16 clouds work as well.
    distance = torch.full((B, N), float("inf"), dtype=xyz.dtype, device=device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)

    if first_idx_strategy == 'max_dist':
        barycenter = xyz.sum(dim=1, keepdim=True) / N
        dist = torch.sum((xyz - barycenter) ** 2, -1)
        curr_idx = torch.max(dist, 1)[1]
    elif first_idx_strategy == 'random':
        # Drawn on the CPU and then moved, so the default CPU generator is the
        # source of randomness regardless of where ``xyz`` lives.
        curr_idx = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    else:
        raise ValueError(f"Unknown first_idx_strategy: {first_idx_strategy}")

    for i in range(npoint):
        centroids[:, i] = curr_idx
        centroid = xyz[batch_indices, curr_idx, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        # Keep, per point, the distance to its nearest selected centroid.
        distance = torch.minimum(distance, dist)
        curr_idx = torch.max(distance, -1)[1]

    return centroids
extrinsics.shape + + if max_num_views is not None: + num_views = min(num_views, max_num_views) + + def determine_per_scene_values(name, value): + if getattr(self._cfg_backup, name) < 0: + setattr(self.cfg, name, value) + + determine_per_scene_values('max_distance_between_context_views', num_views) + determine_per_scene_values('initial_max_distance_between_context_views', num_views) + determine_per_scene_values('min_distance_between_context_views', num_views-1) + determine_per_scene_values('initial_min_distance_between_context_views', num_views-1) + + if min_context_views > 0 and max_context_views > 0 and self.stage != "test": + random_num_views = random.randint(min_context_views, max_context_views) + else: + random_num_views = None + + context_gap = self.get_context_gap(device, max_context_views, max_view_dist, min_view_dist, num_views, + random_num_views) + if context_gap < 0: + context_gap = num_views + + # Compute the margin from context window to target window based on the current global step + max_target_gap = self.get_max_target_gap() + if max_target_gap < 0: + max_target_gap = num_views + 1 + + # Pick the left and right context indices. + index_context_left, index_context_right, index_target_left, index_target_right = self.get_bound_indices( + context_gap, device, max_target_gap, num_views) + + # Note: targets are sampled before extra context views — order matters for reproducibility. + index_target = self.get_target_indices(device, index_target_left, index_target_right, + [index_context_left, index_context_right]) + + # Apply modulo for circular datasets. + if self.cameras_are_circular: + index_target %= num_views + index_context_right %= num_views + + # If more than two context views are desired, pick extra context views between + # the left and right ones. 
def get_extra_views(self, extrinsics, index_context_left, index_context_right, total_num_views, index_target):
    """Pick the context views that go between the left and right bounds.

    Args:
        extrinsics: per-view camera matrices; only read by the
            'farthest_point' strategy, which uses ``extrinsics[:, :3, -1]``
            as the position of each candidate view.
        index_context_left: index of the left bounding context view.
        index_context_right: index of the right bounding context view.
        total_num_views: total number of context views wanted, bounds included.
        index_target: 1-D tensor of target view indices; never chosen as context.

    Returns:
        ``(extra_views, index_context_left, index_context_right)`` where
        ``extra_views`` is a sorted list of indices. The bounds are returned
        because the 'farthest_point' strategy replaces them with the smallest
        and largest selected index.

    Raises:
        ValueError: if there are fewer non-target candidates than requested
            views (same message as ``sample_unique_excluding``).
    """
    if total_num_views <= 2:
        # The two bounding views already satisfy the request.
        return [], index_context_left, index_context_right

    num_extra_views = total_num_views - 2
    extra_views = []
    strategy = self.cfg.extra_views_sampling_strategy

    if strategy == 'random':
        extra_views = self.sample_unique_excluding(
            index_context_left + 1,
            index_context_right - 1,
            num_extra_views,
            index_target,
        )
    elif strategy == 'farthest_point':
        # Remove target views from the candidates. Plain Python ints are used
        # so the filter does not depend on which device the targets live on.
        excluded = set(torch.as_tensor(index_target).reshape(-1).tolist())
        candidates = [i for i in range(index_context_left, index_context_right + 1) if i not in excluded]
        if len(candidates) < total_num_views:
            # Sampling more points than candidates would repeat views.
            raise ValueError("Not enough candidates to sample from!")
        context_bounded_index = torch.tensor(candidates, dtype=torch.long)
        candidate_views_position = extrinsics[context_bounded_index, :3, -1].unsqueeze(0)
        index_context_local = farthest_point_sample(candidate_views_position, total_num_views).squeeze(0)
        # Remap the local (candidate-relative) indices back to scene indices.
        index_context = context_bounded_index[index_context_local.to(context_bounded_index.device)]
        index_context = index_context.sort().values
        index_context_left = index_context[0].item()
        index_context_right = index_context[-1].item()
        extra_views = index_context[1:-1].tolist()
    elif strategy == 'equal':
        # NOTE(review): 'equal' is accepted by the config but not implemented;
        # it yields no extra views, so only the two bounds are returned.
        pass

    return sorted(extra_views), index_context_left, index_context_right
def get_context_gap(self, device, max_context_views, max_view_dist, min_view_dist, num_views, random_num_views):
    """Draw the index distance between the left and right context views.

    The admissible ``[lower, upper]`` range is built in stages:

    1. ``stage == "test"`` pins both ends to the configured maximum;
       otherwise the range is the warm-up schedule (when
       ``context_gap_warm_up_steps > 0``) or the configured min / max.
    2. If both ``min_view_dist`` and ``max_view_dist`` are given they replace
       the range (mixed-dataset training with per-dataset distances).
    3. If ``random_num_views`` is given, both ends are divided by
       ``max(max_context_views // random_num_views, 1)``.
    4. For non-circular cameras the upper end is capped at ``num_views - 1``.

    Returns:
        An int drawn uniformly from ``[lower, upper]``.

    Raises:
        ValueError: if the range is empty after the steps above.
    """
    cfg = self.cfg

    if self.stage == "test":
        # When testing, always use the full gap.
        lower = upper = cfg.max_distance_between_context_views
    elif cfg.context_gap_warm_up_steps > 0:
        warm_up = cfg.context_gap_warm_up_steps
        upper = self.schedule(
            cfg.initial_max_distance_between_context_views,
            cfg.max_distance_between_context_views,
            warm_up,
        )
        lower = self.schedule(
            cfg.initial_min_distance_between_context_views,
            cfg.min_distance_between_context_views,
            warm_up,
        )
    else:
        upper = cfg.max_distance_between_context_views
        lower = cfg.min_distance_between_context_views

    has_override = min_view_dist is not None and max_view_dist is not None
    if has_override:
        lower, upper = min_view_dist, max_view_dist

    if random_num_views is not None:
        # Fewer context views than the maximum -> proportionally smaller gap.
        shrink = max(max_context_views // random_num_views, 1)
        upper = upper // shrink
        lower = lower // shrink

    if not self.cameras_are_circular:
        upper = min(num_views - 1, upper)

    if upper < lower:
        raise ValueError("Example does not have enough frames!")

    drawn = torch.randint(lower, upper + 1, size=tuple(), device=device)
    return drawn.item()
def get_bound_indices(self, context_gap, device, max_target_gap, num_views):
    """Choose the context window and the surrounding target window.

    Args:
        context_gap: index distance between the left and right context views.
        device: device on which the random draw is made.
        max_target_gap: how far targets may lie outside the context window.
        num_views: number of views available in the scene.

    Returns:
        ``(index_context_left, index_context_right, index_target_left,
        index_target_right)`` as Python ints. For non-circular cameras the
        target bounds are clamped to ``[0, num_views - 1]``; for circular
        cameras they may fall outside that range (the caller applies the
        modulo).

    Raises:
        ValueError: if no left index exists for the requested gap, matching
            the error ``get_context_gap`` raises for too-short examples.
    """
    upper = num_views if self.cameras_are_circular else num_views - context_gap
    if upper <= 0:
        # torch.randint would otherwise fail with an opaque RuntimeError.
        raise ValueError("Example does not have enough frames!")

    index_context_left = torch.randint(
        low=0,
        high=upper,
        size=tuple(),
        device=device,
    ).item()
    if self.stage == "test":
        # Testing always starts at the first frame; the draw above is still
        # made so the random stream advances exactly as in the other stages.
        index_context_left = 0

    index_context_right = index_context_left + context_gap
    index_target_left = index_context_left - max_target_gap
    index_target_right = index_context_right + max_target_gap
    if not self.cameras_are_circular:
        index_target_left = max(0, index_target_left)
        index_target_right = min(num_views - 1, index_target_right)
    return index_context_left, index_context_right, index_target_left, index_target_right
class ViewSamplerDense(ViewSampler[ViewSamplerDenseCfg]):
    """Sampler that partitions all views of a scene into context and target.

    Every ``target_every``-th view becomes a target and the rest context, or
    every ``context_every``-th view becomes context and the rest targets,
    depending on which of the two config values is positive. Each side is then
    optionally sub-sampled to ``num_context_views`` / ``num_target_views``.
    """

    def _sample_impl(
        self,
        scene: str,
        extrinsics: Float[Tensor, "view 4 4"],
        intrinsics: Float[Tensor, "view 3 3"],
        device: torch.device = torch.device("cpu"),
        **kwargs,
    ) -> tuple[
        Int64[Tensor, " context_view"],  # indices for context views
        Int64[Tensor, " target_view"],  # indices for target views
    ]:
        """Sample context and target views.

        Only the number of views is read from ``extrinsics``; ``scene``,
        ``intrinsics`` and extra keyword arguments are unused.

        Raises:
            ValueError: if neither ``target_every`` nor ``context_every`` is
                positive, or the sampling strategy is unknown.
            NotImplementedError: if sub-sampling is needed with the
                "neighbors" strategy.
        """
        num_views, _, _ = extrinsics.shape

        all_views = torch.arange(num_views, device=device)

        if self.cfg.target_every > 0:
            target_views = all_views[::self.cfg.target_every]
            context_views = self._remaining_views(all_views, target_views)
        elif self.cfg.context_every > 0:
            context_views = all_views[::self.cfg.context_every]
            target_views = self._remaining_views(all_views, context_views)
        else:
            raise ValueError("Either target_every or context_every must be set to a positive integer.")

        strategy = self.cfg.sample_views_strategy
        # Context is sampled before target so the random stream is consumed in
        # a fixed order.
        index_context = self._sample_views(context_views, self.cfg.num_context_views, strategy)
        index_target = self._sample_views(target_views, self.cfg.num_target_views, strategy)

        return index_context, index_target

    @staticmethod
    def _remaining_views(all_views: Tensor, taken_views: Tensor) -> Tensor:
        """Return the views of ``all_views`` not in ``taken_views``, in ascending order."""
        taken = set(taken_views.tolist())
        remaining = [i for i in all_views.tolist() if i not in taken]
        # Explicit dtype keeps an empty result usable as an index tensor.
        return torch.tensor(remaining, dtype=all_views.dtype, device=all_views.device)

    @staticmethod
    def _sample_views(index_views: Tensor, num_views_to_sample: int, strategy: str) -> Tensor:
        """Sub-sample ``index_views``; ``-1`` or a too-large count returns all of them."""
        if num_views_to_sample == -1 or num_views_to_sample >= len(index_views):
            return index_views
        if strategy == "random":
            return index_views[torch.randperm(len(index_views))[:num_views_to_sample]]
        if strategy == "neighbors":
            # Intended design (not implemented): rank views by translation
            # distance plus rotation angle to a centre view and keep the nearest.
            raise NotImplementedError("The 'neighbors' sampling strategy is not implemented.")
        raise ValueError(f"Unknown sampling strategy: {strategy}")

    @property
    def num_context_views(self) -> int:
        return self.cfg.num_context_views

    @property
    def num_target_views(self) -> int:
        return self.cfg.num_target_views
def _sample_impl(
    self,
    scene: str,
    extrinsics: Float[Tensor, "view 4 4"],
    intrinsics: Float[Tensor, "view 3 3"],
    device: torch.device = torch.device("cpu"),
    **kwargs,
) -> tuple[
    Int64[Tensor, " context_view"],  # indices for context views
    Int64[Tensor, " target_view"],  # indices for target views
]:
    """Return the pre-computed context/target indices stored for ``scene``.

    The indices come from ``self.index``, which the constructor loads from the
    JSON evaluation index. The camera arguments are not used.

    Raises:
        ValueError: if the scene is missing from the index or its entry is null.
    """

    def as_index_tensor(values) -> Tensor:
        # Index entries are stored as tuples of ints.
        return torch.tensor(values, dtype=torch.int64, device=device)

    entry = self.index.get(scene)
    if entry is None:
        raise ValueError(f"No indices available for scene {scene}.")
    return as_index_tensor(entry.context), as_index_tensor(entry.target)
def compute_depth_errors(gt, pred):
    """Compute error metrics between ground-truth and predicted depths.

    Args:
        gt: array of ground-truth depths.
        pred: array of predicted depths, same shape as ``gt``.

    Returns:
        ``(abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)`` where ``a1``/``a2``/``a3``
        are the fractions of elements whose ratio ``max(gt / pred, pred / gt)``
        is below ``1.25``, ``1.25 ** 2`` and ``1.25 ** 3``.

    No masking is applied: the ratios and logarithms are taken over every
    element as given.
    """
    ratio = np.maximum((gt / pred), (pred / gt))
    a1, a2, a3 = [(ratio < 1.25 ** power).mean() for power in (1, 2, 3)]

    diff = gt - pred
    squared_diff = diff ** 2
    rmse = np.sqrt(squared_diff.mean())

    log_diff = np.log(gt) - np.log(pred)
    rmse_log = np.sqrt((log_diff ** 2).mean())

    # Relative errors are normalised by the ground-truth depth.
    abs_rel = np.mean(np.abs(diff) / gt)
    sq_rel = np.mean(squared_diff / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
class EvaluationIndexGenerator(LightningModule):
    """Lightning module that builds a fixed evaluation index, one entry per scene.

    For each scene seen in ``test_step`` it searches for a pair of context
    views whose mutual image overlap lies in ``[min_overlap, max_overlap]``,
    draws ``num_target_views`` distinct target frames between them, and stores
    the result in ``self.index`` (``None`` when no valid pair exists). All
    randomness comes from a private generator seeded with ``cfg.seed``.
    """

    generator: torch.Generator
    cfg: EvaluationIndexGeneratorCfg
    index: dict[str, IndexEntry | None]

    def __init__(self, cfg: EvaluationIndexGeneratorCfg) -> None:
        super().__init__()
        self.cfg = cfg
        self.generator = torch.Generator()
        self.generator.manual_seed(cfg.seed)
        self.index = {}

    def test_step(self, batch, batch_idx):
        """Find a context pair and target views for the single scene in ``batch``."""
        b, v, _, h, w = batch["target"]["image"].shape
        assert b == 1
        extrinsics = batch["target"]["extrinsics"][0]
        intrinsics = batch["target"]["intrinsics"][0]
        scene = batch["scene"][0]

        context_indices = torch.randperm(v, generator=self.generator)

        # The pixel grid depends only on the image size, so build it once.
        xy, _ = sample_image_grid((h, w), self.device)
        xy = rearrange(xy, "h w xy -> (h w) xy")

        min_overlap = self.cfg.min_overlap
        max_overlap = self.cfg.max_overlap
        max_distance = self.cfg.max_distance

        for context_index in tqdm(context_indices, "Finding context pair"):
            context_origins, context_directions = get_world_rays(
                xy,
                extrinsics[context_index],
                intrinsics[context_index],
            )

            # Step away from context view until the minimum overlap threshold is met.
            valid_indices = []
            for step in (1, -1):
                current_index = context_index + step * self.cfg.min_distance

                while 0 <= current_index.item() < v:
                    # Compute overlap in both directions.
                    current_origins, current_directions = get_world_rays(
                        xy,
                        extrinsics[current_index],
                        intrinsics[current_index],
                    )
                    projection_onto_current = project_rays(
                        context_origins,
                        context_directions,
                        extrinsics[current_index],
                        intrinsics[current_index],
                    )
                    projection_onto_context = project_rays(
                        current_origins,
                        current_directions,
                        extrinsics[context_index],
                        intrinsics[context_index],
                    )
                    overlap_a = projection_onto_context["overlaps_image"].float().mean()
                    overlap_b = projection_onto_current["overlaps_image"].float().mean()

                    overlap = min(overlap_a, overlap_b)
                    delta = (current_index - context_index).abs()

                    if min_overlap <= overlap <= max_overlap:
                        valid_indices.append(
                            (current_index.item(), overlap_a, overlap_b)
                        )

                    # Stop once the camera has panned away too much.
                    if overlap < min_overlap or delta > max_distance:
                        break

                    current_index = current_index + step

            if not valid_indices:
                continue

            # Pick a random valid view. Index the resulting views.
            num_options = len(valid_indices)
            chosen = torch.randint(
                0, num_options, size=tuple(), generator=self.generator
            )
            chosen, overlap_a, overlap_b = valid_indices[chosen]

            context_left = min(chosen, context_index.item())
            context_right = max(chosen, context_index.item())
            delta = context_right - context_left

            # Targets are drawn from the inclusive range [left, right]. If it
            # holds fewer frames than requested, distinct targets cannot exist
            # and the rejection loop below would never end, so try the next
            # starting frame instead.
            if delta + 1 < self.cfg.num_target_views:
                continue

            # Pick non-repeated random target views. Rejection sampling is kept
            # so indices generated from a given seed stay the same.
            while True:
                target_views = torch.randint(
                    context_left,
                    context_right + 1,
                    (self.cfg.num_target_views,),
                    generator=self.generator,
                )
                if (target_views.unique(return_counts=True)[1] == 1).all():
                    break

            target = tuple(sorted(target_views.tolist()))
            self.index[scene] = IndexEntry(
                context=(context_left, context_right),
                target=target,
            )

            # Optionally, save a preview.
            if self.cfg.save_previews:
                preview_path = self.cfg.output_path / "previews"
                preview_path.mkdir(exist_ok=True, parents=True)
                a = batch["target"]["image"][0, chosen]
                a = add_label(a, f"Overlap: {overlap_a * 100:.1f}%")
                b = batch["target"]["image"][0, context_index]
                b = add_label(b, f"Overlap: {overlap_b * 100:.1f}%")
                vis = add_border(add_border(hcat(a, b)), 1, 0)
                vis = add_label(vis, f"Distance: {delta} frames")
                save_image(add_border(vis), preview_path / f"{scene}.png")
            break
        else:
            # This happens if no starting frame produces a valid evaluation example.
            self.index[scene] = None

    def save_index(self) -> None:
        """Write the collected index to ``<output_path>/evaluation_index.json``."""
        self.cfg.output_path.mkdir(exist_ok=True, parents=True)
        with (self.cfg.output_path / "evaluation_index.json").open("w") as f:
            json.dump(
                {k: None if v is None else asdict(v) for k, v in self.index.items()}, f
            )
+ for method in self.cfg.methods: + if not (method.path / scene).exists(): + print(f'Skipping "{scene}".') + return + + # Load the images. + all_images = {} + try: + for method in self.cfg.methods: + images = [ + load_image(method.path / scene / f"color/{index.item():0>6}.png") + for index in batch["target"]["index"][0] + ] + all_images[method.key] = torch.stack(images).to(self.device) + except FileNotFoundError: + print(f'Skipping "{scene}".') + return + + # Compute metrics. + all_metrics = {} + rgb_gt = batch["target"]["image"][0] + for key, images in all_images.items(): + alex_lpips, vgg_lpips = compute_lpips(rgb_gt, images) + all_metrics = { + **all_metrics, + f"alex_lpips_{key}": alex_lpips, + f"vgg_lpips_{key}": vgg_lpips, + f"ssim_{key}": compute_ssim(rgb_gt, images), + f"psnr_{key}": compute_psnr(rgb_gt, images), + } + self.log_dict(all_metrics) + self.print_preview_metrics(all_metrics) + + # Skip the rest if no side-by-side is needed. + if self.cfg.side_by_side_path is None: + return + + # Create side-by-side. + scene_key = f"{batch_idx:0>6}_{scene}" + for i in range(v): + true_index = batch["target"]["index"][0, i] + row = [add_label(batch["target"]["image"][0, i], "Ground Truth")] + for method in self.cfg.methods: + image = all_images[method.key][i] + image = add_label(image, method.name) + row.append(image) + start_frame = batch["target"]["index"][0, 0] + end_frame = batch["target"]["index"][0, -1] + label = f"Scene {batch['scene'][0]} (frames {start_frame} to {end_frame})" + row = add_border(add_label(hcat(*row), label, font_size=16)) + save_image( + row, + self.cfg.side_by_side_path / scene_key / f"{true_index:0>6}.png", + ) + + # Create an animation. 
def print_preview_metrics(self, metrics: dict[str, float]) -> None:
    """Update the running mean of every metric and print a per-method table.

    Args:
        metrics: metrics of the current step, keyed ``"<metric>_<method key>"``
            (as built in ``test_step``: ``psnr_*``, ``ssim_*``, ``alex_lpips_*``
            and ``vgg_lpips_*``).

    The LPIPS column shows ``lpips_<key>`` when that key is present and
    otherwise falls back to ``vgg_lpips_<key>``, which is the key ``test_step``
    logs.
    """
    if getattr(self, "running_metrics", None) is None:
        self.running_metrics = metrics
        self.running_metric_steps = 1
    else:
        s = self.running_metric_steps
        # Incremental mean: fold the new sample into the average of ``s`` steps.
        self.running_metrics = {
            k: ((s * v) + metrics[k]) / (s + 1)
            for k, v in self.running_metrics.items()
        }
        self.running_metric_steps += 1

    def running_value(metric: str, method_key: str):
        # ``test_step`` prefixes LPIPS with the backbone name, so a plain
        # "lpips_<key>" lookup alone would raise KeyError.
        for name in (f"{metric}_{method_key}", f"vgg_{metric}_{method_key}"):
            if name in self.running_metrics:
                return self.running_metrics[name]
        raise KeyError(f"{metric}_{method_key}")

    rows = []
    for method in self.cfg.methods:
        row = [
            f"{running_value(metric, method.key):.3f}"
            for metric in ("psnr", "lpips", "ssim")
        ]
        rows.append((method.key, *row))

    metrics_table(rows, ["Method", "PSNR (dB)", "LPIPS", "SSIM"])
def compute_rgb_metrics(rgb, rgb_gt, metrics: list[str], iter_batch_size: int = -1, device=None) -> dict:
    """Compute the requested image metrics between predictions and ground truth.

    Args:
        rgb: predicted images, batched along dim 0.
        rgb_gt: ground-truth images, same layout as ``rgb``.
        metrics: names of metrics to compute; each must be a key of
            ``metric_fn_dict``.
        iter_batch_size: ``-1`` evaluates all images at once; a positive value
            evaluates chunks of that many images to bound memory use.
        device: device the metric functions run on. Defaults to ``"cuda"``
            when CUDA is available, otherwise to the device of ``rgb``.

    Returns:
        Dict mapping each metric name to its score. A score is whatever the
        metric function returns: a scalar tensor, or a tuple of scalar tensors
        (as for "lpips").

    Raises:
        ValueError: for an unknown metric name or a non-positive
            ``iter_batch_size`` other than ``-1``.
    """
    if device is None:
        # Fall back to the inputs' device so CPU-only hosts work.
        device = "cuda" if torch.cuda.is_available() else rgb.device
    if iter_batch_size != -1 and iter_batch_size < 1:
        raise ValueError(f"iter_batch_size must be -1 or a positive integer, got {iter_batch_size}")

    metric_scores = {}
    for m in metrics:

        # check if metric is recognized
        if m not in metric_fn_dict:
            raise ValueError(f"Metric {m} not recognized. Available metrics: {list(metric_fn_dict.keys())}")
        metric_fn = metric_fn_dict[m]

        if iter_batch_size == -1:
            # Compute all at once. ``.to`` is a no-op once the tensors are on
            # ``device``, so later metrics reuse the moved copies.
            rgb = rgb.to(device)
            rgb_gt = rgb_gt.to(device)
            score = metric_fn(rgb_gt, rgb)
            # can be tuple (for lpips) or single tensor

        else:
            # batchify to save memory
            all_batches_scores = []
            batch_sizes = []

            num_images = rgb.shape[0]
            batch_num = num_images // iter_batch_size + int(num_images % iter_batch_size != 0)
            for i in tqdm(range(0, num_images, iter_batch_size), disable=batch_num < 20,
                          desc=f"Computing {m} in batches"):
                bs = min(iter_batch_size, num_images - i)
                rgb_batch = rgb[i:i + bs].to(device)
                rgb_gt_batch = rgb_gt[i:i + bs].to(device)
                all_batches_scores.append(metric_fn(rgb_gt_batch, rgb_batch))
                batch_sizes.append(bs)

            assert len(all_batches_scores) > 0, "No batch scores computed."

            # Each batch score is a mean over ``bs`` images; weighting by ``bs``
            # recovers the per-image mean even when the last batch is smaller.
            weights = torch.tensor(batch_sizes, dtype=torch.float32)
            total = weights.sum()
            first = all_batches_scores[0]

            # Case 1: scalar tensors
            if isinstance(first, torch.Tensor):
                vals = torch.stack(all_batches_scores).cpu()
                score = (vals * weights).sum() / total

            # Case 2: tuples of tensors
            elif isinstance(first, tuple):
                n = len(first)
                cols = [torch.stack([batch[j] for batch in all_batches_scores]).cpu() for j in range(n)]
                score = tuple((col * weights).sum() / total for col in cols)

            else:
                raise TypeError("Unexpected element type: must be torch.Tensor or tuple of torch.Tensors.")

        metric_scores[m] = score

    return metric_scores
+""" + +__all__ = ["OptGS", "OptGSError"] + + +def __getattr__(name: str): + if name == "OptGS": + from optgs.experimental.api.api import OptGS + + return OptGS + if name == "OptGSError": + from optgs.experimental.api.integration.scene_protocol import OptGSError + + return OptGSError + raise AttributeError(f"module 'optgs.experimental.api' has no attribute {name!r}") + + +def __dir__(): + return sorted(__all__) diff --git a/optgs/experimental/api/api.py b/optgs/experimental/api/api.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e49b51e623025b07ace591c0387fd7c44f9bf1 --- /dev/null +++ b/optgs/experimental/api/api.py @@ -0,0 +1,465 @@ +"""Public API: use optgs's learned optimizer in external 3DGS codebases. + +Typical inria (graphdeco-inria/gaussian-splatting) integration — replace the +hand-written training loop with three lines:: + + from optgs.experimental.api import OptGS + + gaussians = GaussianModel(sh_degree) # set up as usual (SfM init) + scene = Scene(dataset, gaussians) + optgs = OptGS(checkpoint="hf://org/repo/model.ckpt", device="cuda") + optgs.initialize(scene) # ingest scene + build optimizer + optgs.optimize(scene) # learned optimization, written back in place + scene.save(iteration) # proceed as normal + +Full-replacement semantics: ``optimize`` overwrites ``scene.gaussians`` in +place and nulls the inria Adam optimizer + densification accumulators. If you +later want to resume inria Adam, call ``gaussians.training_setup(...)`` again. + +For non-inria codebases use :meth:`OptGS.initialize_from_ply` / +:meth:`OptGS.initialize_from_tensors` + :meth:`OptGS.export_ply`. + +External SfM scenes carry no optgs encoder features, so checkpoints trained +with ``init_state_wo_features=False`` are coerced at construction (with a +warning): the feature-conditioned ``update_proj`` weights are dropped and the +optimizer state is initialized standard-normal. 
+""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Sequence + +import torch + +from optgs.experimental.api.integration.scene_protocol import OptGSError + +if TYPE_CHECKING: # pragma: no cover - typing only + from optgs.model.types import Gaussians + +__all__ = ["OptGS", "OptGSError"] + + +class OptGS: + """Facade around the learned per-scene optimizer.""" + + def __init__( + self, + checkpoint: str, + *, + device: str | torch.device = "cuda", + num_refine: int | None = None, + iter_batch_size: int | None = None, + opt_batch_size: int | None = None, + opt_batch_strategy: str | None = None, + background_color: Sequence[float] | None = None, + rasterize_mode: str | None = None, + eps2d: float | None = None, + strict_load: bool = True, + ) -> None: + if not checkpoint: + raise OptGSError( + "OptGS(checkpoint=...) is required (an 'hf://org/repo/file' " + "reference or a local checkpoint path)." + ) + self.device = torch.device(device) + if self.device.type != "cuda": + raise OptGSError( + "OptGS requires a CUDA device (the learned optimizer uses " + "CUDA/KNN kernels). Pass device='cuda'." + ) + # float32 only — the learned optimizer's CUDA/KNN kernels and the + # gsplat rasterizer require it (and the checkpoint trained with it). + self.dtype = torch.float32 + self.iter_batch_size = iter_batch_size + self.opt_batch_size = opt_batch_size + self.opt_batch_strategy = opt_batch_strategy + + from optgs.config import _find_config_for_checkpoint + from optgs.experimental.api.integration.config_bridge import ( + build_decoder, + build_optimizer, + build_optimizer_cfg, + get_scene_trainer_scalar, + load_optimizer_state, + ) + from optgs.misc.hf_ckpt import hf_sibling_config, maybe_resolve_hf_ref + + local_ckpt = maybe_resolve_hf_ref(checkpoint) + # For hf:// refs, hf_hub_download fetches only the ckpt; pull the + # sibling config.yaml so the architecture can be rebuilt. 
+ cfg_path = hf_sibling_config(checkpoint) or _find_config_for_checkpoint(local_ckpt) + if cfg_path is None: + raise OptGSError( + f"no config.yaml found next to checkpoint {local_ckpt!r} " + f"(looked for /../../config.yaml and the wandb " + f"latest-run fallback). OptGS needs the training config to " + f"rebuild the optimizer architecture." + ) + + opt_cfg, num_update_steps = build_optimizer_cfg(cfg_path) + + if not getattr(opt_cfg, "init_state_wo_features", False): + warnings.warn( + "this checkpoint was trained WITH encoder features " + "(scene_trainer.scene_optimizer.init_state_wo_features=False). " + "External SfM/inria scenes carry no optgs encoder features; " + "proceeding with init_state_wo_features=True — the " + "feature-conditioned update_proj weights are dropped and the " + "initial optimizer state is set to a standard-normal random " + "vector (init_state_type='random', init_state_scale=1.0)." + ) + opt_cfg.init_state_wo_features = True + opt_cfg.init_state_type = "random" + opt_cfg.init_state_scale = 1.0 + + optimizer = build_optimizer(opt_cfg) # asserts cfg.name; nn.Module + load_optimizer_state( + optimizer, local_ckpt, init_state_wo_features=True, strict=strict_load + ) + self.optimizer = optimizer.to(device=self.device, dtype=self.dtype).eval() + + from types import SimpleNamespace + + bg = list(background_color) if background_color is not None else [0.0, 0.0, 0.0] + # Build the renderer the checkpoint trained with (gsplat by default; + # NOT a hardcoded backend — see build_decoder). rasterize_mode / eps2d, + # when given, override the checkpoint's decoder config. 
+ decoder_overrides = { + k: v + for k, v in (("rasterize_mode", rasterize_mode), ("eps2d", eps2d)) + if v is not None + } + self.decoder = build_decoder( + cfg_path, SimpleNamespace(background_color=bg), decoder_overrides + ).to(self.device) + + resolved = num_refine if num_refine is not None else num_update_steps + if resolved is None: + raise OptGSError( + "num_refine could not be determined: pass OptGS(num_refine=...) " + "or use a checkpoint whose config has " + "scene_trainer.num_update_steps." + ) + self.num_refine = int(resolved) + + # Render-batching size: user override, else the checkpoint's + # scene_trainer.iter_batch_size (-1 = render all views per step). + if self.iter_batch_size is None: + self.iter_batch_size = int( + get_scene_trainer_scalar(cfg_path, "iter_batch_size", -1) + ) + + # Per-step view minibatch — opt_batch_size views are fed to the + # optimizer each step (the checkpoint's scene_trainer.opt_batch_size / + # opt_batch_strategy, i.e. the regime it was trained with). -1 = all. + if self.opt_batch_size is None: + self.opt_batch_size = int( + get_scene_trainer_scalar(cfg_path, "opt_batch_size", -1) + ) + if self.opt_batch_strategy is None: + self.opt_batch_strategy = str( + get_scene_trainer_scalar(cfg_path, "opt_batch_strategy", "random") + ) + if self.opt_batch_strategy not in ("random", "sequential", "fps"): + raise OptGSError( + f"opt_batch_strategy={self.opt_batch_strategy!r} is not supported " + f"by the API (supported: 'random', 'sequential', 'fps'). Pass " + f"OptGS(opt_batch_strategy='random')." + ) + + self._opt_cfg = opt_cfg + # SH degree the checkpoint's Gaussians use — derived from the optimizer + # cfg's init_sh_d (= (sh_degree + 1) ** 2, set by opt_cfg.update from the + # initializer cfg). API consumers build/render Gaussians with this; it is + # dictated by the checkpoint, not a free choice. 
+ self.sh_degree = int(round(opt_cfg.init_sh_d ** 0.5)) - 1 + self._initialized = False + self._scene_ref = None + self._context = None + self._init_output = None + self._refined: "Gaussians | None" = None + + # ------------------------------------------------------------------ + # Ingest + # ------------------------------------------------------------------ + + def initialize(self, scene: object) -> "OptGS": + """Ingest an already-initialized inria-style scene. + + This does NOT run optgs's learned Initializer — the scene already has + Gaussians (e.g. from SfM / inria ``create_from_pcd``). + """ + from optgs.experimental.api.integration.inria_bridge import ( + batched_views_from_cameras, + optgs_gaussians_from_inria_model, + ) + from optgs.experimental.api.integration.scene_protocol import ( + assert_scene_protocol, + ) + from optgs.scene_trainer.initializer.initializer import InitializerOutput + + assert_scene_protocol(scene) + g = optgs_gaussians_from_inria_model( + scene.gaussians, device=self.device, dtype=self.dtype + ) + self._init_output = InitializerOutput(gaussians=g, features=None, depths=None) + self._context = batched_views_from_cameras( + list(scene.getTrainCameras()), + scene_scale=float(scene.cameras_extent), + device=self.device, + dtype=self.dtype, + ) + self._scene_ref = scene + self._initialized = True + return self + + def initialize_from_ply( + self, + ply_path: str, + cameras: Sequence[object], + *, + sh_degree: int, + scene_scale: float, + ) -> "OptGS": + """Low-level ingest for non-inria codebases (no inria ``Scene``). + + ``cameras`` is a sequence of inria-``Camera``-like objects (``R``, + ``T``, ``FoVx``, ``FoVy``, ``image_width``, ``image_height``, + ``original_image``). 
+ """ + from optgs.experimental.api.integration.inria_bridge import ( + batched_views_from_cameras, + optgs_gaussians_from_ply, + ) + from optgs.scene_trainer.initializer.initializer import InitializerOutput + + g = optgs_gaussians_from_ply( + ply_path, sh_degree=sh_degree, device=self.device, dtype=self.dtype + ) + self._init_output = InitializerOutput(gaussians=g, features=None, depths=None) + self._context = batched_views_from_cameras( + list(cameras), scene_scale=scene_scale, device=self.device, dtype=self.dtype + ) + self._scene_ref = None + self._initialized = True + return self + + def initialize_from_tensors(self, gaussians: object, batched_views: object) -> "OptGS": + """Low-level ingest from optgs-native objects (power users). + + ``gaussians``: an optgs ``Gaussians`` (batch=1, post-activation). + ``batched_views``: an optgs ``BatchedViews`` or a dict accepted by + ``BatchedViews.from_dict``. + """ + from optgs.dataset.data_types import BatchedViews + from optgs.model.types import Gaussians + from optgs.scene_trainer.initializer.initializer import InitializerOutput + + if not isinstance(gaussians, Gaussians): + raise OptGSError( + "initialize_from_tensors expects an optgs Gaussians instance " + "(use initialize_from_ply for raw 3DGS PLY input)." + ) + bv = ( + batched_views + if isinstance(batched_views, BatchedViews) + else BatchedViews.from_dict(batched_views) + ) + self._init_output = InitializerOutput( + gaussians=gaussians.to(device=self.device, dtype=self.dtype), + features=None, + depths=None, + ) + self._context = bv + self._scene_ref = None + self._initialized = True + return self + + # ------------------------------------------------------------------ + # Optimize + # ------------------------------------------------------------------ + + def _view_minibatch(self, views): + """Sample the next per-step view minibatch from ``views``. 
+ + Mirrors SceneTrainer's viewpoint-stack cycling: views are drawn + ``opt_batch_size`` at a time and the stack is refilled once exhausted, + so every view is seen before any repeats. ``random``/``sequential`` take + the front of the (shuffled/ordered) stack; ``fps`` picks a + farthest-point spread over the remaining views' camera positions. + Returns ``views`` unchanged when ``opt_batch_size`` is <= 0 or already + covers the whole scene. + """ + v = views.image.shape[1] + bs = self.opt_batch_size + if bs <= 0 or bs >= v: + return views + + views.reset_viewpoint_stack_if_needed(self.opt_batch_strategy, bs) + stack = views.viewpoint_stack # [B, V_stack] + + if self.opt_batch_strategy == "fps": + from optgs.dataset.view_sampler.view_sampler_bounded_v2 import ( + farthest_point_sample, + ) + + b = stack.shape[0] + arange = torch.arange(b, device=stack.device)[:, None] + # FPS over the camera positions of the views still in the stack. + positions = views.extrinsics[arange, stack][:, :, :3, 3] # [B, V_stack, 3] + local = farthest_point_sample(positions, bs, first_idx_strategy="random") + idx = stack[arange, local] # [B, bs] + keep = ~(stack.unsqueeze(-1) == idx.unsqueeze(1)).any(-1) # [B, V_stack] + views.viewpoint_stack = stack[keep].view(b, -1) + else: # random / sequential — take the front of the stack + idx = stack[:, :bs] + views.viewpoint_stack = stack[:, bs:] + return views.batchify_views(idx) + + @torch.no_grad() + def optimize(self, scene: object | None = None, *, optimizer=None): + """Run the learned optimization. + + inria path: refined Gaussians are written back into ``scene.gaussians`` + in place and ``scene.gaussians`` is returned. Low-level path: the + refined optgs ``Gaussians`` is returned (use :meth:`export_ply` to + persist). + + ``optimizer`` swaps in a different optgs ``Optimizer`` (e.g. an Adam + baseline) — running the *same* per-scene pipeline (init, view minibatch, + step budget, renderer) with another update rule, i.e. a fair + comparison. 
Defaults to the checkpoint's learned optimizer. + """ + if scene is not None and scene is not self._scene_ref: + self.initialize(scene) + if not self._initialized: + raise OptGSError("call initialize(scene) before optimize().") + + opt = optimizer if optimizer is not None else self.optimizer + + from optgs.scene_trainer.optimizer.optimizer import ( + OptimizerInput, + OptimizerOutput, + OptimizerPreviousOutput, + ) + + inp = OptimizerInput( + context=self._context, + renderer=self.decoder, + prev_output=self._init_output, + num_refine=self.num_refine, + iter_batch_size=self.iter_batch_size, + target=self._context, + ) + opt.validate_input(inp) + opt.on_scene_start(inp) # InitializerOutput -> OptimizerPreviousOutput (+ADC) + if not isinstance(inp.prev_output, OptimizerPreviousOutput): + raise OptGSError( + "optimizer.on_scene_start did not produce an " + f"OptimizerPreviousOutput (got {type(inp.prev_output)})." + ) + + out = OptimizerOutput.empty(t=0) + out.T = self.num_refine + steps = range(self.num_refine) + try: + from tqdm import tqdm + + steps = tqdm(steps, desc=f"optimize[{type(opt).__name__}]") + except Exception: + pass + for step in steps: + # Feed the optimizer a fresh view minibatch each step (the regime it + # was trained with); full_context/full_target stay the whole scene. 
+ batch = self._view_minibatch(self._context) + inp.context = batch + inp.target = batch + out = opt( + step, inp, out, full_context=self._context, full_target=self._context + ) + out.t = (out.t or 0) + 1 + + if torch.cuda.is_available(): + torch.cuda.synchronize() + opt.on_scene_end() + + final = inp.prev_output.gaussians + self._refined = final + + if self._scene_ref is not None: + from optgs.experimental.api.integration.inria_bridge import ( + write_back_to_inria_model, + ) + + write_back_to_inria_model(self._scene_ref.gaussians, final) + return self._scene_ref.gaussians + return final + + def optimize_iter(self, *, optimizer=None): + """Generator form of :meth:`optimize`: yields ``(step, gaussians)`` after + each optimization step. + + Lets a caller drive the learned optimization one step at a time and + render the Gaussians in between — used by ``demo.py``'s ``--with-gui``. + ``on_scene_end()`` runs even if the caller closes the generator early + (e.g. a GUI Reset), via the ``finally`` block. + """ + if not self._initialized: + raise OptGSError("call initialize(...) before optimize_iter().") + + opt = optimizer if optimizer is not None else self.optimizer + + from optgs.scene_trainer.optimizer.optimizer import ( + OptimizerInput, + OptimizerOutput, + OptimizerPreviousOutput, + ) + + with torch.no_grad(): + inp = OptimizerInput( + context=self._context, + renderer=self.decoder, + prev_output=self._init_output, + num_refine=self.num_refine, + iter_batch_size=self.iter_batch_size, + target=self._context, + ) + opt.validate_input(inp) + opt.on_scene_start(inp) # InitializerOutput -> OptimizerPreviousOutput + if not isinstance(inp.prev_output, OptimizerPreviousOutput): + raise OptGSError( + "optimizer.on_scene_start did not produce an " + f"OptimizerPreviousOutput (got {type(inp.prev_output)})." 
+ ) + + out = OptimizerOutput.empty(t=0) + out.T = self.num_refine + try: + for step in range(self.num_refine): + # Fresh view minibatch each step (the regime the optimizer + # was trained with); full_context/target stay the whole scene. + batch = self._view_minibatch(self._context) + inp.context = batch + inp.target = batch + out = opt( + step, inp, out, + full_context=self._context, full_target=self._context, + ) + out.t = (out.t or 0) + 1 + yield step, inp.prev_output.gaussians + finally: + if torch.cuda.is_available(): + torch.cuda.synchronize() + opt.on_scene_end() + self._refined = inp.prev_output.gaussians + + def export_ply(self, path: str) -> None: + """Write the most recently refined Gaussians to a 3DGS PLY.""" + if self._refined is None: + raise OptGSError("nothing to export — call optimize() first.") + from pathlib import Path + + from optgs.model.ply_export import save_gaussian_ply + + save_gaussian_ply(self._refined, save_path=Path(path)) diff --git a/optgs/experimental/api/integration/__init__.py b/optgs/experimental/api/integration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..378da3222aedca2e785fd914f6fa95982ae64ef7 --- /dev/null +++ b/optgs/experimental/api/integration/__init__.py @@ -0,0 +1,22 @@ +"""Integration helpers for using optgs's learned optimizer in external +(inria-style) 3D Gaussian Splatting codebases. + +Public surface is :class:`optgs.OptGS`; the symbols here are the building +blocks (kept importable for advanced/low-level use and testing). 
+""" + +from optgs.experimental.api.integration.scene_protocol import ( + OptGSError, + CameraLike, + GaussiansLike, + SceneLike, + assert_scene_protocol, +) + +__all__ = [ + "OptGSError", + "CameraLike", + "GaussiansLike", + "SceneLike", + "assert_scene_protocol", +] diff --git a/optgs/experimental/api/integration/config_bridge.py b/optgs/experimental/api/integration/config_bridge.py new file mode 100644 index 0000000000000000000000000000000000000000..245ad91c87af2cf6b3dbfb7a249acb7b14405e61 --- /dev/null +++ b/optgs/experimental/api/integration/config_bridge.py @@ -0,0 +1,420 @@ +"""Hydra-free checkpoint -> optimizer construction. + +Rebuilds the learned optimizer (architecture + weights) from a checkpoint +*without* going through Hydra. Only ``_load_checkpoint_cfg`` + +``load_typed_config`` are used (both Hydra-free); the Hydra coupling lives in +``setup_cfg`` / ``merge_config_from_file`` / ``setup_output_dir`` which we never +call. All heavy imports are deferred into the functions so ``import optgs`` +stays cheap. +""" + +from __future__ import annotations + +from functools import lru_cache +from pathlib import Path +from typing import TYPE_CHECKING + +from optgs.experimental.api.integration.scene_protocol import OptGSError + + +@lru_cache(maxsize=8) +def _load_ckpt_cfg_cached(cfg_path_str: str): + """Load + migrate a checkpoint config once per path (read-only callers). + + ``build_optimizer_cfg`` / ``build_decoder`` / ``get_scene_trainer_scalar`` + all need the same DictConfig; caching avoids re-parsing the file. + """ + from optgs.config import _load_checkpoint_cfg # Hydra-free + + return _load_checkpoint_cfg(Path(cfg_path_str)) + + +def get_scene_trainer_scalar(cfg_path: Path, key: str, default): + """Read ``scene_trainer.`` from a checkpoint config (or ``default``). + + Used for scalars that live on the (Hydra-free unavailable) scene-trainer + config rather than the optimizer cfg: ``num_update_steps``, + ``iter_batch_size``, ``sh_degree_interval``. 
+ """ + from omegaconf import OmegaConf + + cfg = _load_ckpt_cfg_cached(str(cfg_path)) + return OmegaConf.select(cfg, f"scene_trainer.{key}", default=default) + +if TYPE_CHECKING: # pragma: no cover - typing only + from torch import nn + + from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizerCfg + + +def _optimizer_class_by_cfg_name(): + """Map a checkpoint config's ``scene_optimizer.name`` -> optimizer class. + + The registry (``SCENE_OPTIMIZERS``) keys on registry names (e.g. + ``"depthsplat"``), but a checkpoint config's ``scene_optimizer.name`` is the + cfg literal (``"knn_based"`` / ``"l2s"`` / ``"resplat_v1"`` / + ``"resplat_v2"``). Each class asserts ``cfg.name`` is its ``OPTIMIZER_NAME`` + or one of its ``OPTIMIZER_NAME_ALIASES`` (e.g. legacy ``"clogs"`` for + ``Learn2SplatOptimizer``), so we dispatch on both. + """ + from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizer + from optgs.scene_trainer.optimizer.optimizer_learn2splat import ( + Learn2SplatOptimizer, + ) + from optgs.scene_trainer.optimizer.optimizer_resplat import ( + ResplatOptimizerV1, + ResplatOptimizerV2, + ) + + classes = ( + KnnBasedOptimizer, + Learn2SplatOptimizer, + ResplatOptimizerV1, + ResplatOptimizerV2, + ) + mapping = {} + for cls in classes: + for name in (cls.OPTIMIZER_NAME, *getattr(cls, "OPTIMIZER_NAME_ALIASES", ())): + mapping[name] = cls + return mapping + + +def _initializer_cfg_class(name: str): + """Map ``scene_initializer.name`` -> its concrete typed Cfg dataclass. + + ``InitializerCfg`` is a PEP-604 union; dacite needs a concrete dataclass + as the top-level target (a union is only resolvable as a *field* type). + Keyed to match both ``SCENE_INITIALIZERS`` and each Cfg's ``name`` + Literal. 
+ """ + from optgs.scene_trainer.initializer import ( + InitializerColmapCfg, + InitializerEdgsCfg, + InitializerPlyCfg, + InitializerPointcloudCfg, + InitializerRandomCfg, + ResplatInitializerCfg, + ) + + return { + "resplat_v1": ResplatInitializerCfg, + "resplat_v2": ResplatInitializerCfg, + "colmap": InitializerColmapCfg, + "ply": InitializerPlyCfg, + "edgs": InitializerEdgsCfg, + "random": InitializerRandomCfg, + "pointcloud": InitializerPointcloudCfg, + }.get(name) + + +def _compose_default_group(group: str, value: str): + """Hydra-compose the bundled default for ``scene_trainer.=``. + + Released checkpoints predate fields later added to the typed configs + (e.g. ``scene_optimizer.refiner.fallback_means_lr``). The training/eval + pipeline reconciles this by merging the checkpoint config over the + *current* default config (config.py:merge_config_from_file). We mirror + that: compose the bundled default for the group (e.g. + ``scene_optimizer=knn_based`` -> base -> refiner:none, or + ``scene_initializer=colmap``) so missing fields can be backfilled with + current defaults while checkpoint values win for shared keys. + + Scoped use of ``hydra.compose`` (no ``@hydra.main`` / ``HydraConfig.get``, + no app context); lazily imported so ``import optgs`` stays light. Returns + ``None`` if composition fails (caller falls back to a strict parse). 
+ """ + try: + import optgs + from hydra import compose, initialize_config_dir + from hydra.core.global_hydra import GlobalHydra + from omegaconf import OmegaConf + + config_dir = str(Path(optgs.__file__).resolve().parent / "config") + GlobalHydra.instance().clear() + try: + with initialize_config_dir(version_base=None, config_dir=config_dir): + composed = compose( + config_name="main", + overrides=[f"scene_trainer/{group}={value}"], + ) + finally: + GlobalHydra.instance().clear() + return OmegaConf.select(composed, f"scene_trainer.{group}") + except Exception as e: # noqa: BLE001 - best-effort backfill + print( + f"[optgs] warning: could not compose default scene_trainer.{group}" + f"={value} for back-compat merge ({type(e).__name__}: {e}); " + f"parsing checkpoint config as-is." + ) + return None + + +def build_optimizer_cfg(cfg_path: Path) -> tuple["KnnBasedOptimizerCfg", int | None]: + """Load a checkpoint's saved config and return its typed optimizer cfg. + + Returns ``(KnnBasedOptimizerCfg, num_update_steps)`` where + ``num_update_steps`` (the per-scene optimization step count) is read from + ``scene_trainer.num_update_steps`` if present (it is NOT part of the + optimizer cfg), else ``None``. + """ + from omegaconf import OmegaConf + + from optgs.config import load_typed_config + from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizerCfg + + cfg = _load_ckpt_cfg_cached(str(cfg_path)) # read_omega_cfg + migrate; NO Hydra + so = OmegaConf.select(cfg, "scene_trainer.scene_optimizer") + name = OmegaConf.select(cfg, "scene_trainer.scene_optimizer.name") + if so is None or name in (None, "none"): + raise OptGSError( + f"checkpoint config at {cfg_path} has no learned scene_optimizer " + f"(scene_trainer.scene_optimizer={name!r}). OptGS needs a learned " + f"optimizer checkpoint (knn_based / clogs / resplat_v1 / resplat_v2)." 
+ ) + # Backfill fields a released (older) checkpoint config lacks with the + # current defaults, then let checkpoint values win for shared keys + # (mirrors config.py:merge_config_from_file's OmegaConf.merge). + default_so = _compose_default_group("scene_optimizer", "knn_based") + if default_so is not None: + OmegaConf.set_struct(default_so, False) + merged_so = OmegaConf.merge(default_so, so) + else: + merged_so = so + try: + opt_cfg = load_typed_config(merged_so, KnnBasedOptimizerCfg) + except Exception as e: # dacite/omegaconf errors -> actionable message + raise OptGSError( + f"failed to parse scene_optimizer from {cfg_path} into " + f"KnnBasedOptimizerCfg ({type(e).__name__}: {e})." + ) from e + + # Mirror SceneTrainerCfg (scene_trainer_cfg.py: scene_optimizer.update( + # scene_initializer)): wire the checkpoint's initializer cfg into the + # optimizer cfg so the runtime-only fields init_gaussian_param_num / + # init_sh_d / sh_d — absent from every config file — are populated before + # the optimizer nn.Module is built. + si = OmegaConf.select(cfg, "scene_trainer.scene_initializer") + si_name = OmegaConf.select(cfg, "scene_trainer.scene_initializer.name") + if si is None or si_name in (None, "none"): + raise OptGSError( + f"checkpoint config at {cfg_path} has no scene_initializer " + f"(name={si_name!r}); cannot derive init_gaussian_param_num " + f"required to build the optimizer." + ) + init_cls = _initializer_cfg_class(str(si_name)) + if init_cls is None: + raise OptGSError( + f"unsupported scene_initializer.name={si_name!r} in {cfg_path}; " + f"cannot derive init_gaussian_param_num for the optimizer." 
+ ) + default_si = _compose_default_group("scene_initializer", str(si_name)) + if default_si is not None: + OmegaConf.set_struct(default_si, False) + merged_si = OmegaConf.merge(default_si, si) + else: + merged_si = si + try: + init_cfg = load_typed_config(merged_si, init_cls) + opt_cfg.update(init_cfg) # sets init_gaussian_param_num/init_sh_d/sh_d + except Exception as e: + raise OptGSError( + f"failed to wire scene_initializer ({si_name!r}) into the " + f"optimizer cfg from {cfg_path} ({type(e).__name__}: {e})." + ) from e + + num_update_steps = OmegaConf.select( + cfg, "scene_trainer.num_update_steps", default=None + ) + return opt_cfg, num_update_steps + + +def build_decoder( + cfg_path: Path, dataset_cfg: object, decoder_overrides: dict | None = None +) -> "nn.Module": + """Build the renderer the checkpoint was trained with. + + Uses ``scene_trainer.decoder`` from the checkpoint config (NOT a hardcoded + backend): the learned optimizer's in-loop render gradients must match the + backend it trained with, and only the registered/available backends are + usable (e.g. ``gsplat`` — the optgs default; the ``inria`` backend needs + ``diff_gaussian_rasterization``, which is optional). ``dataset_cfg`` only + needs a ``background_color`` attribute. ``decoder_overrides`` (e.g. + ``rasterize_mode`` / ``eps2d``) take precedence over the checkpoint config. + """ + from omegaconf import OmegaConf + + from optgs.config import load_typed_config + from optgs.model.decoder import DecoderCfg, get_decoder + + cfg = _load_ckpt_cfg_cached(str(cfg_path)) + node = OmegaConf.select(cfg, "scene_trainer.decoder") + if node is None: + raise OptGSError( + f"checkpoint config at {cfg_path} has no scene_trainer.decoder; " + f"cannot rebuild the renderer the optimizer trained with." + ) + # gsplat decoder rasterize_mode / eps2d, by precedence: + # caller override > checkpoint config > gsplat rasterization() default + # (so an older checkpoint that omits a field behaves as plain gsplat would). 
+ if OmegaConf.select(node, "name") == "gsplat": + import inspect + + from gsplat.rendering import rasterization + + sig = inspect.signature(rasterization).parameters + node = OmegaConf.merge( + OmegaConf.create( + {f: sig[f].default for f in ("rasterize_mode", "eps2d") if f in sig} + ), + node, + OmegaConf.create(dict(decoder_overrides or {})), + ) + try: + decoder_cfg = load_typed_config(node, DecoderCfg) + except Exception as e: + raise OptGSError( + f"failed to parse scene_trainer.decoder from {cfg_path} " + f"({type(e).__name__}: {e})." + ) from e + try: + return get_decoder(decoder_cfg, dataset_cfg) + except (KeyError, ImportError) as e: + raise OptGSError( + f"decoder backend {decoder_cfg.name!r} is not available in this " + f"environment ({type(e).__name__}: {e}). Install its backend " + f"(e.g. diff_gaussian_rasterization for 'inria') or use a " + f"checkpoint trained with the 'gsplat' decoder." + ) from e + + +def build_optimizer(opt_cfg: "KnnBasedOptimizerCfg") -> "nn.Module": + """Construct the concrete learned optimizer for ``opt_cfg`` (no weights).""" + from optgs.misc.io import FrequencyScheduler + + mapping = _optimizer_class_by_cfg_name() + cls = mapping.get(opt_cfg.name) + if cls is None: + raise OptGSError( + f"unsupported scene_optimizer.name={opt_cfg.name!r}; OptGS supports " + f"{sorted(mapping)}." + ) + optimizer = cls(opt_cfg) + # The optimizer's save_every (info/context/target/debug artifact dumps) is + # wired by SceneTrainer during training; the optimizer calls it + # unconditionally, so the API inference path — which has nothing to dump — + # installs a disabled scheduler instead of leaving it None. + save_every = FrequencyScheduler(last_step=0) + save_every.disable(True) + optimizer.save_every = save_every + return optimizer + + +def build_adam_baseline(num_refine: int) -> "nn.Module": + """Build the codebase's 3DGS Adam optimizer for a fair baseline comparison. 
+ + Uses the bundled ``scene_optimizer=3dgs`` config — gsplat's example + hyperparameters (LRs, betas). Densification is disabled so the baseline + refines the same fixed Gaussian set as the learned optimizer (a + like-for-like update-rule comparison), and the means-LR decay horizon is set + to ``num_refine``. Returns a ready-to-run ``AdamOptimizer``. + """ + from omegaconf import OmegaConf + + from optgs.config import load_typed_config + from optgs.misc.io import FrequencyScheduler + from optgs.scene_trainer.optimizer.optimizer_adam import ( + AdamOptimizer, + AdamOptimizerCfg, + ) + + composed = _compose_default_group("scene_optimizer", "3dgs") + if composed is None: + raise OptGSError( + "could not Hydra-compose the bundled 'scene_optimizer=3dgs' config " + "for the Adam baseline." + ) + OmegaConf.set_struct(composed, False) + # gsplat decays the means LR over the full step budget. + composed.means_lr_max_steps = int(num_refine) + # Disable densification — the baseline refines the same fixed Gaussian set + # as the learned optimizer (a like-for-like comparison of the update rule). + for flag in ("do_densify", "do_prune", "do_opacity_reset"): + if flag in composed.refiner: + composed.refiner[flag] = False + try: + adam_cfg = load_typed_config(composed, AdamOptimizerCfg) + except Exception as e: + raise OptGSError( + f"failed to parse the bundled '3dgs' config into AdamOptimizerCfg " + f"({type(e).__name__}: {e})." + ) from e + + optimizer = AdamOptimizer(adam_cfg) + save_every = FrequencyScheduler(last_step=0) # nothing to dump (see build_optimizer) + save_every.disable(True) + optimizer.save_every = save_every + # AdamOptimizer is a NonlearnedOptimizer — already pinned to eval mode. + return optimizer + + +# Module-attribute renames applied when the legacy Resplat encoder was split +# into separate initializer/optimizer modules (transcribed from +# optgs/main.py:load_optimizer). 
_ORIG_OPTIMIZER_ATTR_RENAMES = {
    # legacy encoder attribute -> attribute name on the split-out optimizer
    "render_error_mv_attn": "update_error_attn",
}


def load_optimizer_state(
    optimizer: "nn.Module",
    ckpt_path: str,
    init_state_wo_features: bool,
    strict: bool,
) -> None:
    """Load optimizer weights from ``ckpt_path`` into ``optimizer``.

    Transcribes the prefix-stripping / legacy-rename / feature-drop logic from
    ``optgs/main.py:load_optimizer`` (we cannot call that function: it needs a
    full Hydra ``cfg`` and a ``scene_trainer``).

    Args:
        optimizer: module that receives the weights via ``load_state_dict``.
        ckpt_path: path handed to ``torch.load``. The file may be a bare state
            dict or a dict that wraps one under ``"state_dict"``.
        init_state_wo_features: when true, every key containing
            ``"update_proj"`` is dropped before loading.
        strict: forwarded to ``optimizer.load_state_dict``.

    Raises:
        OptGSError: if the checkpoint is not a mapping, or holds neither
            ``optimizer.*`` nor legacy ``encoder.*`` keys.
    """
    import torch

    # NOTE(review): torch.load unpickles the file. Only point this at trusted
    # checkpoints, or pass weights_only=True once the supported torch versions
    # and checkpoint contents are confirmed to allow it.
    state = torch.load(ckpt_path, map_location="cpu")
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    if not isinstance(state, dict):
        raise OptGSError(
            f"checkpoint {ckpt_path} does not contain a state dict "
            f"(got {type(state).__name__})."
        )

    # Strip the Lightning "scene_trainer." prefix if present. Only a *leading*
    # occurrence is removed: str.replace would also rewrite the same substring
    # in the middle of a key (e.g. a sub-module that is itself named
    # "scene_trainer") and silently produce wrong parameter names.
    lightning_prefix = "scene_trainer."
    state = {
        (k[len(lightning_prefix):] if k.startswith(lightning_prefix) else k): v
        for k, v in state.items()
    }

    if any(k.startswith("optimizer.") for k in state):
        # Unified repo format: keys are optimizer.*
        osd = {
            k[len("optimizer."):]: v
            for k, v in state.items()
            if k.startswith("optimizer.")
        }
    else:
        # Legacy Resplat format: keys are encoder.* (before init/opt split).
        osd = {
            k[len("encoder."):]: v
            for k, v in state.items()
            if k.startswith("encoder.")
        }
        renamed = {}
        for k, v in osd.items():
            for old, new in _ORIG_OPTIMIZER_ATTR_RENAMES.items():
                # Match the attribute itself or anything nested below it.
                if k == old or k.startswith(old + "."):
                    k = new + k[len(old):]
                    break
            renamed[k] = v
        osd = renamed

    if not osd:
        raise OptGSError(
            f"no optimizer weights found in {ckpt_path} (looked for "
            f"'optimizer.*' or legacy 'encoder.*' keys)."
        )

    if init_state_wo_features:
        osd = {k: v for k, v in osd.items() if "update_proj" not in k}

    optimizer.load_state_dict(osd, strict=strict)
def _import_inria_graphics():
    """Import inria's ``getWorld2View2`` / ``fov2focal`` on demand.

    Returns:
        ``(getWorld2View2, fov2focal)``.

    Raises:
        OptGSError: if ``utils.graphics_utils`` cannot be imported for any
            reason; the original exception is chained and quoted.
    """
    try:
        from utils.graphics_utils import fov2focal, getWorld2View2  # type: ignore
    except Exception as e:  # ImportError or deeper
        from optgs.experimental.api.integration.scene_protocol import OptGSError

        raise OptGSError(
            "could not import inria graphics utils (utils.graphics_utils). "
            "Run from your inria gaussian-splatting checkout (so it is on "
            "sys.path), or use OptGS.initialize_from_ply / "
            "initialize_from_tensors. "
            f"Original error: {type(e).__name__}: {e}"
        ) from e
    return getWorld2View2, fov2focal


# ---------------------------------------------------------------------------
# Gaussians
# ---------------------------------------------------------------------------

def optgs_gaussians_from_ply(
    ply_path: str | Path,
    *,
    sh_degree: int,
    device: "torch.device",
    dtype: "torch.dtype",
) -> "Gaussians":
    """Load a 3DGS PLY into an optgs ``Gaussians`` (batch=1, post-activation).

    Args:
        ply_path: PLY file to read.
        sh_degree: forwarded to ``load_gaussians_ply`` as ``max_sh_degree``.
        device, dtype: target placement of the returned Gaussians.

    Returns:
        The loaded Gaussians. ``covariances`` is populated when it can be
        built from the scales and rotations, and is ``None`` otherwise.
    """
    import warnings

    from optgs.model.ply_export import load_gaussians_ply
    from optgs.scene_trainer.common.gaussians import build_covariance

    g = load_gaussians_ply(str(ply_path), max_sh_degree=sh_degree)
    g = g.to(device=device, dtype=dtype)
    # Populate covariances so any depth / use_covariances path is safe (the
    # default color path recomputes from scales+rotations anyway).
    try:
        g.covariances = build_covariance(g.scales[0], g.rotations[0]).unsqueeze(0)
    except Exception as e:
        # Still best-effort, but no longer silent: a path that needs the
        # covariances would otherwise fail far from the real cause.
        warnings.warn(
            f"could not build covariances for {ply_path} "
            f"({type(e).__name__}: {e}); continuing with covariances=None.",
            RuntimeWarning,
            stacklevel=2,
        )
        g.covariances = None
    return g


def optgs_gaussians_from_inria_model(
    gm: object,
    *,
    device: "torch.device",
    dtype: "torch.dtype",
) -> "Gaussians":
    """Ingest an inria ``GaussianModel`` via a temp PLY round-trip.

    ``gm.max_sh_degree`` is used as the SH degree (3 when the attribute is
    absent). The temporary directory is removed as soon as the PLY is read.
    """
    sh_degree = int(getattr(gm, "max_sh_degree", 3))
    with tempfile.TemporaryDirectory(prefix="optgs_ingest_") as d:
        tmp = Path(d) / "gaussians.ply"
        gm.save_ply(str(tmp))  # inria writer (original-3DGS schema)
        return optgs_gaussians_from_ply(
            tmp, sh_degree=sh_degree, device=device, dtype=dtype
        )


def write_back_to_inria_model(gm: object, final: "Gaussians") -> None:
    """Replace an inria ``GaussianModel``'s params with refined Gaussians.

    Full-replacement semantics: the learned ADC may change the point count, so
    we reallocate every parameter (via inria ``load_ply``) and reset the inria
    Adam optimizer + densification accumulators. The caller must call
    ``gaussians.training_setup(...)`` again before resuming inria Adam.
    """
    import torch

    from optgs.model.ply_export import save_gaussian_ply

    with tempfile.TemporaryDirectory(prefix="optgs_writeback_") as d:
        tmp = Path(d) / "refined.ply"
        # save_gaussian_ply: B==1, xyzw->wxyz, re-inverts activations.
        save_gaussian_ply(final, save_path=tmp)
        gm.load_ply(str(tmp))  # reallocs _xyz/_features_*/_opacity/_scaling/_rotation

    # The per-point bookkeeping must match the (possibly changed) point count.
    n = gm._xyz.shape[0]
    dev = gm._xyz.device
    gm.optimizer = None
    gm.xyz_gradient_accum = torch.zeros((n, 1), device=dev)
    gm.denom = torch.zeros((n, 1), device=dev)
    gm.max_radii2D = torch.zeros((n,), device=dev)


# ---------------------------------------------------------------------------
# Cameras
# ---------------------------------------------------------------------------

def batched_views_from_cameras(
    cameras: Sequence[object],
    *,
    scene_scale: float,
    device: "torch.device",
    dtype: "torch.dtype",
    near: float = 0.01,
    far: float = 100.0,
):
    """Build an optgs ``BatchedViews`` (B=1) from inria-style cameras.

    ``near``/``far`` default to inria's hardcoded ``znear=0.01``/``zfar=100.0``
    (also the optgs colmap-dataset constants). All cameras must share one
    (H, W) — the decoder takes a single image shape.

    Raises:
        OptGSError: if ``cameras`` is empty, the cameras disagree on (H, W),
            or a camera's ``original_image`` does not have the (H, W) the
            camera declares.
    """
    import torch

    from optgs.dataset.data_types import BatchedViews
    from optgs.experimental.api.integration.scene_protocol import OptGSError

    getWorld2View2, fov2focal = _import_inria_graphics()

    if len(cameras) == 0:
        raise OptGSError("no cameras provided.")

    exts, intrs, imgs = [], [], []
    H0 = W0 = None
    for cam in cameras:
        W = int(cam.image_width)
        H = int(cam.image_height)
        if H0 is None:
            H0, W0 = H, W
        elif (H, W) != (H0, W0):
            raise OptGSError(
                f"all train cameras must share one (H, W); got {(H, W)} vs "
                f"{(H0, W0)}. Render to a single resolution before optimizing."
            )

        w2c = torch.tensor(
            getWorld2View2(cam.R, cam.T), dtype=torch.float32
        )  # [4,4] world->camera
        c2w = torch.inverse(w2c)  # optgs extrinsics convention

        fx = fov2focal(cam.FoVx, W)
        fy = fov2focal(cam.FoVy, H)
        # Principal point falls back to the image centre when the camera does
        # not carry cx / cy.
        cx = float(getattr(cam, "cx", W / 2.0))
        cy = float(getattr(cam, "cy", H / 2.0))
        K = torch.eye(3, dtype=torch.float32)
        K[0, 0] = fx / W  # normalized focal
        K[1, 1] = fy / H
        K[0, 2] = cx / W  # normalized principal point
        K[1, 2] = cy / H

        img = cam.original_image
        if not torch.is_tensor(img):
            img = torch.as_tensor(img)
        img = img.float().clamp(0.0, 1.0)  # [3, H, W]
        # The intrinsics above are normalized by the declared size, so an
        # image of another size would be stacked next to mismatched
        # intrinsics (or fail in torch.stack with an opaque message).
        if img.ndim != 3 or tuple(img.shape[-2:]) != (H, W):
            raise OptGSError(
                f"camera original_image has shape {tuple(img.shape)}, expected "
                f"[C, {H}, {W}] to match image_height/image_width."
            )

        exts.append(c2w)
        intrs.append(K)
        imgs.append(img)

    V = len(cameras)
    extrinsics = torch.stack(exts).unsqueeze(0).to(device=device, dtype=dtype)
    intrinsics = torch.stack(intrs).unsqueeze(0).to(device=device, dtype=dtype)
    image = torch.stack(imgs).unsqueeze(0).to(device=device, dtype=dtype)
    return BatchedViews.from_dict(
        {
            "extrinsics": extrinsics,
            "intrinsics": intrinsics,
            "image": image,
            "near": torch.full((1, V), near, device=device, dtype=dtype),
            "far": torch.full((1, V), far, device=device, dtype=dtype),
            "index": torch.arange(V, device=device).unsqueeze(0),
            "scene_scale": torch.tensor([float(scene_scale)], device=device, dtype=dtype),
        }
    )
class OptGSError(RuntimeError):
    """Raised for all OptGS API misuse / unsupported-checkpoint conditions."""


@runtime_checkable
class GaussiansLike(Protocol):
    """An inria ``GaussianModel`` (raw, pre-activation parameter storage)."""

    active_sh_degree: int
    max_sh_degree: int
    _xyz: Any
    _features_dc: Any
    _features_rest: Any
    _scaling: Any
    _rotation: Any
    _opacity: Any

    def save_ply(self, path: str) -> None: ...

    def load_ply(self, path: str) -> None: ...


@runtime_checkable
class CameraLike(Protocol):
    """An inria ``Camera`` (world->camera ``R,T`` + FoV, COLMAP convention)."""

    R: Any
    T: Any
    FoVx: float
    FoVy: float
    image_width: int
    image_height: int
    original_image: Any  # [3, H, W] in [0, 1]


@runtime_checkable
class SceneLike(Protocol):
    """An inria ``Scene`` holding a ``GaussianModel`` and posed cameras."""

    cameras_extent: float
    gaussians: GaussiansLike

    def getTrainCameras(self, scale: float = 1.0) -> Sequence[CameraLike]: ...


# The member names are spelled out so that a failed check can report exactly
# which one is absent; a plain isinstance() against the Protocols cannot.
_SCENE_ATTRS = ("cameras_extent", "gaussians")
_SCENE_METHODS = ("getTrainCameras",)
_GAUSSIAN_ATTRS = (
    "active_sh_degree",
    "max_sh_degree",
    "_xyz",
    "_features_dc",
    "_features_rest",
    "_scaling",
    "_rotation",
    "_opacity",
)
_GAUSSIAN_METHODS = ("save_ply", "load_ply")
_CAMERA_ATTRS = (
    "R",
    "T",
    "FoVx",
    "FoVy",
    "image_width",
    "image_height",
    "original_image",
)


def _missing(obj: object, attrs: tuple[str, ...], methods: tuple[str, ...] = ()) -> list[str]:
    """List the members ``obj`` lacks: attributes by name, methods as ``name()``.

    An entry of ``methods`` counts as missing when it is absent or not callable.
    Attributes are reported first, then methods, each in the order given.
    """
    absent: list[str] = []
    for attr_name in attrs:
        if not hasattr(obj, attr_name):
            absent.append(attr_name)
    for method_name in methods:
        if not callable(getattr(obj, method_name, None)):
            absent.append(f"{method_name}()")
    return absent


def assert_scene_protocol(scene: object) -> None:
    """Check ``scene`` against :class:`SceneLike` and fail with a precise message.

    The check is duck-typed. It looks at the scene itself, then at
    ``scene.gaussians``, then at the first camera returned by
    ``scene.getTrainCameras()``; the first level with a problem raises.

    Raises:
        OptGSError: naming the missing members, or reporting that no train
            cameras were returned.
    """
    scene_gaps = _missing(scene, _SCENE_ATTRS, _SCENE_METHODS)
    if scene_gaps:
        listed = ', '.join(scene_gaps)
        raise OptGSError(
            f"scene is missing required attribute(s)/method(s): {listed}. "
            f"Expected an inria-style Scene (cameras_extent, gaussians, "
            f"getTrainCameras()). Use OptGS.initialize_from_ply/"
            f"initialize_from_tensors for non-inria codebases."
        )

    model_gaps = _missing(scene.gaussians, _GAUSSIAN_ATTRS, _GAUSSIAN_METHODS)
    if model_gaps:
        listed = ', '.join(model_gaps)
        raise OptGSError(
            f"scene.gaussians is missing: {listed}. Expected an inria "
            f"GaussianModel (raw _xyz/_features_dc/_features_rest/_scaling/"
            f"_rotation/_opacity + save_ply/load_ply)."
        )

    train_cams = scene.getTrainCameras()
    if train_cams is None or len(train_cams) == 0:
        raise OptGSError("scene.getTrainCameras() returned no cameras.")

    # Only the first camera is inspected.
    camera_gaps = _missing(train_cams[0], _CAMERA_ATTRS)
    if camera_gaps:
        listed = ', '.join(camera_gaps)
        raise OptGSError(
            f"train camera is missing: {listed}. Expected an inria "
            f"Camera (R, T, FoVx, FoVy, image_width, image_height, "
            f"original_image)."
        )
@torch.no_grad()
def init_gaussians_with_corr(
        viewpoints_img: Tensor,  # (N, 3, H, W) images; multiplied by 255 below, so presumably in [0, 1]
        viewpoints_w2c: Tensor,  # (N, 4, 4) world-to-camera matrices per the name; original note says column-major -- TODO confirm
        viewpoints_proj: Tensor,  # (N, 4, 4) projection matrices (original note: column-major)
        camera_centers: Tensor,  # (N, 3) camera centers in world coordinates
        init_opacity: float,
        roma_model_type: str,
        verbose: bool = False
):
    """
    Initialize Gaussians from dense RoMa correspondences between nearby views.

    Reference views are picked with K-means over the flattened view matrices.
    Each reference view is matched against its nearest views, matches are
    sampled by certainty, triangulated with the two projection matrices, and
    the best triangulation per keypoint becomes one Gaussian.

    Args:
        viewpoints_img: images of all views.
        viewpoints_w2c: per-view 4x4 matrices; flattened and used both for the
            K-means selection and the nearest-view search. Its device is the
            device RoMa is created on.
        viewpoints_proj: per-view 4x4 projection matrices used to triangulate.
        camera_centers: per-view camera centers; the Gaussian scale is
            0.001 x the distance from the reference camera center.
        init_opacity: opacity given to every new Gaussian before the
            bad-point penalty.
        roma_model_type: ``"indoors"`` selects ``roma_indoor``; anything else
            selects ``roma_outdoor``.
        verbose: print shapes / statistics and collect debug figures.

    Returns:
        ``(closest_indices_selected, visualizations, points_dict)`` where
        ``closest_indices_selected`` is the per-view neighbour index array,
        ``visualizations`` holds the figures collected when ``verbose`` is set,
        and ``points_dict`` has the keys ``xyz``, ``rgb``, ``scales`` and
        ``opacities`` (opacities are post-sigmoid).

    Raises:
        ValueError: if fewer than two views are given.
        ImportError: if RoMa (``romatch``) is not installed.
    """
    # Default values used in the original EDGS code.
    num_refs: int = 180
    nns_per_ref: int = 3
    matches_per_ref: int = 20000
    proj_err_tolerance: float = 0.01
    device = viewpoints_w2c.device

    N_VIEWS = viewpoints_img.shape[0]
    if N_VIEWS < 2:
        # Matching needs a second view; without this guard the warm-up below
        # fails with a bare IndexError on viewpoints_img[1].
        raise ValueError(
            f"init_gaussians_with_corr needs at least 2 views, got {N_VIEWS}."
        )

    try:
        from romatch import roma_outdoor, roma_indoor
    except ImportError as e:
        raise ImportError(
            "The edgs initializer requires RoMa (romatch), which is not "
            "installed. Install it with: "
            "pip install git+https://github.com/Parskatt/RoMa.git"
        ) from e
    if roma_model_type == "indoors":
        roma_model = roma_indoor(device=device)
    else:
        roma_model = roma_outdoor(device=device)
    roma_model.upsample_preds = False
    roma_model.symmetric = False
    M = matches_per_ref
    upper_thresh = roma_model.sample_thresh
    expansion_factor = 1
    keypoint_fit_error_tolerance = proj_err_tolerance
    visualizations = {}

    NUM_REFERENCE_FRAMES = min(num_refs, N_VIEWS)
    # A view has at most N_VIEWS - 1 neighbours other than itself. Asking for
    # N_VIEWS made the view its own "neighbour" on scenes with <= 3 views.
    NUM_NNS_PER_REFERENCE = min(nns_per_ref, N_VIEWS - 1)

    # Select cameras using K-means
    viewpoint_cam_all = viewpoints_w2c.reshape(N_VIEWS, -1)  # (N_VIEWS, 16)

    selected_indices = select_cameras_kmeans(cameras=viewpoint_cam_all.detach().cpu().numpy(), K=NUM_REFERENCE_FRAMES)
    selected_indices = sorted(selected_indices)

    # Find the k-closest vectors for each vector
    closest_indices = k_closest_vectors(viewpoint_cam_all, NUM_NNS_PER_REFERENCE)

    if verbose:
        print("Indices of k-closest vectors for each vector:\n", closest_indices)

    closest_indices_selected = closest_indices[:, :].detach().cpu().numpy()

    all_new_xyz = []
    all_new_rgb = []
    all_new_scaling = []
    all_new_opacities_raw = []

    # Run roma_model.match once to initialize the model; the result is discarded.
    viewpoint_img1 = viewpoints_img[0].cpu().numpy().transpose(1, 2, 0)  # [H, W, 3]
    viewpoint_img2 = viewpoints_img[1].cpu().numpy().transpose(1, 2, 0)  # [H, W, 3]
    imA = Image.fromarray(np.clip(viewpoint_img1 * 255, 0, 255).astype(np.uint8))
    imB = Image.fromarray(np.clip(viewpoint_img2 * 255, 0, 255).astype(np.uint8))

    warp, certainty_warp = roma_model.match(imA, imB, device=device)
    if verbose:
        print("Once run full roma_model.match warp.shape:", warp.shape)
        print("Once run full roma_model.match certainty_warp.shape:", certainty_warp.shape)
    del warp, certainty_warp
    torch.cuda.empty_cache()

    for source_idx in tqdm(sorted(selected_indices)):

        # 1. Compute keypoints and warping for all the neighboring views.
        certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all, warps_all = aggregate_confidences_and_warps(
            viewpoints_img=viewpoints_img,
            closest_indices=closest_indices_selected,
            roma_model=roma_model,
            source_idx=source_idx,
            verbose=verbose, output_dict=visualizations
        )

        # 2. Sample matches. Certainties above the model's sampling threshold
        # are saturated to 1 so they are all equally likely to be drawn.
        certainty = certainties_max.clone()
        certainty[certainty > upper_thresh] = 1
        certainty = certainty.reshape(-1)

        good_samples = torch.multinomial(certainty,
                                         num_samples=min(expansion_factor * M, len(certainty)),
                                         replacement=False)

        reference_image_dict = {
            "ref_image": imA,
            "NNs_images": imB_compound,
            "certainties_all": certainties_all,
            "warps_all": warps_all,
            "triangulated_points": [],
            "triangulated_points_errors_proj1": [],
            "triangulated_points_errors_proj2": []
        }

        # 3. Triangulate the sampled matches against every neighbour.
        for NN_idx in tqdm(range(len(warps_all))):
            matches_NN = warps_all[NN_idx].reshape(-1, 4)[good_samples]

            # Extract keypoints and colors
            kptsA_np, kptsB_np, kptsB_proj_matrices_idcs, kptsA_color, kptsB_color = extract_keypoints_and_colors(
                imA, imB_compound, certainties_max, certainties_max_idcs, matches_NN, roma_model
            )

            proj_matrices_A = viewpoints_proj[source_idx]
            proj_matrices_B = viewpoints_proj[closest_indices_selected[source_idx, NN_idx]]

            # Fewer than M samples are drawn when the certainty map has fewer
            # than M pixels, so the projection batch follows the keypoints
            # instead of being hard-wired to M.
            n_pts = kptsA_np.shape[0]
            triangulated_points, triangulated_points_errors_proj1, triangulated_points_errors_proj2 = triangulate_points(
                P1=proj_matrices_A.unsqueeze(0).repeat(n_pts, 1, 1),
                P2=proj_matrices_B.unsqueeze(0).repeat(n_pts, 1, 1),
                k1_x=kptsA_np[:, 0], k1_y=kptsA_np[:, 1],
                k2_x=kptsB_np[:, 0], k2_y=kptsB_np[:, 1],
                device=device)

            reference_image_dict["triangulated_points"].append(triangulated_points)
            reference_image_dict["triangulated_points_errors_proj1"].append(triangulated_points_errors_proj1)
            reference_image_dict["triangulated_points_errors_proj2"].append(triangulated_points_errors_proj2)

        NNs_triangulated_points_selected, NNs_triangulated_points_selected_proj_errors = select_best_keypoints(
            NNs_triangulated_points=torch.stack(reference_image_dict["triangulated_points"], dim=0),
            NNs_errors_proj1=np.stack(reference_image_dict["triangulated_points_errors_proj1"], axis=0),
            NNs_errors_proj2=np.stack(reference_image_dict["triangulated_points_errors_proj2"], axis=0))

        # 4. Save as gaussians
        new_xyz = NNs_triangulated_points_selected[:, :-1]  # drop the last (presumably homogeneous) column
        all_new_xyz.append(new_xyz)
        all_new_rgb.append(torch.tensor(kptsA_color.astype(np.float32) / 255.).to(device))

        mask_bad_points = torch.tensor(
            NNs_triangulated_points_selected_proj_errors > keypoint_fit_error_tolerance,
            dtype=torch.float32)
        mask_bad_points = mask_bad_points.to(device)
        if verbose:
            print("Number of bad points for source_idx", source_idx, ":", mask_bad_points.sum().item())

        new_opacities = torch.ones((new_xyz.shape[0]), device=device) * init_opacity
        new_opacities_raw = torch.logit(new_opacities)
        # Points whose reprojection error exceeds the tolerance are pushed
        # 10 logits down, i.e. they start almost transparent.
        new_opacities_raw = new_opacities_raw - mask_bad_points * (1e1)
        all_new_opacities_raw.append(new_opacities_raw)

        camera_center = camera_centers[source_idx].unsqueeze(0)  # (1, 3)
        dist_points_to_cam1 = torch.linalg.norm(camera_center - new_xyz, dim=1, ord=2)
        all_new_scaling.append((dist_points_to_cam1 * 0.001).unsqueeze(1).repeat(1, 3))

    all_new_xyz = torch.cat(all_new_xyz, dim=0)
    all_new_rgb = torch.cat(all_new_rgb, dim=0)
    all_new_scaling = torch.cat(all_new_scaling, dim=0)
    all_new_opacities_raw = torch.cat(all_new_opacities_raw, dim=0)
    all_new_opacities = torch.sigmoid(all_new_opacities_raw)

    points_dict = {
        "xyz": all_new_xyz,
        "rgb": all_new_rgb,
        "scales": all_new_scaling,
        "opacities": all_new_opacities,
    }

    return closest_indices_selected, visualizations, points_dict
F + + +def pairwise_distances(matrix): + """ + Computes the pairwise Euclidean distances between all vectors in the input matrix. + + Args: + matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality. + + Returns: + torch.Tensor: Pairwise distance matrix of shape [N, N]. + """ + # Compute squared pairwise distances + squared_diff = torch.cdist(matrix, matrix, p=2) + return squared_diff + + +def k_closest_vectors(matrix, k): + """ + Finds the k-closest vectors for each vector in the input matrix based on Euclidean distance. + + Args: + matrix (torch.Tensor): Input matrix of shape [N, D], where N is the number of vectors and D is the dimensionality. + k (int): Number of closest vectors to return for each vector. + + Returns: + torch.Tensor: Indices of the k-closest vectors for each vector, excluding the vector itself. + """ + # Compute pairwise distances + distances = pairwise_distances(matrix) + + # For each vector, sort distances and get the indices of the k-closest vectors (excluding itself) + # Set diagonal distances to infinity to exclude the vector itself from the nearest neighbors + distances.fill_diagonal_(float('inf')) + + # Get the indices of the k smallest distances (k-closest vectors) + _, indices = torch.topk(distances, k, largest=False, dim=1) + + return indices + + +def select_cameras_kmeans(cameras, K): + """ + Selects K cameras from a set using K-means clustering. + + Args: + cameras: NumPy array of shape (N, 16), representing N cameras with their 4x4 homogeneous matrices flattened. + K: Number of clusters (cameras to select). + + Returns: + selected_indices: List of indices of the cameras closest to the cluster centers. 
+ """ + # Ensure input is a NumPy array + if not isinstance(cameras, np.ndarray): + cameras = np.asarray(cameras) + + if cameras.shape[1] != 16: + raise ValueError("Each camera must have 16 values corresponding to a flattened 4x4 matrix.") + + # Perform K-means clustering + cluster_centers, _ = kmeans(cameras, K) + + # Assign each camera to a cluster and find distances to cluster centers + cluster_assignments, _ = vq(cameras, cluster_centers) + + # Find the camera nearest to each cluster center + selected_indices = [] + for k in range(K): + cluster_members = cameras[cluster_assignments == k] + distances = cdist([cluster_centers[k]], cluster_members)[0] + nearest_camera_idx = np.where(cluster_assignments == k)[0][np.argmin(distances)] + selected_indices.append(nearest_camera_idx) + + return selected_indices + + +def compute_warp_and_confidence( + # viewpoint_cam1, + # viewpoint_cam2, + viewpoint_img1: Tensor, # (3, H, W) + viewpoint_img2: Tensor, # (3, H, W) + roma_model, + device="cuda", + verbose=False, + output_dict={} +): + """ + Computes the warp and confidence between two viewpoint cameras using the roma_model. + + Args: + viewpoint_cam1: Source viewpoint camera. + viewpoint_cam2: Target viewpoint camera. + roma_model: Pre-trained Roma model for correspondence matching. + device: Device to run the computation on. + verbose: If True, displays the images. + + Returns: + certainty: Confidence tensor. + warp: Warp tensor. + imB: Processed image B as numpy array. 
+ """ + # Prepare images + # imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0) + # imB = viewpoint_cam2.original_image.detach().cpu().numpy().transpose(1, 2, 0) + imA = viewpoint_img1.detach().cpu().numpy().transpose(1, 2, 0) # [H, W, 3] + imB = viewpoint_img2.detach().cpu().numpy().transpose(1, 2, 0) # [H, W, 3] + imA = Image.fromarray(np.clip(imA * 255, 0, 255).astype(np.uint8)) + imB = Image.fromarray(np.clip(imB * 255, 0, 255).astype(np.uint8)) + + if verbose: + fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8)) + cax1 = ax[0].imshow(imA) + ax[0].set_title("Image 1") + cax2 = ax[1].imshow(imB) + ax[1].set_title("Image 2") + fig.colorbar(cax1, ax=ax[0]) + fig.colorbar(cax2, ax=ax[1]) + + for axis in ax: + axis.axis('off') + # Save the figure into the dictionary + output_dict[f'image_pair'] = fig + + # Transform images + ws, hs = roma_model.w_resized, roma_model.h_resized + + from romatch.utils import get_tuple_transform_ops + test_transform = get_tuple_transform_ops(resize=(hs, ws), normalize=True) + im_A, im_B = test_transform((imA, imB)) + batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)} + + # Forward pass through Roma model + corresps = roma_model.forward(batch) if not roma_model.symmetric else roma_model.forward_symmetric(batch) + finest_scale = 1 + hs, ws = roma_model.upsample_res if roma_model.upsample_preds else (hs, ws) + + # Process certainty and warp + certainty = corresps[finest_scale]["certainty"] + im_A_to_im_B = corresps[finest_scale]["flow"] + if roma_model.attenuate_cert: + low_res_certainty = F.interpolate( + corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear" + ) + certainty -= 0.5 * low_res_certainty * (low_res_certainty < 0) + + # Upsample predictions if needed + if roma_model.upsample_preds: + im_A_to_im_B = F.interpolate( + im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear" + ) + certainty = F.interpolate( + certainty, size=(hs, ws), 
align_corners=False, mode="bilinear" + ) + + # Convert predictions to final format + im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1) + im_A_coords = torch.stack(torch.meshgrid( + torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device), + torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device), + indexing='ij' + ), dim=0).permute(1, 2, 0).unsqueeze(0).expand(im_A_to_im_B.size(0), -1, -1, -1) + + warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1) + certainty = certainty.sigmoid() + + return certainty[0, 0], warp[0], np.array(imB) + + +def resize_batch(tensors_3d, tensors_4d, target_shape): + """ + Resizes a batch of tensors with shapes [B, H, W] and [B, H, W, 4] to the target spatial dimensions. + + Args: + tensors_3d: Tensor of shape [B, H, W]. + tensors_4d: Tensor of shape [B, H, W, 4]. + target_shape: Tuple (target_H, target_W) specifying the target spatial dimensions. + + Returns: + resized_tensors_3d: Tensor of shape [B, target_H, target_W]. + resized_tensors_4d: Tensor of shape [B, target_H, target_W, 4]. + """ + target_H, target_W = target_shape + + # Resize [B, H, W] tensor + resized_tensors_3d = F.interpolate( + tensors_3d.unsqueeze(1), size=(target_H, target_W), mode="bilinear", align_corners=False + ).squeeze(1) + + # Resize [B, H, W, 4] tensor + B, _, _, C = tensors_4d.shape + resized_tensors_4d = F.interpolate( + tensors_4d.permute(0, 3, 1, 2), size=(target_H, target_W), mode="bilinear", align_corners=False + ).permute(0, 2, 3, 1) + + return resized_tensors_3d, resized_tensors_4d + + +def aggregate_confidences_and_warps( + # viewpoint_stack, + viewpoints_img, + # viewpoints_c2w, + closest_indices, + roma_model, + source_idx, + verbose=False, + output_dict={} +): + """ + Aggregates confidences and warps by iterating over the nearest neighbors of the source viewpoint. + + Args: + viewpoint_stack: Stack of viewpoint cameras. + closest_indices: Indices of the nearest neighbors for each viewpoint. + roma_model: Pre-trained Roma model. 
+ source_idx: Index of the source viewpoint. + verbose: If True, displays intermediate results. + + Returns: + certainties_max: Aggregated maximum confidences. + warps_max: Aggregated warps corresponding to maximum confidences. + certainties_max_idcs: Pixel-wise index of the image from which we taken the best matching. + imB_compound: List of the neighboring images. + """ + certainties_all, warps_all, imB_compound = [], [], [] + + for nn in tqdm(closest_indices[source_idx]): + + # viewpoint_cam1 = viewpoint_stack[source_idx] + # viewpoint_cam2 = viewpoint_stack[nn] + viewpoint_img1 = viewpoints_img[source_idx] # (3, H, W) + viewpoint_img2 = viewpoints_img[nn] # (3, H, W) + + certainty, warp, imB = compute_warp_and_confidence( + # viewpoint_cam1, + # viewpoint_cam2, + viewpoint_img1, + viewpoint_img2, + roma_model, + verbose=verbose, + output_dict=output_dict + ) + certainties_all.append(certainty) + warps_all.append(warp) + imB_compound.append(imB) + + certainties_all = torch.stack(certainties_all, dim=0) + target_shape = imB_compound[0].shape[:2] + if verbose: + print("certainties_all.shape:", certainties_all.shape) + print("torch.stack(warps_all, dim=0).shape:", torch.stack(warps_all, dim=0).shape) + print("target_shape:", target_shape) + + certainties_all_resized, warps_all_resized = resize_batch(certainties_all, + torch.stack(warps_all, dim=0), + target_shape + ) + + if verbose: + print("warps_all_resized.shape:", warps_all_resized.shape) + for n, cert in enumerate(certainties_all): + fig, ax = plt.subplots() + cax = ax.imshow(cert.cpu().numpy(), cmap='viridis') + fig.colorbar(cax, ax=ax) + ax.set_title("Pixel-wise Confidence") + output_dict[f'certainty_{n}'] = fig + + for n, warp in enumerate(warps_all): + fig, ax = plt.subplots() + cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis') + fig.colorbar(cax, ax=ax) + ax.set_title("Pixel-wise warp") + output_dict[f'warp_resized_{n}'] = fig + + for n, cert in enumerate(certainties_all_resized): + fig, ax = 
plt.subplots() + cax = ax.imshow(cert.cpu().numpy(), cmap='viridis') + fig.colorbar(cax, ax=ax) + ax.set_title("Pixel-wise Confidence resized") + output_dict[f'certainty_resized_{n}'] = fig + + for n, warp in enumerate(warps_all_resized): + fig, ax = plt.subplots() + cax = ax.imshow(warp.cpu().numpy()[:, :, :3], cmap='viridis') + fig.colorbar(cax, ax=ax) + ax.set_title("Pixel-wise warp resized") + output_dict[f'warp_resized_{n}'] = fig + + certainties_max, certainties_max_idcs = torch.max(certainties_all_resized, dim=0) + H, W = certainties_max.shape + + warps_max = warps_all_resized[certainties_max_idcs, torch.arange(H).unsqueeze(1), torch.arange(W)] + + # imA = viewpoint_cam1.original_image.detach().cpu().numpy().transpose(1, 2, 0) + imA = viewpoint_img1.detach().cpu().numpy().transpose(1, 2, 0) # [H, W, 3] + imA = np.clip(imA * 255, 0, 255).astype(np.uint8) + + return certainties_max, warps_max, certainties_max_idcs, imA, imB_compound, certainties_all_resized, warps_all_resized + + + +def extract_keypoints_and_colors(imA, imB_compound, certainties_max, certainties_max_idcs, matches, roma_model, + verbose=False, output_dict={}): + """ + Extracts keypoints and corresponding colors from the source image (imA) and multiple target images (imB_compound). + + Args: + imA: Source image as a NumPy array (H_A, W_A, C). + imB_compound: List of target images as NumPy arrays [(H_B, W_B, C), ...]. + certainties_max: Tensor of pixel-wise maximum confidences. + certainties_max_idcs: Tensor of pixel-wise indices for the best matches. + matches: Matches in normalized coordinates. + roma_model: Roma model instance for keypoint operations. + verbose: if to show intermediate outputs and visualize results + + Returns: + kptsA_np: Keypoints in imA in normalized coordinates. + kptsB_np: Keypoints in imB in normalized coordinates. + kptsA_color: Colors of keypoints in imA. + kptsB_color: Colors of keypoints in imB based on certainties_max_idcs. 
+ """ + H_A, W_A, _ = imA.shape + H, W = certainties_max.shape + + # Convert matches to pixel coordinates + kptsA, kptsB = roma_model.to_pixel_coordinates( + matches, W_A, H_A, H, W # W, H + ) + + kptsA_np = kptsA.detach().cpu().numpy() + kptsB_np = kptsB.detach().cpu().numpy() + kptsA_np = kptsA_np[:, [1, 0]] + + if verbose: + fig, ax = plt.subplots(figsize=(12, 6)) + cax = ax.imshow(imA) + ax.set_title("Reference image, imA") + output_dict[f'reference_image'] = fig + + fig, ax = plt.subplots(figsize=(12, 6)) + cax = ax.imshow(imB_compound[0]) + ax.set_title("Image to compare to image, imB_compound") + output_dict[f'imB_compound'] = fig + + fig, ax = plt.subplots(figsize=(12, 6)) + cax = ax.imshow(np.flipud(imA)) + cax = ax.scatter(kptsA_np[:, 0], H_A - kptsA_np[:, 1], s=.03) + ax.set_title("Keypoints in imA") + ax.set_xlim(0, W_A) + ax.set_ylim(0, H_A) + output_dict[f'kptsA'] = fig + + fig, ax = plt.subplots(figsize=(12, 6)) + cax = ax.imshow(np.flipud(imB_compound[0])) + cax = ax.scatter(kptsB_np[:, 0], H_A - kptsB_np[:, 1], s=.03) + ax.set_title("Keypoints in imB") + ax.set_xlim(0, W_A) + ax.set_ylim(0, H_A) + output_dict[f'kptsB'] = fig + + # Keypoints are in format (row, column) so the first value is alwain in range [0;height] and second is in range[0;width] + + kptsA_np = kptsA.detach().cpu().numpy() + kptsB_np = kptsB.detach().cpu().numpy() + + # Extract colors for keypoints in imA (vectorized) + # New experimental version + kptsA_x = np.round(kptsA_np[:, 0] / 1.).astype(int) + kptsA_y = np.round(kptsA_np[:, 1] / 1.).astype(int) + kptsA_color = imA[np.clip(kptsA_x, 0, H - 1), np.clip(kptsA_y, 0, W - 1)] + + # Create a composite image from imB_compound + imB_compound_np = np.stack(imB_compound, axis=0) + H_B, W_B, _ = imB_compound[0].shape + + # Extract colors for keypoints in imB using certainties_max_idcs + imB_np = imB_compound_np[ + certainties_max_idcs.detach().cpu().numpy(), + np.arange(H).reshape(-1, 1), + np.arange(W) + ] + + if verbose: + 
def prepare_tensor(input_array, device):
    """
    Return an independent, detached float32 copy of ``input_array`` on ``device``.

    Args:
        input_array (array-like or torch.Tensor): The data to convert.
        device (str or torch.device): The device the result is placed on.
    Returns:
        torch.Tensor: A float32 tensor that shares neither storage nor autograd
        history with the input.
    """
    if isinstance(input_array, torch.Tensor):
        # copy=True forces fresh storage even when device and dtype already match,
        # so a single call replaces the former clone() + two .to() round trips.
        return input_array.detach().to(device=device, dtype=torch.float32, copy=True)
    # torch.tensor always copies its input; build it on the target device directly.
    return torch.tensor(input_array, dtype=torch.float32, device=device)


def triangulate_points(P1, P2, k1_x, k1_y, k2_x, k2_y, device="cuda"):
    """
    Solves for a batch of 3D points given projection matrices and the matching image points.

    The matrices are applied to row vectors (``X @ P``). For every match the four linear
    constraints ``X @ (P[:, 0] - k_x * P[:, 2]) = 0`` and ``X @ (P[:, 1] - k_y * P[:, 2]) = 0``
    (two per camera) are solved in the least-squares sense with the homogeneous
    coordinate fixed to 1.

    Parameters:
    - P1, P2: Projection matrices of size (batch_size, 4, 4) or (4, 4); unbatched
      matrices are broadcast over the batch.
    - k1_x, k1_y: Image coordinates in the first view, shape (batch_size,)
    - k2_x, k2_y: Image coordinates in the second view, shape (batch_size,)
    - device: Device on which the computation runs.

    Returns:
    - X: Tensor with the 3D homogeneous coordinates, shape (batch_size, 4)
    - errors_proj1: numpy array (batch_size,), L1 reprojection error in the first view
    - errors_proj2: numpy array (batch_size,), L1 reprojection error in the second view
    """
    EPS = 1e-4  # keeps the homogeneous divide in the reprojection check finite

    # Work on detached float32 copies so the caller's data is never touched.
    P1 = prepare_tensor(P1, device)
    P2 = prepare_tensor(P2, device)
    k1_x, k1_y, k2_x, k2_y = (
        prepare_tensor(k, device) for k in (k1_x, k1_y, k2_x, k2_y)
    )
    batch_size = k1_x.shape[0]

    # Expand P1 and P2 if they are not batched
    if P1.ndim == 2:
        P1 = P1.unsqueeze(0).expand(batch_size, -1, -1)
    if P2.ndim == 2:
        P2 = P2.unsqueeze(0).expand(batch_size, -1, -1)

    # Column vectors (batch_size, 1) broadcast against the (batch_size, 4) matrix columns.
    k1_x, k1_y, k2_x, k2_y = (k.view(-1, 1) for k in (k1_x, k1_y, k2_x, k2_y))

    # NOTE(review): the constraints divide by column 2 of P while the reprojection
    # check below divides by column 3 - confirm both are meant to be the same
    # quantity for the projection matrices used by the callers.
    A = torch.stack(
        [
            P1[:, :, 0] - k1_x * P1[:, :, 2],  # camera 1, x constraint
            P1[:, :, 1] - k1_y * P1[:, :, 2],  # camera 1, y constraint
            P2[:, :, 0] - k2_x * P2[:, :, 2],  # camera 2, x constraint
            P2[:, :, 1] - k2_y * P2[:, :, 2],  # camera 2, y constraint
        ],
        dim=1,
    )  # Shape: (batch_size, 4, 4)

    # With the homogeneous coordinate fixed to 1, the last column moves to the
    # right-hand side and the first three columns multiply (x, y, z).
    rhs = -A[:, :, 3:]  # Shape: (batch_size, 4, 1)
    X_xyz = torch.linalg.lstsq(A[:, :, :3], rhs).solution.squeeze(2)  # Shape: (batch_size, 3)

    # Append 1 to get homogeneous coordinates
    ones = torch.ones((batch_size, 1), dtype=torch.float32, device=X_xyz.device)
    X = torch.cat([X_xyz, ones], dim=1)  # Shape: (batch_size, 4)

    def _reprojection_error(P, target):
        # Project, divide by the last homogeneous component, compare in L1.
        projected = (X.unsqueeze(1) @ P).squeeze(1)
        projected = projected / (EPS + projected[:, [3]])
        return torch.abs(projected[:, :2] - target).sum(1).detach().cpu().numpy()

    errors_proj1 = _reprojection_error(P1, torch.cat([k1_x, k1_y], dim=1))
    errors_proj2 = _reprojection_error(P2, torch.cat([k2_x, k2_y], dim=1))

    return X, errors_proj1, errors_proj2


def select_best_keypoints(
        NNs_triangulated_points, NNs_errors_proj1, NNs_errors_proj2, device="cuda"):
    """
    For every keypoint, keep the candidate (one per neighbor view) with the smallest
    worst-case reprojection error.

    Args:
        NNs_triangulated_points: torch tensor with keypoints coordinates (num_nns, num_points, dim).
            dim can be arbitrary, usually 3 or 4 (for homogeneous representation).
        NNs_errors_proj1: numpy array with projection error of the estimated keypoint on the
            reference frame (num_nns, num_points).
        NNs_errors_proj2: numpy array with projection error of the estimated keypoint on the
            neighbor frame (num_nns, num_points).
        device: Kept for backward compatibility. The index tensors are now created on the
            device of ``NNs_triangulated_points`` so that CPU inputs work with the default
            arguments; the result stays on that same device, as before.
    Returns:
        selected_keypoints: tensor (num_points, dim) with the best-scoring candidate per keypoint.
        selected_errors: numpy array (num_points,) with the corresponding error.
    """
    # A candidate is only as good as its worse projection.
    NNs_errors_proj = np.maximum(NNs_errors_proj1, NNs_errors_proj2)
    best_nn = np.argmin(NNs_errors_proj, axis=0)

    # Index tensors must live where the indexed tensor lives.
    target_device = NNs_triangulated_points.device
    nn_indices = torch.from_numpy(best_nn).long().to(target_device)
    point_indices = torch.arange(NNs_triangulated_points.shape[1], device=target_device).long()

    # Advanced indexing picks one candidate per keypoint. Shape: [num_points, dim]
    NNs_triangulated_points_selected = NNs_triangulated_points[nn_indices, point_indices, :]

    return NNs_triangulated_points_selected, np.min(NNs_errors_proj, axis=0)
points_dict["opacities"].clone().to(device) # [N,] + opacities_raw = torch.logit(opacities) + + return { + "xyz": xyz, + "sh0": sh0, # [N, 1, 3] + "shN": shN, # [N, sh_d-1, 3] or None + "scales_raw": scales_raw, + "rotations_unnorm": quats_unnorm, + "opacities_raw": opacities_raw, + } diff --git a/optgs/geometry/__init__.py b/optgs/geometry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/geometry/epipolar_lines.py b/optgs/geometry/epipolar_lines.py new file mode 100644 index 0000000000000000000000000000000000000000..d4ea46bfbdfefb77851e747dd16d1eade1970d2c --- /dev/null +++ b/optgs/geometry/epipolar_lines.py @@ -0,0 +1,292 @@ +import itertools +from typing import Iterable, Literal, Optional, TypedDict + +import torch +from einops import einsum, repeat +from jaxtyping import Bool, Float +from torch import Tensor +from torch.utils.data.dataloader import default_collate + +from .projection import ( + get_world_rays, + homogenize_points, + homogenize_vectors, + intersect_rays, + project_camera_space, +) + + +def _is_in_bounds( + xy: Float[Tensor, "*batch 2"], + epsilon: float = 1e-6, +) -> Bool[Tensor, " *batch"]: + """Check whether the specified XY coordinates are within the normalized image plane, + which has a range from 0 to 1 in each direction. 
+ """ + return (xy >= -epsilon).all(dim=-1) & (xy <= 1 + epsilon).all(dim=-1) + + +def _is_in_front_of_camera( + xyz: Float[Tensor, "*batch 3"], + epsilon: float = 1e-6, +) -> Bool[Tensor, " *batch"]: + """Check whether the specified points in camera space are in front of the camera.""" + return xyz[..., -1] > -epsilon + + +def _is_positive_t( + t: Float[Tensor, " *batch"], + epsilon: float = 1e-6, +) -> Bool[Tensor, " *batch"]: + """Check whether the specified t value is positive.""" + return t > -epsilon + + +class PointProjection(TypedDict): + t: Float[Tensor, " *batch"] # ray parameter, as in xyz = origin + t * direction + xy: Float[Tensor, "*batch 2"] # image-space xy (normalized to 0 to 1) + + # A "valid" projection satisfies two conditions: + # 1. It is in front of the camera (i.e., its 3D Z coordinate is positive). + # 2. It is within the image frame (i.e., its 2D coordinates are between 0 and 1). + valid: Bool[Tensor, " *batch"] + + +def _intersect_image_coordinate( + intrinsics: Float[Tensor, "*#batch 3 3"], + origins: Float[Tensor, "*#batch 3"], + directions: Float[Tensor, "*#batch 3"], + dimension: Literal["x", "y"], + coordinate_value: float, +) -> PointProjection: + """Compute the intersection of the projection of a camera-space ray with a line + that's parallel to the image frame, either horizontally or vertically. + """ + + # Define shorthands. 
+ dim = "xy".index(dimension) + other_dim = 1 - dim + fs = intrinsics[..., dim, dim] # focal length, same coordinate + fo = intrinsics[..., other_dim, other_dim] # focal length, other coordinate + cs = intrinsics[..., dim, 2] # principal point, same coordinate + co = intrinsics[..., other_dim, 2] # principal point, other coordinate + os = origins[..., dim] # ray origin, same coordinate + oo = origins[..., other_dim] # ray origin, other coordinate + ds = directions[..., dim] # ray direction, same coordinate + do = directions[..., other_dim] # ray direction, other coordinate + oz = origins[..., 2] # ray origin, z coordinate + dz = directions[..., 2] # ray direction, z coordinate + c = (coordinate_value - cs) / fs # coefficient (computed once and factored out) + + # Compute the value of t at the intersection. + # Note: Infinite values of t are fine. No need to handle division by zero. + t_numerator = c * oz - os + t_denominator = ds - c * dz + t = t_numerator / t_denominator + + # Compute the value of the other coordinate at the intersection. + # Note: Infinite coordinate values are fine. No need to handle division by zero. + coordinate_numerator = fo * (oo * (c * dz - ds) + do * (os - c * oz)) + coordinate_denominator = dz * os - ds * oz + coordinate_other = co + coordinate_numerator / coordinate_denominator + coordinate_same = torch.ones_like(coordinate_other) * coordinate_value + xy = [coordinate_same] + xy.insert(other_dim, coordinate_other) + xy = torch.stack(xy, dim=-1) + xyz = origins + t[..., None] * directions + + # These will all have exactly the same batch shape (no broadcasting necessary). In + # terms of jaxtyping annotations, they all match *batch, not just *#batch. 
+ return { + "t": t, + "xy": xy, + "valid": _is_in_bounds(xy) & _is_in_front_of_camera(xyz) & _is_positive_t(t), + } + + +def _compare_projections( + intersections: Iterable[PointProjection], + reduction: Literal["min", "max"], +) -> PointProjection: + intersections = {k: v.clone() for k, v in default_collate(intersections).items()} + t = intersections["t"] + xy = intersections["xy"] + valid = intersections["valid"] + + # Make sure out-of-bounds values are not chosen. + lowest_priority = { + "min": torch.inf, + "max": -torch.inf, + }[reduction] + t[~valid] = lowest_priority + + # Run the reduction (either t.min() or t.max()). + reduced, selector = getattr(t, reduction)(dim=0) + + # Index the results. + return { + "t": reduced, + "xy": xy.gather(0, repeat(selector, "... -> () ... xy", xy=2))[0], + "valid": valid.gather(0, selector[None])[0], + } + + +def _compute_point_projection( + xyz: Float[Tensor, "*#batch 3"], + t: Float[Tensor, "*#batch"], + intrinsics: Float[Tensor, "*#batch 3 3"], +) -> PointProjection: + xy = project_camera_space(xyz, intrinsics) + return { + "t": t, + "xy": xy, + "valid": _is_in_bounds(xy) & _is_in_front_of_camera(xyz) & _is_positive_t(t), + } + + +class RaySegmentProjection(TypedDict): + t_min: Float[Tensor, " *batch"] # ray parameter + t_max: Float[Tensor, " *batch"] # ray parameter + xy_min: Float[Tensor, "*batch 2"] # image-space xy (normalized to 0 to 1) + xy_max: Float[Tensor, "*batch 2"] # image-space xy (normalized to 0 to 1) + + # Whether the segment overlaps the image. If not, the above values are meaningless. 
+ overlaps_image: Bool[Tensor, " *batch"] + + +def project_rays( + origins: Float[Tensor, "*#batch 3"], + directions: Float[Tensor, "*#batch 3"], + extrinsics: Float[Tensor, "*#batch 4 4"], + intrinsics: Float[Tensor, "*#batch 3 3"], + near: Optional[Float[Tensor, "*#batch"]] = None, + far: Optional[Float[Tensor, "*#batch"]] = None, + epsilon: float = 1e-6, +) -> RaySegmentProjection: + # Transform the rays into camera space. + world_to_cam = torch.linalg.inv(extrinsics) + origins = homogenize_points(origins) + origins = einsum(world_to_cam, origins, "... i j, ... j -> ... i") + directions = homogenize_vectors(directions) + directions = einsum(world_to_cam, directions, "... i j, ... j -> ... i") + origins = origins[..., :3] + directions = directions[..., :3] + + # Compute intersections with the image's frame. + frame_intersections = ( + _intersect_image_coordinate(intrinsics, origins, directions, "x", 0.0), + _intersect_image_coordinate(intrinsics, origins, directions, "x", 1.0), + _intersect_image_coordinate(intrinsics, origins, directions, "y", 0.0), + _intersect_image_coordinate(intrinsics, origins, directions, "y", 1.0), + ) + frame_intersection_min = _compare_projections(frame_intersections, "min") + frame_intersection_max = _compare_projections(frame_intersections, "max") + + if near is None: + # Compute the ray's projection at zero depth. If an origin's depth (z value) is + # within epsilon of zero, this can mean one of two things: + # 1. The origin is at the camera's position. In this case, use the direction + # instead (the ray is probably coming from the camera). + # 2. The origin isn't at the camera's position, and randomly happens to be on + # the plane at zero depth. In this case, its projection is outside the image + # plane, and is thus marked as invalid. 
+ origins_for_projection = origins.clone() + mask_depth_zero = origins_for_projection[..., -1] < epsilon + mask_at_camera = origins_for_projection.norm(dim=-1) < epsilon + origins_for_projection[mask_at_camera] = directions[mask_at_camera] + projection_at_zero = _compute_point_projection( + origins_for_projection, + torch.zeros_like(frame_intersection_min["t"]), + intrinsics, + ) + projection_at_zero["valid"][mask_depth_zero & ~mask_at_camera] = False + else: + # If a near plane is specified, use it instead. + t_near = near.broadcast_to(frame_intersection_min["t"].shape) + projection_at_zero = _compute_point_projection( + origins + near[..., None] * directions, + t_near, + intrinsics, + ) + + if far is None: + # Compute the ray's projection at infinite depth. Using the projection function + # with directions (vectors) instead of points may seem wonky, but is equivalent + # to projecting the point at (origins + infinity * directions). + projection_at_infinity = _compute_point_projection( + directions, + torch.ones_like(frame_intersection_min["t"]) * torch.inf, + intrinsics, + ) + else: + # If a far plane is specified, use it instead. + t_far = far.broadcast_to(frame_intersection_min["t"].shape) + projection_at_infinity = _compute_point_projection( + origins + far[..., None] * directions, + t_far, + intrinsics, + ) + + # Build the result by handling cases for ray intersection. 
+ result = { + "t_min": torch.empty_like(projection_at_zero["t"]), + "t_max": torch.empty_like(projection_at_infinity["t"]), + "xy_min": torch.empty_like(projection_at_zero["xy"]), + "xy_max": torch.empty_like(projection_at_infinity["xy"]), + "overlaps_image": torch.empty_like(projection_at_zero["valid"]), + } + + for min_valid, max_valid in itertools.product([True, False], [True, False]): + min_mask = projection_at_zero["valid"] ^ (not min_valid) + max_mask = projection_at_infinity["valid"] ^ (not max_valid) + mask = min_mask & max_mask + min_value = projection_at_zero if min_valid else frame_intersection_min + max_value = projection_at_infinity if max_valid else frame_intersection_max + result["t_min"][mask] = min_value["t"][mask] + result["t_max"][mask] = max_value["t"][mask] + result["xy_min"][mask] = min_value["xy"][mask] + result["xy_max"][mask] = max_value["xy"][mask] + result["overlaps_image"][mask] = (min_value["valid"] & max_value["valid"])[mask] + + return result + + +class RaySegmentProjection(TypedDict): + t_min: Float[Tensor, " *batch"] # ray parameter + t_max: Float[Tensor, " *batch"] # ray parameter + xy_min: Float[Tensor, "*batch 2"] # image-space xy (normalized to 0 to 1) + xy_max: Float[Tensor, "*batch 2"] # image-space xy (normalized to 0 to 1) + + # Whether the segment overlaps the image. If not, the above values are meaningless. + overlaps_image: Bool[Tensor, " *batch"] + + +def lift_to_3d( + origins: Float[Tensor, "*#batch 3"], + directions: Float[Tensor, "*#batch 3"], + xy: Float[Tensor, "*#batch 2"], + extrinsics: Float[Tensor, "*#batch 4 4"], + intrinsics: Float[Tensor, "*#batch 3 3"], +) -> Float[Tensor, "*batch 3"]: + """Calculate the 3D positions that correspond to the specified 2D points on the + epipolar lines defined by the origins and directions. The extrinsics and intrinsics + are for the images the 2D points lie on. 
+ """ + + xy_origins, xy_directions = get_world_rays(xy, extrinsics, intrinsics) + return intersect_rays(origins, directions, xy_origins, xy_directions) + + +def get_depth( + origins: Float[Tensor, "*#batch 3"], + directions: Float[Tensor, "*#batch 3"], + xy: Float[Tensor, "*#batch 2"], + extrinsics: Float[Tensor, "*#batch 4 4"], + intrinsics: Float[Tensor, "*#batch 3 3"], +) -> Float[Tensor, " *batch"]: + """Calculate the depths that correspond to the specified 2D points on the epipolar + lines defined by the origins and directions. The extrinsics and intrinsics are for + the images the 2D points lie on. + """ + xyz = lift_to_3d(origins, directions, xy, extrinsics, intrinsics) + return (xyz - origins).norm(dim=-1) diff --git a/optgs/geometry/projection.py b/optgs/geometry/projection.py new file mode 100644 index 0000000000000000000000000000000000000000..3228991cda0c4f9d8ca6fe69370c1ff2109d95b1 --- /dev/null +++ b/optgs/geometry/projection.py @@ -0,0 +1,278 @@ +from math import prod + +import torch +from einops import einsum, rearrange, reduce, repeat +from jaxtyping import Bool, Float, Int64 +from torch import Tensor + + +def homogenize_points( + points: Float[Tensor, "*batch dim"], +) -> Float[Tensor, "*batch dim+1"]: + """Convert batched points (xyz) to (xyz1).""" + return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + + +def homogenize_vectors( + vectors: Float[Tensor, "*batch dim"], +) -> Float[Tensor, "*batch dim+1"]: + """Convert batched vectors (xyz) to (xyz0).""" + return torch.cat([vectors, torch.zeros_like(vectors[..., :1])], dim=-1) + + +def transform_rigid( + homogeneous_coordinates: Float[Tensor, "*#batch dim"], + transformation: Float[Tensor, "*#batch dim dim"], +) -> Float[Tensor, "*batch dim"]: + """Apply a rigid-body transformation to points or vectors.""" + return einsum(transformation, homogeneous_coordinates, "... i j, ... j -> ... 
i") + + +def transform_cam2world( + homogeneous_coordinates: Float[Tensor, "*#batch dim"], + extrinsics: Float[Tensor, "*#batch dim dim"], +) -> Float[Tensor, "*batch dim"]: + """Transform points from 3D camera coordinates to 3D world coordinates.""" + return transform_rigid(homogeneous_coordinates, extrinsics) + + +def transform_world2cam( + homogeneous_coordinates: Float[Tensor, "*#batch dim"], + extrinsics: Float[Tensor, "*#batch dim dim"], +) -> Float[Tensor, "*batch dim"]: + """Transform points from 3D world coordinates to 3D camera coordinates.""" + return transform_rigid(homogeneous_coordinates, extrinsics.inverse()) + + +def project_camera_space( + points: Float[Tensor, "*#batch dim"], + intrinsics: Float[Tensor, "*#batch dim dim"], + epsilon: float = torch.finfo(torch.float32).eps, + infinity: float = 1e8, +) -> Float[Tensor, "*batch dim-1"]: + points = points / (points[..., -1:] + epsilon) + points = points.nan_to_num(posinf=infinity, neginf=-infinity) + points = einsum(intrinsics, points, "... i j, ... j -> ... i") + return points[..., :-1] + + +def project( + points: Float[Tensor, "*#batch dim"], + extrinsics: Float[Tensor, "*#batch dim+1 dim+1"], + intrinsics: Float[Tensor, "*#batch dim dim"], + epsilon: float = torch.finfo(torch.float32).eps, +) -> tuple[ + Float[Tensor, "*batch dim-1"], # xy coordinates + Bool[Tensor, " *batch"], # whether points are in front of the camera +]: + points = homogenize_points(points) + points = transform_world2cam(points, extrinsics)[..., :-1] + in_front_of_camera = points[..., -1] >= 0 + return project_camera_space(points, intrinsics, epsilon=epsilon), in_front_of_camera + + +def unproject( + coordinates: Float[Tensor, "*#batch dim"], + z: Float[Tensor, "*#batch"], + intrinsics: Float[Tensor, "*#batch dim+1 dim+1"], +) -> Float[Tensor, "*batch dim+1"]: + """Unproject 2D camera coordinates with the given Z values.""" + + # Apply the inverse intrinsics to the coordinates. 
def get_world_rays(
    coordinates: Float[Tensor, "*#batch dim"],
    extrinsics: Float[Tensor, "*#batch dim+2 dim+2"],
    intrinsics: Float[Tensor, "*#batch dim+1 dim+1"],
) -> tuple[
    Float[Tensor, "*batch dim+1"],  # origins
    Float[Tensor, "*batch dim+1"],  # directions
]:
    """Build world-space rays (origins, directions) through the given image coordinates.

    Directions are scaled so that their last camera-space component is 1 before
    being rotated into world space; origins are the translation column of the
    extrinsics, broadcast to the shape of the directions.
    """
    # Unproject at unit depth, then rescale so the last component is exactly 1.
    unit_depth = torch.ones_like(coordinates[..., 0])
    cam_dirs = unproject(coordinates, unit_depth, intrinsics)
    cam_dirs = cam_dirs / cam_dirs[..., -1:]

    # Homogenized as vectors so only the rotation part of the extrinsics applies.
    world_dirs = transform_cam2world(homogenize_vectors(cam_dirs), extrinsics)[..., :-1]

    ray_origins = extrinsics[..., :-1, -1].broadcast_to(world_dirs.shape)
    return ray_origins, world_dirs


def sample_image_grid(
    shape: tuple[int, ...],
    device: torch.device = torch.device("cpu"),
) -> tuple[
    Float[Tensor, "*shape dim"],  # float coordinates (xy indexing)
    Int64[Tensor, "*shape dim"],  # integer indices (ij indexing)
]:
    """Get normalized (range 0 to 1) coordinates and integer indices for an image."""

    # One integer axis per image dimension; in 2D the stacked grid holds (row, col).
    axis_indices = [torch.arange(extent, device=device) for extent in shape]
    index_grid = torch.stack(torch.meshgrid(*axis_indices, indexing="ij"), dim=-1)

    # Pixel centers in (0, 1). The axes are reversed so that, in 2D, each entry
    # of the stacked grid is an (x, y) coordinate.
    centers = [(axis + 0.5) / extent for axis, extent in zip(axis_indices, shape)]
    coordinate_grid = torch.stack(torch.meshgrid(*centers[::-1], indexing="xy"), dim=-1)

    return coordinate_grid, index_grid


def sample_training_rays(
    image: Float[Tensor, "batch view channel ..."],
    intrinsics: Float[Tensor, "batch view dim dim"],
    extrinsics: Float[Tensor, "batch view dim+1 dim+1"],
    num_rays: int,
) -> tuple[
    Float[Tensor, "batch ray dim"],  # origins
    Float[Tensor, "batch ray dim"],  # directions
    Float[Tensor, "batch ray 3"],  # sampled color
]:
    """Draw ``num_rays`` random rays per batch element, with replacement, from all
    views and pixels, and return their origins, directions and pixel colors.
    """
    device = extrinsics.device
    b, v, _, *grid_shape = image.shape

    def flatten_rays(rays):
        # Fold the view axis and every pixel axis into a single ray axis.
        return rearrange(rays, "... b v xy -> b (v ...) xy", b=b, v=v)

    # One candidate ray per pixel of every view.
    xy, _ = sample_image_grid(tuple(grid_shape), device)
    all_origins, all_directions = get_world_rays(
        rearrange(xy, "... d -> ... () () d"),
        extrinsics,
        intrinsics,
    )
    all_origins = flatten_rays(all_origins)
    all_directions = flatten_rays(all_directions)
    all_pixels = rearrange(image, "b v c ... -> b (v ...) c")

    # Uniform random choice among the candidates, independently per batch element.
    ray_indices = torch.randint(v * prod(grid_shape), (b, num_rays), device=device)
    batch_indices = torch.arange(b, device=device)[:, None].expand(b, num_rays)

    return (
        all_origins[batch_indices, ray_indices],
        all_directions[batch_indices, ray_indices],
        all_pixels[batch_indices, ray_indices],
    )
+ shape = torch.broadcast_shapes( + origins_x.shape, + directions_x.shape, + origins_y.shape, + directions_y.shape, + ) + origins_x = origins_x.broadcast_to(shape) + directions_x = directions_x.broadcast_to(shape) + origins_y = origins_y.broadcast_to(shape) + directions_y = directions_y.broadcast_to(shape) + + # Detect and remove batch elements where the directions are parallel. + parallel = einsum(directions_x, directions_y, "... xyz, ... xyz -> ...") > 1 - eps + origins_x = origins_x[~parallel] + directions_x = directions_x[~parallel] + origins_y = origins_y[~parallel] + directions_y = directions_y[~parallel] + + # Stack the rays into (2, *shape). + origins = torch.stack([origins_x, origins_y], dim=0) + directions = torch.stack([directions_x, directions_y], dim=0) + dtype = origins.dtype + device = origins.device + + # Compute n_i * n_i^T - eye(3) from the equation. + n = einsum(directions, directions, "r b i, r b j -> r b i j") + n = n - torch.eye(3, dtype=dtype, device=device).broadcast_to((2, 1, 3, 3)) + + # Compute the left-hand side of the equation. + lhs = reduce(n, "r b i j -> b i j", "sum") + + # Compute the right-hand side of the equation. + rhs = einsum(n, origins, "r b i j, r b j -> r b i") + rhs = reduce(rhs, "r b i -> b i", "sum") + + # Left-matrix-multiply both sides by the pseudo-inverse of lhs to find p. + result = torch.linalg.lstsq(lhs, rhs).solution + + # Handle the case of parallel lines by setting depth to infinity. 
def get_fov(intrinsics: Float[Tensor, "batch 3 3"]) -> Float[Tensor, "batch 2"]:
    """Compute the horizontal and vertical field of view, in radians, per batch element.

    The mid-points of the four image borders, written as homogeneous pixels
    (0, 0.5, 1), (1, 0.5, 1), (0.5, 0, 1) and (0.5, 1, 1), are mapped through the
    inverse intrinsics and normalized; the field of view is the angle between
    opposite border rays. The result stacks (fov_x, fov_y) on the last axis.
    """
    intrinsics_inv = intrinsics.inverse()

    def border_ray(pixel):
        # Build the probe in the intrinsics' own dtype: the einsum needs matching
        # dtypes, so a hard-coded float32 probe fails for float64/half intrinsics.
        pixel = torch.tensor(pixel, dtype=intrinsics.dtype, device=intrinsics.device)
        ray = einsum(intrinsics_inv, pixel, "b i j, j -> b i")
        return ray / ray.norm(dim=-1, keepdim=True)

    left = border_ray([0, 0.5, 1])
    right = border_ray([1, 0.5, 1])
    top = border_ray([0.5, 0, 1])
    bottom = border_ray([0.5, 1, 1])

    # The dot product of two unit vectors is the cosine of the angle between them.
    fov_x = (left * right).sum(dim=-1).acos()
    fov_y = (top * bottom).sum(dim=-1).acos()
    return torch.stack((fov_x, fov_y), dim=-1)
from typing import Optional

from omegaconf import DictConfig

# Process-wide configuration. Stays None until set_cfg() has been called.
cfg: Optional[DictConfig] = None


def get_cfg() -> DictConfig:
    """Return the global configuration (None if set_cfg() has not been called yet)."""
    # Reading a module-level name needs no `global` declaration.
    return cfg


def set_cfg(new_cfg: DictConfig) -> None:
    """Install ``new_cfg`` as the global configuration."""
    global cfg
    cfg = new_cfg


def get_seed() -> int:
    """Return ``seed`` from the global configuration.

    Raises:
        AttributeError: if the configuration has not been set. The exception type
            matches what the bare attribute access on None used to raise, but the
            message now says what to do about it.
    """
    if cfg is None:
        raise AttributeError(
            "global config is not set; call set_cfg() before get_seed()"
        )
    return cfg.seed
LossDeltas, + LossSsimCfgWrapper: LossSsim, + LossSh0CfgWrapper: LossSh0, + LossIsoScalesCfgWrapper: LossIsoScales, + LossSGDCfgWrapper: LossSGD, + LossGaussiansCfgWrapper: LossGaussians, + LossStabilityCfgWrapper: LossStability, +} + +LossCfgWrapper = ( + LossLpipsCfgWrapper | + LossMseCfgWrapper | + LossDeltasCfgWrapper | + LossSsimCfgWrapper | + LossSh0CfgWrapper | + LossIsoScalesCfgWrapper | + LossSGDCfgWrapper | + LossGaussiansCfgWrapper | + LossStabilityCfgWrapper +) + + +def get_losses(cfgs: list[LossCfgWrapper]) -> list[Loss]: + return [LOSSES[type(cfg)](cfg) for cfg in cfgs] diff --git a/optgs/loss/loss.py b/optgs/loss/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..aa1a73890d0faac3b891456a60e45c8a4d4b655d --- /dev/null +++ b/optgs/loss/loss.py @@ -0,0 +1,54 @@ +from abc import ABC, abstractmethod +from dataclasses import fields +from typing import Generic, TypeVar + +from jaxtyping import Float +from torch import Tensor, nn + +from ..misc.batchify import batched_select +from ..model.decoder.decoder import DecoderOutput +from ..model.types import Gaussians +from optgs.scene_trainer.gaussian_module import GaussiansModule + +T_cfg = TypeVar("T_cfg") +T_wrapper = TypeVar("T_wrapper") + + +class Loss(nn.Module, ABC, Generic[T_cfg, T_wrapper]): + cfg: T_cfg + name: str + + def __init__(self, cfg: T_wrapper) -> None: + super().__init__() + + # Extract the configuration from the wrapper. 
+ (field,) = fields(type(cfg)) + self.cfg = getattr(cfg, field.name) + self.name = field.name + + @abstractmethod + def forward( + self, + prediction: DecoderOutput, + gaussians: Gaussians | GaussiansModule | None, + global_step: int, + gt_rgb: Tensor, + pred_rgb: Tensor, + valid_depth_mask: Tensor | None, + **kwargs, + ) -> Float[Tensor, ""]: + pass + + @staticmethod + def extract_pred_gt(curr_gt_rgb, prediction, error_idx, valid_depth_mask): + # curr_gt_rgb is already subsampled to the rendered views (opt_batch_size subset); + # error_idx further subsamples both gt and pred to the views used for the loss. + pred_rgb = prediction.color # [B, V_rendered, C, H, W] + gt_rgb = curr_gt_rgb + if error_idx is not None: + gt_rgb = batched_select(gt_rgb, error_idx) + pred_rgb = batched_select(pred_rgb, error_idx) + if valid_depth_mask is not None: + valid_depth_mask = batched_select(valid_depth_mask, error_idx) + + return gt_rgb, pred_rgb, valid_depth_mask diff --git a/optgs/loss/loss_deltas.py b/optgs/loss/loss_deltas.py new file mode 100644 index 0000000000000000000000000000000000000000..e9f06b0266e6bfeebb603f0b80d201df9f9d49c5 --- /dev/null +++ b/optgs/loss/loss_deltas.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass + +import torch +from jaxtyping import Float +from torch import Tensor + +from optgs.loss import Loss +from optgs.model.decoder.decoder import DecoderOutput +from optgs.model.types import Gaussians +from optgs.scene_trainer.gaussian_module import GaussiansModule + + +@dataclass +class LossDeltasCfg: + weight: float | int + exclude_by_norm_grad: bool + exclude_by_norm_grad_opposite: bool + eps: float + apply_after_step: int + +@dataclass +class LossDeltasCfgWrapper: + deltas: LossDeltasCfg + + +class LossDeltas(Loss[LossDeltasCfg, LossDeltasCfgWrapper]): + def forward( + self, + prediction: DecoderOutput, + gaussians: Gaussians | GaussiansModule | None, + global_step: int, + gt_rgb: Tensor, + pred_rgb: Tensor, + valid_depth_mask: Tensor | None, + 
l1_loss: bool, + clamp_large_error: float, + **kwargs, + ) -> Float[Tensor, ""]: + + cfg = self.cfg + # Before the specified step, don't apply the loss. + if global_step < cfg.apply_after_step: + return torch.tensor(0, dtype=torch.float32, device=prediction.color.device) + + if gaussians is None: + raise ValueError("Gaussians must be provided for LossDeltas.") + + predicted_deltas = gaussians.deltas + + if not cfg.exclude_by_norm_grad: + return predicted_deltas.abs().mean() * cfg.weight + + norm_g = gaussians.norm_gradients + if norm_g is None: + return predicted_deltas.abs().mean() * cfg.weight + + g = gaussians.gradients + eps = cfg.eps + g_abs = g.abs() + + # Condition 1: small gradients + cond_small = g_abs < eps + mask = cond_small + + # Condition 2: large gradients but opposite sign + # deltas are added (sgd substract), so in practice we want to exclude when they have the same sign + if cfg.exclude_by_norm_grad_opposite: + cond_opposite = (g_abs > self.cfg.eps) & (norm_g.sign() == predicted_deltas.sign()) + # Combine both + mask = cond_small | cond_opposite + + if not mask.any(): + return prediction.color.new_zeros((), dtype=torch.float32) + + # predicted_deltas[mask] creates a new tensor + # return predicted_deltas[mask].abs().mean() * cfg.weight + # alternative without indexing + mask_f = mask.to(predicted_deltas.dtype) + + loss = (predicted_deltas.abs() * mask_f).sum() / mask_f.sum() + return loss * cfg.weight diff --git a/optgs/loss/loss_depth_smooth.py b/optgs/loss/loss_depth_smooth.py new file mode 100644 index 0000000000000000000000000000000000000000..b81949c145ebe73337ca7231f2f2cb9b97ba6ed7 --- /dev/null +++ b/optgs/loss/loss_depth_smooth.py @@ -0,0 +1,29 @@ +import torch + + +def get_smooth_loss(disp, img, no_mean=False): + """Computes the smoothness loss for a disparity image + The color image is used for edge-aware smoothness + ref: https://github.com/nianticlabs/monodepth2/blob/master/layers.py#L202 + """ + if no_mean: + out = 
torch.zeros_like(disp) + + grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:]) + grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :]) + + grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True) + grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True) + + grad_disp_x *= torch.exp(-grad_img_x) + grad_disp_y *= torch.exp(-grad_img_y) + + if no_mean: + out[:, :, :, :-1] = out[:, :, :, :-1] + grad_disp_x + out[:, :, :-1, :] = out[:, :, :-1, :] + grad_disp_y + + return out + + return grad_disp_x.mean() + grad_disp_y.mean() + + \ No newline at end of file diff --git a/optgs/loss/loss_gaussians.py b/optgs/loss/loss_gaussians.py new file mode 100644 index 0000000000000000000000000000000000000000..4290c560e1ca2c659d0fa0066a54227dc120dff1 --- /dev/null +++ b/optgs/loss/loss_gaussians.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from math import isqrt + +import torch +from jaxtyping import Float +from torch import Tensor + +from optgs.loss import Loss +from optgs.model.decoder.decoder import DecoderOutput +from optgs.model.types import Gaussians +from optgs.scene_trainer.gaussian_module import GaussiansModule + + +@dataclass +class LossGaussiansCfg: + weight: float | int + weight_scales: float + weight_opacities: float + weight_sh: float + sh_alpha: float + + +@dataclass +class LossGaussiansCfgWrapper: + gaussians: LossGaussiansCfg + + +class LossGaussians(Loss[LossGaussiansCfg, LossGaussiansCfgWrapper]): + """L2 regularization on Gaussian scales, opacities, and SH coefficients. + + Each component has an independent weight so they can be tuned separately. + SH degree 0 (DC / base color) is always excluded. 
+ + sh_alpha controls per-degree weighting for the SH term: + alpha=1.0 (default): uniform across all degrees >= 1 + alpha>1.0: exponentially increasing penalty on higher degrees + (degree d gets alpha^d weighting) + """ + + def forward( + self, + prediction: DecoderOutput, + gaussians: Gaussians | GaussiansModule | None, + global_step: int, + gt_rgb: Tensor, + pred_rgb: Tensor, + valid_depth_mask: Tensor | None, + **kwargs, + ) -> Float[Tensor, ""]: + if gaussians is None: + raise ValueError("Gaussians must be provided for LossGaussians.") + + loss = 0 + nr_valid = gaussians.nr_valid + + if self.cfg.weight_scales > 0: + loss = loss + self.cfg.weight_scales * (gaussians.scales[:, :nr_valid] ** 2).mean() # [B, G, 3] + + if self.cfg.weight_opacities > 0: + loss = loss + self.cfg.weight_opacities * (gaussians.opacities[:, :nr_valid] ** 2).mean() # [B, G] + + if self.cfg.weight_sh > 0: + shs = gaussians.harmonics # [B, G, 3, d_sh] + d_sh = shs.shape[-1] + if d_sh > 1: + alpha = self.cfg.sh_alpha + if alpha == 1.0: + shN = shs[:, :nr_valid, :, 1:] + loss = loss + self.cfg.weight_sh * (shN ** 2).mean() + else: + max_degree = isqrt(d_sh) - 1 + sh_loss = torch.tensor(0.0, device=shs.device, dtype=shs.dtype) + for degree in range(1, max_degree + 1): + start = degree ** 2 + end = (degree + 1) ** 2 + sh_band = shs[:, :nr_valid, :, start:end] + sh_loss = sh_loss + (alpha ** degree) * (sh_band ** 2).mean() + loss = loss + self.cfg.weight_sh * sh_loss + + return loss * self.cfg.weight diff --git a/optgs/loss/loss_iso_scales.py b/optgs/loss/loss_iso_scales.py new file mode 100644 index 0000000000000000000000000000000000000000..502e8cd28bb205497cb06a12e53b378c97aacfad --- /dev/null +++ b/optgs/loss/loss_iso_scales.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass + +from jaxtyping import Float +from torch import Tensor + +from optgs.loss import Loss +from optgs.model.decoder.decoder import DecoderOutput +from optgs.model.types import Gaussians +from 
optgs.scene_trainer.gaussian_module import GaussiansModule + + +@dataclass +class LossIsoScalesCfg: + weight: float | int + +@dataclass +class LossIsoScalesCfgWrapper: + iso_scales: LossIsoScalesCfg + +class LossIsoScales(Loss[LossIsoScalesCfg, LossIsoScalesCfgWrapper]): + """ Enforce isotropic scales of the gaussians. """ + def forward( + self, + prediction: DecoderOutput, + gaussians: Gaussians | GaussiansModule | None, + global_step: int, + gt_rgb: Tensor, + pred_rgb: Tensor, + valid_depth_mask: Tensor | None, + **kwargs, + ) -> Float[Tensor, ""]: + + scales = gaussians.scales # [B, G, 3] + min_scales = scales.min(-1).values # [B, G] + max_scales = scales.max(-1).values # [B, G] + aspect_ratio = min_scales / max_scales + iso_loss = ((aspect_ratio - 1) ** 2).mean() + return iso_loss * self.cfg.weight + + diff --git a/optgs/loss/loss_lpips.py b/optgs/loss/loss_lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..e649a53fee1e33a7533bde2a09822f7cbef694bd --- /dev/null +++ b/optgs/loss/loss_lpips.py @@ -0,0 +1,76 @@ +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from einops import rearrange +from jaxtyping import Float +from lpips import LPIPS +from torch import Tensor + +from ..misc.nn_module_tools import convert_to_buffer +from ..model.decoder.decoder import DecoderOutput +from ..model.types import Gaussians +from optgs.scene_trainer.gaussian_module import GaussiansModule +from .loss import Loss +from .perceptual_loss import PerceptualLoss + + +@dataclass +class LossLpipsCfg: + weight: float + apply_after_step: int + perceptual_loss: bool + + +@dataclass +class LossLpipsCfgWrapper: + lpips: LossLpipsCfg + + +class LossLpips(Loss[LossLpipsCfg, LossLpipsCfgWrapper]): + lpips: LPIPS + + def __init__(self, cfg: LossLpipsCfgWrapper) -> None: + super().__init__(cfg) + + if self.cfg.perceptual_loss: + self.lpips = PerceptualLoss() + else: + self.lpips = LPIPS(net="vgg") + + convert_to_buffer(self.lpips, 
persistent=False) + + def forward( + self, + prediction: DecoderOutput, + gaussians: Gaussians | GaussiansModule | None, + global_step: int, + gt_rgb: Tensor, + pred_rgb: Tensor, + valid_depth_mask: Tensor | None, + half_res_lpips: bool = False, + **kwargs, + ) -> Float[Tensor, ""]: + + if global_step < self.cfg.apply_after_step: + return torch.tensor(0, dtype=torch.float32, device=pred_rgb.device) + + if valid_depth_mask is not None and valid_depth_mask.max() > 0.5: + pred_rgb = pred_rgb.clone() + gt_rgb = gt_rgb.clone() + pred_rgb[valid_depth_mask] = 0 + gt_rgb[valid_depth_mask] = 0 + + pred = rearrange(pred_rgb, "b v c h w -> (b v) c h w") + gt = rearrange(gt_rgb, "b v c h w -> (b v) c h w") + + if half_res_lpips: + pred = F.interpolate(pred, scale_factor=0.5, mode="bilinear", align_corners=True) + gt = F.interpolate(gt, scale_factor=0.5, mode="bilinear", align_corners=True) + + if self.cfg.perceptual_loss: + loss = self.lpips(pred, gt) + else: + loss = self.lpips(pred, gt, normalize=True) + + return self.cfg.weight * loss.mean() diff --git a/optgs/loss/loss_monodepth.py b/optgs/loss/loss_monodepth.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd9a65c02549e42d21944f2b79adefca13b86ec --- /dev/null +++ b/optgs/loss/loss_monodepth.py @@ -0,0 +1,42 @@ +import torch + + +import cv2 +import torch + +from optgs.model.encoder.depth_anything_v2.dpt import DepthAnythingV2 + + + +def get_monodepth_model(): + model_configs = { + 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, + 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, + 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, + 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} + } + + encoder = 'vitl' # or 'vits', 'vitb', 'vitg' + + model = DepthAnythingV2(**model_configs[encoder]) + 
model.load_state_dict(torch.load(f'pretrained/depth_anything_v2_{encoder}.pth', map_location='cpu')) + model = model.eval() + + for param in model.parameters(): + param.requires_grad = False + + return model + + + +def get_monodepth_pred(img, model): + + with torch.no_grad(): + pass + + +def get_monodepth_loss(pred_depth, img): + pass + + + diff --git a/optgs/loss/loss_mse.py b/optgs/loss/loss_mse.py new file mode 100644 index 0000000000000000000000000000000000000000..788c19f1cc31b6bf74defaa3c2e8f9f166954875 --- /dev/null +++ b/optgs/loss/loss_mse.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass + +from jaxtyping import Float +from torch import Tensor + +from ..model.decoder.decoder import DecoderOutput +from ..model.types import Gaussians +from optgs.scene_trainer.gaussian_module import GaussiansModule +from .loss import Loss + + +@dataclass +class LossMseCfg: + weight: float + + +@dataclass +class LossMseCfgWrapper: + mse: LossMseCfg + + +class LossMse(Loss[LossMseCfg, LossMseCfgWrapper]): + def forward( + self, + prediction: DecoderOutput, + gaussians: Gaussians | GaussiansModule | None, + global_step: int, + gt_rgb: Tensor, + pred_rgb: Tensor, + valid_depth_mask: Tensor | None, + l1_loss: bool, + clamp_large_error: float, + **kwargs, + ) -> Float[Tensor, ""]: + + error = pred_rgb - gt_rgb # [B, V, C, H, W] + + if valid_depth_mask is not None and valid_depth_mask.max() > 0.5 and valid_depth_mask.min() < 0.5: + error = error[~valid_depth_mask] + + if l1_loss: + # l1 loss + error = error.abs() + else: + # l2 loss + error = error ** 2 + + if clamp_large_error > 0: + valid_mask = error < clamp_large_error + error = error[valid_mask] + + error = error.mean() + + return self.cfg.weight * error diff --git a/optgs/loss/loss_sgd.py b/optgs/loss/loss_sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..0f565b45e1c00c4a5a2d729886575ed370b6ae0f --- /dev/null +++ b/optgs/loss/loss_sgd.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass + +from 
jaxtyping import Float +from torch import Tensor + +from optgs.dataset.data_types import BatchedExample +from optgs.loss import Loss +from optgs.model.decoder.decoder import DecoderOutput +from optgs.model.types import Gaussians +from optgs.scene_trainer.gaussian_module import GaussiansModule + + +@dataclass +class LossSGDCfg: + pass + +@dataclass +class LossSGDCfgWrapper: + sgd: LossSGDCfg + + +class LossSGD(Loss[LossSGDCfg, LossSGDCfgWrapper]): + def forward( + self, + prediction: DecoderOutput, + batch: BatchedExample, + gaussians: Gaussians | GaussiansModule | None, + global_step: int, + l1_loss: bool, + clamp_large_error: float, + valid_depth_mask: Tensor | None, + **kwargs, + ) -> Float[Tensor, ""]: + + if gaussians is None: + raise ValueError("Gaussians must be provided for LossDeltas.") + + predicted_deltas = gaussians.deltas + gt_gradients = gaussians.gradients + + # cast to float16 if necessary + if predicted_deltas.dtype != gt_gradients.dtype: + gt_gradients = gt_gradients.to(predicted_deltas.dtype) + + if l1_loss: + loss = (predicted_deltas - gt_gradients).abs().mean() + else: + loss = ((predicted_deltas - gt_gradients) ** 2).mean() + if clamp_large_error > 0: + loss = loss.clamp(max=clamp_large_error) + return loss diff --git a/optgs/loss/loss_sh0.py b/optgs/loss/loss_sh0.py new file mode 100644 index 0000000000000000000000000000000000000000..1d559f4ab0ac8a09e5c373b53f7c29f2b7e56a31 --- /dev/null +++ b/optgs/loss/loss_sh0.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass + +import torch +import torch.nn.functional +from einops import rearrange +from gsplat.exporter import sh2rgb + +from optgs.loss import Loss +from optgs.model.decoder.decoder import DecoderOutput + + +@dataclass +class LossSh0Cfg: + weight: float + +@dataclass +class LossSh0CfgWrapper: + sh0: LossSh0Cfg + +class LossSh0(Loss[LossSh0Cfg, LossSh0CfgWrapper]): + def forward( + self, + prediction: DecoderOutput, + gaussians, + global_step: int, + gt_rgb: torch.Tensor, + pred_rgb: 
torch.Tensor, + valid_depth_mask, + gt_image: torch.Tensor, # full-res image [B, V, C, H, W], all views un-subsampled + **kwargs, + ): + sh0_pred = gaussians.harmonics[..., 0] # [B, G, 3] + # Convert SH0 to RGB + rgb_pred = sh2rgb(sh0_pred) # [B, G, 3] + + rgb = gt_image # [B, V, C, H, W] + h, w = rgb.shape[-2:] + means2d = prediction.means2d.detach().clone() # [B, V, G, 2] + means2d[..., 0] = (means2d[..., 0] / (w - 1)) * 2 - 1 + means2d[..., 1] = (means2d[..., 1] / (h - 1)) * 2 - 1 + rgb_gt = torch.nn.functional.grid_sample(rearrange(rgb, "b v c h w -> (b v) c h w"), + rearrange(means2d, "b v g c -> (b v) 1 g c"), + align_corners=False, + padding_mode="border") # [(B V), 3, 1, G] + rgb_gt = rearrange(rgb_gt, "(b v) c 1 g -> b v g c", b=rgb.shape[0], v=rgb.shape[1]) # [B, V, G, 3] + # Calculate mean over views, exclude invalid pixels + # Calculate only for valid intersection of the gaussians and the views + radii = prediction.radii.detach() # [B, V, G, 2] + # Gaussian didn't contribute to this view + # For these gaussians, means2d is (0,0), so we want to exclude them from the computation + valid = (radii > 0).all(-1, keepdim=True) # [B, V, G, 1] + valid_counts = valid.sum(1) # [B, G, 1] + denom = valid_counts + (valid_counts == 0).float() # avoid division by zero + rgb_gt_avg = rgb_gt * valid # [B, V, G, 3] + rgb_gt_avg = rgb_gt_avg.sum(1) / denom # [B, G, 3] + + error = rgb_pred - rgb_gt_avg + error = error[(valid_counts > 0)[..., 0]] + loss = (error ** 2).abs().mean() + return loss * self.cfg.weight + diff --git a/optgs/loss/loss_ssim.py b/optgs/loss/loss_ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..ad99cb195f190e8d2c1a584b5138c6b3ed73aa25 --- /dev/null +++ b/optgs/loss/loss_ssim.py @@ -0,0 +1,53 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# +from dataclasses import dataclass + +from jaxtyping import Float +from torch import Tensor + +from .loss import Loss +from ..model.decoder.decoder import DecoderOutput +from optgs.scene_trainer.gaussian_module import GaussiansModule +from ..model.types import Gaussians + +from fused_ssim import fused_ssim +from einops import rearrange + +@dataclass +class LossSsimCfg: + weight: float + + +@dataclass +class LossSsimCfgWrapper: + ssim: LossSsimCfg + + +class LossSsim(Loss[LossSsimCfg, LossSsimCfgWrapper]): + def forward( + self, + prediction: DecoderOutput, + gaussians: Gaussians | GaussiansModule | None, + global_step: int, + gt_rgb: Tensor, + pred_rgb: Tensor, + valid_depth_mask: Tensor | None, + **kwargs, + ) -> Float[Tensor, ""]: + # same calculation as gsplat + # https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py#L684 + # predicted_image, gt_image: [BS, CH, H, W] + # predicted_image is differentiable + pred = rearrange(pred_rgb, "b v c h w -> (b v) c h w") + gt = rearrange(gt_rgb, "b v c h w -> (b v) c h w") + ssim_value = 1 - fused_ssim(pred, gt, padding="valid") + + return self.cfg.weight * ssim_value diff --git a/optgs/loss/loss_stability.py b/optgs/loss/loss_stability.py new file mode 100644 index 0000000000000000000000000000000000000000..fa299b2139eb589d33852f30b5bf6d600145867d --- /dev/null +++ b/optgs/loss/loss_stability.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass + +import torch +from jaxtyping import Float +from torch import Tensor + +from optgs.dataset.data_types import BatchedExample +from optgs.loss import Loss +from optgs.scene_trainer.optimizer.optimizer import OptimizerOutput + + +@dataclass +class LossStabilityCfg: + weight: float | int + + +@dataclass +class LossStabilityCfgWrapper: + stability: LossStabilityCfg + + +class LossStability(Loss[LossStabilityCfg, LossStabilityCfgWrapper]): + def forward( + self, + optimizer_output: OptimizerOutput, + 
batch: BatchedExample, + **kwargs, + ) -> Float[Tensor, ""]: + total_loss = torch.tensor(0.0, device=optimizer_output.get_render_list("context")[0].color.device) + # Stability loss: encourage the model to produce similar outputs for the same input across iterations. + for input_str in ["context", "target"]: + render_list = optimizer_output.get_render_list(input_str) + index_list = optimizer_output.get_index_list(input_str) # list of I-1 tensors of shape [B, V] + + if len(index_list) == 0: + predictions = [render.color for render in render_list] + predictions = torch.stack(predictions, dim=0) # [I, B, V, C, H, W] + gt = batch[input_str]["image"] # [B, V_all, C, H, W] + + # V == V_all + # Compute l1 loss between predictions and gt for each iteration + loss = torch.abs(predictions - gt).mean(dim=[3, 4, 5]) # [I, B, V] + change_in_loss = loss[1:] - loss[:-1].detach() # [I-1, B, V] + change_in_loss = torch.relu(change_in_loss) # Only consider increases in loss as contributing to the stability loss + else: + continue + + # Duplicate the first index for the initialization + index_list = [index_list[0]] + index_list # Now we have I tensors of shape [B, V] + index_list = torch.stack(index_list, dim=0) # [I-1, B, V] + + b = gt.shape[0] + device = gt.device + batch_idx = torch.arange(b, device=device)[None, :, None] # [1, B, 1] + gt_indexed = gt[batch_idx, index_list] # [I, B, V, C, H, W] + + # Compute l1 loss between predictions and gt for each iteration + # Consider the the indexing of the views within the full batch + loss = torch.abs(predictions - gt_indexed).mean(dim=[3, 4, 5]) # [I, B, V] + + # We want to make sure that the loss decreases across iterations for specific views + I, B, V_all = predictions.shape[0], gt.shape[0], gt.shape[1] + + # Scatter losses into full view space + # Don't use scatter_ in-place to enable backpropagation through the loss values + loss_full = torch.zeros(I, B, V_all, device=loss.device).scatter(2, index_list, loss) # [I, B, V_all] + + 
iter_idx = torch.arange(I, device=device).view(-1, 1, 1) # [I,1,1] + + # mark unvisited as -1 + visited = loss_full > 0 # [I, B, V_all] + visit_ids = torch.where(visited, iter_idx, torch.full_like(iter_idx, -1)) # [I, B, V] + + # running max gives last visit index + last_visit = torch.cummax(visit_ids, dim=0).values # [I,B,V] + + # shift to get strictly previous visit + prev_visit = torch.roll(last_visit, shifts=1, dims=0) + prev_visit[0] = -1 # first iter has no previous + + safe_prev = prev_visit.clamp(min=0) + + prev_loss = loss_full.gather(0, safe_prev).detach() + + has_prev = prev_visit >= 0 + + change_in_loss = torch.relu(loss_full - prev_loss) + change_in_loss = change_in_loss * has_prev.detach() + + + # # Create a mask to identify views that have been visited in previous iterations (cumulative OR) + # # Calculate the + # visited = loss_full > 0 # [I, B, V_all] + # + # # Calcaulate the last visited index for each view + # # Indices along I dimension: shape [I, 1, 1], broadcast over B and v_all + # indices = torch.arange(I, device=visited.device).view(-1, 1, 1).expand_as(visited) # [I, 1, 1] -> [I, B, V_all] + # indices = indices.clone() + # indices[visited == 0] = 0 + # prev_visit_idx = torch.cummax(indices, dim=0).values - 1 # [I, B, V_all] + # # valid previous visit exists + # has_prev = prev_visit_idx >= 0 + # prev_visit_idx = torch.clamp(prev_visit_idx, min=0) # Ensure indices are non-negative + # + # # Loss from the previous visit for each view at each iteration (starting from the second iteration) + # prev_loss = loss_full.detach().gather(0, prev_visit_idx)[1:] # [I-1, B, V_all] + # + # curr_loss = loss_full[1:] # [I-1, B, V_all], current loss for each view at each iteration + # + # change_in_loss = curr_loss - prev_loss # [I-1, B, V_all] + # change_in_loss = torch.relu(change_in_loss) # Only consider increases in loss as contributing to the stability loss + # + # # Valid comparison mask: + # # - current iter visited + # # - previous visit index is 
strictly smaller (i.e. a real previous visit exists) + # mask = visited[1:] & has_prev[:-1] # [I-1, B, V_all] + # change_in_loss = change_in_loss * mask.float().detach() # Zero out change_in_loss for views that haven't been visited in both iterations + # + + + # # Fill in the loss values for the previous visits + # loss_full_filled = loss_full.gather(0, prev_visit_idx) # [I, B, V_all], now loss_full[i] contains the loss from the previous visit for each view + # + # # Update visited + # visited_filled = loss_full > 0 # [I, B, V_all], now visited[i] is True for all views visited up to iteration i + # + # # Now compute change_in_loss across consecutive iterations + # change_in_loss = loss_full_filled[1:] - loss_full_filled[:-1].detach() # [I-1, B, V_all] + # + # # Mask change_in_loss to only consider views that have been visited in previous iterations (i.e., views that have a valid loss comparison) + # # Detach the mask to prevent gradients from flowing through it + # mask = visited_filled[1:] & visited_filled[:-1] # [I-1, B, V_all], True for views that have been visited in both iterations being compared + # mask = mask.detach() + # change_in_loss = change_in_loss * mask.float() # [I-1, B, V_all], zero out change_in_loss for views that haven't been visited in both iterations + # + # # Apply ReLU to only penalize increases in loss + # change_in_loss = torch.relu(change_in_loss) # [I-1, B, V_all], only positive change_in_loss contribute to the loss + + # loss + total_loss += change_in_loss.sum() + return total_loss * self.cfg.weight diff --git a/optgs/loss/perceptual_loss.py b/optgs/loss/perceptual_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d356e8ca3809a2f0ddc01dc1f4996b089d21e6c4 --- /dev/null +++ b/optgs/loss/perceptual_loss.py @@ -0,0 +1,143 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F +from torchvision.models import vgg19 +import scipy.io +import os +from pathlib import Path + + +# the perception loss 
code is modified from https://github.com/zhengqili/Crowdsampling-the-Plenoptic-Function/blob/f5216f312cf82d77f8d20454b5eeb3930324630a/models/networks.py#L1478 +# and some parts are based on https://github.com/arthurhero/Long-LRM/blob/main/model/loss.py +class PerceptualLoss(nn.Module): + def __init__(self): + super().__init__() + self.vgg = self._build_vgg() + self._load_weights() + self._setup_feature_blocks() + + def _build_vgg(self): + """Create VGG model with average pooling instead of max pooling.""" + model = vgg19() + # Replace max pooling with average pooling + for i, layer in enumerate(model.features): + if isinstance(layer, nn.MaxPool2d): + model.features[i] = nn.AvgPool2d(kernel_size=2, stride=2) + + model = model.eval() + + for param in model.parameters(): + param.requires_grad = False + + return model + + def _load_weights(self): + """Load pre-trained VGG weights. """ + weight_file = Path("./metric_checkpoint/imagenet-vgg-verydeep-19.mat") + weight_file.parent.mkdir(exist_ok=True, parents=True) + + # if torch.distributed.get_rank() == 0: + # Download weights if needed + if not weight_file.exists(): + os.system(f'wget https://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat -O {weight_file}') + # torch.distributed.barrier() + + # Load MatConvNet weights + vgg_data = scipy.io.loadmat(weight_file) + vgg_layers = vgg_data["layers"][0] + + # Layer indices and filter sizes + layer_indices = [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34] + filter_sizes = [64, 64, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512, 512, 512] + + # Transfer weights to PyTorch model + with torch.no_grad(): + for i, layer_idx in enumerate(layer_indices): + # Set weights + weights = torch.from_numpy(vgg_layers[layer_idx][0][0][2][0][0]).permute(3, 2, 0, 1) + self.vgg.features[layer_idx].weight = nn.Parameter(weights, requires_grad=False) + + # Set biases + biases = torch.from_numpy(vgg_layers[layer_idx][0][0][2][0][1]).view(filter_sizes[i]) + 
self.vgg.features[layer_idx].bias = nn.Parameter(biases, requires_grad=False) + + def _setup_feature_blocks(self): + """Create feature extraction blocks at different network depths.""" + output_indices = [0, 4, 9, 14, 23, 32] + self.blocks = nn.ModuleList() + + # Create sequential blocks + for i in range(len(output_indices) - 1): + block = nn.Sequential(*list(self.vgg.features[output_indices[i]:output_indices[i+1]])) + self.blocks.append(block.eval()) + + # Freeze all parameters + for param in self.vgg.parameters(): + param.requires_grad = False + + def _extract_features(self, x): + """Extract features from each block.""" + features = [] + for block in self.blocks: + x = block(x) + features.append(x) + return features + + def _preprocess_images(self, images): + """Convert images to VGG input format.""" + # VGG mean values for ImageNet + mean = torch.tensor([123.6800, 116.7790, 103.9390]).reshape(1, 3, 1, 1).to(images.device) + return images * 255.0 - mean + + @staticmethod + def _compute_error(real, fake): + return torch.mean(torch.abs(real - fake)) + + def forward(self, pred_img, target_img, return_feature=False, **kwargs): + """Compute perceptual loss between prediction and target.""" + # Preprocess images + target_img_p = self._preprocess_images(target_img) + pred_img_p = self._preprocess_images(pred_img) + + # Extract features + target_features = self._extract_features(target_img_p) + pred_features = self._extract_features(pred_img_p) + # for x in target_features: + # print(x.shape) + + if return_feature: + return pred_features, target_features + + # Pixel-level error + e0 = self._compute_error(target_img_p, pred_img_p) + + # Feature-level errors with scaling factors + e1 = self._compute_error(target_features[0], pred_features[0]) / 2.6 + e2 = self._compute_error(target_features[1], pred_features[1]) / 4.8 + e3 = self._compute_error(target_features[2], pred_features[2]) / 3.7 + e4 = self._compute_error(target_features[3], pred_features[3]) / 5.6 + e5 = 
self._compute_error(target_features[4], pred_features[4]) * 10 / 1.5 + + # Combine all errors and normalize + total_loss = (e0 + e1 + e2 + e3 + e4 + e5) / 255.0 + + return total_loss + + + +def test(): + b, h, w = 2, 128, 256 + device = torch.device('cuda') + x = torch.randn(b, 3, h, w).to(device) + y = torch.randn(b, 3, h, w).to(device) + + model = PerceptualLoss().to(device) + + loss = model(x, y) + + print(loss) + + +if __name__ == '__main__': + test() + diff --git a/optgs/main.py b/optgs/main.py new file mode 100644 index 0000000000000000000000000000000000000000..a1316a83e36c6832f54e2169baad072753bc442e --- /dev/null +++ b/optgs/main.py @@ -0,0 +1,279 @@ +import os +import random +import sys +import warnings +from pathlib import Path + +import hydra +import numpy as np +import torch +from jaxtyping import install_import_hook +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ( + LearningRateMonitor, + ModelCheckpoint, +) +from pytorch_lightning.loggers.wandb import WandbLogger +from pytorch_lightning.plugins.environments import LightningEnvironment +from pytorch_lightning.profilers import PyTorchProfiler + +from optgs.misc.io import cyan +from optgs.misc.console import banner, config_table, warn + +# Configure beartype and jaxtyping. +with install_import_hook( + ("optgs",), + ("beartype", "beartype"), +): + from optgs.config import setup_cfg, SkipRun + from optgs.dataset.data_module import DataModule + from optgs.loss import get_losses + from optgs.misc.step_tracker import StepTracker + from optgs.misc.wandb_tools import update_checkpoint_path, setup_wandb_logger + from optgs.misc.checkpointing import find_latest_ckpt, load_model_weights + from optgs.meta_trainer.meta_trainer import MetaTrainer + +# print torch device info +print(cyan(f"Torch version: {torch.__version__}")) +if torch.cuda.is_available(): + print(cyan(f"CUDA is available. 
Number of devices: {torch.cuda.device_count()}")) + for i in range(torch.cuda.device_count()): + print(cyan(f"Device {i}: {torch.cuda.get_device_name(i)}")) +else: + print(cyan("CUDA is not available.")) + # raise ValueError("CUDA is required to run this code.") + + +@hydra.main( + version_base=None, + config_path="config", + config_name="main", +) +def train(cfg_dict: DictConfig): + print(cyan(f"Starting main script. cli cfg was parsed ")) + # Set up configuration. + try: + cfg, cfg_dict, eval_cfg = setup_cfg(cfg_dict) + except SkipRun as e: + print(cyan(f"Skipping run: {e}")) + sys.exit(0) + + print_important_cfg_flags(cfg) + + if cfg.debug_cfg: + print(cyan("=" * 60)) + print(cfg) + print(cyan("=" * 60)) + print(cyan(f"Config debug mode, exiting..")) + exit(0) + + # Set up logging with wandb. + callbacks = [] + logger = setup_wandb_logger(cfg, cfg_dict) + if isinstance(logger, WandbLogger): + callbacks.append(LearningRateMonitor("step", True)) + + # Set up checkpointing. + callbacks.append( + ModelCheckpoint( + cfg_dict.output_dir / "checkpoints", + every_n_train_steps=cfg.checkpointing.every_n_train_steps, + save_top_k=cfg.checkpointing.save_top_k, + monitor="info/global_step", + mode="max", + ) + ) + for cb in callbacks: + cb.CHECKPOINT_EQUALS_CHAR = '_' + + # Prepare the checkpoint for loading. + if cfg.checkpointing.resume: + if not os.path.exists(cfg_dict.output_dir / 'checkpoints'): + checkpoint_path = None + else: + checkpoint_path = find_latest_ckpt(cfg_dict.output_dir / 'checkpoints') + # Pass to Lightning via ckpt_path — it restores weights, optimizer, scheduler, and step. + # Do not also set pretrained_model; that would double-load the weights. + print(f'resume from {checkpoint_path}') + else: + checkpoint_path = update_checkpoint_path(cfg.checkpointing.load, cfg.wandb) + + # This allows the current step to be shared with the data loader processes. 
+ step_tracker = StepTracker() + + strategy = cfg.meta_trainer.get_dist_strategy(cfg.scene_trainer) + + if cfg_dict.profiling.mode == "basic": + profiler = "simple" + elif cfg_dict.profiling.mode == "advanced": + profiler = "advanced" + elif cfg_dict.profiling.mode == "pytorch": + # wall clock time not representative of true wall clock time + profiler = PyTorchProfiler(filename="profile-logs") # saves separate reports per rank when distributed training + else: + profiler = None + + trainer = Trainer( + max_epochs=-1, + accelerator="gpu" if torch.cuda.is_available() else "auto", + logger=logger, + devices=torch.cuda.device_count() if torch.cuda.is_available() else "auto", + strategy=strategy, + callbacks=callbacks, + val_check_interval=cfg.meta_trainer.val_check_interval, + enable_progress_bar=cfg.mode == "test", + gradient_clip_val=cfg.meta_trainer.gradient_clip_val if not cfg.scene_trainer.use_fsdp else 0., + # clip by norm is not supported by fsdp + max_steps=cfg.meta_trainer.max_steps, + num_sanity_val_steps=cfg.meta_trainer.num_sanity_val_steps, + num_nodes=cfg.meta_trainer.num_nodes, + plugins=LightningEnvironment() if cfg.use_plugins else None, + limit_test_batches=cfg.meta_trainer.limit_test_batches, + limit_train_batches=cfg.meta_trainer.limit_train_batches, + inference_mode=False, # never use inference mode to allow autograd graph construction + profiler=profiler, + ) + + seed = cfg_dict.seed + trainer.global_rank + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + # Note: Only helpful w/ ReSplat initializer for ours + init_name = getattr(cfg.scene_trainer.scene_initializer, "name", None) + opt_name = getattr(cfg.scene_trainer.scene_optimizer, "name", None) + if init_name == "resplat" and opt_name in ["clogs", "learn2splat"]: + if not cfg.scene_trainer.scene_optimizer.update_only_nonzero_grad: + # Means that the number of gaussians is fixed along itertaion + torch.backends.cudnn.benchmark = True + + # Create the model (MetaTrainer 
wraps SceneTrainer) + meta_trainer = MetaTrainer( + cfg=cfg, + meta_optimizer_cfg=cfg.meta_optimizer, + test_cfg=cfg.meta_trainer.test, + train_cfg=cfg.meta_trainer.train, + scene_trainer_cfg=cfg.scene_trainer, + losses=get_losses(cfg.loss), + step_tracker=step_tracker, + eval_data_cfg=(None if eval_cfg is None else eval_cfg.dataset), + ) + + data_module = DataModule( + cfg.dataset, + cfg.data_loader, + step_tracker, + global_rank=trainer.global_rank, + ) + + if cfg.mode == "train": + print("train:", len(data_module.train_dataloader())) + print("val:", len(data_module.val_dataloader())) + print("test:", len(data_module.test_dataloader())) + else: + print("test:", len(data_module.test_dataloader())) + + strict_load = not cfg.checkpointing.no_strict_load + + if cfg.mode == "train": + assert cfg.scene_trainer.train_scene_opt or cfg.scene_trainer.train_scene_init, \ + "Both scene optimizer and initializer are frozen. Nothing to train." + load_model_weights(cfg, meta_trainer.scene_trainer, strict_load, mode="train") + trainer.fit(meta_trainer, datamodule=data_module, ckpt_path=checkpoint_path) + else: + load_model_weights(cfg, meta_trainer.scene_trainer, strict_load, mode="test") + trainer.test( + meta_trainer, + datamodule=data_module, + ckpt_path=checkpoint_path, + ) + + +def print_important_cfg_flags(cfg): + def kv(param_name): + """Return (param_name, value) for a param known to exist.""" + return param_name, eval(param_name, {"cfg": cfg}) + + def maybe(param_name): + """Return (param_name, value), or None if the attribute is absent.""" + try: + return kv(param_name) + except AttributeError: + return None + + def present(*rows): + """Drop rows that `maybe` resolved to None.""" + return [r for r in rows if r is not None] + + if cfg.scene_trainer.scene_optimizer is None: + optimizer_rows = [("cfg.scene_trainer.scene_optimizer", "None")] + else: + optimizer_rows = present( + maybe("cfg.scene_trainer.scene_optimizer.name"), + 
maybe("cfg.scene_trainer.scene_optimizer.init_state_wo_features"), + maybe("cfg.scene_trainer.scene_optimizer.init_state_scale"), + maybe("cfg.scene_trainer.scene_optimizer.init_state_type"), + maybe("cfg.scene_trainer.scene_optimizer.use_fused_attn"), + maybe("cfg.scene_trainer.scene_optimizer.knn_idx_update_every"), + maybe("cfg.scene_trainer.scene_optimizer.update_only_nonzero_grad"), + ) + + sections = { + "Output dir": [kv("cfg.output_dir"), kv("cfg.mode")], + "Scene trainer": [ + kv("cfg.scene_trainer.opt_batch_size"), + kv("cfg.scene_trainer.opt_batch_strategy"), + ], + "Checkpoints": [ + kv("cfg.checkpointing.pretrained_model"), + kv("cfg.checkpointing.pretrained_optimizer"), + kv("cfg.checkpointing.pretrained_initializer"), + kv("cfg.checkpointing.no_strict_load"), + ], + "Optimizer": optimizer_rows, + "Initialization": present( + kv("cfg.scene_trainer.scene_initializer.name"), + maybe("cfg.scene_trainer.scene_initializer.path"), + maybe("cfg.scene_trainer.scene_initializer.dl3dv_settings"), + maybe("cfg.scene_trainer.scene_initializer.eval_fixed_gaussians_num"), + maybe("cfg.scene_trainer.scene_initializer.filter_zero_rgb"), + ), + "Dataset": present( + kv("cfg.dataset.name"), + maybe("cfg.dataset.test_start_idx"), + maybe("cfg.dataset.num_scenes"), + kv("cfg.dataset.view_sampler.name"), + maybe("cfg.dataset.view_sampler.num_context_views"), + maybe("cfg.dataset.view_sampler.index_path"), + maybe("cfg.dataset.image_shape"), + maybe("cfg.dataset.ori_image_shape"), + ), + "Training": present(maybe("cfg.loss")), + } + config_table(sections, title="Important config params") + + +def main(): + """Console entry point. 
Equivalent to `python -m optgs.main`.""" + warnings.filterwarnings("ignore") + torch.set_float32_matmul_precision('high') + + if not torch.cuda.is_available(): + warn("CUDA is not available, running on CPU.") + + banner( + "optgs", + [ + f"host {os.uname().nodename}", + f"slurm job id {os.environ.get('SLURM_JOB_ID', 'N/A')}", + f"slurm gpus {os.environ.get('SLURM_STEP_GPUS', 'N/A')}", + f"working dir {Path.cwd()}", + ], + ) + + train() + + +if __name__ == "__main__": + main() diff --git a/optgs/meta_trainer/__init__.py b/optgs/meta_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/meta_trainer/meta_trainer.py b/optgs/meta_trainer/meta_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..b09b32c4d8846f2e2813d80229e542da21d944ea --- /dev/null +++ b/optgs/meta_trainer/meta_trainer.py @@ -0,0 +1,2546 @@ +import json +import math +import os +import time +import warnings +from collections import defaultdict +from pathlib import Path +from typing import Optional, runtime_checkable, Protocol, Literal + +import numpy as np +import pandas as pd +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torchvision +import wandb +from einops import rearrange, repeat, pack +from jaxtyping import Float +from lightning_fabric.utilities import rank_zero_only +from pytorch_lightning import LightningModule +from pytorch_lightning.loggers import WandbLogger +from torch import optim, Tensor, nn +from tqdm import tqdm + +from optgs.config import RootCfg, MetaTrainerCfg +from optgs.dataset import DatasetCfg +from optgs.dataset.data_module import get_data_shim +from optgs.dataset.data_types import BatchedExample, BatchedViews +from optgs.evaluation.depth_metrics import compute_depth_errors +from optgs.evaluation.metrics import compute_psnr, compute_ssim, compute_rgb_metrics +from optgs.loss import Loss +from 
optgs.loss.loss_depth_smooth import get_smooth_loss +from optgs.loss.loss_stability import LossStability +from optgs.meta_trainer.replay_buffer import GaussianEpisodeEntry +from optgs.misc.LocalLogger import LocalLogger, LOG_PATH +from optgs.misc.batchify import batched_select +from optgs.misc.benchmarker import Benchmarker +from optgs.misc.console import rule, warn +from optgs.misc.general_utils import SkipBatchException +from optgs.misc.image_io import prep_image, save_video, save_image +from optgs.misc.io import CustomPath +from optgs.misc.stablize_camera import render_stabilization_path +from optgs.misc.step_tracker import StepTracker +from optgs.model.colmap_utils.convert_to_colmap import save_opencv_camera +from optgs.model.colmap_utils.extract_sparse_view_extrinsics import extract_sparse_images_bin +from optgs.model.decoder import get_decoder +from optgs.model.ply_export import save_gaussian_ply +from optgs.paths import DEBUG +from optgs.scene_trainer.initializer.initializer import InitializerOutput, Initializer +from optgs.scene_trainer.optimizer.optimizer import OptimizerPreviousOutput, OptimizerOutput, Optimizer +from optgs.scene_trainer.postprocessing import PostProcessing3DGS +from optgs.scene_trainer.scene_trainer import SceneTrainer # Use existing SceneTrainer +from optgs.scene_trainer.scene_trainer_cfg import SceneTrainerCfg, MetaOptimizerCfg, TestCfg, TrainCfg +from optgs.visualization.annotation import add_label +from optgs.visualization.camera_trajectory.interpolation import interpolate_extrinsics, interpolate_intrinsics +from optgs.visualization.camera_trajectory.wobble import generate_wobble, generate_wobble_transformation +from optgs.visualization.color_map import apply_color_map_to_image +from optgs.visualization.layout import hcat, vcat, add_border +from optgs.visualization.validation_in_3d import render_projections +from optgs.visualization.vis_depth import viz_depth_tensor + +try: + from bitsandbytes.optim import AdamW8bit +except: + pass + 
+try: + import moviepy.editor as mpy +except: + import moviepy as mpy + + +@runtime_checkable +class TrajectoryFn(Protocol): + def __call__( + self, + t: Float[Tensor, " t"], + ) -> tuple[ + Float[Tensor, "batch view 4 4"], # extrinsics + Float[Tensor, "batch view 3 3"], # intrinsics + ]: + pass + + +slurm_id_logged = False +debug_count = 0 + + +class _SkipStepException(Exception): + """Raised inside meta_training_step to signal that this step should be + skipped. Caught in training_step, which then does a single all_reduce so + every rank skips together — preventing NCCL hangs.""" + pass + + +class MetaTrainer(LightningModule): + """ + Meta-level trainer that handles the outer loop of meta-learning. + + This class focuses on: + - Meta-level training loop and replay buffer management + - Delegating scene-level optimization to the existing SceneTrainer + - Meta-optimization of the SceneTrainer's parameters + """ + + meta_optimizer_cfg: MetaOptimizerCfg + test_cfg: TestCfg + train_cfg: TrainCfg + logger: Optional[WandbLogger] + scene_trainer_cfg: SceneTrainerCfg + losses: nn.ModuleList + step_tracker: StepTracker | None + eval_data_cfg: Optional[DatasetCfg | None] + meta_trainer_cfg: MetaTrainerCfg + + def __init__( + self, + cfg: RootCfg, + meta_optimizer_cfg: MetaOptimizerCfg, + test_cfg: TestCfg, + train_cfg: TrainCfg, + scene_trainer_cfg: SceneTrainerCfg, + losses: list[Loss], + step_tracker: StepTracker | None, + eval_data_cfg: Optional[DatasetCfg] = None, + ) -> None: + super().__init__() + self.meta_optimizer_cfg = cfg.meta_optimizer + self.test_cfg = cfg.meta_trainer.test + self.train_cfg = cfg.meta_trainer.train + self.step_tracker = step_tracker + self.eval_data_cfg = eval_data_cfg + self.scene_trainer_cfg = cfg.scene_trainer + self.meta_trainer_cfg = cfg.meta_trainer + + # Create the existing SceneTrainer that contains all the scene-level logic + # This includes the initializer, optimizer, decoder, and get_optimized_gaussians method + self.scene_trainer = 
SceneTrainer( + test_cfg=test_cfg, + train_cfg=train_cfg, + scene_trainer_cfg=scene_trainer_cfg, + decoder=get_decoder(cfg.scene_trainer.decoder, cfg.dataset), + step_tracker=step_tracker, + eval_data_cfg=eval_data_cfg, + ) + + self.initializer_data_shim = get_data_shim(self.scene_initializer) + self.losses = nn.ModuleList(losses) + + # Testing utilities + self.benchmarker = Benchmarker() + self.eval_cnt = 0 + + if self.test_cfg.compute_scores: + self.test_step_outputs_target = defaultdict(list) + self.test_step_outputs_context = defaultdict(list) + + if cfg.mode == "train" and self.train_cfg.use_replay_buffer and self.scene_trainer_cfg.num_update_steps > 0: + assert self.scene_optimizer is not None + assert self.scene_optimizer.strategy == "learned" + + if getattr(self.scene_optimizer.cfg, 'concat_init_state', False): + raise NotImplementedError("Replay buffer with concat_init_state is not supported") + if getattr(self.scene_optimizer.cfg, 'replace_init_state', False): + raise NotImplementedError("Replay buffer with replace_init_state is not supported") + from optgs.meta_trainer.replay_buffer import EpisodeReplayBuffer + self.buffer = EpisodeReplayBuffer(self.train_cfg.replay_buffer_cfg) + else: + self.buffer = None + + self._use_dataloader_batch = True # default + self._new_scenes_cnt = -1 + self.gaussian_timestep_list = [] + self.gaussian_timestep_table = wandb.Table(columns=["epoch", "gaussian_timestep", "count"]) + + self.promoting_buffer_sample = False + + if self.training: + self._inner_iteration_data = [] # Store data for logging inner iterations psnr across meta iterations + + # ==================== Lightning Hooks ==================== + + def on_before_batch_transfer(self, batch: BatchedExample, dataloader_idx: int) -> BatchedExample: + """Decide before device transfer whether this step should draw from the replay buffer or the dataloader.""" + # Decide whether we'll use the buffer + if self.training and self.buffer is not None and 
self.buffer.should_sample(): + self._use_dataloader_batch = False + else: + self._use_dataloader_batch = True + return batch + + def on_after_batch_transfer(self, batch: BatchedExample, dataloader_idx: int) -> BatchedExample: + """Convert raw context/target dicts into typed BatchedViews after the batch lands on the device.""" + batch["context"] = BatchedViews.from_dict(batch["context"]) + batch["target"] = BatchedViews.from_dict(batch["target"]) + return batch + + def will_move_minibatch_to_device(self, batch): + """True when only a sub-batch of views needs to move to device (non-learned init + opt_batch_size < V).""" + # TODO Naama: check if used + # When we sabsample a minibatch, we can move only it to device + return (self.scene_initializer.strategy == "nonlearned" and + self.scene_optimizer is not None and + self.scene_trainer_cfg.opt_batch_size != batch["context"]["image"].shape[1]) + + def transfer_batch_to_device(self, batch, device, dataloader_idx): + # Only transfer if we're going to use this batch + if self.training: + if self._use_dataloader_batch: + should_move = True # move if using dataloader batch + # Also, if the initializer is not learned and the optimizer uses inner batch size, then we also don't want to + # move the batch + # if self.will_move_minibatch_to_device(batch): + # should_move = False + else: + should_move = False # don't move if using buffer sample (we'll move it in the buffer sampling code) + else: + should_move = True # always move during validation and testing + + if should_move: + return super().transfer_batch_to_device(batch, device, dataloader_idx) + else: + return batch # Don't move — we're going to ignore this batch anyway + + def on_save_checkpoint(self, checkpoint): + # Remove the monodepth_model weights from the checkpoint + if 'state_dict' in checkpoint: + keys_to_remove = [k for k in checkpoint['state_dict'] if k.startswith('pretrained_monodepth')] + for k in keys_to_remove: + del checkpoint['state_dict'][k] + + def 
on_load_checkpoint(self, checkpoint): + # Override scheduler total_steps to match current max_steps so LR doesn't + # hit 0 early when resuming for extended training. + for scheduler in checkpoint.get("lr_schedulers", []): + saved_steps = scheduler.get("total_steps") + if saved_steps is not None and saved_steps != self.trainer.max_steps: + print( + f"Resuming with extended training: scheduler total_steps " + f"{saved_steps} → {self.trainer.max_steps}. " + f"LR schedule will be stretched, not restarted from scratch." + ) + scheduler["total_steps"] = self.trainer.max_steps + + def on_train_epoch_start(self): + """Handle epoch start for scene-based training.""" + if hasattr(self.scene_trainer, 'on_train_epoch_start'): + return self.scene_trainer.on_train_epoch_start() + + def on_train_epoch_end(self) -> None: + if self.global_rank == 0: + if self.buffer is not None: + print(f"Buffer size: {len(self.buffer)}") + + if self.logger is not None and isinstance(self.logger, WandbLogger): + # counts = Counter(self.gaussian_timestep_list) + # for eid, c in counts.items(): + # self.gaussian_timestep_table.add_data(self.current_epoch, eid, c) + + # wandb.log({"replay_buffer/event_counts_table": self.gaussian_timestep_table}) + # log also histogram + wandb.log({"replay_buffer/gaussian_timestep_histogram": wandb.Histogram(self.gaussian_timestep_list)}) + + if self.buffer is not None: + self.buffer.clear() + self.gaussian_timestep_list = [] + + def on_validation_epoch_end(self) -> None: + """hack to run the full validation""" + if self.trainer.sanity_checking and self.global_rank == 0: + print(self) # log the model to wandb log files + + if (not self.trainer.sanity_checking) and (self.eval_data_cfg is not None): + self.eval_cnt = self.eval_cnt + 1 + if self.eval_cnt % self.train_cfg.eval_model_every_n_val == 0: + # backup current ckpt before running full test sets eval + if self.train_cfg.eval_save_model: + ckpt_saved_path = ( + 
self.trainer.checkpoint_callback.format_checkpoint_name( + dict( + epoch=self.trainer.current_epoch, + step=self.trainer.global_step, + ) + ) + ) + backup_dir = str( + Path(ckpt_saved_path).parent.parent / "checkpoints_backups" + ) + if self.global_rank == 0: + os.makedirs(backup_dir, exist_ok=True) + ckpt_saved_path = os.path.join( + backup_dir, os.path.basename(ckpt_saved_path) + ) + if self.global_rank == 0: + print(f"backup model to {ckpt_saved_path}.") + # call save_checkpoint on ALL process as suggested by pytorch_lightning + self.trainer.save_checkpoint( + ckpt_saved_path, + weights_only=True, + ) + + # run full test sets eval on rank=0 device + self.run_full_test_sets_eval() + + def on_test_epoch_start(self): + """Handle test epoch start.""" + if hasattr(self.scene_trainer, 'on_test_epoch_start'): + return self.scene_trainer.on_test_epoch_start() + + def on_test_epoch_end(self): + """Handle test epoch end.""" + if hasattr(self.scene_trainer, 'on_test_epoch_end'): + return self.scene_trainer.on_test_epoch_end() + + def on_test_end(self) -> None: + out_dir = self.test_cfg.output_path + + # Merge sub-module benchmarkers so all tags land in one file. + # scene_trainer.benchmarker holds "initializer" (wall-clock, from init_gaussians_and_render). + # optimizer.benchmarker is unused here — decoder/optimizer split is recorded per-scene + # in meta_test_step via benchmarker.record() directly on self.benchmarker. 
+ self.benchmarker.merge(self.scene_trainer.benchmarker) + + # saved_scores = {} + if self.test_cfg.compute_scores: + self.benchmarker.dump_memory(out_dir / "peak_memory.json") + self.benchmarker.dump(out_dir / "benchmark.json") + + for output_dict, input_str in zip([self.test_step_outputs_context, self.test_step_outputs_target], + ["context", "target"]): + for metric_name, metric_scores in output_dict.items(): + metric_scores = torch.tensor(metric_scores) # [scenes, update_steps] + if metric_scores.numel() == 0: + continue + metric_scores = metric_scores.float() # [scenes, update_steps] + update_step_scores = metric_scores.mean(dim=0).tolist() # [update_steps] + # saved_scores[f"{input_str}_{metric_name}"] = update_step_scores[-1] + print(input_str, metric_name, update_step_scores) + with (out_dir / "metrics" / f"{input_str}_{metric_name}.json").open("w") as f: + json.dump(metric_scores.tolist(), f) + + self.benchmarker.clear_history() + else: + self.benchmarker.dump(out_dir / "metrics" / "benchmark.json") + self.benchmarker.dump_memory(out_dir / "metrics" / "peak_memory.json") + self.benchmarker.summarize() + + # ==================== Training ==================== + + def _move_batch_to_device(self, batch: dict) -> dict: + """Move a batch dict to the current device.""" + + def move_tensor(x): + if isinstance(x, Tensor): + return x.to(self.device) + elif isinstance(x, dict): + return {k: move_tensor(v) for k, v in x.items()} + elif isinstance(x, list): + return [move_tensor(v) for v in x] + return x + + return move_tensor(batch) + + def training_step(self, batch, batch_idx): + """ + This is a meta trainer class. Each training step and test step corresponds to training on one scene. + We delegate the actual training to meta_training_step and meta_test_step. + The loop over inner training steps (training within a specific scene) is performed in + self.get_optimized_gaussians. + Each inner iteration is done by calling the forward call of the optimizer. 
+ """ + # DDP-safe skip: all ranks must call all_reduce together, so we do it + # here in training_step (which is always called on every rank) rather + # than inside meta_training_step (which may return early on only one rank). + # + # Each rank sets skip_flag=1 if it wants to skip, 0 otherwise. + # After MAX all_reduce, every rank sees 1 if *any* rank wants to skip, + # and all return zero loss together — keeping NCCL collectives in sync. + is_dist = dist.is_available() and dist.is_initialized() + wants_skip = torch.zeros(1, device=self.device) + + try: + loss = self.meta_training_step(batch, batch_idx) + except _SkipStepException: + wants_skip.fill_(1) + loss = torch.tensor(0.0, device=self.device, requires_grad=True) + + if is_dist: + dist.all_reduce(wants_skip, op=dist.ReduceOp.MAX) + if wants_skip.item() > 0: + return torch.tensor(0.0, device=self.device, requires_grad=True) + + return loss + + def meta_training_step(self, scene_batch, batch_idx): + """One meta-training step: initialize Gaussians, run optimizer refinement, compute loss, optionally push to replay buffer.""" + batch_size, init_target_render_output = None, None + optimizer_output: OptimizerOutput | None = None + + # Prepare input (from dataloader or replay buffer) + if self._use_dataloader_batch: + # Use new batch from dataloader + scene_batch: BatchedExample = self.initializer_data_shim(scene_batch) + + # Get initialization Gaussians + try: + init_output = self.get_init_gaussians(scene_batch, is_training=self.scene_trainer_cfg.train_scene_init) + except SkipBatchException as e: + self.log("skip_zero_gaussians_batch", 1, prog_bar=True) + if self.global_rank == 0: + warn(f"Skipping batch {batch_idx} due to {e}. 
t meta {self.global_step}") + raise _SkipStepException(f"SkipBatch(init): {e}") + + prev_output = init_output + + # Render the init gaussians for loss calculation (only when training the initializer) + if self.scene_trainer_cfg.train_scene_init: + batch_size, init_target_render_output = ( + self.train_render_output_for_init_gaussians(scene_batch, init_output.gaussians)) + + curr_inner_iter = 0 + self._new_scenes_cnt += 1 + else: + # Resample from replay buffer intermediate optimized Gaussians (only when training the optimizer) + assert self.scene_trainer_cfg.train_scene_opt + assert not self.scene_trainer_cfg.train_scene_init + + # Sample from buffer + gaussian_episode_entry: GaussianEpisodeEntry = self.buffer.sample(device=self.device, + leave_batch_fn=self.will_move_minibatch_to_device) + + # Adjust sample + scene_batch = gaussian_episode_entry.batch + prev_output = OptimizerPreviousOutput(gaussians=gaussian_episode_entry.gaussians, + state=gaussian_episode_entry.state) + curr_inner_iter = gaussian_episode_entry.t + + # Simulate init_output for logging (no training of the init_model in this case) + init_output = InitializerOutput(gaussians=gaussian_episode_entry.gaussians) + + # Log the current timestep for analysis + self.gaussian_timestep_list.append(curr_inner_iter) + + # Optimize the gaussians + if self.scene_trainer.optimizer is not None and self.scene_trainer_cfg.train_scene_opt: + # During optimization, we render the context and target images for: + # 1. error/gradients calculation + # 2. loss calculation + # Although it is not necessary, we also render the init target image again for loss calculation + # In the case or training both initializer and optimizer, this is redundant. 
+ + try: + optimizer_output: OptimizerOutput = self.get_optimized_gaussians(scene_batch, prev_output, + curr_iter=curr_inner_iter) + except (torch.cuda.OutOfMemoryError, torch.OutOfMemoryError) as e: + self.log("skip_oom_batch", 1, prog_bar=True) + print( + f"[rank {self.global_rank}] skipping batch {batch_idx} t meta {self.global_step} t inner {curr_inner_iter}: {e}") + torch.cuda.empty_cache() + raise _SkipStepException("OOM") + except SkipBatchException as e: + self.log("skip_nan_batch", 1, prog_bar=True) + if self.global_rank == 0: + warn(f"Skipping batch {batch_idx} due to {e}. " + f"t meta {self.global_step} t inner {curr_inner_iter}") + raise _SkipStepException(f"SkipBatch(opt): {e}") + curr_inner_iter = optimizer_output.t + + if optimizer_output.last_prev_output.state.state is not None: + state_norm = optimizer_output.last_prev_output.state.state.norm(dim=1).mean() + self.log("info/state_norm", state_norm) + + # Compute and log loss. + init_gaussians = init_output.gaussians + + try: + total_loss = self.train_calc_total_loss(scene_batch, optimizer_output, init_gaussians, + init_target_render_output, init_output.depths) + except (torch.cuda.OutOfMemoryError, torch.OutOfMemoryError) as e: + self.log("skip_oom_batch", 1, prog_bar=True) + print( + f"[rank {self.global_rank}] OOM: {e}, skipping batch {batch_idx} t meta {self.global_step} t inner {curr_inner_iter} num of inner {len(optimizer_output.gaussian_list)}") + torch.cuda.empty_cache() + raise _SkipStepException("OOM") + + # More logging + if optimizer_output is not None: + last_gaussians = optimizer_output.gaussian_list[-1] + else: + last_gaussians = init_gaussians + self.train_logging(scene_batch, optimizer_output, last_gaussians, total_loss) + + # Check for NaN loss + # Skipping pushing to the replat buffer + if torch.isnan(total_loss) or torch.isinf(total_loss): + self.log("skip_nan_batch", 1, prog_bar=True) + if self.global_rank == 0: + warn(f"Skipping batch {batch_idx} due to NaN loss. 
" + f"t meta {self.global_step} t inner {optimizer_output.t}") + raise _SkipStepException("NaN/Inf loss") + + # Push back to buffer + if self.buffer is not None and self.buffer.should_push(new_sample=self._use_dataloader_batch, + t=curr_inner_iter): + push = True + if self.train_cfg.replay_buffer_cfg.simulate_ahead: + min_steps = self.train_cfg.replay_buffer_cfg.simulate_ahead_min_steps + cfg_max_steps = self.train_cfg.replay_buffer_cfg.simulate_ahead_max_steps + + if self.train_cfg.replay_buffer_cfg.simulate_ahead_grow > 0: + t_meta = self.global_step + T_grow = self.train_cfg.replay_buffer_cfg.simulate_ahead_grow + max_steps = min_steps + (cfg_max_steps - min_steps) * min(1.0, t_meta / T_grow) + max_steps = int(max_steps) + else: + max_steps = cfg_max_steps + + if min_steps == max_steps: + steps = min_steps + else: + steps = np.random.randint(low=min_steps, high=max_steps + 1) + with torch.no_grad(): + # Set eval mode + self.eval() + self.scene_optimizer.save_every.set_all_tags(False) + self.promoting_buffer_sample = True + + try: + optimizer_output = self.get_optimized_gaussians(scene_batch, optimizer_output.last_prev_output, + curr_iter=optimizer_output.t, + num_update_steps=steps, + disable_tqdm=True) + last_gaussians = optimizer_output.last_prev_output.gaussians + # catching multiple errors + except (ValueError, SkipBatchException) as e: + warn(f"Skipping pushing batch {batch_idx} to buffer due to {e}.") + push = False + self.train() + self.scene_optimizer.save_every.set_all_tags(True) + self.promoting_buffer_sample = False + + # assert len(optimizer_output.target_render_list) == 0 # no rendering needed + + if optimizer_output.last_prev_output.state.state is not None: + with torch.no_grad(): + state_norm = optimizer_output.last_prev_output.state.state.norm(dim=1).mean() + if state_norm > 500: + warnings.warn(f"Pushing sample norm state {state_norm} {optimizer_output.t} {self.global_step}") + if push: + 
self.buffer.push(GaussianEpisodeEntry(t=optimizer_output.t, + batch=scene_batch, + gaussians=last_gaussians, + state=optimizer_output.last_prev_output.state, + id=self._new_scenes_cnt), to_cpu=True) + + self.log("replay_buffer/size", len(self.buffer.buffer)) + if self.train_cfg.replay_buffer_cfg.simulate_ahead: + self.log("replay_buffer/simulate_ahead", steps) + self.log("replay_buffer/stored_step", optimizer_output.t) + + return total_loss + + def train_logging(self, batch, optimizer_output, gaussians, total_loss): + self.log("loss/total", total_loss) + if ( + self.global_rank == 0 + and (self.global_step % self.train_cfg.print_log_every_n_steps == 0 or total_loss > 5) + ): + print( + f"train step {self.global_step}; " + f"scene_name = {[x[:20] for x in batch['scene']]}; " + f"context = {batch['context']['index'].tolist()}; " + f"target = {batch['target']['index'].tolist()}; " + f"bound = [{batch['context']['near'].detach().cpu().numpy().mean()} " + f"{batch['context']['far'].detach().cpu().numpy().mean()}]; " + f"loss = {total_loss:.6f}; " + ) + self.log("info/near", batch["context"]["near"].detach().cpu().numpy().mean()) + self.log("info/far", batch["context"]["far"].detach().cpu().numpy().mean()) + self.log("info/global_step", self.global_step) # hack for ckpt monitor + + # log gaussians scales + if self.scene_trainer_cfg.num_update_steps > 0 and "deltas" in optimizer_output.info: + delta_means = [deltas["means"] for deltas in optimizer_output.info["deltas"]] + delta_scales = [deltas["scales"] for deltas in optimizer_output.info["deltas"]] + + for i in range(len(delta_means)): + self.log(f"update{i}/delta_means_min", delta_means[i].abs().min().item()) + self.log(f"update{i}/delta_means_mean", delta_means[i].abs().mean().item()) + self.log(f"update{i}/delta_means_max", delta_means[i].abs().max().item()) + + for i in range(len(delta_scales)): + self.log(f"update{i}/delta_scales_min", delta_scales[i].abs().min().item()) + self.log(f"update{i}/delta_scales_mean", 
def compute_losses(self, gaussians, i, num_output, render_output, curr_gt_rgb, valid_depth_mask,
                   error_idx=None, all_gt_rgb=None, tag="target"):
    """Return the weighted sum of all configured losses for one optimizer step.

    curr_gt_rgb [B, V_rendered, C, H, W] is already narrowed to the views the
    optimizer rendered. all_gt_rgb [B, V_all, C, H, W] is the full ground truth
    and is handed to the losses as ``gt_image`` for those that need every view
    (e.g. LossSh0); it defaults to curr_gt_rgb.

    Step ``i`` out of ``num_output`` is scaled by
    ``intermediate_loss_weight ** (num_output - 1 - i)``, so earlier refinement
    steps are discounted. Each individual loss is logged under
    ``loss/{tag}_{name}`` with a ``_{i + 1}`` suffix for every step after the
    first. LossStability entries are skipped here because that loss is
    computed outside the inner-step loop.
    """
    full_gt_rgb = curr_gt_rgb if all_gt_rgb is None else all_gt_rgb
    step_weight = self.train_cfg.intermediate_loss_weight ** (num_output - 1 - i)

    gt_rgb, pred_rgb, valid_depth_mask = Loss.extract_pred_gt(
        curr_gt_rgb, render_output, error_idx, valid_depth_mask
    )

    step_suffix = f"_{i + 1}" if i > 0 else ""
    shared_kwargs = dict(
        gt_rgb=gt_rgb,
        pred_rgb=pred_rgb,
        gt_image=full_gt_rgb,
        valid_depth_mask=valid_depth_mask,
        l1_loss=self.train_cfg.l1_loss,
        clamp_large_error=self.train_cfg.train_ignore_large_loss,
        half_res_lpips=self.train_cfg.half_res_lpips_loss,
    )

    total_loss = 0
    for loss_fn in self.losses:
        # Stability loss covers all intermediate outputs and is handled by the caller.
        if isinstance(loss_fn, LossStability):
            continue
        # TODO Naama review
        loss = loss_fn(render_output, gaussians, self.global_step, **shared_kwargs)
        self.log("loss/" + f"{tag}_" + loss_fn.name + step_suffix, loss)
        total_loss += step_weight * loss

    return total_loss
LossSh0) + if all_gt_rgb is None: + all_gt_rgb = curr_gt_rgb + total_loss = 0 + curr_loss_weight = self.train_cfg.intermediate_loss_weight ** (num_output - 1 - i) + + gt_rgb, pred_rgb, valid_depth_mask = Loss.extract_pred_gt( + curr_gt_rgb, render_output, error_idx, valid_depth_mask + ) + + for loss_fn in self.losses: + if isinstance(loss_fn, LossStability): + # Stability loss is applied on all intermediate outputs + # Will be calculated outside of the inner steps loop + continue + # TODO Naama review + loss = loss_fn( + render_output, + gaussians, + self.global_step, + gt_rgb=gt_rgb, + pred_rgb=pred_rgb, + gt_image=all_gt_rgb, + valid_depth_mask=valid_depth_mask, + l1_loss=self.train_cfg.l1_loss, + clamp_large_error=self.train_cfg.train_ignore_large_loss, + half_res_lpips=self.train_cfg.half_res_lpips_loss, + ) + + loss_tag = f"{tag}_" + loss_fn.name + loss_tag += f"_{i + 1}" if i > 0 else "" + self.log(f"loss/{loss_tag}", loss) + + total_loss += curr_loss_weight * loss + + return total_loss + + def train_calc_total_loss(self, batch, optimizer_output: OptimizerOutput | None, init_gaussians, + init_target_render_output, pred_depths): + """Accumulate total training loss: init + optimizer steps + depth + monodepth losses.""" + total_loss = 0 + valid_depth_mask = None + + target_gt_rgb = batch["target"]["image"] + t = optimizer_output.t if optimizer_output is not None else 0 + + # Log and calculate loss of init + if self.scene_trainer_cfg.train_scene_init: + total_loss += self._calc_init_loss(init_gaussians, init_target_render_output, target_gt_rgb, + valid_depth_mask) + else: + # Still log init psnr, but init_target_render_output is None + self._log_init_metrics_from_optimizer(batch, optimizer_output) + + # Log and calculate loss of intermediate outputs during refinement + if self.scene_trainer_cfg.train_scene_opt: + total_loss += self._calc_opt_loss(batch, optimizer_output, t, valid_depth_mask) + + # More loss on the last prediction + assert 
self.scene_trainer_cfg.train_scene_init ^ self.scene_trainer_cfg.train_scene_opt + last_target_decoder_output = optimizer_output.target_render_list[ + -1] if optimizer_output is not None else init_target_render_output + + # render depth loss + if self.train_cfg.render_depth_loss_weight > 0: + # [B, V, H, W] + near = batch["target"]["near"][..., None, None] # [B, V, 1, 1] + far = batch["target"]["far"][..., None, None] + + target_gt_depth = batch["target"]["depth"] + render_depth = last_target_decoder_output.depth + + valid = (target_gt_depth >= near) & (target_gt_depth <= far) & (render_depth >= near) & ( + render_depth <= far) + + render_depth_loss = self.train_cfg.render_depth_loss_weight * ( + torch.log(target_gt_depth[valid]) - torch.log(render_depth[valid])).abs().mean() + + self.log(f"loss/render_depth", render_depth_loss) + total_loss = total_loss + render_depth_loss + + # depth loss + if self.train_cfg.depth_loss_weight > 0: + near = batch["context"]["near"][..., None, None] # [B, V, 1, 1] + far = batch["context"]["far"][..., None, None] + + depth_gt = batch['context']["depth"] # [B, V, H, W] + + valid = (depth_gt >= near) & (depth_gt <= far) + + # in case there is no valid gt depth (loss will be nan) + if valid.max() > 0.5: + # log or inverse depth loss + if self.train_cfg.log_depth_loss: + depth_loss = ( + torch.log(pred_depths[valid]) - torch.log(depth_gt[valid])).abs().mean() + else: + depth_loss = ( + 1. / pred_depths[valid] - 1. / depth_gt[valid]).abs().mean() + + depth_loss = self.train_cfg.depth_loss_weight * depth_loss + + self.log(f"loss/depth", depth_loss) + total_loss = total_loss + depth_loss + + # depth smooth loss + if self.train_cfg.depth_smooth_loss_weight > 0: + imgs = batch["context"]["image"].flatten(0, 1) # [BV, 3, H, W] + + depth = pred_depths.flatten(0, 1).unsqueeze(1) + + disp = 1. 
/ depth + if self.train_cfg.depth_smooth_loss_nonorm: + norm_disp = disp + else: + mean_disp = disp.mean(2, True).mean(3, True) + norm_disp = disp / (mean_disp + 1e-7) + + # resize to depth's resolution + if imgs.shape[-2:] != norm_disp.shape[-2:]: + imgs = F.interpolate(imgs, size=norm_disp.shape[-2:], mode='bilinear', align_corners=True) + + depth_smooth_loss = get_smooth_loss(norm_disp, imgs) + + depth_smooth_loss = self.train_cfg.depth_smooth_loss_weight * depth_smooth_loss + + self.log(f"loss/depth_smooth", depth_smooth_loss) + total_loss = total_loss + depth_smooth_loss + # depth smooth loss for novel views + if self.train_cfg.depth_smooth_loss_weight_nvs > 0: + imgs = batch["target"]["image"].flatten(0, 1) # [BV, 3, H, W] + + depth = last_target_decoder_output.depth.flatten(0, 1).unsqueeze(1) + + disp = 1. / depth.clamp(min=1e-3, max=1000.) + if self.train_cfg.depth_smooth_loss_nonorm: + norm_disp = disp + else: + mean_disp = disp.mean(2, True).mean(3, True) + norm_disp = disp / (mean_disp + 1e-7) + + depth_smooth_loss_nvs = get_smooth_loss(norm_disp, imgs) + + depth_smooth_loss_nvs = self.train_cfg.depth_smooth_loss_weight_nvs * depth_smooth_loss_nvs + + self.log(f"loss/depth_smooth_nvs", depth_smooth_loss_nvs) + total_loss = total_loss + depth_smooth_loss_nvs + # monodepth loss + if self.train_cfg.monodepth_loss_weight > 0: + imgs = batch["context"]["image"].flatten(0, 1) # [BV, 3, H, W] + + pred_disp = 1. 
def _calc_opt_loss(self, batch, optimizer_output, t, valid_depth_mask):
    """Sum the losses of every optimizer refinement step over target and context views.

    For each refinement step and each view set, the rendered views are logged
    via ``_log_train_metrics``. When the loss is enabled for that view set, a
    random subset of the rendered views (one permutation shared by all batch
    items) is passed to ``compute_losses``. If a LossStability instance is
    configured it is evaluated once on the whole optimizer output, added, and
    logged as ``loss/stability``.
    """
    accumulated = 0
    assert optimizer_output is not None
    # The first entry of each render list is the initialization, not a refinement step.
    num_refine_steps = len(optimizer_output.context_render_list) - 1

    # (tag, loss enabled, number of views used for the loss; negative means "all rendered views")
    view_settings = (
        ("target", self.train_cfg.loss_on_target_views, self.train_cfg.loss_on_target_views_num),
        ("context", self.train_cfg.loss_on_input_views, self.train_cfg.loss_on_input_views_num),
    )

    for step in range(num_refine_steps):
        for tag, use_loss, max_loss_views in view_settings:
            renders = optimizer_output.get_render_list(tag)
            rendered_indices = optimizer_output.get_index_list(tag)
            full_gt = batch[tag]["image"]  # [B, V_all, C, H, W]

            # A non-empty index list means the optimizer rendered only a subset of views this step.
            # NOTE(review): indices of step `step` are paired with render `step + 1` — confirm
            # against how the index list is built.
            if rendered_indices:
                step_gt = batched_select(full_gt, rendered_indices[step])  # [B, V_rendered, C, H, W]
            else:
                step_gt = full_gt

            step_render = renders[step + 1]
            self._log_train_metrics(step + 1, step_render.color, step_gt, tag=tag, t=t)

            if not use_loss:
                continue

            batch_size, num_views = step_gt.shape[:2]
            keep = num_views if max_loss_views < 0 else max_loss_views
            # Subsample the rendered views down to `keep` for the loss computation.
            picked = torch.randperm(num_views, device=step_gt.device)[:keep]
            picked = picked.unsqueeze(0).expand(batch_size, -1)
            accumulated += self.compute_losses(optimizer_output.gaussian_list[step], step, num_refine_steps,
                                               step_render, step_gt, valid_depth_mask,
                                               error_idx=picked, all_gt_rgb=full_gt, tag=tag)

    stability_fn = next((fn for fn in self.losses if isinstance(fn, LossStability)), None)
    if stability_fn is not None:
        stability_loss = stability_fn(optimizer_output, batch)
        accumulated += stability_loss
        self.log("loss/stability", stability_loss)

    return accumulated
def _log_init_metrics_from_optimizer(self, batch, optimizer_output):
    """Log training metrics for the first render (index 0) of the optimizer output.

    Runs for the context views and then the target views. When the optimizer
    recorded per-step view indices, the ground truth is narrowed with the
    first step's indices, which are the ones used for rendering during
    optimization; otherwise the full ground truth is used.
    """
    assert optimizer_output is not None
    for tag in ("context", "target"):
        render_list = optimizer_output.get_render_list(tag)
        index_list = optimizer_output.get_index_list(tag)
        all_gt_rgb = batch[tag]["image"]
        curr_gt_rgb = batched_select(all_gt_rgb, index_list[0]) if index_list else all_gt_rgb
        self._log_train_metrics(0, render_list[0].color, curr_gt_rgb, tag=tag)
all_gt_rgb + self._log_train_metrics(0, render_list[0].color, curr_gt_rgb, tag=tag) + + def _calc_init_loss(self, init_gaussians, init_target_render_output, target_gt_rgb, valid_depth_mask): + assert not self.train_cfg.loss_on_input_views + self._log_train_metrics(0, init_target_render_output.color, target_gt_rgb, tag="target") + # TODO Naama: train init model on context+target? + return self.compute_losses(init_gaussians, 0, 1, init_target_render_output, target_gt_rgb, valid_depth_mask) + + def _log_train_metrics(self, i, pred, gt, tag, t=-1): + psnr = compute_psnr( + rearrange(gt, "b v c h w -> (b v) c h w"), + rearrange(pred, "b v c h w -> (b v) c h w"), + ) + self.log(f"train/{tag}_psnr_{i}", psnr.mean().item()) + + if self.global_step < (100000 if DEBUG else 10) and self.global_rank == 0: + print( + f"Training step {self.global_step}, inner step {t} i {i} train psnr {psnr.mean().item()}") + + def train_render_output_for_init_gaussians(self, batch, gaussians): + b, v, _, h, w = batch["context"]["image"].shape + assert gaussians.means.size(0) == batch["target"]["extrinsics"].size(0), \ + "num_scales must be 1; multi-scale depth supervision is not supported" + batch_size = batch["target"]["extrinsics"].size(0) + output = self.scene_decoder.forward( + gaussians, + batch["target"]["extrinsics"], + batch["target"]["intrinsics"], + batch["target"]["near"], + batch["target"]["far"], + (h, w), + depth_mode='depth' if self.train_cfg.render_depth_loss_weight > 0 else None, + ) + return batch_size, output + + # ==================== Meta Optimizer Configuration ==================== + + def _split_params(self, filter_key: str) -> tuple[list, list]: + """Split parameters into (matched, rest) based on whether name contains filter_key.""" + matched, rest = [], [] + for name, param in self.named_parameters(): + (matched if filter_key in name else rest).append(param) + return matched, rest + + def _build_adamw(self, params_or_groups, weight_decay: float, **kwargs): + 
"""Instantiate AdamW or AdamW8bit depending on config.""" + cls = AdamW8bit if self.meta_optimizer_cfg.adamw_8bit else optim.AdamW + return cls(params_or_groups, weight_decay=weight_decay, **kwargs) + + def configure_optimizers(self): + # This is the *meta* optimizer — it optimizes the parameters of the learned optimizer itself + # (i.e. the KnnBasedOptimizer weights), not individual scene Gaussians. + # Only called by Lightning during fit(); skipped entirely in test mode. + cfg = self.meta_optimizer_cfg + + if cfg.lr_depth > 0: + pretrained, rest = self._split_params("depth_predictor") + param_groups = [{"params": pretrained, "lr": cfg.lr_depth}, {"params": rest, "lr": cfg.lr}] + scheduler_lrs = [cfg.lr_monodepth, cfg.lr] + elif cfg.lr_monodepth > 0: + pretrained, rest = self._split_params("pretrained") + param_groups = [{"params": pretrained, "lr": cfg.lr_monodepth}, {"params": rest, "lr": cfg.lr}] + scheduler_lrs = [cfg.lr_monodepth, cfg.lr] + else: + param_groups = self.parameters() + scheduler_lrs = cfg.lr + + meta_optimizer = self._build_adamw(param_groups, cfg.weight_decay) + scheduler = torch.optim.lr_scheduler.OneCycleLR( + meta_optimizer, + scheduler_lrs, + self.trainer.max_steps + 10, + pct_start=cfg.warm_up_ratio, + cycle_momentum=False, + anneal_strategy="cos", + ) + + return { + "optimizer": meta_optimizer, + "lr_scheduler": {"scheduler": scheduler, "interval": "step", "frequency": 1}, + } + + # ==================== Testing ==================== + + @torch.no_grad() + def test_step(self, scene_batch: BatchedExample, batch_idx: int): + """ + This is a meta trainer class. Each training step and test step corresponds to training on one scene. + We delegate the actual training/testing to meta_training_step and meta_test_step. + The loop over inner training steps (training within a specific scene) is performed in + self.get_optimized_gaussians. + Each inner iteration is done by calling the forward call of the optimizer. 
+ """ + return self.meta_test_step(scene_batch, batch_idx) + + @torch.no_grad() + def meta_test_step(self, scene_batch: BatchedExample, batch_idx: int): + """Run the full test pipeline for one scene: initialize, optimize, then evaluate and save.""" + if self.test_cfg.scenes_filter is not None and scene_batch['scene'][0] not in self.test_cfg.scenes_filter: + print(f"Scenes filter: {self.test_cfg.scenes_filter}") + print(f"Skipping scene {scene_batch['scene'][0]} (not in scenes_filter)") + return + + output_path = self.test_cfg.output_path + + if output_path is not None and self.test_cfg.skip_if_outputs_exist: + optimizer_name = self.scene_trainer.optimizer.__class__.__name__.lower() if self.scene_trainer.optimizer is not None else "no_optimizer" + target_metric_path = output_path / optimizer_name / "metrics" / f"{scene_batch['scene'][0]}" / f"target_{optimizer_name}.json" + context_metric_path = output_path / optimizer_name / "metrics" / f"{scene_batch['scene'][0]}" / f"context_{optimizer_name}.json" + should_eval_context = self.test_cfg.eval_context_views + should_eval_target = True # always evaluate target views + + skip_target = (should_eval_target and target_metric_path.exists()) or not should_eval_target + skip_context = (should_eval_context and context_metric_path.exists()) or not should_eval_context + + if skip_target and skip_context: + print( + f"Metrics for scene {scene_batch['scene'][0]} already exist at {target_metric_path} and {context_metric_path}. 
@torch.no_grad()
def meta_test_step(self, scene_batch: BatchedExample, batch_idx: int):
    """Run the full test pipeline for one scene: initialize, optimize, then evaluate and save.

    Steps:
        1. Skip the scene if it is excluded by ``test_cfg.scenes_filter`` or,
           with ``skip_if_outputs_exist``, if its metric files already exist.
        2. Shim the batch, optionally add image noise, optionally dump the
           cameras as JSON / NPZ.
        3. Initialize the Gaussians (optionally evaluating the initialization).
        4. Run the scene optimizer, if any, recording peak VRAM and timings.
        5. Evaluate and save the optimizer output and the post-processed output.

    Returns None in every case. A scene whose optimization raises
    SkipBatchException or a RuntimeError (including CUDA out-of-memory) is
    skipped with a warning.
    """
    scene_id = scene_batch['scene'][0]
    if self.test_cfg.scenes_filter is not None and scene_id not in self.test_cfg.scenes_filter:
        print(f"Scenes filter: {self.test_cfg.scenes_filter}")
        print(f"Skipping scene {scene_id} (not in scenes_filter)")
        return

    output_path = self.test_cfg.output_path

    if output_path is not None and self.test_cfg.skip_if_outputs_exist:
        optimizer_name = self.scene_trainer.optimizer.__class__.__name__.lower() if self.scene_trainer.optimizer is not None else "no_optimizer"
        metrics_dir = output_path / optimizer_name / "metrics" / f"{scene_id}"
        target_metric_path = metrics_dir / f"target_{optimizer_name}.json"
        context_metric_path = metrics_dir / f"context_{optimizer_name}.json"

        # Target views are always evaluated; context views only when requested.
        skip_target = target_metric_path.exists()
        skip_context = (not self.test_cfg.eval_context_views) or context_metric_path.exists()

        if skip_target and skip_context:
            print(
                f"Metrics for scene {scene_id} already exist at {target_metric_path} and {context_metric_path}. Skipping...")
            return

    rule(f"Testing scene {batch_idx}: {scene_id}")

    # input (context and target)
    batch: BatchedExample = self.initializer_data_shim(scene_batch)

    # Process batch for experiments, e.g., add noise (skip if not needed)
    if self.test_cfg.experimental_add_noise_to_images:
        batch = self.experimental_process_batch(batch)

    # Save cameras as JSON (before optimization, cameras are fixed)
    if self.test_cfg.save_cameras_json:
        scene_name = batch["scene"][0]
        relevant_keys = ["extrinsics", "intrinsics"]
        context_info = {key: batch["context"][key][0].cpu().tolist() for key in relevant_keys}
        target_info = {key: batch["target"][key][0].cpu().tolist() for key in relevant_keys}
        resolution = list(batch["context"]["image"].shape[-2:])
        cameras_data = {
            "scene": scene_name,
            "context": context_info,
            "target": target_info,
            "resolution": resolution,
        }
        cameras_dir = output_path / "cameras"
        cameras_dir.mkdir(parents=True, exist_ok=True)
        cameras_path = cameras_dir / f"{scene_name}_cameras.json"
        with open(cameras_path, "w") as f:
            json.dump(cameras_data, f, indent=4)
        print(f"Saved cameras JSON to {cameras_path}")

    # Save cameras as NPZ in the exact form fed to the rasterizer:
    #   viewmats = inverse(extrinsics)        (world-to-camera, [V,4,4])
    #   Ks       = intrinsics * diag(W, H, 1) (pixel-space,     [V,3,3])
    # Mirrors GSplatDecoderSplattingCUDA.forward (gsplat_decoder_splatting_cuda.py:137-140).
    if self.test_cfg.save_cameras_npz:
        scene_name = batch["scene"][0]
        cameras_dir = output_path / "cameras"
        cameras_dir.mkdir(parents=True, exist_ok=True)
        npz_data = {"scene": scene_name}
        for input_str in ("context", "target"):
            view = batch[input_str]
            extrinsics = view["extrinsics"][0]  # [V, 4, 4] cam-to-world
            intrinsics = view["intrinsics"][0]  # [V, 3, 3] normalized
            h, w = view["image"].shape[-2:]
            viewmats = extrinsics.inverse()  # [V, 4, 4] world-to-cam
            scale = intrinsics.new_tensor([[w], [h], [1]])
            Ks = intrinsics * scale  # [V, 3, 3] pixel-space
            npz_data[f"{input_str}_viewmats"] = viewmats.cpu().numpy()
            npz_data[f"{input_str}_Ks"] = Ks.cpu().numpy()
            npz_data[f"{input_str}_image_shape"] = np.array([h, w], dtype=np.int64)
        cameras_npz_path = cameras_dir / f"{scene_name}_cameras.npz"
        np.savez(cameras_npz_path, **npz_data)
        print(f"Saved renderer-ready cameras NPZ to {cameras_npz_path}")

    self.scene_initializer.preprocessing(batch, self.train_cfg)

    # Infer Gaussians.

    # init
    scene_name = batch["scene"][0]
    init_output: InitializerOutput = self.init_gaussians_and_render(
        batch,
        visualization_dump={},
        render_context=True,
        render_target=True,
        grad_enabled=False,
        cached_data_path=Path(os.path.join("cache", "edgs", scene_name)),  # for EDGS only for now  # TODO Naama: review
    )

    if self.test_cfg.eval_initialization:
        print("\nEvaluating initialization...")

        # Evaluate and save initialization
        self._eval_and_save(
            self.scene_initializer,
            batch,
            batch_idx,
            init_output,
            output_path
        )

    # Optimization
    if self.scene_trainer.optimizer is None:
        optimizer_output = None
    else:
        # run optimizer
        torch.cuda.reset_peak_memory_stats()
        try:
            optimizer_output = self.get_optimized_gaussians(
                batch,
                init_output,
                output_path=output_path / self.scene_trainer.optimizer.__class__.__name__.lower(),
                scene_name=scene_name,
                debug_dict=defaultdict(list),
            )
        except SkipBatchException as e:
            # Handled first so that the broad RuntimeError clause below can never shadow it.
            warn(f'skipping scene due to SkipBatch during optimization: {e}')
            return None
        except torch.OutOfMemoryError:
            warn('ran out of memory during optimization. Skipping scene.')
            torch.cuda.empty_cache()
            return None
        except RuntimeError as e:
            # Still best-effort (the scene is skipped), but report the real cause
            # instead of claiming every runtime error is out-of-memory.
            warn(f'runtime error during optimization: {e}. Skipping scene.')
            torch.cuda.empty_cache()
            return None

        peak_vram_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
        self.benchmarker.record("peak_vram_mb", peak_vram_mb)

        # Record per-scene timing from CUDA event logs (all in ms).
        # optimizer_net = on_scene_start + all iteration steps. Excludes save-every renders
        # (which happen after iter_end.record() and are therefore not in iter_time_log).
        opt = self.scene_trainer.optimizer
        decoder_ms = sum(opt.decoder_time_log)
        optimizer_ms = sum(opt.optimizer_time_log)
        optimizer_net_ms = opt.scene_start_ms + decoder_ms + optimizer_ms
        self.benchmarker.record("decoder", decoder_ms)
        self.benchmarker.record("optimizer", optimizer_ms)
        self.benchmarker.record("optimizer_net", optimizer_net_ms)
        print(
            f"[timing] scene={scene_name}  "
            f"scene_start={opt.scene_start_ms:.0f}ms  "
            f"decoder={decoder_ms:.0f}ms  "
            f"optimizer={optimizer_ms:.0f}ms  "
            f"optimizer_net={optimizer_net_ms:.0f}ms  "
            f"peak_vram={peak_vram_mb:.0f}MB"
        )
        opt.decoder_time_log.clear()
        opt.optimizer_time_log.clear()

        # Collected here; written into target_*.json / context_*.json by _eval_and_save below.
        _scene_timing_metrics = {
            "peak_vram_mb": peak_vram_mb,
            "decoder_ms": decoder_ms,
            "optimizer_ms": optimizer_ms,
            "optimizer_net_ms": optimizer_net_ms,
            "scene_start_ms": opt.scene_start_ms,
        }

    # (label, metrics_dict) pairs; kept so the combined metrics plot can be re-enabled.
    plot_phases = []

    if optimizer_output is not None:
        # Init is already spliced into position 0 of optimizer_output lists by
        # SceneTrainer.get_optimized_gaussians (see _insert_init_into_output).

        # Run evaluation and saving
        opt_metrics = self._eval_and_save(
            self.scene_trainer.optimizer,
            batch,
            batch_idx,
            optimizer_output,
            output_path,
            extra_scene_metrics=_scene_timing_metrics,
        )
        opt_label = self.scene_trainer.optimizer.__class__.__name__.lower()
        plot_phases.append((opt_label, opt_metrics))

        # updates, parameters and gradients visualizations
        # self.debugging(optimizer_output, output_path, batch["scene"][0])

    # Post-processing
    postprocessed_output = self.test_postprocess_gaussians(
        batch,
        gaussians=optimizer_output.gaussian_list[-1] if optimizer_output is not None else init_output.gaussians,
        visualization_dump={}
    )

    # Evaluate and save post-processing
    if postprocessed_output is not None:
        pp_metrics = self._eval_and_save(
            self.scene_trainer.postprocess,
            batch,
            batch_idx,
            postprocessed_output,
            output_path
        )
        pp_label = self.scene_trainer.postprocess.__class__.__name__.lower()
        plot_phases.append((pp_label, pp_metrics))

    # Combined metrics plot (optimizer + postprocessing) is currently disabled:
    # self._plot_combined_metrics(
    #     output_path=output_path,
    #     scene_name=scene_name,
    #     phases=plot_phases,
    # )
def experimental_process_batch(self, batch: BatchedExample) -> BatchedExample:
    """Add Gaussian noise to the context and target images, in place.

    The noise standard deviation comes from
    ``test_cfg.experimental_add_noise_to_images_std``. Noisy images are
    clamped to [0, 1] and replace ``batch[key]["image"]``; the untouched
    originals are stored under ``batch[key]["clean_image"]`` so evaluation
    can still use them. Returns the same batch object.
    """
    sigma = self.test_cfg.experimental_add_noise_to_images_std
    for view_set in ("context", "target"):
        entry = batch[view_set]
        clean = entry["image"]  # [B, V, 3, H, W]
        perturbation = torch.randn_like(clean) * sigma
        entry["image"] = torch.clamp(clean + perturbation, 0.0, 1.0)
        entry["clean_image"] = clean  # keep clean images for evaluation
    return batch
@torch.no_grad()
@rank_zero_only
def validation_step(self, scene_batch: BatchedExample, batch_idx: int):
    """Validate on a single scene and log metrics and visualizations.

    Initializes the Gaussians, optionally refines them with the scene
    optimizer, logs PSNR/SSIM of the final target render, then logs depth
    visualizations, a side-by-side comparison image, projections and videos
    according to ``train_cfg``. Returns early with a warning when
    initialization or refinement raises SkipBatchException. Asserts a batch
    size of 1.
    """
    scene_batch = self.initializer_data_shim(scene_batch)

    self.scene_initializer.preprocessing(scene_batch, self.train_cfg)

    if self.global_rank == 0:
        print(
            f"validation step {self.global_step}; "
            f"scene_name = {[a[:20] for a in scene_batch['scene']]}; "
            f"context = {scene_batch['context']['index'].tolist()}; "
            f"target = {scene_batch['target']['index'].tolist()}"
        )

    context_images = scene_batch["context"]["image"]
    target_images = scene_batch["target"]["image"]
    cfg = self.train_cfg

    # Render Gaussians.
    b, _, _, h, w = context_images.shape
    assert b == 1

    try:
        initializer_output = self.get_init_gaussians(scene_batch, is_training=False)
    except SkipBatchException as e:
        warn(f"Skipping validation for scene {scene_batch['scene'][0]} due to error in initialization: {e}")
        return

    output_softmax = self.scene_decoder.forward_target(
        initializer_output.gaussians, scene_batch, (h, w),
        depth_mode='depth' if cfg.eval_render_depth or cfg.viz_render_depth else None,
    )

    # refine
    debug_dict = {}
    if self.scene_optimizer is not None:
        try:
            optimizer_output = self.get_optimized_gaussians(
                scene_batch,
                initializer_output,
                debug_dict=debug_dict
            )
        except SkipBatchException as e:
            warn(f"Skipping validation for scene {scene_batch['scene'][0]} due to error: {e}")
            return
        # The last target render of the refinement replaces the initializer's render.
        output_softmax = optimizer_output.target_render_list[-1]

    # Move prediction back to device
    rgb_softmax = output_softmax.color[0].to(target_images.device)

    # Compute validation metrics.
    rgb_gt = target_images[0]
    psnr = compute_psnr(rgb_gt, rgb_softmax)
    self.log("val/psnr_val", psnr)
    ssim = compute_ssim(rgb_gt, rgb_softmax)
    self.log("val/ssim_val", ssim)

    # viz depth
    if initializer_output.depths is not None and cfg.viz_depth_separate:
        # only visualize predicted depth
        pred_depths = initializer_output.depths[0]  # [V, H, W]

        # downsample image to depth resolution
        if pred_depths.shape[1:] != context_images.shape[-2:]:
            input_images = F.interpolate(
                context_images[0],
                size=pred_depths.shape[-2:],
                mode="bilinear",
                align_corners=True,
            ).squeeze(1)
        else:
            input_images = context_images[0]  # [N, 3, H, W]

        depth_panel = self._make_depth_viz(1.0 / pred_depths, input_images)

        # reshape when the number of input images is too large
        # otherwise the image will be too wide
        num_inputs = input_images.shape[0]
        width = input_images.shape[-1]
        if num_inputs > 8:
            rows = 4
            assert num_inputs % rows == 0
            row_width = width * (num_inputs // rows)
            strips = [depth_panel[:, :, row_width * r: row_width * (r + 1)] for r in range(rows)]
            depth_panel = torch.cat(strips, dim=1)  # [3, H*2*R, W*N/R]

            # resize to half resolution to save space
            # NOTE(review): nesting was reconstructed from a whitespace-flattened source;
            # confirm this resize belongs inside the `num_inputs > 8` branch.
            depth_panel = F.interpolate(depth_panel.unsqueeze(0), scale_factor=0.5, mode='bilinear',
                                        align_corners=True).squeeze(0)

        self.logger.log_image(
            "depth",
            [depth_panel],
            step=self.global_step,
            caption=scene_batch["scene"],
        )

    # viz rendered depth
    if cfg.eval_render_depth or cfg.viz_render_depth:
        render_depth = output_softmax.depth[0]  # [V, H, W]
        render_panel = self._make_depth_viz(
            1.0 / render_depth.clamp(min=0.01, max=1000.),
            target_images[0],  # [N, 3, H, W]
        )

        self.logger.log_image(
            "render_depth",
            [render_panel],
            step=self.global_step,
            caption=scene_batch["scene"],
        )

    # Subsample context images when there are too many to fit comfortably side-by-side
    n_ctx = context_images[0].shape[0]
    if n_ctx > 16:
        stride = 4
    elif n_ctx > 8:
        stride = 2
    else:
        stride = 1
    viz_input = context_images[0][::stride]
    tag = "Context" if stride == 1 else f"Context (1/{stride})"

    comparison = self._build_comparison_image(
        initializer_output, viz_input, tag, rgb_gt, rgb_softmax, stride
    )

    self.logger.log_image(
        "comparison",
        [prep_image(add_border(comparison))],
        step=self.global_step,
        caption=scene_batch["scene"],
    )

    if not cfg.no_log_projections:
        # Render projections and construct projection image.
        projections = hcat(
            *render_projections(
                initializer_output.gaussians,
                256,
                extra_label="(Prediction)",
            )[0]
        )
        self.logger.log_image(
            "projection",
            [prep_image(add_border(projections))],
            step=self.global_step,
        )

    # Draw cameras.
    # cameras = hcat(*render_cameras(batch, 256))
    # self.logger.log_image(
    #     "cameras", [prep_image(add_border(cameras))], step=self.global_step
    # )

    # Run video validation step.
    if not cfg.no_viz_video:
        self.render_video_interpolation(scene_batch)
        # self.render_video_wobble(batch)
        # NOTE(review): nesting reconstructed from a flattened source; confirm the
        # exaggerated video is meant to be gated by `no_viz_video` as well.
        if cfg.extended_visualization:
            self.render_video_interpolation_exaggerated(scene_batch)
def _build_comparison_image(
        self,
        initializer_output: InitializerOutput,
        viz_input: Tensor,
        tag: str,
        rgb_gt: Tensor,
        rgb_softmax: Tensor,
        stride: int,
) -> Tensor:
    """Build the side-by-side comparison image for validation logging.

    Columns, left to right: the (possibly subsampled) context views labelled
    ``tag``, the target ground truth and the target prediction. When depth is
    not visualized separately and the initializer produced depths, an
    inverse-depth column is inserted right after the context column, resized
    to the context resolution and subsampled with the same ``stride``.
    """
    columns = [
        add_label(vcat(*viz_input), tag),
        add_label(vcat(*rgb_gt), "Target (Ground Truth)"),
        add_label(vcat(*rgb_softmax), "Target (Prediction)"),
    ]

    show_depth_column = not self.train_cfg.viz_depth_separate and initializer_output.depths is not None
    if show_depth_column:
        pred_depths = initializer_output.depths[0]  # [V, H, W]
        num_views = pred_depths.shape[0]

        # Stack all views vertically so they are colorized in a single call.
        stacked_inverse_depth = rearrange(1.0 / pred_depths, "v h w -> (v h) w")
        colored = viz_depth_tensor(stacked_inverse_depth.cpu().detach())
        colored = colored.to(pred_depths.device).float() / 255.
        per_view = rearrange(colored, "c (v h) w -> v c h w", v=num_views)

        if per_view.shape[-2:] != viz_input.shape[-2:]:
            per_view = F.interpolate(per_view, size=viz_input.shape[-2:], mode='bilinear', align_corners=True)

        per_view = per_view[::stride]
        columns.insert(1, add_label(vcat(*per_view), "Depth (Prediction)"))

    return hcat(*columns)


@staticmethod
def _make_depth_viz(inverse_depth: Tensor, images: Tensor) -> Tensor:
    """Combine inverse-depth colormap with RGB images for logging. Returns [3, H*2, W*V].

    The views are laid out side by side; the RGB strip (scaled by 255) sits
    on top of the colorized inverse-depth strip.
    """
    depth_strip = torch.cat(list(inverse_depth), dim=1).cpu().detach()  # [H, W*V]
    colored_depth = viz_depth_tensor(depth_strip)  # [3, H, W*V]
    rgb_strip = torch.cat(list(images), dim=-1).cpu().detach() * 255  # [3, H, W*V]
    return torch.cat((rgb_strip, colored_depth), dim=1)  # [3, H*2, W*V]
def on_fit_start(self):
    """Register the custom step metric on the logger's experiment.

    Defines "inner_iteration" and makes it the step metric for every
    "test/psnr/*" series. Does nothing when the logger has no experiment.
    """
    experiment = self.logger.experiment
    if experiment is None:
        return
    experiment.define_metric("inner_iteration")
    experiment.define_metric("test/psnr/*", step_metric="inner_iteration")
+ b, v, _, h, w = batch["target"]["image"].shape + assert b == 1 + if batch_idx < time_skip_first_n_steps: + time_skip_steps_dict["encoder"] += 1 + time_skip_steps_dict["decoder"] += v + + with self.benchmarker.time("encoder"): + init_output = self.get_init_gaussians(batch, is_training=False) + + with self.benchmarker.time("decoder", num_calls=v): + output_probabilistic = self.scene_decoder.forward_target( + init_output.gaussians, batch, (h, w), + depth_mode='depth' if self.train_cfg.eval_render_depth or self.train_cfg.viz_render_depth else None, + ) + + init_rgb = output_probabilistic.color[0] + + # refine + if self.scene_optimizer is not None: + try: + optimizer_output = self.get_optimized_gaussians(batch, init_output) + except SkipBatchException as e: + warn(f'Skipping batch due to SkipBatch during optimization: {e}') + continue + render_output = optimizer_output.target_render_list + output_probabilistic = render_output[-1] + + rgbs = [init_rgb] + if self.scene_trainer_cfg.num_update_steps > 0: + rgbs += [render.color[0] for render in render_output] + tags = ["probabilistic"] * len(rgbs) + + if self.train_cfg.eval_deterministic: + gaussians_deterministic = self.encoder( + batch["context"], + self.global_step, + deterministic=True, + ) + output_deterministic = self.scene_decoder.forward( + gaussians_deterministic, + batch["target"]["extrinsics"], + batch["target"]["intrinsics"], + batch["target"]["near"], + batch["target"]["far"], + (h, w), + ) + rgbs.append(output_deterministic.color[0]) + tags.append("deterministic") + + # Compute validation metrics. 
+ rgb_gt = batch["target"]["image"][0] + if self.scene_optimizer is not None: + steps = self.scene_optimizer.save_every.get_iterations(len(rgbs)) + else: + steps = [0] + for i, (tag, rgb) in enumerate(zip(tags, rgbs)): + # Move prediction back to device + rgb = rgb.to(batch["target"]["image"].device) + metric_scores: dict = compute_rgb_metrics( + rgb, rgb_gt, + metrics=["psnr", "ssim", "lpips"], + iter_batch_size=-1, + ) + for name, score in metric_scores.items(): + if name == "lpips": + # tuple of (alex, vgg) + scores_dict[f"alex_lpips_{steps[i]}"][tag].append(score[0].item()) + scores_dict[f"vgg_lpips_{steps[i]}"][tag].append(score[1].item()) + else: + scores_dict[f"{name}_{steps[i]}"][tag].append(score.item()) + # log the last step metrics to compare between runs + if i == len(rgbs) - 1: + if name == "lpips": + # tuple of (alex, vgg) + scores_dict[f"alex_lpips"][tag].append(score[0].item()) + scores_dict[f"vgg_lpips"][tag].append(score[1].item()) + else: + scores_dict[f"{name}"][tag].append(score.item()) + + # compute depth metrics + if pred_depths is not None and depth_gt is not None and depth_gt.max() > 0: + assert pred_depths is not None and depth_gt is not None + + pred_depths = pred_depths[0] # [V, H, W] + + # gaussian downsample + if pred_depths.shape[1:] != batch["context"]["image"].shape[-2:]: + pred_depths = F.interpolate( + pred_depths.unsqueeze(1), + size=batch["context"]["image"].shape[-2:], + mode="bilinear", + align_corners=True, + ).squeeze(1) + + depth_gt = depth_gt[0] # [V, H, W] + + near = batch["context"]["near"][..., + None, None][0] # [V, 1, 1] + far = batch["context"]["far"][..., None, None][0] # [V, 1, 1] + + valid = (depth_gt >= near) & (depth_gt <= far) + + all_metrics = compute_depth_errors(depth_gt[valid].detach().cpu().numpy(), + pred_depths[valid].detach().cpu().numpy()) + scores_dict["abs_rel"]["probabilistic"].append(all_metrics[0]) + scores_dict["rmse"]["probabilistic"].append(all_metrics[2]) + 
scores_dict["a1"]["probabilistic"].append(all_metrics[4]) + + # compute rendered depth metrics + if self.train_cfg.eval_render_depth: + render_depth = output_probabilistic.depth + target_depth_gt = batch["target"]["depth"] + + pred_depths = render_depth[0] # [V, H, W] + depth_gt = target_depth_gt[0] # [V, H, W] + + near = batch["target"]["near"][..., None, None][0] # [V, 1, 1] + far = batch["target"]["far"][..., None, None][0] # [V, 1, 1] + + valid = (depth_gt >= near) & (depth_gt <= far) + + all_metrics = compute_depth_errors(depth_gt[valid].detach().cpu().numpy(), + pred_depths[valid].detach().cpu().numpy()) + + scores_dict["render_abs_rel"]["probabilistic"].append(all_metrics[0]) + scores_dict["render_rmse"]["probabilistic"].append(all_metrics[2]) + scores_dict["render_a1"]["probabilistic"].append(all_metrics[4]) + + # summarise scores and log to logger + # Create wandb table for inner iteration visualization + # For now, log only psnr + if hasattr(self.logger, 'experiment') and self.logger.experiment is not None: + # Extract metrics that have step numbers (e.g., "psnr_0", "psnr_1", etc.) + inner_iteration_data = [] + for score_tag, methods in scores_dict.items(): + # Check if this is a step-specific metric (e.g., "psnr_0", "psnr_1", etc.) 
+ if '_' in score_tag and score_tag.split('_')[-1].isdigit(): + metric_name, step_str = score_tag.rsplit('_', 1) + inner_step = int(step_str) + + if metric_name not in ["psnr"]: + continue + + for method_tag, cur_scores in methods.items(): + if len(cur_scores) > 0: + cur_mean = sum(cur_scores) / len(cur_scores) + inner_iteration_data.append({ + 'meta_iteration': self.global_step, + 'inner_iteration': inner_step, + 'metric_name': metric_name, + 'method': method_tag, + 'value': cur_mean + }) + + # Log the table if we have inner iteration data + if inner_iteration_data: + try: + # Rewrite the chart (wandb cannot append to the current figure (?)) + df = pd.DataFrame(inner_iteration_data) + df["meta_iteration_str"] = df["meta_iteration"].astype(str) + metric_to_plot = "psnr" + df_metric = df[df["metric_name"] == metric_to_plot] + + table = wandb.Table(dataframe=df_metric) + + # self.logger.experiment.log({f"{metric_to_plot}_line": wandb.plot.line( + # table, + # x="inner_iteration", + # y="value", + # title=f"{metric_to_plot} per inner iteration", + # stroke="meta_iteration_str", + # )}) + + # Plot psnr for current meta iteration in a separate chart + # run = self.logger.experiment + current_meta = self.global_step + + df_current = df_metric[df_metric["meta_iteration"] == current_meta] + + if len(df_current) > 1: + run = self.logger.experiment + run.define_metric("inner_iteration") + run.define_metric(f"test/psnr/meta_{current_meta}", step_metric="inner_iteration") + for _, row in df_current.iterrows(): + run.log({ + "inner_iteration": row["inner_iteration"], + f"test/psnr/meta_{current_meta}": row["value"], + }) + + except Exception as e: + warn(f"Could not create automatic charts: {e}") + # Fallback: just log the table + pass + + # Keep the original logging + for score_tag, methods in scores_dict.items(): + for method_tag, cur_scores in methods.items(): + if len(cur_scores) > 0: + cur_mean = sum(cur_scores) / len(cur_scores) + self.log(f"test/{score_tag}", cur_mean) + 
# summarise run time + for tag, times in self.benchmarker.execution_times.items(): + times = times[int(time_skip_steps_dict[tag]):] + print(f"{tag}: {len(times)} calls, avg. {np.mean(times)} seconds per call") + self.log(f"test/runtime_avg_{tag}", np.mean(times)) + self.benchmarker.clear_history() + + overall_eval_time = time.time() - start_t + psnr_list = [scores_dict[f"psnr_{i}"]["probabilistic"] for i in + range(self.scene_trainer_cfg.num_update_steps + 1)] + psnr_list = [sum(pnsr) / len(pnsr) for pnsr in psnr_list if len(pnsr) > 0] + psnr_str = ", ".join(f"psnr_{i}: {np.mean(pnsr):.3f}" for i, pnsr in enumerate(psnr_list)) + example_num = len(scores_dict['psnr_0']['probabilistic']) + print(f"Eval total time cost: {overall_eval_time:.3f}s, {psnr_str}, example_num: {example_num} ") + self.log("test/runtime_all", overall_eval_time) + + @staticmethod + def _get_renders_list( + output: OptimizerOutput | InitializerOutput, + input_str: str, + module: Initializer | Optimizer | PostProcessing3DGS, + ) -> tuple[list, list[int]]: + """Get render list and corresponding iteration indices for a given view tag ('context' or 'target').""" + if isinstance(output, OptimizerOutput): + renders_list = output.get_render_list(input_str) + elif isinstance(output, InitializerOutput): + if input_str == "context": + assert output.context_render is not None, "InitializerOutput must contain context_render" + renders_list = [output.context_render] + elif input_str == "target": + assert output.target_render is not None, "InitializerOutput must contain target_render" + renders_list = [output.target_render] + else: + raise ValueError(f"Unknown input_str: {input_str}") + else: + raise ValueError(f"Unknown output type: {type(output)}") + iterations = [0] if isinstance(module, Initializer) else module.save_every.get_iterations(len(renders_list)) + return renders_list, iterations + + @staticmethod + def _compute_depth_range( + output: OptimizerOutput | InitializerOutput, + input_strs: list[str], 
+ module: Initializer | Optimizer | PostProcessing3DGS, + ) -> tuple[float, float]: + """Scan all renders to find the global depth min/max for consistent visualization.""" + depth_vmin, depth_vmax = np.inf, -np.inf + have_depths = True + for input_str in input_strs: + renders_list, _ = MetaTrainer._get_renders_list(output, input_str, module) + assert renders_list is not None, f"No renders found for {input_str}" + for iter_renders in renders_list: + iter_depths = iter_renders.depth # (1, V, H, W) + if iter_depths is None: + have_depths = False + continue + iter_depths = iter_depths[0] # (V, H, W) + depth_vmin = min(depth_vmin, iter_depths.min().item()) + depth_vmax = max(depth_vmax, iter_depths.max().item()) + if not have_depths: + depth_vmin, depth_vmax = 0.0, 1.0 + return depth_vmin, depth_vmax + + def _compute_error_vmax( + self, + output: OptimizerOutput | InitializerOutput, + input_strs: list[str], + module: Initializer | Optimizer | PostProcessing3DGS, + batch: BatchedExample, + ) -> float: + # Per-scene error range for error-map visualization. The 99th percentile + # (rather than the raw max) keeps a single outlier pixel from flattening + # the magma color scale across the whole scene. 
+ if not self.test_cfg.save_error_image: + return 1.0 + error_values = [] + for input_str in input_strs: + renders_list, _ = self._get_renders_list(output, input_str, module) + if renders_list is None: + continue + if "clean_image" in batch[input_str]: + rgb_gt = batch[input_str]["clean_image"][0] # (V, 3, H, W) + else: + rgb_gt = batch[input_str]["image"][0] # (V, 3, H, W) + for iter_renders in renders_list: + iter_rgbs = iter_renders.color[0] # (V, 3, H, W) + err = (iter_rgbs - rgb_gt.to(iter_rgbs)).abs().mean(1) # (V, H, W) + error_values.append(err.flatten().cpu().numpy()) + del iter_rgbs, err + if not error_values: + return 1.0 + return float(np.percentile(np.concatenate(error_values), 99)) + + def _compute_and_save_scores( + self, + module: Initializer | Optimizer | PostProcessing3DGS, + output: OptimizerOutput | InitializerOutput, + renders_list: list, + rgb_gt: Tensor, + iterations: list[int], + input_str: str, + module_name: str, + out_dir: Path, + extra_scene_metrics: dict | None, + ) -> None: + """Compute RGB metrics per iteration, accumulate into output_dict, and save per-scene JSON.""" + # Collect per-step stats logs from the module + if isinstance(module, Initializer): + nr_gaussians_log = [output.gaussians.means.shape[1]] + iter_time_log = [0.0] + nr_nonzero_grads_log = [0.0] + elif isinstance(module, (Optimizer, PostProcessing3DGS)): + nr_gaussians_log = module.nr_gaussians_log + nr_nonzero_grads_log = module.nr_nonzero_grad_log + iter_time_log = module.iter_time_log + iter_time_log[0] = 0.0 + else: + raise ValueError(f"Unknown module type: {type(module)}") + + self.init_output_dict_for_new_scene(input_str=input_str, tag=module_name) + output_dict = self.test_step_outputs_context if input_str == "context" else self.test_step_outputs_target + out_dir.mkdir(parents=True, exist_ok=True) + + for i, step in tqdm(enumerate(iterations), desc=f"Evaluating {input_str}", total=len(iterations)): + is_last = (i == len(iterations) - 1) + nr_iter = iterations[i] 
+ # j: index into per-step logs; last step uses nr_iter-1 because logs are 0-indexed up to nr_iter + j = nr_iter - 1 if is_last else nr_iter + iter_rgb = renders_list[i].color[0] # (V, 3, H, W) + + scores: dict = compute_rgb_metrics( + iter_rgb, + rgb_gt, + metrics=self.test_cfg.compute_scores_metrics, + iter_batch_size=self.test_cfg.metrics_batch_size, + ) + + if nr_gaussians_log is not None: + assert j <= len(nr_gaussians_log), f"{j}, {len(nr_gaussians_log)}" + scores["gaussians"] = torch.tensor(nr_gaussians_log[j]) + + if nr_nonzero_grads_log is not None and nr_nonzero_grads_log: + assert j <= len(nr_nonzero_grads_log), f"{j}, {len(nr_nonzero_grads_log)}" + scores["nonzero_grads"] = torch.tensor(nr_nonzero_grads_log[j]) + + if iter_time_log is not None: + assert j <= len(iter_time_log), f"{j}, {len(iter_time_log)}" + scores["time"] = torch.tensor(sum(iter_time_log[:j + 1])) + + for name, score in scores.items(): + if name == "lpips": + output_dict[f"{module_name}_alex_lpips"][-1].append(score[0].item()) + output_dict[f"{module_name}_vgg_lpips"][-1].append(score[1].item()) + else: + output_dict[f"{module_name}_{name}"][-1].append(score.item()) + output_dict[f"{module_name}_iterations"][-1].append(nr_iter) + + del iter_rgb + + # Save per-scene metrics to JSON + last_scene_metrics = {key: vals[-1] for key, vals in output_dict.items()} + if extra_scene_metrics: + last_scene_metrics.update(extra_scene_metrics) + metrics_save_path = out_dir / f"{input_str}_{module_name}.json" + with metrics_save_path.open("w") as f: + print(f"Saving metrics to {metrics_save_path}") + json.dump(last_scene_metrics, f, indent=4) + + @torch.no_grad() + def _eval_and_save( + self, + module: Initializer | Optimizer | PostProcessing3DGS, + batch: BatchedExample, + batch_idx: int, + output: OptimizerOutput | InitializerOutput, + output_path: Path, + extra_scene_metrics: dict | None = None, + ) -> dict: + """Evaluate and save results. 
Returns collected metrics dict (keyed by module_name_metric).""" + module_name = module.__class__.__name__.lower() + + output_path = CustomPath(output_path / module_name) + output_path.mkdir(parents=True, exist_ok=True) + + target_shape = batch["target"]["image"].shape # [B, V, 3, H, W] + context_shape = batch["context"]["image"].shape # [B, V, 3, H, W] + assert target_shape[-3:] == context_shape[-3:], f"{target_shape}, {context_shape}" + b, v, _, h, w = target_shape + assert b == 1, "Evaluation only supports scene batch size 1." + scene_name = batch["scene"][0] + + # Save poses + if self.test_cfg.save_poses: + poses_data = { + "context": {"shape": context_shape[-2:]}, + "target": {"shape": target_shape[-2:]}, + } + for key in ["extrinsics", "intrinsics", "near", "far"]: + poses_data["context"][key] = batch["context"][key][0].cpu().numpy().tolist() + poses_data["target"][key] = batch["target"][key][0].cpu().numpy().tolist() + save_path = output_path / 'poses' / f"{scene_name}_poses.json" + save_path.parent.mkdir(parents=True, exist_ok=True) + print(f"Saving poses to {save_path.parent}") + with open(save_path, 'w') as f: + json.dump(poses_data, f, indent=4) + + # Save gaussians + if self.test_cfg.save_gaussian: + if isinstance(output, InitializerOutput): + save_path = output_path / 'gaussians' / scene_name / 'init.ply' + save_gaussian_ply(output.gaussians, save_path) + elif isinstance(output, OptimizerOutput): + iterations = module.save_every.get_iterations(len(output.gaussian_list)) + for step, iter_gaussians in zip(iterations, output.gaussian_list): + save_path = output_path / 'gaussians' / scene_name / f'step{step}.ply' + save_gaussian_ply(iter_gaussians, save_path) + else: + raise ValueError(f"Unknown output type: {type(output)}") + + input_strs = ["target"] + if self.test_cfg.eval_context_views: + input_strs.insert(0, "context") + + depth_vmin, depth_vmax = self._compute_depth_range(output, input_strs, module) + error_vmax = self._compute_error_vmax(output, 
input_strs, module, batch) + + for input_str in input_strs: + indices = batch[input_str]["index"][0] # (V,) # TODO Naama: bug when using opt bs > 0 + renders_list, iterations = self._get_renders_list(output, input_str, module) + if renders_list is None: + continue + + if "clean_image" in batch[input_str]: + rgb_gt = batch[input_str]["clean_image"][0].cpu() # (V, 3, H, W) + else: + rgb_gt = batch[input_str]["image"][0].cpu() # (V, 3, H, W) + depth_gt = batch[input_str].get("depth", None) + + # save pred rgbs + if self.test_cfg.save_render_image: + if self.test_cfg.save_render_image_last_only: + self.test_save_last_rendered_images(renders_list, indices, output_path, scene_name, input_str) + else: + self.test_save_rendered_images(renders_list, indices, output_path, scene_name, input_str) + + # save gt rgbs + if self.test_cfg.save_gt_image and rgb_gt is not None: + self.test_save_gt_images(rgb_gt, indices, output_path, scene_name, input_str) + + # save rgb error maps + if self.test_cfg.save_error_image and rgb_gt is not None: + self.test_save_rendered_errors( + renders_list, rgb_gt, indices, output_path, scene_name, input_str, + vmin=0.0, + vmax=error_vmax, + ) + + # save depths + if self.test_cfg.save_render_depth: + self.test_save_rendered_depth(renders_list, indices, output_path, scene_name, input_str, + vmin=depth_vmin, vmax=depth_vmax) + + # save gt depths + if self.test_cfg.save_gt_depth and depth_gt is not None: + self.test_save_gt_depth(depth_gt, indices, output_path, scene_name, input_str, + vmin=depth_vmin, vmax=depth_vmax) + + # save video + # TODO Naama: reorganize video rendering + # Note: when video mode is enabled this returns early, skipping score computation. 
+ if module is not None and self.test_cfg.save_video and isinstance(output, OptimizerOutput): + # Generate only for the first view in the batch + # Generate a video with optimization trajectory for the first view (using ffmpeg) + if input_str == "target": + if self.test_cfg.save_video_fixed_view: + self.render_supp_videos(batch, h, input_str, iterations, output.gaussian_list, output_path, + scene_name, v, w, fixed_view_video=True, video_type="fixed_view") + if self.test_cfg.save_video_fixed_iteration: + for t in self.test_cfg.save_video_fixed_iteration_indices: + self.render_supp_videos(batch, h, input_str, iterations, output.gaussian_list, output_path, + scene_name, v, w, + fixed_view_video=self.test_cfg.save_video_fixed_iteration_render_fixed_view, + # render a fixed view until the required iteration + fixed_iteration_video=True, + fixed_iteration_indices=[t], + video_type="fixed_iteration") + if self.test_cfg.save_video_combined: + self.render_supp_videos(batch, h, input_str, iterations, output.gaussian_list, output_path, + scene_name, v, w, + fixed_view_video=True, + fixed_iteration_video=True, + fixed_iteration_indices=self.test_cfg.save_video_combined_iterations, + fixed_iteration_length=self.test_cfg.save_video_combined_fixed_iteration_length, + video_type="combined") + return + + # Compute scores + if self.test_cfg.compute_scores: + print("\nComputing scores...") + self._compute_and_save_scores( + module, output, renders_list, rgb_gt, iterations, input_str, module_name, + out_dir=output_path / "metrics" / scene_name, + extra_scene_metrics=extra_scene_metrics, + ) + + # Merge metrics from target (and context if evaluated) for combined plotting + all_metrics = {} + if self.test_cfg.compute_scores: + for key, vals in self.test_step_outputs_target.items(): + if vals: + all_metrics[key] = vals[-1] + for key, vals in self.test_step_outputs_context.items(): + if vals: + all_metrics[key] = vals[-1] + return all_metrics + + @staticmethod + def 
_plot_combined_metrics( + output_path: Path, + scene_name: str, + phases: list[tuple[str, dict]], + ): + """Create a combined metrics plot for optimizer + postprocessing per scene. + + Args: + output_path: Root output directory. + scene_name: Name of the current scene. + phases: List of (label, metrics_dict) tuples in order. Each metrics_dict + has keys like "{label}_psnr", "{label}_iterations", etc. + """ + try: + MetaTrainer._plot_combined_metrics_impl(output_path, scene_name, phases) + except Exception as e: + warn(f"[plot] failed to create combined metrics plot: {e}") + import traceback + traceback.print_exc() + + @staticmethod + def _plot_combined_metrics_impl( + output_path: Path, + scene_name: str, + phases: list[tuple[str, dict]], + ): + from matplotlib import pyplot as plt + + # Filter out phases with no data + phases = [(label, data) for label, data in phases if data] + if not phases: + print("[plot] No metrics data available, skipping combined plot.") + return + + plot_metrics = ["psnr", "ssim", "alex_lpips", "vgg_lpips", "gaussians"] + + # Build combined series for each metric + plots = [] # (title, x_values_list, y_values_list, labels_list, divider_x) + for metric in plot_metrics: + combined_x = [] + combined_y = [] + combined_labels = [] + x_offset = 0 + divider_x = None + + for phase_idx, (label, data) in enumerate(phases): + iter_key = f"{label}_iterations" + metric_key = f"{label}_{metric}" + + iterations = data.get(iter_key, []) + values = data.get(metric_key, []) + + if not iterations or not values: + continue + + n = min(len(iterations), len(values)) + xs = [x_offset + iterations[j] for j in range(n)] + ys = values[:n] + + combined_x.append(xs) + combined_y.append(ys) + combined_labels.append(label) + + if phase_idx < len(phases) - 1 and xs: + divider_x = xs[-1] + x_offset = divider_x + + if combined_x: + plots.append((metric, combined_x, combined_y, combined_labels, divider_x)) + + if not plots: + print("[plot] No plottable metrics found, 
skipping combined plot.") + return + + fig, axes = plt.subplots(len(plots), 1, figsize=(10, 3.5 * len(plots)), squeeze=False) + axes = axes[:, 0] + + # Metrics where lower is better + lower_is_better = {"alex_lpips", "vgg_lpips"} + + for ax, (metric_name, x_lists, y_lists, labels, divider_x) in zip(axes, plots): + for xs, ys, label in zip(x_lists, y_lists, labels): + ax.plot(xs, ys, marker=".", markersize=3, label=label) + if divider_x is not None: + ax.axvline(x=divider_x, color="gray", linestyle="--", linewidth=1, alpha=0.7) + + # Find and annotate the best value across all phases + all_ys = [v for ys in y_lists for v in ys] + if all_ys: + if metric_name in lower_is_better: + best_val = min(all_ys) + else: + best_val = max(all_ys) + ax.axhline(y=best_val, color="red", linestyle=":", linewidth=1, alpha=0.6) + ax.text( + 1.0, best_val, f" best={best_val:.4f}", + transform=ax.get_yaxis_transform(), + va="bottom", ha="right", fontsize=7, color="red", + ) + + ax.set_title(metric_name) + ax.set_xlabel("iteration") + ax.set_ylabel(metric_name) + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + + fig.suptitle(f"Scene: {scene_name}", fontsize=12, y=1.0) + fig.tight_layout() + + plot_dir = output_path / "plots" / scene_name + plot_dir.mkdir(parents=True, exist_ok=True) + save_path = plot_dir / "combined_metrics.png" + plt.savefig(save_path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved combined metrics plot to {save_path}") + + # region ==================== Save Results Methods ======================= + @staticmethod + def test_save_rendered_images(renders_list: list, indices, output_path, scene_name, input_str): + out_dir = output_path / "images" / scene_name / f"color_{input_str}" + for i, index in tqdm(enumerate(indices), desc=f"Saving {input_str} images"): + color = [] + for iter_renders in renders_list: + iter_rgbs = iter_renders.color[0] # (V, 3, H, W) + color.append(iter_rgbs[i]) + color = torch.cat(color, dim=-1) # concat along width + 
save_image(color, out_dir / f"{index:06d}.png") + del iter_rgbs + del color + # save last image separately too + MetaTrainer.test_save_last_rendered_images(renders_list, indices, output_path, scene_name, input_str) + + @staticmethod + def test_save_last_rendered_images(renders_list: list, indices, output_path, scene_name, input_str): + out_dir = output_path / "images" / scene_name / "last" / f"color_{input_str}" + out_dir.mkdir(parents=True, exist_ok=True) + for i, index in tqdm(enumerate(indices), desc=f"Saving {input_str} last images"): + iter_renders = renders_list[-1] + iter_rgbs = iter_renders.color[0] # (V, 3, H, W) + color = iter_rgbs[i] + save_image(color, out_dir / f"{index:06d}.png") + del iter_rgbs + del color + + @staticmethod + def test_save_gt_images(rgb_gt, indices, output_path, scene_name, input_str): + out_dir = output_path / "images" / scene_name / f"color_{input_str}" + for index, gt in tqdm(zip(indices, rgb_gt), desc=f"Saving {input_str} GT images"): + save_image(gt, out_dir / f"{index:06d}_gt.png") + + @staticmethod + def test_save_rendered_depth(renders_list: list, indices, output_path, scene_name, input_str, + vmin: float = 0.0, + vmax: float = 1.0): + out_dir = output_path / "images" / scene_name / f"depth_{input_str}" + for i, index in tqdm(enumerate(indices), desc=f"Saving {input_str} depths"): + depth = [] + for iter_renders in renders_list: + iter_depths = iter_renders.depth # (1, V, 3, H, W) + assert iter_depths is not None, "Depths not found in renders." 
+ iter_depths = iter_depths[0] # (V, 3, H, W) + depth.append(iter_depths[i]) + depth = torch.cat(depth, dim=-1) # concat along width + color = viz_depth_tensor(depth, return_numpy=False, as_uint8=False, vmin=vmin, vmax=vmax) + save_image(color, out_dir / f"{index:06d}.png") + del iter_depths + del color + + @staticmethod + def test_save_gt_depth(depth_gt, indices, output_path, scene_name, input_str, vmin: float = 0.0, + vmax: float = 1.0): + out_dir = output_path / "images" / scene_name / f"depth_{input_str}" + for index, gt in tqdm(zip(indices, depth_gt), desc=f"Saving {input_str} GT depths"): + color = viz_depth_tensor(gt, return_numpy=False, as_uint8=False, vmin=vmin, vmax=vmax) + save_image(color, out_dir / f"{index:06d}_gt.png") + + @staticmethod + def test_save_rendered_errors(renders_list: list, rgb_gt, indices, output_path, scene_name, input_str, + vmin: float = 0.0, + vmax: float = 1.0): + # rgb_gt is (V, 3, H, W). Looping per view (outer) and per iteration (inner) + # keeps only one view's iterations in memory at a time. 
+ out_dir = output_path / "images" / scene_name / f"error_{input_str}" + for i, index in tqdm(enumerate(indices), desc=f"Saving {input_str} errors"): + error_maps = [] + for iter_renders in renders_list: + iter_rgbs = iter_renders.color[0] # (V, 3, H, W) + error_maps.append((iter_rgbs[i] - rgb_gt[i].to(iter_rgbs)).abs().mean(0)) # (H, W) + error_map = torch.cat(error_maps, dim=-1) # concat along width + color = viz_depth_tensor(error_map, return_numpy=False, as_uint8=False, colormap='magma', + vmin=vmin, vmax=vmax) + save_image(color, out_dir / f"{index:06d}.png") + del iter_rgbs, error_maps, error_map + + def save_colmap_test_train_views(self, batch, h, w): + # load the distortion parameters from the original colmap data + assert self.test_cfg.ori_colmap_data_path is not None + (scene_name,) = batch["scene"] + output_path = self.test_cfg.output_path + # training views + input_images = batch["context"]["image"][0] # [V, 3, H, W] + index = batch["context"]["index"][0] + for idx, color in zip(index, input_images): + # NOTE: the original image id starts from 1 + save_image(color, output_path / scene_name / "images_train" / f"frame_{idx + 1:05d}.png") + # testing views + target_images = batch["target"]["image"][0] # [V, 3, H, W] + index = batch["target"]["index"][0] + for idx, color in zip(index, target_images): + # NOTE: the original image id starts from 1 + save_image(color, output_path / scene_name / "images_test" / f"frame_{idx + 1:05d}.png") + # save the camera intrinsics + intrinsics = batch["context"]["intrinsics"][0][0].clone() # [3, 3] + # need to rescale to the image size + intrinsics[0, :] *= w + intrinsics[1, :] *= h + # distortion parameters + json_path = os.path.join(self.test_cfg.ori_colmap_data_path, scene_name, "nerfstudio", "transforms.json") + assert os.path.exists(json_path), f"Cannot find {json_path}" + sparse_save_dir = output_path / scene_name / "sparse" / "0" + sparse_save_dir_train = output_path / scene_name / "sparse_train" / "0" + # save to 
cameras.bin + save_opencv_camera(intrinsics.cpu().numpy(), json_path, sparse_save_dir, image_size=(w, h)) + save_opencv_camera(intrinsics.cpu().numpy(), json_path, sparse_save_dir_train, image_size=(w, h)) + # extract extrinsics from the dense view images.bin + dense_view_extrinsics = os.path.join(self.test_cfg.ori_colmap_data_path, scene_name, + "nerfstudio/colmap/sparse/0") + selected_train_ids = [idx + 1 for idx in batch["context"]["index"][0].tolist()] + selected_test_ids = [idx + 1 for idx in batch["target"]["index"][0].tolist()] + selected_ids = selected_train_ids + selected_test_ids + # also save the sparse features and points3D + extract_sparse_images_bin(dense_view_extrinsics, sparse_save_dir, selected_ids, keep_features=False) + # only for training views: reconstruct the sparse point cloud only from training views + extract_sparse_images_bin(dense_view_extrinsics, sparse_save_dir_train, selected_train_ids) + return + + def compute_depth_scores(self, batch, depth_gt, init_pred_depths): + pred_depths = init_pred_depths[0] # [V, H, W] + depth_gt = depth_gt[0] # [V, H, W] + near = batch["context"]["near"][..., + None, None][0] # [V, 1, 1] + far = batch["context"]["far"][..., None, None][0] # [V, 1, 1] + valid = (depth_gt >= near) & (depth_gt <= far) + all_metrics = compute_depth_errors(depth_gt[valid].detach().cpu().numpy(), + pred_depths[valid].detach().cpu().numpy()) + print(all_metrics) + self.test_step_outputs_target[f"abs_rel"].append( + float(all_metrics[0])) + self.test_step_outputs_target[f"rmse"].append(float(all_metrics[2])) + self.test_step_outputs_target[f"a1"].append(float(all_metrics[4])) + + def init_output_dict_for_new_scene(self, input_str, tag=None): + tag = "" if tag is None else f"{tag}_" + + if input_str == "target": + output_dict = self.test_step_outputs_target + elif input_str == "context": + output_dict = self.test_step_outputs_context + else: + raise ValueError(f"Unknown input_str={input_str}") + + for metric in ["psnr", "ssim", 
"lpips", "iterations", "time", "gaussians", "nonzero_grads"]: + if metric == "lpips": + # alex + key = f"{tag}alex_lpips" + output_dict[key].append([]) + # vgg + key = f"{tag}vgg_lpips" + output_dict[key].append([]) + else: + key = f"{tag}{metric}" + output_dict[key].append([]) + + # endregion + + # region ==================== Video Rendering Methods ==================== + def render_supp_videos(self, batch, h, input_str, all_iterations, gaussian_list, output_path, scene_name, + v, w, fixed_view_video=False, fixed_iteration_video=False, + fixed_iteration_indices=None, + fixed_iteration_length=-1, + video_type=None): + out_dir = output_path / "supp_videos" / scene_name + out_dir.mkdir(parents=True, exist_ok=True) + combined_iterations = [] + + view = self.test_cfg.save_video_fixed_view_index + + all_frames = [] + + duplicate = self.test_cfg.save_video_fixed_view_duplicate # to focus on the optimization steps + + if fixed_iteration_length > 0: + start = view + end = view + fixed_iteration_length + else: + start = None + end = None + + for i, t in enumerate(all_iterations): + # Render only the view + if fixed_view_video: + decoder_output = self.test_render_videos_views(batch, gaussian_list[i], h, v, w, input_str, + start=view, end=view + 1) + frames_t = decoder_output.color[0].detach().cpu() # (1, 3, H, W) + assert frames_t.shape[0] == 1, f"{frames_t.shape}" + all_frames += [frames_t[0]] * duplicate # (3, H, W) + combined_iterations.extend([t] * duplicate) + + if fixed_iteration_video: + if t in fixed_iteration_indices: + # Render a trajectory around the scene + decoder_output = self.test_render_videos_views(batch, gaussian_list[i], h, v, w, input_str, + start=start, end=end) + frames_t = decoder_output.color[0].detach().cpu() # (num_frames, 3, H, W) + for i in range(3): # forward and backward + for frame in frames_t: + all_frames += [frame] * 3 + for frame in frames_t.flip(0): + all_frames += [frame] * 3 + combined_iterations.extend(['orbit'] * frames_t.shape[0] * 2) 
+ if fixed_iteration_video and not fixed_view_video: + break # no need to continue + + if video_type == "combined": + save_str = f"_combined_{view}" + elif video_type == "fixed_view": + save_str = f"_fixed_view_{view}" + elif video_type == "fixed_iteration": + assert len(fixed_iteration_indices) == 1, f"{fixed_iteration_indices}" + save_str = f"_fixed_iteration_{fixed_iteration_indices[0]}" + else: + raise ValueError + save_video(all_frames, out_dir / f"{input_str}_{save_str}.mp4") + with (open(out_dir / f"{input_str}_{save_str}_iterations.json", 'w')) as f: + json.dump(combined_iterations, f, indent=4) + + def test_render_videos_views(self, batch, gaussians, h, v, w, input_str="target", poses=None, start=None, end=None): + gaussians = gaussians.to(batch["target"]["image"].device) + with self.benchmarker.time("decoder", num_calls=v): + if poses is None: + camera_poses = batch[input_str]["extrinsics"] + + if self.test_cfg.stablize_camera: + stable_poses = render_stabilization_path( + camera_poses[0].detach().cpu().numpy(), + k_size=self.test_cfg.stab_camera_kernel, + ) + + stable_poses = list( + map( + lambda x: np.concatenate( + (x, np.array([[0.0, 0.0, 0.0, 1.0]])), axis=0 + ), + stable_poses, + ) + ) + stable_poses = torch.from_numpy(np.stack(stable_poses, axis=0)).to( + camera_poses + ) + camera_poses = stable_poses.unsqueeze(0) + else: + camera_poses = poses.unsqueeze(0) + + if self.test_cfg.render_chunk_size is not None: + assert start is None + assert end is None + chunk_size = self.test_cfg.render_chunk_size + num_chunks = math.ceil(camera_poses.shape[1] / chunk_size) + + output = None + for i in range(num_chunks): + start = chunk_size * i + end = chunk_size * (i + 1) + curr_output = self.scene_decoder.forward_batch(gaussians, batch, (h, w), + input_str=input_str, + start=start, end=end, camera_poses=camera_poses) + + if i == 0: + output = curr_output + else: + # ignore depth + output.color = torch.cat((output.color, curr_output.color), dim=1) + + else: + 
output = self.scene_decoder.forward_batch(gaussians.to(batch["target"]["image"].device), + batch, (h, w), + input_str=input_str, + camera_poses=camera_poses, + start=start, + end=end) + return output + + @rank_zero_only + def render_video_wobble(self, batch: BatchedExample) -> None: + # Two views are needed to get the wobble radius. + _, v, _, _ = batch["context"]["extrinsics"].shape + if v != 2: + return + + def trajectory_fn(t): + origin_a = batch["context"]["extrinsics"][:, 0, :3, 3] + origin_b = batch["context"]["extrinsics"][:, 1, :3, 3] + delta = (origin_a - origin_b).norm(dim=-1) + extrinsics = generate_wobble( + batch["context"]["extrinsics"][:, 0], + delta * 0.25, + t, + ) + intrinsics = repeat( + batch["context"]["intrinsics"][:, 0], + "b i j -> b v i j", + v=t.shape[0], + ) + return extrinsics, intrinsics + + return self.render_video_generic(batch, trajectory_fn, "wobble", num_frames=60) + + @rank_zero_only + def render_video_interpolation(self, batch: BatchedExample) -> None: + _, v, _, _ = batch["context"]["extrinsics"].shape + + def trajectory_fn(t): + extrinsics = interpolate_extrinsics( + batch["context"]["extrinsics"][0, 0], + ( + batch["context"]["extrinsics"][0, 1] + if v == 2 + else batch["target"]["extrinsics"][0, 0] + ), + t, + ) + intrinsics = interpolate_intrinsics( + batch["context"]["intrinsics"][0, 0], + ( + batch["context"]["intrinsics"][0, 1] + if v == 2 + else batch["target"]["intrinsics"][0, 0] + ), + t, + ) + return extrinsics[None], intrinsics[None] + + return self.render_video_generic(batch, trajectory_fn, "rgb") + + @rank_zero_only + def render_video_interpolation_exaggerated(self, batch: BatchedExample) -> None: + # Two views are needed to get the wobble radius. 
    @rank_zero_only
    def render_video_generic(
            self,
            batch: BatchedExample,
            trajectory_fn: TrajectoryFn,
            name: str,
            num_frames: int = 30,
            smooth: bool = True,
            loop_reverse: bool = True,
    ) -> None:
        """Render a camera trajectory as an RGB + depth video and log it as ``video/{name}``.

        Args:
            batch: batch whose ``"context"`` entry is fed to ``self.encoder`` and
                supplies the image size and near/far planes.
            trajectory_fn: called with the per-frame parameter tensor ``t``;
                must return ``(extrinsics, intrinsics)``.
            name: suffix of the logging key.
            num_frames: number of samples of ``t`` in ``[0, 1]``.
            smooth: if true, ``t`` is remapped with a cosine ease
                (``(cos(pi * (t + 1)) + 1) / 2``, which runs from 0 to 1).
            loop_reverse: if true, the reversed clip (minus its first and last
                frame) is appended after the forward clip.

        Returns nothing; returns immediately when ``self.train_cfg.no_log_video``
        is set.
        """
        if self.train_cfg.no_log_video:
            return
        # Render probabilistic estimate of scene.
        gaussians_prob = self.encoder(batch["context"], self.global_step, False)
        # gaussians_det = self.encoder(batch["context"], self.global_step, True)

        # The encoder may return a dict; in that case only its "gaussians" entry is used.
        if isinstance(gaussians_prob, dict):
            gaussians_prob = gaussians_prob["gaussians"]

        t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=self.device)
        if smooth:
            t = (torch.cos(torch.pi * (t + 1)) + 1) / 2

        extrinsics, intrinsics = trajectory_fn(t)

        _, _, _, h, w = batch["context"]["image"].shape

        # Color-map the result.
        def depth_map(result):
            # Normalization bounds in log space: 1% quantile of the positive values and
            # 99% quantile of all values. Both are computed on at most the first
            # 16,000,000 elements - presumably to stay under torch.quantile's
            # input-size limit (TODO confirm).
            near = result[result > 0][:16_000_000].quantile(0.01).log()
            far = result.reshape(-1)[:16_000_000].quantile(0.99).log()
            result = result.log()
            # Inverted so that values at `near` map to 1 and values at `far` map to 0.
            result = 1 - (result - near) / (far - near)
            return apply_color_map_to_image(result, "turbo")

        # Near/far planes of the first context view, repeated for every rendered frame.
        near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
        far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)
        output_prob = self.scene_decoder.forward(
            gaussians_prob, extrinsics, intrinsics, near, far, (h, w), "depth"
        )
        # Only batch element 0 is visualized: one vcat(rgb, depth) image per frame.
        rgb_pred = [
            vcat(rgb, depth)
            for rgb, depth in zip(output_prob.color[0], depth_map(output_prob.depth[0]))
        ]

        # NOTE(review): zipping rgb_pred with itself only iterates rgb_pred once; the
        # second element is discarded. Looks like a leftover from a two-column layout.
        images = [
            add_border(
                hcat(
                    add_label(image_prob, "Prediction"),
                )
            )
            for image_prob, _ in zip(rgb_pred, rgb_pred)
        ]

        video = torch.stack(images)
        # Clamp to [0, 1] and convert to uint8 before handing the frames to wandb.
        video = (video.clip(min=0, max=1) * 255).type(torch.uint8).cpu().numpy()
        if loop_reverse:
            # Drop the first and last frame of the reversed copy so the turning
            # points are not shown twice.
            video = pack([video, video[::-1][1:-1]], "* c h w")[0]

        visualizations = {
            f"video/{name}": wandb.Video(video[None], fps=30, format="mp4")
        }

        # Since the PyTorch Lightning doesn't support video logging, log to wandb directly.
        try:
            wandb.log(visualizations)
        except Exception:
            # Fallback when wandb.log fails: only valid with the LocalLogger, in which
            # case the clip is written under LOG_PATH / key, named by the global step.
            # NOTE(review): relies on private wandb.Video attributes
            # (_prepare_video, _fps), which may change between wandb versions.
            assert isinstance(self.logger, LocalLogger)
            for key, value in visualizations.items():
                tensor = value._prepare_video(value.data)
                clip = mpy.ImageSequenceClip(list(tensor), fps=value._fps)
                dir = LOG_PATH / key
                dir.mkdir(exist_ok=True, parents=True)
                clip.write_videofile(
                    str(dir / f"{self.global_step:0>6}.mp4"), logger=None
                )
import Any + +import torch + +from optgs.dataset.data_types import BatchedExample +from optgs.model.types import Gaussians +from optgs.scene_trainer.optimizer.optimizer import OptimizerState + + +def to_device(obj: Any, device: torch.device | str, detach=True) -> Any: + """ + Recursively moves all tensors (and nested dataclasses) to the given device. + - Skips None fields + - Works with nested dataclasses + - Works with lists/tuples of tensors or dataclasses + """ + if torch.is_tensor(obj): + if detach: + obj = obj.detach() + return obj.to(device) + + elif is_dataclass(obj): + kwargs = {} + for f in fields(obj): + val = getattr(obj, f.name) + if val is not None: + kwargs[f.name] = to_device(val, device, detach=detach) + return replace(obj, **kwargs) + + elif isinstance(obj, (list, tuple)): + return type(obj)(to_device(v, device, detach=detach) for v in obj) + + elif isinstance(obj, dict): + return {k: to_device(v, device, detach=detach) for k, v in obj.items()} + + else: + return obj # Leave unchanged (e.g., int, float, str) + + +@dataclass +class GaussianEpisodeEntry: + id: int + t: int + batch: BatchedExample + gaussians: Gaussians + state: OptimizerState | None = None + info: dict[str, Any] | None = None + + +@dataclass +class ReplayBufferCfg: + capacity: int # number of snapshots to store + sample_batch_size: int # number of snapshots to sample when resuming training + sample_prob: float | int # probability of sampling from the buffer vs starting fresh + insert_prob: float | int # probability of pushing to the buffer a new sample + return_prob: float | int # probability of returning the sampled snapshot (vs discarding it) + simulate_ahead: bool # whether to simulate ahead before returning the updated snapshot + simulate_ahead_min_steps: int # min steps to simulate ahead + simulate_ahead_max_steps: int # max steps to simulate ahead + simulate_ahead_grow: int # number of steps to scale up the max steps over meta iterations + max_t: int | None # maximum number of 
class EpisodeReplayBuffer:
    """Bounded list of training snapshots that can be pushed and sampled at random.

    Behaviour is controlled by the config object passed to the constructor; the
    attributes read here are ``capacity``, ``sample_batch_size``, ``sample_prob``,
    ``insert_prob``, ``return_prob``, ``max_t``, ``push_only_if_not_full`` and
    ``remove_strategy_when_full``.
    """

    def __init__(self, cfg: "ReplayBufferCfg"):
        self.cfg = cfg
        self.buffer = []
        assert self.cfg.sample_batch_size == 1, "Only batch size of 1 is supported for now."

    def push(self, entry, to_cpu=True):
        """Store one snapshot (intermediate state of training).

        With ``to_cpu`` the entry is first detached and moved to the CPU. If the
        buffer then exceeds ``cfg.capacity`` one entry is evicted: the oldest
        one (``"oldest"``) or a random one other than the entry just added
        (``"random"``).

        Raises:
            ValueError: if the buffer overflows and the remove strategy is unknown.
        """
        if to_cpu:
            entry = to_device(entry, 'cpu', detach=True)

        self.buffer.append(entry)
        if len(self.buffer) > self.cfg.capacity:
            if self.cfg.remove_strategy_when_full == "oldest":
                self.buffer.pop(0)  # remove oldest if full
            elif self.cfg.remove_strategy_when_full == "random":
                # Remove a random entry except the newly added one. If the new
                # entry is the only one there is nothing else to evict, so fall
                # back to index 0 instead of calling randint on an empty range.
                last_candidate = max(len(self.buffer) - 2, 0)
                idx = random.randint(0, last_candidate)
                del self.buffer[idx]
            else:
                raise ValueError("Invalid remove strategy when full")

    def sample(self, device, leave_batch_fn=None):
        """Return and remove a random element from the buffer.

        Args:
            device: device the sampled entry is moved to.
            leave_batch_fn: accepted for interface compatibility; currently unused.

        Raises:
            ValueError: if the buffer holds fewer than ``cfg.sample_batch_size`` entries.
        """
        if len(self.buffer) < self.cfg.sample_batch_size:
            raise ValueError("Not enough elements in the buffer to sample")

        # Sample random entries
        indices = random.sample(range(len(self.buffer)), self.cfg.sample_batch_size)
        sampled_entries = [self.buffer[i] for i in indices]

        # Remove from buffer by index (must go in reverse to avoid shifting)
        for idx in sorted(indices, reverse=True):
            del self.buffer[idx]

        assert self.cfg.sample_batch_size == 1, "Only batch size of 1 is supported for now."
        sampled_entry = sampled_entries[0]

        # Move to device
        return to_device(sampled_entry, device)

    def flipcoin(self, action: str):
        """Draw a random decision for ``action``: ``"sample"``, ``"insert"`` or ``"return"``.

        Returns True with the probability configured for that action.

        Raises:
            ValueError: for any other action name.
        """
        if action == "sample":
            return random.random() < self.cfg.sample_prob
        elif action == "insert":
            return random.random() < self.cfg.insert_prob
        elif action == "return":
            return random.random() < self.cfg.return_prob
        else:
            raise ValueError(f"action must be 'sample', 'insert' or 'return', got {action!r}")

    def should_sample(self):
        """Decide whether to resume from the buffer: never before it is full, then by coin flip."""
        buffer_is_not_full = len(self.buffer) < self.cfg.capacity
        if buffer_is_not_full:
            return False
        return len(self.buffer) >= self.cfg.sample_batch_size and self.flipcoin("sample")

    def should_push(self, new_sample: bool, t: int):
        """Decide whether a snapshot at inner step ``t`` should be stored.

        Args:
            new_sample: True for a freshly started episode (uses ``insert_prob``),
                False for one that was sampled from the buffer (uses ``return_prob``).
            t: inner step of the snapshot.

        Returns:
            bool: whether to call ``push``.
        """
        if self.cfg.push_only_if_not_full and len(self.buffer) >= self.cfg.capacity:
            return False  # do not push if buffer is full

        if self.cfg.max_t is not None and t >= self.cfg.max_t:
            return False  # do not store entries beyond max_t

        if len(self.buffer) < self.cfg.capacity:
            # Always fill the buffer if possible
            return True

        if new_sample:
            return self.flipcoin("insert")
        return self.flipcoin("return")

    def __len__(self):
        return len(self.buffer)

    def clear(self):
        """Drop all stored snapshots."""
        self.buffer.clear()
def __init__(self) -> None: + super().__init__() + self.experiment = None + if LOG_PATH.exists(): + os.system(f"rm -r {LOG_PATH}") + + @property + def name(self): + return "LocalLogger" + + @property + def version(self): + return 0 + + @rank_zero_only + def log_hyperparams(self, params): + pass + + @rank_zero_only + def log_metrics(self, metrics, step): + pass + + @rank_zero_only + def log_image( + self, + key: str, + images: list[Any], + step: Optional[int] = None, + **kwargs, + ): + # The function signature is the same as the wandb logger's, but the step is + # actually required. + assert step is not None + for index, image in enumerate(images): + path = LOG_PATH / f"{key}/{index:0>2}_{step:0>6}.png" + path.parent.mkdir(exist_ok=True, parents=True) + if isinstance(image, torch.Tensor): + Image.fromarray(image.permute(1, 2, 0).numpy().astype(np.uint8)).save(path) + else: + Image.fromarray(image).save(path) diff --git a/optgs/misc/__init__.py b/optgs/misc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/misc/batchify.py b/optgs/misc/batchify.py new file mode 100644 index 0000000000000000000000000000000000000000..f1c26dc23cddb4b79283244273afb6900485bfa5 --- /dev/null +++ b/optgs/misc/batchify.py @@ -0,0 +1,36 @@ +import torch + + +def split_to_minibatch(batch_split, iter_context_idxs): + minibatch = { + "image": batch_split["image"][0][iter_context_idxs].unsqueeze( + 0 + ), # [1, Vc', 3, Hc, Wc] + "extrinsics": batch_split["extrinsics"][0][iter_context_idxs].unsqueeze( + 0 + ), # [1, Vc', 4, 4] + "intrinsics": batch_split["intrinsics"][0][iter_context_idxs].unsqueeze( + 0 + ), # [1, Vc', 4, 4] + "near": batch_split["near"][0][iter_context_idxs].unsqueeze(0), # [1, Vc'] + "far": batch_split["far"][0][iter_context_idxs].unsqueeze(0), # [1, Vc'] + } + return minibatch + + +def batched_select(data, indices): + """ + Select data[i, indices[i]] for each batch element i. 
class Benchmarker:
    """Collects per-tag lists of timings and writes them out as JSON.

    NOTE(review): ``time`` stores differences of ``time.time()`` values, while
    ``record`` and ``summarize`` speak of milliseconds - confirm which unit each
    tag is expected to use before comparing tags.
    """

    def __init__(self):
        self.execution_times = defaultdict(list)

    @contextmanager
    def time(self, tag: str, num_calls: int = 1):
        """Time the ``with`` body and store the elapsed time split evenly over ``num_calls`` entries.

        The measurement is recorded even when the body raises.
        """
        began = time()
        try:
            yield
        finally:
            elapsed = time() - began
            for _ in range(num_calls):
                self.execution_times[tag].append(elapsed / num_calls)

    def record(self, tag: str, elapsed_ms: float) -> None:
        """Record a pre-measured elapsed time (in milliseconds) under the given tag."""
        self.execution_times[tag].append(elapsed_ms)

    def merge(self, other: "Benchmarker") -> None:
        """Merge another benchmarker's recorded times into this one."""
        for tag in other.execution_times:
            self.execution_times[tag].extend(other.execution_times[tag])

    def dump(self, path: Path) -> None:
        """Write all recorded timings to ``path`` as JSON, creating parent directories."""
        path.parent.mkdir(exist_ok=True, parents=True)
        payload = dict(self.execution_times)
        with path.open("w") as f:
            json.dump(payload, f)

    def dump_memory(self, path: Path) -> None:
        """Write the ``allocated_bytes.all.peak`` entry of ``torch.cuda.memory_stats()`` to ``path``."""
        path.parent.mkdir(exist_ok=True, parents=True)
        with path.open("w") as f:
            json.dump(torch.cuda.memory_stats()["allocated_bytes.all.peak"], f)

    def summarize(self) -> None:
        """Print call count, mean and total for every tag."""
        for tag in self.execution_times:
            times = self.execution_times[tag]
            print(f"{tag}: {len(times)} calls, avg {np.mean(times):.1f} ms/call, total {sum(times)/1000:.1f} s")

    def clear_history(self) -> None:
        """Forget all recorded timings."""
        self.execution_times = defaultdict(list)
def load_optimizer(cfg, scene_trainer, strict_load):
    """Load ``cfg.checkpointing.pretrained_optimizer`` into ``scene_trainer.optimizer``.

    Two checkpoint layouts are handled after ``"scene_trainer."`` has been
    removed from the keys:

    * keys starting with ``optimizer.`` - used with that prefix stripped;
    * otherwise keys starting with ``encoder.`` - prefix stripped and legacy
      attribute names mapped to their current names.

    When ``cfg.scene_trainer.scene_optimizer.init_state_wo_features`` is set
    (missing attribute counts as False), every key containing ``"update_proj"``
    is dropped before loading.

    Args:
        cfg: config providing ``checkpointing.pretrained_optimizer`` and
            ``scene_trainer.scene_optimizer``.
        scene_trainer: object whose ``optimizer`` module receives the weights.
        strict_load: forwarded as ``strict`` to ``load_state_dict``.
    """
    pretrained_model = torch.load(cfg.checkpointing.pretrained_optimizer, map_location='cpu')
    if 'state_dict' in pretrained_model:
        pretrained_model = pretrained_model['state_dict']
    # Strip scene_trainer. prefix if present (Lightning checkpoint format)
    pretrained_model = {k.replace("scene_trainer.", ""): v for k, v in pretrained_model.items()}
    if any(k.startswith("optimizer.") for k in pretrained_model):
        # Unified repo format: keys are optimizer.*
        optimizer_state_dict = {k[len("optimizer."):]: v for k, v in pretrained_model.items() if
                                k.startswith("optimizer.")}
    else:
        # Resplat repo format: keys are encoder.* (before init/opt split).
        # Strip encoder. prefix; init-related keys will be ignored via strict=False.
        optimizer_state_dict = {k[len("encoder."):]: v for k, v in pretrained_model.items() if k.startswith("encoder.")}
        # Rename module attributes that changed when the encoder was split.
        legacy_attr_renames = {
            "render_error_mv_attn": "update_error_attn",
        }
        renamed = {}
        for k, v in optimizer_state_dict.items():
            for old, new in legacy_attr_renames.items():
                if k == old or k.startswith(old + "."):
                    k = new + k[len(old):]
                    break
            renamed[k] = v
        optimizer_state_dict = renamed

    # If init_state_wo_features is True, remove all feature-related parameters from the
    # optimizer state dict. The flag is read with a default so configs that do not
    # define it keep working.
    if getattr(cfg.scene_trainer.scene_optimizer, "init_state_wo_features", False):
        optimizer_state_dict = {k: v for k, v in optimizer_state_dict.items() if "update_proj" not in k}
    scene_trainer.optimizer.load_state_dict(optimizer_state_dict, strict=strict_load)
    print(cyan(f"Loaded pretrained optimizer: {cfg.checkpointing.pretrained_optimizer}"))
prefix if present (Lightning checkpoint format) + pretrained_model = {k.replace("scene_trainer.", ""): v for k, v in pretrained_model.items()} + if any(k.startswith("initializer.") for k in pretrained_model): + assert all(k.startswith("initializer.") for k in pretrained_model) + # Current repo format: keys are initializer.* + initializer_state_dict = {k[len("initializer."):]: v for k, v in pretrained_model.items() if + k.startswith("initializer.")} + else: + # Resplat repo format: keys are encoder.* (before init/opt split) + initializer_state_dict = {k[len("encoder."):]: v for k, v in pretrained_model.items() if + k.startswith("encoder.")} + scene_trainer.initializer.load_state_dict(initializer_state_dict, strict=strict_load) + print(cyan(f"Loaded pretrained initializer: {cfg.checkpointing.pretrained_initializer}")) + + +def load_full_model(cfg, scene_trainer, strict_load): + pretrained_model = torch.load(cfg.checkpointing.pretrained_model, map_location='cpu') + if 'state_dict' in pretrained_model: + pretrained_model = pretrained_model['state_dict'] + if cfg.checkpointing.partial_load: + print('partial load') + load_partial_state_dict(scene_trainer, pretrained_model) + else: + scene_trainer.load_state_dict(pretrained_model, strict=strict_load) + print(cyan(f"Loaded pretrained weights: {cfg.checkpointing.pretrained_model}")) + + +def load_base_model(cfg, scene_trainer, strict_load: bool | Any): + if cfg.checkpointing.pretrained_model is not None: + load_full_model(cfg, scene_trainer, strict_load) + else: + # Load pretrained initializer if available + if cfg.checkpointing.pretrained_initializer is not None: + load_initializer(cfg, scene_trainer, strict_load) + + if cfg.checkpointing.pretrained_optimizer is not None and scene_trainer.optimizer is not None: + load_optimizer(cfg, scene_trainer, strict_load) + + +def load_model_weights(cfg, scene_trainer, strict_load, mode: str): + assert mode in ("train", "test") + + if mode == "train": + # only load monodepth + if 
cfg.checkpointing.pretrained_monodepth is not None: + strict_load = False + pretrained_model = torch.load(cfg.checkpointing.pretrained_monodepth, map_location='cpu') + if 'state_dict' in pretrained_model: + pretrained_model = pretrained_model['state_dict'] + if cfg.model.encoder.separate_depth_color or cfg.model.encoder.separate_depth_gaussian_scale: + scene_trainer.encoder.feature_extractor.load_state_dict(pretrained_model, strict=strict_load) + else: + scene_trainer.encoder.depth_predictor.load_state_dict(pretrained_model, strict=strict_load) + print(cyan(f"Loaded pretrained monodepth: {cfg.checkpointing.pretrained_monodepth}")) + + # freeze mono vit + if cfg.checkpointing.freeze_mono_vit: + print('freeze mono vit') + for params in scene_trainer.encoder.depth_predictor.pretrained.parameters(): + params.requires_grad = False + + # load pretrained mvdepth + if cfg.checkpointing.pretrained_mvdepth is not None: + pretrained_model = torch.load(cfg.checkpointing.pretrained_mvdepth, map_location='cpu')['model'] + if cfg.model.encoder.separate_depth_color or cfg.model.encoder.separate_depth_gaussian_scale: + scene_trainer.encoder.feature_extractor.load_state_dict(pretrained_model, strict=False) + else: + scene_trainer.encoder.depth_predictor.load_state_dict(pretrained_model, strict=False) + print(cyan(f"Loaded pretrained mvdepth: {cfg.checkpointing.pretrained_mvdepth}")) + + # load full model (or separate initializer/optimizer checkpoints) + load_base_model(cfg, scene_trainer, strict_load) + + # load pretrained depth + if cfg.checkpointing.pretrained_depth is not None: + pretrained_model = _load_state_dict(cfg.checkpointing.pretrained_depth) + if mode == "train": + if cfg.checkpointing.partial_load: + print('partial load depth') + load_partial_state_dict(scene_trainer.initializer.depth_predictor, pretrained_model) + else: + if cfg.checkpointing.no_resume_upsampler: + pretrained_model = no_resume_upsampler(pretrained_model) + strict_load = False + 
def apply_freezes(cfg, scene_trainer):
    """Disable gradients on parts of ``scene_trainer`` according to ``cfg.scene_trainer``.

    - ``scene_initializer.freeze_depth`` (optional, default False): freeze the
      initializer's depth predictor.
    - ``train_scene_init`` false: freeze the whole initializer.
    - only when ``num_update_steps > 0``:
        - ``train_scene_opt`` false: freeze the whole optimizer;
        - ``scene_optimizer.train_global_update_only``: freeze every optimizer
          parameter whose name does not contain ``'global_update'``.
    """

    def _freeze(parameters):
        for p in parameters:
            p.requires_grad = False

    trainer_cfg = cfg.scene_trainer

    if getattr(trainer_cfg.scene_initializer, 'freeze_depth', False):
        print('freeze depth')
        _freeze(scene_trainer.initializer.depth_predictor.parameters())

    if not trainer_cfg.train_scene_init:
        print('train refine only, freezing scene initializer')
        _freeze(p for _, p in scene_trainer.initializer.named_parameters())

    if not trainer_cfg.num_update_steps > 0:
        return

    if not trainer_cfg.train_scene_opt:
        print('train refine only, freezing scene optimizer')
        _freeze(p for _, p in scene_trainer.optimizer.named_parameters())
    if trainer_cfg.scene_optimizer.train_global_update_only:
        print('train global update only')
        _freeze(
            p for name, p in scene_trainer.optimizer.named_parameters()
            if 'global_update' not in name
        )
def metrics_table(
    rows: list[tuple], headers: list[str], title: str | None = None
) -> None:
    """Print a metrics table on the shared console.

    One column is added per entry of `headers`; each tuple in `rows` becomes
    one table row. Headers and cells are converted with `str()`.
    """
    grid = Table(title=title, header_style="bold", box=box.SIMPLE_HEAD)
    for heading in headers:
        grid.add_column(str(heading))
    for cells in rows:
        grid.add_row(*(str(cell) for cell in cells))
    CONSOLE.print(grid)
+ """ + table = Table(title=title, box=box.SIMPLE, show_header=False, pad_edge=False) + table.add_column("key", style="muted", no_wrap=True) + table.add_column("value") + first = True + for group, pairs in sections.items(): + if not pairs: + continue + if not first: + table.add_row("", "") + first = False + table.add_row(f"[info]{group}[/info]", "") + for key, value in pairs: + table.add_row(f" {key}", str(value)) + CONSOLE.print(table) diff --git a/optgs/misc/detaching_cpu_list.py b/optgs/misc/detaching_cpu_list.py new file mode 100644 index 0000000000000000000000000000000000000000..cd6e721e9e85d2c916d5fd96980d921b346fc2cc --- /dev/null +++ b/optgs/misc/detaching_cpu_list.py @@ -0,0 +1,321 @@ +# import os +# import torch +# from dataclasses import dataclass, is_dataclass, replace, field +# from typing import TypeVar, Generic, Optional, Any, Callable, Iterable +# from pathlib import Path +# import pickle +# import tempfile +# import shutil +# import atexit +# import signal +# from typing import TypeVar, Generic +# +# T = TypeVar('T') +# +# +# @dataclass +# class DetachingCPUList(list[T]): +# cache_dir: Optional[Path] = field(default=None) +# save_serializer: Optional[Callable[[Any, Path], None]] = field(default=None) +# load_serializer: Optional[Callable[[Path], Any]] = field(default=None) +# detach_func: Optional[Callable[[Any], Any]] = field(default=None) +# remove_files_on_delete: bool = field(default=True) +# verbose: bool = field(default=True) +# +# _cache: dict = field(init=False, repr=False, default_factory=dict) +# _fd_map: dict = field(init=False, repr=False, default_factory=dict) +# _no_cache_set: set = field(init=False, repr=False, default_factory=set) # Track paths marked as no_cache +# _cache_dir: Path = field(init=False, repr=False) +# _tmp_dir_created: bool = field(init=False, repr=False, default=False) +# +# def __post_init__(self): +# +# # Set default serializers if not provided +# if self.save_serializer is None: +# def _save_pickle(obj, path: 
Path): +# path.parent.mkdir(parents=True, exist_ok=True) +# with open(path, "wb") as f: +# pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL) +# self.save_serializer = _save_pickle +# +# # Set default deserializers if not provided +# if self.load_serializer is None: +# def _load_pickle(source): +# if isinstance(source, (str, Path)): +# with open(str(source), "rb") as f: +# return pickle.load(f) +# else: +# return pickle.load(source) +# self.load_serializer = _load_pickle +# +# # Set default detach_func if not provided +# if self.detach_func is None: +# self.detach_func = self._detach_recursive +# +# if self.cache_dir is None: +# tmp = tempfile.mkdtemp(prefix="detaching_cpu_list_") +# self._cache_dir = Path(tmp) +# self._tmp_dir_created = True +# else: +# self._cache_dir = Path(self.cache_dir) +# self._cache_dir.mkdir(parents=True, exist_ok=True) +# +# atexit.register(self._cleanup) +# for sig in (signal.SIGINT, signal.SIGTERM): +# try: +# old = signal.getsignal(sig) +# def _handler(signum, frame, _old=old): +# self._cleanup() +# if callable(_old) and _old not in (signal.SIG_DFL, signal.SIG_IGN): +# _old(signum, frame) +# signal.signal(sig, _handler) +# except Exception: +# pass +# +# # -------------------------- +# # Core cleanup +# # -------------------------- +# def _cleanup(self): +# # Close all file descriptors +# for fd in list(self._fd_map.values()): +# try: +# os.close(fd) +# except Exception: +# pass +# self._fd_map.clear() +# # Remove on-disk files (non-guaranteed mode) +# if self.remove_files_on_delete and self._tmp_dir_created and self._cache_dir.exists(): +# shutil.rmtree(self._cache_dir, ignore_errors=True) +# +# # -------------------------- +# # Save / load helpers +# # -------------------------- +# def _save_unlinked(self, item: Any): +# """Save item to disk with guaranteed deletion (fd-based approach).""" +# tmp = tempfile.NamedTemporaryFile(delete=False, dir=str(self._cache_dir)) +# temp_path = Path(tmp.name) +# tmp.close() +# +# assert 
self.save_serializer is not None, "save_serializer must be defined" +# self.save_serializer(item, temp_path) +# fd = os.open(str(temp_path), os.O_RDONLY) +# os.unlink(str(temp_path)) # unlink immediately - kernel guarantees cleanup on fd close +# pseudo = Path(f"/proc/self/fd/{fd}") +# token = pseudo.as_posix() +# self._fd_map[token] = fd +# if self.verbose: +# print(f"Saved item to fd {fd} with path {pseudo}") +# return pseudo +# +# def _load_from_fd(self, token: str): +# """Load from file descriptor path.""" +# fd = self._fd_map.get(token) +# if fd is None: +# raise RuntimeError(f"FD {token} not available.") +# dupfd = os.dup(fd) +# assert self.load_serializer is not None, "load_serializer must be defined" +# with os.fdopen(dupfd, "rb") as f: +# f.seek(0) +# obj = self.load_serializer(f) +# if self.verbose: +# print(f"Loaded item from fd {fd} (token {token})") +# return obj +# +# def _load_from_disk(self, path: Path): +# """Load from regular file path.""" +# assert self.load_serializer is not None, "load_serializer must be defined" +# if self.verbose: +# print(f"Loading from disk: {path}") +# return self.load_serializer(path) +# +# # -------------------------- +# # Public interface +# # -------------------------- +# def append(self, item, detach_and_cpu: bool = False, save_to_disk: bool = False, no_cache: bool = False): +# """ +# Append an item to the list. 
+# +# Args: +# item: The item to append +# detach_and_cpu: If True, apply detach_func to move tensors to CPU +# save_to_disk: If True, save to disk using fd-based guaranteed deletion +# no_cache: If True, never cache this item in memory when accessed (always reload from disk) +# """ +# # Validate save_to_disk requires detach capability +# if save_to_disk and not detach_and_cpu: +# raise ValueError("Cannot save to disk without detach_and_cpu=True") +# +# if not save_to_disk and no_cache: +# print("Warning: no_cache=True has no effect when save_to_disk=False") +# +# if detach_and_cpu and self.detach_func: +# item = self.detach_func(item) +# +# if save_to_disk: +# # Always use fd-based guaranteed deletion +# p = self._save_unlinked(item) +# super().append(p) +# +# # Mark as no_cache if requested +# if no_cache: +# self._no_cache_set.add(p.as_posix()) +# else: +# super().append(item) +# +# def insert(self, index: int, item, detach_and_cpu: bool = False, save_to_disk: bool = False, no_cache: bool = False): +# """ +# Insert an item at a specific index in the list. 
+# +# Args: +# index: The index to insert the item at +# item: The item to insert +# detach_and_cpu: If True, apply detach_func to move tensors to CPU +# save_to_disk: If True, save to disk using fd-based guaranteed deletion +# no_cache: If True, never cache this item in memory when accessed (always reload from disk) +# """ +# # Validate save_to_disk requires detach capability +# if save_to_disk and not detach_and_cpu: +# raise ValueError("Cannot save to disk without detach_and_cpu=True") +# +# if detach_and_cpu and self.detach_func: +# item = self.detach_func(item) +# +# if save_to_disk: +# # Always use fd-based guaranteed deletion +# p = self._save_unlinked(item) +# super().insert(index, p) +# +# # Mark as no_cache if requested +# if no_cache: +# self._no_cache_set.add(p.as_posix()) +# else: +# super().insert(index, item) +# +# def extend(self, items: Iterable[Any], **kwargs): +# for it in items: +# self.append(it, **kwargs) +# +# def __getitem__(self, index): +# """ +# Return the item at `index`. +# +# - If the underlying stored value is a Path, load from disk/fd. +# - If marked as no_cache, always reload from disk (never cache). +# - Otherwise, cache after first load and return cached object on subsequent accesses. +# - If not a Path, return the in-memory value directly. 
+# """ +# raw = super().__getitem__(index) +# +# # If it's not a Path, it's an in-memory object: return directly +# if not isinstance(raw, Path): +# return raw +# +# # it's a Path -> use its string as cache key (works for both normal paths and /proc/self/fd/) +# key = raw.as_posix() +# +# # If in cache, return cached object +# if key in self._cache: +# assert key not in self._no_cache_set, "Inconsistent state: item both cached and marked no_cache" +# return self._cache[key] +# +# # Always reload from disk/fd, never cache +# if str(raw).startswith("/proc/self/fd/"): +# obj = self._load_from_fd(key) +# else: +# raise NotImplementedError("Loading from disk is not implemented") +# # return self._load_from_disk(raw) +# +# if key in self._no_cache_set: +# pass # never cache +# else: +# # Cache it permanently (unless marked as no_cache) +# self._cache[key] = obj +# +# return obj +# +# def pop(self, index: int = -1): +# raw = super().pop(index) +# # If it was a Path, return the loaded object (and keep the fd open / mapping intact) +# if isinstance(raw, Path): +# key = raw.as_posix() +# # return cached value if present; else load now and cache it +# if key in self._cache: +# return self._cache[key] +# if str(raw).startswith("/proc/self/fd/"): +# obj = self._load_from_fd(key) +# else: +# obj = self._load_from_disk(raw) +# self._cache[key] = obj +# return obj +# return raw +# +# def clear(self): +# # do not close fds; keep them open until process exit as requested +# # keep cache intact if you want (or clear it if you prefer) +# # here we remove list entries but keep any cached objects and open fds +# super().clear() +# +# def __iter__(self): +# """Iterate over items, loading from disk as needed.""" +# for i in range(len(self)): +# yield self[i] +# +# def __del__(self): +# self._cleanup() +# +# def _detach_recursive(self, obj): +# if isinstance(obj, torch.Tensor): +# return obj.detach().cpu() +# elif isinstance(obj, dict): +# return {k: self._detach_recursive(v) for k, v in 
from dataclasses import dataclass, fields, is_dataclass, replace

import torch


class DetachingCPUList(list):
    """A plain ``list`` whose mutators can detach tensors and move them to CPU.

    ``append`` / ``extend`` / ``insert`` behave exactly like their ``list``
    counterparts; passing ``detach_and_cpu=True`` first runs the item(s)
    through :meth:`_detach_recursive`.

    The class is deliberately NOT decorated with ``@dataclass``. It declares
    no fields, so the generated ``__eq__`` made every instance compare equal
    regardless of contents, the generated ``__repr__`` hid the elements, and
    the generated zero-argument ``__init__`` broke ``DetachingCPUList(iterable)``
    (which ``_detach_recursive`` needs to rebuild nested instances).
    """

    # TODO Naama: Add back disk saving
    def append(self, item, detach_and_cpu=False, save_to_disk=False, no_cache=False):
        """Append ``item``, optionally detaching it first.

        ``save_to_disk`` and ``no_cache`` are accepted so existing call sites
        keep working; they are currently ignored (disk saving is disabled).
        """
        if detach_and_cpu:
            item = self._detach_recursive(item)
        super().append(item)

    def extend(self, iterable, detach_and_cpu=False, save_to_disk=False, no_cache=False):
        """Extend with ``iterable``, optionally detaching every element.

        Takes the same (currently ignored) ``save_to_disk`` / ``no_cache``
        keywords as ``append`` and ``insert`` so the three mutators can be
        called with one shared set of keyword arguments.
        """
        if detach_and_cpu:
            iterable = (self._detach_recursive(x) for x in iterable)
        super().extend(iterable)

    def insert(self, index, item, detach_and_cpu=False, save_to_disk=False, no_cache=False):
        """Insert ``item`` at ``index``, optionally detaching it first.

        ``save_to_disk`` and ``no_cache`` are accepted but currently ignored.
        """
        if detach_and_cpu:
            item = self._detach_recursive(item)
        super().insert(index, item)

    def _detach_recursive(self, obj):
        """Return a copy of ``obj`` with every tensor detached and on the CPU.

        Recurses through dicts (rebuilt as plain ``dict``), lists, tuples
        (including namedtuples) and dataclass instances; anything else is
        returned unchanged.
        """
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu()
        if isinstance(obj, dict):
            return {k: self._detach_recursive(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            converted = [self._detach_recursive(x) for x in obj]
            if isinstance(obj, tuple) and hasattr(obj, "_fields"):
                # namedtuple constructors take positional fields, not an iterable.
                return type(obj)(*converted)
            return type(obj)(converted)
        # is_dataclass() is also true for dataclass *types*; only instances
        # can be passed to replace().
        if is_dataclass(obj) and not isinstance(obj, type):
            # fields() skips ClassVar pseudo-fields, and replace() refuses
            # init=False fields, so only constructor fields are forwarded.
            return replace(obj, **{
                f.name: self._detach_recursive(getattr(obj, f.name))
                for f in fields(obj)
                if f.init
            })
        return obj
def sample_discrete_distribution(
    pdf: Float[Tensor, "*batch bucket"],
    num_samples: int,
    eps: float = torch.finfo(torch.float32).eps,
) -> tuple[
    Int64[Tensor, "*batch sample"],  # index
    Float[Tensor, "*batch sample"],  # probability density
]:
    """Draw ``num_samples`` bucket indices per batch element from ``pdf``.

    ``pdf`` is normalized along its last dimension (``eps`` guards against a
    zero sum), turned into a CDF, and inverted with uniform random numbers.

    Returns:
        ``(index, probability)``: the sampled bucket indices and the
        normalized probability of each sampled bucket.
    """
    *batch, bucket = pdf.shape
    normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum"))
    cdf = normalized_pdf.cumsum(dim=-1)
    # Draw the uniforms in the CDF's dtype: torch.searchsorted rejects
    # mismatched dtypes, so the default-dtype draw failed for e.g. float64 pdfs.
    samples = torch.rand((*batch, num_samples), device=pdf.device, dtype=cdf.dtype)
    # The clip keeps a uniform that lands at or above the final CDF value
    # from indexing one past the last bucket.
    index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1)
    return index, normalized_pdf.gather(dim=-1, index=index)


def gather_discrete_topk(
    pdf: Float[Tensor, "*batch bucket"],
    num_samples: int,
    eps: float = torch.finfo(torch.float32).eps,
) -> tuple[
    Int64[Tensor, "*batch sample"],  # index
    Float[Tensor, "*batch sample"],  # probability density
]:
    """Pick the ``num_samples`` highest-valued buckets per batch element.

    Deterministic counterpart of ``sample_discrete_distribution``.

    Returns:
        ``(index, probability)``: the top-k bucket indices (taken on the
        unnormalized ``pdf``) and their normalized probabilities.
    """
    normalized_pdf = pdf / (eps + reduce(pdf, "... bucket -> ... ()", "sum"))
    index = pdf.topk(k=num_samples, dim=-1).indices
    return index, normalized_pdf.gather(dim=-1, index=index)
def rotate_quats(rots, quats):
    """Left-multiply the rotations encoded by ``quats`` with ``rots``.

    ``quats`` are xyzw (scalar-last) quaternions; they are normalized,
    converted to rotation matrices, pre-multiplied by ``rots`` and converted
    back to xyzw quaternions.

    NOTE(review): the original inline comments give ``rots`` as
    ``[B, V, 1, 3, 3]`` but describe ``quats`` both as ``[B, N, 4]`` and as
    ``[1, V, HW, 4]`` -- confirm the intended shapes against the callers.
    """
    # Imported locally, as in the original.
    from optgs.scene_trainer.common.gaussians import quaternion_to_matrix
    from optgs.scene_trainer.common.gaussians import rotation_matrix_to_quaternion_xyzw

    unit_quats = F.normalize(quats, dim=-1)
    local_matrices = quaternion_to_matrix(unit_quats)
    # Plain left-multiplication (not a conjugation R @ M @ R^T).
    rotated_matrices = rots @ local_matrices
    return rotation_matrix_to_quaternion_xyzw(rotated_matrices)


class SkipBatchException(Exception):
    """Exception to signal that the current batch should be skipped."""
ExponentialLR vs get_expon_lr_func") + print("=" * 100) + + # ======================================================================== + # TEST 1: gsplat-style (no warm-up, lr_final = 0.01 * lr_init) + # ======================================================================== + print("\n" + "=" * 100) + print("TEST 1: gsplat-style (lr_final = 0.01 * lr_init, no warm-up)") + print("=" * 100) + + lr_final_gsplat = 0.01 * lr_init # 1.6e-6 + + # PyTorch ExponentialLR (gsplat style) + dummy_param = torch.nn.Parameter(torch.zeros(1)) + optimizer_torch = torch.optim.Adam([dummy_param], lr=lr_init) + gamma = 0.01 ** (1.0 / max_steps) + scheduler_torch = torch.optim.lr_scheduler.ExponentialLR(optimizer_torch, gamma=gamma) + + # Custom scheduler (configured to match gsplat) + scheduler_custom = get_expon_lr_func( + lr_init=lr_init, + lr_final=lr_final_gsplat, + lr_delay_steps=0, + lr_delay_mult=1.0, + max_steps=max_steps + ) + + print(f"\nSettings:") + print(f" lr_init = {lr_init:.2e}") + print(f" lr_final = {lr_final_gsplat:.2e}") + print(f" lr_delay_steps = 0") + print(f" lr_delay_mult = 1.0") + print(f" max_steps = {max_steps}") + print(f" gamma = {gamma:.10f}") + + test_steps_1 = [0, 1, 10, 100, 500, 1000, 2000, 5000, 10000, 15000, 20000, 25000, 29000, 29900, 29990, 30000] + + print(f"\n{'Step':<10} {'PyTorch LR':<20} {'Custom LR':<20} {'Ratio':<15} {'Abs Diff':<15} {'Rel Diff %':<15}") + print("-" * 100) + + prev_step = 0 + for step in test_steps_1: + # Get PyTorch LR by stepping + if step > 0: + for _ in range(step - prev_step): + scheduler_torch.step() + lr_torch = optimizer_torch.param_groups[0]['lr'] + + # Get custom LR + lr_custom = scheduler_custom(step) + + # Calculate ratio + ratio = lr_custom / lr_torch if lr_torch != 0 else 0 + + # Calculate difference + abs_diff = abs(lr_torch - lr_custom) + rel_diff = abs_diff / lr_torch * 100 if lr_torch != 0 else 0 + + print(f"{step:<10} {lr_torch:<20.10e} {lr_custom:<20.10e} {ratio:<15.6f} {abs_diff:<15.4e} 
{rel_diff:<15.8f}") + + prev_step = step + + # ======================================================================== + # TEST 2: Original config (with warm-up, higher lr_final) + # ======================================================================== + print("\n" + "=" * 100) + print("TEST 2: Original config (lr_final = 1.0e-5, warm-up with delay_mult=0.01)") + print("=" * 100) + + lr_final_original = 1.0e-5 + lr_delay_steps = 0 + lr_delay_mult = 0.01 + + # PyTorch ExponentialLR + dummy_param2 = torch.nn.Parameter(torch.zeros(1)) + optimizer_torch2 = torch.optim.Adam([dummy_param2], lr=lr_init) + scheduler_torch2 = torch.optim.lr_scheduler.ExponentialLR(optimizer_torch2, gamma=gamma) + + # Custom scheduler (original config) + scheduler_custom_yours = get_expon_lr_func( + lr_init=lr_init, + lr_final=lr_final_original, + lr_delay_steps=lr_delay_steps, + lr_delay_mult=lr_delay_mult, + max_steps=max_steps + ) + + print(f"\nSettings:") + print(f" lr_init = {lr_init:.2e}") + print(f" lr_final = {lr_final_original:.2e}") + print(f" lr_delay_steps = {lr_delay_steps}") + print(f" lr_delay_mult = {lr_delay_mult}") + print(f" max_steps = {max_steps}") + + test_steps_2 = [0, 1, 10, 50, 100, 250, 500, 750, 1000, 1500, 2000, 5000, 10000, 15000, 20000, 25000, 29000, 30000] + + print(f"\n{'Step':<10} {'PyTorch LR':<20} {'Original Custom LR':<20} {'Ratio':<15} {'Abs Diff':<15} {'Rel Diff %':<15}") + print("-" * 100) + + prev_step = 0 + for step in test_steps_2: + # Get PyTorch LR + if step > 0: + for _ in range(step - prev_step): + scheduler_torch2.step() + lr_torch = optimizer_torch2.param_groups[0]['lr'] + + # Get custom LR + lr_custom = scheduler_custom_yours(step) + + # Calculate ratio + ratio = lr_custom / lr_torch if lr_torch != 0 else 0 + + # Calculate difference + abs_diff = abs(lr_torch - lr_custom) + rel_diff = abs_diff / lr_torch * 100 if lr_torch != 0 else 0 + + print(f"{step:<10} {lr_torch:<20.10e} {lr_custom:<20.10e} {ratio:<15.6f} {abs_diff:<15.4e} 
{rel_diff:<15.8f}") + + prev_step = step + + + # ======================================================================== + # SUMMARY + # ======================================================================== + print("\n" + "=" * 100) + print("SUMMARY") + print("=" * 100) + + print("\nTEST 1 (gsplat-style matching):") + print(" ✓ Custom scheduler matches PyTorch ExponentialLR when configured identically") + print(" ✓ Both decay from 1.6e-4 to 1.6e-6 (1% of initial)") + print(" ✓ Relative difference is < 0.000001% at all steps") + + print("\nTEST 2 (Original config vs gsplat):") + print(" ⚠ Original config has HIGHER final LR:") + print(f" - gsplat final LR: {lr_final_gsplat:.2e}") + print(f" - Original final LR: {lr_final_original:.2e} (~{lr_final_original/lr_final_gsplat:.1f}x higher)") + print(f" - At step 30000: Original LR is {lr_final_original/lr_final_gsplat:.2f}x higher than gsplat") + + print("\n" + "=" * 100) + +if __name__ == "__main__": + test_lr_schedulers() diff --git a/optgs/misc/heterogeneous_pairings.py b/optgs/misc/heterogeneous_pairings.py new file mode 100644 index 0000000000000000000000000000000000000000..76d9b7fd1ef8676ad47135110ecd0795bf86a363 --- /dev/null +++ b/optgs/misc/heterogeneous_pairings.py @@ -0,0 +1,43 @@ +import torch +from einops import repeat +from jaxtyping import Int +from torch import Tensor + +Index = Int[Tensor, "n n-1"] + + +def generate_heterogeneous_index( + n: int, + device: torch.device = torch.device("cpu"), +) -> tuple[Index, Index]: + """Generate indices for all pairs except self-pairs.""" + arange = torch.arange(n, device=device) + + # Generate an index that represents the item itself. + index_self = repeat(arange, "h -> h w", w=n - 1) + + # Generate an index that represents the other items. 
def generate_heterogeneous_index_transpose(
    n: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[Index, Index]:
    """Generate an index that can be used to "transpose" the heterogeneous index.
    Applying the index a second time inverts the "transpose."

    Both returned tensors have shape ``(n, n - 1)``. For row ``i`` and column
    ``j`` the first holds ``j + 1`` when ``j >= i`` (else ``j``), and the
    second holds ``i - 1`` when ``j < i`` (else ``i``).
    """
    positions = torch.arange(n, device=device)
    upper = torch.ones((n, n), device=device, dtype=torch.int64).triu()
    strictly_lower = 1 - upper

    # Column indices, bumped by one on and above the diagonal.
    shifted_columns = repeat(positions, "w -> h w", h=n).clone() + upper
    # Row indices, lowered by one strictly below the diagonal.
    shifted_rows = repeat(positions, "h -> h w", w=n) - strictly_lower

    # Drop the last column: each item pairs with the n - 1 other items.
    return shifted_columns[:, :-1], shifted_rows[:, :-1]
HF_PREFIX = "hf://"

# hf:// checkpoints (and their sibling config.yaml) are downloaded here on
# first access -- relative to the working directory -- as plain files laid out
# by their in-repo path (e.g. ./checkpoints/dense/checkpoints/model.ckpt),
# instead of the global HF cache's models--*/snapshots/<hash>/ structure.
# huggingface_hub still skips the download when the local copy is current.
HF_CACHE_DIR = "checkpoints"


def is_hf_ref(path: str | None) -> bool:
    """Return True when ``path`` is a string that starts with ``hf://``."""
    return isinstance(path, str) and path.startswith(HF_PREFIX)


def _parse_hf_ref(ref: str) -> tuple[str, str, str | None]:
    """Split an ``hf://`` reference into ``(repo_id, filename, revision)``.

    Raises:
        ValueError: if ``ref`` lacks the ``hf://`` prefix, has fewer than
            three path components, contains an empty component, or ends in a
            bare ``@`` (empty revision).
    """
    error = ValueError(
        f"Invalid HF checkpoint reference {ref!r}. Expected "
        f"'hf://<org>/<repo>/<path>[@<revision>]'."
    )
    # Without this check a non-hf string would silently lose its first
    # five characters to the prefix slice below.
    if not is_hf_ref(ref):
        raise error

    body = ref[len(HF_PREFIX):]
    revision = None
    if "@" in body:
        body, revision = body.rsplit("@", 1)
        if not revision:
            raise error

    parts = body.split("/")
    if len(parts) < 3 or not all(parts):
        raise error
    return "/".join(parts[:2]), "/".join(parts[2:]), revision


def resolve_hf_ref(ref: str) -> str:
    """Download an ``hf://`` reference and return the local cached file path.

    The reference is validated before ``huggingface_hub`` is imported, so a
    malformed reference is reported as such even when the package is missing.

    Raises:
        ValueError: if ``ref`` is not a well-formed ``hf://`` reference.
        ImportError: if ``huggingface_hub`` is not installed.
    """
    repo_id, filename, revision = _parse_hf_ref(ref)

    try:
        from huggingface_hub import hf_hub_download
    except ImportError as e:  # pragma: no cover - depends on env
        raise ImportError(
            "huggingface_hub is required to load 'hf://' checkpoints. "
            "Install it with `pip install huggingface_hub`."
        ) from e

    print(cyan(f"Resolving HF checkpoint {ref} (repo={repo_id}, "
               f"file={filename}, revision={revision or 'main'})"))
    local_path = hf_hub_download(
        repo_id=repo_id, filename=filename, revision=revision,
        local_dir=HF_CACHE_DIR,
    )
    print(cyan(f"Downloaded to {local_path}"))
    return local_path


def maybe_resolve_hf_ref(path: str | None) -> str | None:
    """Resolve `path` if it is an `hf://` reference, otherwise return it as-is."""
    if is_hf_ref(path):
        return resolve_hf_ref(path)
    return path
+ + Released checkpoints are laid out as ``/checkpoints/.ckpt`` with + the training config at ``/config.yaml`` (the same `/../../` + relation `_find_config_for_checkpoint` expects). ``hf_hub_download`` only + fetches the requested file, so the sibling config must be fetched + explicitly; pulling it into the same repo/revision snapshot makes it + discoverable. Returns the local path, or ``None`` if ``ref`` is not an + ``hf://`` reference / the sibling does not exist. + """ + if not is_hf_ref(ref): + return None + from pathlib import PurePosixPath + + from huggingface_hub import hf_hub_download + + body = ref[len(HF_PREFIX):] + revision = None + if "@" in body: + body, revision = body.rsplit("@", 1) + parts = body.split("/") + if len(parts) < 3: + return None + repo_id = "/".join(parts[:2]) + file_in_repo = "/".join(parts[2:]) + cfg_in_repo = str(PurePosixPath(file_in_repo).parent.parent / "config.yaml") + try: + local = hf_hub_download( + repo_id=repo_id, filename=cfg_in_repo, revision=revision, + local_dir=HF_CACHE_DIR, + ) + print(cyan(f"Fetched sibling config {cfg_in_repo} -> {local}")) + return local + except Exception as e: # sibling may not exist for non-standard layouts + print(cyan(f"No sibling config.yaml for {ref} ({type(e).__name__}); " + f"will fall back to local config discovery.")) + return None diff --git a/optgs/misc/image_io.py b/optgs/misc/image_io.py new file mode 100644 index 0000000000000000000000000000000000000000..7fcdb2a133f7191fb62e435482f1807898af4164 --- /dev/null +++ b/optgs/misc/image_io.py @@ -0,0 +1,130 @@ +from pathlib import Path +from typing import Union +# import skvideo.io +import imageio +import cv2 +import numpy as np +import torch +import torchvision.transforms as tf +from einops import rearrange, repeat +from jaxtyping import Float, UInt8 +from matplotlib.figure import Figure +from PIL import Image +from torch import Tensor + +from optgs.misc.io import CustomPath + +FloatImage = Union[ + Float[Tensor, "height width"], + 
def fig_to_image(
    fig: Figure,
    dpi: int = 100,
    device: torch.device = torch.device("cpu"),
) -> Float[Tensor, "3 height width"]:
    """Rasterize a matplotlib figure into a float RGB tensor in ``[0, 1]``.

    The figure is saved in matplotlib's ``raw`` format (4 bytes per pixel),
    reshaped to ``(4, h, w)`` and the fourth channel is dropped.

    NOTE(review): ``h`` and ``w`` are read from ``fig.bbox`` while the buffer
    is rendered at ``dpi``; this presumably requires ``fig.dpi == dpi`` --
    confirm against callers.
    """
    # Standard-library io, imported here because this module never imported
    # it at the top (the original call raised NameError).
    import io

    with io.BytesIO() as buffer:
        fig.savefig(buffer, format="raw", dpi=dpi)
        data = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
    h = int(fig.bbox.bounds[3])
    w = int(fig.bbox.bounds[2])
    data = rearrange(data, "(h w c) -> c h w", h=h, w=w, c=4)
    return (torch.tensor(data, device=device, dtype=torch.float32) / 255)[:3]


def prep_image(image: FloatImage) -> UInt8[np.ndarray, "height width channel"]:
    """Convert a float image tensor in ``[0, 1]`` to an ``H x W x C`` uint8 array.

    Batched input is tiled horizontally, 2-D input gains a channel axis, and
    single-channel input is repeated to three channels.
    """
    # Handle batched images.
    if image.ndim == 4:
        image = rearrange(image, "b c h w -> c h (b w)")

    # Handle single-channel images.
    if image.ndim == 2:
        image = rearrange(image, "h w -> () h w")

    # Ensure that there are 3 or 4 channels.
    channel, _, _ = image.shape
    if channel == 1:
        image = repeat(image, "() h w -> c h w", c=3)
    assert image.shape[0] in (3, 4)

    # Round-half-up to match torchvision.utils.save_image (3DGS-LM's path).
    image = (image.detach().clip(min=0, max=1) * 255 + 0.5).clip(0, 255).type(torch.uint8)
    return rearrange(image, "c h w -> h w c").cpu().numpy()


def save_image(
    image: FloatImage,
    path: Union[Path, str],
) -> None:
    """Save an image. Assumed to be in range 0-1."""

    # Create the parent directory if it doesn't already exist.
    path = Path(path)
    path.parent.mkdir(exist_ok=True, parents=True)

    # Save the image.
    Image.fromarray(prep_image(image)).save(path)


def load_image(
    path: Union[Path, str],
) -> Float[Tensor, "3 height width"]:
    """Load an image file as a float tensor, keeping at most three channels."""
    # The context manager closes the file handle, which the original left to
    # the garbage collector; ToTensor has copied the pixels by then.
    with Image.open(path) as image:
        return tf.ToTensor()(image)[:3]
def save_video(images, path, fps=None, iterations=None):
    """Write ``images`` to a video file at ``path``.

    Does nothing when fewer than three frames are given. ``fps`` defaults to
    30. When ``iterations`` is given (one entry per frame) each frame is
    stamped with ``"Iter <n>"`` in its top-left corner.
    """
    if len(images) < 3:
        return
    stamp_frames = iterations is not None
    if stamp_frames:
        assert len(images) == len(iterations)

    path = CustomPath(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    fps = 30 if fps is None else fps

    # ensure frames are contiguous uint8 arrays
    frames = []
    for img in images:
        pixels = prep_image(img).clip(0, 255).astype("uint8")
        frames.append(np.ascontiguousarray(pixels))

    # write iteration number on each frame if given
    if stamp_frames:
        for idx, frame in enumerate(frames):
            label = f"Iter {iterations[idx]}"
            cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
            frames[idx] = frame

    # TODO Naama: videos cannot be saved with odd dimensions
    with imageio.get_writer(str(path), fps=fps) as writer:
        for frame in frames:
            writer.append_data(frame)
pathlib.Path and initialize the _flavour property. + # https://stackoverflow.com/questions/61689391/error-with-simple-subclassing-of-pathlib-path-no-flavour-attribute + # noinspection PyProtectedMember + # noinspection PyUnresolvedReferences + _flavour = type(pathlib.Path())._flavour + + def __format__(self, format_spec): + if format_spec == '': + return str(self) + elif format_spec == 'link': + if self.exists(): + return _create_hyperlink(self.resolve()) + else: + # Missing path: find first existing parent + missing_path = self.resolve() + existing_parent = self.parent + while existing_parent and not existing_parent.exists(): + existing_parent = existing_parent.parent + + # Build base error message + base_msg = f"\033[1;31m{missing_path} does not exist.\033[0m" + + if existing_parent and existing_parent.exists(): + parent_link = _create_hyperlink(existing_parent.resolve()) + + # Gather existing parent’s contents + content_msg = "" + if existing_parent.is_dir(): + content = list(existing_parent.iterdir()) + if content: + content_msg = ( + "\n" + cyan("Nearest existing directory contents:") + "\n" + + "\n".join([' ' + _create_hyperlink(p.resolve()) for p in content]) + ) + + return f"{base_msg}\nNearest existing directory: {parent_link}{content_msg}" + else: + return f"{base_msg}\n(No existing parent found.)" + elif format_spec.startswith('last'): + i = int(format_spec[4:]) + return "/".join(self.parts[-i:]) + elif format_spec == 'exists': + if self.exists(): + # Normal case: just print the link + return _create_hyperlink(self.resolve()) + else: + return _create_hyperlink(self.resolve()) + ' does not exist. 
def _create_hyperlink(text: str | pathlib.Path):
    """Return ``text`` as a ``file:///`` URL with forward slashes."""
    if isinstance(text, pathlib.Path):
        text = str(text)
    return f'file:///' + text.replace('\\', '/')


def cyan(text: str) -> str:
    """Wrap ``text`` in the colorama cyan foreground code and a reset."""
    return f"{Fore.CYAN}{text}{Fore.RESET}"


class FrequencyScheduler:
    """Decides at which iterations a periodic event fires.

    The schedule is computed once, in ``__init__``, and stored as the sorted
    list ``self.iterations``. It is given either explicitly (``iters``) or as
    piecewise frequencies: ``frequencies[k]`` applies up to and including
    ``steps[k + 1]``, and the last one up to ``last_step``. Iteration
    ``last_step`` is always part of the schedule. With neither given, the
    schedule is just ``[0, last_step]``.

    Each query names a tag from ``KNOWN_TAGS``; tags can be switched off
    individually, and the whole scheduler can be disabled.
    """

    def __init__(
            self,
            last_step: int,
            frequencies: list[int] | None = None,
            steps: list[int] | None = None,
            iters: list[int] | None = None,
            enable_target: bool = True,
            enable_context: bool = True,
            enable_info: bool = True,
            enable_debug: bool = True,
    ):
        if iters is not None:
            print("FrequencyScheduler: using iters argument, ignoring frequencies and steps.")
            # assert frequencies is None and steps is None, "When iters is provided, frequencies and steps must be None"
        elif frequencies is None and steps is None:
            # Make sure frequencies and steps are both either None or lists of the same length
            frequencies = [99999999]  # effectively never
            steps = [0]
        elif frequencies is None or steps is None:
            raise ValueError("frequencies and steps must both be None or both be lists")
        else:
            assert len(frequencies) == len(
                steps), f"frequencies and steps must be same length. Got {len(frequencies)} and {len(steps)}"
            assert steps[0] == 0, f"first step must be 0. Got {steps}"

        if iters is not None:
            # Keep entries up to last_step and always include 0 and last_step.
            # Going through a set removes duplicates (which would otherwise be
            # repeated by get_iterations) and keeps the result sorted.
            kept = {i for i in iters if i <= last_step}
            kept.update((0, last_step))
            self.iterations: list[int] = sorted(kept)
        else:
            frequencies = copy(frequencies)
            steps = copy(steps)
            steps.pop(0)  # remove the first step which is always 0
            if last_step not in steps:
                steps.append(last_step)  # ensure last step is included
            pairs = list(zip(frequencies, steps))
            self.iterations = self.get_all_iterations(pairs, last_step)

        self.verbose = False
        self.last_step = last_step

        self.enabled_tags = {
            "target": enable_target,
            "context": enable_context,
            "info": enable_info,
            "debug": enable_debug
        }

        self.is_disabled = False

    def set_verbose(self, verbose: bool):
        """Store the verbosity flag."""
        self.verbose = verbose

    def set_all_tags(self, enabled: bool):
        """Enable or disable every tag at once."""
        for key in self.enabled_tags:
            self.enabled_tags[key] = enabled

    def check_iteration(self, iteration: int, tag: str) -> bool:
        """Return True if ``tag`` is enabled and ``iteration`` is scheduled."""
        assert tag in KNOWN_TAGS, f"Invalid tag: {tag}, must be in {KNOWN_TAGS}"
        if self.enabled_tags[tag]:
            return iteration in self.iterations
        else:
            return False

    def _occurs_at(self, iteration: int, pairs, last_step) -> bool:
        """Return True if a frequency event occurs at ``iteration``.

        ``pairs`` is a list of ``(frequency, end)`` segments ordered by
        ``end``. The first segment whose ``end`` is not below ``iteration``
        decides; ``last_step`` itself always counts.
        """
        if iteration == last_step:
            return True

        for freq, end in pairs:
            if iteration <= end:
                return iteration % freq == 0

        return False

    def get_all_iterations(self, pairs, last_step) -> list[int]:
        """Return every iteration in ``[0, last_step]`` where an event occurs."""
        return [t for t in range(last_step + 1) if self._occurs_at(t, pairs, last_step)]

    def get_iterations(self, length_of_event: int) -> list[int]:
        """Return the first ``length_of_event`` scheduled iterations.

        A request for a single event returns the last scheduled iteration
        rather than the first.

        Raises:
            ValueError: if fewer than ``length_of_event`` iterations exist.
        """
        if self.iterations is not None and len(self.iterations) >= length_of_event:
            if length_of_event == 1:
                return [self.iterations[-1]]
            return self.iterations[:length_of_event]
        else:
            raise ValueError(
                f"Not enough iterations up to last_step {self.last_step} to get {length_of_event} events. "
                f"Only got {len(self.iterations)} events.")

    def disable(self, flag):
        """Set the global off switch; when true, every query returns False."""
        self.is_disabled = flag

    def __call__(self, iteration: int, tag: str = "") -> bool:
        """Return True if the event fires at ``iteration`` for ``tag``.

        The default ``tag=""`` is treated as an untagged query that bypasses
        the per-tag switches; previously it always failed the ``KNOWN_TAGS``
        assertion, which made the default unusable.
        """
        if self.is_disabled:
            return False
        if tag == "":
            return iteration in self.iterations
        return self.check_iteration(iteration, tag)

    def __repr__(self):
        return f"FrequencyScheduler({self.iterations})"


def log_mem(tag=""):
    """Print current, reserved and peak CUDA memory in MB, prefixed by ``tag``."""
    torch.cuda.synchronize()
    print(f"{tag}: allocated={torch.cuda.memory_allocated() / 1e6:.1f}MB, "
          f"reserved={torch.cuda.memory_reserved() / 1e6:.1f}MB, "
          f"max_allocated={torch.cuda.max_memory_allocated() / 1e6:.1f}MB")
+ for _tag in ( + 'tag:yaml.org,2002:python/object/apply:optgs.misc.io.CustomPath', + 'tag:yaml.org,2002:python/object/apply:src.misc.io.CustomPath', + ): + yaml.add_constructor(_tag, custompath_constructor) + + # --- 2. Load with PyYAML safely --- + with open(path, "r") as f: + raw_cfg = yaml.load(f, Loader=yaml.FullLoader) + + # --- 3. Convert to OmegaConf --- + loaded_cfg = OmegaConf.create(raw_cfg) + + return loaded_cfg + + +if __name__ == '__main__': + + print_every = FrequencyScheduler( + frequencies=[1, 2, 5], + steps=[0, 5, 10], + last_step=37, + # iters=[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 56, 67, 100] + ) + for i in range(37 + 1): + if print_every(i, "target"): + pass + + print(print_every.get_iterations(15)) diff --git a/optgs/misc/memory_profiler.py b/optgs/misc/memory_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..4f23ffa8a99dc6eb523e16a6ed10ed2f54654514 --- /dev/null +++ b/optgs/misc/memory_profiler.py @@ -0,0 +1,88 @@ +import gc +import torch +from torch.autograd import profiler + + +def report_gpu_tensors(device='cuda:0'): + tensors_info = [] + total = 0 + + # Scan all objects tracked by Python + for obj in gc.get_objects(): + try: + if torch.is_tensor(obj) and obj.is_cuda and str(obj.device) == device: + # Calculate memory + mem = obj.numel() * obj.element_size() / 1024**2 + total += mem + + # Try to find variable names pointing to this tensor + names = [name for name, val in globals().items() if val is obj] + name_str = ", ".join(names) if names else "" + + tensors_info.append((name_str, tuple(obj.shape), str(obj.dtype), mem)) + except: + pass + + # Sort by memory usage descending + tensors_info.sort(key=lambda x: -x[3]) + + # Print nicely + print(f"{'Name(s)':>30} | {'Shape':>20} | {'Dtype':>10} | {'Memory (MB)':>12}") + print("-" * 80) + for name, shape, dtype, mem in tensors_info: + print(f"{name:>30} | {str(shape):>20} | {dtype:>10} | {mem:12.2f}") + print("-" * 80) + print(f"Total tracked GPU 
tensor memory: {total:.2f} MB") + + +def profile_gpu_memory(fn, *args, top_n=10, **kwargs): + """ + Profile GPU memory usage including custom CUDA kernels. + Prints peak memory and top PyTorch operations. + + Args: + fn: function to call (e.g., model.forward) + *args, **kwargs: arguments to pass to fn + top_n: number of top memory-consuming operations to print + """ + # Reset memory stats + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + # Record memory before call + mem_before = torch.cuda.memory_allocated() + + # Use PyTorch profiler for visible PyTorch ops + with profiler.profile(use_cuda=True, record_shapes=True) as prof: + result = fn(*args, **kwargs) + torch.cuda.synchronize() + + # Record memory after call + mem_after = torch.cuda.memory_allocated() + mem_diff = (mem_after - mem_before) / 1024**3 + peak = torch.cuda.max_memory_allocated() / 1024**3 + + print("\n=== GPU Memory Profiling ===") + print("Function:", fn.__name__) + print(f"\nMemory before call: {mem_before / 1024**3:.2f} GiB") + print(f"Memory after call : {mem_after / 1024**3:.2f} GiB") + print(f"Memory diff : {mem_diff:.2f} GiB") + print(f"Peak allocated : {peak:.2f} GiB\n") + + # Get key averages from profiler + key_avg = prof.key_averages() + key_avg_sorted = sorted( + key_avg, + key=lambda k: getattr(k, "self_cuda_memory_usage", 0), + reverse=True + ) + + # Print top N operations + print(f"{'Operation':<40} | {'CUDA Memory (MB)':>15} | {'Shape info':>20} | #Calls") + print("-" * 100) + for evt in key_avg_sorted[:top_n]: + mem_mb = getattr(evt, "self_cuda_memory_usage", 0) / 1024**2 + shapes = str(evt.input_shapes) if hasattr(evt, 'input_shapes') else "-" + print(f"{evt.key:<40} | {mem_mb:15.2f} | {shapes:>20} | {evt.count}") + + return result \ No newline at end of file diff --git a/optgs/misc/nn_module_tools.py b/optgs/misc/nn_module_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..21570e2880349a8c6083118d76183e90a501e617 --- /dev/null +++ 
b/optgs/misc/nn_module_tools.py @@ -0,0 +1,16 @@ +from torch import nn + + +def convert_to_buffer(module: nn.Module, persistent: bool = True): + # Recurse over child modules. + for name, child in list(module.named_children()): + convert_to_buffer(child, persistent) + + # Also re-save buffers to change persistence. + for name, parameter_or_buffer in ( + *module.named_parameters(recurse=False), + *module.named_buffers(recurse=False), + ): + value = parameter_or_buffer.detach().clone() + delattr(module, name) + module.register_buffer(name, value, persistent=persistent) diff --git a/optgs/misc/render_utils.py b/optgs/misc/render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb8025c10db948ff8eaad3dd76a4651bc0fed881 --- /dev/null +++ b/optgs/misc/render_utils.py @@ -0,0 +1,330 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/render_utils.py + +import numpy as np +import os +import enum +import types +from typing import List, Mapping, Optional, Text, Tuple, Union +import copy +from PIL import Image +# import mediapy as media +from matplotlib import cm +from tqdm import tqdm + +import torch + + +def normalize(x: np.ndarray) -> np.ndarray: + """Normalization helper function.""" + return x / np.linalg.norm(x) + + +def pad_poses(p: np.ndarray) -> np.ndarray: + """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1].""" + bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape) + return np.concatenate([p[..., :3, :4], bottom], axis=-2) + + +def unpad_poses(p: np.ndarray) -> np.ndarray: + """Remove the homogeneous bottom row from [..., 4, 4] pose matrices.""" + return p[..., :3, :4] + + +def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Recenter poses around the origin.""" + cam2world = average_pose(poses) + transform = np.linalg.inv(pad_poses(cam2world)) + poses = transform @ pad_poses(poses) + return unpad_poses(poses), transform + + +def average_pose(poses: np.ndarray) -> np.ndarray: + """New pose using average position, z-axis, and up vector of input poses.""" + position = poses[:, :3, 3].mean(0) + z_axis = poses[:, :3, 2].mean(0) + up = poses[:, :3, 1].mean(0) + cam2world = viewmatrix(z_axis, up, position) + return cam2world + + +def viewmatrix(lookdir: np.ndarray, up: np.ndarray, + position: np.ndarray) -> np.ndarray: + """Construct lookat view matrix.""" + vec2 = normalize(lookdir) + vec0 = normalize(np.cross(up, vec2)) + vec1 = normalize(np.cross(vec2, vec0)) + m = np.stack([vec0, vec1, vec2, position], axis=1) + return m + + +def focus_point_fn(poses: np.ndarray) -> np.ndarray: + """Calculate nearest point to all focal axes in poses.""" + directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4] + m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1]) + 
mt_m = np.transpose(m, [0, 2, 1]) @ m + focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] + return focus_pt + + +def transform_poses_pca(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Transforms poses so principal components lie on XYZ axes. + + Args: + poses: a (N, 3, 4) array containing the cameras' camera to world transforms. + + Returns: + A tuple (poses, transform), with the transformed poses and the applied + camera_to_world transforms. + """ + t = poses[:, :3, 3] + t_mean = t.mean(axis=0) + t = t - t_mean + + eigval, eigvec = np.linalg.eig(t.T @ t) + # Sort eigenvectors in order of largest to smallest eigenvalue. + inds = np.argsort(eigval)[::-1] + eigvec = eigvec[:, inds] + rot = eigvec.T + if np.linalg.det(rot) < 0: + rot = np.diag(np.array([1, 1, -1])) @ rot + + transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1) + poses_recentered = unpad_poses(transform @ pad_poses(poses)) + transform = np.concatenate([transform, np.eye(4)[3:]], axis=0) + + # Flip coordinate system if z component of y-axis is negative + if poses_recentered.mean(axis=0)[2, 1] < 0: + poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered + transform = np.diag(np.array([1, -1, -1, 1])) @ transform + + return poses_recentered, transform + # points = np.random.rand(3,100) + # points_h = np.concatenate((points,np.ones_like(points[:1])), axis=0) + # (poses_recentered @ points_h)[0] + # (transform @ pad_poses(poses) @ points_h)[0,:3] + # import pdb; pdb.set_trace() + + # # Just make sure it's it in the [-1, 1]^3 cube + # scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3])) + # poses_recentered[:, :3, 3] *= scale_factor + # transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform + + # return poses_recentered, transform + + +def generate_ellipse_path(poses: np.ndarray, + n_frames: int = 120, + const_speed: bool = True, + z_variation: float = 0., + z_phase: float = 0.) 
-> np.ndarray: + """Generate an elliptical render path based on the given poses.""" + # Calculate the focal point for the path (cameras point toward this). + center = focus_point_fn(poses) + # Path height sits at z=0 (in middle of zero-mean capture pattern). + offset = np.array([center[0], center[1], 0]) + + # Calculate scaling for ellipse axes based on input camera positions. + sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) + # Use ellipse that is symmetric about the focal point in xy. + low = -sc + offset + high = sc + offset + # Optional height variation need not be symmetric + z_low = np.percentile((poses[:, :3, 3]), 10, axis=0) + z_high = np.percentile((poses[:, :3, 3]), 90, axis=0) + + def get_positions(theta): + # Interpolate between bounds with trig functions to get ellipse in x-y. + # Optionally also interpolate in z to change camera height along path. + return np.stack([ + low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5), + low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5), + z_variation * (z_low[2] + (z_high - z_low)[2] * + (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)), + ], -1) + + theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True) + positions = get_positions(theta) + + # if const_speed: + + # # Resample theta angles so that the velocity is closer to constant. + # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) + # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1) + # positions = get_positions(theta) + + # Throw away duplicated last position. + positions = positions[:-1] + + # Set path's up vector to axis closest to average of input pose up vectors. 
+ avg_up = poses[:, :3, 1].mean(0) + avg_up = avg_up / np.linalg.norm(avg_up) + ind_up = np.argmax(np.abs(avg_up)) + up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) + + return np.stack([viewmatrix(p - center, up, p) for p in positions]) + + +def generate_path(viewpoint_cameras, n_frames=480): + c2ws = np.array([np.linalg.inv(np.asarray( + (cam.world_view_transform.T).cpu().numpy())) for cam in viewpoint_cameras]) + pose = c2ws[:, :3, :] @ np.diag([1, -1, -1, 1]) + pose_recenter, colmap_to_world_transform = transform_poses_pca(pose) + + # generate new poses + new_poses = generate_ellipse_path(poses=pose_recenter, n_frames=n_frames) + # warp back to orignal scale + new_poses = np.linalg.inv(colmap_to_world_transform) @ pad_poses(new_poses) + + traj = [] + for c2w in new_poses: + c2w = c2w @ np.diag([1, -1, -1, 1]) + cam = copy.deepcopy(viewpoint_cameras[0]) + cam.image_height = int(cam.image_height / 2) * 2 + cam.image_width = int(cam.image_width / 2) * 2 + cam.world_view_transform = torch.from_numpy( + np.linalg.inv(c2w).T).float().cuda() + cam.full_proj_transform = (cam.world_view_transform.unsqueeze( + 0).bmm(cam.projection_matrix.unsqueeze(0))).squeeze(0) + cam.camera_center = cam.world_view_transform.inverse()[3, :3] + traj.append(cam) + + return traj + + +def generate_video_render_path(c2ws, n_frames=480): + # c2ws = np.array([np.linalg.inv(np.asarray( + # (cam.world_view_transform.T).cpu().numpy())) for cam in viewpoint_cameras]) + # c2ws: [V, 4, 4] + pose = c2ws[:, :3, :] @ np.diag([1, -1, -1, 1]) + pose_recenter, colmap_to_world_transform = transform_poses_pca(pose) + + # generate new poses + new_poses = generate_ellipse_path(poses=pose_recenter, n_frames=n_frames) + # warp back to orignal scale + new_poses = np.linalg.inv(colmap_to_world_transform) @ pad_poses(new_poses) + + traj = [] + for c2w in new_poses: + c2w = c2w @ np.diag([1, -1, -1, 1]) + # cam = copy.deepcopy(viewpoint_cameras[0]) + # cam.image_height = int(cam.image_height / 2) * 2 + # 
cam.image_width = int(cam.image_width / 2) * 2 + # cam.world_view_transform = torch.from_numpy( + # np.linalg.inv(c2w).T).float().cuda() + # cam.full_proj_transform = (cam.world_view_transform.unsqueeze( + # 0).bmm(cam.projection_matrix.unsqueeze(0))).squeeze(0) + # cam.camera_center = cam.world_view_transform.inverse()[3, :3] + # traj.append(cam) + traj.append(c2w) + + return traj + + +def load_img(pth: str) -> np.ndarray: + """Load an image and cast to float32.""" + with open(pth, 'rb') as f: + image = np.array(Image.open(f), dtype=np.float32) + return image + + +def create_videos(base_dir, input_dir, out_name, num_frames=480): + """Creates videos out of the images saved to disk.""" + # Last two parts of checkpoint path are experiment name and scene name. + video_prefix = f'{out_name}' + + zpad = max(5, len(str(num_frames - 1))) + def idx_to_str(idx): return str(idx).zfill(zpad) + + os.makedirs(base_dir, exist_ok=True) + render_dist_curve_fn = np.log + + # Load one example frame to get image shape and depth range. 
+ depth_file = os.path.join(input_dir, 'vis', f'depth_{idx_to_str(0)}.tiff') + depth_frame = load_img(depth_file) + shape = depth_frame.shape + p = 3 + distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p]) + lo, hi = [render_dist_curve_fn(x) for x in distance_limits] + print(f'Video shape is {shape[:2]}') + + video_kwargs = { + 'shape': shape[:2], + 'codec': 'h264', + 'fps': 60, + 'crf': 18, + } + + for k in ['depth', 'normal', 'color']: + video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4') + input_format = 'gray' if k == 'alpha' else 'rgb' + + file_ext = 'png' if k in ['color', 'normal'] else 'tiff' + idx = 0 + + if k == 'color': + file0 = os.path.join(input_dir, 'renders', + f'{idx_to_str(0)}.{file_ext}') + else: + file0 = os.path.join( + input_dir, 'vis', f'{k}_{idx_to_str(0)}.{file_ext}') + + if not os.path.exists(file0): + print(f'Images missing for tag {k}') + continue + print(f'Making video {video_file}...') + with media.VideoWriter( + video_file, **video_kwargs, input_format=input_format) as writer: + for idx in tqdm(range(num_frames)): + # img_file = os.path.join(input_dir, f'{k}_{idx_to_str(idx)}.{file_ext}') + if k == 'color': + img_file = os.path.join( + input_dir, 'renders', f'{idx_to_str(idx)}.{file_ext}') + else: + img_file = os.path.join( + input_dir, 'vis', f'{k}_{idx_to_str(idx)}.{file_ext}') + + if not os.path.exists(img_file): + ValueError(f'Image file {img_file} does not exist.') + img = load_img(img_file) + if k in ['color', 'normal']: + img = img / 255. + elif k.startswith('depth'): + img = render_dist_curve_fn(img) + img = np.clip((img - np.minimum(lo, hi)) / + np.abs(hi - lo), 0, 1) + img = cm.get_cmap('turbo')(img)[..., :3] + + frame = (np.clip(np.nan_to_num(img), 0., 1.) 
+ * 255.).astype(np.uint8) + writer.add_image(frame) + idx += 1 + + +def save_img_u8(img, pth): + """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG.""" + with open(pth, 'wb') as f: + Image.fromarray( + (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)).save( + f, 'PNG') + + +def save_img_f32(depthmap, pth): + """Save an image (probably a depthmap) to disk as a float32 TIFF.""" + with open(pth, 'wb') as f: + Image.fromarray(np.nan_to_num(depthmap).astype( + np.float32)).save(f, 'TIFF') diff --git a/optgs/misc/sh_rotation.py b/optgs/misc/sh_rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..4b1a26d8821232ed4f8aab257a7c67c46d905ff0 --- /dev/null +++ b/optgs/misc/sh_rotation.py @@ -0,0 +1,81 @@ +from math import isqrt + +import torch +from e3nn.o3 import matrix_to_angles, wigner_D +from einops import einsum +from jaxtyping import Float +from torch import Tensor + + +def rotate_sh( + sh_coefficients: Float[Tensor, "*#batch n"], + rotations: Float[Tensor, "*#batch 3 3"], +) -> Float[Tensor, "*batch n"]: + device = sh_coefficients.device + dtype = sh_coefficients.dtype + + *_, n = sh_coefficients.shape + alpha, beta, gamma = matrix_to_angles(rotations) + result = [] + for degree in range(isqrt(n)): + with torch.device(device): + sh_rotations = wigner_D(degree, alpha, beta, gamma).type(dtype) + sh_rotated = einsum( + sh_rotations, + sh_coefficients[..., degree**2 : (degree + 1) ** 2], + "... i j, ... j -> ... i", + ) + result.append(sh_rotated) + + return torch.cat(result, dim=-1) + + +if __name__ == "__main__": + from pathlib import Path + + import matplotlib.pyplot as plt + from e3nn.o3 import spherical_harmonics + from matplotlib import cm + from scipy.spatial.transform.rotation import Rotation as R + + device = torch.device("cuda") + + # Generate random spherical harmonics coefficients. 
+ degree = 4 + coefficients = torch.rand((degree + 1) ** 2, dtype=torch.float32, device=device) + + def plot_sh(sh_coefficients, path: Path) -> None: + phi = torch.linspace(0, torch.pi, 100, device=device) + theta = torch.linspace(0, 2 * torch.pi, 100, device=device) + phi, theta = torch.meshgrid(phi, theta, indexing="xy") + x = torch.sin(phi) * torch.cos(theta) + y = torch.sin(phi) * torch.sin(theta) + z = torch.cos(phi) + xyz = torch.stack([x, y, z], dim=-1) + sh = spherical_harmonics(list(range(degree + 1)), xyz, True) + result = einsum(sh, sh_coefficients, "... n, n -> ...") + result = (result - result.min()) / (result.max() - result.min()) + + # Set the aspect ratio to 1 so our sphere looks spherical + fig = plt.figure(figsize=plt.figaspect(1.0)) + ax = fig.add_subplot(111, projection="3d") + ax.plot_surface( + x.cpu().numpy(), + y.cpu().numpy(), + z.cpu().numpy(), + rstride=1, + cstride=1, + facecolors=cm.seismic(result.cpu().numpy()), + ) + # Turn off the axis planes + ax.set_axis_off() + path.parent.mkdir(exist_ok=True, parents=True) + plt.savefig(path) + + for i, angle in enumerate(torch.linspace(0, 2 * torch.pi, 30)): + rotation = torch.tensor( + R.from_euler("x", angle.item()).as_matrix(), device=device + ) + plot_sh(rotate_sh(coefficients, rotation), Path(f"sh_rotation/{i:0>3}.png")) + + print("Done!") diff --git a/optgs/misc/stablize_camera.py b/optgs/misc/stablize_camera.py new file mode 100644 index 0000000000000000000000000000000000000000..3f2c68c27c0a2bec3e549e81014d2949e989ed5f --- /dev/null +++ b/optgs/misc/stablize_camera.py @@ -0,0 +1,137 @@ +""" +https://github.com/google/dynibar/blob/main/ibrnet/data_loaders/llff_data_utils.py +""" + +import numpy as np +import cv2 + + +def render_stabilization_path(poses, k_size=45, start_idx=0, end_idx=None, loop=False): + """Rendering stablizaed camera path.""" + + # hwf = poses[0, :, 4:5] + + poses = poses[start_idx:end_idx] + if loop: + # Go back and forth + poses = np.concatenate([poses, poses[::-1]], 
axis=0) + + num_frames = poses.shape[0] + output_poses = [] + + input_poses = [] + + for i in range(num_frames): + input_poses.append( + np.concatenate( + [poses[i, :3, 0:1], poses[i, :3, 1:2], poses[i, :3, 3:4]], axis=-1 + ) + ) + + input_poses = np.array(input_poses) + + gaussian_kernel = cv2.getGaussianKernel(ksize=k_size, sigma=-1) + output_r1 = cv2.filter2D(input_poses[:, :, 0], -1, gaussian_kernel) + output_r2 = cv2.filter2D(input_poses[:, :, 1], -1, gaussian_kernel) + + output_r1 = output_r1 / np.linalg.norm(output_r1, axis=-1, keepdims=True) + output_r2 = output_r2 / np.linalg.norm(output_r2, axis=-1, keepdims=True) + + output_t = cv2.filter2D(input_poses[:, :, 2], -1, gaussian_kernel) + + for i in range(num_frames): + output_r3 = np.cross(output_r1[i], output_r2[i]) + + render_pose = np.concatenate( + [ + output_r1[i, :, None], + output_r2[i, :, None], + output_r3[:, None], + output_t[i, :, None], + ], + axis=-1, + ) + + output_poses.append(render_pose[:3, :]) + + return output_poses + + + +def render_looped_stabilization_path( + poses, + start_idx=0, + num_poses=None, + k_size=45, + loop_frames=None +): + """ + Smooth a subset of camera poses and optionally loop back to the starting pose. 
+ + Args: + poses (np.ndarray): Original poses of shape (N, 3, 4) + start_idx (int): Index of the first pose to use + num_poses (int): Number of poses to consider from start_idx + k_size (int): Gaussian kernel size for smoothing + loop_frames (int): If set, number of extra frames to interpolate back to start + + Returns: + output_poses (np.ndarray): Smoothed (and looped) poses + """ + # Slice poses + if num_poses is None: + selected_poses = poses[start_idx:] + else: + selected_poses = poses[start_idx:start_idx + num_poses] + + num_frames = selected_poses.shape[0] + output_poses = [] + + # Convert to columns for smoothing + input_poses = [] + for i in range(num_frames): + input_poses.append( + np.concatenate( + [selected_poses[i, :3, 0:1], selected_poses[i, :3, 1:2], selected_poses[i, :3, 3:4]], axis=-1 + ) + ) + input_poses = np.array(input_poses) + + # Gaussian smoothing + gaussian_kernel = cv2.getGaussianKernel(ksize=k_size, sigma=-1) + output_r1 = cv2.filter2D(input_poses[:, :, 0], -1, gaussian_kernel) + output_r2 = cv2.filter2D(input_poses[:, :, 1], -1, gaussian_kernel) + output_r1 /= np.linalg.norm(output_r1, axis=-1, keepdims=True) + output_r2 /= np.linalg.norm(output_r2, axis=-1, keepdims=True) + output_t = cv2.filter2D(input_poses[:, :, 2], -1, gaussian_kernel) + + # Build smoothed poses + for i in range(num_frames): + r3 = np.cross(output_r1[i], output_r2[i]) + pose = np.concatenate( + [output_r1[i, :, None], output_r2[i, :, None], r3[:, None], output_t[i, :, None]], axis=-1 + ) + output_poses.append(pose[:3, :]) + + output_poses = np.array(output_poses) + + # Optionally loop back to start + if loop_frames is not None and loop_frames > 0: + start_pose = output_poses[0] + end_pose = output_poses[-1] + looped_poses = [] + for i in range(1, loop_frames + 1): + alpha = i / loop_frames + # Linear interpolation for translation + t_interp = (1 - alpha) * end_pose[:, 3] + alpha * start_pose[:, 3] + # Linear interpolation + re-orthonormalize rotations + r1_interp = 
(1 - alpha) * end_pose[:, 0] + alpha * start_pose[:, 0] + r2_interp = (1 - alpha) * end_pose[:, 1] + alpha * start_pose[:, 1] + r1_interp /= np.linalg.norm(r1_interp) + r2_interp /= np.linalg.norm(r2_interp) + r3_interp = np.cross(r1_interp, r2_interp) + looped_pose = np.stack([r1_interp, r2_interp, r3_interp, t_interp], axis=-1) + looped_poses.append(looped_pose) + output_poses = np.concatenate([output_poses, np.array(looped_poses)], axis=0) + + return output_poses \ No newline at end of file diff --git a/optgs/misc/step_tracker.py b/optgs/misc/step_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..7298ffcf5cb028799d67c4dedf7ef1a4bf6fe802 --- /dev/null +++ b/optgs/misc/step_tracker.py @@ -0,0 +1,23 @@ +from multiprocessing import RLock + +import torch +from jaxtyping import Int64 +from torch import Tensor +from torch.multiprocessing import Manager + + +class StepTracker: + lock: RLock + step: Int64[Tensor, ""] + + def __init__(self): + self.lock = Manager().RLock() + self.step = torch.tensor(0, dtype=torch.int64).share_memory_() + + def set_step(self, step: int) -> None: + with self.lock: + self.step.fill_(step) + + def get_step(self) -> int: + with self.lock: + return self.step.item() diff --git a/optgs/misc/wandb_tools.py b/optgs/misc/wandb_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..fec2cd7eaf176f6038681897400117e3f080aa69 --- /dev/null +++ b/optgs/misc/wandb_tools.py @@ -0,0 +1,134 @@ +import os +from pathlib import Path + +import wandb +from pytorch_lightning.loggers.wandb import WandbLogger +from omegaconf import OmegaConf + +from optgs.misc.LocalLogger import LocalLogger +from optgs.paths import DEBUG + + +def version_to_int(artifact) -> int: + """Convert versions of the form vX to X. 
For example, v12 to 12.""" + return int(artifact.version[1:]) + + +def download_checkpoint( + run_id: str, + download_dir: Path, + version: str | None, +) -> Path: + api = wandb.Api() + run = api.run(run_id) + + # Find the latest saved model checkpoint. + chosen = None + for artifact in run.logged_artifacts(): + if artifact.type != "model" or artifact.state != "COMMITTED": + continue + + # If no version is specified, use the latest. + if version is None: + if chosen is None or version_to_int(artifact) > version_to_int(chosen): + chosen = artifact + + # If a specific verison is specified, look for it. + elif version == artifact.version: + chosen = artifact + break + + # Download the checkpoint. + download_dir.mkdir(exist_ok=True, parents=True) + root = download_dir / run_id + chosen.download(root=root) + return root / "model.ckpt" + + +def setup_wandb_logger(cfg, cfg_dict) -> WandbLogger | LocalLogger: + if cfg_dict.wandb.mode == "disabled" or cfg.mode != "train": + return LocalLogger() + + wandb_extra_kwargs = {} + + # Detect the wandb id job run if resuming + if cfg_dict.checkpointing.resume: + if cfg_dict.wandb.id is None: + print(f"Resuming wandb run without id, using latest run in output directory.") + # Find the latest wandb run id in the output directory + wandb_dir = cfg_dict.output_dir / "wandb" / "latest-run" + # look for a file name in the format "run-######.wandb" file and extract the id + wandb_files = list(wandb_dir.glob("run-*.wandb")) + assert len(wandb_files) <= 1, "Multiple wandb files found in the latest run directory." 
+ if len(wandb_files) == 1: + wandb_file = wandb_files[0] + wandb_id = wandb_file.stem.split('-')[1] + wandb_extra_kwargs.update({'id': wandb_id, 'resume': "must"}) + + if cfg_dict.wandb.id is not None: + print(f"Setting wandb run with id from cfg {cfg_dict.wandb.id}.") + wandb_extra_kwargs.update({'id': cfg_dict.wandb.id, 'resume': "must"}) + + run_name = os.path.basename(cfg_dict.output_dir) + + if cfg_dict.log_slurm_id: + hostname = os.uname().nodename + job_id = os.environ.get('SLURM_JOB_ID', "local run: " + hostname) + run_name += f" ({job_id})" + + # if debugging, add a tag to the run name + if DEBUG: + run_name += " DEBUG" + cfg_dict.wandb.update({'tags': ['debug']}) + if os.environ.get('WANDB_ENTITY') is not None: + cfg_dict.wandb.update({'entity': os.environ.get('WANDB_ENTITY')}) + + logger = WandbLogger( + entity=cfg_dict.wandb.entity, + project=cfg_dict.wandb.project, + mode=cfg_dict.wandb.mode, + name=run_name, + tags=cfg_dict.wandb.get("tags", None), + log_model=False, + save_dir=cfg_dict.output_dir, + config=OmegaConf.to_container(cfg_dict), + **wandb_extra_kwargs, + ) + + if logger.experiment is not None: + # Log code + logger.experiment.log_code("optgs") + # Log notes + if cfg_dict.wandb.notes is not None: + logger.experiment.notes = cfg_dict.wandb.notes + # Write wandb run ID to file for SLURM requeue resume + wandb_id_file = os.environ.get("WANDB_ID_FILE") + if wandb_id_file: + with open(wandb_id_file, "w") as f: + f.write(logger.experiment.id) + print(f"Wrote wandb run ID {logger.experiment.id} to {wandb_id_file}") + + return logger + + +def update_checkpoint_path(path: str | None, wandb_cfg: dict) -> Path | None: + if path is None: + return None + + if not str(path).startswith("wandb://"): + return Path(path) + + run_id, *version = path[len("wandb://") :].split(":") + if len(version) == 0: + version = None + elif len(version) == 1: + version = version[0] + else: + raise ValueError("Invalid version specifier!") + + project = wandb_cfg["project"] 
+ return download_checkpoint( + f"{project}/{run_id}", + Path("checkpoints"), + version, + ) diff --git a/optgs/model/__init__.py b/optgs/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/colmap_utils/__init__.py b/optgs/model/colmap_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/colmap_utils/convert_to_colmap.py b/optgs/model/colmap_utils/convert_to_colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6ba58fef9070bb96da63f680f92421c434279c --- /dev/null +++ b/optgs/model/colmap_utils/convert_to_colmap.py @@ -0,0 +1,57 @@ +import json +import numpy as np +from pathlib import Path +from .read_write_model import Camera, write_cameras_binary + + +def save_opencv_camera(K: np.ndarray, json_path: str, output_path, image_size: tuple, camera_id=1): + assert K.shape == (3, 3), "Intrinsic matrix K must be 3x3" + + # Extract intrinsics from K + fx = K[0, 0] + fy = K[1, 1] + cx = K[0, 2] + cy = K[1, 2] + + # Load distortion coefficients from JSON + with open(json_path, 'r') as f: + data = json.load(f) + + k1 = data.get("k1", 0.0) + k2 = data.get("k2", 0.0) + p1 = data.get("p1", 0.0) + p2 = data.get("p2", 0.0) + + # Image resolution + w, h = image_size + + # COLMAP OPENCV model requires 8 parameters + params = (fx, fy, cx, cy, k1, k2, p1, p2) + + # Create and write Camera object + camera = Camera( + id=camera_id, + model="OPENCV", + width=w, + height=h, + params=params + ) + + output_path = Path(output_path) + output_path.mkdir(parents=True, exist_ok=True) + write_cameras_binary({camera_id: camera}, output_path / "cameras.bin") + +# Example usage +if __name__ == "__main__": + K = np.array([ + [629.23, 0, 480.0], + [0, 625.73, 270.0], + [0, 0, 1.0] + ]) + save_opencv_camera( + K=K, + json_path="transforms.json", + output_path="output_model", + 
image_size=(960, 540), # e.g., 4x downsampled from 3840x2160 + camera_id=1 + ) diff --git a/optgs/model/colmap_utils/extract_sparse_view_extrinsics.py b/optgs/model/colmap_utils/extract_sparse_view_extrinsics.py new file mode 100644 index 0000000000000000000000000000000000000000..9db9c22273c2403adc1ce8afd1e890336a5007f5 --- /dev/null +++ b/optgs/model/colmap_utils/extract_sparse_view_extrinsics.py @@ -0,0 +1,54 @@ +import numpy as np +from pathlib import Path +from .read_write_model import ( + read_images_binary, + write_images_binary, + write_points3D_binary, + Image as ColmapImage +) + +def extract_sparse_images_bin(input_model_dir, output_model_dir, selected_image_ids, keep_features=False): + input_model_dir = Path(input_model_dir) + output_model_dir = Path(output_model_dir) + output_model_dir.mkdir(parents=True, exist_ok=True) + + # Load full images.bin + images = read_images_binary(input_model_dir / "images.bin") + + # Select and blank the images + sparse_images = {} + for image_id in selected_image_ids: + image = images[image_id] + + if keep_features: + xys = image.xys + point3D_ids = image.point3D_ids + else: + xys = np.empty((0, 2)) + point3D_ids = np.empty((0,), dtype=int) + + blank_image = ColmapImage( + id=image.id, + qvec=image.qvec, + tvec=image.tvec, + camera_id=image.camera_id, + name=image.name, + xys=xys, + point3D_ids=point3D_ids + ) + sparse_images[image_id] = blank_image + + # Save sparse images.bin + write_images_binary(sparse_images, output_model_dir / "images.bin") + + # Save empty points3D.bin + write_points3D_binary({}, output_model_dir / "points3D.bin") + +# Example usage +if __name__ == "__main__": + selected_ids = [1, 4, 10, 20] # Replace with your sparse frame IDs + extract_sparse_images_bin( + input_model_dir="dense_model", + output_model_dir="sparse_model", + selected_image_ids=selected_ids + ) diff --git a/optgs/model/colmap_utils/read_write_model.py b/optgs/model/colmap_utils/read_write_model.py new file mode 100644 index 
0000000000000000000000000000000000000000..1510f12ce5d623a703f9cf812b34b697d3514d5c --- /dev/null +++ b/optgs/model/colmap_utils/read_write_model.py @@ -0,0 +1,605 @@ +# Copyright (c), ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. 
+ + +import argparse +import collections +import os +import struct + +import numpy as np + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"] +) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"] +) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] +) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] +) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), +} +CAMERA_MODEL_IDS = dict( + [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] +) +CAMERA_MODEL_NAMES = dict( + [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] +) + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. 
+ """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): + """pack and write to a binary file. + :param fid: + :param data: data to send, if multiple elements are sent at the same time, + they should be encapsuled either in a list or a tuple + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera( + id=camera_id, + model=model, + width=width, + height=height, + params=params, + ) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ" + ) + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name 
= CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes( + fid, + num_bytes=8 * num_params, + format_char_sequence="d" * num_params, + ) + cameras[camera_id] = Camera( + id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params), + ) + assert len(cameras) == num_cameras + return cameras + + +def write_cameras_text(cameras, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + HEADER = ( + "# Camera list with one line of data per camera:\n" + + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + + "# Number of cameras: {}\n".format(len(cameras)) + ) + with open(path, "w") as fid: + fid.write(HEADER) + for _, cam in cameras.items(): + to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] + line = " ".join([str(elem) for elem in to_write]) + fid.write(line + "\n") + + +def write_cameras_binary(cameras, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(cameras), "Q") + for _, cam in cameras.items(): + model_id = CAMERA_MODEL_NAMES[cam.model].model_id + camera_properties = [cam.id, model_id, cam.width, cam.height] + write_next_bytes(fid, camera_properties, "iiQQ") + for p in cam.params: + write_next_bytes(fid, float(p), "d") + return cameras + + +def read_images_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = 
fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack( + [ + tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3])), + ] + ) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi" + ) + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + binary_image_name = b"" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + binary_image_name += current_char + current_char = read_next_bytes(fid, 1, "c")[0] + image_name = binary_image_name.decode("utf-8") + num_points2D = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q" + )[0] + x_y_id_s = read_next_bytes( + fid, + num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D, + ) + xys = np.column_stack( + [ + tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3])), + ] + ) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + 
id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def write_images_text(images, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + if len(images) == 0: + mean_observations = 0 + else: + mean_observations = sum( + (len(img.point3D_ids) for _, img in images.items()) + ) / len(images) + HEADER = ( + "# Image list with two lines of data per image:\n" + + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + + "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + + "# Number of images: {}, mean observations per image: {}\n".format( + len(images), mean_observations + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, img in images.items(): + image_header = [ + img.id, + *img.qvec, + *img.tvec, + img.camera_id, + img.name, + ] + first_line = " ".join(map(str, image_header)) + fid.write(first_line + "\n") + + points_strings = [] + for xy, point3D_id in zip(img.xys, img.point3D_ids): + points_strings.append(" ".join(map(str, [*xy, point3D_id]))) + fid.write(" ".join(points_strings) + "\n") + + +def write_images_binary(images, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + for _, img in images.items(): + write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for xy, p3d_id in zip(img.xys, 
img.point3D_ids): + write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_points3D_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def read_points3D_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd" + ) + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q" + )[0] + track_elems = read_next_bytes( + fid, + num_bytes=8 * track_length, + format_char_sequence="ii" * track_length, + ) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + 
image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum( + (len(pt.image_ids) for _, pt in points3D.items()) + ) / len(points3D) + HEADER = ( + "# 3D point list with one line of data per point:\n" + + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + + "# Number of points: {}, mean track length: {}\n".format( + len(points3D), mean_track_length + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, pt in points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def write_points3D_binary(points3D, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def detect_model_format(path, ext): + if ( + os.path.isfile(os.path.join(path, "cameras" + ext)) + and os.path.isfile(os.path.join(path, "images" + 
ext)) + and os.path.isfile(os.path.join(path, "points3D" + ext)) + ): + print("Detected model format: '" + ext + "'") + return True + + return False + + +def read_model(path, ext=""): + # try to detect the extension automatically + if ext == "": + if detect_model_format(path, ".bin"): + ext = ".bin" + elif detect_model_format(path, ".txt"): + ext = ".txt" + else: + print("Provide model format: '.bin' or '.txt'") + return + + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext=".bin"): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, 
Rzy, Rxz, Ryz, Rzz = R.flat + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) + / 3.0 + ) + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def main(): + parser = argparse.ArgumentParser( + description="Read and write COLMAP binary and text models" + ) + parser.add_argument("--input_model", help="path to input model folder") + parser.add_argument( + "--input_format", + choices=[".bin", ".txt"], + help="input model format", + default="", + ) + parser.add_argument("--output_model", help="path to output model folder") + parser.add_argument( + "--output_format", + choices=[".bin", ".txt"], + help="output model format", + default=".txt", + ) + args = parser.parse_args() + + cameras, images, points3D = read_model( + path=args.input_model, ext=args.input_format + ) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model( + cameras, + images, + points3D, + path=args.output_model, + ext=args.output_format, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/optgs/model/decoder/__init__.py b/optgs/model/decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0cbca2a3ae533857b36ff06e384bff2a54b8c24d --- /dev/null +++ b/optgs/model/decoder/__init__.py @@ -0,0 +1,37 @@ +from ...dataset import DatasetCfg +from .decoder import Decoder +from .gsplat_decoder_splatting_cuda import GSplatDecoderSplattingCUDACfg, GSplatDecoderSplattingCUDA + +DECODERS = { + "gsplat": GSplatDecoderSplattingCUDA, +} + +DecoderCfg = GSplatDecoderSplattingCUDACfg + +# The inria decoder is optional (it needs diff_gaussian_rasterization). 
+# Importing this package must NOT require that backend — gsplat is the +# default. If the inria decoder is actually requested while the backend is +# missing, raise a clear, chained ImportError (mirrors the RoMa handling in +# optgs/experimental/edgs/init.py) instead of silently degrading. +try: + from .decoder_splatting_cuda import DecoderSplattingCUDACfg, DecoderSplattingCUDA + DECODERS["inria"] = DecoderSplattingCUDA + DecoderCfg = GSplatDecoderSplattingCUDACfg | DecoderSplattingCUDACfg +except ImportError as _e: + # `except ... as _e` is auto-deleted at block end; keep a stable ref so the + # closure below can chain from the original error. + _INRIA_IMPORT_ERROR = _e + + def _inria_decoder_unavailable(*_args, **_kwargs): + raise ImportError( + "The inria decoder requires diff_gaussian_rasterization, which is " + "not installed. Install it with: " + "pip install git+https://github.com/graphdeco-inria/diff-gaussian-rasterization.git" + ) from _INRIA_IMPORT_ERROR + + DECODERS["inria"] = _inria_decoder_unavailable + + +def get_decoder(decoder_cfg: DecoderCfg, dataset_cfg: DatasetCfg) -> Decoder: + print(f"Using decoder: {decoder_cfg.name}") + return DECODERS[decoder_cfg.name](decoder_cfg, dataset_cfg) diff --git a/optgs/model/decoder/cuda_splatting.py b/optgs/model/decoder/cuda_splatting.py new file mode 100644 index 0000000000000000000000000000000000000000..3beb348b05c7609e5e72f84e313f14cfd014fb84 --- /dev/null +++ b/optgs/model/decoder/cuda_splatting.py @@ -0,0 +1,308 @@ +from math import isqrt +from typing import Literal + +import torch + +try: + from diff_gaussian_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, + ) +except ImportError as e: + raise ImportError( + "The inria decoder requires diff_gaussian_rasterization, which is " + "not installed. 
Install it with: " + "pip install git+https://github.com/graphdeco-inria/diff-gaussian-rasterization.git" + ) from e + +from einops import einsum, rearrange, repeat +from jaxtyping import Float, Int +from torch import Tensor + +from ...geometry.projection import get_fov, homogenize_points + + +def get_projection_matrix( + near: Float[Tensor, " batch"], + far: Float[Tensor, " batch"], + fov_x: Float[Tensor, " batch"], + fov_y: Float[Tensor, " batch"], +) -> Float[Tensor, "batch 4 4"]: + """Maps points in the viewing frustum to (-1, 1) on the X/Y axes and (0, 1) on the Z + axis. Differs from the OpenGL version in that Z doesn't have range (-1, 1) after + transformation and that Z is flipped. + """ + tan_fov_x = (0.5 * fov_x).tan() + tan_fov_y = (0.5 * fov_y).tan() + + top = tan_fov_y * near + bottom = -top + right = tan_fov_x * near + left = -right + + (b,) = near.shape + result = torch.zeros((b, 4, 4), dtype=torch.float32, device=near.device) + result[:, 0, 0] = 2 * near / (right - left) + result[:, 1, 1] = 2 * near / (top - bottom) + result[:, 0, 2] = (right + left) / (right - left) + result[:, 1, 2] = (top + bottom) / (top - bottom) + result[:, 3, 2] = 1 + result[:, 2, 2] = far / (far - near) + result[:, 2, 3] = -(far * near) / (far - near) + return result + + +def render_cuda( + extrinsics: Float[Tensor, "batch 4 4"], + intrinsics: Float[Tensor, "batch 3 3"], + near: Float[Tensor, " batch"], + far: Float[Tensor, " batch"], + image_shape: tuple[int, int], + background_color: Float[Tensor, "batch 3"], + gaussian_means: Float[Tensor, "batch gaussian 3"], + gaussian_covariances: Float[Tensor, "batch gaussian 3 3"] | None, + gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"], + gaussian_opacities: Float[Tensor, "batch gaussian"], + scale_invariant: bool = True, + use_sh: bool = True, + gaussian_scales: Float[Tensor, "batch gaussian 3"] | None = None, + gaussian_rotations: Float[Tensor, "batch gaussian 4"] | None = None, +) -> tuple[ + Float[Tensor, 
"batch 3 height width"], + Int[Tensor, "batch gaussian"], + Float[Tensor, "batch gaussian 2"], +]: + assert use_sh or gaussian_sh_coefficients.shape[-1] == 1 + # Exactly one of (covariances) or (scales+rotations) must be supplied. + using_cov = gaussian_covariances is not None + using_sr = gaussian_scales is not None and gaussian_rotations is not None + assert using_cov ^ using_sr, "Provide either gaussian_covariances or (gaussian_scales+gaussian_rotations)." + + # Make sure everything is in a range where numerical issues don't appear. + if scale_invariant: + scale = 1 / near + extrinsics = extrinsics.clone() + extrinsics[..., :3, 3] = extrinsics[..., :3, 3] * scale[:, None] + if using_cov: + gaussian_covariances = gaussian_covariances * (scale[:, None, None, None] ** 2) + else: + gaussian_scales = gaussian_scales * scale[:, None, None] + gaussian_means = gaussian_means * scale[:, None, None] + near = near * scale + far = far * scale + + _, _, _, n = gaussian_sh_coefficients.shape + degree = isqrt(n) - 1 + shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous() + + b, _, _ = extrinsics.shape + h, w = image_shape + + fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1) + tan_fov_x = (0.5 * fov_x).tan() + tan_fov_y = (0.5 * fov_y).tan() + + cxs = intrinsics[:, 0, 2] * w + cys = intrinsics[:, 1, 2] * h + + projection_matrix = get_projection_matrix(near, far, fov_x, fov_y) + projection_matrix = rearrange(projection_matrix, "b i j -> b j i") + view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i") + full_projection = view_matrix @ projection_matrix + + # The 3DGS-LM fork's settings carry cx/cy/prepare_for_gsgn_backward; stock Inria does not. 
+ _settings_fields = set(GaussianRasterizationSettings._fields) + _fork_has_cxcy = "cx" in _settings_fields and "cy" in _settings_fields + _fork_has_gsgn = "prepare_for_gsgn_backward" in _settings_fields + + all_images = [] + all_radii = [] + all_means2d = [] + for i in range(b): + # Set up a tensor for the gradients of the screen-space means. + mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True) + try: + mean_gradients.retain_grad() + except Exception: + pass + + settings_kwargs = dict( + image_height=h, + image_width=w, + tanfovx=tan_fov_x[i].item(), + tanfovy=tan_fov_y[i].item(), + bg=background_color[i], + scale_modifier=1.0, + viewmatrix=view_matrix[i], + projmatrix=full_projection[i], + sh_degree=degree, + campos=extrinsics[i, :3, 3], + prefiltered=False, + debug=False, + ) + if _fork_has_cxcy: + settings_kwargs["cx"] = float(cxs[i].item()) + settings_kwargs["cy"] = float(cys[i].item()) + if _fork_has_gsgn: + settings_kwargs["prepare_for_gsgn_backward"] = False + settings = GaussianRasterizationSettings(**settings_kwargs) + rasterizer = GaussianRasterizer(settings) + + raster_kwargs = dict( + means3D=gaussian_means[i], + means2D=mean_gradients, + shs=shs[i] if use_sh else None, + colors_precomp=None if use_sh else shs[i, :, 0, :], + opacities=gaussian_opacities[i, ..., None], + ) + if using_cov: + row, col = torch.triu_indices(3, 3) + raster_kwargs["cov3D_precomp"] = gaussian_covariances[i, :, row, col] + else: + raster_kwargs["scales"] = gaussian_scales[i] + raster_kwargs["rotations"] = gaussian_rotations[i] + out = rasterizer(**raster_kwargs) + # Stock returns (image, radii); 3DGS-LM fork returns (image, radii, n_contrib, is_hit). 
+ image, radii = out[0], out[1] + all_images.append(image) + all_radii.append(radii) + all_means2d.append(mean_gradients[:, :2]) + return torch.stack(all_images), torch.stack(all_radii), torch.stack(all_means2d) + + +def render_cuda_orthographic( + extrinsics: Float[Tensor, "batch 4 4"], + width: Float[Tensor, " batch"], + height: Float[Tensor, " batch"], + near: Float[Tensor, " batch"], + far: Float[Tensor, " batch"], + image_shape: tuple[int, int], + background_color: Float[Tensor, "batch 3"], + gaussian_means: Float[Tensor, "batch gaussian 3"], + gaussian_covariances: Float[Tensor, "batch gaussian 3 3"], + gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"], + gaussian_opacities: Float[Tensor, "batch gaussian"], + fov_degrees: float = 0.1, + use_sh: bool = True, + dump: dict | None = None, +) -> Float[Tensor, "batch 3 height width"]: + b, _, _ = extrinsics.shape + h, w = image_shape + assert use_sh or gaussian_sh_coefficients.shape[-1] == 1 + + _, _, _, n = gaussian_sh_coefficients.shape + degree = isqrt(n) - 1 + shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous() + + # Create fake "orthographic" projection by moving the camera back and picking a + # small field of view. + fov_x = torch.tensor(fov_degrees, device=extrinsics.device).deg2rad() + tan_fov_x = (0.5 * fov_x).tan() + distance_to_near = (0.5 * width) / tan_fov_x + tan_fov_y = 0.5 * height / distance_to_near + fov_y = (2 * tan_fov_y).atan() + near = near + distance_to_near + far = far + distance_to_near + move_back = torch.eye(4, dtype=torch.float32, device=extrinsics.device) + move_back[2, 3] = -distance_to_near + extrinsics = extrinsics @ move_back + + # Escape hatch for visualization/figures. 
+ if dump is not None: + dump["extrinsics"] = extrinsics + dump["fov_x"] = fov_x + dump["fov_y"] = fov_y + dump["near"] = near + dump["far"] = far + + projection_matrix = get_projection_matrix( + near, far, repeat(fov_x, "-> b", b=b), fov_y + ) + projection_matrix = rearrange(projection_matrix, "b i j -> b j i") + view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i") + full_projection = view_matrix @ projection_matrix + + all_images = [] + all_radii = [] + for i in range(b): + # Set up a tensor for the gradients of the screen-space means. + mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True) + try: + mean_gradients.retain_grad() + except Exception: + pass + + settings = GaussianRasterizationSettings( + image_height=h, + image_width=w, + tanfovx=tan_fov_x, + tanfovy=tan_fov_y, + bg=background_color[i], + scale_modifier=1.0, + viewmatrix=view_matrix[i], + projmatrix=full_projection[i], + sh_degree=degree, + campos=extrinsics[i, :3, 3], + prefiltered=False, # This matches the original usage. 
+ debug=False, + ) + rasterizer = GaussianRasterizer(settings) + + row, col = torch.triu_indices(3, 3) + + image, radii = rasterizer( + means3D=gaussian_means[i], + means2D=mean_gradients, + shs=shs[i] if use_sh else None, + colors_precomp=None if use_sh else shs[i, :, 0, :], + opacities=gaussian_opacities[i, ..., None], + cov3D_precomp=gaussian_covariances[i, :, row, col], + ) + all_images.append(image) + all_radii.append(radii) + return torch.stack(all_images) + + +DepthRenderingMode = Literal["depth", "disparity", "relative_disparity", "log"] + + +def render_depth_cuda( + extrinsics: Float[Tensor, "batch 4 4"], + intrinsics: Float[Tensor, "batch 3 3"], + near: Float[Tensor, " batch"], + far: Float[Tensor, " batch"], + image_shape: tuple[int, int], + gaussian_means: Float[Tensor, "batch gaussian 3"], + gaussian_covariances: Float[Tensor, "batch gaussian 3 3"], + gaussian_opacities: Float[Tensor, "batch gaussian"], + scale_invariant: bool = True, + mode: DepthRenderingMode = "depth", +) -> Float[Tensor, "batch height width"]: + # Specify colors according to Gaussian depths. + camera_space_gaussians = einsum( + extrinsics.inverse(), homogenize_points(gaussian_means), "b i j, b g j -> b g i" + ) + fake_color = camera_space_gaussians[..., 2] + + if mode == "disparity": + fake_color = 1 / fake_color + elif mode == "log": + fake_color = fake_color.minimum(near[:, None]).maximum(far[:, None]).log() + + # Render using depth as color. 
+ b, _ = fake_color.shape + images, _, _ = render_cuda( + extrinsics, + intrinsics, + near, + far, + image_shape, + torch.zeros((b, 3), dtype=fake_color.dtype, device=fake_color.device), + gaussian_means, + gaussian_covariances, + repeat(fake_color, "b g -> b g c ()", c=3), + gaussian_opacities, + scale_invariant=scale_invariant, + use_sh=False, + ) + return images.mean(dim=1) diff --git a/optgs/model/decoder/decoder.py b/optgs/model/decoder/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..48690f73ba699e3bda4cc7ecc90f32906eb7cc2d --- /dev/null +++ b/optgs/model/decoder/decoder.py @@ -0,0 +1,164 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Generic, Literal, TypeVar + +import torch +from jaxtyping import Float, Int32, Bool, UInt8 +from torch import Tensor, nn + +from ..types import Gaussians +from ...dataset import DatasetCfg +from ...dataset.data_types import BatchedViews, BatchedViewsDict, BatchedExample +from ...scene_trainer.gaussian_module import GaussiansModule + + +DepthRenderingMode = Literal[ + "depth", + "log", + "disparity", + "relative_disparity", +] + + +@dataclass +class DecoderOutput: + color: Float[Tensor, "batch view 3 height width"] | UInt8[Tensor, "batch view 3 height width"] + depth: Float[Tensor, "batch view height width"] | None + normal: Float[Tensor, "batch view 3 height width"] | None = None + distortion_map: Float[Tensor, "batch view height width"] | None = None + accumulated_alpha: Float[Tensor, "batch view height width"] | None = None + radii: Int32[Tensor, "batch view n 2"] | None = None + means2d: Float[Tensor, "batch view n 2"] | None = None + visibility_filter: Bool[Tensor, "batch view n"] | None = None + + +T = TypeVar("T") + + +class Decoder(nn.Module, ABC, Generic[T]): + cfg: T + dataset_cfg: DatasetCfg + + def __init__(self, cfg: T, dataset_cfg: DatasetCfg) -> None: + super().__init__() + self.cfg = cfg + self.dataset_cfg = dataset_cfg + + 
@abstractmethod + def forward( + self, + gaussians: Gaussians | GaussiansModule, + extrinsics: Float[Tensor, "batch view 4 4"], + intrinsics: Float[Tensor, "batch view 3 3"], + near: Float[Tensor, "batch view"], + far: Float[Tensor, "batch view"], + image_shape: tuple[int, int], + depth_mode: DepthRenderingMode | None = None, + to_cpu: bool = False, + ) -> DecoderOutput: + pass + + def forward_batch( + self, + gaussians: Gaussians | GaussiansModule, + batch: BatchedExample, + image_shape: tuple[int, int] | None = None, + input_str: Literal["context", "target"] | None = None, + eval_context_views: bool | None = None, + depth_mode: DepthRenderingMode | None = None, + start=None, end=None, + camera_poses=None, # In case of manipulating camera poses (e.g. for stabilization) + to_cpu: bool = False, # move outputs to cpu as they are rendered + iter_batch_size: int = -1, # -1 to render all views at once + ) -> DecoderOutput: + + assert input_str is not None or eval_context_views is not None + if input_str is None: + input_str = "context" if eval_context_views else "target" + + input = batch[input_str] + + if image_shape is None: + image_shape = input["image_shape"].shape[-2:] + if camera_poses is None: + camera_poses = input["extrinsics"] + return self.forward( + gaussians, + camera_poses[:, start:end], + input["intrinsics"][:, start:end], + input["near"][:, start:end], + input["far"][:, start:end], + image_shape, + depth_mode=depth_mode, + to_cpu=to_cpu, + iter_batch_size=iter_batch_size, + ) + + def forward_batch_subset(self, gaussians: Gaussians | GaussiansModule, + batch_subset: BatchedViewsDict | BatchedViews, + image_shape: tuple[int, int] | None = None, + start: int | None = None, + end: int | None = None, + indices: torch.Tensor | list | None = None, + **kwargs) -> DecoderOutput: + + assert not ((start is not None and end is not None) and ( + indices is not None)), "Either start and end or indices must be provided." 
+ if start is not None: + indices = list(range(start, end)) + + if indices is None: + indices = list(range(batch_subset["extrinsics"].shape[1])) + + if isinstance(indices, list): + # Convert list to tensor for one flow handling + indices = torch.tensor(indices, device=batch_subset["extrinsics"].device) + indices = indices.unsqueeze(0).expand(batch_subset["extrinsics"].shape[0], -1) # (batch, num_indices) + + if image_shape is None: + image_shape = batch_subset["image"].shape[-2:] + + assert indices.dim() == 2, "Indices tensor must be 2D (scene_batch, num_indices)." + scene_batch = indices.size(0) + scene_batch_idx = torch.arange(scene_batch, device=indices.device)[:, None] # (batch, 1) + return self.forward(gaussians, + batch_subset["extrinsics"][scene_batch_idx, indices], + batch_subset["intrinsics"][scene_batch_idx, indices], + batch_subset["near"][scene_batch_idx, indices], + batch_subset["far"][scene_batch_idx, indices], + image_shape, + **kwargs) + + def forward_context( + self, + gaussians: Gaussians | GaussiansModule, + batch: BatchedExample, + image_shape: tuple[int, int] | None = None, + depth_mode: DepthRenderingMode | None = None, + **kwargs, + ) -> DecoderOutput: + return self.forward_batch( + gaussians, + batch, + image_shape, + "context", + depth_mode=depth_mode, + **kwargs, + ) + + def forward_target( + self, + gaussians: Gaussians | GaussiansModule, + batch: BatchedExample, + image_shape: tuple[int, int] | None = None, + depth_mode: DepthRenderingMode | None = None, + **kwargs, + ) -> DecoderOutput: + return self.forward_batch( + gaussians, + batch, + image_shape, + "target", + depth_mode=depth_mode, + **kwargs, + ) diff --git a/optgs/model/decoder/decoder_splatting_cuda.py b/optgs/model/decoder/decoder_splatting_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..b274cf2d56f9b4425e4968872a03876f49fab76b --- /dev/null +++ b/optgs/model/decoder/decoder_splatting_cuda.py @@ -0,0 +1,194 @@ +from dataclasses import dataclass 
+from typing import Literal + +import torch +from einops import rearrange, repeat +from jaxtyping import Float +from torch import Tensor +from tqdm import tqdm + +from ...dataset import DatasetCfg +from ...scene_trainer.gaussian_module import GaussiansModule +from ..types import Gaussians +from .cuda_splatting import DepthRenderingMode, render_cuda, render_depth_cuda +from .decoder import Decoder, DecoderOutput + + +@dataclass +class DecoderSplattingCUDACfg: + name: Literal["inria"] + scale_invariant: bool + # False: pass scales+rotations and let the CUDA kernel compute the covariance + # (matches 3DGS-LM byte-for-byte). True: precompute Python-side and pass + # cov3D_precomp (~42 dB pixel drift from LM, slightly faster on repeat calls). + use_covariances: bool = False + + +class DecoderSplattingCUDA(Decoder[DecoderSplattingCUDACfg]): + background_color: Float[Tensor, "3"] + + def __init__( + self, + cfg: DecoderSplattingCUDACfg, + dataset_cfg: DatasetCfg, + ) -> None: + super().__init__(cfg, dataset_cfg) + self.register_buffer( + "background_color", + torch.tensor(dataset_cfg.background_color, dtype=torch.float32), + persistent=False, + ) + + def forward( + self, + gaussians: Gaussians | GaussiansModule, + extrinsics: Float[Tensor, "batch view 4 4"], + intrinsics: Float[Tensor, "batch view 3 3"], + near: Float[Tensor, "batch view"], + far: Float[Tensor, "batch view"], + image_shape: tuple[int, int], + depth_mode: DepthRenderingMode | None = None, + return_radii: bool = False, + iter_batch_size: int = -1, + to_cpu: bool = False, + ) -> DecoderOutput: + b, v, _, _ = extrinsics.shape + bv = b * v + + # Flatten camera params to (B*V) + flat_ext = rearrange(extrinsics, "b v i j -> (b v) i j") + flat_int = rearrange(intrinsics, "b v i j -> (b v) i j") + flat_near = rearrange(near, "b v -> (b v)") + flat_far = rearrange(far, "b v -> (b v)") + flat_bg = repeat(self.background_color, "c -> (b v) c", b=b, v=v) + + # Prepare Gaussian tensors in flat (B*V) format + scales = 
rotations_wxyz = covars = None + if isinstance(gaussians, GaussiansModule): + means = repeat(gaussians.means, "g xyz -> bv g xyz", bv=bv) + shs = repeat(gaussians.harmonics, "g c d -> bv g c d", bv=bv) + opacities = repeat(gaussians.opacities, "g -> bv g", bv=bv) + if self.cfg.use_covariances: + covars = repeat(gaussians.covariances, "g i j -> bv g i j", bv=bv) + else: + scales = repeat(gaussians.scales, "g d -> bv g d", bv=bv) + # gaussians.rotations is xyzw post-normalization; the rasterizer wants wxyz. + rotations_wxyz = repeat(gaussians.rotations[:, [3, 0, 1, 2]], "g d -> bv g d", bv=bv) + + elif isinstance(gaussians, Gaussians): + means = repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v) + shs = repeat(gaussians.harmonics, "b g c d -> (b v) g c d", v=v) + opacities = repeat(gaussians.opacities, "b g -> (b v) g", v=v) + if self.cfg.use_covariances: + if gaussians.covariances is None: + raise ValueError("use_covariances=true but gaussians.covariances is None.") + covars = repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v) + else: + _scales = gaussians.scales if gaussians.stores_activated else torch.exp(gaussians.scales) + scales = repeat(_scales, "b g d -> (b v) g d", v=v) + rotations_wxyz = repeat(gaussians.rotations[..., [3, 0, 1, 2]], "b g d -> (b v) g d", v=v) + if not gaussians.stores_activated: + opacities = torch.sigmoid(opacities) + else: + raise ValueError(f"Unknown gaussians type: {type(gaussians)}") + + def _render_flat(s: slice): + imgs, radii, means2d = render_cuda( + flat_ext[s], + flat_int[s], + flat_near[s], + flat_far[s], + image_shape, + flat_bg[s], + means[s], + covars[s] if covars is not None else None, + shs[s], + opacities[s], + scale_invariant=self.cfg.scale_invariant, + gaussian_scales=scales[s] if scales is not None else None, + gaussian_rotations=rotations_wxyz[s] if rotations_wxyz is not None else None, + ) + return imgs, radii, means2d + + if iter_batch_size < 0: + imgs, radii_flat, means2d_flat = 
_render_flat(slice(None)) + if to_cpu: + imgs = imgs.detach().cpu() + radii_flat = radii_flat.detach().cpu() + means2d_flat = means2d_flat.detach().cpu() + else: + all_imgs, all_radii, all_means2d = [], [], [] + for i in tqdm(range(0, bv, iter_batch_size), desc="Rendering in batches"): + s = slice(i, min(i + iter_batch_size, bv)) + imgs_c, rad_c, m2d_c = _render_flat(s) + if to_cpu: + imgs_c = imgs_c.detach().cpu() + rad_c = rad_c.detach().cpu() + m2d_c = m2d_c.detach().cpu() + all_imgs.append(imgs_c) + all_radii.append(rad_c) + all_means2d.append(m2d_c) + imgs = torch.cat(all_imgs, dim=0) + radii_flat = torch.cat(all_radii, dim=0) + means2d_flat = torch.cat(all_means2d, dim=0) + + # Reshape (B*V) → (B, V) + color = rearrange(imgs, "(b v) c h w -> b v c h w", b=b, v=v) + radii_bv = rearrange(radii_flat, "(b v) n -> b v n", b=b, v=v) + means2d_bv = rearrange(means2d_flat, "(b v) n d -> b v n d", b=b, v=v) + + # Expand scalar radii [B, V, N] → [B, V, N, 2] to match gsplat interface + radii_out = radii_bv.unsqueeze(-1).expand(-1, -1, -1, 2).contiguous() + visibility_filter = radii_bv > 0 # [B, V, N] + + depth = ( + self._render_depth(gaussians, extrinsics, intrinsics, near, far, image_shape, depth_mode) + if depth_mode is not None + else None + ) + + return DecoderOutput( + color=color, + depth=depth, + accumulated_alpha=None, + means2d=means2d_bv, + radii=radii_out, + visibility_filter=visibility_filter, + ) + + def _render_depth( + self, + gaussians: Gaussians | GaussiansModule, + extrinsics: Float[Tensor, "batch view 4 4"], + intrinsics: Float[Tensor, "batch view 3 3"], + near: Float[Tensor, "batch view"], + far: Float[Tensor, "batch view"], + image_shape: tuple[int, int], + mode: DepthRenderingMode = "depth", + ) -> Float[Tensor, "batch view height width"]: + b, v, _, _ = extrinsics.shape + + if isinstance(gaussians, GaussiansModule): + means = repeat(gaussians.means, "g xyz -> (b v) g xyz", b=b, v=v) + covars = repeat(gaussians.covariances, "g i j -> (b v) g i 
j", b=b, v=v) + opacities = repeat(gaussians.opacities, "g -> (b v) g", b=b, v=v) + else: + means = repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v) + covars = repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v) + opacities = repeat(gaussians.opacities, "b g -> (b v) g", v=v) + if not gaussians.stores_activated: + opacities = torch.sigmoid(opacities) + + result = render_depth_cuda( + rearrange(extrinsics, "b v i j -> (b v) i j"), + rearrange(intrinsics, "b v i j -> (b v) i j"), + rearrange(near, "b v -> (b v)"), + rearrange(far, "b v -> (b v)"), + image_shape, + means, + covars, + opacities, + mode=mode, + scale_invariant=self.cfg.scale_invariant, + ) + return rearrange(result, "(b v) h w -> b v h w", b=b, v=v) diff --git a/optgs/model/decoder/diffgs_cuda_splatting.py b/optgs/model/decoder/diffgs_cuda_splatting.py new file mode 100644 index 0000000000000000000000000000000000000000..8663d799d4d5db66e8da248a6d7763d7a6c04633 --- /dev/null +++ b/optgs/model/decoder/diffgs_cuda_splatting.py @@ -0,0 +1,136 @@ +from math import isqrt +from typing import Literal + +import torch +from diff_gs import ( + GaussianRasterizationSettings, + GaussianRasterizer, +) +from einops import einsum, rearrange, repeat +from jaxtyping import Float +from torch import Tensor + +from ...geometry.projection import get_fov, homogenize_points + + +def get_projection_matrix( + near: Float[Tensor, " batch"], + far: Float[Tensor, " batch"], + fov_x: Float[Tensor, " batch"], + fov_y: Float[Tensor, " batch"], +) -> Float[Tensor, "batch 4 4"]: + """Maps points in the viewing frustum to (-1, 1) on the X/Y axes and (0, 1) on the Z + axis. Differs from the OpenGL version in that Z doesn't have range (-1, 1) after + transformation and that Z is flipped. 
+ """ + tan_fov_x = (0.5 * fov_x).tan() + tan_fov_y = (0.5 * fov_y).tan() + + top = tan_fov_y * near + bottom = -top + right = tan_fov_x * near + left = -right + + (b,) = near.shape + result = torch.zeros((b, 4, 4), dtype=torch.float32, device=near.device) + result[:, 0, 0] = 2 * near / (right - left) + result[:, 1, 1] = 2 * near / (top - bottom) + result[:, 0, 2] = (right + left) / (right - left) + result[:, 1, 2] = (top + bottom) / (top - bottom) + result[:, 3, 2] = 1 + result[:, 2, 2] = far / (far - near) + result[:, 2, 3] = -(far * near) / (far - near) + return result + + +def render_cuda( + extrinsics: Float[Tensor, "batch 4 4"], + intrinsics: Float[Tensor, "batch 3 3"], + near: Float[Tensor, " batch"], + far: Float[Tensor, " batch"], + image_shape: tuple[int, int], + background_color: Float[Tensor, "batch 3"], + gaussian_means: Float[Tensor, "batch gaussian 3"], + gaussian_covariances: Float[Tensor, "batch gaussian 3 3"], + gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"], + gaussian_opacities: Float[Tensor, "batch gaussian"], + scale_invariant: bool = False, + use_sh: bool = True, +): + assert use_sh or gaussian_sh_coefficients.shape[-1] == 1 + + assert scale_invariant is False + + # Make sure everything is in a range where numerical issues don't appear. 
+ if scale_invariant: + scale = 1 / near + extrinsics = extrinsics.clone() + extrinsics[..., :3, 3] = extrinsics[..., :3, 3] * scale[:, None] + gaussian_covariances = gaussian_covariances * (scale[:, None, None, None] ** 2) + gaussian_means = gaussian_means * scale[:, None, None] + near = near * scale + far = far * scale + + _, _, _, n = gaussian_sh_coefficients.shape + degree = isqrt(n) - 1 + shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous() + + b, _, _ = extrinsics.shape + h, w = image_shape + + fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1) + tan_fov_x = (0.5 * fov_x).tan() + tan_fov_y = (0.5 * fov_y).tan() + + projection_matrix = get_projection_matrix(near, far, fov_x, fov_y) + projection_matrix = rearrange(projection_matrix, "b i j -> b j i") + view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i") + full_projection = view_matrix @ projection_matrix + + all_images = [] + all_radii = [] + all_depths = [] + for i in range(b): + # Set up a tensor for the gradients of the screen-space means. 
+ mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True) + try: + mean_gradients.retain_grad() + except Exception: + pass + + settings = GaussianRasterizationSettings( + image_height=h, + image_width=w, + tanfovx=tan_fov_x[i].item(), + tanfovy=tan_fov_y[i].item(), + bg=background_color[i], + scale_modifier=1.0, + viewmatrix=view_matrix[i], + projmatrix=full_projection[i], + sh_degree=degree, + campos=extrinsics[i, :3, 3], + prefiltered=False, + debug=False, + antialiasing=False, + ) + rasterizer = GaussianRasterizer(settings) + + row, col = torch.triu_indices(3, 3) + + image, radii, depth = rasterizer( + means3D=gaussian_means[i], + means2D=mean_gradients, + shs=shs[i] if use_sh else None, + colors_precomp=None if use_sh else shs[i, :, 0, :], + opacities=gaussian_opacities[i, ..., None], + cov3D_precomp=gaussian_covariances[i, :, row, col], + ) + all_images.append(image) + all_radii.append(radii) + all_depths.append(depth) + return { + 'image': torch.stack(all_images), + 'depth': torch.stack(all_depths), + 'radii': torch.stack(all_radii), + } + diff --git a/optgs/model/decoder/diffgs_decoder_splatting_cuda.py b/optgs/model/decoder/diffgs_decoder_splatting_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..f03f2588c0d76aa02a3554e0484c7053f0866d37 --- /dev/null +++ b/optgs/model/decoder/diffgs_decoder_splatting_cuda.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass +from typing import Literal + +import torch +from einops import rearrange, repeat +from jaxtyping import Float +from torch import Tensor + +from ...dataset import DatasetCfg +from ..types import Gaussians +from .decoder import DepthRenderingMode +from .diffgs_cuda_splatting import render_cuda +from .decoder import Decoder, DecoderOutput + + +@dataclass +class DiffgsDecoderSplattingCUDACfg: + name: Literal["diffgs"] + scale_invariant: bool + + +class DiffgsDecoderSplattingCUDA(Decoder[DiffgsDecoderSplattingCUDACfg]): + background_color: Float[Tensor, "3"] + + 
def __init__( + self, + cfg: DiffgsDecoderSplattingCUDACfg, + dataset_cfg: DatasetCfg, + ) -> None: + super().__init__(cfg, dataset_cfg) + self.register_buffer( + "background_color", + torch.tensor(dataset_cfg.background_color, dtype=torch.float32), + persistent=False, + ) + + def forward( + self, + gaussians: Gaussians, + extrinsics: Float[Tensor, "batch view 4 4"], + intrinsics: Float[Tensor, "batch view 3 3"], + near: Float[Tensor, "batch view"], + far: Float[Tensor, "batch view"], + image_shape: tuple[int, int], + depth_mode: DepthRenderingMode | None = None, + ) -> DecoderOutput: + b, v, _, _ = extrinsics.shape + out = render_cuda( + rearrange(extrinsics, "b v i j -> (b v) i j"), + rearrange(intrinsics, "b v i j -> (b v) i j"), + rearrange(near, "b v -> (b v)"), + rearrange(far, "b v -> (b v)"), + image_shape, + repeat(self.background_color, "c -> (b v) c", b=b, v=v), + repeat(gaussians.means, "b g xyz -> (b v) g xyz", v=v), + repeat(gaussians.covariances, "b g i j -> (b v) g i j", v=v), + repeat(gaussians.harmonics, "b g c d_sh -> (b v) g c d_sh", v=v), + repeat(gaussians.opacities, "b g -> (b v) g", v=v), + scale_invariant=self.cfg.scale_invariant, + ) + color = rearrange(out['image'], "(b v) c h w -> b v c h w", b=b, v=v) + # the output is inverse depth (c == 1) + depth = 1. 
/ rearrange(out['depth'], "(b v) c h w -> b (v c) h w", b=b, v=v).clamp(min=1e-6) + + return DecoderOutput( + color, + depth + ) diff --git a/optgs/model/decoder/gsplat_decoder_splatting_cuda.py b/optgs/model/decoder/gsplat_decoder_splatting_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..123c72c6cc626b1107f437f4d63dceeb7516630f --- /dev/null +++ b/optgs/model/decoder/gsplat_decoder_splatting_cuda.py @@ -0,0 +1,204 @@ +from dataclasses import dataclass +from typing import Literal + +import math +import torch +# import torch.nn.functional as F +from einops import repeat +from gsplat.rendering import rasterization +from jaxtyping import Float +from torch import Tensor +from tqdm import tqdm +from optgs.scene_trainer.gaussian_module import GaussiansModule +from .decoder import Decoder, DecoderOutput +from .decoder import DepthRenderingMode +from ..types import Gaussians +from ...dataset import DatasetCfg + + +@dataclass +class GSplatDecoderSplattingCUDACfg: + name: Literal["gsplat"] + use_covariances: bool + rasterize_mode: Literal["antialiased", "classic"] + eps2d: float + + +class GSplatDecoderSplattingCUDA(Decoder[GSplatDecoderSplattingCUDACfg]): + background_color: Float[Tensor, "3"] + + def __init__( + self, + cfg: GSplatDecoderSplattingCUDACfg, + dataset_cfg: DatasetCfg, + ) -> None: + super().__init__(cfg, dataset_cfg) + self.register_buffer( + "background_color", + torch.tensor(dataset_cfg.background_color, dtype=torch.float32), + persistent=False, + ) + + def forward( + self, + gaussians: Gaussians | GaussiansModule, + extrinsics: Float[Tensor, "batch view 4 4"], + intrinsics: Float[Tensor, "batch view 3 3"], + near: Float[Tensor, "batch view"], + far: Float[Tensor, "batch view"], + image_shape: tuple[int, int], + depth_mode: DepthRenderingMode | None = None, # always render depth + return_radii: bool = False, # always return radii + iter_batch_size: int = -1, # -1 to render all views at once + use_covariances: bool = False, # 
override cfg + to_cpu: bool = False, # move outputs to cpu + ) -> DecoderOutput: + + _use_covariances = self.cfg.use_covariances if use_covariances is None else use_covariances + + height, width = image_shape + + if isinstance(gaussians, GaussiansModule): + + # nb: no batch dimension + + means = gaussians.means + quats = gaussians.rotations # [N, 4] in xyzw (scalar last) + quats = quats[:, [3, 0, 1, 2]] # [N, 4] in wxyz (scalar first) + quats = quats # [1, N, 4] + scales = gaussians.scales # post-activation + opacities = gaussians.opacities # post-activation + colors = gaussians.harmonics.permute(0, 2, 1) # [1, N, d_sh, 3] + if _use_covariances: + covars = gaussians.covariances + else: + covars = None + + # add batch dimension + means = means.unsqueeze(0) # [1, N, 3] + quats = quats.unsqueeze(0) # [1, N, 4] + scales = scales.unsqueeze(0) # [1, N, 3] + opacities = opacities.unsqueeze(0) # [1, N, 1] + colors = colors.unsqueeze(0) # [1, N, d_sh, 3] + if covars is not None: + covars = covars.unsqueeze(0) # [1, N, 3, 3] + + elif isinstance(gaussians, Gaussians): + + means = gaussians.means # [B, N, 3] + quats = gaussians.rotations_unnorm # [B, N, 4] in wxyz (scalar first), rasterization normalizes internally + quats = quats[:, :, [3, 0, 1, 2]] # [B, N, 4] in wxyz (scalar first) + scales = gaussians.scales # [B, N, 3] + opacities = gaussians.opacities # [B, G] + colors = gaussians.harmonics.permute(0, 1, 3, 2) # [B, N, d_sh, 3] + if _use_covariances: + covars = gaussians.covariances # [B, N, 3, 3] + if covars is None: + raise ValueError("Covariances are set to be used, but gaussians.covariances is None.") + else: + covars = None + + if gaussians.stores_activated: + # already activated + pass + else: + # activate + scales = torch.exp(scales) # [B, N, 3] + opacities = torch.sigmoid(opacities) # [B, N] + + else: + raise ValueError(f"Unknown type of gaussians: {type(gaussians)}") + + # prepare inputs for rasterization + sh_degree = int(math.sqrt(colors.shape[-2])) - 1 # d_sh 
= (degree + 1) ** 2 + viewmats = extrinsics.inverse() # [B, V, 4, 4] + + # scale intrinsics to image shape (avoid clone by creating scaled version directly) + intrinsics_scaled = intrinsics * intrinsics.new_tensor([[[width], [height], [1]]]) # [B, V, 3, 3] + + def _render(viewmats, Ks): + + # rasterize + render_colors, render_alphas, meta = rasterization( + means=means, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + sh_degree=sh_degree, + viewmats=viewmats, + Ks=Ks, + width=width, + height=height, + # near_plane=near[0, 0].item(), # use default + # far_plane=far[0, 0].item(), # use default + eps2d=self.cfg.eps2d, + rasterize_mode=self.cfg.rasterize_mode, + packed=False, + # absgrad=False, # use default + # sparse_grad=False, # use default + render_mode="RGB+ED", + # with_ut=False, # use default + # with_eval3d=False, # use default + # covars=covars, # use default + ) + + # unpack outputs + color = render_colors[..., :3].permute(0, 1, 4, 2, 3) # [B, V, 3, H, W] + depth = render_colors[..., -1] # [B, V, H, W] + means2d = meta["means2d"] # [B, V, N, 2] + radii = meta["radii"] # [B, V, N, 2] + visibility_filter = torch.all(radii > 0, dim=-1) # [B, V, N] + + return color, depth, render_alphas, means2d, visibility_filter, radii + + # split into chunks to save memory + nr_views = extrinsics.shape[1] + if iter_batch_size < 0: + # render all views at once + color, depth, render_alphas, means2d, visibility_filter, radii = _render(viewmats, intrinsics_scaled) + if to_cpu: + color = color.detach().cpu() + depth = depth.detach().cpu() + render_alphas = render_alphas.detach().cpu() + means2d = means2d.detach().cpu() + visibility_filter = visibility_filter.detach().cpu() + radii = radii.detach().cpu() + else: + # split into chunks + chunk_outputs = [] + for i in tqdm(range(0, nr_views, iter_batch_size), desc="Rendering in batches"): + if i + iter_batch_size > nr_views: + bs = nr_views - i + else: + bs = iter_batch_size + iter_viewmats = viewmats[:, i : 
i + bs] # [B, v, 4, 4] + iter_intrinsics = intrinsics_scaled[:, i : i + bs] # [B, v, 3, 3] + color, depth, render_alphas, means2d, visibility_filter, radii = _render(iter_viewmats, iter_intrinsics) + if to_cpu: + color = color.detach().cpu() + depth = depth.detach().cpu() + render_alphas = render_alphas.detach().cpu() + means2d = means2d.detach().cpu() + visibility_filter = visibility_filter.detach().cpu() + radii = radii.detach().cpu() + chunk_outputs.append((color, depth, render_alphas, means2d, visibility_filter, radii)) + + # concatenate all chunks + color = torch.cat([o[0] for o in chunk_outputs], dim=1) # [B, V, 3, H, W] + depth = torch.cat([o[1] for o in chunk_outputs], dim=1) # [B, V, H, W] + render_alphas = torch.cat([o[2] for o in chunk_outputs], dim=1) # [B, V, H, W, 1] + means2d = torch.cat([o[3] for o in chunk_outputs], dim=1) # [B, V, N, 2] + visibility_filter = torch.cat([o[4] for o in chunk_outputs], dim=1) # [B, V, N] + radii = torch.cat([o[5] for o in chunk_outputs], dim=1) # [B, V, N, 2] + + return DecoderOutput( + color, + depth=depth, + accumulated_alpha=render_alphas.squeeze(-1), # [B, V, H, W] + means2d=means2d, + visibility_filter=visibility_filter, + radii=radii, + ) + + diff --git a/optgs/model/encoder/.deprecated/common/__init__.py b/optgs/model/encoder/.deprecated/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/encoder/.deprecated/common/sampler.py b/optgs/model/encoder/.deprecated/common/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf5b61674093d00ef1d1705835f43648e4bc679 --- /dev/null +++ b/optgs/model/encoder/.deprecated/common/sampler.py @@ -0,0 +1,42 @@ +from jaxtyping import Float, Int64, Shaped +from torch import Tensor, nn + +from ....misc.discrete_probability_distribution import ( + gather_discrete_topk, + sample_discrete_distribution, +) + + +class Sampler(nn.Module): + def forward( + 
self, + probabilities: Float[Tensor, "*batch bucket"], + num_samples: int, + deterministic: bool, + ) -> tuple[ + Int64[Tensor, "*batch 1"], # index + Float[Tensor, "*batch 1"], # probability density + ]: + return ( + gather_discrete_topk(probabilities, num_samples) + if deterministic + else sample_discrete_distribution(probabilities, num_samples) + ) + + def gather( + self, + index: Int64[Tensor, "*batch sample"], + target: Shaped[Tensor, "..."], # *batch bucket *shape + ) -> Shaped[Tensor, "..."]: # *batch sample *shape + """Gather from the target according to the specified index. Handle the + broadcasting needed for the gather to work. See the comments for the actual + expected input/output shapes since jaxtyping doesn't support multiple variadic + lengths in annotations. + """ + bucket_dim = index.ndim - 1 + while len(index.shape) < len(target.shape): + index = index[..., None] + broadcasted_index_shape = list(target.shape) + broadcasted_index_shape[bucket_dim] = index.shape[bucket_dim] + index = index.broadcast_to(broadcasted_index_shape) + return target.gather(dim=bucket_dim, index=index) diff --git a/optgs/model/encoder/.deprecated/foundationstereo/Utils.py b/optgs/model/encoder/.deprecated/foundationstereo/Utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a76573aa69b26a0a785f969ecbeabe67975221eb --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/Utils.py @@ -0,0 +1,134 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ + +import os, sys, time,torch,torchvision,pickle,itertools,datetime,imageio,logging,importlib,argparse +import torch.nn.functional as F +import torch.nn as nn +from functools import partial +# import open3d as o3d +import cv2 +import numpy as np +# from transformations import * +code_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(code_dir) + + + +def set_logging_format(level=logging.INFO): + importlib.reload(logging) + FORMAT = '%(message)s' + logging.basicConfig(level=level, format=FORMAT, datefmt='%m-%d|%H:%M:%S') + +set_logging_format() + + + +def set_seed(random_seed): + import torch,random + np.random.seed(random_seed) + random.seed(random_seed) + torch.manual_seed(random_seed) + torch.cuda.manual_seed_all(random_seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def toOpen3dCloud(points,colors=None,normals=None): + cloud = o3d.geometry.PointCloud() + cloud.points = o3d.utility.Vector3dVector(points.astype(np.float64)) + if colors is not None: + if colors.max()>1: + colors = colors/255.0 + cloud.colors = o3d.utility.Vector3dVector(colors.astype(np.float64)) + if normals is not None: + cloud.normals = o3d.utility.Vector3dVector(normals.astype(np.float64)) + return cloud + + + +def depth2xyzmap(depth:np.ndarray, K, uvs:np.ndarray=None, zmin=0.1): + invalid_mask = (depthmax_H or W_resize>max_W: + if H_resize>W_resize: + W_resize = round_by_divider(W_resize*max_H/H_resize) + H_resize = max_H + else: + H_resize = round_by_divider(H_resize*max_W/W_resize) + W_resize = max_W + return int(H_resize), int(W_resize) + + +def vis_disparity(disp, min_val=None, max_val=None, invalid_thres=np.inf, color_map=cv2.COLORMAP_TURBO, cmap=None, other_output={}): + """ + @disp: np array (H,W) + @invalid_thres: > thres is invalid + """ + disp = disp.copy() + H,W = disp.shape[:2] + invalid_mask = disp>=invalid_thres + if (invalid_mask==0).sum()==0: + other_output['min_val'] = None + other_output['max_val'] = None + 
return np.zeros((H,W,3)) + if min_val is None: + min_val = disp[invalid_mask==0].min() + if max_val is None: + max_val = disp[invalid_mask==0].max() + other_output['min_val'] = min_val + other_output['max_val'] = max_val + vis = ((disp-min_val)/(max_val-min_val)).clip(0,1) * 255 + if cmap is None: + vis = cv2.applyColorMap(vis.clip(0, 255).astype(np.uint8), color_map)[...,::-1] + else: + vis = cmap(vis.astype(np.uint8))[...,:3]*255 + if invalid_mask.any(): + vis[invalid_mask] = 0 + return vis.astype(np.uint8) + + diff --git a/optgs/model/encoder/.deprecated/foundationstereo/__init__.py b/optgs/model/encoder/.deprecated/foundationstereo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/encoder/.deprecated/foundationstereo/config/large.yaml b/optgs/model/encoder/.deprecated/foundationstereo/config/large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..39b74b93de8f590517fb709ee6a814037feba4c9 --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/config/large.yaml @@ -0,0 +1,33 @@ +corr_implementation: reg +corr_levels: 2 +corr_radius: 4 +finetune_ckpt_name: model_best_bp2.pth +finetune_from: null +hidden_dims: +- 128 +- 128 +- 128 +img_gamma: null +inference_tile: 0 +low_memory: 0 +max_disp: 416 +max_val_sample: null +mixed_precision: true +n_downsample: 2 +n_gru_layers: 3 +disp_head_dim: 1 +notes: '' +num_steps: 200000 +num_worker: 8 +slow_fast_gru: false +tags_more: [] +tile_min_overlap: +- 16 +- 16 +tile_wtype: gaussian +time_limit: 14400 +train_iters: 22 +val_interval: 1 +valid_iters: 32 +wdecay: 0 +world_size: 32 diff --git a/optgs/model/encoder/.deprecated/foundationstereo/config/small.yaml b/optgs/model/encoder/.deprecated/foundationstereo/config/small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4358c390672e3432f5a34476eaa74900460bc312 --- /dev/null +++ 
b/optgs/model/encoder/.deprecated/foundationstereo/config/small.yaml @@ -0,0 +1,34 @@ +corr_implementation: reg +corr_levels: 2 +corr_radius: 4 +finetune_ckpt_name: model_best_bp2.pth +finetune_from: null +hidden_dims: +- 128 +- 128 +- 128 +img_gamma: null +inference_tile: 0 +low_memory: 0 +max_disp: 416 +max_val_sample: null +mixed_precision: true +n_downsample: 2 +n_gru_layers: 3 +disp_head_dim: 1 +notes: '' +num_steps: 200000 +num_worker: 8 +slow_fast_gru: false +tags_more: [] +tile_min_overlap: +- 16 +- 16 +tile_wtype: gaussian +time_limit: 14400 +train_iters: 22 +val_interval: 1 +valid_iters: 32 +wdecay: 0 +world_size: 32 +vit_type: vits \ No newline at end of file diff --git a/optgs/model/encoder/.deprecated/foundationstereo/core/__init__.py b/optgs/model/encoder/.deprecated/foundationstereo/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/encoder/.deprecated/foundationstereo/core/extractor.py b/optgs/model/encoder/.deprecated/foundationstereo/core/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..bdee3e6b444a2bc797f41df76e6e46f91d340745 --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/core/extractor.py @@ -0,0 +1,382 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ + +import torch,logging,os,sys,urllib,warnings +import torch.nn as nn +import torch.nn.functional as F +code_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(f'{code_dir}/../') +from core.submodule import * +from Utils import * +import timm + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn=='layer': + self.norm1 = LayerNorm2d(planes) + self.norm2 = LayerNorm2d(planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = LayerNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.Sequential() + + if stride == 1 and in_planes == planes: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.conv1(y) + y = self.norm1(y) + y = self.relu(y) + y = self.conv2(y) + y = self.norm2(y) + y = 
self.relu(y) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class MultiBasicEncoder(nn.Module): + def __init__(self, output_dim=[128], norm_fn='batch', dropout=0.0, downsample=3): + super(MultiBasicEncoder, self).__init__() + self.norm_fn = norm_fn + self.downsample = downsample + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn=='layer': + self.norm1 = LayerNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=1 + (downsample > 1)) + self.layer3 = self._make_layer(128, stride=1 + (downsample > 0)) + self.layer4 = self._make_layer(128, stride=2) + self.layer5 = self._make_layer(128, stride=2) + + output_list = [] + + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(128, 128, self.norm_fn, stride=1), + nn.Conv2d(128, dim[2], 3, padding=1)) + output_list.append(conv_out) + + self.outputs04 = nn.ModuleList(output_list) + + output_list = [] + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(128, 128, self.norm_fn, stride=1), + nn.Conv2d(128, dim[1], 3, padding=1)) + output_list.append(conv_out) + + self.outputs08 = nn.ModuleList(output_list) + + output_list = [] + for dim in output_dim: + conv_out = nn.Conv2d(128, dim[0], 3, padding=1) + output_list.append(conv_out) + + self.outputs16 = nn.ModuleList(output_list) + + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + else: + self.dropout = None + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', 
nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x, dual_inp=False, num_layers=3): + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + if dual_inp: + v = x + x = x[:(x.shape[0]//2)] + + outputs04 = [f(x) for f in self.outputs04] + if num_layers == 1: + return (outputs04, v) if dual_inp else (outputs04,) + + y = self.layer4(x) + outputs08 = [f(y) for f in self.outputs08] + + if num_layers == 2: + return (outputs04, outputs08, v) if dual_inp else (outputs04, outputs08) + + z = self.layer5(y) + outputs16 = [f(z) for f in self.outputs16] + + return (outputs04, outputs08, outputs16, v) if dual_inp else (outputs04, outputs08, outputs16) + + + +class ContextNetDino(MultiBasicEncoder): + def __init__(self, output_dim=[128], norm_fn='batch', downsample=3): + nn.Module.__init__(self) + self.patch_size = 14 + self.image_size = 518 + self.vit_feat_dim = 384 + code_dir = os.path.dirname(os.path.realpath(__file__)) + + self.out_dims = output_dim + + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn=='layer': + self.norm1 = LayerNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1 + (downsample > 2), padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 
64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=1 + (downsample > 1)) + self.layer3 = self._make_layer(128, stride=1 + (downsample > 0)) + self.layer4 = self._make_layer(128, stride=2) + self.layer5 = self._make_layer(128, stride=2) + # unused parameters + # self.down = nn.Sequential( + # nn.Conv2d(128, 128, kernel_size=4, stride=4, padding=0), + # nn.BatchNorm2d(128), + # ) + self.conv2 = BasicConv(128+128, 128, kernel_size=3, padding=1) + self.norm = nn.BatchNorm2d(256) + + output_list = [] + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(128, 128, self.norm_fn, stride=1), + nn.Conv2d(128, dim[2], 3, padding=1)) + output_list.append(conv_out) + + self.outputs04 = nn.ModuleList(output_list) + + output_list = [] + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(128, 128, self.norm_fn, stride=1), + nn.Conv2d(128, dim[1], 3, padding=1)) + output_list.append(conv_out) + + self.outputs08 = nn.ModuleList(output_list) + + output_list = [] + for dim in output_dim: + conv_out = nn.Conv2d(128, dim[0], 3, padding=1) + output_list.append(conv_out) + + self.outputs16 = nn.ModuleList(output_list) + + def forward(self, x_in, vit_feat, dual_inp=False, num_layers=3): + B,C,H,W = x_in.shape + x = self.conv1(x_in) + x = self.norm1(x) + x = self.relu1(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + divider = np.lcm(self.patch_size, 16) + H_resize, W_resize = get_resize_keep_aspect_ratio(H,W, divider=divider, max_H=1344, max_W=1344) + x = torch.cat([x, vit_feat], dim=1) + x = self.conv2(x) + outputs04 = [f(x) for f in self.outputs04] + + y = self.layer4(x) + outputs08 = [f(y) for f in self.outputs08] + + z = self.layer5(y) + outputs16 = [f(z) for f in self.outputs16] + + return (outputs04, outputs08, outputs16) + + +class DepthAnythingFeature(nn.Module): + def __init__(self, encoder='vits'): + super().__init__() + model_configs = { + 'vitl': {'encoder': 'vitl', 'features': 256, 
'out_channels': [256, 512, 1024, 1024]}, + 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, + 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]} + } + from depth_anything.dpt import DepthAnything + self.encoder = encoder + depth_anything = DepthAnything(model_configs[encoder]) + self.depth_anything = depth_anything + + self.intermediate_layer_idx = { #!NOTE For V2 + 'vits': [2, 5, 8, 11], + 'vitb': [2, 5, 8, 11], + 'vitl': [4, 11, 17, 23], + 'vitg': [9, 19, 29, 39] + } + + + def forward(self, x): + """ + @x: (B,C,H,W) + """ + h, w = x.shape[-2:] + features = self.depth_anything.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True) + + + patch_size = self.depth_anything.pretrained.patch_size + patch_h, patch_w = h // patch_size, w // patch_size + out, path_1, path_2, path_3, path_4, disp = self.depth_anything.depth_head.forward(features, patch_h, patch_w, return_intermediate=True) + + return {'out':out, 'path_1':path_1, 'path_2':path_2, 'path_3':path_3, 'path_4':path_4, 'features':features, 'disp':disp} # path_1 is 1/2; path_2 is 1/4 + + +class Feature(nn.Module): + def __init__(self, vit_type='vitl', no_freeze_mono=False): + super(Feature, self).__init__() + model = timm.create_model('edgenext_small', pretrained=True, features_only=False) + self.stem = model.stem + self.stages = model.stages + chans = [48, 96, 160, 304] + self.chans = chans + self.no_freeze_mono = no_freeze_mono + + self.deconv32_16 = Conv2x_IN(chans[3], chans[2], deconv=True, concat=True) + self.deconv16_8 = Conv2x_IN(chans[2]*2, chans[1], deconv=True, concat=True) + self.deconv8_4 = Conv2x_IN(chans[1]*2, chans[0], deconv=True, concat=True) + self.conv4 = nn.Sequential( + BasicConv(chans[0]*2+128, chans[0]*2+128, kernel_size=3, stride=1, padding=1, norm='instance'), + ResidualBlock(chans[0]*2+128, chans[0]*2+128, norm_fn='instance'), + ResidualBlock(chans[0]*2+128, chans[0]*2+128, 
norm_fn='instance'), + ) + + self.dino = DepthAnythingFeature(encoder=vit_type) + if not no_freeze_mono: + self.dino = freeze_model(self.dino) + self.patch_size = 14 + self.d_out = [chans[0]*2+128, chans[1]*2, chans[2]*2, chans[3]] + + self.vit_type = vit_type + if vit_type == 'vits': + self.vit_proj = nn.Conv2d(32, 128, 1) + elif vit_type == 'vitb': + self.vit_proj = nn.Conv2d(64, 128, 1) + + def forward(self, x): + B,C,H,W = x.shape + divider = np.lcm(self.patch_size, 16) + H_resize, W_resize = get_resize_keep_aspect_ratio(H,W, divider=divider, max_H=1344, max_W=1344) + x_in_ = F.interpolate(x, size=(H_resize, W_resize), mode='bicubic', align_corners=False) + if self.no_freeze_mono: + output = self.dino(x_in_) + else: + self.dino = self.dino.eval() + with torch.no_grad(): + output = self.dino(x_in_) + vit_feat = output['out'] + if self.vit_type in ['vits', 'vitb']: + # TODO: or just copy the smaller features to 128 + vit_feat = self.vit_proj(vit_feat) + + vit_feat = F.interpolate(vit_feat, size=(H//4,W//4), mode='bilinear', align_corners=True) + x = self.stem(x) + x4 = self.stages[0](x) + x8 = self.stages[1](x4) + x16 = self.stages[2](x8) + x32 = self.stages[3](x16) + + x16 = self.deconv32_16(x32, x16) + x8 = self.deconv16_8(x16, x8) + x4 = self.deconv8_4(x8, x4) + x4 = torch.cat([x4, vit_feat], dim=1) + x4 = self.conv4(x4) + return [x4, x8, x16, x32], vit_feat + + diff --git a/optgs/model/encoder/.deprecated/foundationstereo/core/foundation_stereo.py b/optgs/model/encoder/.deprecated/foundationstereo/core/foundation_stereo.py new file mode 100644 index 0000000000000000000000000000000000000000..ea45d81ad96da8041036ef652e27997f3a3980a6 --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/core/foundation_stereo.py @@ -0,0 +1,621 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + + +import torch,pdb,logging +import torch.nn as nn +import torch.nn.functional as F +import sys,os +code_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(f'{code_dir}/../') +from core.update import * +from core.extractor import * +from core.geometry import Combined_Geo_Encoding_Volume +from core.submodule import * +from core.utils.utils import * +from Utils import * +import time +import cv2 + +from .unimatch_matching import group_correlation_softmax_depth, CorrBlock, coords_grid, warp_with_pose_depth_candidates +from .unimatch_geometry import compute_flow_with_depth_pose + + +count = 0 + + +try: + # autocast = torch.cuda.amp.autocast + autocast = torch.amp.autocast +except: + class autocast: + def __init__(self, enabled): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + + +def normalize_image(img): + ''' + @img: (B,C,H,W) in range 0-255, RGB order + ''' + tf = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False) + return tf(img/255.0).contiguous() + + +class hourglass(nn.Module): + def __init__(self, cfg, in_channels, feat_dims=None): + super().__init__() + self.cfg = cfg + self.conv1 = nn.Sequential(BasicConv(in_channels, in_channels*2, is_3d=True, bn=True, relu=True, kernel_size=3, + padding=1, stride=2, dilation=1), + Conv3dNormActReduced(in_channels*2, in_channels*2, kernel_size=3, kernel_disp=17)) + + self.conv2 = nn.Sequential(BasicConv(in_channels*2, in_channels*4, is_3d=True, bn=True, relu=True, kernel_size=3, + padding=1, stride=2, dilation=1), + Conv3dNormActReduced(in_channels*4, in_channels*4, 
kernel_size=3, kernel_disp=17)) + + self.conv3 = nn.Sequential(BasicConv(in_channels*4, in_channels*6, is_3d=True, bn=True, relu=True, kernel_size=3, + padding=1, stride=2, dilation=1), + Conv3dNormActReduced(in_channels*6, in_channels*6, kernel_size=3, kernel_disp=17)) + + + self.conv3_up = BasicConv(in_channels*6, in_channels*4, deconv=True, is_3d=True, bn=True, + relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2)) + + self.conv2_up = BasicConv(in_channels*4, in_channels*2, deconv=True, is_3d=True, bn=True, + relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2)) + + self.conv1_up = BasicConv(in_channels*2, in_channels, deconv=True, is_3d=True, bn=True, + relu=True, kernel_size=(4, 4, 4), padding=(1, 1, 1), stride=(2, 2, 2)) + self.conv_out = nn.Sequential( + Conv3dNormActReduced(in_channels, in_channels, kernel_size=3, kernel_disp=17), + Conv3dNormActReduced(in_channels, in_channels, kernel_size=3, kernel_disp=17), + ) + + self.agg_0 = nn.Sequential(BasicConv(in_channels*8, in_channels*4, is_3d=True, kernel_size=1, padding=0, stride=1), + Conv3dNormActReduced(in_channels*4, in_channels*4, kernel_size=3, kernel_disp=17), + Conv3dNormActReduced(in_channels*4, in_channels*4, kernel_size=3, kernel_disp=17),) + + self.agg_1 = nn.Sequential(BasicConv(in_channels*4, in_channels*2, is_3d=True, kernel_size=1, padding=0, stride=1), + Conv3dNormActReduced(in_channels*2, in_channels*2, kernel_size=3, kernel_disp=17), + Conv3dNormActReduced(in_channels*2, in_channels*2, kernel_size=3, kernel_disp=17)) + self.atts = nn.ModuleDict({ + "4": CostVolumeDisparityAttention(d_model=in_channels, nhead=4, dim_feedforward=in_channels, norm_first=False, num_transformer=4, max_len=self.cfg['max_disp']//16), + }) + self.conv_patch = nn.Sequential( + nn.Conv3d(in_channels, in_channels, kernel_size=4, stride=4, padding=0, groups=in_channels), + nn.BatchNorm3d(in_channels), + ) + + self.feature_att_8 = FeatureAtt(in_channels*2, feat_dims[1]) + 
self.feature_att_16 = FeatureAtt(in_channels*4, feat_dims[2]) + self.feature_att_32 = FeatureAtt(in_channels*6, feat_dims[3]) + self.feature_att_up_16 = FeatureAtt(in_channels*4, feat_dims[2]) + self.feature_att_up_8 = FeatureAtt(in_channels*2, feat_dims[1]) + + def forward(self, x, features): + conv1 = self.conv1(x) + conv1 = self.feature_att_8(conv1, features[1]) + + conv2 = self.conv2(conv1) + conv2 = self.feature_att_16(conv2, features[2]) + + conv3 = self.conv3(conv2) + conv3 = self.feature_att_32(conv3, features[3]) + + conv3_up = self.conv3_up(conv3) + conv2 = torch.cat((conv3_up, conv2), dim=1) + conv2 = self.agg_0(conv2) + conv2 = self.feature_att_up_16(conv2, features[2]) + + conv2_up = self.conv2_up(conv2) + conv1 = torch.cat((conv2_up, conv1), dim=1) + conv1 = self.agg_1(conv1) + conv1 = self.feature_att_up_8(conv1, features[1]) + + conv = self.conv1_up(conv1) + x = self.conv_patch(x) + x = self.atts["4"](x) + x = F.interpolate(x, scale_factor=4, mode='trilinear', align_corners=False) + conv = conv + x + conv = self.conv_out(conv) + + return conv + + + +class FoundationStereo(nn.Module): + def __init__(self, args, + bilinear_init_depth=False, + flow_corr=False, + flow_corr_levels=4, + amp_bf16=True, + vit_type='vitl', + no_geo_volume=False, + concat_geo_volume=False, + depth_sample_geo_volume=False, + no_freeze_mono=False, + bilinear_up_depth=False, + local_match_radius=0, + supervise_init_depth=False, + sample_log_depth=False, + ): + super().__init__() + self.args = args + + self.bilinear_init_depth = bilinear_init_depth + self.flow_corr = flow_corr + self.flow_corr_levels = flow_corr_levels + self.amp_bf16 = amp_bf16 + self.no_geo_volume = no_geo_volume + self.concat_geo_volume = concat_geo_volume + self.depth_sample_geo_volume = depth_sample_geo_volume + self.bilinear_up_depth = bilinear_up_depth + self.local_match_radius = local_match_radius + self.supervise_init_depth = supervise_init_depth + self.sample_log_depth = sample_log_depth + + if 
local_match_radius > 0: + self.flow_corr = flow_corr = False + + if concat_geo_volume: + assert flow_corr or local_match_radius > 0 + + if depth_sample_geo_volume: + assert concat_geo_volume + + context_dims = args.hidden_dims + self.cv_group = 8 + volume_dim = 28 + + self.cnet = ContextNetDino(output_dim=[args.hidden_dims, context_dims], downsample=args.n_downsample) + self.update_block = BasicSelectiveMultiUpdateBlock(self.args, self.args.hidden_dims[0], volume_dim=volume_dim, + depth_head_dim=2 * local_match_radius + 1, + ) + self.sam = SpatialAttentionExtractor() + self.cam = ChannelAttentionEnhancement(self.args.hidden_dims[0]) + + # unused parameters + # self.context_zqr_convs = nn.ModuleList([nn.Conv2d(context_dims[i], args.hidden_dims[i]*3, kernel_size=3, padding=3//2) for i in range(self.args.n_gru_layers)]) + + self.feature = Feature(vit_type=vit_type, no_freeze_mono=no_freeze_mono) + self.proj_cmb = nn.Conv2d(self.feature.d_out[0], 12, kernel_size=1, padding=0) + + self.stem_2 = nn.Sequential( + BasicConv_IN(3, 32, kernel_size=3, stride=2, padding=1), + nn.Conv2d(32, 32, 3, 1, 1, bias=False), + nn.InstanceNorm2d(32), nn.ReLU() + ) + # self.stem_4 = nn.Sequential( + # BasicConv_IN(32, 48, kernel_size=3, stride=2, padding=1), + # nn.Conv2d(48, 48, 3, 1, 1, bias=False), + # nn.InstanceNorm2d(48), nn.ReLU() + # ) + + self.spx_2_gru = Conv2x(32, 32, True, bn=False) + self.spx_gru = nn.Sequential( + nn.ConvTranspose2d(2*32, 9, kernel_size=4, stride=2, padding=1), + ) + + + self.corr_stem = nn.Sequential( + nn.Conv3d(32, volume_dim, kernel_size=1), + BasicConv(volume_dim, volume_dim, kernel_size=3, padding=1, is_3d=True), + ResnetBasicBlock3D(volume_dim, volume_dim, kernel_size=3, stride=1, padding=1), + ResnetBasicBlock3D(volume_dim, volume_dim, kernel_size=3, stride=1, padding=1), + ) + self.corr_feature_att = FeatureAtt(volume_dim, self.feature.d_out[0]) + self.cost_agg = hourglass(cfg=self.args, in_channels=volume_dim, feat_dims=self.feature.d_out) + 
self.classifier = nn.Sequential( + BasicConv(volume_dim, volume_dim//2, kernel_size=3, padding=1, is_3d=True), + ResnetBasicBlock3D(volume_dim//2, volume_dim//2, kernel_size=3, stride=1, padding=1), + nn.Conv3d(volume_dim//2, 1, kernel_size=7, padding=3), + ) + + r = self.args.corr_radius + dx = torch.linspace(-r, r, 2*r+1, requires_grad=False).reshape(1, 1, 2*r+1, 1) + self.dx = dx + + if self.flow_corr: + flow_corr_channels = self.flow_corr_levels * (2 * 4 + 1) ** 2 + if self.concat_geo_volume: + self.corr_proj = nn.Conv2d(flow_corr_channels + 504, 522, 1) + else: + self.corr_proj = nn.Conv2d(flow_corr_channels, 522, 1) + + if self.local_match_radius > 0: + # TODO: multi-scale matching, maybe also combine geometry volume sampled from geo_volume + corr_channels = 2 * self.local_match_radius + 1 + self.correlation_proj = nn.Conv2d(corr_channels, 522, 1) + + if self.no_geo_volume: + self.corr_proj = nn.Conv2d(18, 522, 1) + + # unused parameters + # del self.cnet.down[0].weight + # del self.cnet.down[0].bias + # del self.stem_4[0].conv.weight + # del self.stem_4[1].weight + + + def upsample_disp(self, disp, mask_feat_4, stem_2x, task='stereo'): + assert task in ['stereo', 'depth'] + + dtype = torch.bfloat16 if self.amp_bf16 else torch.float16 + + with autocast('cuda', enabled=self.args.mixed_precision, dtype=dtype): + xspx = self.spx_2_gru(mask_feat_4, stem_2x) # 1/2 resolution + spx_pred = self.spx_gru(xspx) + spx_pred = F.softmax(spx_pred, 1) + if task == 'depth': + # no 4x disp since we predict inverse depth + up_disp = context_upsample(disp, spx_pred, task='depth').unsqueeze(1) + else: + up_disp = context_upsample(disp*4., spx_pred).unsqueeze(1) + + return up_disp.float() + + + def forward(self, image1, image2, iters=12, flow_init=None, test_mode=False, low_memory=False, init_disp=None, + no_norm_img=False, + task='stereo', + intrinsics=None, + pose=None, # relative pose transform + min_depth=1. / 0.5, # inverse depth range + max_depth=1. 
/ 10, + num_depth_candidates=64, + pred_bidir_depth=False, + return_features=False, + rectified_stereo=False, + ): + """ Estimate disparity between pair of frames """ + assert task in ['stereo', 'depth'] + + if self.sample_log_depth: + min_depth, max_depth = np.log(1. / max_depth), np.log(1. / min_depth) + # print(min_depth, max_depth) + + if rectified_stereo: + from PIL import Image + ori_img1 = image1[0].permute(1, 2, 0).cpu().numpy() + ori_img2 = image2[0].permute(1, 2, 0).cpu().numpy() + ori_concat = np.concatenate((ori_img1, ori_img2), axis=1) + + save_dir = 'tmp_rectified_stereo' + os.makedirs(save_dir, exist_ok=True) + + # Image.fromarray(ori_concat.astype(np.uint8)).save(save_dir + '/ori.png') + + img_left = ori_img1 + img_right = ori_img2 + + ori_h, ori_w = image1.shape[-2:] + tmp_intrinsics = intrinsics.clone() + tmp_intrinsics[:, 0] *= ori_w + tmp_intrinsics[:, 1] *= ori_h + + K1 = K2 = tmp_intrinsics[0].cpu().numpy() + D1 = np.zeros(5) # Assuming no distortion + D2 = np.zeros(5) + + R = pose[0].cpu().numpy()[:3, :3] + T = pose[0].cpu().numpy()[:3, 3:] + + # TODO: choose left or right + # if T[0] < 0: + # img_left, img_right = ori_img1, ori_img2 + # else: + # # swap + # img_left, img_right = ori_img2, ori_img1 + # # inverse + # R, T = R.T, -R.T @ T + + # [w, h] + img_size = [ori_img1.shape[1], ori_img1.shape[0]] + + # === 3. Stereo rectification === + R1, R2, P1, P2, Q, _, _ = cv2.stereoRectify(K1, D1, K2, D2, img_size, R, T) + + # === 4. Create rectification maps === + map1x, map1y = cv2.initUndistortRectifyMap(K1, D1, R1, P1, img_size, cv2.CV_32FC1) + map2x, map2y = cv2.initUndistortRectifyMap(K2, D2, R2, P2, img_size, cv2.CV_32FC1) + + # === 5. 
Remap (rectify) images === + rect_left = cv2.remap(img_left, map1x, map1y, cv2.INTER_LINEAR) + rect_right = cv2.remap(img_right, map2x, map2y, cv2.INTER_LINEAR) + + # print(rect_left.shape, rect_right.shape) + + rect_concat = np.concatenate((rect_left, rect_right), axis=1) + + concat = np.concatenate((ori_concat, rect_concat), axis=0) + global count + Image.fromarray(concat.astype(np.uint8)).save(save_dir + f'/rect_{count}.png') + count += 1 + + if count > 20: + assert False + + B = len(image1) + low_memory = low_memory or (self.args.get('low_memory', False)) + if not no_norm_img: + image1 = normalize_image(image1) + image2 = normalize_image(image2) + + dtype = torch.bfloat16 if self.amp_bf16 else torch.float16 + with autocast('cuda', enabled=self.args.mixed_precision, dtype=dtype): + out, vit_feat = self.feature(torch.cat([image1, image2], dim=0)) + if not pred_bidir_depth: + vit_feat = vit_feat[:B] + features_left = [o[:B] for o in out] + features_right = [o[B:] for o in out] + if pred_bidir_depth: + stem_2x = self.stem_2(torch.cat([image1, image2], dim=0)) + for i in range(len(features_left)): + features_left[i], features_right[i] = torch.cat([features_left[i], features_right[i]], dim=0), torch.cat([features_right[i], features_left[i]], dim=0) + else: + stem_2x = self.stem_2(image1) + + if task == 'depth': + assert intrinsics is not None and pose is not None + # NOTE: in this codebase, intrinsics are normalized by image width and height + # in unimatch's codebase, no normalization + ori_h, ori_w = image1.shape[-2:] + intrinsics = intrinsics.clone() + intrinsics[:, 0] *= ori_w + intrinsics[:, 1] *= ori_h + + # scale intrinsics + intrinsics_curr = intrinsics.clone() + intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / 4 + + if pred_bidir_depth: + intrinsics_curr = intrinsics_curr.repeat(2, 1, 1) + pose = torch.cat((pose, torch.inverse(pose)), dim=0) + + b, _, h, w = features_left[0].shape + + depth_candidates = torch.linspace(min_depth, max_depth, 
num_depth_candidates).type_as(image1) + depth_candidates = depth_candidates.view(1, num_depth_candidates, 1, 1).repeat(b, 1, h, + w) # [B, D, H, W] + + # gwc_volume: [B, G, D, H, W] + gwc_volume, warped_feature1 = group_correlation_softmax_depth(features_left[0], features_right[0], + intrinsics_curr, + pose, + depth_candidates=depth_candidates, + num_groups=self.cv_group, + sample_log_depth=self.sample_log_depth, + ) + + left_tmp = self.proj_cmb(features_left[0]).unsqueeze(2).repeat(1, 1, num_depth_candidates, 1, 1) # [B, C, D, H, W] + right_tmp = self.proj_cmb(warped_feature1.reshape(b, features_left[0].size(1), -1, w)).reshape(b, -1, num_depth_candidates, h, w) + concat_volume = torch.cat((left_tmp, right_tmp), dim=1) # [B, 2C, D, H, W] + del left_tmp, right_tmp + + else: + gwc_volume = build_gwc_volume(features_left[0], features_right[0], self.args.max_disp//4, self.cv_group) # Group-wise correlation volume (B, N_group, max_disp, H, W) + left_tmp = self.proj_cmb(features_left[0]) + right_tmp = self.proj_cmb(features_right[0]) + concat_volume = build_concat_volume(left_tmp, right_tmp, maxdisp=self.args.max_disp//4) + del left_tmp, right_tmp + + comb_volume = torch.cat([gwc_volume, concat_volume], dim=1) + comb_volume = self.corr_stem(comb_volume) + comb_volume = self.corr_feature_att(comb_volume, features_left[0]) + comb_volume = self.cost_agg(comb_volume, features_left) + + # Init disp from geometry encoding volume + prob = F.softmax(self.classifier(comb_volume).squeeze(1), dim=1) #(B, max_disp, H, W) + if init_disp is None: + if task == 'depth': + init_disp = (prob * depth_candidates).sum(dim=1, keepdim=True) # [B, 1, H, W] + else: + init_disp = disparity_regression(prob, self.args.max_disp//4) # Weighted sum of disparity + + if pred_bidir_depth: + cnet_list = self.cnet(torch.cat((image1, image2), dim=0), vit_feat=vit_feat, num_layers=self.args.n_gru_layers) #(1/4, 1/8, 1/16) + else: + cnet_list = self.cnet(image1, vit_feat=vit_feat, 
num_layers=self.args.n_gru_layers) #(1/4, 1/8, 1/16) + cnet_list = list(cnet_list) + net_list = [torch.tanh(x[0]) for x in cnet_list] # Hidden information + inp_list = [torch.relu(x[1]) for x in cnet_list] # Context information list of pyramid levels + inp_list = [self.cam(x) * x for x in inp_list] + att = [self.sam(x) for x in inp_list] + + if self.flow_corr: + geo_fn = CorrBlock(features_left[0].float(), features_right[0].float(), num_levels=self.flow_corr_levels) + if self.concat_geo_volume: + geo_volume_fn = Combined_Geo_Encoding_Volume(features_left[0].float(), features_right[0].float(), comb_volume.float(), + num_levels=self.args.corr_levels, dx=self.dx, no_corr=True) + + else: + geo_fn = Combined_Geo_Encoding_Volume(features_left[0].float(), features_right[0].float(), comb_volume.float(), num_levels=self.args.corr_levels, dx=self.dx) + + b, c, h, w = features_left[0].shape + coords = torch.arange(w, dtype=torch.float, device=init_disp.device).reshape(1,1,w,1).repeat(b, h, 1, 1) # (B,H,W,1) Horizontal only + disp = init_disp.float() + disp_preds = [] + + # GRUs iterations to update disparity (1/4 resolution) + for itr in range(iters): + disp = disp.detach() + if self.flow_corr: + proj_coords_from_depth = compute_flow_with_depth_pose( + torch.exp(disp.squeeze(1)) if self.sample_log_depth else 1. 
/ disp.squeeze(1), + intrinsics_curr, + extrinsics_rel=pose, + return_coords=True, + ) + geo_feat = geo_fn(proj_coords_from_depth) + + if self.concat_geo_volume: + if self.depth_sample_geo_volume: + indices = torch.sum(disp >= depth_candidates, dim=1, keepdim=True) - 1 + # Clamp to ensure indices are within [0, D-1] + indices = indices.clamp(min=0, max=num_depth_candidates-1).float() + tmp = geo_volume_fn(indices, coords) + else: + tmp = geo_volume_fn(disp, coords) + geo_feat = torch.cat((geo_feat, tmp), dim=1) + + # use the pre-trained weights + geo_feat = self.corr_proj(geo_feat) + + elif self.local_match_radius > 0: + # 2x smaller interval for each iteration + disp_interval = (max_depth - min_depth) / num_depth_candidates / (2 ** itr) + disp_range_min = (disp - disp_interval * self.local_match_radius).clamp(min=min_depth) # [B, 1, H, W] + disp_range_max = (disp + disp_interval * self.local_match_radius).clamp(max=max_depth) + linear_space = torch.linspace(0, 1, 2 * self.local_match_radius + 1 + ).type_as(disp).view(1, -1, 1, 1) # [1, K, 1, 1] + disp_candidates = disp_range_min + linear_space * (disp_range_max - disp_range_min) # [B, K, H, W] + + warped_feature1 = warp_with_pose_depth_candidates(features_right[0].float(), + intrinsics_curr, + pose, + torch.exp(disp_candidates) if self.sample_log_depth else (1. 
/ disp_candidates), + ) # [B, C, K, H, W] + corr = (F.normalize(features_left[0].float().unsqueeze(2), dim=1) * F.normalize(warped_feature1, dim=1)).sum(1) # [B, K, H, W] + geo_feat = self.correlation_proj(corr) + + else: + geo_feat = geo_fn(disp, coords, low_memory=low_memory, no_geo_volume=self.no_geo_volume) + + if self.no_geo_volume: + geo_feat = self.corr_proj(geo_feat) + + dtype = torch.bfloat16 if self.amp_bf16 else torch.float16 + with autocast('cuda', enabled=self.args.mixed_precision, dtype=dtype): + net_list, mask_feat_4, delta_disp = self.update_block(net_list, inp_list, geo_feat, disp, att) + + if self.local_match_radius > 0: + match_prob = F.softmax(delta_disp.float(), dim=1) + disp = (match_prob * disp_candidates).sum(1, keepdim=True) + else: + disp = disp + delta_disp.float() + + if task == 'depth': + disp = disp.clamp(min=min_depth, max=max_depth) + + if test_mode and itr < iters-1: + continue + + # upsample predictions + if self.bilinear_up_depth: + disp_up = F.interpolate(disp, scale_factor=4, mode='bilinear', align_corners=True) + else: + disp_up = self.upsample_disp(disp.float(), mask_feat_4.float(), stem_2x.float(), task=task) + + disp_preds.append(disp_up) + + if iters == 0 and task == 'depth': + disp = F.interpolate(disp, scale_factor=4, mode='bilinear', align_corners=True).squeeze(1) + disp_up = torch.exp(disp) if self.sample_log_depth else (1. / disp) + # else: + # # no refine, check the base model + # disp = disp.detach() + # geo_feat = geo_fn(disp, coords, low_memory=low_memory) + + # dtype = torch.bfloat16 if self.amp_bf16 else torch.float16 + # with autocast('cuda', enabled=self.args.mixed_precision, dtype=dtype): + # net_list, mask_feat_4, delta_disp = self.update_block(net_list, inp_list, geo_feat, disp, att) + + # # upsample predictions + # disp_up = 1. 
/ self.upsample_disp(disp.float(), mask_feat_4.float(), stem_2x.float(), task=task).squeeze(1) + + if test_mode: + if pred_bidir_depth: + half = disp_up.size(0) // 2 + + if return_features: + return disp_up[:half], disp_up[half:], torch.cat((vit_feat, features_left[0]), dim=1), prob + + return disp_up + else: + depth_preds = [disp_up] + + if pred_bidir_depth: + half = depth_preds[0].size(0) // 2 + fwd_depth_preds = [pred[:half] for pred in depth_preds] + bwd_depth_preds = [pred[half:] for pred in depth_preds] + + if return_features: + return fwd_depth_preds, bwd_depth_preds, torch.cat((vit_feat, features_left[0]), dim=1), prob + + return fwd_depth_preds, bwd_depth_preds + + return depth_preds + + if task == 'depth': + # convert inverse depth to depth + disp_up = torch.exp(disp_up.squeeze(1)) if self.sample_log_depth else (1. / disp_up.squeeze(1)) + init_disp = torch.exp(init_disp) if self.sample_log_depth else (1. / init_disp) + for i in range(len(disp_preds)): + disp_preds[i] = torch.exp(disp_preds[i].squeeze(1)) if self.sample_log_depth else (1. 
/ disp_preds[i].squeeze(1)) # [B, H, W] + + if test_mode or not self.training: + if task == 'depth': + # disp_up = disp_up.clamp(min=min_depth, max=max_depth) + + if pred_bidir_depth: + half = disp_up.size(0) // 2 + + if return_features: + return disp_up[:half], disp_up[half:], torch.cat((vit_feat, features_left[0]), dim=1), prob + + return disp_up[:half], disp_up[half:] + + return disp_up + + if task == 'depth': + # upsample to the full resolution to add supervison + init_disp = F.interpolate(init_disp, scale_factor=4, mode='bilinear', align_corners=True).squeeze(1) + + if self.supervise_init_depth: + depth_preds = [init_disp] + disp_preds + else: + depth_preds = disp_preds + + if pred_bidir_depth: + half = depth_preds[0].size(0) // 2 + fwd_depth_preds = [pred[:half] for pred in depth_preds] + bwd_depth_preds = [pred[half:] for pred in depth_preds] + + if return_features: + return fwd_depth_preds, bwd_depth_preds, torch.cat((vit_feat, features_left[0]), dim=1), prob + + return fwd_depth_preds, bwd_depth_preds + + return depth_preds + + return init_disp, disp_preds + + + def run_hierachical(self, image1, image2, iters=12, test_mode=False, low_memory=False, small_ratio=0.5): + B,_,H,W = image1.shape + img1_small = F.interpolate(image1, scale_factor=small_ratio, align_corners=False, mode='bilinear') + img2_small = F.interpolate(image2, scale_factor=small_ratio, align_corners=False, mode='bilinear') + padder = InputPadder(img1_small.shape[-2:], divis_by=32, force_square=False) + img1_small, img2_small = padder.pad(img1_small, img2_small) + disp_small = self.forward(img1_small, img2_small, test_mode=True, iters=iters, low_memory=low_memory) + disp_small = padder.unpad(disp_small.float()) + disp_small_up = F.interpolate(disp_small, size=(H,W), mode='bilinear', align_corners=True) * 1/small_ratio + disp_small_up = disp_small_up.clip(0, None) + + padder = InputPadder(image1.shape[-2:], divis_by=32, force_square=False) + image1, image2, disp_small_up = padder.pad(image1, 
image2, disp_small_up) + disp_small_up += padder._pad[0] + init_disp = F.interpolate(disp_small_up, scale_factor=0.25, mode='bilinear', align_corners=True) * 0.25 # Init disp will be 1/4 + disp = self.forward(image1, image2, iters=iters, test_mode=test_mode, low_memory=low_memory, init_disp=init_disp) + disp = padder.unpad(disp.float()) + return disp + diff --git a/optgs/model/encoder/.deprecated/foundationstereo/core/geometry.py b/optgs/model/encoder/.deprecated/foundationstereo/core/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..eb19acb506021710d7a404a135526489c678d235 --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/core/geometry.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ + +import torch,pdb,os,sys +import torch.nn.functional as F +from core.utils.utils import bilinear_sampler +code_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(f'{code_dir}/../') +from Utils import * + +class Combined_Geo_Encoding_Volume: + def __init__(self, init_fmap1, init_fmap2, geo_volume, num_levels=2, dx=None, no_corr=False): + self.num_levels = num_levels + self.geo_volume_pyramid = [] + self.init_corr_pyramid = [] + self.dx = dx + self.no_corr = no_corr + + b, c, d, h, w = geo_volume.shape + geo_volume = geo_volume.permute(0, 3, 4, 1, 2).reshape(b*h*w, c, 1, d).contiguous() + + self.geo_volume_pyramid.append(geo_volume) + for i in range(self.num_levels-1): + geo_volume = F.avg_pool2d(geo_volume, [1,2], stride=[1,2]) + self.geo_volume_pyramid.append(geo_volume) + + if not no_corr: + # all pairs correlation + init_corr = Combined_Geo_Encoding_Volume.corr(init_fmap1, init_fmap2) + b, h, w, _, w2 = init_corr.shape + init_corr = init_corr.reshape(b*h*w, 1, 1, w2) + self.init_corr_pyramid.append(init_corr) + + for i in range(self.num_levels-1): + init_corr = F.avg_pool2d(init_corr, [1,2], stride=[1,2]) + self.init_corr_pyramid.append(init_corr) + + + def __call__(self, disp, coords, low_memory=False, no_geo_volume=False): + b, _, h, w = disp.shape + self.dx = self.dx.to(disp.device) + out_pyramid = [] + for i in range(self.num_levels): + x0 = self.dx + disp.reshape(b*h*w, 1, 1, 1) / 2**i + y0 = torch.zeros_like(x0) + + if not no_geo_volume: + geo_volume = self.geo_volume_pyramid[i] + disp_lvl = torch.cat([x0,y0], dim=-1) + geo_volume = bilinear_sampler(geo_volume, disp_lvl, low_memory=low_memory) + geo_volume = geo_volume.reshape(b, h, w, -1) + out_pyramid.append(geo_volume) + + if not self.no_corr: + init_corr = self.init_corr_pyramid[i] + init_x0 = coords.reshape(b*h*w, 1, 1, 1)/2**i - disp.reshape(b*h*w, 1, 1, 1) / 2**i + self.dx # X on right image + init_coords_lvl = torch.cat([init_x0,y0], dim=-1) + init_corr = 
bilinear_sampler(init_corr, init_coords_lvl, low_memory=low_memory) + init_corr = init_corr.reshape(b, h, w, -1) + out_pyramid.append(init_corr) + + out_pyramid = torch.cat(out_pyramid, dim=-1) + return out_pyramid.permute(0, 3, 1, 2).contiguous() #(B,C,H,W) + + + @staticmethod + def corr(fmap1, fmap2): + B, D, H, W1 = fmap1.shape + _, _, _, W2 = fmap2.shape + fmap1 = fmap1.reshape(B, D, H, W1) + fmap2 = fmap2.reshape(B, D, H, W2) + with torch.amp.autocast('cuda', enabled=False): + corr = torch.einsum('aijk,aijh->ajkh', F.normalize(fmap1.float(), dim=1), F.normalize(fmap2.float(), dim=1)) + corr = corr.reshape(B, H, W1, 1, W2) + return corr \ No newline at end of file diff --git a/optgs/model/encoder/.deprecated/foundationstereo/core/submodule.py b/optgs/model/encoder/.deprecated/foundationstereo/core/submodule.py new file mode 100644 index 0000000000000000000000000000000000000000..ec78f411951ca2d21e0e85616fbde64b2819cc3c --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/core/submodule.py @@ -0,0 +1,596 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ + +import torch,pdb,os,sys +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import rearrange +from torch import einsum +code_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(f'{code_dir}/../') +from Utils import * +from flash_attn import flash_attn_qkvpacked_func, flash_attn_func + + +def _is_contiguous(tensor: torch.Tensor) -> bool: + if torch.jit.is_scripting(): + return tensor.is_contiguous() + else: + return tensor.is_contiguous(memory_format=torch.contiguous_format) + + +class LayerNorm2d(nn.LayerNorm): + r""" https://huggingface.co/spaces/Roll20/pet_score/blob/b258ef28152ab0d5b377d9142a23346f863c1526/lib/timm/models/convnext.py#L85 + LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). + """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__(normalized_shape, eps=eps) + + def forward(self, x) -> torch.Tensor: + """ + @x: (B,C,H,W) + """ + if _is_contiguous(x): + return F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2).contiguous() + else: + s, u = torch.var_mean(x, dim=1, keepdim=True) + x = (x - u) * torch.rsqrt(s + self.eps) + x = x * self.weight[:, None, None] + self.bias[:, None, None] + return x + + + +class BasicConv(nn.Module): + + def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, bn=True, relu=True, norm='batch', **kwargs): + super(BasicConv, self).__init__() + + self.relu = relu + self.use_bn = bn + self.bn = nn.Identity() + if is_3d: + if deconv: + self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs) + else: + self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs) + if self.use_bn: + if norm=='batch': + self.bn = nn.BatchNorm3d(out_channels) + elif norm=='instance': + self.bn = nn.InstanceNorm3d(out_channels) + else: + if deconv: + self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs) + 
else: + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + if self.use_bn: + if norm=='batch': + self.bn = nn.BatchNorm2d(out_channels) + elif norm=='instance': + self.bn = nn.InstanceNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + if self.use_bn: + x = self.bn(x) + if self.relu: + x = nn.LeakyReLU()(x)#, inplace=True) + return x + + +class Conv3dNormActReduced(nn.Module): + def __init__(self, C_in, C_out, hidden=None, kernel_size=3, kernel_disp=None, stride=1, norm=nn.BatchNorm3d): + super().__init__() + if kernel_disp is None: + kernel_disp = kernel_size + if hidden is None: + hidden = C_out + self.conv1 = nn.Sequential( + nn.Conv3d(C_in, hidden, kernel_size=(1,kernel_size,kernel_size), padding=(0, kernel_size//2, kernel_size//2), stride=(1, stride, stride)), + norm(hidden), + nn.ReLU(), + ) + self.conv2 = nn.Sequential( + nn.Conv3d(hidden, C_out, kernel_size=(kernel_disp, 1, 1), padding=(kernel_disp//2, 0, 0), stride=(stride, 1, 1)), + norm(C_out), + nn.ReLU(), + ) + + + def forward(self, x): + """ + @x: (B,C,D,H,W) + """ + x = self.conv1(x) + x = self.conv2(x) + return x + + + + +class ResnetBasicBlock(nn.Module): + def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=nn.BatchNorm2d, bias=False): + super().__init__() + self.norm_layer = norm_layer + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding) + if self.norm_layer is not None: + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, stride=stride, bias=bias, 
padding=padding) + if self.norm_layer is not None: + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + + def forward(self, x): + identity = x + + out = self.conv1(x) + if self.norm_layer is not None: + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + if self.norm_layer is not None: + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + out += identity + out = self.relu(out) + + return out + + +class ResnetBasicBlock3D(nn.Module): + def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=nn.BatchNorm3d, bias=False): + super().__init__() + self.norm_layer = norm_layer + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding) + if self.norm_layer is not None: + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv3d(planes, planes, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding) + if self.norm_layer is not None: + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + + def forward(self, x): + identity = x + + out = self.conv1(x) + if self.norm_layer is not None: + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + if self.norm_layer is not None: + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + out += identity + out = self.relu(out) + + return out + + +class FlashMultiheadAttention(nn.Module): + def __init__(self, embed_dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.embed_dim = embed_dim + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == 
self.embed_dim, "embed_dim must be divisible by num_heads" + + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + + def forward(self, query, key, value, attn_mask=None, window_size=(-1,-1)): + """ + @query: (B,L,C) + """ + B,L,C = query.shape + Q = self.q_proj(query) + K = self.k_proj(key) + V = self.v_proj(value) + + Q = Q.view(Q.size(0), Q.size(1), self.num_heads, self.head_dim) + K = K.view(K.size(0), K.size(1), self.num_heads, self.head_dim) + V = V.view(V.size(0), V.size(1), self.num_heads, self.head_dim) + + attn_output = flash_attn_func(Q, K, V, window_size=window_size) # Replace with actual FlashAttention function + + attn_output = attn_output.reshape(B,L,-1) + output = self.out_proj(attn_output) + + return output + + + +class FlashAttentionTransformerEncoderLayer(nn.Module): + def __init__(self, embed_dim, num_heads, dim_feedforward, dropout=0.1, act=nn.GELU, norm=nn.LayerNorm): + super().__init__() + self.self_attn = FlashMultiheadAttention(embed_dim, num_heads) + self.act = act() + + self.linear1 = nn.Linear(embed_dim, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, embed_dim) + + self.norm1 = norm(embed_dim) + self.norm2 = norm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + def forward(self, src, src_mask=None, window_size=(-1, -1)): + src2 = self.self_attn(src, src, src, src_mask, window_size=window_size) + src = src + self.dropout1(src2) + src = self.norm1(src) + + src2 = self.linear2(self.dropout(self.act(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + + return src + + + +class UpsampleConv(nn.Module): + def __init__(self, C_in, C_out, is_3d=False, kernel_size=3, bias=True, stride=1, padding=1): + super().__init__() + self.is_3d = is_3d + if is_3d: + self.conv = nn.Conv3d(C_in, C_out, 
kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=bias) + else: + self.conv = nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=bias) + + def forward(self, x): + if self.is_3d: + mode = 'trilinear' + else: + mode = 'bilinear' + x = F.interpolate(x, size=None, scale_factor=2, align_corners=False, mode=mode) + x = self.conv(x) + return x + + + +class Conv2x(nn.Module): + + def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, bn=True, relu=True, keep_dispc=False): + super(Conv2x, self).__init__() + self.concat = concat + self.is_3d = is_3d + if deconv and is_3d: + kernel = (4, 4, 4) + elif deconv: + kernel = 4 + else: + kernel = 3 + + if deconv and is_3d and keep_dispc: + kernel = (1, 4, 4) + stride = (1, 2, 2) + padding = (0, 1, 1) + self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=bn, relu=True, kernel_size=kernel, stride=stride, padding=padding) + else: + self.conv1 = BasicConv(in_channels, out_channels, deconv, is_3d, bn=bn, relu=True, kernel_size=kernel, stride=2, padding=1) + + if self.concat: + mul = 2 if keep_concat else 1 + self.conv2 = BasicConv(out_channels*2, out_channels*mul, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1) + else: + self.conv2 = BasicConv(out_channels, out_channels, False, is_3d, bn, relu, kernel_size=3, stride=1, padding=1) + + def forward(self, x, rem): + x = self.conv1(x) + if x.shape != rem.shape: + x = F.interpolate(x, size=(rem.shape[-2], rem.shape[-1]), mode='bilinear') + if self.concat: + x = torch.cat((x, rem), 1) + else: + x = x + rem + x = self.conv2(x) + return x + + +class BasicConv_IN(nn.Module): + + def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, IN=True, relu=True, **kwargs): + super(BasicConv_IN, self).__init__() + + self.relu = relu + self.use_in = IN + if is_3d: + if deconv: + self.conv = nn.ConvTranspose3d(in_channels, out_channels, bias=False, **kwargs) + else: + 
self.conv = nn.Conv3d(in_channels, out_channels, bias=False, **kwargs) + self.IN = nn.InstanceNorm3d(out_channels) + else: + if deconv: + self.conv = nn.ConvTranspose2d(in_channels, out_channels, bias=False, **kwargs) + else: + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.IN = nn.InstanceNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + if self.use_in: + x = self.IN(x) + if self.relu: + x = nn.LeakyReLU()(x)#, inplace=True) + return x + + +class Conv2x_IN(nn.Module): + + def __init__(self, in_channels, out_channels, deconv=False, is_3d=False, concat=True, keep_concat=True, IN=True, relu=True, keep_dispc=False): + super(Conv2x_IN, self).__init__() + self.concat = concat + self.is_3d = is_3d + if deconv and is_3d: + kernel = (4, 4, 4) + elif deconv: + kernel = 4 + else: + kernel = 3 + + if deconv and is_3d and keep_dispc: + kernel = (1, 4, 4) + stride = (1, 2, 2) + padding = (0, 1, 1) + self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=stride, padding=padding) + else: + self.conv1 = BasicConv_IN(in_channels, out_channels, deconv, is_3d, IN=True, relu=True, kernel_size=kernel, stride=2, padding=1) + + if self.concat: + mul = 2 if keep_concat else 1 + self.conv2 = ResnetBasicBlock(out_channels*2, out_channels*mul, kernel_size=3, stride=1, padding=1, norm_layer=nn.InstanceNorm2d) + else: + self.conv2 = BasicConv_IN(out_channels, out_channels, False, is_3d, IN, relu, kernel_size=3, stride=1, padding=1) + + def forward(self, x, rem): + x = self.conv1(x) + if x.shape != rem.shape: + x = F.interpolate(x, size=(rem.shape[-2], rem.shape[-1]), mode='bilinear') + if self.concat: + x = torch.cat((x, rem), 1) + else: + x = x + rem + x = self.conv2(x) + return x + + +def groupwise_correlation(fea1, fea2, num_groups): + B, C, H, W = fea1.shape + assert C % num_groups == 0, f"C:{C}, num_groups:{num_groups}" + channels_per_group = C // num_groups + fea1 = 
fea1.reshape(B, num_groups, channels_per_group, H, W) + fea2 = fea2.reshape(B, num_groups, channels_per_group, H, W) + with torch.cuda.amp.autocast(enabled=False): + cost = (F.normalize(fea1.float(), dim=2) * F.normalize(fea2.float(), dim=2)).sum(dim=2) #!NOTE Divide first for numerical stability + assert cost.shape == (B, num_groups, H, W) + return cost + +def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups, stride=1): + """ + @refimg_fea: left image feature + @targetimg_fea: right image feature + """ + B, C, H, W = refimg_fea.shape + volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W]) + for i in range(maxdisp): + if i > 0: + volume[:, :, i, :, i:] = groupwise_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i], num_groups) + else: + volume[:, :, i, :, :] = groupwise_correlation(refimg_fea, targetimg_fea, num_groups) + volume = volume.contiguous() + return volume + + + +def build_concat_volume(refimg_fea, targetimg_fea, maxdisp): + B, C, H, W = refimg_fea.shape + volume = refimg_fea.new_zeros([B, 2 * C, maxdisp, H, W]) + for i in range(maxdisp): + if i > 0: + volume[:, :C, i, :, :] = refimg_fea[:, :, :, :] + volume[:, C:, i, :, i:] = targetimg_fea[:, :, :, :-i] + else: + volume[:, :C, i, :, :] = refimg_fea + volume[:, C:, i, :, :] = targetimg_fea + volume = volume.contiguous() + return volume + + + +def disparity_regression(x, maxdisp): + assert len(x.shape) == 4 + disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device) + disp_values = disp_values.reshape(1, maxdisp, 1, 1) + return torch.sum(x * disp_values, 1, keepdim=True) + + +class FeatureAtt(nn.Module): + def __init__(self, cv_chan, feat_chan): + super(FeatureAtt, self).__init__() + + self.feat_att = nn.Sequential( + BasicConv(feat_chan, feat_chan//2, kernel_size=1, stride=1, padding=0), + nn.Conv2d(feat_chan//2, cv_chan, 1) + ) + + def forward(self, cv, feat): + ''' + @cv: cost volume (B,C,D,H,W) + @feat: (B,C,H,W) + ''' + feat_att = 
self.feat_att(feat).unsqueeze(2) #(B,C,1,H,W) + cv = torch.sigmoid(feat_att)*cv + return cv + +def context_upsample(disp_low, up_weights, task='stereo'): + """ + @disp_low: (b,1,h,w) 1/4 resolution + @up_weights: (b,9,4*h,4*w) Image resolution + """ + b, c, h, w = disp_low.shape + + assert task in ['stereo', 'depth'] + if task == 'depth': + # since we predict inverse depth, we don't want to do zero padding + disp_unfold = F.unfold(F.pad(disp_low.reshape(b,c,h,w), pad=(1,1,1,1), mode='replicate'),3,1,0).reshape(b,-1,h,w) + else: + disp_unfold = F.unfold(disp_low.reshape(b,c,h,w),3,1,1).reshape(b,-1,h,w) + disp_unfold = F.interpolate(disp_unfold,(h*4,w*4),mode='nearest').reshape(b,9,h*4,w*4) + + disp = (disp_unfold*up_weights).sum(1) + + return disp + + + +class PositionalEmbedding(nn.Module): + def __init__(self, d_model, max_len=512): + super().__init__() + + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model).float() + pe.require_grad = False + + position = torch.arange(0, max_len).float().unsqueeze(1) #(N,1) + div_term = (torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model)).exp()[None] + + pe[:, 0::2] = torch.sin(position * div_term) #(N, d_model/2) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.pe = pe + # self.register_buffer('pe', pe) #(1, max_len, D) + + + def forward(self, x, resize_embed=False): + ''' + @x: (B,N,D) + ''' + self.pe = self.pe.to(x.device).to(x.dtype) + pe = self.pe + if pe.shape[1] 0 else None + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + x + return x \ No newline at end of file diff --git a/optgs/model/encoder/.deprecated/foundationstereo/core/unimatch_geometry.py 
b/optgs/model/encoder/.deprecated/foundationstereo/core/unimatch_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..edf2c26d7c22ec88e6d3f83505d60b610e494537 --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/core/unimatch_geometry.py @@ -0,0 +1,302 @@ +import torch +import torch.nn.functional as F + + +def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] + 
+ return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, + return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5 + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ + + +def back_project(depth, intrinsics): + # Back project 2D pixel coords to 3D points + # depth: [B, H, W] + # intrinsics: [B, 3, 3] + b, h, w = depth.shape + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + + intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3] + + points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W] + + return points + + +def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None): + # Transform 3D points from reference camera to target camera + # points_ref: [B, 3, H, W] + # extrinsics_ref: [B, 4, 4] + # extrinsics_tgt: [B, 4, 4] + # extrinsics_rel: [B, 4, 4], relative pose transform + b, _, h, w = points_ref.shape + + if extrinsics_rel is None: + extrinsics_rel = torch.bmm(extrinsics_tgt, 
torch.inverse(extrinsics_ref)) # [B, 4, 4] + + points_tgt = torch.bmm(extrinsics_rel[:, :3, :3], + points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] # [B, 3, H*W] + + points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W] + + return points_tgt + + +def reproject(points_tgt, intrinsics, return_mask=False): + # reproject to target view + # points_tgt: [B, 3, H, W] + # intrinsics: [B, 3, 3] + + b, _, h, w = points_tgt.shape + + proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W] + + X = proj_points[:, 0] + Y = proj_points[:, 1] + Z = proj_points[:, 2].clamp(min=1e-3) + + pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in image scale + + if return_mask: + # valid mask in pixel space + mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & ( + pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1)) # [B, H, W] + + return pixel_coords, mask + + return pixel_coords + + +def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, + return_mask=False): + # Compute reprojection sample coords + points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W] + points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel) + + if return_mask: + reproj_coords, mask = reproject(points_tgt, intrinsics, + return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords, mask + + reproj_coords = reproject(points_tgt, intrinsics, + return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords + + +def compute_flow_with_depth_pose(depth_ref, intrinsics, + extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, + return_coords=False, + return_mask=False): + b, h, w = depth_ref.shape + coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W] + + if return_mask: + reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, 
extrinsics_tgt, + extrinsics_rel=extrinsics_rel, + return_mask=return_mask) # [B, 2, H, W] + rigid_flow = reproj_coords - coords_init + + return rigid_flow, mask + + reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, + extrinsics_rel=extrinsics_rel, + return_mask=return_mask) # [B, 2, H, W] + + if return_coords: + return reproj_coords + + rigid_flow = reproj_coords - coords_init + + return rigid_flow + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5, + return_flow_diff=False, + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + if return_flow_diff: + return diff_fwd, diff_bwd + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ + + +def warp_with_depth_pose(feature1, intrinsics, pose, depth, + padding_mode='zeros', + return_rigid_flow=False, + ): + assert depth.dim() == 3 # [B, H, W] + sample_coords = reproject_coords(depth, + intrinsics, + extrinsics_rel=pose, + ) # [B, 2, H, W] + + sample_coords = sample_coords.permute(0, 2, 3, 1) # [B, H, W, 2] + + warped_feature1 = bilinear_sample(feature1, sample_coords, + padding_mode=padding_mode) # [B, C, H, W] + + if return_rigid_flow: + b, h, w = depth.size() + coords_init = coords_grid(b, h, w, device=depth.device) # [B, 2, H, W] + rigid_flow = sample_coords.permute(0, 3, 1, 2) - coords_init + + return 
warped_feature1, rigid_flow + + return warped_feature1 + + +def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth, + padding_mode='zeros', + rigid_flow_to_subtract=None, + ): + # pixel-specific depth candidates, useful for refinement + # feature1: [B, C, H, W] + # intrinsics: [B, 3, 3] + # pose: [B, 4, 4] + # depth: [B, D, H, W] + assert intrinsics.size(1) == intrinsics.size(2) == 3 + assert pose.size(1) == pose.size(2) == 4 + assert depth.dim() == 4 + + b, d, h, w = depth.size() + c = feature1.size(1) + + # stop gradient + with torch.no_grad(): + # pixel coordinates + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + # back project to 3D and transform viewpoint + points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W] + points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat( + 1, 1, d, 1) * depth.view(b, 1, d, h * w) # [B, 3, D, H*W] + points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W] + # reproject to 2D image plane + points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w) # [B, 3, D, H*W] + pixel_coords = points[:, :2] / points[:, -1:].clamp(min=1e-3) # [B, 2, D, H*W] + + if rigid_flow_to_subtract is not None: + assert rigid_flow_to_subtract.dim() == 4 # [B, 2, H, W] + assert rigid_flow_to_subtract.size(1) == 2 + + pixel_coords = pixel_coords - rigid_flow_to_subtract.view(b, 2, h * w).unsqueeze(2) + + # normalize to [-1, 1] + x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2] + + # sample features + warped_feature = F.grid_sample(feature1, grid.view(b, d * h, w, 2), mode='bilinear', + padding_mode=padding_mode, + align_corners=True).view(b, c, d, h, w) # [B, C, D, H, W] + + return warped_feature + \ No newline at end of file diff --git a/optgs/model/encoder/.deprecated/foundationstereo/core/unimatch_matching.py 
b/optgs/model/encoder/.deprecated/foundationstereo/core/unimatch_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..2da593b59ca4e246160e5349b00611f938855bdc --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/core/unimatch_matching.py @@ -0,0 +1,394 @@ +import torch +import torch.nn.functional as F + +from .unimatch_geometry import coords_grid, generate_window_grid, normalize_coords + + +def global_correlation_softmax(feature0, feature1, + pred_bidir_flow=False, + ): + # global correlation + b, c, h, w = feature0.shape + feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.view(b, c, -1) # [B, C, H*W] + + correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W] + + # flow from softmax + init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W] + grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] + + if pred_bidir_flow: + correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] + init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] + grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] + b = b * 2 + + prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] + + correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow + flow = correspondence - init_grid + + return flow, prob + + +def local_correlation_softmax(feature0, feature1, local_radius, + padding_mode='zeros', + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, local_radius, + -local_radius, 
local_radius, + local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] + + correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + flow = correspondence - coords_init + match_prob = prob + + return flow, match_prob + + +def local_correlation_with_flow(feature0, feature1, + flow, + local_radius, + padding_mode='zeros', + dilation=1, + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, local_radius, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 
1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid * dilation # [B, H*W, (2R+1)^2, 2] + + # flow can be zero when using features after transformer + if not isinstance(flow, float): + sample_coords = sample_coords + flow.view( + b, 2, -1).permute(0, 2, 1).unsqueeze(-2) # [B, H*W, (2R+1)^2, 2] + else: + assert flow == 0. + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + corr = corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous() # [B, (2R+1)^2, H, W] + + return corr + + +def global_correlation_softmax_stereo(feature0, feature1, + ): + # global correlation on horizontal direction + b, c, h, w = feature0.shape + + x_grid = torch.linspace(0, w - 1, w, device=feature0.device) # [W] + + feature0 = feature0.permute(0, 2, 3, 1) # [B, H, W, C] + feature1 = feature1.permute(0, 2, 1, 3) # [B, H, C, W] + + correlation = torch.matmul(feature0, feature1) / (c ** 0.5) # [B, H, W, W] + + # mask subsequent positions to make disparity positive + mask = torch.triu(torch.ones((w, w)), diagonal=1).type_as(feature0) # [W, W] + valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(b, h, 1, 1) # [B, H, W, W] + + correlation[~valid_mask] = -1e9 + + prob = F.softmax(correlation, dim=-1) # [B, H, W, W] + + correspondence = (x_grid.view(1, 1, 1, w) * prob).sum(-1) # [B, H, W] + + # NOTE: unlike flow, disparity is typically positive + disparity = x_grid.view(1, 1, w).repeat(b, h, 1) - correspondence # [B, H, W] + + return disparity.unsqueeze(1), prob # feature resolution + + +def local_correlation_softmax_stereo(feature0, feature1, local_radius, + 
): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous() # [B, H*W, 2] + + local_h = 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(0, 0, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1), 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1), 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode='zeros', align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)] + feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)] + + correspondence = torch.matmul(prob.unsqueeze(-2), + sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] + + flow = correspondence - coords_init # flow at feature resolution + match_prob = prob + + flow_x = -flow[:, :1] # [B, 1, H, W] + + return flow_x, match_prob + + +def correlation_softmax_depth(feature0, feature1, + intrinsics, + pose, + depth_candidates, + depth_from_argmax=False, + pred_bidir_depth=False, + ): + b, c, h, w = feature0.size() + assert 
depth_candidates.dim() == 4 # [B, D, H, W] + scale_factor = c ** 0.5 + + if pred_bidir_depth: + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + intrinsics = intrinsics.repeat(2, 1, 1) + pose = torch.cat((pose, torch.inverse(pose)), dim=0) + depth_candidates = depth_candidates.repeat(2, 1, 1, 1) + + # depth candidates are actually inverse depth + warped_feature1 = warp_with_pose_depth_candidates(feature1, intrinsics, pose, + 1. / depth_candidates, + ) # [B, C, D, H, W] + + correlation = (feature0.unsqueeze(2) * warped_feature1).sum(1) / scale_factor # [B, D, H, W] + + match_prob = F.softmax(correlation, dim=1) # [B, D, H, W] + + # for cross-task transfer (flow -> depth), extract depth with argmax at test time + if depth_from_argmax: + index = torch.argmax(match_prob, dim=1, keepdim=True) + depth = torch.gather(depth_candidates, dim=1, index=index) + else: + depth = (match_prob * depth_candidates).sum(dim=1, keepdim=True) # [B, 1, H, W] + + return depth, match_prob + + +def group_correlation_softmax_depth(feature0, feature1, + intrinsics, + pose, + depth_candidates, + depth_from_argmax=False, + pred_bidir_depth=False, + num_groups=8, + sample_log_depth=False, + ): + b, c, h, w = feature0.size() + d = depth_candidates.size(1) + assert depth_candidates.dim() == 4 # [B, D, H, W] + assert c % num_groups == 0, f"c: {c}, num_groups: {num_groups}" + + if pred_bidir_depth: + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + intrinsics = intrinsics.repeat(2, 1, 1) + pose = torch.cat((pose, torch.inverse(pose)), dim=0) + depth_candidates = depth_candidates.repeat(2, 1, 1, 1) + + # depth candidates are actually inverse depth + with torch.amp.autocast('cuda', enabled=False): + warped_feature1 = warp_with_pose_depth_candidates(feature1.float(), intrinsics, pose, + torch.exp(depth_candidates) if sample_log_depth else (1. 
/ depth_candidates), + grid_sample_disable_cudnn=False, + ) # [B, C, D, H, W] + + channels_per_group = c // num_groups + warped_feature1_reshape = warped_feature1.reshape(b, num_groups, channels_per_group, d, h, w) # [B, G, C, D, H, W] + feature0_reshape = feature0.reshape(b, num_groups, channels_per_group, 1, h, w) # [B, G, C, 1, H, W] + + with torch.amp.autocast('cuda', enabled=False): + correlation = (F.normalize(feature0_reshape.float(), dim=2) * F.normalize(warped_feature1_reshape.float(), dim=2)).sum(dim=2) # [B, G, D, H, W] + + return correlation, warped_feature1 + + +def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth, + clamp_min_depth=1e-3, + grid_sample_disable_cudnn=False, + ): + """ + feature1: [B, C, H, W] + intrinsics: [B, 3, 3] + pose: [B, 4, 4] + depth: [B, D, H, W] + """ + + assert intrinsics.size(1) == intrinsics.size(2) == 3 + assert pose.size(1) == pose.size(2) == 4 + assert depth.dim() == 4 + + b, d, h, w = depth.size() + c = feature1.size(1) + + with torch.no_grad(): + # pixel coordinates + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + # back project to 3D and transform viewpoint + points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W] + points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat( + 1, 1, d, 1) * depth.view(b, 1, d, h * w) # [B, 3, D, H*W] + points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W] + # reproject to 2D image plane + points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w) # [B, 3, D, H*W] + pixel_coords = points[:, :2] / points[:, -1:].clamp(min=clamp_min_depth) # [B, 2, D, H*W] + + # normalize to [-1, 1] + x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2] + + # sample features + # ref: https://github.com/pytorch/pytorch/issues/88380 + with torch.backends.cudnn.flags(enabled=not 
grid_sample_disable_cudnn): + warped_feature = F.grid_sample(feature1, grid.view(b, d * h, w, 2), mode='bilinear', + padding_mode='zeros', + align_corners=True).view(b, c, d, h, w) # [B, C, D, H, W] + + return warped_feature + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2*r+1, device=coords.device) + dy = torch.linspace(-r, r, 2*r+1, device=coords.device) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) + + centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + with torch.amp.autocast('cuda', enabled=False): + corr = torch.matmul(F.normalize(fmap1.float(), dim=1).transpose(1, 2), F.normalize(fmap2.float(), dim=1)) + + corr = corr.view(batch, ht, wd, 1, ht, wd) + + # NOTE: normalize first for numerical stability + # corr = torch.matmul(fmap1.transpose(1,2), fmap2) + # return corr / torch.sqrt(torch.tensor(dim).float()) + return corr + + +def bilinear_sampler(img, 
coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + \ No newline at end of file diff --git a/optgs/model/encoder/.deprecated/foundationstereo/core/update.py b/optgs/model/encoder/.deprecated/foundationstereo/core/update.py new file mode 100644 index 0000000000000000000000000000000000000000..ea4d909927b3963a4d5032e67d18f61f8168d3dd --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/core/update.py @@ -0,0 +1,168 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. 
+ + +import torch,pdb,os,sys +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import einsum +code_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(f'{code_dir}/../') +from core.submodule import * +from core.extractor import * + +class DispHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256, output_dim=1): + super(DispHead, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(input_dim, input_dim, kernel_size=3, padding=1), + nn.ReLU(), + EdgeNextConvEncoder(input_dim, expan_ratio=4, kernel_size=7, norm=None), + EdgeNextConvEncoder(input_dim, expan_ratio=4, kernel_size=7, norm=None), + nn.Conv2d(input_dim, output_dim, 3, padding=1), + ) + + def forward(self, x): + return self.conv(x) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim, input_dim, kernel_size=3): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2) + + def forward(self, h, cz, cr, cq, *x_list): + x = torch.cat(x_list, dim=1) + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz(hx) + cz) + r = torch.sigmoid(self.convr(hx) + cr) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)) + cq) + h = (1-z) * h + z * q + return h + + +class BasicMotionEncoder(nn.Module): + def __init__(self, args, ngroup=8): + super(BasicMotionEncoder, self).__init__() + self.args = args + cor_planes = args.corr_levels * (2*args.corr_radius + 1) * (ngroup+1) + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 256, 3, padding=1) + self.convd1 = nn.Conv2d(1, 64, 7, padding=3) + self.convd2 = nn.Conv2d(64, 64, 3, padding=1) + self.conv = nn.Conv2d(64+256, 128-1, 3, padding=1) + + def forward(self, disp, corr): + cor = 
F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + disp_ = F.relu(self.convd1(disp)) + disp_ = F.relu(self.convd2(disp_)) + + cor_disp = torch.cat([cor, disp_], dim=1) + out = F.relu(self.conv(cor_disp)) + return torch.cat([out, disp], dim=1) + +def pool2x(x): + return F.avg_pool2d(x, 3, stride=2, padding=1) + +def pool4x(x): + return F.avg_pool2d(x, 5, stride=4, padding=1) + +def interp(x, dest): + interp_args = {'mode': 'bilinear', 'align_corners': True} + return F.interpolate(x, dest.shape[2:], **interp_args) + + +class RaftConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=256, kernel_size=3): + super().__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size // 2) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size // 2) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size // 2) + + def forward(self, h, x, hx): + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + return h + + +class SelectiveConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=256, small_kernel_size=1, large_kernel_size=3, patch_size=None): + super(SelectiveConvGRU, self).__init__() + self.conv0 = nn.Sequential( + nn.Conv2d(input_dim, input_dim, kernel_size=3, padding=1), + nn.ReLU(), + ) + self.conv1 = nn.Sequential( + nn.Conv2d(input_dim+hidden_dim, input_dim+hidden_dim, kernel_size=3, padding=1), + nn.ReLU(), + ) + self.small_gru = RaftConvGRU(hidden_dim, input_dim, small_kernel_size) + self.large_gru = RaftConvGRU(hidden_dim, input_dim, large_kernel_size) + + def forward(self, att, h, *x): + x = torch.cat(x, dim=1) + x = self.conv0(x) + hx = torch.cat([x, h], dim=1) + hx = self.conv1(hx) + h = self.small_gru(h, x, hx) * att + self.large_gru(h, x, hx) * (1 - att) + + return h + + +class BasicSelectiveMultiUpdateBlock(nn.Module): 
+ def __init__(self, args, hidden_dim=128, volume_dim=8, depth_head_dim=1): + super().__init__() + self.args = args + self.encoder = BasicMotionEncoder(args, volume_dim) + + self.depth_head_dim = depth_head_dim + + if args.n_gru_layers == 3: + self.gru16 = SelectiveConvGRU(hidden_dim, hidden_dim * 2) + if args.n_gru_layers >= 2: + self.gru08 = SelectiveConvGRU(hidden_dim, hidden_dim * (args.n_gru_layers == 3) + hidden_dim * 2) + self.gru04 = SelectiveConvGRU(hidden_dim, hidden_dim * (args.n_gru_layers > 1) + hidden_dim * 2) + if depth_head_dim > 1: + # change the name to skip the pretrained weights + self.depth_head = DispHead(hidden_dim, 256, output_dim=depth_head_dim) + else: + self.disp_head = DispHead(hidden_dim, 256) + self.mask = nn.Sequential( + nn.Conv2d(128, 64, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 32, 3, padding=1), + nn.ReLU(inplace=True), + ) + + def forward(self, net, inp, corr, disp, att): + if self.args.n_gru_layers == 3: + net[2] = self.gru16(att[2], net[2], inp[2], pool2x(net[1])) + if self.args.n_gru_layers >= 2: + if self.args.n_gru_layers > 2: + net[1] = self.gru08(att[1], net[1], inp[1], pool2x(net[0]), interp(net[2], net[1])) + else: + net[1] = self.gru08(att[1], net[1], inp[1], pool2x(net[0])) + + motion_features = self.encoder(disp, corr) + motion_features = torch.cat([inp[0], motion_features], dim=1) + if self.args.n_gru_layers > 1: + net[0] = self.gru04(att[0], net[0], motion_features, interp(net[1], net[0])) + + if self.depth_head_dim > 1: + delta_disp = self.depth_head(net[0]) + else: + delta_disp = self.disp_head(net[0]) + + # scale mask to balence gradients + mask = .25 * self.mask(net[0]) + return net, mask, delta_disp diff --git a/optgs/model/encoder/.deprecated/foundationstereo/core/utils/__init__.py b/optgs/model/encoder/.deprecated/foundationstereo/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/optgs/model/encoder/.deprecated/foundationstereo/core/utils/utils.py b/optgs/model/encoder/.deprecated/foundationstereo/core/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6852481ceeecb126e42999a268e145a189f2faac --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/core/utils/utils.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + + + +import torch,pdb,logging +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel', divis_by=8, force_square=False): + self.ht, self.wd = dims[-2:] + if force_square: + max_side = max(self.ht, self.wd) + pad_ht = ((max_side // divis_by) + 1) * divis_by - self.ht + pad_wd = ((max_side // divis_by) + 1) * divis_by - self.wd + else: + pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by + pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + assert all((x.ndim == 4) for x in inputs) + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self, x): + assert x.ndim == 4 + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False, low_memory=False): + """ 
Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 # Normalize to [-1,1] + assert torch.unique(ygrid).numel() == 1 and H == 1 # This is a stereo problem + grid = torch.cat([xgrid, ygrid], dim=-1).to(img.dtype) + img = F.grid_sample(img, grid, align_corners=True) + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + diff --git a/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/LICENSE.txt b/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/__init__.py b/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/blocks.py b/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..cdf594729fbcc09189ac0ba91b8c95581b80f443 --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/blocks.py @@ -0,0 +1,153 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + if len(in_shape) >= 4: + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, 
bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. 
+ + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output diff --git a/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/dpt.py b/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..442a04a83f890105f5b149f31b72c4a031319b0c --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/dpt.py @@ -0,0 +1,203 @@ +import argparse +import torch,os,sys,pdb +import torch.nn as nn +import torch.nn.functional as F +code_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(f'{code_dir}/../') +from depth_anything.blocks import FeatureFusionBlock, _make_scratch + + +def _make_fusion_block(features, use_bn, size = None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + 
align_corners=True, + size=size, + ) + + +class DPTHead(nn.Module): + def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False): + super(DPTHead, self).__init__() + + self.nclass = nclass + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + if nclass > 1: + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0), + ) + else: + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + + self.scratch.output_conv2 = nn.Sequential( + 
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w, return_intermediate=False, patch_size=14): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * patch_size), int(patch_w * patch_size)), mode="bilinear", align_corners=True) + if return_intermediate: + depth = self.scratch.output_conv2(out) + depth = F.relu(depth) + disp = 1/depth + disp[depth==0] = 0 + disp = disp/disp.max() + return out, path_1, path_2, path_3, path_4, disp + + else: + out = self.scratch.output_conv2(out) + return out + + +class DPT_DINOv2(nn.Module): + def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, pretrained_dino=False): + super(DPT_DINOv2, self).__init__() + + assert encoder in ['vits', 'vitb', 'vitl'] + + # in case the Internet connection is not stable, please load the DINOv2 
locally + # if localhub: + # self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False) + # else: + self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder), pretrained=pretrained_dino) + + + dim = self.pretrained.blocks[0].attn.qkv.in_features + + self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + + def forward(self, x): + h, w = x.shape[-2:] + + features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True) + patch_size = self.pretrained.patch_size + patch_h, patch_w = h // patch_size, w // patch_size + output = self.depth_head(features, patch_h, patch_w, patch_size=patch_size, return_intermediate=True) + return output + + +class DepthAnything(DPT_DINOv2): + def __init__(self, config): + super().__init__(**config) + + def forward(self, x): + h, w = x.shape[-2:] + + features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True) + patch_size = self.pretrained.patch_size + patch_h, patch_w = h // patch_size, w // patch_size + depth = self.depth_head(features, patch_h, patch_w, patch_size=patch_size) + depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True) + depth = F.relu(depth) + + return depth.squeeze(1) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--encoder", + default="vits", + type=str, + choices=["vits", "vitb", "vitl"], + ) + args = parser.parse_args() + + model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder)) + + print(model) diff --git a/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/util/__init__.py b/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/util/transform.py b/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/util/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..d542fefee7ef4c23d7e4425161902e8a8f1e8b6e --- /dev/null +++ b/optgs/model/encoder/.deprecated/foundationstereo/depth_anything/util/transform.py @@ -0,0 +1,248 @@ +import random +from PIL import Image, ImageOps, ImageFilter +import torch +from torchvision import transforms +import torch.nn.functional as F + +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. 
+ + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif 
self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + if "semseg_mask" in sample: + # sample["semseg_mask"] = cv2.resize( + # sample["semseg_mask"], (width, height), 
interpolation=cv2.INTER_NEAREST + # ) + sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode='nearest').numpy()[0, 0] + + if "mask" in sample: + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + # sample["mask"] = sample["mask"].astype(bool) + + # print(sample['image'].shape, sample['depth'].shape) + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + if "semseg_mask" in sample: + sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32) + sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"]) + + return sample diff --git a/optgs/model/encoder/.deprecated/raft_backbone.py b/optgs/model/encoder/.deprecated/raft_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b3cb73c5ff6041f1bb0211cb3591d35579c7f1 --- /dev/null +++ b/optgs/model/encoder/.deprecated/raft_backbone.py @@ -0,0 +1,211 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = 
nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = 
nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class RAFTBasicEncoder(nn.Module): + def __init__(self, output_dim=256, norm_fn='instance', dropout=0.0): + super(RAFTBasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, 
kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + + +if __name__ == '__main__': + model = RAFTBasicEncoder().cuda() + x = torch.randn(2, 3, 64, 128).cuda() + y = model(x) + print(y.shape) + + weights = '/users/hxu/projects/optgs/pretrained/raft_models/raft-sintel.pth' + weights = torch.load(weights) + # remove 'module.fnet.' in the keys + new_state_dict = {} + for k, v in weights.items(): + if 'module.fnet.' 
in k: + new_key = k.replace('module.fnet.', '') if k.startswith('module.fnet.') else k + new_state_dict[new_key] = v + model.load_state_dict(new_state_dict) + diff --git a/optgs/model/encoder/__init__.py b/optgs/model/encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/encoder/depth_anything_v2/__init__.py b/optgs/model/encoder/depth_anything_v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/encoder/depth_anything_v2/dinov2.py b/optgs/model/encoder/depth_anything_v2/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..596bb745ada3b0130c7269b0bbae830dc59807f9 --- /dev/null +++ b/optgs/model/encoder/depth_anything_v2/dinov2.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim 
to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic 
depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating 
point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=self.interpolate_antialias + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + 
return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True + ): + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4., + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def 
vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4., + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4., + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def DINOv2(model_name): + model_zoo = { + "vits": vit_small, + "vitb": vit_base, + "vitl": vit_large, + "vitg": vit_giant2 + } + + return model_zoo[model_name]( + img_size=518, + patch_size=14, + init_values=1.0, + ffn_layer="mlp" if model_name != "vitg" else "swiglufused", + block_chunks=0, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1 + ) diff --git a/optgs/model/encoder/depth_anything_v2/dinov2_layers/__init__.py b/optgs/model/encoder/depth_anything_v2/dinov2_layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1 --- /dev/null +++ b/optgs/model/encoder/depth_anything_v2/dinov2_layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 

from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
# diff --git a/optgs/model/encoder/depth_anything_v2/dinov2_layers/attention.py b/optgs/model/encoder/depth_anything_v2/dinov2_layers/attention.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..815a2bf53dbec496f6a184ed7d03bcecb7124262
# --- /dev/null
# +++ b/optgs/model/encoder/depth_anything_v2/dinov2_layers/attention.py
# @@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging

from torch import Tensor
from torch import nn


logger = logging.getLogger("dinov2")


try:
    from xformers.ops import memory_efficient_attention, unbind, fmha

    XFORMERS_AVAILABLE = True
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    """Multi-head self-attention over a ``(B, N, C)`` token tensor.

    A single linear layer produces the packed query/key/value projections;
    the attention weights are a softmax over scaled dot products, followed by
    an output projection. Dropout is applied to the attention weights and to
    the projected output.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """
        Args:
            dim: channel width of the input and output tokens.
            num_heads: number of attention heads; each head gets ``dim // num_heads`` channels.
            qkv_bias: whether the packed q/k/v projection has a bias.
            proj_bias: whether the output projection has a bias.
            attn_drop: dropout probability applied to the attention weights.
            proj_drop: dropout probability applied after the output projection.
        """
        super().__init__()
        self.num_heads = num_heads
        # Queries are multiplied by head_dim ** -0.5 before the dot product.
        self.scale = (dim // num_heads) ** -0.5

        # Sub-modules are registered in the same order as before so parameter
        # ordering and seeded initialisation stay unchanged.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply self-attention to ``x`` of shape ``(B, N, C)`` and return the same shape."""
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        packed = packed.permute(2, 0, 3, 1, 4)

        query = packed[0] * self.scale
        key = packed[1]
        value = packed[2]

        weights = (query @ key.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, C)
        mixed = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))


class MemEffAttention(Attention):
    """``Attention`` variant that calls xFormers' ``memory_efficient_attention`` when it is importable."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        """Apply attention to ``x``; ``attn_bias`` is only accepted when xFormers is available.

        Without xFormers this defers to ``Attention.forward`` and asserts that
        no ``attn_bias`` was supplied.
        """
        if XFORMERS_AVAILABLE:
            batch, tokens, channels = x.shape
            packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)

            # Split the packed projection along its q/k/v axis.
            query, key, value = unbind(packed, 2)

            attended = memory_efficient_attention(query, key, value, attn_bias=attn_bias)
            attended = attended.reshape([batch, tokens, channels])
            return self.proj_drop(self.proj(attended))

        assert attn_bias is None, "xFormers is required for nested tensors usage"
        return super().forward(x)


# \ No newline at end of file
# diff --git a/optgs/model/encoder/depth_anything_v2/dinov2_layers/block.py b/optgs/model/encoder/depth_anything_v2/dinov2_layers/block.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1
# --- /dev/null
# +++ b/optgs/model/encoder/depth_anything_v2/dinov2_layers/block.py
# @@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
from typing import Callable, List, Any, Tuple, Dict

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


try:
    from xformers.ops import fmha
    from xformers.ops import scaled_index_add, index_select_cat

    XFORMERS_AVAILABLE = True
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False


class Block(nn.Module):
    """Pre-norm transformer block: attention and FFN residual branches.

    Each branch is ``norm -> (attn | ffn) -> layer-scale`` and is added back to
    its input, optionally through stochastic depth while training.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        """
        Args:
            dim: channel width of the tokens.
            num_heads: number of attention heads passed to ``attn_class``.
            mlp_ratio: FFN hidden width is ``int(dim * mlp_ratio)``.
            qkv_bias: bias flag for the attention q/k/v projection.
            proj_bias: bias flag for the attention output projection.
            ffn_bias: bias flag for the FFN layers.
            drop: dropout probability passed to the attention projection and the FFN.
            attn_drop: dropout probability passed to the attention weights.
            init_values: layer-scale initial value; a falsy value replaces
                layer-scale with ``nn.Identity``.
            drop_path: stochastic-depth rate; also stored as ``sample_drop_ratio``.
            act_layer: activation factory passed to the FFN.
            norm_layer: factory called as ``norm_layer(dim)`` for both norms.
            attn_class: attention module factory.
            ffn_layer: FFN module factory.
        """
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Run both residual branches on ``x`` and return a tensor of the same shape."""

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # The FFN branch uses its own drop-path module (was drop_path1).
            # Both are built from the same rate, so sampling is unchanged.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Batch-subset stochastic depth for a single ``(b, n, d)`` tensor.

    ``residual_func`` is evaluated only on a random subset of the batch of size
    ``max(int(b * (1 - sample_drop_ratio)), 1)``. Its output is added back onto
    those rows, scaled by ``b / subset_size``; the other rows are returned
    unchanged. The result has the same shape as ``x``.
    """
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    """Return ``(brange, scale)``: random batch indices to keep and ``b / len(brange)``."""
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` onto the rows of ``x`` selected by ``brange``.

    Without ``scaling_vector`` both tensors are flattened from dim 1 and the
    flattened sum is returned. With it, the addition is delegated to xFormers'
    ``scaled_index_add``. Callers reshape the result with ``view_as``.
    """
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


# Module-level cache of attention biases, keyed by a tuple of
# (batch size, sequence length) pairs; entries are never evicted.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache

    Returns ``(attn_bias, cat_tensors)``. ``cat_tensors`` has batch size 1 and
    holds every selected sample concatenated along the token axis. When
    ``branges`` is given, only those batch indices of each tensor are included.
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache:
        # One sequence length per selected sample, in concatenation order.
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            seqlens.extend([x.shape[1]] * b)
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """List version of ``drop_add_residual_stochastic_depth``.

    A batch subset is drawn independently for each tensor. All subsets are
    concatenated and passed through ``residual_func`` once, together with the
    attention bias, and the residual is then split and added back per tensor.
    Returns one output per input, each with its input's shape.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    """``Block`` that also accepts a list of tensors, processed as one concatenated sequence."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # Layer-scale is left out of the residual functions here: its gamma
            # is passed as ``scaling_vector`` and applied while adding back.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Guard on ls2 itself (was ls1), since ls2.gamma is what is read.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on input type: a ``Tensor`` uses ``Block.forward``, a ``list`` uses ``forward_nested``.

        Raises:
            AssertionError: for any other input type, or for a list when
                xFormers is not installed.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError(f"expected a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}")
# diff --git a/optgs/model/encoder/depth_anything_v2/dinov2_layers/drop_path.py b/optgs/model/encoder/depth_anything_v2/dinov2_layers/drop_path.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1
# --- /dev/null
# +++ b/optgs/model/encoder/depth_anything_v2/dinov2_layers/drop_path.py
# @@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Zero whole samples of ``x`` with probability ``drop_prob`` and rescale the survivors.

    Returns ``x`` itself when not training or when ``drop_prob`` is ``0.0`` or
    ``None``. Otherwise one Bernoulli keep-mask value is drawn per sample along
    dim 0 and broadcast over the remaining dims; kept samples are divided by
    the keep probability.
    """
    # ``None`` is DropPath's default; treat it as "no dropping" instead of
    # failing on ``1 - None`` in training mode.
    if drop_prob is None or drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
# diff --git a/optgs/model/encoder/depth_anything_v2/dinov2_layers/layer_scale.py b/optgs/model/encoder/depth_anything_v2/dinov2_layers/layer_scale.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1
# --- /dev/null
# +++ b/optgs/model/encoder/depth_anything_v2/dinov2_layers/layer_scale.py
# @@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    """Learnable per-channel scale.

    Holds a parameter ``gamma`` of length ``dim`` and multiplies the input by
    it, broadcasting against the trailing dimension.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        """
        Args:
            dim: length of the ``gamma`` vector.
            init_values: initial value of every entry of ``gamma``.
            inplace: when True, ``forward`` multiplies its input in place.
        """
        super().__init__()
        self.inplace = inplace
        initial_gamma = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial_gamma)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by ``gamma``; with ``inplace`` set, ``x`` itself is modified and returned."""
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
# diff --git a/optgs/model/encoder/depth_anything_v2/dinov2_layers/mlp.py b/optgs/model/encoder/depth_anything_v2/dinov2_layers/mlp.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018
# --- /dev/null
# +++ b/optgs/model/encoder/depth_anything_v2/dinov2_layers/mlp.py
# @@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    """Two-layer feed-forward network: ``fc1 -> act -> drop -> fc2 -> drop``."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        """
        Args:
            in_features: width of the input.
            hidden_features: width of the hidden layer; a falsy value means ``in_features``.
            out_features: width of the output; a falsy value means ``in_features``.
            act_layer: zero-argument factory for the activation module.
            drop: dropout probability, applied after the activation and after ``fc2``.
            bias: whether both linear layers have a bias.
        """
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features

        # Layers are created in the original order (fc1, act, fc2, drop) so
        # seeded initialisation and state_dict ordering are unchanged.
        self.fc1 = nn.Linear(in_features, hidden_width, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the network to ``x``; the same dropout module is used at both positions."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
# diff --git a/optgs/model/encoder/depth_anything_v2/dinov2_layers/patch_embed.py b/optgs/model/encoder/depth_anything_v2/dinov2_layers/patch_embed.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779
# --- /dev/null
# +++ b/optgs/model/encoder/depth_anything_v2/dinov2_layers/patch_embed.py
# @@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = 
self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/optgs/model/encoder/depth_anything_v2/dinov2_layers/swiglu_ffn.py b/optgs/model/encoder/depth_anything_v2/dinov2_layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e --- /dev/null +++ b/optgs/model/encoder/depth_anything_v2/dinov2_layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: 
Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/optgs/model/encoder/depth_anything_v2/dpt.py b/optgs/model/encoder/depth_anything_v2/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..18d3e6f8b5c973aea5d4927ef783d43053ba6429 --- /dev/null +++ b/optgs/model/encoder/depth_anything_v2/dpt.py @@ -0,0 +1,221 @@ +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Compose + +from .dinov2 import DINOv2 +from .util.blocks import FeatureFusionBlock, _make_scratch +from .util.transform import Resize, NormalizeImage, PrepareForNet + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + 
out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) 
+ layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + return out + + +class DepthAnythingV2(nn.Module): + def __init__( + self, + encoder='vitl', + features=256, + out_channels=[256, 512, 1024, 1024], + use_bn=False, + use_clstoken=False + ): + super(DepthAnythingV2, self).__init__() + + self.intermediate_layer_idx = { + 'vits': [2, 5, 8, 11], + 'vitb': [2, 5, 8, 11], + 'vitl': [4, 11, 17, 23], + 'vitg': [9, 19, 29, 39] + } + + self.encoder = encoder + self.pretrained = DINOv2(model_name=encoder) + + self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + + def forward(self, x): + patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14 + + features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True) + + depth = self.depth_head(features, patch_h, patch_w) + depth = F.relu(depth) + + return depth.squeeze(1) + + @torch.no_grad() + def infer_image(self, raw_image, input_size=518): + image, (h, w) = self.image2tensor(raw_image, input_size) + + depth = self.forward(image) + + depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0] + + return depth.cpu().numpy() + + def image2tensor(self, raw_image, input_size=518): + transform = Compose([ + Resize( + width=input_size, + height=input_size, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=14, + resize_method='lower_bound', + image_interpolation_method=cv2.INTER_CUBIC, + 
), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + ]) + + h, w = raw_image.shape[:2] + + image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 + + image = transform({'image': image})['image'] + image = torch.from_numpy(image).unsqueeze(0) + + DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + image = image.to(DEVICE) + + return image, (h, w) diff --git a/optgs/model/encoder/depth_anything_v2/util/__init__.py b/optgs/model/encoder/depth_anything_v2/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/encoder/depth_anything_v2/util/blocks.py b/optgs/model/encoder/depth_anything_v2/util/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..382ea183a40264056142afffc201c992a2b01d37 --- /dev/null +++ b/optgs/model/encoder/depth_anything_v2/util/blocks.py @@ -0,0 +1,148 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution 
module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 # fixed at 1, so the conv_merge branch in forward() is never taken + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) # NOTE(review): conv_merge is never assigned in __init__; unreachable while self.groups == 1 + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv # stored only; forward() of this class does not read it + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size # default output size used by forward() when its own size argument is None + + def forward(self, *xs, size=None): + """Forward pass.
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + + output = self.out_conv(output) + + return output diff --git a/optgs/model/encoder/depth_anything_v2/util/transform.py b/optgs/model/encoder/depth_anything_v2/util/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..ed129d720e1bcc40f888a293c8248d56d23cb874 --- /dev/null +++ b/optgs/model/encoder/depth_anything_v2/util/transform.py @@ -0,0 +1,160 @@ +import numpy as np +import cv2 +import torch +import torch.nn.functional as F + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. 
(Output size might be smaller than given size.) + "minimal": Scale as little as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) # nearest multiple + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) # nearest overshoots max_val: round down instead + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) # below min_val: round up instead + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as little as possible + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) + elif self.__resize_method
== "upper_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) # get_size takes (width, height): shape[1] is passed as width, shape[0] as height + + # resize sample + sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method) + + if self.__resize_target: + if "depth" in sample: + sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) + + if "mask" in sample: + sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST) + + return sample + + +class NormalizeImage(object): + """Normalize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std # presumably broadcast per channel over an H x W x C array; confirm against callers + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input.
+ """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + return sample \ No newline at end of file diff --git a/optgs/model/encoder/layer.py b/optgs/model/encoder/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..b013aa2dfaa1890edf2446c4da4893ed1be2d178 --- /dev/null +++ b/optgs/model/encoder/layer.py @@ -0,0 +1,71 @@ +from torch import nn as nn + +from torchvision.models import resnet18 + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_channels, out_channels, stride=1, downsample=None, num_groups=8): + super(BasicBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, + stride=stride, padding=1, bias=False) + self.gn1 = nn.GroupNorm(num_groups, out_channels) + + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, + stride=1, padding=1, bias=False) + self.gn2 = nn.GroupNorm(num_groups, out_channels) + + self.downsample = downsample + self.gelu = nn.GELU() + + def forward(self, x): + identity = x + + out = self.gelu(self.gn1(self.conv1(x))) + out = self.gn2(self.conv2(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.gelu(out) + + return out + + +class ResNetFeatureWarpper(nn.Module): + def __init__(self, shallow_resnet_feature=False): + super(ResNetFeatureWarpper, self).__init__() + + self.shallow_resnet_feature = shallow_resnet_feature + + resnet = resnet18(pretrained=True) + + self.conv1 = resnet.conv1 + self.bn1 = resnet.bn1 + self.relu = resnet.relu + self.maxpool = resnet.maxpool + self.layer1 = resnet.layer1 + if not 
shallow_resnet_feature: + self.layer2 = resnet.layer2 + + def forward(self, x): + out = [] + x = self.conv1(x) + out.append(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + out.append(x) + + if not self.shallow_resnet_feature: + x = self.layer2(x) + out.append(x) + + return out + + diff --git a/optgs/model/encoder/lvsm/__init__.py b/optgs/model/encoder/lvsm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/encoder/lvsm/layer.py b/optgs/model/encoder/lvsm/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8170a2dd2f52ffc51a9f81c5dce7fbc3ca50d7 --- /dev/null +++ b/optgs/model/encoder/lvsm/layer.py @@ -0,0 +1,340 @@ +import torch +import torch.nn as nn +from einops import rearrange + +import warnings +import torch.nn.functional as F + + +USE_FLASH_ATTENTION3 = True +try: + from flash_attn_interface import flash_attn_func + FA3_AVAILABLE = True + warnings.warn('flash attention 3 is available (LVSM)') +except ImportError: + FA3_AVAILABLE = False + warnings.warn('flash attention 3 is not available (LVSM)') + + +try: + import xformers.ops as xops + XFORMERS_AVAILABLE = True +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn('xformers is not available (LVSM)') + # raise ImportError("Please install xformers to use flashatt v2") + + +def init_weights(module, std=0.02): + """Initialize weights for linear and embedding layers. 
+ + Args: + module: Module to initialize + std: Standard deviation for normal initialization + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + torch.nn.init.normal_(module.weight, mean=0.0, std=std) + if isinstance(module, nn.Linear) and module.bias is not None: + torch.nn.init.zeros_(module.bias) + + + +# src: https://github.com/pytorch/benchmark/blob/main/torchbenchmark/models/llama/model.py#L28 +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + + return output * self.weight.type_as(x) + + + +class MLP(nn.Module): + """ + Multi-Layer Perceptron block. + Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L49-L65 + """ + + def __init__( + self, + dim, + mlp_ratio=4, + bias=False, + dropout=0.0, + activation=nn.GELU, + mlp_dim=None, + ): + """ + Args: + dim: Input dimension + mlp_ratio: Multiplier for hidden dimension + bias: Whether to use bias in linear layers + dropout: Dropout probability + activation: Activation function + mlp_dim: Optional explicit hidden dimension (overrides mlp_ratio) + """ + super().__init__() + hidden_dim = mlp_dim if mlp_dim is not None else int(dim * mlp_ratio) + + self.mlp = nn.Sequential( + nn.Linear(dim, hidden_dim, bias=bias), + activation(), + nn.Linear(hidden_dim, dim, bias=bias), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.mlp(x) + + + +class QK_Norm_SelfAttention(nn.Module): + """ + Self-attention with optional Q-K normalization. 
+ Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L68-L92 + """ + + def __init__( + self, + dim, + head_dim, + qkv_bias=False, + fc_bias=True, + attn_dropout=0.0, + fc_dropout=0.0, + use_qk_norm=True, + ): + """ + Args: + dim: Input dimension + head_dim: Dimension of each attention head + qkv_bias: Whether to use bias in QKV projection + fc_bias: Whether to use bias in output projection + attn_dropout: Dropout probability for attention weights + fc_dropout: Dropout probability for output projection + use_qk_norm: Whether to use Q-K normalization + We use flash attention V2 for efficiency. + """ + super().__init__() + assert dim % head_dim == 0, f"Token dimension {dim} should be divisible by head dimension {head_dim}" + + self.dim = dim + self.head_dim = head_dim + self.num_heads = dim // head_dim + self.attn_dropout = attn_dropout + self.use_qk_norm = use_qk_norm + + self.to_qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias) + self.fc = nn.Linear(dim, dim, bias=fc_bias) + self.attn_fc_dropout = nn.Dropout(fc_dropout) + + # Optional Q-K normalization + if self.use_qk_norm: + self.q_norm = RMSNorm(head_dim) + self.k_norm = RMSNorm(head_dim) + + def forward(self, x, attn_bias=None): + """ + Args: + x: Input tensor of shape (batch, seq_len, dim) + attn_bias: Optional attention bias mask + + Returns: + Output tensor of shape (batch, seq_len, dim) + """ + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + q, k, v = (rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.head_dim) for t in (q, k, v)) + + # Apply qk normalization if enabled + if self.use_qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + + if USE_FLASH_ATTENTION3 and FA3_AVAILABLE: + x = flash_attn_func(q, k, v)[0] + elif XFORMERS_AVAILABLE: + x = xops.memory_efficient_attention( + q, k, v, + attn_bias=attn_bias, + p=self.attn_dropout if self.training else 0.0, + op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp), + ) + else: + # use pytorch's built-in 
attention + q = q.permute(0, 2, 1, 3).contiguous() # [B, H, L, C] + k = k.permute(0, 2, 1, 3).contiguous() + v = v.permute(0, 2, 1, 3).contiguous() + x = F.scaled_dot_product_attention(q, k, v) + x = x.permute(0, 2, 1, 3).contiguous() # [B, L, H, C] + + x = rearrange(x, "b l nh dh -> b l (nh dh)") + x = self.attn_fc_dropout(self.fc(x)) + + return x + + + + +class SubsetAttention(nn.Module): + """Attention that can attend to subsets of queries or keys/values.""" + + def __init__( + self, + dim, + head_dim, + qkv_bias=False, + attn_dropout=0.0, + fc_bias=False, + fc_dropout=0.0, + use_qk_norm=False + ): + """ + Args: + dim: Input dimension + head_dim: Dimension of each attention head + qkv_bias: Whether to use bias in QKV projection + attn_dropout: Dropout probability for attention weights + fc_bias: Whether to use bias in output projection + fc_dropout: Dropout probability for output projection + use_qk_norm: Whether to use Q-K normalization + We use flash attention V2 for efficiency. + """ + super().__init__() + assert dim % head_dim == 0, f"Token dimension {dim} should be divisible by head dimension {head_dim}" + + self.dim = dim + self.head_dim = head_dim + self.num_heads = dim // head_dim + self.attn_dropout = attn_dropout + self.use_qk_norm = use_qk_norm + + # Projections + self.to_qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias) + self.fc = nn.Linear(dim, dim, bias=fc_bias) + self.attn_fc_dropout = nn.Dropout(fc_dropout) + + # Optional Q-K normalization + if self.use_qk_norm: + self.q_norm = RMSNorm(head_dim) + self.k_norm = RMSNorm(head_dim) + + def forward(self, x, subset_kv_size=None, subset_q_size=None): + """ + Args: + x: Input tensor of shape (batch, seq_len, dim) + subset_kv_size: If provided, only attend to tokens after this index in KV + subset_q_size: If provided, only compute attention for queries up to this index + + Returns: + Output tensor of shape (batch, seq_len, dim) + """ + # Only one subset parameter can be provided + assert not (subset_kv_size 
is not None and subset_q_size is not None), \ + "Only one of subset_kv_size or subset_q_size can be provided" + + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + q, k, v = (rearrange(t, "b l (nh dh) -> b l nh dh", dh=self.head_dim) for t in (q, k, v)) + + if self.use_qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + + # Handle subset attention cases + if subset_kv_size is not None and subset_kv_size < k.shape[1]: + # Attend to subset of key/value tokens + k_subset = k[:, subset_kv_size:, :, :].contiguous() + v_subset = v[:, subset_kv_size:, :, :].contiguous() + + x = xops.memory_efficient_attention( + q, k_subset, v_subset, + attn_bias=None, + p=self.attn_dropout if self.training else 0.0, + op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp), + ) + elif subset_q_size is not None and subset_q_size < q.shape[1]: + # Only compute attention for subset of query tokens + q_subset = q[:, :subset_q_size, :, :].contiguous() + + x = xops.memory_efficient_attention( + q_subset, k, v, + attn_bias=None, + p=self.attn_dropout if self.training else 0.0, + op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp), + ) + else: + # Regular attention for all tokens + x = xops.memory_efficient_attention( + q, k, v, + attn_bias=None, + p=self.attn_dropout if self.training else 0.0, + op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp), + ) + + x = rearrange(x, "b l nh dh -> b l (nh dh)") + + # Final projection + x = self.attn_fc_dropout(self.fc(x)) + + return x + + + + +class QK_Norm_TransformerBlock(nn.Module): + """ + Standard transformer block with pre-normalization architecture. 
+ Reference: https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L95-L113 + """ + + def __init__( + self, + dim, + head_dim, + ln_bias=False, + attn_qkv_bias=False, + attn_dropout=0.0, + attn_fc_bias=False, + attn_fc_dropout=0.0, + mlp_ratio=4, + mlp_bias=False, + mlp_dropout=0.0, + use_qk_norm=True, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, bias=ln_bias) + self.attn = QK_Norm_SelfAttention( + dim=dim, + head_dim=head_dim, + qkv_bias=attn_qkv_bias, + fc_bias=attn_fc_bias, + attn_dropout=attn_dropout, + fc_dropout=attn_fc_dropout, + use_qk_norm=use_qk_norm, + ) + + self.norm2 = nn.LayerNorm(dim, bias=ln_bias) + self.mlp = MLP( + dim=dim, + mlp_ratio=mlp_ratio, + bias=mlp_bias, + dropout=mlp_dropout, + ) + + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + + + \ No newline at end of file diff --git a/optgs/model/encoder/lvsm/transformer.py b/optgs/model/encoder/lvsm/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..2ce07e1e46be9b7e068dc561336098e93c68fd2b --- /dev/null +++ b/optgs/model/encoder/lvsm/transformer.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn + +from .layer import QK_Norm_TransformerBlock, init_weights + + +class LVSMTransformer(nn.Module): + def __init__(self, + dim=768, + d_head=64, + n_layer=24, + special_init=True, + depth_init=True, + use_qk_norm=True, + ): + super().__init__() + + # Create transformer blocks + self.transformer_blocks = [ + QK_Norm_TransformerBlock( + dim, d_head, use_qk_norm=use_qk_norm + ) for _ in range(n_layer) + ] + + # Apply special initialization if configured + if special_init: + for idx, block in enumerate(self.transformer_blocks): + if depth_init: + weight_init_std = 0.02 / (2 * (idx + 1)) ** 0.5 + else: + weight_init_std = 0.02 / (2 * n_layer) ** 0.5 + block.apply(lambda module: init_weights(module, weight_init_std)) + else: + for block in 
self.transformer_blocks: + block.apply(init_weights) + + self.transformer_blocks = nn.ModuleList(self.transformer_blocks) + + + def forward(self, x): + + for blk in self.transformer_blocks: + x = blk(x) + + return x + + + +if __name__ == '__main__': + device = torch.device('cuda') + model = LVSMTransformer().to(device) + + x = torch.randn(2, 64, 768).to(device) + + with torch.autocast('cuda', dtype=torch.bfloat16): + y = model(x) + + print(y.shape) + + + + diff --git a/optgs/model/encoder/point_transformer/__init__.py b/optgs/model/encoder/point_transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/encoder/point_transformer/debug.py b/optgs/model/encoder/point_transformer/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..3f2314d4dd172c5b2a67c5e567fce28088702e40 --- /dev/null +++ b/optgs/model/encoder/point_transformer/debug.py @@ -0,0 +1,59 @@ +import torch +import time + + +from layer import KNNAttention, TransformerBlock, PlainPointTransformer, SubsampleBlock + + +device = torch.device('cuda') + +def test(): + bs = 4 + npts = 1024 + len_xyz = 3 + feat_dims = 64 + num_classes = 23 + coord = torch.rand(bs * npts, len_xyz).cuda() + feat = torch.rand(bs * npts, feat_dims).cuda() + offset = [npts * i for i in range(1, bs + 1)] + offset = torch.tensor(offset).cuda() + + # data_dict = dict( + # coord = coord, + # feat = feat, + # offset = offset + # ) + + # model = PointTransformerSeg26().cuda() + + # model = KNNAttention(feat_dims, num_samples=16).cuda() + + # model = TransformerBlock(feat_dims).cuda() + + # model = PlainPointTransformer(feat_dims, num_blocks=2).cuda() + model = SubsampleBlock(feat_dims, feat_dims).cuda() + + print(model) + + + # count time + # count = 100 + + # torch.cuda.synchronize() + # start = time.time() + + # for _ in range(count): + # out = model((coord, feat, offset)) + + # torch.cuda.synchronize() + # 
print(time.time() - start) + + + out = model((coord, feat, offset)) + print(out[0].shape) + + # print(out.shape) + + + +test() diff --git a/optgs/model/encoder/point_transformer/layer.py b/optgs/model/encoder/point_transformer/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..148a57edbbbcf725037d3b57ee423dc138c25ca9 --- /dev/null +++ b/optgs/model/encoder/point_transformer/layer.py @@ -0,0 +1,1491 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +import pointops +from pointops import grouping, grouping2 +from einops import rearrange +import time + +from ..unimatch.dinov2.layers.block import Block as MultiViewBlock +from ..unimatch.utils import mv_feature_add_position +from ..unimatch.mv_transformer import MultiViewFeatureTransformer + + +USE_PYTORCH_ATTN = False +USE_FLASH_ATTN3 = False + +# try: +# from flash_attn_interface import flash_attn_func +# FA3_AVAILABLE = True +# warnings.warn('flash attention 3 is available (point attn)') + +# except ImportError: +# FA3_AVAILABLE = False +# warnings.warn('flash attention 3 is not available (point attn)') + + +class KNNAttention(nn.Module): + # TODO: multi-head + def __init__(self, channels, knn_samples=16, no_rpe=True, + qk_norm=False, + num_heads=1, + proj_channels=None, + use_fused=False, + ): + super().__init__() + self.proj_channels = proj_channels + + self.knn_samples = knn_samples + self.no_rpe = no_rpe + self.num_heads = num_heads + assert self.num_heads == 1 + + self.use_fused = use_fused + if use_fused: + try: + import sys + from optgs.paths import PROJECT_DIR + sys.path.append(str(PROJECT_DIR / "submodules")) + from fused_knn_attn import fused_knn_attention, FUSED_KNN_ATTN_CUDA_AVAILABLE + self._fused_knn_attention = fused_knn_attention + if not FUSED_KNN_ATTN_CUDA_AVAILABLE: + import warnings + warnings.warn( + "Fused KNN attention CUDA extension not available, " + "using PyTorch fallback (still avoids [N,K,C] intermediates)" 
+ ) + except ImportError: + import warnings + warnings.warn( + "fused_knn_attn package not found, falling back to unfused attention" + ) + self.use_fused = False + + self.qk_norm = qk_norm + if qk_norm: + self.q_norm = nn.RMSNorm(channels) + self.k_norm = nn.RMSNorm(channels) + + if self.proj_channels is not None: + self.qkv = nn.Linear(channels, self.proj_channels * 3, bias=False) + self.proj = nn.Linear(self.proj_channels, channels) + else: + self.qkv = nn.Linear(channels, channels * 3, bias=False) + self.proj = nn.Linear(channels, channels) + + if not self.no_rpe: + self.rpe = nn.Sequential( + nn.Linear(3, 32), + nn.GELU(), + nn.Linear(32, 1) + ) + + + def forward(self, pxo, knn_idx=None): + # [N, 3], [N, C], [B] + p, x, o = pxo + c = x.size(1) + + if self.proj_channels is not None: + c = self.proj_channels + + assert c % self.num_heads == 0 + head_dim = c // self.num_heads + scale_factor = head_dim ** -0.5 + + qkv = self.qkv(x) # [N, 3*C] + x_q, x_k, x_v = torch.chunk(qkv, chunks=3, dim=-1) # each [N, C] + + # ---- Fused path: gather + attention in one kernel ---- + if self.use_fused and self.no_rpe: + # Ensure we have KNN indices + if knn_idx is None: + knn_idx, _ = pointops.knn_query( + self.knn_samples, p, o, p, o + ) + + # qk_norm: RMSNorm normalizes each C-dim vector independently, + # so applying before gather is equivalent to applying after gather. 
+ if self.qk_norm: + x_q = self.q_norm(x_q) + x_k = self.k_norm(x_k) + + out = self._fused_knn_attention( + x_q.contiguous(), x_k.contiguous(), x_v.contiguous(), + knn_idx.contiguous(), scale_factor + ) + out = self.proj(out) + return out + + # ---- Original unfused path ---- + # # [N, K, C], [N, K] + # x_k, idx = pointops.knn_query_and_group( + # x_k.contiguous(), p, o, new_xyz=p, new_offset=o, + # idx=knn_idx, + # nsample=self.knn_samples, with_xyz=False + # ) # [N, K, C] + # + # # [N, K, C] + # x_v, _ = pointops.knn_query_and_group( + # x_v.contiguous(), + # p, + # o, + # new_xyz=p, + # new_offset=o, + # idx=idx, + # nsample=self.knn_samples, + # with_xyz=False, + # ) + + # ---- Initial improved version ---- + x_kv = torch.cat([x_k, x_v], dim=-1) # [N, 2C]: keys and values gathered together in one query + x_kv_query, _ = pointops.knn_query_and_group( + x_kv.contiguous(), p, o, new_xyz=p, new_offset=o, + idx=knn_idx, nsample=self.knn_samples, with_xyz=False + ) # [N, K, 2C] + x_k, x_v = torch.chunk(x_kv_query, chunks=2, dim=-1) # each [N, K, C] + + # [N, K, 3], [N, K, C] + # NOTE: without xyz in knn + # p_r, x_k = x_k[:, :, :3], x_k[:, :, 3:] + + # [N, 1, K] + assert self.no_rpe # p_r is only produced by the commented-out line above, so the RPE branch below cannot run + if not self.no_rpe: + rpe = self.rpe(p_r).permute(0, 2, 1) + else: + rpe = 0 + + if self.qk_norm: + x_q = self.q_norm(x_q) + x_k = self.k_norm(x_k) + + n, k, c = x_k.shape + + # attention + if USE_PYTORCH_ATTN: + out = F.scaled_dot_product_attention( + x_q.view(n, 1, c), + x_k.view(n, k, c), + x_v.view(n, k, c), + ).reshape(n, c) # [N, C] + + elif (USE_FLASH_ATTN3 and FA3_AVAILABLE and self.no_rpe): # NOTE(review): FA3_AVAILABLE / flash_attn_func exist only in the commented-out import at the top of this file; safe only while USE_FLASH_ATTN3 is False + # no relative pos enc + out = flash_attn_func( + x_q.view(n, 1, self.num_heads, head_dim).to(torch.bfloat16), + x_k.view(n, k, self.num_heads, head_dim).to(torch.bfloat16), + x_v.view(n, k, self.num_heads, head_dim).to(torch.bfloat16), + )[0].reshape(n, c).float() # [N, C] + else: + # [N, 1, K] + scores = torch.matmul(x_q.unsqueeze(1), x_k.permute(0, 2, 1)) * scale_factor + rpe + # [N, C] + out = torch.matmul(torch.softmax(scores, dim=2),
x_v).squeeze(1) + + out = self.proj(out) + + return out + + +class MLP(nn.Module): + def __init__( + self, + channels, + act="gelu", + ): + super().__init__() + + expansion = 4 + + self.fc1 = nn.Linear(channels, channels * expansion) + if act is None or act in ['none', 'identity']: + self.act = nn.Identity() + elif act == 'gelu': + self.act = nn.GELU() + elif act == 'tanh': + self.act = nn.Tanh() + else: + raise ValueError(f"unsupported activation {act}") + self.fc2 = nn.Linear(channels * expansion, channels) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +class TransformerBlock(nn.Module): + def __init__(self, channels, knn_samples=16, post_norm=False, + no_rpe=False, + no_attn=False, + no_norm=False, + act="gelu", + qk_norm=False, + norm_pt_block=False, + num_heads=1, + attn_proj_channels=None, + use_fused_attn=False, + ): + super().__init__() + self.post_norm = post_norm + self.no_attn = no_attn + self.norm_pt_block = norm_pt_block + + if no_norm: + self.norm1 = nn.Identity() + self.norm2 = nn.Identity() + else: + self.norm1 = nn.LayerNorm(channels) + self.norm2 = nn.LayerNorm(channels) + + if self.no_attn: + self.linear = nn.Linear(channels, channels) + else: + self.attn = KNNAttention(channels, knn_samples=knn_samples, no_rpe=no_rpe, + qk_norm=qk_norm, + num_heads=num_heads, + proj_channels=attn_proj_channels, + use_fused=use_fused_attn, + ) + self.mlp = MLP(channels, act=act) + + if self.norm_pt_block: + self.norm3 = nn.LayerNorm(channels) + + def forward(self, pxo, knn_idx=None): + p, x, o = pxo + + if self.post_norm: + if self.no_attn: + x = x + self.norm1(self.linear(x)) + else: + x = x + self.norm1(self.attn((p, x, o), knn_idx=knn_idx)) + x = x + self.norm2(self.mlp(x)) + else: + if self.no_attn: + x = x + self.linear(self.norm1(x)) + else: + x = x + self.attn((p, self.norm1(x), o), knn_idx=knn_idx) + x = x + self.mlp(self.norm2(x)) + + if self.norm_pt_block: + x = self.norm3(x) + + return x + + +class 
FPSSubsample(nn.Module): + def __init__(self, in_planes, out_planes, stride=4, nsample=16, + agg_func='attn', + subsample_method='fps', + return_idx=False, + fps_num_samples=None, + attn_channels=64, + ): + super().__init__() + + assert stride > 0 + + self.agg_func = agg_func + self.subsample_method = subsample_method + self.knn_samples = nsample + self.return_idx = return_idx + + self.stride, self.nsample = stride, nsample + + if fps_num_samples is not None: + self.nsample = fps_num_samples + + # if stride != 1: + # # xyz + feature + # # self.linear = nn.Linear(3 + in_planes, out_planes, bias=not post_norm) + # # only feature + # # TODO: attention aggregation + # if agg_func == 'maxpool': + # self.agg = nn.MaxPool1d(nsample) + # elif agg_func == 'avgpool': + # self.agg = nn.AvgPool1d(nsample) + # else: + # raise ValueError(f"unsupported agg_func {agg_func}") + + # fewer channels to save memory + assert agg_func in ['attn', 'avgpool'] + if self.agg_func == 'attn': + self.q = nn.Linear(in_planes, attn_channels, bias=False) + self.k = nn.Linear(in_planes, attn_channels, bias=False) + self.v = nn.Linear(in_planes, attn_channels, bias=False) + + self.proj = nn.Linear(attn_channels, out_planes, bias=True) + self.residual = nn.Linear(in_planes, out_planes, bias=True) + else: + self.proj = nn.Linear(in_planes, out_planes, bias=True) + + def forward(self, pxo): + p, x, o = pxo # (n, 3), (n, c), (b) + if self.stride != 1: + if self.subsample_method == 'density': + assert False # not well tested + n_o, count = [o[0].item() // self.stride], o[0].item() // self.stride + for i in range(1, o.shape[0]): + count += (o[i].item() - o[i - 1].item()) // self.stride + n_o.append(count) + n_o = torch.tensor(n_o, dtype=torch.int32, device=x.device) + + # [N, K, C+3] + x_k, _ = pointops.knn_query_and_group( + x.contiguous(), p, o, new_xyz=p, new_offset=o, nsample=self.knn_samples, with_xyz=True + ) + + p_r = x_k[:, :, 0:3] + density = torch.mean(torch.norm(p_r, dim=-1), dim=-1) # [N] + + 
# TODO: normalize the distance + weights = (density - density.min()) / (density.max() - density.min() + 1e-6) + # weights = density + + # weights = 1.0 / (density + 1e-6) # Inverse density weighting + + # to batch + lists = [weights[:o[0]]] + for i in range(o.shape[0] - 1): + lists.append(weights[o[i]:o[i+1]]) + + weights = torch.stack(lists, dim=0) # [B, N] + + weights = weights / weights.sum(dim=1, keepdim=True) # Normalize weights + + # Sample points based on weights + batch = n_o.shape[0] + num_samples = o[0].item() // self.stride + sampled_indices = torch.stack([ + torch.multinomial(weights[b], num_samples, replacement=False) + for b in range(batch) + ], dim=0) # (B, num_samples) + + idx = rearrange(sampled_indices, "b n -> (b n)") + + point_list = [p[:o[0]]] + for i in range(o.shape[0] - 1): + point_list.append(p[o[i]:o[i+1], :]) + + points = torch.stack(point_list, dim=0) # [B, N, 3] + + # Gather sampled points + sampled_points = torch.gather(points, 1, sampled_indices.unsqueeze(-1).expand(-1, -1, 3)) + + # print(sampled_points.shape) # [B, M, 3] + + sampled_points = rearrange(sampled_points, "b m c -> (b m) c") + + # average pooling + # TODO: try others + x = x_k.mean(dim=1) # [N, C] + x_list = [x[:o[0]]] + for i in range(o.shape[0] - 1): + x_list.append(x[o[i]:o[i+1], :]) + x = torch.stack(x_list, dim=0) # [B, N, C] + + # Gather sampled points + x = torch.gather(x, 1, sampled_indices.unsqueeze(-1).expand(-1, -1, x.size(-1))) + x = rearrange(x, "b n c -> (b n) c") + + # TODO: do we need to add residual to x here? 
+ # use the index to subsample the initial features + x = self.proj(x) + + p, o = sampled_points, n_o + elif self.subsample_method in ['fps', 'grid']: + n_o, count = [o[0].item() // self.stride], o[0].item() // self.stride + for i in range(1, o.shape[0]): + count += (o[i].item() - o[i - 1].item()) // self.stride + n_o.append(count) + n_o = torch.tensor(n_o, dtype=torch.int32, device=x.device) + + if self.subsample_method == 'fps': + idx = pointops.farthest_point_sampling(p, o, n_o) # (m) + else: + # uniform sampling: sanity check + # first reshape to V, H, W, then do grid sampling + # Generate grid indices + # TODO: grid sample in the image space + idx = torch.arange(0, p.size(0), self.stride).to(x.device) + + n_p = p[idx.long(), :] # (m, 3) + x_subsample = x[idx.long(), :] # [M, C] + if self.agg_func == 'attn': + x_q = self.q(x_subsample) # [M, C] + # [M, K, C] + x_k = self.k(x) # [N, C] + else: + x_k = x + + x_k, knn_idx = pointops.knn_query_and_group( + x_k, + p, + offset=o, + new_xyz=n_p, + new_offset=n_o, + nsample=self.nsample, + with_xyz=False, # remove xyz + ) + + if self.agg_func == 'attn': + x_v = self.v(x) + x_v, _ = pointops.knn_query_and_group( + x_v, + p, + offset=o, + new_xyz=n_p, + new_offset=n_o, + idx=knn_idx, + nsample=self.nsample, + with_xyz=False, # remove xyz + ) + + # attention + # x_q: [M, C], x_k: [M, K, C], x_v: [M, K, C] + scale_factor = x_q.shape[-1] ** -0.5 + + # [M, 1, K] + # no relative posenc + scores = torch.matmul(x_q.unsqueeze(1), x_k.permute(0, 2, 1)) * scale_factor + # [M, C] + x = torch.matmul(torch.softmax(scores, dim=2), x_v).squeeze(1) + + # if self.agg_func in ['avgpool', 'maxpool']: + # x = self.agg(x.transpose(1, 2).contiguous()).squeeze(-1) # (m, c) + # else: + # raise NotImplementedError + + # add residual to x here? 
class SubsampleBlock(nn.Module):
    """Pre-norm downsampling stage: LayerNorm -> FPSSubsample -> residual MLP.

    Args:
        in_channels: feature width of the incoming points.
        out_channels: feature width produced by the subsampling layer.
        stride: forwarded to ``FPSSubsample`` as ``stride``.
        knn_samples: forwarded to ``FPSSubsample`` as ``nsample``.
        post_norm: must be False; only the pre-norm layout is implemented.
        agg_func: forwarded to ``FPSSubsample``.
        subsample_method: forwarded to ``FPSSubsample``.
        return_idx: when True, ``forward`` also returns the index tensor
            reported by ``FPSSubsample``.
        fps_num_samples: forwarded to ``FPSSubsample``.
        attn_proj_channels: forwarded to ``FPSSubsample`` as ``attn_channels``.
            ``None`` (the default) keeps ``FPSSubsample``'s own default instead
            of overriding it with ``None``.
    """

    def __init__(self, in_channels, out_channels, stride=4, knn_samples=16, post_norm=False,
                 agg_func='attn',
                 subsample_method='fps',
                 return_idx=False,
                 fps_num_samples=None,
                 attn_proj_channels=None,
                 ):
        super().__init__()

        assert not post_norm

        self.return_idx = return_idx

        self.post_norm = post_norm
        self.norm1 = nn.LayerNorm(in_channels)

        fps_kwargs = dict(
            stride=stride,
            nsample=knn_samples,
            agg_func=agg_func,
            subsample_method=subsample_method,
            return_idx=return_idx,
            fps_num_samples=fps_num_samples,
        )
        # Only override the projection width when the caller actually chose one.
        # Passing None through would reach nn.Linear(in_planes, None) inside
        # FPSSubsample's attention aggregation and fail at construction time.
        if attn_proj_channels is not None:
            fps_kwargs['attn_channels'] = attn_proj_channels
        self.fps = FPSSubsample(in_channels, out_channels, **fps_kwargs)

        self.norm2 = nn.LayerNorm(out_channels)
        self.mlp = MLP(out_channels)

    def forward(self, pxo):
        """Subsample ``pxo = (p, x, o)`` and refine the features with a residual MLP.

        Returns ``[p, x, o]`` for the subsampled set, or ``([p, x, o], idx)``
        when the block was built with ``return_idx=True``.
        """
        # pre norm
        p, x, o = pxo
        x = self.norm1(x)

        idx = None
        result = self.fps([p, x, o])
        if self.return_idx:
            # FPSSubsample returns (pxo, idx) in this mode.
            result, idx = result

        p, x, o = result

        x = x + self.mlp(self.norm2(x))

        if self.return_idx:
            return [p, x, o], idx

        return [p, x, o]
class PlainPointTransformer(nn.Module):
    """Single-resolution stack of kNN ``TransformerBlock`` layers.

    Neighbour indices are queried once per call (or reused from a cache) and
    shared by all blocks, since the blocks do not move the points.  With
    ``with_mv_attn`` every kNN block is followed by a global multi-view
    attention block.

    Args:
        channels: feature width of every block.
        knn_samples: number of neighbours used for the kNN query.
        num_blocks: number of ``TransformerBlock`` layers (and of multi-view
            blocks when those are enabled).
        post_norm, no_rpe, no_attn, no_norm, act, qk_norm, norm_pt_block,
        num_heads, attn_proj_channels, use_fused_attn: forwarded unchanged to
            every ``TransformerBlock``.
        cache_knn_idx: optional precomputed neighbour indices, used until the
            next refresh.
        knn_idx_update_every: the cached indices are recomputed whenever
            ``iter`` is a multiple of this value.
        with_mv_attn: interleave multi-view attention blocks.
        with_mv_attn_lowres: build ``MultViewLowresAttn`` blocks instead of
            ``MultiViewBlock``.
        with_pos_enc: only accepted together with ``mv_shuffle_attn`` (asserted).
        use_checkpointing: unsupported on the multi-view path (``forward``
            raises ``NotImplementedError``).
        init_use_checkpointing: wrap each block of the plain path in
            ``torch.utils.checkpoint``.
        mv_attn_first, no_mv_attn, conv_with_norm, mv_shuffle_attn,
        shuffle_attn_no_norm, mv_unimatch_attn: accepted so existing configs
            keep working; apart from the assertion above this class does not
            read them.
    """

    def __init__(self, channels, knn_samples=16, num_blocks=4, post_norm=False,
                 no_rpe=False,
                 no_attn=False,
                 no_norm=False,
                 act="gelu",
                 qk_norm=False,
                 norm_pt_block=False,
                 num_heads=1,
                 attn_proj_channels=None,
                 cache_knn_idx=None,
                 knn_idx_update_every=1,
                 with_mv_attn=False,
                 with_mv_attn_lowres=False,
                 mv_attn_first=False,
                 no_mv_attn=False,
                 conv_with_norm=False,
                 mv_shuffle_attn=False,
                 with_pos_enc=False,
                 shuffle_attn_no_norm=False,
                 mv_unimatch_attn=False,
                 use_checkpointing=False,
                 init_use_checkpointing=False,
                 use_fused_attn=False,
                 ):
        super().__init__()

        self.cache_knn_idx = cache_knn_idx
        self.knn_idx_update_every = knn_idx_update_every
        self.knn_samples = knn_samples
        self.use_checkpointing = use_checkpointing
        self.init_use_checkpointing = init_use_checkpointing

        self.with_mv_attn = with_mv_attn
        self.with_mv_attn_lowres = with_mv_attn_lowres
        if with_pos_enc:
            assert mv_shuffle_attn

        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(TransformerBlock(channels, knn_samples=knn_samples,
                                                post_norm=post_norm,
                                                no_rpe=no_rpe,
                                                no_attn=no_attn,
                                                no_norm=no_norm,
                                                act=act,
                                                qk_norm=qk_norm,
                                                norm_pt_block=norm_pt_block,
                                                num_heads=num_heads,
                                                attn_proj_channels=attn_proj_channels,
                                                use_fused_attn=use_fused_attn,
                                                ))

        # multi-view attention
        # (MultViewUniMatchAttn / MultViewUnetAttn were earlier alternatives
        # for these blocks and are currently not selectable from here.)
        if self.with_mv_attn:
            self.mv_blocks = nn.ModuleList()
            for _ in range(num_blocks):
                if self.with_mv_attn_lowres:
                    self.mv_blocks.append(
                        MultViewLowresAttn(
                            channels,
                        )
                    )
                else:
                    self.mv_blocks.append(
                        MultiViewBlock(
                            channels,
                            num_heads=4,
                        )
                    )

    def forward(self, pxo, iter=0, b=None, v=None, h=None, w=None):
        """Run all blocks over ``pxo = (p, x, o)`` and return the updated features ``x``.

        Args:
            pxo: point positions, point features and batch offsets.
            iter: caller-side iteration counter; decides when the cached kNN
                indices are refreshed.
            b, v, h, w: batch / view / height / width used to fold the flat
                point list into ``b (v h w) c`` for multi-view attention.
                Required (asserted) when ``with_mv_attn`` is set.
        """
        p, x, o = pxo
        # compute knn idx here only once and pass it to the model
        # the positions are not changed inside the blocks
        refresh_knn = self.cache_knn_idx is None or (iter % self.knn_idx_update_every) == 0
        if refresh_knn:
            knn_idx, _ = pointops.knn_query(self.knn_samples, p, o, p, o)
            self.cache_knn_idx = knn_idx
        else:
            knn_idx = self.cache_knn_idx

        if self.with_mv_attn:
            assert b is not None and v is not None and h is not None and w is not None
            if self.use_checkpointing:
                raise NotImplementedError

            for i in range(len(self.blocks)):
                # knn attention
                x = self.blocks[i]([p, x, o], knn_idx=knn_idx)
                # global multi-view attention
                x = rearrange(x, "(b v h w) c -> b (v h w) c", b=b, v=v, h=h, w=w)
                if self.with_mv_attn_lowres:
                    x = self.mv_blocks[i](x, v=v, h=h, w=w)
                else:
                    x = self.mv_blocks[i](x)
                x = rearrange(x, "b (v h w) c -> (b v h w) c",
                              b=b, v=v, h=h, w=w)
        else:
            for blk in self.blocks:
                if self.init_use_checkpointing:
                    # checkpointing the initial reconstruction model
                    # NOTE: cannot cache knn_idx here, otherwise index out error,
                    # so knn_idx=None is passed (presumably the block then runs
                    # its own query -- confirm in TransformerBlock).
                    #
                    # `blk` is bound through a default argument on purpose:
                    # checkpoint calls this function again during backward,
                    # after the loop has finished.  A plain closure would then
                    # see the *last* block for every checkpointed segment.
                    def custom_forward(p, x, o, _blk=blk):
                        return _blk((p, x, o), knn_idx=None)
                    x = torch.utils.checkpoint.checkpoint(custom_forward, p, x, o)
                else:
                    x = blk((p, x, o), knn_idx=knn_idx)

        return x
self.down2 = nn.Conv2d(channels, channels, 3, 2, 1) + + self.up2 = nn.Conv2d(channels, channels, 3, 1, 1) + self.up1 = nn.Conv2d(channels, channels, 3, 1, 1) + + self.attn = MultiViewBlock(channels, 4, no_attn=no_mv_attn) + + if self.conv_with_norm: + self.norm1 = nn.LayerNorm(channels) + self.norm2 = nn.LayerNorm(channels) + self.norm3 = nn.LayerNorm(channels) + self.norm4 = nn.LayerNorm(channels) + + def forward(self, x): + v = 8 + h = 256 // 4 + w = 448 // 4 + b = 1 + assert x.size(0) == b * v * h * w + residual = x + x = rearrange(x, "(b v h w) c -> (b v) c h w", b=b, v=v, h=h, w=w) + x1 = self.down1(x) # 1/2 + if self.conv_with_norm: + x1 = self.norm1(x1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x2 = self.down2(x1) # 1/4 + if self.conv_with_norm: + x2 = self.norm2(x2.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x2 = rearrange(x2, "(b v) c h w -> b (v h w) c", b=b, v=v) + x2 = self.attn(x2) # 1/4 + x2 = rearrange(x2, "b (v h w) c -> (b v) c h w", b=b, v=v, h=h//4, w=w//4) + x2 = self.up2(x1 + F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=True)) # 1/2 + if self.conv_with_norm: + x2 = self.norm3(x2.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x = self.up1(x + F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=True)) # 1 + if self.conv_with_norm: + x = self.norm4(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + x = rearrange(x, "(b v) c h w -> (b v h w) c", b=b, v=v) + + x = residual + x + + return x + + +class MultViewShuffleAttn(nn.Module): + def __init__(self, channels, no_mv_attn=False, with_pos_enc=False, shuffle_attn_no_norm=False): + super().__init__() + + self.down_factor = 4 + self.with_pos_enc = with_pos_enc + + self.proj1 = nn.Linear(channels * self.down_factor ** 2, channels) + if shuffle_attn_no_norm: + self.norm1 = nn.Identity() + else: + self.norm1 = nn.LayerNorm(channels) + + self.proj2 = nn.Linear(channels, channels * self.down_factor ** 2) + + if shuffle_attn_no_norm: + self.norm2 = nn.Identity() + else: + self.norm2 = 
class MultViewLowresAttn(nn.Module):
    """Multi-view attention evaluated on a spatially reduced token grid.

    Tokens ``b (v h w) c`` are folded back into per-view images, reduced with
    ``pixel_unshuffle`` (preceded by a bilinear 0.5x resize when
    ``down_factor == 8``), attended over, and restored with the mirrored
    ``pixel_shuffle`` / conv / resize path.  The input is added back as a
    residual.

    Args:
        channels: token width of the input and output.
        no_mv_attn: replace the attention block with ``nn.Identity``.
        with_pos_enc: add ``mv_feature_add_position`` to the query features.
        shuffle_attn_no_norm: replace both LayerNorms with ``nn.Identity``.
        down_factor: total spatial reduction; 8 is realised as a bilinear 0.5x
            resize followed by a pixel-unshuffle of 4.
        attn_proj_channels: if set, tokens are projected to this width before
            attention and back to ``channels`` afterwards.
    """

    def __init__(self, channels, no_mv_attn=False, with_pos_enc=False, shuffle_attn_no_norm=False,
                 down_factor=4,
                 attn_proj_channels=None,
                 ):
        super().__init__()

        self.down_factor = down_factor
        self.with_pos_enc = with_pos_enc

        self.attn_proj_channels = attn_proj_channels

        if attn_proj_channels:
            ori_channels = channels
            self.proj0 = nn.Linear(channels, attn_proj_channels)
            channels = attn_proj_channels

        shuffle = self._shuffle_factor()

        self.proj1 = nn.Linear(channels * shuffle ** 2, channels)
        self.norm1 = nn.Identity() if shuffle_attn_no_norm else nn.LayerNorm(channels)

        self.proj2 = nn.Linear(channels, channels * shuffle ** 2)
        self.norm2 = nn.Identity() if shuffle_attn_no_norm else nn.LayerNorm(channels * shuffle ** 2)

        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)

        if attn_proj_channels:
            self.proj3 = nn.Linear(channels, ori_channels)

        if no_mv_attn:
            self.attn = nn.Identity()
        else:
            num_heads = 1 if self.attn_proj_channels else 4
            self.attn = MultiViewBlock(channels, num_heads, no_attn=no_mv_attn)

    def _shuffle_factor(self):
        """Pixel-(un)shuffle factor; a total factor of 8 is split into resize 0.5x + shuffle 4."""
        return 4 if self.down_factor == 8 else self.down_factor

    def _to_tokens(self, feat, v, h, w, add_pos):
        """Project, fold to images, reduce spatially and return low-res tokens ``b (v h' w') c``."""
        if self.attn_proj_channels:
            feat = self.proj0(feat)

        feat = rearrange(feat, "b (v h w) c -> (b v) c h w", v=v, h=h, w=w)

        # TODO: add positional encoding to x
        if add_pos:
            feat = mv_feature_add_position(feat, attn_splits=1, feature_channels=feat.size(1))

        if self.down_factor == 8:
            # bilinear to half first to save channels
            feat = F.interpolate(feat, scale_factor=0.5, mode='bilinear', align_corners=True)

        feat = F.pixel_unshuffle(feat, self._shuffle_factor())
        feat = rearrange(feat, "(b v) c h w -> b (v h w) c", v=v)
        return self.norm1(self.proj1(feat))

    def _from_tokens(self, x, v, h, w):
        """Inverse of ``_to_tokens``: restore full-resolution tokens of the original width."""
        x = self.norm2(self.proj2(x))

        x = rearrange(x, "b (v h w) c -> (b v) c h w", v=v, h=h // self.down_factor, w=w // self.down_factor)
        x = F.pixel_shuffle(x, self._shuffle_factor())
        x = self.conv(x)
        if self.down_factor == 8:
            # bilinear to full
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = rearrange(x, "(b v) c h w -> b (v h w) c", v=v)
        if self.attn_proj_channels:
            x = self.proj3(x)
        return x

    def forward(self, x, v=None, h=None, w=None, y=None):
        """Self-attention over ``x`` (``b (v h w) c``); dispatches to cross-attention when ``y`` is given."""
        if y is not None:
            return self.forward_cross_attn(x, y, v, h, w)
        residual = x

        x = self._to_tokens(x, v, h, w, add_pos=self.with_pos_enc)
        x = self.attn(x)
        x = self._from_tokens(x, v, h, w)

        return x + residual

    def forward_cross_attn(self, x, y, v=None, h=None, w=None):
        """Cross-attention: queries from ``x`` (``v`` views), keys/values from ``y``.

        ``y`` shares ``h`` and ``w`` with ``x`` but may hold a different number
        of views, which is derived from its token count.
        """
        assert y is not None
        residual = x

        # different v with x
        num_cross_view = y.shape[1] // (h * w)

        x = self._to_tokens(x, v, h, w, add_pos=self.with_pos_enc)
        # y now goes through the same proj0 as x.  Previously it skipped the
        # projection, so proj1 received the wrong channel count whenever
        # attn_proj_channels was set.  Positional encoding stays query-only,
        # as before.
        y = self._to_tokens(y, num_cross_view, h, w, add_pos=False)

        if isinstance(self.attn, nn.Identity):
            # no_mv_attn: nn.Identity accepts a single input, so there is
            # nothing to attend to and x passes through unchanged.
            x = self.attn(x)
        else:
            # there will be slight diff for self and cross attn caused by flash3
            x = self.attn(x, y)

        x = self._from_tokens(x, v, h, w)

        return x + residual
class MultViewUniMatchAttn(nn.Module):
    """Multi-view attention backed by a one-layer ``MultiViewFeatureTransformer``.

    ``no_mv_attn``, ``with_pos_enc`` and ``shuffle_attn_no_norm`` are accepted
    but not read by this class.
    """

    def __init__(self, channels, no_mv_attn=False, with_pos_enc=False, shuffle_attn_no_norm=False):
        super().__init__()

        self.attn = MultiViewFeatureTransformer(num_layers=1, d_model=channels)

    def forward(self, x, v=None, h=None, w=None):
        """Attend across the ``v`` views of ``x`` (``b (v h w) c``) and return tokens of the same layout.

        No residual is added here; the transformer output is returned directly.
        """
        num_splits = 4

        # fold tokens into per-view feature maps and add the positional encoding
        maps = rearrange(x, "b (v h w) c -> (b v) c h w", v=v, h=h, w=w)
        maps = mv_feature_add_position(maps, num_splits, feature_channels=maps.size(1))

        # one feature map per view, as a Python list
        per_view = list(torch.unbind(rearrange(maps, "(b v) c h w -> b v c h w", v=v), dim=1))
        per_view = self.attn(per_view, num_splits)

        return rearrange(torch.stack(per_view, dim=1), "b v c h w -> b (v h w) c")
self.skip_blocks = nn.ModuleList() + for i in range(num_scales - 1, 0, -1): + self.skip_blocks.append( + SkipConnect( + channels * (2 ** i), + channels * (2 ** (i - 1)) + ) + ) + + def forward(self, pxo): + x1 = self.blocks[0](pxo) # 1 + p1, o1 = pxo[0], pxo[2] + p2, x2, o2 = self.down_blocks[0]([p1, x1, o1]) # 1/4 + x2 = self.down_agg[0]([p2, x2, o2]) # 1/4 + p3, x3, o3 = self.down_blocks[1]([p2, x2, o2]) # 1/16 + x3 = self.down_agg[1]([p3, x3, o3]) # 1/16 + + x4 = self.skip_blocks[0]([p2, x2, o2], [p3, x3, o3]) # 1/4 + p4, o4 = p2, o2 + x4 = self.blocks[1]([p4, x4, o4]) + x5 = self.skip_blocks[1]([p1, x1, o1], [p4, x4, o4]) # 1 + p5, o5 = p1, o1 + x5 = self.blocks[2]([p5, x5, o5]) + + return x5 + + +class PointLinearWrapper(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.linear = nn.Linear(in_channels, out_channels) + + def forward(self, pxo, b=None, v=None, h=None, w=None): + p, x, o = pxo + x = self.linear(x) + + return [p, x, o] + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x): + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +def test_fps(): + model = FPSSubsample(256, 256, + fps_num_samples=16, + subsample_method='fps', + ).cuda() + print(model) + + # FPS is significantly slower than grid with many points + + c = 256 + b, n = 2, 40480 + + x = torch.randn(b, n, c).cuda() + offset = torch.tensor([n * (i + 1) for i in range(b)]).to(x.device) + p = torch.randn(b, n, 3).cuda() + pxo = [p.view(-1, 3), x.view(-1, c), offset] + y = model(pxo) + print(y[1].shape) + 
+ count = 100 + + for _ in range(5): + model(pxo) + + torch.cuda.synchronize() + start = time.time() + + for i in range(count): + model(pxo) + + torch.cuda.synchronize() + print(time.time() - start) + +def test_knn_query_and_group(): + c = 256 + # b, n = 2, 80480 + b, n = 8, 57344 + knn_samples = 16 + + x = torch.randn(b, n, c).cuda() + offset = torch.tensor([n * (i + 1) for i in range(b)]).to(x.device) + o = offset + p = torch.randn(b, n, 3).cuda() + p = p.view(-1, 3) + + knn_idx, _ = pointops.knn_query(knn_samples, p, o, p, o) + + print(knn_idx.shape) + + c_qkv = 192 + qkv = torch.randn(b*n, c_qkv).cuda() + T = 1000 + + # chunk first, then query twice + torch.cuda.synchronize() + start_time = time.time() + for _ in range(T): + x_q, x_k, x_v = torch.chunk(qkv, chunks=3, dim=-1) + x_k_query, idx = pointops.knn_query_and_group( + x_k.contiguous(), p, o, new_xyz=p, new_offset=o, + idx=knn_idx, + nsample=knn_samples, with_xyz=False + ) # [N, K, C/3] + x_v_query, _ = pointops.knn_query_and_group( + x_v.contiguous(), + p, + o, + new_xyz=p, + new_offset=o, + idx=idx, + nsample=knn_samples, + with_xyz=False, + ) + torch.cuda.synchronize() + end_time = time.time() + print(f"KNN query and group time: {(end_time - start_time) / T * 1000:.2f} ms") + + # query first, then chunk + torch.cuda.synchronize() + start_time = time.time() + for _ in range(T): + x_qkv_query = pointops.knn_query_and_group( + qkv.contiguous(), p, o, new_xyz=p, new_offset=o, + idx=knn_idx, + nsample=knn_samples, with_xyz=False + )[0] # [N, K, C*3] + x_q, x_k, x_v = torch.chunk(x_qkv_query, chunks=3, dim=-1) + torch.cuda.synchronize() + end_time = time.time() + print(f"KNN query and group time: {(end_time - start_time) / T * 1000:.2f} ms") + + # chunk first, then query once + torch.cuda.synchronize() + start_time = time.time() + for _ in range(T): + x_q, x_k, x_v = torch.chunk(qkv, chunks=3, dim=-1) + x_kv = torch.cat([x_k, x_v], dim=-1) # [N, 2C/3] + x_kv_query = pointops.knn_query_and_group( + 
def count_parameters(model):
    """Return the total element count of the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def test_cross_attn():
    """Smoke test: run ``GaussianErrorCrossAttn`` once on random CUDA tensors and print the shapes."""
    gaussian_channels = 512
    error_channels = 256
    v, h, w = 8, 64, 128
    num_tokens = v * h * w

    # the module is built first, then the inputs (same order as the draws from the RNG before)
    model = GaussianErrorCrossAttn(gaussian_channels, error_channels, error_channels).cuda()
    gaussian = torch.rand(2, num_tokens, gaussian_channels).cuda()
    error = torch.rand(2, num_tokens, error_channels).cuda()

    print(model)

    out = model(gaussian, error, v=v, h=h, w=w)

    print(gaussian.shape, out.shape)
class ResidualBlock(nn.Module):
    """Two 3x3 conv + norm + ReLU layers with a residual shortcut.

    When the stride is not 1 or the channel count changes, the shortcut is a
    1x1 convolution followed by a norm layer; otherwise the input is added
    as-is.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        norm_layer: callable building a norm layer from a channel count.
        stride: stride of the first convolution and of the 1x1 shortcut conv.
        dilation: dilation (and padding) of both 3x3 convolutions.
    """

    def __init__(
        self,
        in_planes,
        planes,
        norm_layer=nn.InstanceNorm2d,
        stride=1,
        dilation=1,
    ):
        super().__init__()

        # settings shared by both 3x3 convolutions
        conv3x3 = dict(kernel_size=3, dilation=dilation, padding=dilation, bias=False)
        self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv3x3)
        self.conv2 = nn.Conv2d(planes, planes, **conv3x3)
        self.relu = nn.ReLU(inplace=True)

        self.norm1 = norm_layer(planes)
        self.norm2 = norm_layer(planes)

        shape_preserved = stride == 1 and in_planes == planes
        if shape_preserved:
            self.downsample = None
        else:
            # norm3 is deliberately registered both as an attribute and inside
            # the Sequential, matching the existing state_dict layout.
            self.norm3 = norm_layer(planes)
            shortcut_conv = nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride)
            self.downsample = nn.Sequential(shortcut_conv, self.norm3)

    def forward(self, x):
        """Return ``relu(shortcut(x) + branch(x))``."""
        branch = self.relu(self.norm1(self.conv1(x)))
        branch = self.relu(self.norm2(self.conv2(branch)))

        shortcut = x if self.downsample is None else self.downsample(x)

        return self.relu(shortcut + branch)
forward(self, x): + output_all_scales = [] + output = [] + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + + if self.return_all_scales: + output_all_scales.append(x) + + if self.num_scales >= 3: + output.append(x) + + x = self.layer2(x) # 1/2 or 1/4 + if self.return_quarter: + output.append(x) + + if self.return_all_scales: + output_all_scales.append(x) + + if self.num_scales >= 2: + output.append(x) + + x = self.layer3(x) # 1/4 or 1/8 + x = self.conv2(x) + + if self.return_all_scales: + output_all_scales.append(x) + + if self.return_all_scales: + return output_all_scales + + if self.return_quarter: + output.append(x) + return output + + if self.num_scales >= 1: + output.append(x) + return output + + out = [x] + + return out diff --git a/optgs/model/encoder/unimatch/dinov2/__init__.py b/optgs/model/encoder/unimatch/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/model/encoder/unimatch/dinov2/dinov2.py b/optgs/model/encoder/unimatch/dinov2/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..fd2b668158f1c4a30ee8c69bb21fe2749e32b431 --- /dev/null +++ b/optgs/model/encoder/unimatch/dinov2/dinov2.py @@ -0,0 +1,435 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + use_checkpointing=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio 
of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + self.use_checkpointing = use_checkpointing + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() 
for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // 
self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, 
self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.use_checkpointing: + x = torch.utils.checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + if self.use_checkpointing: + x = torch.utils.checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + if self.use_checkpointing: + x = torch.utils.checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ): + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, use_checkpointing=False, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4., + block_fn=partial(Block, 
attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + use_checkpointing=use_checkpointing, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, use_checkpointing=False, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + use_checkpointing=use_checkpointing, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, use_checkpointing=False, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4., + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + use_checkpointing=use_checkpointing, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, use_checkpointing=False, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4., + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + use_checkpointing=use_checkpointing, + **kwargs, + ) + return model + + +def DINOv2(model_name, + use_checkpointing=False, + ): + model_zoo = { + "vits": vit_small, + "vitb": vit_base, + "vitl": vit_large, + "vitg": vit_giant2 + } + + return model_zoo[model_name]( + img_size=518, + patch_size=14, + init_values=1.0, + ffn_layer="mlp" if model_name != "vitg" else "swiglufused", + block_chunks=0, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + use_checkpointing=use_checkpointing, + ) + \ No newline at end of file diff --git a/optgs/model/encoder/unimatch/dinov2/layers/__init__.py b/optgs/model/encoder/unimatch/dinov2/layers/__init__.py new file mode 
100644 index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e --- /dev/null +++ b/optgs/model/encoder/unimatch/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/optgs/model/encoder/unimatch/dinov2/layers/attention.py b/optgs/model/encoder/unimatch/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fee19a5accb9784fcf7651238bcb90ce20e192 --- /dev/null +++ b/optgs/model/encoder/unimatch/dinov2/layers/attention.py @@ -0,0 +1,185 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings +import torch + +from torch import Tensor +from torch import nn + +import torch.nn.functional as F + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Attention)") + else: + warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + warnings.warn("xFormers is not available (Attention)") + + +USE_FLASH_ATTENTION3 = True +try: + from flash_attn_interface import flash_attn_func + FA3_AVAILABLE = True + warnings.warn('flash attention 3 is available (ViT)') +except ImportError: + FA3_AVAILABLE = False + warnings.warn('flash attention 3 is not available (ViT)') + + +USE_PYTORCH_ATTN = True # flash attention 2 + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + if USE_PYTORCH_ATTN: + q, k, v = qkv[0], qkv[1], qkv[2] + out = F.scaled_dot_product_attention(q, k, v) + x = out.permute(0, 2, 1, 3).reshape(B, N, C) + else: + q, k, v = qkv[0] * self.scale, 
qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None, y=None) -> Tensor: + if not XFORMERS_AVAILABLE and not (USE_FLASH_ATTENTION3 and FA3_AVAILABLE): + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + if y is not None: + # print('cross attn') + + B, Nq, C = x.shape + context = x if y is None else y + Nk = context.shape[1] + + # Lazy conversion from qkv to q_proj and kv_proj + if not hasattr(self, 'q_proj') or self.q_proj is None: + self.q_proj, self.kv_proj = convert_qkv_to_q_and_kv_proj(self.qkv) + # del self.qkv # Optional: free memory + + # Project q, k, v + q = self.q_proj(x).reshape(B, Nq, self.num_heads, C // self.num_heads) # [B, N, H, d] + kv = self.kv_proj(context).reshape(B, Nk, 2, self.num_heads, C // self.num_heads) # [B, Nk, 2, H, d] + k, v = kv.unbind(dim=2) # [B, h, Nk, d] + + # FlashAttention3 or memory-efficient fallback + if USE_FLASH_ATTENTION3 and FA3_AVAILABLE: + if attn_bias is not None: + raise AssertionError("attn_bias is not supported in FA3") + out = flash_attn_func(q, k, v)[0] # [B, h, Nq, d] + else: + out = memory_efficient_attention(q, k, v, attn_bias=attn_bias) # [B, h, Nq, d] + + # Merge heads + out = out.reshape(B, Nq, C) # [B, Nq, C] + out = self.proj(out) + out = self.proj_drop(out) + return out + + else: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = torch.unbind(qkv, 2) + + if USE_FLASH_ATTENTION3 and FA3_AVAILABLE: + if attn_bias is not None: + raise AssertionError("attn_bias is not supported in FA3") + x = flash_attn_func(q, k, v)[0] + else: + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + + x = x.reshape([B, N, C]) + + x = 
self.proj(x) + x = self.proj_drop(x) + return x + + + +def convert_qkv_to_q_and_kv_proj(qkv_layer: nn.Linear): + """ + Convert a self-attention qkv projection layer (dim -> 3*dim) into + separate q_proj (dim -> dim) and kv_proj (dim -> 2*dim) layers. + + Returns: + q_proj (nn.Linear): projection for query + kv_proj (nn.Linear): projection for key and value + """ + assert isinstance(qkv_layer, nn.Linear), "Expected nn.Linear for qkv_layer" + in_features = qkv_layer.in_features + out_features = qkv_layer.out_features + assert out_features % 3 == 0, "Output features must be divisible by 3" + + dim = out_features // 3 + device = qkv_layer.weight.device + dtype = qkv_layer.weight.dtype + + q_proj = nn.Linear(in_features, dim, bias=qkv_layer.bias is not None).to(device=device, dtype=dtype) + kv_proj = nn.Linear(in_features, dim * 2, bias=qkv_layer.bias is not None).to(device=device, dtype=dtype) + + # Split weights and biases + q_weight, k_weight, v_weight = qkv_layer.weight.chunk(3, dim=0) + q_proj.weight.data.copy_(q_weight) + kv_proj.weight.data.copy_(torch.cat([k_weight, v_weight], dim=0)) + + if qkv_layer.bias is not None: + q_bias, k_bias, v_bias = qkv_layer.bias.chunk(3, dim=0) + q_proj.bias.data.copy_(q_bias) + kv_proj.bias.data.copy_(torch.cat([k_bias, v_bias], dim=0)) + + return q_proj, kv_proj diff --git a/optgs/model/encoder/unimatch/dinov2/layers/block.py b/optgs/model/encoder/unimatch/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..2246ef13ac68b176a864df4638a913e8a6596451 --- /dev/null +++ b/optgs/model/encoder/unimatch/dinov2/layers/block.py @@ -0,0 +1,264 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (Block)") + else: + warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = MemEffAttention, + ffn_layer: Callable[..., nn.Module] = Mlp, + no_attn: bool = False + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + if no_attn: + self.attn = nn.Identity() + else: + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + 
mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, y: Tensor | None = None) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x), y=y)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return 
x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for 
dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = 
get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/optgs/model/encoder/unimatch/dinov2/layers/dino_head.py b/optgs/model/encoder/unimatch/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/optgs/model/encoder/unimatch/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, 
in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/optgs/model/encoder/unimatch/dinov2/layers/drop_path.py b/optgs/model/encoder/unimatch/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/optgs/model/encoder/unimatch/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/optgs/model/encoder/unimatch/dinov2/layers/layer_scale.py b/optgs/model/encoder/unimatch/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/optgs/model/encoder/unimatch/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/optgs/model/encoder/unimatch/dinov2/layers/mlp.py b/optgs/model/encoder/unimatch/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/optgs/model/encoder/unimatch/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/optgs/model/encoder/unimatch/dinov2/layers/patch_embed.py b/optgs/model/encoder/unimatch/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/optgs/model/encoder/unimatch/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = 
self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/optgs/model/encoder/unimatch/dinov2/layers/swiglu_ffn.py b/optgs/model/encoder/unimatch/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9dafa4592a408f6874d54853e8f60db5c41f74 --- /dev/null +++ b/optgs/model/encoder/unimatch/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + warnings.warn("xFormers is available (SwiGLU)") + else: + warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + 
XFORMERS_AVAILABLE = False + + warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/optgs/model/encoder/unimatch/dpt_head.py b/optgs/model/encoder/unimatch/dpt_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9f9be87836a4624e0e78bd44ee7f54ef9369e2 --- /dev/null +++ b/optgs/model/encoder/unimatch/dpt_head.py @@ -0,0 +1,623 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + return scratch + + +class ResidualConvUnit(nn.Module): 
+ """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False, + concat_cnn_features=True, + concat_mv_features=True, + cnn_feature_channels=[64, 96, 128], + concat_features=True, + downsample_factor=8, + return_feature=False, + num_scales=1, + latent_downsample=None, + latent_feature_no_concat=False, + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.concat_cnn_features = concat_cnn_features + self.concat_mv_features = concat_mv_features + self.concat_features = concat_features + self.downsample_factor = downsample_factor + self.return_feature = return_feature + self.num_scales = num_scales + self.latent_downsample = latent_downsample + self.latent_feature_no_concat = latent_feature_no_concat + + if self.concat_features: + if self.downsample_factor == 4 and num_scales == 2: + depth_channel = 0 if self.return_feature else 1 + self.concat_projects = nn.ModuleList( + [ + nn.Conv2d( + cnn_feature_channels[0] + out_channels[0], + out_channels[0], + 1, + ), + nn.Conv2d( + cnn_feature_channels[1] + + out_channels[1] + + 64 + + depth_channel, + out_channels[1], + 1, + ), # 1/4 concat(cnn, mono, mv, depth) + nn.Conv2d( + 
cnn_feature_channels[2] + out_channels[2] + 128, + out_channels[2], + 1, + ), # 1/8 concat(cnn, mono, mv) + ] + ) + elif self.downsample_factor == 2 and num_scales == 2: + depth_channel = 0 if self.return_feature else 1 + self.concat_projects = nn.ModuleList( + [ + nn.Conv2d( + cnn_feature_channels[0] + + cnn_feature_channels[1] + + out_channels[0] + + 64 + + depth_channel, + out_channels[0], + 1, + ), # 1/2 + nn.Conv2d( + cnn_feature_channels[2] + out_channels[1] + 128, + out_channels[1], + 1, + ), # 1/4 concat(cnn, mono, mv, depth) + nn.Conv2d(out_channels[2], out_channels[2], 1), # 1/8 mono + ] + ) + elif self.downsample_factor == 4 and num_scales == 1: + depth_channel = 0 if self.return_feature else 1 + self.concat_projects = nn.ModuleList( + [ + nn.Conv2d( + cnn_feature_channels[0] + + cnn_feature_channels[1] + + out_channels[0], + out_channels[0], + 1, + ), + nn.Conv2d( + cnn_feature_channels[2] + + out_channels[1] + + 128 + + depth_channel, + out_channels[1], + 1, + ), + nn.Conv2d(out_channels[2], out_channels[2], 1), # 1/8 mono + ] + ) + else: + depth_channel = 0 if self.return_feature else 1 + self.concat_projects = nn.ModuleList( + [ + nn.Conv2d( + cnn_feature_channels[0] + out_channels[0], + out_channels[0], + 1, + ), + nn.Conv2d( + cnn_feature_channels[1] + out_channels[1], + out_channels[1], + 1, + ), + nn.Conv2d( + cnn_feature_channels[2] + + out_channels[2] + + 128 + + depth_channel, + out_channels[2], + 1, + ), # 1/8 concat(cnn, mono, mv, depth) + ] + ) + else: + if self.concat_cnn_features: + self.cnn_projects = nn.ModuleList( + [ + nn.Conv2d(cnn_feature_channels[i], out_channels[i], 1) + for i in range(len(cnn_feature_channels)) + ] + ) + + if self.concat_mv_features: + self.mv_projects = nn.Conv2d(128, out_channels[2], 1) + + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) + for out_channel in out_channels + ] + ) + + self.resize_layers = 
nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0, + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1, + ), + ] + ) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()) + ) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + if not self.latent_feature_no_concat: + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + + if self.latent_downsample != 8: + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + # not used + del self.scratch.refinenet4.resConfUnit1 + + head_features_1 = features + head_features_2 = 16 + + if not self.return_feature: + self.scratch.output_conv = nn.Sequential( + nn.Conv2d( + head_features_1, + head_features_1 // 2, + 3, + 1, + 1, + padding_mode="replicate", + ), + nn.GELU(), + nn.Conv2d( + head_features_1 // 2, + head_features_2, + kernel_size=3, + stride=1, + padding=1, + padding_mode="replicate", + ), + nn.GELU(), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + ) + + # init delta depth as zero + nn.init.zeros_(self.scratch.output_conv[-1].weight) + nn.init.zeros_(self.scratch.output_conv[-1].bias) + + def forward( + self, + out_features, + downsample_factor=8, + cnn_features=None, + mv_features=None, + depth=None, + ): + out = [] + for i, x in enumerate(out_features): + x = self.projects[i](x) + x = self.resize_layers[i](x) 
+ + out.append(x) + + # 1/2, 1/4, 1/8, 1/16 + layer_1, layer_2, layer_3, layer_4 = out + + if self.concat_features: + if not self.return_feature: + assert depth is not None + + if self.downsample_factor == 4 and self.num_scales == 1: + concat1 = torch.cat((cnn_features[0], cnn_features[1], layer_1), dim=1) + elif self.downsample_factor == 2 and self.num_scales == 2: + if self.return_feature: + concat1 = torch.cat( + (cnn_features[0], cnn_features[1], mv_features[0], layer_1), + dim=1, + ) + else: + concat1 = torch.cat( + ( + cnn_features[0], + cnn_features[1], + mv_features[0], + depth, + layer_1, + ), + dim=1, + ) + else: + concat1 = torch.cat((cnn_features[0], layer_1), dim=1) + layer_1 = self.concat_projects[0](concat1) # 1/2 + + if self.downsample_factor == 4 and self.num_scales == 2: + assert isinstance(mv_features, list) + if self.return_feature: + concat2 = torch.cat( + (cnn_features[1], layer_2, mv_features[0]), dim=1 + ) + else: + concat2 = torch.cat( + (cnn_features[1], layer_2, mv_features[0], depth), dim=1 + ) + layer_2 = self.concat_projects[1](concat2) # 1/4 + + concat3 = torch.cat((cnn_features[2], layer_3, mv_features[1]), dim=1) + layer_3 = self.concat_projects[2](concat3) # 1/8 + elif self.downsample_factor == 2 and self.num_scales == 2: + assert isinstance(mv_features, list) + concat2 = torch.cat((cnn_features[2], layer_2, mv_features[1]), dim=1) + layer_2 = self.concat_projects[1](concat2) # 1/4 + + concat3 = layer_3 + layer_3 = self.concat_projects[2](concat3) # 1/8 + elif self.downsample_factor == 4 and self.num_scales == 1: + if self.return_feature: + concat2 = torch.cat((cnn_features[2], layer_2, mv_features), dim=1) + else: + concat2 = torch.cat( + (cnn_features[2], layer_2, mv_features, depth), dim=1 + ) + layer_2 = self.concat_projects[1](concat2) # 1/4 + + concat3 = layer_3 + layer_3 = self.concat_projects[2](concat3) # 1/8 + else: + concat2 = torch.cat((cnn_features[1], layer_2), dim=1) + layer_2 = self.concat_projects[1](concat2) # 1/4 
+ + if self.return_feature: + concat3 = torch.cat((cnn_features[2], layer_3, mv_features), dim=1) + else: + concat3 = torch.cat( + (cnn_features[2], layer_3, mv_features, depth), dim=1 + ) + layer_3 = self.concat_projects[2](concat3) # 1/8 + else: + if self.concat_cnn_features: + assert cnn_features is not None + assert len(cnn_features) == 3 # 1/2, 1/4, 1/8 + cnn_features = [ + self.cnn_projects[i](f) for i, f in enumerate(cnn_features) + ] + + layer_1 = layer_1 + cnn_features[0] # 1/2 + layer_2 = layer_2 + cnn_features[1] # 1/4 + layer_3 = layer_3 + cnn_features[2] # 1/8 + + if self.concat_mv_features: + # 1/8 + mv_features = self.mv_projects(mv_features) + + layer_3 = layer_3 + mv_features # 1/8 + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) # 1/8 + if self.latent_feature_no_concat and self.latent_downsample == 8 and self.return_feature: + return path_4 + + path_3 = self.scratch.refinenet3( + path_4, layer_3_rn, size=layer_2_rn.shape[2:] + ) # 1/4 + + if self.latent_feature_no_concat and self.latent_downsample == 4 and self.return_feature: + return path_3 + + path_2 = self.scratch.refinenet2( + path_3, layer_2_rn, size=layer_1_rn.shape[2:] + ) # 1/2 + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) # 1 + + if self.latent_downsample == 4: + # all resize to 1/4 resolution + path_4 = F.interpolate(path_4, scale_factor=2, mode='bilinear', align_corners=True) + path_2 = F.interpolate(path_2, scale_factor=0.5, mode='bilinear', align_corners=True) + path_1 = F.interpolate(path_1, scale_factor=0.25, mode='bilinear', align_corners=True) + # concat all + path_1 = torch.cat((path_4, path_3, path_2, path_1), dim=1) + + if self.return_feature: + return path_1 + + out = self.scratch.output_conv(path_1) + + return out + + +if __name__ == "__main__": + 
device = torch.device("cuda") + c = 384 + model = DPTHead( + in_channels=c, + concat_cnn_features=True, + concat_mv_features=True, + ).to(device) + print(model) + + h, w = 16, 32 + + x = torch.randn(2, c, h, w).to(device) + + out_features = [x] * 4 + + cnn_features = [ + torch.randn(2, 64, h * 4, w * 4).to(device), + torch.randn(2, 96, h * 2, w * 2).to(device), + torch.randn(2, 128, h, w).to(device), + ] + + mv_features = torch.randn(2, 128, h, w).to(device) + + out = model(out_features, h, w, cnn_features=cnn_features, mv_features=mv_features) + + print(out.shape) diff --git a/optgs/model/encoder/unimatch/feature_upsampler.py b/optgs/model/encoder/unimatch/feature_upsampler.py new file mode 100644 index 0000000000000000000000000000000000000000..4ca714b321f9b691a5815dea89157172519afe8c --- /dev/null +++ b/optgs/model/encoder/unimatch/feature_upsampler.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import math + + +class ResizeConvFeatureUpsampler(nn.Module): + """ + https://distill.pub/2016/deconv-checkerboard/ + """ + + def __init__(self, num_scales=1, + lowest_feature_resolution=8, + out_channels=128, + vit_type='vits', + no_mono_feature=False, + gaussian_downsample=None, + monodepth_backbone=False, + ): + super(ResizeConvFeatureUpsampler, self).__init__() + + self.num_scales = num_scales + self.monodepth_backbone = monodepth_backbone + + self.upsampler = nn.ModuleList() + + vit_feature_channel_dict = { + 'vits': 384, + 'vitb': 768, + 'vitl': 1024 + } + + vit_feature_channel = vit_feature_channel_dict[vit_type] + + if monodepth_backbone: + vit_feature_channel = 384 + + out_channels = out_channels // num_scales + + for i in range(num_scales): + cnn_feature_channels = 128 - (32 * i) + mv_transformer_feature_channels = 128 // (2 ** i) + if no_mono_feature: + mono_feature_channels = 0 + else: + mono_feature_channels = vit_feature_channel // (2 ** i) + + in_channels = cnn_feature_channels + \ + 
mv_transformer_feature_channels + mono_feature_channels + + if monodepth_backbone: + in_channels = 384 + + curr_upsample_factor = lowest_feature_resolution // (2 ** i) + + num_upsample = int(math.log(curr_upsample_factor, 2)) + + modules = [] + if num_upsample == 1: + curr_in_channels = out_channels * 2 + else: + curr_in_channels = out_channels * 2 * (num_upsample - 1) + modules.append(nn.Conv2d(in_channels, curr_in_channels, 1)) + for i in range(num_upsample): + modules.append(nn.Upsample(scale_factor=2, mode='nearest')) + + if i == num_upsample - 1: + modules.append(nn.Conv2d(curr_in_channels, + out_channels, 3, 1, 1, padding_mode='replicate')) + else: + modules.append(nn.Conv2d(curr_in_channels, + curr_in_channels // 2, 3, 1, 1, padding_mode='replicate')) + curr_in_channels = curr_in_channels // 2 + modules.append(nn.GELU()) + + if gaussian_downsample is not None: + if gaussian_downsample == 2: + del modules[-3:] + elif gaussian_downsample == 4: + del modules[-6:] + else: + raise NotImplementedError + + self.upsampler.append(nn.Sequential(*modules)) + + def forward(self, features_list_cnn, features_list_mv, features_list_mono=None): + out = [] + + for i in range(self.num_scales): + if self.monodepth_backbone: + concat = features_list_cnn[i] + elif features_list_mono is None: + concat = torch.cat( + (features_list_cnn[i], features_list_mv[i]), dim=1) + else: + concat = torch.cat( + (features_list_cnn[i], features_list_mv[i], features_list_mono[i]), dim=1) + concat = self.upsampler[i](concat) + + out.append(concat) + + out = torch.cat(out, dim=1) + + return out + + +def _test(): + device = torch.device('cuda:0') + + model = ResizeConvFeatureUpsampler(num_scales=2, + lowest_feature_resolution=4, + ).to(device) + print(model) + + b, h, w = 2, 32, 64 + features_list_cnn = [torch.randn(b, 128, h, w).to(device)] + features_list_mv = [torch.randn(b, 128, h, w).to(device)] + features_list_mono = [torch.randn(b, 384, h, w).to(device)] + + # scale 2 + 
features_list_cnn.append(torch.randn(b, 96, h * 2, w * 2).to(device)) + features_list_mv.append(torch.randn(b, 64, h * 2, w * 2).to(device)) + features_list_mono.append(torch.randn(b, 192, h * 2, w * 2).to(device)) + + out = model(features_list_cnn, + features_list_mv, features_list_mono) + + print(out.shape) + + +if __name__ == '__main__': + _test() diff --git a/optgs/model/encoder/unimatch/geometry.py b/optgs/model/encoder/unimatch/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..d28984173fd998f444a5f9e832674a5b1c505eef --- /dev/null +++ b/optgs/model/encoder/unimatch/geometry.py @@ -0,0 +1,306 @@ +import torch +import torch.nn.functional as F + + +def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * 
sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, + return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5 + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ + + +def back_project(depth, intrinsics): + # Back project 2D pixel coords to 3D points + # depth: [B, H, W] + # intrinsics: [B, 3, 3] + b, h, w = depth.shape + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + + intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3] + + points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W] + + return points + + +def camera_transform(points_ref, extrinsics_ref=None, 
extrinsics_tgt=None, extrinsics_rel=None): + # Transform 3D points from reference camera to target camera + # points_ref: [B, 3, H, W] + # extrinsics_ref: [B, 4, 4] + # extrinsics_tgt: [B, 4, 4] + # extrinsics_rel: [B, 4, 4], relative pose transform + b, _, h, w = points_ref.shape + + if extrinsics_rel is None: + extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4] + + points_tgt = torch.bmm(extrinsics_rel[:, :3, :3], + points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] # [B, 3, H*W] + + points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W] + + return points_tgt + + +def reproject(points_tgt, intrinsics, return_mask=False): + # reproject to target view + # points_tgt: [B, 3, H, W] + # intrinsics: [B, 3, 3] + + b, _, h, w = points_tgt.shape + + proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W] + + X = proj_points[:, 0] + Y = proj_points[:, 1] + Z = proj_points[:, 2].clamp(min=1e-3) + + pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in image scale + + if return_mask: + # valid mask in pixel space + mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & ( + pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1)) # [B, H, W] + + return pixel_coords, mask + + return pixel_coords + + +def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, + return_mask=False): + # Compute reprojection sample coords + points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W] + points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel) + + if return_mask: + reproj_coords, mask = reproject(points_tgt, intrinsics, + return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords, mask + + reproj_coords = reproject(points_tgt, intrinsics, + return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords + + +def 
compute_flow_with_depth_pose(depth_ref, intrinsics, + extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, + return_mask=False): + b, h, w = depth_ref.shape + coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W] + + if return_mask: + reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, + extrinsics_rel=extrinsics_rel, + return_mask=return_mask) # [B, 2, H, W] + rigid_flow = reproj_coords - coords_init + + return rigid_flow, mask + + reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, + extrinsics_rel=extrinsics_rel, + return_mask=return_mask) # [B, 2, H, W] + + rigid_flow = reproj_coords - coords_init + + return rigid_flow + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5, + return_flow_diff=False, + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + if return_flow_diff: + return diff_fwd, diff_bwd + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ + + +def warp_with_depth_pose(feature1, intrinsics, pose, depth, + padding_mode='zeros', + return_rigid_flow=False, + return_mask=False, + ): + assert depth.dim() == 3 # [B, H, W] + sample_coords = reproject_coords(depth, + intrinsics, + extrinsics_rel=pose, + return_mask=return_mask, + ) # [B, 2, H, W] + + if return_mask: + sample_coords, 
mask = sample_coords + + sample_coords = sample_coords.permute(0, 2, 3, 1) # [B, H, W, 2] + + warped_feature1 = bilinear_sample(feature1, sample_coords, + padding_mode=padding_mode) # [B, C, H, W] + + if return_mask: + return warped_feature1, mask + + if return_rigid_flow: + b, h, w = depth.size() + coords_init = coords_grid(b, h, w, device=depth.device) # [B, 2, H, W] + rigid_flow = sample_coords.permute(0, 3, 1, 2) - coords_init + + return warped_feature1, rigid_flow + + return warped_feature1 + + +def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth, + padding_mode='zeros', + rigid_flow_to_subtract=None, + ): + # pixel-specific depth candidates, useful for refinement + # feature1: [B, C, H, W] + # intrinsics: [B, 3, 3] + # pose: [B, 4, 4] + # depth: [B, D, H, W] + assert intrinsics.size(1) == intrinsics.size(2) == 3 + assert pose.size(1) == pose.size(2) == 4 + assert depth.dim() == 4 + + b, d, h, w = depth.size() + c = feature1.size(1) + + # stop gradient + with torch.no_grad(): + # pixel coordinates + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + # back project to 3D and transform viewpoint + points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W] + points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat( + 1, 1, d, 1) * depth.view(b, 1, d, h * w) # [B, 3, D, H*W] + points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W] + # reproject to 2D image plane + points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w) # [B, 3, D, H*W] + pixel_coords = points[:, :2] / points[:, -1:].clamp(min=MIN_DEPTH) # [B, 2, D, H*W] + + if rigid_flow_to_subtract is not None: + assert rigid_flow_to_subtract.dim() == 4 # [B, 2, H, W] + assert rigid_flow_to_subtract.size(1) == 2 + + pixel_coords = pixel_coords - rigid_flow_to_subtract.view(b, 2, h * w).unsqueeze(2) + + # normalize to [-1, 1] + x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * pixel_coords[:, 1] / (h - 
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    A single linear layer produces ``2 * dim_out`` features; the first half
    is the value and the second half, passed through GELU, is the gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * F.gelu(gate)
class LinearAttention(nn.Module):
    """Linear-complexity attention over the spatial positions of a 2D map.

    Keys are softmax-normalised over positions and contracted with the values
    first, so no (H*W) x (H*W) attention matrix is materialised.

    :param dim: channels of the input and output feature map.
    :param heads: number of attention heads.
    :param dim_head: channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        fused = self.to_qkv(x)
        queries, keys, values = rearrange(
            fused, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3
        )
        keys = keys.softmax(dim=-1)
        # Per-head (d x e) summary of keys and values, then read out by the queries.
        summary = torch.einsum('bhdn,bhen->bhde', keys, values)
        attended = torch.einsum('bhde,bhdn->bhen', summary, queries)
        attended = rearrange(
            attended, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=height, w=width
        )
        return self.to_out(attended)
torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... 
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is applied to a LayerNorm-ed input and added
    back to its input as a residual.

    :param dim: token feature dimension.
    :param n_heads: number of attention heads.
    :param d_head: channels per attention head.
    :param dropout: dropout rate passed to the attention and feed-forward layers.
    :param context_dim: feature dimension of the cross-attention context.
    :param gated_ff: passed to ``FeedForward`` as ``glu``.
    :param checkpoint: accepted for interface compatibility; gradient
        checkpointing is currently disabled, so the value is ignored.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=False):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        # self.checkpoint = checkpoint

    def forward(self, x, context=None):
        """Run the block on tokens ``x`` with optional ``context``."""
        # return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

        # `_forward` is a method: it must be called through `self`. The bare
        # name raised NameError on every forward pass.
        return self._forward(x, context)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
class CrossAttention(nn.Module):
    """Multi-head cross-attention backed by xFormers memory-efficient attention.

    Queries come from ``x`` and keys/values from ``y``. Construction asserts
    that xFormers was imported successfully.

    :param in_dim1: feature dimension of the query tokens ``x``.
    :param in_dim2: feature dimension of the key/value tokens ``y``.
    :param dim: total attention width, split evenly across heads.
    :param out_dim: output feature dimension; defaults to ``in_dim1``.
    :param num_heads: number of attention heads.
    :param qkv_bias: whether the q and kv projections use a bias.
    :param proj_bias: whether the output projection uses a bias.
    """

    def __init__(
            self,
            in_dim1,
            in_dim2,
            dim=128,
            out_dim=None,
            num_heads=4,
            qkv_bias=False,
            proj_bias=False,
    ):
        super().__init__()

        assert XFORMERS_AVAILABLE

        if out_dim is None:
            out_dim = in_dim1

        self.num_heads = num_heads
        self.dim = dim
        self.q = nn.Linear(in_dim1, dim, bias=qkv_bias)
        # Keys and values share one projection; they are split in forward().
        self.kv = nn.Linear(in_dim2, dim * 2, bias=qkv_bias)
        self.proj = nn.Linear(dim, out_dim, bias=proj_bias)

    def forward(self, x, y):
        width = self.dim
        heads = self.num_heads
        per_head = width // heads

        batch, n_query, _ = x.shape
        n_context, _ = y.shape[1:]

        queries = self.q(x).reshape(batch, n_query, heads, per_head)
        packed_kv = self.kv(y).reshape(batch, n_context, 2, heads, per_head)
        keys, values = unbind(packed_kv, 2)

        attended = memory_efficient_attention(queries, keys, values)
        attended = attended.reshape(batch, n_query, width)

        return self.proj(attended)
ffn_dim_expansion, in_dim1, bias=False), + ) + + if with_norm: + self.norm2 = nn.LayerNorm(in_dim1) + else: + self.norm2 = nn.Identity() + + if self.concat_output: + self.out = nn.Linear(out_dim + in_dim1, in_dim1) + + def forward(self, x, y): + # x: [B, C, H, W] + # y: [B, N, C] or [B, C, H, W] + + if self.no_cross_attn: + assert x.dim() == 4 and y.dim() == 4 + if y.shape[2:] != x.shape[2:]: + y = F.interpolate(y, x.shape[2:], mode='bilinear', align_corners=True) + return self.proj(torch.cat((x, y), dim=1)) + + identity = x + + b, c, h, w = x.size() + x = x.view(b, c, -1).permute(0, 2, 1) + + cross_attn = self.norm1(self.cross_attn(x, y)) + + if self.with_ffn: + if self.concat_cross_attn: + concat = torch.cat((x, cross_attn), dim=-1) + else: + concat = x + cross_attn + + cross_attn = self.norm2(self.mlp(concat)) + + if self.concat_output: + return self.out(torch.cat((x, cross_attn), dim=-1)) + + # reshape back + cross_attn = cross_attn.view(b, h, w, c).permute(0, 3, 1, 2) # [B, C, H, W] + + return identity + cross_attn + diff --git a/optgs/model/encoder/unimatch/ldm_unet/unet.py b/optgs/model/encoder/unimatch/ldm_unet/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..7fe6483313e07de89e2ea8cd7e840f5027606d1e --- /dev/null +++ b/optgs/model/encoder/unimatch/ldm_unet/unet.py @@ -0,0 +1,1243 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable +from einops import rearrange + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import torch + +from .util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from .attention import SpatialTransformer + +from .cross_attention import UNetCrossAttentionBlock + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: 
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that routes extra inputs to the children that take
    them: ``emb`` goes to ``TimestepBlock`` layers, ``context`` goes to
    ``SpatialTransformer`` layers, and every other layer receives only ``x``.
    """

    def forward(self, x, emb, context=None):
        out = x
        for layer in self:
            # TimestepBlock is tested first, matching the original dispatch order.
            if isinstance(layer, TimestepBlock):
                extra = (emb,)
            elif isinstance(layer, SpatialTransformer):
                extra = (context,)
            else:
                extra = ()
            out = layer(out, *extra)
        return out
class TransposedUpsample(nn.Module):
    """Learned 2x upsampling without padding.

    :param channels: input channels.
    :param out_channels: output channels; falls back to ``channels``.
    :param ks: kernel size of the stride-2 transposed convolution.
    """

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.up = nn.ConvTranspose2d(
            in_channels=self.channels,
            out_channels=self.out_channels,
            kernel_size=ks,
            stride=2,
        )

    def forward(self, x):
        upsampled = self.up(x)
        return upsampled
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    Timestep-embedding conditioning is disabled in this variant: ``emb`` is
    accepted by ``forward`` but never used, and ``emb_channels`` /
    ``use_scale_shift_norm`` are only stored.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param postnorm: if True, each stage is conv -> norm -> SiLU instead of
        norm -> SiLU -> conv.
    :param channels_per_group: forwarded to ``normalization``.
    :param kernel_size: kernel size of the convolutions; padding is
        ``(kernel_size - 1) // 2``.
    """

    def __init__(
            self,
            channels,
            emb_channels,
            dropout,
            out_channels=None,
            use_conv=False,
            use_scale_shift_norm=False,
            dims=2,
            use_checkpoint=False,
            up=False,
            down=False,
            postnorm=False,
            channels_per_group=None,
            kernel_size=3,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        same_pad = (kernel_size - 1) // 2

        # NOTE: submodules are built in the same order as before so that
        # parameter initialisation under a fixed seed is unchanged.
        if postnorm:
            in_stack = [
                conv_nd(dims, channels, self.out_channels, kernel_size, padding=same_pad),
                normalization(self.out_channels, channels_per_group=channels_per_group),
                nn.SiLU(),
            ]
        else:
            in_stack = [
                normalization(channels, channels_per_group=channels_per_group),
                nn.SiLU(),
                conv_nd(dims, channels, self.out_channels, kernel_size, padding=same_pad),
            ]
        self.in_layers = nn.Sequential(*in_stack)

        self.updown = up or down

        if up:
            resampler = Upsample
        elif down:
            resampler = Downsample
        else:
            resampler = None

        if resampler is None:
            # One shared Identity serves both the hidden and the skip path.
            self.h_upd = self.x_upd = nn.Identity()
        else:
            self.h_upd = resampler(channels, False, dims)
            self.x_upd = resampler(channels, False, dims)

        if postnorm:
            out_stack = [
                conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=same_pad),
                zero_module(
                    normalization(self.out_channels, channels_per_group=channels_per_group),
                ),
                nn.SiLU(),
            ]
        else:
            out_stack = [
                normalization(self.out_channels, channels_per_group=channels_per_group),
                nn.SiLU(),
                nn.Dropout(p=dropout),
                zero_module(
                    conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=same_pad)
                ),
            ]
        self.out_layers = nn.Sequential(*out_stack)

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=same_pad
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb=None):
        """
        Apply the block to a Tensor.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings
            (forwarded to ``_forward``, which does not use it).
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb=None):
        if not self.updown:
            hidden = self.in_layers(x)
        else:
            # Resample between the leading layers and the last in-layer, and
            # resample the skip input the same way.
            leading, last = self.in_layers[:-1], self.in_layers[-1]
            hidden = leading(x)
            hidden = self.h_upd(hidden)
            x = self.x_upd(x)
            hidden = last(hidden)

        hidden = self.out_layers(hidden)
        return self.skip_connection(x) + hidden
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + postnorm=False, + channels_per_group=None, + num_frames=2, + use_cross_view_self_attn=False, + ): + super().__init__() + + # NOTE: current attention layer doesn't have positional encoding (TODO) + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy( + self.num_heads, + n_frames=num_frames, + use_cross_view_self_attn=use_cross_view_self_attn, + ) + + if postnorm: + self.proj_out = conv_nd(1, channels, channels, 1) + self.norm = zero_module(normalization(channels, channels_per_group=channels_per_group)) + else: + self.norm = normalization(channels, channels_per_group=channels_per_group) + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + self.postnorm = postnorm + + def forward(self, x): + # return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! 
class CrossAttentionBlock(nn.Module):
    """
    Cross-attention conditioning.

    Every depth slice of a 5D feature volume ``x`` attends to the matching
    slice of a condition tensor ``y``; the projected result is added to ``x``
    as a residual.

    :param channels: channels of the feature volume ``x``.
    :param condition_channels: channels of the condition ``y``.
    :param num_heads: number of attention heads.
    :param proj_channels: width of the q/k/v projections (must be divisible
        by ``num_heads`` when ``num_heads > 1``).
    :param num_views: stored on the module; not used by the forward pass.
    :param num_head_channels: accepted for interface compatibility; unused.
    :param use_checkpoint: accepted for interface compatibility; unused.
    :param use_new_attention_order: accepted for interface compatibility; unused.
    :param channels_per_group: forwarded to ``normalization``.
    :param with_norm: normalise the projected output (ignored when
        ``tanh_gating`` is set).
    :param tanh_gating: scale the attention output by ``tanh`` of a learned
        gate initialised to 0 (following Flamingo).
    :param ffn_after_cross_attn: append a tanh-gated 1x1x1 conv feed-forward
        (following Flamingo).
    """

    def __init__(
            self,
            channels,
            condition_channels,
            num_heads=8,
            proj_channels=512,
            num_views=3,
            num_head_channels=-1,
            use_checkpoint=False,
            use_new_attention_order=False,
            channels_per_group=None,
            with_norm=False,
            tanh_gating=False,  # following Flamingo
            ffn_after_cross_attn=False,  # following Flamingo
    ):
        super().__init__()

        self.channels = channels
        self.num_head = num_heads
        self.num_views = num_views
        self.proj_channels = proj_channels
        self.with_norm = with_norm
        self.tanh_gating = tanh_gating
        self.ffn_after_cross_attn = ffn_after_cross_attn

        self.q_proj = nn.Linear(channels, proj_channels)
        self.k_proj = nn.Linear(condition_channels, proj_channels)
        self.v_proj = nn.Linear(condition_channels, proj_channels)

        # TODO: whether need norm layer
        # self.norm = normalization(proj_channels, channels_per_group=channels_per_group)

        if self.tanh_gating:
            self.out_proj = conv_nd(3, proj_channels, channels, 1)
            self.attn_gate = nn.Parameter(torch.tensor([0.]))
        else:
            if self.with_norm:
                self.out_proj = conv_nd(3, proj_channels, channels, 1)
                self.norm = zero_module(normalization(channels, channels_per_group=channels_per_group))
            else:
                self.out_proj = zero_module(conv_nd(3, proj_channels, channels, 1))

        if self.ffn_after_cross_attn:
            self.ffn_gate = nn.Parameter(torch.tensor([0.]))
            self.ffn = nn.Sequential(nn.Conv3d(channels, channels * 4, 1, 1, 0),
                                     normalization(channels * 4, channels_per_group=channels_per_group),
                                     nn.GELU(),
                                     nn.Conv3d(channels * 4, channels, 1, 1, 0)
                                     )

    def forward(self, x, y=None):
        """Attend from feature volume ``x`` to condition ``y`` (required)."""
        # return checkpoint(self._forward, (x,), self.parameters(), True)  # TODO: check checkpoint usage, is True  # TODO: fix the .half call!!!
        # return pt_checkpoint(self._forward, x)  # pytorch
        return self._forward(x, y)

    def _forward(self, x, y=None):
        # x: [B, C1, D, H, W], feature
        # y: [B, Ly, D, C2], condition
        # NOTE: the resolutions of feature x and color y are different
        assert y is not None, 'CrossAttentionBlock requires the condition tensor y'
        assert x.dim() == 5 and y.dim() == 4

        b, c1, d, h, w = x.size()
        lx = h * w
        ly = y.size(1)
        c2 = y.size(-1)

        identity = x

        x = x.permute(0, 2, 3, 4, 1).reshape(b * d, lx, c1)  # [B*D, Lx, C1]
        y = y.permute(0, 2, 1, 3).reshape(b * d, ly, c2)  # [B*D, Ly, C2]

        c = self.proj_channels

        q = self.q_proj(x)  # [B*D, Lx, C]
        k = self.k_proj(y)  # [B*D, Ly, C]
        v = self.v_proj(y)  # [B*D, Ly, C]

        if self.num_head > 1:
            assert c % self.num_head == 0
            head_dim = c // self.num_head

            q = q.view(b * d, lx, self.num_head, head_dim).permute(0, 2, 1, 3)  # [B*D, N, Lx, C/N]
            k = k.view(b * d, ly, self.num_head, head_dim).permute(0, 2, 1, 3)  # [B*D, N, Ly, C/N]
            v = v.view(b * d, ly, self.num_head, head_dim).permute(0, 2, 1, 3)  # [B*D, N, Ly, C/N]

            scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)  # [B*D, N, Lx, Ly]
            prob = torch.softmax(scores, dim=-1)
            out = torch.matmul(prob, v)  # [B*D, N, Lx, C/N]

            # Move the head axis back next to the channel axis before merging.
            # Viewing [B*D, N, Lx, C/N] directly as [B*D, Lx, C] would mix
            # tokens with heads.
            # NOTE(review): weights trained with the previous layout may need
            # re-validation.
            out = out.permute(0, 2, 1, 3).reshape(b * d, lx, c)  # [B*D, Lx, C]
        else:
            scores = torch.matmul(q, k.permute(0, 2, 1)) / (c ** 0.5)  # [B*D, Lx, Ly]
            prob = torch.softmax(scores, dim=-1)

            out = torch.matmul(prob, v)  # [B*D, Lx, C]

        out = out.view(b, d, h, w, c).permute(0, 4, 1, 2, 3)  # [B, C, D, H, W]

        if self.tanh_gating:
            out = self.attn_gate.tanh() * self.out_proj(out)
        else:
            out = self.out_proj(out)
            if self.with_norm:
                out = self.norm(out)

        out = identity + out

        if self.ffn_after_cross_attn:
            out = out + self.ffn_gate.tanh() * self.ffn(out)

        return out
class QKVAttention(nn.Module):
    """
    QKV attention that splits the fused tensor into Q, K and V first and into
    heads second.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        head_shape = (bs * self.n_heads, ch, length)

        queries, keys, values = qkv.chunk(3, dim=1)

        # Scaling q and k by ch^-1/4 each is more stable with f16 than
        # dividing the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        scaled_q = (queries * scale).view(*head_shape)
        scaled_k = (keys * scale).view(*head_shape)

        logits = th.einsum("bct,bcs->bts", scaled_q, scaled_k)
        # Softmax runs in float32 and is cast back to the logits' dtype.
        attn = th.softmax(logits.float(), dim=-1).type(logits.dtype)

        mixed = th.einsum("bts,bcs->bct", attn, values.reshape(*head_shape))
        return mixed.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + middle_block_attn=False, # use attn in middle block + middle_block_no_identity=False, # some previous models are trained without the identity layer + postnorm=False, # default prenorm doesn't converge + attn_prenorm=False, # try postnorm for resblock and prenorm for attn + downsample_3ddim=False, # downsample all 3d dims instead of only spatial dims + zero_final_layer=False, # init zero final output layer + channels_per_group=None, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + cross_attn_condition=False, + tanh_gating=False, + ffn_after_cross_attn=False, + cross_attn_with_norm=False, + condition_channels=384, + condition_num_views=3, + no_self_attn=False, + conv_kernel_size=3, + concat_condition=False, + concat_conv3x3=False, + num_frames=2, + use_cross_view_self_attn=False, + downsample_factor=None, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! 
You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + self.middle_block_attn = middle_block_attn + + self.middle_block_no_identity = middle_block_no_identity + + # output lower resolution gaussians + self.downsample_factor = downsample_factor + + time_embed_dim = model_channels * 4 + # self.time_embed = nn.Sequential( + # linear(model_channels, time_embed_dim), + # nn.SiLU(), + # linear(time_embed_dim, time_embed_dim), + # ) + + # if self.num_classes is not None: + # self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.cross_attn_condition = cross_attn_condition + + self.input_blocks = nn.ModuleList( + [ + nn.Sequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + 
ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + postnorm=postnorm, + channels_per_group=channels_per_group, + kernel_size=conv_kernel_size, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + + if not no_self_attn: # only cross attn, without self attn + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + postnorm=False if attn_prenorm else postnorm, + channels_per_group=channels_per_group, + num_frames=num_frames, + use_cross_view_self_attn=use_cross_view_self_attn, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + + if cross_attn_condition: + layers.append( + UNetCrossAttentionBlock(ch, + condition_channels, + dim=256, + no_cross_attn=concat_condition, + with_norm=cross_attn_with_norm, + concat_conv3x3=concat_conv3x3, + ) + ) + + self.input_blocks.append(nn.Sequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + nn.Sequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + postnorm=postnorm, + channels_per_group=channels_per_group, + kernel_size=conv_kernel_size, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch, + downsample_3ddim=downsample_3ddim, + ) + ) + ) + ch = out_ch + 
input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + + if self.middle_block_attn: + self.middle_block = nn.Sequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + postnorm=postnorm, + channels_per_group=channels_per_group, + kernel_size=conv_kernel_size, + ), + # original has attention block in the middle + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + postnorm=False if attn_prenorm else postnorm, + channels_per_group=channels_per_group, + num_frames=num_frames, + use_cross_view_self_attn=use_cross_view_self_attn, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + # cross attention condition + UNetCrossAttentionBlock(ch, + condition_channels, + dim=256, + no_cross_attn=concat_condition, + with_norm=cross_attn_with_norm, + concat_conv3x3=concat_conv3x3, + ) if cross_attn_condition else nn.Identity(), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + postnorm=postnorm, + channels_per_group=channels_per_group, + kernel_size=conv_kernel_size, + ), + ) + else: + if self.middle_block_no_identity: + self.middle_block = nn.Sequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + postnorm=postnorm, + channels_per_group=channels_per_group, + kernel_size=conv_kernel_size, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + 
use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + postnorm=postnorm, + channels_per_group=channels_per_group, + kernel_size=conv_kernel_size, + ), + ) + else: + self.middle_block = nn.Sequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + postnorm=postnorm, + channels_per_group=channels_per_group, + kernel_size=conv_kernel_size, + ), + UNetCrossAttentionBlock(ch, + condition_channels, + dim=256, + no_cross_attn=concat_condition, + with_norm=cross_attn_with_norm, + concat_conv3x3=concat_conv3x3, + ) if cross_attn_condition else nn.Identity(), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + postnorm=postnorm, + channels_per_group=channels_per_group, + kernel_size=conv_kernel_size, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + postnorm=postnorm, + channels_per_group=channels_per_group, + kernel_size=conv_kernel_size, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + + if not no_self_attn: # only cross attn, without self attn + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + postnorm=False if attn_prenorm 
else postnorm, + channels_per_group=channels_per_group, + num_frames=num_frames, + use_cross_view_self_attn=use_cross_view_self_attn, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + + if cross_attn_condition: + layers.append( + UNetCrossAttentionBlock(ch, + condition_channels, + dim=256, + no_cross_attn=concat_condition, + with_norm=cross_attn_with_norm, + concat_conv3x3=concat_conv3x3, + ) + ) + + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + postnorm=postnorm, + channels_per_group=channels_per_group, + kernel_size=conv_kernel_size, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, + downsample_3ddim=downsample_3ddim, + ) + ) + ds //= 2 + self.output_blocks.append(nn.Sequential(*layers)) + self._feature_size += ch + + if self.downsample_factor is not None: + if self.downsample_factor == 2: + del self.output_blocks[-3:] + elif self.downsample_factor == 4: + del self.output_blocks[-5:] + else: + raise NotImplementedError + + if postnorm: + self.out = nn.Sequential( + conv_nd(dims, model_channels, out_channels, 3, padding=1), + normalization(out_channels, channels_per_group=channels_per_group) if not zero_final_layer else zero_module(normalization(out_channels, channels_per_group=channels_per_group)), + nn.SiLU(), + ) + else: + if self.downsample_factor is not None: + in_channels = self.model_channels * self.channel_mult[self.downsample_factor // 2] + model_channels = model_channels * self.channel_mult[self.downsample_factor // 2] + out_channels = model_channels + else: + in_channels = ch + self.out = nn.Sequential( + normalization(in_channels, channels_per_group=channels_per_group), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, 
out_channels, 3, padding=1)), + ) + + # print(self.out) + + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch, channels_per_group=channels_per_group), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, num_views=None, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. 
class StackUNet(nn.Module):
    """
    A sequence of ``num_stacks`` :class:`UNetModel` instances applied one after
    another.

    The first UNet maps ``in_channels`` to ``out_channels``. Every later UNet
    maps ``out_channels`` to ``out_channels``; its output goes through a
    ``zero_module``-wrapped convolution and is added to the running result as
    a residual.

    :param in_channels: channels of the input tensor (first UNet only).
    :param model_channels: base channel count forwarded to each UNet.
    :param out_channels: channels produced by every UNet in the stack.
    :param num_stacks: number of UNets to chain.
    :param zero_final_layer: forwarded to the first UNet only; later UNets
        always receive ``False``.

    All remaining keyword arguments are forwarded unchanged to every
    ``UNetModel``.
    """

    def __init__(self,
                 in_channels,
                 model_channels,
                 out_channels,
                 num_res_blocks=1,
                 # Immutable defaults: these objects are handed to (and stored
                 # by) every UNetModel, so a shared mutable list default could
                 # leak modifications between instances.
                 attention_resolutions=(),
                 channel_mult=(1, 1, 1, 1),
                 num_head_channels=32,
                 dims=3,
                 postnorm=True,
                 attn_prenorm=False,
                 middle_block_attn=False,
                 num_stacks=1,
                 zero_final_layer=False,
                 resblock_updown=False,
                 channels_per_group=None,
                 cross_attn_condition=False,
                 cross_attn_with_norm=False,
                 condition_channels=128,
                 tanh_gating=False,
                 ffn_after_cross_attn=False,
                 condition_num_views=3,
                 no_self_attn=False,
                 middle_block_no_identity=False,
                 conv_kernel_size=3,
                 ):

        super().__init__()

        self.num_stacks = num_stacks

        self.stacks = nn.ModuleList()

        for i in range(num_stacks):
            self.stacks.append(
                UNetModel(
                    image_size=None,
                    # Only the first UNet sees the raw input; the rest consume
                    # the previous stack's output.
                    in_channels=in_channels if i == 0 else out_channels,
                    model_channels=model_channels,
                    out_channels=out_channels,
                    num_res_blocks=num_res_blocks,
                    attention_resolutions=attention_resolutions,
                    channel_mult=channel_mult,
                    num_head_channels=num_head_channels,
                    dims=dims,
                    middle_block_attn=middle_block_attn,
                    middle_block_no_identity=middle_block_no_identity,
                    postnorm=postnorm,
                    attn_prenorm=attn_prenorm,
                    zero_final_layer=zero_final_layer and i == 0,
                    resblock_updown=resblock_updown,
                    channels_per_group=channels_per_group,
                    cross_attn_condition=cross_attn_condition,
                    tanh_gating=tanh_gating,
                    ffn_after_cross_attn=ffn_after_cross_attn,
                    cross_attn_with_norm=cross_attn_with_norm,
                    condition_channels=condition_channels,
                    condition_num_views=condition_num_views,
                    no_self_attn=no_self_attn,
                    conv_kernel_size=conv_kernel_size,
                )
            )

        # One zeroed conv per additional stack. Input and output widths are both
        # `out_channels` (the original passed a loop variable that had already
        # been reassigned to `out_channels` by this point).
        self.convs = nn.ModuleList()
        for _ in range(num_stacks - 1):
            self.convs.append(zero_module(conv_nd(
                dims, out_channels, out_channels, 3, padding=1
            )))

    def forward(self, x, context=None):
        """
        Run the first UNet, then add each following UNet's conv-projected
        output as a residual.

        :param x: input tensor for the first UNet.
        :param context: conditioning passed as ``context=`` to every UNet.
        :return: the accumulated output tensor.
        """
        x = self.stacks[0](x, context=context)
        for i in range(self.num_stacks - 1):
            residual = self.convs[i](self.stacks[i + 1](x, context=context))
            x = x + residual

        return x
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build a beta schedule of length ``n_timestep``.

    :param schedule: one of ``"linear"``, ``"cosine"``, ``"sqrt_linear"``,
        ``"sqrt"``.
        * ``"linear"``: linspace between the square roots of the endpoints,
          then squared.
        * ``"cosine"``: derived from a squared-cosine alpha curve offset by
          ``cosine_s``; the result is clamped to ``[0, 0.999]``.
        * ``"sqrt_linear"``: plain linspace between the endpoints.
        * ``"sqrt"``: square root of a linspace between the endpoints.
    :param n_timestep: number of betas to produce.
    :param linear_start: first endpoint (unused by ``"cosine"``).
    :param linear_end: last endpoint (unused by ``"cosine"``).
    :param cosine_s: offset used by the ``"cosine"`` schedule.
    :return: a float64 ``numpy.ndarray`` of shape ``(n_timestep,)``.
    :raises ValueError: if ``schedule`` is not one of the names above.
    """
    if schedule == "linear":
        betas = (
                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Stay in torch: calling np.clip on a tensor relies on implicit
        # NumPy<->torch wrapping to still have a `.numpy()` method below.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """
    Pick the cumulative alphas used by a DDIM sampler and derive its sigmas.

    :param alphacums: indexable sequence of cumulative alpha values.
    :param ddim_timesteps: integer index array of the selected timesteps.
    :param eta: multiplier applied to the sigma schedule.
    :param verbose: if True, print the selected alphas and resulting sigmas.
    :return: ``(sigmas, alphas, alphas_prev)``.
    """
    # Alpha at each selected step, and at the step before it (the very first
    # "previous" value is alphacums[0]).
    alpha_t = alphacums[ddim_timesteps]
    earlier = alphacums[ddim_timesteps[:-1]].tolist()
    alpha_t_prev = np.asarray([alphacums[0], *earlier])

    # Sigma formula according to https://arxiv.org/abs/2010.02502
    variance_ratio = (1 - alpha_t_prev) / (1 - alpha_t)
    sigma_t = eta * np.sqrt(variance_ratio * (1 - alpha_t / alpha_t_prev))

    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alpha_t}; a_(t-1): {alpha_t_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigma_t}')
    return sigma_t, alpha_t, alpha_t_prev
def extract_into_tensor(a, t, x_shape):
    """
    Gather entries of ``a`` at indices ``t`` and shape them for broadcasting.

    :param a: tensor gathered along its last dimension.
    :param t: index tensor; its first dimension is taken as the batch size.
    :param x_shape: shape of the target tensor; only its length is used.
    :return: the gathered values reshaped to ``(batch, 1, ..., 1)`` with
        ``len(x_shape)`` dimensions in total.
    """
    batch_size, *_ = t.shape
    gathered = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch_size, *singleton_dims)
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and simply repeat each
                        timestep value ``dim`` times.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(device=timesteps.device)

    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

    # Odd `dim`: cos/sin give only 2 * half columns, so pad one zero column.
    if dim % 2:
        padding = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, padding], dim=-1)
    return embedding
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :param dims: 1, 2 or 3, selecting ``nn.Conv1d`` / ``nn.Conv2d`` /
        ``nn.Conv3d``.
    :param args: positional arguments forwarded to the conv constructor.
    :param kwargs: keyword arguments forwarded to the conv constructor.
    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in candidates:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def coords_grid(b, h, w, homogeneous=False, device=None):
    """
    Build a batch of pixel-coordinate grids.

    Channel 0 holds the x (column) index and channel 1 the y (row) index of
    every pixel; with ``homogeneous=True`` a third all-ones channel is added.

    :param b: batch size (the grid is repeated ``b`` times).
    :param h: grid height.
    :param w: grid width.
    :param homogeneous: append a channel of ones.
    :param device: device to create the grid on (default device if None).
    :return: float tensor of shape [B, 2, H, W] or [B, 3, H, W].
    """
    # Allocate directly on the target device rather than on CPU followed by a
    # copy, and state the indexing explicitly: calling torch.meshgrid without
    # `indexing=` is deprecated and warns. "ij" is the behaviour the original
    # relied on (first output varies along rows).
    y, x = torch.meshgrid(
        torch.arange(h, device=device),
        torch.arange(w, device=device),
        indexing="ij",
    )  # [H, W]

    stacks = [x, y]

    if homogeneous:
        ones = torch.ones_like(x)  # [H, W]
        stacks.append(ones)

    grid = torch.stack(stacks, dim=0).float()  # [2, H, W] or [3, H, W]

    grid = grid[None].repeat(b, 1, 1, 1)  # [B, 2, H, W] or [B, 3, H, W]

    return grid
4] + depth: [B, D, H, W] + """ + + assert intrinsics.size(1) == intrinsics.size(2) == 3 + assert pose.size(1) == pose.size(2) == 4 + assert depth.dim() == 4 + + b, d, h, w = depth.size() + c = feature1.size(1) + + with torch.no_grad(): + # pixel coordinates + grid = coords_grid( + b, h, w, homogeneous=True, device=depth.device + ) # [B, 3, H, W] + # back project to 3D and transform viewpoint + points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W] + points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat( + 1, 1, d, 1 + ) * depth.view( + b, 1, d, h * w + ) # [B, 3, D, H*W] + points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W] + # reproject to 2D image plane + points = torch.bmm(intrinsics, points.view(b, 3, -1)).view( + b, 3, d, h * w + ) # [B, 3, D, H*W] + pixel_coords = points[:, :2] / points[:, -1:].clamp( + min=clamp_min_depth + ) # [B, 2, D, H*W] + + # normalize to [-1, 1] + x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2] + + # sample features + # ref: https://github.com/pytorch/pytorch/issues/88380 + # print(feature1.shape, grid.shape) + # hardcoded workaround + if feature1.numel() > 1000000: + grid_sample_disable_cudnn = True + with torch.backends.cudnn.flags(enabled=not grid_sample_disable_cudnn): + warped_feature = F.grid_sample( + feature1, + grid.view(b, d * h, w, 2), + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ).view( + b, c, d, h, w + ) # [B, C, D, H, W] + + return warped_feature diff --git a/optgs/model/encoder/unimatch/mv_transformer.py b/optgs/model/encoder/unimatch/mv_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a0e577020c394875836469493d1b6068d35a8b --- /dev/null +++ b/optgs/model/encoder/unimatch/mv_transformer.py @@ -0,0 +1,802 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat + 
def single_head_full_attention(q, k, v):
    """
    Plain single-head scaled dot-product attention over the full sequence.

    :param q: queries, [B, L, C].
    :param k: keys, [B, L, C].
    :param v: values, [B, L, C].
    :return: attended values, [B, L, C].
    """
    assert q.dim() == k.dim() == v.dim() == 3

    scale = q.size(2) ** 0.5
    logits = torch.matmul(q, k.transpose(1, 2)) / scale  # [B, L, L]
    weights = logits.softmax(dim=2)  # [B, L, L]
    return torch.matmul(weights, v)  # [B, L, C]
not None and w is not None + assert q.size(1) == h * w + + m = k.size(1) # m + 1 is num of views + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, m, h, w, c) # [B, N-1, H, W, C] + v = v.view(b, m, h, w, c) + + scale_factor = c**0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(2, 3)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(2, 3)) + + q = split_feature( + q, num_splits=num_splits, channel_last=True + ) # [B*K*K, H/K, W/K, C] + k = split_feature( + k.permute(0, 2, 3, 4, 1).reshape(b, h, w, -1), + num_splits=num_splits, + channel_last=True, + ) # [B*K*K, H/K, W/K, C*(N-1)] + v = split_feature( + v.permute(0, 2, 3, 4, 1).reshape(b, h, w, -1), + num_splits=num_splits, + channel_last=True, + ) # [B*K*K, H/K, W/K, C*(N-1)] + + k = ( + k.view(b_new, h // num_splits, w // num_splits, c, m) + .permute(0, 3, 1, 2, 4) + .reshape(b_new, c, -1) + ) # [B*K*K, C, H/K*W/K*(N-1)] + v = ( + v.view(b_new, h // num_splits, w // num_splits, c, m) + .permute(0, 1, 2, 4, 3) + .reshape(b_new, -1, c) + ) # [B*K*K, H/K*W/K*(N-1), C] + + if USE_PYTORCH_ATTN: + # single head + # [B, H, N, C] + # q = q.view(b_new, 1, -1, c) + # k = k.permute(0, 2, 1).contiguous().unsqueeze(1) + # v = v.unsqueeze(1) + # # print(q.shape) + # attn_mask = attn_mask.repeat(b, 1, m).unsqueeze(1) if with_shift else None + # out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask).squeeze(1) + + out = F.scaled_dot_product_attention( + q.view(b_new, 1, -1, c), + k.permute(0, 2, 1).contiguous().unsqueeze(1), + v.unsqueeze(1), + attn_mask=attn_mask.repeat(b, 1, m).unsqueeze(1) if with_shift else None + ).squeeze(1) + else: 
+ scores = ( + torch.matmul(q.view(b_new, -1, c), k) / scale_factor + ) # [B*K*K, H/K*W/K, H/K*W/K*(N-1)] + + if with_shift: + scores += attn_mask.repeat(b, 1, m) + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v) # [B*K*K, H/K*W/K, C] + + out = merge_splits( + out.view(b_new, h // num_splits, w // num_splits, c), + num_splits=num_splits, + channel_last=True, + ) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + + out = out.view(b, -1, c) + else: + # 2-view self-attention or cross-attention + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c**0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature( + q, num_splits=num_splits, channel_last=True + ) # [B*K*K, H/K, W/K, C] + k = split_feature(k, num_splits=num_splits, channel_last=True) + v = split_feature(v, num_splits=num_splits, channel_last=True) + + if USE_PYTORCH_ATTN: + # single head + # [B, H, N, C] + # q = q.view(b_new, 1, -1, c) + # k = k.view(b_new, 1, -1, c) + # v = v.view(b_new, 1, -1, c) + # # print(q.shape) + # attn_mask = attn_mask.repeat(b, 1, 1).unsqueeze(1) if with_shift else None + # out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask).squeeze(1) + + out = F.scaled_dot_product_attention( + q.view(b_new, 1, -1, c), + k.view(b_new, 1, -1, c), + v.view(b_new, 1, -1, c), 
def multi_head_split_window_attention(
    q,
    k,
    v,
    num_splits=1,
    with_shift=False,
    h=None,
    w=None,
    attn_mask=None,
    num_head=1,
):
    """Multi-head scaled dot-product attention restricted to local windows.

    The [h, w] token grid is cut into ``num_splits x num_splits`` windows and
    attention is computed independently inside every window and every head.

    Args:
        q: [B, h*w, C] queries.
        k: [B, h*w, C] keys (this path only handles 3-D keys/values).
        v: [B, h*w, C] values.
        num_splits: number of windows per spatial axis (K).
        with_shift: cyclically shift the grid by half a window before
            splitting (and shift back afterwards); requires ``attn_mask``.
        h, w: spatial size of the token grid; both are required.
        attn_mask: additive mask for the shifted windows; it is repeated over
            batch and heads, so it is expected to be [K*K, h/K*w/K, h/K*w/K].
        num_head: number of heads; must divide C.

    Returns:
        [B, h*w, C] attended features.
    """
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()
    assert c % num_head == 0
    head_dim = c // num_head

    b_new = b * num_splits * num_splits

    window_size_h = h // num_splits
    window_size_w = w // num_splits

    q = q.view(b, h, w, c)  # [B, H, W, C]
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale_factor = head_dim**0.5

    if with_shift:
        assert attn_mask is not None  # computed once by the caller
        shift_size_h = window_size_h // 2
        shift_size_w = window_size_w // 2
        shifts = (-shift_size_h, -shift_size_w)

        q = torch.roll(q, shifts=shifts, dims=(1, 2))
        k = torch.roll(k, shifts=shifts, dims=(1, 2))
        v = torch.roll(v, shifts=shifts, dims=(1, 2))

    # The tensors are channel-last ([B, H, W, C]); say so explicitly, exactly
    # as single_head_split_window_attention does, instead of relying on the
    # helpers' default layout.
    q = split_feature(q, num_splits=num_splits, channel_last=True)  # [B*K*K, H/K, W/K, C]
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    # split channels into heads
    q = q.view(b_new, -1, num_head, head_dim).permute(0, 2, 1, 3)  # [B*K*K, N, L, C/N]
    k = k.view(b_new, -1, num_head, head_dim).permute(0, 2, 3, 1)  # [B*K*K, N, C/N, L]
    v = v.view(b_new, -1, num_head, head_dim).permute(0, 2, 1, 3)  # [B*K*K, N, L, C/N]

    scores = torch.matmul(q, k) / scale_factor  # [B*K*K, N, L, L]

    if with_shift:
        scores += attn_mask.unsqueeze(1).repeat(b, num_head, 1, 1)

    attn = torch.softmax(scores, dim=-1)  # [B*K*K, N, L, L]

    out = torch.matmul(attn, v)  # [B*K*K, N, L, C/N]

    out = merge_splits(
        out.permute(0, 2, 1, 3).reshape(b_new, h // num_splits, w // num_splits, c),
        num_splits=num_splits,
        channel_last=True,
    )  # [B, H, W, C]

    # shift back
    if with_shift:
        out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2))

    out = out.view(b, -1, c)

    return out


class TransformerLayer(nn.Module):
    """One attention layer (optionally followed by an FFN) with a residual.

    ``forward`` returns ``source + message`` where ``message`` is the
    attention output passed through ``merge`` + ``norm1`` and, unless
    ``no_ffn`` is set, through an MLP on ``concat(source, message)`` +
    ``norm2``.
    """

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        no_ffn=False,
        ffn_dim_expansion=4,
        with_shift=False,
        add_per_view_attn=False,
        **kwargs,
    ):
        super(TransformerLayer, self).__init__()

        self.dim = d_model
        self.nhead = nhead
        self.attention_type = attention_type
        self.no_ffn = no_ffn
        self.add_per_view_attn = add_per_view_attn

        self.with_shift = with_shift

        # attention projections
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        # no ffn after self-attn, with ffn after cross-attn
        if not self.no_ffn:
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        """Attend from ``source`` to ``target``.

        Args:
            source: [B, L, C] query features.
            target: [B, L, C] for 2-view / self attention, or [B, N-1, L, C]
                for multi-view cross attention.
            height, width: token grid size, needed for window attention.
            shifted_window_attn_mask: additive mask used when this layer was
                built with ``with_shift=True``.
            attn_num_splits: windows per axis. ``None`` or ``<= 1`` selects
                full (global) attention.
            **kwargs: ``attn_type`` overrides ``self.attention_type`` for this
                call; other entries are ignored.

        Returns:
            [B, L, C] updated source features.
        """
        attn_type = kwargs.get("attn_type", self.attention_type)

        query = self.q_proj(source)  # [B, L, C]
        key = self.k_proj(target)  # [B, L, C] or [B, N-1, L, C]
        value = self.v_proj(target)  # [B, L, C] or [B, N-1, L, C]

        # `attn_num_splits` defaults to None; treat that as "no windows"
        # instead of failing on `None > 1`.
        use_windows = (
            attn_type == "swin"
            and attn_num_splits is not None
            and attn_num_splits > 1
        )

        if use_windows:
            window_kwargs = dict(
                num_splits=attn_num_splits,
                with_shift=self.with_shift,
                h=height,
                w=width,
                attn_mask=shifted_window_attn_mask,
            )

            if self.nhead > 1:
                message = multi_head_split_window_attention(
                    query, key, value, num_head=self.nhead, **window_kwargs
                )
            elif self.add_per_view_attn:
                # attend to every target view separately, then sum over views
                assert query.dim() == 3 and key.dim() == 4 and value.dim() == 4
                b, seq_len, c = query.size()
                query = query.unsqueeze(1).repeat(
                    1, key.size(1), 1, 1
                )  # [B, N-1, L, C]
                query = query.view(-1, seq_len, c)  # [B*(N-1), L, C]
                key = key.view(-1, seq_len, c)
                value = value.view(-1, seq_len, c)
                message = single_head_split_window_attention(
                    query, key, value, **window_kwargs
                )
                message = message.view(b, -1, seq_len, c).sum(1)  # [B, L, C]
            else:
                message = single_head_split_window_attention(
                    query, key, value, **window_kwargs
                )
        else:
            message = single_head_full_attention(query, key, value)  # [B, L, C]

        message = self.merge(message)  # [B, L, C]
        message = self.norm1(message)

        if not self.no_ffn:
            message = self.mlp(torch.cat([source, message], dim=-1))
            message = self.norm2(message)

        return source + message


class TransformerBlock(nn.Module):
    """Self attention followed by cross attention + FFN.

    With ``no_cross_attn=True`` the block is a single self-attention layer
    that carries the FFN itself; otherwise self attention has no FFN and the
    FFN lives in the cross-attention layer.
    """

    def __init__(
        self,
        d_model=256,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        with_shift=False,
        add_per_view_attn=False,
        no_cross_attn=False,
        **kwargs,
    ):
        super(TransformerBlock, self).__init__()

        self.no_cross_attn = no_cross_attn

        shared = dict(
            d_model=d_model,
            nhead=nhead,
            attention_type=attention_type,
            ffn_dim_expansion=ffn_dim_expansion,
            with_shift=with_shift,
        )

        if no_cross_attn:
            self.self_attn = TransformerLayer(
                add_per_view_attn=add_per_view_attn, **shared
            )
        else:
            self.self_attn = TransformerLayer(no_ffn=True, **shared)

            self.cross_attn_ffn = TransformerLayer(
                add_per_view_attn=add_per_view_attn, **shared
            )

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        attn_num_splits=None,
        **kwargs,
    ):
        """Run self attention on ``source`` and then, unless disabled, cross
        attention from the result to ``target``. All keyword arguments are
        forwarded unchanged to the underlying ``TransformerLayer`` calls.
        """
        layer_kwargs = dict(
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_num_splits=attn_num_splits,
            **kwargs,
        )

        # self attention: source attends to itself
        source = self.self_attn(source, source, **layer_kwargs)

        if self.no_cross_attn:
            return source

        # cross attention and ffn
        return self.cross_attn_ffn(source, target, **layer_kwargs)


def _gather_neighbor_views(stacked, view_index):
    """Select views along dim 1 of ``stacked`` ([B, V, ...]).

    ``view_index`` is a [B, K] integer tensor of view ids; the result is
    [B, K, ...]. The index is broadcast over all trailing dimensions, so the
    helper works for any rank of ``stacked``.
    """
    trailing = stacked.shape[2:]
    index = view_index.reshape(
        *view_index.shape, *([1] * len(trailing))
    ).expand(-1, -1, *trailing)
    return torch.gather(stacked, dim=1, index=index)


def batch_features(features, nn_matrix=None):
    """Build batched (query, key/value) inputs for the multi-view transformer.

    Every view is used once as the query; its keys/values are the other views
    (all of them, or the ones selected by ``nn_matrix``).

    Args:
        features: sequence of N per-view tensors, each [B, C, H, W] or
            [B, H*W, C].
        nn_matrix: optional [B, N, K+1] integer tensor. For query view ``i``
            the entries ``nn_matrix[:, i, 1:]`` are the ids of the views used
            as keys/values (column 0 is skipped).

    Returns:
        q: [N*B, ...] queries, concatenated view after view.
        kv: [N*B, N-1, ...] (or [N*B, K, ...] with ``nn_matrix``).
    """
    features = list(features)
    num_views = len(features)

    features_tensor = None
    if nn_matrix is not None:
        # (b v c h w) or (b v hw c)
        features_tensor = torch.stack(features, dim=1)

    q = []
    kv = []
    for i in range(num_views):
        q.append(features[i])

        if features_tensor is not None:
            kv.append(_gather_neighbor_views(features_tensor, nn_matrix[:, i, 1:]))
        else:
            kv.append(torch.stack(features[:i] + features[i + 1 :], dim=1))

    q = torch.cat(q, dim=0)  # [N*B, C, H, W] or [N*B, H*W, C]
    kv = torch.cat(kv, dim=0)  # [N*B, N-1, C, H, W] or [N*B, N-1, H*W, C]

    return q, kv


class MultiViewFeatureTransformer(nn.Module):
    """Stack of ``TransformerBlock``s exchanging information between views."""

    def __init__(
        self,
        num_layers=6,
        d_model=128,
        nhead=1,
        attention_type="swin",
        ffn_dim_expansion=4,
        add_per_view_attn=False,
        no_cross_attn=False,
        use_checkpointing=False,
        **kwargs,
    ):
        super(MultiViewFeatureTransformer, self).__init__()

        self.use_checkpointing = use_checkpointing

        self.attention_type = attention_type

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    attention_type=attention_type,
                    ffn_dim_expansion=ffn_dim_expansion,
                    # every second swin block uses shifted windows
                    with_shift=(attention_type == "swin" and i % 2 == 1),
                    add_per_view_attn=add_per_view_attn,
                    no_cross_attn=no_cross_attn,
                )
                for i in range(num_layers)
            ]
        )

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # Zero-init layers beyond 6: zeroing the last LayerNorm of a
        # sub-layer makes its residual branch output 0, so the extra blocks
        # start out as identity mappings. `cross_attn_ffn` only exists when
        # cross attention is enabled, and a sub-layer's last norm is `norm2`
        # when it has an FFN and `norm1` otherwise.
        for i in range(6, num_layers):
            block = self.layers[i]
            sub_layers = [block.self_attn]
            if not no_cross_attn:
                sub_layers.append(block.cross_attn_ffn)
            for sub_layer in sub_layers:
                last_norm = sub_layer.norm1 if sub_layer.no_ffn else sub_layer.norm2
                nn.init.zeros_(last_norm.weight)
                nn.init.zeros_(last_norm.bias)

    def forward(
        self,
        multi_view_features,
        attn_num_splits=None,
        **kwargs,
    ):
        """Refine per-view features with cross-view attention.

        Args:
            multi_view_features: list of N tensors, each [B, C, H, W] with
                ``C == d_model``.
            attn_num_splits: windows per axis for swin attention; ``None`` or
                ``<= 1`` means full attention (no mask is built).
            **kwargs: ``nn_matrix`` (optional, [B, N, K+1]) restricts the
                key/value views of every query view, see ``batch_features``.

        Returns:
            list of N tensors, each [B, C, H, W].
        """
        nn_matrix = kwargs.pop("nn_matrix", None)

        # multi_view_features: list of [B, C, H, W]
        b, c, h, w = multi_view_features[0].shape
        assert self.d_model == c

        num_views = len(multi_view_features)

        use_windows = (
            self.attention_type == "swin"
            and attn_num_splits is not None
            and attn_num_splits > 1
        )

        if use_windows:
            # global and refine use different number of splits
            window_size_h = h // attn_num_splits
            window_size_w = w // attn_num_splits

            # compute attn mask once
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=multi_view_features[0].device,
            )  # [K*K, H/K*W/K, H/K*W/K]
        else:
            shifted_window_attn_mask = None

        # [N*B, C, H, W], [N*B, N-1, C, H, W]
        concat0, concat1 = batch_features(multi_view_features, nn_matrix=nn_matrix)
        concat0 = concat0.reshape(num_views * b, c, -1).permute(
            0, 2, 1
        )  # [N*B, H*W, C]
        c1_v = num_views - 1 if nn_matrix is None else nn_matrix.shape[-1] - 1
        concat1 = concat1.reshape(num_views * b, c1_v, c, -1).permute(
            0, 1, 3, 2
        )  # [N*B, N-1, H*W, C]

        layer_kwargs = dict(
            height=h,
            width=w,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_num_splits=attn_num_splits,
        )
        last_index = len(self.layers) - 1

        for i, layer in enumerate(self.layers):
            if self.use_checkpointing:
                # Hand the layer itself to checkpoint. A closure over the
                # loop variable would be re-evaluated during backward, when
                # `layer` already refers to the last block, so every segment
                # would be recomputed with the wrong weights. Keyword
                # arguments are forwarded in non-reentrant mode.
                concat0 = torch.utils.checkpoint.checkpoint(
                    layer,
                    concat0,
                    concat1,
                    use_reentrant=False,
                    **layer_kwargs,
                )
            else:
                concat0 = layer(concat0, concat1, **layer_kwargs)

            if i < last_index:
                # re-pair every view with the updated features of the others
                features = list(concat0.chunk(chunks=num_views, dim=0))
                # [N*B, H*W, C], [N*B, N-1, H*W, C]
                concat0, concat1 = batch_features(features, nn_matrix=nn_matrix)

        features = concat0.chunk(chunks=num_views, dim=0)
        features = [
            f.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() for f in features
        ]

        return features


def batch_features_camera_parameters(
    features,
    intrinsics,
    extrinsics,
    nn_matrix=None,
    no_batch=False,
):
    """Pair every view with its target views and cameras for warping.

    Args:
        features: list of V tensors [B, C, H, W].
        intrinsics: list of V tensors [B, 3, 3].
        extrinsics: list of V tensors [B, 4, 4].
        nn_matrix: optional [B, V, K+1] integer tensor; for reference view
            ``i`` the entries ``nn_matrix[:, i, 1:]`` select the target views.
            Without it all other V-1 views are targets.
        no_batch: return the per-view lists instead of batched tensors.

    Returns:
        ``(q, q_intrinsics, q_extrinsics, kv, kv_intrinsics, kv_extrinsics)``.
        Batched shapes are [BV, C, H, W], [BV, 3, 3], [BV, 4, 4],
        [BV, V-1, C, H, W], [BV, V-1, 3, 3] and [BV, V-1, 4, 4], flattened
        batch-major (index ``b * V + v``). With ``no_batch`` each item is a
        list with one entry per reference view.
    """
    assert (
        features[0].dim() == 4 and intrinsics[0].dim() == 3 and extrinsics[0].dim() == 3
    )
    assert intrinsics[0].size(-1) == intrinsics[0].size(-2) == 3
    assert extrinsics[0].size(-1) == extrinsics[0].size(-2) == 4

    features = list(features)
    intrinsics = list(intrinsics)
    extrinsics = list(extrinsics)

    num_views = len(features)
    if nn_matrix is not None:
        features_tensor = torch.stack(features, dim=1)  # [B, V, C, H, W]
        intrinsics_tensor = torch.stack(intrinsics, dim=1)  # [B, V, 3, 3]
        extrinsics_tensor = torch.stack(extrinsics, dim=1)  # [B, V, 4, 4]

        num_selected_views = nn_matrix.size(-1) - 1
    else:
        num_selected_views = num_views - 1

    # reference (query) and target (key/value) features and cameras
    q = []
    q_intrinsics = []
    q_extrinsics = []
    kv = []
    kv_intrinsics = []
    kv_extrinsics = []

    for i in range(num_views):
        q.append(features[i])  # [B, C, H, W]
        q_intrinsics.append(intrinsics[i])
        q_extrinsics.append(extrinsics[i])

        if nn_matrix is not None:
            # select target views based on the provided nn matrix
            selected = nn_matrix[:, i, 1:]  # [B, K]
            kv.append(_gather_neighbor_views(features_tensor, selected))
            kv_intrinsics.append(_gather_neighbor_views(intrinsics_tensor, selected))
            kv_extrinsics.append(_gather_neighbor_views(extrinsics_tensor, selected))
        else:
            # all views except the reference one, in their original order
            kv.append(torch.stack(features[:i] + features[i + 1 :], dim=1))
            kv_intrinsics.append(
                torch.stack(intrinsics[:i] + intrinsics[i + 1 :], dim=1)
            )
            kv_extrinsics.append(
                torch.stack(extrinsics[:i] + extrinsics[i + 1 :], dim=1)
            )

    if no_batch:
        # lists with one entry per reference view
        return q, q_intrinsics, q_extrinsics, kv, kv_intrinsics, kv_extrinsics

    c, h, w = q[0].shape[1:]

    q = torch.stack(q, dim=1).view(-1, c, h, w)  # [BV, C, H, W]
    q_intrinsics = torch.stack(q_intrinsics, dim=1).view(-1, 3, 3)  # [BV, 3, 3]
    q_extrinsics = torch.stack(q_extrinsics, dim=1).view(-1, 4, 4)  # [BV, 4, 4]
    kv = torch.stack(kv, dim=1).view(
        -1, num_selected_views, c, h, w
    )  # [BV, V-1, C, H, W]
    kv_intrinsics = torch.stack(kv_intrinsics, dim=1).view(
        -1, num_selected_views, 3, 3
    )  # [BV, V-1, 3, 3]
    kv_extrinsics = torch.stack(kv_extrinsics, dim=1).view(
        -1, num_selected_views, 4, 4
    )  # [BV, V-1, 4, 4]

    return q, q_intrinsics, q_extrinsics, kv, kv_intrinsics, kv_extrinsics
unet_num_res_blocks=1, + unet_attn_resolutions=[4], + grid_sample_disable_cudnn=False, + only_features=False, + sample_log_depth=False, + bilinear_upsample_depth=False, + no_upsample_depth=False, + use_amp=False, + return_raw_mono_features=False, + max_mono_vit_input_size=560, # constrain the input resolution to vit + use_checkpointing=False, + **kwargs, + ): + super(MultiViewUniMatch, self).__init__() + + # CNN + self.feature_channels = feature_channels + self.num_scales = num_scales + self.lowest_feature_resolution = lowest_feature_resolution + self.upsample_factor = upsample_factor + self.only_features = only_features + self.bilinear_upsample_depth = bilinear_upsample_depth + self.no_upsample_depth = no_upsample_depth + self.return_raw_mono_features = return_raw_mono_features + self.max_mono_vit_input_size = max_mono_vit_input_size + + self.use_amp = use_amp + + # sample depth in the log scale instead of the inverse depth + self.sample_log_depth = sample_log_depth + + # monocular backbones: final + self.vit_type = vit_type + + # cost volume + self.num_depth_candidates = num_depth_candidates + + # upsampler + vit_feature_channel_dict = {"vits": 384, "vitb": 768, "vitl": 1024} + + vit_feature_channel = vit_feature_channel_dict[vit_type] + + # CNN + self.backbone = CNNEncoder( + output_dim=feature_channels, + num_output_scales=num_scales, + downsample_factor=upsample_factor, + lowest_scale=lowest_feature_resolution, + return_all_scales=True, + ) + + # Transformer + self.transformer = MultiViewFeatureTransformer( + num_layers=num_transformer_layers, + d_model=feature_channels, + nhead=num_head, + ffn_dim_expansion=ffn_dim_expansion, + use_checkpointing=use_checkpointing, + ) + + if self.num_scales > 1: + # generate multi-scale features + self.mv_pyramid = ViTFeaturePyramid( + in_channels=128, scale_factors=[2**i for i in range(self.num_scales)] + ) + + # monodepth + encoder = vit_type + # local load dinov2 + self.pretrained = DINOv2(encoder, + 
use_checkpointing=use_checkpointing, + ) + # self.pretrained = torch.hub.load( + # "facebookresearch/dinov2", "dinov2_{:}14".format(encoder) + # ) + + del self.pretrained.mask_token # unused + + if self.num_scales > 1: + # generate multi-scale features + self.mono_pyramid = ViTFeaturePyramid( + in_channels=vit_feature_channel, + scale_factors=[2**i for i in range(self.num_scales)], + ) + + if self.only_features: + return + + # UNet regressor + self.regressor = nn.ModuleList() + self.regressor_residual = nn.ModuleList() + self.depth_head = nn.ModuleList() + + for i in range(self.num_scales): + curr_depth_candidates = num_depth_candidates // (4**i) + cnn_feature_channels = 128 - (32 * i) + mv_transformer_feature_channels = 128 // (2**i) + + mono_feature_channels = vit_feature_channel // (2**i) + + # concat(cost volume, cnn feature, mv feature, mono feature) + in_channels = ( + curr_depth_candidates + + cnn_feature_channels + + mv_transformer_feature_channels + + mono_feature_channels + ) + + # unet channels + channels = unet_channels // (2**i) + + # unet channel mult & unet_attn_resolutions + if i > 0: + unet_channel_mult = unet_channel_mult + [1] + unet_attn_resolutions = [x * 2 for x in unet_attn_resolutions] + + # unet + modules = [ + nn.Conv2d(in_channels, channels, 3, 1, 1), + nn.GroupNorm(8, channels), + nn.GELU(), + ] + + modules.append( + UNetModel( + image_size=None, + in_channels=channels, + model_channels=channels, + out_channels=channels, + num_res_blocks=unet_num_res_blocks, + attention_resolutions=unet_attn_resolutions, + channel_mult=unet_channel_mult, + num_head_channels=32, + dims=2, + postnorm=False, + num_frames=2, + use_cross_view_self_attn=True, + ) + ) + + modules.append(nn.Conv2d(channels, channels, 3, 1, 1)) + + self.regressor.append(nn.Sequential(*modules)) + + # regressor residual + self.regressor_residual.append(nn.Conv2d(in_channels, channels, 1)) + + # depth head + self.depth_head.append( + nn.Sequential( + nn.Conv2d( + channels, channels 
* 2, 3, 1, 1, padding_mode="replicate" + ), + nn.GELU(), + nn.Conv2d( + channels * 2, + curr_depth_candidates, + 3, + 1, + 1, + padding_mode="replicate", + ), + ) + ) + + # upsampler + # concat(lowres_depth, cnn feature, mv feature, mono feature) + in_channels = ( + 1 + + cnn_feature_channels + + mv_transformer_feature_channels + + mono_feature_channels + ) + + model_configs = { + "vits": { + "in_channels": 384, + "features": 32, + "out_channels": [48, 96, 192, 384], + }, + "vitb": { + "in_channels": 768, + "features": 48, + "out_channels": [96, 192, 384, 768], + }, + "vitl": { + "in_channels": 1024, + "features": 64, + "out_channels": [128, 256, 512, 1024], + }, + } + + if not self.bilinear_upsample_depth and not self.no_upsample_depth: + self.upsampler = DPTHead( + **model_configs[vit_type], + downsample_factor=upsample_factor, + num_scales=num_scales, + ) + + self.grid_sample_disable_cudnn = grid_sample_disable_cudnn + + def normalize_images(self, images): + """Normalize image to match the pretrained UniMatch model. 
+ images: (B, V, C, H, W) + """ + shape = [*[1] * (images.dim() - 3), 3, 1, 1] + mean = torch.tensor([0.485, 0.456, 0.406]).reshape(*shape).to(images.device) + std = torch.tensor([0.229, 0.224, 0.225]).reshape(*shape).to(images.device) + + return (images - mean) / std + + def extract_feature(self, images): + # images: [B, V, C, H, W] + b, v = images.shape[:2] + concat = rearrange(images, "b v c h w -> (b v) c h w") + # list of [BV, C, H, W], resolution from high to low + features = self.backbone(concat) + # reverse: resolution from low to high + features = features[::-1] + + return features + + def forward( + self, + images, + attn_splits_list=None, + intrinsics=None, + min_depth=1.0 / 0.5, # inverse depth range + max_depth=1.0 / 100, + num_depth_candidates=128, + extrinsics=None, + nn_matrix=None, + **kwargs, + ): + + results_dict = {} + depth_preds = [] + match_probs = [] + + # first normalize images + images = self.normalize_images(images) + b, v, _, ori_h, ori_w = images.shape + + # update the num_views in unet attention, useful for random input views + if not self.only_features: + set_num_views(self.regressor, num_views=v) + + # NOTE: in this codebase, intrinsics are normalized by image width and height + # in unimatch's codebase: https://github.com/autonomousvision/unimatch, no normalization + intrinsics = intrinsics.clone() + intrinsics[:, :, 0] *= ori_w + intrinsics[:, :, 1] *= ori_h + + # max_depth, min_depth: [B, V] -> [BV] + max_depth = max_depth.view(-1) + min_depth = min_depth.view(-1) + + if self.sample_log_depth: + # inverse depth to depth + min_depth, max_depth = 1. / max_depth, 1. 
/ min_depth + min_depth, max_depth = torch.log(min_depth), torch.log(max_depth) + + # list of features, resolution low to high + # list of [BV, C, H, W] + with torch.amp.autocast(device_type='cuda', enabled=self.use_amp, dtype=torch.bfloat16): + features_list_cnn = self.extract_feature(images) + features_list_cnn_all_scales = features_list_cnn + features_list_cnn = features_list_cnn[: self.num_scales] + results_dict.update({"features_cnn_all_scales": features_list_cnn_all_scales}) + results_dict.update({"features_cnn": features_list_cnn}) + + # mv transformer features + # add position to features + attn_splits = attn_splits_list[0] + + # [BV, C, H, W] + features_cnn_pos = mv_feature_add_position( + features_list_cnn[0], attn_splits, self.feature_channels + ) + + # list of [B, C, H, W] + features_list = list( + torch.unbind( + rearrange(features_cnn_pos, "(b v) c h w -> b v c h w", b=b, v=v), dim=1 + ) + ) + with torch.amp.autocast(device_type='cuda', enabled=self.use_amp, dtype=torch.bfloat16): + if features_list[0].shape[-1] > 96: + attn_splits = 4 + + if features_list[0].shape[-1] > 192: + attn_splits = 8 + + features_list_mv = self.transformer( + features_list, + attn_num_splits=attn_splits, + nn_matrix=nn_matrix, + ) + + features_mv = rearrange( + torch.stack(features_list_mv, dim=1), "b v c h w -> (b v) c h w" + ) # [BV, C, H, W] + + if self.num_scales > 1: + # multi-scale mv features: resolution from low to high + # list of [BV, C, H, W] + with torch.amp.autocast(device_type='cuda', enabled=self.use_amp, dtype=torch.bfloat16): + features_list_mv = self.mv_pyramid(features_mv) + else: + features_list_mv = [features_mv] + + results_dict.update({"features_mv": features_list_mv}) + + # mono feature + ori_h, ori_w = images.shape[-2:] + + # TODO: support portrait images later + assert ori_h <= ori_w + if ori_w > self.max_mono_vit_input_size: + resize_w = self.max_mono_vit_input_size // 14 * 14 + resize_h = int((ori_h / ori_w) * self.max_mono_vit_input_size) // 14 * 
14 + else: + resize_h, resize_w = ori_h // 14 * 14, ori_w // 14 * 14 + # print(resize_h, resize_w) + + concat = rearrange(images, "b v c h w -> (b v) c h w") + concat = F.interpolate( + concat, (resize_h, resize_w), mode="bilinear", align_corners=True + ) + + # get intermediate features + intermediate_layer_idx = { + "vits": [2, 5, 8, 11], + "vitb": [2, 5, 8, 11], + "vitl": [4, 11, 17, 23], + } + + with torch.amp.autocast(device_type='cuda', enabled=self.use_amp, dtype=torch.bfloat16): + mono_intermediate_features = list( + self.pretrained.get_intermediate_layers( + concat, intermediate_layer_idx[self.vit_type], return_class_token=False + ) + ) + + if self.return_raw_mono_features: + raw_mono_features = [] + + for i in range(len(mono_intermediate_features)): + curr_features = ( + mono_intermediate_features[i] + .reshape(concat.shape[0], resize_h // 14, resize_w // 14, -1) + .permute(0, 3, 1, 2) + .contiguous() + ) + if self.return_raw_mono_features: + raw_mono_features.append(curr_features) + # resize to 1/8 resolution + curr_features = F.interpolate( + curr_features, + (ori_h // 8, ori_w // 8), + mode="bilinear", + align_corners=True, + ) + mono_intermediate_features[i] = curr_features + + results_dict.update({"features_mono_intermediate": mono_intermediate_features}) + + if self.return_raw_mono_features: + results_dict.update({"raw_mono_features": raw_mono_features}) + + # last mono feature + # TODO: use all the intermediate features for depth estimation? 
+ mono_features = mono_intermediate_features[-1] + + if self.lowest_feature_resolution == 4: + mono_features = F.interpolate( + mono_features, scale_factor=2, mode="bilinear", align_corners=True + ) + + if self.num_scales > 1: + # multi-scale mono features, resolution from low to high + # list of [BV, C, H, W] + with torch.amp.autocast(device_type='cuda', enabled=self.use_amp, dtype=torch.bfloat16): + features_list_mono = self.mono_pyramid(mono_features) + else: + features_list_mono = [mono_features] + + results_dict.update({"features_mono": features_list_mono}) + + if self.only_features: + return results_dict + + depth = None + + for scale_idx in range(self.num_scales): + downsample_factor = self.upsample_factor * ( + 2 ** (self.num_scales - 1 - scale_idx) + ) + + # scale intrinsics + intrinsics_curr = intrinsics.clone() # [B, V, 3, 3] + intrinsics_curr[:, :, :2] = intrinsics_curr[:, :, :2] / downsample_factor + + # build cost volume + features_mv = features_list_mv[scale_idx] # [BV, C, H, W] + + # list of [B, C, H, W] + features_mv_curr = list( + torch.unbind( + rearrange(features_mv, "(b v) c h w -> b v c h w", b=b, v=v), dim=1 + ) + ) + + intrinsics_curr = list( + torch.unbind(intrinsics_curr, dim=1) + ) # list of [B, 3, 3] + extrinsics_curr = list(torch.unbind(extrinsics, dim=1)) # list of [B, 4, 4] + + # ref: [BV, C, H, W], [BV, 3, 3], [BV, 4, 4] + # tgt: [BV, V-1, C, H, W], [BV, V-1, 3, 3], [BV, V-1, 4, 4] + ( + ref_features, + ref_intrinsics, + ref_extrinsics, + tgt_features, + tgt_intrinsics, + tgt_extrinsics, + ) = batch_features_camera_parameters( + features_mv_curr, + intrinsics_curr, + extrinsics_curr, + nn_matrix=nn_matrix, + ) + + b_new, _, c, h, w = tgt_features.size() + + # relative pose + # extrinsics: c2w + pose_curr = torch.matmul( + tgt_extrinsics.inverse(), ref_extrinsics.unsqueeze(1) + ) # [BV, V-1, 4, 4] + + if scale_idx > 0: + # 2x upsample depth + assert depth is not None + depth = F.interpolate( + depth, scale_factor=2, mode="bilinear", 
align_corners=True + ).detach() + + num_depth_candidates = self.num_depth_candidates // (4**scale_idx) + + # generate depth candidates + if scale_idx == 0: + # min_depth, max_depth: [BV] + depth_interval = (max_depth - min_depth) / ( + self.num_depth_candidates - 1 + ) # [BV] + + linear_space = ( + torch.linspace(0, 1, num_depth_candidates) + .type_as(features_list_cnn[0]) + .view(1, num_depth_candidates, 1, 1) + ) # [1, D, 1, 1] + + depth_candidates = min_depth.view(-1, 1, 1, 1) + linear_space * ( + max_depth - min_depth + ).view( + -1, 1, 1, 1 + ) # [BV, D, 1, 1] + else: + # half interval each scale + depth_interval = ( + (max_depth - min_depth) + / (self.num_depth_candidates - 1) + / (2**scale_idx) + ) # [BV] + # [BV, 1, 1, 1] + depth_interval = depth_interval.view(-1, 1, 1, 1) + + # [BV, 1, H, W] + depth_range_min = ( + depth - depth_interval * (num_depth_candidates // 2) + ).clamp(min=min_depth.view(-1, 1, 1, 1)) + depth_range_max = ( + depth + depth_interval * (num_depth_candidates // 2 - 1) + ).clamp(max=max_depth.view(-1, 1, 1, 1)) + + linear_space = ( + torch.linspace(0, 1, num_depth_candidates) + .type_as(features_list_cnn[0]) + .view(1, num_depth_candidates, 1, 1) + ) # [1, D, 1, 1] + depth_candidates = depth_range_min + linear_space * ( + depth_range_max - depth_range_min + ) # [BV, D, H, W] + + if scale_idx == 0: + # [BV*(V-1), D, H, W] + depth_candidates_curr = ( + depth_candidates.unsqueeze(1) + .repeat(1, tgt_features.size(1), 1, h, w) + .view(-1, num_depth_candidates, h, w) + ) + else: + depth_candidates_curr = ( + depth_candidates.unsqueeze(1) + .repeat(1, tgt_features.size(1), 1, 1, 1) + .view(-1, num_depth_candidates, h, w) + ) + + intrinsics_input = torch.stack(intrinsics_curr, dim=1).view( + -1, 3, 3 + ) # [BV, 3, 3] + intrinsics_input = intrinsics_input.unsqueeze(1).repeat( + 1, tgt_features.size(1), 1, 1 + ) # [BV, V-1, 3, 3] + + ref_features = ref_features.float() + tgt_features = tgt_features.float() + depth_candidates_curr = 
depth_candidates_curr.float() + + warped_tgt_features = warp_with_pose_depth_candidates( + rearrange(tgt_features, "b v ... -> (b v) ..."), + rearrange(intrinsics_input, "b v ... -> (b v) ..."), + rearrange(pose_curr, "b v ... -> (b v) ..."), + torch.exp(depth_candidates_curr) if self.sample_log_depth else 1.0 / depth_candidates_curr, # convert inverse/log depth to depth + grid_sample_disable_cudnn=self.grid_sample_disable_cudnn, + ) # [BV*(V-1), C, D, H, W] + + # ref: [BV, C, H, W] + # warped: [BV*(V-1), C, D, H, W] -> [BV, V-1, C, D, H, W] + warped_tgt_features = rearrange( + warped_tgt_features, + "(b v) ... -> b v ...", + b=b_new, + v=tgt_features.size(1), + ) + # [BV, V-1, D, H, W] -> [BV, D, H, W] + # average cross other views + cost_volume = ( + (ref_features.unsqueeze(-3).unsqueeze(1) * warped_tgt_features).sum(2) + / (c**0.5) + ).mean(1) + + # regressor + features_cnn = features_list_cnn[scale_idx] # [BV, C, H, W] + + features_mono = features_list_mono[scale_idx] # [BV, C, H, W] + + concat = torch.cat( + (cost_volume, features_cnn, features_mv, features_mono), dim=1 + ) + + with torch.amp.autocast(device_type='cuda', enabled=self.use_amp, dtype=torch.bfloat16): + out = self.regressor[scale_idx](concat) + self.regressor_residual[ + scale_idx + ](concat) + + out = out.float() + + # depth pred + match_prob = F.softmax( + self.depth_head[scale_idx](out), dim=1 + ) # [BV, D, H, W] + match_probs.append(match_prob) + + if scale_idx == 0: + # [BV, D, H, W] + depth_candidates = depth_candidates.repeat(1, 1, h, w) + depth = (match_prob * depth_candidates).sum( + dim=1, keepdim=True + ) # [BV, 1, H, W] + + # upsample to the original resolution for supervison at training time only + if self.training and scale_idx < self.num_scales - 1: + depth_bilinear = F.interpolate( + depth, + scale_factor=downsample_factor, + mode="bilinear", + align_corners=True, + ) + depth_preds.append(depth_bilinear) + + # final output, learned upsampler + if scale_idx == self.num_scales - 1: 
def set_num_views(module, num_views):
    """Propagate the number of views to every attention block under ``module``.

    The module tree is walked depth-first. Each ``AttentionBlock`` encountered
    has ``attention.n_frames`` overwritten with ``num_views``; the walk does
    not descend into the children of an ``AttentionBlock``. Objects that are
    not ``nn.Module`` instances are ignored.

    Args:
        module: Root of the (sub)tree to update.
        num_views: Value written to ``attention.n_frames``.
    """
    if isinstance(module, AttentionBlock):
        module.attention.n_frames = num_views
        return

    # nn.ModuleList and nn.Sequential both derive from nn.Module, so a single
    # check covers every container type.
    if not isinstance(module, nn.Module):
        return

    for child in module.children():
        set_num_views(child, num_views)
class PositionEmbeddingSine(nn.Module):
    """Fixed (non-learned) 2-D sinusoidal position encoding for feature maps.

    For an input of shape [B, C, H, W] the module returns a tensor of shape
    [B, 2 * num_pos_feats, H, W]: the first ``num_pos_feats`` channels encode
    the row index, the remaining ones the column index. Only the shape and
    device of the input are used; its values are ignored.

    Args:
        num_pos_feats: Number of channels produced per spatial axis.
        temperature: Base of the geometric frequency progression.
        normalize: If True, coordinates are rescaled to (0, ``scale``].
        scale: Upper end of the normalized range; defaults to 2*pi.

    Raises:
        ValueError: If ``scale`` is given while ``normalize`` is False.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x):
        """Return the position encoding for a [B, C, H, W] input."""
        batch, _, height, width = x.size()
        device = x.device

        # 1-based row / column indices, broadcast to [B, H, W].
        rows = torch.arange(1, height + 1, dtype=torch.float32, device=device)
        cols = torch.arange(1, width + 1, dtype=torch.float32, device=device)
        y_embed = rows.view(1, height, 1).expand(batch, height, width)
        x_embed = cols.view(1, 1, width).expand(batch, height, width)

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Consecutive channel pairs share one frequency.
        channel_idx = torch.arange(
            self.num_pos_feats, dtype=torch.float32, device=device
        )
        dim_t = self.temperature ** (2 * (channel_idx // 2) / self.num_pos_feats)

        def interleave_sin_cos(embed):
            # [B, H, W] -> [B, H, W, num_pos_feats]; sin on even, cos on odd channels.
            angles = embed[:, :, :, None] / dim_t
            return torch.stack(
                (angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()), dim=4
            ).flatten(3)

        encoding = torch.cat(
            (interleave_sin_cos(y_embed), interleave_sin_cos(x_embed)), dim=3
        )
        return encoding.permute(0, 3, 1, 2)
def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None):
    """Build a regular 2-D coordinate grid.

    Args:
        h_min, h_max: First and last coordinate along the vertical axis.
        w_min, w_max: First and last coordinate along the horizontal axis.
        len_h, len_w: Number of samples along each axis.
        device: Target device; required (must not be None).

    Returns:
        Float tensor of shape [len_h, len_w, 2] holding (x, y) per cell.
    """
    assert device is not None

    # "ij" is the historical default of torch.meshgrid; passing it explicitly
    # keeps the layout unchanged and silences the deprecation warning.
    x, y = torch.meshgrid(
        [
            torch.linspace(w_min, w_max, len_w, device=device),
            torch.linspace(h_min, h_max, len_h, device=device),
        ],
        indexing="ij",
    )
    grid = torch.stack((x, y), -1).transpose(0, 1).float()  # [H, W, 2]

    return grid


def normalize_coords(coords, h, w):
    """Map pixel coordinates to [-1, 1].

    Args:
        coords: [B, H, W, 2] tensor of (x, y) pixel coordinates.
        h, w: Image height and width in pixels.

    Returns:
        Tensor of the same shape where 0 maps to -1 and (size - 1) maps to 1.
    """
    c = torch.tensor(
        [(w - 1) / 2.0, (h - 1) / 2.0], dtype=torch.float32, device=coords.device
    )
    return (coords - c) / c  # [-1, 1]


def normalize_img(img0, img1):
    """Normalize two [0, 255] RGB images with the ImageNet mean and std.

    Both images are normalized with statistics placed on ``img1``'s device.

    Returns:
        Tuple ``(img0, img1)`` of normalized images.
    """
    mean = torch.tensor([0.485, 0.456, 0.406], device=img1.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=img1.device).view(1, 3, 1, 1)
    img0 = (img0 / 255.0 - mean) / std
    img1 = (img1 / 255.0 - mean) / std

    return img0, img1


def split_feature(
    feature,
    num_splits=2,
    channel_last=False,
):
    """Cut a feature map into ``num_splits x num_splits`` non-overlapping windows.

    Args:
        feature: [B, C, H, W], or [B, H, W, C] when ``channel_last`` is True.
        num_splits: Number of windows along each spatial axis (K).
        channel_last: Whether the channel dimension is the last one.

    Returns:
        [B*K*K, C, H/K, W/K] (or [B*K*K, H/K, W/K, C]); windows are ordered
        row-major within each batch element.

    Raises:
        AssertionError: If H or W is not divisible by ``num_splits``.
    """
    if channel_last:  # [B, H, W, C]
        b, h, w, c = feature.size()
    else:  # [B, C, H, W]
        b, c, h, w = feature.size()

    assert h % num_splits == 0 and w % num_splits == 0, (
        f"Feature size ({h}, {w}) must be divisible by num_splits ({num_splits})."
    )

    b_new = b * num_splits * num_splits
    h_new = h // num_splits
    w_new = w // num_splits

    if channel_last:
        feature = (
            feature.view(b, num_splits, h_new, num_splits, w_new, c)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(b_new, h_new, w_new, c)
        )  # [B*K*K, H/K, W/K, C]
    else:
        feature = (
            feature.view(b, c, num_splits, h_new, num_splits, w_new)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(b_new, c, h_new, w_new)
        )  # [B*K*K, C, H/K, W/K]

    return feature


def merge_splits(
    splits,
    num_splits=2,
    channel_last=False,
):
    """Inverse of :func:`split_feature`: stitch windows back into full maps.

    Args:
        splits: [B*K*K, C, H/K, W/K], or [B*K*K, H/K, W/K, C] when
            ``channel_last`` is True.
        num_splits: Number of windows along each spatial axis (K).
        channel_last: Whether the channel dimension is the last one.

    Returns:
        [B, C, H, W] (or [B, H, W, C]).
    """
    if channel_last:  # [B*K*K, H/K, W/K, C]
        b, h, w, c = splits.size()
        new_b = b // num_splits // num_splits

        splits = splits.view(new_b, num_splits, num_splits, h, w, c)
        merge = (
            splits.permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(new_b, num_splits * h, num_splits * w, c)
        )  # [B, H, W, C]
    else:  # [B*K*K, C, H/K, W/K]
        b, c, h, w = splits.size()
        new_b = b // num_splits // num_splits

        splits = splits.view(new_b, num_splits, num_splits, c, h, w)
        merge = (
            splits.permute(0, 3, 1, 4, 2, 5)
            .contiguous()
            .view(new_b, c, num_splits * h, num_splits * w)
        )  # [B, C, H, W]

    return merge


def generate_shift_window_attn_mask(
    input_resolution,
    window_size_h,
    window_size_w,
    shift_size_h,
    shift_size_w,
    device=torch.device("cuda"),
):
    """Compute the attention mask for shifted-window self-attention (SW-MSA).

    ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py

    The image is partitioned into 3 x 3 regions determined by the window and
    shift sizes; inside each window, token pairs that come from different
    regions are masked with -100 and pairs from the same region with 0.

    Args:
        input_resolution: ``(h, w)`` of the feature map.
        window_size_h, window_size_w: Window size along each axis.
        shift_size_h, shift_size_w: Shift applied along each axis.
        device: Device on which the mask is created.

    Returns:
        Float tensor of shape
        [num_windows, window_size_h * window_size_w, window_size_h * window_size_w].

    Raises:
        ValueError: If the number of windows along H and W differ; the window
            partition (``split_feature``) uses one split count for both axes.
    """
    h, w = input_resolution

    num_splits = w // window_size_w
    if h // window_size_h != num_splits:
        raise ValueError(
            f"Window counts must match along both axes, got "
            f"{h // window_size_h} (h={h}, window_size_h={window_size_h}) and "
            f"{num_splits} (w={w}, window_size_w={window_size_w})."
        )

    img_mask = torch.zeros((1, h, w, 1), device=device)  # 1 H W 1
    h_slices = (
        slice(0, -window_size_h),
        slice(-window_size_h, -shift_size_h),
        slice(-shift_size_h, None),
    )
    w_slices = (
        slice(0, -window_size_w),
        slice(-window_size_w, -shift_size_w),
        slice(-shift_size_w, None),
    )

    # Label every region with a distinct id. Dedicated loop names keep the
    # resolution values h / w intact.
    region_id = 0
    for h_slice in h_slices:
        for w_slice in w_slices:
            img_mask[:, h_slice, w_slice, :] = region_id
            region_id += 1

    mask_windows = split_feature(img_mask, num_splits=num_splits, channel_last=True)

    mask_windows = mask_windows.view(-1, window_size_h * window_size_w)
    # Pairwise region-id difference: non-zero means "different region".
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
        attn_mask == 0, float(0.0)
    )

    return attn_mask


def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    """Add the same sinusoidal position encoding to two [B, C, H, W] feature maps.

    When ``attn_splits > 1`` the encoding is computed per window (after
    ``split_feature``), so positions restart inside every window; otherwise it
    is computed over the full map. The encoding is derived from ``feature0``'s
    shape and applied to both inputs.

    Returns:
        Tuple ``(feature0, feature1)`` with the encoding added.
    """
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits > 1:  # add position in split windows
        feature0_splits = split_feature(feature0, num_splits=attn_splits)
        feature1_splits = split_feature(feature1, num_splits=attn_splits)

        position = pos_enc(feature0_splits)

        feature0_splits = feature0_splits + position
        feature1_splits = feature1_splits + position

        feature0 = merge_splits(feature0_splits, num_splits=attn_splits)
        feature1 = merge_splits(feature1_splits, num_splits=attn_splits)
    else:
        position = pos_enc(feature0)

        feature0 = feature0 + position
        feature1 = feature1 + position

    return feature0, feature1


def mv_feature_add_position(features, attn_splits, feature_channels):
    """Multi-view variant of :func:`feature_add_position` for one 4-D tensor.

    Args:
        features: [B*V, C, H, W] stacked per-view feature maps.
        attn_splits: Number of windows per axis; > 1 enables per-window encoding.
        feature_channels: Channel count C (the encoder emits C // 2 per axis).

    Returns:
        Tensor of the same shape with the position encoding added.
    """
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    assert features.dim() == 4  # [B*V, C, H, W]

    if attn_splits > 1:  # add position in split windows
        features_splits = split_feature(features, num_splits=attn_splits)
        position = pos_enc(features_splits)
        features_splits = features_splits + position
        features = merge_splits(features_splits, num_splits=attn_splits)
    else:
        position = pos_enc(features)
        features = features + position

    return features
# Ref: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py#L363


class ViTFeaturePyramid(nn.Module):
    """
    This module implements SimpleFeaturePyramid in :paper:`vitdet`.
    It creates pyramid features built on top of the input feature map.

    Every stage is applied to the same input tensor, so the output list has
    one entry per scale factor, in the order the factors were given.
    """

    def __init__(
        self,
        in_channels,
        scale_factors,
    ):
        """
        Args:
            in_channels (int): number of channels of the input feature map.
            scale_factors (list[float]): list of scaling factors to upsample or downsample
                the input features for creating pyramid features. Supported
                values are 4.0, 2.0, 1.0 and 0.5.

        Raises:
            NotImplementedError: if a scale factor is not one of the supported values.
        """
        super().__init__()

        self.scale_factors = scale_factors

        dim = in_channels
        self.stages = nn.ModuleList()
        for scale in scale_factors:
            # Reset per stage: only the upsampling branches reduce the channel
            # count. Without the reset, a 0.5 stage listed after a 2.0 / 4.0
            # stage would inherit the reduced width and build a Conv2d that
            # does not match its (unreduced) input.
            out_dim = dim
            if scale == 4.0:
                layers = [
                    nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
                    nn.GELU(),
                    nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
                ]
                out_dim = dim // 4
            elif scale == 2.0:
                layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)]
                out_dim = dim // 2
            elif scale == 1.0:
                layers = []
            elif scale == 0.5:
                layers = [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

            # The identity stage (scale 1.0) gets no refinement conv.
            if scale != 1.0:
                layers.extend(
                    [
                        nn.GELU(),
                        nn.Conv2d(out_dim, out_dim, 3, 1, 1),
                    ]
                )

            self.stages.append(nn.Sequential(*layers))

    def forward(self, x):
        """Apply every stage to ``x`` ([B, C, H, W]) and return the list of outputs."""
        return [stage(x) for stage in self.stages]


def _test():
    """Smoke test: print the model and the shape of every pyramid level."""
    # Fall back to CPU so the smoke test also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTFeaturePyramid(
        384,
        scale_factors=[1, 2, 4],
    ).to(device)
    print(model)

    x = torch.randn(2, 384, 64, 96, device=device)

    for level in model(x):
        print(level.shape)


if __name__ == "__main__":
    _test()
# Type parameters: the visualizer's config type and the encoder type it inspects.
T_cfg = TypeVar("T_cfg")
T_encoder = TypeVar("T_encoder")


class EncoderVisualizer(ABC, Generic[T_cfg, T_encoder]):
    """Abstract base for objects that render diagnostic images for an encoder.

    Concrete subclasses bind ``T_cfg`` / ``T_encoder`` to their own config and
    encoder classes and implement :meth:`visualize`.
    """

    cfg: T_cfg
    encoder: T_encoder

    def __init__(self, cfg: T_cfg, encoder: T_encoder) -> None:
        """Store the visualizer config and the encoder to be visualized."""
        self.cfg, self.encoder = cfg, encoder

    @abstractmethod
    def visualize(
        self,
        context: dict,
        global_step: int,
    ) -> dict[str, Float[Tensor, "3 _ _"]]:
        """Produce named visualizations for the given context views.

        Args:
            context: Context-view batch passed to the encoder.
            global_step: Current training step.

        Returns:
            Mapping from a visualization name to a 3-channel image tensor.
        """
def box(
    image: Float[Tensor, "3 height width"],
) -> Float[Tensor, "3 new_height new_width"]:
    """Frame ``image`` with two nested borders via ``add_border``.

    The inner call uses ``add_border``'s defaults; the outer call passes
    ``1, 0`` (presumably border width 1 and color 0 -- confirm against
    ``add_border``'s signature).
    """
    return add_border(add_border(image), 1, 0)


class EncoderVisualizerDepthSplat(
    EncoderVisualizer[EncoderVisualizerDepthSplatCfg, EncoderDepthSplat]
):
    """Visualizer bound to ``EncoderDepthSplat``.

    NOTE(review): ``visualize`` currently returns an empty dict immediately,
    so none of the ``visualize_*`` helpers below are reached through it.
    Several helpers read ``self.encoder.sampler`` and an epipolar ``sampling``
    object whose type import is commented out in this module -- confirm these
    exist on ``EncoderDepthSplat`` before re-enabling anything.
    """

    def visualize(
        self,
        context: BatchedViews,
        global_step: int,
    ) -> dict[str, Float[Tensor, "3 _ _"]]:
        """Return the visualizations for ``context``; currently always ``{}``.

        Everything after the first ``return`` is unreachable. It is kept as a
        reference for the full visualization pipeline. Note that it calls
        ``export_ply``, whose import is commented out at the top of this
        module, so it would raise ``NameError`` if it were re-enabled as is.
        """
        # Short-circuit execution when using mvsplat.
        return {}

        # ---- unreachable from here on ----
        visualization_dump = {}

        softmax_weights = []

        def hook(module, input, output):
            # Forward hook: collect each hooked layer's output.
            softmax_weights.append(output)

        # Register hooks to grab attention.
        handles = [
            layer[0].fn.attend.register_forward_hook(hook)
            for layer in self.encoder.epipolar_transformer.transformer.layers
        ]

        result = self.encoder.forward(
            context,
            global_step,
            visualization_dump=visualization_dump,
            deterministic=True,
        )

        # De-register hooks.
        for handle in handles:
            handle.remove()

        softmax_weights = torch.stack(softmax_weights)

        # Generate high-resolution context images that can be drawn on.
        context_images = context["image"]
        _, _, _, h, w = context_images.shape
        length = min(h, w)
        min_resolution = self.cfg.min_resolution
        # Smallest integer factor that brings the shorter side up to
        # min_resolution (ceil division).
        scale_multiplier = (min_resolution + length - 1) // length
        if scale_multiplier > 1:
            # Integer upscaling by pixel repetition.
            context_images = repeat(
                context_images,
                "b v c h w -> b v c (h rh) (w rw)",
                rh=scale_multiplier,
                rw=scale_multiplier,
            )

        # This is kind of hacky for now, since we're using it for short experiments.
        if self.cfg.export_ply and wandb.run is not None:
            name = wandb.run._name.split(" ")[0]
            ply_path = Path(f"outputs/gaussians/{name}/{global_step:0>6}.ply")
            export_ply(
                context["extrinsics"][0, 0],
                result.means[0],
                visualization_dump["scales"][0],
                visualization_dump["rotations"][0],
                result.harmonics[0],
                result.opacities[0],
                ply_path,
            )

        return {
            "attention": self.visualize_attention(
                context_images,
                visualization_dump["sampling"],
                softmax_weights,
            ),
            "epipolar_samples": self.visualize_epipolar_samples(
                context_images,
                visualization_dump["sampling"],
            ),
            "epipolar_color_samples": self.visualize_epipolar_color_samples(
                context_images,
                context,
            ),
            "gaussians": self.visualize_gaussians(
                context["image"],
                result.opacities,
                result.covariances,
                result.harmonics[..., 0],  # Just visualize DC component.
            ),
            "overlaps": self.visualize_overlaps(
                context["image"],
                visualization_dump["sampling"],
                visualization_dump.get("is_monocular", None),
            ),
            "depth": self.visualize_depth(
                context,
                visualization_dump["depth"],
            ),
        }

    def visualize_attention(
        self,
        context_images: Float[Tensor, "batch view 3 height width"],
        sampling: None,
        attention: Float[Tensor, "layer bvr head 1 sample"],
    ) -> Float[Tensor, "3 vis_height vis_width"]:
        """Draw randomly chosen rays and, per layer and head, their attention.

        ``sampling`` is annotated ``None`` but must expose ``xy_sample``,
        ``xy_ray``, ``xy_sample_near`` and ``xy_sample_far`` (see the attribute
        accesses below). Requires ``cfg.num_samples`` to be at most the number
        of rays, since rays are drawn without replacement.
        """
        device = context_images.device

        # Pick a random batch element, view, and other view.
        b, v, ov, r, s, _ = sampling.xy_sample.shape
        rb = randrange(b)
        rv = randrange(v)
        rov = randrange(ov)
        num_samples = self.cfg.num_samples
        rr = np.random.choice(r, num_samples, replace=False)
        rr = torch.tensor(rr, dtype=torch.int64, device=device)

        # Visualize the rays in the ray view.
        # Each point is drawn twice: radius 4 with color 0 first, then radius 3
        # in a distinct per-ray color on top.
        ray_view = draw_points(
            context_images[rb, rv],
            sampling.xy_ray[rb, rv, rr],
            0,
            radius=4,
            x_range=(0, 1),
            y_range=(0, 1),
        )
        ray_view = draw_points(
            ray_view,
            sampling.xy_ray[rb, rv, rr],
            [get_distinct_color(i) for i, _ in enumerate(rr)],
            radius=3,
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Visualize attention in the sample view.
        attention = rearrange(
            attention, "l (b v r) hd () s -> l b v r hd s", b=b, v=v, r=r
        )
        attention = attention[:, rb, rv, rr, :, :]
        num_layers, _, hd, _ = attention.shape

        vis = []
        for il in range(num_layers):
            vis_layer = []
            for ihd in range(hd):
                # Create colors according to attention.
                color = [get_distinct_color(i) for i, _ in enumerate(rr)]
                color = torch.tensor(color, device=attention.device)
                color = rearrange(color, "r c -> r () c")
                attn = rearrange(attention[il, :, ihd], "r s -> r s ()")
                # Scale each ray's color by its per-sample attention weight.
                color = rearrange(attn * color, "r s c -> (r s ) c")

                # Draw the alternating bucket lines.
                vis_layer_head = draw_lines(
                    context_images[rb, self.encoder.sampler.index_v[rv, rov]],
                    rearrange(
                        sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"
                    ),
                    rearrange(
                        sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"
                    ),
                    color,
                    3,
                    cap="butt",
                    x_range=(0, 1),
                    y_range=(0, 1),
                )
                vis_layer.append(vis_layer_head)
            # Heads are stacked vertically within a layer.
            vis.append(add_label(vcat(*vis_layer), f"Layer {il}"))
        # Layers are laid out side by side.
        vis = add_label(add_border(add_border(hcat(*vis)), 1, 0), "Keys & Values")
        vis = add_border(hcat(add_label(ray_view, "ray_view"), vis, align="top"))
        return vis

    def visualize_depth(
        self,
        context: BatchedViews,
        multi_depth: Float[Tensor, "batch view height width surface spp"],
    ) -> Float[Tensor, "3 vis_width vis_height"]:
        """Render color-mapped depth and disparity, one row block per surface.

        Depth is averaged over the last (``spp``) axis. Both quantities are
        rescaled with ``context["near"]`` / ``context["far"]`` so that ``near``
        maps to 0 and ``far`` maps to 1.
        """
        multi_vis = []
        *_, srf, _ = multi_depth.shape
        for i in range(srf):
            depth = multi_depth[..., i, :]
            depth = depth.mean(dim=-1)

            # Compute relative depth and disparity.
            near = rearrange(context["near"], "b v -> b v () ()")
            far = rearrange(context["far"], "b v -> b v () ()")
            relative_depth = (depth - near) / (far - near)
            relative_disparity = 1 - (1 / depth - 1 / far) / (1 / near - 1 / far)

            # Grid layout: views side by side, batch elements stacked.
            relative_depth = apply_color_map_to_image(relative_depth, "turbo")
            relative_depth = vcat(*[hcat(*x) for x in relative_depth])
            relative_depth = add_label(relative_depth, "Depth")
            relative_disparity = apply_color_map_to_image(relative_disparity, "turbo")
            relative_disparity = vcat(*[hcat(*x) for x in relative_disparity])
            relative_disparity = add_label(relative_disparity, "Disparity")
            multi_vis.append(add_border(hcat(relative_depth, relative_disparity)))

        return add_border(vcat(*multi_vis))

    def visualize_overlaps(
        self,
        context_images: Float[Tensor, "batch view 3 height width"],
        sampling: None,
        is_monocular: Optional[Bool[Tensor, "batch view height width"]] = None,
    ) -> Float[Tensor, "3 vis_width vis_height"]:
        """Show, for one random batch element, where views overlap.

        ``sampling.valid`` is upsampled by the epipolar transformer's
        ``downscale`` factor and used to mask a green-tinted copy of the other
        views. If ``is_monocular`` is given it is appended as an extra column.
        """
        device = context_images.device
        b, v, _, h, w = context_images.shape
        green = torch.tensor([0.235, 0.706, 0.294], device=device)[..., None, None]
        rb = randrange(b)
        valid = sampling.valid[rb].float()
        ds = self.encoder.cfg.epipolar_transformer.downscale
        # Unflatten the per-ray validity to (h/ds, w/ds), blow it up to full
        # resolution and replicate it over 3 channels.
        valid = repeat(
            valid,
            "v ov (h w) -> v ov c (h rh) (w rw)",
            c=3,
            h=h // ds,
            w=w // ds,
            rh=ds,
            rw=ds,
        )

        if is_monocular is not None:
            is_monocular = is_monocular[rb].float()
            is_monocular = repeat(is_monocular, "v h w -> v c h w", c=3, h=h, w=w)

        # Select context images in grid.
        context_images = context_images[rb]
        index, _ = generate_heterogeneous_index(v)
        # Average of green and the indexed image, zeroed where not valid.
        valid = valid * (green + context_images[index]) / 2

        vis = vcat(*(hcat(im, hcat(*v)) for im, v in zip(context_images, valid)))
        vis = add_label(vis, "Context Overlaps")

        if is_monocular is not None:
            vis = hcat(vis, add_label(vcat(*is_monocular), "Monocular?"))

        return add_border(vis)

    def visualize_gaussians(
        self,
        context_images: Float[Tensor, "batch view 3 height width"],
        opacities: Float[Tensor, "batch vrspp"],
        covariances: Float[Tensor, "batch vrspp 3 3"],
        colors: Float[Tensor, "batch vrspp 3"],
    ) -> Float[Tensor, "3 vis_height vis_width"]:
        """Render per-pixel Gaussian attributes for one random batch element.

        The flat Gaussian axis is unpacked as ``(v h w spp)``, i.e. it assumes
        ``spp`` Gaussians per pixel of every context view. Panels: context
        images, opacities, opacity-weighted colors, raw colors, and the
        covariance determinant normalized by its maximum.
        """
        b, v, _, h, w = context_images.shape
        rb = randrange(b)
        context_images = context_images[rb]
        opacities = repeat(
            opacities[rb], "(v h w spp) -> spp v c h w", v=v, c=3, h=h, w=w
        )
        colors = rearrange(colors[rb], "(v h w spp) c -> spp v c h w", v=v, h=h, w=w)

        # Color-map Gaussian covariances.
        det = covariances[rb].det()
        det = apply_color_map(det / det.max(), "inferno")
        det = rearrange(det, "(v h w spp) c -> spp v c h w", v=v, h=h, w=w)

        return add_border(
            hcat(
                add_label(box(hcat(*context_images)), "Context"),
                add_label(box(vcat(*[hcat(*x) for x in opacities])), "Opacities"),
                add_label(
                    box(vcat(*[hcat(*x) for x in (colors * opacities)])), "Colors"
                ),
                add_label(box(vcat(*[hcat(*x) for x in colors])), "Colors (Raw)"),
                add_label(box(vcat(*[hcat(*x) for x in det])), "Determinant"),
            )
        )

    def visualize_probabilities(
        self,
        context_images: Float[Tensor, "batch view 3 height width"],
        sampling: None,
        pdf: Float[Tensor, "batch view ray sample"],
    ) -> Float[Tensor, "3 vis_height vis_width"]:
        """Draw random rays and their per-sample probabilities along epipolar lines.

        Produces three panels: the rays, the samples colored by ``pdf``, and
        the samples colored by ``pdf`` rescaled so each ray's maximum is 1.
        """
        device = context_images.device

        # Pick a random batch element, view, and other view.
        b, v, ov, r, _, _ = sampling.xy_sample.shape
        rb = randrange(b)
        rv = randrange(v)
        rov = randrange(ov)
        num_samples = self.cfg.num_samples
        rr = np.random.choice(r, num_samples, replace=False)
        rr = torch.tensor(rr, dtype=torch.int64, device=device)
        colors = [get_distinct_color(i) for i, _ in enumerate(rr)]
        colors = torch.tensor(colors, dtype=torch.float32, device=device)

        # Visualize the rays in the ray view.
        ray_view = draw_points(
            context_images[rb, rv],
            sampling.xy_ray[rb, rv, rr],
            0,
            radius=4,
            x_range=(0, 1),
            y_range=(0, 1),
        )
        ray_view = draw_points(
            ray_view,
            sampling.xy_ray[rb, rv, rr],
            colors,
            radius=3,
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Visualize probabilities in the sample view.
        pdf = pdf[rb, rv, rr]
        pdf = rearrange(pdf, "r s -> r s ()")
        colors = rearrange(colors, "r c -> r () c")
        sample_view = draw_lines(
            context_images[rb, self.encoder.sampler.index_v[rv, rov]],
            rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(pdf * colors, "r s c -> (r s) c"),
            6,
            cap="butt",
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Visualize rescaled probabilities in the sample view.
        # Per-ray max normalization makes low-probability rays visible.
        pdf_magnified = pdf / reduce(pdf, "r s () -> r () ()", "max")
        sample_view_magnified = draw_lines(
            context_images[rb, self.encoder.sampler.index_v[rv, rov]],
            rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(pdf_magnified * colors, "r s c -> (r s) c"),
            6,
            cap="butt",
            x_range=(0, 1),
            y_range=(0, 1),
        )

        return add_border(
            hcat(
                add_label(ray_view, "Rays"),
                add_label(sample_view, "Samples"),
                add_label(sample_view_magnified, "Samples (Magnified PDF)"),
            )
        )

    def visualize_epipolar_samples(
        self,
        context_images: Float[Tensor, "batch view 3 height width"],
        sampling: None,
    ) -> Float[Tensor, "3 vis_height vis_width"]:
        """Draw random rays next to their epipolar lines, buckets and sample points."""
        device = context_images.device

        # Pick a random batch element, view, and other view.
        b, v, ov, r, s, _ = sampling.xy_sample.shape
        rb = randrange(b)
        rv = randrange(v)
        rov = randrange(ov)
        num_samples = self.cfg.num_samples
        rr = np.random.choice(r, num_samples, replace=False)
        rr = torch.tensor(rr, dtype=torch.int64, device=device)

        # Visualize the rays in the ray view.
        ray_view = draw_points(
            context_images[rb, rv],
            sampling.xy_ray[rb, rv, rr],
            0,
            radius=4,
            x_range=(0, 1),
            y_range=(0, 1),
        )
        ray_view = draw_points(
            ray_view,
            sampling.xy_ray[rb, rv, rr],
            [get_distinct_color(i) for i, _ in enumerate(rr)],
            radius=3,
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Visualize the samples and epipolar lines in the sample view.
        # First, draw the epipolar line in black.
        # The line runs from the near end of the first bucket to the far end
        # of the last bucket.
        sample_view = draw_lines(
            context_images[rb, self.encoder.sampler.index_v[rv, rov]],
            sampling.xy_sample_near[rb, rv, rov, rr, 0],
            sampling.xy_sample_far[rb, rv, rov, rr, -1],
            0,
            5,
            cap="butt",
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Create an alternating line color for the buckets.
        # The 0/1 pattern is tiled to at least s entries, then cut to s.
        color = repeat(
            torch.tensor([0, 1], device=device),
            "ab -> r (s ab) c",
            r=len(rr),
            s=(s + 1) // 2,
            c=3,
        )
        color = rearrange(color[:, :s], "r s c -> (r s) c")

        # Draw the alternating bucket lines.
        sample_view = draw_lines(
            sample_view,
            rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            color,
            3,
            cap="butt",
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Draw the sample points.
        sample_view = draw_points(
            sample_view,
            rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            0,
            radius=4,
            x_range=(0, 1),
            y_range=(0, 1),
        )
        # i // s recovers the ray index of a flattened (ray, sample) entry, so
        # all samples of one ray share a color.
        sample_view = draw_points(
            sample_view,
            rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            [get_distinct_color(i // s) for i in range(s * len(rr))],
            radius=3,
            x_range=(0, 1),
            y_range=(0, 1),
        )

        return add_border(
            hcat(add_label(ray_view, "Ray View"), add_label(sample_view, "Sample View"))
        )

    def visualize_epipolar_color_samples(
        self,
        context_images: Float[Tensor, "batch view 3 height width"],
        context: BatchedViews,
    ) -> Float[Tensor, "3 vis_height vis_width"]:
        """Draw random rays and the features sampled along their epipolar lines.

        Unlike the other helpers, the sampling is recomputed here by calling
        ``self.encoder.sampler`` on the original (not upscaled) context.
        """
        device = context_images.device

        sampling = self.encoder.sampler(
            context["image"],
            context["extrinsics"],
            context["intrinsics"],
            context["near"],
            context["far"],
        )

        # Pick a random batch element, view, and other view.
        b, v, ov, r, s, _ = sampling.xy_sample.shape
        rb = randrange(b)
        rv = randrange(v)
        rov = randrange(ov)
        num_samples = self.cfg.num_samples
        rr = np.random.choice(r, num_samples, replace=False)
        rr = torch.tensor(rr, dtype=torch.int64, device=device)

        # Visualize the rays in the ray view.
        ray_view = draw_points(
            context_images[rb, rv],
            sampling.xy_ray[rb, rv, rr],
            0,
            radius=4,
            x_range=(0, 1),
            y_range=(0, 1),
        )
        ray_view = draw_points(
            ray_view,
            sampling.xy_ray[rb, rv, rr],
            [get_distinct_color(i) for i, _ in enumerate(rr)],
            radius=3,
            x_range=(0, 1),
            y_range=(0, 1),
        )

        # Visualize the samples in the sample view.
        # Outer ring (radius 4): per-ray color.
        sample_view = draw_points(
            context_images[rb, self.encoder.sampler.index_v[rv, rov]],
            rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            [get_distinct_color(i // s) for i in range(s * len(rr))],
            radius=4,
            x_range=(0, 1),
            y_range=(0, 1),
        )
        # Inner dot (radius 3): the sampled feature used directly as the color.
        sample_view = draw_points(
            sample_view,
            rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"),
            rearrange(sampling.features[rb, rv, rov, rr], "r s c -> (r s) c"),
            radius=3,
            x_range=(0, 1),
            y_range=(0, 1),
        )

        return add_border(
            hcat(add_label(ray_view, "Ray View"), add_label(sample_view, "Sample View"))
        )
@dataclass
class EncoderVisualizerDepthSplatCfg:
    """Options for the DepthSplat encoder visualizer."""

    # Number of rays drawn (without replacement) per visualization.
    num_samples: int
    # Context images are integer-upscaled until their shorter side reaches this.
    min_resolution: int
    # Whether the visualizer should also dump the Gaussians to a .ply file.
    export_ply: bool


def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the PLY vertex property names in the order they are written.

    Layout: position (x, y, z), normal (nx, ny, nz), 3 DC color coefficients,
    ``num_rest`` higher-order SH coefficients, opacity, 3 scales, 4 rotation
    components.

    Args:
        num_rest: Number of ``f_rest_*`` properties.
    """
    attributes = ["x", "y", "z", "nx", "ny", "nz"]
    attributes += [f"f_dc_{i}" for i in range(3)]
    attributes += [f"f_rest_{i}" for i in range(num_rest)]
    attributes.append("opacity")
    attributes += [f"scale_{i}" for i in range(3)]
    attributes += [f"rot_{i}" for i in range(4)]
    return attributes


def export_ply(
    # extrinsics: Float[Tensor, "4 4"],
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, "gaussian"],
    path: Path,
    # align_to_view: bool = False,  # whether to align world space to the view space (camera space) of the extrinsics
):
    """Write a set of Gaussians to a .ply file.

    ``scales`` and ``opacities`` are expected in activated form; they are
    stored pre-activation (``log`` of the scales, ``logit`` of the opacities).
    Normals are written as zeros. The parent directory is created if needed.

    Args:
        means: Gaussian centers.
        scales: Activated (positive) scales.
        rotations: Quaternions, written as given.
        harmonics: SH coefficients; index 0 of the last axis is the DC term.
        opacities: Activated opacities in (0, 1).
        path: Output file; a ``str`` is accepted as well as a ``Path``.
    """
    # Callers such as save_gaussian_ply forward the path unchecked; a plain
    # string would otherwise fail on `.parent` below.
    path = Path(path)

    means = means.detach().cpu().numpy()
    scales = scales.log().detach().cpu().numpy()
    rotations = rotations.detach().cpu().numpy()
    harmonics = harmonics.detach().cpu()
    opacities = torch.logit(opacities[..., None]).detach().cpu().numpy()

    # Every SH band beyond DC contributes one coefficient per color channel.
    num_rest = 3 * (harmonics.shape[-1] - 1)

    dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(num_rest)]
    elements = np.empty(means.shape[0], dtype=dtype_full)
    # Column order must match construct_list_of_attributes.
    attributes = np.concatenate(
        (
            means,
            np.zeros_like(means),  # normals
            harmonics[..., 0].numpy(),
            harmonics[..., 1:].flatten(start_dim=1).numpy(),
            opacities,
            scales,
            rotations,
        ),
        axis=1,
    )
    elements[:] = list(map(tuple, attributes))
    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)


def save_gaussian_ply(
    gaussians: Gaussians | GaussiansModule,
    save_path,
    save_all_gaussians=True,  # no trim
):
    """
    Save Gaussians to a .ply file for visualization.

    The saved object will have opacities and scales in the pre-activation space,
    i.e., before applying the activation functions (sigmoid for opacity, exp for scales).

    Args:
        gaussians: Either a ``GaussiansModule`` (no batch dimension) or a
            ``Gaussians`` object with batch size 1.
        save_path: Destination .ply path.
        save_all_gaussians: Must be True; trimming is not implemented.

    Raises:
        NotImplementedError: If ``save_all_gaussians`` is False.
        ValueError: If ``gaussians`` is of an unsupported type.
    """

    if not save_all_gaussians:
        raise NotImplementedError("Not implemented yet.")

    if isinstance(gaussians, GaussiansModule):

        # no batch dimension
        means = gaussians.means  # [H*W, 3]
        rotations = gaussians.rotations  # [H*W, 4] in xyzw
        scales = gaussians.scales  # [H*W, 3]
        opacities = gaussians.opacities  # [H*W]
        harmonics = gaussians.harmonics  # [H*W, 3, d_sh]

    elif isinstance(gaussians, Gaussians):
        assert gaussians.means.shape[0] == 1, "Batch size > 1 not supported for saving ply."
        means = gaussians.means[0]  # [H*W, 3]
        rotations = F.normalize(gaussians.rotations_unnorm[0], dim=-1)  # [H*W, 4] in xyzw
        scales = gaussians.scales[0]  # [H*W, 3]
        opacities = gaussians.opacities[0]  # [H*W]
        harmonics = gaussians.harmonics[0]  # [H*W, 3, d_sh]

        # export_ply expects activated values (post-exp scales, post-sigmoid opacities)
        # and applies inverse activation internally. If values are already deactivated,
        # we must activate them first to avoid double inverse activation.
        if not gaussians.stores_activated:
            scales = torch.exp(scales)
            opacities = torch.sigmoid(opacities)
    else:
        raise ValueError(f"Unknown type of gaussians: {type(gaussians)}")

    # convert to wxyz for saving
    rotations = rotations[:, [3, 0, 1, 2]]  # [H*W, 4] in wxyz

    # export_ply inverts the opacity / scale activations (the layout expected
    # by standard Gaussian viewers).
    export_ply(
        means=means,
        scales=scales,
        rotations=rotations,
        harmonics=harmonics,  # [H*W, 3, d_sh]
        opacities=opacities,
        path=save_path,
    )
+ # all other features are zero + print("Loaded PLY has no SH coefficients, only DC features.") + features_extra = np.zeros((xyz.shape[0], 3, (max_sh_degree + 1) ** 2 - 1)) + + elif len(extra_f_names) == (3 * (max_sh_degree + 1) ** 2 - 3): + # loaded ply has full SH coefficients + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape((features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1)) + else: + # not know how to handle + raise ValueError("Mismatch in number of SH coefficients.") + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] + rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + # Create Gaussian object + means = torch.tensor(xyz, dtype=torch.float32) # [P, 3] + + opacities = torch.tensor(opacities, dtype=torch.float32).squeeze(-1) # [P] + opacities = torch.sigmoid(opacities) # convert to post-activation space + + harmonics = torch.zeros((xyz.shape[0], 3, (max_sh_degree + 1) ** 2), dtype=torch.float32) # [P, 3, d_sh] + harmonics[:, :, 0] = torch.tensor(features_dc[:, :, 0], dtype=torch.float32) + harmonics[:, :, 1:] = torch.tensor(features_extra, dtype=torch.float32) + + scales = torch.tensor(scales, dtype=torch.float32) + scales = torch.exp(scales) # convert to post-activation 
space + + quats = torch.tensor(rots, dtype=torch.float32) # in wxyz + quats = quats[:, [1, 2, 3, 0]] # convert to xyzw + quats = F.normalize(quats, dim=-1) # match 3DGS-LM get_rotation which normalizes before rendering + + return Gaussians( + means=means.unsqueeze(0), + harmonics=harmonics.unsqueeze(0), + opacities=opacities.unsqueeze(0), + scales=scales.unsqueeze(0), + rotations=quats.unsqueeze(0), + rotations_unnorm=quats.unsqueeze(0), + ) diff --git a/optgs/model/types.py b/optgs/model/types.py new file mode 100644 index 0000000000000000000000000000000000000000..f26de4c957cad3e5262a6c68c72ab0a44e7467f4 --- /dev/null +++ b/optgs/model/types.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass, fields + +import torch +from jaxtyping import Float, Bool, Int64, BFloat16 +from torch import Tensor + + +@dataclass +class Gaussians: + means: Float[Tensor, "batch gaussian dim"] + harmonics: Float[Tensor, "batch gaussian 3 d_sh"] + opacities: Float[Tensor, "batch gaussian"] + scales: Float[Tensor, "batch gaussian 3"] + rotations_unnorm: Float[Tensor, "batch gaussian 4"] + rotations: Float[Tensor, "batch gaussian 4"] | None = None + covariances: Float[Tensor, "batch gaussian dim dim"] | None = None + probabilities: Float[Tensor, "batch gaussian distr"] | None = None + # mask: Bool[Tensor, "batch gaussian"] | None = None + sel: Int64[Tensor, "valid_gaussian_1"] | None = None + filter_3D: Float[Tensor, "batch gaussian"] | None = None + gradients: Float[Tensor, "batch valid_gaussian_1 total_dim"] | BFloat16[Tensor, "batch valid_gaussian_1 total_dim"] | None = None + norm_gradients: Float[Tensor, "batch valid_gaussian_1 total_dim"] | BFloat16[Tensor, "batch valid_gaussian_1 total_dim"] | None = None + deltas: Float[Tensor, "batch valid_gaussian_2 d_delta"] | BFloat16[Tensor, "batch valid_gaussian_2 d_delta"] | None = None # In case of predicting scale and mag, the raw deltas are 2*total_dim, cannot use SGD loss directly + visibility: Float[Tensor, "batch gaussian"] | None = 
None # visibility information at the end of the current batch for pruning + visibility_aggregator: Float[Tensor, "batch gaussian"] | None = None # aggregates visibility over epoch + stores_activated: bool = True # whether scales and opacities are stored in activated form + nr_valid: int = -1 # the number of valid gaussians (without padding) + + EXCLUDED_FROM_MASKING = {"sel", "stores_activated", "deltas", "gradients", "norm_gradients", "valid_gaussians"} # deltas are predicted to non masked values + + def to(self, device=None, dtype=None) -> "Gaussians": + """ Move all tensors to the specified device or dtype. """ + def to_with_none(tensor): + if isinstance(tensor, bool): + return tensor + elif isinstance(tensor, int): + return tensor + return tensor.to(device=device, dtype=dtype) if tensor is not None else None + + new_tensors = {field.name: to_with_none(getattr(self, field.name)) for field in fields(self)} + + return Gaussians(**new_tensors) + + def clone(self) -> "Gaussians": + """ Clone all tensors. 
""" + # handle None and bool fields + new_tensors = {} + for field in fields(self): + tensor = getattr(self, field.name) + if isinstance(tensor, bool): + new_tensors[field.name] = tensor + elif isinstance(tensor, int): + new_tensors[field.name] = tensor + elif tensor is not None: + new_tensors[field.name] = tensor.clone() + else: + new_tensors[field.name] = None + + return Gaussians(**new_tensors) + + # Override __getitem__ to support indexing + def __getitem__(self, idx) -> "Gaussians": + new_tensors = {} + for field in fields(self): + tensor = getattr(self, field.name) + if isinstance(tensor, bool): + new_tensors[field.name] = tensor + elif isinstance(tensor, int): + new_tensors[field.name] = tensor + elif tensor is not None and field.name not in self.EXCLUDED_FROM_MASKING: + new_tensors[field.name] = tensor[idx] + else: + new_tensors[field.name] = None + return Gaussians(**new_tensors) + + def sample_subset(self, sampled_indices) -> "Gaussians": + """ Randomly sample a subset of gaussians. """ + total_gaussians = self.means.shape[1] + sample_num = len(sampled_indices) + + new_tensors = {} + for field in fields(self): + tensor = getattr(self, field.name) + if tensor is not None: + if isinstance(tensor, bool): + new_tensors[field.name] = tensor + elif isinstance(tensor, int): + new_tensors[field.name] = tensor + else: + new_tensors[field.name] = tensor[:, sampled_indices] + else: + new_tensors[field.name] = None + print(f"Sampled {sample_num} / {total_gaussians} gaussians.") + return Gaussians(**new_tensors) + + def __len__(self): + return self.means.shape[1] + + def update_object_by_curr_mask(self, **new_values) -> "Gaussians": + """ Update certain element using the current mask. """ + sel = self.sel + new_tensors = {} + for field in fields(self): + tensor = getattr(self, field.name) # [B, G, ...] + if tensor is not None: + if field.name in new_values: + new_value = new_values[field.name] # [B, G_valid, ...] 
+ if sel is None or new_value is None or field.name in self.EXCLUDED_FROM_MASKING: + tensor = new_value + else: + tensor = tensor.clone() + tensor[:, sel, ...] = new_value + new_tensors[field.name] = tensor + else: + if field.name in new_values: + if field.name in ["deltas", "gradients", "norm_gradients"]: + # Special case: allow updating deltas even if it is None + new_tensors[field.name] = new_values[field.name] + continue + assert new_values[field.name] is None, f"Cannot update a None field! {field.name}, got {new_values[field.name]}" + new_tensors[field.name] = None + return Gaussians(**new_tensors) diff --git a/optgs/paths.py b/optgs/paths.py new file mode 100644 index 0000000000000000000000000000000000000000..89f1d77b6d5807d056a6be4182f7cc1f22331aea --- /dev/null +++ b/optgs/paths.py @@ -0,0 +1,70 @@ +import os +import sys +from pathlib import Path + +from optgs.misc.io import CustomPath + +_PKG_DIR = CustomPath(__file__).resolve().parent # .../optgs (the package) +_REPO_ROOT = _PKG_DIR.parent # repo root (source) OR site-packages (installed) + +# Source checkout iff a project marker sits next to the package dir. +_IS_SOURCE_CHECKOUT = (_REPO_ROOT / "pyproject.toml").exists() or (_REPO_ROOT / ".git").exists() + +# Working root for run outputs (checkpoints/results/figures) and datasets. 
+# source checkout : the repo root (preserves existing behaviour) +# installed : $OPTGS_HOME if set, else the current working directory +if _IS_SOURCE_CHECKOUT: + PROJECT_DIR = CustomPath(str(_REPO_ROOT)) +else: + PROJECT_DIR = CustomPath(str(os.environ.get("OPTGS_HOME", Path.cwd()))) + +SRC_DIR = CustomPath(str(_PKG_DIR)) # importable package dir (always correct) +CKPT_DIR = PROJECT_DIR / "checkpoints" +RESULTS_DIR = PROJECT_DIR / "results" +FIGURES_DIR = PROJECT_DIR / "figures" +DATA_DIR = PROJECT_DIR / "datasets" + +DL3DV_480P_DIR = DATA_DIR / "dl3dv-480p-chunks" +DL3DV_COLMAP_SfM_DIR = DATA_DIR / "dl3dv-colmap-sfm" + +# Eval-index / font assets are NOT bundled in the wheel (they are data, like +# datasets). Resolution order: $OPTGS_ASSETS, else /assets in a source +# checkout. See README / DATASETS.md. +ASSETS_DIR = ( + CustomPath(str(os.environ["OPTGS_ASSETS"])) + if os.environ.get("OPTGS_ASSETS") + else (PROJECT_DIR / "assets") +) + + +def asset_path(rel) -> CustomPath: + """Resolve an asset file (eval-index JSON, font, ...). + + Absolute or already-existing paths pass through unchanged. A leading + ``assets/`` is stripped so callers may pass either ``assets/x.json`` or + ``x.json``. The base dir is ``$OPTGS_ASSETS`` (recommended for installed + use), else ``/assets`` in a source checkout. + """ + p = Path(str(rel)) + if p.is_absolute() or p.exists(): + return CustomPath(str(p)) + s = str(rel) + if s.startswith("assets/"): + s = s[len("assets/"):] + return CustomPath(str(ASSETS_DIR)) / s + + +# Auto-create the figures dir only in a source checkout (preserves dev +# behaviour). When installed, importing the library must not create +# directories in the user's CWD — callers create output dirs lazily. 
+if _IS_SOURCE_CHECKOUT: + try: + FIGURES_DIR.mkdir(parents=True, exist_ok=True) + except OSError: + pass + +if sys.gettrace() is not None: + print("Running in debug mode") + DEBUG = True +else: + DEBUG = False diff --git a/optgs/scene_trainer/__init__.py b/optgs/scene_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/scene_trainer/adc/__init__.py b/optgs/scene_trainer/adc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..200bf2d4d37ac0c0d3eb1b7c17d8003bf3ba9635 --- /dev/null +++ b/optgs/scene_trainer/adc/__init__.py @@ -0,0 +1,125 @@ +from typing import Any +from optgs.model.types import Gaussians +import torch +from torch import Tensor +from jaxtyping import Bool +from optgs.scene_trainer.gaussian_module import GaussiansModule +from optgs.scene_trainer.adc.base import BaseStrategyCfg +from optgs.scene_trainer.adc.mcmc import McmcStrategyState, update_mcmc_strategy_state +from optgs.scene_trainer.adc.vanilla import VanillaStrategyState, update_vanilla_strategy_state + +def init_strategy_state( + cfg: BaseStrategyCfg, + **kwargs +) -> VanillaStrategyState | McmcStrategyState: + + if cfg.name == "mcmc": + return McmcStrategyState.initialize( + device=kwargs["device"] + ) + elif cfg.name in ["default", "edgs", "none"]: + return VanillaStrategyState.initialize( + nr_points=kwargs["nr_points"], + device=kwargs["device"], + scene_extent=kwargs["scene_extent"], + ) + else: + raise NotImplementedError(f"ADC strategy state initialization not implemented for {cfg.name}") + +def update_strategy_state( + adc_state: VanillaStrategyState | McmcStrategyState, + **kwargs +) -> None: + """Updates adc_state in place.""" + + if isinstance(adc_state, VanillaStrategyState): + return update_vanilla_strategy_state( + adc_state, + radii_2d=kwargs["radii_2d"], + means2d_grads=kwargs["means2d_grads"], + visibility_mask=kwargs["visibility_mask"], + v=kwargs["v"], 
+ w=kwargs["w"], + h=kwargs["h"], + ) + elif isinstance(adc_state, McmcStrategyState): + return update_mcmc_strategy_state(adc_state) + else: + raise NotImplementedError(f"ADC strategy state update not implemented for {type(adc_state)}") + + +def apply_adc_strategy( + cfg: BaseStrategyCfg, + step: int, + gaussians: Gaussians | GaussiansModule, + adc_state: VanillaStrategyState | McmcStrategyState, + smoothers: dict[str, Any], + zero_t: bool = False, + **kwargs +) -> tuple[int, int, int, float | None, float | None]: + """Applies ADC strategy and returns number of cloned, splitted, pruned GSs.""" + + if cfg.name in ["default", "edgs", "none"]: + from optgs.scene_trainer.adc.vanilla import apply_vanilla_strategy + assert isinstance(adc_state, VanillaStrategyState), "adc_state type mismatch." + return apply_vanilla_strategy( + cfg, + step=step, + gaussians=gaussians, + adc_state=adc_state, + smoothers=smoothers, + zero_t=zero_t + ) + elif cfg.name == "mcmc": + from optgs.scene_trainer.adc.mcmc import apply_mcmc_strategy + assert isinstance(adc_state, McmcStrategyState), "adc_state type mismatch." 
+ return apply_mcmc_strategy( + cfg, + step=step, + gaussians=gaussians, + adc_state=adc_state, + smoothers=smoothers, + lr=kwargs["lr"], + # zero_t=zero_t + ) + else: + raise NotImplementedError(f"ADC strategy not implemented for {cfg.name}") + + +@torch.no_grad() +def post_backward( + cfg: BaseStrategyCfg, + step: int, + gaussians: Gaussians | GaussiansModule, + adc_state: VanillaStrategyState | McmcStrategyState, + smoothers: dict[str, Any], + radii_2d: Tensor, # [B, V, G, 2] + means2d_grads: Tensor | None, # [B, V, G, 2] + visibility_mask: Bool[Tensor, "b v gaussian"], # [B, V, G] + iter_batch_size: int, + w: int, + h: int, + zero_t: bool = False, + **kwargs +) -> tuple[int, int, int, float | None, float | None]: + + # update adc state + update_strategy_state( + adc_state, + radii_2d=radii_2d, + means2d_grads=means2d_grads, + visibility_mask=visibility_mask, + v=iter_batch_size, + w=w, + h=h, + ) + + return apply_adc_strategy( + cfg, + step=step, + gaussians=gaussians, + adc_state=adc_state, + smoothers=smoothers, + zero_t=zero_t, + **kwargs + ) diff --git a/optgs/scene_trainer/adc/base.py b/optgs/scene_trainer/adc/base.py new file mode 100644 index 0000000000000000000000000000000000000000..69f075a4c3252b23ea1cf0ad3de2e36c3ae83ac1 --- /dev/null +++ b/optgs/scene_trainer/adc/base.py @@ -0,0 +1,119 @@ +from dataclasses import dataclass +from typing import Any, Literal, Callable +import torch +from torch import Tensor +from optgs.scene_trainer.gaussian_module import GaussiansModule +from optgs.model.types import Gaussians + +@dataclass +class GenericStrategyState: + pass + +@dataclass +class BaseStrategyCfg: + name: Literal["default", "edgs", "mcmc", "none"] + + do_densify: bool + do_prune: bool + do_opacity_reset: bool + + cap_max: int # Maximum number of GSs, -1 for no cap + noise_lr: float # MCMC samping noise learning rate, 0.0 for no MCMC sampling + + pause_refine_after_reset: int + refine_every: int + reset_every: int + refine_start_iter: int + 
refine_stop_iter: int + refine_scale2d_stop_iter: int # Until which iteration 2D scale based refinement / pruning is applied + + grow_grad2d: float # GSs with image plane gradient above this value will be split/duplicated + grow_scale3d: float # GSs with scale below this value will be duplicated. Above will be split + prune_scale3d: float # GSs with scale above this value will be pruned + prune_scale2d: float # GSs with 2d scale (normalized by image resolution) above this value will be pruned + grow_scale2d: float # GSs with 2d scale (normalized by image resolution) above this value will be split + min_opacity: float # GSs with opacity below this value will be pruned + prune_zero_radii: bool # GSs with zero radii in screen space will be pruned + + reduce_opacity: bool # Slightly reduce opacity every few steps + reduce_factor: float # Factor to reduce opacity by + reduce_every: int # Reduce opacity every N iterations + + # Fallback means lr used for MCMC noise injection when the optimizer has no means_lr_scheduler. + # Matches the original paper's intended scale: means_lr (~1.6e-4) * noise_lr (5e5) ≈ 80 world units. + fallback_means_lr: float + + # If True, relocated Gaussians inherit the optimizer state of the alive Gaussian they were + # sampled from (better initialization). If False, state is zeroed (original paper behaviour). + relocate_copy_state: bool = False + + # MCMC noise: cap on the scales used for the noise covariance (does NOT affect rendered scales). + # Needed because knn_based saturates clamp_refine_max_scale, producing covariances orders of + # magnitude larger than vanilla's Adam-evolved scales. The resulting MCMC noise overflows the + # renderer's tile-binning math and causes silent CUDA OOB. Rule of thumb: ~scene_scale / 5. 
+ noise_scale_cap: float = 1.0 + + +def _prune_objects(prune_mask, objects): + for key in objects: + if objects[key] is not None: + objects[key].prune(prune_mask) + +def _clone_objects(clone_mask, objects, zero_t): + for key in objects: + if objects[key] is not None: + objects[key].clone(clone_mask, zero_t) + +def _split_objects(split_mask, objects, N, zero_t): + for key in objects: + if objects[key] is not None: + objects[key].split(split_mask, N, zero_t) + +def _add_to_objects(nr_new, objects): + for key in objects: + if objects[key] is not None: + objects[key].add(nr_new) + +def _replace_objects(dest_indices, from_indices, objects, zero_t): + for key in objects: + if objects[key] is not None: + objects[key].replace(from_indices, dest_indices, zero_t) + +def _1d_indices_from_mask(mask: Tensor | None) -> Tensor | None: + if mask is None: + return None + return mask.nonzero(as_tuple=True)[0] + + +def _densification_postfix( + gaussians: Gaussians | GaussiansModule, + adc_state: GenericStrategyState, + new_means: Tensor, + new_scales: Tensor, + new_opacities: Tensor, + new_rotations: Tensor, + new_rotations_unnorm: Tensor, + new_harmonics: Tensor, + new_covariances: Tensor | None, + params_fn: Callable[[Tensor], Tensor], + state_fn: Callable[[Tensor], Tensor], +) -> None: + """Updates gaussians and adc_state in place.""" + + if isinstance(gaussians, GaussiansModule): + raise NotImplementedError("_densification_postfix not implemented for GaussiansModule") + + # update gaussians + gaussians.means = params_fn(gaussians.means, new_means, dim=1) + gaussians.scales = params_fn(gaussians.scales, new_scales, dim=1) + gaussians.opacities = params_fn(gaussians.opacities, new_opacities, dim=1) + gaussians.rotations = params_fn(gaussians.rotations, new_rotations, dim=1) + gaussians.rotations_unnorm = params_fn(gaussians.rotations_unnorm, new_rotations_unnorm, dim=1) + gaussians.harmonics = params_fn(gaussians.harmonics, new_harmonics, dim=1) + if gaussians.covariances is not 
None and new_covariances is not None: + gaussians.covariances = params_fn(gaussians.covariances, new_covariances, dim=1) + + # update adc state + adc_state.grad2d_norm_accum = state_fn(adc_state.grad2d_norm_accum) + adc_state.denom = state_fn(adc_state.denom) + adc_state.radii2d = state_fn(adc_state.radii2d) diff --git a/optgs/scene_trainer/adc/mcmc.py b/optgs/scene_trainer/adc/mcmc.py new file mode 100644 index 0000000000000000000000000000000000000000..03c9fc12b4a2b59d5a26f2b34d0627e3270781f6 --- /dev/null +++ b/optgs/scene_trainer/adc/mcmc.py @@ -0,0 +1,488 @@ +from dataclasses import dataclass +from typing import Any, Callable, Literal, Optional, Tuple +import math +import torch +import numpy as np +from jaxtyping import Float, Bool +from torch import Tensor +import torch.nn.functional as F +from optgs.model.types import Gaussians +import torch +from typeguard import value +from optgs.scene_trainer.gaussian_module import GaussiansModule +from optgs.scene_trainer.optimizer.layer import AdamInputSmoothing +import matplotlib.pyplot as plt +from optgs.scene_trainer.adc.base import BaseStrategyCfg, GenericStrategyState +from optgs.scene_trainer.adc.base import _replace_objects, _add_to_objects + + +def _make_lazy_cuda_func(name: str) -> Callable: + def call_cuda(*args, **kwargs): + # pylint: disable=import-outside-toplevel + from gsplat.cuda._backend import _C + + return getattr(_C, name)(*args, **kwargs) + + return call_cuda + +class _QuatScaleToCovarPreci(torch.autograd.Function): + """Converts quaternions and scales to covariance and precision matrices.""" + + @staticmethod + def forward( + ctx, + quats: Tensor, # [..., 4], + scales: Tensor, # [..., 3], + compute_covar: bool = True, + compute_preci: bool = True, + triu: bool = False, + ) -> Tuple[Optional[Tensor], Optional[Tensor]]: + covars, precis = _make_lazy_cuda_func("quat_scale_to_covar_preci_fwd")( + quats, scales, compute_covar, compute_preci, triu + ) + ctx.save_for_backward(quats, scales) + 
ctx.compute_covar = compute_covar + ctx.compute_preci = compute_preci + ctx.triu = triu + return covars, precis + + @staticmethod + def backward(ctx, v_covars: Tensor, v_precis: Tensor): + quats, scales = ctx.saved_tensors + compute_covar = ctx.compute_covar + compute_preci = ctx.compute_preci + triu = ctx.triu + if compute_covar and v_covars.is_sparse: + v_covars = v_covars.to_dense() + if compute_preci and v_precis.is_sparse: + v_precis = v_precis.to_dense() + v_quats, v_scales = _make_lazy_cuda_func("quat_scale_to_covar_preci_bwd")( + quats, + scales, + triu, + v_covars.contiguous() if compute_covar else None, + v_precis.contiguous() if compute_preci else None, + ) + return v_quats, v_scales, None, None, None + +@dataclass +class McmcStrategyState(GenericStrategyState): + # Add MCMC specific state variables here + binoms: Tensor # [n_max, n_max] + + def external_pruning(self, valid_points_mask: Tensor) -> None: + # MCMC has no 2D-gradient accumulators, nothing to prune + pass + + @classmethod + def initialize(cls, device: torch.device) -> "McmcStrategyState": + + # from gsplat + n_max = 51 + binoms = torch.zeros((n_max, n_max)) + for n in range(n_max): + for k in range(n + 1): + binoms[n, k] = math.comb(n, k) + + return cls( + binoms=binoms.to(device) + ) + +def update_mcmc_strategy_state( + adc_state: McmcStrategyState +) -> None: + """Updates adc_state in place.""" + pass + +@torch.no_grad() +def inject_noise_to_position( + gaussians: Gaussians | GaussiansModule, + scaler: float, + scale_cap: float = 1.0, +): + if isinstance(gaussians, GaussiansModule): + raise NotImplementedError("noise injection not implemented for GaussiansModule") + + # get all params and remove batch dim + means = gaussians.means.squeeze(0) # [G, 3] + scales = gaussians.scales.squeeze(0) # [G, 3] + opacities = gaussians.opacities.squeeze(0) # [G] + # rotations = gaussians.rotations.squeeze(0) # [G, 4] + rotations_unnorm = gaussians.rotations_unnorm.squeeze(0) # [G, 4] + rotations = 
F.normalize(rotations_unnorm, dim=-1) + + if gaussians.stores_activated: + # already activated + pass + else: + # activate + opacities = torch.sigmoid(opacities) # [G] + scales = torch.exp(scales) # [G, 3] + + def _quat_scale_to_covar_preci( + quats: Tensor, # [..., 4], + scales: Tensor, # [..., 3], + compute_covar: bool = True, + compute_preci: bool = True, + triu: bool = False, + ) -> Tuple[Optional[Tensor], Optional[Tensor]]: + """Converts quaternions and scales to covariance and precision matrices. + + Args: + quats: Quaternions (No need to be normalized). [..., 4] + scales: Scales. [..., 3] + compute_covar: Whether to compute covariance matrices. Default: True. If False, + the returned covariance matrices will be None. + compute_preci: Whether to compute precision matrices. Default: True. If False, + the returned precision matrices will be None. + triu: If True, the return matrices will be upper triangular. Default: False. + + Returns: + A tuple: + + - **Covariance matrices**. If `triu` is True the returned shape is [..., 6], otherwise [..., 3, 3]. + - **Precision matrices**. If `triu` is True the returned shape is [..., 6], otherwise [..., 3, 3]. + """ + batch_dims = quats.shape[:-1] + assert quats.shape == batch_dims + (4,), quats.shape + assert scales.shape == batch_dims + (3,), scales.shape + quats = quats.contiguous() + scales = scales.contiguous() + covars, precis = _QuatScaleToCovarPreci.apply( + quats, scales, compute_covar, compute_preci, triu + ) + return covars if compute_covar else None, precis if compute_preci else None + + # Cap the scales used for the noise covariance only — does NOT change the rendered Gaussian + # scales. knn_based's network saturates clamp_refine_max_scale, producing covariances orders + # of magnitude larger than vanilla's Adam-evolved scales; the resulting noise overflows the + # renderer's tile-binning math and causes a silent CUDA OOB downstream. See BaseStrategyCfg. 
+ scales_for_noise = scales.clamp(max=scale_cap) + covars, _ = _quat_scale_to_covar_preci( + rotations, + scales_for_noise, + compute_covar=True, + compute_preci=False, + triu=False, + ) + + def op_sigmoid(x, k=100, x0=0.995): + return 1 / (1 + torch.exp(-k * (x - x0))) + + noise = ( + torch.randn_like(means) + * (op_sigmoid(1 - opacities)).unsqueeze(-1) + * scaler + ) + noise = torch.einsum("bij,bj->bi", covars, noise) + + means.add_(noise) + # means is a view of gaussians.means[0], so the add_ above already updated + # the underlying storage. Do NOT reassign gaussians.means here — that would + # replace the original leaf tensor with a requires_grad=False view (created + # inside @torch.no_grad), breaking gradient flow on the next iteration. + +@torch.no_grad() +def _multinomial_sample(weights: Tensor, n: int, replacement: bool = True) -> Tensor: + """Sample from a distribution using torch.multinomial or numpy.random.choice. + + This function adaptively chooses between `torch.multinomial` and `numpy.random.choice` + based on the number of elements in `weights`. If the number of elements exceeds + the torch.multinomial limit (2^24), it falls back to using `numpy.random.choice`. + + Args: + weights (Tensor): A 1D tensor of weights for each element. + n (int): The number of samples to draw. + replacement (bool): Whether to sample with replacement. Default is True. + + Returns: + Tensor: A 1D tensor of sampled indices. 
+ """ + num_elements = weights.size(0) + + if num_elements <= 2**24: + # Use torch.multinomial for elements within the limit + return torch.multinomial(weights, n, replacement=replacement) + else: + # Fallback to numpy.random.choice for larger element spaces + weights = weights / weights.sum() + weights_np = weights.detach().cpu().numpy() + sampled_idxs_np = np.random.choice( + num_elements, size=n, p=weights_np, replace=replacement + ) + sampled_idxs = torch.from_numpy(sampled_idxs_np) + + # Return the sampled indices on the original device + return sampled_idxs.to(weights.device) + +@torch.no_grad() +def _compute_relocation( + opacities: Tensor, # [N] + scales: Tensor, # [N, 3] + ratios: Tensor, # [N] + binoms: Tensor, # [n_max, n_max] +) -> Tuple[Tensor, Tensor]: + """Compute new Gaussians from a set of old Gaussians. + + This function interprets the Gaussians as samples from a likelihood distribution. + It uses the old opacities and scales to compute the new opacities and scales. + This is an implementation of the paper + `3D Gaussian Splatting as Markov Chain Monte Carlo `_, + + Args: + opacities: The opacities of the Gaussians. [N] + scales: The scales of the Gaussians. [N, 3] + ratios: The relative frequencies for each of the Gaussians. [N] + binoms: Precomputed lookup table for binomial coefficients used in + Equation 9 in the paper. [n_max, n_max] + + Returns: + A tuple: + + **new_opacities**: The opacities of the new Gaussians. [N] + **new_scales**: The scales of the Gaussians. 
[N, 3] + """ + + N = opacities.shape[0] + n_max, _ = binoms.shape + assert scales.shape == (N, 3), scales.shape + assert ratios.shape == (N,), ratios.shape + opacities = opacities.contiguous() + scales = scales.contiguous() + ratios.clamp_(min=1, max=n_max - 1) + ratios = ratios.int().contiguous() + + new_opacities, new_scales = _make_lazy_cuda_func("relocation")( + opacities, scales, ratios, binoms, n_max + ) + return new_opacities, new_scales + + +def relocate( + gaussians: Gaussians | GaussiansModule, + smoothers: dict[str, Any], + adc_state: McmcStrategyState, + min_opacity: float, + copy_state: bool = False, +) -> int: + """Relocates Gaussians based on MCMC strategy. + + Args: + gaussians (Gaussians | GaussiansModule): Gaussian distributions to relocate. + smoothers (dict[str, Any]): Optimizer smoothers. + adc_state (McmcStrategyState): State of the ADC. + + Returns: + int: Number of relocated Gaussians. + """ + + if isinstance(gaussians, GaussiansModule): + raise NotImplementedError("noise injection not implemented for GaussiansModule") + + n_relocated = 0 + + # get all params and remove batch dim + means = gaussians.means.squeeze(0) # [G, 3] + scales = gaussians.scales.squeeze(0) # [G, 3] + opacities = gaussians.opacities.squeeze(0) # [G] + rotations = gaussians.rotations.squeeze(0) # [G, 4] + rotations_unnorm = gaussians.rotations_unnorm.squeeze(0) # [G, 4] + rotations = F.normalize(rotations_unnorm, dim=-1) + harmonics = gaussians.harmonics.squeeze(0) # [G, H] + covariances = gaussians.covariances.squeeze(0) if gaussians.covariances is not None else None # [G, 6] or None + + if gaussians.stores_activated: + # already activated + pass + else: + # activate + opacities = torch.sigmoid(opacities) # [G] + scales = torch.exp(scales) # [G, 3] + + dead_mask = opacities <= min_opacity + n_gs = dead_mask.sum().item() + if n_gs > 0: + # Inplace relocate some dead Gaussians to the lives ones. 
+ n_relocated = int(n_gs) + + dead_indices = dead_mask.nonzero(as_tuple=True)[0] + alive_indices = (~dead_mask).nonzero(as_tuple=True)[0] + n = len(dead_indices) + + # Sample for new GSs + eps = torch.finfo(torch.float32).eps + probs = opacities[alive_indices].flatten() # ensure its shape is [N,] + sampled_idxs = _multinomial_sample(probs, n, replacement=True) + sampled_idxs = alive_indices[sampled_idxs] + new_opacities, new_scales = _compute_relocation( + opacities=opacities[sampled_idxs], + scales=scales[sampled_idxs], + ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, + binoms=adc_state.binoms, + ) + new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity) + + if gaussians.stores_activated: + # already activated + pass + else: + # deactivate + new_opacities = torch.logit(new_opacities) # [n] + new_scales = torch.log(new_scales) # [n, 3] + + # replace values (batch dim = 0, Gaussian dim = 1) + gaussians.means[0, dead_indices] = means[sampled_idxs] + gaussians.scales[0, dead_indices] = new_scales # relocated scale (deactivated if needed) + gaussians.opacities[0, dead_indices] = new_opacities # relocated opacity (deactivated if needed) + gaussians.rotations_unnorm[0, dead_indices] = rotations_unnorm[sampled_idxs] + gaussians.harmonics[0, dead_indices] = harmonics[sampled_idxs] + if covariances is not None: + gaussians.covariances[0, dead_indices] = covariances[sampled_idxs] + + # replace smoothers + _replace_objects(dead_indices, sampled_idxs, smoothers, zero_t=not copy_state) + + return n_relocated + +def add_new( + gaussians: Gaussians | GaussiansModule, + smoothers: dict[str, Any], + adc_state: McmcStrategyState, + cap_max: int, + min_opacity: float +) -> int: + """Adds new Gaussians based on MCMC strategy. + + Args: + gaussians (Gaussians | GaussiansModule): Gaussian distributions to add new ones to. + smoothers (dict[str, Any]): Optimizer smoothers. + adc_state (McmcStrategyState): State of the ADC. 
+ + Returns: + int: Number of new Gaussians added. + """ + + if isinstance(gaussians, GaussiansModule): + raise NotImplementedError("noise injection not implemented for GaussiansModule") + + n_new = 0 + + # get all params and remove batch dim + means = gaussians.means.squeeze(0) # [G, 3] + scales = gaussians.scales.squeeze(0) # [G, 3] + opacities = gaussians.opacities.squeeze(0) # [G] + # rotations = gaussians.rotations.squeeze(0) # [G, 4] + rotations_unnorm = gaussians.rotations_unnorm.squeeze(0) # [G, 4] + rotations = F.normalize(rotations_unnorm, dim=-1) + harmonics = gaussians.harmonics.squeeze(0) # [G, H] + covariances = gaussians.covariances.squeeze(0) if gaussians.covariances is not None else None # [G, 6] or None + + if gaussians.stores_activated: + # already activated + pass + else: + # activate + opacities = torch.sigmoid(opacities) # [G] + scales = torch.exp(scales) # [G, 3] + + current_n_points = means.shape[0] + n_target = min(cap_max, int(1.05 * current_n_points)) + n_gs = max(0, n_target - current_n_points) + if n_gs > 0: + # add new + n_new = int(n_gs) + + eps = torch.finfo(torch.float32).eps + probs = opacities.flatten() + sampled_idxs = _multinomial_sample(probs, n_gs, replacement=True) + new_opacities, new_scales = _compute_relocation( + opacities=opacities[sampled_idxs], + scales=scales[sampled_idxs], + ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, + binoms=adc_state.binoms, + ) + new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity) + + # deactivate new opacities/scales for storage if needed + if not gaussians.stores_activated: + new_opacities = torch.logit(new_opacities.clamp(min=eps, max=1.0 - eps)) + new_scales = torch.log(new_scales) + + new_means = means[sampled_idxs] + new_rotations_unnorm = rotations_unnorm[sampled_idxs] + new_harmonics = harmonics[sampled_idxs] + + # append to existing Gaussians (batch dim = 0, Gaussian dim = 1) + gaussians.means = torch.cat([gaussians.means, new_means.unsqueeze(0)], dim=1) 
+ gaussians.scales = torch.cat([gaussians.scales, new_scales.unsqueeze(0)], dim=1) + gaussians.opacities = torch.cat([gaussians.opacities, new_opacities.unsqueeze(0)], dim=1) + gaussians.rotations_unnorm = torch.cat([gaussians.rotations_unnorm, new_rotations_unnorm.unsqueeze(0)], dim=1) + gaussians.harmonics = torch.cat([gaussians.harmonics, new_harmonics.unsqueeze(0)], dim=1) + if gaussians.rotations is not None: + gaussians.rotations = torch.cat( + [gaussians.rotations, F.normalize(new_rotations_unnorm, dim=-1).unsqueeze(0)], dim=1 + ) + if covariances is not None: + gaussians.covariances = torch.cat( + [gaussians.covariances, covariances[sampled_idxs].unsqueeze(0)], dim=1 + ) + + # add new entries to smoothers state + _add_to_objects(n_new, smoothers) + + return n_new + +@torch.no_grad() +def apply_mcmc_strategy( + cfg: BaseStrategyCfg, + step: int, + gaussians: Gaussians | GaussiansModule, + adc_state: McmcStrategyState, + smoothers: dict[str, Any], + lr: float, + # zero_t: bool = False +) -> tuple[int, int, int, float | None, float | None]: + """Applies MCMC strategy to the given Gaussian distributions. + + Args: + cfg (BaseStrategyCfg): Configuration for the strategy. + step (int): Current training step. + gaussians (Gaussians | GaussiansModule): Gaussian distributions to apply the strategy to. + adc_state (McmcStrategyState): State of the ADC. + smoothers (dict[str, Any]): Optimizer smoothers. + lr (float): Learning rate for "means" attribute of the GS. 
+ """ + + if isinstance(gaussians, GaussiansModule): + raise NotImplementedError("cloning not implemented for GaussiansModule") + + # Densification and Pruning + nr_cloned, nr_splitted, nr_pruned = 0, 0, 0 + + # check if should densify/prune + if ( + step > cfg.refine_start_iter + and step % cfg.refine_every == 0 + and step % cfg.reset_every >= cfg.pause_refine_after_reset + ): + # teleport dead GSs to positions of alive ones + n_relocated_gs = relocate(gaussians, smoothers, adc_state, cfg.min_opacity, copy_state=cfg.relocate_copy_state) + + # grow population up to cap_max; stop before refine_stop_iter so new Gaussians can converge + n_new_gs = 0 + if step < cfg.refine_stop_iter: + n_new_gs = add_new(gaussians, smoothers, adc_state, cfg.cap_max, cfg.min_opacity) + + torch.cuda.empty_cache() + + print( + f"MCMC @ iter {step}: n_relocated {n_relocated_gs}, n_new_gs {n_new_gs}, total now {gaussians.means.shape[1]}" + ) + + # add noise to GSs + inject_noise_to_position( + gaussians, scaler=lr * cfg.noise_lr, scale_cap=cfg.noise_scale_cap + ) + # no need to update smoothers + + return nr_cloned, nr_splitted, nr_pruned, None, None diff --git a/optgs/scene_trainer/adc/vanilla.py b/optgs/scene_trainer/adc/vanilla.py new file mode 100644 index 0000000000000000000000000000000000000000..366a877bd3c75ee595f42478406842626cc2c118 --- /dev/null +++ b/optgs/scene_trainer/adc/vanilla.py @@ -0,0 +1,537 @@ +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn.functional as F +from jaxtyping import Float, Bool +from torch import Tensor + +from optgs.model.types import Gaussians +from optgs.scene_trainer.adc.base import BaseStrategyCfg, GenericStrategyState +from optgs.scene_trainer.adc.base import ( + _1d_indices_from_mask, + _densification_postfix, + _prune_objects, + _split_objects, + _clone_objects, +) +from optgs.scene_trainer.gaussian_module import GaussiansModule + + +@dataclass +class VanillaStrategyState(GenericStrategyState): + # for 
densification and pruning + grad2d_norm_accum: Float[Tensor, "gaussian"] # running accum of the norm of the image plane gradients for each GS + denom: Float[Tensor, "gaussian"] + radii2d: Float[Tensor, "gaussian"] # max radius in 2D screen space observed for each GS + scene_extent: Float[float, ""] + + def external_pruning(self, valid_points_mask: Tensor) -> None: + if self.grad2d_norm_accum is not None: + self.grad2d_norm_accum = self.grad2d_norm_accum[valid_points_mask] + if self.radii2d is not None: + self.radii2d = self.radii2d[valid_points_mask] + if self.denom is not None: + self.denom = self.denom[valid_points_mask] + + @classmethod + def initialize(cls, nr_points: int, device: torch.device, scene_extent: int | float) -> "VanillaStrategyState": + return cls( + grad2d_norm_accum=torch.zeros(nr_points, device=device), + denom=torch.zeros(nr_points, device=device), + radii2d=torch.zeros(nr_points, device=device), + scene_extent=scene_extent, + ) + + +def update_vanilla_strategy_state( + adc_state: VanillaStrategyState, + radii_2d: Tensor, # [B, V, G, 2] + means2d_grads: Tensor | None, # [B, V, G, 2] + visibility_mask: Bool[Tensor, "b v gaussian"], # [B, V, G] + v: int, # number of views rendered + w: int, # image width + h: int, # image height +) -> None: + """Updates adc_state in place.""" + # get gs ids from visibility mask + + visibility_mask = visibility_mask.squeeze(0) # [V, G], assume batch size 1 + gs_ids = torch.where(visibility_mask)[1] # [G_valid] + assert visibility_mask.ndim == 2, "visibility_mask should be of shape [V, G]" + + if means2d_grads is not None: + assert means2d_grads.ndim == 4 and means2d_grads.shape[-1] == 2, "means2d_grads should be of shape [B, V, G, 2]" + means2d_grads = means2d_grads.squeeze(0) # [V, G, 2], assume batch size 1 + grads = means2d_grads[visibility_mask] # [G_valid, 2] + + # normalize grads to [-1, 1] screen space + grads[..., 0] *= w / 2.0 * v + grads[..., 1] *= h / 2.0 * v + + # accumulate 2D grads norm + 
adc_state.grad2d_norm_accum.index_add_(0, gs_ids, grads.norm(dim=-1)) + + # accumulate denominator + adc_state.denom.index_add_( + 0, gs_ids, torch.ones_like(gs_ids, dtype=torch.float32) + ) + + radii_2d = radii_2d.squeeze(0) # [V, N, 2], assume batch size 1 + assert radii_2d.ndim == 3 and radii_2d.shape[2] == 2, "radii_2d should be of shape [V, G, 2]" + + radii_2d = radii_2d[visibility_mask] + radii_max = radii_2d.max(dim=-1).values # [V, N] + # normalize radii to [0, 1] screen space + radii_max /= float(max(w, h)) + + # update radii2d + adc_state.radii2d[gs_ids] = torch.maximum(adc_state.radii2d[gs_ids], radii_max) + + +def reset_adc_state( + adc_state: VanillaStrategyState, +) -> None: + """Resets adc_state in place.""" + adc_state.grad2d_norm_accum.zero_() + adc_state.denom.zero_() + adc_state.radii2d.zero_() + torch.cuda.empty_cache() + +def prune( + gaussians: Gaussians | GaussiansModule, + adc_state: VanillaStrategyState, + prune_mask: Tensor, +) -> None: + """Gaussians are updated in place.""" + + if isinstance(gaussians, GaussiansModule): + raise NotImplementedError("pruning not implemented for GaussiansModule") + + # TODO Naama: check if we can avoid squeezing and unsqueezing + # get all params and remove batch dim + means = gaussians.means.squeeze(0) # [G, 3] + scales = gaussians.scales.squeeze(0) # [G, 3] + opacities = gaussians.opacities.squeeze(0) # [G] + rotations = gaussians.rotations.squeeze(0) # [G, 4] + rotations_unnorm = gaussians.rotations_unnorm.squeeze(0) # [G, 4] + harmonics = gaussians.harmonics.squeeze(0) # [G, 3, d_sh] + covariances = gaussians.covariances.squeeze(0) if gaussians.covariances is not None else None # [G, + + keep_idx = _1d_indices_from_mask(~prune_mask) + + # prune gaussians + gaussians.means = means[keep_idx].unsqueeze(0) + gaussians.scales = scales[keep_idx].unsqueeze(0) + gaussians.opacities = opacities[keep_idx].unsqueeze(0) + gaussians.rotations = rotations[keep_idx].unsqueeze(0) + gaussians.rotations_unnorm = 
rotations_unnorm[keep_idx].unsqueeze(0) + gaussians.harmonics = harmonics[keep_idx].unsqueeze(0) + if covariances is not None: + gaussians.covariances = covariances[keep_idx].unsqueeze(0) + + # prune adc state + adc_state.grad2d_norm_accum = adc_state.grad2d_norm_accum[keep_idx] + adc_state.radii2d = adc_state.radii2d[keep_idx] + adc_state.denom = adc_state.denom[keep_idx] + +def splitting( + gaussians: Gaussians | GaussiansModule, + adc_state: VanillaStrategyState, + split_mask: Tensor, + N=2, + revised_opacity: bool = False, +) -> None: + """Gaussians are updated in place. + + revised_opacity: Whether to use revised opacity formulation + from arXiv:2404.06109. Default: False. + """ + + if isinstance(gaussians, GaussiansModule): + raise NotImplementedError("splitting not implemented for GaussiansModule") + + # get all params and remove batch dim + means = gaussians.means.squeeze(0) # [G, 3] + scales = gaussians.scales.squeeze(0) # [G, 3] + opacities = gaussians.opacities.squeeze(0) # [G] + # rotations = gaussians.rotations.squeeze(0) # [G, 4] + rotations_unnorm = gaussians.rotations_unnorm.squeeze(0) # [G, 4] + harmonics = gaussians.harmonics.squeeze(0) # [G, 3, d_sh] + covariances = gaussians.covariances.squeeze(0) if gaussians.covariances is not None else None # [G, 3, 3] + rotations = F.normalize(rotations_unnorm, dim=-1) + + if gaussians.stores_activated: + # already activated + pass + else: + # activate + opacities = torch.sigmoid(opacities) # [G] + scales = torch.exp(scales) # [G, 3] + + sel = _1d_indices_from_mask(split_mask) + rest = _1d_indices_from_mask(~split_mask) + + # get params to split + scales = scales[sel] # [S, 3] + rotations = rotations[sel] # [S, 4] + rotations_unnorm = rotations_unnorm[sel] # [S, 4] + opacities = opacities[sel] # [S] + means = means[sel] # [S, 3] + harmonics = harmonics[sel] # [S, 3, d_sh] + if covariances is not None: + covariances = covariances[sel] # [S, 3, 3] + + def _normalized_quat_to_rotmat(quat: Tensor) -> Tensor: + 
"""Convert normalized quaternion to rotation matrix. + + Args: + quat: Normalized quaternion in wxyz convension. (..., 4) + + Returns: + Rotation matrix (..., 3, 3) + """ + assert quat.shape[-1] == 4, quat.shape + w, x, y, z = torch.unbind(quat, dim=-1) + mat = torch.stack( + [ + 1 - 2 * (y ** 2 + z ** 2), + 2 * (x * y - w * z), + 2 * (x * z + w * y), + 2 * (x * y + w * z), + 1 - 2 * (x ** 2 + z ** 2), + 2 * (y * z - w * x), + 2 * (x * z - w * y), + 2 * (y * z + w * x), + 1 - 2 * (x ** 2 + y ** 2), + ], + dim=-1, + ) + return mat.reshape(quat.shape[:-1] + (3, 3)) + + # new means + rotmats = _normalized_quat_to_rotmat(rotations[:, [3, 0, 1, 2]]) # [N, 3, 3] xyzw to wxyz + device = means.device + samples = torch.einsum( + "nij,nj,bnj->bni", + rotmats, + scales, + torch.randn(N, len(scales), 3, device=device), + ) # [split, N, 3] + new_means = (means + samples).reshape(-1, 3) # [2N, 3] + + # new scales + new_scales = (scales / 1.6).repeat(N, 1) # [2N, 3] + + # new opacities + if revised_opacity: + new_opacities = (1.0 - torch.sqrt(1.0 - opacities)).repeat(N) # [2N] + else: + new_opacities = opacities.repeat(N) # [2N] + + if gaussians.stores_activated: + # already activated + pass + else: + # activate + new_opacities = torch.logit(new_opacities) + new_scales = torch.log(new_scales) + + # new rotations + new_rotations = rotations.repeat(N, 1) # [2N, 4] + + # new rotations unnorm + new_rotations_unnorm = rotations_unnorm.repeat(N, 1) # [2N, 4] + + # new harmonics + new_harmonics = harmonics.repeat(N, 1, 1) # [2N, 3, d_sh] + + # new covariances + if covariances is not None: + new_covariances = covariances.repeat(N, 1, 1) # [2N, 3, 3] + else: + new_covariances = None + + def params_fn(v: Tensor, v_new: Tensor, dim: int) -> Tensor: + v = v.squeeze(0)[rest].unsqueeze(0) + return torch.cat([v, v_new.unsqueeze(0)], dim=dim) + + def state_fn(v: Tensor) -> Tensor: + repeats = [2] + [1] * (v.dim() - 1) + v_new = v[sel].repeat(repeats) + return torch.cat([v[rest], v_new], dim=0) + 
+ _densification_postfix( + gaussians, + adc_state, + new_means, + new_scales, + new_opacities, + new_rotations, + new_rotations_unnorm, + new_harmonics, + new_covariances, + params_fn, + state_fn, + ) + +def cloning( + gaussians: Gaussians | GaussiansModule, + adc_state: VanillaStrategyState, + clone_mask: Tensor, +) -> None: + """Gaussians are updated in place.""" + + if isinstance(gaussians, GaussiansModule): + raise NotImplementedError("cloning not implemented for GaussiansModule") + + # get all params and remove batch dim + means = gaussians.means.squeeze(0) # [G, 3] + scales = gaussians.scales.squeeze(0) # [G, 3] + opacities = gaussians.opacities.squeeze(0) # [G] + rotations = gaussians.rotations.squeeze(0) # [G, 4] + rotations_unnorm = gaussians.rotations_unnorm.squeeze(0) # [G, 4] + harmonics = gaussians.harmonics.squeeze(0) # [G, 3, d_sh] + covariances = gaussians.covariances.squeeze(0) if gaussians.covariances is not None else None # [G, 3, 3] + + sel = _1d_indices_from_mask(clone_mask) + + # Clone + new_means = means[sel] + new_opacities = opacities[sel] + new_scales = scales[sel] + new_rotations_unnorm = rotations_unnorm[sel] + new_rotations = rotations[sel] + new_harmonics = harmonics[sel] + new_covariances = covariances[sel] if covariances is not None else None + + def params_fn(v: Tensor, v_new: Tensor, dim: int) -> Tensor: + return torch.cat([v, v_new.unsqueeze(0)], dim=dim) + + def state_fn(v: Tensor) -> Tensor: + return torch.cat([v, v[sel]], dim=0) + + _densification_postfix( + gaussians, + adc_state, + new_means, + new_scales, + new_opacities, + new_rotations, + new_rotations_unnorm, + new_harmonics, + new_covariances, + params_fn, + state_fn, + ) + +@torch.no_grad() +def apply_vanilla_strategy( + cfg: BaseStrategyCfg, + step: int, + gaussians: Gaussians | GaussiansModule, + adc_state: VanillaStrategyState, + smoothers: dict[str, Any], + zero_t: bool = False +) -> tuple[int, int, int, float | None, float | None]: + + if isinstance(gaussians, 
GaussiansModule): + raise NotImplementedError("cloning not implemented for GaussiansModule") + + # Densification and Pruning + nr_cloned, nr_splitted, nr_pruned = 0, 0, 0 + + if step >= cfg.refine_stop_iter: + return nr_cloned, nr_splitted, nr_pruned, None, None + + # Calculate average 2D grads magnitude and scales + grads: Tensor = adc_state.grad2d_norm_accum / adc_state.denom.clamp_min(1.0) # [G] + max_grad2d = grads.max().item() + max_radii = adc_state.radii2d.max().item() + + # check if should densify/prune + if ( + step > cfg.refine_start_iter + and step % cfg.refine_every == 0 + and step % cfg.reset_every >= cfg.pause_refine_after_reset + ): + device = gaussians.means.device + grow_grad2d: float = cfg.grow_grad2d + grow_scale3d: float = cfg.grow_scale3d + prune_scale3d: float = cfg.prune_scale3d + prune_scale2d: float = cfg.prune_scale2d + grow_scale2d: float = cfg.grow_scale2d + min_opacity: float = cfg.min_opacity + prune_zero_radii: bool = cfg.prune_zero_radii + + if cfg.do_densify: + + if isinstance(gaussians, GaussiansModule): + raise NotImplementedError("Densification not implemented for GaussiansModule") + scales: Tensor = gaussians.scales # [G, 3] + elif isinstance(gaussians, Gaussians): + scales: Tensor = gaussians.scales.squeeze(0) # [G, 3] + if gaussians.stores_activated: + # already activated + pass + else: + # activate + scales = torch.exp(scales) # [G, 3] + else: + raise ValueError(f"Unknown type of gaussians: {type(gaussians)}") + # Extract points that satisfy the gradient condition + is_grad_high: Tensor = grads > grow_grad2d # [G] + + is_small: Tensor = scales.max(dim=-1).values <= grow_scale3d * adc_state.scene_extent + + is_large: Tensor = ~is_small + + clone_mask: Tensor = is_grad_high & is_small + + split_mask: Tensor = is_grad_high & is_large + + if step < cfg.refine_scale2d_stop_iter: + split_mask |= adc_state.radii2d > grow_scale2d + + # clone --------------------------------------------------------------------- + + # clone points + 
cloning( + gaussians=gaussians, + adc_state=adc_state, + clone_mask=clone_mask, + ) + _clone_objects(clone_mask, smoothers, zero_t=zero_t) + + # update states + nr_cloned = int(clone_mask.sum().item()) + + # new GSs added by cloning will not be split + split_mask = torch.cat( + [ + split_mask, + torch.zeros(nr_cloned, dtype=torch.bool, device=device), + ] + ) + + # split --------------------------------------------------------------------- + + # split points + # No need to prune after splitting since we already removed the original points in _densification_postfix + N = 2 + splitting( + gaussians=gaussians, + adc_state=adc_state, + split_mask=split_mask, + N=N, # split each point into N points + ) + _split_objects(split_mask, smoothers, N=N, zero_t=zero_t) + nr_splitted = int(split_mask.sum().item()) + + if cfg.do_prune: + + # prune --------------------------------------------------------------------- + + if isinstance(gaussians, GaussiansModule): + raise NotImplementedError("Densification not implemented for GaussiansModule") + scales: Tensor = gaussians.scales # [G, 3] + opacities = gaussians.opacities # [G] + elif isinstance(gaussians, Gaussians): + opacities = gaussians.opacities.squeeze(0) # [G] + scales = gaussians.scales.squeeze(0) # [G, 3] + if gaussians.stores_activated: + # already activated + pass + else: + # activate + scales = torch.exp(scales) # [G, 3] + opacities = torch.sigmoid(opacities) # [G] + else: + raise ValueError(f"Unknown type of gaussians: {type(gaussians)}") + + # find points to prune and prune gaussians + prune_mask = opacities < min_opacity + + if step > cfg.reset_every: + is_too_big = scales.max(dim=-1).values > prune_scale3d * adc_state.scene_extent + if step < cfg.refine_scale2d_stop_iter: + is_too_big |= adc_state.radii2d > prune_scale2d + prune_mask = prune_mask | is_too_big + + # invisible from training views + if prune_zero_radii: + raise NotImplementedError("prune_zero_radii not implemented yet") + + prune(gaussians, adc_state, 
prune_mask) + _prune_objects(prune_mask, objects=smoothers) + + # update states + nr_pruned = int(prune_mask.sum().item()) + + # -------------------------------------------------------------------------- + + # reset adc state + reset_adc_state(adc_state) + + print( + f"Densification/Pruning @ iter {step}: cloned {nr_cloned}, splitted {nr_splitted}, pruned {nr_pruned}, total now {gaussians.means.shape[1]}" + ) + + if cfg.do_opacity_reset: + + # Opacity reset + if step % cfg.reset_every == 0 and step > 0: + + if isinstance(gaussians, GaussiansModule): + raise NotImplementedError("Opacity reset not implemented for GaussiansModule") + elif isinstance(gaussians, Gaussians): + opacities = gaussians.opacities + if gaussians.stores_activated: + # already activated + pass + else: + # activate + opacities = torch.sigmoid(opacities) # [G] + else: + raise ValueError(f"Unknown type of gaussians: {type(gaussians)}") + + value = cfg.min_opacity * 2.0 + new_opacities = torch.min(opacities, torch.ones_like(opacities) * value) + + if gaussians.stores_activated: + # already activated + pass + else: + # deactivated + new_opacities = torch.logit(new_opacities) + + gaussians.opacities = new_opacities + + # reset momentums of opacities + smoothers["opacities"].zero_out(zero_t=zero_t) + print("Opacity reset @ iter", step) + + if cfg.reduce_opacity: + # Slightly reduce opacity every few steps (from EDGS) + if step % cfg.reduce_every == 0: + + opacities = gaussians.opacities + if gaussians.stores_activated: + # already activated + pass + else: + # activate + opacities = torch.sigmoid(opacities) # [G] + + opacities_new = opacities * cfg.reduce_factor + + if gaussians.stores_activated: + # already activated + pass + else: + # deactivate + opacities_new = torch.logit(opacities_new) + + gaussians.opacities = opacities_new + + return nr_cloned, nr_splitted, nr_pruned, max_radii, max_grad2d diff --git a/optgs/scene_trainer/common/__init__.py b/optgs/scene_trainer/common/__init__.py new file mode 
100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/scene_trainer/common/gaussian_adapter.py b/optgs/scene_trainer/common/gaussian_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..27757f76e015e31ef756c607176a8fae369b21a8 --- /dev/null +++ b/optgs/scene_trainer/common/gaussian_adapter.py @@ -0,0 +1,188 @@ +from dataclasses import dataclass + +import torch +from einops import einsum, rearrange +from jaxtyping import Float +from torch import Tensor, nn +import torch.nn.functional as F + +from optgs.geometry.projection import get_world_rays +from optgs.misc.sh_rotation import rotate_sh +from optgs.scene_trainer.common.gaussians import build_covariance + + +@dataclass +class Gaussians: + means: Float[Tensor, "*batch 3"] + covariances: Float[Tensor, "*batch 3 3"] + scales: Float[Tensor, "*batch 3"] + rotations: Float[Tensor, "*batch 4"] + harmonics: Float[Tensor, "*batch 3 _"] + opacities: Float[Tensor, " *batch"] + rotations_unnorm: Float[Tensor, "*batch 4"] + + +@dataclass +class GaussianAdapterCfg: + gaussian_scale_min: float + gaussian_scale_max: float + sh_degree: int + exp_scale: bool + softplus_scale: bool + clamp_min_scale: float + scale_detach_depth: bool + exp_scale_bias: float + no_rotate_sh: bool + no_sh_mask: bool + init_rotation_identity: bool + + +class GaussianAdapter(nn.Module): + cfg: GaussianAdapterCfg + + def __init__(self, cfg: GaussianAdapterCfg): + super().__init__() + self.cfg = cfg + + # Create a mask for the spherical harmonics coefficients. This ensures that at + # initialization, the coefficients are biased towards having a large DC + # component and small view-dependent components. 
+ self.register_buffer( + "sh_mask", + torch.ones((self.d_sh,), dtype=torch.float32), + persistent=False, + ) + for degree in range(1, self.cfg.sh_degree + 1): + self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree + + def forward( + self, + extrinsics: Float[Tensor, "*#batch 4 4"], + intrinsics: Float[Tensor, "*#batch 3 3"] | None, + coordinates: Float[Tensor, "*#batch 2"] | None, + depths: Float[Tensor, "*#batch"] | None, + opacities: Float[Tensor, "*#batch"], + raw_gaussians: Float[Tensor, "*#batch _"], + image_shape: tuple[int, int] | None, + eps: float = 1e-8, + point_cloud: Float[Tensor, "*#batch 3"] | None = None, + input_images: Tensor | None = None, + gaussian_scale_depth: Tensor | None = None, + init_scales: Tensor | None = None, + ) -> Gaussians: + scales, rotations_unnorm, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1) + + if self.cfg.scale_detach_depth: + depths = depths.detach() + + if init_scales is not None: + # learn residual scales + scales = init_scales + scales + + elif gaussian_scale_depth is not None: + scale_min = self.cfg.gaussian_scale_min + scale_max = self.cfg.gaussian_scale_max + scales = scale_min + (scale_max - scale_min) * scales.sigmoid() + h, w = image_shape + pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device) + multiplier = self.get_scale_multiplier(intrinsics, pixel_size) + scales = scales * gaussian_scale_depth[..., None] * multiplier[..., None] + + elif point_cloud is not None: + # TODO: try other activations + if self.cfg.softplus_scale: + scales = torch.clamp(F.softplus(scales - self.cfg.exp_scale_bias), max=self.cfg.gaussian_scale_max) + else: + scales = torch.clamp(torch.exp(scales - self.cfg.exp_scale_bias), max=self.cfg.gaussian_scale_max) + elif self.cfg.exp_scale: + scales = torch.clamp(torch.exp(scales - self.cfg.exp_scale_bias), max=self.cfg.gaussian_scale_max) + elif self.cfg.softplus_scale: + scales = torch.clamp(F.softplus(scales - self.cfg.exp_scale_bias), 
max=self.cfg.gaussian_scale_max) + else: + scale_min = self.cfg.gaussian_scale_min + scale_max = self.cfg.gaussian_scale_max + scales = scale_min + (scale_max - scale_min) * scales.sigmoid() + h, w = image_shape + pixel_size = 1 / torch.tensor((w, h), dtype=torch.float32, device=extrinsics.device) + multiplier = self.get_scale_multiplier(intrinsics, pixel_size) + scales = scales * depths[..., None] * multiplier[..., None] + + # TODO: avoid nan when using exp scale? + scales = torch.clamp(scales, min=self.cfg.clamp_min_scale) + + assert input_images is not None + + if self.cfg.init_rotation_identity: + identity_quat = torch.tensor([0., 0., 0., 1.], device=rotations_unnorm.device, dtype=rotations_unnorm.dtype) + identity_quat = identity_quat.repeat(*rotations_unnorm.shape[:-1], 1) + rotations_unnorm = identity_quat + rotations_unnorm + + # Normalize the quaternion features to yield a valid quaternion. + rotations = rotations_unnorm / (rotations_unnorm.norm(dim=-1, keepdim=True) + eps) + + # [2, 2, 65536, 1, 1, 3, 25] + sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3) + # remove the sh_mask + if self.cfg.no_sh_mask: + sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)).clone() + else: + sh = sh.broadcast_to((*opacities.shape, 3, self.d_sh)) * self.sh_mask + + if input_images is not None: + if point_cloud is not None: + if input_images.dim() == 5: + input_images = rearrange(input_images, "b v c h w -> b v (h w) () () c") + sh[..., 0] = sh[..., 0] + RGB2SH(input_images) + else: + # [B, V, H*W, 1, 1, 3] + imgs = rearrange(input_images, "b v c h w -> b v (h w) () () c") + # init sh with input images + sh[..., 0] = sh[..., 0] + RGB2SH(imgs) + + # Create world-space covariance matrices. + covariances = build_covariance(scales, rotations) + c2w_rotations = extrinsics[..., :3, :3] + covariances = c2w_rotations @ covariances @ c2w_rotations.transpose(-1, -2) + + # Compute Gaussian means. 
+ if point_cloud is not None: + means = point_cloud + else: + origins, directions = get_world_rays(coordinates, extrinsics, intrinsics) + means = origins + directions * depths[..., None] + + return Gaussians( + means=means, + covariances=covariances, + harmonics=sh if self.cfg.no_rotate_sh else rotate_sh(sh, c2w_rotations[..., None, :, :]), + opacities=opacities, + scales=scales, + rotations=rotations.broadcast_to((*scales.shape[:-1], 4)), + rotations_unnorm=rotations_unnorm.broadcast_to((*scales.shape[:-1], 4)), + ) + + def get_scale_multiplier( + self, + intrinsics: Float[Tensor, "*#batch 3 3"], + pixel_size: Float[Tensor, "*#batch 2"], + multiplier: float = 0.1, + ) -> Float[Tensor, " *batch"]: + xy_multipliers = multiplier * einsum( + intrinsics[..., :2, :2].inverse(), + pixel_size, + "... i j, j -> ... i", + ) + return xy_multipliers.sum(dim=-1) + + @property + def d_sh(self) -> int: + return (self.cfg.sh_degree + 1) ** 2 + + @property + def d_in(self) -> int: + return 7 + 3 * self.d_sh + + +def RGB2SH(rgb): + C0 = 0.28209479177387814 + return (rgb - 0.5) / C0 diff --git a/optgs/scene_trainer/common/gaussians.py b/optgs/scene_trainer/common/gaussians.py new file mode 100644 index 0000000000000000000000000000000000000000..a55bc7fbb8e6c481706219a1cb8e1d52b8b63a82 --- /dev/null +++ b/optgs/scene_trainer/common/gaussians.py @@ -0,0 +1,159 @@ +import torch +from einops import rearrange +from jaxtyping import Float +from torch import Tensor + + +# def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: +# """ +# Convert rotations given as quaternions to rotation matrices. + +# Args: +# quaternions: quaternions with real part first, +# as tensor of shape (..., 4). + +# Returns: +# Rotation matrices as tensor of shape (..., 3, 3). +# """ +# r, i, j, k = torch.unbind(quaternions, -1) # wxyz +# # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 
+#     two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+#     o = torch.stack(
+#         (
+#             1 - two_s * (j * j + k * k),
+#             two_s * (i * j - k * r),
+#             two_s * (i * k + j * r),
+#             two_s * (i * j + k * r),
+#             1 - two_s * (i * i + k * k),
+#             two_s * (j * k - i * r),
+#             two_s * (i * k - j * r),
+#             two_s * (j * k + i * r),
+#             1 - two_s * (i * i + j * j),
+#         ),
+#         -1,
+#     )
+#     return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def rotation_matrix_to_quaternion_xyzw(R: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+    """
+    Convert rotation matrices to quaternions in xyzw format.
+
+    Args:
+        R: Tensor of shape [..., 3, 3] (or any batch shape ending with (3,3)).
+        eps: small number for numerical stability.
+
+    Returns:
+        q: Tensor of shape [..., 4], quaternion order [x, y, z, w].
+    """
+    if R.shape[-2:] != (3, 3):
+        raise ValueError(f"Expected last two dimensions (3,3), got {R.shape[-2:]}")
+
+    # Force float (preserve device)
+    dtype = R.dtype if torch.is_floating_point(R) else torch.float32
+    R = R.to(dtype=dtype)
+
+    m00 = R[..., 0, 0]; m01 = R[..., 0, 1]; m02 = R[..., 0, 2]
+    m10 = R[..., 1, 0]; m11 = R[..., 1, 1]; m12 = R[..., 1, 2]
+    m20 = R[..., 2, 0]; m21 = R[..., 2, 1]; m22 = R[..., 2, 2]
+
+    trace = m00 + m11 + m22
+
+    # conditions (shape: [...])
+    cond1 = trace > 0.0
+    cond2 = (m00 >= m11) & (m00 >= m22) & (~cond1)
+    cond3 = (m11 > m22) & (~cond1) & (~cond2)
+    # cond4 is the remaining case
+    cond4 = ~(cond1 | cond2 | cond3)
+
+    # Candidate 1 (trace > 0): s = 4*w
+    s1 = (trace + 1.0).clamp_min(eps).sqrt() * 2.0
+    inv_s1 = 1.0 / s1
+    q1_x = (m21 - m12) * inv_s1
+    q1_y = (m02 - m20) * inv_s1
+    q1_z = (m10 - m01) * inv_s1
+    q1_w = 0.25 * s1
+    q1 = torch.stack([q1_x, q1_y, q1_z, q1_w], dim=-1)
+
+    # Candidate 2 (m00 largest): s = 4*x
+    s2 = (1.0 + m00 - m11 - m22).clamp_min(eps).sqrt() * 2.0
+    inv_s2 = 1.0 / s2
+    q2_x = 0.25 * s2
+    q2_y = (m01 + m10) * inv_s2
+    q2_z = (m02 + m20) * inv_s2
+    q2_w = (m21 - m12) * inv_s2
+    q2 = torch.stack([q2_x, q2_y, q2_z, q2_w], dim=-1)
+
+    # Candidate 3 (m11 largest): s = 4*y
+    s3 = (1.0 + m11 - m00 - m22).clamp_min(eps).sqrt() * 2.0
+    inv_s3 = 1.0 / s3
+    q3_x = (m01 + m10) * inv_s3
+    q3_y = 0.25 * s3
+    q3_z = (m12 + m21) * inv_s3
+    q3_w = (m02 - m20) * inv_s3
+    q3 = torch.stack([q3_x, q3_y, q3_z, q3_w], dim=-1)
+
+    # Candidate 4 (m22 largest): s = 4*z
+    s4 = (1.0 + m22 - m00 - m11).clamp_min(eps).sqrt() * 2.0
+    inv_s4 = 1.0 / s4
+    q4_x = (m02 + m20) * inv_s4
+    q4_y = (m12 + m21) * inv_s4
+    q4_z = 0.25 * s4
+    q4_w = (m10 - m01) * inv_s4
+    q4 = torch.stack([q4_x, q4_y, q4_z, q4_w], dim=-1)
+
+    # Broadcast masks over last dimension and select candidates
+    cond1_u = cond1.unsqueeze(-1)
+    cond2_u = cond2.unsqueeze(-1)
+    cond3_u = cond3.unsqueeze(-1)
+    # nested where ensures every position picks exactly one candidate
+    q = torch.where(cond1_u, q1,
+                    torch.where(cond2_u, q2,
+                                torch.where(cond3_u, q3, q4)))
+
+    # Normalize to unit quaternion
+    q = q / q.norm(dim=-1, keepdim=True).clamp_min(eps)
+
+    return q
+
+
+
+# # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
+def quaternion_to_matrix(
+    quaternions: Float[Tensor, "*batch 4"],
+    eps: float = 1e-8,
+) -> Float[Tensor, "*batch 3 3"]:
+    # Components are unpacked as xyzw (scalar last) to match the scipy format.
+    i, j, k, r = torch.unbind(quaternions, dim=-1)
+    two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)
+
+    o = torch.stack(
+        (
+            1 - two_s * (j * j + k * k),
+            two_s * (i * j - k * r),
+            two_s * (i * k + j * r),
+            two_s * (i * j + k * r),
+            1 - two_s * (i * i + k * k),
+            two_s * (j * k - i * r),
+            two_s * (i * k - j * r),
+            two_s * (j * k + i * r),
+            1 - two_s * (i * i + j * j),
+        ),
+        -1,
+    )
+    return rearrange(o, "... (i j) -> ... i j", i=3, j=3)
+
+
+def build_covariance(
+    scale: Float[Tensor, "*#batch 3"],
+    rotation_xyzw: Float[Tensor, "*#batch 4"],
+) -> Float[Tensor, "*batch 3 3"]:
+    scale = scale.diag_embed()
+    rotation = quaternion_to_matrix(rotation_xyzw)
+    return (
+        rotation
+        @ scale
+        @ rearrange(scale, "... i j -> ... j i")
+        @ rearrange(rotation, "... i j -> ... j i")
+    )
diff --git a/optgs/scene_trainer/gaussian_module.py b/optgs/scene_trainer/gaussian_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..6465a1fb5b2757a49d8e2b52248a672c5488d8bf
--- /dev/null
+++ b/optgs/scene_trainer/gaussian_module.py
@@ -0,0 +1,163 @@
+import torch
+import torch.nn.functional as F
+from jaxtyping import Float
+from torch import Tensor, nn
+from optgs.model.types import Gaussians
+from optgs.scene_trainer.common.gaussians import build_covariance
+
+
+class GaussiansModule(nn.Module):
+    """Trainable ``nn.Module`` wrapper around the Gaussians of one scene.
+
+    The constructor receives post-activation values. Opacities and scales are
+    stored in unconstrained (logit / log) space and re-activated by the
+    properties below; rotations are stored as unnormalized xyzw quaternions.
+    """
+
+    def __init__(
+        self,
+        means: Float[Tensor, "gaussian 3"],
+        harmonics: Float[Tensor, "gaussian 3 d_sh"],
+        opacities: Float[Tensor, "gaussian"],
+        scales: Float[Tensor, "gaussian 3"],
+        rotations_unnorm: Float[Tensor, "gaussian 4"]
+    ):
+        # all gaussians parameters are post-activation
+
+        super().__init__()
+
+        def _register_param(name, value):
+            if value is None:
+                setattr(self, name, None)
+            else:
+                param = nn.Parameter(value)
+                setattr(self, name, param)
+
+        self.scaling_activation = torch.exp
+        self.scaling_inverse_activation = torch.log
+        self.covariance_activation = build_covariance
+        self.opacity_activation = torch.sigmoid
+        self.inverse_opacity_activation = torch.logit
+        self.rotation_activation = F.normalize
+
+        # Register parameters
+        means = means.detach().clone()
+        means.requires_grad_(True)
+
+        harmonics = harmonics.detach().clone()  # [G, 3, d_sh]
+        d_sh = harmonics.shape[-1]
+        # .contiguous() gives each SH parameter its own storage instead of a strided view of `harmonics`
+        sh0 = harmonics[..., 0:1].contiguous()  # [G, 3, 1]
+        if d_sh == 1:
+            # sh_degree = 0
+            shN = None
+        else:
+            # sh_degree > 0
+            shN = harmonics[..., 1:].contiguous()  # [G, 3, d_sh-1]
+
+        sh0.requires_grad_(True)
+        if shN is not None:
+            shN.requires_grad_(True)
+
+        # Invert the opacity to optimize in the unconstrained space
+        opacities_raw = self.inverse_opacity_activation(opacities.detach().clone(), eps=1e-6)
+        opacities_raw.requires_grad_(True)
+
+        # Invert the scales
+        scales_raw = self.scaling_inverse_activation(scales.detach().clone())
+        scales_raw.requires_grad_(True)
+
+        # Rotations in xyzw (scalar last)
+        # remember to convert to wxyz (scalar first) before rendering and saving to ply
+        rotations_unnorm = rotations_unnorm.detach().clone()
+        rotations_unnorm.requires_grad_(True)
+
+        _register_param("opacities_raw", opacities_raw)
+        _register_param("scales_raw", scales_raw)
+        _register_param("means", means)
+        _register_param("rotations_unnorm", rotations_unnorm)
+        _register_param("sh0", sh0)
+        if shN is not None:
+            _register_param("shN", shN)
+
+        for name, param in self.named_parameters():
+            print(f"Registered parameter: {name}, shape: {param.shape}, dtype: {param.dtype}, min: {param.min()}, max: {param.max()}, requires_grad: {param.requires_grad}")
+
+    @property
+    def scales(self):
+        scales = self.scaling_activation(self.scales_raw)
+        return scales
+
+    @property
+    def opacities(self):
+        opacities = self.opacity_activation(self.opacities_raw)
+        return opacities
+
+    @property
+    def rotations(self):
+        rotations = self.rotation_activation(self.rotations_unnorm, dim=-1)
+        return rotations
+
+    @property
+    def harmonics(self):
+        # returns [G, 3, d_sh]
+        shN = getattr(self, "shN", None)
+        if shN is not None:
+            harmonics_ = torch.cat([self.sh0, shN], dim=-1)
+        else:
+            harmonics_ = self.sh0
+        return harmonics_
+
+    @property
+    def covariances(self):
+        rotation_xyzw = self.rotations
+        covariances = self.covariance_activation(self.scales, rotation_xyzw)  # [G, 3, 3]
+        return covariances
+
+    def reset_opacity(self, optimizer=None):
+        """Cap all opacities at 0.01 by overwriting ``opacities_raw`` in place.
+
+        If ``optimizer`` keeps torch.optim-style per-parameter state, the
+        same-shaped state tensors of ``opacities_raw`` (e.g. Adam moments) are
+        zeroed so that stale statistics do not immediately undo the reset.
+        """
+        with torch.no_grad():
+            opacities_capped = self.opacities.clamp(max=0.01)
+            opacities_raw_new = self.inverse_opacity_activation(opacities_capped, eps=1e-6)
+            self.opacities_raw.copy_(opacities_raw_new)
+
+        state = getattr(optimizer, "state", None)
+        if state is not None and self.opacities_raw in state:
+            for value in state[self.opacities_raw].values():
+                if torch.is_tensor(value) and value.shape == self.opacities_raw.shape:
+                    value.zero_()
+
+
+def gaussians2module(gaussians: Gaussians, device: torch.device) -> GaussiansModule:
+    """Wrap a batch-size-1 ``Gaussians`` into a trainable ``GaussiansModule`` on ``device``."""
+    bs = gaussians.means.shape[0]
+    assert bs == 1, "Batch size > 1 not supported for post-processing"
+    # bs = 1
+    # convert Gaussians to GaussiansModule
+    gaussian_module = GaussiansModule(
+        means=gaussians.means[0].to(device),
+        harmonics=gaussians.harmonics[0].to(device),
+        opacities=gaussians.opacities[0].to(device),
+        scales=gaussians.scales[0].to(device),
+        rotations_unnorm=gaussians.rotations_unnorm[0].to(device),
+    )
+    return gaussian_module
+
+
+def module2gaussians(gaussian_module: GaussiansModule) -> Gaussians:
+    """Export the module's post-activation values as a ``Gaussians`` with batch size 1."""
+    gaussians = Gaussians(
+        means=gaussian_module.means.unsqueeze(0),  # [1, G, 3]
+        covariances=gaussian_module.covariances.unsqueeze(0),  # [1, G, 3, 3]
+        harmonics=gaussian_module.harmonics.unsqueeze(0),  # [1, G, 3, d_sh]
+        opacities=gaussian_module.opacities.unsqueeze(0),  # [1, G]
+        scales=gaussian_module.scales.unsqueeze(0),  # [1, G, 3]
+        rotations=gaussian_module.rotations.unsqueeze(0),  # [1, G, 4]
+        rotations_unnorm=gaussian_module.rotations_unnorm.unsqueeze(0),  # [1, G, 4], raw (not normalized) parameter
+    )
+    return gaussians
diff --git a/optgs/scene_trainer/initializer/__init__.py b/optgs/scene_trainer/initializer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c8fba7f4c37b3ca113e478fd5b3464ed7106ced
--- /dev/null
+++ b/optgs/scene_trainer/initializer/__init__.py
@@ -0,0 +1,26 @@
+from .initializer import Initializer
+from .initializer_resplat import ResplatInitializer, ResplatInitializerCfg
+from .initializer_colmap import InitializerColmap, InitializerColmapCfg
+from .initializer_ply import InitializerPly, InitializerPlyCfg
+from .initializer_edgs import InitializerEdgs, InitializerEdgsCfg
+from .initializer_random import InitializerRandom, InitializerRandomCfg
+from .initializer_pointcloud import InitializerPointcloud, InitializerPointcloudCfg + +SCENE_INITIALIZERS = { + "resplat_v1": ResplatInitializer, + "resplat_v2": ResplatInitializer, + "colmap": InitializerColmap, + "ply": InitializerPly, + "edgs": InitializerEdgs, + "random": InitializerRandom, + "pointcloud": InitializerPointcloud, +} + +InitializerCfg = ResplatInitializerCfg | InitializerColmapCfg | InitializerPlyCfg | InitializerEdgsCfg | InitializerRandomCfg | InitializerPointcloudCfg + + +def get_scene_initializer(cfg: InitializerCfg) -> Initializer: + print(f"Using scene initializer: {cfg.name}") + scene_initializer = SCENE_INITIALIZERS[cfg.name] + scene_initializer = scene_initializer(cfg) + return scene_initializer diff --git a/optgs/scene_trainer/initializer/initializer.py b/optgs/scene_trainer/initializer/initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..e867432c43fb7fa5fd99d842cd0b16a44b56e4ac --- /dev/null +++ b/optgs/scene_trainer/initializer/initializer.py @@ -0,0 +1,79 @@ +from abc import ABC +from dataclasses import dataclass +from typing import TypeVar, Generic + +import torch +from torch import nn + +from optgs.model.types import Gaussians +from optgs.model.decoder.decoder import DecoderOutput + +T = TypeVar("T") + + +@dataclass +class InitializerOutput: + gaussians: Gaussians + features: torch.Tensor | None = None + depths: list[torch.Tensor] | torch.Tensor | None = None + target_render: DecoderOutput | None = None + context_render: DecoderOutput | None = None + + +@dataclass +class InitializerCfg: + per_pixel: bool + per_view: bool + + # Gaussian subsampling augmentation (applied before fixed_gaussians_num) + # Set min=max for a fixed subsample count, or use floats for ratio-based sampling + train_min_gaussians_subsample: int | float | None + train_max_gaussians_subsample: int | float | None + eval_min_gaussians_subsample: int | float | None + eval_max_gaussians_subsample: int | float | None + + # Final fixed 
Gaussian count for DDP consistency (subsample or pad to reach this) + # Applied after subsampling augmentation + train_fixed_gaussians_num: int | None + eval_fixed_gaussians_num: int | None + +@dataclass +class NonlearnedInitializerCfg(InitializerCfg): + pass + +@dataclass +class LearnedInitializerCfg(InitializerCfg): + pass + + +@dataclass +class PerPixelInitializerCfg(InitializerCfg): + latent_gs: bool + latent_downsample: int + + +class Initializer(nn.Module, ABC, Generic[T]): + cfg: T + + def __init__(self, cfg: T) -> None: + super().__init__() + self.cfg = cfg + + def preprocessing(self, batch, train_cfg) -> None: + pass + + @property + def strategy(self) -> str: + raise NotImplementedError() + + +class LearnedInitializer(Initializer[T], ABC): + @property + def strategy(self) -> str: + return "learned" + + +class NonlearnedInitializer(Initializer[T], ABC): + @property + def strategy(self) -> str: + return "nonlearned" diff --git a/optgs/scene_trainer/initializer/initializer_colmap.py b/optgs/scene_trainer/initializer/initializer_colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..30e6ed9745b05ee1c108de1e6e89b811df515a6d --- /dev/null +++ b/optgs/scene_trainer/initializer/initializer_colmap.py @@ -0,0 +1,347 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Literal, Optional, Tuple +import os +import tempfile + +import numpy as np +import torch +import torch.nn.functional as F +from plyfile import PlyData + +from optgs.dataset.colmap.utils import Parser +from optgs.dataset.data_types import BatchedViews +from optgs.experimental.initializers_utils import knn, points_to_gaussians +from optgs.misc.general_utils import SkipBatchException +from optgs.model.types import Gaussians +from optgs.scene_trainer.common.gaussian_adapter import build_covariance +from optgs.scene_trainer.initializer.initializer import NonlearnedInitializer, InitializerOutput, NonlearnedInitializerCfg + + +@dataclass +class 
InitializerColmapCfg(NonlearnedInitializerCfg): + name: Literal["colmap"] + path: Path + normalize_world_space: bool + scaling_factor: float + init_opacity: float + sh_degree: int + dl3dv_settings: bool + filter_zero_rgb: bool + randomize_opacity: bool + randomize_opacity_distribution: Literal["uniform", "gaussian"] + randomize_opacity_std: float # Standard deviation for gaussian distribution + randomize_opacity_min: float # Minimum value for uniform distribution + points3d_subdir: Optional[str] # if set, overrides dl3dv_settings/default subdir logic + points3d_ply_filename: Optional[str] # if set, loads points from this PLY file (relative to scene dir) instead of COLMAP binary + override_dataset_poses: bool # if true, overrides the dataset poses with the COLMAP poses (after applying T_world transform) + + def get_gaussian_param_num(self): + # calculate the number of parameters per Gaussian + sh_d = self.get_sh_d() + init_gaussian_param_num = 3 + 4 + 3 * sh_d + 2 + 1 + return init_gaussian_param_num + + def get_sh_d(self): + sh_d = (self.sh_degree + 1) ** 2 + return sh_d + + +class InitializerColmap(NonlearnedInitializer[InitializerColmapCfg]): + def __init__(self, cfg: InitializerColmapCfg) -> None: + super().__init__(cfg) + + def _npz_path(self, datadir: Path) -> Path: + suffix = "_norm" if self.cfg.normalize_world_space else "" + if self.cfg.points3d_ply_filename is not None: + ply_stem = Path(self.cfg.points3d_ply_filename).stem + return datadir / f"colmap_points_cache_ply_{ply_stem}{suffix}.npz" + return datadir / f"colmap_points_cache{suffix}.npz" + + def _load_colmap(self, datadir: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Load COLMAP points/colors/poses. + + On first access, parses the raw COLMAP binary files (or a PLY file when + ``points3d_ply_filename`` is set) and saves a compact .npz next to the + scene folder. On subsequent calls only the tiny .npz is loaded. 
+ """ + npz_path = self._npz_path(datadir) + if npz_path.exists(): + try: + data = np.load(npz_path) + return data["points"], data["points_rgb"], data["camtoworlds"] + except PermissionError: + print(f"Warning: No read permission for cache {npz_path}. Attempting to delete and regenerate.") + try: + os.unlink(npz_path) + except Exception as del_e: + print(f"Warning: Could not delete {npz_path} ({del_e}). Will re-parse but cannot cache.") + except Exception as e: + print(f"Warning: Failed to load cache {npz_path} ({e}). Re-parsing COLMAP data.") + + # Always parse COLMAP cameras/images for the poses. + parser = Parser( + data_dir=str(datadir), + factor=1, + normalize=self.cfg.normalize_world_space, + load_images=False, + dl3dv_settings=False, + points3d_subdir=self.cfg.points3d_subdir, + verbose=False, + ) + camtoworlds = parser.camtoworlds # (M, 4, 4) float64 + + if self.cfg.points3d_ply_filename is not None: + # Load 3-D points from a PLY file located directly in the scene dir. + ply_path = datadir / self.cfg.points3d_ply_filename + if not ply_path.exists(): + raise IOError(f"PLY file not found: {ply_path}") + plydata = PlyData.read(str(ply_path)) + vertex = plydata["vertex"] + points = np.stack([ + np.asarray(vertex["x"]), + np.asarray(vertex["y"]), + np.asarray(vertex["z"]), + ], axis=1).astype(np.float32) + points_rgb = np.stack([ + np.asarray(vertex["red"]), + np.asarray(vertex["green"]), + np.asarray(vertex["blue"]), + ], axis=1).astype(np.uint8) + else: + points = parser.points # (N, 3) float32 + points_rgb = parser.points_rgb # (N, 3) uint8 + + # TODO Patricia: Fix permission denied + # Write atomically with a temp file that already ends in .npz. 
+ try: + tmp_path = '' + tmp_fd, tmp_path = tempfile.mkstemp(dir=datadir, suffix=".npz") + os.close(tmp_fd) + np.savez_compressed(tmp_path, points=points, points_rgb=points_rgb, camtoworlds=camtoworlds) + os.chmod(tmp_path, 0o664) # group-readable so other users can use this cache + os.replace(tmp_path, npz_path) # atomic on POSIX + except Exception: + if os.path.exists(tmp_path): + os.unlink(tmp_path) + print(f"Warning: Failed to save COLMAP cache to {npz_path}. This may cause slow loading in the future.") + return points, points_rgb, camtoworlds + + def forward( + self, + context: BatchedViews, + visualization_dump: Optional[dict] = None, + device: Optional[torch.device] = None, + **kwargs + ) -> InitializerOutput: + verbose = False + + # context not used + + # assert COLMAP dir exists + if not self.cfg.path.exists(): + raise ValueError(f"COLMAP dir {self.cfg.path} does not exist.") + + if "scene" in kwargs: + scene_name = kwargs["scene"] + assert len(scene_name) == 1, f"Only single scene initialization supported. 
{scene_name}" + scene_name = scene_name[0] + if self.cfg.dl3dv_settings: + scene_name = scene_name.replace("dl3dv_", "") + if verbose: + print(f"Initializing scene '{scene_name}' from COLMAP at {self.cfg.path}.") + datadir = self.cfg.path / scene_name + if not datadir.exists(): + raise ValueError(f"COLMAP scene dir {datadir} does not exist.") + else: + datadir = self.cfg.path + + # run COLMAP parser (cached after first load) + points_xyz, points_rgb, camtoworlds = self._load_colmap(datadir) + + if verbose: + print(f"Loaded {points_xyz.shape[0]} points from COLMAP.") + + xyz = torch.from_numpy(points_xyz).float().to(device) + rgbs = torch.from_numpy(points_rgb / 255.0).float().to(device) + + if self.cfg.filter_zero_rgb: + # Filter out points with 0,0,0 RGB values (these are often outliers in COLMAP reconstructions) + valid_mask = (rgbs.sum(dim=-1) > 0) + xyz = xyz[valid_mask] + rgbs = rgbs[valid_mask] + + if self.cfg.dl3dv_settings: + assert "target" in kwargs, "Target key is required in kwargs for COLMAP initializer with dl3dv format." 
+ target = kwargs["target"] + + # In some configration we might move the batch to device later, so we want to keep the device consistent + batch_device = target['extrinsics'].device + + context_c2w_dataset = context['extrinsics'] # (b, V, 4, 4) + c2w_colmap = torch.from_numpy(camtoworlds).to(device=batch_device, + dtype=context_c2w_dataset.dtype) # (N, 4, 4) + # T_world = c2w_dataset[0] @ c2w_colmap[0].inverse() + # eps = 1e-3 + # T_world[T_world.abs() < eps] = 0 + # T_world[(T_world - 1.0).abs() < eps] = 1.0 + # T_world[(T_world + 1.0).abs() < eps] = -1.0 + T_world = torch.tensor([[0., 1., 0., 0.], + [1., 0., 0., 0.], + [0., 0., -1., 0.], + [0., 0., 0., 1.]], device=batch_device, + dtype=context_c2w_dataset.dtype) # hard coded for dl3dv colmap reconstructions + c2w_dataset_predicted = T_world @ c2w_colmap + + # Assume only one scene in the batch + context_x_flipped = context['x_flipped'][0] + target_x_flipped = target['x_flipped'][0] + assert context_x_flipped == target_x_flipped, "Context and target x_flipped values must match." + x_flipped = context_x_flipped + flip_transform = torch.eye(4, device=batch_device, dtype=context_c2w_dataset.dtype) + flip_transform[0, 0] = -1.0 + + if x_flipped: + c2w_dataset_predicted = flip_transform @ c2w_dataset_predicted @ flip_transform + + # Overriding the dataset poses with the COLMAP to ensure consistency + if self.cfg.override_dataset_poses: + context_indices = context['index'][0] + new_context_c2w = c2w_dataset_predicted[context_indices] + new_context_c2w = new_context_c2w[None, ...] # (1, V, 4, 4) + context['extrinsics'] = new_context_c2w + + target_indices = target['index'][0] + new_target_c2w = c2w_dataset_predicted[target_indices] + new_target_c2w = new_target_c2w[None, ...] 
+ target['extrinsics'] = new_target_c2w + + xyz = xyz.to(device) + xyz = T_world.to(device) @ torch.cat([xyz, torch.ones_like(xyz[:, :1])], dim=-1).T + if x_flipped: + xyz[0] *= -1.0 + xyz = xyz[:3, :].T + + # ── Step 1: subsampling augmentation ───────────────────────────────────── + min_sub = self.cfg.train_min_gaussians_subsample if self.training else self.cfg.eval_min_gaussians_subsample + max_sub = self.cfg.train_max_gaussians_subsample if self.training else self.cfg.eval_max_gaussians_subsample + + if min_sub is not None or max_sub is not None: + target_count = self._sample_num_gaussians(xyz.shape[0], min_sub, max_sub) + if xyz.shape[0] > target_count: + indices = torch.randperm(xyz.shape[0], device=xyz.device)[:target_count] + xyz = xyz[indices] + rgbs = rgbs[indices] + + # ── Step 2: subsample to fixed count before knn (so distances are correct) + # If current number of points exceeds the fixed count, we subsample to the fixed count (for DDP consistency). + fixed_num = self.cfg.train_fixed_gaussians_num if self.training else self.cfg.eval_fixed_gaussians_num + if fixed_num is not None and xyz.shape[0] > fixed_num: + indices = torch.randperm(xyz.shape[0], device=xyz.device)[:fixed_num] + xyz = xyz[indices] + rgbs = rgbs[indices] + + if xyz.shape[0] == 0: + black_gaussians_num = (points_rgb == 0).all(axis=-1).sum() + raise SkipBatchException(f"No valid points found in COLMAP data for scene {datadir}. Skipping batch. " + f"Originally {points_xyz.shape[0]} points. 
Black gaussian num {black_gaussians_num}.") + + # ── Step 3: knn-based scale initialisation ─────────────────────────────── + dist2_avg = (knn(xyz, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = dist_avg.unsqueeze(-1).repeat(1, 3) # [N, 3] + + # Initialize opacities with optional randomization + if self.cfg.randomize_opacity: + if self.cfg.randomize_opacity_distribution == "uniform": + # Randomize opacities uniformly between min and max + opacities = (torch.rand(xyz.shape[0], device=xyz.device) * (self.cfg.init_opacity - self.cfg.randomize_opacity_min)) + self.cfg.randomize_opacity_min + elif self.cfg.randomize_opacity_distribution == "gaussian": + # Randomize opacities with a Gaussian distribution + mean = self.cfg.init_opacity + stddev = self.cfg.randomize_opacity_std + opacities = torch.normal(mean, stddev, size=(xyz.shape[0],), device=xyz.device) + opacities = opacities.clamp(0, 1) # Clamp to ensure valid values + else: + raise ValueError(f"Unknown randomize_opacity_distribution: {self.cfg.randomize_opacity_distribution}") + else: + opacities = torch.full((xyz.shape[0],), self.cfg.init_opacity) + + nr_valid = xyz.shape[0] + # ── Step 4: pad to fixed count for DDP consistency ─────────────────────── + if fixed_num is not None and xyz.shape[0] < fixed_num: + pad = fixed_num - xyz.shape[0] + xyz = F.pad(xyz, (0, 0, 0, pad), value=0.0) + rgbs = F.pad(rgbs, (0, 0, 0, pad), value=0.0) + scales = F.pad(scales, (0, 0, 0, pad), value=1e-10) + opacities = F.pad(opacities, (0, pad), value=1e-10) + # TODO Naama: might be a problem if we don't freeze zero-grad gaussians + + points_dict = { + "xyz": xyz, + "rgb": rgbs, + "scales": scales, + "opacities": opacities, + } + + points_dict["scales"] *= self.cfg.scaling_factor + + # pre-activation values on device + gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device) + + means = gaussians_dict["xyz"] + sh0 = gaussians_dict["sh0"] + shN = 
gaussians_dict["shN"] + if shN is not None: + harmonics = torch.cat([sh0, shN], dim=1) # [N, sh_d, 3] + else: + harmonics = sh0 + harmonics = harmonics.permute(0, 2, 1) # [N, 3, sh_d] + rotations_unnorm = gaussians_dict["rotations_unnorm"] + + # post-activation values + opacities = torch.sigmoid(gaussians_dict["opacities_raw"]) + scales = torch.exp(gaussians_dict["scales_raw"]) + rotations = F.normalize(gaussians_dict["rotations_unnorm"], dim=-1) + covariances = build_covariance(scale=scales, rotation_xyzw=rotations) + + gaussians = Gaussians( + means=means.unsqueeze(0), + covariances=covariances.unsqueeze(0), + harmonics=harmonics.unsqueeze(0), # [1, N, C, sh_d] + opacities=opacities.unsqueeze(0), + scales=scales.unsqueeze(0), + rotations=rotations.unsqueeze(0), + rotations_unnorm=rotations_unnorm.unsqueeze(0), + nr_valid=nr_valid + ) + + return InitializerOutput( + gaussians=gaussians, + features=None, + depths=None + ) + + @staticmethod + def _sample_num_gaussians( + available: int, + min_val: int | float | None, + max_val: int | float | None, + ) -> int: + if min_val is None and max_val is None: + return available + + assert min_val is not None and max_val is not None, \ + "Both min and max must be set together for Gaussian subsampling." + assert type(min_val) == type(max_val), \ + "min and max must be the same type (both int or both float)." + + if isinstance(min_val, int): + count = torch.randint(min_val, max_val + 1, (1,)).item() + else: + assert 0.0 < min_val <= 1.0 and 0.0 < max_val <= 1.0, \ + "Float subsampling ratios must be in (0, 1]." 
+ ratio = torch.empty(1).uniform_(min_val, max_val).item() + count = int(available * ratio) + + return min(count, available) diff --git a/optgs/scene_trainer/initializer/initializer_edgs.py b/optgs/scene_trainer/initializer/initializer_edgs.py new file mode 100644 index 0000000000000000000000000000000000000000..057f37305a14e21422f78e704c67607b79732bcb --- /dev/null +++ b/optgs/scene_trainer/initializer/initializer_edgs.py @@ -0,0 +1,189 @@ +from dataclasses import dataclass +from typing import Literal, Optional + +from optgs.dataset.data_types import BatchedViews +import numpy as np +import torch +import math +import torch.nn.functional as F +from pathlib import Path +from optgs.experimental.edgs.init import init_gaussians_with_corr +from optgs.experimental.initializers_utils import knn, points_to_gaussians +from optgs.model.types import Gaussians +from optgs.scene_trainer.common.gaussian_adapter import build_covariance +from optgs.scene_trainer.initializer.initializer import InitializerOutput, NonlearnedInitializer, NonlearnedInitializerCfg + + +@dataclass +class InitializerEdgsCfg(NonlearnedInitializerCfg): + name: Literal["edgs"] + sh_degree: int + init_opacity: float + scaling_factor: float + roma_model_type: str + + sample_init_gaussians: int # if >0, randomly sample this many gaussians from the initialized set + + def get_gaussian_param_num(self): + # calculate the number of parameters per Gaussian + sh_d = self.get_sh_d() + # TODO Naama: check where this is used, and if it is needed + init_gaussian_param_num = 3 + 4 + 3 * sh_d + 2 + 1 + return init_gaussian_param_num + + def get_sh_d(self): + sh_d = (self.sh_degree + 1) ** 2 + return sh_d + + +class InitializerEdgs(NonlearnedInitializer[InitializerEdgsCfg]): + def __init__(self, cfg: InitializerEdgsCfg) -> None: + super().__init__(cfg) + + def forward( + self, + context: BatchedViews, + visualization_dump: Optional[dict] = None, + cached_data_path: Optional[Path] = None, + **kwargs + ) -> InitializerOutput: 
+ + device = context["extrinsics"].device + + # unpack context (batch_dim = 1) + viewpoints_img = context["image"].squeeze(0) # [N, 3, H, W] + h, w = viewpoints_img.shape[2], viewpoints_img.shape[3] + + # poses + viewpoints_c2w = context["extrinsics"].squeeze(0).clone() # [N, 4, 4] + camera_centers = viewpoints_c2w[..., :3, 3] + viewpoints_w2c = torch.inverse(viewpoints_c2w) # [N, 4, 4] + + # convert to column-major + viewpoints_w2c = viewpoints_w2c.permute(0, 2, 1) + + # intrinsics + viewpoints_intrinsics = context["intrinsics"].squeeze(0).clone() # [N, 3, 3] + # un-normalize intrinsics by multiplying by image size + viewpoints_intrinsics[:, 0, :] *= w + viewpoints_intrinsics[:, 1, :] *= h + + def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + + def focal2fov(focal, pixels): + return 2 * math.atan(pixels / (2 * focal)) + + viewpoints_proj = [] + for idx, intrinsic in enumerate(viewpoints_intrinsics): + fx = intrinsic[0, 0] + fy = intrinsic[1, 1] + znear = 0.01 + zfar = 100.0 + fovY = focal2fov(fy, h) + fovX = focal2fov(fx, w) + proj = getProjectionMatrix( + znear=znear, zfar=zfar, fovX=fovX, fovY=fovY + ).transpose(0, 1).cuda() + viewpoints_proj.append(proj) + viewpoints_proj = torch.stack(viewpoints_proj, dim=0) # [N, 4, 4] + + # compute full projection matrices + viewpoints_full_proj = (viewpoints_w2c.bmm(viewpoints_proj)) # [N, 4, 4] + + # check if points_dict is stored on disk already (cached) + found_cached = False + if cached_data_path is not None: + print("Checking 
for cached points_dict at:", str(cached_data_path)) + cache_path = cached_data_path / "points_dict.pt" + if cache_path.exists(): + points_dict = torch.load(cache_path) + print("Loaded cached points_dict from:", str(cache_path)) + found_cached = True + else: + print("No cached points_dict found at:", str(cache_path)) + + if not found_cached: + # recompute points_dict + _, _, points_dict = init_gaussians_with_corr( + viewpoints_img=viewpoints_img, # [N, 3, H, W] + viewpoints_w2c=viewpoints_w2c, # [N, 4, 4] + viewpoints_proj=viewpoints_full_proj, # [N, 4, 4] + camera_centers=camera_centers, # [N, 3] + init_opacity=self.cfg.init_opacity, + roma_model_type=self.cfg.roma_model_type, + verbose=False + ) + if cached_data_path is not None: + print("Saving points_dict to cache at:", str(cache_path)) + cached_data_path.mkdir(parents=True, exist_ok=True) + torch.save(points_dict, cache_path) + + points_dict["scales"] *= self.cfg.scaling_factor + + # printing some stats + for k, v in points_dict.items(): + print(f"points_dict[{k}]: shape={v.shape}, dtype={v.dtype}, min={v.min().item()}, max={v.max().item()}") + + # downsample if needed + if self.cfg.sample_init_gaussians > 0: + # randomly sample a subset of gaussians + total_points = points_dict["xyz"].shape[0] + sample_num = min(self.cfg.sample_init_gaussians, total_points) + sampled_indices = torch.randperm(total_points)[:sample_num] + points_dict = {k: v[sampled_indices] for k, v in points_dict.items()} + print("Nr points after sampling:", points_dict["xyz"].shape[0]) + + + # pre-activation values on device + gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device) + + means = gaussians_dict["xyz"] + sh0 = gaussians_dict["sh0"] + shN = gaussians_dict["shN"] + harmonics = torch.cat([sh0, shN], dim=1) # [N, sh_d, 3] + harmonics = harmonics.permute(0, 2, 1) # [N, 3, sh_d] + rotations_unnorm = gaussians_dict["rotations_unnorm"] + + # post-activation values + opacities = 
@dataclass
class InitializerPlyCfg(NonlearnedInitializerCfg):
    """Configuration for the initializer that loads Gaussians from a PLY file."""

    name: Literal["ply"]
    # Root directory that is searched for the PLY file.
    path: Path
    # Disabled options, kept for reference:
    #   normalize_world_space: bool
    #   scaling_factor: float
    #   init_opacity: float
    sh_degree: int
    #   dl3dv_settings: bool
    # Path relative to the scene dir; may contain sub-directories,
    # e.g. "iteration_20000/point_cloud.ply".
    ply_filename: str = "gaussians.ply"

    def get_gaussian_param_num(self):
        """Return the number of parameters stored per Gaussian.

        The total is 3 + 4 + 3 * sh_d + 2 + 1.
        NOTE(review): presumably scale (3) + rotation quaternion (4) +
        SH coefficients (3 * sh_d) + 2D offset (2) + opacity (1) — confirm.
        """
        coeffs_per_channel = self.get_sh_d()
        return 3 + 4 + 3 * coeffs_per_channel + 2 + 1

    def get_sh_d(self):
        """Return the spherical-harmonics coefficient count, (sh_degree + 1) ** 2."""
        return (self.sh_degree + 1) ** 2
@dataclass
class InitializerPointcloudCfg(NonlearnedInitializerCfg):
    """Configuration for the initializer that builds Gaussians from a point-cloud PLY."""

    name: Literal["pointcloud"]
    # Directory containing the .ply files.
    path: Path
    scaling_factor: float
    init_opacity: float
    sh_degree: int
    filter_zero_rgb: bool
    # Optional 4x4 world transform applied to the point-cloud positions.
    # Required when the PLY lives in a different coordinate system than the
    # camera poses. For ScanNet++/NeRFstudio the PLY is in COLMAP space while
    # the cameras are in NeRFstudio space; the mapping is
    # (x, y, z) -> (y, x, -z), i.e.
    #   [[0,1,0,0],[1,0,0,0],[0,0,-1,0],[0,0,0,1]]
    # Set to null to skip the transform.
    world_transform: Optional[list]

    def get_gaussian_param_num(self):
        """Return the number of parameters per Gaussian: 3 + 4 + 3 * sh_d + 2 + 1."""
        coeffs_per_channel = (self.sh_degree + 1) ** 2
        constant_part = 3 + 4 + 2 + 1
        return constant_part + 3 * coeffs_per_channel

    def get_sh_d(self):
        """Return the spherical-harmonics coefficient count, (sh_degree + 1) ** 2."""
        degree_plus_one = self.sh_degree + 1
        return degree_plus_one ** 2
{scene_name}" + scene_name = scene_name[0] + ply_path = self.cfg.path / f"{scene_name}.ply" + else: + raise ValueError("Scene name is required for pointcloud initializer.") + + if not ply_path.exists(): + raise ValueError(f"PLY file {ply_path} does not exist.") + + print(f"Loading point cloud from {ply_path}") + + # Load PLY + points_xyz, points_rgb = self._load_ply(ply_path) + print(f"Loaded {points_xyz.shape[0]} points.") + + xyz = torch.from_numpy(points_xyz).float().to(device) + rgbs = torch.from_numpy(points_rgb / 255.0).float().to(device) + + # Apply world transform to align point cloud with camera coordinate system + if self.cfg.world_transform is not None: + T = torch.tensor(self.cfg.world_transform, dtype=torch.float32, device=device) + # Transform: new_xyz = (T @ [xyz, 1])[:3] + xyz_h = torch.cat([xyz, torch.ones(xyz.shape[0], 1, device=device)], dim=-1) # [N, 4] + xyz = (T @ xyz_h.T)[:3].T # [N, 3] + + # Filter zero-RGB points + if self.cfg.filter_zero_rgb: + valid_mask = rgbs.sum(dim=-1) > 0 + xyz = xyz[valid_mask] + rgbs = rgbs[valid_mask] + + # ── Step 1: subsampling augmentation ───────────────────────────────────── + min_sub = self.cfg.train_min_gaussians_subsample if self.training else self.cfg.eval_min_gaussians_subsample + max_sub = self.cfg.train_max_gaussians_subsample if self.training else self.cfg.eval_max_gaussians_subsample + + if min_sub is not None or max_sub is not None: + target_count = self._sample_num_gaussians(xyz.shape[0], min_sub, max_sub) + if xyz.shape[0] > target_count: + indices = torch.randperm(xyz.shape[0], device=xyz.device)[:target_count] + xyz = xyz[indices] + rgbs = rgbs[indices] + + # ── Step 2: subsample to fixed count (for DDP consistency) ──────────── + fixed_num = self.cfg.train_fixed_gaussians_num if self.training else self.cfg.eval_fixed_gaussians_num + if fixed_num is not None and xyz.shape[0] > fixed_num: + indices = torch.randperm(xyz.shape[0], device=xyz.device)[:fixed_num] + xyz = xyz[indices] + rgbs = 
rgbs[indices] + + # KNN → scales + dist2_avg = (knn(xyz, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = dist_avg.unsqueeze(-1).repeat(1, 3) # [N, 3] + opacities = torch.full((xyz.shape[0],), self.cfg.init_opacity) + + # Pad to fixed size for distributed training + if self.training and fixed_num is not None: + current_num = xyz.shape[0] + if current_num < fixed_num: + pad_size = fixed_num - current_num + xyz = F.pad(xyz, (0, 0, 0, pad_size), mode='constant', value=0.0) + rgbs = F.pad(rgbs, (0, 0, 0, pad_size), mode='constant', value=0.0) + scales = F.pad(scales, (0, 0, 0, pad_size), mode='constant', value=1e-10) + opacities = F.pad(opacities, (0, pad_size), mode='constant', value=1e-10) + + points_dict = { + "xyz": xyz, + "rgb": rgbs, + "scales": scales * self.cfg.scaling_factor, + "opacities": opacities, + } + + # Convert to Gaussian representation + gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device) + + means = gaussians_dict["xyz"] + sh0 = gaussians_dict["sh0"] + shN = gaussians_dict["shN"] + harmonics = torch.cat([sh0, shN], dim=1) # [N, sh_d, 3] + harmonics = harmonics.permute(0, 2, 1) # [N, 3, sh_d] + rotations_unnorm = gaussians_dict["rotations_unnorm"] + + opacities = torch.sigmoid(gaussians_dict["opacities_raw"]) + scales = torch.exp(gaussians_dict["scales_raw"]) + rotations = F.normalize(gaussians_dict["rotations_unnorm"], dim=-1) + covariances = build_covariance(scale=scales, rotation_xyzw=rotations) + + gaussians = Gaussians( + means=means.unsqueeze(0), + covariances=covariances.unsqueeze(0), + harmonics=harmonics.unsqueeze(0), + opacities=opacities.unsqueeze(0), + scales=scales.unsqueeze(0), + rotations=rotations.unsqueeze(0), + rotations_unnorm=rotations_unnorm.unsqueeze(0), + ) + + return InitializerOutput( + gaussians=gaussians, + features=None, + depths=None, + ) + + @staticmethod + def _sample_num_gaussians(available: int, min_sub: int | float | None, max_sub: int | 
float | None) -> int: + """Sample a target Gaussian count from the [min_sub, max_sub] range.""" + if min_sub is None: + min_sub = max_sub + if max_sub is None: + max_sub = min_sub + + if isinstance(min_sub, int): + target = torch.randint(min_sub, max_sub + 1, (1,)).item() + else: # float → ratio of available + ratio = torch.empty(1).uniform_(min_sub, max_sub).item() + target = int(available * ratio) + + return min(target, available) diff --git a/optgs/scene_trainer/initializer/initializer_random.py b/optgs/scene_trainer/initializer/initializer_random.py new file mode 100644 index 0000000000000000000000000000000000000000..281c85e6ef6a00acf824a113965473286d7075bd --- /dev/null +++ b/optgs/scene_trainer/initializer/initializer_random.py @@ -0,0 +1,102 @@ +from dataclasses import dataclass +from typing import Literal, Optional +from pathlib import Path +import torch +import torch.nn.functional as F + +from optgs.dataset.data_types import BatchedViews +from optgs.scene_trainer.common.gaussian_adapter import build_covariance +from optgs.model.types import Gaussians +from optgs.experimental.initializers_utils import knn, points_to_gaussians +from optgs.scene_trainer.initializer.initializer import NonlearnedInitializer, InitializerOutput, InitializerCfg, NonlearnedInitializerCfg +from optgs.dataset.camera_datasets.camera import get_scene_scale + + +@dataclass +class InitializerRandomCfg(NonlearnedInitializerCfg): + name: Literal["random"] + init_num_pts: int + init_extent: float + scaling_factor: float + init_opacity: float + sh_degree: int + + def get_gaussian_param_num(self): + # calculate the number of parameters per Gaussian + sh_d = self.get_sh_d() + init_gaussian_param_num = 3 + 4 + 3 * sh_d + 2 + 1 + return init_gaussian_param_num + + def get_sh_d(self): + sh_d = (self.sh_degree + 1) ** 2 + return sh_d + + +class InitializerRandom(NonlearnedInitializer[InitializerRandomCfg]): + def __init__(self, cfg: InitializerRandomCfg) -> None: + super().__init__(cfg) + + def 
forward( + self, + context: BatchedViews, + **kwargs + ) -> InitializerOutput: + + device = context["extrinsics"].device + init_num_pts = self.cfg.init_num_pts + init_extent = self.cfg.init_extent + + # calculate scene scale from context + camtoworlds = context["extrinsics"].cpu().numpy() # [B, 4, 4] + assert camtoworlds.shape[0] == 1, "Batch size > 1 not supported in random initializer" + camtoworlds = camtoworlds.squeeze(0) + scene_scale = get_scene_scale(camtoworlds) + + xyz = init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1) + rgbs = torch.rand((init_num_pts, 3)) + + # Initialize the GS size to be the average dist of the 3 nearest neighbors + dist2_avg = (knn(xyz, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = dist_avg.unsqueeze(-1).repeat(1, 3) # [N, 3] + + points_dict = { + "xyz": xyz, + "rgb": rgbs, + "scales": scales, + "opacities": torch.full((xyz.shape[0],), self.cfg.init_opacity), + } + + points_dict["scales"] *= self.cfg.scaling_factor + + # pre-activation values on device + gaussians_dict = points_to_gaussians(points_dict, sh_degree=self.cfg.sh_degree, device=device) + + means = gaussians_dict["xyz"] + sh0 = gaussians_dict["sh0"] + shN = gaussians_dict["shN"] + harmonics = torch.cat([sh0, shN], dim=1) # [N, sh_d, 3] + harmonics = harmonics.permute(0, 2, 1) # [N, 3, sh_d] + rotations_unnorm = gaussians_dict["rotations_unnorm"] + + # post-activation values + opacities = torch.sigmoid(gaussians_dict["opacities_raw"]) + scales = torch.exp(gaussians_dict["scales_raw"]) + rotations = F.normalize(gaussians_dict["rotations_unnorm"], dim=-1) + covariances = build_covariance(scale=scales, rotation_xyzw=rotations) + + gaussians = Gaussians( + means=means.unsqueeze(0), + covariances=covariances.unsqueeze(0), + harmonics=harmonics.unsqueeze(0), # [1, N, C, sh_d] + opacities=opacities.unsqueeze(0), + scales=scales.unsqueeze(0), + rotations=rotations.unsqueeze(0), + rotations_unnorm=rotations_unnorm.unsqueeze(0), + 
) + + return InitializerOutput( + gaussians=gaussians, + features=None, + depths=None + ) \ No newline at end of file diff --git a/optgs/scene_trainer/initializer/initializer_resplat.py b/optgs/scene_trainer/initializer/initializer_resplat.py new file mode 100644 index 0000000000000000000000000000000000000000..5527a572956aeea558733ff4d01d3d74e1f63359 --- /dev/null +++ b/optgs/scene_trainer/initializer/initializer_resplat.py @@ -0,0 +1,1367 @@ +from dataclasses import dataclass +from typing import Literal, Optional, List + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from torch import nn + +from optgs.dataset.data_types import BatchedExample, DataShim, BatchedViews +from optgs.dataset.shims.patch_shim import apply_patch_shim +from optgs.geometry.projection import sample_image_grid, get_world_rays +from optgs.misc.general_utils import rotate_quats +from optgs.misc.io import FrequencyScheduler +from optgs.model.encoder.layer import BasicBlock +from optgs.model.encoder.unimatch.dpt_head import DPTHead +from optgs.model.encoder.unimatch.feature_upsampler import ResizeConvFeatureUpsampler +from optgs.model.encoder.unimatch.ldm_unet.unet import UNetModel +from optgs.model.encoder.unimatch.mv_unimatch import MultiViewUniMatch +from optgs.model.encoder.visualization.encoder_visualizer_depthsplat_cfg import EncoderVisualizerDepthSplatCfg +from optgs.model.types import Gaussians +from optgs.scene_trainer.common.gaussian_adapter import GaussianAdapter, GaussianAdapterCfg, build_covariance, RGB2SH +from optgs.scene_trainer.initializer.initializer import InitializerOutput, LearnedInitializer, PerPixelInitializerCfg + +try: + from optgs.model.encoder.point_transformer.layer import (PlainPointTransformer, SubsampleBlock, PointLinearWrapper, + MultiScalePointTransformer, + MultViewLowresAttn, MultViewUniMatchAttn, + GaussianErrorCrossAttn) +except: + pass + +from optgs.model.encoder.lvsm.transformer 
import LVSMTransformer + +try: + from simple_knn._C import distCUDA2 +except: + pass + + +@dataclass +class ResplatInitializerCfg(PerPixelInitializerCfg): + name: Literal["resplat_v1", "resplat_v2"] + d_feature: int + num_depth_candidates: int + num_surfaces: int + visualizer: EncoderVisualizerDepthSplatCfg + gaussian_adapter: GaussianAdapterCfg + gaussians_per_pixel: int + unimatch_weights_path: str | None + downscale_factor: int + shim_patch_size: int | List[int] + multiview_trans_attn_split: int + + deform_sample_depth: bool # non-pixel aligned Gaussians with learned offsets + deform_sample_depth_debug: bool # check depth sampling + + # mv_unimatch + num_scales: int + upsample_factor: int + lowest_feature_resolution: int + depth_unet_channels: int + grid_sample_disable_cudnn: bool + + # depthsplat color branch + large_gaussian_head: bool + color_large_unet: bool + init_sh_input_img: bool + feature_upsampler_channels: int + gaussian_regressor_channels: int + + # loss config + return_depth: bool + + # only depth + train_depth_only: bool + + # monodepth config + monodepth_vit_type: str + + # multi-view matching + local_mv_match: int + + # point transformer + pt_head: bool + init_pt_with_mv_attn: bool + init_pt_with_mv_attn_lowres: bool + pt_head_conv: bool + pt_head_concat_img: bool + pt_head_channels: int | None + multi_scale_pt: bool + attn_proj_channels: int | None + fps_num_samples: int | None + knn_samples: int + post_norm: bool + no_rpe: bool + no_knn_attn: bool + num_blocks: int + pt_downsample: int + fps_agg_func: str + subsample_method: str + add_pt_residual: bool + pt_pred_residual_position: bool # based on the inital point cloud from depth, predict additional residual + latent_dpt_upsampler: bool + latent_dpt_upsampler_no_concat: bool + light_dpt_feature: bool + + # freeze depth + freeze_depth: bool + use_gt_depth: bool + + # separate depth and color branches + separate_depth_color: bool + separate_depth_type: str + separate_depth_gaussian_scale: bool + 
def get_feature_upsampler_channels(self):
    """Compute the channel count of the upsampled feature map.

    Returns:
        A tuple ``(feature_num, model_configs)``:
        ``feature_num`` is the number of feature channels for the active
        configuration, and ``model_configs`` maps each ViT variant
        ('vits', 'vitb', 'vitl') to its ``in_channels`` / ``features`` /
        ``out_channels`` settings (with ``out_channels`` halved when
        ``light_dpt_feature`` is set and the DPT branch is taken).

    Raises:
        NotImplementedError: if latent Gaussians are used without the latent
            DPT upsampler and ``latent_downsample`` is not 2, 4 or 8.
    """
    # upsample features to the original resolution
    model_configs = {
        'vits': {'in_channels': 384, 'features': 64, 'out_channels': [48, 96, 192, 384]},
        'vitb': {'in_channels': 768, 'features': 96, 'out_channels': [96, 192, 384, 768]},
        'vitl': {'in_channels': 1024, 'features': 128, 'out_channels': [128, 256, 512, 1024]},
    }

    vit_type = self.monodepth_vit_type
    in_channels = model_configs[vit_type]['in_channels']

    if self.latent_gs and not self.latent_dpt_upsampler:
        # Channel count of the concatenated raw features.
        # NOTE(review): the constant terms presumably are the CNN (64 + 96 + 128)
        # and multi-view (128) feature channels — confirm against forward().
        if self.latent_downsample == 2:
            feature_num = in_channels // 64 * 4 + 128 // 4 + 64 + 96 + 128 // 4
        elif self.latent_downsample == 4:
            feature_num = in_channels // 4 + 128 + 64 + 96 + 128
        elif self.latent_downsample == 8:
            if self.fixed_latent_size:
                feature_num = in_channels // 4 + 128 + 64 + 96 + 128
            else:
                feature_num = in_channels + 128 + 64 + 96 + 128
        else:
            # `self` is the config itself: read the field directly. Going through
            # `self.cfg` raised AttributeError and hid the intended error.
            raise NotImplementedError(f"Unsupported latent_downsample value: {self.latent_downsample}")
    elif self.resizeconv_upsampler:
        feature_num = self.feature_upsampler_channels
    else:
        if self.light_dpt_feature:
            # Lightweight DPT variant: halve every per-stage output width.
            for config in model_configs.values():
                config['out_channels'] = [c // 2 for c in config['out_channels']]
        features = model_configs[vit_type]["features"]
        if self.latent_gs and not self.latent_dpt_upsampler_no_concat:
            features *= 4
        feature_num = features

    return feature_num, model_configs
return sh_d + + +class ResplatInitializer(LearnedInitializer[ResplatInitializerCfg]): + def __init__(self, cfg: ResplatInitializerCfg) -> None: + super().__init__(cfg) + + self.depth_predictor = self._get_depth_predictor(cfg) + + if self.cfg.train_depth_only: + return + + feature_upsampler_channels, model_configs = self.cfg.get_feature_upsampler_channels() + + if self.cfg.latent_gs and not self.cfg.latent_dpt_upsampler: + # No need to create a module — this config only computes channels + pass + elif self.cfg.resizeconv_upsampler: + self.feature_upsampler = ResizeConvFeatureUpsampler( + num_scales=cfg.num_scales, + lowest_feature_resolution=cfg.lowest_feature_resolution, + out_channels=self.cfg.feature_upsampler_channels, + vit_type=self.cfg.monodepth_vit_type, + ) + + else: + self.feature_upsampler = DPTHead( + **model_configs[cfg.monodepth_vit_type], + downsample_factor=cfg.upsample_factor, + return_feature=True, + num_scales=cfg.num_scales, + latent_downsample=self.cfg.latent_downsample if self.cfg.latent_gs else None, + latent_feature_no_concat=self.cfg.latent_dpt_upsampler_no_concat, + ) + + # gaussians adapter (can be removed) + self.gaussian_adapter = GaussianAdapter(cfg.gaussian_adapter) + + # concat(img, depth, match_prob, features) + in_channels = 3 + 1 + 1 + feature_upsampler_channels + channels = self.cfg.gaussian_regressor_channels + + if self.cfg.latent_gs: + # image unshuffle + if self.cfg.fixed_latent_size: + # fixed patch size 4 + in_channels = in_channels - 3 + 3 * (4 ** 2) + else: + in_channels = in_channels - 3 + 3 * (self.cfg.latent_downsample ** 2) + + # unet gaussian regressor + if self.cfg.unet_gaussian_regressor: + modules = [ + nn.Conv2d(in_channels, channels, 3, 1, 1), + nn.GroupNorm(8, channels), + nn.GELU(), + ] + + if self.cfg.color_large_unet: + unet_channel_mult = [1, 2, 4, 4, 4] + else: + unet_channel_mult = [1, 1, 1, 1, 1] + unet_attn_resolutions = [16] + + modules.append( + UNetModel( + image_size=None, + in_channels=channels, + 
model_channels=channels, + out_channels=channels, + num_res_blocks=1, # self.unet_per_scale_blocks, + # attention_resolutions=[8, 4, 2], + attention_resolutions=unet_attn_resolutions, + # channel_mult=[1, 1, 1, 1], + channel_mult=unet_channel_mult, + num_head_channels=32 if self.cfg.gaussian_regressor_channels >= 32 else 16, + dims=2, + postnorm=False, + num_frames=2, + use_cross_view_self_attn=True, + ) + ) + + modules.append(nn.Conv2d(channels, channels, 3, 1, 1)) + + elif self.cfg.resnet_gaussian_regressor: + modules = [ + nn.Conv2d(in_channels, channels, 3, 1, 1), + nn.GroupNorm(8, channels), + nn.GELU(), + BasicBlock(channels, channels), + BasicBlock(channels, channels), + ] + + elif self.cfg.lvsm_gaussian_regressor: + modules = [ + nn.Linear(in_channels, channels), + nn.LayerNorm(channels), + nn.GELU(), + LVSMTransformer(channels, + n_layer=self.cfg.lvsm_layers) + ] + + else: + # conv regressor + modules = [ + nn.Conv2d(in_channels, channels, 3, 1, 1), + nn.GELU(), + nn.Conv2d(channels, channels, 3, 1, 1), + ] + + self.gaussian_regressor = nn.Sequential(*modules) + + init_gaussian_param_num = self.cfg.get_gaussian_param_num() + + # gaussian head input channels + # concat(img, features, regressor_out, match_prob) + in_channels = self.cfg.get_pt_in_channels() + + if self.cfg.pt_head: + channels = self.cfg.gaussian_regressor_channels + if self.cfg.pt_head_channels is not None: + channels = self.cfg.pt_head_channels + self.proj = nn.Linear(in_channels, channels) + + if self.cfg.multi_scale_pt: + self.pt = MultiScalePointTransformer(channels, + self.cfg.knn_samples, + downsample_agg_func=self.cfg.fps_agg_func, + subsample_method=self.cfg.subsample_method, + fps_num_samples=self.cfg.fps_num_samples, + attn_proj_channels=self.cfg.attn_proj_channels, + ) + else: + self.pt = PlainPointTransformer(channels, self.cfg.knn_samples, + post_norm=self.cfg.post_norm, + no_rpe=self.cfg.no_rpe, + no_attn=self.cfg.no_knn_attn, + num_blocks=self.cfg.num_blocks, + 
num_heads=self.cfg.pt_heads, + attn_proj_channels=self.cfg.attn_proj_channels, + use_checkpointing=self.cfg.use_checkpointing, + init_use_checkpointing=self.cfg.init_use_checkpointing, + with_mv_attn=self.cfg.init_pt_with_mv_attn, + with_mv_attn_lowres=self.cfg.init_pt_with_mv_attn_lowres, + ) + + out_channels = channels + + # point downsample + if self.cfg.pt_downsample > 0: + num_downsample = int(np.log2(self.cfg.pt_downsample)) + + if num_downsample == 0: + stride = 1 + else: + stride = 2 + + assert num_downsample == 1, f"unsupported num_downsample: {num_downsample}" + + self.pt_down = SubsampleBlock(channels, out_channels=channels * 2, + stride=stride, + knn_samples=self.cfg.knn_samples, + post_norm=self.cfg.post_norm, + agg_func=self.cfg.fps_agg_func, + subsample_method=self.cfg.subsample_method, + ) + + out_channels = channels * 2 + + # TODO: add more pt blocks after downsampling + + if self.cfg.pt_head_concat_img: + # concat to the initial image and features + out_channels = out_channels + 3 + + if self.cfg.latent_gs: + # pixel unshuffle the full image to the latent resolution + out_channels = out_channels - 3 + 3 * (self.cfg.latent_downsample ** 2) + + self.gaussian_head = nn.Sequential( + nn.Linear(out_channels, init_gaussian_param_num), + nn.GELU(), + nn.Linear(init_gaussian_param_num, init_gaussian_param_num) + ) + + # random initialize rotations: first part + # 4 + num_rotation_params = 4 * self.cfg.init_gaussian_multiple + + # zero init other remaining params + # scale, opacity, offset, sh + # 4 + 1 + 1 + 3 * 16 = 54 + nn.init.zeros_(self.gaussian_head[-1].weight[num_rotation_params:]) + nn.init.zeros_(self.gaussian_head[-1].bias[num_rotation_params:]) + + else: + self.gaussian_head = nn.Sequential( + nn.Conv2d(in_channels, init_gaussian_param_num, + 3, 1, 1, padding_mode='replicate'), + nn.GELU(), + nn.Conv2d(init_gaussian_param_num, + init_gaussian_param_num, 3, 1, 1, padding_mode='replicate') + ) + + # random initialize rotations: first part + # 4 + 
def _get_depth_predictor(self, cfg):
    """Construct the MultiViewUniMatch depth network for this initializer.

    The architecture-related arguments are read from the ``cfg`` argument,
    while the depth-sampling, AMP and checkpointing switches are read from
    ``self.cfg``.
    """
    own_cfg = self.cfg

    predictor_kwargs = {
        # taken from the cfg passed in
        "num_scales": cfg.num_scales,
        "upsample_factor": cfg.upsample_factor,
        "lowest_feature_resolution": cfg.lowest_feature_resolution,
        "num_depth_candidates": cfg.num_depth_candidates,
        "vit_type": cfg.monodepth_vit_type,
        "unet_channels": cfg.depth_unet_channels,
        "grid_sample_disable_cudnn": cfg.grid_sample_disable_cudnn,
        # taken from the cfg stored on the module
        "sample_log_depth": own_cfg.sample_log_depth,
        "bilinear_upsample_depth": own_cfg.bilinear_upsample_depth,
        "no_upsample_depth": own_cfg.no_upsample_depth,
        "use_amp": own_cfg.use_amp,
        # raw mono features are requested whenever the latent DPT upsampler is off
        "return_raw_mono_features": not own_cfg.latent_dpt_upsampler,
        "use_checkpointing": own_cfg.use_checkpointing,
    }
    return MultiViewUniMatch(**predictor_kwargs)
attn_splits_list=[2], + min_depth=1. / context["far"], + max_depth=1. / context["near"], + intrinsics=context["intrinsics"], + extrinsics=context["extrinsics"], + nn_matrix=cameras_dist_index, + ) + + # upsample depth to the original resolution + for key in results_dict.keys(): + # NOTE: no need to upsample depth since depth later is in the low resolution + if key != 'depth_preds': + for i in range(len(results_dict[key])): + results_dict[key][i] = F.interpolate(results_dict[key][i], scale_factor=2, mode='bilinear', + align_corners=True) + + # depthsplat: upsample depth to the original resolution + if not self.cfg.latent_gs: + for i in range(len(results_dict['depth_preds'])): + results_dict['depth_preds'][i] = F.interpolate(results_dict['depth_preds'][i], scale_factor=2, + mode='bilinear', align_corners=True) + + else: + results_dict = self.depth_predictor( + context["image"], + attn_splits_list=[2], + min_depth=1. / context["far"], + max_depth=1. / context["near"], + intrinsics=context["intrinsics"], + extrinsics=context["extrinsics"], + nn_matrix=cameras_dist_index, + ) + + if self.cfg.use_gt_depth: + # directly use gt depth as gaussian centers instead of learning them + # to understand the bottleneck of the model + assert 'depth' in context + depth_preds = [context['depth']] + else: + # list of [B, V, H, W], with all the intermediate depths + depth_preds = results_dict['depth_preds'] + + # [B, V, H, W] + depth = depth_preds[-1] + + gaussian_scale_depth = None + + # features [BV, C, H, W] + if self.cfg.latent_gs and not self.cfg.latent_dpt_upsampler: + # concat all features + assert self.cfg.num_scales == 1 + + # use pixelshuffle and pixelunshuffle to align all feature resolutions + # first resize the mono features to 1/16 + mono_features = [F.interpolate(x, size=(h // 16, w // 16), mode='bilinear', align_corners=True) for x in + results_dict['raw_mono_features']] + if self.cfg.fixed_latent_size: + scale_factor = 4 + mono_features = [F.pixel_shuffle(x, 
upscale_factor=scale_factor) for x in mono_features] + mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 16 * 4 + + if self.cfg.latent_downsample == 8: + mono_features = F.interpolate(mono_features, scale_factor=0.5, mode='bilinear', align_corners=True) + else: + if self.cfg.latent_downsample == 4: + scale_factor = 4 + mono_features = [F.pixel_shuffle(x, upscale_factor=scale_factor) for x in mono_features] + mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 16 * 4 + elif self.cfg.latent_downsample == 2: + scale_factor = 8 + mono_features = [F.pixel_shuffle(x, upscale_factor=scale_factor) for x in mono_features] + mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 64 * 4 + elif self.cfg.latent_downsample == 8: + scale_factor = 2 + mono_features = [F.pixel_shuffle(x, upscale_factor=scale_factor) for x in mono_features] + mono_features = torch.cat(mono_features, dim=1) # channel: 384 / 4 * 4 + else: + raise NotImplementedError + + cnn_features = results_dict["features_cnn_all_scales"][::-1] + + if self.cfg.latent_downsample == 2: + # use pixel shuffle to save channels + # 1/2, 1/2, 1/4 + cnn_features[2] = F.pixel_shuffle(cnn_features[2], upscale_factor=2) + # 64 + 96 + 128 // 4 + cnn_features = torch.cat(cnn_features, dim=1) + + # 128 // 4 + mv_features = results_dict["features_mv"][0] + mv_features = F.pixel_shuffle(mv_features, upscale_factor=2) + else: + # resize all cnn features to the latent resolution + target_h, target_w = h // self.cfg.latent_downsample, w // self.cfg.latent_downsample + for i in range(len(cnn_features)): + cnn_features[i] = F.interpolate(cnn_features[i], size=(target_h, target_w), mode='bilinear', + align_corners=True) + cnn_features = torch.cat(cnn_features, dim=1) + + mv_features = results_dict["features_mv"][0] + + if mv_features.shape[-2] != target_h or mv_features.shape[-1] != target_w: + mv_features = F.interpolate(mv_features, size=(target_h, target_w), mode='bilinear', + align_corners=True) 
+ + features = torch.cat((mono_features, cnn_features, mv_features), dim=1) + elif self.cfg.resizeconv_upsampler: + features = self.feature_upsampler(results_dict["features_cnn"], + results_dict["features_mv"], + results_dict["features_mono"], + ) + + else: + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.use_amp, dtype=torch.bfloat16): + features = self.feature_upsampler(results_dict["features_mono_intermediate"], + cnn_features=results_dict["features_cnn_all_scales"][::-1], + mv_features=results_dict["features_mv"][ + 0] if self.cfg.num_scales == 1 else results_dict["features_mv"][ + ::-1] + ) + + # match prob from softmax + # [BV, D, H, W] in feature resolution + match_prob = results_dict['match_probs'][-1] + match_prob = torch.max(match_prob, dim=1, keepdim=True)[ + 0] # [BV, 1, H, W] + + if not self.cfg.latent_gs: + match_prob = F.interpolate( + match_prob, size=depth.shape[-2:], mode='nearest') + + # unet input + if self.cfg.latent_gs: + img_unshuffle = rearrange(context["image"], "b v c h w -> (b v) c h w") + if self.cfg.fixed_latent_size: + if self.cfg.latent_downsample == 8: + img_unshuffle = F.interpolate(img_unshuffle, scale_factor=0.5, mode='area') + + img_unshuffle = F.pixel_unshuffle(img_unshuffle, downscale_factor=4) + else: + img_unshuffle = F.pixel_unshuffle(img_unshuffle, downscale_factor=self.cfg.latent_downsample) + # depth is in the full resolution, downsample to latent depth + if self.cfg.depth_pred_half_res: + latent_depth = F.interpolate(depth, scale_factor=1. 
/ (self.cfg.latent_downsample // 2), + mode='bilinear', align_corners=True) + else: + if self.cfg.no_upsample_depth: + assert self.cfg.latent_downsample == 8 or self.cfg.latent_downsample == 4 + if self.cfg.latent_downsample == 8: + latent_depth = depth + else: + # 1/8 depth to 1/4 + latent_depth = F.interpolate(depth, scale_factor=2, mode='bilinear', align_corners=True) + else: + if self.cfg.avgpool_depth: + latent_depth = F.avg_pool2d(depth, kernel_size=self.cfg.latent_downsample, + stride=self.cfg.latent_downsample) + elif self.cfg.nearest_down_depth: + latent_depth = F.interpolate(depth, scale_factor=1. / self.cfg.latent_downsample, + mode='nearest') + else: + latent_depth = F.interpolate(depth, scale_factor=1. / self.cfg.latent_downsample, + mode='bilinear', align_corners=True) + + if match_prob.shape[-2:] != latent_depth.shape[-2]: + match_prob = F.interpolate( + match_prob, size=latent_depth.shape[-2:], mode='nearest') + + concat = torch.cat(( + img_unshuffle, + rearrange(latent_depth, "b v h w -> (b v) () h w"), + match_prob, + features, + ), dim=1) + else: + concat = torch.cat(( + rearrange(context["image"], "b v c h w -> (b v) c h w"), + rearrange(depth, "b v h w -> (b v) () h w"), + match_prob, + features, + ), dim=1) + + if self.cfg.lvsm_gaussian_regressor: + h, w = concat.shape[-2:] + tmp = rearrange(concat, "(b v) c h w -> b (v h w) c", b=b, v=v) + with torch.autocast('cuda', dtype=torch.bfloat16): + out = self.gaussian_regressor(tmp) + + out = rearrange(out, "b (v h w) c -> (b v) c h w", b=b, v=v, h=h, w=w) + else: + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.use_amp, dtype=torch.bfloat16): + out = self.gaussian_regressor(concat) + + if self.cfg.latent_gs: + concat = [out, img_unshuffle, features, match_prob] + else: + concat = [out, + rearrange(context["image"], + "b v c h w -> (b v) c h w"), + features, + match_prob] + + out = torch.cat(concat, dim=1) + + # [BV, C, H, W] + condition_features = out + + init_scales = None + + if 
self.cfg.pt_head: + if self.cfg.latent_gs: + h, w = latent_depth.shape[-2:] + else: + h, w = depth.shape[-2:] + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_head_amp, dtype=torch.bfloat16): + tmp_feature = self.proj(rearrange(out, "bv c h w -> (bv h w) c")) + # get point cloud + xy_ray, _ = sample_image_grid((h, w), out.device) + xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy") + + # [B, V, H*W, 1, 2] + tmp_coords = xy_ray.unsqueeze(0).unsqueeze(0).repeat(b, v, 1, 1, 1) + + # [B, V, H*W, 1, 1] + if self.cfg.latent_gs: + tmp_depth = rearrange(latent_depth, "b v h w -> b v (h w) () ()") + else: + tmp_depth = rearrange(depth, "b v h w -> b v (h w) () ()") + + # [B, V, 1, 1, 4, 4] + tmp_extrinsics = context["extrinsics"].unsqueeze(2).unsqueeze(2) + # [B, V, 1, 1, 3, 3] + tmp_intrinsics = context["intrinsics"].unsqueeze(2).unsqueeze(2) + + # [B, V, H*W, 1, 3] + origins, directions = get_world_rays(tmp_coords, tmp_extrinsics, tmp_intrinsics) + point_cloud = origins + directions * tmp_depth + + # Create offset directly on device to avoid CPU-GPU transfer + offset = torch.arange(1, b + 1, device=depth.device, dtype=torch.long) * (v * h * w) + + point_cloud = rearrange(point_cloud, "b v h w c -> (b v h w) c") + + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_head_amp, dtype=torch.bfloat16): + if self.cfg.add_pt_residual: + out = tmp_feature + self.pt((point_cloud, tmp_feature, offset), b=b, v=v, h=h, w=w) + else: + out = self.pt((point_cloud, tmp_feature, offset), b=b, v=v, h=h, w=w) + + condition_features = rearrange(out, "(bv h w) c -> bv c h w", h=h, w=w) + + if self.cfg.pt_downsample > 0: + out, fps_idx = self.pt_down((point_cloud, out, offset)) + # [N, 3] + point_cloud, out, offset = out + + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_head_amp, dtype=torch.bfloat16): + if self.cfg.pt_head_concat_img: + if self.cfg.latent_gs: + # pixel unshuffle image + img_unshuffle = rearrange(context["image"], "b v c h w -> 
(b v) c h w") + img_unshuffle = F.pixel_unshuffle(img_unshuffle, downscale_factor=self.cfg.latent_downsample) + img_unshuffle = rearrange(img_unshuffle, "(b v) c h w -> (b v h w) c", b=b, v=v) + + out = torch.cat((out, img_unshuffle), dim=-1) + + if self.cfg.pt_head_conv: + out = rearrange(out, "(b v h w) c -> (b v) c h w", b=b, v=v, h=h, w=w) + + out = self.gaussian_head(out) + + if self.cfg.pt_head_conv: + out = rearrange(out, "(b v) c h w -> (b v h w) c", b=b, v=v) + + if self.cfg.pt_downsample > 0: + # [N, C] + gaussians = out + else: + if self.cfg.pt_pred_residual_position: + # TODO: add intermediate supervision to the initial point cloud + # TODO: multiple scale factor to the delta position to make it more stable + # residual position + point_cloud = point_cloud + out[..., -3:] # [BVHW, 3] + + # remaining gaussians + out = out[..., :-3] + + point_cloud = rearrange(point_cloud, "(b v h w) c -> b v (h w) () () c", b=b, v=v, h=h, w=w) + + gaussians = rearrange(out, "(b v h w) c -> (b v) c h w", b=b, h=h, w=w) + + else: + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.use_amp, dtype=torch.bfloat16): + gaussians = self.gaussian_head(out) # [BV, C, H, W] + + # [BV, C, H, W] + gaussians = gaussians.float() + + if self.cfg.latent_gs: + if self.cfg.init_gaussian_multiple > 1: + # hard coded for now + if self.cfg.init_gaussian_multiple == 4: + # TODO: try avgpooling downsampling depth + if self.cfg.latent_downsample == 4: + # resize full resolution depth + depths = F.interpolate(depth, scale_factor=0.5, mode='bilinear', align_corners=True) + elif self.cfg.latent_downsample == 8: + depths = F.interpolate(depth, scale_factor=0.25, mode='bilinear', align_corners=True) + elif self.cfg.latent_downsample == 2: + depths = depth + else: + raise NotImplementedError + elif self.cfg.init_gaussian_multiple == 16: + # TODO: try avgpooling downsampling depth + if self.cfg.latent_downsample == 4: + depths = depth + elif self.cfg.latent_downsample == 8: + depths = 
F.interpolate(depth, scale_factor=0.5, mode='bilinear', align_corners=True) + else: + raise NotImplementedError + else: + raise NotImplementedError + + depths = rearrange(depths, "b v h w -> b v (h w) () ()") + else: + depths = rearrange(latent_depth, "b v h w -> b v (h w) () ()") + else: + depths = rearrange(depth, "b v h w -> b v (h w) () ()") + + if self.cfg.pt_downsample > 0: + + # split batch + assert offset.shape[0] == b + + if self.cfg.latent_gs: + sh_input_images = rearrange(context["image"], "b v c h w -> (b v) c h w") + if self.cfg.latent_gs_img_interp == 'bicubic': + sh_input_images = F.interpolate(sh_input_images, scale_factor=1. / self.cfg.latent_downsample, + mode='bicubic', align_corners=True) + elif self.cfg.latent_gs_img_interp == 'area': + sh_input_images = F.interpolate(sh_input_images, scale_factor=1. / self.cfg.latent_downsample, + mode='area') + elif self.cfg.latent_gs_img_interp == 'softmax': + sh_input_images = self.softmax_downsample(sh_input_images) + else: + raise NotImplementedError + + h, w = sh_input_images.shape[-2:] + + sh_input_images = rearrange(sh_input_images, "(b v) c h w -> b v c h w", b=b, v=v) + else: + sh_input_images = context["image"] + + sh_input_images = rearrange(sh_input_images, "b v c h w -> (b v h w) c") + + # subsample with fps index + sh_input_images = sh_input_images[fps_idx.long(), :] # [N, 3] + + # extrinsics + extrinsics_all = rearrange(repeat(context["extrinsics"], "b v i j -> b v h w i j", h=h, w=w), + "b v h w i j -> (b v h w) i j" + ) + extrinsics_all = extrinsics_all[fps_idx.long(), :, :] # [N, 4, 4] + + point_list = [point_cloud[:offset[0]]] + gaussian_list = [gaussians[:offset[0]]] + sh_img_list = [sh_input_images[:offset[0]]] + extrinsics_list = [extrinsics_all[:offset[0]]] + + for i in range(b - 1): + point_list.append(point_cloud[offset[i]:offset[i + 1]]) + gaussian_list.append(gaussians[offset[i]:offset[i + 1]]) + sh_img_list.append(sh_input_images[offset[i]:offset[i + 1]]) + 
extrinsics_list.append(extrinsics_all[offset[i]:offset[i + 1]]) + + point_cloud = torch.stack(point_list, dim=0) # [B, N, 3] + gaussians = torch.stack(gaussian_list, dim=0) # [B, N, C] + sh_imgs = torch.stack(sh_img_list, dim=0) # [B, N, 3] + extrinsics_all = torch.stack(extrinsics_list, dim=0) # [B, N, 4, 4] + + # point_cloud = [point_cloud[offset[i]:offset[i+1]] for i in range(b)] + # point_cloud = torch.stack(point_cloud, dim=0) # [B, N, 3] + # gaussians = [gaussians[offset[i]:offset[i+1]] for i in range(b)] + # gaussians = torch.stack(gaussians, dim=0) # [B, N, 3] + + opacities = gaussians[..., 0].sigmoid() # [B, N] + + gaussians = self.gaussian_adapter.forward( + extrinsics=extrinsics_all, + intrinsics=None, + coordinates=None, + depths=None, + opacities=opacities, + raw_gaussians=gaussians[..., 1:], + image_shape=None, + point_cloud=point_cloud, + input_images=sh_imgs, + ) + + gaussians = rearrange(gaussians, "(b v) c h w -> b v c h w", b=b, v=v) + + # [B, V, H*W, 84] + raw_gaussians = rearrange( + gaussians, "b v c h w -> b v (h w) c") + + assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported" + + # [B, V, H*W, C] + repeat = self.cfg.init_gaussian_multiple + num_sh = self.gaussian_adapter.d_sh + + if self.cfg.no_pixel_offset: + rotations_unnorm, scales, opacities_raw, sh = raw_gaussians.split( + [4 * repeat, 3 * repeat, 1 * repeat, 3 * num_sh * repeat], + dim=-1, + ) + else: + rotations_unnorm, scales, opacities_raw, offset, sh = raw_gaussians.split( + [4 * repeat, 3 * repeat, 1 * repeat, 2 * repeat, 3 * num_sh * repeat], + dim=-1, + ) + + latent_h, latent_w = gaussians.shape[-2:] + + if repeat > 1: + # reshape all the gaussian parameters + if True or self.cfg.latent_new_reshape: + # this works + r = int(np.sqrt(repeat)) + rotations_unnorm = rearrange(rotations_unnorm, "b v (h w) (c x y) -> b v (h x w y) c", + h=latent_h, w=latent_w, x=r, y=r) + scales = rearrange(scales, "b v (h w) (c x y) -> b v (h x w y) c", 
h=latent_h, w=latent_w, x=r, + y=r) + opacities_raw = rearrange(opacities_raw, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, + w=latent_w, x=r, y=r) + offset = rearrange(offset, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, + y=r) + sh = rearrange(sh, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, y=r) + else: + # doesn't work + rotations_unnorm = rearrange(rotations_unnorm, "b v hw (k c) -> b v (hw k) c", k=repeat) + scales = rearrange(scales, "b v hw (k c) -> b v (hw k) c", k=repeat) + opacities_raw = rearrange(opacities_raw, "b v hw (k c) -> b v (hw k) c", k=repeat) + offset = rearrange(offset, "b v hw (k c) -> b v (hw k) c", k=repeat) + sh = rearrange(sh, "b v hw (k c) -> b v (hw k) c", k=repeat) + + opacities = opacities_raw.sigmoid() # [B, V, H*W*K, 1] + + if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4: + scale_factor = 2 + elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4: + scale_factor = 2 + elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16: + scale_factor = 4 + elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4: + scale_factor = 2 + elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16: + scale_factor = 4 + else: + scale_factor = 1 + + h, w = latent_h * scale_factor, latent_w * scale_factor + + # unproject depth + xy_ray, _ = sample_image_grid((h, w), device) # [H, W, 2] in [0, 1] + xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy") # [H*W, 1, 2] + + if self.cfg.no_pixel_offset: + offset_xy = torch.ones_like(raw_gaussians[..., :2]).unsqueeze(-2).to( + raw_gaussians.device) * 0.5 # [B, V, H*W, 1, 2] + else: + offset_xy = offset.sigmoid().unsqueeze(-2) # [B, V, H*W, 1, 2] + + pixel_size = 1 / \ + torch.tensor((w, h), dtype=torch.float32, device=device) + # [H*W, 1, 2] + if self.cfg.deform_sample_depth and not self.cfg.deform_sample_depth_debug: + # (offset_xy - 0.5) in 
-0.5 to 0.5, without multiplying by pixel size such that the points can move in the image space + xy_ray = (xy_ray + (offset_xy - 0.5)).clamp(min=0., max=1.) + else: + xy_ray = xy_ray + (offset_xy - 0.5) * pixel_size + + if self.cfg.deform_sample_depth: + # use low-res xy_ray to sample full-res depth + + sample_grid = rearrange(xy_ray, "b v (h w) c xy -> (b v) h w (c xy)", h=h, w=w) # in [0, 1] + # to [-1, 1] + sample_grid = 2 * (sample_grid - 0.5) # [BV, h, w, 2] + + fullres_depth = rearrange(depth, "b v h w -> (b v) () h w") # [BV, 1, H, W] + sampled_depth = F.grid_sample(fullres_depth, sample_grid, mode='bilinear', align_corners=True, + padding_mode="border") # [BV, 1, h, w] + # reshape + depths = rearrange(sampled_depth, "(b v) () h w -> b v (h w) () ()", b=b, v=v, h=h, w=w) + + if self.cfg.latent_gs: + sh_input_images = rearrange(context["image"], "b v c h w -> (b v) c h w") + if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4: + sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area') + elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16: + pass + elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4: + pass + elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4: + sh_input_images = F.interpolate(sh_input_images, scale_factor=0.25, mode='area') + elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16: + sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area') + else: + sh_input_images = F.interpolate(sh_input_images, scale_factor=1. 
/ self.cfg.latent_downsample, + mode='area') + + sh_input_images = rearrange(sh_input_images, "(b v) c h w -> b v c h w", b=b, v=v) + + else: + sh_input_images = context["image"] + + assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported" + + # build gaussians + # scale + scales = torch.clamp(F.softplus(scales - self.cfg.gaussian_adapter.exp_scale_bias), + min=self.cfg.gaussian_adapter.clamp_min_scale, + max=self.cfg.gaussian_adapter.gaussian_scale_max + ) + + # Normalize the quaternion features to yield a valid quaternion. + # rotations = rotations_unnorm / (rotations_unnorm.norm(dim=-1, keepdim=True) + 1e-8) + + # Convert rotations to world-space + c2w_rotations = context["extrinsics"][..., :3, :3].unsqueeze(2) # [B, V, 1, 3, 3] + rotations = rotate_quats(c2w_rotations, rotations_unnorm) + rotations_unnorm = rotations.clone() + # Create world-space covariance matrices. + covariances = build_covariance(scale=scales, rotation_xyzw=rotations) # [B, V, H*W, 3, 3] + + # means + # [B, V, H*W, 1, 2] + # xy_ray = xy_ray.unsqueeze(0).unsqueeze(0).repeat(b, v, 1, 1, 1) + origins, directions = get_world_rays(xy_ray, + context["extrinsics"].unsqueeze(2).unsqueeze(2), + context["intrinsics"].unsqueeze(2).unsqueeze(2)) + means = origins + directions * depths + + # sh: [B, V, HW, 3, SH] + sh = rearrange(sh, "... (xyz d_sh) -> ... 
xyz d_sh", xyz=3).clone() + # sh = sh.broadcast_to((*opacities.shape, 3, self.gaussian_adapter.d_sh)).clone() + + # [B, V, H*W, 3] + sh_input_images = rearrange(sh_input_images, "b v c h w -> b v (h w) c") + # init sh with input images + sh[..., 0] = sh[..., 0] + RGB2SH(sh_input_images) + + gaussians = Gaussians( + means=rearrange(means, "b v r spp xyz -> b (v r spp) xyz"), + covariances=rearrange(covariances, "b v r i j -> b (v r) i j"), + harmonics=rearrange(sh, "b v r c d_sh -> b (v r) c d_sh"), + opacities=rearrange(opacities, "b v r spp -> b (v r spp)"), + scales=rearrange(scales, "b v r xyz -> b (v r) xyz"), + rotations=rearrange(rotations, "b v r wxyz -> b (v r) wxyz"), # in wxyz format + rotations_unnorm=rearrange(rotations_unnorm, "b v r wxyz -> b (v r) wxyz") # in wxyz format + ) + + else: + gaussians = rearrange(gaussians, "(b v) c h w -> b v c h w", b=b, v=v) + + # [B, V, H*W, 84] + raw_gaussians = rearrange( + gaussians, "b v c h w -> b v (h w) c") + + assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported" + + # [B, V, H*W, C] + repeat = self.cfg.init_gaussian_multiple + num_sh = self.gaussian_adapter.d_sh + + if self.cfg.no_pixel_offset: + rotations_unnorm, scales, opacities_raw, sh = raw_gaussians.split( + [4 * repeat, 3 * repeat, 1 * repeat, 3 * num_sh * repeat], + dim=-1, + ) + else: + rotations_unnorm, scales, opacities_raw, offset, sh = raw_gaussians.split( + [4 * repeat, 3 * repeat, 1 * repeat, 2 * repeat, 3 * num_sh * repeat], + dim=-1, + ) + + latent_h, latent_w = gaussians.shape[-2:] + + if repeat > 1: + # reshape all the gaussian parameters + if True or self.cfg.latent_new_reshape: + # this works + r = int(np.sqrt(repeat)) + rotations_unnorm = rearrange(rotations_unnorm, "b v (h w) (c x y) -> b v (h x w y) c", + h=latent_h, w=latent_w, x=r, y=r) + scales = rearrange(scales, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, + y=r) + opacities_raw = rearrange(opacities_raw, "b v 
(h w) (c x y) -> b v (h x w y) c", h=latent_h, + w=latent_w, x=r, y=r) + offset = rearrange(offset, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, + y=r) + sh = rearrange(sh, "b v (h w) (c x y) -> b v (h x w y) c", h=latent_h, w=latent_w, x=r, y=r) + else: + # doesn't work + rotations_unnorm = rearrange(rotations_unnorm, "b v hw (k c) -> b v (hw k) c", k=repeat) + scales = rearrange(scales, "b v hw (k c) -> b v (hw k) c", k=repeat) + opacities_raw = rearrange(opacities_raw, "b v hw (k c) -> b v (hw k) c", k=repeat) + offset = rearrange(offset, "b v hw (k c) -> b v (hw k) c", k=repeat) + sh = rearrange(sh, "b v hw (k c) -> b v (hw k) c", k=repeat) + + opacities = opacities_raw.sigmoid() # [B, V, H*W*K, 1] + + if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4: + scale_factor = 2 + elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4: + scale_factor = 2 + elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16: + scale_factor = 4 + elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4: + scale_factor = 2 + elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16: + scale_factor = 4 + else: + scale_factor = 1 + + h, w = latent_h * scale_factor, latent_w * scale_factor + + # unproject depth + xy_ray, _ = sample_image_grid((h, w), device) + xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy") + + if self.cfg.no_pixel_offset: + offset_xy = torch.ones_like(raw_gaussians[..., :2]).unsqueeze(-2).to( + raw_gaussians.device) * 0.5 # [B, V, H*W, 1, 2] + else: + offset_xy = offset.sigmoid().unsqueeze(-2) # [B, V, H*W, 1, 2] + + pixel_size = 1 / \ + torch.tensor((w, h), dtype=torch.float32, device=device) + # [H*W, 1, 2] + if self.cfg.deform_sample_depth and not self.cfg.deform_sample_depth_debug: + # (offset_xy - 0.5) in -0.5 to 0.5, without multiplying by pixel size such that the points can move in the image space + xy_ray = (xy_ray + 
(offset_xy - 0.5)).clamp(min=0., max=1.) + else: + xy_ray = xy_ray + (offset_xy - 0.5) * pixel_size + + if self.cfg.deform_sample_depth: + # use low-res xy_ray to sample full-res depth + + sample_grid = rearrange(xy_ray, "b v (h w) c xy -> (b v) h w (c xy)", h=h, w=w) # in [0, 1] + # to [-1, 1] + sample_grid = 2 * (sample_grid - 0.5) # [BV, h, w, 2] + + fullres_depth = rearrange(depth, "b v h w -> (b v) () h w") # [BV, 1, H, W] + sampled_depth = F.grid_sample(fullres_depth, sample_grid, mode='bilinear', align_corners=True, + padding_mode="border") # [BV, 1, h, w] + # reshape + depths = rearrange(sampled_depth, "(b v) () h w -> b v (h w) () ()", b=b, v=v, h=h, w=w) + + if self.cfg.latent_gs: + sh_input_images = rearrange(context["image"], "b v c h w -> (b v) c h w") + if self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 4: + sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area') + elif self.cfg.latent_downsample == 4 and self.cfg.init_gaussian_multiple == 16: + pass + elif self.cfg.latent_downsample == 2 and self.cfg.init_gaussian_multiple == 4: + pass + elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 4: + sh_input_images = F.interpolate(sh_input_images, scale_factor=0.25, mode='area') + elif self.cfg.latent_downsample == 8 and self.cfg.init_gaussian_multiple == 16: + sh_input_images = F.interpolate(sh_input_images, scale_factor=0.5, mode='area') + else: + sh_input_images = F.interpolate(sh_input_images, scale_factor=1. 
/ self.cfg.latent_downsample, + mode='area') + + sh_input_images = rearrange(sh_input_images, "(b v) c h w -> b v c h w", b=b, v=v) + + else: + sh_input_images = context["image"] + + assert len(depth_preds) == 1, "num_scales must be 1; multi-scale depth supervision is not supported" + + # build gaussians + # scale + scales = torch.clamp(F.softplus(scales - self.cfg.gaussian_adapter.exp_scale_bias), + min=self.cfg.gaussian_adapter.clamp_min_scale, + max=self.cfg.gaussian_adapter.gaussian_scale_max + ) + + # Convert rotations to world-space + c2w_rotations = context["extrinsics"][..., :3, :3].unsqueeze(2) # [B, V, 1, 3, 3] + # Here quaternions follow the xyzw format (scalar last) + rotations = rotate_quats(c2w_rotations, rotations_unnorm) + rotations_unnorm = rotations.clone() + # Create world-space covariance matrices. + covariances = build_covariance(scale=scales, rotation_xyzw=rotations) # [B, V, H*W, 3, 3] + + # means + # [B, V, H*W, 1, 2] + origins, directions = get_world_rays(xy_ray, + context["extrinsics"].unsqueeze(2).unsqueeze(2), + context["intrinsics"].unsqueeze(2).unsqueeze(2)) + means = origins + directions * depths + + # sh: [B, V, HW, 3, SH] + sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3).clone() + + # [B, V, H*W, 3] + sh_input_images = rearrange(sh_input_images, "b v c h w -> b v (h w) c") + # init sh with input images + sh[..., 0] = sh[..., 0] + RGB2SH(sh_input_images) + + gaussians = Gaussians( + means=rearrange(means, "b v r spp xyz -> b (v r spp) xyz"), + covariances=rearrange(covariances, "b v r i j -> b (v r) i j"), + harmonics=rearrange(sh, "b v r c d_sh -> b (v r) c d_sh"), + opacities=rearrange(opacities, "b v r spp -> b (v r spp)"), + scales=rearrange(scales, "b v r xyz -> b (v r) xyz"), + rotations=rearrange(rotations, "b v r wxyz -> b (v r) wxyz"), + rotations_unnorm=rearrange(rotations_unnorm, "b v r wxyz -> b (v r) wxyz") + ) + + # Dump visualizations if needed. 
+ if visualization_dump is not None: + visualization_dump["depth"] = rearrange( + depths, "b v (h w) srf s -> b v h w srf s", h=h, w=w + ) + # if self.cfg.pt_downsample > 0: + # visualization_dump["scales"] = gaussians.scales + # visualization_dump["rotations"] = gaussians.rotations + # else: + # visualization_dump["scales"] = rearrange( + # gaussians.scales, "b v r srf spp xyz -> b (v r srf spp) xyz" + # ) + # visualization_dump["rotations"] = rearrange( + # gaussians.rotations, "b v r srf spp xyzw -> b (v r srf spp) xyzw" + # ) + + if self.cfg.return_depth: + # return depth prediction for supervision + depths = depth_preds[-1] + + if self.cfg.return_lowres_depth: + assert latent_depth is not None + depths = latent_depth + else: + if depths.shape[-2:] != context["image"].shape[-2:]: + # depths can be at lower resolution since we predict latent + depths = F.interpolate( + depths, size=context["image"].shape[-2:], mode='bilinear', align_corners=True) + + return InitializerOutput( + gaussians=gaussians, + depths=depths, + features=condition_features + ) + else: + + return InitializerOutput( + gaussians=gaussians, + features=condition_features + ) + + def get_data_shim(self) -> DataShim: + def data_shim(batch: BatchedExample) -> BatchedExample: + patch_size = self.cfg.shim_patch_size + if isinstance(self.cfg.shim_patch_size, int): + patch_size = patch_size * self.cfg.downscale_factor + else: + patch_size = [p * self.cfg.downscale_factor for p in patch_size] + batch = apply_patch_shim( + batch, + patch_size=patch_size, + ) + + return batch + + return data_shim + + @staticmethod + def update_gt_depth_range(batch): + assert "depth" in batch["context"] + batch["context"]["near"] = batch["context"]["depth"].min(dim=3)[0].min(dim=2)[0].clamp(min=0.01) + batch["context"]["far"] = batch["context"]["depth"].max(dim=3)[0].max(dim=2)[0].clamp(max=1000.) 
+ batch["target"]["near"] = batch["target"]["depth"].min(dim=3)[0].min(dim=2)[0].clamp(min=0.01) + batch["target"]["far"] = batch["target"]["depth"].max(dim=3)[0].max(dim=2)[0].clamp(max=1000.) + + def update_depth_range_from_disparity(self, batch): + b, v, _, h, w = batch["context"]["image"].shape + # TODO: support multi-view later + assert v == 2 + assert self.decoder.cfg.scale_invariant is False + w = batch["context"]["image"].shape[-1] + # compute the depth range based on disparity range + dist = (batch["context"]["extrinsics"][:, 0, :3, 3] - batch["context"]["extrinsics"][:, 1, :3, 3]).norm( + dim=1, keepdim=True) + focal = batch["context"]["intrinsics"][:, :, 0, 0] * w + min_depth = dist * focal / self.train_cfg.max_disparity + max_depth = dist * focal / self.train_cfg.min_disparity + batch["context"]["near"] = min_depth + batch["context"]["far"] = max_depth + # TODO: also update target near and far + + def predict_scale(self, batch): + context = batch["context"] + # [B, V, H, W] + init_depth = self.encoder.scale_predictor( + context["image"], + attn_splits_list=[2], + min_depth=1. / context["far"], + max_depth=1. / context["near"], + intrinsics=context["intrinsics"], + extrinsics=context["extrinsics"], + )['depth_preds'][-1] + if not self.encoder.cfg.no_pred_depth_range: + new_near = init_depth.min(dim=3)[0].min(dim=2)[0].clamp(min=0.1) # [B, V] + new_far = init_depth.max(dim=3)[0].max(dim=2)[0].clamp(max=100.) 
+ + batch["context"]["near"] = new_near + batch["context"]["far"] = new_far + + batch["target"]["near"] = new_near.min(dim=1, keepdim=True)[0].repeat(1, + batch["target"]["near"].shape[1]) + batch["target"]["far"] = new_far.max(dim=1, keepdim=True)[0].repeat(1, batch["target"]["near"].shape[1]) + if self.encoder.cfg.norm_by_points: + b, v, h, w = init_depth.shape + # get point cloud + xy_ray, _ = sample_image_grid((h, w), batch["context"]["image"].device) + xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy") + + # [B, V, H*W, 1, 2] + tmp_coords = xy_ray.unsqueeze(0).unsqueeze(0).repeat(b, v, 1, 1, 1) + + # [B, V, H*W, 1, 1] + tmp_depth = rearrange(init_depth, "b v h w -> b v (h w) () ()") + + # [B, V, 1, 1, 4, 4] + tmp_extrinsics = context["extrinsics"].unsqueeze(2).unsqueeze(2) + # [B, V, 1, 1, 3, 3] + tmp_intrinsics = context["intrinsics"].unsqueeze(2).unsqueeze(2) + + # [B, V, H*W, 1, 3] + origins, directions = get_world_rays(tmp_coords, tmp_extrinsics, tmp_intrinsics) + point_cloud = origins + directions * tmp_depth + + point_cloud = rearrange(point_cloud, "b v h w c -> b (v h w) c") + + point_dist = point_cloud.norm(dim=-1).mean(dim=-1) # [B] + + norm_factor = point_dist.clamp(min=1e-6) + + # normalize near, far and extrinsics + batch["context"]["near"] = batch["context"]["near"] / norm_factor.view(b, 1) + batch["context"]["far"] = batch["context"]["far"] / norm_factor.view(b, 1) + + batch["target"]["near"] = batch["target"]["near"] / norm_factor.view(b, 1) + batch["target"]["far"] = batch["target"]["far"] / norm_factor.view(b, 1) + + batch["context"]["extrinsics"][:, :, :3, -1] /= norm_factor.view(b, 1, 1) + batch["target"]["extrinsics"][:, :, :3, -1] /= norm_factor.view(b, 1, 1) + + def preprocessing(self, batch, train_cfg): + # use gt depth range instead of a fixed one + if train_cfg.use_gt_depth_range: + self.update_gt_depth_range(batch) + # compute depth range from camera distance and disparity range + if train_cfg.depth_range_from_disparity: + 
self.update_depth_range_from_disparity(batch) + + # use a pretrained depth model to predict scale + if self.cfg.predict_scale: + self.predict_scale(batch) diff --git a/optgs/scene_trainer/optimizer/__init__.py b/optgs/scene_trainer/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05769793054a3f8e6fc2482279e1ca9ceee0ef24 --- /dev/null +++ b/optgs/scene_trainer/optimizer/__init__.py @@ -0,0 +1,41 @@ +from .optimizer import Optimizer, OptimizerCfg +from .optimizer_learn2splat import Learn2SplatOptimizer +from .optimizer_knn_based import KnnBasedOptimizer, KnnBasedOptimizerCfg +from .optimizer_resplat import ResplatOptimizerV1, ResplatOptimizerV2 +from .optimizer_adam import AdamOptimizerCfg, AdamOptimizer + +SceneOptimizerCfg = KnnBasedOptimizerCfg | AdamOptimizerCfg + + +def extract_opt_params(cfg): + opt_params = cfg.__dict__.copy() + opt_params.pop("enabled", None) + opt_params.pop("name", None) + opt_params.pop("steps", None) + opt_params.pop("scheduler", None) + opt_params.pop("scheduler_warm_up_ratio", None) + opt_params.pop("compute_metrics_every", None) + return opt_params + + +SCENE_OPTIMIZERS = { + "none": None, + "depthsplat": KnnBasedOptimizer, + "resplat_v1": ResplatOptimizerV1, + "resplat_v2": ResplatOptimizerV2, + "clogs": Learn2SplatOptimizer, # TODO (release): remove + "l2s": Learn2SplatOptimizer, + "adam": AdamOptimizer, +} + + +def get_scene_optimizer(cfg: SceneOptimizerCfg | None) -> Optimizer | None: + if cfg is None: + print("Using scene optimizer: None") + return None + print(f"Using scene optimizer: {cfg.name}") + scene_optimizer = SCENE_OPTIMIZERS[cfg.name] + if scene_optimizer is None: + return None + scene_optimizer = scene_optimizer(cfg) + return scene_optimizer diff --git a/optgs/scene_trainer/optimizer/layer.py b/optgs/scene_trainer/optimizer/layer.py new file mode 100644 index 0000000000000000000000000000000000000000..641a68f3c6f64075121f344e41e13e8656e74a29 --- /dev/null +++ 
b/optgs/scene_trainer/optimizer/layer.py @@ -0,0 +1,436 @@ +import torch +from jaxtyping import Float +from torch import nn as nn, Tensor + + +class SlicedG3RNorm(nn.Module): + def __init__(self, num_features, input_slice, eps=1e-8): + """ + Apply G3R normalization to a slice of the input features. + + Devide the input by the maximum absolute value in each channel. + + Args: + num_features (int): Total number of features (channels). + input_slice (slice): Size of each slice to normalize independently. + eps (float): Small constant to prevent division by zero. + """ + super().__init__() + self.input_slice = input_slice + dummy = torch.zeros(1, num_features) + chunk = dummy[:, input_slice] + + self.slice_size = chunk.shape[-1] + self.eps = eps + + def forward(self, x: Float[Tensor, "B C"]): + """ + Args: + x (Tensor): Shape (B, C) where C = num_features + Returns: + Tensor: Same shape, only subset of channels normalized + """ + # Split input into the slice to normalize and the rest + chunk = x[:, self.input_slice] + + # Compute max absolute value per channel + # Detach to avoid backpropagating through max operation + max_val_per_channel = chunk.abs().max(0, keepdim=True)[0].detach() + self.eps + + # Apply G3R normalization to the selected slice + # Replace the normalized slice back into the original input + x = x.clone() + x[:, self.input_slice] = chunk / max_val_per_channel + return x + + +class SlicedBatchNorm1d(nn.Module): + def __init__(self, num_features, input_slice, eps=1e-8, affine=False, track_running_stats=True): + """ + Apply normalization independently to a slice of the input features. + + Args: + num_features (int): Total number of features (channels). + input_slice (slice): Size of each slice to normalize independently. + eps (float): Small constant to prevent division by zero. + affine (bool): Whether to include learnable scale and bias per slice. 
+ """ + super().__init__() + self.input_slice = input_slice + dummy = torch.zeros(1, num_features) + chunk = dummy[:, input_slice] + + self.slice_size = chunk.shape[-1] + self.eps = eps + + # Create a BatchNorm1d module for each slice + self.slice_norm = nn.BatchNorm1d(self.slice_size, eps=eps, affine=affine, + track_running_stats=track_running_stats) + + def forward(self, x): + """ + Args: + x (Tensor): Shape (B, C) where C = num_features + Returns: + Tensor: Same shape, only subset of channels normalized + """ + B, C = x.shape + + # Split input into the slice to normalize and the rest + chunk = x[:, self.input_slice] + + # Apply normalization to the selected slice + chunk = self.slice_norm(chunk) + + # Replace the normalized slice back into the original input + x = x.clone() + x[:, self.input_slice] = chunk + return x + + +class CustomGroupNorm(nn.Module): + def __init__(self, group_sizes, eps=1e-8, affine=True): + """ + Args: + group_sizes (list[int]): List of channel counts for each group. Must sum to total input channels. + eps (float): Small constant to prevent division by zero. + affine (bool): Whether to include learnable scale and bias per group. 
def slice_length(s, dim):
    """Return how many elements ``s`` selects from a sequence of length ``dim``.

    The bounds are resolved with ``slice.indices``, i.e. with Python's own rules
    for missing bounds, negative indices, out-of-range clamping and negative
    steps, so the result equals ``len(range(dim)[s])``.

    Args:
        s (slice): The slice to measure.
        dim (int): Length of the dimension the slice is applied to.

    Returns:
        int: Number of selected elements (0 for an empty selection).
    """
    # A step of 0 (like a missing step) is treated as 1, as this helper has
    # always done; slice.indices itself would reject a zero step.
    step = s.step or 1
    return len(range(*slice(s.start, s.stop, step).indices(dim)))
"""Fused moment update + bias-corrected output for the masked path.""" + m_sel = m[sel].lerp_(chunk, 1 - beta1) + m[sel] = m_sel + + v_sel = v[sel].mul_(beta2).addcmul_(chunk, chunk, value=1 - beta2) + v[sel] = v_sel + + t[sel] += 1 + + t_sel = t[sel].reshape(-1, *([1] * (m.ndim - 1))) + m_hat = m_sel / (1 - beta1 ** t_sel) + v_hat = v_sel / (1 - beta2 ** t_sel) + return m_hat / (torch.sqrt(v_hat) + eps) + + +class AdamInputSmoothing(nn.Module): + def __init__(self, beta1=0.9, beta2=0.999, eps=1e-15, input_slice: slice | None = None, + shape: tuple | None = None, + device=None): + """ + Implements Adam-like smoothing for input vectors. + + Args: + beta1 (float): Exponential decay rate for the first moment estimates. + beta2 (float): Exponential decay rate for the second moment estimates. + eps (float): Small constant to prevent division by zero. + input_slice (slice, optional): If provided, only apply smoothing to this slice of the input. + """ + super().__init__() + self.beta1 = beta1 + self.beta2 = beta2 + self.eps = eps + self.input_slice: slice | None = input_slice + if self.input_slice is not None: + assert isinstance(self.input_slice, slice), "input_slice must be a slice or None" + + # Initialize first and second moment vectors + if shape is None: + self.reset() + else: + self.initialize(shape, + device=device) + + def forward(self, x: Tensor) -> Tensor: + """ + Apply Adam-like smoothing to the input. 
+ + Args: + x (Tensor): Input tensor of shape (..., input_dim) + + Returns: + Tensor: Smoothed tensor of same shape as input + """ + # Select the relevant slice of the input + chunk = x[..., self.input_slice] if self.input_slice is not None else x + + # Initialize internal state if needed + if self.is_reset(): + self.initialize(chunk.shape, device=chunk.device) + + chunk_detached = chunk.detach() + + if self.sel is None: + # Increment step first (matches PyTorch Adam convention) + self.t += 1 + # Fused moment update + bias-corrected output (compiled kernel) + chunk_smoothed = _adam_smooth_unmasked(self.m, self.v, self.t, chunk_detached, + self.beta1, self.beta2, self.eps) + else: + # Fused masked update (compiled kernel) + chunk_smoothed = _adam_smooth_masked(self.m, self.v, self.t, self.sel, chunk_detached, + self.beta1, self.beta2, self.eps) + + # Replace in original tensor + if self.input_slice is not None: + output_shape = slice_length(self.input_slice, x.shape[-1]) + if output_shape == x.shape[-1]: + x_out = chunk_smoothed + else: + # only replace a slice, so we need to clone to avoid modifying input + x_out = x.clone() + x_out[..., self.input_slice] = chunk_smoothed + else: + # we overwrite the whole tensor, no need to clone + x_out = chunk_smoothed + + return x_out + + def reset(self): + """Reset the internal state.""" + self.m = torch.tensor(0, dtype=torch.float32) + self.v = torch.tensor(0, dtype=torch.float32) + self.t = torch.tensor(0, dtype=torch.int64) + + self.sel = None + + def initialize(self, shape, device) -> None: + """Initialize the internal state with zeros for the given number of elements and input dimension.""" + self.m = torch.zeros(shape, dtype=torch.float32, device=device) + self.v = torch.zeros(shape, dtype=torch.float32, device=device) + self.t = torch.zeros(shape[0], dtype=torch.int64, device=device) + + self.sel = None + + def update_state(self, adam_state: AdamState) -> None: + """Update the internal state with provided values.""" + 
m, v, t = adam_state.m, adam_state.v, adam_state.t + self.m = m + self.v = v + self.t = t + + self.sel = None + + def prune(self, prune_mask: Tensor) -> None: + """Prune the internal state to only keep entries at the specified indices.""" + assert not self.is_reset(), ( + "Cannot prune state that has not been initialized. Call forward() at least once first." + ) + sel = torch.where(~prune_mask)[0] + self.m = self.m[sel] + self.v = self.v[sel] + self.t = self.t[sel] + + if self.sel is not None: + self.sel = self.sel[sel] + + def zero_out(self, zero_t=False) -> None: + """Zero out the moments. Called when resetting gaussians opacities.""" + assert not self.is_reset(), ( + "Cannot extend state that has not been initialized. Call forward() at least once first." + ) + self.m = torch.zeros_like(self.m) + self.v = torch.zeros_like(self.v) + if zero_t: + self.t = torch.zeros_like(self.t) + + def replace(self, from_indices: Tensor, dest_indices: Tensor, zero_t=False) -> None: + """Replace the internal state to duplicate entries at the specified indices.""" + assert not self.is_reset(), ( + "Cannot extend state that has not been initialized. Call forward() at least once first." + ) + + self.m[dest_indices] = self.m[from_indices] + self.v[dest_indices] = self.v[from_indices] + if zero_t: + self.t[dest_indices] = 0 + else: + self.t[dest_indices] = self.t[from_indices] + + def clone(self, clone_mask: Tensor, zero_t=False) -> None: + """Clone the internal state to duplicate entries at the specified indices.""" + assert not self.is_reset(), ( + "Cannot extend state that has not been initialized. Call forward() at least once first." 
+ ) + + num_new_rows = clone_mask.sum() + new_zeros = torch.zeros((num_new_rows, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype) + if zero_t: + new_t = torch.zeros((num_new_rows, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype) + else: + sel = torch.where(clone_mask)[0] + new_t = self.t[sel] + + self.m = torch.cat([self.m, new_zeros], dim=0) + self.v = torch.cat([self.v, new_zeros], dim=0) + self.t = torch.cat([self.t, new_t], dim=0) + + def add(self, nr_new: int) -> None: + """Add new entries to the internal state.""" + assert not self.is_reset(), ( + "Cannot extend state that has not been initialized. Call forward() at least once first." + ) + + new_zeros = torch.zeros((nr_new, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype) + new_t = torch.zeros((nr_new, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype) + + self.m = torch.cat([self.m, new_zeros], dim=0) + self.v = torch.cat([self.v, new_zeros], dim=0) + self.t = torch.cat([self.t, new_t], dim=0) + + def split(self, split_mask: Tensor, N: int, zero_t=False) -> None: + """Split the internal state to duplicate entries at the specified indices.""" + assert not self.is_reset(), ( + "Cannot extend state that has not been initialized. Call forward() at least once first." 
+ ) + + # Count how many new rows we need + num_new_rows = split_mask.sum() * N + + # Handle t depending on zero_t flag + if zero_t: + new_t = torch.zeros((num_new_rows, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype) + else: + # Only t needs to copy repeated original values + sel = torch.where(split_mask)[0] + new_t = self.t[sel].repeat_interleave(N, dim=0) + + rest_sel = torch.where(~split_mask)[0] + + # Preallocate zeros directly for m and v + new_zeros = torch.zeros((num_new_rows, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype) + self.m = torch.cat([self.m[rest_sel], new_zeros], dim=0) + self.v = torch.cat([self.v[rest_sel], new_zeros], dim=0) + self.t = torch.cat([self.t[rest_sel], new_t], dim=0) + + def get_state(self) -> AdamState: + """Get the current internal state.""" + return AdamState(self.m, self.v, self.t) + + def subgroups_view(self, slices: dict[str, slice]) -> dict[str, "AdamInputSmoothing"]: + """ + Create lightweight subgroups that share memory with the main tensor states. + + Args: + slices (dict[str, slice]): Mapping from subgroup name to slice, e.g.: + {"means": slice(0, 3) ,"scale": slice(3, 6), "rotation": slice(6, 10), "opacity": slice(10, 11), "sh": slice(11, 59)} + + Returns: + dict[str, AdamInputSmoothing]: Submodules that share self.m and self.v tensors. + """ + if not hasattr(self, "m") or self.m.ndim == 0: + raise RuntimeError("Cannot create subgroups before the first forward() call.") + + subgroups = {} + for name, slc in slices.items(): + sub = AdamInputSmoothing( + beta1=self.beta1, + beta2=self.beta2, + eps=self.eps, + input_slice=None + ) + + # share the same memory (not copy) + sub.m = self.m[..., slc] + sub.v = self.v[..., slc] + sub.t = self.t # shared time step + + subgroups[name] = sub + + return subgroups + + def aggregate_from_subgroups(self, subgroups: dict[str, "AdamInputSmoothing"], slices: dict[str, slice]) -> None: + """ + Aggregate states from subgroups back into the main module. 
+ + Args: + subgroups (dict[str, AdamInputSmoothing]): Submodules created via subgroups_view. + slices (dict[str, slice]): Mapping from subgroup name to slice, e.g.: + {"means": slice(0, 3) ,"scale": slice(3, 6), "rotation": slice(6, 10), "opacity": slice(10, 11), "sh": slice(11, 59)} + """ + if not hasattr(self, "m") or self.is_reset(): + raise RuntimeError("Cannot aggregate states before the first forward() call.") + + # Adjust stats shape + first_m_val = next(iter(subgroups.values())).m + if self.m.shape[:-1] != first_m_val.shape[:-1]: + self.m = torch.zeros((*first_m_val.shape[:-1], self.m.shape[-1]), dtype=first_m_val.dtype, + device=first_m_val.device) + self.v = torch.zeros((*first_m_val.shape[:-1], self.v.shape[-1]), dtype=first_m_val.dtype, + device=first_m_val.device) + + for name, slc in slices.items(): + sub = subgroups[name] + self.m[..., slc] = sub.m + self.v[..., slc] = sub.v + # Assume time step is the same across all subgroups + self.t = next(iter(subgroups.values())).t + + def is_reset(self) -> bool: + """Check if the internal state is reset.""" + assert self.m.shape == self.v.shape, "First and second moment vectors must have the same shape." 
+ return bool(self.m.ndim == 0 and self.v.ndim == 0 and self.t == 0) diff --git a/optgs/scene_trainer/optimizer/lr_scheduler.py b/optgs/scene_trainer/optimizer/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..a6064dbf7f2d63a00cef3e75d5e595db85e0409e --- /dev/null +++ b/optgs/scene_trainer/optimizer/lr_scheduler.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass + +from optgs.scene_trainer.optimizer.optimizer_utils import Number3DGSCfg, Bool3DGSCfg + + +@dataclass +class SchedulerCfg: + name: str + lr_data: Number3DGSCfg + apply_scheduler: Bool3DGSCfg + + +@dataclass +class DDIMSchedulerCfg(SchedulerCfg): + T: int + min_lr: float | int + s: float | int + + +LrSchedulerCfgType = DDIMSchedulerCfg | SchedulerCfg + + +class Scheduler[SchedulerCfg]: + def __init__(self, cfg: SchedulerCfg): + self.cfg = cfg + + def get_lr(self, t: int, param: str) -> float | int: + lr = getattr(self.cfg.lr_data, param) + apply = getattr(self.cfg.apply_scheduler, param) + if apply: + return self.scheduler_fn(t, lr) + else: + return lr + + def scheduler_fn(self, t: int, base_lr: float | int) -> float | int: + raise NotImplementedError + + +class DummyScheduler(Scheduler[SchedulerCfg]): + def get_lr(self, t: int, param: str) -> float | int: + return 1 + + +class DDIMCosineScheduler(Scheduler[DDIMSchedulerCfg]): + def scheduler_fn(self, t: int, base_lr: float | int) -> float | int: + # Implement DDIM Cosine scheduling logic here + import math + t = min(t, self.cfg.T) + rel_t = t / self.cfg.T + alpha_bar = math.cos((rel_t + self.cfg.s) / (1 + self.cfg.s) * math.pi / 2) ** 2 + lr = self.cfg.min_lr + alpha_bar * (base_lr - self.cfg.min_lr) + return lr + + +SCHEDULERS = { + "none": Scheduler, + "ddim": DDIMCosineScheduler, +} + +def get_scheduler(cfg: SchedulerCfg) -> Scheduler: + print(f"Using scheduler: {cfg.name}") + scheduler_class = SCHEDULERS[cfg.name] + scheduler = scheduler_class(cfg) + return scheduler + +if __name__ == "__main__": + cfg = 
DDIMSchedulerCfg( + name="DDIMCosineScheduler", + lr_data=Number3DGSCfg(_base=1.0, _means=1.0, _scales=1.0, _opacities=1.0, _quats=1.0, _sh0=1.0, _shN=1.0), + apply_scheduler=Bool3DGSCfg(_base=True, _means=True, _scales=True, _opacities=True, _quats=True, _sh0=True, + _shN=True), + T=24, + min_lr=0.0, + s=0.008 + ) + scheduler = DDIMCosineScheduler(cfg) + + iterations = list(range(0, 24)) + lr = [scheduler.get_lr(t, "means") for t in iterations] + + import matplotlib.pyplot as plt + + plt.plot(iterations, lr) + plt.xlabel("Timestep") + plt.ylabel("Learning Rate") + plt.title("DDIM Cosine Learning Rate Schedule") + plt.grid() + plt.show() diff --git a/optgs/scene_trainer/optimizer/optimizer.py b/optgs/scene_trainer/optimizer/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..bb1f2030a85c1799732ab343ccba19e550e41642 --- /dev/null +++ b/optgs/scene_trainer/optimizer/optimizer.py @@ -0,0 +1,452 @@ +from abc import ABC +from dataclasses import dataclass, field +from pathlib import Path +from typing import TypeVar, Generic, Optional, TYPE_CHECKING, Any +import torch +from matplotlib import pyplot as plt +from torch import nn +from torch import Tensor +import numpy as np +import os +from optgs.dataset.camera_datasets.camera import get_scene_scale +from optgs.misc.io import FrequencyScheduler +from optgs.dataset.data_types import BatchedViews +from optgs.model.decoder import Decoder +from optgs.model.decoder.decoder import DecoderOutput +from optgs.model.types import Gaussians +from optgs.scene_trainer.adc.base import BaseStrategyCfg +from optgs.scene_trainer.initializer.initializer import InitializerOutput +from optgs.scene_trainer.optimizer.layer import AdamState +from optgs.scene_trainer.initializer import InitializerCfg +from optgs.misc.detaching_cpu_list import DetachingCPUList +from optgs.scene_trainer.optimizer.lr_scheduler import LrSchedulerCfgType, get_scheduler + +if TYPE_CHECKING: + from optgs.scene_trainer.adc.vanilla import 
VanillaStrategyState + from optgs.scene_trainer.adc.mcmc import McmcStrategyState + + +@dataclass +class OptimizerState: + state: torch.Tensor | None = None + init_state: torch.Tensor | None = None # state at the beginning of the optimization + adam_state: AdamState | None = None + adc_state: Any = None # VanillaStrategyState | McmcStrategyState | None + + +@dataclass +class OptimizerPreviousOutput: + gaussians: Gaussians + state: OptimizerState | None = None + + +@dataclass +class OptimizerInput: + context: BatchedViews + renderer: Decoder + prev_output: InitializerOutput | OptimizerPreviousOutput + num_refine: int + iter_batch_size: int | None + target: BatchedViews | None = None + context_remain: dict | None = None + debug_dict: dict | None = None + additional_info: tuple | None = None + + @property + def device(self) -> torch.device: + return self.context["image"].device + + +@dataclass +class OptimizerOutput: + # TODO Naama: should we add here iterations? + gaussian_list: DetachingCPUList[Gaussians] + t: int | None = None + T: int | None = None + last_prev_output: OptimizerPreviousOutput | None = None + target_render_list: DetachingCPUList[DecoderOutput] | None = None + context_render_list: DetachingCPUList[DecoderOutput] | None = None + info: dict | None = None + context_index_list: list[int] = field(default_factory=list) + target_index_list: list[int] = field(default_factory=list) + + def get_render_list(self, which: str) -> DetachingCPUList[DecoderOutput] | None: + if which == "target": + return self.target_render_list + elif which == "context": + return self.context_render_list + else: + raise ValueError(f"Unknown which: {which}, should be 'target' or 'context'") + + def get_index_list(self, which: str): + if which == "target": + return self.target_index_list + elif which == "context": + return self.context_index_list + else: + raise ValueError(f"Unknown which: {which}, should be 'target' or 'context'") + + @classmethod + def empty(cls, t=None) -> 
@dataclass
class OptimizerCfg:
    """Configuration fields shared by the scene optimizers, plus derived flags."""

    # Per-parameter switches: True excludes that parameter group from refinement.
    no_refine_mean: bool
    no_refine_scale: bool
    no_refine_rotation: bool
    no_refine_opacity: bool
    no_refine_sh0: bool
    no_refine_shN: bool

    # Learning-rate scheduler config.
    lr_scheduler: LrSchedulerCfgType

    # Densify / prune / opacity-reset strategy config (read by the properties below).
    refiner: BaseStrategyCfg

    # Chunk size used when computing input gradients; None means the full image.
    input_gradients_chunk_size: int | None

    # Weight of the L1 opacity regularization from 3DGS-MCMC (arXiv:2404.09591);
    # 0.0 disables it.
    opacity_reg_lambda: float

    def update(self, initializer_cfg: InitializerCfg):
        """Hook that receives the initializer config; does nothing in this base class."""
        return None

    @property
    def any_adc(self) -> bool:
        """Truthy when the refiner enables densification, pruning or opacity reset."""
        refiner = self.refiner
        return refiner.do_densify or refiner.do_prune or refiner.do_opacity_reset

    @property
    def need_2d_grads(self) -> bool:
        """Mirrors ``refiner.do_densify``."""
        return self.refiner.do_densify

    @property
    def optimize_all(self):
        """True only when none of the ``no_refine_*`` flags is set."""
        frozen_flags = (
            self.no_refine_mean,
            self.no_refine_scale,
            self.no_refine_rotation,
            self.no_refine_opacity,
            self.no_refine_sh0,
            self.no_refine_shN,
        )
        return all(not flag for flag in frozen_flags)
+ self.decoder_event_start = torch.cuda.Event(enable_timing=True) + self.decoder_event_end = torch.cuda.Event(enable_timing=True) + # scene_start_event_start/end bracket optimizer.on_scene_start() (KNN, Adam init). + # Read after the post-loop cuda.synchronize() in scene_trainer.get_optimized_gaussians. + self.scene_start_event_start = torch.cuda.Event(enable_timing=True) + self.scene_start_event_end = torch.cuda.Event(enable_timing=True) + + # Init logs for densification/pruning + self.radii_max_log = [] + self.grads_max_log = [] + self.nr_cloned_log = [] + self.nr_splitted_log = [] + self.nr_pruned_log = [] + self.nr_gaussians_log = [] + self.iter_time_log = [] # total ms per iteration + self.decoder_time_log = [] # ms spent in rendering-for-gradients per iteration + self.optimizer_time_log = [] # ms spent in update step (iter_time - decoder_time) + self.scene_start_ms = 0.0 # ms for on_scene_start (KNN lookup, Adam state init) + self.nr_nonzero_grad_log = [] + + # LR scheduler + self.scheduler = get_scheduler(self.cfg.lr_scheduler) + + def forward(self, i, optimizer_input: OptimizerInput, optimizer_output: OptimizerOutput, **kwargs) -> OptimizerOutput: + return self._forward_impl(i, optimizer_input, optimizer_output, **kwargs) + + def _record_iter_timing(self) -> None: + """Record per-iteration timing into iter/decoder/optimizer_time_log. 
+ Call right after the timed region; iter_start must already be recorded.""" + self.iter_end.record() + torch.cuda.synchronize() + elapsed_time = self.iter_start.elapsed_time(self.iter_end) + self.iter_time_log.append(elapsed_time) + decoder_ms = self.decoder_event_start.elapsed_time(self.decoder_event_end) + self.decoder_time_log.append(decoder_ms) + self.optimizer_time_log.append(elapsed_time - decoder_ms) + + def on_scene_start(self, optimizer_input: OptimizerInput) -> None: + self._on_scene_start_impl(optimizer_input) + + def _on_scene_start_impl(self, optimizer_input: OptimizerInput) -> None: + init_output = optimizer_input.prev_output + assert isinstance(init_output, InitializerOutput), \ + (f"base Optimizer class on_scene_start just convert the InitializerOutput to OptimizerPreviousOutput, " + f"without handling the state. " + f"It also initialize a new state for density control." + f"Got type {type(init_output)}") + + # Converting the initializer output to optimizer previous output + optimizer_prev_output = OptimizerPreviousOutput( + gaussians=init_output.gaussians.clone(), + state=None, + ) + optimizer_input.prev_output = optimizer_prev_output + + if self.cfg.any_adc: + self.reset_logs() + optimizer_prev_output.state = OptimizerState() # init to empty state + self.initialize_adc_state(self.cfg, optimizer_input) + + def on_scene_end(self) -> None: + pass + + def reset_logs(self): + self.radii_max_log = [] + self.grads_max_log = [] + self.nr_cloned_log = [] + self.nr_splitted_log = [] + self.nr_pruned_log = [] + self.nr_gaussians_log = [] + self.iter_time_log = [] + self.decoder_time_log = [] + self.optimizer_time_log = [] + self.scene_start_ms = 0.0 + self.nr_nonzero_grad_log = [] + + @staticmethod + def initialize_adc_state(cfg: OptimizerCfg, optimizer_input: OptimizerInput) -> None: + # Lazy import to avoid circular dependency + from optgs.scene_trainer.adc import init_strategy_state + + # get number of points + init_gaussians = 
optimizer_input.prev_output.gaussians + nr_points = init_gaussians.means.shape[1] + # get scene extent + context = optimizer_input.context + target = optimizer_input.target + assert ( + context["extrinsics"].shape[0] == context["intrinsics"].shape[0] == 1 + ), "scene batch size > 1 not supported yet..." + + scene_scale = context["scene_scale"][0].item() + # Initialize ADC state + optimizer_input.prev_output.state.adc_state = init_strategy_state( + cfg=cfg.refiner, + nr_points=nr_points, + device=init_gaussians.means.device, + scene_extent=scene_scale + ) + print("Initialized ADC state with", nr_points, "points and scene extent", scene_scale) + + def _forward_impl(self, i, optimizer_input: OptimizerInput, optimizer_output: OptimizerOutput, **kwargs) -> OptimizerOutput: + raise NotImplementedError() + + def validate_input(self, optimizer_input: OptimizerInput) -> None: + pass + + def _save_post_update_renders( + self, + i: int, + optimizer_input: OptimizerInput, + optimizer_output: OptimizerOutput, + updated_gaussians: Gaussians, + full_context: BatchedViews, + full_target: BatchedViews, + ) -> None: + """Render and append post-update context+target views. + + Renders every iteration during training (so per-step renders can feed the meta-loss); + otherwise renders only when save_every fires for the given tag. The per-iter subset + (optimizer_input.context/target) is used in training when sampling indices exist, + otherwise the full views. 
+ """ + for tag, full, iter_views in ( + ("context", full_context, optimizer_input.context), + ("target", full_target, optimizer_input.target), + ): + if not (self.training or self.save_every(i + 1, tag=tag)): + continue + index_list = optimizer_output.get_index_list(tag) + subset = iter_views if (index_list and self.training) else full + render_output = optimizer_input.renderer.forward_batch_subset( + updated_gaussians, + subset, + iter_batch_size=optimizer_input.iter_batch_size, + ) + optimizer_output.get_render_list(tag).append( + render_output, + detach_and_cpu=not self.training, + ) + + @torch.no_grad() + def apply_adc(self, i, v, h, w, adc_state, gaussians, meta, object_dict_to_adjust=None): + """ + Apply adaptive density control (ADC) based on 2D gradient norms. + Implements densification and pruning of Gaussians during optimization, as in vanilla 3DGS. + + Args: + gaussians: Gaussians to be densified/pruned in place. + h: Height of the rendered images. + i: Current optimization iteration. + v: Number of views. + meta: Metadata dict from the rendering, including visibility masks and radii. + w: Width of the rendered images. + object_dict_to_adjust: Dict of object to adjust after pruning and densification, if needed. + """ + # Lazy import to avoid circular dependency + from optgs.scene_trainer.adc import post_backward + + visibility_mask = meta["visibility_filter"] # [B, V, N] + radii_2d = meta["radii"].float() # [B, V, N, 2] + means2d_grads = meta["means_2d_grads"] # [B, V, N, 2] or None + + # means lr for MCMC noise injection + # check if optimizer has means_lr_scheduler + if hasattr(self, "means_lr_scheduler"): + assert self.means_lr_scheduler is not None, "means_lr_scheduler is None." + lr = self.means_lr_scheduler(i) + else: + # Use fallback_means_lr from the refiner config so noise magnitude matches the + # original paper (means_lr * noise_lr ≈ 1.6e-4 * 5e5 = 80 covariance-units). 
+ lr = self.cfg.refiner.fallback_means_lr + + # Post-backward (ADC) + nr_cloned, nr_splitted, nr_pruned, max_radii, max_grad2d = post_backward( + cfg=self.cfg.refiner, + step=i, + gaussians=gaussians, + adc_state=adc_state, + smoothers=object_dict_to_adjust, + radii_2d=radii_2d, # [V, N] + means2d_grads=means2d_grads, # [V, N, 2] + visibility_mask=visibility_mask, # [V, N] + iter_batch_size=v, + w=w, + h=h, + lr=lr + ) + + self.nr_cloned_log.append(nr_cloned) + self.nr_splitted_log.append(nr_splitted) + self.nr_pruned_log.append(nr_pruned) + if max_radii is not None: + self.radii_max_log.append(max_radii) + else: + self.radii_max_log.append(0.0) + if max_grad2d is not None: + self.grads_max_log.append(max_grad2d) + else: + self.grads_max_log.append(0.0) + + def plot_info(self, step, output_path: Path | None = None, scene_name: str | None = None) -> None: + + if output_path is None: + return + + if scene_name is None: + return + + save_path = output_path / "plots" / scene_name + os.makedirs(save_path, exist_ok=True) + + # Define datasets and labels in a compact structure + data = [] + + if len(self.radii_max_log) == len(self.iter_time_log): + data.append((range(len(self.iter_time_log)), self.radii_max_log, "Max Radius")) + if len(self.grads_max_log) == len(self.iter_time_log): + data.append((range(len(self.iter_time_log)), self.grads_max_log, "Max Grad magnitude")) + if len(self.nr_cloned_log) == len(self.iter_time_log): + data.append((range(len(self.iter_time_log)), self.nr_cloned_log, "Cloned")) + if len(self.nr_splitted_log) == len(self.iter_time_log): + data.append((range(len(self.iter_time_log)), self.nr_splitted_log, "Splitted")) + if len(self.nr_pruned_log) == len(self.iter_time_log): + data.append((range(len(self.iter_time_log)), self.nr_pruned_log, "Pruned")) + + data.append((range(len(self.iter_time_log)), self.nr_gaussians_log, "Total")) + data.append((range(len(self.iter_time_log)), self.iter_time_log, "Iteration Time (ms)")) + + # Create a larger figure 
with shared x-axis + nr_rows = len(data) + fig, axes = plt.subplots(nr_rows, 1, figsize=(10, 15), sharex=True) + + # Define some styles for visual variety + styles = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink'] + assert nr_rows <= len(styles), "Not enough styles defined for the number of subplots." + + # Loop through subplots + for ax, (x, y, label), color in zip(axes, data, styles): + ax.plot(x, y, label=label, color=color, linewidth=2) + ax.set_ylabel("Value", fontsize=11) + ax.grid(True, linestyle="--", alpha=0.6) + ax.legend(loc="upper right", fontsize=10) + ax.set_title(f"{label} Gaussians", fontsize=13, pad=5) + # show x-axis ticks on all plots + ax.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=True) + # set y-axis vmin to 0 + # ax.set_ylim(bottom=0) + + # Shared x-axis label + axes[-1].set_xlabel("Iteration", fontsize=12) + # Improve layout + plt.tight_layout() + plt.subplots_adjust(hspace=0.3) + # + # module_name = self.__class__.__name__.lower() + + # Save and close + save_path = save_path / f"stats_{step}.png" + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + print("Saved optimizer stats plot to:", save_path) + + +class LearnedOptimizer(Optimizer[T], ABC): + @property + def strategy(self) -> str: + return "learned" + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + +class NonlearnedOptimizer(Optimizer[T], ABC): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # nn.Module.__init__ sets training=True (a plain attribute, not via + # train()); a non-learned optimizer has no trainable parameters, so pin + # it to eval at construction. + self.eval() + + @property + def strategy(self) -> str: + return "nonlearned" + + def train(self, mode: bool = True): + # train mode is meaningless here, and `self.training` gates + # meta-training-only code paths (e.g. 
_save_post_update_renders + # retaining full-scene renders on GPU). Pin to eval, even under a + # generic `module.train()` recursion. + return super().train(False) diff --git a/optgs/scene_trainer/optimizer/optimizer_adam.py b/optgs/scene_trainer/optimizer/optimizer_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..2632db990e489de54d4f7f521949a2ff8f7f5b69 --- /dev/null +++ b/optgs/scene_trainer/optimizer/optimizer_adam.py @@ -0,0 +1,387 @@ +from dataclasses import dataclass +from functools import partial +from typing import Literal, List, Optional + +import torch +from torch import Tensor + +from optgs.dataset.data_types import BatchedViews +from optgs.misc.general_utils import get_expon_lr_func +from optgs.misc.io import FrequencyScheduler +from optgs.model.decoder.decoder import Decoder +from optgs.model.types import Gaussians +from optgs.scene_trainer.initializer import InitializerCfg +from optgs.scene_trainer.optimizer.layer import AdamInputSmoothing +from optgs.scene_trainer.optimizer.optimizer import ( + OptimizerInput, + OptimizerOutput, + OptimizerCfg, NonlearnedOptimizer, +) +from optgs.scene_trainer.optimizer.optimizer_utils import ( + calc_input_gradients, + squeeze_grad_dict, + smooth_grads, +) + + +@dataclass +class AdamOptimizerCfg(OptimizerCfg): + name: Literal["adam"] + + # adam params + betas: List[float | int] # Typically a list of two floats, e.g., [0.9, 0.999] + eps: float + weight_decay: float + + # learning rates + base_lr: int | float + means_lr_init: float + means_lr_final: float + means_lr_delay_mult: float + means_lr_max_steps: int # should be equal to total optimization steps + scales_lr: float + rotations_lr: float + opacities_lr: float + sh0s_lr: float + shNs_lr: float # 20 times less as sh0s_lr in original paper + + def update(self, initializer_cfg: InitializerCfg): + pass + + +class AdamOptimizer(NonlearnedOptimizer[AdamOptimizerCfg]): + def __init__( + self, cfg: AdamOptimizerCfg, save_every: 
def _on_scene_start_impl(self, optimizer_input: OptimizerInput) -> None:
    """Create the per-scene Adam state and the means learning-rate schedule.

    One ``AdamInputSmoothing`` object is built per parameter group, shaped like
    that group's tensor without its leading dimension (presumably the scene
    batch, which is asserted to be 1). The higher-order SH group is ``None``
    when the harmonics hold a single coefficient. The means schedule is scaled
    by ``context["scene_scale"]``; a ``None`` entry is treated as scale 1.
    """
    super()._on_scene_start_impl(optimizer_input)

    # Only a single scene per batch is supported.
    views = optimizer_input.context
    single_scene = views["extrinsics"].shape[0] == views["intrinsics"].shape[0] == 1
    assert single_scene, "scene batch size > 1 not supported yet..."

    start = optimizer_input.prev_output.gaussians
    device = start.means.device
    harmonics = start.harmonics

    # All smoothers share the Adam hyper-parameters; only the shape differs.
    make_smoother = partial(
        AdamInputSmoothing,
        beta1=self.cfg.betas[0],
        beta2=self.cfg.betas[1],
        eps=self.cfg.eps,
        device=device,
    )
    group_shapes = {
        "means": start.means.shape[1:],
        "scales": start.scales.shape[1:],
        "rotations": start.rotations.shape[1:],
        "opacities": start.opacities.shape[1:],
        "sh0s": harmonics[..., :, :1].shape[1:],
    }
    smoothers = {group: make_smoother(shape=shape) for group, shape in group_shapes.items()}

    # Higher-order SH coefficients exist only when there is more than one band entry.
    if harmonics.shape[-1] > 1:
        smoothers["shNs"] = make_smoother(shape=harmonics[..., :, 1:].shape[1:])
    else:
        smoothers["shNs"] = None
    self.smoothers = smoothers

    # Scene extent: fall back to 1 when the context carries no scale.
    extent = views["scene_scale"]
    if extent is None:
        extent = torch.ones(1, 1, device=device)
    extent = extent[0].item()

    # Exponential schedule for the means learning rate, scaled by the scene extent.
    self.means_lr_scheduler = get_expon_lr_func(
        lr_init=self.cfg.means_lr_init * extent,
        lr_final=self.cfg.means_lr_final * extent,
        lr_delay_mult=self.cfg.means_lr_delay_mult,
        max_steps=self.cfg.means_lr_max_steps,
    )
+ # torch.cat (used by add_new/relocate) produces a non-leaf even with requires_grad=True, + # so .grad is never populated by backward(). detach() cuts the grad_fn first. + buf_nr_gaussians = self._meta_bufs['N'] + actual_nr_gaussians = gaussians.means.shape[1] + if buf_nr_gaussians != actual_nr_gaussians: + self._meta_bufs.clear() + # TODO Naama: need to think if the detach is necessary (was added during mcmc implementation) + gaussians.means = gaussians.means.detach().requires_grad_(True) + gaussians.scales = gaussians.scales.detach().requires_grad_(True) + gaussians.rotations_unnorm = gaussians.rotations_unnorm.detach().requires_grad_(True) + gaussians.opacities = gaussians.opacities.detach().requires_grad_(True) + gaussians.harmonics = gaussians.harmonics.detach().requires_grad_(True) + + # Timing + self._record_iter_timing() + + # TODO Naama: we can log stats with save_every, but need to change stuff later. + # Log stats — guard with save_every + if grads_raw is not None: # and self.save_every(i + 1, tag="info"): + G = grads_raw["means"].shape[0] + nonzero_grads = [(g.reshape(G, -1) != 0).any(dim=-1) for g in grads_raw.values() if g is not None] + nonzero_grads = torch.stack(nonzero_grads) # [num_params, G] + nonzero_grads = nonzero_grads.any(dim=0) # [G] + self.nr_nonzero_grad_log.append(nonzero_grads.sum().item()) + + # Save updated gaussians (for next iteration) + optimizer_input.prev_output.gaussians = gaussians + + # Info + if self.save_every(i + 1, tag="info"): + + # save gaussians + optimizer_output.gaussian_list.append(gaussians, detach_and_cpu=True, save_to_disk=False, no_cache=False) + + # Save delta stats + assert optimizer_output.info is not None + + # log deltas + if "deltas" not in optimizer_output.info: + optimizer_output.info["deltas"] = [] + optimizer_output.info["deltas"].append({k: v.cpu() for k, v in updates.items() if v is not None}) + + # log gradients + if "grads" not in optimizer_output.info: + optimizer_output.info["grads"] = [] + 
def apply_one_update_step(
        self, i, optimizer_input: OptimizerInput, optimizer_output: OptimizerOutput, sh_degree: int | None = None
) -> tuple[Gaussians, dict | None, dict, dict[str, Tensor], dict[str, Tensor], dict[str, float]]:
    """Apply one Adam step to the Gaussian parameters, updating them in place.

    On the first step (``i == 0``) the Gaussians must hold activated values;
    scales and opacities are converted in place with ``log_`` / ``logit_`` and
    ``requires_grad`` is enabled on all parameter tensors. On later steps the
    Gaussians must already be in that raw form.

    Args:
        i: Zero-based step index; also fed to the means learning-rate schedule.
        optimizer_input: Supplies the context views, the renderer and the
            Gaussians to update (``prev_output.gaussians``).
        optimizer_output: Not used by this method; kept for interface compatibility.
        sh_degree: Forwarded unchanged to ``calc_input_gradients``.

    Returns:
        ``(gaussians, meta_for_adc, updates, grads_raw, grads_adam, learning_rates)``.
        ``updates`` maps each parameter group to the applied delta, or ``None``
        when the group was skipped. ``grads_raw`` holds the squeezed raw
        gradients and ``grads_adam`` the output of ``smooth_grads``.
        ``learning_rates`` holds the step sizes used.
    """
    iter_context = optimizer_input.context
    renderer = optimizer_input.renderer
    gaussians = optimizer_input.prev_output.gaussians

    if i == 0:
        # First iteration: the incoming gaussians hold activated values.
        assert gaussians.stores_activated, "Gaussians must store activated values."
        # deactivate values in-place (avoids allocating new tensors)
        gaussians.scales.log_()  # [B, N, 3]
        gaussians.opacities.logit_()
        gaussians.stores_activated = False
        # enable requires_grad once — .grad buffers persist across steps,
        # so backward() reuses them instead of allocating new tensors each call
        gaussians.means.requires_grad_(True)
        gaussians.scales.requires_grad_(True)
        gaussians.rotations_unnorm.requires_grad_(True)
        gaussians.opacities.requires_grad_(True)
        gaussians.harmonics.requires_grad_(True)
    else:
        assert not gaussians.stores_activated, "Gaussians must not store activated values."

    # learning rates
    # TODO Naama: use current cfg field lr_scheduler, which also defines the lr per param
    assert self.means_lr_scheduler is not None, "means_lr_scheduler is not initialized"
    means_lr = self.means_lr_scheduler(i) * self.cfg.base_lr
    scales_lr = self.cfg.scales_lr * self.cfg.base_lr
    rotations_lr = self.cfg.rotations_lr * self.cfg.base_lr
    opacities_lr = self.cfg.opacities_lr * self.cfg.base_lr
    sh0s_lr = self.cfg.sh0s_lr * self.cfg.base_lr
    shNs_lr = self.cfg.shNs_lr * self.cfg.base_lr

    # Check BOTH camera tensors: a single scene per batch is all that is supported.
    assert (
        iter_context["extrinsics"].shape[0] == iter_context["intrinsics"].shape[0] == 1
    ), "scene batch size > 1 not supported yet..."

    # unpack gaussians
    means = gaussians.means  # [B, N, 3]
    rotations_unnorm = gaussians.rotations_unnorm  # [B, N, 4]
    scales_raw = gaussians.scales  # [B, N, 3]
    opacities_raw = gaussians.opacities  # [B, N]
    shs = gaussians.harmonics  # [B, N, 3, sh_d]

    # Render + backward; the decoder events bracket this call for timing.
    self.decoder_event_start.record()
    _, grads_raw, meta_for_adc = calc_input_gradients(
        iter_context,
        means,
        scales_raw,
        rotations_unnorm,
        opacities_raw,
        shs,
        renderer,
        need_2d_grads=self.cfg.need_2d_grads,
        chunk_size=self.cfg.input_gradients_chunk_size,
        any_adc=self.cfg.any_adc,
        sh_degree=sh_degree,
        meta_bufs=self._meta_bufs,
        opacity_reg_lambda=self.cfg.opacity_reg_lambda,
    )
    self.decoder_event_end.record()

    # get updates from adam optimizer
    grads_raw = squeeze_grad_dict(grads_raw)
    assert self.smoothers is not None, "Smoothers not initialized"
    grads_adam = smooth_grads(grads_raw, self.smoothers)

    # update the gaussians parameters
    # Batch delta computation for contiguous params with _foreach_mul to reduce kernel launches.
    # no_refine flags are handled by excluding the param from the batch (delta stays None).
    _grad_lr_pairs = [
        (grads_adam["means"], -means_lr, self.cfg.no_refine_mean),
        (grads_adam["scales"], -scales_lr, self.cfg.no_refine_scale),
        (grads_adam["rotations"], -rotations_lr, self.cfg.no_refine_rotation),
        (grads_adam["opacities"], -opacities_lr, self.cfg.no_refine_opacity),
    ]
    _active_grads = [g for g, lr, skip in _grad_lr_pairs if not skip]
    _active_lrs = [lr for g, lr, skip in _grad_lr_pairs if not skip]
    _active_deltas = torch._foreach_mul(_active_grads, _active_lrs) if _active_grads else []

    # Hand the batched results back out in the same order they went in.
    _delta_iter = iter(_active_deltas)
    delta_means = next(_delta_iter) if not self.cfg.no_refine_mean else None
    delta_scales_raw = next(_delta_iter) if not self.cfg.no_refine_scale else None
    delta_rotations_unnorm = next(_delta_iter) if not self.cfg.no_refine_rotation else None
    delta_opacities_raw = next(_delta_iter) if not self.cfg.no_refine_opacity else None

    # SH deltas stay separate (non-contiguous slice views)
    delta_sh0s = None if self.cfg.no_refine_sh0 else -sh0s_lr * grads_adam["sh0s"]
    delta_shNs = None
    if grads_adam["shNs"] is not None and not self.cfg.no_refine_shN:
        delta_shNs = -shNs_lr * grads_adam["shNs"]

    # step — batch contiguous params with _foreach_add_ to reduce kernel launches;
    # SH slice views are non-contiguous so they stay separate
    _params = [means, scales_raw, rotations_unnorm, opacities_raw]
    _deltas = [delta_means, delta_scales_raw, delta_rotations_unnorm, delta_opacities_raw]
    _active = [(p, d) for p, d in zip(_params, _deltas) if d is not None]
    if _active:
        torch._foreach_add_([p for p, d in _active], [d for p, d in _active])
    self.safe_inplace_update(delta_sh0s, shs[..., 0:1])
    self.safe_inplace_update(delta_shNs, shs[..., 1:])

    # assign (means/scales/rotations/harmonics are the same objects; in-place ops above
    # already updated their storage. opacities_raw is a view — do NOT reassign
    # gaussians.opacities here, as that would replace the persistent leaf with a non-leaf
    # view and break retain_grad() on subsequent steps.)
    gaussians.means = means
    gaussians.scales = scales_raw
    gaussians.rotations_unnorm = rotations_unnorm
    gaussians.harmonics = shs

    # group updates
    updates = {
        "means": delta_means,
        "scales": delta_scales_raw,
        "rotations": delta_rotations_unnorm,
        "opacities": delta_opacities_raw,
        "sh0s": delta_sh0s,
        "shNs": delta_shNs,
    }

    learning_rates = {
        "means": means_lr,
        "scales": scales_lr,
        "rotations": rotations_lr,
        "opacities": opacities_lr,
        "sh0s": sh0s_lr,
        "shNs": shNs_lr,
    }

    return gaussians, meta_for_adc, updates, grads_raw, grads_adam, learning_rates
optgs.model.types import Gaussians +from optgs.scene_trainer.common.gaussian_adapter import build_covariance +from optgs.scene_trainer.initializer import InitializerCfg, InitializerColmapCfg, InitializerEdgsCfg, \ + InitializerRandomCfg, InitializerPointcloudCfg +from optgs.scene_trainer.initializer import InitializerPlyCfg +from optgs.scene_trainer.initializer.initializer_resplat import ResplatInitializerCfg +from optgs.scene_trainer.optimizer.optimizer import OptimizerInput, LearnedOptimizer, OptimizerOutput, OptimizerState, \ + OptimizerPreviousOutput, OptimizerCfg +from optgs.scene_trainer.optimizer.optimizer_utils import Number3DGSCfg, Bool3DGSCfg +from optgs.scene_trainer.optimizer.optimizer_utils import unpack_gaussians, \ + get_visibility_contribution_from_gaussian_obj + +try: + from optgs.model.encoder.point_transformer.layer import (PlainPointTransformer, SubsampleBlock, PointLinearWrapper, + MultiScalePointTransformer, + MultViewLowresAttn) +except: + pass + +try: + from simple_knn._C import distCUDA2 +except: + pass + +from optgs.scene_trainer.optimizer.layer import CustomGroupNorm, AdamInputSmoothing, SlicedG3RNorm +from optgs.scene_trainer.initializer.initializer import InitializerOutput +from optgs.scene_trainer.optimizer.time_embed import get_embedder, TimeEncodingWrapper + +from optgs.loss.loss_depth_smooth import get_smooth_loss +from optgs.scene_trainer.optimizer.optimizer_utils import ( + inner_loss_for_input_gradients, + chunk_index_iter, + split_grads, + get_gaussian_param_slices, + get_gaussian_param_sizes, + pack_gaussians, +) + +_IMAGENET_NORM = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + +@dataclass +class KnnBasedOptimizerCfg(OptimizerCfg): + name: Literal["knn_based", "resplat_v1", "resplat_v2", "clogs", "l2s"] # TODO (release) remove clogs + # iterative refine + no_render_error: bool + input_error_shallow_resnet_feature: bool + input_error_resnet_feature_layers: int + refine_sh_only: bool + 
num_basic_refine_blocks: int + num_refine_blocks: int + concat_init_state: bool # always concat init state during updates + replace_init_state: bool # always use the init state during updates + state_channels: int + refine_block_rmsnorm: bool + refine_block_layernorm: bool + pt_qk_norm: bool + norm_pt_block: bool + refine_gaussian_multiple: int # predict more gaussian residuals based on the previous gaussian center + refine_residual_init_state: bool # add residual connection in the prediction head to the inital state + clamp_refine_max_scale: float + clamp_min_scale: float | int + clamp_min_raw_scales: float | int + clamp_max_raw_scales: float | int + clamp_min_raw_opacities: float | int + clamp_max_raw_opacities: float | int + clamp_min_sh0: float | int + clamp_max_sh0: float | int + clamp_min_shs: float | int + clamp_max_shs: float | int + clamp_shs_soft: bool + + gaussian_head_multiple: int # use multiple non-weight sharing heads to predict multiple gaussians + gradient_update_scale: float | int + input_gradient_with_ssim_loss: bool + update_attn_proj_channels: int | None + update_no_knn_attn: bool + update_no_tran_block_norm: bool + update_tran_block_act: str | None + multi_gaussian_scale_smaller: bool + refine_condition_pt_feature: bool + reinit_gaussian_when_refine_multiple: bool + refine_same_num_points: bool # when init_gaussian_multiple > 1, refine directly works on it instead of subsampling points + + refine_knn_samples: int + refine_multi_scale_pt: bool + + # KNN + use_fused_attn: bool + prune_invisible_gaussians: bool + knn_idx_update_every: int + + # point transformer + pt_heads: int + + # inputs + input_alpha: bool + input_depth: bool + input_depth_smooth_error: bool + + # input error + input_error: bool # render error as input to the refine head + input_error_rgb_no_shuffle: bool # sample single pixel instead of pixel unshuffling + input_error_add_rgb_feature: bool + + # resnet + input_error_resnet_feature: bool + input_error_cache_resnet_feature: 
bool + input_error_no_freeze_resnet_feature: bool + + # number of views for render error + input_error_num_views: int + input_error_additional_cross_attn: bool + input_error_num_intermediate_views: int + + # render error with remaining context views + input_error_remain_context: bool + input_error_merge_remain_context: bool + input_error_warp_remain_context: bool + input_error_random_num_remain_context: bool + input_error_num_remain_context_test: int + + # render error mv attn + input_error_mv_attn: bool + input_error_mv_attn_blocks: int + + # refine global attention + refine_with_mv_attn: bool + refine_with_mv_attn_lowres: bool + refine_no_mv_attn: bool # remove only the attn + mv_attn_conv_with_norm: bool # unet-attn conv with norm + refine_mv_shuffle_attn: bool # use pixel shuffle to save computation instead of unet + refine_mv_attn_with_pos_enc: bool + refine_shuffle_attn_no_norm: bool + refine_mv_unimatch_attn: bool + + # input gradients + input_gradient: bool + input_gradient_log: bool + input_gradient_log_clip_deltas: float | int + input_gradient_scale: float | int + input_gradient_same_loss: bool # use the same loss as the gaussian update + input_gradient_loss_reduction: str + scale_residual_grads: bool + + # sliding window + window_local_refine: bool # refine each local window separately and then combine all windows + window_global_refine: bool # refine all windows together + window_local_global_refine: bool # first refine each window seprately, and then refine all windows together + + # sliding window update instead of update all gaussians together + update_window_size: int + local_gaussian_render: bool + + # time encoding + use_time_encoding: bool + time_encoding_max_steps: int + + train_global_update_only: bool + + # random size refine + # update more for low resolution, less for high + random_update_with_size: bool + + # amp + use_amp: bool + pt_head_amp: bool + pt_update_amp: bool + + use_checkpointing: bool + recurrent_use_checkpointing: bool + + # 
Debugging + debug_refine_update_module: bool + + # Normalizing input + input_gradient_normalize: bool + input_gradient_normalize_type: str + input_normalize_state: bool + input_normalize_gaussians: bool + + # State scaling + predict_state_scale: bool + predict_state_scale_norm: bool # whether to normalize the state before scaling + + # Use optimizer without condition features + init_state_wo_features: bool + init_state_type: Literal["random", "constant"] + init_state_scale: float | int + + opt_scales_before_act: bool # optimize scale before activation (raw -> exp -> scale -> log -> raw) + + # Preprocessing the init gaussians + scale_initial_opacities: float | int + + # Experimental + experimental_run: bool + experimental_update: Bool3DGSCfg + experimental_use_grads: bool + experimental_use_norm_grads: Bool3DGSCfg + experimental_lr: Number3DGSCfg + # Deactivate gaussians + local_prune_zero_radii: bool + local_prune_low_weights: bool + local_prune_low_weights_thresh: float | int + update_only_nonzero_grad: bool + + # update learn residual state + residual_state: bool + + # Update head + update_head_layer_num: int + update_head_concat_img: bool + update_head_act: str | None # update_head activation to predict the deltas + update_head_final_act: str | None # final activation in the update_head + update_head_hidden_dim_matches: str # rebuttal or submission version + + update_head_scale_mag: bool # predict deltas as scale * 0.01 * jnp.exp(mag * 0.01) + update_head_scalar_scale: bool # predict deltas as scalar * delta / norm(delta) + update_head_scalar_scale_act: str # activation for the scalar scale output + + # Per-parameter-group update head (Feature A) + update_head_per_param_heads: bool # separate heads per param group, each with own normalize+scale + update_head_per_param_hidden_dim: int # hidden dim for per-param heads (SH head gets 2x) + # Per-parameter scalar scales (Feature B) — requires update_head_scalar_scale=true + update_head_per_param_scales: bool # 
per-group scalar scales instead of one global scalar + + # Config from initializer + sh_d: int | None + init_gaussian_param_num: int | None = None + init_sh_d: int | None = None + # Fow initialization from feed forward, gaussians are aligned with pixels. + init_gaussian_multiple: int | None = None + latent_downsample: int | None = None + + delta_adam_combine_step: int = 0 # combine deltas and adam updates + + def update(self, initializer_cfg: InitializerCfg): + """ Update the optimizer config based on the initializer config""" + + # General settings + self.init_gaussian_param_num = initializer_cfg.get_gaussian_param_num() + self.init_sh_d = initializer_cfg.get_sh_d() + if self.sh_d is None: + # get sh_d from initializer if not set + self.sh_d = initializer_cfg.get_sh_d() + + # Settings specific to DepthSplat initializer + if isinstance(initializer_cfg, ResplatInitializerCfg): + self.latent_downsample = initializer_cfg.latent_downsample + self.init_gaussian_multiple = initializer_cfg.init_gaussian_multiple + + # update proj channels + if self.refine_condition_pt_feature: + self.condition_channels = initializer_cfg.gaussian_regressor_channels + else: + self.condition_channels = initializer_cfg.get_pt_in_channels() + # Settings specific to Colmap initializer + elif isinstance(initializer_cfg, + (InitializerPlyCfg, InitializerColmapCfg, InitializerEdgsCfg, InitializerRandomCfg, + InitializerPointcloudCfg)): + # Since pixels and gaussians are not alligned, we can not use pixel attributes + assert not self.input_error, "The error calculation assumes per pixel gaussians" + assert not self.update_head_concat_img + assert not self.input_alpha + assert not self.local_gaussian_render, "The local rendering assumes per view gaussians" + + assert self.init_state_wo_features, "Colmap initializer does not have point features, init_state_wo_features must be set to True" + + self.init_gaussian_multiple = 1 + self.latent_downsample = 1 + else: + raise ValueError(f"Unsupported 
class KnnBasedOptimizerState:
    """Mutable holder for a state tensor whose rows are edited in place.

    Every method indexes or concatenates along dim 0 of ``state`` and
    reassigns/mutates ``self.state``; none of them returns a value. The method
    names mirror densification operations (clone / split / replace / prune /
    add) — presumably one row per Gaussian; confirm against the ADC caller.
    """

    # TODO Naama: OptimizerState class already exists
    def __init__(self, state: torch.Tensor):
        self.state = state

    def clone(self, clone_mask: torch.Tensor, zero_t: bool) -> None:
        """Append a copy of the rows selected by ``clone_mask`` (zeros if ``zero_t``)."""
        cloned_state = self.state[clone_mask]
        if zero_t:
            cloned_state = torch.zeros_like(cloned_state)
        self.state = torch.cat([self.state, cloned_state], dim=0)

    def split(self, split_mask, num_splits: int, zero_t: bool) -> None:
        """Append the rows selected by ``split_mask`` (zeros if ``zero_t``).

        The number of appended rows equals the number of selected rows,
        whatever ``num_splits`` is. The selected rows used to be cut with
        ``chunk(num_splits)`` and re-concatenated chunk by chunk, which gives
        the same tensor but raised IndexError whenever ``torch.chunk`` returned
        fewer than ``num_splits`` pieces (e.g. one selected row with
        ``num_splits=2``). ``num_splits`` is kept for interface compatibility.
        """
        new_states = self.state[split_mask]
        if zero_t:
            new_states = torch.zeros_like(new_states)
        self.state = torch.cat([self.state, new_states], dim=0)

    def replace(self, from_indices: torch.Tensor, dest_indices: torch.Tensor, zero_t: bool) -> None:
        """Overwrite rows ``dest_indices`` with rows ``from_indices`` (or with zeros if ``zero_t``)."""
        if zero_t:
            self.state[dest_indices] = 0.0
        else:
            self.state[dest_indices] = self.state[from_indices]

    def prune(self, prune_mask: torch.Tensor) -> None:
        """Drop the rows where ``prune_mask`` is True."""
        self.state = self.state[~prune_mask]

    def add(self, num_new: int) -> None:
        """Append ``num_new`` zero rows; a non-positive count is a no-op."""
        if num_new <= 0:
            return
        device = self.state.device
        dtype = self.state.dtype
        input_dim = self.state.shape[1:]
        self.state = torch.cat([self.state, torch.zeros((num_new, *input_dim), device=device, dtype=dtype)], dim=0)

    def extend(self, num_new):
        """Alias of :meth:`add`."""
        self.add(num_new)
{activation}") + + +class KnnBasedOptimizer(LearnedOptimizer[KnnBasedOptimizerCfg]): + OPTIMIZER_NAME = "knn_based" + OPTIMIZER_NAME_ALIASES: tuple[str, ...] = () + + def __init__(self, cfg: KnnBasedOptimizerCfg, save_every: Optional[FrequencyScheduler] = None) -> None: + valid = {self.OPTIMIZER_NAME, *self.OPTIMIZER_NAME_ALIASES} + assert cfg.name in valid, f"Expected optimizer name {valid}, got {cfg.name}" + + super().__init__(cfg, save_every) + + if self.cfg.residual_state: + assert not self.cfg.refine_residual_init_state + + # State channel + self.state_channels = self.cfg.state_channels + + # time embedder + if self.cfg.use_time_encoding: + self.time_encoder_fn, self.time_embedding_dim = get_embedder(multires=6) + else: + self.time_encoder_fn = None + self.time_embedding_dim = 0 + + # update_proj + if not self.cfg.init_state_wo_features: + self.update_proj = nn.Conv2d(self.cfg.condition_channels, self.state_channels, 1) + + channels, in_channels, update_gaussian_param_num, out_channels, error_features_channels = ( + self.define_update_channels(self.cfg.init_gaussian_param_num)) + self.error_features_channels = error_features_channels + self.gaussian_param_num = out_channels + + if self.cfg.input_error: + + self.update_feature = self.get_input_error_feature_extractor() + if self.cfg.input_error_add_rgb_feature: + if self.cfg.init_gaussian_multiple == 4: # re10k + self.update_rgb_error_proj = nn.Sequential( + nn.Linear(3, error_features_channels), + nn.LayerNorm(error_features_channels) + ) + else: + self.update_rgb_error_proj = nn.Sequential( + nn.Linear(3 * self.cfg.latent_downsample ** 2, error_features_channels), + nn.LayerNorm(error_features_channels) + ) + self.update_input_norm = self.get_update_input_norm(in_channels) + self.update_module = self.get_update_module(channels, in_channels) + + # predict multiple gaussians + out_channels = out_channels * self.cfg.refine_gaussian_multiple + + if not self.cfg.refine_same_num_points: + out_channels = 
out_channels * self.cfg.init_gaussian_multiple + + # make sure the input size of the gaussian head is updated accordingly + if self.cfg.use_time_encoding: + channels += self.time_embedding_dim + + # Compute per-param group dims (needed by per_param_heads and per_param_scales) + if self.cfg.update_head_per_param_heads or self.cfg.update_head_per_param_scales: + self._per_param_group_dims = self._compute_per_param_group_dims(out_channels) + + # Scaling state for update head + if self.cfg.predict_state_scale: + self.state_scale_head = self.get_state_scale_head(in_channels) + + self.update_head = self.get_update_head(in_channels, channels, out_channels) + + # multiple gaussian heads to predict multiple gaussians + if self.cfg.gaussian_head_multiple > 1: + self.update_head_list = self.get_update_head_list(channels, out_channels) + + # Define error calculation + # add global attention to the render error + if self.cfg.input_error and self.cfg.input_error_mv_attn: + assert self.cfg.input_error_resnet_feature + self.update_error_attn = nn.ModuleList([ + MultViewLowresAttn(error_features_channels) + for _ in range(self.cfg.input_error_mv_attn_blocks) + ]) + + self.param_slices = get_gaussian_param_slices(self.cfg.sh_d) + + def _reset_knn_caches(self) -> None: + """Invalidate cached KNN indices on all point-transformer sub-modules. + + Must be called whenever the number of Gaussians changes (e.g. after add_new) + so the next forward recomputes KNN from scratch instead of using stale indices + that index out-of-bounds into the grown point cloud. 
+ """ + for module in self.modules(): + if hasattr(module, "cache_knn_idx"): + module.cache_knn_idx = None + + @property + def adc_object_dict_to_adjust(self): + if self.cfg.any_adc: + object_dict: dict[str, Any] = {"depthsplat_state": None} + # For ADC + if self.cfg.input_gradient_normalize and self.cfg.input_gradient_normalize_type == "adam": + object_dict.update(self.update_input_norm.subgroups_view(self.param_slices)) + else: + return None + + return object_dict + + def _compute_per_param_group_dims(self, out_channels): + """Compute per-parameter-group output dimensions from total out_channels. + + Returns a dict {group_name: dim} in the same order as split_delta_gaussians. + Accounts for no_refine_rotation, no_refine_mean, refine_sh_only, and multipliers. + """ + + # TODO Naama: allow combination of no_refine_* + p = get_gaussian_param_sizes(self.cfg.sh_d) + + all_params = [ + ("means", "means"), + ("scales", "scales"), + ("rotations", "quats"), + ("opacities", "opacities"), + ("shs", "shs"), + ] + + if self.cfg.refine_sh_only: + excluded = {"means", "scales", "rotations", "opacities"} + elif self.cfg.no_refine_rotation: + excluded = {"rotations"} + elif self.cfg.no_refine_mean: + excluded = {"means"} + else: + excluded = set() + + multiplier = self.cfg.refine_gaussian_multiple + if not self.cfg.refine_same_num_points: + multiplier *= self.cfg.init_gaussian_multiple + + group_dims = {name: p[key] * multiplier for name, key in all_params if name not in excluded} + + assert sum(group_dims.values()) == out_channels, ( + f"Per-param group dims {dict(group_dims)} sum={sum(group_dims.values())} != out_channels={out_channels}" + ) + return group_dims + + def _build_per_param_heads(self, channels, out_channels): + """Build per-parameter-group heads (Feature A). + + Each head: Linear(channels, hidden) -> act -> Linear(hidden, dim+1) + The +1 is a per-group scalar scale. Each head independently normalizes + scales. 
def get_update_head(self, in_channels, channels, out_channels):
    """Build the MLP that turns a per-Gaussian feature vector into parameter deltas.

    Args:
        in_channels: Not referenced by this builder; kept so the call signature
            stays the same for existing callers.
        channels: Width of the feature vector fed to the first linear layer
            (before the optional image skip connection is added).
        out_channels: Number of delta channels to predict, before the optional
            magnitude / scalar-scale outputs are appended.

    Returns:
        An ``nn.Sequential`` ending in ``Linear -> final activation``. When
        ``cfg.update_head_per_param_heads`` is set, the result of
        ``self._build_per_param_heads(channels, out_channels)`` is returned instead.

    Raises:
        ValueError: ``cfg.update_head_scalar_scale`` is enabled with an activation
            other than ``softplus``, ``relu`` or ``abs``.

    Side effects:
        Sets ``self.scale_act`` when ``cfg.update_head_scalar_scale`` is enabled.
    """
    update_head_activation_cls = get_activation_cls(self.cfg.update_head_act)
    final_head_activation_cls = get_activation_cls(self.cfg.update_head_final_act)

    # Skip connection to the image color.
    if self.cfg.update_head_concat_img:
        channels += 3 * (self.cfg.latent_downsample ** 2)

    # Feature A: per-parameter-group heads (early return; that builder creates its own container).
    if self.cfg.update_head_per_param_heads:
        assert not self.cfg.update_head_scale_mag, "update_head_scale_mag not supported with per_param_heads"
        assert not self.cfg.update_head_per_param_scales, "per_param_heads already includes per-group scales"
        return self._build_per_param_heads(channels, out_channels)

    # Resolve the scalar-scale configuration before any layer is created, so an
    # unsupported activation fails fast instead of after the head has been built.
    scale_act_name = None
    scale_init_bias = None
    num_scale_outputs = 0
    if self.cfg.update_head_scalar_scale:
        # Small initial scale so that gradients can still flow through the scale output.
        init_bias_map = {
            'softplus': -1,
            'relu': 1e-8,
            'abs': 1e-8,
        }
        scale_act_name = self.cfg.update_head_scalar_scale_act
        if scale_act_name not in init_bias_map:
            raise ValueError(f"Unsupported scalar_scale_act: {scale_act_name}")
        scale_init_bias = init_bias_map[scale_act_name]
        if self.cfg.update_head_per_param_scales:
            # Feature B: one scalar scale per parameter group.
            num_scale_outputs = len(self._per_param_group_dims)
        else:
            num_scale_outputs = 1

    # With scale_mag, every delta channel gets a second "magnitude" channel.
    if self.cfg.update_head_scale_mag:
        out_channels = out_channels * 2
    out_channels = out_channels + num_scale_outputs

    # Hidden width follows either the input width or the output width.
    # TODO (from the original author): the default should eventually be the
    # output width; "input" is kept to reproduce the rebuttal results.
    if self.cfg.update_head_hidden_dim_matches == "input":
        hidden_dim = channels  # rebuttal version
    else:
        hidden_dim = out_channels  # submitted version

    layers_list = [nn.Linear(channels, hidden_dim), update_head_activation_cls()]
    for _ in range(self.cfg.update_head_layer_num - 2):
        layers_list += [nn.Linear(hidden_dim, hidden_dim), update_head_activation_cls()]
    layers_list += [nn.Linear(hidden_dim, out_channels), final_head_activation_cls()]
    update_head = nn.Sequential(*layers_list)

    # Start from a zero delta: zero weights on the last linear layer, and a bias
    # chosen so the final activation outputs (approximately) nothing.
    last_linear = update_head[-2]
    nn.init.zeros_(last_linear.weight)
    if final_head_activation_cls == torch.nn.Sigmoid:
        desired_init_delta = 0.005
        # Inverse sigmoid of the desired output: logit(0.005) ~= -5.29.
        bias = math.log(desired_init_delta / (1 - desired_init_delta))
        nn.init.constant_(last_linear.bias, bias)
    else:
        nn.init.zeros_(last_linear.bias)

    if self.cfg.update_head_scalar_scale:
        # The scale outputs occupy the trailing channels of the last layer.
        # Guard against zero groups: bias[-0:] would select the whole bias.
        if num_scale_outputs > 0:
            nn.init.constant_(last_linear.bias[-num_scale_outputs:], scale_init_bias)

        act_class = get_activation_cls(scale_act_name)
        self.scale_act = act_class(beta=1) if scale_act_name == 'softplus' else act_class()

    return update_head
get_update_module(self, channels, in_channels): + if not self.cfg.debug_refine_update_module: + return None + + if self.cfg.refine_multi_scale_pt: + update_module = nn.Sequential( + PointLinearWrapper(in_channels, channels), + MultiScalePointTransformer(channels, + self.cfg.refine_knn_samples, + subsample_method=self.cfg.subsample_method, + attn_proj_channels=self.cfg.update_attn_proj_channels, + ) + ) + else: + update_module = nn.Sequential( + PointLinearWrapper(in_channels, channels), + PlainPointTransformer(channels, self.cfg.refine_knn_samples, + num_blocks=self.cfg.num_basic_refine_blocks, + qk_norm=self.cfg.pt_qk_norm, + norm_pt_block=self.cfg.norm_pt_block, + num_heads=self.cfg.pt_heads, + no_rpe=True, + no_attn=self.cfg.update_no_knn_attn, + no_norm=self.cfg.update_no_tran_block_norm, + act=self.cfg.update_tran_block_act, + attn_proj_channels=self.cfg.update_attn_proj_channels, + with_mv_attn=self.cfg.refine_with_mv_attn, + with_mv_attn_lowres=self.cfg.refine_with_mv_attn_lowres, + no_mv_attn=self.cfg.refine_no_mv_attn, + conv_with_norm=self.cfg.mv_attn_conv_with_norm, + mv_shuffle_attn=self.cfg.refine_mv_shuffle_attn, + with_pos_enc=self.cfg.refine_mv_attn_with_pos_enc, + shuffle_attn_no_norm=self.cfg.refine_shuffle_attn_no_norm, + mv_unimatch_attn=self.cfg.refine_mv_unimatch_attn, + use_checkpointing=self.cfg.use_checkpointing, + use_fused_attn=self.cfg.use_fused_attn, + knn_idx_update_every=self.cfg.knn_idx_update_every + ) + ) + + # Init normalization layers + if self.cfg.input_normalize_state: + for block in update_module[1].blocks: + nn.init.zeros_(block.norm1.bias) + nn.init.zeros_(block.norm2.bias) + nn.init.ones_(block.norm1.weight) + nn.init.ones_(block.norm2.weight) + + return update_module + + def get_state_scale_head(self, in_channels): + state_scale_head = nn.Sequential( + nn.Linear(in_channels, in_channels // 2), + nn.ReLU(), + nn.Linear(in_channels // 2, 1), + nn.ReLU() + ) + + # Init the scale to 1 + # 
def define_error_channels(self):
    """Work out how many channels the render-error input contributes.

    Returns:
        A tuple ``(error_channels, error_feature_channels)``:

        * Without ResNet error features, ``error_channels`` is the RGB error
          width (0 when ``cfg.no_render_error``, 3 when
          ``cfg.input_error_rgb_no_shuffle``, otherwise
          ``3 * cfg.latent_downsample ** 2``) and ``error_feature_channels``
          is fixed at 256.
        * With ``cfg.input_error_resnet_feature``, both values equal the summed
          width of the selected ResNet feature scales; the RGB width computed
          above is replaced, not added.

    Raises:
        NotImplementedError: ``cfg.input_error_resnet_feature_layers`` is not
            one of 18, 34 or 50 while ResNet error features are enabled.
    """
    cfg = self.cfg

    if cfg.no_render_error:
        error_channels = 0
    elif cfg.input_error_rgb_no_shuffle:
        error_channels = 3
    else:
        error_channels = 3 * cfg.latent_downsample ** 2

    if not cfg.input_error_resnet_feature:
        return error_channels, 256

    # 3 scales: 1/2, 1/4, 1/8, channels: 64, 64, 128
    depth = cfg.input_error_resnet_feature_layers
    if depth in (18, 34):
        # The shallow variant keeps only the first two scales.
        if cfg.input_error_shallow_resnet_feature:
            error_feature_channels = 64 + 64
        else:
            error_feature_channels = 64 + 64 + 128
    elif depth == 50:
        error_feature_channels = 64 + 256 + 512
    else:
        raise NotImplementedError(
            f"input_error_resnet_feature_layers={depth!r} is not supported (expected 18, 34 or 50)"
        )

    # The feature error replaces the RGB error as the input signal width.
    return error_feature_channels, error_feature_channels
def update_gaussians_for_window(self, v, h, w, optimizer_input):
    """Restrict the Gaussians on ``optimizer_input`` to the active view window.

    Reads the window description from ``optimizer_input.additional_info``
    (``local_window_update, test_window_size, window_end, window_start``).
    When a local window update is requested and ``cfg.local_gaussian_render``
    is enabled, ``optimizer_input.prev_output.gaussians`` is replaced in place
    by the result of ``select_gaussian_subset`` evaluated at latent resolution
    (``h`` and ``w`` floor-divided by ``cfg.latent_downsample``). Otherwise
    nothing is modified.

    Returns:
        None.
    """
    use_window, _window_size, win_end, win_start = optimizer_input.additional_info

    # Nothing to do unless both the window update and local rendering are on.
    if not (use_window and self.cfg.local_gaussian_render):
        return

    previous = optimizer_input.prev_output
    previous.gaussians = select_gaussian_subset(
        previous.gaussians, win_start, win_end,
        v=v,
        h=h // self.cfg.latent_downsample,
        w=w // self.cfg.latent_downsample,
    )
== 1, "Batch size > 1 not supported for post-processing" + + # Log number of gaussians + self.nr_gaussians_log.append( + optimizer_input.prev_output.gaussians.means.shape[1] + ) + + # One optimization step + res = self.apply_one_update_step( + i, optimizer_input, optimizer_output + ) + updated_gaussians: Gaussians = res[0] + state: Tensor = res[1] + meta_for_adc: dict = res[2] + updates: dict[str, Tensor] = res[3] + grads_raw: Tensor | None = res[4] + normalized_grads: Tensor | None = res[5] + scaled_state: Tensor | None = res[6] + gaussians_sel: Tensor | None = res[7] + + # Timing + self._record_iter_timing() + + # Log stats + if grads_raw is not None: + grads = grads_raw # [B, G, D] + nonzero_grads = (grads != 0).any(-1) # [B, G] + # Filter out strictly zero gradients for logging + grads = grads[nonzero_grads].unsqueeze(0) # [1, N_nonzero, D] + assert nonzero_grads.shape[0] == 1 + self.nr_nonzero_grad_log.append(nonzero_grads[0].sum().item()) + + # Local ADC + # if optimizer_output.t == 500: + # weight_vis_contribution, _ = get_visibility_contribution_from_gaussian_obj(iter_context, updated_gaussians) # [N] + # prune_mask = weight_vis_contribution < 5 + # print(f"Pruning {torch.sum(prune_mask)} gaussians out of {prune_mask.shape[0]} at iteration {i}") + # updated_gaussians = updated_gaussians[:, ~prune_mask] + # state = state[~prune_mask] + # if self.cfg.normalize_update_input and self.cfg.normalize_update_input_type == "adam": + # if not self.update_input_norm.is_reset(): + # self.update_input_norm.prune(prune_mask) + + # Densification and Pruning + if self.cfg.any_adc: + + n_before_adc = updated_gaussians.means.shape[1] + + # Prepare objects to adjust during ADC + object_dict = self.adc_object_dict_to_adjust + object_dict["depthsplat_state"] = KnnBasedOptimizerState(state) + object_dict["depthsplat_init_state"] = KnnBasedOptimizerState(optimizer_input.prev_output.state.init_state) + + # Apply ADC + self.apply_adc( + i=i, v=v, h=h, w=w, 
adc_state=optimizer_input.prev_output.state.adc_state, + gaussians=updated_gaussians, meta=meta_for_adc, object_dict_to_adjust=object_dict + ) + + # Update state after ADC + state = object_dict["depthsplat_state"].state + optimizer_input.prev_output.state.init_state = object_dict["depthsplat_init_state"].state + + del object_dict["depthsplat_state"] + if self.cfg.input_gradient_normalize and self.cfg.input_gradient_normalize_type == "adam": + self.update_input_norm.aggregate_from_subgroups(object_dict, self.param_slices) + + # If N changed (add_new grew the population), stale KNN caches in the + # point transformer modules would index out-of-bounds on the next forward + # pass → CUDA illegal memory access. Reset them so they are recomputed. + if updated_gaussians.means.shape[1] != n_before_adc: + self._reset_knn_caches() + + # Save updated gaussians and state + optimizer_input.prev_output.gaussians = updated_gaussians + optimizer_input.prev_output.state.state = state + + if self.cfg.input_gradient_normalize_type == "adam": + optimizer_input.prev_output.state.adam_state = self.update_input_norm.get_state() + + if self.training: + optimizer_output.gaussian_list.append(updated_gaussians) + + # Info + if not self.training and self.save_every(i + 1, tag="info"): + # TODO Naama: review and refactor + + # save guassians + optimizer_output.gaussian_list.append(updated_gaussians, detach_and_cpu=True, save_to_disk=False) + + # Save delta stats + assert optimizer_output.info is not None + + # log updates + + # unpack shs + shs = updates.pop("shs") # [1, N, 3*sh_d] + assert shs.shape[0] == 1, "Batch size > 1 not supported" + shs = shs.squeeze(0) # [N, 3*sh_d] + shs = rearrange(shs, "n (c x) -> n c x", c=3, x=self.cfg.sh_d) # [N, 3, sh_d] + updates["sh0s"] = shs[..., 0:1] + if self.cfg.sh_d > 1: + updates["shNs"] = shs[..., 1:] + else: + updates["shNs"] = None + + # log deltas + if "deltas" not in optimizer_output.info: + optimizer_output.info["deltas"] = [] + 
optimizer_output.info["deltas"].append( + {k: v.squeeze(0).cpu() if v is not None else None for k, v in updates.items()}) + + # Split each vector grad into gaussians components + if grads_raw is not None: + if gaussians_sel is not None: + # Restore the zero gradients for tracking + b, g_valid, d = grads_raw.shape + g = state.shape[0] + grads_raw_full = torch.zeros((b, g, d)) + normalized_grads_full = torch.zeros((b, g, d)) + grads_raw_full[:, gaussians_sel, :] = grads_raw.cpu() + normalized_grads_full[:, gaussians_sel, :] = normalized_grads.cpu() + grads_raw = grads_raw_full + normalized_grads = normalized_grads_full + + grads_raw: dict[str, Tensor] = split_grads(grads_raw.cpu(), self.cfg) + + # Split each vector normalized_grads into gaussians components + if normalized_grads is not None: + normalized_grads: dict[str, Tensor] = split_grads(normalized_grads.cpu(), self.cfg) + + assert grads_raw["means"].shape == normalized_grads["means"].shape, \ + f"Shape mismatch between grads and normalized_grads: {grads_raw['means'].shape} vs {normalized_grads['means'].shape}" + + # log states + if scaled_state is not None: + if "states_norms" not in optimizer_output.info: + optimizer_output.info["states_norms"] = [] + state_norm = torch.norm(scaled_state, dim=-1) # [B, N] + optimizer_output.info["states_norms"].append(state_norm.cpu()) + + # log gradients + if "grads" not in optimizer_output.info: + optimizer_output.info["grads"] = [] + optimizer_output.info["grads"].append(grads_raw) + + # log normalized gradients + if "normalized_grads" not in optimizer_output.info: + optimizer_output.info["normalized_grads"] = [] + optimizer_output.info["normalized_grads"].append(normalized_grads) + + # Check if output_path in kwargs + output_path = kwargs.get("output_path", None) + scene_name = kwargs.get("scene_name", None) + + if self.cfg.any_adc: + pass + # Plot stats + # self.plot_info(i, output_path=output_path, scene_name=scene_name) + + # Post-update context + target renders + 
self._save_post_update_renders( + i, optimizer_input, optimizer_output, updated_gaussians, + full_context, full_target, + ) + + # Optimizer output is being changed in place, but for clarity we return it + return optimizer_output + + def apply_one_update_step( + self, + i, + optimizer_input: OptimizerInput, + optimizer_output: OptimizerOutput + ) -> tuple[Gaussians, Tensor, dict, dict[str, Tensor], Tensor | None, Tensor | None, Tensor | None, Tensor | None]: + # Unpacking + context = optimizer_input.context + target = optimizer_input.target + renderer = optimizer_input.renderer + debug_dict = optimizer_input.debug_dict + num_refine = optimizer_input.num_refine + gaussians = optimizer_input.prev_output.gaussians # Gaussian object of [B, N, C] + state = optimizer_input.prev_output.state.state # [N, C] + init_state = optimizer_input.prev_output.state.init_state # [N, C] + local_window_update, test_window_size, window_end, window_start = optimizer_input.additional_info + # Get input signal for the optimizer model (erros/gradients) + self.decoder_event_start.record() + input_signal, gaussian_grads_raw, gaussian_grads, grad_sign, context_render_output, means2d_grads = ( + self.prepare_input_signal(context, i, gaussians, local_window_update, renderer, window_end, + window_start, num_refine) + ) + self.decoder_event_end.record() + + # Preparing meta for ADC + if means2d_grads is not None: + means2d_grads = means2d_grads.detach() # [B, V, N, 2] + meta_for_adc = { + "visibility_filter": context_render_output.visibility_filter.detach(), # [B, V, N] + "radii": context_render_output.radii.detach(), # [B, V, N, 1] + "means_2d_grads": means2d_grads, # [B, V, N, 2] + } + + # Handle zero gradient gaussians + # We either prune them, or exclude them from the input/output update + if self.cfg.update_only_nonzero_grad and gaussian_grads is not None: + gaussian_grads, gaussian_grads_raw, gaussians, grad_sign, init_state, input_signal, state = ( + self.handle_zero_grad_gaussians( + 
context, + context_render_output, + gaussian_grads, + gaussian_grads_raw, + gaussians, + grad_sign, + init_state, + input_signal, + means2d_grads, + meta_for_adc, + optimizer_input, + state) + ) + + # For training, if the number of active gaussians is too high, skip this batch + # TODO Naama: maybe sampling? + active_gaussians_num = state.shape[0] + if self.training: + if active_gaussians_num > 100_000: + print(f"Skipping batch at iteration {i} with {active_gaussians_num} active gaussians.") + raise SkipBatchException() + if active_gaussians_num < self.cfg.refine_knn_samples: + print( + f"Skipping batch at iteration {i} with only {active_gaussians_num} active gaussians (need >= {self.cfg.refine_knn_samples}).") + raise SkipBatchException() + + # Training only: save the rendering of initialization for logging + # Will not be used for loss calculation + # TODO Naama: this cause to many confusion. Pull it out of this function + if self.training and i == 0: + # Append context images initialization + assert context_render_output is not None + optimizer_output.context_render_list.append(context_render_output, detach_and_cpu=False) + + # render target images initialization + target_render_output = renderer.forward_batch_subset(gaussians, target) + optimizer_output.target_render_list.append(target_render_output, detach_and_cpu=False) + + # Unpack Gaussians + means, scales, rotations_unnorm, opacities_raw, shs = unpack_gaussians( + gaussians, + scales_log=self.cfg.opt_scales_before_act, + opacities_logit=True, + opacities_unsqueeze=True, + detach=True, # stop gradient of last predictions + scales_lims=(self.cfg.clamp_min_scale, self.cfg.clamp_refine_max_scale), + raw_opacities_lims=(self.cfg.clamp_min_raw_opacities, self.cfg.clamp_max_raw_opacities) + ) + + gaussians_concat = pack_gaussians(means, scales, rotations_unnorm, opacities_raw, shs) # [B, N, C] + + b, v, c, h, w = context["image"].shape + latent_h = h // self.cfg.latent_downsample + latent_w = w // 
self.cfg.latent_downsample + # Debugging reprojection error + if debug_dict is not None and (not self.training and self.save_every(i, tag="debug")): + if "reprojection_error" in debug_dict: + self.debug_reprojection_error(means, debug_dict, context, i, latent_h, latent_w) + + # prepare pt input + point_cloud, tmp_batch_size = self.get_point_cloud(latent_h, latent_w, local_window_update, means, + test_window_size, v) + # Create offset directly on device to avoid CPU-GPU transfer + offset = torch.arange(1, b + 1, device=state.device, dtype=torch.long) * tmp_batch_size + + # reshape + tmp_gaussian = self.reshape_gaussians_to_nc(latent_h, latent_w, gaussians_concat, v) # [B, N, C] --> [BN, C] + # add global attention to exchange info across views + if self.cfg.input_error_mv_attn: + input_signal = self.apply_global_attn(b, h, input_signal, latent_h, + latent_w, local_window_update, test_window_size, v, w) + + tmp_input_signal = input_signal.reshape(-1, + input_signal.shape[-1]) # [B, N, C] --> [BN, C] - faster than rearrange + tmp_input_signal = self.append_to_input_signal(b, context, context_render_output, tmp_input_signal, v) + + # Normalize state before input it to the update module + if self.cfg.input_normalize_state: + state_norm = state.norm(dim=1, keepdim=True) / math.sqrt(state.shape[-1]) # [BG, 1] + state = state / (state_norm + 1e-8) # [BG, C] + + normalized_input_signal = self.update_input_norm(tmp_input_signal) + + if self.cfg.input_normalize_gaussians: + tmp_gaussian_mean = tmp_gaussian.mean() + tmp_gaussian_std = tmp_gaussian.std() + tmp_gaussian = (tmp_gaussian - tmp_gaussian_mean) / (tmp_gaussian_std + 1e-8) + + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, dtype=torch.bfloat16): + point_cloud, tmp_gaussian, state, update_input = self.prepare_update_input(b, i, init_state, + normalized_input_signal, + latent_h, + latent_w, + local_window_update, + point_cloud, + tmp_gaussian, + # gradients/errors + additional pixel related 
quantities + state, v, window_end, + window_start) + + # if self.cfg.refine_with_mv_attn: + # state = concat + # for i in range(len(self.update_module)): + # print(i, len(self.update_module), self.update_module[i]) + # state = self.update_module[i]([point_cloud, state, offset]) # [N, C] + # else: + updated_state = self.apply_update_module(b, latent_h, latent_w, offset, + point_cloud, update_input, v, state, i) + + # Hard coded extract normalized gradients + if self.cfg.input_gradient and self.cfg.input_gradient_normalize: + normalized_grads = normalized_input_signal + else: + normalized_grads = None + + # Recover the state norm + if self.cfg.input_normalize_state: + # state = state * state_std + state_mean + updated_state = updated_state * state_norm + + # Predict a scale for the updtaed scale for the MLP deltas prediction + # The updated state for the next stage remains the same + if self.cfg.predict_state_scale: + state_scale = self.state_scale_head(update_input.detach()) + if self.cfg.predict_state_scale_norm: + # Normalize the state vector + state_scale = state_scale / (state_scale.norm(p=2, dim=1, keepdim=True) + 1e-8) + else: + state_scale = torch.tensor([1], device=state.device, dtype=state.dtype) + updated_state_scaled = state_scale * updated_state + + # optionally append time encodiing to normalize input + with TimeEncodingWrapper(self.cfg.use_time_encoding, + self.time_encoder_fn, + optimizer_output.t, + self.cfg.time_encoding_max_steps, + updated_state_scaled) as embedded_state: + if self.cfg.use_time_encoding: + assert not self.cfg.concat_init_state + assert not self.cfg.replace_init_state + + # delta gaussian head + delta_gaussians = self.apply_delta_gaussian_head(b, context, init_state, embedded_state, v) + + visibility_scale = None # disable for now + + delta_means, delta_opacities, delta_rotations, delta_scales, delta_shs, init_repeat, delta_gaussians = ( + self.postprocess_deltas(b, delta_gaussians, gaussian_grads, gaussians_concat, grad_sign, 
latent_h, latent_w, + local_window_update, normalized_grads, state, test_window_size, v, window_end, + window_start, optimizer_output.t, optimizer_output.T, visibility_scale) + ) + + means, opacities_raw, rotations_unnorm, scales, shs = self.repeat_gaussians(means, opacities_raw, + rotations_unnorm, scales, shs) + + covariances, means, scales, rotations, rotations_unnorm, opacities_raw, shs = self.update_gaussians_params( + delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs, + means, scales, rotations_unnorm, opacities_raw, shs, init_repeat) + + # Recover the state in non valid gaussians (and grad for logging) + if gaussians.sel is not None: + sel = gaussians.sel # [B, G] + full_state = optimizer_input.prev_output.state.state + + # Convert full state to the dtype of state + full_state = full_state.to(state.dtype) + # Use non-in-place index_put to avoid in-place modification of tensors + # in the autograd computation graph (fixes version mismatch errors with stability loss) + updated_state = full_state.index_put((sel,), updated_state) + else: + sel = None + + # update gaussians (only where mask is True) + # Use view instead of rearrange for speed + shs_reshaped = shs.view(shs.shape[0], shs.shape[1], 3, -1) + gaussians = gaussians.update_object_by_curr_mask( + means=means, + covariances=covariances, + harmonics=shs_reshaped, + opacities=opacities_raw.squeeze(-1).sigmoid(), + scales=scales, + rotations=rotations, + rotations_unnorm=rotations_unnorm, + sel=None, + deltas=delta_gaussians if self.training else None, + gradients=gaussian_grads_raw if self.training else None, + norm_gradients=normalized_grads.unsqueeze(0) if normalized_grads is not None and self.training else None + ) + + updates = { + "means": delta_means.detach(), + "scales": delta_scales.detach(), + "rotations": delta_rotations.detach(), + "opacities": delta_opacities.detach(), + "shs": delta_shs.detach() + } + + grads_raw = gaussian_grads.detach() if gaussian_grads is not None else 
None + grads_adam = normalized_grads.detach() if normalized_grads is not None else None + + return gaussians, updated_state, meta_for_adc, updates, grads_raw, grads_adam, updated_state_scaled, sel + + def postprocess_deltas(self, b, delta_gaussians, gaussian_grads, gaussians_concat, grad_sign, latent_h, latent_w, + local_window_update, normalized_grads, state, test_window_size, v, window_end, window_start, + t, T, visibility_scale): + # Updates for gradient input (scale, log scale, ) + delta_gaussians_raw = delta_gaussians + delta_gaussians = self.update_delta_for_gradients_input(delta_gaussians_raw, grad_sign, normalized_grads, + visibility_scale) + + # Rearrange back to [B, N, C] + delta_gaussians, delta_gaussians_raw = self.rearrange_delta_gaussians(b, delta_gaussians, + delta_gaussians_raw, latent_h, + latent_w, local_window_update, + gaussians_concat, + test_window_size, v, window_end, + window_start) + + # TODO Naama: shouldn't it be before rearranging? + # multiple gaussian heads to predict multiple gaussians + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, dtype=torch.bfloat16): + if self.cfg.gaussian_head_multiple > 1: + num_additional_heads = self.cfg.gaussian_head_multiple - 1 + delta_gaussian_list = [delta_gaussians] # list of [B, N, C] + for i in range(num_additional_heads): + curr_delta = self.update_head_list[i](state) + curr_delta = rearrange(curr_delta, "(b n) c -> b n c", b=b) + delta_gaussian_list.append(curr_delta) + delta_gaussians = torch.cat(delta_gaussian_list, dim=1) # [B, K*N, C] + + # Experimental overide deltas + if self.cfg.experimental_run: + self.experimental_update_deltas(delta_gaussians, gaussian_grads, normalized_grads) + + # Split + delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs, init_repeat = ( + self.split_delta_gaussians(delta_gaussians) + ) + + # Apply lr + delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs = self.scale_deltas_with_lr( + t, delta_means, 
def handle_zero_grad_gaussians(self, context, context_render_output, gaussian_grads, gaussian_grads_raw, gaussians,
                               grad_sign, init_state, input_signal, means2d_grads, meta_for_adc, optimizer_input,
                               state):
    """Exclude Gaussians whose gradient is exactly zero from the update inputs.

    Two modes, selected by ``cfg.prune_invisible_gaussians``:

    * Pruning: delegates to ``self.prune_invisible_gaussians`` and forwards its
      five results; ``gaussian_grads_raw`` and ``init_state`` are returned as
      they were received.
    * Masking (default): keeps every Gaussian but slices the per-Gaussian
      inputs down to those with at least one non-zero gradient channel. The
      kept indices are stored on ``gaussians.sel`` (``None`` when every
      Gaussian is valid) and, for Adam-style input normalisation, on
      ``self.update_input_norm.sel``.

    Returns:
        ``(gaussian_grads, gaussian_grads_raw, gaussians, grad_sign, init_state,
        input_signal, state)``.

    Raises:
        AssertionError: in masking mode, when local pruning options are enabled
            or the gradient batch dimension is not 1.
    """
    if self.cfg.prune_invisible_gaussians:
        # NOTE(review): the pruning helper returns five values, so
        # gaussian_grads_raw and init_state are not updated here — confirm the
        # helper handles them (e.g. through optimizer_input).
        gaussian_grads, gaussians, grad_sign, input_signal, state = self.prune_invisible_gaussians(context,
                                                                                                   context_render_output,
                                                                                                   gaussian_grads,
                                                                                                   gaussian_grads_raw,
                                                                                                   gaussians,
                                                                                                   grad_sign,
                                                                                                   input_signal,
                                                                                                   means2d_grads,
                                                                                                   meta_for_adc,
                                                                                                   optimizer_input,
                                                                                                   state)
        return gaussian_grads, gaussian_grads_raw, gaussians, grad_sign, init_state, input_signal, state

    # Masking mode: Gaussians that did not contribute to any context pixel have
    # strictly zero gradients. They are kept (they may matter in other views)
    # but left out of this update step.
    assert not self.cfg.local_prune_zero_radii
    assert not self.cfg.local_prune_low_weights
    assert gaussian_grads.shape[0] == 1, "Batch size > 1 not supported with mask"

    # (Disabled debugging checks that compared zero-radius Gaussians against
    # the zero-gradient mask used to live here.)

    # Compute the [G] mask without materialising a [B, G, D] bool tensor;
    # any() on floats treats non-zero entries as True.
    valid_g = gaussian_grads[0].any(dim=-1)  # [G] bool
    sel = None

    # If everything is valid, skip all slicing work and leave sel as None.
    if not valid_g.all():
        sel = valid_g.nonzero(as_tuple=True)[0]  # [G_valid]

        input_signal = input_signal[:, sel, :]  # [B, G_valid, C]

        gaussian_grads = gaussian_grads[:, sel, :]  # [B, G_valid, D]
        if gaussian_grads_raw is not None:
            gaussian_grads_raw = gaussian_grads_raw[:, sel, :]
        if grad_sign is not None:
            grad_sign = grad_sign[:, sel, :]

        state = state[sel, :]  # [G_valid, C]
        init_state = init_state[sel, :]  # [G_valid, C]

    # Record the selection so later stages can scatter results back.
    gaussians.sel = sel
    if self.cfg.input_gradient_normalize_type == "adam":
        self.update_input_norm.sel = sel

    return gaussian_grads, gaussian_grads_raw, gaussians, grad_sign, init_state, input_signal, state
visible_mask is None: + return gaussian_grads, gaussians, grad_sign, input_signal, state + assert visible_mask.shape[0] == 1 + visible_mask = visible_mask[0, :, 0] # [N], squeeze batch and last dim + # Apply mask + gaussians = gaussians[:, visible_mask] + state = state[visible_mask] + input_signal = input_signal[:, visible_mask] # [B, N, C] + if gaussian_grads is not None: + gaussian_grads = gaussian_grads[:, visible_mask] # [B, N, C] + if gaussian_grads_raw is not None: + gaussian_grads_raw = gaussian_grads_raw[:, visible_mask] # [B, N, C] + if grad_sign is not None: + grad_sign = grad_sign[:, visible_mask] # [B, N, C] + meta_for_adc["visibility_filter"] = context_render_output.visibility_filter[:, :, visible_mask] + meta_for_adc["radii"] = context_render_output.radii[:, :, visible_mask] + if means2d_grads is not None: + meta_for_adc["means_2d_grads"] = means2d_grads[:, :, visible_mask] + if self.cfg.input_gradient_normalize and self.cfg.input_gradient_normalize_type == "adam": + if not self.update_input_norm.is_reset(): + prune_mask = ~visible_mask + self.update_input_norm.prune(prune_mask) # the prune fn will invert the mask again + if self.cfg.any_adc: + optimizer_input.prev_output.state.adc_state.external_pruning(visible_mask) + return gaussian_grads, gaussians, grad_sign, input_signal, state + + def deactivate_updates(self, subset, gaussians, radii_vis_mask, deltas, gaussian_grads): + """ Deactivate updates for gaussians that are not visible in any view """ + visible_mask = self.get_visible_gaussian_mask(gaussian_grads, gaussians, radii_vis_mask, subset) + deltas = deltas * visible_mask # [B, N, C] + return deltas + + def get_visible_gaussian_mask(self, gaussian_grads, gaussians, radii_vis_mask, subset): + """ + Get mask for gaussians that are visible in at least one view. + + We calculate two criteria: + 1. Whether the projected 2d radius is visible in at least one view. + 2. Whether the gaussian has a non-zero weight contribution to the rendering. 
+ + If neither pruning criterion is enabled, returns None. + + Args: + gaussian_grads: [B, N, C] or None + gaussians: Gaussians object + radii_vis_mask: [B, V, N], bool + subset: dict, context or target + """ + # If no pruning criteria are active, return None + if not (self.cfg.local_prune_zero_radii or self.cfg.local_prune_low_weights): + return None + + b, v, n = radii_vis_mask.shape + + # Criterion 1: Projected radius visibility + if self.cfg.local_prune_zero_radii: + radii_vis_mask = radii_vis_mask.any(dim=1).unsqueeze(-1) # [B, N, 1] + else: + radii_vis_mask = torch.ones((b, n, 1), dtype=torch.bool, device=radii_vis_mask.device) + + # Criterion 2: Weight contribution visibility + if self.cfg.local_prune_low_weights: + threshold = self.cfg.local_prune_low_weights_thresh + weight_vis_contribution, _ = get_visibility_contribution_from_gaussian_obj(subset, gaussians) # [N] + weight_cont_mask = (weight_vis_contribution > threshold).view(1, -1, 1) + else: + weight_cont_mask = torch.ones((b, n, 1), dtype=torch.bool, device=radii_vis_mask.device) + + visible_mask = radii_vis_mask & weight_cont_mask # [B, N, 1] + return visible_mask + + def experimental_inplace_update_delta(self, deltas, grads, normalized_grads, cfg_attr): + # Slicing of the gradients vector per parameter + param_num = grads.shape[-1] + assert param_num == 11 + self.cfg.sh_d * 3 + param_slices = self.param_slices + + update = getattr(self.cfg.experimental_update, cfg_attr) + if update: + # Update this parameter + use_norm_grad = getattr(self.cfg.experimental_use_norm_grads, cfg_attr) + use_grad = self.cfg.experimental_use_grads and not use_norm_grad + use_resplat = not use_grad and not use_norm_grad + assert not (use_grad and use_norm_grad) + if use_grad: + # Use the inverse of the gradients + # TODO Naama: hard coded learning rate for SGD + deltas[..., param_slices[cfg_attr]] = -(grads[..., param_slices[cfg_attr]]).to(deltas.dtype) * 30 + elif use_norm_grad: + # Use the inverse of the normalized 
gradients + updated_delta = -normalized_grads[..., param_slices[cfg_attr]] * getattr(self.cfg.experimental_lr, + cfg_attr) + deltas[..., param_slices[cfg_attr]] = updated_delta.to(deltas.dtype) + else: + # Use the network prediction (already negated before) + pass + else: + # Do not update this parameter + deltas[..., param_slices[cfg_attr]] = 0 + + def experimental_update_deltas(self, deltas, grads, normalized_grads): + # Verify that at least one parameter is actually using norm_grads or grads override + any_norm_grad = any( + getattr(self.cfg.experimental_use_norm_grads, p) for p in self.cfg.experimental_update.param_names) + any_grad = self.cfg.experimental_use_grads + any_override = any_norm_grad or any_grad + assert any_override, ( + "experimental_run=true but no parameter has use_norm_grads or use_grads enabled. " + "Check that experimental_use_norm_grads._base=true (it gates all other fields via property)." + ) + if any_norm_grad: + assert normalized_grads is not None, ( + "experimental_use_norm_grads is enabled but normalized_grads is None. " + "Ensure input_gradient=true and input_gradient_normalize=true." 
+ ) + + for p in self.cfg.experimental_update.param_names: + self.experimental_inplace_update_delta(deltas, grads, normalized_grads, p) + + def scale_deltas_with_lr(self, t, delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs): + # Scale deltas with learning rates + delta_means = delta_means * self.scheduler.get_lr(t, "means") + delta_scales = delta_scales * self.scheduler.get_lr(t, "scales") + if delta_rotations is not None: + delta_rotations = delta_rotations * self.scheduler.get_lr(t, "rotations") + delta_opacities = delta_opacities * self.scheduler.get_lr(t, "opacities") + + # Use view instead of rearrange for speed + delta_shs = delta_shs.view(delta_shs.shape[0], delta_shs.shape[1], 3, -1) # [b, g, 3, c] + delta_sh0 = delta_shs[..., 0] # [B, N, C] + delta_shN = delta_shs[..., 1:] + delta_sh0 = delta_sh0 * self.scheduler.get_lr(t, "sh0") + delta_shN = delta_shN * self.scheduler.get_lr(t, "shN") + delta_shs = torch.cat((delta_sh0.unsqueeze(-1), delta_shN), dim=-1) + delta_shs = delta_shs.flatten(-2) # [b, g, d*c] - faster than rearrange + return delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs + + def append_to_input_signal(self, b, context, context_render, tmp_input_signal, v): + if self.cfg.input_alpha: + render_alpha = rearrange(context_render.accumulated_alpha, "b v h w -> (b v) () h w") + render_alpha = F.pixel_unshuffle(render_alpha, downscale_factor=self.cfg.latent_downsample) + render_alpha = rearrange(render_alpha, "(b v) c h w -> (b v h w) c", b=b, v=v) + tmp_input_signal = torch.cat((tmp_input_signal, render_alpha), dim=-1) + if self.cfg.input_depth: + render_depth = rearrange(context_render.depth, "b v h w -> (b v) () h w") + render_depth = F.pixel_unshuffle(render_depth, downscale_factor=self.cfg.latent_downsample) + render_depth = rearrange(render_depth, "(b v) c h w -> (b v h w) c", b=b, v=v) + tmp_input_signal = torch.cat((tmp_input_signal, render_depth), dim=-1) + if self.cfg.input_depth_smooth_error: + 
disp = 1. / context_render.depth.clamp(min=1e-3, max=1000.) # [B, V, H, W] + disp = rearrange(disp, "b v h w -> (b v) () h w") + + mean_disp = disp.mean(2, True).mean(3, True) + norm_disp = disp / (mean_disp + 1e-7) + + tmp_imgs = rearrange(context["image"], "b v c h w -> (b v) c h w") + + depth_smooth_error = get_smooth_loss(norm_disp, tmp_imgs, no_mean=True) + + depth_smooth_error = F.pixel_unshuffle(depth_smooth_error, downscale_factor=self.cfg.latent_downsample) + depth_smooth_error = rearrange(depth_smooth_error, "(b v) c h w -> (b v h w) c", b=b, v=v) + tmp_input_signal = torch.cat((tmp_input_signal, depth_smooth_error), dim=-1) + return tmp_input_signal + + def repeat_gaussians(self, prev_means, prev_opacities_raw, prev_rotations_unnorm, prev_scales, prev_shs): + if self.cfg.gaussian_head_multiple > 1: + # predict multiple gaussians for each point + prev_means = prev_means.repeat(1, self.cfg.gaussian_head_multiple, 1) + prev_scales = prev_scales.repeat(1, self.cfg.gaussian_head_multiple, 1) + prev_rotations_unnorm = prev_rotations_unnorm.repeat(1, self.cfg.gaussian_head_multiple, 1) + prev_opacities_raw = prev_opacities_raw.repeat(1, self.cfg.gaussian_head_multiple, + 1) / self.cfg.gaussian_head_multiple # smaller opacities, important + prev_shs = prev_shs.repeat(1, self.cfg.gaussian_head_multiple, 1) + # NOTE: only repeat at the first iteration + refine_repeat = self.cfg.refine_gaussian_multiple + if refine_repeat > 1: + # predict multiple gaussians for each point + prev_means = prev_means.repeat(1, refine_repeat, 1) + prev_scales = prev_scales.repeat(1, refine_repeat, 1) + prev_rotations_unnorm = prev_rotations_unnorm.repeat(1, refine_repeat, 1) + prev_opacities_raw = prev_opacities_raw.repeat(1, refine_repeat, 1) # smaller opacities, important + prev_shs = prev_shs.repeat(1, refine_repeat, 1) + return prev_means, prev_opacities_raw, prev_rotations_unnorm, prev_scales, prev_shs + + def split_delta_gaussians(self, delta_gaussians): + delta_rotations = None 
+ + if self.cfg.init_gaussian_multiple > 1 and not self.cfg.refine_same_num_points: + init_repeat = self.cfg.init_gaussian_multiple + else: + init_repeat = 1 + p = get_gaussian_param_sizes(self.cfg.sh_d) + if self.cfg.refine_sh_only: + delta_shs = delta_gaussians + delta_means = delta_scales = delta_opacities = 0. + elif self.cfg.no_refine_rotation: + delta_means, delta_scales, delta_opacities, delta_shs = delta_gaussians.split( + (p["means"] * init_repeat, p["scales"] * init_repeat, p["opacities"] * init_repeat, + p["shs"] * init_repeat), dim=-1 + ) + elif self.cfg.no_refine_mean: + delta_scales, delta_rotations, delta_opacities, delta_shs = delta_gaussians.split( + (p["scales"] * init_repeat, p["quats"] * init_repeat, p["opacities"] * init_repeat, + p["shs"] * init_repeat), dim=-1 + ) + delta_means = torch.zeros_like(delta_scales) + else: + delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs = delta_gaussians.split( + (p["means"] * init_repeat, p["scales"] * init_repeat, p["quats"] * init_repeat, + p["opacities"] * init_repeat, p["shs"] * init_repeat), dim=-1 + ) + if ( + self.cfg.refine_gaussian_multiple > 1 or self.cfg.init_gaussian_multiple > 1) and not self.cfg.refine_same_num_points: + delta_means = rearrange(delta_means, "b n (c k) -> b (n k) c", k=init_repeat) + delta_scales = rearrange(delta_scales, "b n (c k) -> b (n k) c", k=init_repeat) + delta_rotations = rearrange(delta_rotations, "b n (c k) -> b (n k) c", k=init_repeat) + delta_opacities = rearrange(delta_opacities, "b n (c k) -> b (n k) c", k=init_repeat) + delta_shs = rearrange(delta_shs, "b n (c k) -> b (n k) c", k=init_repeat) + return delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs, init_repeat + + def rearrange_delta_gaussians(self, b, delta_gaussians, delta_gaussians_raw, latent_h, latent_w, + local_window_update, prev_gaussians_concat, test_window_size, v, window_end, + window_start): + # [BV, C] + # update gaussian parameters + delta_gaussians = 
rearrange(delta_gaussians, "(b n) c -> b n c", b=b) + delta_gaussians_raw = rearrange(delta_gaussians_raw, "(b n) c -> b n c", b=b) + if local_window_update and not self.cfg.local_gaussian_render: + # zero padding for non-updated gaussians + # curr_v = self.cfg.update_window_size if self.training else test_window_size + curr_v = test_window_size + tmp_delta = rearrange(delta_gaussians, "b (v h w) c -> b v h w c", b=b, v=curr_v, h=latent_h, + w=latent_w) + + all_delta = [] + # padding + if window_start > 0: + tmp_size = rearrange(prev_gaussians_concat, "b (v h w) c -> b v h w c", b=b, v=v, h=latent_h, + w=latent_w) + pad_left = torch.zeros_like(tmp_size[:, :window_start, :, :, :], requires_grad=False) + all_delta.append(pad_left) + + all_delta.append(tmp_delta) + + if window_end < v: + tmp_size = rearrange(prev_gaussians_concat, "b (v h w) c -> b v h w c", b=b, v=v, h=latent_h, + w=latent_w) + pad_right = torch.zeros_like(tmp_size[:, window_end:, :, :, :], requires_grad=False) + all_delta.append(pad_right) + + tmp_delta = torch.cat(all_delta, dim=1) # [B, V, H, W, C] + delta_gaussians = rearrange(tmp_delta, "b v h w c -> b (v h w) c") # [B, N, C] + return delta_gaussians, delta_gaussians_raw + + def update_gaussians_params(self, delta_means, delta_scales, delta_rotations, delta_opacities, delta_shs, + means, scales, rotations_unnorm, opacities_raw, shs, + repeat): + means = self.update_means(delta_means, means) + + # clamp the scale + scales = self.update_scales(delta_scales, scales, repeat) + if self.cfg.opt_scales_before_act: + scales = scales.exp() + + if not self.cfg.no_refine_rotation: + rotations, rotations_unnorm = self.update_rotations(delta_rotations, rotations_unnorm) + else: + rotations = F.normalize(rotations_unnorm, dim=-1) + + # compute covariance + covariances = build_covariance(scales, rotations) # ([1, VHW, 3, 3]) + + opacities_raw = self.update_opacities(delta_opacities, opacities_raw, repeat) + shs = self.update_shs(delta_shs, shs) + return 
covariances, means, scales, rotations, rotations_unnorm, opacities_raw, shs + + def update_shs(self, delta_shs, prev_shs): + shs = prev_shs + delta_shs # [B, N, 3*sh_d] + + if self.cfg.clamp_shs_soft: + assert self.cfg.clamp_min_shs == -self.cfg.clamp_max_shs, "For soft clamp, min and max should be symmetric around 0" + shs = torch.tanh(shs / self.cfg.clamp_max_shs) * self.cfg.clamp_max_shs + else: + shs = shs.clamp(min=self.cfg.clamp_min_shs, max=self.cfg.clamp_max_shs) + + return shs + + def update_opacities(self, delta_opacities, prev_opacities_raw, repeat): + # update init opacities when predicting multiple gaussians + if repeat > 1 and not self.cfg.multi_gaussian_scale_smaller and (self.cfg.init_gaussian_multiple == 1): + # Given y = sigmoid(x), to get new x' such that sigmoid(x') = y / K: + # The formula is: x' = x + log((1 - y) / (K - y)) + # This adjusts x so that the sigmoid output is scaled down by a factor of K + tmp_sigmoid = prev_opacities_raw.sigmoid() + prev_opacities_raw = prev_opacities_raw + torch.log( + (1 - tmp_sigmoid) / (repeat - tmp_sigmoid)) + delta_opacities + else: + prev_opacities_raw = prev_opacities_raw + delta_opacities + # prev_opacities_raw = prev_opacities_raw.clamp(min=-5, max=5) + return prev_opacities_raw + + @staticmethod + def update_rotations(delta_rotations, prev_rotations_unnorm): + assert delta_rotations is not None + prev_rotations_unnorm = prev_rotations_unnorm + delta_rotations + # normazlie + prev_rotations = prev_rotations_unnorm / (prev_rotations_unnorm.norm(dim=-1, keepdim=True) + 1e-8) + return prev_rotations, prev_rotations_unnorm + + def update_scales(self, delta_scales, prev_scales, repeat): + if repeat > 1 and self.cfg.multi_gaussian_scale_smaller: + # smaller initial scales + new_scales = (prev_scales / repeat + delta_scales).clamp(min=self.cfg.gaussian_adapter.clamp_min_scale) + else: + new_scales = (prev_scales + delta_scales) + + if self.cfg.opt_scales_before_act: + min_scale = self.cfg.clamp_min_raw_scales 
+ max_scale = self.cfg.clamp_max_raw_scales + else: + min_scale = self.cfg.clamp_min_scale + max_scale = self.cfg.clamp_refine_max_scale + + new_scales = new_scales.clamp(min=min_scale) + new_scales = new_scales.clamp(max=max_scale) + + return new_scales + + @staticmethod + def update_means(delta_means, prev_means): + prev_means = (prev_means + delta_means) + return prev_means + + def _on_scene_start_impl(self, optimizer_input: OptimizerInput) -> None: + # Reset the state + if isinstance(optimizer_input.prev_output, InitializerOutput): # New scene + from_init = True + # Reset the optimizer state for a new scene + # We cannot just use super().on_scene_start() because we need to process the InitializerOutput in case it + # contain conditioning features + self.reset_logs() + + if self.cfg.input_gradient_normalize_type == "adam": + self.update_input_norm.reset() + nr_gaussians = rearrange(optimizer_input.prev_output.gaussians.means, "b n c -> (b n) c").shape[0] + param_num = self.gaussian_param_num + self.update_input_norm.initialize(shape=(nr_gaussians, param_num), + device=optimizer_input.prev_output.gaussians.means.device) + + # make sure xyz are contiguous + optimizer_input.prev_output.gaussians.means = optimizer_input.prev_output.gaussians.means.contiguous() + elif isinstance(optimizer_input.prev_output, OptimizerPreviousOutput): + from_init = False + if self.cfg.input_gradient_normalize_type == "adam": + # Continuing previous optimization from replay buffer + self.update_input_norm.update_state(optimizer_input.prev_output.state.adam_state) + + # TODO Naama: logs are not handled right now for continuing from replay buffer + self.reset_logs() + else: + raise ValueError(f"Unknown prev_output type {type(optimizer_input.prev_output)}") + + # Preparing the input for a new scene (will handle both new scene and continuing from replay buffer) + # Will convert init_output to prev_output internally + self.optimizer_preprocessing(optimizer_input, from_init=from_init) + + # 
initialize adc state, after converting to prev_output + if from_init and self.cfg.any_adc: + self.initialize_adc_state(self.cfg, optimizer_input) + + def reshape_gaussians_to_nc(self, latent_h, latent_w, prev_gaussians_concat, v): + if self.cfg.init_gaussian_multiple == 4 and not self.cfg.refine_same_num_points: + # gaussians are with more points, reshape + tmp_gaussian = rearrange(prev_gaussians_concat, "b (v h x w y) c -> (b v h w) (c x y)", + v=v, h=latent_h, w=latent_w, x=2, y=2) + elif self.cfg.init_gaussian_multiple == 16 and not self.cfg.refine_same_num_points: + tmp_gaussian = rearrange(prev_gaussians_concat, "b (v h x w y) c -> (b v h w) (c x y)", + v=v, h=latent_h, w=latent_w, x=4, y=4) + else: + tmp_gaussian = rearrange(prev_gaussians_concat, "b n c -> (b n) c") + return tmp_gaussian + + def get_point_cloud(self, latent_h, latent_w, local_window_update, prev_means, test_window_size, v): + # TODO: when the initial model predicts multiple gaussians, the number of points also increases + if self.cfg.init_gaussian_multiple == 4 and not self.cfg.refine_same_num_points: + point_cloud = rearrange(prev_means, "b (v h w) c -> b v h w c", + v=v, h=latent_h * 2, w=latent_w * 2, + ) + tmp_batch_size = v * latent_h * latent_w + # simply use uniform grid subsample of point cloud to reduce points + point_cloud = point_cloud[:, :, ::2, ::2] + point_cloud = rearrange(point_cloud, "b v h w c -> (b v h w) c") + elif self.cfg.init_gaussian_multiple == 16 and not self.cfg.refine_same_num_points: + point_cloud = rearrange(prev_means, "b (v h w) c -> b v h w c", + v=v, h=latent_h * 4, w=latent_w * 4, + ) + tmp_batch_size = v * latent_h * latent_w + # simply use uniform grid subsample of point cloud to reduce points + point_cloud = point_cloud[:, :, ::4, ::4] + point_cloud = rearrange(point_cloud, "b v h w c -> (b v h w) c") + else: + point_cloud = rearrange(prev_means, "b n c -> (b n) c") + if local_window_update: + tmp_batch_size = test_window_size * latent_h * latent_w + 
else: + tmp_batch_size = prev_means.shape[1] + return point_cloud, tmp_batch_size + + def get_vector_state(self, b, v, n, optimizer_input, from_init): + if from_init: + # Starting a new scene directly from the initializer + # State should not be provided + # Create initial state + # optimizer_input.prev_output is of type InitializerOutput + if optimizer_input.prev_output.features is None or self.cfg.init_state_wo_features: + # Creating state without initializer features + assert self.cfg.init_state_wo_features + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, dtype=torch.bfloat16): + dtype = torch.get_autocast_dtype('cuda') + if self.cfg.init_state_type == "constant": + state = torch.ones((b, n, self.cfg.state_channels), device=self.device, dtype=dtype) + elif self.cfg.init_state_type == "random": + state = torch.randn((b, n, self.cfg.state_channels), device=self.device, dtype=dtype) + else: + raise ValueError(f"Unknown init_state_type {self.cfg.init_state_type}") + state = state * self.cfg.init_state_scale + else: + # Calculating state from initializer features + state = self.get_state_from_condition_features(b, optimizer_input.prev_output.features, + v) # [B, N, C] + + else: + # Restarting optimizing a scene from a replay buffer + state = optimizer_input.prev_output.state.state + # TODO Naama: need to understand why rearrange here, perhaps something with pruning + state = rearrange(state, "(b n) c -> b n c", b=b) + + # combine gaussians of all scnes in the batch [B*N, C] + state = rearrange(state, "b n c -> (b n) c") # [B*N, C] + + # Do something with window size + _, _, _, h, w = optimizer_input.context["image"].shape # [B, V, C, H, W] + local_window_update, test_window_size, window_end, window_start = optimizer_input.additional_info + # select initial state + if local_window_update and self.cfg.local_gaussian_render: + state = rearrange(state, "(b v h w) c -> b v h w c", b=b, v=v, + h=h // self.cfg.latent_downsample, + w=w // 
self.cfg.latent_downsample) + state = state[:, window_start:window_end, :, :, :] + state = rearrange(state, "b v h w c -> (b v h w) c") + + return state + + @staticmethod + def _align_features(features, latent_h: int, latent_w: int) -> list: + """Resize each feature map to (latent_h, latent_w) if needed and return as a list.""" + out = [] + vals = features.values() if isinstance(features, dict) else features + for feat in vals: + if feat.shape[-2:] != (latent_h, latent_w): + feat = F.interpolate(feat, size=(latent_h, latent_w), mode='bilinear', align_corners=True) + out.append(feat) + return out + + def _get_latent_size(self, h: int, w: int) -> tuple[int, int]: + """Compute latent (H, W) from image (H, W), accounting for init_gaussian_multiple upsampling.""" + latent_h = h // self.cfg.latent_downsample + latent_w = w // self.cfg.latent_downsample + if self.cfg.init_gaussian_multiple == 4 and self.cfg.refine_same_num_points: + latent_h *= 2 + latent_w *= 2 + elif self.cfg.init_gaussian_multiple == 16 and self.cfg.refine_same_num_points: + latent_h *= 4 + latent_w *= 4 + return latent_h, latent_w + + def render_input_views_for_error_calc(self, context, + local_window_update, + prev_gaussians, + renderer, + window_end, + window_start, + num_refine, + i): + _, _, _, h, w = context["image"].shape # [B, V, C, H, W] + + render_res = (h, w) + + # Default rendering parameters + input_info = context + start = None + end = None + cfg = self.cfg + + # Use only first N views + if cfg.input_error_num_views > 0: + end = cfg.input_error_num_views + + # Local window update logic + elif local_window_update: + if i >= num_refine - 1: + return None # Skip rendering on the last iteration + start = window_start + end = window_end + + # Final unified rendering call + return renderer.forward_batch_subset( + prev_gaussians, + input_info, + render_res, + start=start, + end=end, + return_radii=False + ) + + def get_state_from_condition_features(self, b, condition_features, v): + with 
torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, dtype=torch.bfloat16): + if not self.cfg.pt_update_amp and condition_features.dtype == torch.bfloat16: + condition_features = condition_features.float() + state = self.update_proj(condition_features.detach()) # [B, C, H, W] + if self.cfg.init_gaussian_multiple == 4 and self.cfg.refine_same_num_points: + state = F.interpolate(state, scale_factor=2, mode='bilinear', align_corners=True) + elif self.cfg.init_gaussian_multiple == 16 and self.cfg.refine_same_num_points: + state = F.interpolate(state, scale_factor=4, mode='bilinear', align_corners=True) + else: + pass + # Convert to vector of Gaussians per batch [B, N, C] + state = rearrange(state, "(b v) c h w -> b (v h w) c", b=b, v=v) # N = v * h * w + return state + + def get_window_size(self, v): + test_window_size = None + if self.cfg.update_window_size > 0: + + local_window_update = True + # if self.training: + # window_start = random.randint(0, v - self.cfg.update_window_size) + # window_end = window_start + self.cfg.update_window_size + # else: + # fixed window at test time, uniformly move from left to right + # TODO: loop closure, connect left and right + if self.training: + test_window_size = self.cfg.update_window_size + window_start = random.randint(0, test_window_size) + window_end = window_start + test_window_size + + if window_end == v: + # restart + window_start = random.randint(0, test_window_size) + window_end = window_start + test_window_size + else: + # at least do a full pass of all input views + # test_window_size = int(np.ceil(v / self.cfg.num_refine)) + test_window_size = self.cfg.update_window_size + window_start = 0 + window_end = window_start + test_window_size + + else: + local_window_update = False + window_start = 0 + window_end = v + return local_window_update, test_window_size, window_end, window_start + + def prepare_update_input(self, b, i, init_state, input_signal, latent_h, latent_w, local_window_update, 
point_cloud, + tmp_gaussian, state, v, window_end, window_start): + if self.cfg.replace_init_state: + state = init_state + + if self.cfg.no_render_error: + update_input = torch.cat((tmp_gaussian, state), dim=-1) + else: + if local_window_update and not self.cfg.local_gaussian_render: + # select local window + tmp_gaussian = rearrange(tmp_gaussian, "(b v h w) c -> b v h w c", b=b, v=v, h=latent_h, + w=latent_w) + tmp_gaussian = tmp_gaussian[:, window_start:window_end, :, :, :] + tmp_gaussian = rearrange(tmp_gaussian, "b v h w c -> (b v h w) c") + + if i == 0: + state = rearrange(state, "(b v h w) c -> b v h w c", b=b, v=v, h=latent_h, + w=latent_w) + state = state[:, window_start:window_end, :, :, :] + state = rearrange(state, "b v h w c -> (b v h w) c") + + # local point cloud + point_cloud = rearrange(point_cloud, "(b v h w) c -> b v h w c", b=b, v=v, h=latent_h, + w=latent_w) + point_cloud = point_cloud[:, window_start:window_end, :, :, :] + point_cloud = rearrange(point_cloud, "b v h w c -> (b v h w) c") + + update_input = torch.cat((tmp_gaussian, state, input_signal), dim=-1) + if self.cfg.concat_init_state: + update_input = torch.cat((update_input, init_state), dim=-1) + return point_cloud, tmp_gaussian, state, update_input + + def apply_update_module(self, b, latent_h, latent_w, offset, point_cloud, update_input, v, state, iter): + + def recurrent_chunk(update_input, point_cloud, offset): + pxo = self.update_module[0]([point_cloud, update_input, offset]) + state = self.update_module[1](pxo, iter=iter, b=b, v=v, h=latent_h, w=latent_w) + return state + + if self.cfg.use_checkpointing or self.cfg.recurrent_use_checkpointing: + new_state = torch.utils.checkpoint.checkpoint( + recurrent_chunk, + update_input, point_cloud, offset, + use_reentrant=False, + ) + else: + new_state = recurrent_chunk(update_input, point_cloud, offset) + + if self.cfg.residual_state: + new_state = new_state + state + return new_state + + def apply_delta_gaussian_head(self, b, context, 
init_state, state, v): + if self.cfg.update_head_concat_img: + # pixel unshuffle image + img_unshuffle = rearrange(context["image"], "b v c h w -> (b v) c h w") + img_unshuffle = F.pixel_unshuffle(img_unshuffle, downscale_factor=self.cfg.latent_downsample) + img_unshuffle = rearrange(img_unshuffle, "(b v) c h w -> (b v h w) c", b=b, v=v) + head_input = torch.cat((state, img_unshuffle), dim=-1) + + else: + if self.cfg.refine_residual_init_state: + head_input = state + init_state + else: + head_input = state + + if self.cfg.update_head_per_param_heads: + delta_gaussians = self._apply_per_param_heads(head_input) + else: + delta_gaussians = self.update_head(head_input) + + return delta_gaussians + + def _apply_per_param_heads(self, head_input): + """Run per-parameter-group heads and concatenate results. + + Each head outputs [N, dim+1] where the last dim is the scalar scale. + Per-group normalize + scale is applied independently. + """ + deltas = [] + for name, dim in self._per_param_group_dims.items(): + raw = self.update_head[name](head_input) # [N, dim+1] + scale = self.scale_act(raw[:, -1:]) # [N, 1] + delta = raw[:, :-1] # [N, dim] + if dim > 1: + delta = delta / (delta.norm(p=2, dim=-1, keepdim=True) + 1e-8) * scale + else: + # 1-d (e.g. opacities): no direction to normalize, just scale magnitude + delta = delta * scale + deltas.append(delta) + return torch.cat(deltas, dim=-1) + + def apply_global_attn(self, b, h, input_signal, latent_h, latent_w, + local_window_update, test_window_size, v, w): + # TODO Naama: do we need local_window? 
+ assert self.cfg.input_error_resnet_feature + assert self.cfg.input_error + + if self.cfg.input_gradient and self.cfg.input_error: + input_render_error = input_signal[..., :self.error_features_channels] + else: + input_render_error = input_signal + + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.use_amp, dtype=torch.bfloat16): + for blk in self.update_error_attn: + if self.cfg.refine_same_num_points: + # no downsample, for re10k 256 + input_render_error = blk(input_render_error, v=v, h=h, w=w) + else: + curr_v = test_window_size if local_window_update else v + input_render_error = blk(input_render_error, v=curr_v, h=latent_h, w=latent_w) + + if self.cfg.input_gradient and self.cfg.input_error: + input_signal[..., :self.error_features_channels] = input_render_error + else: + input_signal = input_render_error + + return input_signal + + def prepare_input_signal(self, context, i, gaussians, + local_window_update, renderer, + window_end, window_start, num_refine): + # TODO Naama: review + # make sure at least one of the following is True + assert self.cfg.input_gradient or self.cfg.input_error + input_view_features = None + input_signal = None + input_render_error = None + context_render_output = None + gaussian_grads_raw = None + gaussian_grads = None + grad_sign = None + means2d_grads = None + + # calculate input gradients + if self.cfg.input_gradient: + gaussian_grads_raw, gaussian_grads, grad_sign, context_render_output, means2d_grads = ( + self._calc_input_gradients(context, gaussians, renderer) + ) + + input_signal = gaussian_grads_raw + + # When using gradients, context_render_output cannot be used for the meta-training, + # because there was already one backward pass. + # So we render again if in training. 
+ if context_render_output is None or self.training: + context_render_output = self.render_input_views_for_error_calc(context, local_window_update, + gaussians, renderer, window_end, + window_start, num_refine, i) + + # calculate input rendering errors + if self.cfg.input_error: + if means2d_grads is None and self.cfg.need_2d_grads: + raise NotImplementedError("Calculating 2dgrad for ADC is not implemented for error input alone") + input_render_error = self._calc_input_errors(context, i, context_render_output, + input_view_features, + local_window_update, + gaussians.means.detach(), + window_end, + window_start) + input_signal = input_render_error + + if self.cfg.input_gradient and self.cfg.input_error: + # Concatenate both gradients and errors + input_signal = torch.cat((input_render_error, gaussian_grads), dim=-1) + + return input_signal, gaussian_grads_raw, gaussian_grads, grad_sign, context_render_output, means2d_grads + + def get_data_shim(self) -> DataShim: + def data_shim(batch: BatchedExample) -> BatchedExample: + batch = apply_patch_shim( + batch, + patch_size=self.cfg.shim_patch_size + * self.cfg.downscale_factor, + ) + + return batch + + return data_shim + + @property + def sampler(self): + return None + + def debug_reprojection_error(self, means, debug_dict, context, i, latent_h, latent_w): + # Prepare means (remove singleton dim) + means = rearrange(means, "b (v h w) c -> b v (h w) c", h=latent_h, w=latent_w) # [B, V, H*W, 3] + + # Expand extrinsics/intrinsics for broadcast + extrinsics = context["extrinsics"].unsqueeze(2) # [B, V, 1, 4, 4] + intrinsics = context["intrinsics"].unsqueeze(2) # [B, V, 1, 3, 3] + + # Project + xy_ray_reconstructed, in_front = project(means, extrinsics, intrinsics) # [B, V, H*W, 2], [B, V, H*W] + + xy_ray, _ = sample_image_grid((latent_h, latent_w), xy_ray_reconstructed.device) # [B, V, H*W, 1, 2] + xy_ray = rearrange(xy_ray, "h w xy -> (h w) () xy") + + xy_ray = xy_ray.squeeze(-2) # [B, V, H*W, 2] + + xy_ray_unnorm = 
xy_ray.clone() + xy_ray_unnorm[..., 0] *= latent_w + xy_ray_unnorm[..., 1] *= latent_h + + xy_ray_reconstructed_unnorm = xy_ray_reconstructed.clone() + xy_ray_reconstructed_unnorm[..., 0] *= latent_w + xy_ray_reconstructed_unnorm[..., 1] *= latent_h + + reprojection_error = (xy_ray_unnorm - xy_ray_reconstructed_unnorm).abs() + + if debug_dict["reprojection_error"] is None: + # First iteration, first scene + debug_dict["reprojection_error"] = [[]] + elif i == 0: + # New iteration, new scene + debug_dict["reprojection_error"].append([]) + + debug_dict["reprojection_error"][-1].append(reprojection_error.detach().cpu()) + + # import matplotlib.pyplot as plt + # plt.figure(figsize=(12, 6)) + # plt.hist(reprojection_error.flatten().detach().cpu(), bins=100, range=(0, 10)) + # plt.title(f"Reprojection Error - step {i}") + # plt.xlabel("Error (pixels)") + # plt.ylabel("Frequency") + # plt.show() + + def _calc_input_errors(self, context, i, input_render, input_view_features, + local_window_update, prev_means, + window_end, window_start): + b, v, _, h, w = context["image"].shape + # Detach the last rendered object + input_rgb = input_render.color.detach() + # compute input view rendering error + if self.cfg.input_error_resnet_feature: + input0 = rearrange(input_rgb, "b v c h w -> (b v) c h w") + if self.cfg.input_error_num_views > 0: + gt_input = context["image"][:, :self.cfg.input_error_num_views, :, :, :] + elif local_window_update: + gt_input = context["image"][:, window_start:window_end, :, :, :] + else: + gt_input = context["image"] + input1 = rearrange(gt_input, "b v c h w -> (b v) c h w") + + transform = _IMAGENET_NORM + + if input_view_features is None: + assert i == 0 + # first time: extract all features + concat = torch.cat((input0, input1), dim=0) + + input_tensor = transform(concat) + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, + dtype=torch.bfloat16): + # Extract features + with torch.no_grad(): + features = 
self.update_feature(input_tensor) + + # align to the latent resolution + latent_h, latent_w = self._get_latent_size(h, w) + + all_features = torch.cat(self._align_features(features, latent_h, latent_w), dim=1) + + render_view_features = all_features[:input0.shape[0]] + input_view_features = all_features[input0.shape[0]:] + + else: + # only extract render view features + with torch.amp.autocast(device_type='cuda', enabled=self.cfg.pt_update_amp, + dtype=torch.bfloat16): + # Extract features + with torch.no_grad(): + features = self.update_feature(transform(input0)) + + # align to the latent resolution + latent_h, latent_w = self._get_latent_size(h, w) + + render_view_features = torch.cat(self._align_features(features, latent_h, latent_w), dim=1) + + corr = render_view_features - input_view_features + + if self.cfg.input_error_num_views > 0: + # pad to V views + curr_v = self.cfg.input_error_num_views + indices = torch.arange(v) * curr_v // v + corr = rearrange(corr, "(b v) c h w -> b v c h w", b=b) + corr = corr[torch.arange(b).unsqueeze(1), indices, :, :, :] + input_render_error = rearrange(corr, "b v c h w -> b (v h w) c") + else: + input_render_error = rearrange(corr, "(b v) c h w -> b (v h w) c", b=b) + + else: + input_render_error = (input_render.color - context["image"]).abs() # [B, V, 3, H, W] + input_render_error = rearrange(input_render_error, "b v c h w -> (b v) c h w") + + if self.cfg.input_error_rgb_no_shuffle: + # bilinear + input_render_error = F.interpolate(input_render_error, + scale_factor=1. 
/ self.cfg.latent_downsample, + mode='bilinear', align_corners=True) + else: + # pixel unshuffle + # TODO: when fps is used, how to reshape the render error to make sure its somehow pixel aligned to the gaussians + input_render_error = F.pixel_unshuffle(input_render_error, + downscale_factor=self.cfg.latent_downsample) + + input_render_error = rearrange(input_render_error, "(b v) c h w -> b (v h w) c", b=b, + v=v) # [B, N, C] + + # include both feature error and image error + if self.cfg.input_error_add_rgb_feature: + rgb_render_error = input_render.color - context["image"] + rgb_render_error = rearrange(rgb_render_error, "b v c h w -> (b v) c h w") + + if self.cfg.input_error_rgb_no_shuffle: + # bilinear + rgb_render_error = F.interpolate(rgb_render_error, scale_factor=1. / self.cfg.latent_downsample, + mode='bilinear', align_corners=True) + else: + # pixel unshuffle + # TODO: when fps is used, how to reshape the render error to make sure its somehow pixel aligned to the gaussians + rgb_render_error = F.pixel_unshuffle(rgb_render_error, + downscale_factor=self.cfg.latent_downsample) + + rgb_render_error = rearrange(rgb_render_error, "(b v) c h w -> b (v h w) c", b=b, v=v) # [B, N, C] + + rgb_render_error = self.update_rgb_error_proj(rgb_render_error) + input_render_error = input_render_error + rgb_render_error + + return input_render_error + + def get_input_error_feature_extractor(self): + update_feature = None + # resnet feature + if self.cfg.input_error_resnet_feature: + update_feature = ResNetFeatureWarpper( + shallow_resnet_feature=self.cfg.input_error_shallow_resnet_feature) + + if self.cfg.input_error_no_freeze_resnet_feature: + # remove unused layers + # NOTE: layer 3 is also not used + update_feature.layer3 = nn.Identity() + update_feature.train() + for params in update_feature.parameters(): + params.requires_grad = True + else: + update_feature.eval() + + for params in update_feature.parameters(): + params.requires_grad = False + + return update_feature + 
+ def update_delta_for_gradients_input(self, delta_gaussians, grad_sign, normalized_grad, + visibility_scale: Tensor | None = None): + if self.cfg.input_gradient: + delta_gaussians = delta_gaussians / self.cfg.input_gradient_scale + if self.cfg.input_gradient_log: + grad_sign = rearrange(grad_sign, "b n c -> (b n) c") + # recover log scale for applying the deltas. + # For loss calculation the delta should still be in log scale + + delta_gaussians = grad_sign * (delta_gaussians.exp() - 1e-8) + + if self.cfg.input_gradient_log_clip_deltas > 0: + # clip the delta to avoid too large updates + clip_value = self.cfg.input_gradient_log_clip_deltas + clip_mask = delta_gaussians.abs() > clip_value + delta_gaussians[clip_mask] = delta_gaussians[clip_mask].sign() * clip_value + + # TODO Naama: move these two, as they are not related to gradients + if self.cfg.update_head_scale_mag: + out_channels = delta_gaussians.shape[-1] + param_num = out_channels / 2 + assert param_num.is_integer() + param_num = int(param_num) + scale = delta_gaussians[:, :param_num] + mag = delta_gaussians[:, param_num:] + delta_gaussians = scale * 0.01 * torch.exp(mag * 0.01) + + if self.cfg.update_head_scalar_scale: + if self.cfg.update_head_per_param_heads: + # Already handled in _apply_per_param_heads — nothing to do here + pass + elif self.cfg.update_head_per_param_scales: + # Feature B: per-group scalar scales + num_groups = len(self._per_param_group_dims) + scales = delta_gaussians[:, -num_groups:] # [G, num_groups] + scales = self.scale_act(scales) + deltas = delta_gaussians[:, :-num_groups] # [G, D] + + normalized_deltas = [] + offset = 0 + for i, (name, dim) in enumerate(self._per_param_group_dims.items()): + group_delta = deltas[:, offset:offset + dim] # [G, dim] + group_scale = scales[:, i:i + 1] # [G, 1] + if dim > 1: + group_delta = group_delta / (group_delta.norm(p=2, dim=-1, keepdim=True) + 1e-8) + group_delta = group_delta * group_scale + normalized_deltas.append(group_delta) + offset += 
dim + + delta_gaussians = torch.cat(normalized_deltas, dim=-1) + else: + # Original global scalar scale + scale = delta_gaussians[:, -1:] # [G, 1] + scale = self.scale_act(scale) # make sure scale is positive + deltas_unnorm = delta_gaussians[:, :-1] # [G, D] + deltas_norm = deltas_unnorm / (deltas_unnorm.norm(p=2, dim=1, keepdim=True) + 1e-8) # [G, D] + delta_gaussians = deltas_norm * scale + + if visibility_scale is not None: + delta_gaussians = delta_gaussians * visibility_scale + + if self.cfg.scale_residual_grads: + delta_gaussians = delta_gaussians * normalized_grad * self.cfg.gradient_update_scale # 1.0 + + # To match the default behavior of SGD, Adam, and other optimizers, deltas are negated. + # SGD applies the gradients as `x = x - lr * grad`, while resaplt applies them as `x = x + lr * deltas`. + delta_gaussians = -delta_gaussians + + return delta_gaussians + + def _calc_input_gradients(self, context, gaussians, renderer): + assert not self.cfg.input_gradient_same_loss, "input_gradient_same_loss is not implemented" + _, v, _, h, w = context["image"].shape + + with torch.enable_grad(): + + # Unpack gaussians + means, scales, rotations_unnorm, opacities_raw, shs = unpack_gaussians( + gaussians, + scales_log=self.cfg.opt_scales_before_act, + opacities_logit=True, + opacities_unsqueeze=True, + detach=True, + clone=False, + requires_grad=True, + scales_lims=(self.cfg.clamp_min_scale, self.cfg.clamp_refine_max_scale), + raw_opacities_lims=(self.cfg.clamp_min_raw_opacities, self.cfg.clamp_max_raw_opacities) + ) + + # Create temporary Gaussians with same values but requires_grad=True + grad_batch_size = self.cfg.input_gradients_chunk_size + if grad_batch_size == -1: + grad_batch_size = v + gaussian_grads = 0 + means2d_grads_chunks = [] + nr_chunks = math.ceil(v / grad_batch_size) + + # Pre-compute shapes and config flags outside the loop + shs_shape = (shs.shape[0], shs.shape[1], 3, -1) + opt_scales_before_act = self.cfg.opt_scales_before_act + # Pre-compute 
normalized rotations once (not in gradient inputs, so no grad needed) + with torch.no_grad(): + rotations = rotations_unnorm / (rotations_unnorm.norm(dim=-1, keepdim=True) + 1e-8) + + for chunk_idx, start, stop in chunk_index_iter(v, grad_batch_size): + # zero grads + + means = means.detach().requires_grad_(True) + scales = scales.detach().requires_grad_(True) + rotations_unnorm = rotations_unnorm.detach().requires_grad_(True) + opacities_raw = opacities_raw.detach().requires_grad_(True) + shs = shs.detach().requires_grad_(True) + + # Apply activation to scales if needed (before calculating covariance) + scales_act = scales.exp() if opt_scales_before_act else scales + + tmp_gaussians = Gaussians( + means=means, + covariances=None, + harmonics=shs.view(shs_shape), + opacities=torch.sigmoid(opacities_raw.squeeze(-1)), + scales=scales_act, + rotations=rotations, + rotations_unnorm=rotations_unnorm, + ) + + # render input views, calculate inner loss and backprop to get gradients + context_render_output = renderer.forward_batch_subset( + tmp_gaussians, + context, + start=start, + end=stop, + image_shape=(h, w), + ) + + inputs = [means, scales, rotations_unnorm, opacities_raw, shs] + + if self.cfg.need_2d_grads: + assert context_render_output.means2d is not None, "output_renderer.means2d is None" + means2d = context_render_output.means2d # [B, V, N, 2] + # means2d.retain_grad() # retain grad for means2d grads computation + inputs.append(means2d) + + inner_loss = inner_loss_for_input_gradients( + context["image"][:, start:stop], + context_render_output, + reduction=self.cfg.input_gradient_loss_reduction, + with_ssim=self.cfg.input_gradient_with_ssim_loss, + ) + if self.cfg.opacity_reg_lambda > 0.0: + inner_loss = inner_loss + self.cfg.opacity_reg_lambda * torch.sigmoid(opacities_raw).mean() + grads = torch.autograd.grad(outputs=inner_loss, + inputs=inputs, + create_graph=False, + retain_graph=False, + ) + + gaussian_grads = gaussian_grads + torch.cat(grads[:5], dim=-1) # 
[B, G, D] + assert not torch.isnan(gaussian_grads).any(), "NaN detected in gaussian_grads" + if self.cfg.need_2d_grads: + means2d_grads_chunks.append(grads[5]) # [B, V_chunk, N, 2] + + gaussian_grads = gaussian_grads / nr_chunks + + if self.cfg.need_2d_grads: + means2d_grads = torch.cat(means2d_grads_chunks, dim=1) # [B, V, N, 2] + if self.cfg.input_gradient_loss_reduction == "mean_pixels_sum_views": + means2d_grads = means2d_grads / v + else: + means2d_grads = None + + gaussian_grads_raw = gaussian_grads * self.cfg.input_gradient_scale + if self.cfg.input_gradient_log: + # log gradients + grads_sign = gaussian_grads.sign() + gaussian_grads_raw = (gaussian_grads_raw.abs() + 1e-8).log() + else: + grads_sign = None + + # Detach gradients to avoid gradient flow through the input + gaussian_grads = gaussian_grads.detach() + gaussian_grads_raw = gaussian_grads_raw.detach() + if grads_sign is not None: + grads_sign = grads_sign.detach() + + # Returning also the render output, but it can only be used for visualization, + # as we already backpropogate gradients through it + return gaussian_grads_raw, gaussian_grads, grads_sign, context_render_output, means2d_grads + + +def select_gaussian_subset(gaussians, window_start, window_end, v, h, w): + """Select a subset of gaussians based on view window. 
Optimized to avoid rearrange overhead.""" + b = gaussians.means.shape[0] + hw = h * w + window_v = window_end - window_start + new_n = window_v * hw + + # Helper to slice view dimension efficiently using view+slice+reshape instead of rearrange + def slice_tensor(t, extra_dims): + # t shape: [b, v*h*w, *extra_dims] -> [b, window_v*h*w, *extra_dims] + shape = (b, v, hw) + extra_dims + new_shape = (b, new_n) + extra_dims + return t.view(shape)[:, window_start:window_end, :].reshape(new_shape) + + means = slice_tensor(gaussians.means, (3,)) + covariances = slice_tensor(gaussians.covariances, (3, 3)) if gaussians.covariances is not None else None + shs = slice_tensor(gaussians.harmonics, gaussians.harmonics.shape[2:]) + opacities = slice_tensor(gaussians.opacities.unsqueeze(-1), ()).squeeze(-1) + scales = slice_tensor(gaussians.scales, (3,)) + rotations = slice_tensor(gaussians.rotations, (4,)) if gaussians.rotations is not None else None + rotations_unnorm = slice_tensor(gaussians.rotations_unnorm, (4,)) + + return Gaussians( + means=means, + covariances=covariances, + harmonics=shs, + opacities=opacities, + scales=scales, + rotations=rotations, + rotations_unnorm=rotations_unnorm, + ) + + +def replace_window(original, window, window_start, window_end, dim=1): + slices = [] + if window_start > 0: + # TODO: detach or not + # slices.append(original[:, :window_start].detach()) + slices.append(original[:, :window_start]) + slices.append(window) + if window_end < original.shape[dim]: + # TODO: detach or not + # slices.append(original[:, window_end:].detach()) + slices.append(original[:, window_end:]) + return torch.cat(slices, dim=dim) + + +def freeze_batchnorm_layers(model): + import torch.nn as nn + for module in model.modules(): + if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or isinstance(module, + nn.BatchNorm3d): + module.eval() # Set to evaluation mode + for param in module.parameters(): + param.requires_grad = False # Freeze parameters 
diff --git a/optgs/scene_trainer/optimizer/optimizer_learn2splat.py b/optgs/scene_trainer/optimizer/optimizer_learn2splat.py new file mode 100644 index 0000000000000000000000000000000000000000..9812d3b07d9ebafcd90690a96be71f2806609bd0 --- /dev/null +++ b/optgs/scene_trainer/optimizer/optimizer_learn2splat.py @@ -0,0 +1,6 @@ +from optgs.scene_trainer.optimizer.optimizer_knn_based import KnnBasedOptimizer + + +class Learn2SplatOptimizer(KnnBasedOptimizer): + OPTIMIZER_NAME = "l2s" + OPTIMIZER_NAME_ALIASES = ("clogs",) # TODO (release): remove aliases diff --git a/optgs/scene_trainer/optimizer/optimizer_resplat.py b/optgs/scene_trainer/optimizer/optimizer_resplat.py new file mode 100644 index 0000000000000000000000000000000000000000..1439c605eae63fabcc9503bdea13187cef6c0ade --- /dev/null +++ b/optgs/scene_trainer/optimizer/optimizer_resplat.py @@ -0,0 +1,10 @@ +from typing import Literal + +from optgs.scene_trainer.optimizer import KnnBasedOptimizer, KnnBasedOptimizerCfg + + +class ResplatOptimizerV1(KnnBasedOptimizer): + OPTIMIZER_NAME = "resplat_v1" + +class ResplatOptimizerV2(KnnBasedOptimizer): + OPTIMIZER_NAME = "resplat_v2" diff --git a/optgs/scene_trainer/optimizer/optimizer_utils.py b/optgs/scene_trainer/optimizer/optimizer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a4d04e72b00836b64c11658175b619037cbdf3f --- /dev/null +++ b/optgs/scene_trainer/optimizer/optimizer_utils.py @@ -0,0 +1,850 @@ +import warnings +from dataclasses import dataclass +import torch.nn.functional as F +import math +from typing import List, Tuple, Iterator +from typing import Literal, Generic, TypeVar +from fused_ssim import allowed_padding, FusedSSIMMap +import torch +from gsplat import fully_fused_projection, isect_tiles, isect_offset_encode, rasterize_to_indices_in_range +from nerfacc import accumulate_along_rays, render_weight_from_alpha +from torch import Tensor +from optgs.model.decoder.decoder import Decoder, DecoderOutput +from optgs.model.types 
import Gaussians +from einops import rearrange +from tqdm import tqdm +import gc +import torch.autograd.profiler as profiler +from optgs.misc.memory_profiler import profile_gpu_memory, report_gpu_tensors + +T = TypeVar("T") +GPU_MEM_PROFILING = False # set to True to enable GPU memory profiling + + +def split_grads(grads_tensor, cfg): + + assert isinstance(grads_tensor, Tensor), "grads_tensor is not a Tensor" + + # handle case where grads_tensor has batch dimension + if grads_tensor.ndim == 3: + assert grads_tensor.shape[0] == 1, "Batch size > 1 not supported for grads_tensor with ndim 3" + grads_tensor = grads_tensor.squeeze(0) # [N, D] + + # Split the last dimension + means, scales, rotations, opacities, shs = torch.split( + grads_tensor, (3, 3, 4, 1, 3 * cfg.sh_d), dim=-1 + ) + + shs = rearrange(shs, "n (c x) -> n c x", c=3, x=cfg.sh_d) # [N, 3, sh_d] + sh0s = shs[..., 0:1] + if cfg.sh_d > 1: + shNs = shs[..., 1:] + else: + shNs = None + grads: dict = { + "means": means, + "scales": scales, + "rotations": rotations, + "opacities": opacities, + "sh0s": sh0s, + "shNs": shNs, + } + return grads + + +def inner_loss_for_input_gradients( + gt_images, + output_renderer: DecoderOutput, + reduction: str = "mean", + with_ssim: bool = True, +) -> Tensor: + # compute scalar loss + # assume batch size 1 + assert gt_images.shape[0] == 1 + assert gt_images.shape == output_renderer.color.shape + + l1_loss = (output_renderer.color - gt_images).abs() + if reduction == "mean": + l1_loss = l1_loss.mean() + elif reduction == "sum": + l1_loss = l1_loss.sum() + elif reduction == "mean_pixels_sum_views": + l1_loss = l1_loss.mean(dim=(-1, -2, -3)).sum(dim=-1).mean() + else: + raise ValueError(f"Unknown reduction: {reduction!r}") + + if not with_ssim: + return l1_loss + + gt_images_for_ssim = gt_images.clone() if gt_images.is_inference() else gt_images + ssim_loss = fused_ssim_with_reduction( + rearrange(output_renderer.color, "b v c h w -> (b v) c h w"), + rearrange(gt_images_for_ssim, 
"b v c h w -> (b v) c h w"), + padding="valid", + reduction=reduction, + loss=True, # returns mean(1 - ssim), i.e. the SSIM loss + ) + return 0.8 * l1_loss + 0.2 * ssim_loss + + +def squeeze_grad_dict(grad_dict): + for k, v in grad_dict.items(): + if v is not None: + grad_dict[k] = v.squeeze(0) + return grad_dict + + +def smooth_grads(grads: dict, smoothers: dict) -> dict: + smoothed_grads = {} + for k, v in grads.items(): + if k not in smoothers: + continue + else: + if v is not None: + smoothed_grads[k] = smoothers[k](v) + else: + smoothed_grads[k] = None + return smoothed_grads + +def chunk_ranges(v: int, chunk_size: int) -> List[Tuple[int, int]]: + """ + Return a list of (start, stop) index ranges that partition [0, v). + Last chunk may be smaller if v % chunk_size != 0. + Example: chunk_ranges(10, 4) -> [(0,4),(4,8),(8,10)] + """ + if chunk_size <= 0: + raise ValueError("chunk_size must be > 0") + ranges = [] + start = 0 + while start < v: + stop = min(start + chunk_size, v) + ranges.append((start, stop)) + start = stop + return ranges + +def chunk_slices(v: int, chunk_size: int, dim: int = 1) -> List[slice]: + """ + Return a list of slice objects that slice along axis `dim`. + Use like: tensor[(slice(None), slice_start_stop, ...)] — easier: use helper below. + NOTE: slice objects don't encode the axis; they only give start/stop; see usage. + """ + return [slice(s, e) for s, e in chunk_ranges(v, chunk_size)] + +def chunk_index_iter(v: int, chunk_size: int) -> Iterator[Tuple[int,int,int]]: + """ + Iterate chunk info as (chunk_idx, start, stop) for convenience. 
+ """ + for idx, (s, e) in enumerate(chunk_ranges(v, chunk_size)): + yield idx, s, e + + + +def fused_ssim_with_reduction(img1, img2, padding="same", train=True, reduction="mean", loss=False): + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + assert padding in allowed_padding + + img1 = img1.contiguous() + ssim_map = FusedSSIMMap.apply(C1, C2, img1, img2, padding, train) # [v c h w] + + if loss: + ssim_map = 1 - ssim_map + + if reduction == "mean": + return ssim_map.mean() + elif reduction == "sum": + return ssim_map.sum() + elif reduction == "mean_pixels_sum_views": + # Mean over spatial (h, w) and channel (c) dims, then sum over views (v) + return ssim_map.mean(dim=(-1, -2, -3)).sum(dim=-1) + else: + raise ValueError(f"Unsupported reduction: {reduction}") + + +def calc_input_gradients( + iter_context, + prev_means, + prev_scales_raw, + prev_rotations_unnorm, + prev_opacities_raw, # [B, N] — may be a non-leaf view of gaussians.opacities + prev_shs, # [B, N, 3, sh_d] + renderer: Decoder, + need_2d_grads: bool, + chunk_size: int | None, + any_adc: bool = True, + sh_degree: int | None = None, + meta_bufs: dict | None = None, # mutable dict populated/reused across calls for radii & visibility + loss_reduction: str = "mean", + loss_with_ssim: bool = True, + opacity_reg_lambda: float = 0.0, # L1 opacity regularization weight (3DGS-MCMC) +) -> tuple[Tensor, dict[str, Tensor], dict[str, Tensor | None] | None]: + + b, v, _, h, w = iter_context["image"].shape + assert b == 1, "Batch size > 1 not supported for post-processing" + + if chunk_size == -1: + chunk_size = v + nr_chunks = math.ceil(v / chunk_size) + N = prev_means.shape[1] + device = prev_means.device + + # --- Grad setup --- + # Gradients are obtained functionally via torch.autograd.grad below, so .grad buffers + # are never read or written. Enable requires_grad on the leaf params as a fallback if + # the caller did not already set it up. 
Order matters: it defines the autograd.grad + # input order and therefore the order of the returned per-param gradients. + _leaf_params = [prev_means, prev_scales_raw, prev_rotations_unnorm, prev_opacities_raw, prev_shs] + for t in _leaf_params: + if not t.requires_grad: + t.requires_grad_(True) + + # --- Allocate or reuse radii / visibility buffers (only needed when any_adc) --- + bufs_valid = ( + any_adc + and meta_bufs is not None + and meta_bufs.get("N") == N + and meta_bufs.get("v") == v + ) + if bufs_valid: + radii_all = meta_bufs["radii"] + visibility_all = meta_bufs["visibility"] + means2d_grads_all = meta_bufs.get("means2d_grads") + if need_2d_grads and means2d_grads_all is None: + means2d_grads_all = torch.empty((b, v, N, 2), dtype=torch.float32, device=device) + meta_bufs["means2d_grads"] = means2d_grads_all + elif any_adc: + radii_all = torch.empty((b, v, N, 2), dtype=torch.float32, device=device) + visibility_all = torch.empty((b, v, N), dtype=torch.bool, device=device) + means2d_grads_all = ( + torch.empty((b, v, N, 2), dtype=torch.float32, device=device) + if need_2d_grads else None + ) + if meta_bufs is not None: + meta_bufs.update({"N": N, "v": v, "radii": radii_all, + "visibility": visibility_all, "means2d_grads": means2d_grads_all}) + else: + radii_all = visibility_all = means2d_grads_all = None + + # --- Forward + autograd.grad loop --- + # Per-chunk gradients for the leaf params are summed here, then averaged below. 
+ accumulated_grads: list[Tensor] | None = None + with torch.enable_grad(): + assert not torch.is_inference_mode_enabled() + + for chunk_idx, start, stop in tqdm(chunk_index_iter(v, chunk_size), disable=nr_chunks <= 1, + desc="Computing input gradients in chunks"): + image_chunk = iter_context["image"][:, start:stop] + extrinsics_chunk = iter_context["extrinsics"][:, start:stop] + intrinsics_chunk = iter_context["intrinsics"][:, start:stop] + near_chunk = iter_context["near"][:, start:stop] + far_chunk = iter_context["far"][:, start:stop] + + prev_opacities = torch.sigmoid(prev_opacities_raw) + prev_scales = torch.exp(prev_scales_raw) + prev_rotations = F.normalize(prev_rotations_unnorm, dim=-1) + + if sh_degree is not None: + prev_shs_for_render = prev_shs[..., :(sh_degree + 1) ** 2] + else: + prev_shs_for_render = prev_shs + + tmp_gaussians = Gaussians( + means=prev_means, + covariances=None, + harmonics=prev_shs_for_render, + opacities=prev_opacities, + scales=prev_scales, + rotations=prev_rotations, + rotations_unnorm=prev_rotations_unnorm, + stores_activated=True, + ) + + if GPU_MEM_PROFILING: + output_renderer: DecoderOutput = profile_gpu_memory( + fn=renderer.forward, gaussians=tmp_gaussians, + extrinsics=extrinsics_chunk, intrinsics=intrinsics_chunk, + near=near_chunk, far=far_chunk, image_shape=(h, w)) + else: + output_renderer: DecoderOutput = renderer.forward( + gaussians=tmp_gaussians, + extrinsics=extrinsics_chunk, intrinsics=intrinsics_chunk, + near=near_chunk, far=far_chunk, image_shape=(h, w)) + + loss = inner_loss_for_input_gradients(image_chunk, output_renderer, + reduction=loss_reduction, with_ssim=loss_with_ssim) + + # L1 opacity regularization (3DGS-MCMC) folded into the differentiated loss. 
+ grad_loss = loss + if opacity_reg_lambda > 0.0: + grad_loss = loss + opacity_reg_lambda * torch.sigmoid(prev_opacities_raw).mean() + + grad_inputs = list(_leaf_params) + if need_2d_grads: + assert output_renderer.means2d is not None + grad_inputs.append(output_renderer.means2d) + + chunk_grads = torch.autograd.grad(grad_loss, grad_inputs, + create_graph=False, retain_graph=False) + + param_grads = [g.detach() for g in chunk_grads[:5]] + if accumulated_grads is None: + accumulated_grads = param_grads + else: + accumulated_grads = [a + g for a, g in zip(accumulated_grads, param_grads)] + + # store per-chunk meta + if any_adc: + radii_all[:, start:stop] = output_renderer.radii + visibility_all[:, start:stop] = output_renderer.visibility_filter + if need_2d_grads: + means2d_grads_all[:, start:stop] = chunk_grads[5].detach() + + # --- Average grads for multi-chunk --- + if nr_chunks > 1: + inv = 1.0 / nr_chunks + accumulated_grads = [g * inv for g in accumulated_grads] + + means_grads, scales_raw_grads, rotations_unnorm_grads, opacities_raw_grads, harmonics_grads = accumulated_grads + + sh0s_grads = harmonics_grads[..., 0:1] + shNs_grads = harmonics_grads[..., 1:] if harmonics_grads.shape[-1] > 1 else None + + grads = { + "means": means_grads, + "scales": scales_raw_grads, + "rotations": rotations_unnorm_grads, + "opacities": opacities_raw_grads, + "sh0s": sh0s_grads, + "shNs": shNs_grads, + } + + meta_for_adc = { + "visibility_filter": visibility_all, + "radii": radii_all, + "means_2d_grads": means2d_grads_all if need_2d_grads else None, + } if any_adc else None + + return loss, grads, meta_for_adc + + +def unpack_gaussians( + gaussians: Gaussians, + scales_log: bool, + opacities_logit: bool, + opacities_unsqueeze: bool, + detach: bool = True, + clone: bool = False, + requires_grad: bool = False, + scales_lims: tuple | None = None, # post activation (1e-6, 3) + raw_opacities_lims: tuple | None = None, # pre activation (-7, 7) +): + """ Unpack Gaussian parameters and 
invert opacities and scales. + + # TODO Naama: fix this + Clamp values for scales are in post-activation space, i.e., after exponentiation. + Clamp values for opacities are in pre-activation space, i.e., before sigmoid + + """ + + # Means + means = gaussians.means # [B, N, 3] + + # Scales + scales = gaussians.scales # [B, N, 3] + if scales_lims is not None: + scales = torch.clamp(scales, scales_lims[0], scales_lims[1]) + # if self.cfg.opt_scales_before_act: + if scales_log: + # Invert also scales + scales = torch.log(scales + 1e-8) + + # Quaternions + # use unnormalized rotations since we are going to refine the unnormed rotations + rotations_unnorm = gaussians.rotations_unnorm # [B, N, 4] + + # Opacities + # before sigmoid, eps is necessary, otherwise might be nan + if opacities_logit: + opacities_raw = torch.logit(gaussians.opacities, eps=1e-7) # [B, N] + if raw_opacities_lims is not None: + opacities_raw = torch.clamp(opacities_raw, raw_opacities_lims[0], raw_opacities_lims[1]) + else: + opacities_raw = gaussians.opacities # [B, N] + + if opacities_unsqueeze: + opacities_raw = opacities_raw.unsqueeze(-1) # [B, N, 1] + + # SHs - use flatten instead of rearrange for speed + shs = gaussians.harmonics # [B, N, 3, 9] + shs = shs.flatten(-2) # [B, N, C] - faster than rearrange + + if gaussians.sel is not None: + # TODO Naama: move method to Gaussians class + sel = gaussians.sel # [B, N] + means = means[:, sel] + opacities_raw = opacities_raw[:, sel] + rotations_unnorm = rotations_unnorm[:, sel] + scales = scales[:, sel] + shs = shs[:, sel] + + if detach: + means = means.detach() + opacities_raw = opacities_raw.detach() + rotations_unnorm = rotations_unnorm.detach() + scales = scales.detach() + shs = shs.detach() + + if clone: + means = means.clone() + opacities_raw = opacities_raw.clone() + rotations_unnorm = rotations_unnorm.clone() + scales = scales.clone() + shs = shs.clone() + + if requires_grad: + means.requires_grad_(True) + opacities_raw.requires_grad_(True) + 
rotations_unnorm.requires_grad_(True) + scales.requires_grad_(True) + shs.requires_grad_(True) + + # # predicting multiple gaussians per point, init new gaussians by copying with scaled opacities + # if self.cfg.reinit_gaussian_when_refine_multiple and self.cfg.refine_gaussian_multiple > 1: + # raise NotImplementedError + # # This should only be called at the first iteration + # # TODO Naama: might be bug if we use replay buffer + # repeat = self.cfg.refine_gaussian_multiple + # prev_means = prev_means.repeat(1, repeat, 1) + # prev_scales = prev_scales.repeat(1, repeat, 1) + # prev_rotations_unnorm = prev_rotations_unnorm.repeat(1, repeat, 1) + # + # # scale down opacities + # prev_opacities_raw = prev_opacities_raw.repeat(1, repeat, 1) # smaller opacities, important + # # Given y = sigmoid(x), to get new x' such that sigmoid(x') = y / K: + # # The formula is: x' = x + log((1 - y) / (K - y)) + # # This adjusts x so that the sigmoid output is scaled down by a factor of K + # tmp_sigmoid = prev_opacities_raw.sigmoid() + # # print(tmp_sigmoid.mean().item()) + # prev_opacities_raw = prev_opacities_raw + torch.log((1 - tmp_sigmoid) / (repeat - tmp_sigmoid)) + # + # prev_shs = prev_shs.repeat(1, repeat, 1) + # + # # TODO: this part not ready + + return means, scales, rotations_unnorm, opacities_raw, shs + + +def get_gaussian_param_slices(sh_d: int) -> dict: + """Return index slices for each Gaussian parameter group in the packed vector. + + Layout (must match pack_gaussians): + [means(3) | scales(3) | quats(4) | opacities(1) | shs(3*sh_d)] + """ + sh_end = 11 + 3 * sh_d + return { + "means": slice(0, 3), + "scales": slice(3, 6), + "quats": slice(6, 10), + "opacities": slice(10, 11), + "sh0": slice(11, sh_end, sh_d), + "shN": [i for i in range(11, sh_end) if (i - 11) % sh_d != 0], + } + + +def get_gaussian_param_sizes(sh_d: int) -> dict: + """Return the element count for each Gaussian parameter group. 
def pack_gaussians(
        means: Tensor,
        scales: Tensor,
        rotations_unnorm: Tensor,
        opacities_raw: Tensor,
        shs: Tensor,
) -> Tensor:
    """Join the per-group Gaussian parameters along the last dimension.

    Resulting channel layout (must match get_gaussian_param_slices):
        [means(3) | scales(3) | quats(4) | opacities(1) | shs(3*sh_d)]
    """
    # Order matters: it defines the packed layout documented above.
    ordered_groups = [means, scales, rotations_unnorm, opacities_raw, shs]
    packed = torch.cat(ordered_groups, dim=-1)
    return packed
def get_gaussians_visibility_contribution(
        means: Tensor,  # [N, 3]
        quats: Tensor,  # [N, 4]
        scales: Tensor,  # [N, 3]
        opacities: Tensor,  # [N]
        viewmats: Tensor,  # [V, 4, 4]
        Ks: Tensor,  # [V, 3, 3]
        width: int,
        height: int,
        # set these as in your render function
        near_plane: float = 0.01,
        far_plane: float = 1e10,
        eps2d: float = 0.3,
        tile_size: int = 16,
        rasterize_mode: Literal["classic", "antialiased"] = "antialiased",
        batch_per_iter: int = 100,
) -> tuple[Tensor, dict]:
    """
    Accumulate, per Gaussian, its visibility contribution over all views.

    The Gaussians are projected to 2D for every view, intersected with image
    tiles, and the accumulation itself is delegated to
    `_gaussians_vis_contribution`.

    Returns:
        A pair ``(contribution, info)``:
        - contribution: (N,) shaped tensor whose entry k is
              out[k] = sum_{c,i,j}^{C, H, W} w_{k,c,i,j}
        - info: dict with keys "alphas", "radii", "means2d", "conics",
          "depths" and "weights_per_view".
    """
    num_gaussians = means.shape[0]
    num_views = viewmats.shape[0]

    # Validate input shapes; the offending shape is the assertion message.
    expected_shapes = (
        (means, (num_gaussians, 3)),
        (quats, (num_gaussians, 4)),
        (scales, (num_gaussians, 3)),
        (opacities, (num_gaussians,)),
        (viewmats, (num_views, 4, 4)),
        (Ks, (num_views, 3, 3)),
    )
    for tensor, shape in expected_shapes:
        assert tensor.shape == shape, tensor.shape

    # Project Gaussians to 2D. Outputs are shaped [V, N, ...]; only the
    # elements with radii > 0 are valid.
    use_compensations = (rasterize_mode == "antialiased")
    radii, means2d, depths, conics, compensations = fully_fused_projection(
        means=means,
        covars=None,
        quats=quats,
        scales=scales,
        viewmats=viewmats,
        Ks=Ks,
        width=width,
        height=height,
        eps2d=eps2d,
        near_plane=near_plane,
        far_plane=far_plane,
        calc_compensations=use_compensations,
    )

    # One opacity row per view, optionally scaled by the projection compensations.
    per_view_opacities = opacities.repeat(num_views, 1)  # [V, N]
    if compensations is not None:
        per_view_opacities = per_view_opacities * compensations

    # Identify intersecting tiles.
    tiles_x = math.ceil(width / float(tile_size))
    tiles_y = math.ceil(height / float(tile_size))
    _, isect_ids, flatten_ids = isect_tiles(
        means2d,
        radii,
        depths,
        tile_size,
        tiles_x,
        tiles_y,
        packed=False,
        n_images=num_views,
        image_ids=None,
        gaussian_ids=None,
    )
    isect_offsets = isect_offset_encode(isect_ids, num_views, tiles_x, tiles_y)

    contribution, render_alphas, weights_per_view = _gaussians_vis_contribution(
        means2d,
        conics,
        per_view_opacities,
        width,
        height,
        tile_size,
        isect_offsets,
        flatten_ids,
        batch_per_iter=batch_per_iter,
    )  # (N,)

    info = {
        "alphas": render_alphas,
        "radii": radii,
        "means2d": means2d,
        "conics": conics,
        "depths": depths,
        "weights_per_view": weights_per_view,
    }
    return contribution, info
def get_m_intersection_weights(range_size, conics, flatten_ids, image_height, image_width, isect_offsets, means2d,
                               opacities, step, tile_size, total_pixels, transmittances):
    """Compute rendering weights for the pixel/Gaussian intersections of one batch.

    Intersections are looked up for the range [step, step + range_size) via
    `rasterize_to_indices_in_range`; each of the M intersections is a
    (gs_id, pixel_id, image_id) triple. For every intersection the 2D Gaussian
    falloff is evaluated at the pixel centre and multiplied by the opacity to
    get an alpha, which `render_weight_from_alpha` turns into a weight.

    Returns:
        (gs_ids, image_ids, indices, pixel_ids, weights), each of length M, where
        ``indices`` is the flat ray index image_id * H * W + pixel_id.
    """
    range_end = step + range_size
    gs_ids, pixel_ids, image_ids = rasterize_to_indices_in_range(
        step,
        range_end,
        transmittances,
        means2d,
        conics,
        opacities,
        image_width,
        image_height,
        tile_size,
        isect_offsets,
        flatten_ids,
    )  # [M], [M]

    # Pixel centres (x, y) recovered from the row-major pixel id.
    col = pixel_ids % image_width
    row = pixel_ids // image_width
    centres = torch.stack([col, row], dim=-1) + 0.5  # [M, 2]

    # Offset from each Gaussian's projected mean, and its conic coefficients.
    offset = centres - means2d[image_ids, gs_ids]  # [M, 2]
    conic = conics[image_ids, gs_ids]  # [M, 3]
    dx = offset[:, 0]
    dy = offset[:, 1]
    quadratic = 0.5 * (conic[:, 0] * dx ** 2 + conic[:, 2] * dy ** 2)
    cross = conic[:, 1] * dx * dy
    sigmas = quadratic + cross  # [M]

    # Gaussian-pixel alpha: opacity attenuated by the 2D Gaussian falloff.
    alphas = opacities[image_ids, gs_ids] * torch.exp(-sigmas)
    exceeds_one = (alphas > 1).any()
    if exceeds_one:
        warnings.warn(f"Not all alphas <= 1, max alpha: {alphas.max().item()}")

    # Flat ray index of every sample, shape (M,).
    indices = image_ids * image_height * image_width + pixel_ids
    weights, _ = render_weight_from_alpha(
        alphas, ray_indices=indices, n_rays=total_pixels
    )  # (M,)
    return gs_ids, image_ids, indices, pixel_ids, weights
@dataclass
class Bool3DGSCfg(Base3DGSAttributeCfg[bool]):
    """Boolean on/off switches for each 3DGS parameter group.

    The public attributes (``means``, ``scales``, ``quats``, ``opacities``,
    ``sh0``, ``shN``) come from the base class, which multiplies each private
    flag by ``_base``; a group therefore counts as enabled only when both its
    own flag and ``_base`` are true.
    """
    # Config loading via dacite doesn't seem to support generic type, so need to write types explicitly
    _base: bool
    _means: bool
    _scales: bool
    _opacities: bool
    _quats: bool
    _sh0: bool
    _shN: bool

    def all_true(self):
        """Return True when every parameter group in ``param_names`` is enabled."""
        return all(getattr(self, attr) for attr in self.param_names)

    def __str__(self):
        """Return "all" if every group is enabled, else the enabled names joined by "_"."""
        # all_true is a method and must be called: the bare bound method is
        # always truthy, which made this branch unconditional.
        if self.all_true():
            return "all"
        return "_".join(f"{attr}" for attr in self.param_names if getattr(self, attr))
class Embedder:
    """Frequency (sin/cos style) embedding built from keyword settings.

    Expected kwargs: 'input_dims', 'include_input', 'max_freq_log2',
    'num_freqs', 'log_sampling' and 'periodic_fns'.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build the list of embedding callables and record the output width."""
        settings = self.kwargs
        in_dims = settings['input_dims']

        # Optionally pass the raw input through as the first component.
        fns = [lambda x: x] if settings['include_input'] else []

        top_log2 = settings['max_freq_log2']
        n_bands = settings['num_freqs']
        if settings['log_sampling']:
            # Frequencies spaced evenly in log2 space.
            bands = 2. ** torch.linspace(0., top_log2, steps=n_bands)
        else:
            # Frequencies spaced evenly in linear space.
            bands = torch.linspace(2. ** 0., 2. ** top_log2, steps=n_bands)

        def _at_frequency(fn, band):
            # Bind fn and band now so each callable keeps its own pair.
            return lambda x: fn(x * band)

        for band in bands:
            for periodic in settings['periodic_fns']:
                fns.append(_at_frequency(periodic, band))

        self.embed_fns = fns
        # Every callable contributes `input_dims` output channels.
        self.out_dim = in_dims * len(fns)

    def embed(self, inputs):
        """Apply every embedding callable and concatenate along the last dim."""
        pieces = []
        for fn in self.embed_fns:
            pieces.append(fn(inputs))
        return torch.cat(pieces, -1)
+ + rel_step = torch.tensor([self.t / self.T], device=state.device) + + time_encoding = self.time_encoder_fn(rel_step) # [embedding_dim] + time_encoding = time_encoding.unsqueeze(0).repeat(state.shape[0], 1) # [N, embedding_dim] + + # Concatenate encoding to state + state = torch.cat([state, time_encoding], dim=-1) # [N, c+embedding_dim] + + return state # returns the modified state + + def __exit__(self, exc_type, exc_val, exc_tb): + # Do nothing, the original state is preserved outside the context manager + # Return False to propagate exceptions, if any + return False + + +if __name__ == "__main__": + # Example usage + embed_fn, output_dim = get_embedder(multires=6) + print(f"Output embedding dimension: {output_dim}") + steps = torch.randn(10, 1) # Example input (steps normalized between 0 and 1) + print(f"Input shape: {steps.shape}") + print("steps[0:2]:", steps[0:2]) + embedded_x = embed_fn(steps) + print(f"Embedded shape: {embedded_x.shape}") + print("embedded_x[0:2]:", embedded_x[0:2]) diff --git a/optgs/scene_trainer/postprocess/__init__.py b/optgs/scene_trainer/postprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/scene_trainer/postprocessing.py b/optgs/scene_trainer/postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..99c78755a13db00bf2c6ac84265767fb18bdc6fd --- /dev/null +++ b/optgs/scene_trainer/postprocessing.py @@ -0,0 +1,632 @@ +from dataclasses import dataclass, field +from typing import List +import tqdm as tqdm +import numpy as np +import torch +from torch import Tensor +import math +from pytorch_optimizer import load_optimizer +from torch.optim.lr_scheduler import LambdaLR +import torch.nn.functional as F +from einops import rearrange +from optgs.evaluation.metrics import compute_rgb_metrics +from optgs.misc.io import FrequencyScheduler +from optgs.scene_trainer.gaussian_module import GaussiansModule, 
@dataclass
class PostProcessADCCfg:
    """ADC (Adaptive Density Control) config for postprocessing.
    Defaults match vanilla 3DGS (config/scene_trainer/scene_optimizer/refiner/default.yaml).

    The fields are consumed by PostProcessing3DGS._apply_adc, which decides per
    step whether to clone/split, prune, or reset opacities.
    """
    # Master switches for the three ADC operations (clone+split, prune, opacity reset).
    do_densify: bool = True
    do_prune: bool = True
    do_opacity_reset: bool = True

    # Scheduling
    # Refinement is skipped while (step % reset_every) < pause_refine_after_reset.
    pause_refine_after_reset: int = 0
    # Refinement runs on steps where step % refine_every == 0.
    refine_every: int = 100
    # Opacities are reset on steps where step % reset_every == 0 (and step > 0);
    # scale-based pruning is only applied once step > reset_every.
    reset_every: int = 3000
    # Refinement requires refine_start_iter <= step < refine_stop_iter.
    refine_start_iter: int = 500
    refine_stop_iter: int = 15000
    # The 2D-radius criteria (grow_scale2d / prune_scale2d) are only used while
    # step < refine_scale2d_stop_iter; the default 0 leaves them unused.
    refine_scale2d_stop_iter: int = 0

    # Densification thresholds
    # A Gaussian is a densification candidate when its averaged 2D gradient norm exceeds this.
    grow_grad2d: float = 0.0002
    # Candidates with max scale <= grow_scale3d * scene_extent are cloned, the others split.
    grow_scale3d: float = 0.01  # aka percent_dense
    # Gaussians whose 2D radius exceeds this are also split (see refine_scale2d_stop_iter).
    grow_scale2d: float = 0.05

    # Pruning thresholds
    # Prune when max scale > prune_scale3d * scene_extent.
    prune_scale3d: float = 0.1
    # Prune when the 2D radius exceeds this (see refine_scale2d_stop_iter).
    prune_scale2d: float = 0.15
    # Prune when opacity < min_opacity; an opacity reset caps opacities at 2 * min_opacity.
    min_opacity: float = 0.005

    # Forwarded unchanged to the splitting() routine.
    revised_opacity: bool = False
def _module_to_deactivated_gaussians(gm: GaussiansModule) -> Gaussians:
    """Snapshot a GaussiansModule as a Gaussians object holding raw values for ADC.

    Every tensor is detached and given a leading dimension of size 1. Scales
    and opacities are taken from the module's ``*_raw`` attributes and the
    result is flagged ``stores_activated=False``.
    """

    def _batched(tensor):
        # Drop the autograd link and prepend a leading dimension.
        return tensor.detach().unsqueeze(0)

    means = _batched(gm.means)
    scales = _batched(gm.scales_raw)  # log space
    opacities = _batched(gm.opacities_raw)  # logit space
    rotations = _batched(gm.rotations)
    rotations_unnorm = _batched(gm.rotations_unnorm)
    harmonics = _batched(gm.harmonics)

    return Gaussians(
        means=means,
        scales=scales,
        opacities=opacities,
        rotations=rotations,
        rotations_unnorm=rotations_unnorm,
        harmonics=harmonics,
        stores_activated=False,
    )
torch.cuda.Event(enable_timing=True) + self.iter_end = torch.cuda.Event(enable_timing=True) + + self.reset_logs() + + def reset_logs(self): + self.radii_max_log = [] + self.grads_max_log = [] + self.nr_cloned_log = [] + self.nr_splitted_log = [] + self.nr_pruned_log = [] + self.nr_gaussians_log = [] + self.nr_nonzero_grad_log = [] + self.iter_time_log = [] + + def _calc_loss( + self, context, output_renderer: DecoderOutput + ) -> Tensor: + # compute scalar loss + # assume batch size 1 + assert context["image"].shape[0] == 1 + assert context["image"].shape == output_renderer.color.shape + l1_render_error = (output_renderer.color - context["image"]).abs().mean() + + ssim_score = fused_ssim( + rearrange(output_renderer.color, "b v c h w -> (b v) c h w"), + rearrange(context["image"], "b v c h w -> (b v) c h w"), + padding="valid" + ) + loss = 0.8 * l1_render_error + 0.2 * (1 - ssim_score) + + return loss + + def _chunked_forward_backward(self, gaussian_module, iter_context, decoder, render_res, adc_state): + """Render views in chunks, accumulate gradients, and collect ADC metadata. + + Matches the gradient accumulation approach of calc_input_gradients in the vanilla optimizer: + each chunk computes a mean loss, gradients accumulate, then are averaged by nr_chunks. 
+ """ + v = iter_context["image"].shape[1] + chunk_size = self.cfg.chunk_size if self.cfg.chunk_size > 0 else v + nr_chunks = math.ceil(v / chunk_size) + + # Accumulate means2d grads and radii for ADC across chunks + need_adc = adc_state is not None + h, w = render_res + if need_adc: + N = gaussian_module.means.shape[0] + means2d_grads_all = torch.zeros((1, v, N, 2), device=gaussian_module.means.device) + radii_all = torch.zeros((1, v, N, 2), device=gaussian_module.means.device) + visibility_all = torch.zeros((1, v, N), dtype=torch.bool, device=gaussian_module.means.device) + + for chunk_start in range(0, v, chunk_size): + chunk_end = min(chunk_start + chunk_size, v) + + # Slice views for this chunk + chunk_context = { + "image": iter_context["image"][:, chunk_start:chunk_end], + "extrinsics": iter_context["extrinsics"][:, chunk_start:chunk_end], + "intrinsics": iter_context["intrinsics"][:, chunk_start:chunk_end], + "near": iter_context["near"][:, chunk_start:chunk_end], + "far": iter_context["far"][:, chunk_start:chunk_end], + } + + # Render + chunk_output = decoder.forward_batch_subset(gaussian_module, chunk_context, render_res) + + # Retain means2d grad for ADC + if need_adc and chunk_output.means2d is not None: + chunk_output.means2d.retain_grad() + + # Loss and backward (gradients accumulate across chunks) + chunk_loss = self._calc_loss(chunk_context, chunk_output) + chunk_loss.backward() + + # Collect ADC metadata from this chunk + if need_adc: + if chunk_output.radii is not None: + radii_all[:, chunk_start:chunk_end] = chunk_output.radii.detach() + if chunk_output.visibility_filter is not None: + visibility_all[:, chunk_start:chunk_end] = chunk_output.visibility_filter.detach() + if chunk_output.means2d is not None and chunk_output.means2d.grad is not None: + means2d_grads_all[:, chunk_start:chunk_end] = chunk_output.means2d.grad.detach() + + # Average gradients across chunks (matches vanilla behavior) + if nr_chunks > 1: + for param in 
gaussian_module.parameters(): + if param.grad is not None: + param.grad /= nr_chunks + + # Return ADC metadata + if need_adc: + return { + "radii": radii_all, + "visibility_filter": visibility_all, + "means_2d_grads": means2d_grads_all, + } + return None + + def _apply_adc(self, step, gaussian_module, adc_state, device): + """Apply ADC (clone/split/prune/opacity reset) using the same logic as vanilla 3DGS. + + Returns (gaussian_module, optimizer_needs_rebuild). + """ + from optgs.scene_trainer.adc.vanilla import cloning, splitting, prune, reset_adc_state + + adc_cfg = self.cfg.adc + changed = False + nr_cloned, nr_splitted, nr_pruned = 0, 0, 0 + + # Convert to deactivated Gaussians for ADC (ADC functions expect Gaussians, not GaussiansModule) + gaussians = _module_to_deactivated_gaussians(gaussian_module) + + if step < adc_cfg.refine_stop_iter: + grads = adc_state.grad2d_norm_accum / adc_state.denom.clamp_min(1.0) + scene_extent = adc_state.scene_extent + + if ( + step >= adc_cfg.refine_start_iter + and step % adc_cfg.refine_every == 0 + and step % adc_cfg.reset_every >= adc_cfg.pause_refine_after_reset + ): + if adc_cfg.do_densify: + scales = torch.exp(gaussians.scales.squeeze(0)) # activate + is_grad_high = grads > adc_cfg.grow_grad2d + is_small = scales.max(dim=-1).values <= adc_cfg.grow_scale3d * scene_extent + + clone_mask = is_grad_high & is_small + split_mask = is_grad_high & ~is_small + + if step < adc_cfg.refine_scale2d_stop_iter: + split_mask |= adc_state.radii2d > adc_cfg.grow_scale2d + + # Clone + cloning(gaussians, adc_state, clone_mask) + nr_cloned = int(clone_mask.sum().item()) + + # Extend split_mask for newly cloned points (they should not be split) + split_mask = torch.cat([ + split_mask, + torch.zeros(nr_cloned, dtype=torch.bool, device=split_mask.device), + ]) + + # Split + splitting(gaussians, adc_state, split_mask, N=2, + revised_opacity=adc_cfg.revised_opacity) + nr_splitted = int(split_mask.sum().item()) + + changed = True + + if 
adc_cfg.do_prune: + opacities = torch.sigmoid(gaussians.opacities.squeeze(0)) # activate + scales = torch.exp(gaussians.scales.squeeze(0)) # activate + + prune_mask = opacities < adc_cfg.min_opacity + if step > adc_cfg.reset_every: + is_too_big = scales.max(dim=-1).values > adc_cfg.prune_scale3d * scene_extent + if step < adc_cfg.refine_scale2d_stop_iter: + is_too_big |= adc_state.radii2d > adc_cfg.prune_scale2d + prune_mask = prune_mask | is_too_big + + prune(gaussians, adc_state, prune_mask) + nr_pruned = int(prune_mask.sum().item()) + changed = True + + reset_adc_state(adc_state) + print( + f"ADC @ iter {step}: cloned {nr_cloned}, split {nr_splitted}, " + f"pruned {nr_pruned}, total {gaussians.means.shape[1]}" + ) + + # Opacity reset + if adc_cfg.do_opacity_reset: + if step % adc_cfg.reset_every == 0 and step > 0: + opacities = torch.sigmoid(gaussians.opacities) # activate + value = adc_cfg.min_opacity * 2.0 + new_opacities = torch.min(opacities, torch.ones_like(opacities) * value) + gaussians.opacities = torch.logit(new_opacities) # deactivate back + changed = True + print(f"Opacity reset @ iter {step}") + + self.nr_cloned_log.append(nr_cloned) + self.nr_splitted_log.append(nr_splitted) + self.nr_pruned_log.append(nr_pruned) + + if changed: + # Rebuild GaussiansModule from modified Gaussians + gaussian_module = _deactivated_gaussians_to_module(gaussians, device) + + return gaussian_module, changed + + @torch.no_grad() + def apply( + self, + batch, + gaussians: Gaussians, + decoder, + metrics=["psnr", "ssim"], + iter_batch_size: int = -1, + batchify_fn=None, + visualization_dump=None + ) -> OptimizerOutput | None: + + target_render_list = DetachingCPUList() + context_render_list = DetachingCPUList() + + if self.cfg.steps == 0: + return None + + # [Improvement 1] Calculate scene_scale from both context + target (matches vanilla optimizer) + camtoworlds_context = batch['context']['extrinsics'][0].cpu().numpy() # [Vc, 4, 4] + camtoworlds_target = 
batch['target']['extrinsics'][0].cpu().numpy() # [Vt, 4, 4] + camtoworlds = np.concatenate([camtoworlds_context, camtoworlds_target], axis=0) + scene_scale = get_scene_scale(camtoworlds) + print("scene_scale:", scene_scale) + + device = batch['context']['image'].device + + # convert Gaussians to GaussiansModule + gaussian_module = gaussians2module(gaussians, device=device) + + optimizer = self.get_optimizer(gaussian_module, scene_scale) + scheduler = self.get_scheduler(optimizer, scene_scale=scene_scale, prior_steps=self.cfg.prior_steps) + + # print all optimizer param groups + for i, param_group in enumerate(optimizer.param_groups): + print(f"Param group {i}: lr={param_group['lr']}, weight_decay={param_group.get('weight_decay', 0.0)}, requires_grad={param_group['params'][0].requires_grad}") + + assert batch["context"]["extrinsics"].shape[0] == batch["context"]["extrinsics"].shape[0] == 1, \ + "Batch size > 1 not supported for post-processing" + + nr_context_views, _, h, w = batch["context"]["image"][0].shape + + # controlling number of context views seen at each iteration (for rendering chunk size) + _iter_batch_size = iter_batch_size if iter_batch_size > 0 else nr_context_views + print("using iter_batch_size =", _iter_batch_size) + + render_res = (h, w) + + # [Improvement 3] Initialize ADC state if configured + adc_state = None + if self.cfg.adc is not None: + from optgs.scene_trainer.adc.vanilla import VanillaStrategyState + nr_points = gaussian_module.means.shape[0] + adc_state = VanillaStrategyState.initialize( + nr_points=nr_points, + device=device, + scene_extent=scene_scale, + ) + print(f"Initialized ADC state with {nr_points} points") + + # render before first step + context_render_output = decoder.forward_batch_subset(gaussian_module, batch["context"], render_res, iter_batch_size=_iter_batch_size) + context_render_list.append(context_render_output, detach_and_cpu=True) # initial rendering + + target_render_output = 
decoder.forward_batch_subset(gaussian_module, batch["target"], render_res, iter_batch_size=_iter_batch_size) + target_render_list.append(target_render_output, detach_and_cpu=True) # initial rendering + + # Reset viewpoint stack for fresh sampling in postprocessing + batch["context"].viewpoint_stack = None + + pbar = tqdm.tqdm(range(self.cfg.steps), desc=f"PP {self.cfg.name}", ncols=120) + pbar_postfix = {} + for i in pbar: + + self.iter_start.record() + + with torch.enable_grad(): + + # Log number of gaussians + self.nr_gaussians_log.append(gaussian_module.means.shape[0]) + + # reset gradients + optimizer.zero_grad() + + # Sample context views using the same strategy as the optimizer + iter_context, _ = batchify_fn(batch, "context") + + # [Improvement 4] Render in chunks, accumulate gradients, collect ADC metadata + meta_for_adc = self._chunked_forward_backward( + gaussian_module, iter_context, decoder, render_res, adc_state + ) + + # step + optimizer.step() + + # update scheduler + if scheduler is not None: + scheduler.step() + + # [Improvement 3] ADC: update state and apply densification/pruning + if adc_state is not None and meta_for_adc is not None: + from optgs.scene_trainer.adc.vanilla import update_vanilla_strategy_state + + v_chunk = iter_context["image"].shape[1] + update_vanilla_strategy_state( + adc_state, + radii_2d=meta_for_adc["radii"], + means2d_grads=meta_for_adc["means_2d_grads"], + visibility_mask=meta_for_adc["visibility_filter"], + v=v_chunk, + w=w, + h=h, + ) + + gaussian_module, adc_changed = self._apply_adc(i, gaussian_module, adc_state, device) + if adc_changed: + # Rebuild optimizer and scheduler after ADC changed Gaussian count + optimizer = self.get_optimizer(gaussian_module, scene_scale) + scheduler = self.get_scheduler( + optimizer, scene_scale=scene_scale, prior_steps=self.cfg.prior_steps + ) + # Fast-forward scheduler to current step + for _ in range(i + 1): + scheduler.step() if scheduler is not None else None + + # Timing + 
def debug_grads(self, gaussians: GaussiansModule, debug_dict, step):
    """Record the current per-Gaussian gradients into ``debug_dict["grads"]``.

    ``debug_dict["grads"]`` is a list of per-scene lists. A new inner list is
    opened when the entry is still ``None`` or when ``step == 0``; the tensor
    built here is appended to the most recent inner list.

    Args:
        gaussians: module whose parameters (and their ``.grad``) are read.
        debug_dict: mutable dict holding the ``"grads"`` history.
        step: current post-processing iteration index.
    """
    history = debug_dict["grads"]
    if history is None:
        # Very first call: create the outer list with one scene bucket.
        history = debug_dict["grads"] = [[]]
    elif step == 0:
        # Step 0 of a later scene: open a fresh bucket.
        history.append([])

    num_gaussians = gaussians.means.shape[0]

    # Flatten every available gradient to [num_gaussians, D_i] on the CPU.
    flat_grads = []
    for _, param in gaussians.named_parameters():
        if param.grad is None:
            continue
        flat_grads.append(param.grad.view(num_gaussians, -1).detach().cpu())

    # [num_gaussians, total_param_dim]
    history[-1].append(torch.cat(flat_grads, dim=-1))
# Optimizer name -> config attributes that are forwarded to its constructor.
_OPT_PARAMS = {
    "sgd": ("momentum", "weight_decay", "nesterov"),
    "adam": ("betas", "eps", "weight_decay", "amsgrad"),
}

def extract_opt_params(self):
    """Collect the optimizer keyword arguments that are set on ``self.cfg``.

    Only attributes listed in ``_OPT_PARAMS`` for ``self.cfg.name`` are
    considered; attributes that are missing or ``None`` are left out. An
    optimizer name without an entry yields an empty dict.
    """
    opt_params = {}
    for key in self._OPT_PARAMS.get(self.cfg.name, ()):
        value = getattr(self.cfg, key, None)
        if value is not None:
            opt_params[key] = value
    return opt_params
continues from where scene trainer left off + for _ in range(prior_steps): + scheduler.step() + return scheduler + + else: + raise ValueError(f"Unknown scheduler: {self.cfg.scheduler}") diff --git a/optgs/scene_trainer/scene_trainer.py b/optgs/scene_trainer/scene_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..83821a01b5a8104ad274d376e6519c5fc4cf2703 --- /dev/null +++ b/optgs/scene_trainer/scene_trainer.py @@ -0,0 +1,641 @@ +import math +import random +from pathlib import Path +from typing import Optional, Mapping, Any + +import torch +from einops import rearrange +from lightning_fabric.utilities import move_data_to_device +from torch import Tensor, nn +from tqdm import tqdm + +from optgs.dataset import DatasetCfg +from optgs.dataset.data_types import BatchedExample +from optgs.dataset.view_sampler.view_sampler_bounded_v2 import farthest_point_sample +from optgs.loss.loss_monodepth import get_monodepth_model +from optgs.misc.benchmarker import Benchmarker +from optgs.misc.io import FrequencyScheduler +from optgs.misc.step_tracker import StepTracker +from optgs.model.decoder.decoder import Decoder +from optgs.model.types import Gaussians +from optgs.paths import DEBUG +from optgs.scene_trainer.initializer import get_scene_initializer +from optgs.scene_trainer.initializer.initializer import InitializerOutput +from optgs.scene_trainer.optimizer import get_scene_optimizer +from optgs.scene_trainer.optimizer.optimizer import OptimizerInput, Optimizer, OptimizerOutput, OptimizerPreviousOutput +from optgs.scene_trainer.postprocessing import PostProcessing3DGS +from optgs.scene_trainer.scene_trainer_cfg import SceneTrainerCfg, TestCfg, TrainCfg +from optgs.scripts.dev.debugging_optimizer import debugging_convergence, debugging_invisible_gaussians + + +class SceneTrainer(nn.Module): + test_cfg: TestCfg + train_cfg: TrainCfg + scene_trainer_cfg: SceneTrainerCfg + decoder: Decoder + step_tracker: StepTracker | None + eval_data_cfg: 
def __init__(
    self,
    test_cfg: TestCfg,
    train_cfg: TrainCfg,
    scene_trainer_cfg: SceneTrainerCfg,
    decoder: Decoder,
    step_tracker: StepTracker | None,
    eval_data_cfg: Optional[DatasetCfg | None] = None,
) -> None:
    """Store the configs and build initializer, optimizer, decoder and post-processing.

    The optimizer is only created when ``scene_trainer_cfg.num_update_steps > 0``
    and post-processing only when ``test_cfg.postprocessing`` is set and active;
    otherwise the corresponding attribute is ``None``.
    """
    super().__init__()
    self.test_cfg = test_cfg
    self.train_cfg = train_cfg
    self.step_tracker = step_tracker
    self.eval_data_cfg = eval_data_cfg
    self.scene_trainer_cfg = scene_trainer_cfg

    def build_save_schedule(last_step):
        # Both save schedules share the test-config settings and differ only
        # in the index of their final step.
        return FrequencyScheduler(
            frequencies=self.test_cfg.save_every_freq,
            steps=self.test_cfg.save_every_steps,
            iters=self.test_cfg.save_at_iters,
            last_step=last_step,
            enable_context=self.test_cfg.eval_context_views,
        )

    # Set up the model
    self.initializer = get_scene_initializer(scene_trainer_cfg.scene_initializer)

    # Scene trainer performs updates only when update steps are requested
    if self.scene_trainer_cfg.num_update_steps > 0:
        optimizer_save_every = build_save_schedule(self.scene_trainer_cfg.num_update_steps)
        self.optimizer: Optimizer | None = get_scene_optimizer(scene_trainer_cfg.scene_optimizer)
        self.optimizer.save_every = optimizer_save_every
    else:
        self.optimizer = None

    self.decoder = decoder
    self.benchmarker = Benchmarker()

    # The monodepth model is only loaded when its loss term is enabled.
    if self.train_cfg.monodepth_loss_weight > 0:
        self.pretrained_monodepth = get_monodepth_model()

    postprocessing_cfg = self.test_cfg.postprocessing
    if postprocessing_cfg is not None and postprocessing_cfg.is_active:
        self.postprocess_save_every = build_save_schedule(postprocessing_cfg.steps)
        self.postprocess = PostProcessing3DGS(
            cfg=postprocessing_cfg,
            save_every=self.postprocess_save_every,
        )
    else:
        self.postprocess = None
def load_state_dict(
    self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
):
    """Load weights into initializer and optimizer, skipping non-learned strategies.

    Args:
        state_dict: checkpoint mapping. Keys may carry a leading
            ``"scene_trainer."`` prefix; after it is removed every key must
            start with ``"initializer."`` or ``"optimizer."``.
        strict: forwarded to the sub-modules' ``load_state_dict``.
        assign: forwarded to the sub-modules only when ``True``, so the
            default call path stays exactly as before.

    Raises:
        AssertionError: if a key has a top-level prefix other than
            ``initializer`` / ``optimizer``.
    """
    # Strip the wrapper prefix only where it actually leads the key; a plain
    # str.replace would also eat "scene_trainer." in the middle of a
    # parameter path and silently rename that weight.
    root_prefix = "scene_trainer."
    state_dict = {
        (key[len(root_prefix):] if key.startswith(root_prefix) else key): value
        for key, value in state_dict.items()
    }

    prefixes = {key.split(".")[0] for key in state_dict}
    assert all(p in ("initializer", "optimizer") for p in prefixes), \
        f"State dict keys must start with 'initializer.' or 'optimizer.', got {prefixes}"

    def sub_state_dict(prefix):
        # Entries under `prefix`, with the prefix removed from the keys.
        return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

    # Only pass `assign` along when the caller asked for it.
    extra_kwargs = {"assign": True} if assign else {}

    if self.initializer.strategy == "learned":
        self.initializer.load_state_dict(
            sub_state_dict("initializer."), strict=strict, **extra_kwargs
        )

    if self.optimizer is not None and self.optimizer.strategy == "learned":
        self.optimizer.load_state_dict(
            sub_state_dict("optimizer."), strict=strict, **extra_kwargs
        )
+ If we optimize a new scene it will be of type InitializerOutput, the output from the initializer + containing initial Gaussians and optional features. + In this case, on_scene_start of the optimizer should transform the InitializerOutput to + OptimizerPreviousOutput. + If we resample from the replay buffer it will be of type OptimizerPreviousOutput, the output of each + intermidiate update step of the optimizer, which contain gaussians and optional state. + curr_iter: int, the current iteration. + Should be 0 when starting a new scene. + debug_dict: Optional[dict], a dictionary to store debug information. + num_update_steps: Optional[int], number of update steps to perform. If None, use default from config. + Returns: + OptimizerOutput: The output from the optimizer containing the optimized Gaussians and renderings for + intermediate optimization steps. + """ + + assert self.optimizer is not None, "Optimizer is not initialized." + + if num_update_steps is None: + num_update_steps = self.get_num_update_steps() + + optimizer_input = OptimizerInput( + context=batch["context"], # this is full context, not iter batch + target=batch["target"], + prev_output=prev_output, + renderer=self.decoder, + num_refine=num_update_steps, + iter_batch_size=self.scene_trainer_cfg.iter_batch_size, # For rendering in batches + debug_dict=debug_dict, + ) + + # Handles both new scenes (InitializerOutput) and replay buffer continuations (OptimizerPreviousOutput). 
+ self.optimizer.validate_input(optimizer_input) + self.optimizer.scene_start_event_start.record() + self.optimizer.on_scene_start(optimizer_input) + self.optimizer.scene_start_event_end.record() + assert isinstance(optimizer_input.prev_output, OptimizerPreviousOutput), \ + f"Should be OptimizerPreviousOutput after on_scene_start, got {type(optimizer_input.prev_output)}" + + # Initialize empty output to store intermediate and final results + optimizer_output: OptimizerOutput = OptimizerOutput.empty(t=curr_iter) + optimizer_output.T = num_update_steps + + # Insert the initialization into position 0 of the output lists so downstream + # consumers (evaluation, plotting, replay buffer) can treat init uniformly with + # the optimizer steps. No-op when there's no init render to attach (train path + # before init renders are wired through; replay buffer continuations). + if isinstance(prev_output, InitializerOutput): + self._insert_init_into_output(optimizer_output, prev_output) + + # SH degree scheduling (inspired by gsplat simple_trainer): + # sh_degree_to_use = min(step // sh_degree_interval, max_sh_degree) + sh_degree_interval = self.scene_trainer_cfg.sh_degree_interval + if sh_degree_interval > 0: + max_sh_degree = int(math.sqrt(optimizer_input.prev_output.gaussians.harmonics.shape[-1])) - 1 + + # Loop over update steps + for step in tqdm(range(num_update_steps), + disable=(self.training or num_update_steps < 20 or DEBUG) or disable_tqdm, + total=num_update_steps): + + # Sample minibatch of context/target views and move to device + optimizer_input.context, batch_idx = self.batchify_views(batch, "context", self.device) + if batch_idx is not None: + optimizer_output.context_index_list.append(batch_idx) + optimizer_input.target, batch_idx = self.batchify_views(batch, "target", self.device) + if batch_idx is not None: + optimizer_output.target_index_list.append(batch_idx) + + # Build per-step kwargs, adding SH degree if scheduler is active + step_kwargs = dict(kwargs) + 
if sh_degree_interval > 0: + step_kwargs["sh_degree"] = min(step // sh_degree_interval, max_sh_degree) + + # Single optimization step + # Optimizer output is updated in place, but we return it for clarity + optimizer_output = self.optimizer( + step, + optimizer_input, + optimizer_output, + full_context=batch["context"], + full_target=batch["target"], + **step_kwargs + ) + optimizer_output.t += 1 + + # Sync GPU before reading scene_start elapsed time (events were recorded before the loop). + torch.cuda.synchronize() + self.optimizer.scene_start_ms = self.optimizer.scene_start_event_start.elapsed_time( + self.optimizer.scene_start_event_end + ) + + self.optimizer.on_scene_end() + + # Extract the last output (for replay buffer) + optimizer_output.last_prev_output = optimizer_input.prev_output + + return optimizer_output + + def batchify_views(self, scene_batch, input_str, device, batch_size=None): + """ + Sample a subset of views from the batch for the current optimization step. + + Args: + scene_batch: Full batch containing context/target views + input_str: "context" or "target" + device: Target device for the subset + batch_size: Override batch size. If None, uses config-based batch size. 
+ + Returns: + Tuple of (subset_batch, indices) where indices is None if no subsampling + """ + scene_batch_split = scene_batch[input_str] + + # Determine batch size (may be randomized during training) + if batch_size is None: + batch_size = self._get_batch_size() + v_all = scene_batch_split["image"].shape[1] + if batch_size <= 0 or batch_size >= v_all: + return scene_batch_split, None + + strategy = self.scene_trainer_cfg.opt_batch_strategy + views_idxs = self._sample_indices(scene_batch_split, batch_size, strategy) # [scene_batch, views_batch] + views_batch = scene_batch_split.batchify_views(views_idxs) + views_batch = move_data_to_device(views_batch, device) + return views_batch, views_idxs + + def _get_batch_size(self) -> int: + """Determine the batch size, potentially randomized during training.""" + batch_size = self.scene_trainer_cfg.opt_batch_size + + # Randomize batch size if configured (training or promoting buffer) + if self.scene_trainer_cfg.opt_batch_size_max > 0: + if self.training or self.promoting_buffer_sample: + batch_size = random.randint( + self.scene_trainer_cfg.opt_batch_size_min, + self.scene_trainer_cfg.opt_batch_size_max + ) + return batch_size + + def _sample_indices(self, batch_split, views_batch_size: int, strategy: str) -> torch.Tensor: + """Sample a minibatch of view indices using the configured strategy. 
+ Uses viewpoint_stack to cycle through all views before reshuffling.""" + + # Initialize or reset viewpoint stack for new epoch + batch_split.reset_viewpoint_stack_if_needed(strategy, views_batch_size) + viewpoint_stack = batch_split.viewpoint_stack # [B, V] + scene_batch, v = viewpoint_stack.shape + + views_batch_size = min(views_batch_size, v) + + if strategy in ["random", "sequential"]: + # Take views from the front of the stack (shuffled if random) + batch_idxs = viewpoint_stack[:, :views_batch_size] + idx_to_remove = batch_idxs + + elif strategy == "neighbors": + # Use first view in stack as center, select its neighbors + extrinsics = batch_split["extrinsics"] + if extrinsics.ndim == 4: # [B, V, 4, 4] + assert extrinsics.shape[0] == 1, "Batch size must be 1 for neighbor sampling" + extrinsics = extrinsics[0] + + center_idx = viewpoint_stack[0, 0] + batch_idxs = self._get_neighbor_indices(extrinsics, center_idx, views_batch_size) + idx_to_remove = torch.tensor([[center_idx]]) # Only remove center from stack + elif strategy == "fps": + # FPS on camera positions of the remaining views in the stack + extrinsics = batch_split["extrinsics"] # [B, V_total, 4, 4] + B = extrinsics.shape[0] + batch_arange = torch.arange(B, device=self.device)[:, None] + stack_positions = extrinsics[batch_arange, viewpoint_stack][:, :, :3, 3] # [B, V_stack, 3] + fps_local_idxs = farthest_point_sample(stack_positions, views_batch_size, + first_idx_strategy="random") # [B, K] + batch_idxs = viewpoint_stack[batch_arange, fps_local_idxs] # [B, K] + idx_to_remove = batch_idxs + else: + raise ValueError(f"Unknown opt_batch_strategy: {strategy}") + + # Remove used indices from the stack, preserving order between the views separately for each batch. 
@staticmethod
def calc_extrinsics_dist(center_idx, extrinsics):
    """Combined position + rotation distance from a center view to all views.

    Args:
        center_idx: index of the reference view (int or 0-dim tensor).
        extrinsics: ``[V, 4, 4]`` camera matrices.

    Returns:
        ``[V]`` tensor: Euclidean distance between camera centers plus the
        relative rotation angle in radians, both measured against
        ``center_idx``.
    """
    rotations = extrinsics[:, :3, :3]  # [V, 3, 3]
    translation = extrinsics[:, :3, [3]]  # [V, 3, 1]

    # Camera center computed as -R^T t.
    # NOTE(review): this assumes world-to-camera matrices — confirm against
    # the dataset convention.
    centers = (-rotations.transpose(1, 2) @ translation).squeeze(-1)  # [V, 3]

    # One distance per view. Indexing the [V, 1] norm with [0] would keep only
    # view 0's distance and broadcast it as a constant over all views.
    dists = torch.norm(centers - centers[center_idx], dim=-1)  # [V]

    # Relative rotation to the center view; its angle follows from the trace.
    center_rot = rotations[center_idx]  # [3, 3]
    rot_diffs = torch.matmul(rotations, center_rot.transpose(0, 1))  # [V, 3, 3]
    trace = rot_diffs[:, 0, 0] + rot_diffs[:, 1, 1] + rot_diffs[:, 2, 2]  # [V]
    cos_angles = torch.clamp((trace - 1) / 2, -1.0, 1.0)  # keep acos in range
    angles = torch.acos(cos_angles)  # [V]

    return dists + angles  # [V]
def get_init_gaussians(self, batch, is_training: bool, **kwargs) -> InitializerOutput:
    """Produce the initial Gaussians for a scene, optionally via a sliding window.

    The window size comes from ``train_cfg.train_window_size`` when
    ``is_training`` is true and from ``test_cfg.inference_window_size``
    otherwise. Autograd is enabled exactly when ``is_training`` is true, so no
    gradients are recorded for the initializer in non-training runs.
    """
    if is_training:
        window_size = self.train_cfg.train_window_size
    else:
        window_size = self.test_cfg.inference_window_size

    with torch.set_grad_enabled(is_training):
        if window_size is None:
            # Single pass over all context views. The target views are passed
            # too because some initializers need them (e.g. to manipulate the
            # poses in the colmap dataset).
            return self.initializer(
                batch["context"],
                scene=batch["scene"],
                target=batch["target"],
                device=self.device,
                **kwargs,
            )
        return self.init_gaussians_with_window(batch, window_size, **kwargs)
+ curr_features = initializer_output.features # [BV, C, H, W] + + all_gaussians.append(curr_gaussians) + all_states.append(curr_features) + if initializer_output.depths is not None: + all_pred_depths.append(initializer_output.depths) + + # merge all gaussians + def combine_gaussians_attribute(attr_name): + all_attr = [getattr(g, attr_name) for g in all_gaussians[:-1]] + last_g = all_gaussians[-1] + last_g_attr = getattr(last_g, attr_name) + # handle the overlapping in the last window + if v % window != 0: + x = v % window + b, vhw, *d = last_g_attr.shape + if self.initializer.cfg.per_pixel: + # per-pixel initialization + h_gaussians = h // self.initializer.cfg.latent_downsample + w_gaussians = w // self.initializer.cfg.latent_downsample + else: + raise NotImplementedError + last_g_attr = last_g_attr.view(b, window, h_gaussians, w_gaussians, *d) # [B, V, H, W, ...] + last_g_attr = last_g_attr[:, -x:, ...] # [B, x, H, W, ...] + last_g_attr = last_g_attr.view(b, -1, *d) # [B, x*H*W, ...] + all_attr.append(last_g_attr) + return torch.cat(all_attr, dim=1) + + gaussians = Gaussians( + means=combine_gaussians_attribute('means'), + covariances=combine_gaussians_attribute('covariances'), + harmonics=combine_gaussians_attribute('harmonics'), + opacities=combine_gaussians_attribute('opacities'), + scales=combine_gaussians_attribute('scales'), + rotations=combine_gaussians_attribute('rotations'), + rotations_unnorm=combine_gaussians_attribute('rotations_unnorm'), + ) + + # Collect condition features for the optimizer (only needed if optimizer is active) + if self.scene_trainer_cfg.num_update_steps > 0: + out = [] + is_ori_feature = True # set by first window; True = [BV,C,H,W], False = [BVHW,C] + for i in range(len(all_states)): + # Assuming no overlap between windows + curr = all_states[i] + if curr.dim() == 4: + # [BV, C, H, W] + curr = rearrange(curr, "(b v) c h w -> b v c h w", b=b) + is_ori_feature = True + elif curr.dim() == 2: + # [BVHW, C] + curr = rearrange(curr, "(b 
v h w) c -> b v h w c", b=b, + h=h // self.initializer.cfg.latent_downsample, + w=w // self.initializer.cfg.latent_downsample, + ) + is_ori_feature = False + else: + raise NotImplementedError + + # Only need to handle the overlaping in the last window + if i == len(all_states) - 1 and v % window != 0: + # last window with overlap + x = v % window + curr = curr[:, -x:, ...] + out.append(curr) + + # concat + if is_ori_feature: + concat = torch.cat(out, dim=1) # [B, V*K, C, H, W] + concat = rearrange(concat, "b v c h w -> (b v) c h w") + else: + concat = torch.cat(out, dim=1) # [B, V*K, H, W, C] + concat = rearrange(concat, "b v h w c -> (b v) c h w") + + condition_features = concat + else: + condition_features = None + + return InitializerOutput(gaussians=gaussians, + features=condition_features, + depths=all_pred_depths) + + def debugging(self, optimizer_output, output_path: Path, scene_name: str): # TODO (release): remove in public code + + # Debugging reprojection errors + # if 'reprojection_error' in visualization_dump: + # self.debugging_reprojection_error(visualization_dump) + + assert "deltas" in optimizer_output.info, "Deltas not found in optimizer output info." + assert "grads" in optimizer_output.info, "Grads not found in optimizer output info." + assert "normalized_grads" in optimizer_output.info, "Normalized grads not found in optimizer output info." + # assert "learning_rates" in optimizer_output.info, "Learning rates not found in optimizer output info." 
+ + # Unpack Optimizer output + deltas_list: list[dict[str, Tensor]] = optimizer_output.info["deltas"] + + grads_raw_list: list[dict[str, Tensor]] = optimizer_output.info["grads"] + normalized_grads_list: list[dict[str, Tensor]] = optimizer_output.info["normalized_grads"] + + # Get PSNR list + module_name = self.optimizer.__class__.__name__.lower() + psnr_list = self.test_step_outputs_target[f"{module_name}_psnr"][0] # list of psnr for target views per scene + + # Get iterations list + iterations_list = self.optimizer.save_every.get_iterations(len(psnr_list)) + + means2d_list = [render.means2d for render in optimizer_output.target_render_list] + radii_list = [render.radii for render in optimizer_output.target_render_list] + + debugging_invisible_gaussians( + optimizer_output.gaussian_list, + grads_raw_list, + normalized_grads_list, + means2d_list, + radii_list, + psnr_list, + iterations_list, + output_path / module_name, + scene_name + ) + + # Remove init. + psnr_list = psnr_list[1:] + iterations_list = iterations_list[1:] + + if "states_norms" in optimizer_output.info: + states_norms_list: list[Tensor] = optimizer_output.info["states_norms"] + debugging_convergence( + deltas_list, + states_norms_list, + grads_raw_list, + normalized_grads_list, + psnr_list, + iterations_list, + output_path / module_name, + scene_name + ) + + + def _insert_init_into_output( + self, + optimizer_output: OptimizerOutput, + init_output: InitializerOutput, + ) -> None: + """Prepend the init render/gaussians at position 0 of the optimizer_output lists. + + No-op if no init render exists yet (train path before init renders are wired + through). Otherwise inserts gaussians plus whichever of context/target renders + are populated. detach_and_cpu mirrors the per-step append policy. 
+ """ + if init_output.context_render is None and init_output.target_render is None: + return + + optimizer_output.gaussian_list.insert(0, init_output.gaussians) + if init_output.context_render is not None: + optimizer_output.context_render_list.insert( + 0, init_output.context_render, detach_and_cpu=not self.training + ) + if init_output.target_render is not None: + optimizer_output.target_render_list.insert( + 0, init_output.target_render, detach_and_cpu=not self.training + ) + + def init_gaussians_and_render( + self, batch, visualization_dump, + render_context: bool, render_target: bool, grad_enabled: bool, + **kwargs, + ) -> InitializerOutput: + """Run the initializer and optionally render its output to context/target views. + + Used in both training (grad_enabled=True, outputs stay on GPU so the init-loss term + can backward through them) and test (grad_enabled=False, outputs moved to CPU to save + memory since they're only consumed for evaluation/saving). + """ + + # run initializer + with self.benchmarker.time("initializer"): + init_output: InitializerOutput = self.get_init_gaussians(batch, is_training=grad_enabled, **kwargs) + + # to_cpu freezes the render off the GPU for evaluation/saving; with grads enabled we + # keep it on GPU so the init-loss term can backward through it. + to_cpu = not grad_enabled + + with torch.set_grad_enabled(grad_enabled): + for input_str, should_render in ( + ("context", render_context), + ("target", render_target), + ): + attr = f"{input_str}_render" + if not should_render or getattr(init_output, attr) is not None: + continue + views = batch[input_str] + h, w = views["image"].shape[-2:] + rendered = self.decoder.forward_batch( + init_output.gaussians.to(batch["target"]["image"].device), + batch, (h, w), + input_str=input_str, + to_cpu=to_cpu, + iter_batch_size=self.scene_trainer_cfg.iter_batch_size, + ) + # TODO Naama: should we make it a render list as in OptimizerOutput and then the flow will be more unified? 
def sliding_window_indices(N, x, y):
    """Return ``[start, end]`` pairs for a sliding window over ``N`` views.

    Windows have size ``x`` and consecutive windows overlap by ``y`` views.
    The last window is always ``[N - x, N]`` so any remainder is covered; it
    may overlap the previous window by more than ``y``.

    Args:
        N: total number of views.
        x: window size.
        y: overlap between consecutive windows.

    Returns:
        List of ``[start, end]`` pairs (``end`` exclusive).

    Raises:
        ValueError: if more than one window is needed (``x < N``) but the
            stride ``x - y`` is not positive. Without this check the loop
            below never advances and appends windows forever.

    NOTE(review): when ``x > N`` the single returned window has a negative
    start. That is left unchanged because the visible caller reshapes every
    window to exactly ``x`` views.
    """
    stride = x - y
    if x < N and stride <= 0:
        raise ValueError(
            f"Window stride must be positive, got window size {x} with overlap {y}"
        )

    indices = []
    start = 0
    # Full windows strictly before the tail; the tail window is added below.
    while start + x < N:
        indices.append([start, start + x])
        start += stride

    # Tail window, anchored at the end of the sequence.
    indices.append([N - x, N])

    return indices
train_scene_opt: bool + num_update_steps: int + iter_batch_size: int # if -1, use full batch + + train_min_refine: int + train_max_refine: int + + opt_batch_size: int # if -1, use full batch + opt_batch_size_min: int # if > 0, use random sub-batch + opt_batch_size_max: int # if > 0, use random sub-batch + opt_batch_strategy: Literal["random", "sequential", "neighbors", "fps"] # strategy for sub-batch + sh_degree_interval: int # 0 = disabled; N = steps between SH degree increments (like gsplat simple_trainer) + + def __post_init__(self): + if self.scene_optimizer is not None: + self.scene_optimizer.update(self.scene_initializer) + + +# TODO Naama, probably need to move into meta_trainer cfg file + +@dataclass +class MetaOptimizerCfg: + lr: float + warm_up_steps: int + lr_monodepth: float + lr_depth: float + weight_decay: float + warm_up_ratio: float + adamw_8bit: bool + + +@dataclass +class TestCfg: + output_path: Path | None + compute_scores: bool + compute_scores_metrics: list[str] | None + metrics_batch_size: int + eval_initialization: bool + save_render_image: bool + save_render_image_last_only: bool + save_gt_image: bool + save_render_depth: bool + save_gt_depth: bool + save_error_image: bool + save_video: bool + save_video_fixed_view: bool + save_video_fixed_view_index: int + save_video_fixed_view_duplicate: int + save_video_fixed_iteration: bool + save_video_fixed_iteration_indices: list | None + save_video_fixed_iteration_render_fixed_view: bool + save_video_combined: bool + save_video_combined_iterations: list | None + save_video_combined_fixed_iteration_length: int + eval_time_skip_steps: int + save_gaussian: bool + save_poses: bool + save_cameras_json: bool + save_cameras_npz: bool + save_point_cloud: bool + render_chunk_size: int | None + stablize_camera: bool + stab_camera_kernel: int + eval_context_views: bool + inference_window_size: int | None + profile_model: bool + save_colmap_train_test_views: bool + ori_colmap_data_path: str | None + 
postprocessing: PostProcessCfg | None + save_at_iters: list[int] | None + save_every_freq: list[int] | None + save_every_steps: list[int] | None + skip_if_outputs_exist: bool + scenes_filter: list[str] | None + + experimental_add_noise_to_images: bool + experimental_add_noise_to_images_std: float | int | None + + +# TODO Naama split into scene and meta trainer cfgs +@dataclass +class TrainCfg: + depth_mode: DepthRenderingMode | None + extended_visualization: bool + print_log_every_n_steps: int + eval_model_every_n_val: int + eval_data_length: int + eval_deterministic: bool + eval_time_skip_steps: int + eval_save_model: bool + l1_loss: bool + intermediate_loss_weight: float + no_viz_video: bool + eval_depth: bool + train_ignore_large_loss: float + no_log_projections: bool + + depth_loss_weight: float + log_depth_loss: bool + depth_smooth_loss_weight: float + depth_teacher_loss_weight: float + viz_depth_teacher: bool + depth_smooth_loss_nonorm: bool + depth_smooth_loss_weight_nvs: float # for novel views + monodepth_loss_weight: float # for monocular depth loss + + eval_render_depth: bool + render_depth_loss_weight: float + viz_render_depth: bool + viz_depth_separate: bool + + use_gt_depth_range: bool + depth_range_from_disparity: bool + max_disparity: float + min_disparity: float + + no_log_video: bool + + # when doing refinement, supervise input view or not since we also render input views + loss_on_target_views: bool + loss_on_target_views_num: int + loss_on_input_views: bool + loss_on_input_views_num: int + + # half res lpips loss to save memory + half_res_lpips_loss: bool + + # local window training + train_window_size: int | None + + # Replay buffer + use_replay_buffer: bool + replay_buffer_cfg: ReplayBufferCfg | None + + # L2 weight decay regularization on Gaussian properties (meta-loss) + scale_l2_loss_weight: float + sh_l2_loss_weight: float + opacity_l2_loss_weight: float diff --git a/optgs/scripts/__init__.py b/optgs/scripts/__init__.py new file mode 
100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/scripts/analyze_metrics.sh b/optgs/scripts/analyze_metrics.sh new file mode 100644 index 0000000000000000000000000000000000000000..21f1710ed2339de04f48bee507c00035eb8218df --- /dev/null +++ b/optgs/scripts/analyze_metrics.sh @@ -0,0 +1,133 @@ +#!/usr/bin/env bash +# Summarize per-scene metrics under //target_.json +# and (optionally) the PSNR improvement over an init-only baseline. +# +# Usage: +# analyze_metrics.sh [init_only_metrics_dir] +# +# Examples: +# analyze_metrics.sh \ +# results/mipnerf360/8_all_9_scenes/3dgslm_ply_init/vanilla/100_match_3dgslm_metrics/vanillaoptimizer/metrics +# +# analyze_metrics.sh \ +# results/mipnerf360_sfm/.../clogsoptimizer/metrics \ +# results/mipnerf360_sfm/8_all_9_scenes/3dgslm_ply_init/init_only/100_match_3dgslm_metrics/initializerply/metrics + +set -euo pipefail + +if [[ $# -lt 1 || $# -gt 2 ]]; then + echo "Usage: $0 [init_only_metrics_dir]" >&2 + exit 1 +fi + +OPT_DIR="$1" +INIT_DIR="${2:-}" + +python3 - "$OPT_DIR" "$INIT_DIR" <<'PY' +import sys, os, glob, json + +opt_dir = sys.argv[1] +init_dir = sys.argv[2] or None + +def find_target_json(metrics_dir): + # Expect //target_.json + files = sorted(glob.glob(os.path.join(metrics_dir, "*", "target_*.json"))) + if not files: + sys.exit(f"No target_*.json found under {metrics_dir}") + module = os.path.basename(files[0])[len("target_"):-len(".json")] + for f in files: + if not os.path.basename(f) == f"target_{module}.json": + sys.exit(f"Mixed module names under {metrics_dir}: {f}") + return files, module + +def find_any(obj, key): + if isinstance(obj, dict): + if key in obj: + yield obj[key] + for v in obj.values(): + yield from find_any(v, key) + elif isinstance(obj, list): + for x in obj: + yield from find_any(x, key) + +opt_files, opt_mod = find_target_json(opt_dir) +opt_psnr, opt_vram, opt_time_s, opt_iters, opt_gauss = {}, {}, {}, {}, {} +for f in opt_files: + 
scene = os.path.basename(os.path.dirname(f)) + with open(f) as fh: + d = json.load(fh) + psnr_list = d.get(f"{opt_mod}_psnr") + time_list = d.get(f"{opt_mod}_time") + iters_list = d.get(f"{opt_mod}_iterations") + gauss_list = d.get(f"{opt_mod}_gaussians") + vram = list(find_any(d, "peak_vram_mb")) + opt_psnr[scene] = psnr_list[-1] if psnr_list else None + opt_time_s[scene] = time_list[-1]/1000.0 if time_list else None + opt_iters[scene] = iters_list[-1] if iters_list else None + opt_gauss[scene] = gauss_list[-1] if gauss_list else None + opt_vram[scene] = vram[0] if vram else None + +init_psnr = {} +init_mod = None +if init_dir: + init_files, init_mod = find_target_json(init_dir) + for f in init_files: + scene = os.path.basename(os.path.dirname(f)) + with open(f) as fh: + d = json.load(fh) + psnr_list = d.get(f"{init_mod}_psnr") + init_psnr[scene] = psnr_list[-1] if psnr_list else None + +print(f"Optimizer module: {opt_mod} (dir: {opt_dir})") +if init_mod: + print(f"Init module : {init_mod} (dir: {init_dir})") +print() + +scenes = sorted(opt_psnr) +header = f"{'scene':<10}" +if init_dir: + header += f" {'init':>8} {'last':>8} {'diff':>8}" +else: + header += f" {'last_psnr':>10}" +header += f" {'vram_MB':>10} {'time_s':>9} {'iters':>7} {'gauss':>10}" +print(header) + +def mean(xs): + xs = [x for x in xs if x is not None] + return sum(xs)/len(xs) if xs else None + +diffs = [] +for s in scenes: + row = f"{s:<10}" + if init_dir: + ip = init_psnr.get(s) + op = opt_psnr.get(s) + diff = (op - ip) if (ip is not None and op is not None) else None + if diff is not None: + diffs.append(diff) + row += f" {('--' if ip is None else f'{ip:8.3f}')}" + row += f" {('--' if op is None else f'{op:8.3f}')}" + row += f" {('--' if diff is None else f'{diff:+8.3f}')}" + else: + op = opt_psnr.get(s) + row += f" {('--' if op is None else f'{op:10.3f}')}" + v = opt_vram.get(s); row += f" {('--' if v is None else f'{v:10.2f}')}" + t = opt_time_s.get(s); row += f" {('--' if t is None else 
f'{t:9.3f}')}" + it = opt_iters.get(s); row += f" {('--' if it is None else f'{it:7d}')}" + g = opt_gauss.get(s); row += f" {('--' if g is None else f'{g:10d}')}" + print(row) + +print() +def fmt(x, spec, suffix=""): + return "--" if x is None else f"{x:{spec}}{suffix}" + +print(f"n scenes : {len(scenes)}") +print(f"avg last_psnr : {fmt(mean(opt_psnr.values()), '.3f')}") +if init_dir: + print(f"avg init_psnr : {fmt(mean(init_psnr.values()), '.3f')}") + print(f"avg improvement : {fmt(mean(diffs) if diffs else None, '+.3f', ' dB')}") +print(f"avg peak_vram : {fmt(mean(opt_vram.values()), '.2f', ' MB')}") +print(f"avg time : {fmt(mean(opt_time_s.values()), '.3f', ' s')}") +print(f"avg iters : {fmt(mean(opt_iters.values()), '.1f')}") +print(f"avg gaussians : {fmt(mean(opt_gauss.values()), '.0f')}") +PY diff --git a/optgs/scripts/convert_dl3dv_test.py b/optgs/scripts/convert_dl3dv_test.py new file mode 100644 index 0000000000000000000000000000000000000000..54071abbb2310eae250f5591302d118f67b41ded --- /dev/null +++ b/optgs/scripts/convert_dl3dv_test.py @@ -0,0 +1,143 @@ +import argparse +import os +from glob import glob +from pathlib import Path + +import torch +from tqdm import tqdm + +from optgs.scripts.convert_dl3dv_utils import is_image_shape_matched, Example, get_size, load_images, load_metadata + +parser = argparse.ArgumentParser() +parser.add_argument("--input_dir", type=str, help="original dataset directory") +parser.add_argument("--output_dir", type=str, help="processed dataset directory") +parser.add_argument( + "--img_subdir", + type=str, + default="images_8", + help="image directory name", + choices=[ + "images_4", + "images_8", + ], +) +parser.add_argument("--n_test", type=int, default=8, help="test skip") +parser.add_argument("--which_stage", type=str, default=None, help="dataset directory") +parser.add_argument("--detect_overlap", action="store_true") + +args = parser.parse_args() + +INPUT_DIR = Path(args.input_dir) +OUTPUT_DIR = Path(args.output_dir) + +# 
def legal_check_for_all_scenes(root_dir, target_shape, img_dir):
    """Collect the scene folders under ``root_dir`` that can be converted.

    Every folder matching ``<root_dir>/*/nerfstudio`` is checked. A scene is
    kept when ``is_image_shape_matched`` accepts its ``img_dir`` sub-folder for
    ``target_shape`` and a ``transforms.json`` file exists in the scene folder.

    Args:
        root_dir: Dataset root that holds one folder per scene.
        target_shape: Expected image shape as ``(h, w)``.
        img_dir: Name of the image sub-folder inside each scene
            (e.g. ``"images_8"``).

    Returns:
        Sorted list of scene folder paths (strings) that passed both checks.
    """
    valid_folders = []
    sub_folders = sorted(glob(os.path.join(root_dir, "*/nerfstudio")))
    # NOTE that
    # 07d9f9724ca854fae07cb4c57d7ea22bf667d5decd4058f547728922f909956b
    # images_4 folder has resolution 270x480

    for sub_folder in tqdm(sub_folders, desc="checking scenes..."):
        # Keep the per-scene path in its own variable. Re-assigning `img_dir`
        # here (as before) made every later iteration join against the
        # previous scene's full path instead of the sub-folder name.
        scene_img_dir = os.path.join(sub_folder, img_dir)
        if not is_image_shape_matched(Path(scene_img_dir), target_shape):
            print(f"image shape does not match for {sub_folder}")
            continue
        pose_file = os.path.join(sub_folder, "transforms.json")
        if not os.path.isfile(pose_file):
            print(f"cannot find pose file for {sub_folder}")
            continue

        valid_folders.append(sub_folder)

    return valid_folders
+ chunk_size = 0 + chunk_index += 1 + chunk = [] + + + for image_dir in tqdm(image_dirs, desc=f"Processing {stage}"): + key = os.path.basename(os.path.dirname(image_dir.strip("/"))) + + image_dir = Path(image_dir) / args.img_subdir # 270x480 + # image_dir = Path(image_dir) / 'images_4' # 540x960 + + num_bytes = get_size(image_dir) + + # Read images and metadata. + try: + images = load_images(image_dir) + except: + print("image loading error") + continue + meta_path = image_dir.parent / "transforms.json" + if not meta_path.is_file(): + error_msg = f"---------> [ERROR] no meta file in {key}, skip." + print(error_msg) + error_logs.append(error_msg) + continue + example = load_metadata(meta_path) + + # Merge the images into the example. + try: + example["images"] = [ + images[timestamp.item()] for timestamp in example["timestamps"] + ] + except: + error_msg = f"---------> [ERROR] Some images missing in {key}, skip." + print(error_msg) + error_logs.append(error_msg) + continue + + # Add the key to the example. 
+ example["key"] = key + + chunk.append(example) + chunk_size += num_bytes + + if chunk_size >= TARGET_BYTES_PER_CHUNK: + save_chunk() + + if chunk_size > 0: + save_chunk() diff --git a/optgs/scripts/convert_dl3dv_train.py b/optgs/scripts/convert_dl3dv_train.py new file mode 100644 index 0000000000000000000000000000000000000000..82cef440c5de1ea4a9cbe672b21b7b546e015f8a --- /dev/null +++ b/optgs/scripts/convert_dl3dv_train.py @@ -0,0 +1,149 @@ +import argparse +import json +import os +from glob import glob +from pathlib import Path + +import torch +from tqdm import tqdm + +from optgs.scripts.convert_dl3dv_utils import Example, get_size, load_images, load_metadata, is_image_shape_matched + +parser = argparse.ArgumentParser() +parser.add_argument("--input_dir", type=str, help="original dataset directory") +parser.add_argument("--output_dir", type=str, help="processed dataset directory") +parser.add_argument( + "--img_subdir", + type=str, + default="images_8", + help="image directory name", + choices=[ + "images_4", + "images_8", + ], +) +parser.add_argument("--n_test", type=int, default=10, help="test skip") +parser.add_argument("--which_stage", type=str, default=None, help="dataset directory") +parser.add_argument("--detect_overlap", action="store_true") + +args = parser.parse_args() + +INPUT_DIR = Path(args.input_dir) +OUTPUT_DIR = Path(args.output_dir) + +# Target 200 MB per chunk. 
def legal_check_for_all_scenes(root_dir, target_shape, img_dir="images_4"):
    """Collect the scene folders under ``root_dir`` that can be converted.

    Every entry matching ``<root_dir>/*/*`` is checked. A scene is kept when
    ``is_image_shape_matched`` accepts its ``img_dir`` sub-folder for
    ``target_shape`` and a ``transforms.json`` file exists in the scene folder.

    Args:
        root_dir: Dataset root.
        target_shape: Expected image shape as ``(h, w)``.
        img_dir: Name of the image sub-folder to validate. Defaults to
            ``"images_4"``, the value that used to be hard-coded, so existing
            two-argument calls behave exactly as before.
            NOTE(review): ``target_shape`` in this script is derived from
            ``--img_subdir``; callers most likely want to pass that same
            sub-folder here so the checked folder and the expected shape
            agree — confirm.

    Returns:
        Sorted list of scene folder paths (strings) that passed both checks.
    """
    valid_folders = []
    sub_folders = sorted(glob(os.path.join(root_dir, "*/*")))
    for sub_folder in tqdm(sub_folders, desc="checking scenes..."):
        scene_img_dir = os.path.join(sub_folder, img_dir)
        if not is_image_shape_matched(Path(scene_img_dir), target_shape):
            print(f"image shape does not match for {sub_folder}")
            continue
        pose_file = os.path.join(sub_folder, "transforms.json")
        if not os.path.isfile(pose_file):
            print(f"cannot find pose file for {sub_folder}")
            continue

        valid_folders.append(sub_folder)

    return valid_folders
+ chunk_size = 0 + chunk_index += 1 + chunk = [] + + + for image_dir in tqdm(image_dirs, desc=f"Processing {stage}"): + key = os.path.basename(image_dir.strip("/")) + # skip test scenes + if key in overlap_scenes: + print(f"scene {key} in benchmark, skip.") + continue + + image_dir = Path(image_dir) / "images_8" # 270x480 + # image_dir = Path(image_dir) / 'images_4' # 540x960 + + num_bytes = get_size(image_dir) + + # Read images and metadata. + try: + images = load_images(image_dir) + except: + print("image loading error") + continue + meta_path = image_dir.parent / "transforms.json" + if not meta_path.is_file(): + error_msg = f"---------> [ERROR] no meta file in {key}, skip." + print(error_msg) + error_logs.append(error_msg) + continue + example = load_metadata(meta_path) + + # Merge the images into the example. + try: + example["images"] = [ + images[timestamp.item()] for timestamp in example["timestamps"] + ] + except: + error_msg = f"---------> [ERROR] Some images missing in {key}, skip." + print(error_msg) + error_logs.append(error_msg) + continue + + # Add the key to the example. 
def get_size(path: Path) -> int:
    """Get file or folder size in bytes.

    Computed with the standard library instead of shelling out to ``du -b``,
    which is a GNU-only flag and spawned one subprocess per call.

    For a regular file this is its size. For a directory it is the sum of the
    sizes of all regular files below it (recursively); the directory entries
    themselves are not counted.

    Args:
        path: File or directory to measure.

    Returns:
        Size in bytes.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"No such file or directory: {path}")
    if path.is_file():
        return path.stat().st_size
    return sum(entry.stat().st_size for entry in path.rglob("*") if entry.is_file())
def is_image_shape_matched(image_dir, target_shape):
    """Check whether the first entry of ``image_dir`` is an image of ``target_shape``.

    Only the first entry (in sorted order) of the directory is inspected.

    Args:
        image_dir: Directory that holds the images of one scene.
        target_shape: Expected shape as ``(h, w)``.

    Returns:
        True if the first entry opens as an image whose ``(height, width)``
        equals ``target_shape``; False if the directory is empty or missing,
        the entry cannot be opened as an image, or the shape differs.
    """
    image_paths = sorted(glob(str(Path(image_dir) / "*")))
    if not image_paths:
        return False

    # Import the PIL.Image *module* locally: this file's top-level
    # `from PIL.Image import Image` binds the Image class, which has no
    # `open`. The resulting AttributeError was swallowed by the bare `except`,
    # so this function used to return False for every scene.
    from PIL import Image as pil_image

    try:
        # The context manager closes the underlying file handle right away.
        with pil_image.open(image_paths[0]) as im:
            w, h = im.size
    except Exception:
        # Unreadable / non-image file: treat as a mismatch (best effort).
        return False

    return (h, w) == tuple(target_shape)
def npz_path(scene_dir: Path, normalize: bool) -> Path:
    """Return the cache file path of a scene (``_norm`` suffix when normalized)."""
    suffix = "_norm" if normalize else ""
    return scene_dir / f"colmap_points_cache{suffix}.npz"


def ensure_npz(scene_dir: Path, normalize: bool) -> Path:
    """Write the .npz cache if it doesn't exist (or is corrupt), return its path.

    An existing cache is accepted when it loads and contains the arrays
    ``points``, ``points_rgb`` and ``camtoworlds``; otherwise it is deleted and
    rebuilt from a fresh ``Parser`` run. The new cache is written to a
    temporary file in ``scene_dir``, read back once, and only then moved to its
    final name.

    Every ``np.load`` result is closed via ``with``: an open ``NpzFile`` keeps
    the archive's file handle alive, which leaks handles and makes the
    subsequent ``os.replace`` / ``unlink`` fail on Windows.

    Args:
        scene_dir: Scene directory (passed to ``Parser`` as ``data_dir``).
        normalize: Forwarded to ``Parser``; also selects the cache file name.

    Returns:
        Path of the (now valid) cache file.

    Raises:
        Exception: Whatever writing or verifying the new cache raised; the
            temporary file is removed before re-raising.
    """
    p = npz_path(scene_dir, normalize)

    # Delete corrupt/empty files before attempting to create.
    if p.exists():
        try:
            with np.load(p) as data:
                _ = data["points"], data["points_rgb"], data["camtoworlds"]
            return p  # healthy, nothing to do
        except Exception:
            print(f"  WARNING: corrupt cache found at {p}, deleting and re-creating…")
            p.unlink(missing_ok=True)

    print(f"  Creating .npz cache for {scene_dir.name}…", end="", flush=True)
    parser = Parser(
        data_dir=str(scene_dir),
        factor=1,
        normalize=normalize,
        load_images=False,
        dl3dv_settings=False,
        verbose=False,
    )
    # NOTE: np.savez_compressed auto-appends ".npz" if the path doesn't end
    # with it, so the temp file must already carry the .npz suffix; otherwise
    # savez would write next to the (empty) temp file instead of into it.
    tmp_fd, tmp_path = tempfile.mkstemp(dir=scene_dir, suffix=".npz")
    os.close(tmp_fd)
    try:
        np.savez_compressed(
            tmp_path,
            points=parser.points,
            points_rgb=parser.points_rgb,
            camtoworlds=parser.camtoworlds,
        )
        # Verify it's readable before promoting to the final path; the handle
        # must be closed again before the file is renamed.
        with np.load(tmp_path, allow_pickle=False) as data:
            _ = data["points"], data["points_rgb"], data["camtoworlds"]
        os.replace(tmp_path, p)
        print(" done.")
    except Exception:
        print(f" ERROR creating .npz cache for {scene_dir.name}.", file=sys.stderr)
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise
    return p
pollute the timing. + print("Pre-creating .npz caches (if missing)…") + good_scenes = [] + for s in scenes: + try: + ensure_npz(s, args.normalize) + good_scenes.append(s) + except Exception as e: + print(f" SKIP {s.name}: {e}", file=sys.stderr) + print(f"Done. {len(good_scenes)}/{len(scenes)} scenes OK.\n") + + if not good_scenes: + print("No valid scenes to benchmark.", file=sys.stderr) + sys.exit(1) + + col_w = max(len(s.name) for s in good_scenes) + header = f"{'Scene':<{col_w}} {'bin (s)':>10} {'npz (s)':>10} {'speedup':>10}" + print(header) + print("-" * len(header)) + + bin_times, npz_times, speedups = [], [], [] + + for scene in good_scenes: + print(f" timing {scene.name}…", end="", flush=True) + try: + b = np.median([time_bin(scene, args.normalize) for _ in range(args.repeats)]) + z = np.median([time_npz(scene, args.normalize) for _ in range(args.repeats)]) + except Exception as e: + print(f"\r SKIP {scene.name}: {e}", file=sys.stderr) + continue + sp = b / z if z > 0 else float("inf") + + bin_times.append(b) + npz_times.append(z) + speedups.append(sp) + + print(f"\r{scene.name:<{col_w}} {b:>10.3f} {z:>10.4f} {sp:>9.1f}x") + + print("-" * len(header)) + if not bin_times: + print("No scenes were successfully benchmarked.") + sys.exit(1) + print(f"{'MEAN':<{col_w}} {np.mean(bin_times):>10.3f} {np.mean(npz_times):>10.4f} {np.mean(speedups):>9.1f}x") + print(f"{'MEDIAN':<{col_w}} {np.median(bin_times):>10.3f} {np.median(npz_times):>10.4f} {np.median(speedups):>9.1f}x") + print(f"{'MAX':<{col_w}} {np.max(bin_times):>10.3f} {np.max(npz_times):>10.4f} {np.max(speedups):>9.1f}x") + + +if __name__ == "__main__": + main() + diff --git a/optgs/scripts/dev/debug_dataset.py b/optgs/scripts/dev/debug_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f227b3e87dd9ae3f5b6e4ef66e901b66e47ac4e4 --- /dev/null +++ b/optgs/scripts/dev/debug_dataset.py @@ -0,0 +1,94 @@ +import os +import sys +import warnings +from pathlib import Path + +import hydra 
@hydra.main(
    version_base=None,
    config_path="config",
    config_name="main",
)
def train(cfg_dict: DictConfig):
    """Iterate once over the test data loader to surface dataset problems early.

    Prints the data-loader lengths, reports every sample whose flattened
    context extrinsics have a norm above 1e3, shows the first context image of
    each batch, and then exits the process with status 0.

    Args:
        cfg_dict: Hydra configuration, converted by ``setup_cfg``.
    """
    # Set up configuration.
    cfg, cfg_dict, eval_cfg = setup_cfg(cfg_dict)

    # This allows the current step to be shared with the data loader processes.
    step_tracker = StepTracker()

    data_module = DataModule(
        cfg.dataset,
        cfg.data_loader,
        step_tracker,
    )

    if cfg.mode == "train":
        print("train:", len(data_module.train_dataloader()))
        print("val:", len(data_module.val_dataloader()))
        print("test:", len(data_module.test_dataloader()))
    else:
        print("test:", len(data_module.test_dataloader()))

    # DEBUGGING: loop over all data once to catch errors early
    for batch_idx, batch in enumerate(data_module.test_dataloader()):
        extrinsics = batch["context"]["extrinsics"]
        # One norm per sample in the batch (first dim kept, the rest flattened).
        pose_norm = extrinsics.view(extrinsics.shape[0], -1).norm(dim=1)
        # Check sample by sample: `if pose_norm > 1e3` (and `.item()`) on the
        # whole tensor raises as soon as the batch holds more than one sample.
        for sample_idx, norm_value in enumerate(pose_norm.tolist()):
            if norm_value > 1e3:
                print(
                    f"Batch {batch_idx} sample {sample_idx}: pose norm {norm_value:.4f} "
                    f"{extrinsics[sample_idx]} {batch['scene'][sample_idx]} {batch['context']['index']}"
                )

        # First context view of the first sample, channels moved last for imshow.
        image = batch["context"]["image"][0, 0].permute(1, 2, 0).cpu().numpy()

        plt.figure()
        plt.imshow(image)
        plt.title(f"Batch {batch_idx}\n{batch['scene'][0]}")
        plt.show()

    print(cyan("DEBUG: Completed one full pass through the data loaders without errors. Exiting now."))
    sys.exit(0)
+ print(cyan("=" * 80)) + print(cyan(f"Starting training on {os.uname().nodename}, slurm job id: {os.environ.get('SLURM_JOB_ID', 'N/A')}")) + print(cyan(f"Current working directory: {Path.cwd()}")) + print(cyan("=" * 80)) + + train() diff --git a/optgs/scripts/dev/debug_sh.py b/optgs/scripts/dev/debug_sh.py new file mode 100644 index 0000000000000000000000000000000000000000..51ff9517105e7549a72d6ca3ad7902dda6526e3a --- /dev/null +++ b/optgs/scripts/dev/debug_sh.py @@ -0,0 +1,52 @@ +import gsplat +import torch +from gsplat import spherical_harmonics + +if __name__ == '__main__': + print(f"PyTorch version: {torch.__version__}") + print(f"CUDA version: {torch.version.cuda}") + print(f"cuDNN version: {torch.backends.cudnn.version()}") + print(f"gsplat version: {gsplat.__version__}") + + b = 1 + v = 2 + g = 10 + d = 1 + sh_degree_to_use = 0 + + # Directions [b, v, g, 3] + dirs = torch.tensor([[[[0.0, 0.0, 1.0]] * g] * v] * b) # [b, v, g, 3] + dirs = dirs.to(dtype=torch.float32, device="cuda") + + # SHs [b, v, g, d, 3] + shs = torch.ones(b, v, g, d, 3) * 0.1 # [b, v, g, d, 3] + shs = shs.to(dtype=torch.float32, device="cuda") + + # Masks (optional) [b, v, g] + masks = torch.rand(b, v, g) > 0.5 # Random boolean mask + masks = masks.to(device="cuda") + + print("======================== With Mask ========================") + for i in range(5): + dirs_copy = dirs.clone() + shs_copy = shs.clone() + masks_copy = masks.clone() + + colors = spherical_harmonics( + sh_degree_to_use, dirs_copy, shs_copy, masks=masks_copy + ) # [..., C, N, 3] + + print( + f"Iteration {i}: colors max {colors.max().item():.4f}, min {colors.min().item():.4f}, mean {colors.mean().item():.4f}") + + print("======================== Without Mask ========================") + for i in range(5): + dirs_copy = dirs.clone() + shs_copy = shs.clone() + + colors = spherical_harmonics( + sh_degree_to_use, dirs_copy, shs_copy + ) # [..., C, N, 3] + + print( + f"Iteration {i}: colors max {colors.max().item():.4f}, min 
{colors.min().item():.4f}, mean {colors.mean().item():.4f}") \ No newline at end of file diff --git a/optgs/scripts/dev/debug_stability_loss.py b/optgs/scripts/dev/debug_stability_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..13140f36c4801f3cd979e93dac97af6e5b06539e --- /dev/null +++ b/optgs/scripts/dev/debug_stability_loss.py @@ -0,0 +1,245 @@ +""" +Debug script mimicking the learned optimizer training loop with stability loss. +Simulates the meta-training loop with inner iterations and the stability loss. +""" +import torch +import torch.nn as nn +from dataclasses import dataclass +from typing import List, Optional + + +# ───────────────────────────────────────────── +# Minimal stubs mirroring your real classes +# ───────────────────────────────────────────── + +@dataclass +class RenderOutput: + color: torch.Tensor # [B, V, C, H, W] + + +@dataclass +class OptimizerOutput: + context_render_list: List[RenderOutput] + target_render_list: List[RenderOutput] + context_index_list: List[Optional[torch.Tensor]] # list of [B, V] or empty + target_index_list: List[Optional[torch.Tensor]] + + def get_render_list(self, input_str: str) -> List[RenderOutput]: + return self.context_render_list if input_str == "context" else self.target_render_list + + def get_index_list(self, input_str: str) -> List[torch.Tensor]: + lst = self.context_index_list if input_str == "context" else self.target_index_list + return [x for x in lst if x is not None] + + +# ───────────────────────────────────────────── +# Tiny "optimizer network" that produces renders +# across inner iterations — all connected in graph +# ───────────────────────────────────────────── + +class TinyOptimizerNet(nn.Module): + """ + Simulates a learned optimizer that refines a rendering across I inner steps. + Each step: render = prev_render + delta(prev_render, params) + This creates a graph that chains across iterations, just like your real model. 
+ """ + def __init__(self, hidden=16): + super().__init__() + self.refine = nn.Sequential( + nn.Conv2d(3, hidden, 1), + nn.ReLU(), + nn.Conv2d(hidden, 3, 1), + ) + + def forward(self, init_render: torch.Tensor, num_inner: int): + """ + init_render: [B, V, C, H, W] + Returns list of renders of length num_inner+1 (init + refined) + """ + renders = [init_render.detach()] # init is detached, like in your code + curr = init_render.detach() + B, V, C, H, W = curr.shape + for _ in range(num_inner): + flat = curr.view(B * V, C, H, W) + delta = self.refine(flat) + curr = curr + delta.view(B, V, C, H, W) # connected graph across iters + renders.append(curr) + return renders + + +# ───────────────────────────────────────────── +# Stability loss (copied from your code) +# ───────────────────────────────────────────── + +class LossStability(nn.Module): + def forward(self, optimizer_output: OptimizerOutput, batch: dict) -> torch.Tensor: + total_loss = 0 + for input_str in ["context", "target"]: + render_list = optimizer_output.get_render_list(input_str) + index_list = optimizer_output.get_index_list(input_str) + + predictions = [render.color for render in render_list] + predictions = torch.stack(predictions, dim=0) # [I, B, V, C, H, W] + gt = batch[input_str]["image"] # [B, V_all, C, H, W] + + if len(index_list) == 0: + loss = torch.abs(predictions - gt).mean(dim=[3, 4, 5]) # [I, B, V] + change_in_loss = loss[1:] - loss[:-1].detach() # [I-1, B, V] + change_in_loss = torch.relu(change_in_loss) + total_loss = total_loss + change_in_loss.sum() + print(f" Stability loss ({input_str}): {total_loss.item():.6f}") + continue + + # With index lists + index_list_padded = [index_list[0]] + index_list # I tensors + index_list_t = torch.stack(index_list_padded, dim=0) # [I, B, V] + + b = gt.shape[0] + device = gt.device + batch_idx = torch.arange(b, device=device)[None, :, None] + gt_indexed = gt[batch_idx, index_list_t] # [I, B, V, C, H, W] + + loss = torch.abs(predictions - 
gt_indexed).mean(dim=[3, 4, 5]) # [I, B, V] + + I, B, V_all = predictions.shape[0], gt.shape[0], gt.shape[1] + loss_full = torch.zeros(I, B, V_all, device=device).scatter(2, index_list_t, loss) + + iter_idx = torch.arange(I, device=device).view(-1, 1, 1) + visited = loss_full > 0 + visit_ids = torch.where(visited, iter_idx, torch.full_like(iter_idx, -1)) + last_visit = torch.cummax(visit_ids, dim=0).values + prev_visit = torch.roll(last_visit, shifts=1, dims=0) + prev_visit[0] = -1 + safe_prev = prev_visit.clamp(min=0) + prev_loss = loss_full.gather(0, safe_prev).detach() + has_prev = prev_visit >= 0 + change_in_loss = torch.relu(loss_full - prev_loss) + change_in_loss = change_in_loss * has_prev.float().detach() + total_loss = total_loss + change_in_loss.sum() + print(f" Stability loss ({input_str}): {change_in_loss.sum().item():.6f}") + + return total_loss + + +# ───────────────────────────────────────────── +# Other losses (L1 and a fake LPIPS) +# ───────────────────────────────────────────── + +def compute_l1_loss(render_color, gt): + return torch.abs(render_color - gt).mean() + +def compute_lpips_loss(render_color, gt): + # Fake LPIPS: just MSE on downsampled version + return ((render_color - gt) ** 2).mean() + +def compute_meta_losses(optimizer_output, batch, num_inner): + """Mimics your _calc_opt_loss loop (without stability).""" + opt_loss = 0 + for i in range(num_inner): + pred = optimizer_output.context_render_list[i + 1].color + gt = batch["context"]["image"] + opt_loss = opt_loss + compute_l1_loss(pred, gt) + opt_loss = opt_loss + 0.1 * compute_lpips_loss(pred, gt) + + pred = optimizer_output.target_render_list[i + 1].color + gt = batch["target"]["image"] + opt_loss = opt_loss + compute_l1_loss(pred, gt) + opt_loss = opt_loss + 0.1 * compute_lpips_loss(pred, gt) + return opt_loss + + +# ───────────────────────────────────────────── +# Main training loop +# ───────────────────────────────────────────── + +def make_batch(B=2, V=3, C=3, H=8, W=8, 
device="cpu"): + return { + "context": {"image": torch.rand(B, V, C, H, W, device=device)}, + "target": {"image": torch.rand(B, V, C, H, W, device=device)}, + } + + +def run_meta_iteration(net, batch, stability_loss_fn, num_inner=5, use_index_list=False): + B, V, C, H, W = batch["context"]["image"].shape + device = batch["context"]["image"].device + + # Simulate init render (detached from network, like 3DGS init) + init_context = torch.rand(B, V, C, H, W, requires_grad=False, device=device) + init_target = torch.rand(B, V, C, H, W, requires_grad=False, device=device) + + context_renders = net(init_context, num_inner) + target_renders = net(init_target, num_inner) + + context_render_list = [RenderOutput(color=r) for r in context_renders] + target_render_list = [RenderOutput(color=r) for r in target_renders] + + if use_index_list: + # Simulate partial view sampling: pick V//2 views each inner iter + num_views = max(1, V // 2) + context_index_list = [ + torch.randint(0, V, (B, num_views), device=device) + for _ in range(num_inner) + ] + target_index_list = [ + torch.randint(0, V, (B, num_views), device=device) + for _ in range(num_inner) + ] + else: + context_index_list = [None] * num_inner + target_index_list = [None] * num_inner + + optimizer_output = OptimizerOutput( + context_render_list=context_render_list, + target_render_list=target_render_list, + context_index_list=context_index_list, + target_index_list=target_index_list, + ) + + # ── Meta losses (L1 + LPIPS across inner iters) ── + meta_loss = compute_meta_losses(optimizer_output, batch, num_inner) + + # ── Stability loss ── + stab_loss = stability_loss_fn(optimizer_output, batch) + + total_loss = meta_loss + 0.01 * stab_loss + + return total_loss + + +def main(): + torch.manual_seed(42) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + net = TinyOptimizerNet(hidden=16).to(device) + optimizer = torch.optim.Adam(net.parameters(), lr=1e-3) + stability_loss_fn = LossStability() + + 
NUM_META_STEPS = 5 + NUM_INNER = 5 # inner iterations (like your 6, minus init) + + for mode in ["no_index_list", "with_index_list"]: + print(f"\n{'='*50}") + print(f"Mode: {mode}") + print(f"{'='*50}") + use_index = (mode == "with_index_list") + + for step in range(NUM_META_STEPS): + batch = make_batch(device=device) + optimizer.zero_grad() + + try: + total_loss = run_meta_iteration( + net, batch, stability_loss_fn, + num_inner=NUM_INNER, + use_index_list=use_index + ) + total_loss.backward() + optimizer.step() + print(f" Step {step+1}: total_loss={total_loss.item():.6f} ✓") + except RuntimeError as e: + print(f" Step {step+1}: ERROR - {e}") + break + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/optgs/scripts/dev/debugging_optimizer.py b/optgs/scripts/dev/debugging_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..10b79f5737addc62dbafd278b40a2423b6a605e3 --- /dev/null +++ b/optgs/scripts/dev/debugging_optimizer.py @@ -0,0 +1,1186 @@ +import os + +import matplotlib.animation as animation +import matplotlib.pyplot as plt + +from optgs.misc.io import CustomPath +from optgs.model.types import Gaussians + +plt.rcParams.update({'font.size': 18, + # line widths + 'lines.linewidth': 6, + }) + +import matplotlib.gridspec as gridspec +import subprocess +from torch import Tensor + + +def calc_hist(values, bins=100, density=True): + """Utility: return (x, y) for a histogram.""" + v = values.detach().cpu().numpy().flatten() + y, x = np.histogram(v, bins=bins, density=density) + x = 0.5 * (x[:-1] + x[1:]) + return x, y + + +def plot_gaussians_params_histograms( + data_groups: dict[str, list[Tensor]], + psnrs, + iters, + out_path=CustomPath("dashboard.mp4"), + max_frames=None, + last_k_hist=5, + save_last_time_only=False, + save_video=False +): + """ + Create a dashboard video visualizing parameter distributions and PSNR over iterations. 
+ Shows histograms of the last K iterations with color fading for comparison. + """ + + # if save video is true, save_last_time_only must be false + assert not (save_video and save_last_time_only), "Cannot save video when save_last_time_only is True." + + # ---- Prepare parameter names ---- + sh_d = data_groups["shs"][0].shape[-1] // 3 + param_axis_names = { + "opacities": [""], + "means": ["x", "y", "z"], + "scales": ["x", "y", "z"], + "quats": ["x", "y", "z", "w"], + "shs": [f"r{i}" for i in range(sh_d)] + + [f"g{i}" for i in range(sh_d)] + + [f"b{i}" for i in range(sh_d)], + } + + # check shape of shs in first iteration + # if data_groups["shs"][0].dim() == 3: + # g_shape = data_groups["shs"][0].shape + # reshaped_shs = [] + # for iter_params in data_groups["shs"]: + # reshaped_shs.append(iter_params.reshape(-1, g_shape[1] * g_shape[2])) + # data_groups["shs"] = reshaped_shs + + # ---- Frame control ---- + T = len(iters) + if max_frames is not None: + T = min(T, max_frames) + + # ---- Prepare figure layout ---- + total_dims = sum(g[0].shape[-1] for g in data_groups.values()) + ncols = 5 + nrows = int(np.ceil(total_dims / ncols)) + + fig = plt.figure(figsize=(5 * ncols, 3.5 * (nrows + 1))) + gs = gridspec.GridSpec(nrows + 1, ncols, height_ratios=[1] * nrows + [0.5]) + axes = [fig.add_subplot(gs[i // ncols, i % ncols]) for i in range(nrows * ncols)] + ax_psnr = fig.add_subplot(gs[-1, :]) + + # ---- Precompute histograms and limits ---- + print("🔍 Precomputing histograms and axis limits...") + subplot_map = [] + i = 0 + for key, iters_params in data_groups.items(): + + D = iters_params.shape[-1] + + coord_names = [f"{key} {param_axis_names[key][d]}" for d in range(D)] + + for d in range(D): + # g_at_t = [iters_params[t] for t in range(T)] + all_hist_data = [calc_hist(iters_params[t][..., d], density=True) for t in range(T)] + all_x, all_y = zip(*all_hist_data) + xmin = min(x.min() for x in all_x) + xmax = max(x.max() for x in all_x) + # Center the x-axis around 0 + 
x_max_abs = max(abs(xmin), abs(xmax)) + xmin, xmax = -x_max_abs, x_max_abs + ymin = 0.0 + ymax = max(y.max() for y in all_y) * 1.1 + subplot_map.append((key, d, axes[i], coord_names[d], all_x, all_y, xmin, xmax, ymin, ymax)) + i += 1 + + # Hide unused subplots + total_used_subplots = len(subplot_map) + for j in range(total_used_subplots, len(axes)): + axes[j].set_visible(False) + + # ---- Output folders ---- + out_dir = out_path.parent + inter_dir = out_dir / "gaussians_histograms" + inter_dir.mkdir(parents=True, exist_ok=True) + + print(f"📸 Generating histograms frames in {inter_dir:link}") + + # ---- Frame generation loop ---- + for frame_idx in range(T): + + if save_last_time_only and frame_idx < T - 1: + continue + + fig.suptitle(f"Iteration {iters[frame_idx]} — PSNR: {psnrs[frame_idx]:.2f}", fontsize=18) + + for key, d, ax, name, all_x, all_y, xmin, xmax, ymin, ymax in subplot_map: + ax.clear() + + # Plot last_k_hist iterations with progressive color fading + k = min(last_k_hist, frame_idx + 1) + idxs = list(range(frame_idx - k + 1, frame_idx + 1)) + for rel_i, hist_idx in enumerate(idxs): + color = plt.cm.viridis(rel_i / max(1, k - 1)) # gradient color + label = f"Iter {iters[hist_idx]}" + ax.plot(all_x[hist_idx], all_y[hist_idx], color=color, alpha=0.9, lw=6, label=label) + + ax.set_xlim(xmin, xmax) + ax.set_ylim(ymin, ymax) + ax.set_title(name) + ax.legend(frameon=False, loc="upper right", fontsize=7) + ax.grid(True, linestyle='--', alpha=0.5) + # Add vertical line at x=0 to show center + ax.axvline(0, color='black', linewidth=1, linestyle=':', alpha=0.7) + + # ---- PSNR subplot ---- + ax_psnr.clear() + ax_psnr.plot(iters[:frame_idx + 1], psnrs[:frame_idx + 1], color="#ffbc42", linewidth=8) + ax_psnr.scatter(iters[frame_idx], psnrs[frame_idx], color="#ffbc42", s=60, zorder=3, linewidth=8) + ax_psnr.set_xlim(min(iters), max(iters)) + ax_psnr.set_ylim(max(psnrs) * 0.7, max(psnrs) * 1.1) + ax_psnr.set_title("PSNR Progress") + ax_psnr.set_xlabel("Iteration") + 
ax_psnr.set_ylabel("PSNR") + + plt.tight_layout(rect=[0, 0, 1, 0.97]) + + frame_path = inter_dir / f"hist_{frame_idx:05d}.png" + fig.savefig(frame_path, dpi=400) + + plt.close(fig) + + if not save_video: + print(f"✅ Saved dashboard frames to {inter_dir} ({T} frames total)") + return + + # ---- Combine with ffmpeg ---- + total_duration_sec = 20.0 + fps = T / total_duration_sec + + cmd = [ + "ffmpeg", "-y", "-framerate", f"{fps}", + "-i", str(inter_dir / "hist_%05d.png"), + "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", + "-c:v", "libx264", "-pix_fmt", "yuv420p", + "-crf", "18", str(out_path) + + ] + print("🎞️ Running FFmpeg to create video...") + subprocess.run(cmd, check=True) + + print(f"✅ Saved dashboard video to {out_path} ({total_duration_sec:.1f}s total)") + + +def make_gaussians_dashboard_video_with_ani(data_groups, psnrs, iters, out_path=CustomPath("dashboard.mp4"), + max_frames=None, scene=0): + """ + Create a dashboard video visualizing parameter distributions and PSNR over iterations. + Args: + data_groups (dict): Dictionary containing parameter groups as keys and list of tensors as values. + Each list should have T entry of shape (N, D) where T is time + psnrs (list): List of PSNR values over iterations. + iters (list): List of iteration numbers corresponding to the PSNR values. + out_path (CustomPath): Path to save the output video. + max_frames (int, optional): Maximum number of frames to include in the video. If None, include all frames. 
+ """ + # Groups to visualize + + # Axis names for each parameter group + # Calc sh axis names + sh_d = data_groups["shs"][0][0].shape[-1] // 3 + param_axis_names = { + "opacities": [""], + "means": ["x", "y", "z"], + "scales": ["x", "y", "z"], + "quats": ["x", "y", "z", "w"], + "shs": ["r" + str(i) for i in range(sh_d)] + [f"g{i}" for i in range(sh_d)] + [f"b{i}" for i in range(sh_d)], + } + + T = list(data_groups.values())[0].shape[0] + if max_frames is not None: + T = min(T, max_frames) + if iters is None: + iters = list(range(T)) + + # Count total subplots + total_dims = sum(g.shape[-1] for g in data_groups.values()) + ncols = 4 + nrows = int(np.ceil(total_dims / ncols)) + + # Use GridSpec to reserve one bottom row for PSNR plot + fig = plt.figure(figsize=(5 * ncols, 3.5 * (nrows + 1))) + gs = gridspec.GridSpec(nrows + 1, ncols, height_ratios=[1] * nrows + [0.5]) + axes = [fig.add_subplot(gs[i // ncols, i % ncols]) for i in range(nrows * ncols)] + ax_psnr = fig.add_subplot(gs[-1, :]) + + # Precompute all histograms and axis limits + subplot_map = [] + i = 0 + for key, g in data_groups.items(): + D = g.shape[-1] + coord_names = [f"{key} {param_axis_names[key][i]}" for i in range(D)] + for d in range(D): + all_hist_data = [calc_hist(g[t, scene, :, d], density=True) for t in range(T)] + all_x, all_y = zip(*all_hist_data) + # Find all time min/max for consistent axis limits + xmin = min(x.min() for x in all_x) + xmax = max(x.max() for x in all_x) + ymin = 0.0 + ymax = max(y.max() for y in all_y) * 1.1 + subplot_map.append((key, d, axes[i], coord_names[d], all_x, all_y, xmin, xmax, ymin, ymax)) + i += 1 + + # Animation update + def update(frame_idx): + fig.suptitle(f"Iteration {iters[frame_idx]} — PSNR: {psnrs[frame_idx]:.2f}", fontsize=18) + + for key, d, ax, name, all_x, all_y, xmin, xmax, ymin, ymax in subplot_map: + ax.clear() + ax.plot(all_x[frame_idx], all_y[frame_idx], color="#17becf", label=r"Resplat $\Delta$") + + ax.set_xlim(xmin, xmax) + ax.set_ylim(ymin, 
ymax) + ax.set_title(name) + ax.legend(frameon=False, loc="upper left") + + # PSNR curve subplot + ax_psnr.clear() + ax_psnr.plot(iters[:frame_idx + 1], psnrs[:frame_idx + 1], color="#ffbc42") + ax_psnr.scatter(iters[frame_idx], psnrs[frame_idx], color="#ffbc42", s=60, zorder=3) + ax_psnr.set_xlim(min(iters), max(iters)) + ax_psnr.set_ylim(min(psnrs) * 0.98, max(psnrs) * 1.02) + ax_psnr.set_title("PSNR Progress") + ax_psnr.set_xlabel("Iteration") + ax_psnr.set_ylabel("PSNR") + + plt.tight_layout(rect=[0, 0, 1, 0.97]) + return axes + [ax_psnr] + + # Create video + Path(out_path).parent.mkdir(parents=True, exist_ok=True) + + total_duration_sec = 20.0 # desired total duration + interval_ms = total_duration_sec * 1000 / T # milliseconds per frame + + try: + ani = animation.FuncAnimation(fig, update, frames=T, interval=interval_ms, blit=False) + ani.save(out_path, writer="ffmpeg", dpi=300) + except FileNotFoundError: + print("⚠️ FFmpeg not found. Saving as GIF instead.") + ani = animation.FuncAnimation(fig, update, frames=T, interval=interval_ms, blit=False) + ani.save(out_path.replace(".mp4", ".gif"), writer="pillow", dpi=300) + + plt.close(fig) + print(f"✅ Saved dashboard video to {out_path} ({total_duration_sec:.1f}s total)") + + plt.close(fig) + print(f"✅ Saved dashboard video to {out_path}") + + +# def make_dashboard_video(info, psnrs, iters, vanilla_lr, out_path="dashboard.mp4", max_frames=None): +# # Groups to visualize +# param_groups = ["opacities", "means", "scales", "rotations", "shs"] +# +# # Axis names for each parameter group +# # Calc sh axis names +# sh_d = info["delta_shs"][0][0].shape[-1] // 3 +# param_axis_names = { +# "opacities": [""], +# "means": ["x", "y", "z"], +# "scales": ["x", "y", "z"], +# "rotations": ["x", "y", "z", "w"], +# "shs": ["r" + str(i) for i in range(sh_d)] + [f"g{i}" for i in range(sh_d)] + [f"b{i}" for i in range(sh_d)], +# } +# +# # Extract and stack tensors +# data = {} +# for key in param_groups: +# delta_data = 
torch.stack(info[f"delta_{key}"], dim=0) # (T, B, N, D) +# norm_grads_data = torch.stack(info[f"normalized_grad_{key}"], dim=0) # (T, N, D) +# data[key] = (delta_data, norm_grads_data) +# +# T = list(data.values())[0][0].shape[0] +# if max_frames is not None: +# T = min(T, max_frames) +# if iters is None: +# iters = list(range(T)) +# scene = 0 +# +# # Compute axis limits for each param/dim +# axis_limits = {} +# for key, (delta_data, norm_grads_data) in data.items(): +# D = delta_data.shape[-1] +# axis_limits[key] = [] +# for d in range(D): +# delta_all = delta_data[:, scene, :, d].float().flatten().cpu().numpy() +# grad_all = norm_grads_data[:, :, d].float().flatten().cpu().numpy() * vanilla_lr[key] +# vmin = min(delta_all.min(), grad_all.min()) +# vmax = max(delta_all.max(), grad_all.max()) +# +# # Compute max y-density across all frames +# y_max = 0.0 +# for t in range(T): +# delta = delta_data[t, scene, :, d].float().cpu().numpy() +# grad = norm_grads_data[t, :, d].float().cpu().numpy() * vanilla_lr[key] +# _, y1 = calc_hist(delta, density=True) +# _, y2 = calc_hist(grad, density=True) +# y_max = max(y_max, y1.max(), y2.max()) +# +# axis_limits[key].append((vmin, vmax, 0.0, y_max * 0.1)) # add small headroom +# +# # Count total subplots +# total_dims = sum(delta.shape[-1] for delta, _ in data.values()) +# ncols = 4 +# nrows = int(np.ceil(total_dims / ncols)) +# +# # Use GridSpec to reserve one bottom row for PSNR plot +# fig = plt.figure(figsize=(5 * ncols, 3.5 * (nrows + 1))) +# gs = gridspec.GridSpec(nrows + 1, ncols, height_ratios=[1] * nrows + [0.5]) +# axes = [fig.add_subplot(gs[i // ncols, i % ncols]) for i in range(nrows * ncols)] +# ax_psnr = fig.add_subplot(gs[-1, :]) +# +# subplot_map = [] +# i = 0 +# for key, (delta_data, norm_grads_data) in data.items(): +# D = delta_data.shape[-1] +# coord_names = [f"{key} {param_axis_names[key][i]}" for i in range(D)] +# for d in range(D): +# subplot_map.append((key, d, axes[i], coord_names[d])) +# i += 1 +# +# # 
Animation update +# def update(frame_idx): +# fig.suptitle(f"Iteration {iters[frame_idx]} — PSNR: {psnrs[frame_idx]:.2f}", fontsize=18) +# +# for key, d, ax, name in subplot_map: +# ax.clear() +# delta_data, grads_data = data[key] +# delta = delta_data[frame_idx, scene, :, d].float().cpu().numpy() +# grad = grads_data[frame_idx, :, d].float().cpu().numpy() * vanilla_lr[key] +# vmin, vmax, ymin, ymax = axis_limits[key][d] +# +# # Δ histogram +# x1, y1 = calc_hist(delta, density=True) +# ax.plot(x1, y1, color="#17becf", label=r"Resplat $\Delta$") +# +# # grad histogram +# x2, y2 = calc_hist(grad, density=True) +# ax.plot(x2, y2, color="#e377c2", ls="--", label=r"Adam $\Delta$") +# +# ax.set_xlim(vmin, vmax) +# ax.set_ylim(ymin, ymax) +# ax.set_title(name) +# ax.legend(frameon=False, loc="upper left") +# +# # PSNR curve subplot +# ax_psnr.clear() +# ax_psnr.plot(iters[:frame_idx + 1], psnrs[:frame_idx + 1], color="#ffbc42") +# ax_psnr.scatter(iters[frame_idx], psnrs[frame_idx], color="#ffbc42", s=60, zorder=3) +# ax_psnr.set_xlim(min(iters), max(iters)) +# ax_psnr.set_ylim(min(psnrs) * 0.98, max(psnrs) * 1.02) +# ax_psnr.set_title("PSNR Progress") +# ax_psnr.set_xlabel("Iteration") +# ax_psnr.set_ylabel("PSNR") +# +# plt.tight_layout(rect=[0, 0, 1, 0.97]) +# return axes + [ax_psnr] +# +# # Create video +# Path(out_path).parent.mkdir(parents=True, exist_ok=True) +# +# total_duration_sec = 20.0 # desired total duration +# interval_ms = total_duration_sec * 1000 / T # milliseconds per frame +# +# try: +# ani = animation.FuncAnimation(fig, update, frames=T, interval=interval_ms, blit=False) +# ani.save(out_path, writer="ffmpeg", dpi=300) +# except FileNotFoundError: +# print("⚠️ FFmpeg not found. 
Saving as GIF instead.") +# ani = animation.FuncAnimation(fig, update, frames=T, interval=interval_ms, blit=False) +# ani.save(out_path.replace(".mp4", ".gif"), writer="pillow", dpi=300) +# +# plt.close(fig) +# print(f"✅ Saved dashboard video to {out_path} ({total_duration_sec:.1f}s total)") +# +# plt.close(fig) +# print(f"✅ Saved dashboard video to {out_path}") + + +import numpy as np +import torch +from torch import Tensor +from pathlib import Path +import matplotlib.pyplot as plt + + +def calc_hist(data, max_percentile=99.9, min_percentile=0.1, density=False): + max_val = np.percentile(data, max_percentile) + min_val = np.percentile(data, min_percentile) + curr_data = data.clip(min_val, max_val) + counts, bin_edges = np.histogram(curr_data, bins=100, range=(min_val, max_val), density=density) + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + + return bin_centers, counts + + +def debugging_convergence( + deltas_list: list[dict[str, Tensor]], + states_norms_list: list[Tensor], + grads_raw_list: list[dict[str, Tensor]], + normalized_grads_list: list[dict[str, Tensor]], + psnr_list: list[float], + iterations_list: list[int], + output_path: Path, + scene_name: str +): + print("📈 Generating convergence plots...") + assert len(iterations_list) > 0, "Iterations list cannot be empty." + assert len(psnr_list) == len(iterations_list), "PSNR list length must match iterations list length." 
+ + iters = iterations_list + psnrs = psnr_list + states_norms = [] + for state_norms in states_norms_list: + states_norms.append(state_norms.mean().item()) + + deltas_abs_means = [] + for deltas in deltas_list: + total_mean = 0.0 + count = 0 + for key, delta in deltas.items(): + total_mean += delta.abs().mean().item() + count += 1 + deltas_abs_means.append(total_mean / count) + + grads_raw_abs_means = [] + for grads in grads_raw_list: + total_mean = 0.0 + count = 0 + for key, grad in grads.items(): + total_mean += grad.abs().mean().item() + count += 1 + grads_raw_abs_means.append(total_mean / count) + + normalized_grads_abs_means = [] + for normalized_grads in normalized_grads_list: + total_mean = 0.0 + count = 0 + for key, grad in normalized_grads.items(): + total_mean += grad.abs().mean().item() + count += 1 + normalized_grads_abs_means.append(total_mean / count) + + # set rc once (inside context to avoid global mutation) + rc = { + 'axes.titlesize': 17, + 'axes.labelsize': 15, + 'xtick.labelsize': 15, + 'ytick.labelsize': 15, + 'legend.fontsize': 11 + } + + with plt.rc_context(rc): + # plot all quantities in one figure with 4 subplots + fig, axs = plt.subplots(5, 1, figsize=(10, 15)) + # PSNR + axs[0].plot(iters, psnrs, marker='o', color='blue') + axs[0].set_title('PSNR over Iterations') + axs[0].set_xlabel('Iteration') + axs[0].set_ylabel('PSNR') + axs[0].grid(True, alpha=0.3) + # State norm + axs[1].plot(iters, states_norms, marker='o', color='orange') + axs[1].set_title('State Norm over Iterations') + axs[1].set_xlabel('Iteration') + axs[1].set_ylabel('State Norm') + axs[1].grid(True, alpha=0.3) + # Delta abs mean + axs[2].plot(iters, deltas_abs_means, marker='o', color='green') + axs[2].set_title('Mean Absolute Delta over Iterations') + axs[2].set_xlabel('Iteration') + axs[2].set_ylabel('Mean Absolute Delta') + axs[2].grid(True, alpha=0.3) + # Gradient abs mean + axs[3].plot(iters, grads_raw_abs_means, marker='o', color='red', label='Raw Grads') + 
axs[3].set_title('Mean Absolute Gradient over Iterations') + axs[3].set_xlabel('Iteration') + axs[3].set_ylabel('Mean Absolute Gradient') + axs[3].grid(True, alpha=0.3) + # Normalized Gradient abs mean + axs[4].plot(iters, normalized_grads_abs_means, marker='o', color='purple', label='Normalized Grads') + axs[4].set_title('Mean Absolute Normalized Gradient over Iterations') + axs[4].set_xlabel('Iteration') + axs[4].set_ylabel('Mean Absolute Normalized Gradient') + axs[4].grid(True, alpha=0.3) + plt.tight_layout() + (output_path / "plots" / scene_name).mkdir(parents=True, exist_ok=True) + plt.savefig(output_path / "plots" / scene_name / "convergence_plot.png", dpi=300) + plt.close() + + +def debugging_deltas( + deltas_list: list[dict[str, Tensor]], + grads_list: list[dict[str, Tensor]], + normalized_grads_list: list[dict[str, Tensor]], + learning_rates: list[dict[str, float]], + psnr_list: list[float], + iterations_list: list[int], + output_path: Path, + scene_name: str +): + assert len(iterations_list) > 0, "Iterations list cannot be empty." + assert len(psnr_list) == len(iterations_list), "PSNR list length must match iterations list length." + + # Remove init. + psnr_list = psnr_list[1:] + iterations_list = iterations_list[1:] + + assert len(deltas_list) == len(iterations_list), "Deltas list length must match iterations list length." + assert len(grads_list) == len(iterations_list), "Grads list length must match iterations list length." + assert len(normalized_grads_list) == len( + iterations_list), "Normalized grads list length must match iterations list length." + if len(learning_rates) > 0: + assert len(learning_rates) == len( + iterations_list), "Learning rates list length must match iterations list length." 
+ iters = iterations_list + psnrs = psnr_list + # max_iter = max(iters) if len(iters) > 0 else 1 + nr_iters = len(iters) + + # set rc once (inside context to avoid global mutation) + rc = { + 'axes.titlesize': 17, + 'axes.labelsize': 15, + 'xtick.labelsize': 15, + 'ytick.labelsize': 15, + 'legend.fontsize': 11 + } + + # Plot delta histograms + for key in ["opacities", "means", "scales", "rotations"]: + + # TODO: log "sh0s", "shNs" + + # here N can change between iterations, C changes based on parameter type + delta_data = [deltas[key] for deltas in deltas_list] # list of [N, C] + grads_data = [grads[key] for grads in grads_list] # list of [N, C] + normalized_grads_data = [normalized_grads[key] for normalized_grads in normalized_grads_list] # list of [N, C] + # if len(learning_rates) == 0: + # lr_data = [1.0] * len(delta_data) # list of floats + # else: + # lr_data = [lrs[key] for lrs in learning_rates] # list of floats + + # Plot histogram of delta means for each step for each coordinate + D = delta_data[0].shape[-1] + + rows = 3 # delta, grad, normalized_grad + + with plt.rc_context(rc): + plt.figure(figsize=(10 * D, 8 * rows)) + + if D in [3, 4]: + coord_names = ['X', 'Y', 'Z', 'W'][:D] + elif D == 1: + coord_names = [""] + else: + coord_names = [f"Dim {i}" for i in range(D)] + + for r, kind in enumerate(["delta", "grad", "grad_norm"]): + for d in range(D): + ax = plt.subplot(rows, D, r * D + d + 1) + + for i, t in enumerate(iters): + color_frac = float(i) / float(nr_iters) + + # Select the correct dataset + if kind == "delta": + curr = delta_data[i][:, d].float().cpu().numpy() + cmap = plt.cm.viridis + elif kind == "grad": + curr = grads_data[i][:, d].float().cpu().numpy() + cmap = plt.cm.cividis + else: # grad_norm + curr = normalized_grads_data[i][:, d].float().cpu().numpy() + cmap = plt.cm.plasma + + # Compute histogram as normalized density + bin_centers, counts = calc_hist(curr) + max_counts = counts.max() + if max_counts > 0: + counts = counts / max_counts 
# normalize peak=1 + + label = f"step: {t}, psnr: {psnrs[i]}" + ax.plot(bin_centers, counts, label=label, + color=cmap(color_frac), linewidth=2) + + xlim = (-np.max(np.abs(bin_centers)), np.max(np.abs(bin_centers))) + ax.set_xlim(xlim) # Center around 0 + ax.axvline(0, color='black', linewidth=1, linestyle=':') # vertical center line + + if r == rows - 1: + ax.set_xlabel(f"{coord_names[d]}") + if d == 0: + ax.set_ylabel("Density") + + # Titles + ax.set_title(f"{kind.replace('_', ' ').title()} {key.replace('_', ' ').title()} {coord_names[d]}") + + ax.legend(fontsize=9) + ax.grid(True, alpha=0.3) + + plt.suptitle(f"{key.replace('_', ' ').title()} histograms (centered & normalized)", fontsize=18) + plt.tight_layout(rect=[0, 0, 1, 0.97]) + + # save figure + save_dir = os.path.join(output_path, "plots", scene_name) + os.makedirs(save_dir, exist_ok=True) + save_path = os.path.join(save_dir, f"{key}_deltas_histogram.png") + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + print(f"Saved delta histogram plot to {save_path}") + + # plt.figure(figsize=(10 * D, 10)) + # # Adjust font size + # plt.rcParams.update({ + # 'axes.titlesize': 17, + # 'axes.labelsize': 15, + # 'xtick.labelsize': 15, + # 'ytick.labelsize': 15, + # 'legend.fontsize': 11 # Smaller + # }) + # if D in [3, 4]: + # coord_names = ['X', 'Y', 'Z', 'W'] + # elif D == 1: + # coord_names = [""] + # else: + # coord_names = [f"Dim {i}" for i in range(D)] + + # for d in range(D): + # plt.subplot(1, D, d + 1) + # for i, t in enumerate(iters): + + # # Plot histogram of delta + + # color = plt.cm.viridis(t / iters[-1]) + # scene = 0 + # curr_delta = delta_data[i, scene, :, d].float().cpu().numpy() + # bin_centers, counts = calc_hist(curr_delta) + # plt.plot(bin_centers, counts, label=fr"{psnrs[i]} $\Delta$ step {t}", color=color, linewidth=3) + + # # Plot histogram of normalized grad + + # color = plt.cm.plasma(t / iters[-1]) + # curr_norm_grad = normalized_grads_data[i, :, d].float().cpu().numpy() + 
# bin_centers, counts = calc_hist(curr_norm_grad) + # plt.plot(bin_centers, counts, label=fr"{psnrs[i]} $g_t$ normalized step {t}", color=color, + # linewidth=3, + # linestyle='--') + + # plt.xlabel(f"Delta {coord_names[d]}") + # plt.ylabel("Count") + # plt.title(f"{name} {coord_names[d]} histogram") + + # # Arange irst delta handles and then normalized grad handles + # handles, labels = plt.gca().get_legend_handles_labels() + # delta_handles = [h for h, l in zip(handles, labels) if "Delta" in l] + # norm_grad_handles = [h for h, l in zip(handles, labels) if "g_t" in l] + # handles = delta_handles + norm_grad_handles + # labels = [l for l in labels if "Delta" in l] + [l for l in labels if "g_t" in l] + # plt.legend(handles, labels) + # plt.tight_layout() + + # # save figure + # save_path = output_path / "plots" / scene_name + # os.makedirs(save_path, exist_ok=True) + + # save_path = save_path / f"{key}_deltas_histogram.png" + # plt.savefig(save_path, dpi=300, bbox_inches='tight') + # plt.close() + # print(f"Saved delta histogram plot to {save_path}") + + # # Plot delta cumsum + # for key in ["delta_opacities"]: + + # name = key.replace("_", " ").title() + # delta_data = deltas[key] # list of [B, N, 3] + # delta_data = torch.stack(delta_data, dim=0) # (steps, B, N, d) + # delta_cumsum = delta_data.cumsum(dim=0) # (steps, B, N, d) + + # # Plot cumsum of delta for randomly sampled 10 gaussians + + # D = delta_data.shape[-1] + # plt.figure(figsize=(10 * D, 10)) + # # Adjust font size + # plt.rcParams.update({ + # 'axes.titlesize': 17, + # 'axes.labelsize': 15, + # 'xtick.labelsize': 15, + # 'ytick.labelsize': 15, + # 'legend.fontsize': 11 # Smaller + # }) + + # indices = np.random.choice(delta_data.shape[2], size=20, replace=False, ) + # scene = 0 + # # get indices of the maximum cumsum at the last step + # indices = torch.argsort(delta_cumsum[-1, scene].abs().sum(dim=-1), descending=True)[:20].cpu().numpy() + # for d in range(D): + # plt.subplot(1, D, d + 1) + # for 
idx in indices: + # curr_delta = delta_cumsum[:, scene, idx, d].float().cpu().numpy() + # plt.plot(iters, curr_delta, label=f"Gaussian {idx}", linewidth=2) + # # plt.plot(iters, psnrs[1:], 'k--', label="PSNR", linewidth=4) + + # plt.xlabel("Iteration") + # plt.ylabel(f"Accumulative of delta {name}") + # plt.title(f"Accumulative of delta {name} for 10 maximum gaussians") + # plt.grid(True) + # plt.legend() + + # plt.tight_layout() + # # plt.show() + + # raise NotImplementedError("Plot saving not implemented yet.") + + +# def debugging_reprojection_error(visualization_dump): +# reprojection_error = visualization_dump['reprojection_error'] # list of list of (B, V, H*W, 2) +# # Convert list of list to tensor +# reprojection_error = [torch.stack(scene_errors, dim=0) for scene_errors in +# reprojection_error] # list of (iterations, B, V, H*W, 2) +# reprojection_error = torch.stack(reprojection_error, dim=0) # (scenes, iterations, B, V, H*W, 2) +# reprojection_error = torch.permute(reprojection_error, +# (1, 0, 2, 3, 4, 5)) # [iterations, scenes, B, V, H*W, 2] +# max_val = 3 +# reprojection_error = reprojection_error.clamp(0, max_val) +# iterations = self.optimizer.save_every.get_iterations(len(reprojection_error)) +# target_psnrs = self.test_step_outputs_target["psnr"] # list of psnr for target views per scene +# target_psnrs = torch.Tensor(target_psnrs) # [scenes, iterations] +# target_psnrs = target_psnrs.mean(0) # [iterations] + +# # Plot histograms of reprojection error through out the iterations + +# out_dir = self.test_cfg.output_path / "debugging" +# out_dir.mkdir(parents=True, exist_ok=True) +# plt.figure(figsize=(6, 5)) +# for i, t in enumerate(iterations): +# error = reprojection_error[i] +# error_hist = error.view(-1).cpu().numpy() + +# counts, bin_edges = np.histogram(error_hist, bins=100, range=(0, max_val), density=False) +# bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) +# plt.plot(bin_centers, counts, +# label=f"Iter {t}, psnr 
def debugging_gaussians(gaussian_list: list[Gaussians], psnr_list: list[float], iter_list: list[int], output_path: Path,
                        scene_name: str):
    """Collect per-step Gaussian parameters and hand them to the histogram plotter.

    :param gaussian_list: one Gaussians snapshot per recorded step (must be non-empty)
    :param psnr_list: PSNR value per recorded step, same length as ``iter_list``
    :param iter_list: iteration index per recorded step
    :param output_path: root output directory; the result goes to
        ``<output_path>/plots/<scene_name>/params.mp4``
    :param scene_name: sub-directory name used under ``plots/``
    """
    assert len(gaussian_list) > 0, "Gaussian list cannot be empty."
    assert len(gaussian_list) == len(iter_list), "Gaussian list length must match iterations list length."
    assert len(psnr_list) == len(iter_list), "PSNR list length must match iterations list length."

    def _identity(value):
        return value

    # When the snapshots store activated values, map them back through the
    # inverse activations (log for scales, logit for opacities); otherwise
    # the stored values are used untouched.
    activated = gaussian_list[0].stores_activated
    to_raw_scale = torch.log if activated else _identity
    to_raw_opacity = torch.logit if activated else _identity

    def _on_cpu(tensor):
        # Drop the leading dimension, cut the autograd graph, move to host.
        return tensor.squeeze(0).detach().cpu()

    opacity_series = []
    scale_series = []
    quat_series = []
    mean_series = []
    sh_series = []
    for snapshot in gaussian_list:
        opacity_series.append(_on_cpu(to_raw_opacity(snapshot.opacities)).unsqueeze(-1))
        scale_series.append(_on_cpu(to_raw_scale(snapshot.scales)))
        quat_series.append(_on_cpu(snapshot.rotations))
        mean_series.append(_on_cpu(snapshot.means))
        sh_series.append(_on_cpu(snapshot.harmonics))

    data_groups = {
        "opacities": opacity_series,
        "scales": scale_series,
        "quats": quat_series,
        "means": mean_series,
        "shs": sh_series,
    }

    target_file = output_path / f"plots/{scene_name}/params.mp4"
    plot_gaussians_params_histograms(
        data_groups=data_groups,
        psnrs=psnr_list,
        iters=iter_list,
        out_path=target_file,
    )
vt[:, t] / (1 - beta2 ** (t + 1)) + +# denom = torch.sqrt(vt_hat) + eps +# delta = mt_hat / denom + +# # Plot histograms of gt, gt^2, mt_hat, vt_hat, delta + +# # Adjust font size +# plt.rcParams.update({ +# 'axes.titlesize': 17, +# 'axes.labelsize': 15, +# 'xtick.labelsize': 15, +# 'ytick.labelsize': 15, +# 'legend.fontsize': 9 # Smaller +# }) +# d = 0 # means x +# d = 2 # means z +# scene = 0 +# plt.figure(figsize=(20, 15)) +# names = [r"$g_t$", r"$g_t^2$", r"$\hat{m}_t$", r"$\hat{v}_t$", r"$\sqrt{\hat{v}_t} + \epsilon$", r"$\Delta$"] +# data_list = [gt, gt2, mt_hat, vt_hat, denom, delta] +# for i, (name, data) in enumerate(zip(names, data_list)): +# plt.subplot(2, 3, i + 1) + +# T = data.shape[1] +# for t in range(T): +# data_t = data[scene, t, :, d] +# max_val = np.percentile(data_t, 99.9) +# min_val = np.percentile(data_t, 0.1) +# data_t = data_t.clamp(min_val, max_val).cpu().numpy() +# counts, bin_edges = np.histogram(data_t, bins=100, range=(min_val, max_val), density=False) +# bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) +# # plot color with virdis colormap +# color = plt.cm.viridis(t / T) +# plt.plot(bin_centers, counts, label=fr"step {t}", color=color, linewidth=3) +# # plt.xlim((min_val, max_val)) +# plt.xlabel(name) +# plt.ylabel("Count") +# plt.title(f"{name}") +# plt.legend() +# plt.suptitle(f"Histograms of Adam statistics for gradient element {d}") +# plt.tight_layout() +# # plt.show() +# raise NotImplementedError("Plot saving not implemented yet.") + +# # Compare gt to mt +# gt_mt_diff = gt - mt_hat +# plt.figure(figsize=(6, 5)) +# T = gt_mt_diff.shape[1] +# for t in range(1, T): +# data = gt_mt_diff[scene, t, :, d] +# max_val = np.percentile(data, 99.9) +# min_val = np.percentile(data, 0.1) +# data = data.clamp(min=min_val, max=max_val) +# counts, bin_edges = np.histogram(data, bins=100, range=(min_val, max_val), density=False) +# bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) +# # plot color with virdis colormap +# color = 
plt.cm.viridis(t / T) +# plt.plot(bin_centers, counts, label=fr"step {t}", color=color, linewidth=3) +# plt.xlabel(r"$g_t - \hat{m}_t$") +# plt.ylabel("Count") +# plt.title(r"Histogram of $g_t - \hat{m}_t$") +# plt.legend() +# plt.tight_layout() +# plt.grid(True) +# # plt.show() +# raise NotImplementedError("Plot saving not implemented yet.") + +# # Compaer |gt| to sqrt(vt) + eps +# gt_abs_ratio = gt.abs() / denom +# plt.figure(figsize=(6, 5)) +# T = gt_abs_ratio.shape[1] +# for t in range(1, T): +# data = gt_abs_ratio[scene, t, :, d] +# max_val = np.percentile(data, 99.9) +# min_val = np.percentile(data, 0.1) +# data = data.clamp(min=min_val, max=max_val) +# counts, bin_edges = np.histogram(data, bins=100, range=(min_val, max_val), density=False) +# bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) +# # plot color with virdis colormap +# color = plt.cm.viridis(t / T) +# plt.plot(bin_centers, counts, label=fr"step {t}", color=color, linewidth=3) +# plt.xlabel(r"$|g_t| / (\sqrt{\hat{v}_t} + \epsilon)$") +# plt.ylabel("Count") +# plt.title(r"Histogram of $|g_t| / (\sqrt{\hat{v}_t} + \epsilon)$") +# plt.legend() +# plt.tight_layout() +# plt.grid(True) +# # plt.show() +# raise NotImplementedError("Plot saving not implemented yet.") + +# # Compare gt to delta +# delta_ratio = delta / gt +# plt.figure(figsize=(10, 5)) +# T = delta_ratio.shape[1] +# for t in range(1, T): +# data = delta_ratio[scene, t, :, d] +# max_val = np.percentile(data, 99.9) +# min_val = np.percentile(data, 0.1) +# data = data.clamp(min=min_val, max=max_val) +# counts, bin_edges = np.histogram(data, bins=100, range=(min_val, max_val), density=False) +# bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) +# # plot color with virdis colormap +# color = plt.cm.viridis(t / T) +# plt.plot(bin_centers, counts, label=fr"step {t}", color=color, linewidth=3) +# plt.xlabel(r"$g_t / \Delta$") +# plt.ylabel("Count") +# plt.title(r"Histogram of $g_t / \Delta$") +# plt.legend() +# plt.tight_layout() +# 
def debugging_invisible_gaussians(
        gaussian_list,
        grads_raw_list,
        normalized_grads_list,
        means2d_list,
        radii_list,
        psnr_list,
        iterations_list,
        output_path,
        scene_name
):
    """Plot how Gaussians move between zero / partial / non-zero gradient states.

    A Gaussian is "zero" at a step when every gradient component is exactly 0,
    "nonzero" when every component is non-zero, and "partial" otherwise.  The
    figure has ten stacked panels (state counts, state transitions, and, for
    the Gaussians with the largest first scale component at the last step,
    their gradient statistics, parameters, 2D means and radii).  The figure is
    written to ``<output_path>/<scene_name>_debug_invisible_gaussians_over_time.png``.

    :param gaussian_list: Gaussians snapshots; the first entry (init) is dropped
    :param grads_raw_list: per-step dicts ``param name -> gradient tensor``
    :param normalized_grads_list: same layout as ``grads_raw_list``, holding the
        normalized gradients (labelled "adam grad" in the plots)
    :param means2d_list: per-step 2D mean tensors, concatenated along dim 0
    :param radii_list: per-step radii tensors, concatenated along dim 0
    :param psnr_list: unused; kept for interface compatibility
    :param iterations_list: iteration indices; the first entry (init) is dropped
    :param output_path: directory that receives the PNG (created if missing)
    :param scene_name: used in the figure title and the file name
    """

    def concat_grads(grads_list):
        # Flatten every parameter's gradient to (G, -1), stack over steps and
        # concatenate all parameters along the feature axis.
        first = grads_list[0]
        G = next(iter(first.values())).shape[0]  # number of gaussians
        per_param = [
            # detach + cpu so the tensors can be handed to matplotlib directly
            torch.stack([grads[key].reshape(G, -1) for grads in grads_list], dim=0).detach().cpu()  # (T, G, d)
            for key in first.keys()
        ]
        return torch.cat(per_param, dim=-1), per_param  # (T, G, D), list of (T, G, d)

    # === Prepare data ===
    grads_mat, _ = concat_grads(grads_raw_list)  # (T, G, D)
    _, norm_grads_per_params = concat_grads(normalized_grads_list)
    # NOTE(review): indices 1 and 3 assume the grads dict is ordered
    # means, scales, rotations, opacities, ... — confirm against the caller.
    scales_norm_grads = norm_grads_per_params[1]  # (T, G, scale_dim)
    opacities_norm_grads = norm_grads_per_params[3]  # (T, G, opacity_dim)
    means2d = torch.cat(means2d_list, dim=0).cpu()[1:]  # drop init.
    # Fix: the radii were previously built from means2d_list, so the "radii"
    # panel silently re-plotted the 2D means.
    radii = torch.cat(radii_list, dim=0).cpu()[1:]  # drop init.

    T, G, _ = grads_mat.shape
    iterations_list = iterations_list[1:]  # remove init.

    # === Convert Gaussian params to tensor ===
    def extract_params(gaussians: list[Gaussians], grads):
        # One (T, G, dim) tensor per gradient key (spherical-harmonics keys are
        # replaced by a single `harmonics` entry appended last).
        params = []
        for k in grads[0].keys():
            if k in ["shNs", "sh0s"]:
                continue
            params.append(torch.stack([getattr(g, k)[0].detach().cpu() for g in gaussians]))
        params.append(torch.stack([g.harmonics[0].detach().cpu() for g in gaussians]))
        return [p[1:] for p in params]  # remove init.

    params = extract_params(gaussian_list, grads_raw_list)
    scales = params[1]
    opacities = params[3]

    # === Compute zero / partial grad masks ===
    zero_grad_mask = (grads_mat == 0)  # (T, G, D)
    zero_grad_cnt = zero_grad_mask.sum(dim=-1)  # (T, G)
    is_zero = zero_grad_mask.all(dim=-1)  # (T, G)
    is_nonzero = (~zero_grad_mask).all(dim=-1)  # (T, G)
    is_partial = ~(is_zero | is_nonzero)  # (T, G)
    validation = is_zero.float() + is_nonzero.float() + is_partial.float()
    assert (validation == 1).all(), "Gradient classification error: some Gaussians are not classified properly."

    # 0 = zero, 1 = partial, 2 = nonzero
    state = torch.zeros_like(is_zero, dtype=torch.int8)
    state[is_partial] = 1
    state[is_nonzero] = 2

    # === Compute counts ===
    zero_cnt = is_zero.sum(dim=1).cpu().numpy()  # (T,)
    partial_cnt = is_partial.sum(dim=1).cpu().numpy()  # (T,)

    # === Count state transitions between consecutive steps ===
    def count_transitions(src, dst):
        return ((state[:-1] == src) & (state[1:] == dst)).sum(dim=1)  # (T-1,)

    zero_to_partial = count_transitions(0, 1)
    zero_to_nonzero = count_transitions(0, 2)
    partial_to_nonzero = count_transitions(1, 2)
    partial_to_zero = count_transitions(1, 0)
    nonzero_to_zero = count_transitions(2, 0)
    nonzero_to_partial = count_transitions(2, 1)

    # Stay as is
    zero_to_zero = count_transitions(0, 0)
    partial_to_partial = count_transitions(1, 1)
    nonzero_to_nonzero = count_transitions(2, 2)

    total = (zero_to_nonzero + zero_to_partial + partial_to_nonzero + partial_to_zero + nonzero_to_zero
             + nonzero_to_partial + zero_to_zero + partial_to_partial + nonzero_to_nonzero)
    assert (total == G).all(), "Transition counts do not sum up to total number"

    # === Gaussian indices ===
    # Track the Gaussians with the largest first scale component at the last
    # step; clamp k so scenes with fewer than 30 Gaussians do not crash topk.
    n_vis = min(30, G)
    tracked = torch.topk(scales[-1, ..., 0], k=n_vis, largest=True).indices

    grad_norms = grads_mat.norm(dim=-1)  # (T, G)

    # === Create figure ===
    fig, axes = plt.subplots(10, 1, figsize=(12, 18), sharex=True)
    fig.suptitle(f"Debugging Invisible Gaussians — {scene_name}", fontsize=16)

    # Panel 0: zero / partial counts
    i = 0
    axes[i].plot(iterations_list, zero_cnt, label="Zero Grad Gaussians")
    axes[i].plot(iterations_list, partial_cnt, label="Partial Grad Gaussians")
    axes[i].set_ylabel("Count")
    axes[i].set_title("Zero vs Partial Grad Gaussians Count")
    axes[i].legend()

    # Panel 1: transitions (defined between consecutive steps, hence [1:])
    i += 1
    axes[i].plot(iterations_list[1:], zero_to_partial.cpu(), label='Zero → Partial')
    axes[i].plot(iterations_list[1:], zero_to_nonzero.cpu(), label='Zero → Nonzero')
    axes[i].plot(iterations_list[1:], partial_to_nonzero.cpu(), label='Partial → Nonzero')
    axes[i].plot(iterations_list[1:], partial_to_zero.cpu(), label='Partial → Zero')
    axes[i].plot(iterations_list[1:], nonzero_to_zero.cpu(), label='Nonzero → Zero')
    axes[i].plot(iterations_list[1:], nonzero_to_partial.cpu(), label='Nonzero → Partial')
    axes[i].set_ylabel("Count")
    axes[i].set_title("Transition Grad Gaussians Count")
    axes[i].legend()

    # Panel 2: number of zero gradient components per tracked Gaussian
    i += 1
    axes[i].plot(iterations_list, zero_grad_cnt[:, tracked])
    axes[i].set_title("Gaussians zero grad count")
    axes[i].set_ylabel("Zero grad cnt")

    # Panel 3: gradient magnitudes
    i += 1
    axes[i].plot(iterations_list, grad_norms[:, tracked])
    axes[i].set_title("Gaussians gradient magnitude")
    axes[i].set_ylabel("Gradient norm")

    # Panel 4: mean scale
    i += 1
    axes[i].plot(iterations_list, scales[:, tracked].mean(-1))
    axes[i].set_title("Gaussians scales")
    axes[i].set_ylabel("Scales")

    # Panel 5: opacities
    i += 1
    axes[i].plot(iterations_list, opacities[:, tracked])
    axes[i].set_title("Gaussians opacities")
    axes[i].set_ylabel("Opacities")

    # Panel 6: normalized gradient of the first scale component
    i += 1
    axes[i].plot(iterations_list, scales_norm_grads[:, tracked, 0])
    axes[i].set_title("Gaussians scales X adam grad")
    axes[i].set_ylabel("Scales X grad")

    # Panel 7: normalized opacity gradient.
    # NOTE(review): this panel used the raw opacity gradients although it is
    # titled "adam grad" like the scales panel above; switched to the
    # normalized ones for consistency — confirm this is the intended series.
    i += 1
    axes[i].plot(iterations_list, opacities_norm_grads[:, tracked, 0])
    axes[i].set_title("Gaussians opacities adam grad")
    axes[i].set_ylabel("Opacities X grad")

    # Panel 8: first 2D-mean component, index 0 along dim 1
    i += 1
    axes[i].plot(iterations_list, means2d[:, 0, tracked, 0])
    axes[i].set_title("Gaussians means 2D X")
    axes[i].set_ylabel("Means 2D X")
    axes[i].set_xlabel("Iteration")

    # Panel 9: first radii component summed over dim 1
    i += 1
    axes[i].plot(iterations_list, radii[:, :, tracked, 0].sum(1))
    axes[i].set_title("Gaussians radii 2D X")
    axes[i].set_ylabel("Radii 2D X")
    axes[i].set_xlabel("Iteration")

    plt.tight_layout(rect=[0, 0, 1, 0.96])

    # === Save figure ===
    # Fix: the save code was commented out while the final print still used
    # `fig_path`, so the function always ended in a NameError.
    os.makedirs(output_path, exist_ok=True)
    fig_path = os.path.join(output_path, f"{scene_name}_debug_invisible_gaussians_over_time.png")
    fig.savefig(fig_path)
    plt.close(fig)

    print(f"✅ Saved time-evolution debug plot → {fig_path}")
def collect_pngs(root: Path, subdir: str) -> dict[str, Path]:
    """Map every PNG below ``root/subdir`` to its path, keyed by relative path.

    :param root: directory to scan
    :param subdir: optional sub-path under ``root``; empty string scans ``root``
    :return: ``{relative path (str): absolute/joined Path}`` for all ``*.png``
    Exits the process with a message when the directory does not exist.
    """
    base = root / subdir if subdir else root
    if not base.exists():
        sys.exit(f"Missing path: {base}")
    return {str(p.relative_to(base)): p for p in base.rglob("*.png")}


def _load_rgb(path: Path):
    """Load an image as a float32 RGB array scaled to [0, 1].

    The context manager closes the underlying file handle right away instead
    of leaving it to the garbage collector.
    """
    with Image.open(path) as img:
        return np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0


def diff_pair(a_path: Path, b_path: Path) -> dict:
    """Compute absolute-difference statistics between two images.

    :return: on matching shapes a dict with ``max_abs``, ``mean_abs``, ``mse``,
        ``psnr`` and ``shape``; otherwise ``{"shape_a", "shape_b", "skipped": True}``
    """
    a = _load_rgb(a_path)
    b = _load_rgb(b_path)
    if a.shape != b.shape:
        return {"shape_a": a.shape, "shape_b": b.shape, "skipped": True}
    d = np.abs(a - b)
    mse = float((d ** 2).mean())
    # Peak value is 1.0; the 1e-12 keeps log10 finite for identical images.
    psnr = float(20 * np.log10(1.0) - 10 * np.log10(mse + 1e-12))
    return {
        "max_abs": float(d.max()),
        "mean_abs": float(d.mean()),
        "mse": mse,
        "psnr": psnr,
        "shape": list(a.shape),
    }


def save_diff_image(a_path: Path, b_path: Path, out_path: Path, scale: float = 5.0) -> None:
    """Write ``clip(|a - b| * scale, 0, 1)`` as an 8-bit image to ``out_path``.

    Does nothing when the two images differ in shape.  Parent directories of
    ``out_path`` are created as needed.
    """
    a = _load_rgb(a_path)
    b = _load_rgb(b_path)
    if a.shape != b.shape:
        return
    d = np.clip(np.abs(a - b) * scale, 0, 1)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray((d * 255).astype(np.uint8)).save(out_path)


def _pair_images(pngs_a: dict, pngs_b: dict, mode: str):
    """Pair the PNGs of both roots.

    :param mode: ``"name"`` pairs by identical relative path, anything else
        pairs by index after sorting both roots
    :return: ``(pairs, only_a, only_b)`` where ``pairs`` is a list of
        ``(label, path_a, path_b)``; ``only_a`` / ``only_b`` are the unmatched
        relative paths (always empty, independent lists in sorted mode)
    Exits the process when nothing can be paired.
    """
    if mode == "name":
        common = sorted(set(pngs_a) & set(pngs_b))
        only_a = sorted(set(pngs_a) - set(pngs_b))
        only_b = sorted(set(pngs_b) - set(pngs_a))
        print(f"pair=name; common: {len(common)}; only_a: {len(only_a)}; only_b: {len(only_b)}")
        if not common:
            sys.exit("No common PNGs to diff. Try --pair sorted if filenames differ but order matches.")
        return [(rel, pngs_a[rel], pngs_b[rel]) for rel in common], only_a, only_b

    sa = sorted(pngs_a.items())
    sb = sorted(pngs_b.items())
    if len(sa) != len(sb):
        sys.exit(f"pair=sorted: counts differ (root_a={len(sa)}, root_b={len(sb)}); can't pair by index.")
    print(f"pair=sorted; {len(sa)} pairs")
    pairs = [(f"{ra}|{rb}", pa, pb) for (ra, pa), (rb, pb) in zip(sa, sb)]
    return pairs, [], []


def main() -> None:
    """Command-line entry point: diff two render directories and report stats."""
    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("root_a", type=Path)
    p.add_argument("root_b", type=Path)
    p.add_argument("--subdir", default="",
                   help="Restrict comparison to this subpath under each root (e.g. 'initializerply/images').")
    p.add_argument("--save-diff", type=Path, default=None,
                   help="Directory to save scaled |a-b| images, mirroring the relative path.")
    p.add_argument("--diff-scale", type=float, default=5.0)
    p.add_argument("--top-k", type=int, default=10, help="Show K worst pairs by max|diff|.")
    p.add_argument("--json", type=Path, default=None, help="Optional path to dump per-pair stats as JSON.")
    p.add_argument("--pair", choices=["name", "sorted"], default="name",
                   help="'name': pair by matching relative path. 'sorted': pair by index after sorting "
                        "each root's PNGs (use when filename schemes differ but order corresponds).")
    args = p.parse_args()

    pngs_a = collect_pngs(args.root_a, args.subdir)
    pngs_b = collect_pngs(args.root_b, args.subdir)

    print(f"root_a: {args.root_a / args.subdir if args.subdir else args.root_a}")
    print(f"root_b: {args.root_b / args.subdir if args.subdir else args.root_b}")

    pairs, only_a, only_b = _pair_images(pngs_a, pngs_b, args.pair)

    results = []
    skipped = []
    for rel, a_path, b_path in pairs:
        stats = diff_pair(a_path, b_path)
        if stats.get("skipped"):
            skipped.append((rel, stats))
            continue
        results.append((rel, stats))
        if args.save_diff is not None:
            # '|' (sorted-mode label separator) is not a safe file-name character.
            save_diff_image(a_path, b_path, args.save_diff / rel.replace("|", "_VS_"), scale=args.diff_scale)

    if skipped:
        print(f"\nShape-mismatch pairs ({len(skipped)}):")
        for rel, s in skipped[:10]:
            print(f" {rel}: {s['shape_a']} vs {s['shape_b']}")

    if not results:
        sys.exit("All pairs had mismatched shapes.")

    max_abs = np.array([s["max_abs"] for _, s in results])
    mean_abs = np.array([s["mean_abs"] for _, s in results])
    psnr = np.array([s["psnr"] for _, s in results])

    print(f"\nPer-pair stats ({len(results)} pairs):")
    print(f" max|diff| — min: {max_abs.min():.4e} median: {np.median(max_abs):.4e} max: {max_abs.max():.4e}")
    print(f" mean|diff| — min: {mean_abs.min():.4e} median: {np.median(mean_abs):.4e} max: {mean_abs.max():.4e}")
    print(f" PSNR(dB) — min: {psnr.min():.2f} median: {np.median(psnr):.2f} max: {psnr.max():.2f}")

    results.sort(key=lambda r: -r[1]["max_abs"])
    print(f"\nWorst {min(args.top_k, len(results))} pairs by max|diff|:")
    for rel, s in results[: args.top_k]:
        print(f" max={s['max_abs']:.4e} mean={s['mean_abs']:.4e} psnr={s['psnr']:.2f}dB {rel}")

    if args.json is not None:
        args.json.parent.mkdir(parents=True, exist_ok=True)
        with open(args.json, "w") as f:
            json.dump(
                {
                    "root_a": str(args.root_a),
                    "root_b": str(args.root_b),
                    "subdir": args.subdir,
                    # Fix: `common` only existed in name mode, so
                    # `--pair sorted --json ...` crashed with a NameError.
                    # len(pairs) equals len(common) in name mode.
                    "common_count": len(pairs),
                    "only_a": only_a,
                    "only_b": only_b,
                    "pairs": [{"rel": r, **s} for r, s in results],
                    "skipped": [{"rel": r, **s} for r, s in skipped],
                },
                f,
                indent=2,
            )
        print(f"\nWrote per-pair stats to {args.json}")


if __name__ == "__main__":
    main()
def hf_download_path(repo_path: str, odir: str, max_try: int = 5):
    """ hf api is not reliable, retry when failed with max tries

    :param repo_path: The path of the repo to download
    :param odir: output path
    :param max_try: maximum number of download attempts
    :return: True on success, False once all attempts failed
    """
    rel_path = os.path.relpath(repo_path, repo_root)

    for attempt in range(1, max_try + 1):
        try:
            api.hf_hub_download(repo_id=repo_root, filename=rel_path, repo_type='dataset', local_dir=odir,
                                cache_dir=join(odir, '.cache'))
            return True
        # Only retry on ordinary errors: catching BaseException here also
        # swallowed KeyboardInterrupt / SystemExit, so Ctrl-C merely triggered
        # another retry instead of stopping the script.
        except Exception:
            traceback.print_exc()
            print(f'Retry {attempt}')

    print("ERROR: Download {} failed.".format(repo_path))
    return False


def clean_huggingface_cache(cache_dir: str):
    """ Huggingface cache may take too much space, we clean the cache to save space if necessary

    :param cache_dir: the current cache directory
    """
    # Current huggingface hub does not provide good practice to clean the space.
    # We manually clean the cache directory if necessary.  This is deliberately
    # best-effort: a missing or partially removable cache is not an error.
    shutil.rmtree(join(cache_dir, 'datasets--DL3DV--DL3DV-10K-Benchmark'), ignore_errors=True)


def download_by_hash(filepath_dict: dict, odir: str, hash: str, only_level4: bool, only_sfm: bool):
    """ Given a hash, download the relevant data from the huggingface repo

    :param filepath_dict: the cache dict that stores all the file relative paths
    :param odir: the download directory
    :param hash: the hash code for the scene
    :param only_level4: the images_4 resolution level, if true, only the images_4 resolution level will be downloaded
    :param only_sfm: if true, only the nerfstudio ``.json`` / ``.bin`` files are downloaded (this takes
        precedence over ``only_level4``) and they are then moved to the scene root directory
    :return: True on success, False as soon as one file fails to download
    """
    all_files = filepath_dict[hash]
    download_files = [join(repo_root, f) for f in all_files]

    if only_level4:  # only download images_4 level data
        # Skip image files outside an `images_4` directory and anything
        # containing 'input' in its path.
        download_files = [
            join(repo_root, f) for f in all_files
            if not ('images' in f and os.path.basename(os.path.dirname(f)) != 'images_4' or 'input' in f)
        ]

    if only_sfm:  # only download nerfstudio colmap data
        download_files = [
            join(repo_root, f) for f in all_files
            if 'nerfstudio' in f and ('.json' in f or '.bin' in f)
        ]

    for f in download_files:
        if not hf_download_path(f, odir):
            return False

    if only_sfm:
        # Move files to the scene root directory
        # <hash>/nerfstudio/transforms.json --> <hash>/transforms.json
        # <hash>/nerfstudio/colmap/sparse --> <hash>/sparse
        scene_dir = CustomPath(odir) / hash
        nerfstudio_dir = scene_dir / 'nerfstudio'

        # transforms.json
        shutil.move(nerfstudio_dir / 'transforms.json', scene_dir / 'transforms.json')

        # sparse
        src_sparse_path = nerfstudio_dir / 'colmap' / 'sparse'
        dst_sparse_path = scene_dir / 'sparse'
        # Remove a stale destination first.  Relying on shutil.move to raise
        # does not work: when the destination directory exists the source is
        # silently moved *inside* it (sparse/sparse) on the first re-run.
        if os.path.exists(dst_sparse_path):
            print(f'Warning: {hash} sparse already exists. Overwriting.')
            if os.path.isdir(dst_sparse_path):
                shutil.rmtree(dst_sparse_path)
            else:
                os.remove(dst_sparse_path)
        shutil.move(src_sparse_path, dst_sparse_path)

        # remove the nerfstudio directory when all that is left is an empty colmap directory
        colmap_dir = nerfstudio_dir / 'colmap'
        if (os.listdir(nerfstudio_dir) == ['colmap']
                and os.path.isdir(colmap_dir)
                and not os.listdir(colmap_dir)):
            shutil.rmtree(nerfstudio_dir)

    return True


def download_benchmark(args):
    """ Download the benchmark based on the user inputs.

        1. download the benchmark-meta.csv
        2. based on the args, download the specific subset
            a. full benchmark
            b. full benchmark in images_4 resolution level
            c. full benchmark only with nerfstudio colmaps (w.o. gaussian splatting colmaps)
            d. specific scene based on its hash

    :param args: argparse args. Used to decide the subset.
    :return: download success or not
    """
    output_dir = args.odir
    subset_opt = args.subset
    level4_opt = args.only_level4
    hash_name = args.hash
    is_clean_cache = args.clean_cache
    only_sfm = args.only_sfm

    os.makedirs(output_dir, exist_ok=True)

    # STEP 1: download the benchmark-meta.csv and .cache/filelist.bin
    meta_repo_path = join(repo_root, 'benchmark-meta.csv')
    cache_file_path = join(repo_root, '.cache/filelist.bin')
    if not hf_download_path(meta_repo_path, output_dir):
        print('ERROR: Download benchmark-meta.csv failed.')
        return False

    if not hf_download_path(cache_file_path, output_dir):
        print('ERROR: Download .cache/filelist.bin failed.')
        return False

    # STEP 2: download the specific subset
    df = pd.read_csv(join(output_dir, 'benchmark-meta.csv'))
    # NOTE(review): unpickling executes arbitrary code from the file; this is
    # only acceptable because filelist.bin comes from the dataset repo above.
    # The `with` block closes the handle that was previously leaked.
    with open(join(output_dir, '.cache/filelist.bin'), 'rb') as filelist_file:
        filepath_dict = pickle.load(filelist_file)
    hashlist = df['hash'].tolist()
    download_list = hashlist

    # sanity check
    if subset_opt == 'hash':
        if hash_name not in hashlist:
            print(f'ERROR: hash {hash_name} not in the benchmark-meta.csv')
            return False

        # if subset is hash, only download the specific hash
        download_list = [hash_name]

    # download the dataset
    for cur_hash in tqdm(download_list):
        if not download_by_hash(filepath_dict, output_dir, cur_hash, level4_opt, only_sfm):
            return False

        if is_clean_cache:
            clean_huggingface_cache(join(output_dir, '.cache'))

    return True
if __name__ == '__main__':
    # Report which scenes exist both in the local COLMAP cache and in the
    # metrics directory of a reference test-set run (compared by entry name).
    cache_root = pathlib.Path("datasets/dl3dv-colmap-cache/1K")
    metrics_root = pathlib.Path("results/dl3dv/8_views_140_scenes/best_adam/2000/optimizervanilla/metrics")

    cached_names = {entry.name for entry in cache_root.iterdir()}
    testset_names = {entry.name for entry in metrics_root.iterdir()}
    shared_names = cached_names & testset_names

    print(f"Number of available scenes: {len(cached_names)}")
    print(f"Number of testset scenes: {len(testset_names)}")
    print(f"Number of scenes in intersection: {len(shared_names)}")
    print("Scenes in intersection:")
    for name in sorted(shared_names):
        print(f"- {name}")
1K(0~1K), 2K(1K~2K), 3K(2K~3K), etc + - [X] specific hash + - [X] file_type: raw video | images+poses | colmap cache + + Notes: + - file_type + resolution will decide which dataset repo to download the files + - subset will decide which subdir will be used + - if hash is set, only the specific hash will be downloaded + + example usage: python dl3dv_hf_download.py --odir ../../datasets/dl3dv-colmap-sfm --file_type colmap_sfm --hash e2cedefea8a0ed2d0ffbd5bdc08acbe7e1f85c96f72f7b790e9dfe1c98963047 --clean_cache --subset 1K --resolution 480P + + +""" + +import os +import pathlib +from os.path import join +import pandas as pd +from tqdm import tqdm +from huggingface_hub import HfApi +import argparse +import traceback +import shutil +import urllib.request +import zipfile +from huggingface_hub import HfFileSystem + +api = HfApi() +resolution2repo = { + '480P': 'DL3DV/DL3DV-ALL-480P', + '960P': 'DL3DV/DL3DV-ALL-960P', + '2K': 'DL3DV/DL3DV-ALL-2K', + '4K': 'DL3DV/DL3DV-ALL-4K' +} + + +def verify_access(repo: str): + """ This function can be used to verify if the user has access to the repo. 
+ + :param repo: the repo name + :return: True if the user has access, False otherwise + """ + fs = HfFileSystem() + try: + fs.ls(f'datasets/{repo}') + return True + except BaseException as e: + return False + + +def hf_download_path(repo: str, rel_path: str, odir: str, max_try: int = 5): + """ hf api is not reliable, retry when failed with max tries + + :param repo: The huggingface dataset repo + :param rel_path: The relative path in the repo + :param odir: output path + :param max_try: As the downloading is not a reliable process, we will retry for max_try times + """ + counter = 0 + while True: + if counter >= max_try: + print(f"ERROR: Download {repo}/{rel_path} failed.") + return False + try: + api.hf_hub_download(repo_id=repo, + filename=rel_path, + repo_type='dataset', + local_dir=odir, + cache_dir=join(odir, '.cache')) + return True + + except KeyboardInterrupt: + print('Keyboard Interrupt. Exit.') + exit() + except BaseException as e: + traceback.print_exc() + counter += 1 + + +def download_from_url(url: str, ofile: str): + """ Download a file from the url to ofile + + :param url: The url link + :param ofile: The output path + :return: True if download success, False otherwise + """ + try: + # Use urllib.request.urlretrieve to download the file from `url` and save it locally at `local_file_path` + urllib.request.urlretrieve(url, ofile) + return True + except Exception as e: + print(f"An error occurred while downloading the file: {e}") + return False + + +def clean_huggingface_cache(output_dir: str, repo: str): + """ Huggingface cache may take too much space, we clean the cache to save space if necessary + + Current huggingface hub does not provide good practice to clean the space. + We mannually clean the cache directory if necessary. 
+ + :param output_dir: the current output directory + :param output_dir: the huggingface repo + """ + repo_cache_dir = repo.replace('/', '--') + # cur_cache_dir = join(output_dir, '.cache', f'datasets--{repo_cache_dir}') + cur_cache_dir = join(output_dir, '.cache') + + if os.path.exists(cur_cache_dir): + shutil.rmtree(cur_cache_dir) + + +def get_download_list(subset_opt: str, hash_name: str, reso_opt: str, file_type: str, output_dir: str): + """ Get the download list based on the subset and hash name + + 1. Get the meta file + 2. Select the subset. Based on reso_opt, get the downloading list prepared. + 3. Return the download list. + + :param subset_opt: Subset of the 10K, e.g. 1K(0~1K), 2K(1K~2K), 3K(2K~3K), etc + :param hash_name: If provided a non-empty string, ignore the subset_opt and only download the specific hash + :param reso_opt: The resolution to download. + :param file_type: The file type to download: video | images+poses | colmap_cache + :param output_dir: The output directory. + """ + + def to_download_item(hash_name, reso, batch, file_type): + if file_type == 'images+poses': + repo = resolution2repo[reso] + rel_path = f'{batch}/{hash_name}.zip' + elif file_type == 'video': + repo = 'DL3DV/DL3DV-ALL-video' + rel_path = f'{batch}/{hash_name}/video.mp4' + elif file_type in ['colmap_cache', 'colmap_sfm']: + repo = 'DL3DV/DL3DV-ALL-ColmapCache' + rel_path = f'{batch}/{hash_name}.zip' + else: + raise ValueError('Unknown file_type option.') + + # return f'{repo}/{batch}/{hash_name}' + return {'repo': repo, 'rel_path': rel_path} + + ret = [] + + meta_link = 'https://raw.githubusercontent.com/DL3DV-10K/Dataset/main/cache/DL3DV-valid.csv' + cache_folder = join(output_dir, '.cache') + meta_file = join(cache_folder, 'DL3DV-valid.csv') + os.makedirs(cache_folder, exist_ok=True) + if not os.path.exists(meta_file): + assert download_from_url(meta_link, meta_file), 'Download meta file failed.' 
+ + df = pd.read_csv(meta_file) + + # if hash is set, ignore the subset_opt + if hash_name != '': + assert hash_name in df['hash'].values, f'Hash {hash_name} not found in the meta file.' + + batch = df[df['hash'] == hash_name]['batch'].values[0] + link = to_download_item(hash_name, reso_opt, batch, file_type) + ret = [link] + return ret + + # if hash not set, we download the whole subset + subdf = df[df['batch'] == subset_opt] + for i, r in subdf.iterrows(): + hash_name = r['hash'] + ret.append(to_download_item(hash_name, reso_opt, subset_opt, file_type)) + + return ret + + +SFM_BIN_FILES = {"cameras.bin", "images.bin", "points3D.bin"} + + +def sfm_cleanup_scene(scene_dir: pathlib.Path): + """ + Keep only COLMAP sparse SfM files: + cameras.bin, images.bin, points3D.bin + Delete everything else. + """ + print(f"Cleaning up SfM scene at {scene_dir.resolve()}") + scene_dir = scene_dir.resolve() + + if not scene_dir.exists(): + print(f"[WARN] {scene_dir} does not exist") + return + + # First pass: delete unwanted files + for path in scene_dir.rglob("*"): + if path.is_file(): + # keep sparse/[0-9]+/{cameras,images,points3D}.bin and transforms.json + is_bin_file = (path.name in SFM_BIN_FILES and + path.parent.name.isdigit() and + path.parent.parent.name == "sparse") + is_transforms_file = (path.name == "transforms.json" and path.parent == scene_dir) + if is_bin_file or is_transforms_file: + continue + + path.unlink() + + # Second pass: remove empty directories bottom-up + for path in sorted(scene_dir.rglob("*"), reverse=True): + if path.is_dir() and not any(path.iterdir()): + path.rmdir() + + # Third pass: rearrange files to fit clogs training + # move //colmap/sparse/* to //* + # and remove empty dirs + subset_dir = scene_dir.parent + dataset_dir = scene_dir.parent.parent + colmap_dir = scene_dir / "colmap" + sparse_dir = colmap_dir / "sparse" + if sparse_dir.exists(): + # move the sparse dir to a scene_dir inside dataset dir + target_sparse_dir = dataset_dir / 
scene_dir.name / "sparse" + target_sparse_dir.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(sparse_dir), str(target_sparse_dir)) + + # remove empty dirs + if not any(colmap_dir.iterdir()): + colmap_dir.rmdir() + if not any(scene_dir.iterdir()): + scene_dir.rmdir() + + +def validate_sfm_structure(scene_dir: pathlib.Path, unsucc_count: int): + """ + Validate the SfM cleanup by checking if the scene dir only contains the sparse/0/cameras.bin, images.bin, points3D.bin files and transforms.json + """ + scene_dir = scene_dir.resolve() + + if not scene_dir.exists(): + print(f"[WARN: {unsucc_count}] {scene_dir} does not exist") + return False + + # Check if transforms.json exists in the root of the scene dir + # transforms_file = scene_dir / "transforms.json" + # TODO Naama: skipping transforms.json, but will need to redownload + # if not transforms_file.is_file(): + # print(f"[ERROR] transforms.json is missing in {scene_dir}") + # return False + + # Check if sparse/0/cameras.bin, images.bin, points3D.bin exist + sparse_0_dir = scene_dir / "sparse" / "0" + for bin_file in SFM_BIN_FILES: + if not (sparse_0_dir / bin_file).is_file(): + print(f"[ERROR: {unsucc_count}] {bin_file} is missing in {sparse_0_dir}") + return False + + # Check if there are any other files or directories in the scene dir + for path in scene_dir.rglob("*"): + if path.is_file(): + is_bin_file = (path.name in SFM_BIN_FILES and + path.parent.name == "0" and + path.parent.parent.name == "sparse") + is_transforms_file = (path.name == "transforms.json" and path.parent == scene_dir) + is_image_file = (path.suffix in ['.jpg', '.png'] and path.parent.name.startswith("images") and path.parent.parent == scene_dir) + if not (is_bin_file or is_transforms_file or is_image_file): + # print(f"*********** [WARN: {unsucc_count}] Unexpected file {path} found in {scene_dir}") + # remove the unexpected file + # path.unlink() + pass + elif path.is_dir(): + # if there is any dir other than sparse/0, it's 
unexpected + is_sparse_dir = (path.name == "sparse" and path.parent == scene_dir) + is_sparse_0_dir = (path.name == "0" and path.parent.name == "sparse") + # For test scenes we might have images* dir + is_images_dir = ("images" in path.name and path.parent == scene_dir) + if not (is_sparse_0_dir or is_sparse_dir or is_images_dir): + # print(f"*********** [WARN: {unsucc_count}] Unexpected directory {path} found in {scene_dir}") + # remove the unexpected dir + # shutil.rmtree(path) + pass + return True + + +def download(download_list: list, output_dir: str, is_clean_cache: bool, only_sfm: bool = False): + """ Download the dataset based on the download_list and user options. + + :param download_list: the list of files to download, [{'repo', 'rel_path'}] + :param output_dir: the output directory + :param reso_opt: the resolution option + :param is_clean_cache: if set, will clean the huggingface cache to save space + :param only_sfm: if set, only download the colmap sfm files (remove all other files) + """ + succ_count = 0 + unsucc_count = 0 + + for item in tqdm(download_list, desc='Downloading'): + repo = item['repo'] + rel_path = item['rel_path'] + + output_path = os.path.join(output_dir, rel_path) + output_path = output_path.replace('.zip', '') + # skip if already exists locally + # scene dir can be moved from root/subset/hash/ to root/hash/ after sfm_cleanup, so we need to check both paths + output_path_without_subset = pathlib.Path(output_path).parent.parent / pathlib.Path(output_path).name + # if os.path.exists(output_path): + # print(f"File {output_path} already exists, skip downloading.") + # succ_count += 1 + # continue + if output_path_without_subset.exists(): + # For sfm, verify donwload based on the hash dir after cleanup, which is moved to root/hash/ + if only_sfm: + if validate_sfm_structure(output_path_without_subset, unsucc_count): + succ_count += 1 + continue + else: + succ_count += 1 + continue + unsucc_count += 1 + succ = hf_download_path(repo, 
rel_path, output_dir) + + if succ: + succ_count += 1 + if is_clean_cache: + clean_huggingface_cache(output_dir, repo) + + # unzip the file + if rel_path.endswith('.zip'): + zip_file = join(output_dir, rel_path) + hash_name = os.path.splitext(os.path.basename(rel_path))[0] + subset_name = os.path.dirname(rel_path) + target_dir = join(output_dir, subset_name, hash_name) + + # Ensure target directory exists + os.makedirs(target_dir, exist_ok=True) + + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + # Get list of files in the zip + zip_contents = zip_ref.namelist() + + # Check if all files are under a single directory that matches the hash + common_prefix = None + if zip_contents: + # Check if there's a common directory prefix + first_path = zip_contents[0] + if '/' in first_path: + potential_prefix = first_path.split('/')[0] + '/' + if all(path.startswith(potential_prefix) for path in zip_contents if + not path.endswith('/')): + common_prefix = potential_prefix.rstrip('/') + + # Extract files + if common_prefix == hash_name: + # Files are already under hash directory, extract normally + zip_ref.extractall(join(output_dir, subset_name)) + else: + # Extract directly to target hash directory + zip_ref.extractall(target_dir) + + if only_sfm: + scene_dir = pathlib.Path(target_dir) + sfm_cleanup_scene(scene_dir) + os.remove(zip_file) + + + else: + print(f'Download {rel_path} failed') + + print(f'Summary: {succ_count}/{len(download_list)} files downloaded successfully') + return succ_count == len(download_list) + + +def download_dataset(args): + """ Download the dataset based on the user inputs. + + :param args: argparse args. Used to decide the subset. 
+ :return: download success or not + """ + output_dir = args.odir + subset_opt = args.subset + reso_opt = args.resolution + hash_name = args.hash + file_type = args.file_type + is_clean_cache = args.clean_cache + + os.makedirs(output_dir, exist_ok=True) + + download_list = get_download_list(subset_opt, hash_name, reso_opt, file_type, output_dir) + return download(download_list, output_dir, is_clean_cache, only_sfm=file_type == 'colmap_sfm') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--odir', type=str, help='output directory', required=True) + parser.add_argument('--subset', choices=['1K', '2K', '3K', '4K', '5K', '6K', '7K', '8K', '9K', '10K', '11K'], + help='The subset of the benchmark to download', required=True) + parser.add_argument('--resolution', choices=['4K', '2K', '960P', '480P'], help='The resolution to donwnload', + default='480P') + parser.add_argument('--file_type', choices=['images+poses', 'video', 'colmap_cache', 'colmap_sfm'], + help='The file type to download', required=True, default='images+poses') + parser.add_argument('--hash', type=str, help='If set subset=hash, this is the hash code of the scene to download', + default='') + parser.add_argument('--clean_cache', action='store_true', + help='If set, will clean the huggingface cache to save space') + params = parser.parse_args() + + assert params.file_type in ['images+poses', 'video', 'colmap_cache', 'colmap_sfm', + 'only_sfm'], 'Check the file_type input.' + + if params.file_type == 'images+poses': + repo = resolution2repo[params.resolution] + elif params.file_type == 'video': + repo = 'DL3DV/DL3DV-ALL-video' + elif params.file_type in ['colmap_cache', 'colmap_sfm']: + repo = 'DL3DV/DL3DV-ALL-ColmapCache' + + if not verify_access(repo): + print( + f'You have not grant the access yet. Go to relevant huggingface repo (https://huggingface.co/datasets/{repo}) and apply for the access.') + exit(1) + + if download_dataset(params): + print('Download Done. 
Refer to', params.odir) + else: + print(f'Download to {params.odir} failed. See error messsage.') diff --git a/optgs/scripts/dl3dv_verify_colmap_poses.py b/optgs/scripts/dl3dv_verify_colmap_poses.py new file mode 100644 index 0000000000000000000000000000000000000000..ef819e324c1130ec1bbad14355efd376f163eec9 --- /dev/null +++ b/optgs/scripts/dl3dv_verify_colmap_poses.py @@ -0,0 +1,84 @@ +import pathlib + +import torch + +from optgs.dataset.dataset_colmap import Parser +import json + +from optgs.scripts.convert_dl3dv_utils import load_metadata +from einops import rearrange, repeat + +if __name__ == '__main__': + scene = "14eb48a50e37df548894ab6d8cd628a21dae14bbe6c462e894616fc5962e6c49" + colmap_cache_dir = pathlib.Path("datasets/dl3dv-colmap-cache/1K") + colmap_benchmark_dir = pathlib.Path("datasets/dl3dv-benchmark") + chunk_path = pathlib.Path("datasets/dl3dv-480p-chunks/test/000004.torch") + + # Extract points and cameras from colmap cache + parser_colmap_cache = Parser( + data_dir=str(colmap_cache_dir / scene), + factor=1, # not used for point cloud + normalize=False, # not used for point cloud + load_images=False, # not used for point cloud + dl3dv_settings=None + ) + c2w_colmap_cache = torch.from_numpy(parser_colmap_cache.camtoworlds) + points_xyz_colmap_cache = torch.from_numpy(parser_colmap_cache.points) + + # Load colmap cache transform + with open(colmap_cache_dir / scene / "transforms.json", 'r') as f: + transform_colmap_cache_data = json.load(f) + transforms_colmap_c2ws = [] + for frame in transform_colmap_cache_data['frames']: + c2w = torch.tensor(frame['transform_matrix'], dtype=c2w_colmap_cache.dtype) + transforms_colmap_c2ws.append(c2w) + transforms_colmap_c2ws = torch.stack(transforms_colmap_c2ws, dim=0) + + # Extract points and cameras from colmap benchmark + # images.bin is missing, so we do not have the poses of colmap from the benchmark + parser_benchmark = Parser( + data_dir=str(colmap_benchmark_dir / scene / "nerfstudio" / "colmap"), + # The 
sparse dir is not in the same hyrarchy of the images, for debugging we need to copy the spase dir one step up + factor=1, # not used for point cloud + normalize=False, # not used for point cloud + load_images=False, # not used for point cloud + dl3dv_settings=None + ) + c2w_benchmark = torch.from_numpy(parser_benchmark.camtoworlds) + points_xyz_benchmark = torch.from_numpy(parser_benchmark.points) + + # Load transforms.json from nerfstudio format + with open(colmap_benchmark_dir / scene / "nerfstudio" / "transforms.json", 'r') as f: + transforms_benchmark_data = json.load(f) + transforms_benchmark_c2ws = [] + for frame in transforms_benchmark_data['frames']: + c2w = torch.tensor(frame['transform_matrix'], dtype=c2w_colmap_cache.dtype) + transforms_benchmark_c2ws.append(c2w) + transforms_benchmark_c2ws = torch.stack(transforms_benchmark_c2ws, dim=0) + applied_transform = torch.tensor(transforms_benchmark_data["applied_transform"], + dtype=c2w_colmap_cache.dtype) # [3, 4] + + # Loading chunk example cameras + chunk = torch.load(chunk_path) + chunk = chunk[0] + assert chunk["url"] == scene + cameras = chunk["cameras"] + w2c = repeat(torch.eye(4, dtype=c2w_colmap_cache.dtype), + "h w -> b h w", b=len(cameras)).clone() + w2c[:, :3] = rearrange(cameras[:, 6:], "b (h w) -> b h w", h=3, w=4) + c2w_chunk = w2c.inverse() + + blender2opencv = torch.tensor( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], + dtype=c2w_colmap_cache.dtype, + device=c2w_colmap_cache.device + ) + c2w_chunk_transformed = c2w_chunk @ blender2opencv + + # Compare c2w_chunk_transformed with transforms_colmap_cache_c2ws + diff = c2w_chunk_transformed - transforms_benchmark_c2ws + max_diff = diff.abs().max() + print(f"Max difference between chunk poses and benchmark colmap poses: {max_diff.item()}") + assert max_diff < 1e-4, "Chunk camera poses do not match benchmark colmap poses after transformation." 
+ + diff --git a/optgs/scripts/generate_dl3dv_evaluation_index.py b/optgs/scripts/generate_dl3dv_evaluation_index.py new file mode 100644 index 0000000000000000000000000000000000000000..dcfcee5b4099cf387020cb65c373aadc14bf1b64 --- /dev/null +++ b/optgs/scripts/generate_dl3dv_evaluation_index.py @@ -0,0 +1,178 @@ +import argparse +import json +import os +from collections import OrderedDict +from glob import glob + +import numpy as np +import torch +from einops import rearrange, repeat +from jaxtyping import Float +from jaxtyping import install_import_hook +from torch import Tensor +from tqdm import tqdm + +# Configure beartype and jaxtyping. +with install_import_hook( + ("optgs",), + ("beartype", "beartype"), +): + from optgs.dataset.view_sampler.view_sampler_bounded_v2 import farthest_point_sample + from optgs.paths import asset_path + + +def convert_poses( + poses: Float[Tensor, "batch 18"], +) -> tuple[ + Float[Tensor, "batch 4 4"], # extrinsics + Float[Tensor, "batch 3 3"], # intrinsics +]: + b, _ = poses.shape + + # Convert the intrinsics to a 3x3 normalized K matrix. + intrinsics = torch.eye(3, dtype=torch.float32) + intrinsics = repeat(intrinsics, "h w -> b h w", b=b).clone() + fx, fy, cx, cy = poses[:, :4].T + intrinsics[:, 0, 0] = fx + intrinsics[:, 1, 1] = fy + intrinsics[:, 0, 2] = cx + intrinsics[:, 1, 2] = cy + + # Convert the extrinsics to a 4x4 OpenCV-style W2C matrix. 
+ w2c = repeat(torch.eye(4, dtype=torch.float32), "h w -> b h w", b=b).clone() + w2c[:, :3] = rearrange(poses[:, 6:], "b (h w) -> b h w", h=3, w=4) + return w2c.inverse(), intrinsics + + +def partition_list(lst, n_bins): + if n_bins <= 0: + raise ValueError("Number of bins must be greater than 0") + if len(lst) < n_bins: + raise ValueError("Number of bins cannot exceed the length of the list") + + bin_size = len(lst) // n_bins + borders = [lst[0]] # First border is always the first index + for i in range(1, n_bins): + border_index = min( + i * bin_size, len(lst) - 1 + ) # Ensure last bin doesn't exceed list length + borders.append(lst[border_index]) + borders.append(lst[-1]) # Last border is always the last index + return borders + + +def find_train_and_test_index(chunk_path, scene_name=None, num_context_views=5, + num_target_skip=1, num_target_views=28, + start_frame=None, + frame_distance=None, + render_video=False, + uniform_sample=False, + ): + chunk = torch.load(chunk_path) + out_dict = OrderedDict() + for example in chunk: + cur_scene_name = example["key"] + if scene_name is not None and cur_scene_name != scene_name: + continue + + extrinsics, intrinsics = convert_poses(example["cameras"]) + + # bounded evaluation to make the task easier + if start_frame is not None: + assert frame_distance is not None + end_frame = start_frame + frame_distance + + extrinsics = extrinsics[start_frame:end_frame] + + n_views = extrinsics.shape[0] + + if uniform_sample: + index_context = [int(x) for x in np.linspace(0, n_views, num_context_views, dtype=int)] + else: + index_context = sorted(farthest_point_sample( + extrinsics[:, :3, -1].unsqueeze(0), num_context_views + ).squeeze(0).tolist()) + + if render_video: + assert start_frame is not None + assert frame_distance is not None + index_target = list(range(start_frame, end_frame)) + else: + index_target_all = [x for x in range(n_views) if x not in index_context] + + if uniform_sample: + index_target_select = [(index_context[i] 
+ index_context[i + 1]) // 2 for i in + range(len(index_context) - 1)] + else: + index_target_select = partition_list(index_target_all, num_target_views) + + if start_frame is not None: + # the original index in the full sequence + index_context = [idx + start_frame for idx in index_context] + index_target_select = [idx + start_frame for idx in index_target_select] + + assert ( + len(index_target_select) >= num_target_views + ), f"double check {cur_scene_name} at {chunk_path}: target len: {len(index_target_select)} from {len(index_target_all)}" + index_target = index_target_select[:num_target_views] + + out_dict[cur_scene_name] = {"context": index_context, "target": index_target} + + return out_dict + + +def generate_index_file(args): + n_ctx = args.num_context_views + if args.uniform_sample: + args.num_target_views = n_ctx - 1 + n_tgt = args.num_target_views + + out_dir = str(asset_path("dl3dv_evaluation")) + os.makedirs(out_dir, exist_ok=True) + data_dir = "datasets/dl3dv/test" + chunk_paths = sorted(glob(os.path.join(data_dir, "*.torch"))) + out_dict_all = OrderedDict() + for chunk_path in tqdm(chunk_paths): + out_dict = find_train_and_test_index( + chunk_path, scene_name=None, num_context_views=n_ctx, + num_target_views=n_tgt, + start_frame=args.start_frame, + frame_distance=args.frame_distance, + render_video=args.render_video, + uniform_sample=args.uniform_sample, + ) + out_dict_all.update(out_dict) + + if args.start_frame is not None: + if args.render_video: + save_file = f"dl3dv_start_{args.start_frame}_distance_{args.frame_distance}_ctx_{n_ctx}v_video.json" + else: + save_file = f"dl3dv_start_{args.start_frame}_distance_{args.frame_distance}_ctx_{n_ctx}v_tgt_{n_tgt}v.json" + else: + save_file = f"dl3dv_ctx_{n_ctx}v_tgt_{n_tgt}v.json" + + if args.uniform_sample: + save_file = save_file[:-5] + '_uniform.json' + + out_path = os.path.join(out_dir, save_file) + + with open(out_path, "w") as f: + json.dump(out_dict_all, f) + + print("Done") + + +if __name__ == 
"__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num_target_views", type=int, default=28, help="test skip") + parser.add_argument("--num_context_views", type=int, default=5, help="test skip") + parser.add_argument('--render_video', action='store_true') + + # bounded evaluation to make the task easier + parser.add_argument('--start_frame', default=None, type=int) + parser.add_argument('--frame_distance', default=None, type=int) + parser.add_argument('--uniform_sample', action='store_true') + + args = parser.parse_args() + + generate_index_file(args) diff --git a/optgs/scripts/generate_dl3dv_index.py b/optgs/scripts/generate_dl3dv_index.py new file mode 100644 index 0000000000000000000000000000000000000000..655ae0f5a29fa5dc8d4a67dba0451165ba19e27d --- /dev/null +++ b/optgs/scripts/generate_dl3dv_index.py @@ -0,0 +1,23 @@ +import json +from pathlib import Path + +import torch +from tqdm import tqdm + +DATASET_PATH = Path("/capstor/store/cscs/swissai/a03/hxu/datasets/dl3dv_2kres") + +if __name__ == "__main__": + # "train" or "test" + for stage in ["test"]: + stage = DATASET_PATH / stage + + index = {} + for chunk_path in tqdm( + sorted(list(stage.iterdir())), desc=f"Indexing {stage.name}" + ): + if chunk_path.suffix == ".torch": + chunk = torch.load(chunk_path) + for example in chunk: + index[example["key"]] = str(chunk_path.relative_to(stage)) + with (stage / "index.json").open("w") as f: + json.dump(index, f) diff --git a/optgs/scripts/preextract_colmap_npz.py b/optgs/scripts/preextract_colmap_npz.py new file mode 100644 index 0000000000000000000000000000000000000000..41df909a0e3ddc8454133580398e2fb80e320690 --- /dev/null +++ b/optgs/scripts/preextract_colmap_npz.py @@ -0,0 +1,115 @@ +"""Pre-extract COLMAP point clouds for all scenes and save as .npz files. 
+ +Run once before training to avoid slow COLMAP parsing at every iteration: + + python scripts/preextract_colmap_npz.py --root [--normalize] [--workers 8] + +Each scene directory is expected to contain a `sparse/0/` (or `sparse/`) +sub-directory with the standard COLMAP binary model files. + +For every scene a file `colmap_points_cache.npz` (or +`colmap_points_cache_norm.npz` when --normalize is used) is written next to +the scene directory. The InitializerColmap class will pick these files up +automatically and skip the full SceneManager parse. +""" + +import argparse +import concurrent.futures +import os +import sys +import traceback +from pathlib import Path + +import numpy as np + +# Make sure the project root is on sys.path so that src.* imports work. +PROJECT_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +from optgs.dataset.colmap.utils import Parser + + +def _npz_path(scene_dir: Path, normalize: bool) -> Path: + suffix = "_norm" if normalize else "" + return scene_dir / f"colmap_points_cache{suffix}.npz" + + +def process_scene(scene_dir: Path, normalize: bool, overwrite: bool) -> str: + npz = _npz_path(scene_dir, normalize) + if npz.exists() and not overwrite: + return f"SKIP {scene_dir.name}" + try: + parser = Parser( + data_dir=str(scene_dir), + factor=1, + normalize=normalize, + load_images=False, + dl3dv_settings=False, + verbose=False, + ) + np.savez_compressed( + npz, + points=parser.points, + points_rgb=parser.points_rgb, + camtoworlds=parser.camtoworlds, + ) + return f"OK {scene_dir.name} ({parser.points.shape[0]} pts)" + except Exception as e: + return f"ERROR {scene_dir.name}: {e}\n{traceback.format_exc()}" + + +def find_scene_dirs(root: Path) -> list[Path]: + """Return all direct children of root that look like a COLMAP scene.""" + scenes = [] + for child in sorted(root.iterdir()): + if not child.is_dir(): + continue + sparse = child / "sparse" / "0" + if not sparse.exists(): + sparse = child / "sparse" + 
if sparse.exists(): + scenes.append(child) + return scenes + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--root", required=True, type=Path, help="Root directory containing one sub-dir per scene.") + parser.add_argument("--normalize", action="store_true", help="Apply world-space normalisation (matches normalize_world_space: true in config).") + parser.add_argument("--overwrite", action="store_true", help="Re-extract even if .npz already exists.") + parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers (default: 4).") + args = parser.parse_args() + + root: Path = args.root.resolve() + if not root.exists(): + print(f"Root directory does not exist: {root}", file=sys.stderr) + sys.exit(1) + + scenes = find_scene_dirs(root) + if not scenes: + print(f"No COLMAP scene directories found under {root}", file=sys.stderr) + sys.exit(1) + + print(f"Found {len(scenes)} scenes under {root}") + print(f"normalize={args.normalize} overwrite={args.overwrite} workers={args.workers}\n") + + ok = skip = error = 0 + with concurrent.futures.ProcessPoolExecutor(max_workers=args.workers) as pool: + futures = {pool.submit(process_scene, s, args.normalize, args.overwrite): s for s in scenes} + for i, fut in enumerate(concurrent.futures.as_completed(futures), 1): + msg = fut.result() + prefix = msg[:5].strip() + if prefix == "OK": + ok += 1 + elif prefix == "SKIP": + skip += 1 + else: + error += 1 + print(f"[{i}/{len(scenes)}] {msg}") + + print(f"\nDone. 
if __name__ == '__main__':
    # Cross-check the scenes referenced by the DL3DV chunk files against the scenes
    # that have a valid COLMAP SfM directory on disk, report the differences, and
    # write train/test index files restricted to scenes with COLMAP data.
    chunk_dir = CustomPath("datasets/dl3dv-480p-chunks/train")
    colmap_dir = CustomPath("datasets/dl3dv-colmap-sfm")

    # The messages now state what is wrong instead of only naming the directory.
    assert chunk_dir.is_dir(), f"Chunk directory {chunk_dir:link} does not exist"
    assert colmap_dir.is_dir(), f"Colmap directory {colmap_dir:link} does not exist"

    # First check if we have already saved the chunk scene names to a text file
    chunk_scene_names_file = chunk_dir / "dl3dv_chunk_scenes.txt"
    if chunk_scene_names_file.is_file():
        with chunk_scene_names_file.open("r") as f:
            # Skip blank lines (e.g. a trailing newline) so that an empty string is
            # never treated as a scene name.
            chunk_scene_names = {line.strip() for line in f if line.strip()}
        print(f"Loaded {len(chunk_scene_names)} scene names from {chunk_scene_names_file}")
    else:
        # Collect scene names from chunk files
        chunk_scene_names = set()
        for i, chunk_path in tqdm(enumerate(chunk_dir.glob("*.torch"))):
            # Only the "key" entries are read, so keep every tensor on the CPU; this
            # also lets chunks that were saved from GPU tensors load on a CPU-only host.
            chunk = torch.load(chunk_path, map_location="cpu")
            for scene in chunk:
                scene_name = scene["key"]
                scene_name = scene_name.replace("dl3dv_", "")
                chunk_scene_names.add(scene_name)
            if (i + 1) % 10 == 0:
                print(f"Processed {i + 1} chunk files, collected {len(chunk_scene_names)} unique scene names so far...")

        print(f"Scenes in chunk files: {len(chunk_scene_names)}")
        # Save chunk scene names to a text file for reuse
        with chunk_scene_names_file.open("w") as f:
            for scene_name in sorted(chunk_scene_names):
                f.write(f"{scene_name}\n")

    # Collect scene names from colmap directory
    colmap_scene_names = set()
    unsucc_count = 0
    for scene in colmap_dir.iterdir():
        # Verify dir structure: should be
        # scene_name/
        #   - transforms.json (for now, we don't have this)
        #   - sparse/
        #       - 0/
        #           - cameras.bin
        #           - images.bin
        #           - points3D.bin
        if not validate_sfm_structure(scene, unsucc_count=unsucc_count):
            unsucc_count += 1
            continue

        colmap_scene_names.add(scene.name)

    # Compare the two sets
    in_chunk_not_colmap = chunk_scene_names - colmap_scene_names
    in_colmap_not_chunk = colmap_scene_names - chunk_scene_names

    print(f"Scenes in chunk but not in colmap: {len(in_chunk_not_colmap)}")
    for scene_name in sorted(in_chunk_not_colmap):
        print(f"- {scene_name}")

    # Only the count is reported for this direction; the individual names are not listed.
    print(f"\nScenes in colmap but not in chunk: {len(in_colmap_not_chunk)}")

    # Generate index_colmap.json
    target_train_path = CustomPath("datasets/dl3dv-480p-chunks/train/index_colmap.json")
    target_test_path = CustomPath("datasets/dl3dv-480p-chunks/test/index_colmap.json")

    full_train_index_path = CustomPath("datasets/dl3dv-480p-chunks/train/index.json")
    full_test_index_path = CustomPath("datasets/dl3dv-480p-chunks/test/index.json")

    # Load the full index files
    with open(full_train_index_path, "r") as f:
        full_train_index = json.load(f)  # with "dl3dv_" prefix in scene names
    with open(full_test_index_path, "r") as f:
        full_test_index = json.load(f)  # without "dl3dv_" prefix in scene names

    # Filter the full index to only include scenes that have colmap data
    filtered_train_index = {
        scene_name: data
        for scene_name, data in full_train_index.items()
        if scene_name.replace("dl3dv_", "") in colmap_scene_names
    }
    filtered_test_index = {
        scene_name: data
        for scene_name, data in full_test_index.items()
        if scene_name in colmap_scene_names
    }

    # Save the filtered index files
    target_train_path.parent.mkdir(parents=True, exist_ok=True)
    target_test_path.parent.mkdir(parents=True, exist_ok=True)
    with target_train_path.open("w") as f:
        json.dump(filtered_train_index, f, indent=4)
    with target_test_path.open("w") as f:
        json.dump(filtered_test_index, f, indent=4)

    print(f"Saved filtered train index with {len(filtered_train_index)} scenes to {target_train_path.resolve()}")
    print(f"Saved filtered test index with {len(filtered_test_index)} scenes to {target_test_path.resolve()}")
def interpolate_intrinsics(
    initial: Float[Tensor, "*#batch 3 3"],
    final: Float[Tensor, "*#batch 3 3"],
    t: Float[Tensor, " time_step"],
) -> Float[Tensor, "*batch time_step 3 3"]:
    """Linearly blend two sets of intrinsics matrices over a 1-D set of time steps.

    A new ``time_step`` axis is inserted directly in front of the two matrix axes.
    Entry ``k`` along that axis equals ``initial + (final - initial) * t[k]``, so
    ``t == 0`` reproduces ``initial`` and ``t == 1`` reproduces ``final``.
    """
    # Make room for the time axis just before the matrix rows and columns.
    start = initial.unsqueeze(-3)
    stop = final.unsqueeze(-3)

    # Shape the 1-D weights as (time_step, 1, 1) so they broadcast over each matrix.
    weight = t[:, None, None]

    return start + (stop - start) * weight
def generate_coordinate_frame(
    y: Float[Tensor, "*#batch 3"],
    z: Float[Tensor, "*#batch 3"],
) -> Float[Tensor, "*batch 3 3"]:
    """Generate a coordinate frame given perpendicular, unit-length Y and Z vectors.

    The X axis is computed as ``y x z``. The three axes are stacked along the last
    dimension, so ``frame[..., :, 0]``, ``frame[..., :, 1]`` and ``frame[..., :, 2]``
    are the X, Y and Z axes respectively.
    """
    y, z = torch.broadcast_tensors(y, z)

    # The cross product must be taken over the vector (last) dimension explicitly.
    # Without ``dim`` the cross product is computed along the first dimension of
    # size 3, which is the batch dimension whenever exactly three vectors are
    # passed, silently yielding a wrong frame.
    x = y.cross(z, dim=-1)

    return torch.stack([x, y, z], dim=-1)
def matrix_to_euler(
    rotations: Float[Tensor, "*batch 3 3"],
    pattern: str,
) -> Float[Tensor, "*batch 3"]:
    """Convert rotation matrices to Euler angles using SciPy.

    ``pattern`` is forwarded unchanged to ``Rotation.as_euler`` (for example
    ``"YXZ"``). The conversion runs on the CPU through NumPy; the angles are returned
    with the dtype and device of the input, and with its leading batch shape.
    """
    batch_shape = rotations.shape[:-2]

    # SciPy handles a flat list of matrices, so collapse all batch dimensions.
    flat = rotations.reshape(-1, 3, 3)
    scipy_rotation = R.from_matrix(flat.detach().cpu().numpy())
    angles = torch.tensor(
        scipy_rotation.as_euler(pattern),
        dtype=flat.dtype,
        device=flat.device,
    )

    return angles.reshape(*batch_shape, 3)
def interpolate_circular(
    a: Float[Tensor, "*#batch"],
    b: Float[Tensor, "*#batch"],
    t: Float[Tensor, "*#batch"],
) -> Float[Tensor, " *batch"]:
    """Interpolate between angles ``a`` and ``b`` (radians) along the shorter arc.

    Both angles are first reduced modulo 2*pi. The start angle is then considered
    as-is and shifted by -2*pi and +2*pi; whichever candidate lies closest to ``b``
    is linearly interpolated towards ``b`` with weight ``t``. The result is not
    wrapped back into ``[0, 2*pi)``.
    """
    a, b, t = torch.broadcast_tensors(a, b, t)

    tau = 2 * torch.pi
    a = a % tau
    b = b % tau

    # Candidate start angles: unshifted, one turn below, one turn above.
    a_below = a - tau
    a_above = a + tau
    gap = (b - a).abs()
    gap_below = (b - a_below).abs()
    gap_above = (b - a_above).abs()

    # Pick exactly one candidate per element. The unshifted one wins only when it is
    # strictly closest; otherwise the lower shift wins only when strictly closer than
    # the upper shift.
    keep_direct = (gap < gap_below) & (gap < gap_above)
    pick_below = (gap_below < gap_above) & (~keep_direct)
    pick_above = (~keep_direct) & (~pick_below)

    direct = a + (b - a) * t
    from_below = a_below + (b - a_below) * t
    from_above = a_above + (b - a_above) * t

    return torch.where(pick_below, from_below, torch.where(pick_above, from_above, direct))
def generate_spin(
    num_frames: int,
    device: torch.device,
    elevation: float,
    radius: float,
) -> Float[Tensor, "frame 4 4"]:
    """Generate ``num_frames`` 4x4 transforms that spin once around the Y axis.

    Each transform is the product ``azimuth @ elevation @ translation``:

    * ``translation`` negates the first two rows of the identity and places
      ``-radius`` in the Z translation entry;
    * ``elevation`` is a rotation about the X axis by ``elevation`` degrees;
    * ``azimuth`` is a rotation about the Y axis by ``2*pi*k/num_frames`` for frame
      ``k``, so the last frame stops one step short of a full turn.
    """

    def identity():
        return torch.eye(4, dtype=torch.float32, device=device)

    # Translate back along the camera's look vector.
    tf_translation = identity()
    tf_translation[:2] *= -1
    tf_translation[2, 3] = -radius

    # One rotation about Y per frame, evenly spaced over a full turn.
    phi = 2 * np.pi * (np.arange(num_frames) / num_frames)
    zeros = np.zeros_like(phi)
    azimuth_matrices = R.from_rotvec(np.stack([zeros, phi, zeros], axis=-1)).as_matrix()
    tf_azimuth = identity().expand(num_frames, 4, 4).clone()
    tf_azimuth[:, :3, :3] = torch.tensor(
        azimuth_matrices, dtype=torch.float32, device=device
    )

    # A single rotation about X; the angle is given in degrees.
    elevation_radians = np.deg2rad(elevation)
    tilt = R.from_rotvec(np.array([elevation_radians, 0, 0], dtype=np.float32))
    tf_elevation = identity()
    tf_elevation[:3, :3] = torch.tensor(tilt.as_matrix())

    return tf_azimuth @ tf_elevation @ tf_translation
def apply_color_map(
    x: Float[Tensor, " *batch"],
    color_map: str = "inferno",
) -> Float[Tensor, "*batch 3"]:
    """Map scalar values to RGB colors using a named Matplotlib color map.

    Values are clipped to ``[0, 1]`` before the lookup. The alpha channel returned by
    Matplotlib is discarded, and the result is a float32 tensor on ``x``'s device with
    a trailing RGB axis.

    Raises:
        ValueError: if ``color_map`` is not a registered color map name.
    """
    # ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9.
    # Prefer the ``matplotlib.colormaps`` registry (available since 3.5) and only fall
    # back to the legacy function on releases that predate the registry.
    import matplotlib

    registry = getattr(matplotlib, "colormaps", None)
    if registry is None:
        cmap = cm.get_cmap(color_map)
    else:
        try:
            cmap = registry[color_map]
        except KeyError as error:
            # Keep the exception type the legacy lookup raised for unknown names.
            raise ValueError(f"Unknown color map: {color_map!r}") from error

    # Convert to NumPy so that Matplotlib color maps can be used.
    mapped = cmap(x.detach().clip(min=0, max=1).cpu().numpy())[..., :3]

    # Convert back to the original format.
    return torch.tensor(mapped, device=x.device, dtype=torch.float32)
def get_distinct_color(index: int) -> tuple[float, float, float]:
    """Return entry ``index`` of ``DISTINCT_COLORS`` as an RGB tuple scaled to [0, 1].

    The index wraps around, so any integer selects a color.
    """
    color_code = DISTINCT_COLORS[index % len(DISTINCT_COLORS)]
    channels = ImageColor.getcolor(color_code, "RGB")
    return tuple(channel / 255 for channel in channels)
def draw_cameras(
    resolution: int,
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    color: Float[Tensor, "batch 3"],
    near: Optional[Scalar] = None,
    far: Optional[Scalar] = None,
    margin: float = 0.1,  # relative to AABB
    frustum_scale: float = 0.05,  # relative to image resolution
) -> Float[Tensor, "3 3 height width"]:
    """Draw the camera frustums projected onto the three axis-aligned planes.

    For each dropped axis, a ``resolution x resolution`` canvas is created, the
    optional near/far plane outlines are drawn in gray, the small per-camera frustums
    are drawn in ``color``, and a text label naming the projection is attached with
    ``add_label``. The three labelled images are stacked along a new first dimension.

    Args:
        resolution: side length in pixels of each (unlabelled) canvas.
        extrinsics: per-camera 4x4 matrices; the rotation is read from ``[:3, :3]``
            and the camera origin from ``[:3, 3]``.
        intrinsics: per-camera 3x3 intrinsics passed to the unprojection helper.
        color: per-camera RGB line color.
        near: optional near-plane depth; the plane is drawn when given.
        far: optional far-plane depth; the plane is drawn when given.
        margin: extra padding around the scene bounds, relative to their extent.
        frustum_scale: depth of the drawn frustums relative to the scene extent.
    """
    device = extrinsics.device

    # Compute scene bounds.
    minima, maxima = compute_aabb(extrinsics, intrinsics, near, far)
    scene_minima, scene_maxima = compute_equal_aabb_with_margin(
        minima, maxima, margin=margin
    )
    span = (scene_maxima - scene_minima).max()

    # Compute frustum locations. ``near`` and ``far`` are sanitized here exactly as
    # compute_aabb does before handing them to unproject_frustum_corners, which
    # rearranges the depth as a 1-D tensor and therefore cannot take a plain number.
    corner_depth = (span * frustum_scale)[None]
    frustum_corners = unproject_frustum_corners(extrinsics, intrinsics, corner_depth)
    if near is not None:
        near_corners = unproject_frustum_corners(
            extrinsics, intrinsics, sanitize_scalar(near, device)
        )
    if far is not None:
        far_corners = unproject_frustum_corners(
            extrinsics, intrinsics, sanitize_scalar(far, device)
        )

    # Project the cameras onto each axis-aligned plane.
    projections = []
    for projected_axis in range(3):
        image = torch.zeros(
            (3, resolution, resolution),
            dtype=torch.float32,
            device=device,
        )
        image_x_axis = (projected_axis + 1) % 3
        image_y_axis = (projected_axis + 2) % 3

        def project(points: Float[Tensor, "*batch 3"]) -> Float[Tensor, "*batch 2"]:
            # Orthographic projection: keep the two axes that span this view.
            x = points[..., image_x_axis]
            y = points[..., image_y_axis]
            return torch.stack([x, y], dim=-1)

        x_range, y_range = torch.stack(
            (project(scene_minima), project(scene_maxima)), dim=-1
        )

        def draw_plane_segments(canvas, starts, ends):
            # Gray helper lines shared by the near plane, far plane and connectors.
            return draw_lines(
                canvas,
                rearrange(starts, "b p xy -> (b p) xy"),
                rearrange(ends, "b p xy -> (b p) xy"),
                color=0.25,
                width=2,
                x_range=x_range,
                y_range=y_range,
            )

        # Draw near and far planes. Rolling the corner axis by one pairs every corner
        # with its neighbour, which closes the rectangle.
        if near is not None:
            projected_near_corners = project(near_corners)
            image = draw_plane_segments(
                image, projected_near_corners, projected_near_corners.roll(1, 1)
            )
        if far is not None:
            projected_far_corners = project(far_corners)
            image = draw_plane_segments(
                image, projected_far_corners, projected_far_corners.roll(1, 1)
            )
        if near is not None and far is not None:
            # Connect matching near and far corners.
            image = draw_plane_segments(
                image, projected_near_corners, projected_far_corners
            )

        # Draw the camera frustums themselves. Per camera there are two groups of four
        # segments ending at the frustum corners: one group starting at the camera
        # origin and one starting at the neighbouring corner.
        projected_origins = project(extrinsics[:, :3, 3])
        projected_frustum_corners = project(frustum_corners)
        start = [
            repeat(projected_origins, "b xy -> (b p) xy", p=4),
            rearrange(projected_frustum_corners.roll(1, 1), "b p xy -> (b p) xy"),
        ]
        start = rearrange(torch.cat(start, dim=0), "(r b p) xy -> (b r p) xy", r=2, p=4)
        image = draw_lines(
            image,
            start,
            repeat(projected_frustum_corners, "b p xy -> (b r p) xy", r=2),
            color=repeat(color, "b c -> (b r p) c", r=2, p=4),
            width=2,
            x_range=x_range,
            y_range=y_range,
        )

        x_name = "XYZ"[image_x_axis]
        y_name = "XYZ"[image_y_axis]
        image = add_label(image, f"{x_name}{y_name} Projection")

        # TODO: Draw axis indicators.
        projections.append(image)

    return torch.stack(projections)
def compute_equal_aabb_with_margin(
    minima: Float[Tensor, "*#batch 3"],
    maxima: Float[Tensor, "*#batch 3"],
    margin: float = 0.1,
) -> tuple[
    Float[Tensor, "*batch 3"],  # minima of the scene
    Float[Tensor, "*batch 3"],  # maxima of the scene
]:
    """Expand a bounding box into a cube with a relative margin.

    The cube keeps the midpoint of the input box. Its side length is the largest
    extent found in ``maxima - minima`` (the maximum is taken over every element)
    scaled by ``1 + margin``.
    """
    center = (maxima + minima) * 0.5
    side = (maxima - minima).max() * (1 + margin)
    half_side = 0.5 * side
    return center - half_side, center + half_side
def generate_conversions(
    shape: tuple[int, int],
    device: torch.device,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> tuple[
    ConversionFunction,  # conversion from world coordinates to pixel coordinates
    ConversionFunction,  # conversion from pixel coordinates to world coordinates
]:
    """Build the pair of mappings between world and pixel coordinates.

    ``shape`` is ``(height, width)`` in pixels. ``x_range`` / ``y_range`` give the
    world-space interval mapped onto the image width / height; when omitted they
    default to ``(0, width)`` and ``(0, height)``, which makes both conversions the
    identity. The two returned callables are inverses of each other and operate on
    tensors whose last axis holds ``(x, y)``.
    """
    height, width = shape
    if x_range is None:
        x_range = (0, width)
    if y_range is None:
        y_range = (0, height)
    x_range = sanitize_pair(x_range, device)
    y_range = sanitize_pair(y_range, device)

    # Stacking on the last axis and unpacking the first one regroups the two ranges
    # into a per-axis lower bound and a per-axis upper bound.
    minima, maxima = torch.stack((x_range, y_range), dim=-1)
    wh = torch.tensor((width, height), dtype=torch.float32, device=device)

    def convert_world_to_pixel(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        normalized = (xy - minima) / (maxima - minima)
        return normalized * wh

    def convert_pixel_to_world(
        xy: Float[Tensor, "*batch 2"],
    ) -> Float[Tensor, "*batch 2"]:
        normalized = xy / wh
        return normalized * (maxima - minima) + minima

    return convert_world_to_pixel, convert_pixel_to_world
def draw_lines(
    image: Float[Tensor, "3 height width"],
    start: Vector,
    end: Vector,
    color: Vector,
    width: Scalar,
    cap: Literal["butt", "round", "square"] = "round",
    num_msaa_passes: int = 1,
    x_range: Optional[Pair] = None,
    y_range: Optional[Pair] = None,
) -> Float[Tensor, "3 height width"]:
    """Draw line segments on top of ``image``.

    ``start`` and ``end`` are given in world coordinates and converted to pixel
    coordinates using ``x_range`` / ``y_range``. ``width`` is compared against pixel
    distances. Where several lines cover the same sample, the line with the highest
    index determines the color.

    ``cap`` controls the segment ends: ``"butt"`` stops at the end points,
    ``"square"`` extends each end by half the width, and ``"round"`` adds a disc of
    half the width around both end points.
    """
    device = image.device
    start = sanitize_vector(start, 2, device)
    end = sanitize_vector(end, 2, device)
    color = sanitize_vector(color, 3, device)
    width = sanitize_scalar(width, device)
    (num_lines,) = torch.broadcast_shapes(
        start.shape[0],
        end.shape[0],
        color.shape[0],
        width.shape,
    )

    # Convert world-space points to pixel space.
    _, h, w = image.shape
    world_to_pixel, _ = generate_conversions((h, w), device, x_range, y_range)
    start = world_to_pixel(start)
    end = world_to_pixel(end)

    def color_function(
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:
        half_width = 0.5 * width[:, None]

        # Unit direction and length of every segment.
        segment = end - start
        segment_length = segment.norm(dim=-1, keepdim=True)
        direction = segment / segment_length

        # Offset of every sample from every segment's start point.
        from_start = xy - start[:, None]

        # Extent test along the segment; square caps lengthen it at both ends.
        overhang = half_width if cap == "square" else 0
        along = einsum(direction, from_start, "l xy, l s xy -> l s")
        within_length = (along <= segment_length + overhang) & (along > -overhang)

        # Distance test across the segment.
        across = from_start - along[..., None] * direction[:, None]
        within_width = across.norm(dim=-1) < half_width

        covered = within_length & within_width

        # Round caps are discs centred on both end points.
        if cap == "round":
            from_end = xy - end[:, None]
            covered |= from_start.norm(dim=-1) < half_width
            covered |= from_end.norm(dim=-1) < half_width

        # Pick the color of the highest-index line covering each sample.
        palette = color.broadcast_to((num_lines, 3))
        ranking = covered * torch.arange(num_lines, device=device)[:, None]
        winner = ranking.argmax(dim=0)
        top_color = palette.gather(dim=0, index=repeat(winner, "s -> s c", c=3))

        # Alpha is 1 wherever any line covers the sample, 0 elsewhere.
        alpha = covered.any(dim=0).float()[:, None]
        return torch.cat((top_color, alpha), dim=-1)

    return render_over_image(image, color_function, device, num_passes=num_msaa_passes)
@runtime_checkable
class ColorFunction(Protocol):
    """Callable that maps sample positions to RGBA colors.

    Because the protocol is runtime-checkable and only declares ``__call__``,
    ``isinstance(obj, ColorFunction)`` holds for any callable object.
    """

    def __call__(
        self,
        xy: Float[Tensor, "point 2"],
    ) -> Float[Tensor, "point 4"]:  # RGBA color
        ...
generate_sample_grid( + shape: tuple[int, int], + device: torch.device, +) -> Float[Tensor, "height width 2"]: + h, w = shape + x = torch.arange(w, device=device) + 0.5 + y = torch.arange(h, device=device) + 0.5 + x, y = torch.meshgrid(x, y, indexing="xy") + return torch.stack([x, y], dim=-1) + + +def detect_msaa_pixels( + image: Float[Tensor, "batch 4 height width"], +) -> Bool[Tensor, "batch height width"]: + b, _, h, w = image.shape + + mask = torch.zeros((b, h, w), dtype=torch.bool, device=image.device) + + # Detect horizontal differences. + horizontal = (image[:, :, :, 1:] != image[:, :, :, :-1]).any(dim=1) + mask[:, :, 1:] |= horizontal + mask[:, :, :-1] |= horizontal + + # Detect vertical differences. + vertical = (image[:, :, 1:, :] != image[:, :, :-1, :]).any(dim=1) + mask[:, 1:, :] |= vertical + mask[:, :-1, :] |= vertical + + # Detect diagonal (top left to bottom right) differences. + tlbr = (image[:, :, 1:, 1:] != image[:, :, :-1, :-1]).any(dim=1) + mask[:, 1:, 1:] |= tlbr + mask[:, :-1, :-1] |= tlbr + + # Detect diagonal (top right to bottom left) differences. + trbl = (image[:, :, :-1, 1:] != image[:, :, 1:, :-1]).any(dim=1) + mask[:, :-1, 1:] |= trbl + mask[:, 1:, :-1] |= trbl + + return mask + + +def reduce_straight_alpha( + rgba: Float[Tensor, "batch 4 height width"], +) -> Float[Tensor, "batch 4"]: + color, alpha = rgba.split((3, 1), dim=1) + + # Color becomes a weighted average of color (weighted by alpha). + weighted_color = reduce(color * alpha, "b c h w -> b c", "sum") + alpha_sum = reduce(alpha, "b c h w -> b c", "sum") + color = weighted_color / (alpha_sum + 1e-10) + + # Alpha becomes mean alpha. 
+ alpha = reduce(alpha, "b c h w -> b c", "mean") + + return torch.cat((color, alpha), dim=-1) + + +@torch.no_grad() +def run_msaa_pass( + xy: Float[Tensor, "batch height width 2"], + color_function: ColorFunction, + scale: float, + subdivision: int, + remaining_passes: int, + device: torch.device, + batch_size: int = int(2**16), +) -> Float[Tensor, "batch 4 height width"]: # color (RGBA with straight alpha) + # Sample the color function. + b, h, w, _ = xy.shape + color = [ + color_function(batch) + for batch in rearrange(xy, "b h w xy -> (b h w) xy").split(batch_size) + ] + color = torch.cat(color, dim=0) + color = rearrange(color, "(b h w) c -> b c h w", b=b, h=h, w=w) + + # If any MSAA passes remain, subdivide. + if remaining_passes > 0: + mask = detect_msaa_pixels(color) + batch_index, row_index, col_index = torch.where(mask) + xy = xy[batch_index, row_index, col_index] + + offsets = generate_sample_grid((subdivision, subdivision), device) + offsets = (offsets / subdivision - 0.5) * scale + + color_fine = run_msaa_pass( + xy[:, None, None] + offsets, + color_function, + scale / subdivision, + subdivision, + remaining_passes - 1, + device, + batch_size=batch_size, + ) + color[batch_index, :, row_index, col_index] = reduce_straight_alpha(color_fine) + + return color + + +@torch.no_grad() +def render( + shape: tuple[int, int], + color_function: ColorFunction, + device: torch.device, + subdivision: int = 8, + num_passes: int = 2, +) -> Float[Tensor, "4 height width"]: # color (RGBA with straight alpha) + xy = generate_sample_grid(shape, device) + return run_msaa_pass( + xy[None], + color_function, + 1.0, + subdivision, + num_passes, + device, + )[0] + + +def render_over_image( + image: Float[Tensor, "3 height width"], + color_function: ColorFunction, + device: torch.device, + subdivision: int = 8, + num_passes: int = 1, +) -> Float[Tensor, "3 height width"]: + _, h, w = image.shape + overlay = render( + (h, w), + color_function, + device, + subdivision=subdivision, 
+ num_passes=num_passes, + ) + color, alpha = overlay.split((3, 1), dim=0) + return image * (1 - alpha) + color * alpha diff --git a/optgs/visualization/drawing/types.py b/optgs/visualization/drawing/types.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e2b0304319e9016d7f538aff35dea308f3e259 --- /dev/null +++ b/optgs/visualization/drawing/types.py @@ -0,0 +1,67 @@ +from typing import Iterable, Union + +import torch +from einops import repeat +from jaxtyping import Float, Shaped +from torch import Tensor + +Real = Union[float, int] + +Vector = Union[ + Real, + Iterable[Real], + Shaped[Tensor, "3"], + Shaped[Tensor, "batch 3"], +] + + +def sanitize_vector( + vector: Vector, + dim: int, + device: torch.device, +) -> Float[Tensor, "*#batch dim"]: + if isinstance(vector, Tensor): + vector = vector.type(torch.float32).to(device) + else: + vector = torch.tensor(vector, dtype=torch.float32, device=device) + while vector.ndim < 2: + vector = vector[None] + if vector.shape[-1] == 1: + vector = repeat(vector, "... () -> ... 
c", c=dim) + assert vector.shape[-1] == dim + assert vector.ndim == 2 + return vector + + +Scalar = Union[ + Real, + Iterable[Real], + Shaped[Tensor, ""], + Shaped[Tensor, " batch"], +] + + +def sanitize_scalar(scalar: Scalar, device: torch.device) -> Float[Tensor, "*#batch"]: + if isinstance(scalar, Tensor): + scalar = scalar.type(torch.float32).to(device) + else: + scalar = torch.tensor(scalar, dtype=torch.float32, device=device) + while scalar.ndim < 1: + scalar = scalar[None] + assert scalar.ndim == 1 + return scalar + + +Pair = Union[ + Iterable[Real], + Shaped[Tensor, "2"], +] + + +def sanitize_pair(pair: Pair, device: torch.device) -> Float[Tensor, "2"]: + if isinstance(pair, Tensor): + pair = pair.type(torch.float32).to(device) + else: + pair = torch.tensor(pair, dtype=torch.float32, device=device) + assert pair.shape == (2,) + return pair diff --git a/optgs/visualization/export_point_cloud.py b/optgs/visualization/export_point_cloud.py new file mode 100644 index 0000000000000000000000000000000000000000..270882121c795faab902307d39f5780c907b248e --- /dev/null +++ b/optgs/visualization/export_point_cloud.py @@ -0,0 +1,72 @@ +import numpy as np +import os +import torch +try: + import open3d as o3d +except: + pass + + +def export_to_point_cloud(point_cloud, colors, + save_path='point_cloud.ply', + denoise_cloud=False, + denoise_nb_points=10, + denoise_radius=0.03, + ): + # point_cloud: [N, 3] + # colors: [N, 3] + + # point_cloud = point_cloud / 10. 
+ + point_cloud = point_cloud - np.median(point_cloud, axis=0, keepdims=True) + + # Ensure colors are in the range [0, 1] + colors = np.clip(colors, 0, 1) + + # Create an Open3D PointCloud object + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(point_cloud) + pcd.colors = o3d.utility.Vector3dVector(colors) + + save_dir = os.path.dirname(save_path) + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + # Export the point cloud to a PLY file + o3d.io.write_point_cloud(save_path, pcd) + + if denoise_cloud: + print("denoise point cloud...") + cl, ind = pcd.remove_radius_outlier(nb_points=denoise_nb_points, radius=denoise_radius) + inlier_cloud = pcd.select_by_index(ind) + o3d.io.write_point_cloud(save_path[:-4] + '_denoise.ply', inlier_cloud) + + +def transform_points(world_points, cam_to_world): + """ + Transforms world 3D points to camera coordinates. + + Args: + world_points (torch.Tensor): Nx3 tensor of 3D points in world coordinates. + cam_to_world (torch.Tensor): 4x4 tensor of camera-to-world extrinsics. + + Returns: + torch.Tensor: Nx3 tensor of 3D points in camera coordinates. 
+ """ + # Convert world points to homogeneous coordinates (Nx4) + N = world_points.shape[0] + ones = torch.ones((N, 1), device=world_points.device) + world_points_h = torch.cat([world_points, ones], dim=1) # Nx4 + + # Compute the inverse of the extrinsics (world-to-camera transformation) + world_to_cam = torch.inverse(cam_to_world) + + # Apply transformation + camera_points_h = (world_to_cam @ world_points_h.T).T # Nx4 + + # Convert back to 3D coordinates (drop the homogeneous coordinate) + camera_points = camera_points_h[:, :3] # Nx3 + + return camera_points + diff --git a/optgs/visualization/layout.py b/optgs/visualization/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..ca44b8a0814f8e601d82aaea9dd737d8881c972c --- /dev/null +++ b/optgs/visualization/layout.py @@ -0,0 +1,228 @@ +"""This file contains useful layout utilities for images. They are: + +- add_border: Add a border to an image. +- cat/hcat/vcat: Join images by arranging them in a line. If the images have different + sizes, they are aligned as specified (start, end, center). Allows you to specify a gap + between images. + +Images are assumed to be float32 tensors with shape (channel, height, width). +""" + +from typing import Any, Generator, Iterable, Literal, Optional, Union + +import torch +import torch.nn.functional as F +from jaxtyping import Float +from torch import Tensor + +Alignment = Literal["start", "center", "end"] +Axis = Literal["horizontal", "vertical"] +Color = Union[ + int, + float, + Iterable[int], + Iterable[float], + Float[Tensor, "#channel"], + Float[Tensor, ""], +] + + +def _sanitize_color(color: Color) -> Float[Tensor, "#channel"]: + # Convert tensor to list (or individual item). + if isinstance(color, torch.Tensor): + color = color.tolist() + + # Turn iterators and individual items into lists. 
+ if isinstance(color, Iterable): + color = list(color) + else: + color = [color] + + return torch.tensor(color, dtype=torch.float32) + + +def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]: + it = iter(iterable) + yield next(it) + for item in it: + yield delimiter + yield item + + +def _get_main_dim(main_axis: Axis) -> int: + return { + "horizontal": 2, + "vertical": 1, + }[main_axis] + + +def _get_cross_dim(main_axis: Axis) -> int: + return { + "horizontal": 1, + "vertical": 2, + }[main_axis] + + +def _compute_offset(base: int, overlay: int, align: Alignment) -> slice: + assert base >= overlay + offset = { + "start": 0, + "center": (base - overlay) // 2, + "end": base - overlay, + }[align] + return slice(offset, offset + overlay) + + +def overlay( + base: Float[Tensor, "channel base_height base_width"], + overlay: Float[Tensor, "channel overlay_height overlay_width"], + main_axis: Axis, + main_axis_alignment: Alignment, + cross_axis_alignment: Alignment, +) -> Float[Tensor, "channel base_height base_width"]: + # The overlay must be smaller than the base. + _, base_height, base_width = base.shape + _, overlay_height, overlay_width = overlay.shape + assert base_height >= overlay_height and base_width >= overlay_width + + # Compute spacing on the main dimension. + main_dim = _get_main_dim(main_axis) + main_slice = _compute_offset( + base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment + ) + + # Compute spacing on the cross dimension. + cross_dim = _get_cross_dim(main_axis) + cross_slice = _compute_offset( + base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment + ) + + # Combine the slices and paste the overlay onto the base accordingly. 
+ selector = [..., None, None] + selector[main_dim] = main_slice + selector[cross_dim] = cross_slice + result = base.clone() + result[selector] = overlay + return result + + +def cat( + main_axis: Axis, + *images: Iterable[Float[Tensor, "channel _ _"]], + align: Alignment = "center", + gap: int = 8, + gap_color: Color = 1, +) -> Float[Tensor, "channel height width"]: + """Arrange images in a line. The interface resembles a CSS div with flexbox.""" + device = images[0].device + gap_color = _sanitize_color(gap_color).to(device) + + # Find the maximum image side length in the cross axis dimension. + cross_dim = _get_cross_dim(main_axis) + cross_axis_length = max(image.shape[cross_dim] for image in images) + + # Pad the images. + padded_images = [] + for image in images: + # Create an empty image with the correct size. + padded_shape = list(image.shape) + padded_shape[cross_dim] = cross_axis_length + base = torch.ones(padded_shape, dtype=torch.float32, device=device) + base = base * gap_color[:, None, None] + padded_images.append(overlay(base, image, main_axis, "start", align)) + + # Intersperse separators if necessary. + if gap > 0: + # Generate a separator. + c, _, _ = images[0].shape + separator_size = [gap, gap] + separator_size[cross_dim - 1] = cross_axis_length + separator = torch.ones((c, *separator_size), dtype=torch.float32, device=device) + separator = separator * gap_color[:, None, None] + + # Intersperse the separator between the images. 
+ padded_images = list(_intersperse(padded_images, separator)) + + return torch.cat(padded_images, dim=_get_main_dim(main_axis)) + + +def hcat( + *images: Iterable[Float[Tensor, "channel _ _"]], + align: Literal["start", "center", "end", "top", "bottom"] = "start", + gap: int = 8, + gap_color: Color = 1, +): + """Shorthand for a horizontal linear concatenation.""" + return cat( + "horizontal", + *images, + align={ + "start": "start", + "center": "center", + "end": "end", + "top": "start", + "bottom": "end", + }[align], + gap=gap, + gap_color=gap_color, + ) + + +def vcat( + *images: Iterable[Float[Tensor, "channel _ _"]], + align: Literal["start", "center", "end", "left", "right"] = "start", + gap: int = 8, + gap_color: Color = 1, +): + """Shorthand for a horizontal linear concatenation.""" + return cat( + "vertical", + *images, + align={ + "start": "start", + "center": "center", + "end": "end", + "left": "start", + "right": "end", + }[align], + gap=gap, + gap_color=gap_color, + ) + + +def add_border( + image: Float[Tensor, "channel height width"], + border: int = 8, + color: Color = 1, +) -> Float[Tensor, "channel new_height new_width"]: + color = _sanitize_color(color).to(image) + c, h, w = image.shape + result = torch.empty( + (c, h + 2 * border, w + 2 * border), dtype=torch.float32, device=image.device + ) + result[:] = color[:, None, None] + result[:, border : h + border, border : w + border] = image + return result + + +def resize( + image: Float[Tensor, "channel height width"], + shape: Optional[tuple[int, int]] = None, + width: Optional[int] = None, + height: Optional[int] = None, +) -> Float[Tensor, "channel new_height new_width"]: + assert (shape is not None) + (width is not None) + (height is not None) == 1 + _, h, w = image.shape + + if width is not None: + shape = (int(h * width / w), width) + elif height is not None: + shape = (height, int(w * height / h)) + + return F.interpolate( + image[None], + shape, + mode="bilinear", + align_corners=False, + 
antialias="bilinear", + )[0] diff --git a/optgs/visualization/plots3d/__init__.py b/optgs/visualization/plots3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/optgs/visualization/plots3d/matplotlib.py b/optgs/visualization/plots3d/matplotlib.py new file mode 100644 index 0000000000000000000000000000000000000000..5e2a8805e2bfb706e7aeb62c26f3c70bed3f1e84 --- /dev/null +++ b/optgs/visualization/plots3d/matplotlib.py @@ -0,0 +1,945 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt +from typing import Optional, Literal, Tuple, List, Union +from pathlib import Path +from itertools import product, combinations +from optgs.visualization.plots3d.utils import PointCloud, Camera +from optgs.dataset.camera_datasets.camera import get_scene_scale + + +TRANSPARENT = False +BBOX_INCHES = "tight" # "tight" or "auto" +PAD_INCHES = 0.1 +DPI = 100 +COLORBAR_FRACTION = 0.04625 +LARGE_SCALE_MULTIPLIER = 0.05 +SCALE_MULTIPLIER = 0.05 +RAY_LENGTH_MULTIPLIER = 1.5 + + +def get_scale(scene_radius: float) -> float: + scale = SCALE_MULTIPLIER + if scene_radius <= 1.0: + return scale + else: + return scale + (scene_radius * LARGE_SCALE_MULTIPLIER) + + +def _draw_3d_init( + ax: plt.Axes, + scene_radius: float = 1.0, + elevation_deg: float = 60.0, + azimuth_deg: float = 30.0, + up: Literal["z", "y"] = "z", +): + if scene_radius < 1.0: + lim = 1.0 + else: + lim = scene_radius + ax.set_xlim([-lim, lim]) + ax.set_ylim([-lim, lim]) + ax.set_zlim([max(-1, -lim), lim]) + + ax.set_xlabel("X") + ax.set_ylabel("Y") if up == "z" else ax.set_ylabel("Z") + ax.set_zlabel("Z") if up == "z" else ax.set_zlabel("Y") + + # axis equal + ax.set_aspect("equal") + ax.view_init(elevation_deg, azimuth_deg) + + +def _draw_rays( + ax: plt.Axes, + rays_o: np.ndarray, + rays_d: np.ndarray, + t_near: Optional[np.ndarray] = None, + t_far: Optional[np.ndarray] = None, + rgbs: Optional[np.ndarray] = None, + masks: 
Optional[np.ndarray] = None, + max_nr_rays: Optional[int] = None, + up: Literal["z", "y"] = "z", + scene_radius: float = 1.0, +): + if rays_o is None or rays_d is None: + return + + assert ( + rays_o.shape[0] == rays_d.shape[0] + ), "ray_o and ray_d must have the same length" + + # subsample + if max_nr_rays is not None: + if max_nr_rays < rays_o.shape[0]: + idx = np.random.permutation(rays_o.shape[0])[:max_nr_rays] + rays_o = rays_o[idx] + rays_d = rays_d[idx] + if rgbs is not None: + rgbs = rgbs[idx] + if masks is not None: + masks = masks[idx] + if t_near is not None: + t_near = t_near[idx] + if t_far is not None: + t_far = t_far[idx] + + ray_lenght = RAY_LENGTH_MULTIPLIER * scene_radius + + # draw rays + for i, (ray_o, ray_d) in enumerate(zip(rays_o, rays_d)): + start_point = ray_o + end_point = ray_o + ray_d * ray_lenght + if rgbs is not None: + color = rgbs[i] + # check if color is in [0, 255] + if np.max(color) > 1.0: + color = color / 255.0 + else: + color = "blue" + alpha = 0.75 + if masks is not None: + mask = masks[i] + if mask < 0.5: + alpha = 0.5 + # plot line segment + ax.plot( + [start_point[0], end_point[0]], + ( + [start_point[1], end_point[1]] + if up == "z" + else [start_point[2], end_point[2]] + ), + ( + [start_point[2], end_point[2]] + if up == "z" + else [start_point[1], end_point[1]] + ), + color=color, + alpha=0.3 * float(alpha), + ) + + # draw t_near, t_far points + _draw_near_far_points( + ax=ax, + rays_o=rays_o, + rays_d=rays_d, + t_near=t_near, + t_far=t_far, + up=up, + scene_radius=scene_radius, + ) + + +def _draw_point_cloud( + ax: plt.Axes, + point_cloud: PointCloud, + alpha: Optional[float] = None, + max_nr_points: Optional[int] = None, + up: Literal["z", "y"] = "z", + scene_radius: float = 1.0, +): + if point_cloud is None: + return + + scale = get_scale(scene_radius) + + points_3d = point_cloud.points_3d + points_rgb = point_cloud.points_rgb # could be None + + # subsample + if max_nr_points is not None and max_nr_points < 
point_cloud.points_3d.shape[0]: + # random subsample + idx = np.random.permutation(points_3d.shape[0])[:max_nr_points] + else: + # keep all points + idx = np.arange(points_3d.shape[0]) + + points_3d = points_3d[idx] + if points_rgb is not None: + points_rgb = points_rgb[idx] + + colors = point_cloud.color + if colors is None: + colors = "black" + + # prioritize points_rgb over color + if points_rgb is not None: + colors = points_rgb / 255.0 + + size = point_cloud.size + if size is None: + size = 10.0 + size = max(5.0, size * scale) + + marker = point_cloud.marker + if marker is None: + marker = "o" + + label = point_cloud.label + # if None, keep it None + + if alpha is None: + alpha = 0.5 + + # draw points + if up == "z": + ax.scatter( + points_3d[:, 0], + points_3d[:, 1], + points_3d[:, 2], + s=size, + color=colors, + alpha=alpha, + marker=marker, + label=label, + ) + else: # up = "y" + ax.scatter( + points_3d[:, 0], + points_3d[:, 2], + points_3d[:, 1], + s=size, + color=colors, + alpha=alpha, + marker=marker, + label=label, + ) + + if label is not None: + ax.legend() + + +def _draw_frame( + ax: plt.Axes, + pose: np.ndarray, + idx: int = 0, + up: Literal["z", "y"] = "z", + scene_radius: float = 1.0, +): + if pose is None: + return + + scale = get_scale(scene_radius) + + # get axis directions (normalized) + x_dir = pose[:3, 0] + x_dir /= np.linalg.norm(x_dir) + y_dir = pose[:3, 1] + y_dir /= np.linalg.norm(y_dir) + z_dir = pose[:3, 2] + z_dir /= np.linalg.norm(z_dir) + + # frame center + pos = pose[:3, 3] + + # draw bb frame + ax.quiver( + pos[0], # x + pos[1] if up == "z" else pos[2], # y + pos[2] if up == "z" else pos[1], # z + x_dir[0], + x_dir[1] if up == "z" else x_dir[2], + x_dir[2] if up == "z" else x_dir[1], + length=scale, + color="r", + ) + ax.quiver( + pos[0], # x + pos[1] if up == "z" else pos[2], # y + pos[2] if up == "z" else pos[1], # z + y_dir[0], + y_dir[1] if up == "z" else y_dir[2], + y_dir[2] if up == "z" else y_dir[1], + length=scale, + 
color="g", + ) + ax.quiver( + pos[0], # x + pos[1] if up == "z" else pos[2], # y + pos[2] if up == "z" else pos[1], # z + z_dir[0], + z_dir[1] if up == "z" else z_dir[2], + z_dir[2] if up == "z" else z_dir[1], + length=scale, + color="b", + ) + eps = 0.2 * scale + ax.text( + pos[0] + eps, # x + pos[1] + eps if up == "z" else pos[2] + eps, # y + pos[2] + eps if up == "z" else pos[1] + eps, # z + str(idx), + ) + + +def _draw_cartesian_axis( + ax: plt.Axes, up: Literal["z", "y"] = "z", scene_radius: float = 1.0 +): + _draw_frame(ax=ax, pose=np.eye(4), idx="w", up=up, scene_radius=scene_radius) + + +def _draw_image_plane( + ax: plt.Axes, camera: Camera, up: Literal["z", "y"] = "z", scene_radius: float = 1.0 +): + if camera is None: + return + + scale = get_scale(scene_radius) + + # get image plane corner points in 3D + # from screen coordinates + corner_points_2d_screen = np.array( + [[0, 0], [camera.width, 0], [0, camera.height], [camera.width, camera.height]] + ) + + _, corner_points_d, _ = camera.get_rays( + points_2d_screen=torch.from_numpy(corner_points_2d_screen).float() + ) # torch.Tensor + corner_points_d = corner_points_d.cpu().numpy() + + camera_center = camera.get_center() + corner_points_3d_world = camera_center + corner_points_d * scale + + for i, j in combinations(range(4), 2): + if up == "z": + ax.plot3D( + *zip(corner_points_3d_world[i], corner_points_3d_world[j]), + color="black", + linewidth=1.0, + alpha=0.5, + ) + else: + ax.plot3D( + *zip( + corner_points_3d_world[:, [0, 2, 1]][i], + corner_points_3d_world[:, [0, 2, 1]][j], + ), + color="black", + linewidth=1.0, + alpha=0.5, + ) + + +def _draw_frustum( + ax: plt.Axes, camera: Camera, up: Literal["z", "y"] = "z", scene_radius: float = 1.0 +): + if camera is None: + return + + # get image plane corner points in 3D + # from screen coordinates + image_plane_vertices_2d = np.array( + [[0, 0], [camera.width, 0], [0, camera.height], [camera.width, camera.height]] + ) + + rays_o, rays_d, _ = 
camera.get_rays( + points_2d_screen=torch.from_numpy(image_plane_vertices_2d).float() + ) # torch.Tensor + rays_o = rays_o.cpu().numpy() + rays_d = rays_d.cpu().numpy() + + _draw_rays( + ax=ax, + rays_o=rays_o, + rays_d=rays_d, + rgbs=np.zeros((rays_o.shape[0], 3)), + masks=np.ones((rays_o.shape[0], 1)), + up=up, + scene_radius=scene_radius, + ) + + +def _draw_camera_frame( + ax: plt.Axes, + pose: np.ndarray, + label: str = "c", + up: Literal["z", "y"] = "z", + scene_radius: float = 1.0, +): + if pose is None: + return + + scale = get_scale(scene_radius) + + # get axis directions (normalized) + x_dir = pose[:3, 0] + x_dir /= np.linalg.norm(x_dir) + y_dir = pose[:3, 1] + y_dir /= np.linalg.norm(y_dir) + z_dir = pose[:3, 2] + z_dir /= np.linalg.norm(z_dir) + # frame center + pos = pose[:3, 3] + + # draw camera frame + ax.quiver( + pos[0], # x + pos[1] if up == "z" else pos[2], # y + pos[2] if up == "z" else pos[1], # z + x_dir[0], + x_dir[1] if up == "z" else x_dir[2], + x_dir[2] if up == "z" else x_dir[1], + length=scale, + color="r", + ) + ax.quiver( + pos[0], # x + pos[1] if up == "z" else pos[2], # y + pos[2] if up == "z" else pos[1], # z + y_dir[0], + y_dir[1] if up == "z" else y_dir[2], + y_dir[2] if up == "z" else y_dir[1], + length=scale, + color="g", + ) + ax.quiver( + pos[0], # x + pos[1] if up == "z" else pos[2], # y + pos[2] if up == "z" else pos[1], # z + z_dir[0], + z_dir[1] if up == "z" else z_dir[2], + z_dir[2] if up == "z" else z_dir[1], + length=scale, + color="b", + ) + ax.text( + pos[0], # x + pos[1] if up == "z" else pos[2], # y + pos[2] if up == "z" else pos[1], # z + label, + ) + + +def _draw_point_clouds( + ax: plt.Axes, + point_clouds: List[PointCloud] = None, + max_nr_points: Optional[int] = None, + up: Literal["z", "y"] = "z", + scene_radius: float = 1.0, +): + if point_clouds is None: + return + + if not isinstance(point_clouds, list): + raise ValueError("point_clouds must be a list of PointClouds") + + # if pc are given + if 
len(point_clouds) > 0: + + # split max_nr_points among point clouds + if max_nr_points is not None: + max_nr_points_per_pc = max_nr_points // len(point_clouds) + if max_nr_points_per_pc == 0: + max_nr_points_per_pc = 1 + else: + max_nr_points_per_pc = None + + # plot point clouds + for i, pc in enumerate(point_clouds): + _draw_point_cloud( + ax=ax, + point_cloud=pc, + max_nr_points=max_nr_points_per_pc, + up=up, + scene_radius=scene_radius, + ) + + +def _draw_cameras( + ax: plt.Axes, + cameras: List[Camera] = None, + nr_rays: int = 0, + draw_every_n_cameras: int = 1, + up: Literal["z", "y"] = "z", + scene_radius: float = 1.0, + draw_image_planes=True, + draw_cameras_frustums=True, +): + if cameras is None: + return + + if not isinstance(cameras, list): + raise ValueError("cameras must be a list of Cameras") + + if len(cameras) > 0: + nr_cameras = len(cameras) // draw_every_n_cameras + nr_rays_per_camera = nr_rays // nr_cameras + + # draw camera frames + for i, camera in enumerate(cameras): + + if i % draw_every_n_cameras == 0: + + pose = camera.get_pose() + label = camera.label + _draw_camera_frame( + ax=ax, + pose=pose, + label=label, + up=up, + scene_radius=scene_radius, + ) + if draw_image_planes: + _draw_image_plane( + ax=ax, camera=camera, up=up, scene_radius=scene_radius + ) + if draw_cameras_frustums: + _draw_frustum( + ax=ax, camera=camera, up=up, scene_radius=scene_radius + ) + if nr_rays_per_camera > 0: + _draw_camera_rays( + ax=ax, + camera=camera, + nr_rays=nr_rays_per_camera, + up=up, + scene_radius=scene_radius, + ) + + else: + # skip camera + pass + + +def plot_3d( + cameras: List[Camera] = None, + point_clouds: List[PointCloud] = None, + nr_rays: int = 0, + draw_every_n_cameras: int = 1, + max_nr_points: int = 1000, + azimuth_deg: float = 60.0, + elevation_deg: float = 30.0, + scene_radius: Optional[float] = None, + up: Literal["z", "y"] = "z", + draw_origin: bool = True, + draw_image_planes: bool = True, + draw_cameras_frustums: bool = True, + 
figsize: Tuple[int, int] = (15, 15), + title: Optional[str] = None, + show: bool = True, + save_path: Optional[Path] = None, # if set, saves the figure to the given path +) -> None: + """ + Returns: + None + """ + + if not (up == "z" or up == "y"): + raise ValueError("up must be either 'y' or 'z'") + + # + if scene_radius is None: + if cameras is not None and len(cameras) > 0: + camtoworlds = [camera.get_pose() for camera in cameras] # list of (4, 4) + # stack to numpy array + camtoworlds = np.stack(camtoworlds, axis=0) # (N, 4, 4) + scene_radius = get_scene_scale(camtoworlds) + else: + scene_radius = 1.0 + + # init figure + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111, projection="3d") + if title is not None: + ax.set_title(title) + _draw_3d_init( + ax=ax, + scene_radius=scene_radius, + up=up, + elevation_deg=elevation_deg, + azimuth_deg=azimuth_deg, + ) + + if draw_origin: + _draw_cartesian_axis(ax=ax, up=up, scene_radius=scene_radius) + + # draw points + _draw_point_clouds( + ax=ax, + point_clouds=point_clouds, + # points_3d=points_3d, + # points_3d_colors=points_3d_colors, + # points_3d_labels=points_3d_labels, + # points_3d_sizes=points_3d_sizes, + # points_3d_markers=points_3d_markers, + max_nr_points=max_nr_points, + up=up, + scene_radius=scene_radius, + ) + + # draw camera frames + _draw_cameras( + ax=ax, + cameras=cameras, + nr_rays=nr_rays, + draw_every_n_cameras=draw_every_n_cameras, + up=up, + scene_radius=scene_radius, + draw_image_planes=draw_image_planes, + draw_cameras_frustums=draw_cameras_frustums, + ) + + if save_path is not None: + plt.savefig( + save_path, + transparent=TRANSPARENT, + bbox_inches=BBOX_INCHES, + pad_inches=PAD_INCHES, + dpi=DPI, + ) + print(f"saved figure to {save_path}") + + if show: + plt.show() + + plt.close() + + +def _draw_camera_rays( + ax: plt.Axes, + camera, + nr_rays, + frame_idx=0, + up: Literal["z", "y"] = "z", + scene_radius: float = 1.0, +): + rays_o, rays_d, points_2d_screen = camera.get_rays() # 
torch.Tensor + rays_o = rays_o.cpu().numpy() + rays_d = rays_d.cpu().numpy() + + # color rays with their uv coordinates + xy = points_2d_screen # [:, [1, 0]] + z = np.zeros((xy.shape[0], 1)) + rgbs = np.concatenate([xy, z], axis=1) + rgbs[:, 0] /= np.max(rgbs[:, 0]) + rgbs[:, 1] /= np.max(rgbs[:, 1]) + + # set to ones + masks = np.ones((camera.height, camera.width, 1)).reshape(-1, 1) * 0.5 + + # draw rays + _draw_rays( + ax=ax, + rays_o=rays_o, + rays_d=rays_d, + rgbs=rgbs, + masks=masks, + max_nr_rays=nr_rays, + up=up, + scene_radius=scene_radius, + ) + + +def _draw_near_far_points( + ax: plt.Axes, + rays_o: np.ndarray, + rays_d: np.ndarray, + t_near: float, + t_far: float, + up: Literal["z", "y"] = "z", + scene_radius: float = 1.0, +): + if rays_o is None or rays_d is None: + return + if t_near is None or t_far is None: + return + + assert ( + rays_o.shape[0] == rays_d.shape[0] + ), "ray_o and ray_d must have the same length" + assert ( + t_near.shape[0] == t_far.shape[0] + ), "t_near and t_far must have the same length" + assert ( + rays_o.shape[0] == t_near.shape[0] + ), "ray_o and t_near must have the same length" + + # unsqueeze t_near, t_far if needed + if t_near.ndim == 1: + t_near = t_near[:, np.newaxis] + if t_far.ndim == 1: + t_far = t_far[:, np.newaxis] + + # draw t_near, t_far points + p_near = rays_o + rays_d * t_near + p_far = rays_o + rays_d * t_far + + # unsqueeze p_near, p_far if needed + if p_near.ndim == 1: + p_near = p_near[np.newaxis, :] + if p_far.ndim == 1: + p_far = p_far[np.newaxis, :] + + p_boundaries = np.concatenate( + [p_near[:, np.newaxis, :], p_far[:, np.newaxis, :]], axis=1 + ) + + pc = PointCloud( + points_3d=p_boundaries.reshape(-1, 3), size=200, color="black", marker="x" + ) + + for i in range(p_boundaries.shape[0]): + # draw t_near, t_far points + _draw_point_cloud( + ax=ax, + point_cloud=pc, + up=up, + scene_radius=scene_radius, + ) + + +def plot_current_batch( + cameras: List[Camera], + cameras_idx: np.ndarray, + rays_o: 
np.ndarray, + rays_d: np.ndarray, + rgbs: Optional[np.ndarray] = None, + masks: Optional[np.ndarray] = None, + azimuth_deg: float = 60.0, + elevation_deg: float = 30.0, + scene_radius: float = 1.0, + up: Literal["z", "y"] = "z", + draw_origin: bool = True, + draw_image_planes: bool = True, + figsize: Tuple[int, int] = (15, 15), + title: Optional[str] = None, + show: bool = True, + save_path: Optional[Path] = None, # if set, saves the figure to the given path +) -> None: + """ + Returns: + None + """ + + if not (up == "z" or up == "y"): + raise ValueError("up must be either 'y' or 'z'") + + if rgbs is None: + # if rgb is not given, color rays blue + rgbs = np.zeros((rays_o.shape[0], 3)) + rgbs[:, 2] = 1.0 + + if masks is None: + # if mask is not given, set to 0.5 + masks = np.ones((rays_o.shape[0], 1)) * 0.5 + + # get unique camera idxs + unique_cameras_idx = np.unique(cameras_idx, axis=0) + + # init figure + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111, projection="3d") + if title is not None: + ax.set_title(title) + _draw_3d_init( + ax=ax, + scene_radius=scene_radius, + up=up, + elevation_deg=elevation_deg, + azimuth_deg=azimuth_deg, + ) + + if draw_origin: + _draw_cartesian_axis(ax=ax, up=up, scene_radius=scene_radius) + + # iterate over all unique cameras in batch + for idx in unique_cameras_idx: + camera = cameras[idx] + pose = camera.get_pose() + label = camera.label + _draw_camera_frame( + ax=ax, pose=pose, label=label, up=up, scene_radius=scene_radius + ) + if draw_image_planes: + _draw_image_plane(ax=ax, camera=camera, up=up, scene_radius=scene_radius) + + # draw rays + _draw_rays( + ax=ax, + rays_o=rays_o, + rays_d=rays_d, + rgbs=rgbs, + masks=masks, + max_nr_rays=None, + up=up, + scene_radius=scene_radius, + ) + + if save_path is not None: + plt.savefig( + save_path, + transparent=TRANSPARENT, + bbox_inches=BBOX_INCHES, + pad_inches=PAD_INCHES, + dpi=DPI, + ) + print(f"saved figure to {save_path}") + + if show: + plt.show() + + plt.close() 
+ + +def plot_rays_samples( + rays_o: np.ndarray, + rays_d: np.ndarray, + t_near: Optional[np.ndarray] = None, + t_far: Optional[np.ndarray] = None, + nr_rays: int = 32, + point_clouds: List[PointCloud] = None, + camera: Camera = None, + azimuth_deg: float = 60.0, + elevation_deg: float = 30.0, + scene_radius: float = 1.0, + up: Literal["z", "y"] = "z", + draw_origin: bool = True, + figsize: Tuple[int, int] = (15, 15), + title: Optional[str] = None, + show: bool = True, + save_path: Optional[Path] = None, # if set, saves the figure to the given path +) -> None: + """ + Returns: + None + """ + + if not (up == "z" or up == "y"): + raise ValueError("up must be either 'y' or 'z'") + + # init figure + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(111, projection="3d") + if title is not None: + ax.set_title(title) + _draw_3d_init( + ax=ax, + scene_radius=scene_radius, + up=up, + elevation_deg=elevation_deg, + azimuth_deg=azimuth_deg, + ) + + if draw_origin: + _draw_cartesian_axis(ax=ax, up=up, scene_radius=scene_radius) + + # draw points + _draw_point_clouds( + ax=ax, + point_clouds=point_clouds, + # points_3d=points_samples, + # points_3d_colors=points_samples_colors, + # points_3d_labels=points_samples_labels, + # points_3d_sizes=points_samples_sizes, + up=up, + scene_radius=scene_radius, + ) + + # draw rays + _draw_rays( + ax=ax, + rays_o=rays_o, + rays_d=rays_d, + t_near=t_near, + t_far=t_far, + max_nr_rays=nr_rays, + up=up, + scene_radius=scene_radius, + ) + + # draw camera + if camera is not None: + _draw_cameras( + ax=ax, + cameras=[camera], + up=up, + scene_radius=scene_radius, + draw_image_planes=True, + draw_cameras_frustums=True, + ) + + # Get current axes and check if there are any labels + handles, labels = plt.gca().get_legend_handles_labels() + + # Only display legend if there are labels + if labels: + plt.legend() + + if save_path is not None: + plt.savefig( + save_path, + transparent=TRANSPARENT, + bbox_inches=BBOX_INCHES, + 
def plot_image(
    image: np.ndarray,  # (W, H)
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    draw_colorbar: bool = False,
    cmap: str = "viridis",
    figsize: Tuple[int, int] = (15, 15),
    show: bool = True,
    save_path: Optional[str] = None,
):
    """Draw a single image with pyplot, optionally saving and/or showing it.

    Args:
        image (np.ndarray): width-first array, (W, H) or (W, H, C). A 2D input
            gets a trailing channel axis; the first two axes are then swapped
            so that the array handed to ``imshow`` is (H, W, C).
        title (str, optional): figure title. Defaults to None (no title).
        xlabel (str, optional): x-axis label. Defaults to "W" when None.
        ylabel (str, optional): y-axis label. Defaults to "H" when None.
        draw_colorbar (bool): add a colorbar scaled by the image aspect ratio.
        cmap (str): colormap name forwarded to ``imshow``.
        figsize (Tuple[int, int]): figure size forwarded to ``plt.figure``.
        show (bool): call ``plt.show()`` before closing the figure.
        save_path (str, optional): if set, the figure is written to this path.
    """

    plt.figure(figsize=figsize)

    # Give grayscale input an explicit channel axis so one transpose fits all.
    pixels = np.expand_dims(image, axis=-1) if image.ndim == 2 else image
    pixels = np.transpose(pixels, (1, 0, 2))  # width-first -> (H, W, C)

    plt.imshow(pixels, cmap=cmap)

    # height / width of the drawn image; scales the colorbar below
    aspect = pixels.shape[0] / pixels.shape[1]

    plt.xlabel("W" if xlabel is None else xlabel)
    plt.ylabel("H" if ylabel is None else ylabel)

    if title is not None:
        plt.title(title)

    if draw_colorbar:
        plt.colorbar(fraction=COLORBAR_FRACTION * aspect)

    if save_path is not None:
        savefig_options = dict(
            transparent=TRANSPARENT,
            bbox_inches=BBOX_INCHES,
            pad_inches=PAD_INCHES,
            dpi=DPI,
        )
        plt.savefig(save_path, **savefig_options)
        print(f"saved figure to {save_path}")

    if show:
        plt.show()

    plt.close()
Optional[float] = None, + ): + self.intrinsic = intrinsic + self.extrinsic = extrinsic + self.width = width + self.height = height + + # plotting attributes + self.color = color + self.label = label + self.alpha = alpha + self.line_width = line_width + + def get_intrinsics_inv(self) -> np.ndarray: + """Get inverse of intrinsic matrix.""" + # check if matrix is invertible + # if np.linalg.matrix_rank(self.intrinsic) < 3: + # print(self.intrinsic) + # raise ValueError("Intrinsic matrix is not invertible.") + return np.linalg.inv(self.intrinsic) + + def get_rays( + self, + points_2d_screen: Optional[Tensor] = None, + nr_rays_per_pixel: int = 1, + jitter_pixels: bool = False, + device: str = "cpu", + ) -> Tuple[Tensor, Tensor, Tensor]: + """Get rays from 2D screen points. + + Args: + points_2d_screen (Tensor): (N, 2) tensor of 2D screen points. + """ + + """returns image rays origins and directions + for 2d points on the image plane. + If points are not provided, they are sampled + from the image plane for every pixel. + + Args: + points_2d_screen (torch.Tensor, float or int, optional): (N, 2) + Values in [0, W-1], [0, H-1]. + Default is None. + device (str, optional): device to store tensors. Defaults to "cpu". + jitter_pixels (bool, optional): Whether to jitter pixels. + Only used if points_2d_screen is None. + Defaults to False. 
class PointCloud:
    """A set of 3D points with optional colors and plotting attributes.

    Colors can be supplied either per point, shape (N, 3), or as one color of
    shape (3,) shared by every point. Subsetting operations (``downsample``,
    ``mask``) only index per-point colors; a shared color is left untouched.
    """

    def __init__(
        self,
        points_3d: np.ndarray,
        points_rgb: Optional[np.ndarray] = None,  # (N, 3) or (3,)
        color: Optional[str] = None,
        label: Optional[str] = None,
        size: Optional[float] = None,
        marker: Optional[str] = None,
    ):
        """Store the points and validate the color array shape.

        Args:
            points_3d: point coordinates; the first axis is the point count
                (``transform`` additionally requires shape (N, 3)).
            points_rgb: per-point colors (N, 3) or one shared color (3,).
            color, label, size, marker: plotting attributes, stored as given.

        Raises:
            ValueError: if ``points_rgb`` is neither (N, 3) with N matching
                ``points_3d`` nor (3,).
        """
        self.points_3d = points_3d
        self.points_rgb = points_rgb

        if self.points_rgb is not None:
            # check if dimensions are correct
            if self.points_rgb.ndim == 2:
                # first dimension must be the same as points_3d
                if self.points_rgb.shape[0] != self.points_3d.shape[0]:
                    raise ValueError(
                        f"Points RGB must have the same number of points as points 3D, got {self.points_rgb.shape[0]} and {self.points_3d.shape[0]}"
                    )
                # second dimension must be 3
                if self.points_rgb.shape[1] != 3:
                    raise ValueError(
                        f"Points RGB must have shape (N, 3), got {self.points_rgb.shape}"
                    )
            elif self.points_rgb.ndim == 1:
                # first dimension must be 3
                if self.points_rgb.shape[0] != 3:
                    raise ValueError(
                        f"Points RGB must have shape (3,), got {self.points_rgb.shape}"
                    )
            else:
                raise ValueError(
                    f"Points RGB must have shape (N, 3) or (3,), got {self.points_rgb.shape}"
                )

        # plotting attributes
        self.color = color
        self.label = label
        self.size = size
        self.marker = marker

    def _has_per_point_rgb(self) -> bool:
        """True when ``points_rgb`` holds one color per point, i.e. is 2D."""
        return self.points_rgb is not None and self.points_rgb.ndim == 2

    def downsample(self, nr_points: int):
        """Randomly keep ``nr_points`` points (in place, without replacement).

        Does nothing when the cloud already has ``nr_points`` points or fewer.
        """
        if nr_points >= self.points_3d.shape[0]:
            # do nothing
            return

        idxs = np.random.choice(self.points_3d.shape[0], nr_points, replace=False)
        self.points_3d = self.points_3d[idxs]

        # A shared (3,) color must not be indexed with point indices.
        if self._has_per_point_rgb():
            self.points_rgb = self.points_rgb[idxs]

    def mask(self, mask: np.ndarray):
        """Keep only the points selected by ``mask`` (in place)."""
        self.points_3d = self.points_3d[mask]

        # A shared (3,) color must not be indexed with a per-point mask.
        if self._has_per_point_rgb():
            self.points_rgb = self.points_rgb[mask]

    def shape(self):
        """Return the shape of the point array."""
        return self.points_3d.shape

    def __str__(self) -> str:
        return f"PointCloud with {self.points_3d.shape[0]} points"

    def transform(self, transformation: np.ndarray):
        """Apply ``transformation`` to the points in place via ``apply_transformation_3d``."""
        self.points_3d = apply_transformation_3d(self.points_3d, transformation)
def euclidean_to_homogeneous(
    points: Union[np.ndarray, torch.Tensor],
) -> Union[np.ndarray, torch.Tensor]:
    """
    Converts Euclidean coordinates to homogeneous coordinates by appending a column of ones.

    Args:
        points (np.ndarray or torch.Tensor): A 2D array of shape (N, C) representing Euclidean points.

    Returns:
        np.ndarray or torch.Tensor: A 2D array of shape (N, C+1) in homogeneous coordinates.
            For tensors the ones column uses the dtype and device of `points`;
            for NumPy arrays it is created with NumPy's default dtype.

    Raises:
        TypeError: If `points` is not a NumPy array or PyTorch tensor.
        ValueError: If `points` is not a 2D array.
    """
    # Check the type first: objects such as lists have no `.ndim`, and would
    # otherwise fail with AttributeError instead of the documented TypeError.
    if not isinstance(points, (np.ndarray, torch.Tensor)):
        raise TypeError("`points` must be either a numpy.ndarray or torch.Tensor.")

    # Check if input is a 2D array
    if points.ndim != 2:
        raise ValueError("`points` must be a 2D array of shape (N, C).")

    if isinstance(points, np.ndarray):
        ones = np.ones((points.shape[0], 1))
        return np.hstack((points, ones))

    ones = torch.ones((points.shape[0], 1), dtype=points.dtype, device=points.device)
    return torch.cat((points, ones), dim=1)
def jitter_points(points: torch.Tensor) -> torch.Tensor:
    """Shift every coordinate by an independent uniform random offset.

    Args:
        points (torch.Tensor): (..., 2) float32 pixel centers (in screen space)
    Returns:
        jittered_pixels (torch.Tensor): (..., 2) input plus offsets drawn
            uniformly and clamped to [-0.5 + 1e-6, 0.5 - 1e-6]
    """

    assert points.dtype == torch.float32, "points must be float32"

    half = 0.5
    margin = 1e-6

    # uniformly sampled noise in [0, 1), recentred around zero
    noise = torch.rand_like(points, device=points.device) - half
    # keep the offset strictly inside the half-unit interval
    noise = noise.clamp(min=-half + margin, max=half - margin)

    return points + noise
def local_inv_perspective_projection(
    intrinsics_inv: Union[np.ndarray, torch.Tensor],
    points_2d_screen: Union[np.ndarray, torch.Tensor],
) -> Union[np.ndarray, torch.Tensor]:
    """
    Apply inverse perspective projection to 2D screen points.

    Each point is lifted to homogeneous coordinates (x, y, 1) and multiplied
    by the inverse intrinsic matrix.

    Args:
        intrinsics_inv (np.ndarray or torch.Tensor): Inverse of camera intrinsic matrix of shape (N, 3, 3) or (3, 3).
        points_2d_screen (np.ndarray or torch.Tensor): 2D points in screen coordinates of shape (N, 2).

    Returns:
        np.ndarray or torch.Tensor: Unprojected 3D points of shape (N, 3).

    Raises:
        ValueError: If inputs have invalid shapes or types.
    """

    # check input shapes
    if intrinsics_inv.ndim == 2:
        intrinsics_inv = intrinsics_inv[None, ...]  # Add batch dimension
    elif intrinsics_inv.ndim == 3:
        pass
    else:
        raise ValueError(
            f"intrinsics_inv: {intrinsics_inv.shape} must have shape (N, 3, 3) or (3, 3)."
        )

    if intrinsics_inv.shape[1:] != (3, 3):
        raise ValueError(
            f"intrinsics_inv: {intrinsics_inv.shape} must have shape (N, 3, 3) or (3, 3)."
        )

    # Reject anything that is not exactly (N, 2). This runs before the batch
    # comparison below, which reads points_2d_screen.shape[0].
    if points_2d_screen.ndim != 2 or points_2d_screen.shape[-1] != 2:
        raise ValueError("`points_2d_screen` must have shape (N, 2).")

    # A single matrix (batch of 1) is broadcast over all points.
    if (
        intrinsics_inv.shape[0] != points_2d_screen.shape[0]
        and intrinsics_inv.shape[0] != 1
    ):
        raise ValueError(
            f"input shapes do not match: intrinsics_inv: {intrinsics_inv.shape} and points_2d_screen: {points_2d_screen.shape}."
        )

    augmented_points_2d_screen = euclidean_to_homogeneous(points_2d_screen)  # (N, 3)
    augmented_points_2d_screen = augmented_points_2d_screen[..., None]  # (N, 3, 1)
    augmented_points_3d_camera = (
        intrinsics_inv @ augmented_points_2d_screen
    )  # (N, 3, 3) @ (N, 3, 1)
    # reshape to (N, 3)
    augmented_points_3d_camera = augmented_points_3d_camera.squeeze(-1)  # (N, 3)

    return augmented_points_3d_camera
def pad(images: list[Shaped[Tensor, "..."]]) -> list[Shaped[Tensor, "..."]]:
    """Pad tensors of equal rank to a common shape, filling with ones.

    The target shape is the element-wise maximum over all input shapes. Each
    input is copied into the low-index corner of a ones-filled tensor that
    keeps the input's dtype and device.

    Args:
        images: tensors with the same number of dimensions.

    Returns:
        One padded tensor per input, all with the same shape. An empty input
        list yields an empty list.
    """
    if not images:
        # torch.stack cannot stack an empty list of shapes.
        return []

    shapes = torch.stack([torch.tensor(x.shape) for x in images])
    padded_shape = shapes.max(dim=0)[0].tolist()

    results = [
        torch.ones(padded_shape, dtype=x.dtype, device=x.device) for x in images
    ]
    for image, result in zip(images, results):
        # Index with a tuple of slices: indexing with a list of slices is the
        # deprecated non-tuple sequence form.
        region = tuple(slice(0, extent) for extent in image.shape)
        result[region] = image
    return results
def render_cameras(batch: dict, resolution: int) -> Float[Tensor, "3 3 height width"]:
    """Draw the context and target cameras of the first batch element.

    Context cameras come first and get color (1, 1, 1); target cameras follow
    and get color (1, 0, 0). Only index 0 along the batch axis is used.
    """
    context, target = batch["context"], batch["target"]

    n_context = context["extrinsics"].shape[1]
    n_target = target["extrinsics"].shape[1]

    # Start from all ones, then zero the last two channels of the target rows.
    color = torch.ones(
        (n_target + n_context, 3),
        dtype=torch.float32,
        device=target["extrinsics"].device,
    )
    color[n_context:, 1:] = 0

    def first_scene(key: str) -> Tensor:
        # context views followed by target views, for batch element 0
        return torch.cat((context[key][0], target[key][0]))

    return draw_cameras(
        resolution,
        first_scene("extrinsics"),
        first_scene("intrinsics"),
        color,
        first_scene("near"),
        first_scene("far"),
    )
def vis_disparity(disp):
    """Color-map a disparity array with OpenCV's INFERNO colormap.

    The input is min-max normalised to [0, 255], cast to uint8 and passed to
    ``cv2.applyColorMap``.

    Args:
        disp: NumPy array of disparities.

    Returns:
        The array returned by ``cv2.applyColorMap`` for the normalised input.
    """
    lowest = disp.min()
    value_range = disp.max() - lowest

    if value_range > 0:
        disp_vis = (disp - lowest) / value_range * 255.0
    else:
        # Constant (or all-NaN) input: the normalisation would divide by zero
        # and produce NaN, whose uint8 cast is undefined. Map it to zeros.
        disp_vis = np.zeros_like(disp, dtype=np.float32)

    disp_vis = disp_vis.astype("uint8")
    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)

    return disp_vis
+readme = "README.md" +requires-python = ">=3.10" +license = { file = "LICENSE" } +authors = [ + { name = "Naama Pearl" }, + { name = "Stefano Esposito" }, + { name = "Haofei Xu" }, + { name = "Amit Peleg" }, + { name = "Patricia Gschoßmann" }, + { name = "Lorenzo Porzi" }, + { name = "Peter Kontschieder" }, + { name = "Gerard Pons-Moll" }, + { name = "Andreas Geiger" }, +] +keywords = ["gaussian-splatting", "3d-reconstruction", "novel-view-synthesis", "depth-estimation"] +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: POSIX :: Linux", + "Topic :: Scientific/Engineering :: Image Processing", +] +# Runtime deps are sourced from requirements.txt (single source of truth, +# also used by setup.sh) — see [tool.setuptools.dynamic] below. +# +# NOT covered here (cannot be expressed as plain PyPI deps — install via +# setup.sh or the documented steps): +# - torch / torchvision / torchaudio (CUDA-specific wheel index) +# - gsplat, nerfacc (built from git against installed torch) +# - submodules/* CUDA extensions (simple-knn, fused-ssim, pointops, +# fused_knn_attn, pycolmap) +dynamic = ["dependencies"] + +[project.urls] +Homepage = "https://naamapearl.github.io/learn2splat/" +Repository = "https://github.com/autonomousvision/learn2splat" + +[project.scripts] +optgs = "optgs.main:main" + +[tool.setuptools.dynamic] +dependencies = { file = ["requirements.txt"] } + +[tool.setuptools.packages.find] +# Flat layout: only the `optgs` package ships. Excludes tests/, visualization/, +# scripts/, baselines/, submodules/, mlcloud_scripts/ at the repo root, and the +# vendored non-package CUDA dirs inside optgs/ (no __init__.py / hyphenated). +include = ["optgs*"] +namespaces = false + +[tool.setuptools.package-data] +# Hydra configs live inside the package and must ship in the wheel so the +# `optgs` CLI works after `pip install` from any directory. 
+# (Eval-index assets are intentionally NOT bundled — see docs / OPTGS_ASSETS.) +optgs = ["config/**/*.yaml", "config/**/*.yml"] + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b19d042603e2fcfd0f928df6130da272ac1b1b19 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,40 @@ +beartype==0.18.5 +colorama==0.4.6 +colorspacious==1.1.2 +dacite==1.8.1 +e3nn==0.5.1 +einops==0.8.1 +hydra-core +jaxtyping==0.3.2 +lpips==0.1.4 +matplotlib==3.10.1 +moviepy +numpy==1.26.4 +opencv_python==4.11.0.86 +Pillow==11.1.0 +plyfile==1.1 +pytorch_lightning==2.4.0 +viser==1.0.29 +rich==14.3.4 +scikit-image==0.24.0 +sk-video==1.1.10 +tabulate==0.9.0 +tqdm==4.67.1 +wandb==0.17.7 +decorator==5.2.1 +ipython +pytorch-optimizer +pytest +natsort +typeguard +scikit-learn +pandas +albumentations +h5py +huggingface_hub +kornia +loguru +timm +poselib +tyro +piexif \ No newline at end of file diff --git a/submodules/fused-ssim/.gitignore b/submodules/fused-ssim/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..89296b3454152a317bb190a760480e850fa2d47a --- /dev/null +++ b/submodules/fused-ssim/.gitignore @@ -0,0 +1,3 @@ +build/ +fused_ssim.egg-info/ +fused_ssim_cuda.cpython-310-x86_64-linux-gnu.so diff --git a/submodules/fused-ssim/LICENSE b/submodules/fused-ssim/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..541f73944912fde14ffb971f22ba516042b95ab7 --- /dev/null +++ b/submodules/fused-ssim/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Rahul Goel + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the 
Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/submodules/fused-ssim/README.md b/submodules/fused-ssim/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b2cd9676ef94dbd0f6488bd08b71fc6b5a159310 --- /dev/null +++ b/submodules/fused-ssim/README.md @@ -0,0 +1,72 @@ +# Fully Fused Differentiable SSIM + +This repository contains an efficient fully-fused implementation of [SSIM](https://en.wikipedia.org/wiki/Structural_similarity_index_measure) which is differentiable in nature. There are several factors that contribute to an efficient implementation: +- Convolutions in SSIM are spatially localized leading to fully-fused implementation without touching global memory for intermediate steps. +- Backpropagation through Gaussian Convolution is simply another Gaussian Convolution itself. +- Gaussian Convolutions are separable leading to reduced computation. +- Gaussians are symmetric in nature leading to fewer computations. +- Single convolution pass for multiple statistics. + +As per the original SSIM paper, this implementation uses `11x11` sized convolution kernel. The weights for it have been hardcoded and this is another reason for it's speed. This implementation currently only supports **2D images** but with **variable number of channels** and **batch size**. 
+ +## PyTorch Installation Instructions +- You must have CUDA and PyTorch+CUDA installed in your Python 3.X environment. This project has currently been tested with: + - PyTorch `2.3.1+cu118` and CUDA `11.8` on Ubuntu 24.04 LTS. + - PyTorch `2.4.1+cu124` and CUDA `12.4` on Ubuntu 24.04 LTS. + - PyTorch `2.5.1+cu124` and CUDA `12.6` on Windows 11. +- Run `pip install git+https://github.com/rahul-goel/fused-ssim/` or clone the repository and run `pip install .` from the root of this project. +- setup.py should detect your GPU architecture automatically. If you want to see the output, run `pip install git+https://github.com/rahul-goel/fused-ssim/ -v` or clone the repository and run `pip install . -v` from the root of this project. +- If the previous command does not work, run `python setup.py install` from the root of this project. + +## Usage +```python +import torch +from fused_ssim import fused_ssim + +# predicted_image, gt_image: [BS, CH, H, W], CUDA tensors +# predicted_image is differentiable +gt_image = torch.rand(2, 3, 1080, 1920, device="cuda") +predicted_image = torch.nn.Parameter(torch.rand_like(gt_image)) +ssim_value = fused_ssim(predicted_image, gt_image) +``` + +By default, `same` padding is used. To use `valid` padding which is the kind of padding used by [pytorch-msssim](https://github.com/VainF/pytorch-msssim): +```python +ssim_value = fused_ssim(predicted_image, gt_image, padding="valid") +``` + +If you don't want to train and use this only for inference, use the following for even faster speed: +```python +with torch.no_grad(): + ssim_value = fused_ssim(predicted_image, gt_image, train=False) +``` + +## Constraints +- Currently, only one of the images is allowed to be differentiable i.e. only the first image can be `nn.Parameter`. +- Limited to 2D images. +- Images must be normalized to range `[0, 1]`. +- Standard `11x11` convolutions supported.
+ +## Performance +This implementation is 5-8x faster than the previous fastest (to the best of my knowledge) differentiable SSIM implementation [pytorch-mssim](https://github.com/VainF/pytorch-msssim). + + + +## BibTeX +If you leverage fused SSIM for your research work, please cite our main paper: +``` +@inproceedings{taming3dgs, + author = {Mallick, Saswat Subhajyoti and Goel, Rahul and Kerbl, Bernhard and Steinberger, Markus and Carrasco, Francisco Vicente and De La Torre, Fernando}, + title = {Taming 3DGS: High-Quality Radiance Fields with Limited Resources}, + year = {2024}, + url = {https://doi.org/10.1145/3680528.3687694}, + doi = {10.1145/3680528.3687694}, + booktitle = {SIGGRAPH Asia 2024 Conference Papers}, + series = {SA '24} +} +``` + +## Acknowledgements +Thanks to [Bernhard](https://snosixtyboo.github.io) for the idea. +Thanks to [Janusch](https://github.com/MrNeRF) for further optimizations. +Thanks to [Florian](https://fhahlbohm.github.io/) and [Ishaan](https://ishaanshah.xyz) for testing. 
from typing import NamedTuple
import torch.nn as nn
import torch
from fused_ssim_cuda import fusedssim, fusedssim_backward

allowed_padding = ["same", "valid"]

class FusedSSIMMap(torch.autograd.Function):
    """Autograd wrapper around the fused CUDA SSIM kernels.

    ``forward`` returns the per-pixel SSIM map; ``backward`` returns a
    gradient for ``img1`` only -- every other input receives ``None``.
    """

    @staticmethod
    def forward(ctx, C1, C2, img1, img2, padding="same", train=True):
        """Compute the SSIM map of ``img1`` against ``img2``.

        Args:
            C1, C2: SSIM stabilisation constants forwarded to the kernel.
            img1, img2: image tensors handed to ``fusedssim``.
            padding: ``"same"`` keeps the full map, ``"valid"`` crops a
                5-pixel border on both spatial axes.
            train: forwarded to ``fusedssim``; when False the three
                derivative tensors it returns carry no data, so the result
                cannot be back-propagated.
        """
        ssim_map, dm_dmu1, dm_dsigma1_sq, dm_dsigma12 = fusedssim(C1, C2, img1, img2, train)

        if padding == "valid":
            ssim_map = ssim_map[:, :, 5:-5, 5:-5]

        ctx.save_for_backward(img1.detach(), img2, dm_dmu1, dm_dsigma1_sq, dm_dsigma12)
        ctx.C1 = C1
        ctx.C2 = C2
        ctx.padding = padding
        # Remembered so backward() can refuse to run on missing derivatives.
        ctx.train = train

        return ssim_map

    @staticmethod
    def backward(ctx, opt_grad):
        """Return ``(None, None, dL/d(img1), None, None, None)``.

        Raises:
            RuntimeError: if the forward pass ran with ``train=False``. The
                saved derivative tensors are empty in that case and the CUDA
                kernel would index past their end.
        """
        if not ctx.train:
            raise RuntimeError(
                "fused_ssim was called with train=False, so the partial "
                "derivatives required for backward were not computed. Use "
                "train=True (the default) or wrap the call in torch.no_grad()."
            )

        img1, img2, dm_dmu1, dm_dsigma1_sq, dm_dsigma12 = ctx.saved_tensors
        C1, C2, padding = ctx.C1, ctx.C2, ctx.padding
        dL_dmap = opt_grad
        if padding == "valid":
            # Re-embed the cropped gradient into a full-size, zero-filled map
            # so it lines up with the saved full-size derivative tensors.
            dL_dmap = torch.zeros_like(img1)
            dL_dmap[:, :, 5:-5, 5:-5] = opt_grad
        grad = fusedssim_backward(C1, C2, img1, img2, dL_dmap, dm_dmu1, dm_dsigma1_sq, dm_dsigma12)
        return None, None, grad, None, None, None

def fused_ssim(img1, img2, padding="same", train=True):
    """Return the mean SSIM between ``img1`` and ``img2`` as a scalar tensor.

    Args:
        img1: image tensor; the only input that receives a gradient.
        img2: reference image tensor.
        padding: one of ``allowed_padding`` (``"same"`` or ``"valid"``).
        train: set to False for inference-only calls; such a result cannot
            be back-propagated.
    """
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    assert padding in allowed_padding

    img1 = img1.contiguous()
    # Named ssim_map rather than `map` to avoid shadowing the builtin.
    ssim_map = FusedSSIMMap.apply(C1, C2, img1, img2, padding, train)
    return ssim_map.mean()
a/submodules/fused-ssim/images/albert.jpg b/submodules/fused-ssim/images/albert.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e07952182bbd42b2059cf12066159ff9fb579f03 --- /dev/null +++ b/submodules/fused-ssim/images/albert.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:855317ee9e69adf8fe7ae6428bf286515baa4afc77fe31a260564998ee8571f8 +size 2309396 diff --git a/submodules/fused-ssim/images/inference_time.png b/submodules/fused-ssim/images/inference_time.png new file mode 100644 index 0000000000000000000000000000000000000000..a7d85cccd2910eb0e4c7ea278617e207f674c99a --- /dev/null +++ b/submodules/fused-ssim/images/inference_time.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ec2a0f9ef49a2289536c6650f17542a6e73abd295810671303e488731813a193 +size 139569 diff --git a/submodules/fused-ssim/images/inference_time_4090.png b/submodules/fused-ssim/images/inference_time_4090.png new file mode 100644 index 0000000000000000000000000000000000000000..7ad6fea7531dda447836044877236992df4c8741 --- /dev/null +++ b/submodules/fused-ssim/images/inference_time_4090.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ea5b004315116b01ff6bf5c23e41d8798c802b50ead6efe451d495ee8f420a36 +size 130548 diff --git a/submodules/fused-ssim/images/predicted.jpg b/submodules/fused-ssim/images/predicted.jpg new file mode 100644 index 0000000000000000000000000000000000000000..20393cc046628fd63eb083a97af1e637c557b6e0 --- /dev/null +++ b/submodules/fused-ssim/images/predicted.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:863c0be5c47667097a98f5af767255f2d571f6a1fe6acd20b0ff9e8902472486 +size 1275874 diff --git a/submodules/fused-ssim/images/training_time.png b/submodules/fused-ssim/images/training_time.png new file mode 100644 index 0000000000000000000000000000000000000000..61edf9fb654e71fa416e2c3ebce08b0b6d54280e --- /dev/null +++ 
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import torch
import sys
import os

# Force unbuffered output
os.environ['PYTHONUNBUFFERED'] = '1'
sys.stderr.reconfigure(line_buffering=True)


def _announce(msg):
    """Print ``msg`` on stdout and stderr so it shows up in pip's build logs."""
    print(msg)
    print(msg, file=sys.stderr, flush=True)


# Default fallback architectures
fallback_archs = [
    "-gencode=arch=compute_75,code=sm_75",
    "-gencode=arch=compute_80,code=sm_80",
    "-gencode=arch=compute_89,code=sm_89",
]

nvcc_args = [
    "-O3",
    "--maxrregcount=32",
    "--use_fast_math",
]

detected_arch = None

# An explicit TORCH_CUDA_ARCH_LIST wins over both auto-detection and the
# fallback list. torch.utils.cpp_extension stops emitting its own
# TORCH_CUDA_ARCH_LIST-derived -gencode flags once the user-supplied nvcc flags
# already contain an arch flag, so adding ours here would silently override the
# requested architectures (e.g. in a GPU-less Docker build).
requested_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", "").strip()

if requested_arch_list:
    _announce(
        f"TORCH_CUDA_ARCH_LIST is set ({requested_arch_list}); "
        "leaving GPU architecture selection to PyTorch."
    )
    detected_arch = f"TORCH_CUDA_ARCH_LIST={requested_arch_list}"
elif torch.cuda.is_available():
    try:
        device = torch.cuda.current_device()
        compute_capability = torch.cuda.get_device_capability(device)
        arch = f"sm_{compute_capability[0]}{compute_capability[1]}"

        _announce(f"Detected GPU architecture: {arch}")

        nvcc_args.append(f"-arch={arch}")
        detected_arch = arch
    except Exception as e:
        _announce(f"Failed to detect GPU architecture: {e}. Falling back to multiple architectures.")
        nvcc_args.extend(fallback_archs)
else:
    _announce("CUDA not available. Falling back to multiple architectures.")
    nvcc_args.extend(fallback_archs)

# Create a custom class that prints the architecture information
class CustomBuildExtension(BuildExtension):
    def build_extensions(self):
        arch_info = f"Building with GPU architecture: {detected_arch if detected_arch else 'multiple architectures'}"
        print("\n" + "="*50)
        print(arch_info)
        print("="*50 + "\n")
        super().build_extensions()

setup(
    name="fused_ssim",
    packages=['fused_ssim'],
    ext_modules=[
        CUDAExtension(
            name="fused_ssim_cuda",
            sources=[
                "ssim.cu",
                "ext.cpp"],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": nvcc_args
            }
        )
    ],
    cmdclass={
        'build_ext': CustomBuildExtension
    }
)

# Print again at the end of setup.py execution
final_msg = f"Setup completed. NVCC args: {nvcc_args}"
print(final_msg)
+#define BLOCK_X 16 +#define BLOCK_Y 16 +#define HALO 5 + +#define SHARED_X (BLOCK_X + 2 * HALO) +#define SHARED_Y (BLOCK_Y + 2 * HALO) + +// For partial results after horizontal pass +#define CONV_X BLOCK_X +#define CONV_Y SHARED_Y + +// ------------------------------------------ +// Utility: Safe pixel fetch w/ zero padding +// ------------------------------------------ +__device__ __forceinline__ float get_pix_value( + const float* img, + int b, int c, int y, int x, + int CH, int H, int W +) { + if (x < 0 || x >= W || y < 0 || y >= H) { + return 0.0f; + } + return img[b * CH * H * W + c * H * W + y * W + x]; +} + +// ------------------------------------------ +// Forward Kernel: Fused SSIM +// - Two-pass convolution to get mu1, mu2, +// sigma1_sq, sigma2_sq, sigma12, etc. +// - Writes final SSIM map to ssim_map +// - Optionally writes partial derivatives +// to dm_dmu1, dm_dsigma1_sq, dm_dsigma12 +// ------------------------------------------ +__global__ void fusedssimCUDA( + int H, + int W, + int CH, + float C1, + float C2, + const float* __restrict__ img1, + const float* __restrict__ img2, + float* __restrict__ ssim_map, + float* __restrict__ dm_dmu1, + float* __restrict__ dm_dsigma1_sq, + float* __restrict__ dm_dsigma12 +) { + auto block = cg::this_thread_block(); + const int bIdx = block.group_index().z; // batch index + const int pix_y = block.group_index().y * BLOCK_Y + block.thread_index().y; + const int pix_x = block.group_index().x * BLOCK_X + block.thread_index().x; + const int pix_id = pix_y * W + pix_x; + const int num_pix = H * W; + + // Shared memory for the tile (img1, img2) + __shared__ float sTile[SHARED_Y][SHARED_X][2]; + // After horizontal pass, store partial sums here + // xconv[y][x] -> (sumX, sumX^2, sumY, sumY^2, sumXY) + __shared__ float xconv[CONV_Y][CONV_X][5]; + + // Each block processes B x C sub-batches. 
We loop over channels: + for (int c = 0; c < CH; ++c) { + // ------------------------------------------------------------ + // 1) Load (img1, img2) tile + halo into shared memory + // ------------------------------------------------------------ + { + const int tileSize = SHARED_Y * SHARED_X; + const int threads = BLOCK_X * BLOCK_Y; + const int steps = (tileSize + threads - 1) / threads; + + const int tileStartY = block.group_index().y * BLOCK_Y; + const int tileStartX = block.group_index().x * BLOCK_X; + + for (int s = 0; s < steps; ++s) { + int tid = s * threads + block.thread_rank(); + if (tid < tileSize) { + int local_y = tid / SHARED_X; + int local_x = tid % SHARED_X; + int gy = tileStartY + local_y - HALO; + int gx = tileStartX + local_x - HALO; + + float X = get_pix_value(img1, bIdx, c, gy, gx, CH, H, W); + float Y = get_pix_value(img2, bIdx, c, gy, gx, CH, H, W); + + sTile[local_y][local_x][0] = X; + sTile[local_y][local_x][1] = Y; + } + } + } + block.sync(); + + // ------------------------------------------------------------ + // 2) Horizontal convolution (11x1) in shared memory + // We'll accumulate symmetrical pairs around center. 
+ // ------------------------------------------------------------ + { + int ly = threadIdx.y; + int lx = threadIdx.x + HALO; // skip left halo + + float sumX = 0.f; + float sumX2 = 0.f; + float sumY = 0.f; + float sumY2 = 0.f; + float sumXY = 0.f; + + // #pragma unroll for those 5 pairs +#pragma unroll + for (int d = 1; d <= HALO; ++d) { + float w = cGauss[HALO - d]; + float Xleft = sTile[ly][lx - d][0]; + float Yleft = sTile[ly][lx - d][1]; + float Xright = sTile[ly][lx + d][0]; + float Yright = sTile[ly][lx + d][1]; + + sumX += (Xleft + Xright) * w; + sumX2 += ((Xleft * Xleft) + (Xright * Xright)) * w; + sumY += (Yleft + Yright) * w; + sumY2 += ((Yleft * Yleft) + (Yright * Yright)) * w; + sumXY += ((Xleft * Yleft) + (Xright * Yright)) * w; + } + // center + { + float centerX = sTile[ly][lx][0]; + float centerY = sTile[ly][lx][1]; + float wc = cGauss[HALO]; + sumX += centerX * wc; + sumX2 += (centerX * centerX) * wc; + sumY += centerY * wc; + sumY2 += (centerY * centerY) * wc; + sumXY += (centerX * centerY) * wc; + } + + // Write out partial sums + xconv[ly][threadIdx.x][0] = sumX; + xconv[ly][threadIdx.x][1] = sumX2; + xconv[ly][threadIdx.x][2] = sumY; + xconv[ly][threadIdx.x][3] = sumY2; + xconv[ly][threadIdx.x][4] = sumXY; + + // Possibly handle second row in same warp + int ly2 = ly + BLOCK_Y; + if (ly2 < CONV_Y) { + sumX = 0.f; sumX2 = 0.f; + sumY = 0.f; sumY2 = 0.f; + sumXY = 0.f; + +#pragma unroll + for (int d = 1; d <= HALO; ++d) { + float w = cGauss[HALO - d]; + float Xleft = sTile[ly2][lx - d][0]; + float Yleft = sTile[ly2][lx - d][1]; + float Xright = sTile[ly2][lx + d][0]; + float Yright = sTile[ly2][lx + d][1]; + + sumX += (Xleft + Xright) * w; + sumX2 += ((Xleft * Xleft) + (Xright * Xright)) * w; + sumY += (Yleft + Yright) * w; + sumY2 += ((Yleft * Yleft) + (Yright * Yright)) * w; + sumXY += ((Xleft * Yleft) + (Xright * Yright)) * w; + } + // center + { + float cx = sTile[ly2][lx][0]; + float cy = sTile[ly2][lx][1]; + float wc = cGauss[HALO]; + sumX 
+= cx * wc; + sumX2 += (cx * cx) * wc; + sumY += cy * wc; + sumY2 += (cy * cy) * wc; + sumXY += (cx * cy) * wc; + } + xconv[ly2][threadIdx.x][0] = sumX; + xconv[ly2][threadIdx.x][1] = sumX2; + xconv[ly2][threadIdx.x][2] = sumY; + xconv[ly2][threadIdx.x][3] = sumY2; + xconv[ly2][threadIdx.x][4] = sumXY; + } + } + block.sync(); + + // ------------------------------------------------------------ + // 3) Vertical convolution (1x11) + final SSIM + // ------------------------------------------------------------ + { + int ly = threadIdx.y + HALO; + int lx = threadIdx.x; + + float out0 = 0.f, out1 = 0.f, out2 = 0.f, out3 = 0.f, out4 = 0.f; + +#pragma unroll + for (int d = 1; d <= HALO; ++d) { + float w = cGauss[HALO - d]; + float* top = xconv[ly - d][lx]; + float* bot = xconv[ly + d][lx]; + + out0 += (top[0] + bot[0]) * w; + out1 += (top[1] + bot[1]) * w; + out2 += (top[2] + bot[2]) * w; + out3 += (top[3] + bot[3]) * w; + out4 += (top[4] + bot[4]) * w; + } + // center + { + float wC = cGauss[HALO]; + float* ctr = xconv[ly][lx]; + out0 += ctr[0] * wC; + out1 += ctr[1] * wC; + out2 += ctr[2] * wC; + out3 += ctr[3] * wC; + out4 += ctr[4] * wC; + } + + if (pix_x < W && pix_y < H) { + float mu1 = out0; + float mu2 = out2; + float mu1_sq = mu1 * mu1; + float mu2_sq = mu2 * mu2; + + float sigma1_sq = out1 - mu1_sq; + float sigma2_sq = out3 - mu2_sq; + float sigma12 = out4 - mu1 * mu2; + + float A = mu1_sq + mu2_sq + C1; + float B = sigma1_sq + sigma2_sq + C2; + float C_ = 2.f * mu1 * mu2 + C1; + float D_ = 2.f * sigma12 + C2; + + float val = (C_ * D_) / (A * B); + + int global_idx = bIdx * CH * num_pix + c * num_pix + pix_id; + ssim_map[global_idx] = val; + + if (dm_dmu1) { + // partial derivatives + float d_m_dmu1 = ( + (mu2 * 2.f * D_) / (A * B) + - (mu2 * 2.f * C_) / (A * B) + - (mu1 * 2.f * C_ * D_) / (A * A * B) + + (mu1 * 2.f * C_ * D_) / (A * B * B) + ); + float d_m_dsigma1_sq = (-C_ * D_) / (A * B * B); + float d_m_dsigma12 = (2.f * C_) / (A * B); + + dm_dmu1[global_idx] 
= d_m_dmu1; + dm_dsigma1_sq[global_idx] = d_m_dsigma1_sq; + dm_dsigma12[global_idx] = d_m_dsigma12; + } + } + } + } +} + +// ------------------------------------------ +// Backward Kernel: Apply chain rule to get +// dL/d(img1) from partial derivatives +// (dm_dmu1, dm_dsigma1_sq, dm_dsigma12) +// and dL/dmap (the gradient from above). +// ------------------------------------------ +__global__ void fusedssim_backwardCUDA( + int H, + int W, + int CH, + float C1, + float C2, + const float* __restrict__ img1, + const float* __restrict__ img2, + const float* __restrict__ dL_dmap, + float* __restrict__ dL_dimg1, + const float* __restrict__ dm_dmu1, + const float* __restrict__ dm_dsigma1_sq, + const float* __restrict__ dm_dsigma12 +) { + auto block = cg::this_thread_block(); + + const int pix_y = block.group_index().y * BLOCK_Y + block.thread_index().y; + const int pix_x = block.group_index().x * BLOCK_X + block.thread_index().x; + const int pix_id = pix_y * W + pix_x; + const int num_pix = H * W; + const int bIdx = block.group_index().z; + + // Shared memory for the fused data: + // [0]: dm_dmu1*dL, [1]: dm_dsigma1_sq*dL, [2]: dm_dsigma12*dL + __shared__ float sData[3][SHARED_Y][SHARED_X]; + __shared__ float sScratch[CONV_Y][CONV_X][3]; + + for (int c = 0; c < CH; ++c) { + float p1 = 0.f, p2 = 0.f; + if (pix_x < W && pix_y < H) { + p1 = get_pix_value(img1, bIdx, c, pix_y, pix_x, CH, H, W); + p2 = get_pix_value(img2, bIdx, c, pix_y, pix_x, CH, H, W); + } + + // (1) Load + fuse multiplication + { + const int start_y = block.group_index().y * BLOCK_Y; + const int start_x = block.group_index().x * BLOCK_X; + + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int warp_id = tid / 32; + int lane_id = tid % 32; + int totalThreads = BLOCK_X * BLOCK_Y; + int num_warps = (totalThreads + 31) / 32; + + for (int row = warp_id; row < SHARED_Y; row += num_warps) { + int gy = start_y + row - HALO; + for (int col = lane_id; col < SHARED_X; col += 32) { + int gx = start_x + col - HALO; 
+ + float chain = get_pix_value(dL_dmap, bIdx, c, gy, gx, CH, H, W); + float vmu = get_pix_value(dm_dmu1, bIdx, c, gy, gx, CH, H, W); + float vs1 = get_pix_value(dm_dsigma1_sq,bIdx, c, gy, gx, CH, H, W); + float vs12 = get_pix_value(dm_dsigma12, bIdx, c, gy, gx, CH, H, W); + + sData[0][row][col] = vmu * chain; + sData[1][row][col] = vs1 * chain; + sData[2][row][col] = vs12 * chain; + } + } + } + block.sync(); + + // (2) Horizontal pass + { + int ly = threadIdx.y; + int lx = threadIdx.x + HALO; + + for (int pass = 0; pass < 2; ++pass) { + int yy = ly + pass * BLOCK_Y; + if (yy < CONV_Y) { + float accum0 = 0.f, accum1 = 0.f, accum2 = 0.f; + +#pragma unroll + for (int d = 1; d <= HALO; ++d) { + float w = cGauss[HALO - d]; + float left0 = sData[0][yy][lx - d]; + float left1 = sData[1][yy][lx - d]; + float left2 = sData[2][yy][lx - d]; + + float right0 = sData[0][yy][lx + d]; + float right1 = sData[1][yy][lx + d]; + float right2 = sData[2][yy][lx + d]; + + accum0 += (left0 + right0) * w; + accum1 += (left1 + right1) * w; + accum2 += (left2 + right2) * w; + } + // center + { + float wc = cGauss[HALO]; + float c0 = sData[0][yy][lx]; + float c1 = sData[1][yy][lx]; + float c2 = sData[2][yy][lx]; + accum0 += c0 * wc; + accum1 += c1 * wc; + accum2 += c2 * wc; + } + + sScratch[yy][threadIdx.x][0] = accum0; + sScratch[yy][threadIdx.x][1] = accum1; + sScratch[yy][threadIdx.x][2] = accum2; + } + } + } + block.sync(); + + // (3) Vertical pass -> finalize dL/d(img1) + if (pix_x < W && pix_y < H) { + int ly = threadIdx.y + HALO; + int lx = threadIdx.x; + + float sum0 = 0.f, sum1 = 0.f, sum2 = 0.f; + +#pragma unroll + for (int d = 1; d <= HALO; ++d) { + float w = cGauss[HALO - d]; + float* top = sScratch[ly - d][lx]; + float* bot = sScratch[ly + d][lx]; + + sum0 += (top[0] + bot[0]) * w; + sum1 += (top[1] + bot[1]) * w; + sum2 += (top[2] + bot[2]) * w; + } + // center + { + float wc = cGauss[HALO]; + float* ctr = sScratch[ly][lx]; + sum0 += ctr[0] * wc; + sum1 += ctr[1] * wc; + sum2 
// ------------------------------------------
// PyTorch Interface (Forward)
// Returns (ssim_map, dm_dmu1, dm_dsigma1_sq, dm_dsigma12).
// If train=false, derivative Tensors are empty.
//
// img1 / img2 must be 4-D float32 CUDA tensors of identical shape on the
// same device; anything else is rejected up front because the kernel
// indexes both images with img1's dimensions.
// ------------------------------------------
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
fusedssim(
    float C1,
    float C2,
    torch::Tensor &img1,
    torch::Tensor &img2,
    bool train
) {
    TORCH_CHECK(img1.is_cuda() && img2.is_cuda(),
                "fusedssim: img1 and img2 must be CUDA tensors");
    TORCH_CHECK(img1.device() == img2.device(),
                "fusedssim: img1 and img2 must be on the same device");
    TORCH_CHECK(img1.scalar_type() == torch::kFloat32 &&
                img2.scalar_type() == torch::kFloat32,
                "fusedssim: img1 and img2 must be float32 tensors");
    TORCH_CHECK(img1.dim() == 4,
                "fusedssim: expected a 4-D tensor, got ", img1.dim(), " dimensions");
    TORCH_CHECK(img1.sizes() == img2.sizes(),
                "fusedssim: img1 and img2 must have the same shape");

    const at::cuda::OptionalCUDAGuard device_guard(device_of(img1));
    const int B  = static_cast<int>(img1.size(0));
    const int CH = static_cast<int>(img1.size(1));
    const int H  = static_cast<int>(img1.size(2));
    const int W  = static_cast<int>(img1.size(3));

    // The kernel uses dense row-major indexing, so it must see contiguous
    // inputs. Named locals keep the (possibly copied) buffers alive.
    const auto img1_c = img1.contiguous();
    const auto img2_c = img2.contiguous();

    // Outputs are allocated from the contiguous copy so that their layout
    // matches the kernel's indexing even when img1 itself is strided.
    auto ssim_map = torch::zeros_like(img1_c);

    // Optionally allocate derivative Tensors
    auto dm_dmu1       = train ? torch::zeros_like(img1_c) : torch::empty({0}, img1.options());
    auto dm_dsigma1_sq = train ? torch::zeros_like(img1_c) : torch::empty({0}, img1.options());
    auto dm_dsigma12   = train ? torch::zeros_like(img1_c) : torch::empty({0}, img1.options());

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (img1_c.numel() == 0) {
        return std::make_tuple(ssim_map, dm_dmu1, dm_dsigma1_sq, dm_dsigma12);
    }

    // Launch config: one thread per pixel, one grid z-slice per batch item.
    dim3 grid((W + BLOCK_X - 1) / BLOCK_X,
              (H + BLOCK_Y - 1) / BLOCK_Y,
              B);
    dim3 block(BLOCK_X, BLOCK_Y);

    fusedssimCUDA<<<grid, block>>>(
        H, W, CH, C1, C2,
        img1_c.data_ptr<float>(),
        img2_c.data_ptr<float>(),
        ssim_map.data_ptr<float>(),
        train ? dm_dmu1.data_ptr<float>()       : nullptr,
        train ? dm_dsigma1_sq.data_ptr<float>() : nullptr,
        train ? dm_dsigma12.data_ptr<float>()   : nullptr
    );

    return std::make_tuple(ssim_map, dm_dmu1, dm_dsigma1_sq, dm_dsigma12);
}

// ------------------------------------------
// PyTorch Interface (Backward)
// Takes the gradient wrt the SSIM map and
// the partial derivatives from forward;
// returns dL/d(img1).
//
// Every tensor argument must have img1's shape: the kernel reads all of
// them with img1's dimensions, so the empty derivative tensors produced by
// a train=false forward pass are rejected here.
// ------------------------------------------
torch::Tensor
fusedssim_backward(
    float C1,
    float C2,
    torch::Tensor &img1,
    torch::Tensor &img2,
    torch::Tensor &dL_dmap,
    torch::Tensor &dm_dmu1,
    torch::Tensor &dm_dsigma1_sq,
    torch::Tensor &dm_dsigma12
) {
    TORCH_CHECK(img1.is_cuda(), "fusedssim_backward: img1 must be a CUDA tensor");
    TORCH_CHECK(img1.dim() == 4,
                "fusedssim_backward: expected a 4-D tensor, got ", img1.dim(), " dimensions");
    TORCH_CHECK(img2.sizes() == img1.sizes() &&
                dL_dmap.sizes() == img1.sizes(),
                "fusedssim_backward: img2 and dL_dmap must have the same shape as img1");
    TORCH_CHECK(dm_dmu1.sizes() == img1.sizes() &&
                dm_dsigma1_sq.sizes() == img1.sizes() &&
                dm_dsigma12.sizes() == img1.sizes(),
                "fusedssim_backward: derivative tensors must have the same shape as img1 "
                "(was the forward pass run with train=false?)");

    const at::cuda::OptionalCUDAGuard device_guard(device_of(img1));
    const int B  = static_cast<int>(img1.size(0));
    const int CH = static_cast<int>(img1.size(1));
    const int H  = static_cast<int>(img1.size(2));
    const int W  = static_cast<int>(img1.size(3));

    // Contiguous views/copies kept alive in named locals for the launch.
    const auto img1_c          = img1.contiguous();
    const auto img2_c          = img2.contiguous();
    const auto dL_dmap_c       = dL_dmap.contiguous();
    const auto dm_dmu1_c       = dm_dmu1.contiguous();
    const auto dm_dsigma1_sq_c = dm_dsigma1_sq.contiguous();
    const auto dm_dsigma12_c   = dm_dsigma12.contiguous();

    auto dL_dimg1 = torch::zeros_like(img1_c);

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (img1_c.numel() == 0) {
        return dL_dimg1;
    }

    dim3 grid((W + BLOCK_X - 1) / BLOCK_X,
              (H + BLOCK_Y - 1) / BLOCK_Y,
              B);
    dim3 block(BLOCK_X, BLOCK_Y);

    fusedssim_backwardCUDA<<<grid, block>>>(
        H, W, CH, C1, C2,
        img1_c.data_ptr<float>(),
        img2_c.data_ptr<float>(),
        dL_dmap_c.data_ptr<float>(),
        dL_dimg1.data_ptr<float>(),
        dm_dmu1_c.data_ptr<float>(),
        dm_dsigma1_sq_c.data_ptr<float>(),
        dm_dsigma12_c.data_ptr<float>()
    );

    return dL_dimg1;
}
+import time +import os + +plt.style.use('ggplot') +gpu = torch.cuda.get_device_name() + +if __name__ == "__main__": + torch.manual_seed(0) + + B, CH = 5, 1 + dimensions = list(range(50, 1550, 50)) + iterations = 50 + + data = { + "pytorch_mssim": [], + "fused-ssim": [] + } + + pm_ssim = SSIM(data_range=1.0, channel=CH) + + for d in dimensions: + with torch.no_grad(): + img1_og = torch.rand([B, CH, d, d], device="cuda") + img2_og = torch.rand([B, CH, d, d], device="cuda") + + img1_mine_same = torch.nn.Parameter(img1_og.clone()) + img2_mine_same = img2_og.clone() + + img1_pm = torch.nn.Parameter(img1_og.clone()) + img2_pm = img2_og.clone() + + begin = time.time() + for _ in range(iterations): + pm_ssim_val = pm_ssim(img1_pm, img2_pm) + pm_ssim_val.backward() + torch.cuda.synchronize() + end = time.time() + data["pytorch_mssim"].append((end - begin) / iterations * 1000) + + begin = time.time() + for _ in range(iterations): + mine_ssim_val_same = fused_ssim(img1_mine_same, img2_mine_same) + mine_ssim_val_same.backward() + torch.cuda.synchronize() + end = time.time() + data["fused-ssim"].append((end - begin) / iterations * 1000) + + num_pixels = (B * np.array(dimensions) ** 2) / 1e6 + plt.plot(num_pixels, data["pytorch_mssim"], label="pytorch_mssim") + plt.plot(num_pixels, data["fused-ssim"], label="fused-ssim") + plt.legend() + plt.xlabel("Number of pixels (in millions).") + plt.ylabel("Time for one training iteration (ms).") + plt.title(f"Training Benchmark on {gpu}.") + plt.savefig(os.path.join("..", "images", "training_time.png"), dpi=300) + + data = { + "pytorch_mssim": [], + "fused-ssim": [] + } + + plt.clf() + for d in dimensions: + with torch.no_grad(): + img1_og = torch.rand([B, CH, d, d], device="cuda") + img2_og = torch.rand([B, CH, d, d], device="cuda") + + img1_mine_same = torch.nn.Parameter(img1_og.clone()) + img2_mine_same = img2_og.clone() + + img1_pm = torch.nn.Parameter(img1_og.clone()) + img2_pm = img2_og.clone() + + begin = time.time() + for _ in 
range(iterations): + pm_ssim_val = pm_ssim(img1_pm, img2_pm) + torch.cuda.synchronize() + end = time.time() + data["pytorch_mssim"].append((end - begin) / iterations * 1000) + + begin = time.time() + for _ in range(iterations): + mine_ssim_val_same = fused_ssim(img1_mine_same, img2_mine_same, train=False) + torch.cuda.synchronize() + end = time.time() + data["fused-ssim"].append((end - begin) / iterations * 1000) + + num_pixels = (B * np.array(dimensions) ** 2) / 1e6 + plt.plot(num_pixels, data["pytorch_mssim"], label="pytorch_mssim") + plt.plot(num_pixels, data["fused-ssim"], label="fused-ssim") + plt.legend() + plt.xlabel("Number of pixels (in millions).") + plt.ylabel("Time for one inference iteration (ms).") + plt.title(f"Inference Benchmark on {gpu}.") + plt.savefig(os.path.join("..", "images", "inference_time.png"), dpi=300) diff --git a/submodules/fused-ssim/tests/test.py b/submodules/fused-ssim/tests/test.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd85d5fd36f56ed7d80c085468adaae8cd71cda --- /dev/null +++ b/submodules/fused-ssim/tests/test.py @@ -0,0 +1,157 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.autograd import Variable +from math import exp +from fused_ssim import fused_ssim +from pytorch_msssim import SSIM +import time + +# Reference Implementation is taken from the following: +# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py +# https://github.com/graphdeco-inria/gaussian-splatting/blob/main/utils/loss_utils.py + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 
+ return window + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +if __name__ == "__main__": + torch.manual_seed(0) + B, CH, H, W = 5, 5, 1080, 1920 + pm_ssim = SSIM(data_range=1.0, channel=CH) + iterations = 100 + + for _ in range(iterations): + with torch.no_grad(): + img1_og = nn.Parameter(torch.rand([B, CH, H, W], device="cuda")) + img2_og = torch.rand([B, CH, H, W], device="cuda") + + img1_mine_same = nn.Parameter(img1_og.clone()) + img2_mine_same = img2_og.clone() + + img1_mine_valid = nn.Parameter(img1_og.clone()) + img2_mine_valid = img2_og.clone() + + img1_pm = nn.Parameter(img1_og.clone()) + img2_pm = img2_og.clone() + + og_ssim_val = ssim(img1_og, img2_og) + mine_ssim_val_same = fused_ssim(img1_mine_same, img2_mine_same) + mine_ssim_val_valid = fused_ssim(img1_mine_valid, img2_mine_valid, "valid") + pm_ssim_val = pm_ssim(img1_pm, img2_pm) + + assert torch.isclose(og_ssim_val, mine_ssim_val_same) + 
assert torch.isclose(mine_ssim_val_valid, pm_ssim_val) + + og_ssim_val.backward() + mine_ssim_val_same.backward() + mine_ssim_val_valid.backward() + pm_ssim_val.backward() + + assert torch.isclose(img1_og.grad, img1_mine_same.grad).all() + assert torch.isclose(img1_mine_valid.grad, img1_pm.grad).all() + + img1 = nn.Parameter(torch.rand([B, CH, H, W], device="cuda")) + img2 = torch.rand([B, CH, H, W], device="cuda") + + # benchmark og + begin = time.time() + for _ in range(iterations): + og_ssim_val = ssim(img1, img2) + torch.cuda.synchronize() + end = time.time() + og_time_forward = (end - begin) / iterations * 1000 + print("Reference Time (Forward):", og_time_forward, "ms") + + begin = time.time() + for _ in range(iterations): + og_ssim_val = ssim(img1, img2) + og_ssim_val.backward() + torch.cuda.synchronize() + end = time.time() + og_time_backward = (end - begin) / iterations * 1000 - og_time_forward + print("Reference Time (Backward):", og_time_backward, "ms") + + # benchmark pytorch_mssim (pm) + begin = time.time() + for _ in range(iterations): + pm_ssim_val = pm_ssim(img1, img2) + torch.cuda.synchronize() + end = time.time() + pm_time_forward = (end - begin) / iterations * 1000 + print("pytorch_mssim Time (Forward):", pm_time_forward, "ms") + + begin = time.time() + for _ in range(iterations): + pm_ssim_val = pm_ssim(img1, img2) + pm_ssim_val.backward() + torch.cuda.synchronize() + end = time.time() + pm_time_backward = (end - begin) / iterations * 1000 - pm_time_forward + print("pytorch_mssim Time (Backward):", pm_time_backward, "ms") + + + # benchmark mine + begin = time.time() + for _ in range(iterations): + mine_ssim_val = fused_ssim(img1, img2) + torch.cuda.synchronize() + end = time.time() + mine_time_forward = (end - begin) / iterations * 1000 + print("fused-ssim Time (Forward):", mine_time_forward, "ms") + + begin = time.time() + for _ in range(iterations): + mine_ssim_val = fused_ssim(img1, img2) + mine_ssim_val.backward() + torch.cuda.synchronize() + 
"""Fit a randomly initialised image to a reference image by maximising fused SSIM.

The script optimises the pixels of ``pred_image`` directly with Adam until the
SSIM against the reference exceeds 0.9999, then writes the result next to the
reference image. It needs a CUDA device and the ``fused_ssim`` extension.
"""
import os

import numpy as np
import torch
from PIL import Image

from fused_ssim import fused_ssim

# The two unsqueeze calls add batch and channel axes in front of the decoded
# array. NOTE(review): this presumes albert.jpg decodes to a 2-D (H, W)
# grayscale array - confirm against the asset.
gt_image = torch.tensor(np.array(Image.open(os.path.join("..", "images", "albert.jpg"))), dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0) / 255.0
pred_image = torch.nn.Parameter(torch.rand_like(gt_image))

with torch.no_grad():
    ssim_value = fused_ssim(pred_image, gt_image, train=False)
    print("Starting with SSIM value:", ssim_value)


optimizer = torch.optim.Adam([pred_image])

while ssim_value < 0.9999:
    optimizer.zero_grad()
    loss = 1.0 - fused_ssim(pred_image, gt_image)
    loss.backward()
    optimizer.step()

    # Re-evaluate without building a graph; this value drives the loop exit.
    with torch.no_grad():
        ssim_value = fused_ssim(pred_image, gt_image, train=False)
        print("SSIM value:", ssim_value)

# The optimisation is unconstrained, so pixels may drift slightly outside
# [0, 1]. Clamp before the uint8 cast: casting out-of-range floats wraps
# around and would show up as speckle in the saved image.
final_image = (pred_image.detach().clamp(0.0, 1.0) * 255.0).squeeze(0).squeeze(0)
to_save = final_image.cpu().numpy().astype(np.uint8)
Image.fromarray(to_save).save(os.path.join("..", "images", "predicted.jpg"))
"""
Fused KNN Gather + Attention kernel.

Replaces the two-step gather-then-attend pattern in KNNAttention with a single
CUDA kernel that never materializes the [N, K, C] intermediate tensors.

Usage:
    from .fused_knn_attn import fused_knn_attention  # autograd Function

    # Same result as the unfused version but faster and less memory
    out = fused_knn_attention(q, k, v, knn_idx, scale)
"""

import torch
from torch.autograd import Function

# Try to import the compiled CUDA extension; fall back to pure-PyTorch
try:
    import fused_knn_attn_cuda as _C
    FUSED_KNN_ATTN_CUDA_AVAILABLE = True
except ImportError:
    FUSED_KNN_ATTN_CUDA_AVAILABLE = False


class FusedKNNAttentionFunction(Function):
    """Autograd function for fused KNN gather + scaled dot-product attention.

    Forward:
        q, k, v: [N, C]  (query, key, value features for all points)
        idx:     [N, K]  (pre-computed KNN indices, int32)
        scale:   float   (attention scale factor, typically head_dim ** -0.5)

    Returns:
        out: [N, C]

    Inputs are converted to contiguous float32 (int32 for ``idx``) before the
    kernel runs, so the output and the gradients are always float32.
    """

    @staticmethod
    def forward(ctx, q, k, v, idx, scale):
        # Ensure contiguous float32 for the CUDA kernel
        q = q.contiguous().float()
        k = k.contiguous().float()
        v = v.contiguous().float()
        idx = idx.contiguous().int()

        N, C = q.shape
        num_k = idx.shape[1]

        out = torch.empty((N, C), dtype=torch.float32, device=q.device)
        attn_weights = torch.empty((N, num_k), dtype=torch.float32, device=q.device)

        # The kernel uses one block per query point; a zero-sized grid is not a
        # valid launch, so skip the call when there is nothing to compute.
        if N > 0:
            _C.fused_knn_attn_forward_cuda(
                q, k, v, idx, out, attn_weights,
                N, C, num_k, float(scale)
            )

        ctx.save_for_backward(q, k, v, idx, attn_weights)
        ctx.scale = scale
        ctx.N = N
        ctx.C = C
        ctx.num_k = num_k

        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, idx, attn_weights = ctx.saved_tensors
        scale = ctx.scale
        N, C, num_k = ctx.N, ctx.C, ctx.num_k

        grad_out = grad_out.contiguous().float()

        grad_q = torch.zeros((N, C), dtype=torch.float32, device=q.device)
        # grad_k / grad_v are scattered into by neighbour index, so they must
        # have the shape of k / v themselves (as in the PyTorch fallback), not
        # the shape of q. They must start at zero because the kernel
        # accumulates into them with atomicAdd.
        grad_k = torch.zeros_like(k)
        grad_v = torch.zeros_like(v)

        if N > 0:
            _C.fused_knn_attn_backward_cuda(
                grad_out, q, k, v, idx, attn_weights,
                grad_q, grad_k, grad_v,
                N, C, num_k, float(scale)
            )

        # idx and scale don't need gradients
        return grad_q, grad_k, grad_v, None, None


class FusedKNNAttentionFunctionPyTorch(Function):
    """Pure-PyTorch fallback (same semantics, no CUDA extension required).

    Avoids materializing full [N, K, C] by iterating over K neighbors, so at
    most one [N, C] gathered tensor is alive at a time.
    Slower than the CUDA kernel.
    """

    @staticmethod
    def forward(ctx, q, k, v, idx, scale):
        N, C = q.shape
        num_k = idx.shape[1]
        # Advanced indexing wants int64; convert once instead of per neighbour.
        idx_long = idx.long()

        # Compute scores by iterating over neighbors (avoids [N, K, C] tensor)
        scores = torch.empty(N, num_k, device=q.device, dtype=q.dtype)
        for kk in range(num_k):
            k_neighbor = k[idx_long[:, kk]]  # [N, C]
            scores[:, kk] = (q * k_neighbor).sum(dim=-1) * scale

        attn_weights = torch.softmax(scores, dim=-1)  # [N, K]

        # Compute output as the attention-weighted sum of neighbour values
        out = torch.zeros(N, C, device=q.device, dtype=q.dtype)
        for kk in range(num_k):
            v_neighbor = v[idx_long[:, kk]]  # [N, C]
            out += attn_weights[:, kk:kk+1] * v_neighbor

        ctx.save_for_backward(q, k, v, idx, attn_weights)
        ctx.scale = scale
        return out

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v, idx, attn_weights = ctx.saved_tensors
        scale = ctx.scale
        N, C = q.shape
        num_k = idx.shape[1]
        idx_long = idx.long()

        # grad_attn[k] = dot(grad_out, V[idx[:, k]])
        grad_attn = torch.empty(N, num_k, device=q.device, dtype=q.dtype)
        for kk in range(num_k):
            v_neighbor = v[idx_long[:, kk]]
            grad_attn[:, kk] = (grad_out * v_neighbor).sum(dim=-1)

        # Softmax backward: grad_scores = attn * (grad_attn - sum(attn * grad_attn))
        ds = (attn_weights * grad_attn).sum(dim=-1, keepdim=True)  # [N, 1]
        grad_scores = attn_weights * (grad_attn - ds)  # [N, K]

        grad_q = torch.zeros(N, C, device=q.device, dtype=q.dtype)
        grad_k = torch.zeros_like(k)
        grad_v = torch.zeros_like(v)
        for kk in range(num_k):
            neighbor_idx = idx_long[:, kk]
            scaled_gs = grad_scores[:, kk:kk+1] * scale  # [N, 1]
            # grad_Q: each query collects from the keys it attended to
            grad_q += scaled_gs * k[neighbor_idx]
            # grad_K / grad_V: scatter-add, a point can be a neighbour of many queries
            grad_k.index_add_(0, neighbor_idx, scaled_gs * q)
            grad_v.index_add_(0, neighbor_idx, attn_weights[:, kk:kk+1] * grad_out)

        return grad_q, grad_k, grad_v, None, None


def fused_knn_attention(q, k, v, idx, scale):
    """Fused KNN gather + attention. Uses CUDA kernel if available, else PyTorch fallback.

    Args:
        q: [N, C] query features
        k: [N, C] key features
        v: [N, C] value features
        idx: [N, K] pre-computed KNN neighbor indices (int32)
        scale: attention scale factor

    Returns:
        out: [N, C] attention output (float32 on the CUDA path, ``q.dtype`` on
        the fallback path)
    """
    if FUSED_KNN_ATTN_CUDA_AVAILABLE and q.is_cuda:
        return FusedKNNAttentionFunction.apply(q, k, v, idx, scale)
    else:
        return FusedKNNAttentionFunctionPyTorch.apply(q, k, v, idx, scale)
fused_knn_attn_backward_cuda( + at::Tensor grad_out, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor idx, at::Tensor attn_weights, + at::Tensor grad_q, at::Tensor grad_k, at::Tensor grad_v, + int N, int C, int num_k, float scale +) { + fused_knn_attn_backward_cuda_launcher( + grad_out.data_ptr(), q.data_ptr(), + k.data_ptr(), v.data_ptr(), + idx.data_ptr(), attn_weights.data_ptr(), + grad_q.data_ptr(), grad_k.data_ptr(), + grad_v.data_ptr(), + N, C, num_k, scale + ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_knn_attn_forward_cuda", &fused_knn_attn_forward_cuda); + m.def("fused_knn_attn_backward_cuda", &fused_knn_attn_backward_cuda); +} diff --git a/submodules/fused_knn_attn/csrc/fused_knn_attn_kernel.cu b/submodules/fused_knn_attn/csrc/fused_knn_attn_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..d7dcb7c9ade0417752df6cfce0ff2e4be54705ac --- /dev/null +++ b/submodules/fused_knn_attn/csrc/fused_knn_attn_kernel.cu @@ -0,0 +1,293 @@ +/* + * Fused KNN Gather + Scaled Dot-Product Attention CUDA Kernel + * + * Fuses the two separate gather operations (for K and V) and the attention + * computation into a single kernel, avoiding materialization of [N, K, C] + * intermediate tensors. + * + * Forward: + * Given Q [N, C], K [N, C], V [N, C], idx [N, num_k]: + * For each query point n: + * 1. Gather K[idx[n, :]] and compute scores = Q[n] . K[neighbor] * scale + * 2. Softmax over scores + * 3. 
// Warp-level sum reduction.
//
// Each step adds the value held `offset` lanes higher, halving the offset from
// 16 down to 1. The full mask 0xffffffff names all 32 lanes as participants,
// so every lane of the warp must execute this call. The complete sum is only
// meaningful in lane 0; other lanes end up with partial sums.
__device__ __forceinline__ float warp_reduce_sum(float val) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

// Block-level sum reduction using shared memory.
//
//   val        - this thread's partial value
//   shared     - scratch buffer with room for one float per warp, i.e.
//                (block_size + 31) / 32 entries
//   tid        - thread index within the block
//   block_size - number of threads in the block
//
// Returns the block-wide sum in thread 0 only.
//
// Contains a __syncthreads(), so every thread of the block must call it.
// NOTE(review): there is no barrier after the final read of `shared`, so a
// caller that invokes this repeatedly needs its own __syncthreads() between
// calls before `shared` is overwritten - the kernels in this file do that.
__device__ float block_reduce_sum(float val, float* shared, int tid, int block_size) {
    int lane = tid & 31;
    int warp_id = tid >> 5;

    // Stage 1: reduce within each warp; lane 0 of the warp holds its sum.
    val = warp_reduce_sum(val);

    if (lane == 0) shared[warp_id] = val;
    __syncthreads();

    // Stage 2: the first num_warps threads pick up one per-warp sum each
    // (everyone else contributes 0) and warp 0 reduces those.
    int num_warps = (block_size + 31) / 32;
    val = (tid < num_warps) ? shared[tid] : 0.0f;
    if (warp_id == 0) {
        val = warp_reduce_sum(val);
    }
    return val;  // result valid in thread 0
}
// Forward kernel: one block per query point, threads cooperate across C.
//
// For query n:
//   1. scores[kk] = dot(Q[n], K[idx[n, kk]]) * scale
//   2. attn = softmax(scores), also written to attn_weights for backward
//   3. out[n]  = sum_kk attn[kk] * V[idx[n, kk]]
//
// Dynamic shared memory layout (floats):
//   scores[num_k] | attn[num_k] | reduce_buf[ceil(blockDim.x / 32)]
//
// Neighbour indices are used as row offsets into k and v without range checks.
__global__ void fused_knn_attn_forward_kernel(
    const float* __restrict__ q,        // [N, C]
    const float* __restrict__ k,        // [N, C]
    const float* __restrict__ v,        // [N, C]
    const int* __restrict__ idx,        // [N, num_k]
    float* __restrict__ out,            // [N, C]
    float* __restrict__ attn_weights,   // [N, num_k] saved for backward
    const int N,
    const int C,
    const int num_k,
    const float scale
) {
    const int n = blockIdx.x;
    // Uniform per block, so returning before the barriers below is safe.
    if (n >= N) return;

    const int tid = threadIdx.x;
    const int block_size = blockDim.x;

    extern __shared__ float smem[];
    float* scores = smem;                       // [num_k]
    float* attn = scores + num_k;               // [num_k]
    float* reduce_buf = attn + num_k;           // [ceil(block_size/32)]

    // All row offsets are computed in 64 bits: n * C and n * num_k can exceed
    // the int range for large point clouds.
    const long long row = n;
    const float* q_n = q + row * C;
    const int* idx_n = idx + row * num_k;
    float* attn_w_n = attn_weights + row * num_k;
    float* out_n = out + row * C;

    // ------- Step 1: Compute attention scores -------
    for (int kk = 0; kk < num_k; kk++) {
        const int neighbor = idx_n[kk];
        const float* k_neighbor = k + (long long)neighbor * C;

        // Dot product Q[n] . K[neighbor] over C, distributed across threads
        float partial = 0.0f;
        for (int c = tid; c < C; c += block_size) {
            partial += q_n[c] * k_neighbor[c];
        }

        const float dot = block_reduce_sum(partial, reduce_buf, tid, block_size);
        if (tid == 0) {
            scores[kk] = dot * scale;
        }
        // Protects reduce_buf before the next iteration reuses it and
        // publishes scores[kk] to the block.
        __syncthreads();
    }

    // ------- Step 2: Softmax (serial on thread 0; num_k is small) -------
    if (tid == 0) {
        // Subtract the maximum for numerical stability.
        float max_s = -FLT_MAX;
        for (int kk = 0; kk < num_k; kk++) {
            max_s = fmaxf(max_s, scores[kk]);
        }
        float sum_exp = 0.0f;
        for (int kk = 0; kk < num_k; kk++) {
            attn[kk] = expf(scores[kk] - max_s);
            sum_exp += attn[kk];
        }
        const float inv_sum = 1.0f / sum_exp;
        for (int kk = 0; kk < num_k; kk++) {
            attn[kk] *= inv_sum;
            attn_w_n[kk] = attn[kk];
        }
    }
    __syncthreads();

    // ------- Step 3: Weighted sum of V neighbors -------
    for (int c = tid; c < C; c += block_size) {
        float val = 0.0f;
        for (int kk = 0; kk < num_k; kk++) {
            const int neighbor = idx_n[kk];
            val += attn[kk] * v[(long long)neighbor * C + c];
        }
        out_n[c] = val;
    }
}
// Backward kernel: one block per query point. Computes grad_Q directly and
// scatters grad_K / grad_V with atomicAdd (a point can be the neighbour of
// many queries, so several blocks may update the same row).
//
// Equations:
//   grad_attn[k]        = sum_c grad_out[n,c] * V[idx[n,k], c]
//   ds                  = sum_k attn[k] * grad_attn[k]
//   grad_scores[k]      = attn[k] * (grad_attn[k] - ds)
//   grad_Q[n, c]        = sum_k grad_scores[k] * K[idx[n,k], c] * scale
//   grad_K[idx[n,k], c] += grad_scores[k] * Q[n, c] * scale   (atomicAdd)
//   grad_V[idx[n,k], c] += attn[k] * grad_out[n, c]           (atomicAdd)
//
// grad_k and grad_v are only ever added to, so the caller must zero them
// before the launch. grad_q rows are overwritten.
//
// Dynamic shared memory layout (floats):
//   grad_attn[num_k] | attn[num_k] | grad_scores[num_k] | reduce_buf[ceil(blockDim.x / 32)]
__global__ void fused_knn_attn_backward_kernel(
    const float* __restrict__ grad_out,     // [N, C]
    const float* __restrict__ q,            // [N, C]
    const float* __restrict__ k,            // [N, C]
    const float* __restrict__ v,            // [N, C]
    const int* __restrict__ idx,            // [N, num_k]
    const float* __restrict__ attn_weights, // [N, num_k]
    float* __restrict__ grad_q,             // [N, C]
    float* __restrict__ grad_k,             // [N, C]
    float* __restrict__ grad_v,             // [N, C]
    const int N,
    const int C,
    const int num_k,
    const float scale
) {
    const int n = blockIdx.x;
    // Uniform per block, so returning before the barriers below is safe.
    if (n >= N) return;

    const int tid = threadIdx.x;
    const int block_size = blockDim.x;

    extern __shared__ float smem[];
    float* s_grad_attn = smem;                      // [num_k]
    float* s_attn = s_grad_attn + num_k;            // [num_k]
    float* s_grad_scores = s_attn + num_k;          // [num_k]
    float* reduce_buf = s_grad_scores + num_k;      // [ceil(block_size/32)]

    const float* grad_out_n = grad_out + (long long)n * C;
    const float* q_n = q + (long long)n * C;
    const int* idx_n = idx + (long long)n * num_k;
    const float* attn_n = attn_weights + (long long)n * num_k;
    float* grad_q_n = grad_q + (long long)n * C;

    // Load attn weights into shared memory. Strided so that every entry is
    // loaded even when num_k exceeds the block size (the launcher can pick a
    // block as small as 32 threads for small C).
    for (int kk = tid; kk < num_k; kk += block_size) {
        s_attn[kk] = attn_n[kk];
    }
    __syncthreads();

    // ------- Step 1: Compute grad_attn[k] = dot(grad_out[n], V[idx[n,k]]) -------
    for (int kk = 0; kk < num_k; kk++) {
        const int neighbor = idx_n[kk];
        const float* v_neighbor = v + (long long)neighbor * C;

        float partial = 0.0f;
        for (int c = tid; c < C; c += block_size) {
            partial += grad_out_n[c] * v_neighbor[c];
        }
        const float dot = block_reduce_sum(partial, reduce_buf, tid, block_size);
        if (tid == 0) {
            s_grad_attn[kk] = dot;
        }
        // Protects reduce_buf before the next iteration reuses it.
        __syncthreads();
    }

    // ------- Step 2: Softmax backward (serial on thread 0) -------
    //   ds = sum_k attn[k] * grad_attn[k]
    //   grad_scores[k] = attn[k] * (grad_attn[k] - ds)
    if (tid == 0) {
        float ds = 0.0f;
        for (int kk = 0; kk < num_k; kk++) {
            ds += s_attn[kk] * s_grad_attn[kk];
        }
        for (int kk = 0; kk < num_k; kk++) {
            s_grad_scores[kk] = s_attn[kk] * (s_grad_attn[kk] - ds);
        }
    }
    __syncthreads();

    // ------- Step 3: grad_Q[n, c] = sum_k grad_scores[k] * K[idx[n,k], c] * scale -------
    for (int c = tid; c < C; c += block_size) {
        float g = 0.0f;
        for (int kk = 0; kk < num_k; kk++) {
            const int neighbor = idx_n[kk];
            g += s_grad_scores[kk] * k[(long long)neighbor * C + c];
        }
        grad_q_n[c] = g * scale;
    }

    // ------- Step 4: Scatter grad_K and grad_V using atomicAdd -------
    for (int kk = 0; kk < num_k; kk++) {
        const long long neighbor_row = (long long)idx_n[kk] * C;
        const float gs = s_grad_scores[kk] * scale;
        const float aw = s_attn[kk];

        for (int c = tid; c < C; c += block_size) {
            // grad_K[neighbor, c] += grad_scores[k] * Q[n, c] * scale
            atomicAdd(grad_k + neighbor_row + c, gs * q_n[c]);
            // grad_V[neighbor, c] += attn[k] * grad_out[n, c]
            atomicAdd(grad_v + neighbor_row + c, aw * grad_out_n[c]);
        }
    }
}
// Host-side launcher for fused_knn_attn_backward_kernel.
//
// All pointers are device pointers to dense row-major arrays:
//   grad_out, q, k, v, grad_q, grad_k, grad_v : [N, C] float
//   idx                                        : [N, num_k] int
//   attn_weights                               : [N, num_k] float
// grad_k and grad_v must be zero-filled by the caller; the kernel accumulates
// into them. The launch is asynchronous; no error checking is done here.
void fused_knn_attn_backward_cuda_launcher(
    const float* grad_out, const float* q, const float* k, const float* v,
    const int* idx, const float* attn_weights,
    float* grad_q, float* grad_k, float* grad_v,
    int N, int C, int num_k, float scale
) {
    // The grid has one block per query point. A grid of zero blocks is an
    // invalid launch configuration and would leave a CUDA error behind for
    // the next error check to trip over, so there is nothing to do for N <= 0.
    if (N <= 0) return;

    // Use the smallest power-of-two block that covers C (at least one warp,
    // at most THREADS_PER_BLOCK) so small feature sizes don't idle threads.
    int block_size = THREADS_PER_BLOCK;
    if (C < block_size) {
        block_size = 1;
        while (block_size < C) block_size <<= 1;
        if (block_size < 32) block_size = 32;
    }

    const int num_warps = (block_size + 31) / 32;
    // smem: grad_attn[num_k] + attn[num_k] + grad_scores[num_k] + reduce_buf[num_warps]
    const size_t smem_size = (3 * static_cast<size_t>(num_k) + num_warps) * sizeof(float);

    fused_knn_attn_backward_kernel<<<N, block_size, smem_size>>>(
        grad_out, q, k, v, idx, attn_weights,
        grad_q, grad_k, grad_v, N, C, num_k, scale
    );
}
"""Build script for the ``fused_knn_attn_cuda`` extension (pybind bindings + CUDA kernel)."""
import os
import sysconfig

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Drop -Wstrict-prototypes from the interpreter's default OPT flags: it is a
# C-only warning flag and only produces noise when the sources are C++.
# `sysconfig` is used instead of `distutils.sysconfig`, which was removed from
# the standard library in Python 3.12. OPT can be undefined on some platforms,
# in which case there is nothing to filter.
opt = sysconfig.get_config_var("OPT")
if opt is not None:
    os.environ["OPT"] = " ".join(
        flag for flag in opt.split() if flag != "-Wstrict-prototypes"
    )

setup(
    name="fused_knn_attn_cuda",
    version="1.0",
    install_requires=["torch"],
    ext_modules=[
        CUDAExtension(
            name="fused_knn_attn_cuda",
            sources=[
                "csrc/fused_knn_attn.cpp",
                "csrc/fused_knn_attn_kernel.cu",
            ],
            extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
class Aggregation(Function):
    """Autograd wrapper around the pointops aggregation CUDA kernels."""

    @staticmethod
    def forward(ctx, input, position, weight, idx):
        """
        input: input: (n, c), position: (n, nsample, c), weight : (n, nsample, c'), idx: (n, nsample)
        output: (n, c)
        """
        assert (
            input.is_contiguous()
            and position.is_contiguous()
            and weight.is_contiguous()
        )
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        # Allocate on the inputs' device rather than the process-wide current
        # CUDA device (which is what the legacy torch.cuda.FloatTensor
        # constructor uses), so the op also works on a non-default GPU.
        output = torch.zeros((n, c), dtype=torch.float32, device=input.device)
        aggregation_forward_cuda(
            n, nsample, c, w_c, input, position, weight, idx, output
        )
        ctx.save_for_backward(input, position, weight, idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_out: (n, c)
        output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight : (n, nsample, c')
        """
        input, position, weight, idx = ctx.saved_tensors
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        # Autograd may hand over a strided gradient.
        # NOTE(review): the extension is assumed to read dense memory, as the
        # contiguity asserts in forward suggest - confirm against the kernel.
        grad_output = grad_output.contiguous()
        device = grad_output.device
        grad_input = torch.zeros((n, c), dtype=torch.float32, device=device)
        grad_position = torch.zeros((n, nsample, c), dtype=torch.float32, device=device)
        grad_weight = torch.zeros((n, nsample, w_c), dtype=torch.float32, device=device)
        aggregation_backward_cuda(
            n,
            nsample,
            c,
            w_c,
            input,
            position,
            weight,
            idx,
            grad_output,
            grad_input,
            grad_position,
            grad_weight,
        )
        # idx is an index tensor and receives no gradient.
        return grad_input, grad_position, grad_weight, None
class AttentionFusionStep(Function):
    """Autograd wrapper around the pointops attention-fusion CUDA kernels."""

    @staticmethod
    def forward(ctx, weight, value, index_target, index_refer):
        """
        input - weight: (m, g), value: (n, g, c)
                index_target: (m), index_value: (m)
        output - output: (n, g, c)
        """

        assert (
            weight.is_contiguous()
            and value.is_contiguous()
            and index_target.is_contiguous()
            and index_refer.is_contiguous()
        )

        assert index_target.shape[0] == index_refer.shape[0]

        n, g, c = value.shape
        # Both index tensors have length m (asserted above); read it from
        # index_target, the same tensor backward uses.
        m = index_target.shape[0]
        # Allocate on the inputs' device rather than the process-wide current
        # CUDA device used by the legacy torch.cuda.FloatTensor constructor.
        output = torch.zeros((n, g, c), dtype=torch.float32, device=value.device)
        attention_fusion_step_forward_cuda(
            m, g, c, weight, value, index_target.int(), index_refer.int(), output
        )
        ctx.save_for_backward(weight, value, index_target, index_refer)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_output: (n, g, c)
        output: grad_weight: (m, g), grad_value: (n, g, c), none, none
        """
        weight, value, index_target, index_refer = ctx.saved_tensors
        n, g, c = value.shape
        m = index_target.shape[0]
        # Autograd may hand over a strided gradient.
        # NOTE(review): the extension is assumed to read dense memory, as the
        # contiguity asserts in forward suggest - confirm against the kernel.
        grad_output = grad_output.contiguous()
        device = grad_output.device
        grad_weight = torch.zeros((m, g), dtype=torch.float32, device=device)
        grad_value = torch.zeros((n, g, c), dtype=torch.float32, device=device)
        attention_fusion_step_backward_cuda(
            m,
            g,
            c,
            weight,
            grad_weight,
            value,
            grad_value,
            index_target.int(),
            index_refer.int(),
            grad_output,
        )
        # The two index tensors receive no gradient.
        return grad_weight, grad_value, None, None
import grouping_forward_cuda, grouping_backward_cuda + + +class Grouping(Function): + @staticmethod + def forward(ctx, input, idx): + """ + input: input: (n, c), idx : (m, nsample) + output: (m, nsample, c) + """ + assert input.is_contiguous() and idx.is_contiguous() + m, nsample, n, c = idx.shape[0], idx.shape[1], input.shape[0], input.shape[1] + output = torch.cuda.FloatTensor(m, nsample, c) + grouping_forward_cuda(m, nsample, c, input, idx, output) + ctx.n = n + ctx.save_for_backward(idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_out: (m, c, nsample) + output: (n, c), None + """ + n = ctx.n + (idx,) = ctx.saved_tensors + m, nsample, c = grad_output.shape + grad_input = torch.cuda.FloatTensor(n, c).zero_() + grouping_backward_cuda(m, nsample, c, grad_output, idx, grad_input) + return grad_input, None + + +def grouping(idx, feat, xyz, new_xyz=None, with_xyz=False): + if new_xyz is None: + new_xyz = xyz + assert xyz.is_contiguous() and feat.is_contiguous() + m, nsample, c = idx.shape[0], idx.shape[1], feat.shape[1] + if idx.min() < 0: + feat = torch.cat([feat, torch.zeros([1, c]).to(feat.device)], dim=0) + grouped_feat = torch.nn.functional.embedding(idx, feat) # [m,num_sample,c] + + if with_xyz: + assert new_xyz.is_contiguous() + if idx.min() < 0: + xyz = torch.cat([xyz, torch.zeros([1, 3]).to(xyz.device)], dim=0) + mask = torch.sign(idx + 1) + grouped_xyz = xyz[idx.view(-1).long(), :].view( + m, nsample, 3 + ) - new_xyz.unsqueeze( + 1 + ) # (m, num_sample, 3) + grouped_xyz = torch.einsum( + "n s c, n s -> n s c", grouped_xyz, mask + ) # (m, num_sample, 3) + return torch.cat((grouped_xyz, grouped_feat), -1) + else: + return grouped_feat + + +grouping2 = Grouping.apply diff --git a/submodules/pointops/functions/interpolation.py b/submodules/pointops/functions/interpolation.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5c861f272f89421fa097505d9882b2c473a060 --- /dev/null +++ 
b/submodules/pointops/functions/interpolation.py @@ -0,0 +1,59 @@ +import torch +from torch.autograd import Function + +from pointops._C import interpolation_forward_cuda, interpolation_backward_cuda +from .query import knn_query + + +def interpolation(xyz, new_xyz, feat, offset, new_offset, k=3): + """ + input: coords: (m, 3), new_xyz: (n, 3), color: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() + idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset) # (n, 3), (n, 3) + dist_recip = 1.0 / (dist + 1e-8) # (n, 3) + norm = torch.sum(dist_recip, dim=1, keepdim=True) + weight = dist_recip / norm # (n, 3) + + new_feat = torch.cuda.FloatTensor(new_xyz.shape[0], feat.shape[1]).zero_() + for i in range(k): + new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1) + return new_feat + + +class Interpolation(Function): + @staticmethod + def forward(ctx, xyz, new_xyz, input, offset, new_offset, k=3): + """ + input: coords: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and input.is_contiguous() + idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset) # (n, k), (n, k) + dist_recip = 1.0 / (dist + 1e-8) # (n, k) + norm = torch.sum(dist_recip, dim=1, keepdim=True) + weight = dist_recip / norm # (n, k) + + n, c, m = new_xyz.shape[0], input.shape[1], input.shape[0] + output = torch.cuda.FloatTensor(n, c).zero_() + interpolation_forward_cuda(n, c, k, input, idx, weight, output) + ctx.m, ctx.k = m, k + ctx.save_for_backward(idx, weight) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: coords: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b) + output: (n, c) + """ + m, k = ctx.m, ctx.k + idx, weight = ctx.saved_tensors + n, c = grad_output.shape + grad_input = torch.cuda.FloatTensor(m, c).zero_() + 
interpolation_backward_cuda(n, c, k, grad_output, idx, weight, grad_input) + return None, None, grad_input, None, None, None + + +interpolation2 = Interpolation.apply diff --git a/submodules/pointops/functions/query.py b/submodules/pointops/functions/query.py new file mode 100644 index 0000000000000000000000000000000000000000..c1294b6125e00ae1d1dec21ed52a803c164c4810 --- /dev/null +++ b/submodules/pointops/functions/query.py @@ -0,0 +1,113 @@ +import torch +from torch.autograd import Function + +from pointops._C import knn_query_cuda, random_ball_query_cuda, ball_query_cuda + + +class KNNQuery(Function): + @staticmethod + def forward(ctx, nsample, xyz, offset, new_xyz=None, new_offset=None): + """ + input: coords: (n, 3), new_xyz: (m, 3), offset: (b), new_offset: (b) + output: idx: (m, nsample) -1 is placeholder, dist2: (m, nsample) + """ + if new_xyz is None or new_offset is None: + new_xyz = xyz + new_offset = offset + assert xyz.is_contiguous() and new_xyz.is_contiguous() + m = new_xyz.shape[0] + idx = torch.cuda.IntTensor(m, nsample).zero_() + dist2 = torch.cuda.FloatTensor(m, nsample).zero_() + knn_query_cuda( + m, nsample, xyz, new_xyz, offset.int(), new_offset.int(), idx, dist2 + ) + return idx, torch.sqrt(dist2) + + +class RandomBallQuery(Function): + """Random Ball Query. + + Find nearby points in spherical space. 
+ """ + + @staticmethod + def forward( + ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None + ): + """ + input: coords: (n, 3), new_xyz: (m, 3), offset: (b), new_offset: (b) + output: idx: (m, nsample), dist2: (m, nsample) + """ + if new_xyz is None or new_offset is None: + new_xyz = xyz + new_offset = offset + assert xyz.is_contiguous() and new_xyz.is_contiguous() + assert min_radius < max_radius + + m = new_xyz.shape[0] + order = [] + for k in range(offset.shape[0]): + s_k, e_k = (0, offset[0]) if k == 0 else (offset[k - 1], offset[k]) + order.append( + torch.randperm(e_k - s_k, dtype=torch.int32, device=offset.device) + s_k + ) + order = torch.cat(order, dim=0) + idx = torch.cuda.IntTensor(m, nsample).zero_() + dist2 = torch.cuda.FloatTensor(m, nsample).zero_() + random_ball_query_cuda( + m, + nsample, + min_radius, + max_radius, + order, + xyz, + new_xyz, + offset.int(), + new_offset.int(), + idx, + dist2, + ) + return idx, torch.sqrt(dist2) + + +class BallQuery(Function): + """Ball Query. + + Find nearby points in spherical space. 
+ """ + + @staticmethod + def forward( + ctx, nsample, max_radius, min_radius, xyz, offset, new_xyz=None, new_offset=None + ): + """ + input: coords: (n, 3), new_xyz: (m, 3), offset: (b), new_offset: (b) + output: idx: (m, nsample), dist2: (m, nsample) + """ + if new_xyz is None or new_offset is None: + new_xyz = xyz + new_offset = offset + assert xyz.is_contiguous() and new_xyz.is_contiguous() + assert min_radius < max_radius + + m = new_xyz.shape[0] + idx = torch.cuda.IntTensor(m, nsample).zero_() + dist2 = torch.cuda.FloatTensor(m, nsample).zero_() + ball_query_cuda( + m, + nsample, + min_radius, + max_radius, + xyz, + new_xyz, + offset.int(), + new_offset.int(), + idx, + dist2, + ) + return idx, torch.sqrt(dist2) + + +knn_query = KNNQuery.apply +ball_query = BallQuery.apply +random_ball_query = RandomBallQuery.apply diff --git a/submodules/pointops/functions/sampling.py b/submodules/pointops/functions/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..9f233d4afe02e43a6a390ca465f7108a01b98541 --- /dev/null +++ b/submodules/pointops/functions/sampling.py @@ -0,0 +1,27 @@ +import torch +from torch.autograd import Function + +from pointops._C import farthest_point_sampling_cuda + + +class FarthestPointSampling(Function): + @staticmethod + def forward(ctx, xyz, offset, new_offset): + """ + input: coords: (n, 3), offset: (b), new_offset: (b) + output: idx: (m) + """ + assert xyz.is_contiguous() + n, b, n_max = xyz.shape[0], offset.shape[0], offset[0] + for i in range(1, b): + n_max = max(offset[i] - offset[i - 1], n_max) + idx = torch.cuda.IntTensor(new_offset[b - 1].item()).zero_() + tmp = torch.cuda.FloatTensor(n).fill_(1e10) + farthest_point_sampling_cuda( + b, n_max, xyz, offset.int(), new_offset.int(), tmp, idx + ) + del tmp + return idx + + +farthest_point_sampling = FarthestPointSampling.apply diff --git a/submodules/pointops/functions/subtraction.py b/submodules/pointops/functions/subtraction.py new file mode 100644 index 
0000000000000000000000000000000000000000..bc683ce3d75901777e57886adc077d570230e027 --- /dev/null +++ b/submodules/pointops/functions/subtraction.py @@ -0,0 +1,38 @@ +import torch +from torch.autograd import Function + +from pointops._C import subtraction_forward_cuda, subtraction_backward_cuda + + +class Subtraction(Function): + @staticmethod + def forward(ctx, input1, input2, idx): + """ + input: input1: (n, c), input2: (n, c), idx: (n, nsample) + output: (n, nsample, c) + """ + assert input1.is_contiguous() and input2.is_contiguous() + n, c = input1.shape + nsample = idx.shape[-1] + output = torch.cuda.FloatTensor(n, nsample, c).zero_() + subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output) + ctx.save_for_backward(idx) + return output + + @staticmethod + def backward(ctx, grad_output): + """ + input: grad_out: (n, nsample, c) + output: grad_input1: (n, c), grad_input2: (n, c) + """ + (idx,) = ctx.saved_tensors + n, nsample, c = grad_output.shape + grad_input1 = torch.cuda.FloatTensor(n, c).zero_() + grad_input2 = torch.cuda.FloatTensor(n, c).zero_() + subtraction_backward_cuda( + n, nsample, c, idx, grad_output, grad_input1, grad_input2 + ) + return grad_input1, grad_input2, None + + +subtraction = Subtraction.apply diff --git a/submodules/pointops/functions/utils.py b/submodules/pointops/functions/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..15e3e328bb012bb684787466f3ec2e97d1317b2b --- /dev/null +++ b/submodules/pointops/functions/utils.py @@ -0,0 +1,121 @@ +import torch +from pointops import knn_query, ball_query, grouping + + +def knn_query_and_group( + feat, + xyz, + offset=None, + new_xyz=None, + new_offset=None, + idx=None, + nsample=None, + with_xyz=False, +): + if idx is None: + assert nsample is not None + idx, _ = knn_query(nsample, xyz, offset, new_xyz, new_offset) + return grouping(idx, feat, xyz, new_xyz, with_xyz), idx + + +def ball_query_and_group( + feat, + xyz, + offset=None, + new_xyz=None, + 
new_offset=None, + idx=None, + max_radio=None, + min_radio=0, + nsample=None, + with_xyz=False, +): + if idx is None: + assert nsample is not None and offset is not None + assert max_radio is not None and min_radio is not None + idx, _ = ball_query( + nsample, max_radio, min_radio, xyz, offset, new_xyz, new_offset + ) + return grouping(idx, feat, xyz, new_xyz, with_xyz), idx + + +def query_and_group( + nsample, + xyz, + new_xyz, + feat, + idx, + offset, + new_offset, + dilation=0, + with_feat=True, + with_xyz=True, +): + """ + input: coords: (n, 3), new_xyz: (m, 3), color: (n, c), idx: (m, nsample), offset: (b), new_offset: (b) + output: new_feat: (m, nsample, c+3), grouped_idx: (m, nsample) + """ + assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous() + if new_xyz is None: + new_xyz = xyz + + if idx is None: + num_samples_total = 1 + (nsample - 1) * (dilation + 1) + # num points in a batch might < num_samples_total => [n1, n2, ..., nk, ns, ns, ns, ...] + idx_no_dilation, _ = knn_query( + num_samples_total, xyz, offset, new_xyz, new_offset + ) # (m, nsample * (d + 1)) + idx = [] + batch_end = offset.tolist() + batch_start = [0] + batch_end[:-1] + new_batch_end = new_offset.tolist() + new_batch_start = [0] + new_batch_end[:-1] + for i in range(offset.shape[0]): + if batch_end[i] - batch_start[i] < num_samples_total: + soft_dilation = (batch_end[i] - batch_start[i] - 1) / (nsample - 1) - 1 + else: + soft_dilation = dilation + idx.append( + idx_no_dilation[ + new_batch_start[i] : new_batch_end[i], + [int((soft_dilation + 1) * i) for i in range(nsample)], + ] + ) + idx = torch.cat(idx, dim=0) + + if not with_feat: + return idx + + n, m, c = xyz.shape[0], new_xyz.shape[0], feat.shape[1] + grouped_xyz = xyz[idx.view(-1).long(), :].view(m, nsample, 3) # (m, nsample, 3) + # grouped_xyz = grouping(coords, idx) # (m, nsample, 3) + grouped_xyz -= new_xyz.unsqueeze(1) # (m, nsample, 3) + grouped_feat = feat[idx.view(-1).long(), :].view(m, nsample, c) 
# (m, nsample, c) + # grouped_feat = grouping(color, idx) # (m, nsample, c) + + if with_xyz: + return torch.cat((grouped_xyz, grouped_feat), -1), idx # (m, nsample, 3+c) + else: + return grouped_feat, idx + + +def offset2batch(offset): + return ( + torch.cat( + [ + ( + torch.tensor([i] * (o - offset[i - 1])) + if i > 0 + else torch.tensor([i] * o) + ) + for i, o in enumerate(offset) + ], + dim=0, + ) + .long() + .to(offset.device) + ) + + +def batch2offset(batch): + return torch.cumsum(batch.bincount(), dim=0).int() diff --git a/submodules/pointops/setup.py b/submodules/pointops/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..0cdf07b6c12bf702b40accbb51fd1825e4050a8b --- /dev/null +++ b/submodules/pointops/setup.py @@ -0,0 +1,33 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +from distutils.sysconfig import get_config_vars + +(opt,) = get_config_vars("OPT") +os.environ["OPT"] = " ".join( + flag for flag in opt.split() if flag != "-Wstrict-prototypes" +) + +src = "src" +sources = [ + os.path.join(root, file) + for root, dirs, files in os.walk(src) + for file in files + if file.endswith(".cpp") or file.endswith(".cu") +] + +setup( + name="pointops", + version="1.0", + install_requires=["torch", "numpy"], + packages=["pointops"], + package_dir={"pointops": "functions"}, + ext_modules=[ + CUDAExtension( + name="pointops._C", + sources=sources, + extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]}, + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/submodules/pointops/src/__init__.py b/submodules/pointops/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/submodules/pointops/src/aggregation/aggregation_cuda.cpp b/submodules/pointops/src/aggregation/aggregation_cuda.cpp new file mode 100644 index 
0000000000000000000000000000000000000000..491b6f41660edf9b5ea5656cc88edba8ed807d71 --- /dev/null +++ b/submodules/pointops/src/aggregation/aggregation_cuda.cpp @@ -0,0 +1,28 @@ +#include +#include +#include +#include "aggregation_cuda_kernel.h" + + +void aggregation_forward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor output_tensor) +{ + const float *input = input_tensor.data_ptr(); + const float *position = position_tensor.data_ptr(); + const float *weight = weight_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + aggregation_forward_cuda_launcher(n, nsample, c, w_c, input, position, weight, idx, output); +} + +void aggregation_backward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input_tensor, at::Tensor grad_position_tensor, at::Tensor grad_weight_tensor) +{ + const float *input = input_tensor.data_ptr(); + const float *position = position_tensor.data_ptr(); + const float *weight = weight_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + const float *grad_output = grad_output_tensor.data_ptr(); + float *grad_input = grad_input_tensor.data_ptr(); + float *grad_position = grad_position_tensor.data_ptr(); + float *grad_weight = grad_weight_tensor.data_ptr(); + aggregation_backward_cuda_launcher(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight); +} diff --git a/submodules/pointops/src/aggregation/aggregation_cuda_kernel.cu b/submodules/pointops/src/aggregation/aggregation_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..8339bb7e2088abffefba02c26b248edafed6cf47 --- /dev/null +++ b/submodules/pointops/src/aggregation/aggregation_cuda_kernel.cu @@ -0,0 +1,53 @@ 
+#include "../cuda_utils.h" +#include "aggregation_cuda_kernel.h" + + +__global__ void aggregation_forward_cuda_kernel(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output) { + // input: input: (n, c), position: (n, nsample, c), weight: (n, nsample, w_c), idx: (n, nsample), output: (n, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= n * c) return; + const int c_idx = index % c; + const int n_idx = index / c; + const int w_c_idx = c_idx % w_c; + for (int nsample_idx = 0; nsample_idx < nsample; nsample_idx++) + { + int idx_idx = n_idx * nsample + nsample_idx; + int input_idx = idx[idx_idx] * c + c_idx; + int position_idx = n_idx * nsample * c + nsample_idx * c + c_idx; + int weight_idx = n_idx * nsample * w_c + nsample_idx * w_c + w_c_idx; + output[index] += (input[input_idx] + position[position_idx]) * weight[weight_idx]; + } +} + +__global__ void aggregation_backward_cuda_kernel(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight) { + // input: grad_output: (n, c), output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight: (n, nsample, w_c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= n * c) return; + const int c_idx = index % c; + const int n_idx = index / c; + const int w_c_idx = c_idx % w_c; + for (int nsample_idx = 0; nsample_idx < nsample; nsample_idx++) + { + int idx_idx = n_idx * nsample + nsample_idx; + int input_idx = idx[idx_idx] * c + c_idx; + int position_idx = n_idx * nsample * c + nsample_idx * c + c_idx; + int weight_idx = n_idx * nsample * w_c + nsample_idx * w_c + w_c_idx; + atomicAdd(grad_input + input_idx, grad_output[index] * weight[weight_idx]); + grad_position[position_idx] = grad_output[index] * weight[weight_idx]; + atomicAdd(grad_weight + weight_idx, 
grad_output[index] * (input[input_idx] + position[position_idx])); + } +} + +void aggregation_forward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output) { + // input: input: (n, c), position: (n, nsample, c), weight: (n, nsample, w_c), idx: (n, nsample), output: (n, c) + dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + aggregation_forward_cuda_kernel<<>>(n, nsample, c, w_c, input, position, weight, idx, output); +} + +void aggregation_backward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight) { + // input: grad_output: (n, c), output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight: (n, nsample, w_c) + dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + aggregation_backward_cuda_kernel<<>>(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight); +} diff --git a/submodules/pointops/src/aggregation/aggregation_cuda_kernel.h b/submodules/pointops/src/aggregation/aggregation_cuda_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..5211a96aa2acbe0d9baf32bddc9ab4be87703072 --- /dev/null +++ b/submodules/pointops/src/aggregation/aggregation_cuda_kernel.h @@ -0,0 +1,20 @@ +#ifndef _AGGREGATION_CUDA_KERNEL +#define _AGGREGATION_CUDA_KERNEL +#include +#include +#include + +void aggregation_forward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor output_tensor); +void aggregation_backward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor grad_output_tensor, 
at::Tensor grad_input_tensor, at::Tensor grad_position_tensor, at::Tensor grad_weight_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void aggregation_forward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, float *output); +void aggregation_backward_cuda_launcher(int n, int nsample, int c, int w_c, const float *input, const float *position, const float *weight, const int *idx, const float *grad_output, float *grad_input, float *grad_position, float *grad_weight); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/submodules/pointops/src/attention/attention_cuda.cpp b/submodules/pointops/src/attention/attention_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..79b90c7ebc3ed85dc389bc4ae3169a086efc5848 --- /dev/null +++ b/submodules/pointops/src/attention/attention_cuda.cpp @@ -0,0 +1,76 @@ +#include +#include +#include +#include "attention_cuda_kernel.h" + + +void attention_relation_step_forward_cuda(int m, int g, int c, + at::Tensor query_tensor, at::Tensor key_tensor, at::Tensor weight_tensor, + at::Tensor index_target_tensor, at::Tensor index_refer_tensor, + at::Tensor output_tensor) +{ + const float *query = query_tensor.data_ptr(); + const float *key = key_tensor.data_ptr(); + const float *weight = weight_tensor.data_ptr(); + const int *index_target = index_target_tensor.data_ptr(); + const int *index_refer = index_refer_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + attention_relation_step_forward_cuda_launcher(m, g, c, query, key, weight, index_target, index_refer, output); +} + +void attention_relation_step_backward_cuda(int m, int g, int c, + at::Tensor query_tensor, at::Tensor grad_query_tensor, + at::Tensor key_tensor, at::Tensor grad_key_tensor, + at::Tensor weight_tensor, at::Tensor grad_weight_tensor, + at::Tensor index_target_tensor, at::Tensor index_refer_tensor, + at::Tensor grad_output_tensor) +{ + const float 
*query = query_tensor.data_ptr(); + float *grad_query = grad_query_tensor.data_ptr(); + const float *key = key_tensor.data_ptr(); + float *grad_key = grad_key_tensor.data_ptr(); + const float *weight = weight_tensor.data_ptr(); + float *grad_weight = grad_weight_tensor.data_ptr(); + const int *index_target = index_target_tensor.data_ptr(); + const int *index_refer = index_refer_tensor.data_ptr(); + const float *grad_output = grad_output_tensor.data_ptr(); + attention_relation_step_backward_cuda_launcher(m, g, c, + query, grad_query, + key, grad_key, + weight, grad_weight, + index_target, index_refer, grad_output); +} + + +void attention_fusion_step_forward_cuda(int m, int g, int c, + at::Tensor weight_tensor, at::Tensor value_tensor, + at::Tensor index_target_tensor, at::Tensor index_refer_tensor, + at::Tensor output_tensor) +{ + const float *weight = weight_tensor.data_ptr(); + const float *value = value_tensor.data_ptr(); + const int *index_target = index_target_tensor.data_ptr(); + const int *index_refer = index_refer_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + attention_fusion_step_forward_cuda_launcher(m, g, c, weight, value, index_target, index_refer, output); +} + + +void attention_fusion_step_backward_cuda(int m, int g, int c, + at::Tensor weight_tensor, at::Tensor grad_weight_tensor, + at::Tensor value_tensor, at::Tensor grad_value_tensor, + at::Tensor index_target_tensor, at::Tensor index_refer_tensor, + at::Tensor grad_output_tensor) +{ + const float *weight = weight_tensor.data_ptr(); + float *grad_weight = grad_weight_tensor.data_ptr(); + const float *value = value_tensor.data_ptr(); + float *grad_value = grad_value_tensor.data_ptr(); + const int *index_target = index_target_tensor.data_ptr(); + const int *index_refer = index_refer_tensor.data_ptr(); + const float *grad_output = grad_output_tensor.data_ptr(); + attention_fusion_step_backward_cuda_launcher(m, g, c, + weight, grad_weight, + value, grad_value, + index_target, 
index_refer, grad_output); +} diff --git a/submodules/pointops/src/attention/attention_cuda_kernel.cu b/submodules/pointops/src/attention/attention_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..05f4544a4dc4da584ad70eece75265d4845171e7 --- /dev/null +++ b/submodules/pointops/src/attention/attention_cuda_kernel.cu @@ -0,0 +1,149 @@ +#include "../cuda_utils.h" +#include "attention_cuda_kernel.h" + + +/* +Kernels +*/ + +__global__ void attention_relation_step_forward_cuda_kernel(int m, int g, int c, + const float *query, const float *key, const float *weight, + const int *index_target, const int *index_refer, + float *output) +{ + int r_idx = blockIdx.x * blockDim.x + threadIdx.x; + int g_idx = blockIdx.y; + int c_idx = blockIdx.z; + + if (r_idx >= m || g_idx >= g || c_idx >= c) return; + int q_idx = index_target[r_idx] * g * c + g_idx * c + c_idx; + int k_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx; + + float r = query[q_idx] * key[k_idx] * weight[c_idx]; + atomicAdd(output + r_idx * g + g_idx, r); +} + +__global__ void attention_relation_step_backward_cuda_kernel(int m, int g, int c, + const float *query, float *grad_query, + const float *key, float *grad_key, + const float *weight, float *grad_weight, + const int *index_target, const int *index_refer, + const float *grad_output) +{ + int r_idx = blockIdx.x * blockDim.x + threadIdx.x; + int g_idx = blockIdx.y; + int c_idx = blockIdx.z; + + if (r_idx >= m || g_idx >= g || c_idx >= c) return; + + int q_idx = index_target[r_idx] * g * c + g_idx * c + c_idx; + int k_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx; + int o_idx = r_idx * g + g_idx; + float grad_r = grad_output[o_idx]; + atomicAdd(grad_query + q_idx, grad_r * key[k_idx] * weight[c_idx]); + atomicAdd(grad_key + k_idx, grad_r * query[q_idx] * weight[c_idx]); + atomicAdd(grad_weight + c_idx, grad_r * key[k_idx] * query[q_idx]); +} + + +__global__ void attention_fusion_step_forward_cuda_kernel(int m, int g, 
int c, + const float *weight, const float *value, + const int *index_target, const int *index_refer, + float *output) +{ + int r_idx = blockIdx.x * blockDim.x + threadIdx.x; + int g_idx = blockIdx.y; + int c_idx = blockIdx.z; + + if (r_idx >= m || g_idx >= g || c_idx >= c) return; + + int o_idx = index_target[r_idx] * g * c + g_idx * c + c_idx; + int v_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx; + + float f = weight[r_idx * g + g_idx] * value[v_idx]; + atomicAdd(output + o_idx, f); +} + + +__global__ void attention_fusion_step_backward_cuda_kernel(int m, int g, int c, + const float *weight, float *grad_weight, + const float *value, float *grad_value, + const int *index_target, const int *index_refer, + const float *grad_output) +{ + int r_idx = blockIdx.x * blockDim.x + threadIdx.x; + int g_idx = blockIdx.y; + int c_idx = blockIdx.z; + + if (r_idx >= m || g_idx >= g || c_idx >= c) return; + + int o_idx = index_target[r_idx] * g * c + g_idx * c + c_idx; + int v_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx; + int w_idx = r_idx * g + g_idx; + float grad = grad_output[o_idx]; + atomicAdd(grad_weight + w_idx, grad * value[v_idx]); + atomicAdd(grad_value + v_idx, grad * weight[w_idx]); +} + +/* +Launchers +*/ + + +void attention_relation_step_forward_cuda_launcher(int m, int g, int c, + const float *query, const float *key, const float *weight, + const int *index_target, const int *index_refer, + float *output) +{ + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c); + dim3 threads(THREADS_PER_BLOCK); + attention_relation_step_forward_cuda_kernel<<>>(m, g, c, query, key, weight, + index_target, index_refer, output); +} + +void attention_relation_step_backward_cuda_launcher(int m, int g, int c, + const float *query, float *grad_query, + const float *key, float *grad_key, + const float *weight, float *grad_weight, + const int *index_target, const int *index_refer, + const float *grad_output) +{ + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c); + dim3 
threads(THREADS_PER_BLOCK); + attention_relation_step_backward_cuda_kernel<<>>(m, g, c, + query, grad_query, + key, grad_key, + weight, grad_weight, + index_target, index_refer, + grad_output); +} + + +void attention_fusion_step_forward_cuda_launcher(int m, int g, int c, + const float *weight, const float *value, + const int *index_target, const int *index_refer, + float *output) +{ + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c); + dim3 threads(THREADS_PER_BLOCK); + attention_fusion_step_forward_cuda_kernel<<>>(m, g, c, weight, value, + index_target, index_refer, output); +} + + +void attention_fusion_step_backward_cuda_launcher(int m, int g, int c, + const float *weight, float *grad_weight, + const float *value, float *grad_value, + const int *index_target, const int *index_refer, + const float *grad_output) +{ + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c); + dim3 threads(THREADS_PER_BLOCK); + attention_fusion_step_backward_cuda_kernel<<>>(m, g, c, + weight, grad_weight, + value, grad_value, + index_target, index_refer, + grad_output); +} + + diff --git a/submodules/pointops/src/attention/attention_cuda_kernel.h b/submodules/pointops/src/attention/attention_cuda_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..fec965c0415c4cb5c64fd10e441b6a4c6a6c9ae9 --- /dev/null +++ b/submodules/pointops/src/attention/attention_cuda_kernel.h @@ -0,0 +1,54 @@ +#ifndef _ATTENTION_CUDA_KERNEL +#define _ATTENTION_CUDA_KERNEL +#include +#include +#include + +void attention_relation_step_forward_cuda(int m, int g, int c, + at::Tensor query_tensor, at::Tensor key_tensor, at::Tensor weight_tensor, + at::Tensor index_target_tensor, at::Tensor index_refer_tensor, + at::Tensor output_tensor); +void attention_relation_step_backward_cuda(int m, int g, int c, + at::Tensor query_tensor, at::Tensor grad_query_tensor, + at::Tensor key_tensor, at::Tensor grad_key_tensor, + at::Tensor weight_tensor, at::Tensor grad_weight_tensor, + at::Tensor index_target_tensor, 
at::Tensor index_refer_tensor, + at::Tensor grad_output_tensor); +void attention_fusion_step_forward_cuda(int m, int g, int c, + at::Tensor weight_tensor, at::Tensor value_tensor, + at::Tensor index_target_tensor, at::Tensor index_refer_tensor, + at::Tensor output_tensor); +void attention_fusion_step_backward_cuda(int m, int g, int c, + at::Tensor weight_tensor, at::Tensor grad_weight_tensor, + at::Tensor value_tensor, at::Tensor grad_value_tensor, + at::Tensor index_target_tensor, at::Tensor index_refer_tensor, + at::Tensor grad_output_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void attention_relation_step_forward_cuda_launcher(int m, int g, int c, + const float *query, const float *key, const float *weight, + const int *index_target, const int *index_refer, + float *output); +void attention_relation_step_backward_cuda_launcher(int m, int g, int c, + const float *query, float *grad_query, + const float *key, float *grad_key, + const float *weight, float *grad_weight, + const int *index_target, const int *index_refer, + const float *grad_output); +void attention_fusion_step_forward_cuda_launcher(int m, int g, int c, + const float *weight, const float *value, + const int *index_target, const int *index_refer, + float *output); +void attention_fusion_step_backward_cuda_launcher(int m, int g, int c, + const float *weight, float *grad_weight, + const float *value, float *grad_value, + const int *index_target, const int *index_refer, + const float *grad_output); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/submodules/pointops/src/ball_query/ball_query_cuda.cpp b/submodules/pointops/src/ball_query/ball_query_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..04cd5ff9e8e39c006222d5651f3aae70ce2e35c9 --- /dev/null +++ b/submodules/pointops/src/ball_query/ball_query_cuda.cpp @@ -0,0 +1,20 @@ +#include +#include +#include +#include "ball_query_cuda_kernel.h" + + +void ball_query_cuda(int m, int nsample, + float min_radius, 
float max_radius, + at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, + at::Tensor offset_tensor, at::Tensor new_offset_tensor, + at::Tensor idx_tensor, at::Tensor dist2_tensor) +{ + const float *xyz = xyz_tensor.data_ptr(); + const float *new_xyz = new_xyz_tensor.data_ptr(); + const int *offset = offset_tensor.data_ptr(); + const int *new_offset = new_offset_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + float *dist2 = dist2_tensor.data_ptr(); + ball_query_cuda_launcher(m, nsample, min_radius, max_radius, xyz, new_xyz, offset, new_offset, idx, dist2); +} diff --git a/submodules/pointops/src/ball_query/ball_query_cuda_kernel.cu b/submodules/pointops/src/ball_query/ball_query_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..7b3d95a9835f607798f0d63e2b66ddb3af9032da --- /dev/null +++ b/submodules/pointops/src/ball_query/ball_query_cuda_kernel.cu @@ -0,0 +1,190 @@ +#include "../cuda_utils.h" +#include "ball_query_cuda_kernel.h" + + +namespace ball_query_utils{ + +template +__device__ void swap(DType *x, DType *y) +{ + DType tmp = *x; + *x = *y; + *y = tmp; +} + +__device__ void reheap(float *dist, int *idx, int k) +{ + int root = 0; + int child = root * 2 + 1; + while (child < k) + { + if(child + 1 < k && dist[child+1] > dist[child]) + child++; + if(dist[root] > dist[child]) + return; + swap(&dist[root], &dist[child]); + swap(&idx[root], &idx[child]); + root = child; + child = root * 2 + 1; + } +} + + +__device__ void heap_sort(float *dist, int *idx, int k) +{ + int i; + for (i = k - 1; i > 0; i--) + { + swap(&dist[0], &dist[i]); + swap(&idx[0], &idx[i]); + reheap(dist, idx, i); + } +} + +__device__ int get_bt_idx(int idx, const int *offset) +{ + int i = 0; + while (1) + { + if (idx < offset[i]) + break; + else + i++; + } + return i; +} +} // namespace ball_query_utils + +__global__ void ball_query_cuda_kernel(int m, int nsample, + float min_radius, float max_radius, + const float *__restrict__ xyz, const float *__restrict__ 
new_xyz, + const int *__restrict__ offset, const int *__restrict__ new_offset, + int *__restrict__ idx, float *__restrict__ dist2) { + // input: xyz (n, 3) new_xyz (m, 3) + // output: idx (m, nsample) dist2 (m, nsample); slots without a match get idx -1 and dist2 1e10 + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (pt_idx >= m) return; + + new_xyz += pt_idx * 3; + idx += pt_idx * nsample; + dist2 += pt_idx * nsample; + + int bt_idx = ball_query_utils::get_bt_idx(pt_idx, new_offset); + int start; + if (bt_idx == 0) + start = 0; + else + start = offset[bt_idx - 1]; + int end = offset[bt_idx]; + + float max_radius2 = max_radius * max_radius; + float min_radius2 = min_radius * min_radius; + float new_x = new_xyz[0]; + float new_y = new_xyz[1]; + float new_z = new_xyz[2]; + + float candi_dist[2048]; + int candi_idx[2048]; + int candi_num = 0; + + for(int i = start; i < end; i++){ + float x = xyz[i * 3 + 0]; + float y = xyz[i * 3 + 1]; + float z = xyz[i * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); + + if (candi_num < 2048 && (d2 <= 1e-5 || (d2 >= min_radius2 && d2 < max_radius2))){ + // TODO: Check d2 <= 1e-5; the candi_num < 2048 guard keeps writes inside the fixed-size candidate buffers + candi_dist[candi_num] = d2; + candi_idx[candi_num] = i; + candi_num += 1; + } + } + ball_query_utils::heap_sort(candi_dist, candi_idx, candi_num); // NOTE(review): heap_sort only re-sifts from the root, so it expects a max-heap on entry; the candidates are in scan order here - confirm the resulting order is intended + if(candi_num <= nsample){ + for(int i = 0; i < candi_num; i++){ + idx[i] = candi_idx[i]; + dist2[i] = candi_dist[i]; + } + for(int i = candi_num; i < nsample; i++){ + idx[i] = -1; + dist2[i] = 1e10; + } + } + else{ + float sep = static_cast(candi_num) / nsample; // more candidates than slots: take nsample evenly spaced entries + for(int i = 0; i < nsample; i++) + { + int index = static_cast(sep * i); + idx[i] = candi_idx[index]; + dist2[i] = candi_dist[index]; // squared distance of the picked candidate, not its point index + } + } +} + +/* Random Sample Mode Ball Query */ + +// __global__ void ball_query_cuda_kernel(int m, int nsample, +// float min_radius, float max_radius, +// const float *__restrict__ xyz, const float *__restrict__ new_xyz, +// const int *__restrict__ offset, const int *__restrict__ new_offset, +// int *__restrict__ idx, 
float *__restrict__ dist2) { +// // input: xyz (n, 3) new_xyz (m, 3) +// // output: idx (m, nsample) dist (m, nsample) +// int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; +// if (pt_idx >= m) return; +// +// new_xyz += pt_idx * 3; +// idx += pt_idx * nsample; +// dist2 += pt_idx * nsample; +// +// int bt_idx = ball_get_bt_idx(pt_idx, new_offset); +// int start; +// if (bt_idx == 0) +// start = 0; +// else +// start = offset[bt_idx - 1]; +// int end = offset[bt_idx]; +// +// float max_radius2 = max_radius * max_radius; +// float min_radius2 = min_radius * min_radius; +// float new_x = new_xyz[0]; +// float new_y = new_xyz[1]; +// float new_z = new_xyz[2]; +// +// int cnt = 0; +// for(int i = start; i < end; i++){ +// float x = xyz[i * 3 + 0]; +// float y = xyz[i * 3 + 1]; +// float z = xyz[i * 3 + 2]; +// float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); +// +// if (d2 == 0 || (d2 >= min_radius2 && d2 < max_radius2)) { +// if (cnt == 0) { +// for (int l = 0; l < nsample; ++l) { +// idx[l] = i; +// dist2[l] = d2; +// } +// } +// idx[cnt] = i; +// ++cnt; +// if (cnt >= nsample) break; +// } +// } +// } + + +void ball_query_cuda_launcher(int m, int nsample, + float min_radius, float max_radius, + const float *xyz, const float *new_xyz, + const int *offset, const int *new_offset, + int *idx, float *dist2) { + // input: new_xyz: (m, 3), xyz: (n, 3), idx: (m, nsample) + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + ball_query_cuda_kernel<<>>(m, nsample, + min_radius, max_radius, + xyz, new_xyz, + offset, new_offset, + idx, dist2); +} diff --git a/submodules/pointops/src/ball_query/ball_query_cuda_kernel.h b/submodules/pointops/src/ball_query/ball_query_cuda_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..03007a285a3559da85d099681f1316915e1d31b1 --- /dev/null +++ b/submodules/pointops/src/ball_query/ball_query_cuda_kernel.h @@ -0,0 +1,26 @@ +#ifndef 
_BALL_QUERY_CUDA_KERNEL +#define _BALL_QUERY_CUDA_KERNEL +#include +#include +#include + +void ball_query_cuda(int m, int nsample, + float min_radius, float max_radius, + at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, + at::Tensor offset_tensor, at::Tensor new_offset_tensor, + at::Tensor idx_tensor, at::Tensor dist2_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void ball_query_cuda_launcher(int m, int nsample, + float min_radius, float max_radius, + const float *xyz, const float *new_xyz, + const int *offset, const int *new_offset, + int *idx, float *dist2); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/submodules/pointops/src/cuda_utils.h b/submodules/pointops/src/cuda_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..bbfe7a06bf989056c0bd99e3e64fdbe7d15bb093 --- /dev/null +++ b/submodules/pointops/src/cuda_utils.h @@ -0,0 +1,23 @@ +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include +#include + +#define TOTAL_THREADS 1024 +#define THREADS_PER_BLOCK 512 +#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) + +inline int opt_n_threads(int work_size) { + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1); +} + +inline dim3 opt_block_config(int x, int y) { + const int x_threads = opt_n_threads(x); + const int y_threads = std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); + dim3 block_config(x_threads, y_threads, 1); + return block_config; +} + +#endif diff --git a/submodules/pointops/src/grouping/grouping_cuda.cpp b/submodules/pointops/src/grouping/grouping_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f7990adaf43f0a77050eed0d55adad19f256e10 --- /dev/null +++ b/submodules/pointops/src/grouping/grouping_cuda.cpp @@ -0,0 +1,21 @@ +#include +#include +#include +#include "grouping_cuda_kernel.h" + + +void grouping_forward_cuda(int m, int nsample, int c, at::Tensor input_tensor, at::Tensor idx_tensor, 
at::Tensor output_tensor) +{ + const float *input = input_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + grouping_forward_cuda_launcher(m, nsample, c, input, idx, output); +} + +void grouping_backward_cuda(int m, int nsample, int c, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor grad_input_tensor) +{ + const float *grad_output = grad_output_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + float *grad_input = grad_input_tensor.data_ptr(); + grouping_backward_cuda_launcher(m, nsample, c, grad_output, idx, grad_input); +} diff --git a/submodules/pointops/src/grouping/grouping_cuda_kernel.cu b/submodules/pointops/src/grouping/grouping_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..58ec0a21a2949f9f82504ccd24597c544c50af40 --- /dev/null +++ b/submodules/pointops/src/grouping/grouping_cuda_kernel.cu @@ -0,0 +1,40 @@ +#include "../cuda_utils.h" +#include "grouping_cuda_kernel.h" + + +__global__ void grouping_forward_cuda_kernel(int m, int nsample, int c, const float *__restrict__ input, const int *__restrict__ idx, float *__restrict__ output) { + // input: input: (n, c), idx: (m, nsample), output: (m, nsample, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= m * nsample * c) return; + const int c_idx = index % c; + const int nsample_idx = (index / c) % nsample; + const int m_idx = index / nsample / c; + const int input_idx = idx[m_idx * nsample + nsample_idx] * c + c_idx; + output[index] = input[input_idx]; +} + +__global__ void grouping_backward_cuda_kernel(int m, int nsample, int c, const float *__restrict__ grad_output, const int *__restrict__ idx, float *__restrict__ grad_input) { + // input: grad_output: (m, nsample, c), idx: (m, nsample), output: grad_input: (n, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= m * nsample * c) return; + const int c_idx = index % c; + const int nsample_idx = (index 
/ c) % nsample; + const int m_idx = index / nsample / c; + const int input_idx = idx[m_idx * nsample + nsample_idx] * c + c_idx; + atomicAdd(grad_input + input_idx, grad_output[index]); +} + +void grouping_forward_cuda_launcher(int m, int nsample, int c, const float *input, const int *idx, float *output) { + // input: input: (n, c), idx: (m, nsample), output: (m, nsample, c) + dim3 blocks(DIVUP(m * nsample * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + grouping_forward_cuda_kernel<<>>(m, nsample, c, input, idx, output); +} + +void grouping_backward_cuda_launcher(int m, int nsample, int c, const float *grad_output, const int *idx, float *grad_input) +{ + // input: grad_output: (m, nsample, c), idx: (m, nsample), output: grad_input: (n, c) + dim3 blocks(DIVUP(m * nsample * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + grouping_backward_cuda_kernel<<>>(m, nsample, c, grad_output, idx, grad_input); +} diff --git a/submodules/pointops/src/grouping/grouping_cuda_kernel.h b/submodules/pointops/src/grouping/grouping_cuda_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..3db4aaa9fad5811d559d47c500e4b00f0165d9b4 --- /dev/null +++ b/submodules/pointops/src/grouping/grouping_cuda_kernel.h @@ -0,0 +1,20 @@ +#ifndef _GROUPING_CUDA_KERNEL +#define _GROUPING_CUDA_KERNEL +#include +#include +#include + +void grouping_forward_cuda(int m, int nsample, int c, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor output_tensor); +void grouping_backward_cuda(int m, int nsample, int c, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor grad_input_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void grouping_forward_cuda_launcher(int m, int nsample, int c, const float *input, const int *idx, float *output); +void grouping_backward_cuda_launcher(int m, int nsample, int c, const float *grad_output, const int *idx, float *grad_input); + +#ifdef __cplusplus +} +#endif +#endif diff --git 
a/submodules/pointops/src/interpolation/interpolation_cuda.cpp b/submodules/pointops/src/interpolation/interpolation_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f2c1b0078f4b70626705d7b3f5d1d65d37ee6de7 --- /dev/null +++ b/submodules/pointops/src/interpolation/interpolation_cuda.cpp @@ -0,0 +1,23 @@ +#include +#include +#include +#include "interpolation_cuda_kernel.h" + + +void interpolation_forward_cuda(int n, int c, int k, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor output_tensor) +{ + const float *input = input_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + const float *weight = weight_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + interpolation_forward_cuda_launcher(n, c, k, input, idx, weight, output); +} + +void interpolation_backward_cuda(int n, int c, int k, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_input_tensor) +{ + const float *grad_output = grad_output_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + const float *weight = weight_tensor.data_ptr(); + float *grad_input = grad_input_tensor.data_ptr(); + interpolation_backward_cuda_launcher(n, c, k, grad_output, idx, weight, grad_input); +} diff --git a/submodules/pointops/src/interpolation/interpolation_cuda_kernel.cu b/submodules/pointops/src/interpolation/interpolation_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..f560d8c92c6eac865b8c1e1dc27140fe3fcc2250 --- /dev/null +++ b/submodules/pointops/src/interpolation/interpolation_cuda_kernel.cu @@ -0,0 +1,47 @@ +#include "../cuda_utils.h" +#include "interpolation_cuda_kernel.h" + + +__global__ void interpolation_forward_cuda_kernel(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output) +{ + // input: input: (m, c), idx: (n, k), weight: (n, k), output: output (n, c) + int index = blockIdx.x * blockDim.x + 
threadIdx.x; + if (index >= n * c) return; + int c_idx = index % c; + int n_idx = index / c; + for (int i = 0; i < k; i++) + { + int idx_idx = n_idx * k + i; + int input_idx = idx[idx_idx] * c + c_idx; + output[index] += input[input_idx] * weight[idx_idx]; + } +} + +__global__ void interpolation_backward_cuda_kernel(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input) +{ + // input: grad_output: (n, c), idx: (n, k), weight: (n, k), output: grad_input (m, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= n * c) return; + int c_idx = index % c; + int n_idx = index / c; + for (int i = 0; i < k; i++) + { + int idx_idx = n_idx * k + i; + int input_idx = idx[idx_idx] * c + c_idx; + atomicAdd(grad_input + input_idx, grad_output[index] * weight[idx_idx]); + } +} + +void interpolation_forward_cuda_launcher(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output) { + // input: input: (m, c), idx: (n, k), weight: (n, k), output: output (n, c) + dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + interpolation_forward_cuda_kernel<<>>(n, c, k, input, idx, weight, output); +} + +void interpolation_backward_cuda_launcher(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input) { + // input: grad_output: (n, c), idx: (n, k), weight: (n, k), output: grad_input (m, c) + dim3 blocks(DIVUP(n * c, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + interpolation_backward_cuda_kernel<<>>(n, c, k, grad_output, idx, weight, grad_input); +} diff --git a/submodules/pointops/src/interpolation/interpolation_cuda_kernel.h b/submodules/pointops/src/interpolation/interpolation_cuda_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..309e5dd0a34ccb58807bbf32389ba65e7ee6961b --- /dev/null +++ b/submodules/pointops/src/interpolation/interpolation_cuda_kernel.h @@ -0,0 +1,20 @@ +#ifndef 
_INTERPOLATION_CUDA_KERNEL +#define _INTERPOLATION_CUDA_KERNEL +#include +#include +#include + +void interpolation_forward_cuda(int n, int c, int k, at::Tensor input_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor output_tensor); +void interpolation_backward_cuda(int n, int c, int k, at::Tensor grad_output_tensor, at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_input_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void interpolation_forward_cuda_launcher(int n, int c, int k, const float *input, const int *idx, const float *weight, float *output); +void interpolation_backward_cuda_launcher(int n, int c, int k, const float *grad_output, const int *idx, const float *weight, float *grad_input); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/submodules/pointops/src/knn_query/knn_query_cuda.cpp b/submodules/pointops/src/knn_query/knn_query_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bbe841ce0352fd234143b3b4978ec001522b31dd --- /dev/null +++ b/submodules/pointops/src/knn_query/knn_query_cuda.cpp @@ -0,0 +1,16 @@ +#include +#include +#include +#include "knn_query_cuda_kernel.h" + + +void knn_query_cuda(int m, int nsample, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor idx_tensor, at::Tensor dist2_tensor) +{ + const float *xyz = xyz_tensor.data_ptr(); + const float *new_xyz = new_xyz_tensor.data_ptr(); + const int *offset = offset_tensor.data_ptr(); + const int *new_offset = new_offset_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + float *dist2 = dist2_tensor.data_ptr(); + knn_query_cuda_launcher(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2); +} diff --git a/submodules/pointops/src/knn_query/knn_query_cuda_kernel.cu b/submodules/pointops/src/knn_query/knn_query_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..297740237eae98cc4e61421bc261755d79b83142 --- /dev/null 
+++ b/submodules/pointops/src/knn_query/knn_query_cuda_kernel.cu @@ -0,0 +1,112 @@ +#include "../cuda_utils.h" +#include "knn_query_cuda_kernel.h" + + +namespace knn_query_utils{ + +template +__device__ void swap(DType *x, DType *y) +{ + DType tmp = *x; + *x = *y; + *y = tmp; +} + +__device__ void reheap(float *dist, int *idx, int k) +{ + int root = 0; + int child = root * 2 + 1; + while (child < k) + { + if(child + 1 < k && dist[child+1] > dist[child]) + child++; + if(dist[root] > dist[child]) + return; + swap(&dist[root], &dist[child]); + swap(&idx[root], &idx[child]); + root = child; + child = root * 2 + 1; + } +} + + +__device__ void heap_sort(float *dist, int *idx, int k) +{ + int i; + for (i = k - 1; i > 0; i--) + { + swap(&dist[0], &dist[i]); + swap(&idx[0], &idx[i]); + reheap(dist, idx, i); + } +} + + +__device__ int get_bt_idx(int idx, const int *offset) +{ + int i = 0; + while (1) + { + if (idx < offset[i]) + break; + else + i++; + } + return i; +} +} // namespace knn_query_utils + + +__global__ void knn_query_cuda_kernel(int m, int nsample, const float *__restrict__ xyz, const float *__restrict__ new_xyz, const int *__restrict__ offset, const int *__restrict__ new_offset, int *__restrict__ idx, float *__restrict__ dist2) { + // input: xyz (n, 3) new_xyz (m, 3) + // output: idx (m, nsample) dist2 (m, nsample); each thread handles one query point pt_idx + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (pt_idx >= m) return; + + new_xyz += pt_idx * 3; + idx += pt_idx * nsample; + dist2 += pt_idx * nsample; + + int bt_idx = knn_query_utils::get_bt_idx(pt_idx, new_offset); // first i with pt_idx < new_offset[i] + int start; + if (bt_idx == 0) + start = 0; + else + start = offset[bt_idx - 1]; + int end = offset[bt_idx]; // candidates scanned below are xyz rows [start, end) + + float new_x = new_xyz[0]; + float new_y = new_xyz[1]; + float new_z = new_xyz[2]; + + float best_dist[128]; // NOTE(review): fixed capacity of 128 but indexed up to nsample below - confirm callers keep nsample <= 128 + int best_idx[128]; + for(int i = 0; i < nsample; i++){ + best_dist[i] = 1e10; + best_idx[i] = -1; // -1 marks a slot that no candidate has filled yet + } + for(int i = start; i < end; i++){ + float x = xyz[i * 3 + 0]; + float y = xyz[i * 3 + 1]; + float z 
= xyz[i * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); + if (d2 < best_dist[0]){ + best_dist[0] = d2; + best_idx[0] = i; + knn_query_utils::reheap(best_dist, best_idx, nsample); + } + } + knn_query_utils::heap_sort(best_dist, best_idx, nsample); + for(int i = 0; i < nsample; i++){ + idx[i] = best_idx[i]; + dist2[i] = best_dist[i]; + } +} + + +void knn_query_cuda_launcher(int m, int nsample, const float *xyz, const float *new_xyz, const int *offset, const int *new_offset, int *idx, float *dist2) { + // input: new_xyz: (m, 3), xyz: (n, 3), idx: (m, nsample) + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + knn_query_cuda_kernel<<>>(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2); +} diff --git a/submodules/pointops/src/knn_query/knn_query_cuda_kernel.h b/submodules/pointops/src/knn_query/knn_query_cuda_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..c07c1cb46a56b7a37d55e25fb78816e034a8387e --- /dev/null +++ b/submodules/pointops/src/knn_query/knn_query_cuda_kernel.h @@ -0,0 +1,18 @@ +#ifndef _KNN_QUERY_CUDA_KERNEL +#define _KNN_QUERY_CUDA_KERNEL +#include +#include +#include + +void knn_query_cuda(int m, int nsample, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor idx_tensor, at::Tensor dist2_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void knn_query_cuda_launcher(int m, int nsample, const float *xyz, const float *new_xyz, const int *offset, const int *new_offset, int *idx, float *dist2); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/submodules/pointops/src/pointops_api.cpp b/submodules/pointops/src/pointops_api.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5ca4377607eb181d48d458d700f1df876294a848 --- /dev/null +++ b/submodules/pointops/src/pointops_api.cpp @@ -0,0 +1,32 @@ +#include +#include + +#include 
"knn_query/knn_query_cuda_kernel.h" +#include "ball_query/ball_query_cuda_kernel.h" +#include "random_ball_query/random_ball_query_cuda_kernel.h" +#include "sampling/sampling_cuda_kernel.h" +#include "grouping/grouping_cuda_kernel.h" +#include "interpolation/interpolation_cuda_kernel.h" +#include "aggregation/aggregation_cuda_kernel.h" +#include "subtraction/subtraction_cuda_kernel.h" +#include "attention/attention_cuda_kernel.h" + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("knn_query_cuda", &knn_query_cuda, "knn_query_cuda"); + m.def("ball_query_cuda", &ball_query_cuda, "ball_query_cuda"); + m.def("random_ball_query_cuda", &random_ball_query_cuda, "random_ball_query_cuda"); + m.def("farthest_point_sampling_cuda", &farthest_point_sampling_cuda, "farthest_point_sampling_cuda"); + m.def("grouping_forward_cuda", &grouping_forward_cuda, "grouping_forward_cuda"); + m.def("grouping_backward_cuda", &grouping_backward_cuda, "grouping_backward_cuda"); + m.def("interpolation_forward_cuda", &interpolation_forward_cuda, "interpolation_forward_cuda"); + m.def("interpolation_backward_cuda", &interpolation_backward_cuda, "interpolation_backward_cuda"); + m.def("subtraction_forward_cuda", &subtraction_forward_cuda, "subtraction_forward_cuda"); + m.def("subtraction_backward_cuda", &subtraction_backward_cuda, "subtraction_backward_cuda"); + m.def("aggregation_forward_cuda", &aggregation_forward_cuda, "aggregation_forward_cuda"); + m.def("aggregation_backward_cuda", &aggregation_backward_cuda, "aggregation_backward_cuda"); + m.def("attention_relation_step_forward_cuda", &attention_relation_step_forward_cuda, "attention_relation_step_forward_cuda"); + m.def("attention_relation_step_backward_cuda", &attention_relation_step_backward_cuda, "attention_relation_step_backward_cuda"); + m.def("attention_fusion_step_forward_cuda", &attention_fusion_step_forward_cuda, "attention_fusion_step_forward_cuda"); + m.def("attention_fusion_step_backward_cuda", 
&attention_fusion_step_backward_cuda, "attention_fusion_step_backward_cuda"); +} diff --git a/submodules/pointops/src/random_ball_query/random_ball_query_cuda.cpp b/submodules/pointops/src/random_ball_query/random_ball_query_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2618c94b6b19175f044131cebeefe8a23152c47 --- /dev/null +++ b/submodules/pointops/src/random_ball_query/random_ball_query_cuda.cpp @@ -0,0 +1,21 @@ +#include +#include +#include +#include "random_ball_query_cuda_kernel.h" + + +void random_ball_query_cuda(int m, int nsample, + float min_radius, float max_radius, at::Tensor order_tensor, + at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, + at::Tensor offset_tensor, at::Tensor new_offset_tensor, + at::Tensor idx_tensor, at::Tensor dist2_tensor) +{ + const int *order = order_tensor.data_ptr(); + const float *xyz = xyz_tensor.data_ptr(); + const float *new_xyz = new_xyz_tensor.data_ptr(); + const int *offset = offset_tensor.data_ptr(); + const int *new_offset = new_offset_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + float *dist2 = dist2_tensor.data_ptr(); + random_ball_query_cuda_launcher(m, nsample, min_radius, max_radius, order, xyz, new_xyz, offset, new_offset, idx, dist2); +} diff --git a/submodules/pointops/src/random_ball_query/random_ball_query_cuda_kernel.cu b/submodules/pointops/src/random_ball_query/random_ball_query_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..bfafb0f8b731e201783c94144cad9de3e11228ad --- /dev/null +++ b/submodules/pointops/src/random_ball_query/random_ball_query_cuda_kernel.cu @@ -0,0 +1,123 @@ +#include "../cuda_utils.h" +#include "random_ball_query_cuda_kernel.h" + + +namespace random_ball_query_utils{ + +template +__device__ void swap(DType *x, DType *y) +{ + DType tmp = *x; + *x = *y; + *y = tmp; +} + +__device__ void reheap(float *dist, int *idx, int k) +{ + int root = 0; + int child = root * 2 + 1; + while (child < k) + { + if(child + 1 < 
k && dist[child+1] > dist[child]) + child++; + if(dist[root] > dist[child]) + return; + swap(&dist[root], &dist[child]); + swap(&idx[root], &idx[child]); + root = child; + child = root * 2 + 1; + } +} + + +__device__ void heap_sort(float *dist, int *idx, int k) +{ + int i; + for (i = k - 1; i > 0; i--) + { + swap(&dist[0], &dist[i]); + swap(&idx[0], &idx[i]); + reheap(dist, idx, i); + } +} + +__device__ int get_bt_idx(int idx, const int *offset) +{ + int i = 0; + while (1) + { + if (idx < offset[i]) + break; + else + i++; + } + return i; +} +} // namespace random_ball_query_utils + +__global__ void random_ball_query_cuda_kernel(int m, int nsample, + float min_radius, float max_radius, const int *__restrict__ order, + const float *__restrict__ xyz, const float *__restrict__ new_xyz, + const int *__restrict__ offset, const int *__restrict__ new_offset, + int *__restrict__ idx, float *__restrict__ dist2) { + // input: xyz (n, 3) new_xyz (m, 3) + // output: idx (m, nsample) dist2 (m, nsample) + int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (pt_idx >= m) return; + + new_xyz += pt_idx * 3; + idx += pt_idx * nsample; + dist2 += pt_idx * nsample; + + int bt_idx = random_ball_query_utils::get_bt_idx(pt_idx, new_offset); + int start; + if (bt_idx == 0) + start = 0; + else + start = offset[bt_idx - 1]; + int end = offset[bt_idx]; + + float max_radius2 = max_radius * max_radius; + float min_radius2 = min_radius * min_radius; + float new_x = new_xyz[0]; + float new_y = new_xyz[1]; + float new_z = new_xyz[2]; + + int cnt = 0; + + for(int i = start; i < end; i++){ + float x = xyz[order[i] * 3 + 0]; + float y = xyz[order[i] * 3 + 1]; + float z = xyz[order[i] * 3 + 2]; + float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); + + if (d2 <= 1e-5 || (d2 >= min_radius2 && d2 < max_radius2)){ + dist2[cnt] = d2; + idx[cnt] = order[i]; + cnt += 1; + if (cnt >= nsample) break; // points are visited in order[] sequence and the scan stops at the first nsample matches + } + } + + if (cnt < nsample) { + for (int i = cnt; i < nsample; i++){ 
+ idx[i] = -1; + dist2[i] = 1e10; + } + } +} + +void random_ball_query_cuda_launcher(int m, int nsample, + float min_radius, float max_radius, const int *order, + const float *xyz, const float *new_xyz, + const int *offset, const int *new_offset, + int *idx, float *dist2) { + // input: new_xyz: (m, 3), xyz: (n, 3), idx: (m, nsample) + dim3 blocks(DIVUP(m, THREADS_PER_BLOCK)); + dim3 threads(THREADS_PER_BLOCK); + random_ball_query_cuda_kernel<<>>(m, nsample, + min_radius, max_radius, order, + xyz, new_xyz, + offset, new_offset, + idx, dist2); +} diff --git a/submodules/pointops/src/random_ball_query/random_ball_query_cuda_kernel.h b/submodules/pointops/src/random_ball_query/random_ball_query_cuda_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d3e35be21933d95b50e9c42150067071502bbc1e --- /dev/null +++ b/submodules/pointops/src/random_ball_query/random_ball_query_cuda_kernel.h @@ -0,0 +1,26 @@ +#ifndef _RANDOM_BALL_QUERY_CUDA_KERNEL +#define _RANDOM_BALL_QUERY_CUDA_KERNEL +#include +#include +#include + +void random_ball_query_cuda(int m, int nsample, + float min_radius, float max_radius, at::Tensor order_tensor, + at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, + at::Tensor offset_tensor, at::Tensor new_offset_tensor, + at::Tensor idx_tensor, at::Tensor dist2_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void random_ball_query_cuda_launcher(int m, int nsample, + float min_radius, float max_radius, const int *order, + const float *xyz, const float *new_xyz, + const int *offset, const int *new_offset, + int *idx, float *dist2); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/submodules/pointops/src/sampling/sampling_cuda.cpp b/submodules/pointops/src/sampling/sampling_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7dc8094c3343f874457fd23d1506b25fd006fd0b --- /dev/null +++ b/submodules/pointops/src/sampling/sampling_cuda.cpp @@ -0,0 +1,15 @@ +#include +#include +#include +#include 
"sampling_cuda_kernel.h" + + +void farthest_point_sampling_cuda(int b, int n, at::Tensor xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor tmp_tensor, at::Tensor idx_tensor) +{ + const float *xyz = xyz_tensor.data_ptr(); + const int *offset = offset_tensor.data_ptr(); + const int *new_offset = new_offset_tensor.data_ptr(); + float *tmp = tmp_tensor.data_ptr(); + int *idx = idx_tensor.data_ptr(); + farthest_point_sampling_cuda_launcher(b, n, xyz, offset, new_offset, tmp, idx); +} diff --git a/submodules/pointops/src/sampling/sampling_cuda_kernel.cu b/submodules/pointops/src/sampling/sampling_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..9a8676876672f68cd94913a0500d64813133b387 --- /dev/null +++ b/submodules/pointops/src/sampling/sampling_cuda_kernel.cu @@ -0,0 +1,171 @@ +#include "../cuda_utils.h" +#include "sampling_cuda_kernel.h" + + +__device__ void __update(float *dists, int *dists_i, int idx1, int idx2) { + const float v1 = dists[idx1], v2 = dists[idx2]; + const int i1 = dists_i[idx1], i2 = dists_i[idx2]; + dists[idx1] = max(v1, v2); + dists_i[idx1] = v2 > v1 ? 
i2 : i1; +} + +// input xyz: (n, 3), tmp: (b, n_max) +// output idx (m) +template +__global__ void farthest_point_sampling_cuda_kernel(const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx) +{ + __shared__ float dists[block_size]; + __shared__ int dists_i[block_size]; + + int bid = blockIdx.x; + int start_n, end_n, start_m, end_m, old; + if (bid == 0) { + start_n = 0; + end_n = offset[0]; + start_m = 0; + end_m = new_offset[0]; + old = 0; + } + else { + start_n = offset[bid - 1]; + end_n = offset[bid]; + start_m = new_offset[bid - 1]; + end_m = new_offset[bid]; + old = offset[bid - 1]; + } + + const int stride = block_size; + int tid = threadIdx.x; + if (tid == 0) idx[start_m] = start_n; // NOTE(review): written even when end_m == start_m - confirm every segment requests at least one sample + + __syncthreads(); + for (int j = start_m + 1; j < end_m; j++) + { + int besti = start_n; + float best = -1; + float x1 = xyz[old * 3 + 0]; // old is the point index picked in the previous round + float y1 = xyz[old * 3 + 1]; + float z1 = xyz[old * 3 + 2]; + for (int k = start_n + tid; k < end_n; k += stride) + { + float x2 = xyz[k * 3 + 0]; + float y2 = xyz[k * 3 + 1]; + float z2 = xyz[k * 3 + 2]; + float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); + float d2 = min(d, tmp[k]); // tmp[k] keeps the running minimum squared distance seen for point k + tmp[k] = d2; + besti = d2 > best ? k : besti; + best = d2 > best ? 
d2 : best; + } + dists[tid] = best; + dists_i[tid] = besti; + __syncthreads(); + + if (block_size >= 1024) { + if (tid < 512) { + __update(dists, dists_i, tid, tid + 512); + } + __syncthreads(); + } + if (block_size >= 512) { + if (tid < 256) { + __update(dists, dists_i, tid, tid + 256); + } + __syncthreads(); + } + if (block_size >= 256) { + if (tid < 128) { + __update(dists, dists_i, tid, tid + 128); + } + __syncthreads(); + } + if (block_size >= 128) { + if (tid < 64) { + __update(dists, dists_i, tid, tid + 64); + } + __syncthreads(); + } + if (block_size >= 64) { + if (tid < 32) { + __update(dists, dists_i, tid, tid + 32); + } + __syncthreads(); + } + if (block_size >= 32) { + if (tid < 16) { + __update(dists, dists_i, tid, tid + 16); + } + __syncthreads(); + } + if (block_size >= 16) { + if (tid < 8) { + __update(dists, dists_i, tid, tid + 8); + } + __syncthreads(); + } + if (block_size >= 8) { + if (tid < 4) { + __update(dists, dists_i, tid, tid + 4); + } + __syncthreads(); + } + if (block_size >= 4) { + if (tid < 2) { + __update(dists, dists_i, tid, tid + 2); + } + __syncthreads(); + } + if (block_size >= 2) { + if (tid < 1) { + __update(dists, dists_i, tid, tid + 1); + } + __syncthreads(); + } + + old = dists_i[0]; + if (tid == 0) + idx[j] = old; + } +} + +void farthest_point_sampling_cuda_launcher(int b, int n, const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx) +{ + unsigned int n_threads = opt_n_threads(n); + switch (n_threads) { + case 1024: + farthest_point_sampling_cuda_kernel<1024><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 512: + farthest_point_sampling_cuda_kernel<512><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 256: + farthest_point_sampling_cuda_kernel<256><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 128: + farthest_point_sampling_cuda_kernel<128><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 64: + farthest_point_sampling_cuda_kernel<64><<>>(xyz, offset, new_offset, tmp, 
idx); + break; + case 32: + farthest_point_sampling_cuda_kernel<32><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 16: + farthest_point_sampling_cuda_kernel<16><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 8: + farthest_point_sampling_cuda_kernel<8><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 4: + farthest_point_sampling_cuda_kernel<4><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 2: + farthest_point_sampling_cuda_kernel<2><<>>(xyz, offset, new_offset, tmp, idx); + break; + case 1: + farthest_point_sampling_cuda_kernel<1><<>>(xyz, offset, new_offset, tmp, idx); + break; + default: + farthest_point_sampling_cuda_kernel<512><<>>(xyz, offset, new_offset, tmp, idx); + } +} diff --git a/submodules/pointops/src/sampling/sampling_cuda_kernel.h b/submodules/pointops/src/sampling/sampling_cuda_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..f0e07607394a10b2b70c29f7497589d5edb8aab3 --- /dev/null +++ b/submodules/pointops/src/sampling/sampling_cuda_kernel.h @@ -0,0 +1,18 @@ +#ifndef _SAMPLING_CUDA_KERNEL +#define _SAMPLING_CUDA_KERNEL +#include +#include +#include + +void farthest_point_sampling_cuda(int b, int n, at::Tensor xyz_tensor, at::Tensor offset_tensor, at::Tensor new_offset_tensor, at::Tensor tmp_tensor, at::Tensor idx_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void farthest_point_sampling_cuda_launcher(int b, int n, const float *xyz, const int *offset, const int *new_offset, float *tmp, int *idx); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/submodules/pointops/src/subtraction/subtraction_cuda.cpp b/submodules/pointops/src/subtraction/subtraction_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b343857a1671eafe5199089973e863e2ac5b618c --- /dev/null +++ b/submodules/pointops/src/subtraction/subtraction_cuda.cpp @@ -0,0 +1,23 @@ +#include +#include +#include +#include "subtraction_cuda_kernel.h" + + +void subtraction_forward_cuda(int n, int 
nsample, int c, at::Tensor input1_tensor, at::Tensor input2_tensor, at::Tensor idx_tensor, at::Tensor output_tensor) +{ + const float *input1 = input1_tensor.data_ptr(); + const float *input2 = input2_tensor.data_ptr(); + const int *idx = idx_tensor.data_ptr(); + float *output = output_tensor.data_ptr(); + subtraction_forward_cuda_launcher(n, nsample, c, input1, input2, idx, output); +} + +void subtraction_backward_cuda(int n, int nsample, int c, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input1_tensor, at::Tensor grad_input2_tensor) +{ + const int *idx = idx_tensor.data_ptr(); + const float *grad_output = grad_output_tensor.data_ptr(); + float *grad_input1 = grad_input1_tensor.data_ptr(); + float *grad_input2 = grad_input2_tensor.data_ptr(); + subtraction_backward_cuda_launcher(n, nsample, c, idx, grad_output, grad_input1, grad_input2); +} diff --git a/submodules/pointops/src/subtraction/subtraction_cuda_kernel.cu b/submodules/pointops/src/subtraction/subtraction_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..9b8d4f752940d580ee2b49f1b2946a8d6386d11a --- /dev/null +++ b/submodules/pointops/src/subtraction/subtraction_cuda_kernel.cu @@ -0,0 +1,44 @@ +#include "../cuda_utils.h" +#include "subtraction_cuda_kernel.h" + + +__global__ void subtraction_forward_cuda_kernel(int n, int nsample, int c, const float *input1, const float *input2, const int *idx, float *output) { + // input: input1: (n, c), input2: (n, c), idx: (n, nsample), output: (n, nsample, c) + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= n * nsample * c) return; + const int c_idx = index % c; + const int nsample_idx = (index / c) % nsample; + const int n_idx = index / nsample / c; + const int idx_idx = n_idx * nsample + nsample_idx; + const int input1_idx = n_idx * c + c_idx; + const int input2_idx = idx[idx_idx] * c + c_idx; + output[index] = input1[input1_idx] - input2[input2_idx]; +} + +__global__ void 
// Launches subtraction_forward_cuda_kernel with one thread per output element.
// input: input1: (n, c), input2: (n, c), idx: (n, nsample), output: (n, nsample, c)
// DIVUP / THREADS_PER_BLOCK come from ../cuda_utils.h (presumably ceil-division
// and the block width -- not visible from here).
void subtraction_forward_cuda_launcher(int n, int nsample, int c, const float *input1, const float *input2, const int *idx, float *output) {
    const dim3 blocks(DIVUP(n * nsample * c, THREADS_PER_BLOCK));
    const dim3 threads(THREADS_PER_BLOCK);
    // A kernel launch needs an explicit <<<grid, block, shared_bytes>>> config.
    subtraction_forward_cuda_kernel<<<blocks, threads, 0>>>(n, nsample, c, input1, input2, idx, output);
}

// Launches subtraction_backward_cuda_kernel with one thread per grad_output element.
// input: grad_output: (n, nsample, c), output: grad_input1: (n, c), grad_input2: (n, c)
void subtraction_backward_cuda_launcher(int n, int nsample, int c, const int *idx, const float *grad_output, float *grad_input1, float *grad_input2) {
    const dim3 blocks(DIVUP(n * nsample * c, THREADS_PER_BLOCK));
    const dim3 threads(THREADS_PER_BLOCK);
    subtraction_backward_cuda_kernel<<<blocks, threads, 0>>>(n, nsample, c, idx, grad_output, grad_input1, grad_input2);
}
+#include +#include + +void subtraction_forward_cuda(int n, int nsample, int c, at::Tensor input1_tensor, at::Tensor input2_tensor, at::Tensor idx_tensor, at::Tensor output_tensor); +void subtraction_backward_cuda(int n, int nsample, int c, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input1_tensor, at::Tensor grad_input2_tensor); + +#ifdef __cplusplus +extern "C" { +#endif + +void subtraction_forward_cuda_launcher(int n, int nsample, int c, const float *input1, const float *input2, const int *idx, float *output); +void subtraction_backward_cuda_launcher(int n, int nsample, int c, const int *idx, const float *grad_output, float *grad_input1, float *grad_input2); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/submodules/pycolmap/.gitignore b/submodules/pycolmap/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..99677473f833855d40432ad737142de685716f23 --- /dev/null +++ b/submodules/pycolmap/.gitignore @@ -0,0 +1,4 @@ +*.pyc +*.sw* +*.egg-info +build/ diff --git a/submodules/pycolmap/LICENSE.txt b/submodules/pycolmap/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..5156d3d49e0c312561c59680658b6261f635abe3 --- /dev/null +++ b/submodules/pycolmap/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 True Price, UNC Chapel Hill + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/submodules/pycolmap/README.md b/submodules/pycolmap/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7f6769bb42537467eb155356012441e90b3c8920 --- /dev/null +++ b/submodules/pycolmap/README.md @@ -0,0 +1,18 @@ +# pycolmap +Python interface for COLMAP reconstructions, plus some convenient scripts for +loading/modifying/converting reconstructions. + +This code does not, however, run reconstruction -- it only provides a +convenient interface for handling COLMAP's output. + +## Installation + +The following works with `setuptools >= 62.0.0`. Run `python3 -m pip install +setuptools --upgrade` if necessary. + +``` +git clone https://github.com/rmbrualla/pycolmap.git +cd pycolmap +python3 -m pip install -e . 
def opencv_distortion(camera, x):
    """Apply the OPENCV (radial + tangential) distortion model.

    camera: object exposing the coefficients k1, k2 (radial) and p1, p2
        (tangential).
    x: array of shape (..., 2) holding normalized (x, y) image coordinates.

    Returns an array of the same shape with the distorted coordinates:

        x' = x * (1 + k1 r^2 + k2 r^4) + 2 p1 x y + p2 (r^2 + 2 x^2)
        y' = y * (1 + k1 r^2 + k2 r^4) + p1 (r^2 + 2 y^2) + 2 p2 x y
    """
    x = np.asarray(x)

    # keep a trailing axis of length 1 so the terms broadcast against x
    u = x[..., 0:1]
    v = x[..., 1:2]
    u_sq = np.square(u)
    v_sq = np.square(v)
    uv = u * v
    r_sq = u_sq + v_sq

    radial = 1. + r_sq * (camera.k1 + camera.k2 * r_sq)
    tangential = np.concatenate((
        2. * camera.p1 * uv + camera.p2 * (r_sq + 2. * u_sq),
        camera.p1 * (r_sq + 2. * v_sq) + 2. * camera.p2 * uv),
        axis=-1)

    return x * radial + tangential
self.fx, self.fy, self.cx, self.cy = params + self.distortion_func = None + self.camera_type = 1 + + elif type_ == 2 or type_ == 'SIMPLE_RADIAL': + self.fx, self.cx, self.cy, self.k1 = params + self.fy = self.fx + self.distortion_func = simple_radial_distortion + self.camera_type = 2 + + elif type_ == 3 or type_ == 'RADIAL': + self.fx, self.cx, self.cy, self.k1, self.k2 = params + self.fy = self.fx + self.distortion_func = radial_distortion + self.camera_type = 3 + + elif type_ == 4 or type_ == 'OPENCV': + self.fx, self.fy, self.cx, self.cy = params[:4] + self.k1, self.k2, self.p1, self.p2 = params[4:] + self.distortion_func = opencv_distortion + self.camera_type = 4 + + elif type_ == 5 or type_ == 'OPENCV_FISHEYE': + self.fx, self.fy, self.cx, self.cy = params[:4] + self.k1, self.k2, self.k3, self.k4 = params[4:] + def fn(camera, x): + raise Exception('Fisheye distortion not supported') + self.distortion_func = fn + self.camera_type = 5 + + else: + raise Exception('Camera type not supported') + + + #--------------------------------------------------------------------------- + + def __str__(self): + s = (self.GetNameFromType(self.camera_type) + + ' {} {} {}'.format(self.width, self.height, self.fx)) + + if self.camera_type in (1, 4): # PINHOLE, OPENCV + s += ' {}'.format(self.fy) + + s += ' {} {}'.format(self.cx, self.cy) + + if self.camera_type == 2: # SIMPLE_RADIAL + s += ' {}'.format(self.k1) + + elif self.camera_type == 3: # RADIAL + s += ' {} {}'.format(self.k1, self.k2) + + elif self.camera_type == 4: # OPENCV + s += ' {} {} {} {}'.format(self.k1, self.k2, self.p1, self.p2) + + elif self.camera_type == 5: # OPENCV_FISHEYE + s += ' {} {} {} {}'.format(self.k1, self.k2, self.k3, self.k4) + + return s + + + #--------------------------------------------------------------------------- + + # return the camera parameters in the same order as the colmap output format + def get_params(self): + if self.camera_type == 0: + return np.array((self.fx, self.cx, self.cy)) + 
def get_camera_matrix(self):
    """Return the 3x3 intrinsic matrix built from fx, fy, cx and cy."""
    row_x = (self.fx, 0, self.cx)
    row_y = (0, self.fy, self.cy)
    row_w = (0, 0, 1)
    return np.array((row_x, row_y, row_w))

def get_inverse_camera_matrix(self):
    """Return the closed-form inverse of the 3x3 intrinsic matrix."""
    row_x = (1. / self.fx, 0, -self.cx / self.fx)
    row_y = (0, 1. / self.fy, -self.cy / self.fy)
    row_w = (0, 0, 1)
    return np.array((row_x, row_y, row_w))
# x: array of shape (N,2) or (2,)
# normalized: False if the input points are in pixel coordinates
# denormalize: True if the points should be put back into pixel coordinates
def distort_points(self, x, normalized=True, denormalize=True):
    """Apply this camera's distortion model to 2D points.

    Returns a new (N, 2) array; the caller's array is never modified.
    Non-floating input (e.g. integer pixel coordinates) is promoted to
    float64, floating input keeps its dtype.
    """
    x = np.atleast_2d(x)

    # np.atleast_2d returns a view of an ndarray argument, so the in-place
    # arithmetic below would otherwise write through to the caller's data
    # (and fail outright on integer arrays). Work on a private float copy.
    dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
    x = x.astype(dtype, copy=True)

    # put the points into normalized camera coordinates
    if not normalized:
        x -= np.array([[self.cx, self.cy]])
        x /= np.array([[self.fx, self.fy]])

    # distort, if necessary
    if self.distortion_func is not None:
        x = self.distortion_func(self, x)

    if denormalize:
        x *= np.array([[self.fx, self.fy]])
        x += np.array([[self.cx, self.cy]])

    return x
#-------------------------------------------------------------------------------
# convert SQLite BLOBs to/from numpy arrays

def array_to_blob(arr):
    """Serialize a numpy array to bytes (C order) for storage as a SQLite BLOB.

    `np.getbuffer` only existed on Python 2; `tobytes()` is the Python 3
    equivalent and also copes with non-contiguous views.
    """
    return np.asarray(arr).tobytes()

def blob_to_array(blob, dtype, shape=(-1,)):
    """Reinterpret a BLOB as an array of `dtype`, reshaped to `shape`."""
    return np.frombuffer(blob, dtype).reshape(*shape)


#-------------------------------------------------------------------------------
# convert to/from image pair ids

MAX_IMAGE_ID = 2**31 - 1

def get_pair_id(image_id1, image_id2):
    """Encode an unordered image pair as one integer (smaller id first)."""
    if image_id1 > image_id2:
        image_id1, image_id2 = image_id2, image_id1
    return image_id1 * MAX_IMAGE_ID + image_id2


def get_image_ids_from_pair_id(pair_id):
    """Invert get_pair_id(): return (image_id1, image_id2) as integers.

    Uses integer division; true division would turn image_id1 into a float.
    """
    image_id1, image_id2 = divmod(pair_id, MAX_IMAGE_ID)
    return image_id1, image_id2
def add_camera(db, model, width, height, params, prior_focal_length=False,
               camera_id=None):
    """Insert one row into the `cameras` table.

    `params` is converted to float64 and stored as a BLOB; the remaining
    arguments are stored as given, in the table's column order.
    """
    # TODO: Parameter count checks
    params_blob = array_to_blob(np.asarray(params, np.float64))
    row = (camera_id, model, width, height, params_blob, prior_focal_length)
    db.execute("INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", row)


def add_descriptors(db, image_id, descriptors):
    """Insert the uint8 descriptor matrix of `image_id` into `descriptors`."""
    desc = np.ascontiguousarray(descriptors, np.uint8)
    row = (image_id,) + desc.shape + (array_to_blob(desc),)
    db.execute("INSERT INTO descriptors VALUES (?, ?, ?, ?)", row)
# config: defaults to fundamental matrix
def add_inlier_matches(db, image_id1, image_id2, matches, config=2, F=None,
                       E=None, H=None):
    """Insert geometrically verified matches for an image pair.

    matches: (M, 2) array of keypoint index pairs; the columns are swapped
        when image_id1 > image_id2 so they follow the pair-id ordering.
    F, E, H: optional matrices, stored as float64 BLOBs (NULL when None).

    Rows go into `two_view_geometries`, the table created by
    CREATE_INLIER_MATCHES_TABLE; no table named `inlier_matches` exists.
    """
    assert(len(matches.shape) == 2)
    assert(matches.shape[1] == 2)

    if image_id1 > image_id2:
        matches = matches[:,::-1]

    def _matrix_blob(mat):
        # store matrices the same way as every other array column
        if mat is None:
            return None
        return array_to_blob(np.asarray(mat, np.float64))

    pair_id = get_pair_id(image_id1, image_id2)
    # the column swap above yields a non-contiguous view; make it contiguous
    matches = np.ascontiguousarray(matches, np.uint32)
    db.execute(
        "INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
        (pair_id,) + matches.shape + (array_to_blob(matches), config,
         _matrix_blob(F), _matrix_blob(E), _matrix_blob(H)))


def add_keypoints(db, image_id, keypoints):
    """Insert the keypoints of `image_id` as a float32 (N, 2|4|6) BLOB."""
    assert(len(keypoints.shape) == 2)
    assert(keypoints.shape[1] in [2, 4, 6])

    keypoints = np.asarray(keypoints, np.float32)
    db.execute("INSERT INTO keypoints VALUES (?, ?, ?, ?)",
        (image_id,) + keypoints.shape + (array_to_blob(keypoints),))


def add_matches(db, image_id1, image_id2, matches):
    """Insert raw (unverified) matches for an image pair.

    matches: (M, 2) array of keypoint index pairs; the columns are swapped
        when image_id1 > image_id2 so they follow the pair-id ordering.
    """
    assert(len(matches.shape) == 2)
    assert(matches.shape[1] == 2)

    if image_id1 > image_id2:
        matches = matches[:,::-1]

    pair_id = get_pair_id(image_id1, image_id2)
    matches = np.ascontiguousarray(matches, np.uint32)
    db.execute("INSERT INTO matches VALUES (?, ?, ?, ?)",
        (pair_id,) + matches.shape + (array_to_blob(matches),))
def main(args):
    """Self-test: build a throw-away database, fill it, read it back, delete it.

    Refuses to run when `args.database_path` already exists.
    """
    import os

    if os.path.exists(args.database_path):
        print("Error: database path already exists -- will not modify it.")
        exit()

    db = COLMAPDatabase.connect(args.database_path)

    #
    # for convenience, try creating all the tables upfront
    #

    db.initialize_tables()


    #
    # create dummy cameras; ids are given explicitly so that the images below
    # reference rows that really exist (AUTOINCREMENT ids start at 1, not 0)
    #

    camera_id1, camera_id2 = 1, 2
    model1, w1, h1, params1 = 0, 1024, 768, np.array((1024., 512., 384.))
    model2, w2, h2, params2 = 2, 1024, 768, np.array((1024., 512., 384., 0.1))

    db.add_camera(model1, w1, h1, params1, camera_id=camera_id1)
    db.add_camera(model2, w2, h2, params2, camera_id=camera_id2)


    #
    # create dummy images
    #

    db.add_image("image1.png", camera_id1)
    db.add_image("image2.png", camera_id1)
    db.add_image("image3.png", camera_id2)
    db.add_image("image4.png", camera_id2)


    #
    # create dummy keypoints; note that COLMAP supports 2D keypoints (x, y),
    # 4D keypoints (x, y, theta, scale), and 6D affine keypoints
    # (x, y, a_11, a_12, a_21, a_22)
    #

    N = 1000
    kp1 = np.random.rand(N, 2) * (1024., 768.)
    kp2 = np.random.rand(N, 2) * (1024., 768.)
    kp3 = np.random.rand(N, 2) * (1024., 768.)
    kp4 = np.random.rand(N, 2) * (1024., 768.)

    db.add_keypoints(1, kp1)
    db.add_keypoints(2, kp2)
    db.add_keypoints(3, kp3)
    db.add_keypoints(4, kp4)


    #
    # create dummy matches
    #

    M = 50
    m12 = np.random.randint(N, size=(M, 2))
    m23 = np.random.randint(N, size=(M, 2))
    m34 = np.random.randint(N, size=(M, 2))

    db.add_matches(1, 2, m12)
    db.add_matches(2, 3, m23)
    db.add_matches(3, 4, m34)


    #
    # check cameras; add_camera stores params as float64, so they must be
    # decoded as float64 as well
    #

    rows = db.execute("SELECT * FROM cameras")

    camera_id, model, width, height, params, prior = next(rows)
    params = blob_to_array(params, np.float64)
    assert model == model1 and width == w1 and height == h1
    assert np.allclose(params, params1)

    camera_id, model, width, height, params, prior = next(rows)
    params = blob_to_array(params, np.float64)
    assert model == model2 and width == w2 and height == h2
    assert np.allclose(params, params2)


    #
    # check keypoints (stored as float32 by add_keypoints)
    #

    kps = dict(
        (image_id, blob_to_array(data, np.float32, (-1, 2)))
        for image_id, data in db.execute(
            "SELECT image_id, data FROM keypoints"))

    assert np.allclose(kps[1], kp1)
    assert np.allclose(kps[2], kp2)
    assert np.allclose(kps[3], kp3)
    assert np.allclose(kps[4], kp4)


    #
    # check matches
    #

    pair_ids = [get_pair_id(*pair) for pair in [(1, 2), (2, 3), (3, 4)]]

    matches = dict(
        (get_image_ids_from_pair_id(pair_id),
         blob_to_array(data, np.uint32, (-1, 2)))
        for pair_id, data in db.execute("SELECT pair_id, data FROM matches"))

    assert np.all(matches[(1, 2)] == m12)
    assert np.all(matches[(2, 3)] == m23)
    assert np.all(matches[(3, 4)] == m34)

    #
    # clean up
    #

    db.close()
    os.remove(args.database_path)
# www.euclideanspace.com/maths/geometry/rotations/conversions/angleToMatrix/
# if angle is None, assume ||axis|| == angle, in radians
# if angle is not None, assume that axis is a unit vector
def axis_angle_to_rotation_matrix(axis, angle=None):
    """Convert an axis-angle rotation to a 3x3 rotation matrix (Rodrigues).

    With `angle` omitted, the rotation angle is the length of `axis` and the
    axis is normalized here unless that length is below machine epsilon.
    """
    if angle is None:
        angle = np.linalg.norm(axis)
        has_direction = np.abs(angle) > np.finfo('float').eps
        if has_direction:
            axis = axis / angle

    skew = cross_prod_matrix(axis)
    sin_term = np.sin(angle) * skew
    cos_term = (1. - np.cos(angle)) * skew.dot(skew)
    return np.eye(3) + (sin_term + cos_term)
+ R[2,2] - R[0,0] - R[1,1]) + qw = (R[1,0] - R[0,1]) / s + qx = (R[0,2] + R[2,0]) / s + qy = (R[1,2] + R[2,1]) / s + qz = 0.25 * s + + return Quaternion(np.array((qw, qx, qy, qz))) + + # if angle is None, assume ||axis|| == angle, in radians + # if angle is not None, assume that axis is a unit vector + @staticmethod + def FromAxisAngle(axis, angle=None): + if angle is None: + angle = np.linalg.norm(axis) + if np.abs(angle) > np.finfo('float').eps: + axis = axis / angle + + qw = np.cos(0.5 * angle) + axis = axis * np.sin(0.5 * angle) + + return Quaternion(np.array((qw, axis[0], axis[1], axis[2]))) + + #--------------------------------------------------------------------------- + + def __init__(self, q=np.array((1., 0., 0., 0.))): + if isinstance(q, Quaternion): + self.q = q.q.copy() + else: + q = np.asarray(q) + if q.size == 4: + self.q = q.copy() + elif q.size == 3: # convert from a 3-vector to a quaternion + self.q = np.empty(4) + self.q[0], self.q[1:] = 0., q.ravel() + else: + raise Exception('Input quaternion should be a 3- or 4-vector') + + def __add__(self, other): + return Quaternion(self.q + other.q) + + def __iadd__(self, other): + self.q += other.q + return self + + # conjugation via the ~ operator + def __invert__(self): + return Quaternion( + np.array((self.q[0], -self.q[1], -self.q[2], -self.q[3]))) + + # returns: self.q * other.q if other is a Quaternion; otherwise performs + # scalar multiplication + def __mul__(self, other): + if isinstance(other, Quaternion): # quaternion multiplication + return Quaternion(np.array(( + self.q[0] * other.q[0] - self.q[1] * other.q[1] - + self.q[2] * other.q[2] - self.q[3] * other.q[3], + self.q[0] * other.q[1] + self.q[1] * other.q[0] + + self.q[2] * other.q[3] - self.q[3] * other.q[2], + self.q[0] * other.q[2] - self.q[1] * other.q[3] + + self.q[2] * other.q[0] + self.q[3] * other.q[1], + self.q[0] * other.q[3] + self.q[1] * other.q[2] - + self.q[2] * other.q[1] + self.q[3] * other.q[0]))) + else: # scalar 
multiplication (assumed) + return Quaternion(other * self.q) + + def __rmul__(self, other): + return self * other + + def __imul__(self, other): + self.q[:] = (self * other).q + return self + + def __irmul__(self, other): + self.q[:] = (self * other).q + return self + + def __neg__(self): + return Quaternion(-self.q) + + def __sub__(self, other): + return Quaternion(self.q - other.q) + + def __isub__(self, other): + self.q -= other.q + return self + + def __str__(self): + return str(self.q) + + def copy(self): + return Quaternion(self) + + def dot(self, other): + return self.q.dot(other.q) + + # assume the quaternion is nonzero! + def inverse(self): + return Quaternion((~self).q / self.q.dot(self.q)) + + def norm(self): + return np.linalg.norm(self.q) + + def normalize(self): + self.q /= np.linalg.norm(self.q) + return self + + # assume x is a Nx3 numpy array or a numpy 3-vector + def rotate_points(self, x): + x = np.atleast_2d(x) + return x.dot(self.ToR().T) + + # convert to a rotation matrix + def ToR(self): + return np.eye(3) + 2 * np.array(( + (-self.q[2] * self.q[2] - self.q[3] * self.q[3], + self.q[1] * self.q[2] - self.q[3] * self.q[0], + self.q[1] * self.q[3] + self.q[2] * self.q[0]), + ( self.q[1] * self.q[2] + self.q[3] * self.q[0], + -self.q[1] * self.q[1] - self.q[3] * self.q[3], + self.q[2] * self.q[3] - self.q[1] * self.q[0]), + ( self.q[1] * self.q[3] - self.q[2] * self.q[0], + self.q[2] * self.q[3] + self.q[1] * self.q[0], + -self.q[1] * self.q[1] - self.q[2] * self.q[2]))) + + # convert to axis-angle representation, with angle encoded by the length + def ToAxisAngle(self): + # recall that for axis-angle representation (a, angle), with "a" unit: + # q = (cos(angle/2), a * sin(angle/2)) + # below, for readability, "theta" actually means half of the angle + + sin_sq_theta = self.q[1:].dot(self.q[1:]) + + # if theta is non-zero, then we can compute a unique rotation + if np.abs(sin_sq_theta) > np.finfo('float').eps: + sin_theta = np.sqrt(sin_sq_theta) 
+ cos_theta = self.q[0] + + # atan2 is more stable, so we use it to compute theta + # note that we multiply by 2 to get the actual angle + angle = 2. * ( + np.arctan2(-sin_theta, -cos_theta) if cos_theta < 0. else + np.arctan2(sin_theta, cos_theta)) + + return self.q[1:] * (angle / sin_theta) + + # otherwise, the result is singular, and we avoid dividing by + # sin(angle/2) = 0 + return np.zeros(3) + + # euclideanspace.com/maths/geometry/rotations/conversions/quaternionToEuler + # this assumes the quaternion is non-zero + # returns yaw, pitch, roll, with application in that order + def ToEulerAngles(self): + qsq = self.q**2 + k = 2. * (self.q[0] * self.q[3] + self.q[1] * self.q[2]) / qsq.sum() + + if (1. - k) < np.finfo('float').eps: # north pole singularity + return 2. * np.arctan2(self.q[1], self.q[0]), 0.5 * np.pi, 0. + if (1. + k) < np.finfo('float').eps: # south pole singularity + return -2. * np.arctan2(self.q[1], self.q[0]), -0.5 * np.pi, 0. + + yaw = np.arctan2(2. * (self.q[0] * self.q[2] - self.q[1] * self.q[3]), + qsq[0] + qsq[1] - qsq[2] - qsq[3]) + pitch = np.arcsin(k) + roll = np.arctan2(2. 
* (self.q[0] * self.q[1] - self.q[2] * self.q[3]), + qsq[0] - qsq[1] + qsq[2] - qsq[3]) + + return yaw, pitch, roll + +#------------------------------------------------------------------------------- +# +# DualQuaternion +# +#------------------------------------------------------------------------------- + +class DualQuaternion: + # DualQuaternion from an existing rotation + translation + @staticmethod + def FromQT(q, t): + return DualQuaternion(qe=(0.5 * np.asarray(t))) * DualQuaternion(q) + + def __init__(self, q0=np.array((1., 0., 0., 0.)), qe=np.zeros(4)): + self.q0, self.qe = Quaternion(q0), Quaternion(qe) + + def __add__(self, other): + return DualQuaternion(self.q0 + other.q0, self.qe + other.qe) + + def __iadd__(self, other): + self.q0 += other.q0 + self.qe += other.qe + return self + + # conguation via the ~ operator + def __invert__(self): + return DualQuaternion(~self.q0, ~self.qe) + + def __mul__(self, other): + if isinstance(other, DualQuaternion): + return DualQuaternion( + self.q0 * other.q0, + self.q0 * other.qe + self.qe * other.q0) + elif isinstance(other, complex): # multiplication by a dual number + return DualQuaternion( + self.q0 * other.real, + self.q0 * other.imag + self.qe * other.real) + else: # scalar multiplication (assumed) + return DualQuaternion(other * self.q0, other * self.qe) + + def __rmul__(self, other): + return self.__mul__(other) + + def __imul__(self, other): + tmp = self * other + self.q0, self.qe = tmp.q0, tmp.qe + return self + + def __neg__(self): + return DualQuaternion(-self.q0, -self.qe) + + def __sub__(self, other): + return DualQuaternion(self.q0 - other.q0, self.qe - other.qe) + + def __isub__(self, other): + self.q0 -= other.q0 + self.qe -= other.qe + return self + + # q^-1 = q* / ||q||^2 + # assume that q0 is nonzero! + def inverse(self): + normsq = complex(q0.dot(q0), 2. * self.q0.q.dot(self.qe.q)) + inv_len_real = 1. 
/ normsq.real + return ~self * complex( + inv_len_real, -normsq.imag * inv_len_real * inv_len_real) + + # returns a complex representation of the real and imaginary parts of the norm + # assume that q0 is nonzero! + def norm(self): + q0_norm = self.q0.norm() + return complex(q0_norm, self.q0.dot(self.qe) / q0_norm) + + # assume that q0 is nonzero! + def normalize(self): + # current length is ||q0|| + eps * ( / ||q0||) + # writing this as a + eps * b, the inverse is + # 1/||q|| = 1/a - eps * b / a^2 + norm = self.norm() + inv_len_real = 1. / norm.real + self *= complex(inv_len_real, -norm.imag * inv_len_real * inv_len_real) + return self + + # return the translation vector for this dual quaternion + def getT(self): + return 2 * (self.qe * ~self.q0).q[1:] + + def ToQT(self): + return self.q0, self.getT() diff --git a/submodules/pycolmap/pycolmap/scene_manager.py b/submodules/pycolmap/pycolmap/scene_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..b552ccf09a8e6590e179286b885feb480e09ca13 --- /dev/null +++ b/submodules/pycolmap/pycolmap/scene_manager.py @@ -0,0 +1,676 @@ +# Author: True Price + +import array +import numpy as np +import os +import struct + +from collections import OrderedDict +from itertools import combinations + +from .camera import Camera +from .image import Image +from .rotation import Quaternion + +#------------------------------------------------------------------------------- +# +# SceneManager +# +#------------------------------------------------------------------------------- + +class SceneManager: + + # Safe definition that works across NumPy versions + try: + INVALID_POINT3D = np.uint64(-1) + except OverflowError: + # NumPy versions (like 2.3) raise; fall back to the max of uint64 + INVALID_POINT3D = np.uint64(np.iinfo(np.uint64).max) + + def __init__(self, colmap_results_folder, image_path=None): + self.folder = colmap_results_folder + if not self.folder.endswith('/'): + self.folder += '/' + + self.image_path = 
None + self.load_colmap_project_file(image_path=image_path) + + self.cameras = OrderedDict() + self.images = OrderedDict() + self.name_to_image_id = dict() + + self.last_camera_id = 0 + self.last_image_id = 0 + + # Nx3 array of point3D xyz's + self.points3D = np.zeros((0, 3)) + + # for each element in points3D, stores the id of the point + self.point3D_ids = np.empty(0) + + # point3D_id => index in self.points3D + self.point3D_id_to_point3D_idx = dict() + + # point3D_id => [(image_id, point2D idx in image)] + self.point3D_id_to_images = dict() + + self.point3D_colors = np.zeros((0, 3), dtype=np.uint8) + self.point3D_errors = np.zeros(0) + + #--------------------------------------------------------------------------- + + def load_colmap_project_file(self, project_file=None, image_path=None): + if project_file is None: + project_file = self.folder + 'project.ini' + + self.image_path = image_path + + if self.image_path is None: + try: + with open(project_file, 'r') as f: + for line in iter(f.readline, ''): + if line.startswith('image_path'): + self.image_path = line[11:].strip() + break + except: + pass + + if self.image_path is None: + print('Warning: image_path not found for reconstruction') + elif not self.image_path.endswith('/'): + self.image_path += '/' + + #--------------------------------------------------------------------------- + + def load(self): + self.load_cameras() + self.load_images() + self.load_points3D() + + #--------------------------------------------------------------------------- + + def load_cameras(self, input_file=None): + if input_file is None: + input_file = self.folder + 'cameras.bin' + if os.path.exists(input_file): + self._load_cameras_bin(input_file) + else: + input_file = self.folder + 'cameras.txt' + if os.path.exists(input_file): + self._load_cameras_txt(input_file) + else: + raise IOError('no cameras file found') + + def _load_cameras_bin(self, input_file): + self.cameras = OrderedDict() + + with open(input_file, 'rb') as f: + 
num_cameras = struct.unpack('L', f.read(8))[0] + + for _ in range(num_cameras): + camera_id, camera_type, w, h = struct.unpack('IiLL', f.read(24)) + num_params = Camera.GetNumParams(camera_type) + params = struct.unpack('d' * num_params, f.read(8 * num_params)) + self.cameras[camera_id] = Camera(camera_type, w, h, params) + self.last_camera_id = max(self.last_camera_id, camera_id) + + def _load_cameras_txt(self, input_file): + self.cameras = OrderedDict() + + with open(input_file, 'r') as f: + for line in iter(lambda: f.readline().strip(), ''): + if not line or line.startswith('#'): + continue + + data = line.split() + camera_id = int(data[0]) + self.cameras[camera_id] = Camera( + data[1], int(data[2]), int(data[3]), map(float, data[4:])) + self.last_camera_id = max(self.last_camera_id, camera_id) + + #--------------------------------------------------------------------------- + + def load_images(self, input_file=None): + if input_file is None: + input_file = self.folder + 'images.bin' + if os.path.exists(input_file): + self._load_images_bin(input_file) + else: + input_file = self.folder + 'images.txt' + if os.path.exists(input_file): + self._load_images_txt(input_file) + else: + raise IOError('no images file found') + + def _load_images_bin(self, input_file): + self.images = OrderedDict() + + with open(input_file, 'rb') as f: + num_images = struct.unpack('L', f.read(8))[0] + image_struct = struct.Struct('7x improvements in 60 image model, 23s -> 3s. 
+ points_array = array.array('d') + points_array.fromfile(f, 3 * num_points2D) + points_elements = np.array(points_array).reshape((num_points2D, 3)) + image.points2D = points_elements[:, :2] + + ids_array = array.array('Q') + ids_array.frombytes(points_elements[:, 2].tobytes()) + image.point3D_ids = np.array(ids_array, dtype=np.uint64).reshape( + (num_points2D,)) + + # automatically remove points without an associated 3D point + #mask = (image.point3D_ids != SceneManager.INVALID_POINT3D) + #image.points2D = image.points2D[mask] + #image.point3D_ids = image.point3D_ids[mask] + + self.images[image_id] = image + self.name_to_image_id[image.name] = image_id + + self.last_image_id = max(self.last_image_id, image_id) + + def _load_images_txt(self, input_file): + self.images = OrderedDict() + + with open(input_file, 'r') as f: + is_camera_description_line = False + + for line in iter(lambda: f.readline().strip(), ''): + if not line or line.startswith('#'): + continue + + is_camera_description_line = not is_camera_description_line + + data = line.split() + + if is_camera_description_line: + image_id = int(data[0]) + image = Image(data[-1], int(data[-2]), + Quaternion(np.array(map(float, data[1:5]))), + np.array(map(float, data[5:8]))) + else: + image.points2D = np.array( + [map(float, data[::3]), map(float, data[1::3])]).T + image.point3D_ids = np.array(map(np.uint64, data[2::3])) + + # automatically remove points without an associated 3D point + #mask = (image.point3D_ids != SceneManager.INVALID_POINT3D) + #image.points2D = image.points2D[mask] + #image.point3D_ids = image.point3D_ids[mask] + + self.images[image_id] = image + self.name_to_image_id[image.name] = image_id + + self.last_image_id = max(self.last_image_id, image_id) + + #--------------------------------------------------------------------------- + + def load_points3D(self, input_file=None): + if input_file is None: + input_file = self.folder + 'points3D.bin' + if os.path.exists(input_file): + 
self._load_points3D_bin(input_file) + else: + input_file = self.folder + 'points3D.txt' + if os.path.exists(input_file): + self._load_points3D_txt(input_file) + else: + raise IOError('no points3D file found') + + def _load_points3D_bin(self, input_file): + with open(input_file, 'rb') as f: + num_points3D = struct.unpack('L', f.read(8))[0] + + self.points3D = np.empty((num_points3D, 3)) + self.point3D_ids = np.empty(num_points3D, dtype=np.uint64) + self.point3D_colors = np.empty((num_points3D, 3), dtype=np.uint8) + self.point3D_id_to_point3D_idx = dict() + self.point3D_id_to_images = dict() + self.point3D_errors = np.empty(num_points3D) + + data_struct = struct.Struct('>fid, '# Camera list with one line of data per camera:' + print>>fid, '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]' + print>>fid, '# Number of cameras:', len(self.cameras) + + for camera_id, camera in sorted(self.cameras.iteritems()): + print>>fid, camera_id, camera + + #--------------------------------------------------------------------------- + + def save_images(self, output_folder, output_file=None, binary=True): + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + if output_file is None: + output_file = 'images.bin' if binary else 'images.txt' + + output_file = os.path.join(output_folder, output_file) + + if binary: + self._save_images_bin(output_file) + else: + self._save_images_txt(output_file) + + def _save_images_bin(self, output_file): + with open(output_file, 'wb') as fid: + fid.write(struct.pack('L', len(self.images))) + + for image_id, image in self.images.iteritems(): + fid.write(struct.pack('I', image_id)) + fid.write(image.q.q.tobytes()) + fid.write(image.tvec.tobytes()) + fid.write(struct.pack('I', image.camera_id)) + fid.write(image.name + '\0') + fid.write(struct.pack('L', len(image.points2D))) + data = np.rec.fromarrays( + (image.points2D[:,0], image.points2D[:,1], image.point3D_ids)) + fid.write(data.tobytes()) + + def _save_images_txt(self, output_file): + 
with open(output_file, 'w') as fid: + print>>fid, '# Image list with two lines of data per image:' + print>>fid, '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME' + print>>fid, '# POINTS2D[] as (X, Y, POINT3D_ID)' + print>>fid, '# Number of images: {},'.format(len(self.images)), + print>>fid, 'mean observations per image: unknown' + + for image_id, image in self.images.iteritems(): + print>>fid, image_id, + print>>fid, ' '.join(str(qi) for qi in image.q.q), + print>>fid, ' '.join(str(ti) for ti in image.tvec), + print>>fid, image.camera_id, image.name + + data = np.rec.fromarrays( + (image.points2D[:,0], image.points2D[:,1], + image.point3D_ids.astype(np.int64))) + if len(data) > 0: + np.savetxt(fid, data, '%.2f %.2f %d', newline=' ') + fid.seek(-1, os.SEEK_CUR) + fid.write('\n') + + #--------------------------------------------------------------------------- + + def save_points3D(self, output_folder, output_file=None, binary=True): + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + if output_file is None: + output_file = 'points3D.bin' if binary else 'points3D.txt' + + output_file = os.path.join(output_folder, output_file) + + if binary: + self._save_points3D_bin(output_file) + else: + self._save_points3D_txt(output_file) + + def _save_points3D_bin(self, output_file): + num_valid_points3D = sum( + 1 for point3D_idx in self.point3D_id_to_point3D_idx.itervalues() + if point3D_idx != SceneManager.INVALID_POINT3D) + + iter_point3D_id_to_point3D_idx = \ + self.point3D_id_to_point3D_idx.iteritems() + + with open(output_file, 'wb') as fid: + fid.write(struct.pack('L', num_valid_points3D)) + + for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx: + if point3D_idx == SceneManager.INVALID_POINT3D: + continue + + fid.write(struct.pack('L', point3D_id)) + fid.write(self.points3D[point3D_idx].tobytes()) + fid.write(self.point3D_colors[point3D_idx].tobytes()) + fid.write(self.point3D_errors[point3D_idx].tobytes()) + fid.write( + 
struct.pack('L', len(self.point3D_id_to_images[point3D_id]))) + fid.write(self.point3D_id_to_images[point3D_id].tobytes()) + + def _save_points3D_txt(self, output_file): + num_valid_points3D = sum( + 1 for point3D_idx in self.point3D_id_to_point3D_idx.itervalues() + if point3D_idx != SceneManager.INVALID_POINT3D) + + array_to_string = lambda arr: ' '.join(str(x) for x in arr) + + iter_point3D_id_to_point3D_idx = \ + self.point3D_id_to_point3D_idx.iteritems() + + with open(output_file, 'w') as fid: + print>>fid, '# 3D point list with one line of data per point:' + print>>fid, '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as ', + print>>fid, '(IMAGE_ID, POINT2D_IDX)' + print>>fid, '# Number of points: {},'.format(num_valid_points3D), + print>>fid, 'mean track length: unknown' + + for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx: + if point3D_idx == SceneManager.INVALID_POINT3D: + continue + + print>>fid, point3D_id, + print>>fid, array_to_string(self.points3D[point3D_idx]), + print>>fid, array_to_string(self.point3D_colors[point3D_idx]), + print>>fid, self.point3D_errors[point3D_idx], + print>>fid, array_to_string( + self.point3D_id_to_images[point3D_id].flat) + + #--------------------------------------------------------------------------- + + # return the image id associated with a given image file + def get_image_from_name(self, image_name): + image_id = self.name_to_image_id[image_name] + return image_id, self.images[image_id] + + #--------------------------------------------------------------------------- + + def get_camera(self, camera_id): + return self.cameras[camera_id] + + #--------------------------------------------------------------------------- + + def get_points3D(self, image_id, return_points2D=True, return_colors=False): + image = self.images[image_id] + + mask = (image.point3D_ids != SceneManager.INVALID_POINT3D) + + point3D_idxs = np.array([ + self.point3D_id_to_point3D_idx[point3D_id] + for point3D_id in image.point3D_ids[mask]]) + # 
detect filtered points + filter_mask = (point3D_idxs != SceneManager.INVALID_POINT3D) + point3D_idxs = point3D_idxs[filter_mask] + result = [self.points3D[point3D_idxs,:]] + + if return_points2D: + mask[mask] &= filter_mask + result += [image.points2D[mask]] + if return_colors: + result += [self.point3D_colors[point3D_idxs,:]] + + return result if len(result) > 1 else result[0] + + #--------------------------------------------------------------------------- + + def point3D_valid(self, point3D_id): + return (self.point3D_id_to_point3D_idx[point3D_id] != + SceneManager.INVALID_POINT3D) + + #--------------------------------------------------------------------------- + + def get_filtered_points3D(self, return_colors=False): + point3D_idxs = [ + idx for idx in self.point3D_id_to_point3D_idx.values() + if idx != SceneManager.INVALID_POINT3D] + result = [self.points3D[point3D_idxs,:]] + + if return_colors: + result += [self.point3D_colors[point3D_idxs,:]] + + return result if len(result) > 1 else result[0] + + #--------------------------------------------------------------------------- + + # return 3D points shared by two images + def get_shared_points3D(self, image_id1, image_id2): + point3D_ids = ( + set(self.images[image_id1].point3D_ids) & + set(self.images[image_id2].point3D_ids)) + point3D_ids.discard(SceneManager.INVALID_POINT3D) + + point3D_idxs = np.array([self.point3D_id_to_point3D_idx[point3D_id] + for point3D_id in point3D_ids]) + + return self.points3D[point3D_idxs,:] + + #--------------------------------------------------------------------------- + + # project *all* 3D points into image, return their projection coordinates, + # as well as their 3D positions + def get_viewed_points(self, image_id): + image = self.images[image_id] + + # get unfiltered points + point3D_idxs = set(self.point3D_id_to_point3D_idx.itervalues()) + point3D_idxs.discard(SceneManager.INVALID_POINT3D) + point3D_idxs = list(point3D_idxs) + points3D = self.points3D[point3D_idxs,:] + + # 
orient points relative to camera + R = image.q.ToR() + points3D = points3D.dot(R.T) + image.tvec[np.newaxis,:] + points3D = points3D[points3D[:,2] > 0,:] # keep points with positive z + + # put points into image coordinates + camera = self.cameras[image.camera_id] + points2D = points3D.dot(camera.get_camera_matrix().T) + points2D = points2D[:,:2] / points2D[:,2][:,np.newaxis] + + # keep points that are within the image + mask = ( + (points2D[:,0] >= 0) & + (points2D[:,1] >= 0) & + (points2D[:,0] < camera.width - 1) & + (points2D[:,1] < camera.height - 1)) + + return points2D[mask,:], points3D[mask,:] + + #--------------------------------------------------------------------------- + + def add_camera(self, camera): + self.last_camera_id += 1 + self.cameras[self.last_camera_id] = camera + return self.last_camera_id + + #--------------------------------------------------------------------------- + + def add_image(self, image): + self.last_image_id += 1 + self.images[self.last_image_id] = image + return self.last_image_id + + #--------------------------------------------------------------------------- + + def delete_images(self, image_list): + # delete specified images + for image_id in image_list: + if image_id in self.images: + del self.images[image_id] + + keep_set = set(self.images.iterkeys()) + + # delete references to specified images, and ignore any points that are + # invalidated + iter_point3D_id_to_point3D_idx = \ + self.point3D_id_to_point3D_idx.iteritems() + + for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx: + if point3D_idx == SceneManager.INVALID_POINT3D: + continue + + mask = np.array([ + image_id in keep_set + for image_id in self.point3D_id_to_images[point3D_id][:,0]]) + if np.any(mask): + self.point3D_id_to_images[point3D_id] = \ + self.point3D_id_to_images[point3D_id][mask] + else: + self.point3D_id_to_point3D_idx[point3D_id] = \ + SceneManager.INVALID_POINT3D + + 
#--------------------------------------------------------------------------- + + # camera_list: set of cameras whose points we'd like to keep + # min/max triangulation angle: in degrees + def filter_points3D(self, + min_track_len=0, max_error=np.inf, min_tri_angle=0, + max_tri_angle=180, image_set=set()): + + image_set = set(image_set) + + check_triangulation_angles = (min_tri_angle > 0 or max_tri_angle < 180) + if check_triangulation_angles: + max_tri_prod = np.cos(np.radians(min_tri_angle)) + min_tri_prod = np.cos(np.radians(max_tri_angle)) + + iter_point3D_id_to_point3D_idx = \ + self.point3D_id_to_point3D_idx.iteritems() + + image_ids = [] + + for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx: + if point3D_idx == SceneManager.INVALID_POINT3D: + continue + + if image_set or min_track_len > 0: + image_ids = set(self.point3D_id_to_images[point3D_id][:,0]) + + # check if error and min track length are sufficient, or if none of + # the selected cameras see the point + if (len(image_ids) < min_track_len or + self.point3D_errors[point3D_idx] > max_error or + image_set and image_set.isdisjoint(image_ids)): + self.point3D_id_to_point3D_idx[point3D_id] = \ + SceneManager.INVALID_POINT3D + + # find dot product between all camera viewing rays + elif check_triangulation_angles: + xyz = self.points3D[point3D_idx,:] + tvecs = np.array( + [(self.images[image_id].tvec - xyz) + for image_id in image_ids]) + tvecs /= np.linalg.norm(tvecs, axis=-1)[:,np.newaxis] + + cos_theta = np.array( + [u.dot(v) for u,v in combinations(tvecs, 2)]) + + # min_prod = cos(maximum viewing angle), and vice versa + # if maximum viewing angle is too small or too large, + # don't add this point + if (np.min(cos_theta) > max_tri_prod or + np.max(cos_theta) < min_tri_prod): + self.point3D_id_to_point3D_idx[point3D_id] = \ + SceneManager.INVALID_POINT3D + + # apply the filters to the image point3D_ids + for image in self.images.itervalues(): + mask = np.array([ + 
self.point3D_id_to_point3D_idx.get(point3D_id, 0) \ + == SceneManager.INVALID_POINT3D + for point3D_id in image.point3D_ids]) + image.point3D_ids[mask] = SceneManager.INVALID_POINT3D + + #--------------------------------------------------------------------------- + + # scene graph: {image_id: [image_id: #shared points]} + def build_scene_graph(self): + self.scene_graph = defaultdict(lambda: defaultdict(int)) + point3D_iter = self.point3D_id_to_images.iteritems() + + for i, (point3D_id, images) in enumerate(point3D_iter): + if not self.point3D_valid(point3D_id): + continue + + for image_id1, image_id2 in combinations(images[:,0], 2): + self.scene_graph[image_id1][image_id2] += 1 + self.scene_graph[image_id2][image_id1] += 1 diff --git a/submodules/pycolmap/pyproject.toml b/submodules/pycolmap/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..610d833fc084bc48e114c36f52f2a9e3b99945a6 --- /dev/null +++ b/submodules/pycolmap/pyproject.toml @@ -0,0 +1,19 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "pycolmap" +version = "0.0.1" +dependencies = [ + "numpy", + "scipy", +] +authors = [ + {name = "Ricardo Martin-Brualla"}, +] +description = "A Python API for reading and writing COLMAP-generated files." 
+ +license = {file = "LICENSE.txt"} +readme = "README.md" + diff --git a/submodules/pycolmap/tools/colmap_to_nvm.py b/submodules/pycolmap/tools/colmap_to_nvm.py new file mode 100644 index 0000000000000000000000000000000000000000..38a4ba8ed4b51523e1525abf57331993301adfda --- /dev/null +++ b/submodules/pycolmap/tools/colmap_to_nvm.py @@ -0,0 +1,69 @@ +import itertools +import sys +sys.path.append("..") + +import numpy as np + +from pycolmap import Quaternion, SceneManager + + +#------------------------------------------------------------------------------- + +def main(args): + scene_manager = SceneManager(args.input_folder) + scene_manager.load() + + with open(args.output_file, "w") as fid: + fid.write("NVM_V3\n \n{:d}\n".format(len(scene_manager.images))) + + image_fmt_str = " {:.3f} " + 7 * "{:.7f} " + for image_id, image in scene_manager.images.iteritems(): + camera = scene_manager.cameras[image.camera_id] + f = 0.5 * (camera.fx + camera.fy) + fid.write(args.image_name_prefix + image.name) + fid.write(image_fmt_str.format( + *((f,) + tuple(image.q.q) + tuple(image.C())))) + if camera.distortion_func is None: + fid.write("0 0\n") + else: + fid.write("{:.7f} 0\n".format(-camera.k1)) + + image_id_to_idx = dict( + (image_id, i) for i, image_id in enumerate(scene_manager.images)) + + fid.write("{:d}\n".format(len(scene_manager.points3D))) + for i, point3D_id in enumerate(scene_manager.point3D_ids): + fid.write( + "{:.7f} {:.7f} {:.7f} ".format(*scene_manager.points3D[i])) + fid.write( + "{:d} {:d} {:d} ".format(*scene_manager.point3D_colors[i])) + keypoints = [ + (image_id_to_idx[image_id], kp_idx) + + tuple(scene_manager.images[image_id].points2D[kp_idx]) + for image_id, kp_idx in + scene_manager.point3D_id_to_images[point3D_id]] + fid.write("{:d}".format(len(keypoints))) + fid.write( + (len(keypoints) * " {:d} {:d} {:.3f} {:.3f}" + "\n").format( + *itertools.chain(*keypoints))) + + +#------------------------------------------------------------------------------- + 
+if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Save a COLMAP reconstruction in the NVM format " + "(http://ccwu.me/vsfm/doc.html#nvm).", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("input_folder") + parser.add_argument("output_file") + + parser.add_argument("--image_name_prefix", type=str, default="", + help="prefix image names with this string (e.g., 'images/')") + + args = parser.parse_args() + + main(args) diff --git a/submodules/pycolmap/tools/delete_images.py b/submodules/pycolmap/tools/delete_images.py new file mode 100644 index 0000000000000000000000000000000000000000..f17a84a8a2fa842283c032eeffbce4c5a8de37db --- /dev/null +++ b/submodules/pycolmap/tools/delete_images.py @@ -0,0 +1,36 @@ +import sys +sys.path.append("..") + +import numpy as np + +from pycolmap import DualQuaternion, Image, SceneManager + + +#------------------------------------------------------------------------------- + +def main(args): + scene_manager = SceneManager(args.input_folder) + scene_manager.load() + + image_ids = map(scene_manager.get_image_from_name, + iter(lambda: sys.stdin.readline().strip(), "")) + scene_manager.delete_images(image_ids) + + scene_manager.save(args.output_folder) + + +#------------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Deletes images (filenames read from stdin) from a model.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("input_folder") + parser.add_argument("output_folder") + + args = parser.parse_args() + + main(args) diff --git a/submodules/pycolmap/tools/impute_missing_cameras.py b/submodules/pycolmap/tools/impute_missing_cameras.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff8f322b7919f24a9b8652221dd09a0f0b16d5b --- /dev/null +++ 
b/submodules/pycolmap/tools/impute_missing_cameras.py @@ -0,0 +1,180 @@ +import sys +sys.path.append("..") + +import numpy as np + +from pycolmap import DualQuaternion, Image, SceneManager + + +#------------------------------------------------------------------------------- + +image_to_idx = lambda im: int(im.name[:im.name.rfind(".")]) + + +#------------------------------------------------------------------------------- + +def interpolate_linear(images, camera_id, file_format): + if len(images) < 2: + raise ValueError("Need at least two images for linear interpolation!") + + prev_image = images[0] + prev_idx = image_to_idx(prev_image) + prev_dq = DualQuaternion.FromQT(prev_image.q, prev_image.t) + start = prev_idx + + new_images = [] + + for image in images[1:]: + curr_idx = image_to_idx(image) + curr_dq = DualQuaternion.FromQT(image.q, image.t) + T = curr_idx - prev_idx + Tinv = 1. / T + + # like quaternions, dq(x) = -dq(x), so we'll need to pick the one more + # appropriate for interpolation by taking -dq if the dot product of the + # two q-vectors is negative + if prev_dq.q0.dot(curr_dq.q0) < 0: + curr_dq = -curr_dq + + for i in xrange(1, T): + t = i * Tinv + dq = t * prev_dq + (1. - t) * curr_dq + q, t = dq.ToQT() + new_images.append( + Image(file_format.format(prev_idx + i), args.camera_id, q, t)) + + prev_idx = curr_idx + prev_dq = curr_dq + + return new_images + + +#------------------------------------------------------------------------------- + +def interpolate_hermite(images, camera_id, file_format): + if len(images) < 4: + raise ValueError( + "Need at least four images for Hermite spline interpolation!") + + new_images = [] + + # linear blending for the first frames + T0 = image_to_idx(images[0]) + dq0 = DualQuaternion.FromQT(images[0].q, images[0].t) + T1 = image_to_idx(images[1]) + dq1 = DualQuaternion.FromQT(images[1].q, images[1].t) + + if dq0.q0.dot(dq1.q0) < 0: + dq1 = -dq1 + dT = 1. 
/ float(T1 - T0) + for j in xrange(1, T1 - T0): + t = j * dT + dq = ((1. - t) * dq0 + t * dq1).normalize() + new_images.append( + Image(file_format.format(T0 + j), camera_id, *dq.ToQT())) + + T2 = image_to_idx(images[2]) + dq2 = DualQuaternion.FromQT(images[2].q, images[2].t) + if dq1.q0.dot(dq2.q0) < 0: + dq2 = -dq2 + + # Hermite spline interpolation of dual quaternions + # pdfs.semanticscholar.org/05b1/8ede7f46c29c2722fed3376d277a1d286c55.pdf + for i in xrange(1, len(images) - 2): + T3 = image_to_idx(images[i + 2]) + dq3 = DualQuaternion.FromQT(images[i + 2].q, images[i + 2].t) + if dq2.q0.dot(dq3.q0) < 0: + dq3 = -dq3 + + prev_duration = T1 - T0 + current_duration = T2 - T1 + next_duration = T3 - T2 + + # approximate the derivatives at dq1 and dq2 using weighted central + # differences + dt1 = 1. / float(T2 - T0) + dt2 = 1. / float(T3 - T1) + + m1 = (current_duration * dt1) * (dq2 - dq1) + \ + (prev_duration * dt1) * (dq1 - dq0) + m2 = (next_duration * dt2) * (dq3 - dq2) + \ + (current_duration * dt2) * (dq2 - dq1) + + dT = 1. / float(current_duration) + + for j in xrange(1, current_duration): + t = j * dT # 0 to 1 + t2 = t * t # t squared + t3 = t2 * t # t cubed + + # coefficients of the Hermite spline (a=>dq and b=>m) + a1 = 2. * t3 - 3. * t2 + 1. + b1 = t3 - 2. * t2 + t + a2 = -2. * t3 + 3. * t2 + b2 = t3 - t2 + + dq = (a1 * dq1 + b1 * m1 + a2 * dq2 + b2 * m2).normalize() + + new_images.append( + Image(file_format.format(T1 + j), camera_id, *dq.ToQT())) + + T0, T1, T2 = T1, T2, T3 + dq0, dq1, dq2 = dq1, dq2, dq3 + + # linear blending for the last frames + dT = 1. / float(T2 - T1) + for j in xrange(1, T2 - T1): + t = j * dT # 0 to 1 + dq = ((1. 
def main(args):
    """Impute poses for missing frames and save the augmented model.

    Loads the model in args.input_folder, orders its images by the integer
    encoded in their filename, interpolates the missing frames with the
    method selected by args.method ("linear", otherwise Hermite), adds the
    new images to the scene and saves it to args.output_folder.
    """
    scene_manager = SceneManager(args.input_folder)
    scene_manager.load()

    images = sorted(scene_manager.images.values(), key=image_to_idx)

    if args.method.lower() == "linear":
        new_images = interpolate_linear(images, args.camera_id, args.format)
    else:
        new_images = interpolate_hermite(images, args.camera_id, args.format)

    # Explicit loop: map() is lazy under Python 3, so relying on it for its
    # side effects would silently add nothing.
    for new_image in new_images:
        scene_manager.add_image(new_image)

    scene_manager.save(args.output_folder)
# Saves the cameras as a mesh: one small pyramid (5 vertices, 6 triangles) per
# camera, written as an ASCII PLY file.
#
# inputs:
# - ply_file: output file
# - images: ordered array of pycolmap Image objects
# - scale: amount to shrink/grow the camera model
#
# The vertex color is derived from the position of the image in `images`, so
# the ordering of the cameras is visible in the output.
def save_camera_ply(ply_file, images, scale):
    # apex at the camera center, base one unit in front of it
    points3D = scale * np.array((
        (0., 0., 0.),
        (-1., -1., 1.),
        (-1., 1., 1.),
        (1., -1., 1.),
        (1., 1., 1.)))

    faces = np.array(((0, 2, 1),
                      (0, 4, 2),
                      (0, 3, 4),
                      (0, 1, 3),
                      (1, 2, 4),
                      (1, 4, 3)))

    r = np.linspace(0, 255, len(images), dtype=np.uint8)
    g = 255 - r
    b = r - np.linspace(0, 128, len(images), dtype=np.uint8)
    color = np.column_stack((r, g, b))

    header = (
        "ply",
        "format ascii 1.0",
        "element vertex {}".format(len(points3D) * len(images)),
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
        "element face {}".format(len(faces) * len(images)),
        "property list uchar int vertex_index",
        "end_header",
    )

    # file.write() instead of the Python 2-only "print>>fid" statement, so the
    # tool runs under both Python 2 and Python 3
    with open(ply_file, "w") as fid:
        fid.write("\n".join(header) + "\n")

        for image, c in zip(images, color):
            for p3D in (points3D.dot(image.R()) + image.C()):
                fid.write("{} {} {} {} {} {}\n".format(
                    p3D[0], p3D[1], p3D[2], c[0], c[1], c[2]))

        # each camera owns len(points3D) consecutive vertices
        for i in range(len(images)):
            for f in (faces + len(points3D) * i):
                fid.write("3 {} {} {}\n".format(*f))
def main(args):
    """Apply a 3x4 transformation (read from stdin) to a COLMAP model.

    The matrix is read row-major as three lines of four numbers each. The 3D
    points are transformed directly; the image poses are updated with the
    rotation part of the matrix and its (assumed isotropic) scale. The result
    is saved to args.output_folder.

    Raises ValueError if stdin does not provide three rows of four numbers.
    """
    scene_manager = SceneManager(args.input_folder)
    scene_manager.load()

    # expect each line of input corresponds to one row
    rows = [[float(value) for value in sys.stdin.readline().split()]
            for _ in range(3)]
    if any(len(row) != 4 for row in rows):
        raise ValueError(
            "Expected three rows of four numbers (a 3x4 matrix) on stdin.")
    P = np.array(rows)

    scene_manager.points3D[:] = scene_manager.points3D.dot(P[:,:3].T) + P[:,3]

    # get rotation without any global scaling (assuming isotropic scaling)
    scale = np.cbrt(np.linalg.det(P[:,:3]))
    q_old_from_new = ~Quaternion.FromR(P[:,:3] / scale)

    for image in scene_manager.images.values():
        image.q *= q_old_from_new
        image.tvec = scale * image.tvec - image.R().dot(P[:,3])

    scene_manager.save(args.output_folder)
def main(args):
    """Write the camera track of a COLMAP model in the Bundler format.

    Two files are produced: args.output_file with the Bundler header and one
    camera entry per image (no 3D points), and args.output_file + ".list.txt"
    with the image filenames in the same order. With args.sort set, images
    are ordered by filename; otherwise the scene manager's order is used.
    """
    scene_manager = SceneManager(args.input_folder)
    scene_manager.load_cameras()
    scene_manager.load_images()

    if args.sort:
        images = sorted(
            scene_manager.images.values(), key=lambda im: im.name)
    else:
        images = list(scene_manager.images.values())

    # "with" closes both files even if writing an entry fails
    with open(args.output_file, "w") as fid, \
            open(args.output_file + ".list.txt", "w") as fid_filenames:
        fid.write("# Bundle file v0.3\n")
        # <number of cameras> <number of points>; points are not exported
        fid.write("{} {}\n".format(len(images), 0))

        for image in images:
            fid_filenames.write("{}\n".format(image.name))
            camera = scene_manager.cameras[image.camera_id]
            # single focal length (mean of fx and fy), both distortion
            # coefficients written as zero
            fid.write("{} {} {}\n".format(
                0.5 * (camera.fx + camera.fy), 0, 0))
            R, t = image.R(), image.t
            # the second and third rows of R and components of t are negated
            fid.write("{} {} {}\n".format(R[0, 0], R[0, 1], R[0, 2]))
            fid.write("{} {} {}\n".format(-R[1, 0], -R[1, 1], -R[1, 2]))
            fid.write("{} {} {}\n".format(-R[2, 0], -R[2, 1], -R[2, 2]))
            fid.write("{} {} {}\n".format(t[0], -t[1], -t[2]))
def _read_map_header(fid):
    """Read the ASCII "width&height&channels&" prefix of a binary map file.

    Leaves `fid` positioned at the first byte after the header and returns
    (width, height, channels). Raises ValueError if the file ends before the
    header is complete.
    """
    dims = []
    for _ in range(3):
        digits = b""
        while True:
            # the file is opened in binary mode, so compare against bytes;
            # comparing with the text "&" never matches under Python 3
            ch = fid.read(1)
            if ch == b"&":
                break
            if not ch:
                raise ValueError("Unexpected end of file in map header")
            digits += ch
        dims.append(int(digits))
    return tuple(dims)


def main(args):
    """Convert one depth map of a dense workspace into a colored PLY cloud.

    Reads the image, its depth map (photometric or geometric) and optionally
    its normal map from args.dense_folder, back-projects every pixel using the
    camera of the sparse model, optionally moves the result to world space,
    and writes the vertices to args.output_filename.
    """
    suffix = ".photometric.bin" if args.photometric else ".geometric.bin"

    image_file = os.path.join(args.dense_folder, "images", args.image_filename)
    depth_file = os.path.join(
        args.dense_folder, args.stereo_folder, "depth_maps",
        args.image_filename + suffix)
    if args.save_normals:
        normals_file = os.path.join(
            args.dense_folder, args.stereo_folder, "normal_maps",
            args.image_filename + suffix)

    # load camera intrinsics from the COLMAP reconstruction
    scene_manager = SceneManager(os.path.join(args.dense_folder, "sparse"))
    scene_manager.load_cameras()
    scene_manager.load_images()

    image_id, image = scene_manager.get_image_from_name(args.image_filename)
    camera = scene_manager.cameras[image.camera_id]
    rotation_camera_from_world = image.R()
    camera_center = image.C()

    # load image, depth map, and normal map
    image = imageio.imread(image_file)

    with open(depth_file, "rb") as fid:
        w, h, c = _read_map_header(fid)
        depth_map = np.fromfile(fid, np.float32).reshape(h, w)
    if (h, w) != image.shape[:2]:
        # nearest-neighbor resampling (order=0) to the image resolution
        depth_map = zoom(
            depth_map,
            (float(image.shape[0]) / h, float(image.shape[1]) / w),
            order=0)

    if args.save_normals:
        with open(normals_file, "rb") as fid:
            w, h, c = _read_map_header(fid)
            # stored channel-major; reorder to (h, w, c)
            normals = np.fromfile(
                fid, np.float32).reshape(c, h, w).transpose([1, 2, 0])
        if (h, w) != image.shape[:2]:
            normals = zoom(
                normals,
                (float(image.shape[0]) / h, float(image.shape[1]) / w, 1.),
                order=0)

    if args.min_depth is not None:
        depth_map[depth_map < args.min_depth] = 0.
    if args.max_depth is not None:
        depth_map[depth_map > args.max_depth] = 0.

    # create 3D points
    #depth_map = np.minimum(depth_map, 100.)
    points3D = np.dstack(camera.get_image_grid() + [depth_map])
    points3D[:,:,:2] *= depth_map[:,:,np.newaxis]

    # save
    points3D = points3D.astype(np.float32).reshape(-1, 3)
    if args.save_normals:
        normals = normals.astype(np.float32).reshape(-1, 3)
    image = image.reshape(-1, 3)
    if image.dtype != np.uint8:
        if image.max() <= 1:
            image = (image * 255.).astype(np.uint8)
        else:
            image = image.astype(np.uint8)

    if args.world_space:
        points3D = points3D.dot(rotation_camera_from_world) + camera_center
        if args.save_normals:
            normals = normals.dot(rotation_camera_from_world)

    if args.save_normals:
        vertices = np.rec.fromarrays(
            tuple(points3D.T) + tuple(normals.T) + tuple(image.T),
            names="x,y,z,nx,ny,nz,red,green,blue")
    else:
        vertices = np.rec.fromarrays(
            tuple(points3D.T) + tuple(image.T), names="x,y,z,red,green,blue")
    vertices = PlyElement.describe(vertices, "vertex")
    PlyData([vertices]).write(args.output_filename)
= argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("dense_folder", type=str) + parser.add_argument("image_filename", type=str) + parser.add_argument("output_filename", type=str) + + parser.add_argument( + "--photometric", default=False, action="store_true", + help="use photometric depthmap instead of geometric") + + parser.add_argument( + "--world_space", default=False, action="store_true", + help="apply the camera->world extrinsic transformation to the result") + + parser.add_argument( + "--save_normals", default=False, action="store_true", + help="load the estimated normal map and save as part of the PLY") + + parser.add_argument( + "--stereo_folder", type=str, default="stereo", + help="folder in the dense workspace containing depth and normal maps") + + parser.add_argument( + "--min_depth", type=float, default=None, + help="set pixels with depth less than this value to zero depth") + + parser.add_argument( + "--max_depth", type=float, default=None, + help="set pixels with depth greater than this value to zero depth") + + args = parser.parse_args() + + main(args) diff --git a/submodules/simple-knn/LICENSE.md b/submodules/simple-knn/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..18445c6d34aedbf1ab9d282223f8f10ce38cd79a --- /dev/null +++ b/submodules/simple-knn/LICENSE.md @@ -0,0 +1,91 @@ +Gaussian-Splatting License +=========================== + +**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. +The *Software* is in the process of being registered with the Agence pour la Protection des +Programmes (APP). + +The *Software* is still being developed by the *Licensor*. + +*Licensor*'s goal is to allow the research community to use, test and evaluate +the *Software*. + +## 1. Definitions + +*Licensee* means any person or entity that uses the *Software* and distributes +its *Work*. 
+ +*Licensor* means the owners of the *Software*, i.e Inria and MPII + +*Software* means the original work of authorship made available under this +License ie gaussian-splatting. + +*Work* means the *Software* and any additions to or derivative works of the +*Software* that are made available under this License. + + +## 2. Purpose +This license is intended to define the rights granted to the *Licensee* by +Licensors under the *Software*. + +## 3. Rights granted + +For the above reasons Licensors have decided to distribute the *Software*. +Licensors grant non-exclusive rights to use the *Software* for research purposes +to research users (both academic and industrial), free of charge, without right +to sublicense.. The *Software* may be used "non-commercially", i.e., for research +and/or evaluation purposes only. + +Subject to the terms and conditions of this License, you are granted a +non-exclusive, royalty-free, license to reproduce, prepare derivative works of, +publicly display, publicly perform and distribute its *Work* and any resulting +derivative works in any form. + +## 4. Limitations + +**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do +so under this License, (b) you include a complete copy of this License with +your distribution, and (c) you retain without modification any copyright, +patent, trademark, or attribution notices that are present in the *Work*. + +**4.2 Derivative Works.** You may specify that additional or different terms apply +to the use, reproduction, and distribution of your derivative works of the *Work* +("Your Terms") only if (a) Your Terms provide that the use limitation in +Section 2 applies to your derivative works, and (b) you identify the specific +derivative works that are subject to Your Terms. Notwithstanding Your Terms, +this License (including the redistribution requirements in Section 3.1) will +continue to apply to the *Work* itself. 
+ +**4.3** Any other use without of prior consent of Licensors is prohibited. Research +users explicitly acknowledge having received from Licensors all information +allowing to appreciate the adequacy between of the *Software* and their needs and +to undertake all necessary precautions for its execution and use. + +**4.4** The *Software* is provided both as a compiled library file and as source +code. In case of using the *Software* for a publication or other results obtained +through the use of the *Software*, users are strongly encouraged to cite the +corresponding publications as explained in the documentation of the *Software*. + +## 5. Disclaimer + +THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES +WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY +UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL +CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES +OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL +USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR +ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE +AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR +IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. + +## 6. Files subject to permissive licenses +The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license. 
+ +Title: pytorch-ssim\ +Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\ +Copyright Evan Su, 2017\ +License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT) \ No newline at end of file diff --git a/submodules/simple-knn/ext.cpp b/submodules/simple-knn/ext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ae6cefe6ce61a38352a88d07b69a8e6cb9de5b31 --- /dev/null +++ b/submodules/simple-knn/ext.cpp @@ -0,0 +1,17 @@ +/* + * Copyright (C) 2023, Inria + * GRAPHDECO research group, https://team.inria.fr/graphdeco + * All rights reserved. + * + * This software is free for non-commercial, research and evaluation use + * under the terms of the LICENSE.md file. + * + * For inquiries contact george.drettakis@inria.fr + */ + +#include +#include "spatial.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("distCUDA2", &distCUDA2); +} diff --git a/submodules/simple-knn/setup.py b/submodules/simple-knn/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..580d2bd8dc190ce642d87501d53de4f6d9d46c64 --- /dev/null +++ b/submodules/simple-knn/setup.py @@ -0,0 +1,35 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import os

# Extra flags passed to the host C++ compiler ("cxx" entry below); empty
# unless a platform-specific flag is added.
cxx_compiler_flags = []

# On Windows (os.name == 'nt') add an MSVC-style "/wd<number>" flag.
# NOTE(review): presumably this silences MSVC warning C4624 -- confirm.
if os.name == 'nt':
    cxx_compiler_flags.append("/wd4624")

# Builds the "simple_knn._C" CUDA extension from the two .cu files and the
# pybind wrapper; no extra nvcc flags are passed.
setup(
    name="simple_knn",
    ext_modules=[
        CUDAExtension(
            name="simple_knn._C",
            sources=[
            "spatial.cu",
            "simple_knn.cu",
            "ext.cpp"],
            extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags})
        ],
    cmdclass={
        'build_ext': BuildExtension
    }
)
// Spreads the bits of x apart so that two zero bits separate consecutive
// input bits: the final mask 0x09249249 keeps only bit positions 0, 3, 6, ...,
// 27, i.e. ten bits at every third position. Three such values, shifted by
// 0, 1 and 2 bits, can then be OR-ed together without overlapping.
// The caller is expected to pass a value that fits in 10 bits.
__host__ __device__ uint32_t prepMorton(uint32_t x)
{
	// Each step moves the upper part of the remaining bit groups further left
	// and masks away the bits that ended up between the target positions.
	x = (x | (x << 16)) & 0x030000FF;
	x = (x | (x << 8)) & 0x0300F00F;
	x = (x | (x << 4)) & 0x030C30C3;
	x = (x | (x << 2)) & 0x09249249;
	return x;
}
// Squared distance from point p to the axis-aligned box [box.minn, box.maxx].
// Per axis, the contribution is zero when p lies inside the box's extent on
// that axis; otherwise it is the distance to the nearer of the two faces.
// The result is 0 for points inside the box, and is used by the caller to
// skip boxes that cannot contain a closer neighbor.
__device__ __host__ float distBoxPoint(const MinMax& box, const float3& p)
{
	float3 diff = { 0, 0, 0 };
	if (p.x < box.minn.x || p.x > box.maxx.x)
		diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x));
	if (p.y < box.minn.y || p.y > box.maxx.y)
		diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y));
	if (p.z < box.minn.z || p.z > box.maxx.z)
		diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z));
	// squared, to match the squared point-to-point distances it is compared to
	return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z;
}
// Computes, for each of the P points, the mean of the squared distances to its
// three nearest neighbors and stores it in meanDists (indexed like points).
//
// P:         number of points; nothing is done when P <= 0.
// points:    device pointer to P float3 positions.
// meanDists: device pointer to P floats receiving the result.
//
// Steps: bounding box of all points -> Morton code per point -> sort point
// indices by Morton code -> bounding box per run of BOX_SIZE sorted points ->
// per-point neighbor search that skips boxes that are too far away.
void SimpleKNN::knn(int P, float3* points, float* meanDists)
{
	// An empty input would otherwise launch kernels with zero blocks, which is
	// an invalid launch configuration.
	if (P <= 0)
		return;

	// Device-side slot for the reduction result. Owned by a device_vector so
	// it is released even if a later thrust allocation throws.
	thrust::device_vector<float3> result(1);
	float3* result_ptr = result.data().get();
	size_t temp_storage_bytes = 0;

	// Reduction seeds must be the identity of the operator. Seeding both
	// reductions with the origin would force the bounding box to contain
	// (0,0,0) and waste Morton resolution for scenes away from the origin.
	const float3 min_init = { FLT_MAX, FLT_MAX, FLT_MAX };
	const float3 max_init = { -FLT_MAX, -FLT_MAX, -FLT_MAX };
	float3 minn, maxx;

	// First call only queries the required temporary storage size.
	cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result_ptr, P, CustomMin(), min_init);
	thrust::device_vector<char> temp_storage(temp_storage_bytes);

	cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result_ptr, P, CustomMin(), min_init);
	cudaMemcpy(&minn, result_ptr, sizeof(float3), cudaMemcpyDeviceToHost);

	cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result_ptr, P, CustomMax(), max_init);
	cudaMemcpy(&maxx, result_ptr, sizeof(float3), cudaMemcpyDeviceToHost);

	thrust::device_vector<uint32_t> morton(P);
	thrust::device_vector<uint32_t> morton_sorted(P);
	coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get());

	thrust::device_vector<uint32_t> indices(P);
	thrust::sequence(indices.begin(), indices.end());
	thrust::device_vector<uint32_t> indices_sorted(P);

	// Size query, then the actual sort of (morton code, point index) pairs.
	cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);
	temp_storage.resize(temp_storage_bytes);

	cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);

	// One box (and one thread block) per BOX_SIZE consecutive sorted points.
	uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE;
	thrust::device_vector<MinMax> boxes(num_boxes);
	boxMinMax << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get());
	boxMeanDist << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists);
}
#ifndef SIMPLEKNN_H_INCLUDED
#define SIMPLEKNN_H_INCLUDED

// Entry point of the nearest-neighbor distance computation implemented in
// simple_knn.cu. The class only groups the static function; it has no state.
class SimpleKNN
{
public:
	// P:         number of points.
	// points:    P positions; passed on to CUDA kernels, so this must be a
	//            device pointer.
	// meanDists: P output values, one per point, written by a CUDA kernel
	//            (device pointer as well).
	// NOTE(review): float3 is not declared by this header; it relies on the
	// including translation unit providing the CUDA vector types.
	static void knn(int P, float3* points, float* meanDists);
};

#endif
// Returns a float32 tensor of length P (P = points.size(0)) on the same device
// as `points`, filled by SimpleKNN::knn with one value per point.
//
// points: CUDA float32 tensor holding three floats per row, i.e. exactly
//         3 * P elements; it is reinterpreted as an array of float3.
// Raises a c10::Error (via TORCH_CHECK) when these requirements are not met,
// instead of reading out of bounds or dereferencing host memory on the device.
torch::Tensor
distCUDA2(const torch::Tensor& points)
{
	TORCH_CHECK(points.is_cuda(), "distCUDA2: points must be a CUDA tensor");
	TORCH_CHECK(points.dim() >= 1, "distCUDA2: points must have at least one dimension");
	TORCH_CHECK(points.size(0) <= INT_MAX, "distCUDA2: too many points");

	const int P = points.size(0);
	TORCH_CHECK(points.numel() == static_cast<int64_t>(P) * 3,
		"distCUDA2: points must hold exactly 3 values per point");

	auto float_opts = points.options().dtype(torch::kFloat32);
	torch::Tensor means = torch::full({P}, 0.0, float_opts);

	// Nothing to compute; also avoids launching kernels with zero blocks.
	if (P == 0)
		return means;

	// Keep the contiguous copy alive in a named tensor for the duration of the
	// call. data_ptr<float>() also verifies that the dtype is float32.
	const torch::Tensor points_contiguous = points.contiguous();
	SimpleKNN::knn(
		P,
		reinterpret_cast<float3*>(points_contiguous.data_ptr<float>()),
		means.data_ptr<float>());

	return means;
}