diff --git a/.gitattributes b/.gitattributes index d9dffcfb45bbc505e0f1b9ec34891c79c0f36835..edf4ed9ffbb55204021ba00f4370972a0b393943 100644 --- a/.gitattributes +++ b/.gitattributes @@ -3,3 +3,4 @@ *.safetensors filter=lfs diff=lfs merge=lfs -text *.bin filter=lfs diff=lfs merge=lfs -text examples/omnirooms/*.jpg filter=lfs diff=lfs merge=lfs -text +unisharp/cli/__pycache__/unified_trainer.cpython-313.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/unisharp/.DS_Store b/unisharp/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..055ffb11016a90800e37af080af58b6e38d0a6c7 Binary files /dev/null and b/unisharp/.DS_Store differ diff --git a/unisharp/__init__.py b/unisharp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c75ef965cced1c45a6a3492f0568c29d124d9cc2 --- /dev/null +++ b/unisharp/__init__.py @@ -0,0 +1 @@ +DEFAULT_MAX_DEPTH_M: float = 100.0 diff --git a/unisharp/cli/__init__.py b/unisharp/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ef3954e8d406b1da9a93d1f25cd46d2b42f96c9 --- /dev/null +++ b/unisharp/cli/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import click + +from .train_feature import train_feature_cli + + +@click.group() +def main_cli(): + pass + +main_cli.add_command(train_feature_cli, "train-feature") + diff --git a/unisharp/cli/__main__.py b/unisharp/cli/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0d3a6224dd60eae9a80c302209e094d58141f91 --- /dev/null +++ b/unisharp/cli/__main__.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from unisharp.cli import main_cli + + +def main() -> None: + main_cli() + + +if __name__ == "__main__": + main() + diff --git a/unisharp/cli/__pycache__/__init__.cpython-310.pyc b/unisharp/cli/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f8b4b9d76fb17ae015848a4d1af504f930522dd Binary files 
/dev/null and b/unisharp/cli/__pycache__/__init__.cpython-310.pyc differ diff --git a/unisharp/cli/__pycache__/__init__.cpython-313.pyc b/unisharp/cli/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab4f2fbcf1c24623999a9d75042787d0b9cd0e2e Binary files /dev/null and b/unisharp/cli/__pycache__/__init__.cpython-313.pyc differ diff --git a/unisharp/cli/__pycache__/mixed_sampler.cpython-313.pyc b/unisharp/cli/__pycache__/mixed_sampler.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5e45628b7eb6d894c3d3119183baea2985c48d8 Binary files /dev/null and b/unisharp/cli/__pycache__/mixed_sampler.cpython-313.pyc differ diff --git a/unisharp/cli/__pycache__/train_feature.cpython-310.pyc b/unisharp/cli/__pycache__/train_feature.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8318530c07046fd5c96c7a29359b1d606f94f49 Binary files /dev/null and b/unisharp/cli/__pycache__/train_feature.cpython-310.pyc differ diff --git a/unisharp/cli/__pycache__/train_feature.cpython-313.pyc b/unisharp/cli/__pycache__/train_feature.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..611e10e6962d4a01fb2282612353d4c355497d7b Binary files /dev/null and b/unisharp/cli/__pycache__/train_feature.cpython-313.pyc differ diff --git a/unisharp/cli/__pycache__/train_utils.cpython-313.pyc b/unisharp/cli/__pycache__/train_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9b1b5d58d79057072ebcab66641d7ac5e00c3aa Binary files /dev/null and b/unisharp/cli/__pycache__/train_utils.cpython-313.pyc differ diff --git a/unisharp/cli/__pycache__/unified_trainer.cpython-313.pyc b/unisharp/cli/__pycache__/unified_trainer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a403ef4465911fb0638ebceb6b7897149441dca7 --- /dev/null +++ b/unisharp/cli/__pycache__/unified_trainer.cpython-313.pyc 
class LazyDataLoaderIterator:
    """Iterator wrapper that postpones ``iter(dataloader)`` until first use.

    The wrapped object is only stored at construction time. The underlying
    iterator is created on the first ``next()`` call and reused afterwards.
    """

    def __init__(self, dataloader: Any):
        self.dataloader = dataloader
        # Created on the first ``__next__`` call.
        self.iterator: Iterator[Any] | None = None

    def __iter__(self) -> "LazyDataLoaderIterator":
        # Completes the iterator protocol so the wrapper can be used in
        # ``for`` loops and passed to ``iter()``; it is its own iterator.
        return self

    def __next__(self) -> Any:
        if self.iterator is None:
            self.iterator = iter(self.dataloader)
        return next(self.iterator)


class MixedDatasetSampler:
    """Pick a dataset by weighted random choice and pull its next batch.

    Args:
        datasets: Mapping from dataset name to dataset object. It may contain
            names that have no entry in ``weights``; those are never sampled.
        weights: Strictly positive sampling weight per dataset name. Every
            name must also exist in ``datasets`` and ``iterators``.
        iterators: Mapping from dataset name to the iterator batches are
            pulled from.
        seed: Seed for the private ``random.Random`` instance.

    Raises:
        ValueError: If ``weights`` is empty, a weight is not > 0, or a
            weighted name is missing from ``datasets`` or ``iterators``.
    """

    def __init__(
        self,
        datasets: dict[str, Dataset | IterableDataset],
        weights: dict[str, float],
        iterators: dict[str, Iterator[Any]],
        seed: int | None = None,
    ):
        self.datasets = datasets
        self.weights = weights
        self.iterators = iterators
        self._rng = random.Random(seed)

        if not weights:
            raise ValueError("weights is empty")
        for name, w in weights.items():
            if float(w) <= 0.0:
                raise ValueError(f"Dataset weight must be > 0, got {name}={float(w)}")
            if name not in datasets:
                raise ValueError(f"Unknown dataset in weights: {name}")
            if name not in iterators:
                raise ValueError(f"Missing iterator for dataset: {name}")

        total_weight = float(sum(float(v) for v in weights.values()))
        self.probs = {name: float(w) / total_weight for name, w in weights.items()}
        # Only weighted datasets can be sampled. Walking ``datasets`` keeps the
        # ordering used before (identical when both mappings share the same
        # keys) while no longer raising KeyError for a dataset without weight.
        self.dataset_names = [name for name in datasets if name in self.probs]
        self.prob_list = [self.probs[name] for name in self.dataset_names]

    def sample(self) -> tuple[str, Any]:
        """Choose a dataset at random and return ``(name, next batch)``."""
        dataset_name = self.choose_dataset_name()
        batch = self.next_batch(dataset_name)
        return dataset_name, batch

    def choose_dataset_name(self, allowed_dataset_names: list[str] | None = None) -> str:
        """Return one dataset name drawn according to the configured weights.

        Args:
            allowed_dataset_names: Optional whitelist. Only sampleable names
                contained in it are candidates; their probabilities are used
                as relative weights.

        Raises:
            ValueError: If the whitelist leaves no candidate.
        """
        if allowed_dataset_names is None:
            names = self.dataset_names
            probs = self.prob_list
        else:
            # Build the lookup set once instead of once per candidate name.
            allowed = set(allowed_dataset_names)
            names = [name for name in self.dataset_names if name in allowed]
            if not names:
                raise ValueError("No allowed dataset names available for sampling.")
            probs = [self.probs[name] for name in names]
        return self._rng.choices(names, weights=probs, k=1)[0]

    def next_batch(self, dataset_name: str) -> Any:
        """Return the next item from the iterator registered for ``dataset_name``.

        Raises:
            ValueError: If no iterator is registered under ``dataset_name``.
            StopIteration: If that iterator is exhausted (message names the
                dataset, original exception chained as the cause).
        """
        if dataset_name not in self.iterators:
            raise ValueError(f"Unknown dataset iterator: {dataset_name}")
        try:
            batch = next(self.iterators[dataset_name])
        except StopIteration as exc:
            raise StopIteration(f"Dataset {dataset_name} exhausted") from exc
        return batch

    def get_sampling_stats(self) -> dict[str, dict[str, float]]:
        """Return shallow copies of the normalized probabilities and raw weights."""
        return {
            "probabilities": self.probs.copy(),
            "sampling": self.weights.copy(),
        }
unisharp.datasets.sim_panorama import SimPanoramaDataset +from unisharp.datasets.panogs import PanOGSDataset, panogs_collate +from unisharp.losses import UnisharpLoss, UnisharpLossWeights +from unisharp.models.unisharp_feature import UnisharpFeatureModel, UnisharpFeatureConfig +from unisharp.utils import logging as logging_utils +from unisharp import DEFAULT_MAX_DEPTH_M +from unisharp.utils.gsplat import GSplatRenderer +from unisharp.utils.io import save_image +from unisharp.utils.rayfit_camera import scale_pinhole_intrinsics +from unisharp.utils.unified_vis import save_pair_visualization + +from .mixed_sampler import LazyDataLoaderIterator, MixedDatasetSampler # type: ignore[import] +from .train_utils import warmup_cosine_lr # type: ignore[import] + +LOGGER = logging.getLogger(__name__) +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def _default_dataset_manifest_file(name: str) -> Path: + parent_path = REPO_ROOT.parent / "dataset_manifests" / name + if parent_path.exists(): + return parent_path + return REPO_ROOT / "dataset_manifests" / name + + +DEFAULT_WILDRGBD_ROOTS_FILE = _default_dataset_manifest_file("wildrgbd_roots.txt") + + +def _multiple_aligned_hw(hw: tuple[int, int], multiple: int) -> tuple[int, int]: + h, w = int(hw[0]), int(hw[1]) + m = int(multiple) + if m <= 1: + return h, w + out_h = max(m, (h // m) * m) + out_w = max(m, (w // m) * m) + return min(out_h, h), min(out_w, w) + + +def _erp_multiple_aligned_hw(hw: tuple[int, int], multiple: int) -> tuple[int, int]: + h, w = int(hw[0]), int(hw[1]) + m = int(multiple) + if m <= 1: + return h, w + max_h_from_h = h // m + max_h_from_w = w // (2 * m) + h_units = min(max_h_from_h, max_h_from_w) + if h_units <= 0: + return h, w + out_h = h_units * m + return out_h, 2 * out_h + + +def _resize_chw_tensor(x: torch.Tensor, dst_hw: tuple[int, int], *, kind: str) -> torch.Tensor: + if not torch.is_tensor(x) or x.ndim < 3: + return x + src_hw = (int(x.shape[-2]), int(x.shape[-1])) + if src_hw == tuple(int(v) 
for v in dst_hw): + return x + orig_dtype = x.dtype + flat = x.reshape(-1, int(x.shape[-3]), src_hw[0], src_hw[1]).to(dtype=torch.float32) + if kind == "image": + y = F.interpolate(flat, size=dst_hw, mode="bilinear", align_corners=False) + y = y.round().clamp(0.0, 255.0).to(dtype=orig_dtype) if orig_dtype == torch.uint8 else y.to(dtype=orig_dtype) + elif kind == "ray": + y = F.interpolate(flat, size=dst_hw, mode="bilinear", align_corners=False) + y = y / torch.linalg.vector_norm(y, dim=1, keepdim=True).clamp(min=1e-6) + y = y.to(dtype=orig_dtype) + else: + y = F.interpolate(flat, size=dst_hw, mode="nearest").to(dtype=orig_dtype) + return y.reshape(*x.shape[:-2], int(dst_hw[0]), int(dst_hw[1])).contiguous() + + +def _resize_cube_tensor(x: torch.Tensor, dst_hw: tuple[int, int], *, kind: str) -> torch.Tensor: + if not torch.is_tensor(x) or x.ndim < 4: + return x + src_hw = (int(x.shape[-3]), int(x.shape[-2])) + if src_hw == tuple(int(v) for v in dst_hw): + return x + orig_dtype = x.dtype + channels = int(x.shape[-1]) + flat = x.reshape(-1, src_hw[0], src_hw[1], channels).permute(0, 3, 1, 2).to(dtype=torch.float32) + if kind == "image": + y = F.interpolate(flat, size=dst_hw, mode="bilinear", align_corners=False) + y = y.round().clamp(0.0, 255.0).to(dtype=orig_dtype) if orig_dtype == torch.uint8 else y.to(dtype=orig_dtype) + else: + y = F.interpolate(flat, size=dst_hw, mode="nearest").to(dtype=orig_dtype) + y = y.permute(0, 2, 3, 1) + return y.reshape(*x.shape[:-3], int(dst_hw[0]), int(dst_hw[1]), channels).contiguous() + + +def _training_batch_src_hw(batch: Any) -> tuple[int, int] | None: + for name in ("src_rgb_u8", "src_erp_rgb_u8"): + value = getattr(batch, name, None) + if torch.is_tensor(value) and value.ndim >= 3: + return int(value.shape[-2]), int(value.shape[-1]) + return None + + +def _scale_fisheye624_params_any(params: torch.Tensor, *, src_hw: tuple[int, int], dst_hw: tuple[int, int]) -> torch.Tensor: + if tuple(int(x) for x in src_hw) == tuple(int(x) for x 
in dst_hw): + return params + src_h, src_w = int(src_hw[0]), int(src_hw[1]) + dst_h, dst_w = int(dst_hw[0]), int(dst_hw[1]) + sx = float(dst_w) / float(max(src_w, 1)) + sy = float(dst_h) / float(max(src_h, 1)) + out = params.clone() + out[..., 0] *= sx + out[..., 1] *= sy + out[..., 2] = (out[..., 2] + 0.5) * sx - 0.5 + out[..., 3] = (out[..., 3] + 0.5) * sy - 0.5 + return out + + +def _resize_training_batch_to_multiple(batch: Any, multiple: int) -> Any: + if int(multiple) <= 1 or not is_dataclass(batch): + return batch + src_hw = _training_batch_src_hw(batch) + if src_hw is None: + return batch + + def _view_hw(prefix: str) -> tuple[int, int] | None: + for rgb_name in (f"{prefix}_rgb_u8", f"{prefix}_erp_rgb_u8"): + rgb = getattr(batch, rgb_name, None) + if torch.is_tensor(rgb) and rgb.ndim >= 3: + return int(rgb.shape[-2]), int(rgb.shape[-1]) + return None + + def _aligned_view_hw(prefix: str, hw: tuple[int, int]) -> tuple[int, int]: + is_view_erp = torch.is_tensor(getattr(batch, f"{prefix}_erp_rgb_u8", None)) + return ( + _erp_multiple_aligned_hw(hw, int(multiple)) + if bool(is_view_erp) + else _multiple_aligned_hw(hw, int(multiple)) + ) + + def _field_dst_hw(name: str, value: torch.Tensor) -> tuple[int, int]: + prefix = "tgt" if name.startswith("tgt_") else "src" + view_hw = _view_hw(prefix) + if view_hw is not None: + return _aligned_view_hw(prefix, view_hw) + hw = (int(value.shape[-2]), int(value.shape[-1])) + return _erp_multiple_aligned_hw(hw, int(multiple)) if "_erp_" in name else _multiple_aligned_hw(hw, int(multiple)) + + updates: dict[str, Any] = {} + for field in fields(batch): + name = field.name + value = getattr(batch, name) + if not torch.is_tensor(value): + continue + if name.endswith("_rgb_u8") and value.ndim >= 3: + if "_cube_" in name: + cube_hw = _multiple_aligned_hw((int(value.shape[-3]), int(value.shape[-2])), int(multiple)) + updates[name] = _resize_cube_tensor(value, cube_hw, kind="image") + else: + updates[name] = _resize_chw_tensor(value, 
_field_dst_hw(name, value), kind="image") + elif name.endswith("_depth_m") and value.ndim >= 3: + if "_cube_" in name: + cube_hw = _multiple_aligned_hw((int(value.shape[-3]), int(value.shape[-2])), int(multiple)) + updates[name] = _resize_cube_tensor(value, cube_hw, kind="depth") + else: + updates[name] = _resize_chw_tensor(value, _field_dst_hw(name, value), kind="depth") + elif name.endswith("_valid_mask") and value.ndim >= 3: + updates[name] = _resize_chw_tensor(value, _field_dst_hw(name, value), kind="depth") + elif name.endswith("_rays") and value.ndim >= 3: + updates[name] = _resize_chw_tensor(value, _field_dst_hw(name, value), kind="ray") + + for intr_name in ("src_intrinsics", "tgt_intrinsics"): + intr = getattr(batch, intr_name, None) + if torch.is_tensor(intr): + prefix = "tgt" if intr_name.startswith("tgt_") else "src" + view_hw = _view_hw(prefix) + if view_hw is not None: + updates[intr_name] = scale_pinhole_intrinsics( + intr, + src_hw=view_hw, + dst_hw=_aligned_view_hw(prefix, view_hw), + ) + for params_name in ("src_camera_params", "tgt_camera_params"): + params = getattr(batch, params_name, None) + if torch.is_tensor(params): + prefix = "tgt" if params_name.startswith("tgt_") else "src" + view_hw = _view_hw(prefix) + if view_hw is not None: + updates[params_name] = _scale_fisheye624_params_any( + params, + src_hw=view_hw, + dst_hw=_aligned_view_hw(prefix, view_hw), + ) + + return replace(batch, **updates) if updates else batch + + +def _build_optimizer_param_groups( + raw_model: UnisharpFeatureModel, +) -> tuple[list[torch.nn.Parameter], list[torch.nn.Parameter], list[torch.nn.Parameter]]: + base_params: list[torch.nn.Parameter] = [] + unik3d_encoder_params: list[torch.nn.Parameter] = [] + unik3d_decoder_params: list[torch.nn.Parameter] = [] + for name, param in raw_model.named_parameters(): + if not param.requires_grad: + continue + if name.startswith("feature_extractor.unik3d.pixel_encoder."): + unik3d_encoder_params.append(param) + elif 
name.startswith("second_layer_depth_head."): + unik3d_decoder_params.append(param) + elif name.startswith("feature_extractor.unik3d."): + unik3d_decoder_params.append(param) + else: + base_params.append(param) + return base_params, unik3d_encoder_params, unik3d_decoder_params + + +def _count_numel(params: list[torch.nn.Parameter]) -> int: + return int(sum(int(p.numel()) for p in params)) + + +def _configure_torchhub_cache() -> Path: + torchhub_dir = REPO_ROOT / "checkpoints" / "torchhub" + torchhub_dir.mkdir(parents=True, exist_ok=True) + os.environ["TORCH_HOME"] = str(torchhub_dir) + torch.hub.set_dir(str(torchhub_dir)) + return torchhub_dir + + +def _ddp_is_enabled() -> bool: + return int(os.environ.get("WORLD_SIZE", "1")) > 1 + + +def _ddp_setup(device: str, ddp_timeout_hours: float = 8.0) -> tuple[torch.device, int, int, bool]: + if not _ddp_is_enabled(): + dev = torch.device(device) + return dev, 0, 1, True + + if device != "cuda": + raise RuntimeError("DDP currently supports CUDA only.") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA not available.") + + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + torch.cuda.set_device(local_rank) + timeout_hours = max(float(ddp_timeout_hours), 0.25) + if rank == 0: + print( + "[ddp_setup] init_process_group backend=nccl " + f"world_size={world_size} NCCL_NET={os.environ.get('NCCL_NET', '')} " + f"NCCL_IB_DISABLE={os.environ.get('NCCL_IB_DISABLE', '')}", + flush=True, + ) + dist.init_process_group(backend="nccl", timeout=timedelta(hours=timeout_hours)) + if rank == 0: + print("[ddp_setup] init_process_group done", flush=True) + dev = torch.device("cuda", local_rank) + return dev, rank, world_size, (rank == 0) + + +def _ddp_broadcast_path(p: Path, is_main: bool) -> Path: + if not _ddp_is_enabled(): + return p + obj_list: list[str] = [str(p) if is_main else ""] + dist.broadcast_object_list(obj_list, src=0) + 
return Path(obj_list[0]) + + +def _ddp_broadcast_str(value: str, is_main: bool) -> str: + if not _ddp_is_enabled(): + return value + obj_list: list[str] = [str(value) if is_main else ""] + dist.broadcast_object_list(obj_list, src=0) + return str(obj_list[0]) + + +def _ddp_any_bool(flag: bool, device: torch.device) -> bool: + if not _ddp_is_enabled(): + return bool(flag) + x = torch.tensor(1 if flag else 0, device=device, dtype=torch.int32) + dist.all_reduce(x, op=dist.ReduceOp.MAX) + return bool(int(x.item()) != 0) + + +def _env_flag(name: str, default: bool = False) -> bool: + raw = os.environ.get(name) + if raw is None: + return bool(default) + return raw.strip().lower() in {"1", "true", "yes", "on"} + + +def _is_oom_exception(exc: BaseException) -> bool: + if isinstance(exc, torch.cuda.OutOfMemoryError): + return True + msg = str(exc).lower() + oom_markers = ( + "out of memory", + "cuda error: out of memory", + "cublas_status_alloc_failed", + "cudnn_status_alloc_failed", + "defaultcpuallocator", + ) + return any(marker in msg for marker in oom_markers) + + +def _ddp_barrier(device: torch.device) -> None: + if not _ddp_is_enabled(): + return + if device.type == "cuda" and device.index is not None: + dist.barrier(device_ids=[device.index]) + else: + dist.barrier() + + +def _maybe_set_dataset_epoch(dataset: Any, epoch: int) -> None: + set_epoch = getattr(dataset, "set_epoch", None) + if callable(set_epoch): + set_epoch(int(epoch)) + + +def _ddp_mean(x: torch.Tensor) -> torch.Tensor: + if not _ddp_is_enabled(): + return x + y = x.detach().clone() + dist.all_reduce(y, op=dist.ReduceOp.SUM) + y = y / float(dist.get_world_size()) + return y + + +def _save_train_vis( + out_dir: Path, + step: int, + src_gt: torch.Tensor, + src_pred: torch.Tensor, + src_alpha: torch.Tensor, + tgt_gt: torch.Tensor, + tgt_pred: torch.Tensor, + tgt_alpha: torch.Tensor, + src_gt_depth: torch.Tensor | None = None, + tgt_gt_depth: torch.Tensor | None = None, + src_pred_depth: torch.Tensor | 
None = None, + tgt_pred_depth: torch.Tensor | None = None, + src_unik3d_depth: torch.Tensor | None = None, + tgt_unik3d_depth: torch.Tensor | None = None, + dataset_name: str | None = None, + scene: str | None = None, + src_idx: int | None = None, + tgt_idx: int | None = None, + src_pose_w2c: torch.Tensor | None = None, + tgt_pose_w2c: torch.Tensor | None = None, + src_metric_mask: torch.Tensor | None = None, + tgt_metric_mask: torch.Tensor | None = None, + src_cube_gt_u8: torch.Tensor | None = None, + src_cube_pred_linear: torch.Tensor | None = None, + src_cube_alpha: torch.Tensor | None = None, + tgt_cube_gt_u8: torch.Tensor | None = None, + tgt_cube_pred_linear: torch.Tensor | None = None, + tgt_cube_alpha: torch.Tensor | None = None, +) -> None: + vis_dir = out_dir / "vis" + vis_dir.mkdir(parents=True, exist_ok=True) + LOGGER.info("Saving train visualization: %s", str(vis_dir / f"step_{int(step):07d}.png")) + save_pair_visualization( + vis_dir / f"step_{int(step):07d}.png", + src_gt=src_gt, + src_pred=src_pred, + src_alpha=src_alpha, + tgt_gt=tgt_gt, + tgt_pred=tgt_pred, + tgt_alpha=tgt_alpha, + src_gt_depth=src_gt_depth, + tgt_gt_depth=tgt_gt_depth, + src_pred_depth=src_pred_depth, + tgt_pred_depth=tgt_pred_depth, + src_unik3d_depth=src_unik3d_depth, + tgt_unik3d_depth=tgt_unik3d_depth, + dataset_name=dataset_name, + scene=scene, + step=int(step), + src_idx=src_idx, + tgt_idx=tgt_idx, + src_pose_w2c=src_pose_w2c, + tgt_pose_w2c=tgt_pose_w2c, + src_cube_gt_u8=src_cube_gt_u8, + src_cube_pred_linear=src_cube_pred_linear, + src_cube_alpha=src_cube_alpha, + tgt_cube_gt_u8=tgt_cube_gt_u8, + tgt_cube_pred_linear=tgt_cube_pred_linear, + tgt_cube_alpha=tgt_cube_alpha, + ) + + +def _read_nonempty_lines(path: Path) -> list[str]: + return [line.strip() for line in path.read_text(encoding="utf-8").splitlines() if line.strip()] + + +def _resolve_manifest_file(manifest_dir: Path | None, filename: str) -> Path | None: + if manifest_dir is None: + return None + path = 
Path(manifest_dir) / filename + return path if path.exists() else None + + +@click.command() +@click.option("--data-root-re10k", type=click.Path(path_type=Path, exists=True), default=None) +@click.option("--data-root-hm3d", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/panogs")) +@click.option("--data-root-sim", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/smx_sim")) +@click.option("--sim-pose-root", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/smx_sim/30cm")) +@click.option("--data-root-wildrgbd", type=click.Path(path_type=Path, exists=True), default=None) +@click.option("--wild-roots-file", type=click.Path(path_type=Path, exists=True, dir_okay=False), default=DEFAULT_WILDRGBD_ROOTS_FILE) +@click.option("--data-root-dl3dv", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/sharp/DL3DV-ALL-960P")) +@click.option("--data-root-dl3dv-depth", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/sharp/DL3DV-ALL-960P_da3_outputs")) +@click.option("--data-root-scanetpp", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/scan")) +@click.option("--dataset-manifest-dir", type=click.Path(path_type=Path, file_okay=False), default=None) +@click.option("--out-root", type=click.Path(path_type=Path, file_okay=False), required=True) +@click.option("--run-name", type=str, default=None) +@click.option("--steps", type=int, default=1000000) +@click.option("--batch-size", type=int, default=2) +@click.option("--num-workers", type=int, default=1) +@click.option("--warmup", type=int, default=75000) +@click.option("--lr0", type=float, default=1.2e-4) +@click.option("--lr1", type=float, default=1.6e-5) +@click.option("--unik3d-lr0", type=float, default=2.5e-5, help="UniK3D decoder/head peak LR.") 
+@click.option("--unik3d-lr1", type=float, default=2.5e-6, help="UniK3D decoder/head final LR.") +@click.option("--unik3d-encoder-lr0", type=float, default=1.5e-6, help="UniK3D pixel_encoder peak LR.") +@click.option("--unik3d-encoder-lr1", type=float, default=1.5e-7, help="UniK3D pixel_encoder final LR.") +@click.option("--grad-clip-norm", type=float, default=1.0, show_default=True) +@click.option("--max-step-grad-norm", type=float, default=100000.0, show_default=True, help="Skip optimizer step when pre-clip grad norm exceeds this value. 0 disables.") +@click.option("--max-depth-m", type=float, default=DEFAULT_MAX_DEPTH_M, show_default=True) +@click.option("--sim-far-depth-invalid-m", type=float, default=30.0, show_default=True) +@click.option("--sim-far-depth-invalid-max-frac", type=float, default=1.0, show_default=True) +@click.option("--sim-max-long-edge", type=int, default=512, show_default=True, help="Resize SIM ERP frames before cubemap conversion. 0 keeps native resolution.") +@click.option("--train-resize-multiple", type=int, default=256, show_default=True, help="Before model forward, downsize training inputs to the largest H/W divisible by this value. 0 disables.") +@click.option("--pinhole-train-size", type=int, default=0, show_default=True, help="Resize pinhole training datasets to NxN before model forward. 
0 keeps dataset native resolution.") +@click.option("--scanetpp-fisheye-far-depth-invalid-m", type=float, default=30.0, show_default=True) +@click.option("--max-index-gap", type=int, default=10) +@click.option("--device", type=str, default="cuda") +@click.option("--render-low-pass-filter-eps", type=float, default=1e-2, show_default=True) +@click.option("--ddp-timeout-hours", type=float, default=8.0) +@click.option("--save-every", type=int, default=5000) +@click.option("--log-every", type=int, default=50) +@click.option("--vis-every", type=int, default=500) +@click.option("--unik3d-backbone", type=click.Choice(["vitb", "vitl"]), default="vitl") +@click.option("--unik3d-resolution-level", type=click.IntRange(0, 9), default=0, show_default=True) +@click.option("--initializer-stride", type=click.IntRange(1, 2), default=1) +@click.option("--initializer-scale-factor", type=float, default=1.5, show_default=True) +@click.option("--lambda-aux-ray", type=float, default=3.0) +@click.option("--lambda-aux-depth-scale", type=float, default=3.0) +@click.option("--lambda-aux-depth2-scale", type=float, default=1.0) +@click.option("--lambda-color", type=float, default=1.0) +@click.option("--lambda-alpha", type=float, default=1.5) +@click.option("--alpha-tail-min", type=float, default=0.99, show_default=True, help="Alpha value below which local tail coverage loss is applied.") +@click.option("--alpha-tail-weight", type=float, default=0.0, show_default=True, help="Extra normalized tail weight for local low-alpha holes.") +@click.option("--lambda-percep", type=float, default=1.0) +@click.option("--lambda-depth", type=float, default=0.5) +@click.option("--lambda-tv", type=float, default=1.0) +@click.option("--lambda-grad", type=float, default=1.0) +@click.option("--lambda-grad-img", type=float, default=0.2) +@click.option("--lambda-edge-rgb", type=float, default=0.0, show_default=True, help="Weight for GT RGB edge-band gradient matching.") +@click.option("--lambda-delta", type=float, 
default=1.0) +@click.option("--lambda-delta-rho", type=float, default=0.01, show_default=True) +@click.option("--lambda-splat", type=float, default=1.0) +@click.option("--lambda-edge-splat", type=float, default=0.0, show_default=True, help="Weight for stricter projected-sigma penalty on GT depth-edge bands.") +@click.option("--lambda-grid", type=float, default=0.05, show_default=True, help="Weight for Gaussian-grid 2x2 checkerboard residual regularization.") +@click.option("--delta-clip", type=float, default=10.0, show_default=True) +@click.option("--raw-delta-clip", type=float, default=400.0, show_default=True) +@click.option("--raw-delta-rho-clip", type=float, default=5.0, show_default=True) +@click.option("--delta-rho-limit", type=float, default=2.0, show_default=True) +@click.option("--splat-sigma-min", type=float, default=1e-1, show_default=True, help="Minimum projected screen-space variance for L_splat.") +@click.option("--splat-sigma-max", type=float, default=1e2, show_default=True, help="Maximum projected screen-space variance for L_splat.") +@click.option("--edge-splat-sigma-max", type=float, default=2.0, show_default=True, help="Maximum projected variance on depth-edge bands for L_edge_splat.") +@click.option("--depth-edge-log-threshold", type=float, default=0.05, show_default=True, help="Log-depth jump threshold used to build L_edge_splat edge bands.") +@click.option("--depth-edge-dilate-px", type=int, default=2, show_default=True, help="Dilation radius in pixels for L_edge_splat depth-edge bands.") +@click.option("--target-mask-erode-px", type=int, default=0, show_default=True, help="Erode source-visible target masks by this many pixels before target supervision.") +@click.option("--dataset-weight-re10k", type=float, default=1.0) +@click.option("--dataset-weight-hm3d", type=float, default=1.0) +@click.option("--dataset-weight-sim", type=float, default=1.0) +@click.option("--dataset-weight-wildrgbd", type=float, default=1.0) 
+@click.option("--dataset-weight-dl3dv", type=float, default=1.0) +@click.option("--dataset-weight-scanetpp", type=float, default=0.0) +@click.option( + "--re10k-pseudo-depth-root", + type=click.Path(path_type=Path, file_okay=False), + default=Path("/media/team_data/ML4_team/datasets/nopose/re10k_unik3d_pseudo_depth"), +) +@click.option("--re10k-pseudo-depth-autogen/--no-re10k-pseudo-depth-autogen", default=True) +@click.option("--re10k-pseudo-depth-backbone", type=click.Choice(["vitb", "vitl"]), default="vitl") +@click.option("--re10k-pseudo-depth-device", type=str, default="cpu") +@click.option("--re10k-pseudo-lock-timeout-sec", type=float, default=120.0) +@click.option("--re10k-pseudo-lock-stale-sec", type=float, default=1800.0) +@click.option("--re10k-pseudo-far-depth-invalid-m", type=float, default=30.0) +@click.option("--seed", type=int, default=None) +@click.option("-v", "--verbose", is_flag=True) +def train_feature_cli( + data_root_re10k: Path | None, + data_root_hm3d: Path | None, + data_root_sim: Path | None, + sim_pose_root: Path | None, + data_root_wildrgbd: Path | None, + wild_roots_file: Path, + data_root_dl3dv: Path | None, + data_root_dl3dv_depth: Path | None, + data_root_scanetpp: Path | None, + dataset_manifest_dir: Path | None, + out_root: Path, + run_name: str | None, + steps: int, + batch_size: int, + num_workers: int, + warmup: int, + lr0: float, + lr1: float, + unik3d_lr0: float, + unik3d_lr1: float, + unik3d_encoder_lr0: float, + unik3d_encoder_lr1: float, + grad_clip_norm: float, + max_step_grad_norm: float, + max_depth_m: float, + sim_far_depth_invalid_m: float, + sim_far_depth_invalid_max_frac: float, + sim_max_long_edge: int, + train_resize_multiple: int, + pinhole_train_size: int, + scanetpp_fisheye_far_depth_invalid_m: float, + max_index_gap: int, + device: str, + render_low_pass_filter_eps: float, + ddp_timeout_hours: float, + save_every: int, + log_every: int, + vis_every: int, + unik3d_backbone: str, + unik3d_resolution_level: int, 
+ initializer_stride: int, + initializer_scale_factor: float, + lambda_aux_ray: float, + lambda_aux_depth_scale: float, + lambda_aux_depth2_scale: float, + lambda_color: float, + lambda_alpha: float, + alpha_tail_min: float, + alpha_tail_weight: float, + lambda_percep: float, + lambda_depth: float, + lambda_tv: float, + lambda_grad: float, + lambda_grad_img: float, + lambda_edge_rgb: float, + lambda_delta: float, + lambda_delta_rho: float, + lambda_splat: float, + lambda_edge_splat: float, + lambda_grid: float, + delta_clip: float, + raw_delta_clip: float, + raw_delta_rho_clip: float, + delta_rho_limit: float, + splat_sigma_min: float, + splat_sigma_max: float, + edge_splat_sigma_max: float, + depth_edge_log_threshold: float, + depth_edge_dilate_px: int, + target_mask_erode_px: int, + dataset_weight_re10k: float, + dataset_weight_hm3d: float, + dataset_weight_sim: float, + dataset_weight_wildrgbd: float, + dataset_weight_dl3dv: float, + dataset_weight_scanetpp: float, + re10k_pseudo_depth_root: Path, + re10k_pseudo_depth_autogen: bool, + re10k_pseudo_depth_backbone: str, + re10k_pseudo_depth_device: str, + re10k_pseudo_lock_timeout_sec: float, + re10k_pseudo_lock_stale_sec: float, + re10k_pseudo_far_depth_invalid_m: float, + seed: int | None, + verbose: bool, +) -> None: + detach_init_layer0_distance = True + + log_level = logging.DEBUG if verbose else logging.INFO + logging_utils.configure(log_level) + if float(max_depth_m) <= 0.0: + raise ValueError("--max-depth-m must be positive.") + if float(grad_clip_norm) <= 0.0: + raise ValueError("--grad-clip-norm must be positive.") + if float(max_step_grad_norm) < 0.0: + raise ValueError("--max-step-grad-norm must be non-negative.") + if float(render_low_pass_filter_eps) < 0.0: + raise ValueError("--render-low-pass-filter-eps must be non-negative.") + if not (0.0 <= float(sim_far_depth_invalid_max_frac) <= 1.0): + raise ValueError("--sim-far-depth-invalid-max-frac must be in [0, 1].") + if int(sim_max_long_edge) < 0: + 
raise ValueError("--sim-max-long-edge must be non-negative.") + if int(train_resize_multiple) < 0: + raise ValueError("--train-resize-multiple must be non-negative.") + if int(pinhole_train_size) < 0: + raise ValueError("--pinhole-train-size must be non-negative.") + if float(scanetpp_fisheye_far_depth_invalid_m) < 0.0: + raise ValueError("--scanetpp-fisheye-far-depth-invalid-m must be non-negative.") + if float(delta_clip) < 0.0: + raise ValueError("--delta-clip must be non-negative.") + if float(raw_delta_clip) < 0.0: + raise ValueError("--raw-delta-clip must be non-negative.") + if float(raw_delta_rho_clip) < 0.0: + raise ValueError("--raw-delta-rho-clip must be non-negative.") + if float(lambda_grid) < 0.0: + raise ValueError("--lambda-grid must be non-negative.") + if float(lambda_edge_rgb) < 0.0: + raise ValueError("--lambda-edge-rgb must be non-negative.") + if float(lambda_edge_splat) < 0.0: + raise ValueError("--lambda-edge-splat must be non-negative.") + if float(edge_splat_sigma_max) < 0.0: + raise ValueError("--edge-splat-sigma-max must be non-negative.") + if float(depth_edge_log_threshold) < 0.0: + raise ValueError("--depth-edge-log-threshold must be non-negative.") + if int(depth_edge_dilate_px) < 0: + raise ValueError("--depth-edge-dilate-px must be non-negative.") + if int(target_mask_erode_px) < 0: + raise ValueError("--target-mask-erode-px must be non-negative.") + if not (0.0 <= float(alpha_tail_min) <= 1.0): + raise ValueError("--alpha-tail-min must be in [0, 1].") + if float(alpha_tail_weight) < 0.0: + raise ValueError("--alpha-tail-weight must be non-negative.") + if float(delta_rho_limit) < 0.0: + raise ValueError("--delta-rho-limit must be non-negative.") + if float(splat_sigma_min) < 0.0: + raise ValueError("--splat-sigma-min must be non-negative.") + if float(splat_sigma_max) <= float(splat_sigma_min): + raise ValueError("--splat-sigma-max must be greater than --splat-sigma-min.") + dev, rank, world_size, is_main = _ddp_setup(device, 
ddp_timeout_hours=ddp_timeout_hours) + + if seed is not None: + s = int(seed) + random.seed(s + rank) + np.random.seed(s + rank) + torch.manual_seed(s + rank) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(s + rank) + + if is_main and (run_name is None or run_name.strip() == ""): + run_name = f"unified_feature_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + if run_name is None: + run_name = "unified_feature_ddp" + out_dir = _ddp_broadcast_path(Path(out_root) / run_name, is_main=is_main) + logging_utils.configure(log_level) + if not is_main: + logging.getLogger().setLevel(logging.WARNING) + LOGGER.setLevel(logging.WARNING) + _configure_torchhub_cache() + re10k_enabled_for_train = bool(float(dataset_weight_re10k) > 0.0) + hm3d_enabled_for_train = bool(float(dataset_weight_hm3d) > 0.0) + sim_enabled_for_train = bool(float(dataset_weight_sim) > 0.0) + dl3dv_enabled_for_train = bool(float(dataset_weight_dl3dv) > 0.0) + scanetpp_enabled_for_train = bool(float(dataset_weight_scanetpp) > 0.0) + wild_roots = _read_nonempty_lines(wild_roots_file) if wild_roots_file.exists() else [] + re10k_manifest = _resolve_manifest_file(dataset_manifest_dir, "re10k_train_chunks.txt") + hm3d_manifest = _resolve_manifest_file(dataset_manifest_dir, "hm3d_train_scenes.txt") + sim_manifest = _resolve_manifest_file(dataset_manifest_dir, "sim_train_scenes.txt") + wildrgbd_manifest = _resolve_manifest_file(dataset_manifest_dir, "wildrgbd_train_scenes.txt") + dl3dv_manifest = _resolve_manifest_file(dataset_manifest_dir, "dl3dv_train_scenes.txt") + scanetpp_manifest = _resolve_manifest_file(dataset_manifest_dir, "scanetpp_fisheye_train_scenes.txt") + wildrgbd_enabled_for_train = bool( + ((data_root_wildrgbd is not None) or bool(wild_roots)) and (float(dataset_weight_wildrgbd) > 0.0) + ) + if re10k_enabled_for_train and data_root_re10k is None: + raise ValueError("dataset_weight_re10k>0 but --data-root-re10k is not provided.") + if hm3d_enabled_for_train and data_root_hm3d is None: + 
raise ValueError("dataset_weight_hm3d>0 but --data-root-hm3d is not provided.") + if sim_enabled_for_train and (data_root_sim is None or sim_pose_root is None): + raise ValueError("dataset_weight_sim>0 but --data-root-sim / --sim-pose-root is missing.") + if sim_enabled_for_train and sim_manifest is None: + raise ValueError("dataset_weight_sim>0 but sim_train_scenes.txt is missing from --dataset-manifest-dir.") + if float(dataset_weight_wildrgbd) > 0.0 and (data_root_wildrgbd is None) and (not wild_roots): + raise ValueError("dataset_weight_wildrgbd>0 but neither --data-root-wildrgbd nor --wild-roots-file is provided.") + if dl3dv_enabled_for_train and (data_root_dl3dv is None or data_root_dl3dv_depth is None): + raise ValueError("dataset_weight_dl3dv>0 but --data-root-dl3dv / --data-root-dl3dv-depth is missing.") + if scanetpp_enabled_for_train and data_root_scanetpp is None: + raise ValueError("dataset_weight_scanetpp>0 but --data-root-scanetpp is missing.") + + if is_main: + out_dir.mkdir(parents=True, exist_ok=True) + LOGGER.info( + "Training start: out=%s branch=gt-override scratch_unik3d_pretrained backbone=%s steps=%d batch=%d", + str(out_dir), + str(unik3d_backbone), + int(steps), + int(batch_size), + ) + LOGGER.info( + "Loss weights: color=%.3g alpha=%.3g depth=%.3g percep=%.3g aux_ray=%.3g aux_depth0=%.3g aux_depth1=%.3g", + float(lambda_color), + float(lambda_alpha), + float(lambda_depth), + float(lambda_percep), + float(lambda_aux_ray), + float(lambda_aux_depth_scale), + float(lambda_aux_depth2_scale), + ) + + dataset_seed = int(seed) if seed is not None else 12345 + pinhole_output_h = int(pinhole_train_size) if int(pinhole_train_size) > 0 else None + pinhole_output_w = int(pinhole_train_size) if int(pinhole_train_size) > 0 else None + + re10k_ds = None + if re10k_enabled_for_train: + re10k_ds = Re10KDataset( + root=data_root_re10k, + chunks_file=re10k_manifest, + split="train", + min_frame_gap=1, + max_frame_gap=int(max_index_gap), + 
pair_max_translation_m=0.5, + pair_min_overlap=0.6, + output_h=pinhole_output_h, + output_w=pinhole_output_w, + shuffle_chunk=True, + shuffle_example=True, + ddp_rank=rank, + ddp_world_size=world_size, + pseudo_depth_root=re10k_pseudo_depth_root, + pseudo_depth_autogen=bool(re10k_pseudo_depth_autogen), + pseudo_depth_backbone=str(re10k_pseudo_depth_backbone), + pseudo_depth_device=str(re10k_pseudo_depth_device), + pseudo_lock_timeout_sec=float(re10k_pseudo_lock_timeout_sec), + pseudo_lock_stale_sec=float(re10k_pseudo_lock_stale_sec), + batch_size_hint=int(batch_size), + depth_max_m=float(max_depth_m), + pseudo_far_depth_invalid_m=float(re10k_pseudo_far_depth_invalid_m), + seed=dataset_seed, + ) + hm3d_train_root = None + if data_root_hm3d is not None: + hm3d_train_root = data_root_hm3d / "train" if (data_root_hm3d / "train").exists() else data_root_hm3d + + hm3d_ds = None + if hm3d_enabled_for_train: + hm3d_ds = PanOGSDataset( + root=hm3d_train_root, + index_manifest_path=hm3d_manifest, + src_tgt_max_index_gap=int(max_index_gap), + use_cubemap_supervision=True, + pair_sampling=True, + pair_max_translation_m=0.5, + pair_min_depth_overlap=0.6, + pair_overlap_face_w=64, + pair_overlap_margin=1.05, + pair_max_tries=48, + depth_max_m=float(max_depth_m), + ) + sim_ds = None + if sim_enabled_for_train: + sim_ds = SimPanoramaDataset( + root=data_root_sim, + pose_root=sim_pose_root, + scene_list_file=sim_manifest, + max_index_gap=int(max_index_gap), + pair_max_translation_m=0.5, + pair_min_depth_overlap=0.6, + pairs_per_chunk=15, + chunk_size=30, + shuffle_scene=True, + ddp_rank=rank, + ddp_world_size=world_size, + depth_max_m=float(max_depth_m), + far_depth_invalid_m=float(sim_far_depth_invalid_m), + far_depth_invalid_max_frac=float(sim_far_depth_invalid_max_frac), + max_long_edge=int(sim_max_long_edge), + seed=dataset_seed, + ) + wildrgbd_ds = None + if wildrgbd_enabled_for_train: + wild_dataset_roots = [Path(p) for p in wild_roots] + if data_root_wildrgbd is not None: + 
wild_dataset_roots.append(data_root_wildrgbd) + wildrgbd_ds = WildRGBDDataset( + root=None, + scene_list_file=wildrgbd_manifest, + split="scenes", + min_frame_gap=1, + max_frame_gap=int(max_index_gap), + pair_max_translation_m=0.5, + pair_min_overlap=0.6, + output_h=pinhole_output_h, + output_w=pinhole_output_w, + shuffle_scene=True, + shuffle_frame=False, + ddp_rank=rank, + ddp_world_size=world_size, + roots=wild_dataset_roots, + depth_max_m=float(max_depth_m), + seed=dataset_seed, + ) + dl3dv_ds = None + if dl3dv_enabled_for_train: + dl3dv_ds = DL3DVDataset( + root=data_root_dl3dv, + depth_root=data_root_dl3dv_depth, + scene_specs_file=dl3dv_manifest, + min_frame_gap=1, + max_frame_gap=int(max_index_gap), + pair_max_translation_m=0.5, + pair_min_overlap=0.6, + output_h=pinhole_output_h, + output_w=pinhole_output_w, + shuffle_scene=True, + shuffle_frame=False, + ddp_rank=rank, + ddp_world_size=world_size, + batch_size_hint=int(batch_size), + depth_max_m=float(max_depth_m), + seed=dataset_seed, + ) + + scanetpp_ds = None + if scanetpp_enabled_for_train: + scanetpp_ds = ScannetppFisheyeDataset( + root=data_root_scanetpp, + scene_list_file=scanetpp_manifest, + min_frame_gap=1, + max_frame_gap=int(max_index_gap), + pair_max_translation_m=0.5, + shuffle_scene=True, + shuffle_frame=False, + ddp_rank=rank, + ddp_world_size=world_size, + batch_size_hint=int(batch_size), + depth_max_m=float(max_depth_m), + far_depth_invalid_m=float(scanetpp_fisheye_far_depth_invalid_m), + seed=dataset_seed, + ) + + hm3d_sampler = None + if hm3d_ds is not None and _ddp_is_enabled(): + hm3d_sampler = DistributedSampler(hm3d_ds, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False) + + re10k_num_workers = int(num_workers) + if re10k_ds is not None and bool(re10k_pseudo_depth_autogen) and re10k_num_workers > 0: + re10k_num_workers = 0 + if is_main: + LOGGER.warning( + "RE10K pseudo-depth auto-generate enabled: force re10k dataloader num_workers=%d (requested=%d).", + 
int(re10k_num_workers), + int(num_workers), + ) + if re10k_ds is not None and batch_size > 1 and re10k_num_workers > 0: + re10k_num_workers = 0 + if is_main: + LOGGER.warning( + "Dynamic-resolution RE10K batching requires ordered same-resolution samples: force re10k dataloader num_workers=%d (requested=%d).", + int(re10k_num_workers), + int(num_workers), + ) + + highres_pin_memory = os.environ.get("HIGHRES_TRAIN_PIN_MEMORY", "0").strip().lower() in {"1", "true", "yes", "on"} + standard_pin_memory = os.environ.get("TRAIN_PIN_MEMORY", "1").strip().lower() in {"1", "true", "yes", "on"} + try: + train_prefetch_factor = max(1, int(os.environ.get("TRAIN_PREFETCH_FACTOR", "1").strip())) + except Exception: + train_prefetch_factor = 1 + def _loader_worker_kwargs(worker_count: int, *, pin_memory: bool) -> dict[str, Any]: + kwargs: dict[str, Any] = { + "num_workers": int(worker_count), + "pin_memory": bool(pin_memory), + } + if int(worker_count) > 0: + kwargs["prefetch_factor"] = int(train_prefetch_factor) + return kwargs + + re10k_dl = None + if re10k_ds is not None: + re10k_dl = DataLoader( + re10k_ds, + batch_size=None, + **_loader_worker_kwargs(re10k_num_workers, pin_memory=standard_pin_memory), + collate_fn=re10k_passthrough, + ) + + hm3d_dl = None + if hm3d_ds is not None: + hm3d_dl = DataLoader( + hm3d_ds, + batch_size=batch_size, + shuffle=(hm3d_sampler is None), + sampler=hm3d_sampler, + **_loader_worker_kwargs(num_workers, pin_memory=highres_pin_memory), + collate_fn=panogs_collate, + ) + + sim_dl = None + if sim_ds is not None: + sim_dl = DataLoader( + sim_ds, + batch_size=batch_size, + **_loader_worker_kwargs(num_workers, pin_memory=highres_pin_memory), + collate_fn=panogs_collate, + ) + + wildrgbd_dl = None + if wildrgbd_ds is not None: + wildrgbd_dl = DataLoader( + wildrgbd_ds, + batch_size=batch_size, + **_loader_worker_kwargs(num_workers, pin_memory=standard_pin_memory), + collate_fn=wildrgbd_collate, + ) + + dl3dv_dl = None + if dl3dv_ds is not None: + 
dl3dv_dl = DataLoader( + dl3dv_ds, + batch_size=None, + **_loader_worker_kwargs(num_workers, pin_memory=standard_pin_memory), + collate_fn=re10k_passthrough, + ) + + scanetpp_dl = None + if scanetpp_ds is not None: + scanetpp_dl = DataLoader( + scanetpp_ds, + batch_size=None, + **_loader_worker_kwargs(num_workers, pin_memory=highres_pin_memory), + collate_fn=scannetpp_fisheye_passthrough, + ) + + candidate_datasets: dict[str, Any] = {} + candidate_dataloaders: dict[str, DataLoader] = {} + candidate_weights: dict[str, float] = {} + if re10k_ds is not None and re10k_dl is not None: + candidate_datasets["re10k"] = re10k_ds + candidate_dataloaders["re10k"] = re10k_dl + candidate_weights["re10k"] = float(dataset_weight_re10k) + if hm3d_ds is not None and hm3d_dl is not None: + candidate_datasets["hm3d"] = hm3d_ds + candidate_dataloaders["hm3d"] = hm3d_dl + candidate_weights["hm3d"] = float(dataset_weight_hm3d) + if sim_ds is not None and sim_dl is not None: + candidate_datasets["sim"] = sim_ds + candidate_dataloaders["sim"] = sim_dl + candidate_weights["sim"] = float(dataset_weight_sim) + if wildrgbd_ds is not None and wildrgbd_dl is not None: + candidate_datasets["wildrgbd"] = wildrgbd_ds + candidate_dataloaders["wildrgbd"] = wildrgbd_dl + candidate_weights["wildrgbd"] = float(dataset_weight_wildrgbd) + if dl3dv_ds is not None and dl3dv_dl is not None: + candidate_datasets["dl3dv"] = dl3dv_ds + candidate_dataloaders["dl3dv"] = dl3dv_dl + candidate_weights["dl3dv"] = float(dataset_weight_dl3dv) + if scanetpp_ds is not None and scanetpp_dl is not None: + candidate_datasets["scanetpp_fisheye"] = scanetpp_ds + candidate_dataloaders["scanetpp_fisheye"] = scanetpp_dl + candidate_weights["scanetpp_fisheye"] = float(dataset_weight_scanetpp) + + datasets: dict[str, Any] = {} + dataloaders: dict[str, DataLoader] = {} + sampling: dict[str, float] = {} + for name, w in candidate_weights.items(): + if float(w) > 0.0: + datasets[name] = candidate_datasets[name] + dataloaders[name] = 
candidate_dataloaders[name] + sampling[name] = float(w) + elif is_main: + LOGGER.warning("Skip dataset in mixed sampler: %s (weight=%.4f <= 0)", name, float(w)) + + if len(datasets) == 0: + raise ValueError("No dataset selected for mixed sampler (all dataset weights <= 0).") + for name, dataset in datasets.items(): + _maybe_set_dataset_epoch(dataset, 0) + iterators = {name: LazyDataLoaderIterator(dl) for name, dl in dataloaders.items()} + sampler_seed = int(seed + rank) if seed is not None else int(12345 + rank) + sampler = MixedDatasetSampler( + datasets=datasets, + weights=sampling, + iterators=iterators, + seed=sampler_seed, + ) + + config = UnisharpFeatureConfig( + unik3d_backbone=unik3d_backbone, + unik3d_resolution_level=int(unik3d_resolution_level), + initializer_stride=int(initializer_stride), + initializer_scale_factor=float(initializer_scale_factor), + detach_init_layer0_distance=bool(detach_init_layer0_distance), + delta_rho_limit=float(delta_rho_limit), + ) + setattr(config, "max_distance_m", float(max_depth_m)) + + model = UnisharpFeatureModel(config).to(dev).train() + + if _ddp_is_enabled(): + model = DDP( + model, + device_ids=[dev.index], + output_device=dev.index, + find_unused_parameters=True, + gradient_as_bucket_view=True, + ) + + raw_model = model.module if isinstance(model, DDP) else model + base_params, unik3d_encoder_params, unik3d_decoder_params = _build_optimizer_param_groups(raw_model) + unik3d_params = unik3d_encoder_params + unik3d_decoder_params + trainable_params = base_params + unik3d_params + if len(trainable_params) == 0: + raise RuntimeError("No trainable parameters found.") + if len(unik3d_params) == 0: + raise RuntimeError( + "No UniK3D parameters were collected for the default unfreeze training path. " + "Please check parameter naming." 
+ ) + depth_head_params = [p for p in raw_model.second_layer_depth_head.parameters() if p.requires_grad] + if len(depth_head_params) == 0: + raise RuntimeError("Depth heads have no trainable parameters; depth branch would not train.") + + opt_groups: list[dict[str, Any]] = [{"params": base_params, "lr": float(lr0), "group_name": "base"}] + if len(unik3d_encoder_params) > 0: + opt_groups.append( + { + "params": unik3d_encoder_params, + "lr": float(unik3d_encoder_lr0), + "group_name": "unik3d_encoder", + } + ) + if len(unik3d_decoder_params) > 0: + opt_groups.append( + { + "params": unik3d_decoder_params, + "lr": float(unik3d_lr0), + "group_name": "unik3d_decoder", + } + ) + opt = torch.optim.Adam(opt_groups) + if is_main: + LOGGER.info( + "Model ready: scratch heads, pretrained UniK3D, trainable_params=%d", + _count_numel(trainable_params), + ) + if dev.type == "cuda": + scaler = torch.amp.GradScaler("cuda", enabled=True) + else: + scaler = torch.amp.GradScaler("cpu", enabled=False) + + renderer = GSplatRenderer( + color_space="sRGB", + background_color="black", + low_pass_filter_eps=float(render_low_pass_filter_eps), + ).to(dev) + + loss_w = UnisharpLossWeights( + lambda_color=float(lambda_color), + lambda_alpha=float(lambda_alpha), + lambda_percep=float(lambda_percep), + lambda_depth=float(lambda_depth), + lambda_tv=float(lambda_tv), + lambda_grad=float(lambda_grad), + lambda_grad_img=float(lambda_grad_img), + lambda_edge_rgb=float(lambda_edge_rgb), + lambda_delta=float(lambda_delta), + lambda_delta_rho=float(lambda_delta_rho), + lambda_splat=float(lambda_splat), + lambda_edge_splat=float(lambda_edge_splat), + lambda_grid=float(lambda_grid), + ) + loss_fn = UnisharpLoss( + weights=loss_w, + delta_clip=float(delta_clip), + raw_delta_clip=float(raw_delta_clip), + raw_delta_rho_clip=float(raw_delta_rho_clip), + alpha_tail_min=float(alpha_tail_min), + alpha_tail_weight=float(alpha_tail_weight), + splat_sigma_min=float(splat_sigma_min), + 
splat_sigma_max=float(splat_sigma_max), + edge_splat_sigma_max=float(edge_splat_sigma_max), + depth_edge_log_threshold=float(depth_edge_log_threshold), + depth_edge_dilate_px=int(depth_edge_dilate_px), + ).to(dev) + loss_fn.SUPERVISION_MAX_DEPTH_M = float(max_depth_m) + + if is_main: + config_dict = { + "max_depth_m": float(max_depth_m), + "sim_far_depth_invalid_m": float(sim_far_depth_invalid_m), + "sim_far_depth_invalid_max_frac": float(sim_far_depth_invalid_max_frac), + "re10k_pseudo_far_depth_invalid_m": float(re10k_pseudo_far_depth_invalid_m), + "scanetpp_fisheye_far_depth_invalid_m": float(scanetpp_fisheye_far_depth_invalid_m), + "render_low_pass_filter_eps": float(render_low_pass_filter_eps), + } + (out_dir / "config.json").write_text( + json.dumps(config_dict, ensure_ascii=False, indent=2, sort_keys=True) + "\n", + encoding="utf-8", + ) + + loss_csv = out_dir / "losses.csv" + loss_csv_fields = [ + "loss", + "src_loss", + "tgt_loss", + "dataset", + ] + if is_main: + with loss_csv.open("w", newline="") as f: + csv.DictWriter(f, fieldnames=loss_csv_fields).writeheader() + + if is_main: + LOGGER.info("Training loop started.") + + from unisharp.cli.unified_trainer import UnifiedTrainer + + trainer = UnifiedTrainer( + model=model, + renderer=renderer, + loss_fn=loss_fn, + device=dev, + max_depth_m=float(max_depth_m), + sim_far_depth_invalid_m=float(sim_far_depth_invalid_m), + re10k_pseudo_far_depth_invalid_m=float(re10k_pseudo_far_depth_invalid_m), + scanetpp_fisheye_far_depth_invalid_m=float(scanetpp_fisheye_far_depth_invalid_m), + aux_ray_loss_weight=float(lambda_aux_ray), + aux_depth_scale_loss_weight=float(lambda_aux_depth_scale), + aux_depth2_scale_loss_weight=float(lambda_aux_depth2_scale), + target_mask_erode_px=int(target_mask_erode_px), + ) + skip_forward_oom = _env_flag("TRAIN_SKIP_FORWARD_OOM", default=True) + + dataset_epochs: dict[str, int] = {name: 0 for name in dataloaders.keys()} + dataset_samplers: dict[str, DistributedSampler | None] = {"hm3d": 
hm3d_sampler} + + for step in range(1, steps + 1): + lr = warmup_cosine_lr(step, warmup, steps, lr0, lr1) + lr_unik3d_encoder = warmup_cosine_lr(step, warmup, steps, unik3d_encoder_lr0, unik3d_encoder_lr1) + lr_unik3d_decoder = warmup_cosine_lr(step, warmup, steps, unik3d_lr0, unik3d_lr1) + for g in opt.param_groups: + if g.get("group_name") == "unik3d_encoder": + g["lr"] = lr_unik3d_encoder + elif g.get("group_name") == "unik3d_decoder": + g["lr"] = lr_unik3d_decoder + else: + g["lr"] = lr + + if _ddp_is_enabled(): + batch = None + available_dataset_names = list(dataloaders.keys()) + dataset_name = "" + for _dataset_attempt in range(max(1, len(dataloaders))): + dataset_name = _ddp_broadcast_str( + sampler.choose_dataset_name(available_dataset_names) if is_main else "", + is_main=is_main, + ) + + local_exhausted = False + try: + batch = sampler.next_batch(dataset_name) + except StopIteration: + local_exhausted = True + + exhausted_any = _ddp_any_bool(local_exhausted, device=dev) + if exhausted_any: + dataset_epochs[dataset_name] = dataset_epochs.get(dataset_name, 0) + 1 + ds_sampler = dataset_samplers.get(dataset_name, None) + if ds_sampler is not None: + ds_sampler.set_epoch(dataset_epochs[dataset_name]) + _maybe_set_dataset_epoch(datasets[dataset_name], dataset_epochs[dataset_name]) + iterators[dataset_name] = iter(dataloaders[dataset_name]) + sampler.iterators = iterators + batch = None + + local_exhausted = False + try: + batch = sampler.next_batch(dataset_name) + except StopIteration: + local_exhausted = True + exhausted_any = _ddp_any_bool(local_exhausted, device=dev) + + if not exhausted_any: + break + + batch = None + available_dataset_names = [name for name in available_dataset_names if name != dataset_name] + if len(available_dataset_names) == 0: + break + if batch is None: + raise RuntimeError(f"Failed to fetch synchronized DDP batch for dataset={dataset_name}") + else: + try: + dataset_name, batch = sampler.sample() + except StopIteration as e: + msg = 
str(e) + exhausted_name = None + if msg.startswith("Dataset ") and msg.endswith(" exhausted"): + exhausted_name = msg[len("Dataset ") : -len(" exhausted")] + if exhausted_name is None or exhausted_name not in dataloaders: + raise + dataset_epochs[exhausted_name] = dataset_epochs.get(exhausted_name, 0) + 1 + ds_sampler = dataset_samplers.get(exhausted_name, None) + if ds_sampler is not None: + ds_sampler.set_epoch(dataset_epochs[exhausted_name]) + _maybe_set_dataset_epoch(datasets[exhausted_name], dataset_epochs[exhausted_name]) + iterators[exhausted_name] = iter(dataloaders[exhausted_name]) + sampler.iterators = iterators + dataset_name, batch = sampler.sample() + + batch = _resize_training_batch_to_multiple(batch, int(train_resize_multiple)) + + opt.zero_grad(set_to_none=True) + + autocast_enabled = dev.type == "cuda" + if autocast_enabled and torch.cuda.is_bf16_supported(): + autocast_dtype = torch.bfloat16 + else: + autocast_dtype = torch.float16 if autocast_enabled else torch.bfloat16 + + need_vis = bool(is_main and vis_every > 0 and (step % vis_every == 0)) + result: dict[str, Any] | None = None + forward_oom_local = False + forward_oom_error = "" + try: + with torch.autocast(device_type=dev.type, enabled=autocast_enabled, dtype=autocast_dtype): + result = trainer.process_batch( + batch, + dataset_name, + step, + need_vis=need_vis, + ) + except Exception as e: + if skip_forward_oom and _is_oom_exception(e): + forward_oom_local = True + forward_oom_error = str(e) + opt.zero_grad(set_to_none=True) + if dev.type == "cuda": + torch.cuda.empty_cache() + else: + raise + + forward_oom_any = _ddp_any_bool(forward_oom_local, device=dev) + if forward_oom_any: + opt.zero_grad(set_to_none=True) + if result is not None: + del result + result = None + if dev.type == "cuda": + torch.cuda.empty_cache() + if is_main: + LOGGER.error( + "Skipping optimizer step=%d because forward OOM occurred on at least one rank | dataset=%s", + int(step), + str(dataset_name), + ) + continue + 
+ if result is None: + raise RuntimeError(f"Forward returned no result for dataset={dataset_name} step={step}") + total_loss = result["total"] + local_nonfinite_loss = not bool(torch.isfinite(total_loss.detach()).item()) + nonfinite_loss_any = _ddp_any_bool(local_nonfinite_loss, device=dev) + if nonfinite_loss_any: + opt.zero_grad(set_to_none=True) + if is_main: + LOGGER.error( + "Skipping optimizer step=%d because loss is non-finite on at least one rank | dataset=%s", + int(step), + str(dataset_name), + ) + continue + + try: + scaler.scale(total_loss).backward() + except Exception as e: + raise + try: + scaler.unscale_(opt) + grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=float(grad_clip_norm)) + except Exception as e: + LOGGER.error("Gradient unscale/clip failed at step=%d: %s", int(step), str(e)) + raise + grad_norm_value = float(grad_norm.detach().to(dtype=torch.float32).cpu().item()) if torch.is_tensor(grad_norm) else float(grad_norm) + local_nonfinite_grad = not np.isfinite(grad_norm_value) + nonfinite_grad_any = _ddp_any_bool(local_nonfinite_grad, device=dev) + if nonfinite_grad_any: + opt.zero_grad(set_to_none=True) + scaler.update() + if is_main: + LOGGER.error( + "Skipping optimizer step=%d because grad norm is non-finite on at least one rank | dataset=%s | local_grad_norm=%s", + int(step), + str(dataset_name), + str(grad_norm_value), + ) + continue + local_huge_grad = bool(float(max_step_grad_norm) > 0.0 and grad_norm_value > float(max_step_grad_norm)) + huge_grad_any = _ddp_any_bool(local_huge_grad, device=dev) + if huge_grad_any: + opt.zero_grad(set_to_none=True) + scaler.update() + if is_main: + LOGGER.error( + "Skipping optimizer step=%d because grad norm exceeded max-step-grad-norm on at least one rank | dataset=%s | local_grad_norm=%.6g | threshold=%.6g", + int(step), + str(dataset_name), + float(grad_norm_value), + float(max_step_grad_norm), + ) + continue + scaler.step(opt) + scaler.update() + + if log_every > 0 and step % 
log_every == 0: + loss_v = float(_ddp_mean(total_loss.detach()).item()) + src_v = float(_ddp_mean(result["src"].detach()).item()) + tgt_v = float(_ddp_mean(result["tgt"].detach()).item()) + row = { + "loss": loss_v, + "src_loss": src_v, + "tgt_loss": tgt_v, + "dataset": str(dataset_name), + } + if is_main: + LOGGER.info( + "step=%d dataset=%s loss=%.6f src_loss=%.6f tgt_loss=%.6f", + step, + dataset_name, + loss_v, + src_v, + tgt_v, + ) + row_csv = dict(row) + for k in ("loss", "src_loss", "tgt_loss"): + v = float(row_csv.get(k, float("nan"))) + row_csv[k] = "" if not np.isfinite(v) else f"{v:.4f}" + with loss_csv.open("a", newline="") as f: + csv.DictWriter(f, fieldnames=loss_csv_fields).writerow(row_csv) + + if need_vis and result.get("vis_payload"): + vis = result["vis_payload"] + _save_train_vis( + out_dir, + step, + vis["src_gt"], + vis["src_pred"], + vis["src_alpha"], + vis["tgt_gt"], + vis["tgt_pred"], + vis["tgt_alpha"], + src_gt_depth=vis.get("src_gt_depth"), + tgt_gt_depth=vis.get("tgt_gt_depth"), + src_pred_depth=vis.get("src_pred_depth"), + tgt_pred_depth=vis.get("tgt_pred_depth"), + src_unik3d_depth=vis.get("src_unik3d_depth"), + tgt_unik3d_depth=vis.get("tgt_unik3d_depth"), + dataset_name=vis.get("dataset_name"), + scene=vis.get("scene"), + src_idx=vis.get("src_idx"), + tgt_idx=vis.get("tgt_idx"), + src_pose_w2c=vis.get("src_pose_w2c"), + tgt_pose_w2c=vis.get("tgt_pose_w2c"), + src_metric_mask=vis.get("src_metric_mask"), + tgt_metric_mask=vis.get("tgt_metric_mask"), + src_cube_gt_u8=vis.get("src_cube_gt_u8"), + src_cube_pred_linear=vis.get("src_cube_pred_linear"), + src_cube_alpha=vis.get("src_cube_alpha"), + tgt_cube_gt_u8=vis.get("tgt_cube_gt_u8"), + tgt_cube_pred_linear=vis.get("tgt_cube_pred_linear"), + tgt_cube_alpha=vis.get("tgt_cube_alpha"), + ) + + if need_vis: + if "vis" in locals(): + del vis + if dev.type == "cuda": + torch.cuda.empty_cache() + del result + del total_loss + batch = None + + if is_main and (save_every > 0) and (step % 
def quat_mul_wxyz(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Return the Hamilton product ``q1 * q2`` of (w, x, y, z) quaternions.

    The four components are read from the last dimension of each input and
    the product is stacked along a new last dimension.
    """
    aw, ax, ay, az = q1.unbind(dim=-1)
    bw, bx, by, bz = q2.unbind(dim=-1)
    product = (
        aw * bw - ax * bx - ay * by - az * bz,  # w
        aw * bx + ax * bw + ay * bz - az * by,  # x
        aw * by - ax * bz + ay * bw + az * bx,  # y
        aw * bz + ax * by - ay * bx + az * bw,  # z
    )
    return torch.stack(product, dim=-1)


def rotmat_to_quat_wxyz(Rm: torch.Tensor) -> torch.Tensor:
    """Convert one rotation matrix to a normalised (w, x, y, z) quaternion.

    ``Rm`` is read element-wise as ``Rm[row, col]`` for rows/cols 0..2, so a
    single matrix is handled per call (no batch dimension). The formula is
    selected by the sign of the trace and, when the trace is not positive, by
    which diagonal entry is largest. The result is divided by its norm
    (clamped to at least 1e-8).
    """
    r00, r01, r02 = Rm[0, 0], Rm[0, 1], Rm[0, 2]
    r10, r11, r12 = Rm[1, 0], Rm[1, 1], Rm[1, 2]
    r20, r21, r22 = Rm[2, 0], Rm[2, 1], Rm[2, 2]
    trace = r00 + r11 + r22

    if trace > 0.0:
        scale = torch.sqrt(trace + 1.0) * 2.0
        wxyz = (
            0.25 * scale,
            (r21 - r12) / scale,
            (r02 - r20) / scale,
            (r10 - r01) / scale,
        )
    elif (r00 > r11) and (r00 > r22):
        scale = torch.sqrt(1.0 + r00 - r11 - r22) * 2.0
        wxyz = (
            (r21 - r12) / scale,
            0.25 * scale,
            (r01 + r10) / scale,
            (r02 + r20) / scale,
        )
    elif r11 > r22:
        scale = torch.sqrt(1.0 + r11 - r00 - r22) * 2.0
        wxyz = (
            (r02 - r20) / scale,
            (r01 + r10) / scale,
            0.25 * scale,
            (r12 + r21) / scale,
        )
    else:
        scale = torch.sqrt(1.0 + r22 - r00 - r11) * 2.0
        wxyz = (
            (r10 - r01) / scale,
            (r02 + r20) / scale,
            (r12 + r21) / scale,
            0.25 * scale,
        )

    quat = torch.stack(wxyz)
    return quat / quat.norm().clamp(min=1e-8)


def to_k4(k3: torch.Tensor) -> torch.Tensor:
    """Embed a batch of 3x3 matrices in the top-left of 4x4 identity matrices.

    The leading dimension of ``k3`` is used as the batch size; the output has
    shape ``(batch, 4, 4)`` and keeps the dtype and device of ``k3``.
    """
    batch = k3.shape[0]
    k4 = k3.new_zeros((batch, 4, 4))
    k4[:, 3, 3] = 1  # only identity entry that the 3x3 block does not overwrite
    k4[:, :3, :3] = k3
    return k4
1, 1) + out[:, :3, :3] = k3 + return out + + +def warmup_cosine_lr(step: int, warmup: int, total: int, lr0: float, lr1: float) -> float: + if step <= warmup: + return lr0 * float(step) / float(max(1, warmup)) + t = (step - warmup) / float(max(1, total - warmup)) + cos = 0.5 * (1 + np.cos(np.pi * t)) + return lr1 + (lr0 - lr1) * cos + + +@torch.no_grad() +def compute_frustum_mask( + depth: torch.Tensor, + tgt_w2c: torch.Tensor, + src_w2c: torch.Tensor, + src_k3: torch.Tensor, + tgt_k3: torch.Tensor, + img_h: int, + img_w: int, + source_img_h: int | None = None, + source_img_w: int | None = None, + depth_min: float = 0.05, + margin: float = 0.05, +) -> torch.Tensor: + dev = depth.device + f32 = torch.float32 + src_h = int(img_h if source_img_h is None else source_img_h) + src_w = int(img_w if source_img_w is None else source_img_w) + + d = depth[0, 0].to(f32) + valid = d > depth_min + + vy, vx = torch.meshgrid( + torch.arange(img_h, device=dev, dtype=f32), + torch.arange(img_w, device=dev, dtype=f32), + indexing="ij", + ) + + fx_t = tgt_k3[0, 0, 0].to(f32) + fy_t = tgt_k3[0, 1, 1].to(f32) + cx_t = tgt_k3[0, 0, 2].to(f32) + cy_t = tgt_k3[0, 1, 2].to(f32) + X_t = (vx - cx_t) / fx_t * d + Y_t = (vy - cy_t) / fy_t * d + Z_t = d + pts_t = torch.stack([X_t, Y_t, Z_t], dim=-1).reshape(-1, 3) + + c2w_t = torch.linalg.inv(tgt_w2c[0].to(f32)) + pts_w = pts_t @ c2w_t[:3, :3].T + c2w_t[:3, 3][None, :] + + w2c_s = src_w2c[0].to(f32) + pts_s = pts_w @ w2c_s[:3, :3].T + w2c_s[:3, 3][None, :] + + Z_s = pts_s[:, 2].clamp(min=1e-4) + fx_s = src_k3[0, 0, 0].to(f32) + fy_s = src_k3[0, 1, 1].to(f32) + cx_s = src_k3[0, 0, 2].to(f32) + cy_s = src_k3[0, 1, 2].to(f32) + u_s = pts_s[:, 0] / Z_s * fx_s + cx_s + v_s = pts_s[:, 1] / Z_s * fy_s + cy_s + + half_w = (src_w - 1) * 0.5 + half_h = (src_h - 1) * 0.5 + x_ndc = (u_s - half_w) / half_w + y_ndc = (v_s - half_h) / half_h + + in_frust = ( + (x_ndc.abs() <= 1.0 + margin) + & (y_ndc.abs() <= 1.0 + margin) + & (pts_s[:, 2] > 0) + ) + + mask = 
@dataclass
class _ModeStrategy:
    """Bundle of per-batch data and callbacks grouped under one strategy object.

    NOTE(review): the consumers of this dataclass are not visible in this
    part of the file; the field descriptions below restate the declared
    types only and the intended semantics should be confirmed against the
    code that builds and reads ``_ModeStrategy``.
    """

    # Integer batch size associated with this strategy.
    batch_size: int
    # Opaque gaussians object (type not constrained here).
    gaussians: Any
    # Callback taking (int, Any) and returning Any — presumably a sample
    # index plus the gaussians; TODO confirm argument meaning.
    make_world_gaussians: Callable[[int, Any], Any]
    # Callback taking (int, Any, bool) and returning a string-keyed dict.
    make_sample: Callable[[int, Any, bool], dict[str, Any]]
    # Boolean switch, off by default; presumably widens visualisation
    # collection — verify against the caller.
    collect_all_vis: bool = False
self.device = device + self.enable_tgt_unik3d_vis = bool(enable_tgt_unik3d_vis) + self.max_depth_m = float(max_depth_m) + self.sim_far_depth_invalid_m = float(sim_far_depth_invalid_m) + self.re10k_pseudo_far_depth_invalid_m = float(re10k_pseudo_far_depth_invalid_m) + self.scanetpp_fisheye_far_depth_invalid_m = float(scanetpp_fisheye_far_depth_invalid_m) + self.aux_ray_loss_weight = float(aux_ray_loss_weight) + self.aux_depth_scale_loss_weight = float(aux_depth_scale_loss_weight) + self.aux_depth2_scale_loss_weight = float(aux_depth2_scale_loss_weight) + self.target_mask_erode_px = max(int(target_mask_erode_px), 0) + + @staticmethod + def _erode_supervision_mask(mask: torch.Tensor, radius_px: int, *, circular_h: bool = False) -> torch.Tensor: + radius = max(int(radius_px), 0) + if radius <= 0: + return mask + if not torch.is_tensor(mask): + return mask + m = mask.to(dtype=torch.float32).clamp(0.0, 1.0) + if m.ndim == 3: + m = m.unsqueeze(1) + invalid = 1.0 - m + kernel = 2 * radius + 1 + if bool(circular_h): + invalid = F.pad(invalid, (radius, radius, 0, 0), mode="circular") + invalid = F.pad(invalid, (0, 0, radius, radius), mode="constant", value=0.0) + dilated_invalid = F.max_pool2d(invalid, kernel_size=kernel, stride=1) + else: + dilated_invalid = F.max_pool2d(invalid, kernel_size=kernel, stride=1, padding=radius) + return (m * (1.0 - dilated_invalid)).to(device=mask.device, dtype=mask.dtype) + + def _aux_ray_losses( + self, + *, + pred_rays: torch.Tensor | None, + gt_rays: torch.Tensor | None, + mask: torch.Tensor | None, + pred_distance: torch.Tensor | None = None, + pred_distance2: torch.Tensor | None = None, + gt_distance: torch.Tensor | None = None, + gt_distance2: torch.Tensor | None = None, + depth_mask: torch.Tensor | None = None, + depth_mask2: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + out: dict[str, torch.Tensor] = {} + if torch.is_tensor(pred_rays) and torch.is_tensor(gt_rays) and self.aux_ray_loss_weight > 0.0: + out["unik3d_ray"] 
= self.aux_ray_loss_weight * self._unik3d_polar_ray_loss( + pred_rays, + gt_rays, + mask, + ) + if torch.is_tensor(pred_distance) and torch.is_tensor(gt_distance): + out["unik3d_depth_scale"] = self.aux_depth_scale_loss_weight * self._unik3d_scale_depth_loss( + pred_distance, + gt_distance, + depth_mask if torch.is_tensor(depth_mask) else mask, + ) + depth2_target = gt_distance2 if torch.is_tensor(gt_distance2) else gt_distance + if torch.is_tensor(pred_distance2) and torch.is_tensor(depth2_target): + depth2_mask = depth_mask2 if torch.is_tensor(depth_mask2) else depth_mask + out["unik3d_depth2_scale"] = self.aux_depth2_scale_loss_weight * self._unik3d_scale_depth_loss( + pred_distance2, + depth2_target, + depth2_mask if torch.is_tensor(depth2_mask) else mask, + ) + return out + + DEPTH_SUPERVISION_MAX_M: float = DEFAULT_MAX_DEPTH_M + + def _distance_init_cap_for_dataset(self, dataset_name: str) -> float | None: + name = str(dataset_name).lower() + if name == "re10k" and self.re10k_pseudo_far_depth_invalid_m > 0.0: + return self.re10k_pseudo_far_depth_invalid_m + if name == "sim" and self.sim_far_depth_invalid_m > 0.0: + return self.sim_far_depth_invalid_m + if name in {"scanetpp_fisheye", "scannetpp_fisheye"} and self.scanetpp_fisheye_far_depth_invalid_m > 0.0: + return self.scanetpp_fisheye_far_depth_invalid_m + return None + + @staticmethod + def _unik3d_polar_ray_loss( + pred_rays: torch.Tensor | None, + gt_rays: torch.Tensor | None, + mask: torch.Tensor | None, + ) -> torch.Tensor: + if not torch.is_tensor(pred_rays) or not torch.is_tensor(gt_rays): + device = pred_rays.device if torch.is_tensor(pred_rays) else torch.device("cpu") + return torch.zeros((), device=device, dtype=torch.float32) + pred = pred_rays.to(dtype=torch.float32) + gt = gt_rays.to(device=pred.device, dtype=torch.float32) + if pred.ndim == 3: + pred = pred.unsqueeze(0) + if gt.ndim == 3: + gt = gt.unsqueeze(0) + if tuple(pred.shape) != tuple(gt.shape): + gt = F.interpolate(gt, 
size=pred.shape[-2:], mode="bilinear", align_corners=False) + gt = gt / torch.norm(gt, dim=1, keepdim=True).clamp(min=1e-5) + pred = pred / torch.norm(pred, dim=1, keepdim=True).clamp(min=1e-5) + gt = gt / torch.norm(gt, dim=1, keepdim=True).clamp(min=1e-5) + + px, py, pz = pred.unbind(dim=1) + gx, gy, gz = gt.unbind(dim=1) + polar_pred = torch.acos(pz.clamp(min=-0.99999, max=0.99999)) + polar_gt = torch.acos(gz.clamp(min=-0.99999, max=0.99999)) + az_pred = torch.atan2(py, px.abs().clamp(min=1e-5) * (2.0 * (px > 0).to(px.dtype) - 1.0)) + az_gt = torch.atan2(gy, gx.abs().clamp(min=1e-5) * (2.0 * (gx > 0).to(gx.dtype) - 1.0)) + polar_error = (polar_pred - polar_gt).abs() + az_delta = az_pred - az_gt + az_error = torch.atan2(torch.sin(az_delta), torch.cos(az_delta)).abs() + quantile_weight = torch.ones_like(polar_error) + quantile_weight[(polar_gt > polar_pred) & (polar_gt > torch.pi / 2)] = 1.4 + quantile_weight[(polar_gt <= polar_pred) & (polar_gt > torch.pi / 2)] = 0.6 + + if torch.is_tensor(mask): + m = mask.to(device=pred.device, dtype=torch.float32) + if m.ndim == 3: + m = m.unsqueeze(1) + if tuple(m.shape[-2:]) != tuple(pred.shape[-2:]): + m = F.interpolate(m, size=pred.shape[-2:], mode="nearest") + m = m[:, 0].clamp(0.0, 1.0) + else: + m = torch.ones_like(polar_error) + denom = m.sum(dim=(-1, -2), keepdim=False).clamp(min=1.0) + mean_polar = (polar_error * quantile_weight * m).sum(dim=(-1, -2)) / denom + mean_azimuth = (az_error * m).sum(dim=(-1, -2)) / denom + mean_error = (3.0 * mean_polar + mean_azimuth) / 4.0 + return torch.sqrt(mean_error + 1e-4).mean() + + @staticmethod + def _unik3d_scale_depth_loss( + pred_distance: torch.Tensor, + gt_distance: torch.Tensor, + mask: torch.Tensor | None, + ) -> torch.Tensor: + pred = UnifiedTrainer._as_b1hw_depth(pred_distance).to(dtype=torch.float32) + gt = UnifiedTrainer._as_b1hw_depth(gt_distance).to(device=pred.device, dtype=torch.float32) + if tuple(gt.shape[-2:]) != tuple(pred.shape[-2:]): + gt = F.interpolate(gt, 
def process_batch(
    self,
    batch: Any,
    dataset_name: str,
    step: int,
    need_vis: bool = False,
) -> dict[str, Any]:
    """Pick a strategy builder from the batch's attributes and run the loss loop.

    The batch schema is recognised by attribute presence, checked in this
    order:

    * ``src_rgb_u8`` + ``src_intrinsics``       -> ``_build_pinhole_strategy``
    * ``src_rgb_u8`` + ``src_camera_params``    -> ``_build_fisheye_strategy``
    * ``src_erp_rgb_u8`` + ``src_cube_depth_m`` -> ``_build_spherical_strategy``

    Args:
        batch: Data-loader batch object.
        dataset_name: Dataset label; forwarded to the builder as ``str``.
        step: Training step, forwarded to the builder.
        need_vis: Forwarded to both the builder and the loop.

    Returns:
        Whatever ``_run_strategy_loop`` returns for the built strategy.

    Raises:
        ValueError: If none of the attribute pairs above is present.
    """
    if hasattr(batch, "src_rgb_u8") and hasattr(batch, "src_intrinsics"):
        build = self._build_pinhole_strategy
    elif hasattr(batch, "src_rgb_u8") and hasattr(batch, "src_camera_params"):
        build = self._build_fisheye_strategy
    elif hasattr(batch, "src_erp_rgb_u8") and hasattr(batch, "src_cube_depth_m"):
        build = self._build_spherical_strategy
    else:
        raise ValueError(f"Unknown batch schema for dataset={dataset_name}")

    plan = build(batch, step, need_vis=need_vis, dataset_name=str(dataset_name))
    return self._run_strategy_loop(plan, need_vis=need_vis)
torch.Tensor] = {} + aux_log_sum: dict[str, torch.Tensor] = {} + vis_payload = None + vis_payloads: list[dict[str, Any]] = [] + + def _accumulate_loss_terms(term_specs: list[dict[str, Any]]) -> dict[str, torch.Tensor]: + merged: dict[str, torch.Tensor] = {} + for spec in term_specs: + term_losses = self._compute_view_loss(**spec) + for k, v in term_losses.items(): + merged[k] = merged.get(k, torch.zeros((), device=self.device)) + v + return merged + + collect_all_vis = bool(getattr(strategy, "collect_all_vis", False)) + for b in range(int(strategy.batch_size)): + g = strategy.gaussians + g_b = type(g)( + mean_vectors=g.mean_vectors[b : b + 1], + singular_values=g.singular_values[b : b + 1], + quaternions=g.quaternions[b : b + 1], + colors=g.colors[b : b + 1], + opacities=g.opacities[b : b + 1], + ) + g_world = strategy.make_world_gaussians(b, g_b) + sample = strategy.make_sample( + b, + g_world, + bool(need_vis and (collect_all_vis or b == 0)), + ) + + if isinstance(sample.get("src_loss_terms", None), list): + src_losses = _accumulate_loss_terms(sample["src_loss_terms"]) + else: + src_losses = self._compute_view_loss( + pred_rgb_linear=sample["src_pred_rgb_linear"], + pred_alpha=sample["src_pred_alpha"], + pred_depth_m=sample["src_pred_depth_m"], + pred_depth2_m=sample.get("src_pred_depth2_m", None), + gt_rgb_u8=sample["src_gt_rgb_u8"], + gt_depth_m=sample["src_gt_depth_m"], + mask=sample["src_mask"], + apply_color=bool(sample.get("src_apply_color", True)), + apply_alpha=bool(sample.get("src_apply_alpha", True)), + apply_depth=bool(sample.get("src_apply_depth", True)), + apply_percep=False, + apply_tv=True, + apply_grad=bool(sample.get("src_apply_grad", True)), + apply_grad_img=bool(sample.get("src_apply_grad_img", True)), + apply_splat=bool(sample.get("src_apply_splat", True)), + grad_img_circular_h=sample.get("src_grad_img_circular_h", None), + gaussian_scales=sample.get("gaussian_scales", None), + gaussian_quaternions=sample.get("gaussian_quaternions", None), + 
gaussian_angular_cell=sample.get("gaussian_angular_cell", None), + delta_xy=sample.get("delta_xy", None), + delta_rho=sample.get("delta_rho", None), + delta_grid=sample.get("delta_grid", None), + gaussian_mean_vectors=sample.get("gaussian_mean_vectors", None), + gaussian_base_mean_vectors=sample.get("gaussian_base_mean_vectors", None), + gaussian_opacities=sample.get("gaussian_opacities", None), + gauss_grid_shape=sample.get("gauss_grid_shape", None), + projected_scale_factor=sample.get("projected_scale_factor", None), + projection_model=sample.get("projection_model", None), + projection_intrinsics=sample.get("projection_intrinsics", None), + projection_camera_params=sample.get("projection_camera_params", None), + depth_mask=sample.get("src_depth_mask", None), + ) + if isinstance(sample.get("src_extra_loss_terms", None), list): + extra_src_losses = _accumulate_loss_terms(sample["src_extra_loss_terms"]) + for k, v in extra_src_losses.items(): + src_losses[k] = src_losses.get(k, torch.zeros((), device=self.device)) + v + if isinstance(sample.get("tgt_loss_terms", None), list): + tgt_losses = _accumulate_loss_terms(sample["tgt_loss_terms"]) + else: + tgt_losses = self._compute_view_loss( + pred_rgb_linear=sample["tgt_pred_rgb_linear"], + pred_alpha=sample["tgt_pred_alpha"], + pred_depth_m=sample["tgt_pred_depth_m"], + pred_depth2_m=sample.get("tgt_pred_depth2_m", None), + gt_rgb_u8=sample["tgt_gt_rgb_u8"], + gt_depth_m=sample["tgt_gt_depth_m"], + mask=sample["tgt_mask"], + apply_color=bool(sample.get("tgt_apply_color", True)), + apply_alpha=bool(sample.get("tgt_apply_alpha", True)), + apply_depth=bool(sample.get("tgt_apply_depth", True)), + apply_percep=bool(sample.get("tgt_apply_percep", False)), + apply_tv=False, + apply_grad=False, + apply_grad_img=bool(sample.get("tgt_apply_grad_img", True)), + apply_splat=bool(sample.get("tgt_apply_splat", False)), + grad_img_circular_h=sample.get("tgt_grad_img_circular_h", None), + gaussian_scales=None, + 
gaussian_quaternions=None, + delta_xy=None, + delta_rho=None, + gaussian_mean_vectors=None, + gaussian_base_mean_vectors=None, + gaussian_opacities=None, + gauss_grid_shape=None, + projected_scale_factor=sample.get("projected_scale_factor", None), + projection_model=sample.get("projection_model", None), + projection_intrinsics=sample.get("projection_intrinsics", None), + projection_camera_params=sample.get("projection_camera_params", None), + depth_mask=sample.get("tgt_depth_mask", None), + ) + if isinstance(sample.get("tgt_extra_loss_terms", None), list): + extra_tgt_losses = _accumulate_loss_terms(sample["tgt_extra_loss_terms"]) + for k, v in extra_tgt_losses.items(): + tgt_losses[k] = tgt_losses.get(k, torch.zeros((), device=self.device)) + v + + aux_total = torch.zeros((), device=self.device) + raw_aux = sample.get("aux_losses", None) + if isinstance(raw_aux, dict): + for k, v in raw_aux.items(): + if torch.is_tensor(v): + vv = v.to(device=self.device) + else: + vv = torch.tensor(float(v), device=self.device, dtype=torch.float32) + aux_total = aux_total + vv + aux_log_sum[str(k)] = aux_log_sum.get(str(k), torch.zeros((), device=self.device)) + vv.detach() + + src_sum = src_sum + src_losses["total"] + tgt_sum = tgt_sum + tgt_losses["total"] + total_loss = total_loss + src_losses["total"] + tgt_losses["total"] + aux_total + for k, v in src_losses.items(): + src_log_sum[k] = src_log_sum.get(k, torch.zeros((), device=self.device)) + v.detach() + for k, v in tgt_losses.items(): + tgt_log_sum[k] = tgt_log_sum.get(k, torch.zeros((), device=self.device)) + v.detach() + + if need_vis and isinstance(sample.get("vis_payload", None), dict): + vis_payloads.append(sample["vis_payload"]) + if b == 0: + vis_payload = sample["vis_payload"] + + bs = float(strategy.batch_size) + total_loss = total_loss / bs + src_sum = src_sum / bs + tgt_sum = tgt_sum / bs + loss_breakdown: dict[str, torch.Tensor] = {} + for k, v in src_log_sum.items(): + loss_breakdown[f"src_{k}"] = v / bs + for 
k, v in tgt_log_sum.items(): + loss_breakdown[f"tgt_{k}"] = v / bs + for k, v in aux_log_sum.items(): + loss_breakdown[f"aux_{k}"] = v / bs + batch_stats = { + "batch_size": int(strategy.batch_size), + "gaussian_count": int(strategy.gaussians.mean_vectors.shape[1]), + } + + return { + "total": total_loss, + "src": src_sum, + "tgt": tgt_sum, + "loss_breakdown": loss_breakdown, + "batch_stats": batch_stats, + "vis_payload": vis_payload, + "vis_payloads": vis_payloads, + } + + @staticmethod + def _first_item(x: Any, default: Any = None) -> Any: + if x is None: + return default + if isinstance(x, (list, tuple)): + return x[0] if len(x) > 0 else default + if torch.is_tensor(x): + if x.numel() == 0: + return default + return x.flatten()[0].item() + return x + + @staticmethod + def _item_at(x: Any, index: int, default: Any = None) -> Any: + if x is None: + return default + if isinstance(x, (list, tuple)): + return x[index] if 0 <= int(index) < len(x) else default + if torch.is_tensor(x): + if x.numel() == 0: + return default + if x.ndim == 0: + return x.item() + if 0 <= int(index) < int(x.shape[0]): + item = x[int(index)] + return item.item() if item.numel() == 1 else item + return default + return x + + @staticmethod + def _finite_quantile(x: torch.Tensor, q: float, default: float = float("nan")) -> torch.Tensor: + vals = x[torch.isfinite(x)] + if int(vals.numel()) <= 0: + return torch.tensor(float(default), device=x.device, dtype=torch.float32) + vals = vals.to(torch.float32).flatten() + if int(vals.numel()) > 262144: + step = max(1, int(vals.numel()) // 262144) + vals = vals[::step] + return torch.quantile(vals, float(q)) + + def _clamp_distance_for_supervision( + self, + depth_m: torch.Tensor | None, + *, + max_depth_m: float | None = None, + clamp_max: bool = True, + ) -> torch.Tensor | None: + if not torch.is_tensor(depth_m): + return None + cap = float(self.max_depth_m if max_depth_m is None else max_depth_m) + out = depth_m.to(dtype=torch.float32) + valid = 
torch.isfinite(out) & (out > 0.0) + if bool(clamp_max): + sanitized = out.clamp(min=1e-4, max=cap) + else: + sanitized = out.clamp(min=1e-4) + return torch.where(valid, sanitized, torch.zeros_like(out)) + + @staticmethod + def _rendered_depth_valid_for_inv_loss( + depth_m: torch.Tensor, + alpha: torch.Tensor, + *, + alpha_min: float | None = None, + depth_min_m: float = 1e-3, + ) -> torch.Tensor: + depth = depth_m.detach() + valid = torch.isfinite(depth) & (depth > float(depth_min_m)) + if alpha_min is not None: + a = alpha.detach().to(device=depth.device) + valid = valid & (a[:, :1] > float(alpha_min)) + return valid.to(dtype=depth.dtype) + + def _pinhole_z_to_supervision_distance( + self, + z_depth_b1hw: torch.Tensor | None, + k3_b33: torch.Tensor | None, + *, + clamp_max: bool = True, + ) -> torch.Tensor | None: + if not torch.is_tensor(z_depth_b1hw) or not torch.is_tensor(k3_b33): + return None + dist = self._z_depth_to_distance_pinhole(z_depth_b1hw, k3_b33) + return self._clamp_distance_for_supervision(dist, clamp_max=bool(clamp_max)) + + @staticmethod + def _sanitize_positive_depth(depth_m: torch.Tensor | None) -> torch.Tensor | None: + if not torch.is_tensor(depth_m): + return None + out = depth_m.to(dtype=torch.float32) + valid = torch.isfinite(out) & (out > 0.0) + return torch.where(valid, out, torch.zeros_like(out)) + + @staticmethod + def _as_b1hw_depth(depth: torch.Tensor) -> torch.Tensor: + if depth.ndim == 3: + return depth.unsqueeze(1) + if depth.ndim == 4 and depth.shape[1] == 1: + return depth + raise ValueError(f"Expected depth shape (B,H,W) or (B,1,H,W), got {tuple(depth.shape)}") + + @staticmethod + def _as_bchw_rgb_u8(image: torch.Tensor) -> torch.Tensor: + if image.ndim == 3 and image.shape[0] == 3: + return image.unsqueeze(0) + if image.ndim == 4 and image.shape[1] == 3: + return image + raise ValueError(f"Expected image shape (3,H,W) or (B,3,H,W), got {tuple(image.shape)}") + + @staticmethod + def _as_b33_intrinsics(intrinsics: torch.Tensor) 
-> torch.Tensor: + if intrinsics.ndim == 2 and tuple(intrinsics.shape) == (3, 3): + return intrinsics.unsqueeze(0) + if intrinsics.ndim == 3 and tuple(intrinsics.shape[1:]) == (3, 3): + return intrinsics + raise ValueError( + f"Expected intrinsics shape (3,3) or (B,3,3), got {tuple(intrinsics.shape)}" + ) + + @staticmethod + def _as_b9_camera_params(camera_params: torch.Tensor) -> torch.Tensor: + if camera_params.ndim == 1 and int(camera_params.shape[0]) == 9: + return camera_params.unsqueeze(0) + if camera_params.ndim == 2 and int(camera_params.shape[1]) == 9: + return camera_params + raise ValueError(f"Expected camera_params shape (9,) or (B,9), got {tuple(camera_params.shape)}") + + @staticmethod + def _as_b16_camera_params(camera_params: torch.Tensor) -> torch.Tensor: + if camera_params.ndim == 1 and int(camera_params.shape[0]) == 16: + return camera_params.unsqueeze(0) + if camera_params.ndim == 2 and int(camera_params.shape[1]) == 16: + return camera_params + raise ValueError(f"Expected camera_params shape (16,) or (B,16), got {tuple(camera_params.shape)}") + + @staticmethod + def _as_b44_pose(extrinsics: torch.Tensor) -> torch.Tensor: + if extrinsics.ndim == 2 and tuple(extrinsics.shape) == (4, 4): + return extrinsics.unsqueeze(0) + if extrinsics.ndim == 3 and tuple(extrinsics.shape[1:]) == (4, 4): + return extrinsics + raise ValueError( + f"Expected extrinsics shape (4,4) or (B,4,4), got {tuple(extrinsics.shape)}" + ) + + @staticmethod + def _pick_depth_for_pinhole_frustum_mask( + gt_depth: torch.Tensor | None, + pred_depth: torch.Tensor, + min_valid_px: int = 8, + ) -> torch.Tensor: + if torch.is_tensor(gt_depth): + gt_depth = UnifiedTrainer._as_b1hw_depth(gt_depth) + valid = torch.isfinite(gt_depth) & (gt_depth > 0.0) + if int(valid.sum().item()) >= int(min_valid_px): + return gt_depth + return pred_depth + + @staticmethod + def _pick_depth_for_fisheye_frustum_mask( + gt_depth: torch.Tensor | None, + pred_depth: torch.Tensor, + gt_valid_mask: torch.Tensor 
| None = None, + min_valid_px: int = 8, + ) -> torch.Tensor: + if torch.is_tensor(gt_depth): + gt_depth = UnifiedTrainer._as_b1hw_depth(gt_depth) + if torch.is_tensor(gt_valid_mask): + gt_valid = gt_depth > 0.0 + gt_valid = gt_valid & (gt_valid_mask > 0.5) + else: + gt_valid = torch.isfinite(gt_depth) & (gt_depth > 0.0) + if int(gt_valid.sum().item()) >= int(min_valid_px): + return gt_depth + return pred_depth + + @staticmethod + def _as_cubemap_depth_hw1(depth: torch.Tensor) -> torch.Tensor: + if depth.ndim != 4: + raise ValueError(f"Expected 4D cubemap depth, got shape={tuple(depth.shape)}") + if depth.shape[-1] == 1: + return depth + if depth.shape[1] == 1: + return depth.permute(0, 2, 3, 1).contiguous() + raise ValueError(f"Unsupported cubemap depth shape={tuple(depth.shape)}") + + def _pick_depth_for_cubemap_frustum_mask( + self, + gt_depth_cube: torch.Tensor | None, + pred_depth_cube: torch.Tensor, + face_w: int, + min_valid_px: int = 8, + ) -> torch.Tensor: + pred_hw1 = self._as_cubemap_depth_hw1(pred_depth_cube) + if torch.is_tensor(gt_depth_cube): + gt_hw1 = self._as_cubemap_depth_hw1(gt_depth_cube) + gt_dist = self._cubemap_z_depth_to_distance(gt_hw1) + gt_hw1 = self._as_cubemap_depth_hw1(gt_dist) + if gt_hw1.shape[1] != int(face_w) or gt_hw1.shape[2] != int(face_w): + gt_hw1 = F.interpolate( + gt_hw1.permute(0, 3, 1, 2), + size=(int(face_w), int(face_w)), + mode="nearest", + ).permute(0, 2, 3, 1).contiguous() + valid = torch.isfinite(gt_hw1[..., 0]) & (gt_hw1[..., 0] > 0.0) + if int(valid.sum().item()) >= int(min_valid_px): + return gt_hw1 + return pred_hw1 + + @staticmethod + def _distance_to_z_depth_pinhole( + distance_b1hw: torch.Tensor, + intrinsics_b33: torch.Tensor, + ) -> torch.Tensor: + distance_b1hw = UnifiedTrainer._as_b1hw_depth(distance_b1hw) + intrinsics_b33 = UnifiedTrainer._as_b33_intrinsics(intrinsics_b33) + b, _, h, w = distance_b1hw.shape + dev = distance_b1hw.device + dtype = distance_b1hw.dtype + uu, vv = integer_pixel_center_grid(h, 
@staticmethod
def _z_depth_to_distance_pinhole(
    z_depth_b1hw: torch.Tensor,
    intrinsics_b33: torch.Tensor,
) -> torch.Tensor:
    """Convert pinhole z-depth to distance along each pixel's viewing ray.

    For a pixel at ``(u, v)`` with normalised coordinates
    ``x = (u - cx) / fx`` and ``y = (v - cy) / fy`` the result is
    ``z / (1 / sqrt(x^2 + y^2 + 1))``.

    Args:
        z_depth_b1hw: Depth map, ``(B, H, W)`` or ``(B, 1, H, W)``.
        intrinsics_b33: Camera matrix, ``(3, 3)`` or ``(B, 3, 3)``; its
            batch size must match the depth batch, because the entries are
            viewed as ``(B, 1, 1)``.

    Returns:
        A ``(B, 1, H, W)`` tensor on the device and in the dtype of the depth.
    """
    depth = UnifiedTrainer._as_b1hw_depth(z_depth_b1hw)
    k_mat = UnifiedTrainer._as_b33_intrinsics(intrinsics_b33)
    batch, _, height, width = depth.shape
    device = depth.device
    dtype = depth.dtype

    def _k_entry(row: int, col: int) -> torch.Tensor:
        # One intrinsics entry per batch item, broadcastable over (H, W).
        return k_mat[:, row, col].view(batch, 1, 1).to(dtype=dtype, device=device)

    # NOTE(review): presumably returns (u, v) pixel-centre coordinate grids
    # of shape (H, W) -- confirm against unisharp.utils.pixel_convention.
    u_grid, v_grid = integer_pixel_center_grid(height, width, device=device, dtype=dtype)
    u_grid = u_grid.unsqueeze(0).expand(batch, -1, -1)
    v_grid = v_grid.unsqueeze(0).expand(batch, -1, -1)

    focal_x = _k_entry(0, 0)
    focal_y = _k_entry(1, 1)
    center_x = _k_entry(0, 2)
    center_y = _k_entry(1, 2)

    x_norm = (u_grid - center_x) / focal_x
    y_norm = (v_grid - center_y) / focal_y
    # z-component of the unit viewing ray (cosine of the off-axis angle).
    ray_z = 1.0 / torch.sqrt(x_norm * x_norm + y_norm * y_norm + 1.0).clamp(min=1e-8)
    return depth / ray_z.unsqueeze(1).clamp(min=1e-8)
= get_pinhole_intrinsics_4x4(int(w)).to( + device=depth_61hw.device, + dtype=depth_61hw.dtype, + ) + fx = intr[0, 0] + fy = intr[1, 1] + cx = intr[0, 2] + cy = intr[1, 2] + uu, vv = integer_pixel_center_grid(h, w, device=depth_61hw.device, dtype=depth_61hw.dtype) + x = (uu - cx) / fx + y = (vv - cy) / fy + ray_z = 1.0 / torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8) + dist = depth_61hw / ray_z.view(1, 1, h, w).clamp(min=1e-8) + valid = torch.isfinite(dist) & (depth_61hw > 0.0) + dist = torch.where(valid, dist.clamp(min=1e-4), torch.zeros_like(dist)) + return dist + + def _collect_regularization_inputs( + self, + out: dict[str, Any], + gaussians: Any, + b: int, + projected_scale_factor: float | None, + ) -> dict[str, Any]: + delta_b = out.get("delta", None) + delta_xy_raw = None + if torch.is_tensor(delta_b): + delta_xy_raw = delta_b[b : b + 1, 0:2] + delta_rho_raw = delta_b[b : b + 1, 2:3] + delta_grid_raw = delta_b[b : b + 1] + else: + delta_rho_raw = None + delta_grid_raw = None + delta_rho_applied_all = out.get("delta_rho_applied", None) + delta_rho_applied = ( + delta_rho_applied_all[b : b + 1] + if torch.is_tensor(delta_rho_applied_all) + else None + ) + scale_factor_applied_all = out.get("scale_factor_applied", None) + scale_factor_applied = ( + scale_factor_applied_all[b : b + 1] + if torch.is_tensor(scale_factor_applied_all) + else None + ) + + scales_b = gaussians.singular_values[b : b + 1] + means_b = gaussians.mean_vectors[b : b + 1] + quats_b = gaussians.quaternions[b : b + 1] + opac_b = gaussians.opacities[b : b + 1] + + base_values = out.get("gaussian_base_values", None) + gauss_grid_shape = None + base_means_b = None + base_scales_b = None + angular_cell_b = None + if base_values is not None and hasattr(base_values, "rays"): + _, _, l, hb, wb = base_values.rays.shape + gauss_grid_shape = (int(l), int(hb), int(wb)) + inv_dist_b = base_values.inv_distance[b : b + 1].clamp(min=1e-6) + base_rays_b = F.normalize(base_values.rays[b : b + 1], dim=1, 
eps=1e-6) + base_means_grid = base_rays_b / inv_dist_b + base_scales_b = base_values.scales[b : b + 1] + init_output = out.get("initializer_output", None) + global_scale = ( + init_output.global_scale[b : b + 1] + if init_output is not None + and getattr(init_output, "global_scale", None) is not None + else None + ) + if torch.is_tensor(global_scale): + base_means_grid = base_means_grid * global_scale.view(-1, 1, 1, 1, 1) + base_scales_b = base_scales_b * global_scale.view(-1, 1, 1, 1, 1) + base_means_b = base_means_grid.permute(0, 2, 3, 4, 1).flatten(1, 3) + angular_cell = getattr(base_values, "angular_cell", None) + angular_cell_b = angular_cell[b : b + 1] if torch.is_tensor(angular_cell) else None + + return { + "delta_xy_eff": delta_xy_raw, + "delta_rho_raw": delta_rho_raw, + "delta_grid": delta_grid_raw, + "delta_rho_applied": delta_rho_applied, + "scale_factor_applied": scale_factor_applied, + "gaussian_scales": scales_b, + "gaussian_quaternions": quats_b, + "gaussian_angular_cell": angular_cell_b, + "gaussian_mean_vectors": means_b, + "gaussian_base_mean_vectors": base_means_b, + "gaussian_base_scales": base_scales_b, + "gaussian_opacities": opac_b, + "gauss_grid_shape": gauss_grid_shape, + "projected_scale_factor": projected_scale_factor, + } + + def _compute_view_loss( + self, + *, + pred_rgb_linear: torch.Tensor, + pred_alpha: torch.Tensor, + pred_depth_m: torch.Tensor, + pred_depth2_m: torch.Tensor | None, + gt_rgb_u8: torch.Tensor, + gt_depth_m: torch.Tensor, + mask: torch.Tensor, + apply_color: bool, + apply_alpha: bool, + apply_depth: bool, + apply_percep: bool, + apply_tv: bool, + apply_grad: bool, + apply_grad_img: bool, + grad_img_circular_h: bool | None = None, + gaussian_scales: torch.Tensor | None = None, + gaussian_quaternions: torch.Tensor | None = None, + gaussian_angular_cell: torch.Tensor | None = None, + delta_xy: torch.Tensor | None = None, + delta_rho: torch.Tensor | None = None, + delta_grid: torch.Tensor | None = None, + 
gaussian_mean_vectors: torch.Tensor | None = None, + gaussian_base_mean_vectors: torch.Tensor | None = None, + gaussian_opacities: torch.Tensor | None = None, + gauss_grid_shape: tuple[int, int, int] | None = None, + projected_scale_factor: float | torch.Tensor | None = None, + projection_model: str | None = None, + projection_intrinsics: torch.Tensor | None = None, + projection_camera_params: torch.Tensor | None = None, + loss_scale: float = 1.0, + apply_splat: bool | None = None, + depth_mask: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + losses = self.loss_fn( + pred_rgb_linear=pred_rgb_linear, + pred_alpha=pred_alpha, + pred_depth_m=pred_depth_m, + pred_depth2_m=pred_depth2_m, + gt_rgb_u8=gt_rgb_u8, + gt_depth_m=gt_depth_m, + mask=mask, + depth_mask=depth_mask, + gaussian_scales=gaussian_scales, + gaussian_quaternions=gaussian_quaternions, + gaussian_angular_cell=gaussian_angular_cell, + delta_xy=delta_xy, + delta_rho=delta_rho, + delta_grid=delta_grid, + apply_color=bool(apply_color), + apply_alpha=bool(apply_alpha), + apply_depth=bool(apply_depth), + apply_percep=bool(apply_percep), + apply_tv=bool(apply_tv), + apply_grad=bool(apply_grad), + apply_grad_img=bool(apply_grad_img), + grad_img_circular_h=grad_img_circular_h, + apply_delta=bool(torch.is_tensor(delta_xy) or torch.is_tensor(delta_rho)), + apply_splat=bool(torch.is_tensor(gaussian_scales)) if apply_splat is None else bool(apply_splat), + gaussian_mean_vectors=gaussian_mean_vectors, + gaussian_base_mean_vectors=gaussian_base_mean_vectors, + gaussian_opacities=gaussian_opacities, + gauss_grid_shape=gauss_grid_shape, + projected_scale_factor=projected_scale_factor, + projection_model=projection_model, + projection_intrinsics=projection_intrinsics, + projection_camera_params=projection_camera_params, + ) + scale = float(loss_scale) + if abs(scale - 1.0) > 1e-8: + losses = {k: (v * scale) for k, v in losses.items()} + return losses + + def _build_pinhole_strategy( + self, + batch: Any, + 
step: int, + need_vis: bool = False, + dataset_name: str = "re10k", + ) -> _ModeStrategy: + src_u8 = self._as_bchw_rgb_u8(batch.src_rgb_u8.to(self.device, non_blocking=True)) + tgt_u8 = self._as_bchw_rgb_u8(batch.tgt_rgb_u8.to(self.device, non_blocking=True)) + src_u8_orig = getattr(batch, "src_rgb_u8_orig", None) + tgt_u8_orig = getattr(batch, "tgt_rgb_u8_orig", None) + src_depth_gt = getattr(batch, "src_depth_m", None) + tgt_depth_gt = getattr(batch, "tgt_depth_m", None) + src_depth_gt_orig = getattr(batch, "src_depth_m_orig", None) + tgt_depth_gt_orig = getattr(batch, "tgt_depth_m_orig", None) + has_depth_gt = torch.is_tensor(src_depth_gt) and torch.is_tensor(tgt_depth_gt) + if has_depth_gt: + src_depth_gt = self._as_b1hw_depth( + src_depth_gt.to(self.device, non_blocking=True).to(torch.float32) + ) + tgt_depth_gt = self._as_b1hw_depth( + tgt_depth_gt.to(self.device, non_blocking=True).to(torch.float32) + ) + has_depth_gt_orig = torch.is_tensor(src_depth_gt_orig) and torch.is_tensor(tgt_depth_gt_orig) + if has_depth_gt_orig: + src_depth_gt_orig = self._as_b1hw_depth( + src_depth_gt_orig.to(self.device, non_blocking=True).to(torch.float32) + ) + tgt_depth_gt_orig = self._as_b1hw_depth( + tgt_depth_gt_orig.to(self.device, non_blocking=True).to(torch.float32) + ) + src_w2c = self._as_b44_pose(batch.src_w2c.to(self.device, non_blocking=True).to(torch.float32)) + tgt_w2c = self._as_b44_pose(batch.tgt_w2c.to(self.device, non_blocking=True).to(torch.float32)) + src_k3 = self._as_b33_intrinsics(batch.src_intrinsics.to(self.device, non_blocking=True).to(torch.float32)) + tgt_k3 = self._as_b33_intrinsics(batch.tgt_intrinsics.to(self.device, non_blocking=True).to(torch.float32)) + src_k3_orig = getattr(batch, "src_intrinsics_orig", None) + tgt_k3_orig = getattr(batch, "tgt_intrinsics_orig", None) + has_orig_vis = ( + torch.is_tensor(src_u8_orig) + and torch.is_tensor(tgt_u8_orig) + and torch.is_tensor(src_k3_orig) + and torch.is_tensor(tgt_k3_orig) + ) + if has_orig_vis: + 
src_u8_orig = self._as_bchw_rgb_u8(src_u8_orig.to(self.device, non_blocking=True)) + tgt_u8_orig = self._as_bchw_rgb_u8(tgt_u8_orig.to(self.device, non_blocking=True)) + src_k3_orig = self._as_b33_intrinsics( + src_k3_orig.to(self.device, non_blocking=True).to(torch.float32) + ) + tgt_k3_orig = self._as_b33_intrinsics( + tgt_k3_orig.to(self.device, non_blocking=True).to(torch.float32) + ) + src_depth_gt_dist = None + tgt_depth_gt_dist = None + src_unik3d_gt_dist = None + if has_depth_gt: + src_unik3d_gt_dist = self._pinhole_z_to_supervision_distance(src_depth_gt, src_k3) + src_depth_gt_dist = src_unik3d_gt_dist + tgt_depth_gt_dist = self._pinhole_z_to_supervision_distance(tgt_depth_gt, tgt_k3) + + src = src_u8.float().clamp(0, 255) / 255.0 + tgt = tgt_u8.float().clamp(0, 255) / 255.0 + distance_init_cap_m = self._distance_init_cap_for_dataset(dataset_name) + + share_src_forward = bool(getattr(batch, "share_src_forward", False)) and int(src.shape[0]) > 1 + + def _repeat_first_dim(value: Any, batch_size: int) -> Any: + if torch.is_tensor(value): + if value.ndim > 0 and int(value.shape[0]) == 1: + return value.repeat(batch_size, *([1] * (value.ndim - 1))) + return value + if hasattr(value, "_fields"): + return type(value)(*[_repeat_first_dim(getattr(value, field), batch_size) for field in value._fields]) + return value + + if share_src_forward: + out_single = self.model( + image=src[0:1], + image_u8=src_u8[0:1], + camera_intrinsics=src_k3[0:1], + camera_model="pinhole", + depth_gt=(src_depth_gt_dist[0:1] if torch.is_tensor(src_depth_gt_dist) else None), + distance_init_cap_m=distance_init_cap_m, + return_aux=True, + ) + out = {k: _repeat_first_dim(v, int(src.shape[0])) for k, v in out_single.items()} + else: + out = self.model( + image=src, + image_u8=src_u8, + camera_intrinsics=src_k3, + camera_model="pinhole", + depth_gt=src_depth_gt_dist, + distance_init_cap_m=distance_init_cap_m, + return_aux=True, + ) + gaussians = out["gaussians"] + src_render_k3 = src_k3 + 
tgt_render_k3 = tgt_k3 + src_depth_gt_z_render = src_depth_gt if has_depth_gt else None + tgt_depth_gt_z_render = tgt_depth_gt if has_depth_gt else None + src_depth_gt_render_valid = (torch.isfinite(src_depth_gt) & (src_depth_gt > 0.0)) if has_depth_gt else None + tgt_depth_gt_render_valid = (torch.isfinite(tgt_depth_gt) & (tgt_depth_gt > 0.0)) if has_depth_gt else None + aux_ray_target_all = out.get("unik3d_gt_rays", None) + def make_world_gaussians(b: int, g_b: Any) -> Any: + return g_b + + def make_sample(b: int, g_world: Any, enable_vis: bool) -> dict[str, Any]: + src_h = int(src_u8.shape[-2]) + src_w = int(src_u8.shape[-1]) + tgt_h = int(tgt_u8.shape[-2]) + tgt_w = int(tgt_u8.shape[-1]) + ident = torch.eye(4, dtype=src_w2c.dtype, device=self.device).unsqueeze(0) + rel_tgt_w2c = tgt_w2c[b : b + 1] @ torch.linalg.inv(src_w2c[b : b + 1]) + src_k_render_b = src_render_k3[b : b + 1] + tgt_k_render_b = tgt_render_k3[b : b + 1] + src_out = self.renderer( + g_world, + extrinsics=ident, + intrinsics=to_k4(src_k_render_b), + image_width=src_w, + image_height=src_h, + ) + tgt_out = self.renderer( + g_world, + extrinsics=rel_tgt_w2c, + intrinsics=to_k4(tgt_k_render_b), + image_width=tgt_w, + image_height=tgt_h, + ) + + zeros_src_depth = torch.zeros((1, 1, src_h, src_w), dtype=torch.float32, device=self.device) + zeros_tgt_depth = torch.zeros((1, 1, tgt_h, tgt_w), dtype=torch.float32, device=self.device) + ones_mask = torch.ones_like(zeros_src_depth) + fx_b = float(src_k_render_b[0, 0, 0].item()) + fy_b = float(src_k_render_b[0, 1, 1].item()) + proj_scale_pinhole = 0.5 * (fx_b + fy_b) + reg_inputs = self._collect_regularization_inputs( + out=out, + gaussians=gaussians, + b=b, + projected_scale_factor=proj_scale_pinhole, + ) + + src_depth_for_visibility = None + tgt_gt_depth_for_mask = ( + tgt_depth_gt_z_render[b : b + 1] + if has_depth_gt and torch.is_tensor(tgt_depth_gt_z_render) + else None + ) + if has_depth_gt: + src_depth_for_visibility = ( + src_depth_gt_z_render[b : 
b + 1] + if torch.is_tensor(src_depth_gt_z_render) + else src_depth_gt[b : b + 1] + ) + + tgt_depth_for_mask = self._pick_depth_for_pinhole_frustum_mask( + gt_depth=tgt_gt_depth_for_mask, + pred_depth=tgt_out.depth, + ) + tgt_frustum_mask = compute_frustum_mask( + depth=tgt_depth_for_mask, + tgt_w2c=tgt_w2c[b : b + 1], + src_w2c=src_w2c[b : b + 1], + src_k3=src_k_render_b, + tgt_k3=tgt_k_render_b, + img_h=tgt_h, + img_w=tgt_w, + source_img_h=src_h, + source_img_w=src_w, + source_depth=src_depth_for_visibility, + ) + tgt_frustum_mask_raw = tgt_frustum_mask + tgt_frustum_mask = self._erode_supervision_mask( + tgt_frustum_mask, + self.target_mask_erode_px, + circular_h=False, + ) + src_depth_pred = self._clamp_distance_for_supervision( + out["distance_layers"][b : b + 1, 0:1], + clamp_max=False, + ) + src_depth2_pred = ( + self._clamp_distance_for_supervision(out["distance_layers"][b : b + 1, 1:2], clamp_max=False) + if out["distance_layers"] is not None and out["distance_layers"].shape[1] > 1 + else None + ) + src_depth2_gt_for_aux = ( + src_unik3d_gt_dist[b : b + 1] + if torch.is_tensor(src_unik3d_gt_dist) + else None + ) + src_depth2_mask_for_aux = src_depth_gt[b : b + 1] > 0.0 if has_depth_gt else None + tgt_depth_pred = self._pinhole_z_to_supervision_distance( + tgt_out.depth, + tgt_k_render_b, + clamp_max=False, + ) + tgt_depth_loss_mask = self._rendered_depth_valid_for_inv_loss(tgt_depth_pred, tgt_out.alpha) + if torch.is_tensor(tgt_depth_gt_render_valid): + tgt_depth_loss_mask = tgt_depth_loss_mask * tgt_depth_gt_render_valid[b : b + 1].to( + device=tgt_depth_loss_mask.device, + dtype=tgt_depth_loss_mask.dtype, + ) + tgt_extra_loss_terms: list[dict[str, Any]] = [] + vis_payload = None + if enable_vis: + vis_src_u8 = src_u8[b : b + 1] + vis_tgt_u8 = tgt_u8[b : b + 1] + vis_src_depth_gt = (src_depth_gt[b : b + 1] if has_depth_gt else None) + vis_tgt_depth_gt = (tgt_depth_gt[b : b + 1] if has_depth_gt else None) + vis_src_out = src_out + vis_tgt_out = tgt_out + 
if has_orig_vis: + vis_src_u8 = src_u8_orig[b : b + 1] + vis_tgt_u8 = tgt_u8_orig[b : b + 1] + vis_src_depth_gt = (src_depth_gt_orig[b : b + 1] if has_depth_gt_orig else None) + vis_tgt_depth_gt = (tgt_depth_gt_orig[b : b + 1] if has_depth_gt_orig else None) + vis_src_render_k3 = src_k3_orig[b : b + 1] + vis_tgt_render_k3 = tgt_k3_orig[b : b + 1] + vis_src_out = self.renderer( + g_world, + extrinsics=ident, + intrinsics=to_k4(vis_src_render_k3), + image_width=int(vis_src_u8.shape[-1]), + image_height=int(vis_src_u8.shape[-2]), + ) + vis_tgt_out = self.renderer( + g_world, + extrinsics=rel_tgt_w2c, + intrinsics=to_k4(vis_tgt_render_k3), + image_width=int(vis_tgt_u8.shape[-1]), + image_height=int(vis_tgt_u8.shape[-2]), + ) + src_unik3d_depth = None + tgt_unik3d_depth = None + raw_dist = out.get("unik3d_distance", None) + if torch.is_tensor(raw_dist): + try: + conditioning_rays = out.get("unik3d_ray_conditioning_rays", None) + if not torch.is_tensor(conditioning_rays): + conditioning_rays = out.get("unik3d_rays", None) + ray_z = ( + conditioning_rays[b : b + 1, 2:3].detach() + if torch.is_tensor(conditioning_rays) + else None + ) + if torch.is_tensor(ray_z): + if tuple(ray_z.shape[-2:]) != tuple(raw_dist.shape[-2:]): + ray_z = F.interpolate(ray_z, size=raw_dist.shape[-2:], mode="bilinear", align_corners=False) + src_unik3d_depth = raw_dist[b : b + 1, 0:1].detach() * ray_z + else: + src_unik3d_depth = self._distance_to_z_depth_pinhole( + raw_dist[b : b + 1, 0:1].detach(), + src_k_render_b, + ) + except Exception: + src_unik3d_depth = raw_dist[b : b + 1, 0:1].detach() + if self.enable_tgt_unik3d_vis: + try: + with torch.no_grad(): + from unisharp.utils.unik3d_adapter import forward_unik3d_pinhole + + unik_tgt = forward_unik3d_pinhole( + self._base_model().feature_extractor.unik3d, + rgb_u8=tgt_u8[b : b + 1], + intrinsics=tgt_k3[b : b + 1], + normalize=True, + ) + dist_tgt = unik_tgt.get("distance", None) if isinstance(unik_tgt, dict) else None + if 
torch.is_tensor(dist_tgt): + try: + tgt_unik3d_depth = self._distance_to_z_depth_pinhole( + dist_tgt[:, 0:1].detach(), + tgt_k_render_b, + ) + except Exception: + tgt_unik3d_depth = dist_tgt[:, 0:1].detach() + except Exception: + tgt_unik3d_depth = None + + vis_payload = { + "src_gt": (vis_src_u8.float() / 255.0).detach(), + "src_pred": vis_src_out.color.clamp(0, 1).detach(), + "src_alpha": vis_src_out.alpha.detach(), + "src_gt_depth": (vis_src_depth_gt.detach() if torch.is_tensor(vis_src_depth_gt) else None), + "src_pred_depth": vis_src_out.depth.detach(), + "src_unik3d_depth": src_unik3d_depth, + "tgt_gt": (vis_tgt_u8.float() / 255.0).detach(), + "tgt_pred": vis_tgt_out.color.clamp(0, 1).detach(), + "tgt_alpha": vis_tgt_out.alpha.detach(), + "tgt_gt_depth": (vis_tgt_depth_gt.detach() if torch.is_tensor(vis_tgt_depth_gt) else None), + "tgt_pred_depth": vis_tgt_out.depth.detach(), + "tgt_unik3d_depth": tgt_unik3d_depth, + "dataset_name": str(dataset_name), + "scene": str(self._item_at(getattr(batch, "scene", None), b, "unknown")), + "src_idx": int(self._item_at(getattr(batch, "src_idx", None), b, -1)), + "tgt_idx": int(self._item_at(getattr(batch, "tgt_idx", None), b, -1)), + "src_pose_w2c": src_w2c[b : b + 1].detach(), + "tgt_pose_w2c": tgt_w2c[b : b + 1].detach(), + "tgt_metric_mask_raw": tgt_frustum_mask_raw.detach(), + "tgt_metric_mask": tgt_frustum_mask.detach(), + } + + return { + "src_pred_rgb_linear": src_out.color, + "src_pred_alpha": src_out.alpha, + "src_pred_depth_m": src_depth_pred, + "src_pred_depth2_m": src_depth2_pred, + "src_gt_rgb_u8": src_u8[b : b + 1], + "src_gt_depth_m": (src_depth_gt_dist[b : b + 1] if has_depth_gt and src_depth_gt_dist is not None else zeros_src_depth), + "src_mask": ones_mask, + "src_apply_depth": False, + "src_apply_grad": bool(has_depth_gt), + "src_apply_grad_img": bool(has_depth_gt), + "src_grad_img_circular_h": False, + "tgt_pred_rgb_linear": tgt_out.color, + "tgt_pred_alpha": tgt_out.alpha, + "tgt_pred_depth_m": 
tgt_depth_pred, + "tgt_gt_rgb_u8": tgt_u8[b : b + 1], + "tgt_gt_depth_m": (tgt_depth_gt_dist[b : b + 1] if has_depth_gt and tgt_depth_gt_dist is not None else zeros_tgt_depth), + "tgt_mask": tgt_frustum_mask, + "tgt_depth_mask": tgt_depth_loss_mask, + "tgt_apply_depth": bool(has_depth_gt), + "tgt_apply_grad_img": bool(has_depth_gt), + "tgt_grad_img_circular_h": False, + "tgt_apply_percep": bool(float(self.loss_fn.w.lambda_percep) > 0.0), + "tgt_extra_loss_terms": tgt_extra_loss_terms, + "aux_losses": self._aux_ray_losses( + pred_rays=( + out.get("unik3d_rays", None)[b : b + 1] + if torch.is_tensor(out.get("unik3d_rays", None)) + else None + ), + gt_rays=( + aux_ray_target_all[b : b + 1] + if torch.is_tensor(aux_ray_target_all) + else None + ), + mask=ones_mask, + pred_distance=( + out["unik3d_distance"][b : b + 1, 0:1] + if torch.is_tensor(out.get("unik3d_distance", None)) + else None + ), + pred_distance2=src_depth2_pred, + gt_distance=( + src_unik3d_gt_dist[b : b + 1] + if torch.is_tensor(src_unik3d_gt_dist) + else None + ), + gt_distance2=src_depth2_gt_for_aux, + depth_mask=(src_depth_gt[b : b + 1] > 0.0 if has_depth_gt else None), + depth_mask2=src_depth2_mask_for_aux, + ), + "gaussian_scales": reg_inputs["gaussian_scales"], + "gaussian_quaternions": reg_inputs["gaussian_quaternions"], + "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"], + "delta_xy": reg_inputs["delta_xy_eff"], + "delta_rho": reg_inputs["delta_rho_raw"], + "delta_grid": reg_inputs["delta_grid"], + "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"], + "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"], + "gaussian_opacities": reg_inputs["gaussian_opacities"], + "gauss_grid_shape": reg_inputs["gauss_grid_shape"], + "projected_scale_factor": reg_inputs["projected_scale_factor"], + "projection_model": "pinhole", + "projection_intrinsics": src_k_render_b, + "vis_payload": vis_payload, + } + + return _ModeStrategy( + batch_size=int(src.shape[0]), + 
def _build_fisheye624_strategy(
    self,
    batch: Any,
    step: int,
    need_vis: bool = False,
    dataset_name: str = "scannetpp_fisheye",
) -> _ModeStrategy:
    """Build the training strategy for a batch of fisheye624 source/target pairs.

    The batch tensors are moved to ``self.device``, the model is run once on
    the source images (with GT distance and validity mask), and two closures
    are returned inside a ``_ModeStrategy``:

    * ``make_world_gaussians(b, g_b)`` maps sample ``b``'s gaussians through
      ``transform_gaussians_to_world`` using ``src_w2c[b]``.
    * ``make_sample(b, g_world, enable_vis)`` renders source and target views
      with ``render_gaussians_fisheye624``, derives supervision masks, and
      returns the dict of loss inputs (plus an optional ``vis_payload``).

    Args:
        batch: Object exposing ``src_rgb_u8``, ``tgt_rgb_u8``, ``src_depth_m``,
            ``tgt_depth_m``, ``src_valid_mask``, ``tgt_valid_mask``,
            ``src_w2c``, ``tgt_w2c``, ``src_camera_params`` and
            ``tgt_camera_params`` tensors; ``scene``, ``src_idx``, ``tgt_idx``
            and ``collect_all_vis`` are read with ``getattr`` and are optional.
        step: Unused here; accepted for parity with the other strategy builders.
        need_vis: Unused here; visualisation is decided per sample through
            ``make_sample``'s ``enable_vis`` argument.
        dataset_name: Selects the distance-init cap and labels the vis payload.

    Returns:
        A ``_ModeStrategy`` whose ``batch_size`` is the leading dimension of
        the source image tensor.
    """
    del step, need_vis  # accepted for interface parity only

    src_u8 = self._as_bchw_rgb_u8(batch.src_rgb_u8.to(self.device, non_blocking=True))
    tgt_u8 = self._as_bchw_rgb_u8(batch.tgt_rgb_u8.to(self.device, non_blocking=True))
    src_depth_gt = self._clamp_distance_for_supervision(
        self._as_b1hw_depth(batch.src_depth_m.to(self.device, non_blocking=True).to(torch.float32))
    )
    tgt_depth_gt = self._clamp_distance_for_supervision(
        self._as_b1hw_depth(batch.tgt_depth_m.to(self.device, non_blocking=True).to(torch.float32))
    )
    src_valid_mask = self._as_b1hw_depth(batch.src_valid_mask.to(self.device, non_blocking=True).to(torch.float32))
    tgt_valid_mask = self._as_b1hw_depth(batch.tgt_valid_mask.to(self.device, non_blocking=True).to(torch.float32))
    src_w2c = self._as_b44_pose(batch.src_w2c.to(self.device, non_blocking=True).to(torch.float32))
    tgt_w2c = self._as_b44_pose(batch.tgt_w2c.to(self.device, non_blocking=True).to(torch.float32))
    src_cam_params = self._as_b16_camera_params(
        batch.src_camera_params.to(self.device, non_blocking=True).to(torch.float32)
    )
    tgt_cam_params = self._as_b16_camera_params(
        batch.tgt_camera_params.to(self.device, non_blocking=True).to(torch.float32)
    )
    distance_init_cap_m = self._distance_init_cap_for_dataset(dataset_name)

    # Single batched forward pass; the closures below slice its outputs per sample.
    out = self.model(
        image=src_u8.float().clamp(0, 255) / 255.0,
        image_u8=src_u8,
        camera_intrinsics=None,
        camera_params=src_cam_params,
        camera_model="fisheye624",
        depth_gt=src_depth_gt,
        distance_init_cap_m=distance_init_cap_m,
        validity_mask=src_valid_mask,
        return_aux=True,
    )

    gaussians = out["gaussians"]
    # Rendering uses the batch cameras and masks unchanged; the separate names
    # keep "what is rendered with" distinct from "what came from the batch".
    src_render_cam_params = src_cam_params
    tgt_render_cam_params = tgt_cam_params
    src_render_valid_mask = src_valid_mask
    tgt_render_valid_mask = tgt_valid_mask
    aux_ray_target_all = out.get("unik3d_gt_rays", None)

    def make_world_gaussians(b: int, g_b: Any) -> Any:
        return transform_gaussians_to_world(g_b, src_w2c[b])

    def make_sample(b: int, g_world: Any, enable_vis: bool) -> dict[str, Any]:
        src_h = int(src_u8.shape[-2])
        src_w = int(src_u8.shape[-1])
        tgt_h = int(tgt_u8.shape[-2])
        tgt_w = int(tgt_u8.shape[-1])
        unik3d_rays_all = out.get("unik3d_rays", None)
        unik3d_distance_all = out.get("unik3d_distance", None)

        src_render = render_gaussians_fisheye624(
            g_world,
            extrinsics_w2c=src_w2c[b : b + 1],
            camera_params=src_render_cam_params[b : b + 1],
            image_h=src_h,
            image_w=src_w,
            valid_mask=src_render_valid_mask[b : b + 1],
        )
        tgt_render = render_gaussians_fisheye624(
            g_world,
            extrinsics_w2c=tgt_w2c[b : b + 1],
            camera_params=tgt_render_cam_params[b : b + 1],
            image_h=tgt_h,
            image_w=tgt_w,
            valid_mask=tgt_render_valid_mask[b : b + 1],
        )
        reg_inputs = self._collect_regularization_inputs(
            out=out,
            gaussians=gaussians,
            b=b,
            projected_scale_factor=None,
        )

        # Source supervision mask: batch validity AND renderer validity.
        # Computed once and shared by the frustum test and the loss terms.
        src_mask = src_render_valid_mask[b : b + 1] * src_render["valid_mask"]
        src_depth_mask = src_mask

        tgt_depth_for_mask = self._pick_depth_for_fisheye_frustum_mask(
            gt_depth=tgt_depth_gt[b : b + 1],
            pred_depth=tgt_render["depth_distance"],
            gt_valid_mask=tgt_valid_mask[b : b + 1],
        )
        tgt_frustum_mask = compute_fisheye624_frustum_mask(
            depth_distance_m=tgt_depth_for_mask,
            tgt_w2c=tgt_w2c[b : b + 1],
            src_w2c=src_w2c[b : b + 1],
            tgt_camera_params=tgt_render_cam_params[b : b + 1],
            src_camera_params=src_render_cam_params[b : b + 1],
            src_valid_mask=src_mask,
            source_depth_distance_m=src_depth_gt[b : b + 1],
        )
        tgt_mask = tgt_render_valid_mask[b : b + 1] * tgt_render["valid_mask"] * tgt_frustum_mask
        tgt_mask_raw = tgt_mask  # kept un-eroded for visualisation
        tgt_mask = self._erode_supervision_mask(
            tgt_mask,
            self.target_mask_erode_px,
            circular_h=False,
        )

        src_depth_pred = self._clamp_distance_for_supervision(
            out["distance_layers"][b : b + 1, 0:1],
            clamp_max=False,
        )
        src_depth2_pred = (
            self._clamp_distance_for_supervision(out["distance_layers"][b : b + 1, 1:2], clamp_max=False)
            if out["distance_layers"] is not None and out["distance_layers"].shape[1] > 1
            else None
        )
        tgt_depth_pred = self._clamp_distance_for_supervision(tgt_render["depth_distance"], clamp_max=False)
        tgt_depth_loss_mask = self._rendered_depth_valid_for_inv_loss(tgt_depth_pred, tgt_render["alpha"])

        # Photometric/alpha term on the rendered source view.
        src_loss_terms = [
            {
                "pred_rgb_linear": src_render["color"],
                "pred_alpha": src_render["alpha"],
                "pred_depth_m": src_render["depth_distance"],
                "pred_depth2_m": None,
                "gt_rgb_u8": src_u8[b : b + 1],
                "gt_depth_m": src_depth_gt[b : b + 1],
                "mask": src_mask,
                "apply_color": True,
                "apply_alpha": True,
                "apply_depth": False,
                "apply_percep": False,
                "apply_tv": False,
                "apply_grad": False,
                "apply_grad_img": False,
                "grad_img_circular_h": False,
                "gaussian_scales": None,
                "gaussian_quaternions": None,
                "gaussian_angular_cell": None,
                "delta_xy": None,
                "gaussian_mean_vectors": None,
                "gaussian_opacities": None,
                "gauss_grid_shape": None,
                "projected_scale_factor": None,
                "apply_splat": False,
                "loss_scale": 1.0,
            }
        ]
        # Regularisation-only term on the model's own distance layers; the
        # colour/alpha slots are zero placeholders because those losses are off.
        src_extra_loss_terms = [
            {
                "pred_rgb_linear": torch.zeros((1, 3, src_h, src_w), dtype=torch.float32, device=self.device),
                "pred_alpha": torch.zeros((1, 1, src_h, src_w), dtype=torch.float32, device=self.device),
                "pred_depth_m": src_depth_pred,
                "pred_depth2_m": src_depth2_pred,
                "gt_rgb_u8": torch.zeros((1, 3, src_h, src_w), dtype=torch.uint8, device=self.device),
                "gt_depth_m": src_depth_gt[b : b + 1],
                "mask": src_depth_mask,
                "apply_color": False,
                "apply_alpha": False,
                "apply_depth": False,
                "apply_percep": False,
                "apply_tv": True,
                "apply_grad": True,
                "apply_grad_img": True,
                "grad_img_circular_h": False,
                "gaussian_scales": reg_inputs["gaussian_scales"],
                "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
                "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
                "delta_xy": reg_inputs["delta_xy_eff"],
                "delta_rho": reg_inputs["delta_rho_raw"],
                "delta_grid": reg_inputs["delta_grid"],
                "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
                "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
                "gaussian_opacities": reg_inputs["gaussian_opacities"],
                "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
                "projected_scale_factor": None,
                "projection_model": "fisheye624",
                "projection_camera_params": src_render_cam_params[b : b + 1],
                "apply_splat": True,
                "loss_scale": 1.0,
            }
        ]
        tgt_extra_loss_terms = []

        vis_payload = None
        if enable_vis:
            src_unik3d_depth = (
                unik3d_distance_all[b : b + 1, 0:1].detach()
                if torch.is_tensor(unik3d_distance_all)
                else None
            )
            tgt_unik3d_depth = None
            if self.enable_tgt_unik3d_vis:
                # Best-effort extra visualisation: any failure (including the
                # lazy import) must not interrupt training, so it is dropped.
                try:
                    with torch.no_grad():
                        from unisharp.utils.unik3d_adapter import forward_unik3d_fisheye624

                        unik_tgt = forward_unik3d_fisheye624(
                            self._base_model().feature_extractor.unik3d,
                            rgb_u8=tgt_u8[b : b + 1],
                            camera_params=tgt_render_cam_params[b : b + 1],
                            normalize=True,
                            validity_mask=tgt_valid_mask[b : b + 1],
                        )
                        dist_tgt = unik_tgt.get("distance", None) if isinstance(unik_tgt, dict) else None
                        if torch.is_tensor(dist_tgt):
                            tgt_unik3d_depth = dist_tgt[:, 0:1].detach()
                except Exception:
                    tgt_unik3d_depth = None
            vis_payload = {
                "src_gt": (src_u8[b : b + 1].float() / 255.0).detach(),
                "src_pred": src_render["color"].clamp(0, 1).detach(),
                "src_alpha": src_render["alpha"].detach(),
                "src_gt_depth": src_depth_gt[b : b + 1].detach(),
                "src_pred_depth": src_render["depth_distance"].detach(),
                "src_unik3d_depth": src_unik3d_depth,
                "src_metric_mask": src_mask.detach(),
                "tgt_gt": (tgt_u8[b : b + 1].float() / 255.0).detach(),
                "tgt_pred": tgt_render["color"].clamp(0, 1).detach(),
                "tgt_alpha": tgt_render["alpha"].detach(),
                "tgt_gt_depth": tgt_depth_gt[b : b + 1].detach(),
                "tgt_pred_depth": tgt_depth_pred.detach(),
                "tgt_unik3d_depth": tgt_unik3d_depth,
                "dataset_name": str(dataset_name),
                # Metadata is looked up for sample ``b`` (not the first batch
                # item) so each payload is labelled with its own scene/frames.
                "scene": str(self._item_at(getattr(batch, "scene", None), b, "unknown")),
                "src_idx": int(self._item_at(getattr(batch, "src_idx", None), b, -1)),
                "tgt_idx": int(self._item_at(getattr(batch, "tgt_idx", None), b, -1)),
                "src_pose_w2c": src_w2c[b : b + 1].detach(),
                "tgt_pose_w2c": tgt_w2c[b : b + 1].detach(),
                "tgt_metric_mask_raw": tgt_mask_raw.detach(),
                "tgt_metric_mask": tgt_mask.detach(),
            }

        return {
            "src_loss_terms": src_loss_terms,
            "src_extra_loss_terms": src_extra_loss_terms,
            "tgt_pred_rgb_linear": tgt_render["color"],
            "tgt_pred_alpha": tgt_render["alpha"],
            "tgt_pred_depth_m": tgt_depth_pred,
            "tgt_gt_rgb_u8": tgt_u8[b : b + 1],
            "tgt_gt_depth_m": tgt_depth_gt[b : b + 1],
            "tgt_mask": tgt_mask,
            "tgt_depth_mask": tgt_depth_loss_mask,
            "tgt_apply_depth": True,
            "tgt_apply_grad_img": True,
            "tgt_apply_splat": False,
            "tgt_grad_img_circular_h": False,
            "tgt_apply_percep": bool(float(self.loss_fn.w.lambda_percep) > 0.0),
            "tgt_extra_loss_terms": tgt_extra_loss_terms,
            "aux_losses": self._aux_ray_losses(
                pred_rays=(
                    unik3d_rays_all[b : b + 1]
                    if torch.is_tensor(unik3d_rays_all)
                    else None
                ),
                gt_rays=(
                    aux_ray_target_all[b : b + 1]
                    if torch.is_tensor(aux_ray_target_all)
                    else None
                ),
                mask=src_render_valid_mask[b : b + 1],
                pred_distance=(
                    unik3d_distance_all[b : b + 1, 0:1]
                    if torch.is_tensor(unik3d_distance_all)
                    else None
                ),
                pred_distance2=None,
                gt_distance=src_depth_gt[b : b + 1],
                depth_mask=src_valid_mask[b : b + 1],
            ),
            "gaussian_scales": reg_inputs["gaussian_scales"],
            "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
            "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
            "delta_xy": reg_inputs["delta_xy_eff"],
            "delta_rho": reg_inputs["delta_rho_raw"],
            "delta_grid": reg_inputs["delta_grid"],
            "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
            "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
            "gaussian_opacities": reg_inputs["gaussian_opacities"],
            "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
            "projected_scale_factor": reg_inputs["projected_scale_factor"],
            "projection_model": "fisheye624",
            "projection_camera_params": src_render_cam_params[b : b + 1],
            "vis_payload": vis_payload,
        }

    return _ModeStrategy(
        batch_size=int(src_u8.shape[0]),
        gaussians=gaussians,
        make_world_gaussians=make_world_gaussians,
        make_sample=make_sample,
        collect_all_vis=bool(getattr(batch, "collect_all_vis", False)),
    )
reg_inputs["delta_xy_eff"], + "delta_rho": reg_inputs["delta_rho_raw"], + "delta_grid": reg_inputs["delta_grid"], + "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"], + "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"], + "gaussian_opacities": reg_inputs["gaussian_opacities"], + "gauss_grid_shape": reg_inputs["gauss_grid_shape"], + "projected_scale_factor": reg_inputs["projected_scale_factor"], + "projection_model": "fisheye624", + "projection_camera_params": src_render_cam_params[b : b + 1], + "vis_payload": vis_payload, + } + + return _ModeStrategy( + batch_size=int(src_u8.shape[0]), + gaussians=gaussians, + make_world_gaussians=make_world_gaussians, + make_sample=make_sample, + collect_all_vis=bool(getattr(batch, "collect_all_vis", False)), + ) + + def _build_fisheye_strategy( + self, + batch: Any, + step: int, + need_vis: bool = False, + dataset_name: str = "fisheye", + ) -> _ModeStrategy: + camera_model = str(getattr(batch, "camera_model", "fisheye624")).lower() + if camera_model != "fisheye624": + raise ValueError( + f"Unsupported fisheye camera_model={camera_model!r}; expected 'fisheye624'." 
def _build_spherical_strategy(
    self,
    batch: Any,
    step: int,
    need_vis: bool = False,
    dataset_name: str = "hm3d",
) -> _ModeStrategy:
    """Build the training strategy for a batch of spherical (ERP + cubemap) pairs.

    Moves the batch tensors to ``self.device``, converts the source/target
    poses into a frame relative to the source camera, runs the model once with
    ``camera_model="spherical"``, and returns a ``_ModeStrategy`` whose
    ``make_sample`` closure renders cubemaps for both views, resamples them to
    equirectangular, and assembles the loss-term dicts.

    Args:
        batch: Object exposing ``src_erp_rgb_u8``, ``tgt_erp_rgb_u8``,
            ``src_erp_depth_m``, ``tgt_erp_depth_m``, ``src_cube_depth_m``,
            ``tgt_cube_depth_m``, ``src_cube_rgb_u8``, ``tgt_cube_rgb_u8``,
            ``src_R``, ``src_t``, ``tgt_R`` and ``tgt_t``. ``disable_depth_gt``,
            ``scene``, ``src_idx``, ``tgt_idx`` and ``collect_all_vis`` are
            optional and read with ``getattr``.
        step: Not referenced in this method.
        need_vis: Not referenced in this method; visualisation is controlled
            by ``make_sample``'s ``enable_vis`` argument.
        dataset_name: Selects the distance-init cap, disables the pose axis
            flip for ``"sim"`` / ``"smx_sim_fisheye"``, and enables the
            GT-depth-based source cube mask for ``"hm3d"``.

    Returns:
        A ``_ModeStrategy`` with ``batch_size`` equal to the leading dimension
        of the source ERP image tensor.
    """
    src_erp_u8 = batch.src_erp_rgb_u8.to(self.device, non_blocking=True)
    tgt_erp_u8 = batch.tgt_erp_rgb_u8.to(self.device, non_blocking=True)
    src_erp_depth = self._clamp_distance_for_supervision(
        batch.src_erp_depth_m.to(self.device, non_blocking=True)
    )
    tgt_erp_depth = self._clamp_distance_for_supervision(
        batch.tgt_erp_depth_m.to(self.device, non_blocking=True)
    )
    src_cdep = self._sanitize_positive_depth(
        batch.src_cube_depth_m.to(self.device, non_blocking=True)
    )
    tgt_cdep = self._sanitize_positive_depth(
        batch.tgt_cube_depth_m.to(self.device, non_blocking=True)
    )
    # When set, GT depth is withheld from the model and from every depth loss/mask below.
    disable_depth_gt = bool(getattr(batch, "disable_depth_gt", False))

    src_R = batch.src_R.to(self.device, non_blocking=True)
    src_t = batch.src_t.to(self.device, non_blocking=True)
    tgt_R = batch.tgt_R.to(self.device, non_blocking=True)
    tgt_t = batch.tgt_t.to(self.device, non_blocking=True)

    cur_bs = int(src_erp_u8.shape[0])
    erp_h = int(src_erp_u8.shape[-2])
    erp_w = int(src_erp_u8.shape[-1])
    # Cube face width is taken from axis 2 of the cube depth tensor.
    # NOTE(review): the non-tensor fallback appears unreachable, because
    # ``batch.src_cube_depth_m.to(...)`` is already called unconditionally above.
    cube_face_w = int(batch.src_cube_depth_m.shape[2]) if torch.is_tensor(batch.src_cube_depth_m) else max(1, erp_h // 2)

    # Every sample shares one pose convention ("c2w") and one flip decision.
    use_flip_yz = str(dataset_name).lower() not in {"sim", "smx_sim_fisheye"}
    pose_convs_per_sample = ["c2w"] * cur_bs
    flip_yz_per_sample = [bool(use_flip_yz)] * cur_bs

    extr_src_base = torch.stack(
        [build_extrinsics_w2c(src_R[i], src_t[i], pose_convs_per_sample[i]) for i in range(cur_bs)],
        dim=0
    )
    extr_tgt_base = torch.stack(
        [build_extrinsics_w2c(tgt_R[i], tgt_t[i], pose_convs_per_sample[i]) for i in range(cur_bs)],
        dim=0
    )

    # Pose algebra is done in float32 with autocast switched off.
    # NOTE(review): the extent of this ``with`` block was reconstructed from a
    # flattened diff; it is assumed to cover all pose math down to
    # ``extr_tgt`` -- confirm against the original file.
    with torch.autocast("cuda", enabled=False):
        c2w_src = torch.linalg.inv(extr_src_base.to(torch.float32))
        c2w_tgt = torch.linalg.inv(extr_tgt_base.to(torch.float32))

        flip_mask = torch.tensor(flip_yz_per_sample, device=c2w_src.device, dtype=torch.bool)
        negate_relative_z = False
        if bool(flip_mask.any().item()):
            # The axis-flip convention can be overridden through the environment;
            # unrecognised values fall through to the y/z flip in the ``else`` branch.
            flip_mode = os.environ.get("PANO_POSE_FLIP_CONVENTION", "flip_yz_negate_rel_z").strip().lower()
            negate_relative_z = flip_mode in {
                "flip_yz_negate_rel_z",
                "flip_yz_invert_z_translation",
                "flip_yz_neg_z",
            }
            if flip_mode in {"flip_y_only", "y", "y_only"}:
                diag = [1.0, -1.0, 1.0, 1.0]
            elif flip_mode in {"none", "identity", "no_flip"}:
                diag = [1.0, 1.0, 1.0, 1.0]
            else:
                diag = [1.0, -1.0, -1.0, 1.0]
            D = torch.diag(torch.tensor(diag, device=c2w_src.device, dtype=torch.float32))
            # Right-multiplying by D negates the selected camera-axis columns
            # of the camera-to-world matrices.
            c2w_src = c2w_src.clone()
            c2w_tgt = c2w_tgt.clone()
            c2w_src[flip_mask] = c2w_src[flip_mask] @ D
            c2w_tgt[flip_mask] = c2w_tgt[flip_mask] @ D

        # Re-base both poses on the source camera: ``c2w_src`` becomes
        # inv(c2w_src) @ c2w_src, i.e. identity up to numerical error.
        ref_inv = torch.linalg.inv(c2w_src.to(torch.float32))
        c2w_src = ref_inv @ c2w_src
        c2w_tgt = ref_inv @ c2w_tgt
        if negate_relative_z:
            # Flip the sign of the z translation of the relative target pose.
            c2w_tgt = c2w_tgt.clone()
            c2w_tgt[flip_mask, 2, 3] *= -1.0

        extr_src = torch.linalg.inv(c2w_src).to(dtype=extr_src_base.dtype)
        extr_tgt = torch.linalg.inv(c2w_tgt).to(dtype=extr_tgt_base.dtype)

    src_erp = (src_erp_u8.float() / 255.0).clamp(0, 1)
    distance_init_cap_m = self._distance_init_cap_for_dataset(dataset_name)

    out = self.model(
        image=src_erp,
        image_u8=src_erp_u8,
        camera_intrinsics=None,
        camera_model="spherical",
        depth_gt=None if disable_depth_gt else src_erp_depth,
        distance_init_cap_m=distance_init_cap_m,
        return_aux=True,
    )
    gaussians = out["gaussians"]
    aux_ray_target_all = out.get("unik3d_gt_rays", None)

    def make_world_gaussians(b: int, g_b: Any) -> Any:
        # "World" here is the source-relative frame built above.
        return transform_gaussians_to_world(g_b, extr_src[b])

    def make_sample(b: int, g_world: Any, enable_vis: bool) -> dict[str, Any]:
        # Render six cube faces per view, then resample to equirectangular.
        src_rgb, src_depth, src_alpha = self._render_cubemap(g_world, extr_src[b], face_w=cube_face_w)
        tgt_rgb, tgt_depth, tgt_alpha = self._render_cubemap(g_world, extr_tgt[b], face_w=cube_face_w)

        src_erp_pred = self._cube_to_erp(src_rgb, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w)
        tgt_erp_pred = self._cube_to_erp(tgt_rgb, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w)
        src_erp_alpha = self._cube_to_erp(src_alpha, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w)
        tgt_erp_alpha = self._cube_to_erp(tgt_alpha, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w)
        # Rendered per-face z-depth is converted to a distance before use.
        src_depth_dist = self._clamp_distance_for_supervision(
            self._cubemap_z_depth_to_distance(src_depth),
            clamp_max=False,
        )
        tgt_depth_dist = self._clamp_distance_for_supervision(
            self._cubemap_z_depth_to_distance(tgt_depth),
            clamp_max=False,
        )
        # Only consumed by the visualisation payload below.
        src_erp_depth_render = self._cube_to_erp(
            src_depth_dist, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w
        ).clamp(min=1e-4)

        # Source-view depth supervision uses the model's distance layers
        # directly rather than the rendered cubemap depth.
        src_erp_depth_pred = self._clamp_distance_for_supervision(
            out["distance_layers"][b : b + 1, 0:1],
            clamp_max=False,
        )
        src_erp_depth2_pred = (
            self._clamp_distance_for_supervision(out["distance_layers"][b : b + 1, 1:2], clamp_max=False)
            if out["distance_layers"] is not None and out["distance_layers"].shape[1] > 1
            else None
        )
        tgt_erp_depth_pred = self._cube_to_erp(
            tgt_depth_dist, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w
        ).clamp(min=1e-4)
        tgt_depth_loss_mask = self._rendered_depth_valid_for_inv_loss(tgt_erp_depth_pred, tgt_erp_alpha)
        # Depth used for the frustum/visibility test: GT cube depth when
        # allowed, with the rendered distance passed as the alternative.
        depth_novel = self._pick_depth_for_cubemap_frustum_mask(
            gt_depth_cube=None if disable_depth_gt else (tgt_cdep[b : b + 1][0] if torch.is_tensor(tgt_cdep) else None),
            pred_depth_cube=tgt_depth_dist,
            face_w=cube_face_w,
        )
        source_depth_for_visibility = self._pick_depth_for_cubemap_frustum_mask(
            gt_depth_cube=None if disable_depth_gt else (src_cdep[b : b + 1][0] if torch.is_tensor(src_cdep) else None),
            pred_depth_cube=src_depth_dist,
            face_w=cube_face_w,
        )
        mask_bool = view_frustum_mask_cubemap_union(
            depth_novel=depth_novel,
            extr_novel_w2c=extr_tgt[b],
            extr_source_w2c=extr_src[b],
            face_w=int(cube_face_w),
            source_depth=source_depth_for_visibility,
        )
        mask_erp = self._cube_to_erp(
            mask_bool[:, None].to(torch.float32), equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w
        )

        gt_src_erp_u8 = src_erp_u8[b : b + 1]
        gt_tgt_erp_u8 = tgt_erp_u8[b : b + 1]
        gt_src_erp_depth = src_erp_depth[b : b + 1]
        gt_tgt_erp_depth = tgt_erp_depth[b : b + 1]
        # assumes cube RGB is stored channels-last per face (face, H, W, C) -- TODO confirm
        gt_src_cube_u8 = batch.src_cube_rgb_u8[b].to(self.device, non_blocking=True).permute(0, 3, 1, 2).contiguous()
        gt_tgt_cube_u8 = batch.tgt_cube_rgb_u8[b].to(self.device, non_blocking=True).permute(0, 3, 1, 2).contiguous()

        # Without GT depth every pixel counts as valid; otherwise only positive depths do.
        src_valid = torch.ones_like(gt_src_erp_depth) if disable_depth_gt else (gt_src_erp_depth > 0.0).to(dtype=torch.float32)
        tgt_valid = torch.ones_like(gt_tgt_erp_depth) if disable_depth_gt else (gt_tgt_erp_depth > 0.0).to(dtype=torch.float32)
        src_mask = torch.ones_like(src_valid)
        tgt_mask = (mask_erp.to(dtype=torch.float32) * tgt_valid).clamp(0.0, 1.0)
        tgt_mask_raw = tgt_mask  # un-eroded copy kept for visualisation
        tgt_mask = self._erode_supervision_mask(
            tgt_mask,
            self.target_mask_erode_px,
            circular_h=True,
        )
        src_cube_mask = torch.ones_like(src_alpha)
        if str(dataset_name).lower() == "hm3d" and (not disable_depth_gt) and torch.is_tensor(src_cdep):
            # For hm3d, source cube pixels without positive GT depth are masked out.
            src_cube_valid = (src_cdep[b : b + 1][0, ..., 0] > 0.0).to(dtype=src_alpha.dtype).unsqueeze(1)
            if tuple(src_cube_valid.shape[-2:]) != tuple(src_alpha.shape[-2:]):
                src_cube_valid = F.interpolate(
                    src_cube_valid,
                    size=src_alpha.shape[-2:],
                    mode="nearest",
                )
            src_cube_mask = src_cube_valid.to(device=src_alpha.device, dtype=src_alpha.dtype).clamp(0.0, 1.0)
        tgt_cube_valid = (depth_novel[..., 0] > 0.0).to(dtype=torch.float32).unsqueeze(1)
        tgt_cube_mask = (mask_bool[:, None].to(dtype=torch.float32) * tgt_cube_valid).clamp(0.0, 1.0)
        tgt_cube_mask = self._erode_supervision_mask(
            tgt_cube_mask,
            self.target_mask_erode_px,
            circular_h=False,
        )

        # Zero placeholders for the slots of loss terms whose losses are switched off.
        src_cube_depth_zeros = torch.zeros_like(src_alpha)
        tgt_cube_depth_zeros = torch.zeros_like(tgt_alpha)
        src_erp_rgb_zeros = torch.zeros_like(src_erp_pred)
        tgt_erp_rgb_zeros = torch.zeros_like(tgt_erp_pred)
        src_erp_u8_zeros = torch.zeros_like(gt_src_erp_u8)
        tgt_erp_u8_zeros = torch.zeros_like(gt_tgt_erp_u8)

        # Mean of erp_w / (2*pi) and erp_h / pi (presumably pixels per radian
        # along each ERP axis -- confirm against the regulariser).
        erp_proj_scale = 0.5 * (
            float(erp_w) / (2.0 * 3.141592653589793)
            + float(erp_h) / 3.141592653589793
        )
        reg_inputs = self._collect_regularization_inputs(
            out=out,
            gaussians=gaussians,
            b=b,
            projected_scale_factor=erp_proj_scale,
        )

        vis_payload = None
        if enable_vis:
            src_unik3d_depth = None
            tgt_unik3d_depth = None  # never populated in the spherical path
            raw_dist = out.get("unik3d_distance", None)
            if torch.is_tensor(raw_dist):
                src_unik3d_depth = raw_dist[b : b + 1, 0:1].detach()

            vis_payload = {
                "src_gt": (gt_src_erp_u8.float() / 255.0).detach(),
                "src_pred": src_erp_pred.clamp(0, 1).detach(),
                "src_alpha": src_erp_alpha.detach(),
                "src_gt_depth": None if disable_depth_gt else gt_src_erp_depth.detach(),
                "src_pred_depth": src_erp_depth_render.detach(),
                "src_unik3d_depth": src_unik3d_depth,
                "tgt_gt": (gt_tgt_erp_u8.float() / 255.0).detach(),
                "tgt_pred": tgt_erp_pred.clamp(0, 1).detach(),
                "tgt_alpha": tgt_erp_alpha.detach(),
                "tgt_gt_depth": None if disable_depth_gt else gt_tgt_erp_depth.detach(),
                "tgt_pred_depth": tgt_erp_depth_pred.detach(),
                "tgt_unik3d_depth": tgt_unik3d_depth,
                "dataset_name": str(dataset_name),
                "scene": str(self._item_at(getattr(batch, "scene", None), b, "unknown")),
                "src_idx": int(self._item_at(getattr(batch, "src_idx", None), b, -1)),
                "tgt_idx": int(self._item_at(getattr(batch, "tgt_idx", None), b, -1)),
                "src_pose_w2c": extr_src[b : b + 1].detach(),
                "tgt_pose_w2c": extr_tgt[b : b + 1].detach(),
                "src_cube_gt_u8": (
                    batch.src_cube_rgb_u8[b].detach()
                    if hasattr(batch, "src_cube_rgb_u8") and torch.is_tensor(batch.src_cube_rgb_u8)
                    else None
                ),
                "tgt_cube_gt_u8": (
                    batch.tgt_cube_rgb_u8[b].detach()
                    if hasattr(batch, "tgt_cube_rgb_u8") and torch.is_tensor(batch.tgt_cube_rgb_u8)
                    else None
                ),
                "src_cube_pred_linear": src_rgb.detach(),
                "tgt_cube_pred_linear": tgt_rgb.detach(),
                "src_cube_alpha": src_alpha.detach(),
                "tgt_cube_alpha": tgt_alpha.detach(),
                "tgt_metric_mask_raw": tgt_mask_raw.detach(),
                "tgt_metric_mask": tgt_mask.detach(),
            }

        # Target view: colour/alpha are supervised on cube faces, depth on the ERP image.
        tgt_loss_terms = [
            {
                "pred_rgb_linear": tgt_rgb,
                "pred_alpha": tgt_alpha,
                "pred_depth_m": tgt_cube_depth_zeros,
                "pred_depth2_m": None,
                "gt_rgb_u8": gt_tgt_cube_u8,
                "gt_depth_m": tgt_cube_depth_zeros,
                "mask": tgt_cube_mask,
                "apply_color": True,
                "apply_alpha": True,
                "apply_depth": False,
                "apply_percep": bool(float(self.loss_fn.w.lambda_percep) > 0.0),
                "apply_tv": False,
                "apply_grad": False,
                "apply_grad_img": False,
                "grad_img_circular_h": False,
                "gaussian_scales": None,
                "gaussian_quaternions": None,
                "gaussian_angular_cell": None,
                "delta_xy": None,
                "gaussian_mean_vectors": None,
                "gaussian_opacities": None,
                "gauss_grid_shape": None,
                "projected_scale_factor": None,
            },
            {
                "pred_rgb_linear": tgt_erp_rgb_zeros,
                "pred_alpha": torch.zeros_like(tgt_erp_depth_pred),
                "pred_depth_m": tgt_erp_depth_pred,
                "pred_depth2_m": None,
                "gt_rgb_u8": tgt_erp_u8_zeros,
                "gt_depth_m": gt_tgt_erp_depth,
                "mask": tgt_mask,
                "depth_mask": tgt_depth_loss_mask,
                "apply_color": False,
                "apply_alpha": False,
                "apply_depth": not disable_depth_gt,
                "apply_percep": False,
                "apply_tv": False,
                "apply_grad": False,
                "apply_grad_img": not disable_depth_gt,
                "grad_img_circular_h": True,
                "gaussian_scales": None,
                "gaussian_quaternions": None,
                "gaussian_angular_cell": None,
                "delta_xy": None,
                "gaussian_mean_vectors": None,
                "gaussian_opacities": None,
                "gauss_grid_shape": None,
                "projected_scale_factor": reg_inputs["projected_scale_factor"],
            },
        ]
        return {
            # Source view: cube-face colour/alpha term, then an ERP term that
            # carries the distance layers and the gaussian regularisation inputs.
            "src_loss_terms": [
                {
                    "pred_rgb_linear": src_rgb,
                    "pred_alpha": src_alpha,
                    "pred_depth_m": src_cube_depth_zeros,
                    "pred_depth2_m": None,
                    "gt_rgb_u8": gt_src_cube_u8,
                    "gt_depth_m": src_cube_depth_zeros,
                    "mask": src_cube_mask,
                    "apply_color": True,
                    "apply_alpha": True,
                    "apply_depth": False,
                    "apply_percep": False,
                    "apply_tv": False,
                    "apply_grad": False,
                    "apply_grad_img": False,
                    "grad_img_circular_h": False,
                    "gaussian_scales": None,
                    "gaussian_quaternions": None,
                    "gaussian_angular_cell": None,
                    "delta_xy": None,
                    "gaussian_mean_vectors": None,
                    "gaussian_opacities": None,
                    "gauss_grid_shape": None,
                    "projected_scale_factor": None,
                },
                {
                    "pred_rgb_linear": src_erp_rgb_zeros,
                    "pred_alpha": torch.zeros_like(src_erp_depth_pred),
                    "pred_depth_m": src_erp_depth_pred,
                    "pred_depth2_m": src_erp_depth2_pred,
                    "gt_rgb_u8": src_erp_u8_zeros,
                    "gt_depth_m": gt_src_erp_depth,
                    "mask": src_mask,
                    "apply_color": False,
                    "apply_alpha": False,
                    "apply_depth": False,
                    "apply_percep": False,
                    "apply_tv": True,
                    "apply_grad": False,
                    "apply_grad_img": not disable_depth_gt,
                    "grad_img_circular_h": True,
                    "gaussian_scales": reg_inputs["gaussian_scales"],
                    "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
                    "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
                    "delta_xy": reg_inputs["delta_xy_eff"],
                    "delta_rho": reg_inputs["delta_rho_raw"],
                    "delta_grid": reg_inputs["delta_grid"],
                    "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
                    "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
                    "gaussian_opacities": reg_inputs["gaussian_opacities"],
                    "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
                    "projected_scale_factor": reg_inputs["projected_scale_factor"],
                    "projection_model": "erp",
                },
            ],
            "tgt_loss_terms": tgt_loss_terms,
            "gaussian_scales": reg_inputs["gaussian_scales"],
            "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
            "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
            "delta_xy": reg_inputs["delta_xy_eff"],
            "delta_rho": reg_inputs["delta_rho_raw"],
            "delta_grid": reg_inputs["delta_grid"],
            "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
            "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
            "gaussian_opacities": reg_inputs["gaussian_opacities"],
            "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
            "projected_scale_factor": reg_inputs["projected_scale_factor"],
            "projection_model": "erp",
            "aux_losses": self._aux_ray_losses(
                pred_rays=(
                    out.get("unik3d_rays", None)[b : b + 1]
                    if torch.is_tensor(out.get("unik3d_rays", None))
                    else None
                ),
                gt_rays=(
                    aux_ray_target_all[b : b + 1]
                    if torch.is_tensor(aux_ray_target_all)
                    else None
                ),
                mask=src_valid,
                pred_distance=(
                    out["unik3d_distance"][b : b + 1, 0:1]
                    if torch.is_tensor(out.get("unik3d_distance", None))
                    else None
                ),
                pred_distance2=src_erp_depth2_pred,
                gt_distance=None if disable_depth_gt else gt_src_erp_depth,
                depth_mask=src_valid,
            ),
            "vis_payload": vis_payload,
        }

    return _ModeStrategy(
        batch_size=int(cur_bs),
        gaussians=gaussians,
        make_world_gaussians=make_world_gaussians,
        make_sample=make_sample,
        collect_all_vis=bool(getattr(batch, "collect_all_vis", False)),
    )
def _render_cubemap(
    self,
    gaussians: Any,
    extr_w2c: torch.Tensor,
    face_w: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Render the gaussians into six square cubemap faces.

    Args:
        gaussians: Scene representation; only ``mean_vectors.device`` is read
            here, everything else is forwarded to ``self.renderer``.
        extr_w2c: Extrinsics handed to ``cubemap_face_cameras`` to derive the
            per-face cameras (presumably a single world-to-camera matrix --
            TODO confirm against that helper).
        face_w: Edge length, in pixels, of every rendered face.

    Returns:
        ``(color, depth, alpha)`` taken from the renderer output, each made
        contiguous.
    """
    side = int(face_w)
    target_device = gaussians.mean_vectors.device

    # One shared pinhole intrinsics matrix, broadcast (not copied) to 6 faces.
    shared_intr = get_pinhole_intrinsics_4x4(side).to(device=target_device)
    face_intrinsics = shared_intr[None].expand(6, -1, -1)
    face_extrinsics = cubemap_face_cameras(extr_w2c, device=target_device)

    rendered = self.renderer(
        gaussians,
        extrinsics=face_extrinsics,
        intrinsics=face_intrinsics,
        image_width=side,
        image_height=side,
    )
    color = rendered.color.contiguous()
    depth = rendered.depth.contiguous()
    alpha = rendered.alpha.contiguous()
    return color, depth, alpha


def _cube_to_erp(self, cube: torch.Tensor, equ_h: int, equ_w: int, face_w: int) -> torch.Tensor:
    """Resample a stack of cubemap faces into an equirectangular image.

    The first two axes of ``cube`` are swapped and a leading batch axis is
    added before the tensor is passed to a freshly built ``Cube2Equirec``
    module on the same device.
    # NOTE(review): the input looks like (faces, C, H, W) -- confirm with
    # the callers and with Cube2Equirec's expected layout.

    Args:
        cube: 4-D tensor holding the cubemap faces.
        equ_h: Height of the equirectangular output.
        equ_w: Width of the equirectangular output.
        face_w: Edge length of each cubemap face.

    Returns:
        Whatever ``Cube2Equirec`` produces for the reordered input.
    """
    batched = cube.permute(1, 0, 2, 3).unsqueeze(0)
    converter = Cube2Equirec(
        face_w=int(face_w),
        equ_h=int(equ_h),
        equ_w=int(equ_w),
    ).to(device=batched.device)
    return converter(batched)
--git a/unisharp/datasets/__pycache__/dl3dv.cpython-310.pyc b/unisharp/datasets/__pycache__/dl3dv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..730df0e28e17e0c3183f1d8057320cb8d0abbccc Binary files /dev/null and b/unisharp/datasets/__pycache__/dl3dv.cpython-310.pyc differ diff --git a/unisharp/datasets/__pycache__/dl3dv.cpython-313.pyc b/unisharp/datasets/__pycache__/dl3dv.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f86835e27b5bb15184fddfc95b60ed671c68a772 Binary files /dev/null and b/unisharp/datasets/__pycache__/dl3dv.cpython-313.pyc differ diff --git a/unisharp/datasets/__pycache__/pair_sampling.cpython-310.pyc b/unisharp/datasets/__pycache__/pair_sampling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc973cbfa838a6aad20e0541dc60a1ce037588a8 Binary files /dev/null and b/unisharp/datasets/__pycache__/pair_sampling.cpython-310.pyc differ diff --git a/unisharp/datasets/__pycache__/pair_sampling.cpython-313.pyc b/unisharp/datasets/__pycache__/pair_sampling.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0c09a34edd0ded1b64a73651bd1a15ed336fd8e Binary files /dev/null and b/unisharp/datasets/__pycache__/pair_sampling.cpython-313.pyc differ diff --git a/unisharp/datasets/__pycache__/panogs.cpython-310.pyc b/unisharp/datasets/__pycache__/panogs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..847daaad0d664696183cc7c668a97326fe0761bb Binary files /dev/null and b/unisharp/datasets/__pycache__/panogs.cpython-310.pyc differ diff --git a/unisharp/datasets/__pycache__/panogs.cpython-313.pyc b/unisharp/datasets/__pycache__/panogs.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..384e01e0bcf7f6d316e504a04654fed20935ac25 Binary files /dev/null and b/unisharp/datasets/__pycache__/panogs.cpython-313.pyc differ diff --git 
a/unisharp/datasets/__pycache__/re10k.cpython-310.pyc b/unisharp/datasets/__pycache__/re10k.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6be243d81ddf515d625d4e01068dcbb1d21fb00 Binary files /dev/null and b/unisharp/datasets/__pycache__/re10k.cpython-310.pyc differ diff --git a/unisharp/datasets/__pycache__/re10k.cpython-313.pyc b/unisharp/datasets/__pycache__/re10k.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db367ac8d51902b27d0c6c9718845acbfd90d7be Binary files /dev/null and b/unisharp/datasets/__pycache__/re10k.cpython-313.pyc differ diff --git a/unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-310.pyc b/unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01b8417e8e82f717aecdfa1c4fffa42991eb2f40 Binary files /dev/null and b/unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-310.pyc differ diff --git a/unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-313.pyc b/unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc11ba3316b62057ba98f5f84882a91594ad6e41 Binary files /dev/null and b/unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-313.pyc differ diff --git a/unisharp/datasets/__pycache__/sim_panorama.cpython-310.pyc b/unisharp/datasets/__pycache__/sim_panorama.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6047ad0ef1c6830d24c3e6fcb1c461250b66b25c Binary files /dev/null and b/unisharp/datasets/__pycache__/sim_panorama.cpython-310.pyc differ diff --git a/unisharp/datasets/__pycache__/sim_panorama.cpython-313.pyc b/unisharp/datasets/__pycache__/sim_panorama.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f165fee0cb3b3144371247cf6755926ebc81591 Binary files /dev/null and b/unisharp/datasets/__pycache__/sim_panorama.cpython-313.pyc 
class DL3DVDataset(IterableDataset):
    """Stream source/target frame pairs (RGB, depth, pose, intrinsics) from DL3DV scenes.

    Scenes are taken from a manifest file (one ``name|scene_dir|depth_dir``
    entry per line) when ``scene_specs_file`` is given, otherwise discovered by
    scanning ``root``.  For every source frame assigned to this DDP rank /
    dataloader worker, one target frame is drawn at random among the frames
    that satisfy the index-gap, translation and projected-overlap constraints.

    Yields either single ``Re10KPairSample`` objects (``batch_size_hint <= 1``)
    or the result of ``re10k_collate`` on ``batch_size_hint`` samples that
    share the same image resolution.

    NOTE: when ``batch_size_hint > 1`` samples still waiting in a resolution
    bucket at the end of the iteration are not yielded.
    """

    def __init__(
        self,
        root: Path,
        depth_root: Path,
        scene_specs_file: Path | None = None,
        min_frame_gap: int = 1,
        max_frame_gap: int = 32,
        pair_max_translation_m: float = 0.5,
        pair_min_overlap: float = 0.6,
        pair_overlap_sample_h: int = 32,
        pair_overlap_sample_w: int = 56,
        output_h: int | None = None,
        output_w: int | None = None,
        shuffle_scene: bool = True,
        shuffle_frame: bool = False,
        ddp_rank: int = 0,
        ddp_world_size: int = 1,
        batch_size_hint: int = 1,
        depth_max_m: float = DEFAULT_MAX_DEPTH_M,
        seed: int = 0,
        verify_manifest_paths: bool = False,
    ) -> None:
        """Store the configuration and resolve the list of usable scenes.

        Args:
            root: Directory scanned for scenes when no manifest is given.
            depth_root: Directory holding the per-scene depth exports.
            scene_specs_file: Optional ``name|scene_dir|depth_dir`` manifest.
            min_frame_gap: Minimum index distance between source and target.
            max_frame_gap: Maximum index distance between source and target.
            pair_max_translation_m: Maximum camera-centre distance of a pair.
            pair_min_overlap: Minimum symmetric overlap score of a pair.
            pair_overlap_sample_h: Rows of the pixel grid used for the overlap.
            pair_overlap_sample_w: Columns of the pixel grid used for the overlap.
            output_h: Optional output height; resizing only happens when both
                ``output_h`` and ``output_w`` are set.
            output_w: Optional output width.
            shuffle_scene: Shuffle the scene order each epoch.
            shuffle_frame: Shuffle the source-frame order inside each scene.
            ddp_rank: Rank of this process, used for sharding.
            ddp_world_size: Number of processes, used for sharding.
            batch_size_hint: Samples per yielded batch (values below 1 become 1).
            depth_max_m: Upper clip applied to loaded depth values.
            seed: Base seed of the scene/frame random generators.
            verify_manifest_paths: Drop manifest entries whose directories do
                not exist.

        Raises:
            FileNotFoundError: If ``scene_specs_file`` is given but missing.
            RuntimeError: If no scene could be resolved.
        """
        super().__init__()
        self.root = Path(root)
        self.depth_root = Path(depth_root)
        self.min_frame_gap = int(min_frame_gap)
        self.max_frame_gap = int(max_frame_gap)
        self.pair_max_translation_m = float(pair_max_translation_m)
        self.pair_min_overlap = float(pair_min_overlap)
        self.pair_overlap_sample_h = int(pair_overlap_sample_h)
        self.pair_overlap_sample_w = int(pair_overlap_sample_w)
        self.output_h = int(output_h) if output_h is not None else None
        self.output_w = int(output_w) if output_w is not None else None
        self.shuffle_scene = bool(shuffle_scene)
        self.shuffle_frame = bool(shuffle_frame)
        self.ddp_rank = int(ddp_rank)
        self.ddp_world_size = int(ddp_world_size)
        self.batch_size_hint = int(max(1, batch_size_hint))
        self.depth_max_m = float(depth_max_m)
        self.seed = int(seed)
        self.epoch = 0
        self.verify_manifest_paths = bool(verify_manifest_paths)
        self.scene_specs_file = Path(scene_specs_file) if scene_specs_file is not None else None
        self.scene_specs = self._load_scene_specs()
        if not self.scene_specs:
            raise RuntimeError(f"No valid DL3DV scenes found under {self.root}")

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch number that is mixed into the shuffling seeds."""
        self.epoch = int(epoch)

    def _load_scene_specs(self) -> list[tuple[str, Path, Path]]:
        """Return ``(scene_name, scene_dir, depth_dir)`` triples.

        Without a manifest the directory tree under ``root`` is scanned.  With
        a manifest, blank lines and lines that do not have exactly three
        ``|``-separated fields are ignored; existence of the directories is
        only checked when ``verify_manifest_paths`` is set.

        Raises:
            FileNotFoundError: If the configured manifest file does not exist.
        """
        if self.scene_specs_file is None:
            return self._scan_scenes()
        if not self.scene_specs_file.exists():
            raise FileNotFoundError(self.scene_specs_file)
        out: list[tuple[str, Path, Path]] = []
        for raw in self.scene_specs_file.read_text(encoding="utf-8").splitlines():
            line = raw.strip()
            if not line:
                continue
            parts = line.split("|")
            if len(parts) != 3:
                continue
            scene_name, scene_dir_raw, depth_dir_raw = parts
            scene_dir = Path(scene_dir_raw)
            depth_dir = Path(depth_dir_raw)
            if (not self.verify_manifest_paths) or (scene_dir.exists() and depth_dir.exists()):
                out.append((scene_name, scene_dir, depth_dir))
        return out

    def _scan_scenes(self) -> list[tuple[str, Path, Path]]:
        """Discover scenes laid out as ``root/<bucket>/<scene>[/<inner>]``.

        A scene is kept only when ``transforms.json``, the ``images_4`` folder
        and the matching depth folder
        ``depth_root/<bucket>/<scene>/exports/mini_npz/per_image`` all exist.
        When the scene folder contains sub-directories, the first one returned
        by ``iterdir`` is used as the actual scene directory.
        """
        out: list[tuple[str, Path, Path]] = []
        for bucket_dir in sorted(p for p in self.root.iterdir() if p.is_dir()):
            for scene_stub in sorted(p for p in bucket_dir.iterdir() if p.is_dir()):
                inner_dirs = [p for p in scene_stub.iterdir() if p.is_dir()]
                scene_dir = inner_dirs[0] if inner_dirs else scene_stub
                transforms_path = scene_dir / "transforms.json"
                image_dir = scene_dir / "images_4"
                depth_dir = (
                    self.depth_root / bucket_dir.name / scene_stub.name / "exports" / "mini_npz" / "per_image"
                )
                if transforms_path.exists() and image_dir.exists() and depth_dir.exists():
                    scene_name = f"{bucket_dir.name}/{scene_stub.name}"
                    out.append((scene_name, scene_dir, depth_dir))
        return out

    @staticmethod
    def _load_rgb_u8(path: Path) -> torch.Tensor:
        """Load an image as a contiguous ``uint8`` tensor in CHW layout (RGB)."""
        # The context manager releases the file handle deterministically
        # instead of leaving it to the garbage collector.
        with Image.open(path) as img:
            arr = np.asarray(img.convert("RGB"), dtype=np.uint8).copy()
        return torch.from_numpy(arr).permute(2, 0, 1).contiguous()

    def _load_depth_m(self, path: Path) -> torch.Tensor:
        """Load the ``depth`` array of an ``.npz`` file as a ``(1, H, W)`` float tensor.

        Non-finite values become 0 and the result is clipped to
        ``[0, depth_max_m]``.
        """
        # np.load keeps the archive open until it is closed, so read the array
        # inside a ``with`` block to avoid leaking one descriptor per frame.
        with np.load(path) as payload:
            depth = payload["depth"].astype(np.float32)
        depth[~np.isfinite(depth)] = 0.0
        depth = np.clip(depth, a_min=0.0, a_max=self.depth_max_m)
        return torch.from_numpy(depth).unsqueeze(0)

    @staticmethod
    def _resize_depth_to_image(depth: torch.Tensor, image_hw: tuple[int, int]) -> torch.Tensor:
        """Nearest-neighbour resize of ``depth`` to ``image_hw`` (no-op when sizes match)."""
        target_h, target_w = int(image_hw[0]), int(image_hw[1])
        if depth.shape[-2:] == (target_h, target_w):
            return depth
        return F.interpolate(
            depth.unsqueeze(0),
            size=(target_h, target_w),
            mode="nearest",
        ).squeeze(0)

    @staticmethod
    def _frame_id_from_name(name: str) -> int:
        """Parse the integer after the last ``_`` of the file stem.

        Raises:
            ValueError: If that part of the name is not an integer.
        """
        stem = Path(name).stem
        return int(stem.split("_")[-1])

    def _load_scene(
        self,
        scene_name: str,
        scene_dir: Path,
        depth_dir: Path,
    ) -> tuple[list[int], dict[int, Path], dict[int, Path], dict[int, torch.Tensor], dict[int, torch.Tensor], torch.Tensor]:
        """Read poses, intrinsics and file locations of one scene.

        Returns:
            ``(valid_ids, image_paths, depth_paths, w2c_map, intr_map, k)``
            where ``valid_ids`` are the sorted frame ids that have an image, a
            depth file and a pose, the maps are keyed by frame id, and ``k`` is
            the unscaled 3x3 intrinsics matrix from ``transforms.json``.

        Raises:
            Exception: Any parsing / I/O error is propagated; ``__iter__``
            treats it as "skip this scene".
        """
        meta = json.loads((scene_dir / "transforms.json").read_text(encoding="utf-8"))
        orig_w = int(meta["w"])
        orig_h = int(meta["h"])
        k = torch.eye(3, dtype=torch.float32)
        k[0, 0] = float(meta["fl_x"])
        k[1, 1] = float(meta["fl_y"])
        k[0, 2] = float(meta["cx"])
        k[1, 2] = float(meta["cy"])

        image_dir = scene_dir / "images_4"
        image_paths = {self._frame_id_from_name(p.name): p for p in image_dir.glob("*.png")}
        depth_paths = {self._frame_id_from_name(p.name): p for p in depth_dir.glob("*.npz")}
        w2c_map: dict[int, torch.Tensor] = {}
        intr_map: dict[int, torch.Tensor] = {}
        valid_ids: list[int] = []

        # The first usable image defines the working resolution of the scene;
        # every frame's intrinsics are rescaled to it.
        example_img = None
        for frame in meta.get("frames", []):
            rel_path = str(frame.get("file_path", ""))
            frame_name = Path(rel_path).name
            frame_id = self._frame_id_from_name(frame_name)
            if frame_id not in image_paths or frame_id not in depth_paths:
                continue
            c2w = torch.tensor(frame["transform_matrix"], dtype=torch.float32)
            # Flip the camera Y and Z axes (presumably OpenGL -> OpenCV
            # convention -- TODO confirm against the renderer).
            c2w[:3, 1:3] *= -1.0
            if example_img is None:
                example_img = self._load_rgb_u8(image_paths[frame_id])
            cur_h, cur_w = int(example_img.shape[1]), int(example_img.shape[2])
            k_cur = k.clone()
            if cur_h != orig_h or cur_w != orig_w:
                sx = float(cur_w) / float(orig_w)
                sy = float(cur_h) / float(orig_h)
                k_cur = resize_k3_align_corners_false(k_cur, sx=sx, sy=sy)
            w2c_map[frame_id] = torch.linalg.inv(c2w)
            intr_map[frame_id] = k_cur
            valid_ids.append(frame_id)
        valid_ids = sorted(valid_ids)
        return valid_ids, image_paths, depth_paths, w2c_map, intr_map, k

    def __iter__(self):
        """Yield samples (or collated batches) for this rank/worker shard."""
        scenes = list(self.scene_specs)
        order_rng = random.Random(self.seed + self.epoch)
        if self.shuffle_scene:
            order_rng.shuffle(scenes)
        # Samples are bucketed by resolution so that a collated batch never
        # mixes image sizes.
        pending_by_hw: dict[tuple[int, int], deque[Re10KPairSample]] = defaultdict(deque)
        worker_info = torch.utils.data.get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        total_shards = max(1, self.ddp_world_size * num_workers)
        shard_id = self.ddp_rank * num_workers + worker_id
        src_unit_index = 0

        for scene_order_idx, (scene_name, scene_dir, depth_dir) in enumerate(scenes):
            try:
                valid_ids, image_paths, depth_paths, w2c_map, intr_map, _ = self._load_scene(
                    scene_name, scene_dir, depth_dir
                )
            except Exception:
                # Best effort: a scene that cannot be parsed is skipped.
                continue
            if len(valid_ids) < 2:
                continue
            src_order = list(valid_ids)
            scene_rng = random.Random(self.seed + self.epoch * 1000003 + scene_order_idx)
            if self.shuffle_frame:
                scene_rng.shuffle(src_order)
            centers = torch.stack([torch.linalg.inv(w2c_map[i])[:3, 3] for i in valid_ids], dim=0)
            frame_to_pos = {fid: pos for pos, fid in enumerate(valid_ids)}
            # Image sizes read so far in this scene, keyed by frame id, so the
            # overlap score does not reopen the same file for every candidate.
            hw_by_frame: dict[int, tuple[int, int]] = {}

            def overlap_avg(src_pos: int, tgt_pos: int) -> float:
                """Mean of the src->tgt and tgt->src projected overlap ratios."""
                src_fid = int(valid_ids[src_pos])
                tgt_fid = int(valid_ids[tgt_pos])
                hw = hw_by_frame.get(src_fid)
                if hw is None:
                    with Image.open(image_paths[src_fid]) as img:
                        hw = (int(img.size[1]), int(img.size[0]))
                    hw_by_frame[src_fid] = hw
                h, w = hw
                return float(
                    0.5
                    * (
                        project_overlap_ratio(
                            src_w2c=w2c_map[src_fid],
                            tgt_w2c=w2c_map[tgt_fid],
                            src_k=intr_map[src_fid],
                            tgt_k=intr_map[tgt_fid],
                            h=h,
                            w=w,
                            sample_h=self.pair_overlap_sample_h,
                            sample_w=self.pair_overlap_sample_w,
                        )
                        + project_overlap_ratio(
                            src_w2c=w2c_map[tgt_fid],
                            tgt_w2c=w2c_map[src_fid],
                            src_k=intr_map[tgt_fid],
                            tgt_k=intr_map[src_fid],
                            h=h,
                            w=w,
                            sample_h=self.pair_overlap_sample_h,
                            sample_w=self.pair_overlap_sample_w,
                        )
                    )
                )

            for src_idx in src_order:
                # Round-robin sharding over source frames across all
                # (rank, worker) combinations.
                if src_unit_index % total_shards != shard_id:
                    src_unit_index += 1
                    continue
                src_unit_index += 1
                src_pos = int(frame_to_pos[int(src_idx)])
                tgt_pos_list = select_targets_for_source(
                    src_idx=src_pos,
                    candidate_indices=list(range(len(valid_ids))),
                    centers=centers,
                    min_index_gap=int(self.min_frame_gap),
                    max_index_gap=int(self.max_frame_gap),
                    pair_max_translation_m=float(self.pair_max_translation_m),
                    pair_min_overlap=float(self.pair_min_overlap),
                    overlap_score_fn=overlap_avg,
                )
                if not tgt_pos_list:
                    continue
                tgt_idx = int(valid_ids[scene_rng.choice(tgt_pos_list)])
                try:
                    src_img = self._load_rgb_u8(image_paths[int(src_idx)])
                    tgt_img = self._load_rgb_u8(image_paths[int(tgt_idx)])
                    src_depth = self._load_depth_m(depth_paths[int(src_idx)])
                    tgt_depth = self._load_depth_m(depth_paths[int(tgt_idx)])
                except Exception:
                    # Best effort: unreadable frames are skipped.
                    continue
                src_depth = self._resize_depth_to_image(src_depth, (int(src_img.shape[1]), int(src_img.shape[2])))
                tgt_depth = self._resize_depth_to_image(tgt_depth, (int(tgt_img.shape[1]), int(tgt_img.shape[2])))
                src_intr = intr_map[int(src_idx)].clone()
                tgt_intr = intr_map[int(tgt_idx)].clone()
                if self.output_h is not None and self.output_w is not None:
                    oh, ow = int(src_img.shape[1]), int(src_img.shape[2])
                    if oh != self.output_h or ow != self.output_w:
                        sx = float(self.output_w) / float(ow)
                        sy = float(self.output_h) / float(oh)
                        out_size = (self.output_h, self.output_w)
                        src_img = resize_rgb_u8_chw_high_quality(src_img, size=out_size)
                        tgt_img = resize_rgb_u8_chw_high_quality(tgt_img, size=out_size)
                        src_depth = F.interpolate(src_depth[None], size=out_size, mode="nearest")[0]
                        tgt_depth = F.interpolate(tgt_depth[None], size=out_size, mode="nearest")[0]
                        src_intr = resize_k3_align_corners_false(src_intr, sx=sx, sy=sy)
                        tgt_intr = resize_k3_align_corners_false(tgt_intr, sx=sx, sy=sy)
                sample = Re10KPairSample(
                    src_rgb_u8=src_img,
                    tgt_rgb_u8=tgt_img,
                    src_w2c=w2c_map[int(src_idx)],
                    tgt_w2c=w2c_map[int(tgt_idx)],
                    src_intrinsics=src_intr,
                    tgt_intrinsics=tgt_intr,
                    src_idx=int(src_idx),
                    tgt_idx=int(tgt_idx),
                    scene=scene_name,
                    src_depth_m=src_depth,
                    tgt_depth_m=tgt_depth,
                )
                hw_key = (int(sample.src_rgb_u8.shape[1]), int(sample.src_rgb_u8.shape[2]))
                bucket = pending_by_hw[hw_key]
                bucket.append(sample)
                if self.batch_size_hint <= 1:
                    yield bucket.popleft()
                    continue
                while len(bucket) >= self.batch_size_hint:
                    packed = [bucket.popleft() for _ in range(self.batch_size_hint)]
                    yield re10k_collate(packed)
def resize_k3_align_corners_false(k: torch.Tensor, *, sx: float, sy: float) -> torch.Tensor:
    """Rescale intrinsics ``k`` by ``sx`` / ``sy``.

    Thin wrapper that forwards to ``scale_intrinsics_align_corners_false``
    after converting both factors to ``float``.
    """
    return scale_intrinsics_align_corners_false(k, sx=float(sx), sy=float(sy))


def resize_rgb_u8_chw_high_quality(image: torch.Tensor, *, size: tuple[int, int]) -> torch.Tensor:
    """Resize a CHW image with antialiased bicubic interpolation.

    Args:
        image: 3-D tensor in CHW layout.
        size: Target ``(height, width)``.

    Returns:
        The input itself (made contiguous) when it already has the target
        size, otherwise a new contiguous ``uint8`` tensor whose values are
        rounded and clamped to ``[0, 255]``.

    Raises:
        ValueError: If ``image`` is not a 3-D tensor.
    """
    if not torch.is_tensor(image) or image.ndim != 3:
        raise ValueError(f"Expected CHW tensor, got {tuple(image.shape) if torch.is_tensor(image) else type(image)}")
    dst_h, dst_w = int(size[0]), int(size[1])
    if tuple(image.shape[-2:]) == (dst_h, dst_w):
        return image.contiguous()
    resized = F.interpolate(
        image.unsqueeze(0).to(torch.float32),
        size=(dst_h, dst_w),
        mode="bicubic",
        align_corners=False,
        antialias=True,
    )
    # Bicubic interpolation can overshoot, hence the clamp before the cast.
    return resized[0].round().clamp(0, 255).to(torch.uint8).contiguous()


def project_overlap_ratio(
    src_w2c: torch.Tensor,
    tgt_w2c: torch.Tensor,
    src_k: torch.Tensor,
    tgt_k: torch.Tensor,
    h: int,
    w: int,
    src_hw: tuple[int, int] | None = None,
    tgt_hw: tuple[int, int] | None = None,
    sample_h: int = 32,
    sample_w: int = 56,
    proxy_depth: float = 1.0,
) -> float:
    """Estimate how much of the source view is visible in the target view.

    A ``sample_h`` x ``sample_w`` grid of source pixels is un-projected along
    unit-length rays to distance ``proxy_depth``, moved into the target camera
    frame and projected with the target intrinsics.

    Args:
        src_w2c: 4x4 world-to-camera matrix of the source view.
        tgt_w2c: 4x4 world-to-camera matrix of the target view.
        src_k: Source intrinsics (``fx, fy, cx, cy`` are read from it).
        tgt_k: Target intrinsics.
        h: Default image height for both views.
        w: Default image width for both views.
        src_hw: Optional ``(height, width)`` override for the source view.
        tgt_hw: Optional ``(height, width)`` override for the target view.
        sample_h: Number of sampled rows.
        sample_w: Number of sampled columns.
        proxy_depth: Distance along each ray at which the point is placed.

    Returns:
        Fraction of sampled points that lie in front of the target camera and
        project inside the target image bounds.
    """
    device = src_w2c.device
    src_h, src_w = tuple(int(v) for v in (src_hw or (h, w)))
    tgt_h, tgt_w = tuple(int(v) for v in (tgt_hw or (h, w)))
    ys = torch.linspace(0, src_h - 1, steps=sample_h, device=device)
    xs = torch.linspace(0, src_w - 1, steps=sample_w, device=device)
    vv, uu = torch.meshgrid(ys, xs, indexing="ij")
    u = uu.reshape(-1)
    v = vv.reshape(-1)

    fx, fy = src_k[0, 0], src_k[1, 1]
    cx, cy = src_k[0, 2], src_k[1, 2]
    x = (u - cx) / fx
    y = (v - cy) / fy
    z = torch.ones_like(x)
    rays = torch.stack([x, y, z], dim=-1)
    rays = rays / torch.norm(rays, dim=-1, keepdim=True).clamp(min=1e-6)
    pts_src = rays * float(proxy_depth)

    src_c2w = torch.linalg.inv(src_w2c)
    pts_src_h = torch.cat([pts_src, torch.ones_like(pts_src[:, :1])], dim=-1)
    pts_w = (src_c2w @ pts_src_h.T).T
    pts_tgt = (tgt_w2c @ pts_w.T).T

    xt, yt, z_raw = pts_tgt[:, 0], pts_tgt[:, 1], pts_tgt[:, 2]
    # The visibility test must use the *unclamped* depth: once z is clamped to
    # a positive epsilon every point looks like it is in front of the camera.
    in_front = z_raw > 0.0
    zt = z_raw.clamp(min=1e-6)  # only guards the perspective division
    ut = tgt_k[0, 0] * (xt / zt) + tgt_k[0, 2]
    vt = tgt_k[1, 1] * (yt / zt) + tgt_k[1, 2]
    inside = in_front & (ut >= 0.0) & (ut <= float(tgt_w - 1)) & (vt >= 0.0) & (vt <= float(tgt_h - 1))
    return float(inside.float().mean().item())


def select_targets_for_source(
    *,
    src_idx: int,
    candidate_indices: list[int],
    centers: torch.Tensor,
    min_index_gap: int,
    max_index_gap: int,
    pair_max_translation_m: float,
    pair_min_overlap: float,
    overlap_score_fn: Callable[[int, int], float],
) -> list[int]:
    """Filter candidate target indices for one source index.

    A candidate is kept when its index distance to ``src_idx`` lies in
    ``[min_index_gap, max_index_gap]``, its camera centre is at most
    ``pair_max_translation_m`` away and
    ``overlap_score_fn(src_idx, candidate) >= pair_min_overlap``.  The cheap
    checks run first so the overlap function is only called for survivors.

    Args:
        src_idx: Row of ``centers`` that belongs to the source.
        candidate_indices: Rows of ``centers`` to consider.
        centers: Per-index camera centres.
        min_index_gap: Smallest accepted index distance.
        max_index_gap: Largest accepted index distance.
        pair_max_translation_m: Largest accepted centre distance.
        pair_min_overlap: Smallest accepted overlap score.
        overlap_score_fn: Callable scoring a ``(source, candidate)`` pair.

    Returns:
        Accepted candidates in their original order.
    """
    src = int(src_idx)
    src_c = centers[src]
    tgt_cands: list[int] = []
    for j in candidate_indices:
        j = int(j)
        if j == src:
            continue
        gap = abs(j - src)
        if gap < int(min_index_gap) or gap > int(max_index_gap):
            continue
        trans = float(torch.norm(centers[j] - src_c, p=2).item())
        if trans > float(pair_max_translation_m):
            continue
        if float(overlap_score_fn(src, j)) >= float(pair_min_overlap):
            tgt_cands.append(j)
    return tgt_cands
def _torch_load_any(path: Path) -> object:
    """Load a ``torch.save`` payload onto the CPU.

    ``weights_only=False`` is tried first; if ``torch.load`` rejects the call
    with ``TypeError`` (presumably a torch version without that keyword) the
    load is retried without it.  Failures of either attempt are wrapped.

    Raises:
        RuntimeError: If loading fails with ``KeyError``, ``tarfile.ReadError``,
            ``EOFError``, ``OSError`` or ``RuntimeError``; the original error is
            chained as the cause.
    """
    # NOTE(security): weights_only=False un-pickles arbitrary objects, which
    # can execute code.  Only use this on files from a trusted source.
    try:
        try:
            return torch.load(path, map_location="cpu", weights_only=False)
        except TypeError:
            # Retry inside the outer ``try`` so a corrupted file is reported
            # the same way regardless of which call signature was used.
            return torch.load(path, map_location="cpu")
    except (KeyError, tarfile.ReadError, EOFError, OSError, RuntimeError) as e:
        raise RuntimeError(f"torch.load failed (possibly incomplete/corrupted): {path}") from e


@dataclass(frozen=True)
class PanOGSSample:
    """One source/target panorama pair.

    NOTE(review): shapes and pose conventions are not established in this
    part of the file; the field comments only restate the names -- confirm
    against ``PanOGSDataset.__getitem__``.
    """

    # Equirectangular RGB (uint8) and depth (metres) of both views.
    src_erp_rgb_u8: torch.Tensor
    tgt_erp_rgb_u8: torch.Tensor
    src_erp_depth_m: torch.Tensor
    tgt_erp_depth_m: torch.Tensor

    # Cubemap RGB (uint8) and depth (metres) of both views.
    src_cube_rgb_u8: torch.Tensor
    tgt_cube_rgb_u8: torch.Tensor
    src_cube_depth_m: torch.Tensor
    tgt_cube_depth_m: torch.Tensor

    # Rotation / translation of both views.
    src_R: torch.Tensor
    src_t: torch.Tensor
    tgt_R: torch.Tensor
    tgt_t: torch.Tensor

    # Frame indices inside the scene and the scene name.
    src_idx: int
    tgt_idx: int
    scene: str


def _load_erp_rgb_u8(path: Path) -> torch.Tensor:
    """Load a 3-channel image as a contiguous ``uint8`` CHW tensor.

    Raises:
        ValueError: If the decoded array is not ``H x W x 3``.
    """
    # Decode inside ``with`` so the file handle is closed right away.
    with Image.open(path) as handle:
        img = np.array(handle)
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f"Expected RGB image at {path}, got shape={img.shape}")
    return torch.from_numpy(img.astype(np.uint8)).permute(2, 0, 1).contiguous()


def _load_depth_png(path: Path) -> torch.Tensor:
    """Load a depth image and return its raw pixel values as a tensor."""
    with Image.open(path) as handle:
        dep = np.array(handle)
    return torch.from_numpy(dep)


def _depth_to_meters(depth: torch.Tensor, max_depth_m: float = DEFAULT_MAX_DEPTH_M) -> torch.Tensor:
    """Convert raw depth values to metres, clipped to ``[0, max_depth_m]``.

    Non-finite values are replaced by 0 first.  If the largest remaining value
    exceeds 200 the data is treated as millimetres and divided by 1000.

    The input tensor is never modified, and the unit heuristic only looks at
    finite values, so a stray ``inf`` / ``NaN`` cannot switch it on or off.

    Args:
        depth: Raw depth values of any numeric dtype.
        max_depth_m: Upper clip value in metres.

    Returns:
        A new ``float32`` tensor with the same shape as ``depth``.
    """
    depth_f = depth.to(torch.float32)
    # torch.where allocates a new tensor; ``.to`` may have returned ``depth``
    # itself, so an in-place masked write would leak into the caller's data.
    depth_f = torch.where(torch.isfinite(depth_f), depth_f, torch.zeros_like(depth_f))
    maxv = float(depth_f.max().item()) if depth_f.numel() else 0.0
    if maxv > 200.0:
        depth_f = depth_f / 1000.0
    return depth_f.clamp(min=0.0, max=float(max_depth_m))
float = 0.5, + pair_min_depth_overlap: float = 0.6, + pair_overlap_face_w: int = 64, + pair_overlap_margin: float = 1.05, + pair_max_tries: int = 48, + depth_max_m: float = DEFAULT_MAX_DEPTH_M, + ) -> None: + self.root = root + self.src_tgt_max_index_gap = int(src_tgt_max_index_gap) + self.use_cubemap_supervision = use_cubemap_supervision + self.pair_sampling = bool(pair_sampling) + self.pair_max_translation_m = float(pair_max_translation_m) + self.pair_min_depth_overlap = float(pair_min_depth_overlap) + self.pair_overlap_face_w = int(pair_overlap_face_w) + self.pair_overlap_margin = float(pair_overlap_margin) + self.pair_max_tries = int(pair_max_tries) + self.depth_max_m = float(depth_max_m) + self.index_manifest_path = Path(index_manifest_path) if index_manifest_path is not None else None + + self._pair_valid_tgts: dict[tuple[str, int], list[int]] = {} + self._pair_overlap_cache: dict[tuple[str, int, int], float] = {} + + if not root.exists(): + raise FileNotFoundError(root) + + self.scenes = sorted([p for p in root.iterdir() if p.is_dir()]) + if not self.scenes: + raise RuntimeError(f"No scene folders found in {root}") + + self._pose_cache: dict[str, tuple[np.ndarray, np.ndarray]] = {} + self._meta_paths: dict[str, Path] = {} + self._num_frames: dict[str, int] = {} + self._available_frames: dict[str, list[int]] = {} + + if self.index_manifest_path is not None: + if not self.index_manifest_path.exists(): + raise FileNotFoundError(self.index_manifest_path) + valid_scenes: list[Path] = [] + for raw in self.index_manifest_path.read_text(encoding="utf-8").splitlines(): + line = raw.strip() + if not line: + continue + parts = line.split("|") + scene_name = parts[0].strip() + if not scene_name: + continue + scene_dir = root / scene_name + meta_path = scene_dir / "meta.pt" + if not meta_path.exists(): + continue + if len(parts) >= 2: + try: + n_pose = int(parts[1]) + except ValueError: + n_pose = 0 + else: + n_pose = 0 + if n_pose <= 0: + continue + 
def _get_pose(self, scene: str) -> tuple[np.ndarray, np.ndarray]:
    """Return the cached ``(R, t)`` pose arrays of ``scene``, loading them on first use.

    ``R`` is the upper-left 3x3 block and ``t`` the first three entries of the
    last column of every 4x4 matrix stored under ``"cameras"`` in the scene's
    ``meta.pt``; both are returned as float32 NumPy arrays.

    Raises:
        FileNotFoundError: If no ``meta.pt`` path is indexed for ``scene``.
        ValueError: If ``"cameras"`` is missing, is not a tensor, or is not
            shaped ``(N, 4, 4)``.
    """
    hit = self._pose_cache.get(scene)
    if hit is not None:
        return hit

    meta_file = self._meta_paths.get(scene)
    if meta_file is None:
        raise FileNotFoundError(f"meta.pt not indexed for scene={scene} under {self.root}")

    payload = _torch_load_any(meta_file)
    cameras = payload.get("cameras", None)
    if not isinstance(cameras, torch.Tensor):
        raise ValueError(f"meta.pt missing 'cameras' tensor for scene={scene}")

    cameras = cameras.to(torch.float32)
    if not (cameras.ndim == 3 and tuple(cameras.shape[1:]) == (4, 4)):
        raise ValueError(f"Bad meta.pt cameras shape {tuple(cameras.shape)} for scene={scene}")

    rotation = cameras[:, :3, :3].cpu().numpy()
    translation = cameras[:, :3, 3].cpu().numpy()
    pose = (rotation, translation)
    self._pose_cache[scene] = pose
    return pose


def __len__(self) -> int:
    """Return the number of entries in ``self._index``."""
    # NOTE(review): ``_index`` is not assigned in the part of ``__init__``
    # that was reviewed -- confirm where it is created.
    return len(self._index)
def _candidate_targets_by_translation(self, scene: str, src_idx: int) -> list[int]:
    """List frames near ``src_idx`` whose camera centre is close to the source's.

    A frame qualifies when it differs from ``src_idx``, lies within
    ``src_tgt_max_index_gap`` indices of it, and its camera centre is strictly
    closer than ``pair_max_translation_m``.  The centre distance is the minimum
    over every pose convention listed in ``_PAIR_CONVENTIONS``.

    Returns:
        Qualifying frame indices, or an empty list when the scene has at most
        one frame, ``src_idx`` is outside the pose arrays, or nothing
        qualifies.

    Raises:
        ValueError: If ``_PAIR_CONVENTIONS`` names an unknown convention.
    """
    frame_ids = self._available_frames[scene]
    if len(frame_ids) <= 1:
        return []

    rotations, translations = self._get_pose(scene)
    if src_idx < 0 or src_idx >= len(translations) or src_idx >= len(rotations):
        return []
    max_dist = float(self.pair_max_translation_m)

    def center_of(rot: np.ndarray, trans: np.ndarray, convention: str) -> np.ndarray:
        # Camera centre implied by (rot, trans) under the given convention.
        if convention in ("c2w", "w2c_t_camcenter"):
            return trans
        if convention == "w2c":
            return -(rot.transpose(0, 2, 1) @ trans[..., None])[..., 0]
        if convention == "c2w_t_w2c":
            return -(rot @ trans[..., None])[..., 0]
        raise ValueError(convention)

    window = self.src_tgt_max_index_gap
    nearby = np.array(
        [fid for fid in frame_ids if fid != src_idx and abs(fid - src_idx) <= window],
        dtype=np.int64,
    )
    if nearby.size == 0:
        return []

    src = int(src_idx)
    rot_src = rotations[src : src + 1].astype(np.float32)
    trans_src = translations[src : src + 1].astype(np.float32)
    rot_cand = rotations[nearby].astype(np.float32)
    trans_cand = translations[nearby].astype(np.float32)

    best = None
    for convention in _PAIR_CONVENTIONS:
        origin = center_of(rot_src, trans_src, convention)[0]
        cand_centers = center_of(rot_cand, trans_cand, convention)
        dist = np.linalg.norm(cand_centers - origin[None, :], axis=1)
        best = dist if best is None else np.minimum(best, dist)
    assert best is not None

    close_enough = nearby[best < max_dist]
    return [int(fid) for fid in close_enough.tolist()]
int(depth.shape[2]) + if H == face_w and W == face_w: + return depth.to(dtype=torch.float32) + import torch.nn.functional as F + + x = depth.permute(0, 3, 1, 2).to(dtype=torch.float32) + x = F.interpolate(x, size=(face_w, face_w), mode="bilinear", align_corners=False) + return x.permute(0, 2, 3, 1).contiguous() + + @staticmethod + def _cubemap_z_depth_to_distance(depth: torch.Tensor) -> torch.Tensor: + if depth.ndim != 4 or depth.shape[0] != 6 or depth.shape[-1] != 1: + raise ValueError(f"Expected cube depth shape (6,H,W,1), got {tuple(depth.shape)}") + from unisharp.utils.pano import get_pinhole_intrinsics_4x4 + + h = int(depth.shape[1]) + w = int(depth.shape[2]) + if h != w: + raise ValueError(f"Expected square cubemap faces, got {(h, w)}") + depth_61hw = depth.permute(0, 3, 1, 2).to(dtype=torch.float32).contiguous() + intr = get_pinhole_intrinsics_4x4(w).to(device=depth_61hw.device, dtype=depth_61hw.dtype) + ys = torch.arange(h, device=depth_61hw.device, dtype=depth_61hw.dtype) + xs = torch.arange(w, device=depth_61hw.device, dtype=depth_61hw.dtype) + vv, uu = torch.meshgrid(ys, xs, indexing="ij") + x = (uu - intr[0, 2]) / intr[0, 0].clamp(min=1e-8) + y = (vv - intr[1, 2]) / intr[1, 1].clamp(min=1e-8) + ray_z = 1.0 / torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8) + dist = depth_61hw / ray_z.view(1, 1, h, w).clamp(min=1e-8) + valid = torch.isfinite(dist) & (dist > 0.0) + dist = torch.where(valid, dist, torch.zeros_like(dist)) + return dist.permute(0, 2, 3, 1).contiguous() + + def _pair_depth_overlap_score( + self, + *, + src_R: torch.Tensor, + src_t: torch.Tensor, + tgt_R: torch.Tensor, + tgt_t: torch.Tensor, + src_cube_depth_m: torch.Tensor, + tgt_cube_depth_m: torch.Tensor, + ) -> float: + from unisharp.utils.camera_projection import build_extrinsics_w2c, view_frustum_mask_cubemap_union # noqa: WPS433 + + device = torch.device("cpu") + src_R = src_R.to(device=device, dtype=torch.float32) + src_t = src_t.to(device=device, dtype=torch.float32) + tgt_R = 
tgt_R.to(device=device, dtype=torch.float32) + tgt_t = tgt_t.to(device=device, dtype=torch.float32) + + face_w = int(self.pair_overlap_face_w) + margin = float(self.pair_overlap_margin) + src_d = self._cubemap_z_depth_to_distance(self._resize_cube_depth(src_cube_depth_m.to(device=device), face_w=face_w)) + tgt_d = self._cubemap_z_depth_to_distance(self._resize_cube_depth(tgt_cube_depth_m.to(device=device), face_w=face_w)) + + def _score_one(recipe: tuple[str, bool]) -> float: + pose_conv, flip_yz = recipe + extr_src = build_extrinsics_w2c(src_R, src_t, pose_conv) + extr_tgt = build_extrinsics_w2c(tgt_R, tgt_t, pose_conv) + + with torch.autocast(device_type="cpu", enabled=False): + c2w_src = torch.linalg.inv(extr_src) + c2w_tgt = torch.linalg.inv(extr_tgt) + if bool(flip_yz): + D = torch.diag(torch.tensor([1.0, -1.0, -1.0, 1.0], dtype=torch.float32, device=device)) + c2w_src = c2w_src @ D + c2w_tgt = c2w_tgt @ D + ref_inv = torch.linalg.inv(c2w_src) + c2w_src = ref_inv @ c2w_src + c2w_tgt = ref_inv @ c2w_tgt + extr_src_n = torch.linalg.inv(c2w_src) + extr_tgt_n = torch.linalg.inv(c2w_tgt) + + m_tgt_in_src = view_frustum_mask_cubemap_union( + depth_novel=tgt_d, + extr_novel_w2c=extr_tgt_n, + extr_source_w2c=extr_src_n, + face_w=face_w, + margin=margin, + ) + m_src_in_tgt = view_frustum_mask_cubemap_union( + depth_novel=src_d, + extr_novel_w2c=extr_src_n, + extr_source_w2c=extr_tgt_n, + face_w=face_w, + margin=margin, + ) + tgt_valid = torch.isfinite(tgt_d[..., 0]) & (tgt_d[..., 0] > 0.0) + src_valid = torch.isfinite(src_d[..., 0]) & (src_d[..., 0] > 0.0) + denom_t = float(tgt_valid.sum().item()) + denom_s = float(src_valid.sum().item()) + if denom_t < 10 or denom_s < 10: + return 0.0 + a = float((m_tgt_in_src & tgt_valid).sum().item()) / denom_t + b = float((m_src_in_tgt & src_valid).sum().item()) / denom_s + return 0.5 * (a + b) + + return _score_one(_PAIR_RECIPE_FIXED) + + def __getitem__(self, idx: int) -> PanOGSSample: + src_erp: torch.Tensor | None = None + 
tgt_erp: torch.Tensor | None = None + src_dep: torch.Tensor | None = None + tgt_dep: torch.Tensor | None = None + src_cube: torch.Tensor | None = None + tgt_cube: torch.Tensor | None = None + src_cdep: torch.Tensor | None = None + tgt_cdep: torch.Tensor | None = None + last_err: Exception | None = None + + max_outer = 16 + for outer in range(max_outer): + scene, src_idx = self._index[int(idx) % len(self._index)] + scene_dir = self.root / scene + tgt_idx = self._sample_target(scene, src_idx) + + max_retries = 8 + ok = False + for _ in range(max_retries): + try: + if src_erp is None: + src_erp = _load_erp_rgb_u8(scene_dir / "pano" / f"{src_idx:05d}.png") + src_dep = _depth_to_meters( + _load_depth_png(scene_dir / "pano_depth" / f"{src_idx:05d}.png"), + max_depth_m=self.depth_max_m, + ) + if self.use_cubemap_supervision: + src_cube_any = _torch_load_any(scene_dir / "cubemaps" / f"{src_idx:05d}.torch") + src_cdep_any = _torch_load_any(scene_dir / "cubemaps_depth" / f"{src_idx:05d}.torch") + if not all(isinstance(x, torch.Tensor) for x in [src_cube_any, src_cdep_any]): + raise RuntimeError("Bad .torch payload for src (expected Tensor).") + src_cube = cast(torch.Tensor, src_cube_any) + src_cdep = cast(torch.Tensor, src_cdep_any).to(torch.float32).clamp(min=0.0, max=self.depth_max_m) + else: + src_cube = torch.zeros((6, 256, 256, 3), dtype=torch.uint8) + src_cdep = torch.zeros((6, 256, 256, 1), dtype=torch.float32) + + candidates: list[int] = [] + if self.pair_sampling and self.use_cubemap_supervision: + key = (scene, int(src_idx)) + cached = self._pair_valid_tgts.get(key) + if cached: + candidates = list(cached) + else: + candidates = self._candidate_targets_by_translation(scene, int(src_idx)) + if not candidates: + candidates = [int(tgt_idx)] + + tried: set[int] = set() + found = False + max_try = ( + 1 + if (not self.pair_sampling or not self.use_cubemap_supervision) + else max(1, self.pair_max_tries) + ) + for _try in range(max_try): + pool = [ + c + for c in 
candidates + if int(c) not in tried and int(c) != int(src_idx) + ] + if not pool: + break + j = int(torch.randint(0, len(pool), (1,)).item()) + tgt_idx = int(pool[j]) + tried.add(int(tgt_idx)) + + if self.use_cubemap_supervision: + tgt_cdep_any = _torch_load_any(scene_dir / "cubemaps_depth" / f"{tgt_idx:05d}.torch") + if not isinstance(tgt_cdep_any, torch.Tensor): + raise RuntimeError("Bad .torch payload for tgt depth (expected Tensor).") + tgt_cdep = cast(torch.Tensor, tgt_cdep_any).to(torch.float32).clamp(min=0.0, max=self.depth_max_m) + else: + tgt_cdep = torch.zeros((6, 256, 256, 1), dtype=torch.float32) + + if self.pair_sampling and self.use_cubemap_supervision: + k = (scene, int(src_idx), int(tgt_idx)) + score = self._pair_overlap_cache.get(k) + if score is None: + R_np, t_np = self._get_pose(scene) + src_R = torch.from_numpy(R_np[int(src_idx)]) + src_t = torch.from_numpy(t_np[int(src_idx)]) + tgt_R = torch.from_numpy(R_np[int(tgt_idx)]) + tgt_t = torch.from_numpy(t_np[int(tgt_idx)]) + score = self._pair_depth_overlap_score( + src_R=src_R, + src_t=src_t, + tgt_R=tgt_R, + tgt_t=tgt_t, + src_cube_depth_m=cast(torch.Tensor, src_cdep), + tgt_cube_depth_m=cast(torch.Tensor, tgt_cdep), + ) + self._pair_overlap_cache[k] = float(score) + if float(score) < float(self.pair_min_depth_overlap): + continue + kk = (scene, int(src_idx)) + self._pair_valid_tgts.setdefault(kk, []).append(int(tgt_idx)) + + tgt_erp = _load_erp_rgb_u8(scene_dir / "pano" / f"{tgt_idx:05d}.png") + tgt_dep = _depth_to_meters( + _load_depth_png(scene_dir / "pano_depth" / f"{tgt_idx:05d}.png"), + max_depth_m=self.depth_max_m, + ) + if self.use_cubemap_supervision: + tgt_cube_any = _torch_load_any(scene_dir / "cubemaps" / f"{tgt_idx:05d}.torch") + if not isinstance(tgt_cube_any, torch.Tensor): + raise RuntimeError("Bad .torch payload for tgt RGB cubemap (expected Tensor).") + tgt_cube = cast(torch.Tensor, tgt_cube_any) + else: + tgt_cube = torch.zeros((6, 256, 256, 3), dtype=torch.uint8) + + found = 
True + break + + if not found: + raise RuntimeError( + f"No valid tgt found for scene={scene} src={src_idx} within constraints " + f"(trans<{self.pair_max_translation_m}m, overlap>{self.pair_min_depth_overlap})." + ) + + ok = True + break + except (FileNotFoundError, RuntimeError, EOFError, KeyError, tarfile.ReadError, OSError) as e: + last_err = e + frames = self._available_frames.get(scene, []) + if not frames: + break + src_idx = int(frames[int(torch.randint(0, len(frames), (1,)).item())]) + tgt_idx = self._sample_target(scene, src_idx) + src_erp = None + src_dep = None + src_cube = None + src_cdep = None + + if ok: + break + + idx = int(idx) + 9973 + outer * 13 + else: + raise RuntimeError(f"PanOGS __getitem__ failed after retries. last_err={last_err}") + + assert src_erp is not None and tgt_erp is not None + assert src_dep is not None and tgt_dep is not None + assert src_cube is not None and tgt_cube is not None + assert src_cdep is not None and tgt_cdep is not None + + src_dep = src_dep.to(torch.float32).unsqueeze(0) + tgt_dep = tgt_dep.to(torch.float32).unsqueeze(0) + + R_np, t_np = self._get_pose(scene) + src_R = torch.from_numpy(R_np[src_idx]) + src_t = torch.from_numpy(t_np[src_idx]) + tgt_R = torch.from_numpy(R_np[tgt_idx]) + tgt_t = torch.from_numpy(t_np[tgt_idx]) + + return PanOGSSample( + src_erp_rgb_u8=src_erp, + tgt_erp_rgb_u8=tgt_erp, + src_erp_depth_m=src_dep, + tgt_erp_depth_m=tgt_dep, + src_cube_rgb_u8=src_cube, + tgt_cube_rgb_u8=tgt_cube, + src_cube_depth_m=src_cdep, + tgt_cube_depth_m=tgt_cdep, + src_R=src_R, + src_t=src_t, + tgt_R=tgt_R, + tgt_t=tgt_t, + src_idx=src_idx, + tgt_idx=tgt_idx, + scene=scene, + ) + + +def panogs_collate(batch: list[PanOGSSample]) -> PanOGSSample: + def stack(xs): + if isinstance(xs[0], torch.Tensor): + return torch.stack(xs, dim=0) + return xs + + return PanOGSSample( + src_erp_rgb_u8=stack([b.src_erp_rgb_u8 for b in batch]), + tgt_erp_rgb_u8=stack([b.tgt_erp_rgb_u8 for b in batch]), + 
src_erp_depth_m=stack([b.src_erp_depth_m for b in batch]), + tgt_erp_depth_m=stack([b.tgt_erp_depth_m for b in batch]), + src_cube_rgb_u8=stack([b.src_cube_rgb_u8 for b in batch]), + tgt_cube_rgb_u8=stack([b.tgt_cube_rgb_u8 for b in batch]), + src_cube_depth_m=stack([b.src_cube_depth_m for b in batch]), + tgt_cube_depth_m=stack([b.tgt_cube_depth_m for b in batch]), + src_R=stack([b.src_R for b in batch]), + src_t=stack([b.src_t for b in batch]), + tgt_R=stack([b.tgt_R for b in batch]), + tgt_t=stack([b.tgt_t for b in batch]), + src_idx=[b.src_idx for b in batch], # type: ignore[arg-type] + tgt_idx=[b.tgt_idx for b in batch], # type: ignore[arg-type] + scene=[b.scene for b in batch], # type: ignore[arg-type] + ) + diff --git a/unisharp/datasets/re10k.py b/unisharp/datasets/re10k.py new file mode 100644 index 0000000000000000000000000000000000000000..4856712a835e261502fc09528a0efcf87bc0c6a7 --- /dev/null +++ b/unisharp/datasets/re10k.py @@ -0,0 +1,718 @@ +from __future__ import annotations + +from collections import defaultdict, deque +from dataclasses import dataclass +from io import BytesIO +import logging +import os +from pathlib import Path +import random +import time + +import torch +import torchvision.transforms as tf +from PIL import Image +from torch.utils.data import IterableDataset + +from unisharp.datasets.pair_sampling import ( + project_overlap_ratio, + resize_k3_align_corners_false, + resize_rgb_u8_chw_high_quality, + select_targets_for_source, +) +from unisharp import DEFAULT_MAX_DEPTH_M +from unisharp.utils.pixel_convention import normalized_intrinsics_to_integer_pixel_k +from unisharp.utils.unik3d_adapter import infer_unik3d_pinhole, load_unik3d_model + + +LOGGER = logging.getLogger(__name__) + + +def _torch_load_any(path: Path) -> object: + try: + return torch.load(path, map_location="cpu", weights_only=False) + except TypeError: + return torch.load(path, map_location="cpu") + + +def _pack_re10k_batch(batch: list["Re10KPairSample"]) -> 
"Re10KPairSample": + def stack(xs): + if isinstance(xs[0], torch.Tensor): + ref_shape = tuple(xs[0].shape) + for idx, x in enumerate(xs[1:], start=1): + if tuple(x.shape) != ref_shape: + raise RuntimeError( + "RE10K collate got mixed tensor shapes: " + f"ref={ref_shape} mismatch_idx={idx} got={tuple(x.shape)}" + ) + return torch.stack(xs, dim=0) + return xs + + def stack_optional_depth(xs): + if all(torch.is_tensor(x) for x in xs): + ref_shape = tuple(xs[0].shape) + for idx, x in enumerate(xs[1:], start=1): + if tuple(x.shape) != ref_shape: + raise RuntimeError( + "RE10K collate got mixed depth shapes: " + f"ref={ref_shape} mismatch_idx={idx} got={tuple(x.shape)}" + ) + return torch.stack(xs, dim=0) + return None + + return Re10KPairSample( + src_rgb_u8=stack([b.src_rgb_u8 for b in batch]), + tgt_rgb_u8=stack([b.tgt_rgb_u8 for b in batch]), + src_w2c=stack([b.src_w2c for b in batch]), + tgt_w2c=stack([b.tgt_w2c for b in batch]), + src_intrinsics=stack([b.src_intrinsics for b in batch]), + tgt_intrinsics=stack([b.tgt_intrinsics for b in batch]), + src_idx=[b.src_idx for b in batch], # type: ignore[arg-type] + tgt_idx=[b.tgt_idx for b in batch], # type: ignore[arg-type] + scene=[b.scene for b in batch], # type: ignore[arg-type] + src_depth_m=stack_optional_depth([b.src_depth_m for b in batch]), # type: ignore[arg-type] + tgt_depth_m=stack_optional_depth([b.tgt_depth_m for b in batch]), # type: ignore[arg-type] + ) + + +def re10k_passthrough(batch: "Re10KPairSample") -> "Re10KPairSample": + return batch + + +@dataclass(frozen=True) +class Re10KPairSample: + src_rgb_u8: torch.Tensor + tgt_rgb_u8: torch.Tensor + src_w2c: torch.Tensor + tgt_w2c: torch.Tensor + src_intrinsics: torch.Tensor + tgt_intrinsics: torch.Tensor + src_idx: int + tgt_idx: int + scene: str + src_depth_m: torch.Tensor | None = None + tgt_depth_m: torch.Tensor | None = None + + +class Re10KDataset(IterableDataset): + + def __init__( + self, + root: Path, + chunks_file: Path | None = None, + split: str 
= "train", + min_frame_gap: int = 1, + max_frame_gap: int = 32, + pair_max_translation_m: float = 0.5, + pair_min_overlap: float = 0.6, + pair_overlap_sample_h: int = 32, + pair_overlap_sample_w: int = 56, + pair_max_tries: int = 32, + output_h: int | None = None, + output_w: int | None = None, + shuffle_chunk: bool = True, + shuffle_example: bool = True, + ddp_rank: int = 0, + ddp_world_size: int = 1, + pseudo_depth_root: Path | None = None, + pseudo_depth_autogen: bool = True, + pseudo_depth_backbone: str = "vitl", + pseudo_depth_device: str = "cpu", + pseudo_lock_timeout_sec: float = 120.0, + pseudo_lock_stale_sec: float = 1800.0, + pseudo_wait_poll_sec: float = 0.25, + batch_size_hint: int = 1, + depth_max_m: float = DEFAULT_MAX_DEPTH_M, + pseudo_far_depth_invalid_m: float = 30.0, + seed: int = 0, + ) -> None: + super().__init__() + self.root = root + self.split = split + self.min_frame_gap = int(min_frame_gap) + self.max_frame_gap = int(max_frame_gap) + self.pair_max_translation_m = float(pair_max_translation_m) + self.pair_min_overlap = float(pair_min_overlap) + self.pair_overlap_sample_h = int(pair_overlap_sample_h) + self.pair_overlap_sample_w = int(pair_overlap_sample_w) + self.pair_max_tries = int(pair_max_tries) + self.output_h = int(output_h) if output_h is not None else None + self.output_w = int(output_w) if output_w is not None else None + self.shuffle_chunk = bool(shuffle_chunk) + self.shuffle_example = bool(shuffle_example) + self.ddp_rank = int(ddp_rank) + self.ddp_world_size = int(ddp_world_size) + self.to_tensor = tf.ToTensor() + self.pseudo_depth_root = Path(pseudo_depth_root) if pseudo_depth_root is not None else None + self.pseudo_depth_autogen = bool(pseudo_depth_autogen) + self.pseudo_depth_backbone = str(pseudo_depth_backbone) + self.pseudo_depth_device = str(pseudo_depth_device) + self.pseudo_lock_timeout_sec = float(max(1.0, pseudo_lock_timeout_sec)) + self.pseudo_lock_stale_sec = float(max(30.0, pseudo_lock_stale_sec)) + 
self.pseudo_wait_poll_sec = float(max(0.05, pseudo_wait_poll_sec)) + self.batch_size_hint = int(max(1, batch_size_hint)) + self.depth_max_m = float(depth_max_m) + self.pseudo_far_depth_invalid_m = float(pseudo_far_depth_invalid_m) + self._pseudo_model: torch.nn.Module | None = None + self.seed = int(seed) + self.epoch = 0 + + self.chunks_file = Path(chunks_file) if chunks_file is not None else None + split_dir = self.root / self.split + if self.chunks_file is not None: + if not self.chunks_file.exists(): + raise FileNotFoundError(self.chunks_file) + chunks: list[Path] = [] + for raw in self.chunks_file.read_text(encoding="utf-8").splitlines(): + line = raw.strip() + if not line: + continue + p = Path(line) + if not p.is_absolute(): + p = split_dir / p + if p.suffix == ".torch": + chunks.append(p) + self.chunks = sorted(chunks) + else: + if not split_dir.exists(): + raise FileNotFoundError(split_dir) + self.chunks = sorted([p for p in split_dir.iterdir() if p.suffix == ".torch"]) + if not self.chunks: + source = self.chunks_file if self.chunks_file is not None else split_dir + raise RuntimeError(f"No .torch chunks found for {source}") + + def set_epoch(self, epoch: int) -> None: + self.epoch = int(epoch) + + if self.pseudo_depth_root is not None: + (self.pseudo_depth_root / self.split).mkdir(parents=True, exist_ok=True) + + @staticmethod + def _decode_image_u8(image_bytes_tensor: torch.Tensor) -> torch.Tensor: + if image_bytes_tensor.dtype != torch.uint8: + raise ValueError(f"Expected uint8 bytes tensor, got {image_bytes_tensor.dtype}") + image = Image.open(BytesIO(image_bytes_tensor.numpy().tobytes())).convert("RGB") + chw_float = tf.ToTensor()(image) + return (chw_float * 255.0).round().to(torch.uint8) + + @staticmethod + def _convert_pose_row_to_w2c(poses: torch.Tensor) -> torch.Tensor: + t = poses.shape[0] + w2c = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat(t, 1, 1) + w2c[:, :3] = poses[:, 6:].reshape(t, 3, 4).to(torch.float32) + return w2c + + 
@staticmethod + def _convert_intrinsics_to_pixel(poses: torch.Tensor, h: int, w: int) -> torch.Tensor: + t = poses.shape[0] + fx, fy, cx, cy = poses[:, 0], poses[:, 1], poses[:, 2], poses[:, 3] + del t + return normalized_intrinsics_to_integer_pixel_k( + fx, + fy, + cx, + cy, + height=int(h), + width=int(w), + ) + + @staticmethod + def _sanitize_scene(scene: str) -> str: + s = str(scene).strip() + s = s.replace("\\", "__").replace("/", "__") + return s if len(s) > 0 else "unknown_scene" + + def _pseudo_depth_path(self, scene: str, frame_idx: int) -> Path | None: + if self.pseudo_depth_root is None: + return None + scene_key = self._sanitize_scene(scene) + return self.pseudo_depth_root / self.split / scene_key / f"{int(frame_idx):05d}.pt" + + @staticmethod + def _load_pseudo_depth(path: Path) -> tuple[torch.Tensor | None, str]: + if not path.exists(): + return None, "unknown" + try: + payload = _torch_load_any(path) + depth_kind = "distance" + if isinstance(payload, dict): + depth = payload.get("depth_m", None) + depth_kind = str(payload.get("depth_kind", "distance")).strip().lower() + if depth_kind not in ("distance", "zdepth"): + depth_kind = "distance" + else: + depth = payload + if not torch.is_tensor(depth): + return None, "unknown" + if depth.ndim == 3 and depth.shape[0] == 1: + depth = depth[0] + if depth.ndim != 2: + return None, "unknown" + depth = depth.to(torch.float32) + valid = torch.isfinite(depth) & (depth > 0.0) + if int(valid.sum().item()) <= 0: + return None, "unknown" + return depth.unsqueeze(0), depth_kind + except Exception: + return None, "unknown" + + @staticmethod + def _distance_to_z_depth(depth_1hw: torch.Tensor, intrinsics_k3: torch.Tensor) -> torch.Tensor: + if depth_1hw.ndim != 3 or depth_1hw.shape[0] != 1: + raise ValueError(f"Expected depth shape (1,H,W), got {tuple(depth_1hw.shape)}") + d = depth_1hw.to(torch.float32) + h = int(d.shape[-2]) + w = int(d.shape[-1]) + k = intrinsics_k3.to(dtype=torch.float32, device=d.device) + fx = k[0, 
0] + fy = k[1, 1] + cx = k[0, 2] + cy = k[1, 2] + ys = torch.arange(h, device=d.device, dtype=torch.float32) + xs = torch.arange(w, device=d.device, dtype=torch.float32) + vv, uu = torch.meshgrid(ys, xs, indexing="ij") + x = (uu - cx) / fx + y = (vv - cy) / fy + ray_z = 1.0 / torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8) + z = d[0] * ray_z + return z.unsqueeze(0) + + @staticmethod + def _sanitize_pseudo_depth( + depth_1hw: torch.Tensor, + *, + max_depth_m: float = DEFAULT_MAX_DEPTH_M, + far_depth_invalid_m: float = 30.0, + ) -> torch.Tensor: + d = depth_1hw.to(torch.float32) + valid = torch.isfinite(d) & (d > 0.0) + if int(valid.sum().item()) <= 0: + return d + out = d.clone() + if float(far_depth_invalid_m) > 0.0: + valid = valid & (out <= float(far_depth_invalid_m)) + out = torch.where(valid, out, torch.zeros_like(out)) + out[valid] = out[valid].clamp(max=float(max_depth_m)) + return out + + def _get_or_create_pseudo_model(self) -> torch.nn.Module: + if self._pseudo_model is None: + dev = torch.device(self.pseudo_depth_device) + self._pseudo_model = load_unik3d_model( + backbone=self.pseudo_depth_backbone, + pretrained=True, + device=dev, + ) + self._pseudo_model.eval() + LOGGER.info( + "Re10K pseudo-depth model loaded (split=%s, device=%s, backbone=%s)", + self.split, + str(dev), + self.pseudo_depth_backbone, + ) + return self._pseudo_model + + def _save_pseudo_depth_atomic( + self, + path: Path, + depth_2d: torch.Tensor, + scene: str, + frame_idx: int, + ) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.parent / f".tmp_{os.getpid()}_{int(time.time() * 1e6)}_{random.randint(0, 10_000_000)}.pt" + payload = { + "depth_m": depth_2d.to(torch.float16), + "depth_kind": "distance", + "scene": str(scene), + "frame_idx": int(frame_idx), + } + torch.save(payload, tmp) + os.replace(tmp, path) + + def _acquire_lock_or_wait_for_file(self, target: Path) -> tuple[bool, bool]: + lock_dir = Path(str(target) + ".lock") + start = time.time() + while True: 
+ if target.exists(): + return False, True + try: + lock_dir.mkdir(parents=False, exist_ok=False) + meta = lock_dir / "owner.txt" + meta.write_text(f"pid={os.getpid()} time={time.time():.3f}\n", encoding="utf-8") + return True, False + except FileExistsError: + try: + mtime = lock_dir.stat().st_mtime + if (time.time() - float(mtime)) > self.pseudo_lock_stale_sec: + for p in lock_dir.iterdir(): + try: + p.unlink() + except Exception: + pass + lock_dir.rmdir() + continue + except Exception: + pass + if (time.time() - start) >= self.pseudo_lock_timeout_sec: + return False, False + time.sleep(self.pseudo_wait_poll_sec) + except Exception: + return False, False + + def _release_lock(self, target: Path) -> None: + lock_dir = Path(str(target) + ".lock") + if not lock_dir.exists(): + return + try: + for p in lock_dir.iterdir(): + try: + p.unlink() + except Exception: + pass + lock_dir.rmdir() + except Exception: + pass + + def _get_pseudo_depth_for_frame( + self, + *, + scene: str, + frame_idx: int, + rgb_u8: torch.Tensor, + intrinsics_k3: torch.Tensor, + ) -> torch.Tensor | None: + path = self._pseudo_depth_path(scene, frame_idx) + if path is None: + return None + depth, depth_kind = self._load_pseudo_depth(path) + if depth is not None: + if depth_kind != "zdepth": + try: + depth = self._distance_to_z_depth( + self._sanitize_pseudo_depth( + depth, + max_depth_m=self.depth_max_m, + far_depth_invalid_m=self.pseudo_far_depth_invalid_m, + ), + intrinsics_k3=intrinsics_k3, + ) + except Exception: + return None + else: + depth = self._sanitize_pseudo_depth( + depth, + max_depth_m=self.depth_max_m, + far_depth_invalid_m=self.pseudo_far_depth_invalid_m, + ) + return depth + if not self.pseudo_depth_autogen: + return None + + acquired, ready = self._acquire_lock_or_wait_for_file(path) + if ready: + depth, depth_kind = self._load_pseudo_depth(path) + if depth is None: + return None + if depth_kind != "zdepth": + try: + depth = self._distance_to_z_depth( + 
self._sanitize_pseudo_depth( + depth, + max_depth_m=self.depth_max_m, + far_depth_invalid_m=self.pseudo_far_depth_invalid_m, + ), + intrinsics_k3=intrinsics_k3, + ) + except Exception: + return None + else: + depth = self._sanitize_pseudo_depth( + depth, + max_depth_m=self.depth_max_m, + far_depth_invalid_m=self.pseudo_far_depth_invalid_m, + ) + return depth + if not acquired: + depth, depth_kind = self._load_pseudo_depth(path) + if depth is None: + return None + if depth_kind != "zdepth": + try: + depth = self._distance_to_z_depth( + self._sanitize_pseudo_depth( + depth, + max_depth_m=self.depth_max_m, + far_depth_invalid_m=self.pseudo_far_depth_invalid_m, + ), + intrinsics_k3=intrinsics_k3, + ) + except Exception: + return None + else: + depth = self._sanitize_pseudo_depth( + depth, + max_depth_m=self.depth_max_m, + far_depth_invalid_m=self.pseudo_far_depth_invalid_m, + ) + return depth + + try: + depth, depth_kind = self._load_pseudo_depth(path) + if depth is not None: + if depth_kind != "zdepth": + try: + depth = self._distance_to_z_depth( + self._sanitize_pseudo_depth( + depth, + max_depth_m=self.depth_max_m, + far_depth_invalid_m=self.pseudo_far_depth_invalid_m, + ), + intrinsics_k3=intrinsics_k3, + ) + except Exception: + return None + else: + depth = self._sanitize_pseudo_depth( + depth, + max_depth_m=self.depth_max_m, + far_depth_invalid_m=self.pseudo_far_depth_invalid_m, + ) + return depth + model = self._get_or_create_pseudo_model() + out = infer_unik3d_pinhole( + model, + rgb_u8=rgb_u8.unsqueeze(0), + intrinsics=intrinsics_k3.unsqueeze(0), + ) + dist = out.get("distance", None) if isinstance(out, dict) else None + if not torch.is_tensor(dist) or dist.ndim != 4 or dist.shape[1] != 1: + return None + dist_1hw = self._sanitize_pseudo_depth( + dist[0:1, 0:1].detach().to(torch.float32).cpu()[0], + max_depth_m=self.depth_max_m, + far_depth_invalid_m=self.pseudo_far_depth_invalid_m, + ) + valid = torch.isfinite(dist_1hw) & (dist_1hw > 0.0) + if 
int(valid.sum().item()) <= 0: + return None + self._save_pseudo_depth_atomic( + path, + depth_2d=dist_1hw[0], + scene=scene, + frame_idx=frame_idx, + ) + return self._distance_to_z_depth(dist_1hw, intrinsics_k3=intrinsics_k3.cpu()) + except Exception as e: + LOGGER.warning( + "Pseudo-depth generation failed scene=%s frame=%d: %s", + str(scene), + int(frame_idx), + str(e), + ) + return None + finally: + self._release_lock(path) + + def _candidate_target_indices( + self, + src_idx: int, + num_frames: int, + w2c_all: torch.Tensor, + intr_all: torch.Tensor, + h: int, + w: int, + ) -> list[int]: + if num_frames < 2: + return [] + centers = torch.linalg.inv(w2c_all)[:, :3, 3].to(torch.float32) + sample_h = int(self.pair_overlap_sample_h) + sample_w = int(self.pair_overlap_sample_w) + return select_targets_for_source( + src_idx=int(src_idx), + candidate_indices=list(range(num_frames)), + centers=centers, + min_index_gap=int(self.min_frame_gap), + max_index_gap=int(self.max_frame_gap), + pair_max_translation_m=float(self.pair_max_translation_m), + pair_min_overlap=float(self.pair_min_overlap), + overlap_score_fn=lambda si, tj: float( + 0.5 + * ( + project_overlap_ratio( + src_w2c=w2c_all[si], + tgt_w2c=w2c_all[tj], + src_k=intr_all[si], + tgt_k=intr_all[tj], + h=h, + w=w, + sample_h=sample_h, + sample_w=sample_w, + ) + + project_overlap_ratio( + src_w2c=w2c_all[tj], + tgt_w2c=w2c_all[si], + src_k=intr_all[tj], + tgt_k=intr_all[si], + h=h, + w=w, + sample_h=sample_h, + sample_w=sample_w, + ) + ) + ), + ) + + def __iter__(self): + chunks = list(self.chunks) + order_rng = random.Random(self.seed + self.epoch) + if self.shuffle_chunk and self.split == "train": + order_rng.shuffle(chunks) + pending_by_hw: dict[tuple[int, int], deque[Re10KPairSample]] = defaultdict(deque) + + worker_info = torch.utils.data.get_worker_info() + num_workers = worker_info.num_workers if worker_info is not None else 1 + worker_id = worker_info.id if worker_info is not None else 0 + total_shards = 
max(1, self.ddp_world_size * num_workers) + shard_id = self.ddp_rank * num_workers + worker_id + chunks = [chunk for i, chunk in enumerate(chunks) if i % total_shards == shard_id] + + for chunk_order_idx, chunk_path in enumerate(chunks): + chunk = _torch_load_any(chunk_path) + if not isinstance(chunk, list): + continue + examples = list(chunk) + chunk_rng = random.Random(self.seed + self.epoch * 1000003 + chunk_order_idx) + if self.shuffle_example and self.split == "train": + chunk_rng.shuffle(examples) + + for example in examples: + if not isinstance(example, dict): + continue + if "cameras" not in example or "images" not in example: + continue + poses = example["cameras"] + images = example["images"] + scene = str(example.get("key", "unknown")) + if not torch.is_tensor(poses) or not isinstance(images, list): + continue + if poses.ndim != 2 or poses.shape[1] != 18: + continue + if len(images) != int(poses.shape[0]): + continue + + try: + src_probe = self._decode_image_u8(images[0]) + except Exception: + continue + h, w = int(src_probe.shape[1]), int(src_probe.shape[2]) + w2c_all = self._convert_pose_row_to_w2c(poses) + intr_all = self._convert_intrinsics_to_pixel(poses, h=h, w=w) + src_indices = list(range(len(images))) + if self.shuffle_example and self.split == "train": + chunk_rng.shuffle(src_indices) + for src_idx in src_indices: + tgt_candidates = self._candidate_target_indices( + int(src_idx), + len(images), + w2c_all=w2c_all, + intr_all=intr_all, + h=h, + w=w, + ) + if not tgt_candidates: + continue + tgt_idx = chunk_rng.choice(tgt_candidates) + + try: + src_img = self._decode_image_u8(images[src_idx]) + tgt_img = self._decode_image_u8(images[tgt_idx]) + except Exception: + continue + if src_img.shape != tgt_img.shape: + continue + src_intr = intr_all[src_idx].clone() + tgt_intr = intr_all[tgt_idx].clone() + src_depth = self._get_pseudo_depth_for_frame( + scene=scene, + frame_idx=int(src_idx), + rgb_u8=src_img, + 
intrinsics_k3=intr_all[src_idx].to(torch.float32), + ) + tgt_depth = self._get_pseudo_depth_for_frame( + scene=scene, + frame_idx=int(tgt_idx), + rgb_u8=tgt_img, + intrinsics_k3=intr_all[tgt_idx].to(torch.float32), + ) + if self.pseudo_depth_root is not None and ( + (not torch.is_tensor(src_depth)) or (not torch.is_tensor(tgt_depth)) + ): + continue + if self.output_h is not None and self.output_w is not None: + oh, ow = int(src_img.shape[1]), int(src_img.shape[2]) + if oh > 0 and ow > 0 and (oh != self.output_h or ow != self.output_w): + sx = float(self.output_w) / float(ow) + sy = float(self.output_h) / float(oh) + src_img = resize_rgb_u8_chw_high_quality(src_img, size=(self.output_h, self.output_w)) + tgt_img = resize_rgb_u8_chw_high_quality(tgt_img, size=(self.output_h, self.output_w)) + src_intr = resize_k3_align_corners_false(src_intr, sx=sx, sy=sy) + tgt_intr = resize_k3_align_corners_false(tgt_intr, sx=sx, sy=sy) + if torch.is_tensor(src_depth): + src_depth = ( + torch.nn.functional.interpolate( + src_depth[None], + size=(self.output_h, self.output_w), + mode="bilinear", + align_corners=False, + ) + .squeeze(0) + .to(torch.float32) + ) + if torch.is_tensor(tgt_depth): + tgt_depth = ( + torch.nn.functional.interpolate( + tgt_depth[None], + size=(self.output_h, self.output_w), + mode="bilinear", + align_corners=False, + ) + .squeeze(0) + .to(torch.float32) + ) + + sample = Re10KPairSample( + src_rgb_u8=src_img, + tgt_rgb_u8=tgt_img, + src_w2c=w2c_all[src_idx], + tgt_w2c=w2c_all[tgt_idx], + src_intrinsics=src_intr, + tgt_intrinsics=tgt_intr, + src_idx=int(src_idx), + tgt_idx=int(tgt_idx), + scene=scene, + src_depth_m=src_depth, + tgt_depth_m=tgt_depth, + ) + hw_key = (int(sample.src_rgb_u8.shape[1]), int(sample.src_rgb_u8.shape[2])) + bucket = pending_by_hw[hw_key] + bucket.append(sample) + if self.batch_size_hint <= 1: + yield bucket.popleft() + continue + while len(bucket) >= self.batch_size_hint: + packed = [bucket.popleft() for _ in 
range(self.batch_size_hint)] + yield _pack_re10k_batch(packed) + + dropped = sum(len(bucket) for bucket in pending_by_hw.values()) + if dropped > 0 and self.split == "train" and self.batch_size_hint > 1: + LOGGER.debug( + "Dropped %d RE10K leftover samples that could not form a same-resolution batch of size %d.", + int(dropped), + int(self.batch_size_hint), + ) + + +def re10k_collate(batch: list[Re10KPairSample]) -> Re10KPairSample: + return _pack_re10k_batch(batch) + diff --git a/unisharp/datasets/scannetpp_fisheye.py b/unisharp/datasets/scannetpp_fisheye.py new file mode 100644 index 0000000000000000000000000000000000000000..d00e8c84309829dee5e4d75453f20a4e98fad31a --- /dev/null +++ b/unisharp/datasets/scannetpp_fisheye.py @@ -0,0 +1,491 @@ + +from __future__ import annotations + +from dataclasses import dataclass +import json +import logging +from pathlib import Path +import random + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import IterableDataset + +from unisharp import DEFAULT_MAX_DEPTH_M + + +LOGGER = logging.getLogger(__name__) +IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"} +DEPTH_DIR_NAMES = ("depth", "depths", "distance", "distances", "depth_maps") +MASK_DIR_NAMES = ("masks", "mask") +DEPTH_MAX_M = DEFAULT_MAX_DEPTH_M + + +def _qvec_to_rotmat(qvec: np.ndarray) -> np.ndarray: + q = np.asarray(qvec, dtype=np.float64) + return np.array( + [ + [1 - 2 * q[2] ** 2 - 2 * q[3] ** 2, 2 * q[1] * q[2] - 2 * q[0] * q[3], 2 * q[3] * q[1] + 2 * q[0] * q[2]], + [2 * q[1] * q[2] + 2 * q[0] * q[3], 1 - 2 * q[1] ** 2 - 2 * q[3] ** 2, 2 * q[2] * q[3] - 2 * q[0] * q[1]], + [2 * q[3] * q[1] - 2 * q[0] * q[2], 2 * q[2] * q[3] + 2 * q[0] * q[1], 1 - 2 * q[1] ** 2 - 2 * q[2] ** 2], + ], + dtype=np.float64, + ) + + +def _read_colmap_w2c(images_txt: Path) -> dict[str, torch.Tensor]: + poses: dict[str, torch.Tensor] = {} + if not images_txt.exists(): + return poses + with images_txt.open("r", encoding="utf-8") as f: + for 
raw in f: + line = raw.strip() + if not line or line.startswith("#"): + continue + parts = line.split() + if len(parts) < 10: + continue + try: + qvec = np.asarray([float(x) for x in parts[1:5]], dtype=np.float64) + tvec = np.asarray([float(x) for x in parts[5:8]], dtype=np.float64) + image_name = parts[9] + except Exception: + continue + w2c = np.eye(4, dtype=np.float32) + w2c[:3, :3] = _qvec_to_rotmat(qvec).astype(np.float32) + w2c[:3, 3] = tvec.astype(np.float32) + poses[Path(image_name).name] = torch.from_numpy(w2c) + return poses + + +def _opencv_fisheye_to_fisheye624_params(meta: dict[str, object]) -> torch.Tensor: + if str(meta.get("camera_model", "")) != "OPENCV_FISHEYE": + raise RuntimeError(f"Unsupported ScanNet++ camera_model={meta.get('camera_model')!r}; expected OPENCV_FISHEYE.") + return torch.tensor( + [ + float(meta["fl_x"]), + float(meta["fl_y"]), + float(meta["cx"]), + float(meta["cy"]), + float(meta.get("k1", 0.0)), + float(meta.get("k2", 0.0)), + float(meta.get("k3", 0.0)), + float(meta.get("k4", 0.0)), + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + dtype=torch.float32, + ) + + +def _camera_hw_from_meta(meta: dict[str, object]) -> tuple[int, int] | None: + h = meta.get("h", meta.get("height", None)) + w = meta.get("w", meta.get("width", None)) + if h is None or w is None: + return None + try: + h_i, w_i = int(h), int(w) + except Exception: + return None + return (h_i, w_i) if h_i > 0 and w_i > 0 else None + + +def _scale_fisheye624_params( + params: torch.Tensor, + *, + src_hw: tuple[int, int], + dst_hw: tuple[int, int], +) -> torch.Tensor: + if tuple(int(x) for x in src_hw) == tuple(int(x) for x in dst_hw): + return params.clone() + src_h, src_w = int(src_hw[0]), int(src_hw[1]) + dst_h, dst_w = int(dst_hw[0]), int(dst_hw[1]) + sx = float(dst_w) / float(max(src_w, 1)) + sy = float(dst_h) / float(max(src_h, 1)) + out = params.clone() + out[..., 0] *= sx + out[..., 1] *= sy + out[..., 2] = (out[..., 2] + 0.5) * sx - 0.5 + out[..., 
def _stack_batch(batch: list["ScannetppFisheyePairSample"]) -> "ScannetppFisheyePairSample":
    """Collate a list of pair samples into one batched sample.

    Every tensor field is stacked along a new leading dimension; the
    ``src_idx`` / ``tgt_idx`` / ``scene`` fields become plain Python lists with
    one entry per input sample.
    """

    def stacked(field_name: str) -> torch.Tensor:
        # Gather one tensor attribute from every sample and stack on dim 0.
        return torch.stack([getattr(item, field_name) for item in batch], dim=0)

    return ScannetppFisheyePairSample(
        src_rgb_u8=stacked("src_rgb_u8"),
        tgt_rgb_u8=stacked("tgt_rgb_u8"),
        src_depth_m=stacked("src_depth_m"),
        tgt_depth_m=stacked("tgt_depth_m"),
        src_valid_mask=stacked("src_valid_mask"),
        tgt_valid_mask=stacked("tgt_valid_mask"),
        src_w2c=stacked("src_w2c"),
        tgt_w2c=stacked("tgt_w2c"),
        src_camera_params=stacked("src_camera_params"),
        tgt_camera_params=stacked("tgt_camera_params"),
        src_idx=[item.src_idx for item in batch],  # type: ignore[arg-type]
        tgt_idx=[item.tgt_idx for item in batch],  # type: ignore[arg-type]
        scene=[item.scene for item in batch],  # type: ignore[arg-type]
        camera_model="fisheye624",
    )
batch_size_hint: int = 1, + depth_max_m: float = DEFAULT_MAX_DEPTH_M, + far_depth_invalid_m: float = 30.0, + seed: int = 0, + ) -> None: + super().__init__() + self.root = Path(root) + self.min_frame_gap = int(min_frame_gap) + self.max_frame_gap = int(max_frame_gap) + self.pair_max_translation_m = float(pair_max_translation_m) + self.shuffle_scene = bool(shuffle_scene) + self.shuffle_frame = bool(shuffle_frame) + self.skip_bad = bool(skip_bad) + self.ddp_rank = int(ddp_rank) + self.ddp_world_size = int(ddp_world_size) + self.batch_size_hint = int(max(1, batch_size_hint)) + self.depth_max_m = float(depth_max_m) + self.far_depth_invalid_m = float(far_depth_invalid_m) + self.seed = int(seed) + self.epoch = 0 + self.scene_specs = self._load_scene_specs(scene_list_file) + if not self.scene_specs: + raise RuntimeError(f"No ScanNet++ fisheye scenes found under {self.root}") + + def set_epoch(self, epoch: int) -> None: + self.epoch = int(epoch) + + def _load_scene_specs(self, scene_list_file: Path | None) -> list[tuple[str, Path]]: + specs: list[tuple[str, Path]] = [] + if scene_list_file is not None and Path(scene_list_file).exists(): + for raw in Path(scene_list_file).read_text(encoding="utf-8").splitlines(): + line = raw.strip() + if not line: + continue + parts = line.split("|") + if len(parts) == 1: + scene_dir = Path(parts[0]) + scene_id = scene_dir.name + else: + scene_id = parts[0] + scene_dir = Path(parts[1]) + if not scene_dir.is_absolute(): + scene_dir = self.root / scene_dir + specs.append((scene_id, scene_dir)) + return specs + for transforms in sorted(self.root.glob("*/nerfstudio/transforms.json")): + specs.append((transforms.parent.parent.name, transforms.parent.parent)) + for transforms in sorted(self.root.glob("*/*/nerfstudio/transforms.json")): + specs.append((f"{transforms.parent.parent.parent.name}/{transforms.parent.parent.name}", transforms.parent.parent)) + return specs + + @staticmethod + def _load_rgb(path: Path) -> torch.Tensor: + with 
Image.open(path) as image: + arr = np.asarray(image.convert("RGB"), dtype=np.uint8).copy() + return torch.from_numpy(arr).permute(2, 0, 1).contiguous() + + @staticmethod + def _load_mask(path: Path, image_hw: tuple[int, int]) -> torch.Tensor | None: + if not path.exists(): + return None + with Image.open(path) as image: + arr = np.asarray(image.convert("L"), dtype=np.uint8).copy() + mask = torch.from_numpy(arr).unsqueeze(0).to(torch.float32) / 255.0 + if tuple(mask.shape[-2:]) != tuple(image_hw): + mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=image_hw, mode="nearest").squeeze(0) + return (mask > 0.5).to(torch.float32) + + def _load_depth_map(self, path: Path) -> tuple[torch.Tensor, str]: + depth_kind = "distance" + if path.suffix.lower() == ".npz": + payload = np.load(path, allow_pickle=False) + for key in ("distance_m", "depth_m", "distance", "depth"): + if key in payload: + arr = payload[key] + if key in {"distance_m", "distance"}: + depth_kind = "distance" + elif "depth_kind" in payload: + depth_kind = str(np.asarray(payload["depth_kind"]).item()).strip().lower() + break + else: + raise RuntimeError(f"Unsupported ScanNet++ depth payload keys at {path}") + else: + arr = np.load(path) + depth = torch.from_numpy(np.asarray(arr, dtype=np.float32).copy()) + if depth.ndim == 3 and depth.shape[0] == 1: + depth = depth[0] + if depth.ndim != 2: + raise RuntimeError(f"Expected 2D fisheye depth at {path}, got shape={tuple(depth.shape)}") + depth = depth.unsqueeze(0) + valid = torch.isfinite(depth) & (depth > 0.0) + if self.far_depth_invalid_m > 0.0: + valid = valid & (depth <= self.far_depth_invalid_m) + depth = torch.where(valid, depth, torch.zeros_like(depth)) + if depth_kind in {"radial", "radius", "dist"}: + depth_kind = "distance" + if depth_kind not in {"distance", "z"}: + raise RuntimeError(f"Unsupported fisheye depth_kind={depth_kind!r} at {path}") + return depth.clamp(min=0.0, max=self.depth_max_m), depth_kind + + @staticmethod + def 
def _resolve_depth_path(self, scene_dir: Path, image_name: str) -> Path | None:
    """Return the first existing depth file for ``image_name``, else ``None``.

    Candidates are tried in this order: base directory (``scene_dir`` then
    ``scene_dir / "dslr"``), each entry of ``DEPTH_DIR_NAMES``, file name
    (image stem, then full image file name), suffix (``.npz`` then ``.npy``).
    """
    image_rel = Path(image_name)
    candidates = (
        base / depth_dir_name / f"{name}{suffix}"
        for base in (scene_dir, scene_dir / "dslr")
        for depth_dir_name in DEPTH_DIR_NAMES
        for name in (image_rel.stem, image_rel.name)
        for suffix in (".npz", ".npy")
    )
    # The generator is lazy, so probing stops at the first hit.
    return next((candidate for candidate in candidates if candidate.exists()), None)
def _load_scene_frames(self, scene_id: str, scene_dir: Path) -> tuple[torch.Tensor, list[dict[str, object]]]:
    """Collect the usable frames of one scene together with its camera parameters.

    Reads ``nerfstudio/transforms.json`` (falling back to the copy under
    ``dslr/``), converts the scene-level intrinsics with
    ``_opencv_fisheye_to_fisheye624_params`` and gathers every entry of
    ``frames`` and ``test_frames`` that has an image, a depth file and a pose.

    Returns:
        ``(camera_params, frames)`` where ``frames`` is sorted by image name.
        Each frame dict holds ``image_name``, ``image_path``, ``depth_path``,
        ``mask_path`` (may be ``None``), ``w2c``, ``center``, ``idx``,
        ``scene`` and ``camera_hw``.

    Errors from reading or parsing the metadata are not caught here.
    """
    transforms_path = scene_dir / "nerfstudio" / "transforms.json"
    if not transforms_path.exists():
        transforms_path = scene_dir / "dslr" / "nerfstudio" / "transforms.json"
    meta = json.loads(transforms_path.read_text(encoding="utf-8"))
    camera_params = _opencv_fisheye_to_fisheye624_params(meta)
    camera_hw = _camera_hw_from_meta(meta)
    # NOTE(review): _read_colmap_w2c is assumed to return a mapping from image
    # file name to a 4x4 world-to-camera tensor, empty when nothing could be
    # read; confirm against its definition.
    w2c_by_name = _read_colmap_w2c(scene_dir / "colmap" / "images.txt")
    if not w2c_by_name:
        w2c_by_name = _read_colmap_w2c(scene_dir / "dslr" / "colmap" / "images.txt")

    raw_frames = list(meta.get("frames", [])) + list(meta.get("test_frames", []))
    frames: list[dict[str, object]] = []
    for frame in raw_frames:
        # Only the file name is kept; any directory part of file_path is dropped.
        image_name = Path(str(frame.get("file_path", ""))).name
        if not image_name:
            continue
        if self.skip_bad and bool(frame.get("is_bad", False)):
            continue
        image_path = self._resolve_image_path(scene_dir, image_name)
        depth_path = self._resolve_depth_path(scene_dir, image_name)
        if image_path is None or depth_path is None:
            continue
        # Prefer the pose from images.txt; otherwise invert the frame's own
        # transform_matrix, which is treated here as camera-to-world.
        w2c = w2c_by_name.get(image_name)
        if w2c is None and frame.get("transform_matrix") is not None:
            c2w = torch.tensor(frame["transform_matrix"], dtype=torch.float32)
            w2c = torch.linalg.inv(c2w)
        if w2c is None:
            continue
        # Camera centre in world coordinates = translation column of the inverse pose.
        center = torch.linalg.inv(w2c)[:3, 3]
        frames.append(
            {
                "image_name": image_name,
                "image_path": image_path,
                "depth_path": depth_path,
                "mask_path": self._resolve_mask_path(scene_dir, image_name, frame.get("mask_path")),
                "w2c": w2c.to(torch.float32),
                "center": center.to(torch.float32),
                # Position among accepted frames BEFORE the sort below, so it
                # does not generally equal the index in the returned list.
                "idx": len(frames),
                "scene": scene_id,
                # A per-frame size overrides the scene-level one when present.
                "camera_hw": _camera_hw_from_meta(frame) or camera_hw,
            }
        )
    return camera_params, sorted(frames, key=lambda x: str(x["image_name"]))
self._load_rgb(frame["image_path"]) # type: ignore[arg-type] + rgb_hw = (int(rgb.shape[-2]), int(rgb.shape[-1])) + camera_hw = frame.get("camera_hw", None) + params = camera_params.clone() + if isinstance(camera_hw, tuple): + params = _scale_fisheye624_params(params, src_hw=camera_hw, dst_hw=rgb_hw) + depth, depth_kind = self._load_depth_map(frame["depth_path"]) # type: ignore[arg-type] + if tuple(depth.shape[-2:]) != tuple(rgb.shape[-2:]): + depth = torch.nn.functional.interpolate( + depth.unsqueeze(0), + size=(int(rgb.shape[-2]), int(rgb.shape[-1])), + mode="nearest", + ).squeeze(0) + if depth_kind == "z": + depth = self._fisheye_z_depth_to_distance(depth, params) + valid = (torch.isfinite(depth) & (depth > 0.0)).to(torch.float32) + mask_path = frame.get("mask_path", None) + if isinstance(mask_path, Path): + mask = self._load_mask(mask_path, (int(rgb.shape[-2]), int(rgb.shape[-1]))) + if mask is not None: + valid = valid * mask + else: + valid = valid * (rgb.to(torch.float32).sum(dim=0, keepdim=True) > 1.0).to(torch.float32) + return { + "rgb_u8": rgb, + "depth_m": depth.clamp(min=0.0, max=self.depth_max_m), + "valid_mask": valid, + "camera_params": params, + } + + def _iter_scene_pairs(self, scene_id: str, scene_dir: Path, rng: random.Random): + try: + camera_params, frames = self._load_scene_frames(scene_id, scene_dir) + except Exception as exc: + LOGGER.debug("Skip ScanNet++ scene %s: %s", str(scene_id), str(exc)) + return + if len(frames) < 2: + return + loaded: dict[int, dict[str, torch.Tensor]] = {} + + def get_loaded(pos: int) -> dict[str, torch.Tensor]: + if pos not in loaded: + loaded[pos] = self._load_frame_tensor(frames[pos], camera_params) + return loaded[pos] + + order = list(range(len(frames))) + if self.shuffle_frame: + rng.shuffle(order) + for src_pos in order: + src_item = frames[src_pos] + src_center = src_item["center"] + assert torch.is_tensor(src_center) + candidates: list[int] = [] + for tgt_pos in range(max(0, src_pos - self.max_frame_gap), 
min(len(frames), src_pos + self.max_frame_gap + 1)): + if tgt_pos == src_pos: + continue + gap = abs(tgt_pos - src_pos) + if gap < self.min_frame_gap: + continue + tgt_center = frames[tgt_pos]["center"] + assert torch.is_tensor(tgt_center) + if float(torch.norm(tgt_center - src_center, p=2).item()) > self.pair_max_translation_m: + continue + candidates.append(tgt_pos) + if not candidates: + continue + tgt_pos = rng.choice(candidates) + try: + src_loaded = get_loaded(src_pos) + tgt_loaded = get_loaded(tgt_pos) + except Exception: + continue + yield ScannetppFisheyePairSample( + src_rgb_u8=src_loaded["rgb_u8"], + tgt_rgb_u8=tgt_loaded["rgb_u8"], + src_depth_m=src_loaded["depth_m"], + tgt_depth_m=tgt_loaded["depth_m"], + src_valid_mask=src_loaded["valid_mask"], + tgt_valid_mask=tgt_loaded["valid_mask"], + src_w2c=src_item["w2c"], # type: ignore[arg-type] + tgt_w2c=frames[tgt_pos]["w2c"], # type: ignore[arg-type] + src_camera_params=src_loaded["camera_params"], + tgt_camera_params=tgt_loaded["camera_params"], + src_idx=int(src_item["idx"]), + tgt_idx=int(frames[tgt_pos]["idx"]), + scene=str(scene_id), + ) + + def __iter__(self): + worker = torch.utils.data.get_worker_info() + worker_id = 0 if worker is None else int(worker.id) + num_workers = 1 if worker is None else int(worker.num_workers) + rng = random.Random(self.seed + 1009 * self.epoch + 97 * self.ddp_rank + 17 * worker_id) + specs = list(self.scene_specs) + if self.shuffle_scene: + rng.shuffle(specs) + specs = specs[self.ddp_rank :: max(self.ddp_world_size, 1)] + specs = specs[worker_id :: num_workers] + pending: dict[tuple[int, int], list[ScannetppFisheyePairSample]] = {} + for scene_id, scene_dir in specs: + for sample in self._iter_scene_pairs(scene_id, scene_dir, rng): + hw = (int(sample.src_rgb_u8.shape[-2]), int(sample.src_rgb_u8.shape[-1])) + bucket = pending.setdefault(hw, []) + bucket.append(sample) + while len(bucket) >= self.batch_size_hint: + packed = bucket[: self.batch_size_hint] + del bucket[: 
self.batch_size_hint] + yield _stack_batch(packed) diff --git a/unisharp/datasets/sim_panorama.py b/unisharp/datasets/sim_panorama.py new file mode 100644 index 0000000000000000000000000000000000000000..22f032fb48e3622e6f682e6895f160cec823e17f --- /dev/null +++ b/unisharp/datasets/sim_panorama.py @@ -0,0 +1,497 @@ + +from __future__ import annotations + +import csv +from dataclasses import dataclass +import os +from pathlib import Path +import random +import re + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from torch.utils.data import IterableDataset + +from unisharp.datasets.panogs import PanOGSSample +from unisharp import DEFAULT_MAX_DEPTH_M + +try: + import h5py +except ImportError: + h5py = None + + +_NUM_RE = re.compile(r"(\d+)(?!.*\d)") +_SIM_CACHE_VERSION = 6 + + +def _default_dataset_manifest_dir() -> Path: + repo_root = Path(__file__).resolve().parents[2] + parent_path = repo_root.parent / "dataset_manifests" + if parent_path.exists(): + return parent_path + return repo_root / "dataset_manifests" + + +def _frame_index_from_name(name: str) -> int | None: + match = _NUM_RE.search(Path(name).stem) + if match is None: + return None + return int(match.group(1)) + + +def _sim_csv_xyz_to_training_position(x: float, y: float, z: float) -> torch.Tensor: + return torch.tensor([float(y), -float(z), float(x)], dtype=torch.float32) + + +class _EquirecToCube: + def __init__(self, equ_h: int, equ_w: int, face_w: int) -> None: + self.equ_h = int(equ_h) + self.equ_w = int(equ_w) + self.face_w = int(face_w) + self.grid = self._build_grid() + rng = torch.linspace(-0.5, 0.5, steps=self.face_w, dtype=torch.float32) + xx, yy = torch.meshgrid(rng, -rng, indexing="xy") + self.ray_z = (1.0 / torch.sqrt((2.0 * xx) ** 2 + (2.0 * yy) ** 2 + 1.0)).contiguous() + + def _build_grid(self) -> torch.Tensor: + face_w = self.face_w + rng = torch.linspace(-0.5, 0.5, steps=face_w, dtype=torch.float32) + grid = torch.stack(torch.meshgrid(rng, -rng, 
def run(self, rgb_chw: torch.Tensor, depth_1hw: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Project an equirectangular RGB image and its depth onto six cube faces.

    The RGB input is scaled to ``[0, 1]``, resized to
    ``(self.equ_h, self.equ_w)`` when its size differs, sampled bilinearly
    with ``self.grid`` and quantised back to ``uint8``. Depth is delegated to
    ``self.run_depth``.

    Returns:
        ``(cube_rgb, cube_depth)``: ``cube_rgb`` is a CPU ``uint8`` tensor laid
        out faces x height x width x channels; ``cube_depth`` is whatever
        ``self.run_depth(depth_1hw)`` returns.
    """
    target_hw = (self.equ_h, self.equ_w)
    equirect = rgb_chw.unsqueeze(0).to(torch.float32) / 255.0
    if tuple(equirect.shape[-2:]) != target_hw:
        equirect = F.interpolate(equirect, size=target_hw, mode="bilinear", align_corners=True)
    # One copy of the panorama per face; the grid picks each face's pixels.
    faces_nchw = F.grid_sample(
        equirect.expand(6, -1, -1, -1),
        self.grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
    faces_nhwc = faces_nchw.permute(0, 2, 3, 1).clamp(0.0, 1.0)
    cube_rgb = (faces_nhwc * 255.0).round().to(torch.uint8)
    cube_depth = self.run_depth(depth_1hw)
    return cube_rgb.cpu(), cube_depth
manifest_scene_names = [ + line.strip() + for line in self.scene_list_file.read_text(encoding="utf-8").splitlines() + if line.strip() + ] + if requested_scene_names: + requested = set(requested_scene_names) + self.scene_names = [name for name in manifest_scene_names if name in requested] + else: + self.scene_names = manifest_scene_names + else: + self.scene_names = requested_scene_names + if not self.scene_names: + raise ValueError("SimPanoramaDataset requires scene_names or scene_list_file.") + self.position_scale = float(position_scale) + self.max_index_gap = int(max_index_gap) + self.pair_max_translation_m = float(pair_max_translation_m) + self.pair_min_depth_overlap = float(pair_min_depth_overlap) + self.pair_overlap_margin = float(pair_overlap_margin) + self.pairs_per_chunk = int(pairs_per_chunk) + self.chunk_size = int(chunk_size) + self.shuffle_scene = bool(shuffle_scene) + self.ddp_rank = int(ddp_rank) + self.ddp_world_size = int(ddp_world_size) + self.seed = int(seed) + self.depth_max_m = float(depth_max_m) + self.far_depth_invalid_m = float(far_depth_invalid_m) + self.far_depth_invalid_max_frac = float(far_depth_invalid_max_frac) + self.max_long_edge = max(int(max_long_edge), 0) + self.epoch = 0 + self.cache_dir = _default_dataset_manifest_dir() / "sim_cache" + self.cache_dir.mkdir(parents=True, exist_ok=True) + self._scene_frames_cache: dict[str, list[_SimFrame]] = {} + + def set_epoch(self, epoch: int) -> None: + self.epoch = int(epoch) + + @staticmethod + def _is_depth_path(path: Path) -> bool: + tokens = [part.lower() for part in path.parts] + name = path.name.lower() + return ("depth" in name) or any("depth" in token for token in tokens) + + @staticmethod + def _is_image_path(path: Path) -> bool: + return path.suffix.lower() in (".png", ".jpg", ".jpeg", ".webp") + + @staticmethod + def _load_rgb(path: Path) -> torch.Tensor: + with Image.open(path) as img: + img = img.convert("RGB") + arr = np.asarray(img, dtype=np.uint8).copy() + return 
torch.from_numpy(arr).permute(2, 0, 1).contiguous() + + @staticmethod + def _image_hw(path: Path) -> tuple[int, int]: + with Image.open(path) as img: + width, height = img.size + return int(height), int(width) + + def _load_depth(self, path: Path) -> torch.Tensor: + suffix = path.suffix.lower() + if suffix == ".npy": + dep = np.load(path) + elif suffix == ".npz": + payload = np.load(path) + key = "depth" if "depth" in payload.files else payload.files[0] + dep = payload[key] + elif suffix in (".h5", ".hdf5"): + if h5py is None: + raise ImportError("h5py is required to read sim .h5 depth files but is not installed.") + with h5py.File(path, "r") as f: + keys = list(f.keys()) + if not keys: + raise RuntimeError(f"Empty sim depth file: {path}") + dep = f[keys[0]][()] + else: + with Image.open(path) as img: + dep = np.asarray(img) + dep = dep.astype(np.float32) + if dep.ndim == 3: + dep = dep[..., 0] + dep[~np.isfinite(dep)] = 0.0 + if self.far_depth_invalid_m > 0.0: + far = dep > self.far_depth_invalid_m + if 0.0 < float(far.mean()) <= self.far_depth_invalid_max_frac: + dep[far] = 0.0 + dep = np.clip(dep, a_min=0.0, a_max=self.depth_max_m) + return torch.from_numpy(dep).unsqueeze(0) + + def _resize_erp_if_needed(self, rgb: torch.Tensor, depth: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if self.max_long_edge <= 0: + return rgb, depth + h = int(rgb.shape[-2]) + w = int(rgb.shape[-1]) + long_edge = max(h, w) + if long_edge <= self.max_long_edge: + return rgb, depth + scale = float(self.max_long_edge) / float(long_edge) + new_h = max(2, int(round(float(h) * scale))) + new_w = max(2, int(round(float(w) * scale))) + rgb_f = rgb.unsqueeze(0).to(dtype=torch.float32) + rgb_resized = F.interpolate(rgb_f, size=(new_h, new_w), mode="bilinear", align_corners=False) + rgb_out = rgb_resized[0].round().clamp(0.0, 255.0).to(dtype=torch.uint8).contiguous() + depth_f = depth.unsqueeze(0).to(dtype=torch.float32) + depth_out = F.interpolate(depth_f, size=(new_h, new_w), 
mode="nearest")[0].contiguous() + return rgb_out, depth_out + + def _pose_csv_for_scene(self, scene_name: str) -> Path: + direct = self.pose_root / f"{scene_name}.csv" + if direct.exists(): + return direct + matches = sorted(self.pose_root.glob(f"*{scene_name}*.csv")) + if matches: + return matches[0] + raise FileNotFoundError(f"No pose csv found for sim scene={scene_name} under {self.pose_root}") + + def _parse_pose_csv(self, csv_path: Path) -> list[tuple[int, torch.Tensor]]: + with csv_path.open("r", encoding="utf-8") as f: + rows = list(csv.DictReader(f)) + if not rows: + raise RuntimeError(f"Empty sim pose csv: {csv_path}") + poses: list[tuple[int, torch.Tensor]] = [] + for row_idx, row in enumerate(rows): + lower = {str(k).strip().lower(): v for k, v in row.items()} + frame_val = None + for key in ("frame", "frame_idx", "idx", "index", "id", "image", "filename", "name"): + if key in lower and str(lower[key]).strip(): + frame_val = _frame_index_from_name(str(lower[key])) + if frame_val is None: + try: + frame_val = int(float(str(lower[key]).strip())) + except Exception: + frame_val = None + break + x = next((lower[k] for k in lower if k in ("x", "tx", "pos_x", "world_x")), None) + y = next((lower[k] for k in lower if k in ("y", "ty", "pos_y", "world_y")), None) + z = next((lower[k] for k in lower if k in ("z", "tz", "pos_z", "world_z")), None) + if x is None or y is None or z is None: + numeric_vals = [] + for val in row.values(): + try: + numeric_vals.append(float(str(val).strip())) + except Exception: + continue + if len(numeric_vals) < 3: + raise ValueError(f"Failed to parse xyz from sim csv row: {row}") + x, y, z = numeric_vals[:3] + pos = _sim_csv_xyz_to_training_position(float(x), float(y), float(z)) * self.position_scale + poses.append((int(frame_val if frame_val is not None else row_idx), pos)) + return poses + + def _scan_scene_frames(self, scene_name: str) -> list[_SimFrame]: + scene_dir = self.root / scene_name + if not scene_dir.exists(): + raise 
FileNotFoundError(scene_dir) + all_files = [p for p in scene_dir.rglob("*") if p.is_file()] + image_map: dict[int, Path] = {} + depth_map: dict[int, Path] = {} + for path in all_files: + idx = _frame_index_from_name(path.name) + if idx is None: + continue + if self._is_depth_path(path) and path.suffix.lower() in (".png", ".npy", ".npz", ".exr", ".h5", ".hdf5"): + depth_map.setdefault(idx, path) + elif self._is_image_path(path): + image_map.setdefault(idx, path) + pose_entries = self._parse_pose_csv(self._pose_csv_for_scene(scene_name)) + frames: list[_SimFrame] = [] + for frame_idx, pos in pose_entries: + rgb_path = image_map.get(int(frame_idx)) + depth_path = depth_map.get(int(frame_idx)) + if rgb_path is None or depth_path is None: + continue + frames.append(_SimFrame(frame_idx=int(frame_idx), rgb_path=rgb_path, depth_path=depth_path, position_xyz=pos)) + return frames + + @staticmethod + def _atomic_torch_save(path: Path, payload: object) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + f".tmp.{os.getpid()}") + torch.save(payload, tmp_path) + os.replace(tmp_path, path) + + def _scene_index_cache_path(self, scene_name: str) -> Path: + scene_key = scene_name.replace("/", "__") + return self.cache_dir / f"{scene_key}_ps{self.position_scale:g}_frames_v{_SIM_CACHE_VERSION}.pt" + + def _load_or_build_scene_frames(self, scene_name: str) -> list[_SimFrame]: + cached = self._scene_frames_cache.get(scene_name) + if cached is not None: + return cached + cache_path = self._scene_index_cache_path(scene_name) + frames: list[_SimFrame] + if cache_path.exists(): + try: + payload = torch.load(cache_path, map_location="cpu") + frames = [ + _SimFrame( + frame_idx=int(item["frame_idx"]), + rgb_path=Path(str(item["rgb_path"])), + depth_path=Path(str(item["depth_path"])), + position_xyz=torch.tensor(item["position_xyz"], dtype=torch.float32), + ) + for item in payload["frames"] + ] + except Exception: + frames = 
self._scan_scene_frames(scene_name) + payload = { + "scene": scene_name, + "frames": [ + { + "frame_idx": int(frame.frame_idx), + "rgb_path": str(frame.rgb_path), + "depth_path": str(frame.depth_path), + "position_xyz": frame.position_xyz.tolist(), + } + for frame in frames + ], + } + self._atomic_torch_save(cache_path, payload) + else: + frames = self._scan_scene_frames(scene_name) + payload = { + "scene": scene_name, + "frames": [ + { + "frame_idx": int(frame.frame_idx), + "rgb_path": str(frame.rgb_path), + "depth_path": str(frame.depth_path), + "position_xyz": frame.position_xyz.tolist(), + } + for frame in frames + ], + } + self._atomic_torch_save(cache_path, payload) + self._scene_frames_cache[scene_name] = frames + return frames + + def _random_chunk_pairs(self, chunk: list[_SimFrame], rng: random.Random) -> list[tuple[int, int]]: + if len(chunk) < self.chunk_size: + return [] + indices = list(range(len(chunk))) + rng.shuffle(indices) + max_pairs = min(self.pairs_per_chunk, len(indices) // 2) + return [(indices[2 * i], indices[2 * i + 1]) for i in range(max_pairs)] + + def __iter__(self): + scene_names = list(self.scene_names) + order_rng = random.Random(self.seed + self.epoch) + if self.shuffle_scene: + order_rng.shuffle(scene_names) + worker_info = torch.utils.data.get_worker_info() + num_workers = worker_info.num_workers if worker_info is not None else 1 + worker_id = worker_info.id if worker_info is not None else 0 + total_shards = max(1, self.ddp_world_size * num_workers) + shard_id = self.ddp_rank * num_workers + worker_id + pair_unit_index = 0 + + for scene_order_idx, scene_name in enumerate(scene_names): + try: + frames = self._load_or_build_scene_frames(scene_name) + except Exception: + continue + if len(frames) < self.chunk_size: + continue + for start in range(0, len(frames), self.chunk_size): + chunk = frames[start : start + self.chunk_size] + if len(chunk) < self.chunk_size: + break + try: + equ_h, equ_w = self._image_hw(chunk[0].rgb_path) + if 
self.max_long_edge > 0 and max(equ_h, equ_w) > self.max_long_edge: + scale = float(self.max_long_edge) / float(max(equ_h, equ_w)) + equ_h = max(2, int(round(float(equ_h) * scale))) + equ_w = max(2, int(round(float(equ_w) * scale))) + face_w = max(1, equ_h // 2) + converter = _EquirecToCube(equ_h=equ_h, equ_w=equ_w, face_w=face_w) + except Exception: + continue + def load_frame(local_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + frame = chunk[local_idx] + rgb = self._load_rgb(frame.rgb_path) + depth = self._load_depth(frame.depth_path) + rgb, depth = self._resize_erp_if_needed(rgb, depth) + cube_rgb, cube_depth = converter.run(rgb, depth) + return rgb, depth, cube_rgb, cube_depth + + chunk_rng = random.Random( + self.seed + self.epoch * 1000003 + scene_order_idx * 1009 + start + ) + pairs = self._random_chunk_pairs(chunk, chunk_rng) + for src_local, tgt_local in pairs: + if pair_unit_index % total_shards != shard_id: + pair_unit_index += 1 + continue + pair_unit_index += 1 + src_rgb, src_depth, src_cube_rgb, src_cube_depth = load_frame(src_local) + tgt_rgb, tgt_depth, tgt_cube_rgb, tgt_cube_depth = load_frame(tgt_local) + yield PanOGSSample( + src_erp_rgb_u8=src_rgb, + tgt_erp_rgb_u8=tgt_rgb, + src_erp_depth_m=src_depth, + tgt_erp_depth_m=tgt_depth, + src_cube_rgb_u8=src_cube_rgb, + tgt_cube_rgb_u8=tgt_cube_rgb, + src_cube_depth_m=src_cube_depth, + tgt_cube_depth_m=tgt_cube_depth, + src_R=torch.eye(3, dtype=torch.float32), + src_t=chunk[src_local].position_xyz.clone(), + tgt_R=torch.eye(3, dtype=torch.float32), + tgt_t=chunk[tgt_local].position_xyz.clone(), + src_idx=int(chunk[src_local].frame_idx), + tgt_idx=int(chunk[tgt_local].frame_idx), + scene=str(scene_name), + ) diff --git a/unisharp/datasets/wildrgbd.py b/unisharp/datasets/wildrgbd.py new file mode 100644 index 0000000000000000000000000000000000000000..24f0be83d58c82f421fb114bfb5f2d1a90760432 --- /dev/null +++ b/unisharp/datasets/wildrgbd.py @@ -0,0 +1,352 @@ +from __future__ 
@dataclass(frozen=True)
class WildRGBDPairSample:
    """Immutable source/target frame pair produced by ``WildRGBDDataset``.

    The per-sample layouts noted on the fields are what
    ``WildRGBDDataset.__iter__`` builds. ``wildrgbd_collate`` reuses this class
    for whole batches: tensor fields are then stacked along a new leading
    dimension and ``src_idx`` / ``tgt_idx`` / ``scene`` hold plain lists.
    """

    # RGB frames: uint8, channel-first (3, H, W) as built by _load_rgb_u8.
    # NOTE(review): frames may have gone through resize_rgb_u8_chw_high_quality
    # (defined elsewhere) - presumably layout and dtype are kept; confirm.
    src_rgb_u8: torch.Tensor
    tgt_rgb_u8: torch.Tensor
    # Depth maps: float32 (1, H, W), non-finite values zeroed and the result
    # clipped to [0, depth_max_m] by _load_depth_m.
    src_depth_m: torch.Tensor
    tgt_depth_m: torch.Tensor
    # 4x4 float32 matrices obtained by inverting the pose rows of cam_poses.txt.
    src_w2c: torch.Tensor
    tgt_w2c: torch.Tensor
    # 3x3 float32 intrinsics; rescaled together with the images when
    # output_h / output_w force a resize.
    src_intrinsics: torch.Tensor
    tgt_intrinsics: torch.Tensor
    # Frame ids of the two frames inside the scene.
    src_idx: int
    tgt_idx: int
    # "<parent>/<scene folder>" label built by WildRGBDDataset._scene_name.
    scene: str
self.shuffle_scene = bool(shuffle_scene) + self.shuffle_frame = bool(shuffle_frame) + self.ddp_rank = int(ddp_rank) + self.ddp_world_size = int(ddp_world_size) + self.depth_max_m = float(depth_max_m) + self.seed = int(seed) + self.epoch = 0 + self.verify_manifest_paths = bool(verify_manifest_paths) + self.roots = [Path(p) for p in roots] if roots is not None else ([Path(root)] if root is not None else []) + if not self.roots: + raise ValueError("WildRGBDDataset requires at least one root path.") + self.scene_dirs: list[Path] = [] + self.scene_list_file = Path(scene_list_file) if scene_list_file is not None else None + if self.scene_list_file is not None: + if not self.scene_list_file.exists(): + raise FileNotFoundError(self.scene_list_file) + for raw in self.scene_list_file.read_text(encoding="utf-8").splitlines(): + line = raw.strip() + if not line: + continue + scene_dir = Path(line) + if (not self.verify_manifest_paths) or scene_dir.is_dir(): + self.scene_dirs.append(scene_dir) + else: + for ds_root in self.roots: + split_dir = ds_root / self.split + if not split_dir.exists(): + raise FileNotFoundError(split_dir) + self.scene_dirs.extend(sorted([p for p in split_dir.iterdir() if p.is_dir()])) + if not self.scene_dirs: + raise RuntimeError("No scene folders found in the configured WildRGBD roots.") + + def set_epoch(self, epoch: int) -> None: + self.epoch = int(epoch) + + @staticmethod + def _load_scene_pose_and_k(scene_dir: Path) -> tuple[np.ndarray, dict[int, np.ndarray], torch.Tensor]: + metadata_path = scene_dir / "metadata" + with metadata_path.open("r", encoding="utf-8") as f: + meta = json.load(f) + k_raw = np.asarray(meta["K"], dtype=np.float32).reshape(3, 3).T + k = torch.from_numpy(k_raw.copy()).to(torch.float32) + + pose_path = scene_dir / "cam_poses.txt" + pose_rows = np.genfromtxt(str(pose_path), dtype=np.float32) + if pose_rows.ndim == 1: + pose_rows = pose_rows[None, :] + if pose_rows.shape[1] < 17: + raise ValueError(f"Bad cam_poses.txt 
shape={pose_rows.shape} at {pose_path}") + frame_ids = pose_rows[:, 0].astype(np.int64) + c2w = pose_rows[:, 1:17].reshape(-1, 4, 4).astype(np.float32) + w2c = np.linalg.inv(c2w).astype(np.float32) + w2c_map = {int(fid): w2c[i] for i, fid in enumerate(frame_ids.tolist())} + return frame_ids, w2c_map, k + + @staticmethod + def _collect_frame_ids(folder: Path) -> set[int]: + ids: set[int] = set() + if not folder.exists(): + return ids + for p in folder.iterdir(): + if not p.is_file(): + continue + if p.suffix.lower() not in (".png", ".jpg", ".jpeg"): + continue + try: + ids.add(int(p.stem)) + except ValueError: + continue + return ids + + @staticmethod + def _resolve_img_path(folder: Path, idx: int) -> Path: + for ext in (".png", ".jpg", ".jpeg"): + p = folder / f"{idx:05d}{ext}" + if p.exists(): + return p + raise FileNotFoundError(folder / f"{idx:05d}.png") + + @staticmethod + def _load_rgb_u8(path: Path) -> torch.Tensor: + img = Image.open(path).convert("RGB") + arr = np.asarray(img, dtype=np.uint8).copy() + return torch.from_numpy(arr).permute(2, 0, 1).contiguous() + + def _load_depth_m(self, depth_path: Path) -> torch.Tensor: + dep = np.asarray(Image.open(depth_path)) + if dep.ndim != 2: + raise ValueError(f"Expected single-channel depth at {depth_path}, got {dep.shape}") + depth = dep.astype(np.float32) + if float(np.nanmax(depth)) > 200.0: + depth = depth / 1000.0 + depth[~np.isfinite(depth)] = 0.0 + depth = np.clip(depth, a_min=0.0, a_max=self.depth_max_m) + return torch.from_numpy(depth).unsqueeze(0).to(torch.float32) + + @staticmethod + def _scene_name(scene_dir: Path) -> str: + parent = scene_dir.parent.parent.name if scene_dir.parent.name == "scenes" else scene_dir.parent.name + return f"{parent}/{scene_dir.name}" + + def _sample_target_for_src( + self, + src_idx: int, + valid_ids: list[int], + w2c_map: dict[int, np.ndarray], + intr: torch.Tensor, + h: int, + w: int, + rng: random.Random, + ) -> int | None: + src_w2c = 
torch.from_numpy(w2c_map[int(src_idx)]).to(torch.float32) + src_center = torch.linalg.inv(src_w2c)[:3, 3] + candidates: list[int] = [] + for j in valid_ids: + if int(j) == int(src_idx): + continue + gap = abs(int(j) - int(src_idx)) + if gap < self.min_frame_gap or gap > self.max_frame_gap: + continue + jw2c = torch.from_numpy(w2c_map[int(j)]).to(torch.float32) + jcenter = torch.linalg.inv(jw2c)[:3, 3] + trans = torch.norm(jcenter - src_center, p=2).item() + if trans > self.pair_max_translation_m: + continue + candidates.append(int(j)) + if not candidates: + return None + + rng.shuffle(candidates) + tries = min(self.pair_max_tries, len(candidates)) + src_k = intr.to(torch.float32) + src_w2c_t = src_w2c.to(torch.float32) + for j in candidates[:tries]: + tgt_w2c_t = torch.from_numpy(w2c_map[int(j)]).to(torch.float32) + ov_st = project_overlap_ratio( + src_w2c=src_w2c_t, + tgt_w2c=tgt_w2c_t, + src_k=src_k, + tgt_k=src_k, + h=h, + w=w, + sample_h=self.pair_overlap_sample_h, + sample_w=self.pair_overlap_sample_w, + ) + ov_ts = project_overlap_ratio( + src_w2c=tgt_w2c_t, + tgt_w2c=src_w2c_t, + src_k=src_k, + tgt_k=src_k, + h=h, + w=w, + sample_h=self.pair_overlap_sample_h, + sample_w=self.pair_overlap_sample_w, + ) + if 0.5 * (ov_st + ov_ts) >= self.pair_min_overlap: + return int(j) + return None + + def __iter__(self): + scenes = list(self.scene_dirs) + order_rng = random.Random(self.seed + self.epoch) + if self.shuffle_scene: + order_rng.shuffle(scenes) + + worker_info = torch.utils.data.get_worker_info() + num_workers = worker_info.num_workers if worker_info is not None else 1 + worker_id = worker_info.id if worker_info is not None else 0 + total_shards = max(1, self.ddp_world_size * num_workers) + shard_id = self.ddp_rank * num_workers + worker_id + src_unit_index = 0 + + for scene_order_idx, scene_dir in enumerate(scenes): + try: + pose_ids_np, w2c_map, intr = self._load_scene_pose_and_k(scene_dir) + except Exception: + continue + pose_ids = {int(x) for x in 
pose_ids_np.tolist()} + rgb_ids = self._collect_frame_ids(scene_dir / "rgb") + dep_ids = self._collect_frame_ids(scene_dir / "depth") + valid_ids = sorted(list(pose_ids & rgb_ids & dep_ids)) + if len(valid_ids) < 2: + continue + + src_order = list(valid_ids) + scene_rng = random.Random(self.seed + self.epoch * 1000003 + scene_order_idx) + if self.shuffle_frame: + scene_rng.shuffle(src_order) + + for src_idx in src_order: + if src_unit_index % total_shards != shard_id: + src_unit_index += 1 + continue + src_unit_index += 1 + try: + rgb_src_path = self._resolve_img_path(scene_dir / "rgb", int(src_idx)) + dep_src_path = self._resolve_img_path(scene_dir / "depth", int(src_idx)) + src_img = self._load_rgb_u8(rgb_src_path) + src_depth = self._load_depth_m(dep_src_path) + except Exception: + continue + + h, w = int(src_img.shape[1]), int(src_img.shape[2]) + tgt_idx = self._sample_target_for_src( + src_idx=int(src_idx), + valid_ids=valid_ids, + w2c_map=w2c_map, + intr=intr, + h=h, + w=w, + rng=scene_rng, + ) + if tgt_idx is None: + continue + + try: + rgb_tgt_path = self._resolve_img_path(scene_dir / "rgb", int(tgt_idx)) + dep_tgt_path = self._resolve_img_path(scene_dir / "depth", int(tgt_idx)) + tgt_img = self._load_rgb_u8(rgb_tgt_path) + tgt_depth = self._load_depth_m(dep_tgt_path) + except Exception: + continue + + if src_img.shape != tgt_img.shape: + continue + + src_intr = intr.clone() + tgt_intr = intr.clone() + if self.output_h is not None and self.output_w is not None: + oh, ow = int(src_img.shape[1]), int(src_img.shape[2]) + if oh > 0 and ow > 0 and (oh != self.output_h or ow != self.output_w): + sx = float(self.output_w) / float(ow) + sy = float(self.output_h) / float(oh) + src_img = resize_rgb_u8_chw_high_quality(src_img, size=(self.output_h, self.output_w)) + tgt_img = resize_rgb_u8_chw_high_quality(tgt_img, size=(self.output_h, self.output_w)) + src_depth = F.interpolate( + src_depth.unsqueeze(0), + size=(self.output_h, self.output_w), + mode="nearest", + 
def wildrgbd_collate(batch: list[WildRGBDPairSample]) -> WildRGBDPairSample:
    """Merge per-sample pairs into one batched ``WildRGBDPairSample``.

    Image, depth, pose and intrinsics fields are stacked along a new leading
    dimension when the first sample holds a tensor (otherwise the list of
    values is kept). ``src_idx``, ``tgt_idx`` and ``scene`` are always
    returned as plain lists in batch order.
    """
    stacked_names = (
        "src_rgb_u8",
        "tgt_rgb_u8",
        "src_depth_m",
        "tgt_depth_m",
        "src_w2c",
        "tgt_w2c",
        "src_intrinsics",
        "tgt_intrinsics",
    )
    listed_names = ("src_idx", "tgt_idx", "scene")

    merged = {}
    for field_name in stacked_names:
        values = [getattr(sample, field_name) for sample in batch]
        if isinstance(values[0], torch.Tensor):
            merged[field_name] = torch.stack(values, dim=0)
        else:
            merged[field_name] = values
    for field_name in listed_names:
        merged[field_name] = [getattr(sample, field_name) for sample in batch]  # type: ignore[arg-type]
    return WildRGBDPairSample(**merged)
100644 index 0000000000000000000000000000000000000000..5e6c82dca21dd78bacbf010ff890f2e75e70f8ff Binary files /dev/null and b/unisharp/losses/__pycache__/__init__.cpython-310.pyc differ diff --git a/unisharp/losses/__pycache__/__init__.cpython-313.pyc b/unisharp/losses/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d985174b75d92d7b08042b6938de474ad7e1746 Binary files /dev/null and b/unisharp/losses/__pycache__/__init__.cpython-313.pyc differ diff --git a/unisharp/losses/__pycache__/unisharp_loss.cpython-310.pyc b/unisharp/losses/__pycache__/unisharp_loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb4886f777fdc143805ab99592944209b59f1a03 Binary files /dev/null and b/unisharp/losses/__pycache__/unisharp_loss.cpython-310.pyc differ diff --git a/unisharp/losses/__pycache__/unisharp_loss.cpython-313.pyc b/unisharp/losses/__pycache__/unisharp_loss.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dc89be3ad29b2830c13c1027fe2cbf3d1a3b97f Binary files /dev/null and b/unisharp/losses/__pycache__/unisharp_loss.cpython-313.pyc differ diff --git a/unisharp/losses/unisharp_loss.py b/unisharp/losses/unisharp_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..d99e5e27d231b06c2db3b86497ac9825894df591 --- /dev/null +++ b/unisharp/losses/unisharp_loss.py @@ -0,0 +1,1120 @@ +from __future__ import annotations + +import torch +from torch import nn +from dataclasses import dataclass +import math +import torch.nn.functional as F + +from unisharp import DEFAULT_MAX_DEPTH_M +from unisharp.utils import linalg + + +def _masked_mean(x: torch.Tensor, m: torch.Tensor) -> torch.Tensor: + if m.dtype != x.dtype: + m = m.to(dtype=x.dtype) + while m.ndim < x.ndim: + m = m.unsqueeze(1) + m_expanded = m.expand_as(x) + return (x * m_expanded).sum() / m_expanded.sum().clamp(min=1.0) + + +def _finite_masked_mean_flat(x: torch.Tensor, valid: 
torch.Tensor) -> torch.Tensor: + mask = valid.to(device=x.device, dtype=torch.bool) & torch.isfinite(x) + x_safe = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) + safe = torch.where(mask, x_safe, torch.zeros_like(x_safe)) + return safe.sum() / mask.to(dtype=x.dtype).sum().clamp(min=1.0) + + +def _finite_abs_mean(x: torch.Tensor) -> torch.Tensor: + mask = torch.isfinite(x) + x_safe = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0) + safe_abs = torch.where(mask, x_safe.abs(), torch.zeros_like(x_safe)) + return safe_abs.sum() / mask.to(dtype=x.dtype).sum().clamp(min=1.0) + + +_ERP_PROJECTION_MODELS = {"erp", "spherical", "equirect", "equirectangular"} +_FISHEYE_PROJECTION_MODELS = {"fisheye624", "opencv_fisheye"} + + +def _tv_l1(img: torch.Tensor) -> torch.Tensor: + zero = torch.zeros((), device=img.device, dtype=img.dtype) + dx = (img[..., :, 1:] - img[..., :, :-1]).abs().mean() if int(img.shape[-1]) > 1 else zero + dy = (img[..., 1:, :] - img[..., :-1, :]).abs().mean() if int(img.shape[-2]) > 1 else zero + return dx + dy + + +def _tv_l1_circular_h(img: torch.Tensor) -> torch.Tensor: + zero = torch.zeros((), device=img.device, dtype=img.dtype) + dx = (torch.roll(img, shifts=-1, dims=-1) - img).abs().mean() if int(img.shape[-1]) > 1 else zero + dy = (img[..., 1:, :] - img[..., :-1, :]).abs().mean() if int(img.shape[-2]) > 1 else zero + return dx + dy + + +def _checkerboard_l1_5d(x: torch.Tensor, *, circular_h: bool) -> torch.Tensor: + if x.ndim != 5: + raise ValueError(f"Expected [B,C,L,H,W], got {tuple(x.shape)}") + if int(x.shape[-2]) < 2 or int(x.shape[-1]) < 2: + return torch.zeros((), device=x.device, dtype=x.dtype) + x = x.to(dtype=torch.float32) + if bool(circular_h): + top = x[..., :-1, :] + bottom = x[..., 1:, :] + response = top - torch.roll(top, shifts=-1, dims=-1) - bottom + torch.roll(bottom, shifts=-1, dims=-1) + else: + response = x[..., :-1, :-1] - x[..., :-1, 1:] - x[..., 1:, :-1] + x[..., 1:, 1:] + return _finite_abs_mean(response) + + 
+def _delta_grid_checkerboard_loss(delta_grid: torch.Tensor, *, circular_h: bool) -> torch.Tensor: + if delta_grid.ndim != 5 or int(delta_grid.shape[1]) < 14: + raise ValueError(f"Expected delta grid [B,14,L,H,W], got {tuple(delta_grid.shape)}") + delta = delta_grid.to(dtype=torch.float32) + parts = [ + delta[:, 3:6], + 0.1 * delta[:, 10:13], + delta[:, 13:14], + ] + return torch.stack([_checkerboard_l1_5d(part, circular_h=circular_h) for part in parts]).mean() + + +def _avg_pool2d_circular_h(x: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor: + if kernel_size <= 1 and stride <= 1: + return x + x = F.pad(x, (kernel_size - 1, 0, 0, 0), mode="circular") + return F.avg_pool2d(x, kernel_size=kernel_size, stride=stride) + + +def _resize_max_side(img: torch.Tensor, max_side: int, *, mode: str = "bilinear") -> torch.Tensor: + if max_side <= 0: + return img + h, w = int(img.shape[-2]), int(img.shape[-1]) + ms = max(h, w) + if ms <= max_side: + return img + scale = float(max_side) / float(ms) + nh = max(1, int(math.floor(h * scale))) + nw = max(1, int(math.floor(w * scale))) + if mode in ("bilinear", "bicubic"): + return F.interpolate(img, size=(nh, nw), mode=mode, align_corners=False) + return F.interpolate(img, size=(nh, nw), mode=mode) + + +def _gram_matrix(fmap: torch.Tensor) -> torch.Tensor: + b, c, h, w = fmap.shape + x = fmap.reshape(b, c, h * w) + return x @ x.transpose(1, 2) + + +class _ResNet50Perceptual(nn.Module): + + def __init__(self) -> None: + super().__init__() + try: + from torchvision.models import resnet50, ResNet50_Weights + + net = resnet50(weights=ResNet50_Weights.DEFAULT) + except Exception: + from torchvision.models import resnet50 + + net = resnet50(pretrained=True) + + net.eval() + net.requires_grad_(False) + + self.conv1 = net.conv1 + self.bn1 = net.bn1 + self.relu = net.relu + self.maxpool = net.maxpool + self.layer1 = net.layer1 + self.layer2 = net.layer2 + self.layer3 = net.layer3 + self.layer4 = net.layer4 + + mean = 
@dataclass
class UnisharpLossWeights:
    """Scalar multipliers for the individual terms of ``UnisharpLoss``.

    ``UnisharpLoss.__init__`` stores an instance as ``self.w`` and builds the
    ResNet-50 perceptual network only when ``lambda_percep`` is greater
    than 0.

    NOTE(review): the code that applies the remaining weights is outside this
    view; the groupings below are inferred from the field names only and
    should be confirmed against ``UnisharpLoss.forward``.
    """

    # Presumably image-space terms (colour, alpha, perceptual) - TODO confirm.
    lambda_color: float = 1.0
    lambda_alpha: float = 1.5
    lambda_percep: float = 3.0
    # Presumably depth / smoothness terms - TODO confirm.
    lambda_depth: float = 0.5
    lambda_tv: float = 1.0
    lambda_grad: float = 1.0
    # Presumably regularisers on Gaussian deltas, splat size and the delta
    # grid; all default to 0.0 - TODO confirm.
    lambda_delta: float = 0.0
    lambda_delta_rho: float = 0.0
    lambda_splat: float = 0.0
    lambda_edge_splat: float = 0.0
    lambda_grid: float = 0.0
    # Presumably image-gradient / RGB-edge terms - TODO confirm.
    lambda_grad_img: float = 0.2
    lambda_edge_rgb: float = 0.0
self.raw_delta_clip = float(raw_delta_clip) + self.raw_delta_rho_clip = float(raw_delta_rho_clip) + self.alpha_tail_min = float(alpha_tail_min) + self.alpha_tail_weight = float(alpha_tail_weight) + self.splat_sigma_min = float(splat_sigma_min) + self.splat_sigma_max = float(splat_sigma_max) + self.edge_splat_sigma_max = float(edge_splat_sigma_max) + self.depth_edge_log_threshold = float(depth_edge_log_threshold) + self.depth_edge_dilate_px = int(depth_edge_dilate_px) + self.percep_max_side = int(percep_max_side) + self.grad_img_scales = int(grad_img_scales) + self.grad_img_circular_h = bool(grad_img_circular_h) + + sobel_kx = torch.tensor( + [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]] + ).view(1, 1, 3, 3) + sobel_ky = torch.tensor( + [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]] + ).view(1, 1, 3, 3) + self.register_buffer("_sobel_kx", sobel_kx, persistent=False) + self.register_buffer("_sobel_ky", sobel_ky, persistent=False) + + self._percep_net: nn.Module | None = None + if self.w.lambda_percep > 0: + self._percep_net = _ResNet50Perceptual() + + @staticmethod + def _flatten_gaussian_xyz(x: torch.Tensor | None, gauss_grid_shape: tuple[int, int, int] | None = None) -> torch.Tensor | None: + if not torch.is_tensor(x): + return None + if x.ndim == 5: + return x.permute(0, 2, 3, 4, 1).flatten(1, 3) + if x.ndim == 3 and int(x.shape[-1]) == 3: + return x + if x.ndim == 2 and gauss_grid_shape is not None: + return x.unsqueeze(-1) + return None + + @staticmethod + def _flatten_gaussian_quat( + x: torch.Tensor | None, + gauss_grid_shape: tuple[int, int, int] | None = None, + ) -> torch.Tensor | None: + if not torch.is_tensor(x): + return None + if x.ndim == 5 and int(x.shape[1]) == 4: + return x.permute(0, 2, 3, 4, 1).flatten(1, 3) + if x.ndim == 3 and int(x.shape[-1]) == 4: + return x + if x.ndim == 2 and gauss_grid_shape is not None: + return x.unsqueeze(-1) + return None + + @staticmethod + def _flatten_gaussian_scalar( + x: torch.Tensor | None, + 
gauss_grid_shape: tuple[int, int, int] | None = None, + ) -> torch.Tensor | None: + if not torch.is_tensor(x): + return None + if x.ndim == 5: + return x[:, 0].flatten(1) + if x.ndim == 4: + return x.flatten(1) + if x.ndim == 3 and int(x.shape[-1]) == 1: + return x[..., 0] + if x.ndim == 2: + return x + return None + + @staticmethod + def _central_disparity_gradient(inv_depth: torch.Tensor, *, circular_h: bool) -> torch.Tensor: + if circular_h: + gx = 0.5 * (torch.roll(inv_depth, shifts=-1, dims=-1) - torch.roll(inv_depth, shifts=1, dims=-1)).abs() + else: + padded_x = F.pad(inv_depth, (1, 1, 0, 0), mode="replicate") + gx = 0.5 * (padded_x[..., 2:] - padded_x[..., :-2]).abs() + padded_y = F.pad(inv_depth, (0, 0, 1, 1), mode="replicate") + gy = 0.5 * (padded_y[..., 2:, :] - padded_y[..., :-2, :]).abs() + return torch.sqrt(gx * gx + gy * gy + 1e-12) + + @staticmethod + def _sample_map_at_uv(feat: torch.Tensor, u: torch.Tensor, v: torch.Tensor, valid: torch.Tensor) -> torch.Tensor: + b, _, h, w = feat.shape + valid_bool = valid.to(dtype=torch.bool) & torch.isfinite(u) & torch.isfinite(v) + u_safe = torch.where(valid_bool, u, torch.zeros_like(u)).clamp(0.0, float(max(w - 1, 0))) + v_safe = torch.where(valid_bool, v, torch.zeros_like(v)).clamp(0.0, float(max(h - 1, 0))) + grid_x = (u_safe / max(float(w - 1), 1.0)) * 2.0 - 1.0 + grid_y = (v_safe / max(float(h - 1), 1.0)) * 2.0 - 1.0 + grid = torch.stack([grid_x, grid_y], dim=-1).view(b, -1, 1, 2) + sampled = F.grid_sample(feat, grid, mode="bilinear", padding_mode="zeros", align_corners=True) + return sampled[:, 0, :, 0] * valid_bool.to(dtype=feat.dtype) + + @staticmethod + def _expand_camera_params(camera_params: torch.Tensor, *, batch_size: int, device: torch.device) -> torch.Tensor: + params = camera_params.to(device=device, dtype=torch.float32) + if params.ndim == 1: + params = params.unsqueeze(0) + if int(params.shape[0]) == 1 and int(batch_size) > 1: + params = params.expand(int(batch_size), -1) + return params + + 
@staticmethod + def _project_fisheye624_points_px_stable( + pts: torch.Tensor, + camera_params: torch.Tensor, + *, + image_h: int, + image_w: int, + finite: torch.Tensor, + require_in_bounds: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + b, n, _ = pts.shape + params = UnisharpLoss._expand_camera_params(camera_params, batch_size=b, device=pts.device) + x, y, z = pts.unbind(dim=-1) + radius = torch.linalg.vector_norm(pts, dim=-1).clamp(min=1e-6) + front = z > (radius * 1e-4).clamp(min=1e-4) + projectable = finite & front + + safe_pts = torch.zeros_like(pts) + safe_pts[..., 2] = 1.0 + pts_proj = torch.where(projectable.unsqueeze(-1), pts, safe_pts) + x, y, z = pts_proj.unbind(dim=-1) + z_safe = z.clamp(min=1e-4) + + ab = torch.stack([x / z_safe, y / z_safe], dim=-1) + r = torch.sqrt((ab * ab).sum(dim=-1, keepdim=True) + 1e-12) + theta = torch.atan(r) + unit_ab = ab / r + + coeffs = params[:, 4:10].reshape(b, 1, 6) + theta_powers = torch.cat([theta.pow(3 + i * 2) for i in range(6)], dim=-1) + theta_distorted = theta + (theta_powers * coeffs).sum(dim=-1, keepdim=True) + uv_dist = theta_distorted * unit_ab + + p0 = params[..., -6].reshape(b, 1) + p1 = params[..., -5].reshape(b, 1) + xr = uv_dist[..., 0] + yr = uv_dist[..., 1] + xr_sq = xr.square() + yr_sq = yr.square() + rd_sq = xr_sq + yr_sq + uv_x = uv_dist[..., 0] + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + uv_y = uv_dist[..., 1] + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + + s0 = params[..., -4].reshape(b, 1) + s1 = params[..., -3].reshape(b, 1) + s2 = params[..., -2].reshape(b, 1) + s3 = params[..., -1].reshape(b, 1) + rd_4 = rd_sq.square() + uv_x = uv_x + s0 * rd_sq + s1 * rd_4 + uv_y = uv_y + s2 * rd_sq + s3 * rd_4 + + if int(params.shape[-1]) == 15: + fx = fy = params[..., 0:1] + cx = params[..., 1:2] + cy = params[..., 2:3] + else: + fx = params[..., 0:1] + fy = params[..., 1:2] + cx = params[..., 2:3] + cy = params[..., 3:4] + u = uv_x * fx + cx + v = uv_y * 
fy + cy + valid = projectable & torch.isfinite(u) & torch.isfinite(v) + if require_in_bounds: + valid = valid & (u >= 0.0) & (u <= float(image_w - 1)) & (v >= 0.0) & (v <= float(image_h - 1)) + return u, v, valid, radius + + @staticmethod + def _project_points_px( + points: torch.Tensor, + *, + projection_model: str | None, + image_h: int, + image_w: int, + intrinsics: torch.Tensor | None = None, + camera_params: torch.Tensor | None = None, + require_in_bounds: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + pts_raw = points.to(dtype=torch.float32) + finite = torch.isfinite(pts_raw).all(dim=-1) + pts = torch.nan_to_num(pts_raw, nan=0.0, posinf=0.0, neginf=0.0) + b, n, _ = pts.shape + x, y, z = pts.unbind(dim=-1) + model = (projection_model or "pinhole").lower() + + if model in _ERP_PROJECTION_MODELS: + radius_sq_raw = (pts * pts).sum(dim=-1) + direction_valid = finite & (radius_sq_raw > 1e-12) + safe_pts = torch.zeros_like(pts) + safe_pts[..., 2] = 1.0 + pts_erp = torch.where(direction_valid.unsqueeze(-1), pts, safe_pts) + x, y, z = pts_erp.unbind(dim=-1) + radius_sq = (pts_erp * pts_erp).sum(dim=-1) + radius = torch.sqrt(radius_sq + 1e-12) + horizontal_sq = x.square() + z.square() + horizontal = torch.sqrt(horizontal_sq + 1e-12) + pole_angle_eps = max(1e-4, 0.5 * math.pi / float(max(image_h, image_w, 1))) + lon_valid = horizontal > radius * pole_angle_eps + lon_x = torch.where(lon_valid, x, torch.zeros_like(x)) + lon_z = torch.where(lon_valid, z, torch.ones_like(z)) + lon = torch.atan2(lon_x, lon_z) + pitch_down = torch.atan2(y, horizontal) + u = (lon / (2.0 * math.pi) + 0.5) * float(max(image_w, 1)) - 0.5 + v = (0.5 + pitch_down / math.pi) * float(max(image_h, 1)) - 0.5 + valid = direction_valid & lon_valid + valid = ( + valid + & torch.isfinite(u) + & torch.isfinite(v) + & (u >= 0.0) + & (u <= float(image_w - 1)) + & (v >= 0.0) + & (v <= float(image_h - 1)) + ) + return u, v, valid, radius.clamp(min=1e-6) + + if model in 
_FISHEYE_PROJECTION_MODELS and torch.is_tensor(camera_params): + return UnisharpLoss._project_fisheye624_points_px_stable( + pts, + camera_params, + image_h=image_h, + image_w=image_w, + finite=finite, + require_in_bounds=require_in_bounds, + ) + + valid = finite & (z > 1e-4) + if not torch.is_tensor(intrinsics): + fx = torch.full((b, 1), float(max(image_w, image_h)), device=pts.device, dtype=torch.float32) + fy = fx.clone() + cx = torch.full((b, 1), 0.5 * float(max(image_w - 1, 1)), device=pts.device, dtype=torch.float32) + cy = torch.full((b, 1), 0.5 * float(max(image_h - 1, 1)), device=pts.device, dtype=torch.float32) + else: + k = intrinsics.to(device=pts.device, dtype=torch.float32) + if k.ndim == 2: + k = k.unsqueeze(0) + if int(k.shape[0]) == 1 and b > 1: + k = k.expand(b, -1, -1) + fx = k[:, 0, 0:1] + fy = k[:, 1, 1:2] + cx = k[:, 0, 2:3] + cy = k[:, 1, 2:3] + z_safe = z.clamp(min=1e-4) + u = fx * (x / z_safe) + cx + v = fy * (y / z_safe) + cy + valid = valid & torch.isfinite(u) & torch.isfinite(v) + if require_in_bounds: + valid = valid & (u >= 0.0) & (u <= float(image_w - 1)) & (v >= 0.0) & (v <= float(image_h - 1)) + return u, v, valid, z_safe + + def _projected_sigma_px( + self, + *, + gaussian_scales: torch.Tensor, + gaussian_quaternions: torch.Tensor | None, + gaussian_mean_vectors: torch.Tensor, + valid: torch.Tensor, + projection_model: str | None, + image_h: int, + image_w: int, + intrinsics: torch.Tensor | None = None, + camera_params: torch.Tensor | None = None, + projected_scale_factor: float | torch.Tensor | None = None, + ) -> torch.Tensor: + scales = self._flatten_gaussian_xyz(gaussian_scales) + quats = self._flatten_gaussian_quat(gaussian_quaternions) + means = self._flatten_gaussian_xyz(gaussian_mean_vectors) + if scales is None or means is None: + return torch.zeros_like(valid, dtype=torch.float32) + valid = valid.to(dtype=torch.bool) & torch.isfinite(scales).all(dim=-1) & torch.isfinite(means).all(dim=-1) + scales = 
torch.nan_to_num(scales.to(dtype=torch.float32), nan=0.0, posinf=0.0, neginf=0.0).abs() + means = torch.nan_to_num(means.to(dtype=torch.float32), nan=0.0, posinf=0.0, neginf=0.0) + model = (projection_model or "pinhole").lower() + if model in _ERP_PROJECTION_MODELS: + radius = torch.norm(means, dim=-1).clamp(min=1e-4) + sigma_u = scales[..., 0] / radius * (float(max(image_w, 1)) / (2.0 * math.pi)) + sigma_v = scales[..., 1] / radius * (float(max(image_h, 1)) / math.pi) + sigma_px = torch.maximum(sigma_u.square(), sigma_v.square()) + valid = valid & torch.isfinite(sigma_px) + sigma_px = torch.nan_to_num(sigma_px, nan=0.0, posinf=0.0, neginf=0.0) + return torch.where(valid, sigma_px, torch.zeros_like(sigma_px)) + + if quats is not None and tuple(quats.shape[:2]) == tuple(means.shape[:2]): + quats = torch.nan_to_num(quats.to(dtype=torch.float32), nan=0.0, posinf=0.0, neginf=0.0) + quat_norm = quats.norm(dim=-1, keepdim=True) + valid = valid & torch.isfinite(quats).all(dim=-1) & (quat_norm.squeeze(-1) > 1e-8) + quats = quats / quat_norm.clamp(min=1e-8) + rotations = linalg.rotation_matrices_from_quaternions(quats) + tangent_scales = scales[..., :2] + tangent_rotations = rotations[..., :, :2] + axis_offsets = (tangent_rotations * tangent_scales[..., None, :]).transpose(-1, -2) + axis_points = means[:, :, None, :] + axis_offsets + u0, v0, valid0, _ = self._project_points_px( + means, + projection_model=projection_model, + image_h=image_h, + image_w=image_w, + intrinsics=intrinsics, + camera_params=camera_params, + require_in_bounds=False, + ) + b, n, axis_count, _ = axis_points.shape + u1, v1, valid1, _ = self._project_points_px( + axis_points.reshape(b, n * axis_count, 3), + projection_model=projection_model, + image_h=image_h, + image_w=image_w, + intrinsics=intrinsics, + camera_params=camera_params, + require_in_bounds=False, + ) + u1 = u1.reshape(b, n, axis_count) + v1 = v1.reshape(b, n, axis_count) + valid1 = valid1.reshape(b, n, axis_count) + du = u1 - u0[..., 
None] + dv = v1 - v0[..., None] + if (projection_model or "pinhole").lower() in _ERP_PROJECTION_MODELS: + width = float(max(image_w, 1)) + du = torch.remainder(du + 0.5 * width, width) - 0.5 * width + cov_xx = (du * du).sum(dim=-1) + cov_xy = (du * dv).sum(dim=-1) + cov_yy = (dv * dv).sum(dim=-1) + trace = cov_xx + cov_yy + disc = (cov_xx - cov_yy).square() + 4.0 * cov_xy.square() + sigma_px = 0.5 * (trace + (disc.clamp(min=0.0) + 1e-12).sqrt()) + valid = valid & valid0 & valid1.all(dim=-1) & torch.isfinite(sigma_px) + sigma_px = torch.nan_to_num(sigma_px, nan=0.0, posinf=0.0, neginf=0.0) + return torch.where(valid, sigma_px, torch.zeros_like(sigma_px)) + + sigma_screen_3d = scales[..., :2].to(dtype=torch.float32).abs().amax(dim=-1).clamp(min=1e-8) + if model in {"fisheye624", "opencv_fisheye"} and torch.is_tensor(camera_params): + params = camera_params.to(device=means.device, dtype=torch.float32) + if params.ndim == 1: + params = params.unsqueeze(0) + if int(params.shape[0]) == 1 and int(means.shape[0]) > 1: + params = params.expand(int(means.shape[0]), -1) + if int(params.shape[-1]) == 15: + focal = params[:, 0:1].clamp(min=1.0) + else: + focal = 0.5 * (params[:, 0:1] + params[:, 1:2]).clamp(min=1.0) + radius = torch.norm(means, dim=-1).clamp(min=1e-4) + sigma_px = (sigma_screen_3d / radius * focal).square() + elif torch.is_tensor(intrinsics): + k = intrinsics.to(device=means.device, dtype=torch.float32) + if k.ndim == 2: + k = k.unsqueeze(0) + if int(k.shape[0]) == 1 and int(means.shape[0]) > 1: + k = k.expand(int(means.shape[0]), -1, -1) + focal = 0.5 * (k[:, 0, 0:1] + k[:, 1, 1:2]).clamp(min=1.0) + depth = means[..., 2].clamp(min=1e-4) + sigma_px = (sigma_screen_3d / depth * focal).square() + else: + depth = torch.norm(means, dim=-1).clamp(min=1e-4) + sigma_px = sigma_screen_3d / depth + if torch.is_tensor(projected_scale_factor): + sigma_px = sigma_px * projected_scale_factor.to(device=sigma_px.device, dtype=sigma_px.dtype) + elif projected_scale_factor is 
not None: + sigma_px = sigma_px * float(projected_scale_factor) + sigma_px = sigma_px.square() + valid = valid & torch.isfinite(sigma_px) + sigma_px = torch.nan_to_num(sigma_px, nan=0.0, posinf=0.0, neginf=0.0) + return torch.where(valid, sigma_px, torch.zeros_like(sigma_px)) + + def _depth_edge_band( + self, + depth_m: torch.Tensor, + valid_weight: torch.Tensor, + *, + circular_h: bool, + ) -> torch.Tensor: + depth = depth_m.to(dtype=torch.float32) + if depth.ndim == 3: + depth = depth.unsqueeze(1) + valid = torch.isfinite(depth) & (depth > 0.0) & (valid_weight[:, :1].to(dtype=torch.float32) > 0.5) + log_depth = torch.where(valid, depth.clamp(min=1e-4).log(), torch.zeros_like(depth)) + + if bool(circular_h): + right = torch.roll(log_depth, shifts=-1, dims=-1) + valid_right = valid & torch.roll(valid, shifts=-1, dims=-1) + edge_x = (right - log_depth).abs() > float(self.depth_edge_log_threshold) + edge_x = edge_x & valid_right + else: + edge_x = torch.zeros_like(valid) + edge_x[..., :, :-1] = ( + (log_depth[..., :, 1:] - log_depth[..., :, :-1]).abs() > float(self.depth_edge_log_threshold) + ) & valid[..., :, 1:] & valid[..., :, :-1] + + edge_y = torch.zeros_like(valid) + edge_y[..., :-1, :] = ( + (log_depth[..., 1:, :] - log_depth[..., :-1, :]).abs() > float(self.depth_edge_log_threshold) + ) & valid[..., 1:, :] & valid[..., :-1, :] + edge = (edge_x | edge_y).to(dtype=torch.float32) + + radius = max(int(self.depth_edge_dilate_px), 0) + if radius <= 0: + return edge + kernel = 2 * radius + 1 + if bool(circular_h): + edge = F.pad(edge, (radius, radius, 0, 0), mode="circular") + edge = F.pad(edge, (0, 0, radius, radius), mode="constant", value=0.0) + return F.max_pool2d(edge, kernel_size=kernel, stride=1) + return F.max_pool2d(edge, kernel_size=kernel, stride=1, padding=radius) + + def _ray_cell_sigma( + self, + *, + gaussian_scales: torch.Tensor, + gaussian_mean_vectors: torch.Tensor, + gaussian_angular_cell: torch.Tensor, + gauss_grid_shape: tuple[int, int, int] | 
None, + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + scales = self._flatten_gaussian_xyz(gaussian_scales, gauss_grid_shape) + means = self._flatten_gaussian_xyz(gaussian_mean_vectors, gauss_grid_shape) + if scales is None or means is None: + return None, None + if not torch.is_tensor(gaussian_angular_cell): + return None, None + cell = gaussian_angular_cell.to(device=scales.device, dtype=torch.float32) + if cell.ndim != 5 or int(cell.shape[1]) != 2: + return None, None + if gauss_grid_shape is None: + return None, None + l, h, w = (int(gauss_grid_shape[0]), int(gauss_grid_shape[1]), int(gauss_grid_shape[2])) + if tuple(cell.shape[-2:]) != (h, w): + return None, None + if int(cell.shape[2]) == 1 and l > 1: + cell = cell.expand(-1, -1, l, -1, -1) + elif int(cell.shape[2]) != l: + return None, None + cell_flat = cell.permute(0, 2, 3, 4, 1).flatten(1, 3) + if int(cell_flat.shape[1]) != int(scales.shape[1]): + return None, None + radius = torch.linalg.norm(means.to(dtype=torch.float32), dim=-1, keepdim=True).clamp(min=1e-4) + tangent = scales[..., :2].to(dtype=torch.float32).abs() + sigma_cells = (tangent / radius / cell_flat.clamp(min=1e-6)).square() + valid = torch.isfinite(sigma_cells).all(dim=-1) & torch.isfinite(radius.squeeze(-1)) + sigma_cells = torch.nan_to_num(sigma_cells, nan=0.0, posinf=0.0, neginf=0.0) + return sigma_cells, valid + + def _dynamic_splat_sigma_limits( + self, + *, + sigma_proj: torch.Tensor, + projection_model: str | None, + image_h: int, + image_w: int, + intrinsics: torch.Tensor | None = None, + camera_params: torch.Tensor | None = None, + projected_scale_factor: float | torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + del projection_model, image_h, image_w, intrinsics, camera_params, projected_scale_factor + return ( + torch.as_tensor(self.splat_sigma_min, device=sigma_proj.device, dtype=sigma_proj.dtype), + torch.as_tensor(self.splat_sigma_max, device=sigma_proj.device, dtype=sigma_proj.dtype), + ) + + def 
_sanitize_supervision_depth(self, depth_m: torch.Tensor, *, clamp_max: bool = True) -> torch.Tensor: + depth = depth_m.to(torch.float32) + valid = torch.isfinite(depth) & (depth > 0.0) + depth = torch.where(valid, depth, torch.zeros_like(depth)) + if bool(valid.any().item()): + depth = depth.clone() + if bool(clamp_max): + depth[valid] = depth[valid].clamp(min=1e-4, max=float(self.SUPERVISION_MAX_DEPTH_M)) + else: + depth[valid] = depth[valid].clamp(min=1e-4) + return depth + + def _sobel_gradient_loss_erp( + self, + pred_depth_m: torch.Tensor, + gt_depth_m: torch.Tensor, + depth_weight: torch.Tensor, + circular_h: bool | None = None, + ) -> torch.Tensor: + dtype = pred_depth_m.dtype + device = pred_depth_m.device + + kx = self._sobel_kx.to(dtype=dtype, device=device) # type: ignore[attr-defined] + ky = self._sobel_ky.to(dtype=dtype, device=device) # type: ignore[attr-defined] + + log_pred = torch.log(pred_depth_m.clamp(min=1e-4)) + log_gt = torch.log(gt_depth_m.clamp(min=1e-4)) + log_diff = log_pred - log_gt + + mask = depth_weight.to(dtype=dtype).clamp(min=0.0, max=1.0) + valid_mask = (mask > 0.5).to(dtype=dtype) + log_diff = torch.where(valid_mask > 0.5, log_diff, torch.zeros_like(log_diff)) + + total = torch.zeros((), device=device, dtype=dtype) + n_computed = 0 + + use_circular_h = self.grad_img_circular_h if circular_h is None else bool(circular_h) + ones_kernel = torch.ones((1, 1, 3, 3), device=device, dtype=dtype) + for _s in range(self.grad_img_scales): + if min(log_diff.shape[-2:]) < 4: + break + + if use_circular_h: + padded = F.pad(log_diff, (1, 1, 0, 0), mode="circular") + padded = F.pad(padded, (0, 0, 1, 1), mode="reflect") + padded_mask = F.pad(valid_mask, (1, 1, 0, 0), mode="circular") + padded_mask = F.pad(padded_mask, (0, 0, 1, 1), mode="replicate") + else: + padded = F.pad(log_diff, (1, 1, 1, 1), mode="reflect") + padded_mask = F.pad(valid_mask, (1, 1, 1, 1), mode="replicate") + + gx = F.conv2d(padded, kx) + gy = F.conv2d(padded, ky) + grad_mag = 
torch.sqrt(gx * gx + gy * gy + 1e-8) + + stencil_valid = (F.conv2d(padded_mask, ones_kernel) >= 8.999).to(dtype=dtype) + n_valid = stencil_valid.sum().clamp(min=1.0) + total = total + (grad_mag * stencil_valid).sum() / n_valid + n_computed += 1 + + if _s < self.grad_img_scales - 1: + if use_circular_h: + pooled_mask = _avg_pool2d_circular_h(valid_mask, kernel_size=2, stride=2) + pooled_diff = _avg_pool2d_circular_h(log_diff * valid_mask, kernel_size=2, stride=2) + else: + pooled_mask = F.avg_pool2d(valid_mask, kernel_size=2, stride=2) + pooled_diff = F.avg_pool2d(log_diff * valid_mask, kernel_size=2, stride=2) + log_diff = pooled_diff / pooled_mask.clamp(min=1e-6) + valid_mask = (pooled_mask > 0.999).to(dtype=dtype) + log_diff = torch.where(valid_mask > 0.5, log_diff, torch.zeros_like(log_diff)) + + if n_computed == 0: + return torch.zeros((), device=device, dtype=dtype) + return total / float(n_computed) + + def _sobel_xy_rgb(self, img: torch.Tensor, *, circular_h: bool) -> tuple[torch.Tensor, torch.Tensor]: + channels = int(img.shape[1]) + kx = self._sobel_kx.to(dtype=img.dtype, device=img.device).expand(channels, 1, 3, 3) # type: ignore[attr-defined] + ky = self._sobel_ky.to(dtype=img.dtype, device=img.device).expand(channels, 1, 3, 3) # type: ignore[attr-defined] + if bool(circular_h): + padded = F.pad(img, (1, 1, 0, 0), mode="circular") + padded = F.pad(padded, (0, 0, 1, 1), mode="reflect") + else: + padded = F.pad(img, (1, 1, 1, 1), mode="reflect") + return ( + F.conv2d(padded, kx, groups=channels), + F.conv2d(padded, ky, groups=channels), + ) + + def _edge_rgb_gradient_loss( + self, + pred_rgb_linear: torch.Tensor, + gt_rgb_linear: torch.Tensor, + valid_weight: torch.Tensor, + depth_edge_band: torch.Tensor | None, + *, + circular_h: bool, + ) -> torch.Tensor: + dtype = pred_rgb_linear.dtype + device = pred_rgb_linear.device + pred = pred_rgb_linear.to(dtype=torch.float32) + gt = gt_rgb_linear.to(device=device, dtype=torch.float32) + weight = 
valid_weight.to(device=device, dtype=torch.float32).clamp(0.0, 1.0)[:, :1] + + pred_gx, pred_gy = self._sobel_xy_rgb(pred, circular_h=circular_h) + gt_gx, gt_gy = self._sobel_xy_rgb(gt, circular_h=circular_h) + gt_mag = torch.sqrt(gt_gx.square() + gt_gy.square() + 1e-8).mean(dim=1, keepdim=True) + + flat = gt_mag.detach().flatten(2) + mean = flat.mean(dim=-1, keepdim=True)[..., None] + std = flat.std(dim=-1, keepdim=True, unbiased=False)[..., None] + rgb_edge = (gt_mag.detach() > (mean + 0.5 * std).clamp(min=0.02)).to(dtype=torch.float32) + + if torch.is_tensor(depth_edge_band): + edge_boost = depth_edge_band.to(device=device, dtype=torch.float32).clamp(0.0, 1.0) + if tuple(edge_boost.shape[-2:]) != tuple(gt_mag.shape[-2:]): + edge_boost = F.interpolate(edge_boost, size=gt_mag.shape[-2:], mode="nearest") + edge_weight = rgb_edge * (1.0 + edge_boost[:, :1]) + else: + edge_weight = rgb_edge + + ones_kernel = torch.ones((1, 1, 3, 3), device=device, dtype=torch.float32) + if bool(circular_h): + padded_weight = F.pad(weight, (1, 1, 0, 0), mode="circular") + padded_weight = F.pad(padded_weight, (0, 0, 1, 1), mode="replicate") + else: + padded_weight = F.pad(weight, (1, 1, 1, 1), mode="replicate") + stencil_valid = (F.conv2d(padded_weight, ones_kernel) >= 8.999).to(dtype=torch.float32) + + diff = (pred_gx - gt_gx).abs() + (pred_gy - gt_gy).abs() + diff = diff.mean(dim=1, keepdim=True) + final_weight = edge_weight * stencil_valid + return (diff * final_weight).sum().to(dtype=dtype) / final_weight.sum().clamp(min=1.0).to(dtype=dtype) + + def forward( + self, + pred_rgb_linear: torch.Tensor, + pred_alpha: torch.Tensor, + pred_depth_m: torch.Tensor, + gt_rgb_u8: torch.Tensor, + gt_depth_m: torch.Tensor, + pred_depth2_m: torch.Tensor | None = None, + mask: torch.Tensor | None = None, + depth_mask: torch.Tensor | None = None, + delta_xy: torch.Tensor | None = None, + delta_rho: torch.Tensor | None = None, + delta_grid: torch.Tensor | None = None, + gaussian_scales: torch.Tensor 
| None = None, + gaussian_quaternions: torch.Tensor | None = None, + gaussian_mean_vectors: torch.Tensor | None = None, + gaussian_base_mean_vectors: torch.Tensor | None = None, + gaussian_angular_cell: torch.Tensor | None = None, + gaussian_opacities: torch.Tensor | None = None, + gauss_grid_shape: tuple[int, int, int] | None = None, + projected_scale_factor: float | torch.Tensor | None = None, + projection_model: str | None = None, + projection_intrinsics: torch.Tensor | None = None, + projection_camera_params: torch.Tensor | None = None, + apply_color: bool = True, + apply_alpha: bool = True, + apply_depth: bool = True, + apply_percep: bool = False, + apply_tv: bool = True, + apply_grad: bool = True, + apply_delta: bool = True, + apply_splat: bool = True, + apply_grad_img: bool = True, + grad_img_circular_h: bool | None = None, + ) -> dict[str, torch.Tensor]: + losses: dict[str, torch.Tensor] = {} + circular_h = bool(grad_img_circular_h) if grad_img_circular_h is not None else False + + gt_rgb = gt_rgb_u8.to(pred_rgb_linear.device).float() / 255.0 + gt_rgb_linear = _to_linear_rgb(gt_rgb) + pred_depth_m = self._sanitize_supervision_depth(pred_depth_m.to(pred_rgb_linear.device), clamp_max=False) + if pred_depth2_m is not None: + pred_depth2_m = self._sanitize_supervision_depth(pred_depth2_m.to(pred_rgb_linear.device), clamp_max=False) + gt_depth_raw = self._sanitize_supervision_depth(gt_depth_m.to(pred_rgb_linear.device)) + depth_valid = torch.isfinite(gt_depth_raw) & (gt_depth_raw > 0.0) + gt_depth = gt_depth_raw.clamp(min=1e-4) + + if mask is None: + m = torch.ones_like(pred_alpha) + else: + m = mask.to(pred_rgb_linear.device).to(pred_rgb_linear.dtype) + depth_weight = depth_valid.to(dtype=pred_depth_m.dtype) * m[:, :1].to(dtype=pred_depth_m.dtype) + if depth_mask is not None: + depth_weight = depth_weight * depth_mask.to(pred_rgb_linear.device).to(dtype=pred_depth_m.dtype)[:, :1] + pred_rgb_rendered = pred_rgb_linear.clamp(0.0, 1.0) + + if apply_color and 
self.w.lambda_color > 0: + color_l1 = (pred_rgb_rendered - gt_rgb_linear).abs() + losses["color"] = _masked_mean(color_l1, m) + else: + losses["color"] = torch.zeros((), device=pred_rgb_linear.device) + + if apply_alpha and self.w.lambda_alpha > 0: + a = pred_alpha.clamp(1e-6, 1.0 - 1e-6) + with torch.autocast(device_type=a.device.type, enabled=False): + alpha_bce = F.binary_cross_entropy( + a.to(dtype=torch.float32), + torch.ones_like(a, dtype=torch.float32), + reduction="none", + ) + alpha_loss = _masked_mean(alpha_bce, m) + alpha_tail_min = torch.as_tensor( + self.alpha_tail_min, + device=a.device, + dtype=torch.float32, + ).clamp(min=0.0, max=1.0) + alpha_tail_weight = torch.as_tensor( + max(0.0, self.alpha_tail_weight), + device=a.device, + dtype=torch.float32, + ) + if self.alpha_tail_min > 0.0 and self.alpha_tail_weight > 0.0: + tail = F.relu(alpha_tail_min - a.to(dtype=torch.float32)) + tail = tail / alpha_tail_min.clamp(min=1e-6) + tail_mask = (m[:, :1].to(dtype=torch.bool)) & (tail > 0.0) + alpha_loss = alpha_loss + alpha_tail_weight * _finite_masked_mean_flat(tail, tail_mask) + losses["alpha"] = alpha_loss.to(dtype=pred_rgb_linear.dtype) + else: + losses["alpha"] = torch.zeros((), device=pred_rgb_linear.device) + + if apply_depth and self.w.lambda_depth > 0: + w_depth = depth_weight + inv_pred1 = 1.0 / pred_depth_m.clamp(min=1e-4) + inv_gt = torch.zeros_like(inv_pred1) + inv_gt[depth_valid] = 1.0 / gt_depth[depth_valid] + depth_abs = (inv_pred1 - inv_gt).abs() + losses["depth"] = _masked_mean(depth_abs, w_depth) + else: + losses["depth"] = torch.zeros((), device=pred_rgb_linear.device) + + if apply_tv and self.w.lambda_tv > 0 and (pred_depth2_m is not None): + inv2 = 1.0 / pred_depth2_m.clamp(min=1e-4) + losses["tv"] = _tv_l1_circular_h(inv2) if circular_h else _tv_l1(inv2) + else: + losses["tv"] = torch.zeros((), device=pred_rgb_linear.device) + + image_h, image_w = int(pred_depth_m.shape[-2]), int(pred_depth_m.shape[-1]) + projection_points = 
self._flatten_gaussian_xyz(gaussian_mean_vectors, gauss_grid_shape) + projected_u = projected_v = None + projected_valid = None + if projection_points is not None: + projected_u, projected_v, projected_valid, _projected_depth = self._project_points_px( + projection_points, + projection_model=projection_model, + image_h=image_h, + image_w=image_w, + intrinsics=projection_intrinsics, + camera_params=projection_camera_params, + ) + + if apply_grad and self.w.lambda_grad > 0: + inv1 = 1.0 / pred_depth_m.clamp(min=1e-4) + op_flat = self._flatten_gaussian_scalar(gaussian_opacities, gauss_grid_shape) + if projected_u is not None and projected_v is not None and projected_valid is not None and op_flat is not None: + grad_map = self._central_disparity_gradient(inv1, circular_h=circular_h) + grad_at_gauss = self._sample_map_at_uv(grad_map, projected_u, projected_v, projected_valid) + penalty = 1.0 - torch.exp( + -(1.0 / max(self.grad_sigma, 1e-8)) * F.relu(grad_at_gauss - self.grad_eps) + ) + weight = projected_valid & torch.isfinite(grad_at_gauss) & torch.isfinite(op_flat) + mask_at_gauss = self._sample_map_at_uv(m[:, :1], projected_u, projected_v, projected_valid) + weight = weight & (mask_at_gauss > 0.5) + grad_value = op_flat.to(dtype=penalty.dtype).clamp(0, 1) * penalty + losses["grad"] = _finite_masked_mean_flat(grad_value, weight) + else: + raise RuntimeError( + "L_grad requires gaussian_mean_vectors, gaussian_opacities, " + "gauss_grid_shape, and projection metadata. The old " + "pred_alpha image-space fallback is disabled for ray-local training." 
+ ) + else: + losses["grad"] = torch.zeros((), device=pred_rgb_linear.device) + + if apply_grad_img and self.w.lambda_grad_img > 0: + losses["grad_img"] = self._sobel_gradient_loss_erp( + pred_depth_m=pred_depth_m, + gt_depth_m=gt_depth, + depth_weight=depth_weight, + circular_h=grad_img_circular_h, + ) + else: + losses["grad_img"] = torch.zeros((), device=pred_rgb_linear.device) + + if apply_color and self.w.lambda_edge_rgb > 0: + depth_edge_for_rgb = self._depth_edge_band(gt_depth, depth_weight, circular_h=circular_h) + losses["edge_rgb"] = self._edge_rgb_gradient_loss( + pred_rgb_linear=pred_rgb_rendered, + gt_rgb_linear=gt_rgb_linear, + valid_weight=m, + depth_edge_band=depth_edge_for_rgb, + circular_h=circular_h, + ) + else: + losses["edge_rgb"] = torch.zeros((), device=pred_rgb_linear.device) + + if apply_delta and self.w.lambda_delta > 0: + if delta_xy is not None: + dx = F.relu(delta_xy[:, 0:1].abs() - self.raw_delta_clip) + dy = F.relu(delta_xy[:, 1:2].abs() - self.raw_delta_clip) + losses["delta"] = (dx + dy).mean() + else: + del gaussian_base_mean_vectors + raise RuntimeError( + "L_delta requires raw delta_xy in ray-local training. The old " + "screen-space pixel displacement fallback is disabled." 
+ ) + else: + losses["delta"] = torch.zeros((), device=pred_rgb_linear.device) + + if apply_delta and self.w.lambda_delta_rho > 0 and delta_rho is not None: + dz = delta_rho.to(device=pred_rgb_linear.device, dtype=pred_rgb_linear.dtype) + finite = torch.isfinite(dz) + dz_safe = torch.nan_to_num(dz, nan=0.0, posinf=0.0, neginf=0.0) + penalty = F.relu(dz_safe.abs() - self.raw_delta_rho_clip) + penalty = torch.where(finite, penalty, torch.zeros_like(penalty)) + losses["delta_rho"] = penalty.sum() / finite.to(dtype=penalty.dtype).sum().clamp(min=1.0) + else: + losses["delta_rho"] = torch.zeros((), device=pred_rgb_linear.device) + + if self.w.lambda_grid > 0 and torch.is_tensor(delta_grid): + losses["grid"] = _delta_grid_checkerboard_loss( + delta_grid.to(device=pred_rgb_linear.device), + circular_h=circular_h, + ).to(dtype=pred_rgb_linear.dtype) + else: + losses["grid"] = torch.zeros((), device=pred_rgb_linear.device) + + if apply_splat and self.w.lambda_splat > 0: + if gaussian_scales is None: + raise RuntimeError("L_splat requires gaussian_scales for projected screen-space variance.") + if gaussian_mean_vectors is None or projected_valid is None: + raise RuntimeError( + "L_splat requires gaussian_mean_vectors and projection metadata " + "to compute projected screen-space variance." 
+ ) + sigma_proj = self._projected_sigma_px( + gaussian_scales=gaussian_scales, + gaussian_quaternions=gaussian_quaternions, + gaussian_mean_vectors=gaussian_mean_vectors, + valid=projected_valid, + projection_model=projection_model, + image_h=image_h, + image_w=image_w, + intrinsics=projection_intrinsics, + camera_params=projection_camera_params, + projected_scale_factor=projected_scale_factor, + ) + valid_splat = projected_valid & torch.isfinite(sigma_proj) + splat_sigma_min = torch.as_tensor( + self.splat_sigma_min, + device=sigma_proj.device, + dtype=sigma_proj.dtype, + ) + splat_sigma_max = torch.as_tensor( + self.splat_sigma_max, + device=sigma_proj.device, + dtype=sigma_proj.dtype, + ) + lower_penalty = F.relu(splat_sigma_min - sigma_proj) + upper_penalty = F.relu(sigma_proj - splat_sigma_max) + splat_penalty = lower_penalty + upper_penalty + losses["splat"] = _finite_masked_mean_flat(splat_penalty, valid_splat) + else: + sigma_proj = None + valid_splat = None + losses["splat"] = torch.zeros((), device=pred_rgb_linear.device) + + if apply_splat and self.w.lambda_edge_splat > 0: + if gaussian_scales is None: + raise RuntimeError("L_edge_splat requires gaussian_scales for projected screen-space variance.") + if gaussian_mean_vectors is None or projected_valid is None: + raise RuntimeError( + "L_edge_splat requires gaussian_mean_vectors and projection metadata " + "to sample source depth-edge bands." 
+ ) + if sigma_proj is None or valid_splat is None: + sigma_proj = self._projected_sigma_px( + gaussian_scales=gaussian_scales, + gaussian_quaternions=gaussian_quaternions, + gaussian_mean_vectors=gaussian_mean_vectors, + valid=projected_valid, + projection_model=projection_model, + image_h=image_h, + image_w=image_w, + intrinsics=projection_intrinsics, + camera_params=projection_camera_params, + projected_scale_factor=projected_scale_factor, + ) + valid_splat = projected_valid & torch.isfinite(sigma_proj) + edge_band = self._depth_edge_band(gt_depth, depth_weight, circular_h=circular_h) + edge_at_gauss = self._sample_map_at_uv(edge_band, projected_u, projected_v, projected_valid) + edge_valid = valid_splat & torch.isfinite(edge_at_gauss) & (edge_at_gauss > 0.5) + edge_sigma_max = torch.as_tensor( + self.edge_splat_sigma_max, + device=sigma_proj.device, + dtype=sigma_proj.dtype, + ) + losses["edge_splat"] = _finite_masked_mean_flat(F.relu(sigma_proj - edge_sigma_max), edge_valid) + else: + losses["edge_splat"] = torch.zeros((), device=pred_rgb_linear.device) + + zero = torch.zeros((), device=pred_rgb_linear.device) + losses["percep_feat"] = zero + losses["percep_gram"] = zero + if apply_percep and self.w.lambda_percep > 0 and (self._percep_net is not None): + from unisharp.utils.color_space import linearRGB2sRGB + + pred_srgb = linearRGB2sRGB(pred_rgb_rendered.to(torch.float32)).clamp(0, 1) + gt_srgb = gt_rgb.clamp(0, 1) + + pred_srgb = _resize_max_side(pred_srgb, self.percep_max_side, mode="bilinear") + gt_srgb = _resize_max_side(gt_srgb, self.percep_max_side, mode="bilinear") + + feats_p = self._percep_net(pred_srgb) + feats_g = self._percep_net(gt_srgb) + loss_feat_total = torch.zeros((), device=pred_rgb_linear.device) + loss_gram_total = torch.zeros((), device=pred_rgb_linear.device) + for fp, fg in zip(feats_p, feats_g): + d, h, w = fp.shape[1], fp.shape[2], fp.shape[3] + lam_gram = 10.0 / float(max(1, d * d)) + lam_feat = 1.0 / float(max(1, d * h * w)) + diff 
= (fp - fg).pow(2) + loss_feat = (diff.sum(dim=[1, 2, 3]) * lam_feat).mean() + gram_norm = float(max(1, h * w)) + gp = _gram_matrix(fp) / gram_norm + gg = _gram_matrix(fg) / gram_norm + loss_gram = ((gp - gg).pow(2).sum(dim=[1, 2]) * lam_gram).mean() + loss_feat_total = loss_feat_total + loss_feat + loss_gram_total = loss_gram_total + loss_gram + layer_count = float(max(1, len(feats_p))) + losses["percep_feat"] = loss_feat_total / layer_count + losses["percep_gram"] = loss_gram_total / layer_count + losses["percep"] = losses["percep_feat"] + losses["percep_gram"] + else: + losses["percep"] = torch.zeros((), device=pred_rgb_linear.device) + + losses["total"] = ( + self.w.lambda_color * losses["color"] + + self.w.lambda_alpha * losses["alpha"] + + self.w.lambda_percep * losses["percep"] + + self.w.lambda_depth * losses["depth"] + + self.w.lambda_tv * losses["tv"] + + self.w.lambda_grad * losses["grad"] + + self.w.lambda_grad_img * losses["grad_img"] + + self.w.lambda_edge_rgb * losses["edge_rgb"] + + self.w.lambda_delta * losses["delta"] + + self.w.lambda_delta_rho * losses["delta_rho"] + + self.w.lambda_splat * losses["splat"] + + self.w.lambda_edge_splat * losses["edge_splat"] + + self.w.lambda_grid * losses["grid"] + ) + return losses + diff --git a/unisharp/models/__init__.py b/unisharp/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38edc964ae2d49f617e3cf17e4879c37e769cbf3 --- /dev/null +++ b/unisharp/models/__init__.py @@ -0,0 +1,23 @@ + +from __future__ import annotations + +from .feature_gaussian_decoder import ( + FeatureGaussianDecoder, + FeatureGaussianDecoderParams, + ImageFeatures, + create_feature_gaussian_decoder, +) +from .unisharp_params import PanoPredictorParams +from .unisharp_feature import UnisharpFeatureConfig, UnisharpFeatureModel +from .unik3d_feature_extractor import UniK3DFeatureExtractor + +__all__ = [ + "PanoPredictorParams", + "UniK3DFeatureExtractor", + "FeatureGaussianDecoder", + 
"FeatureGaussianDecoderParams", + "ImageFeatures", + "create_feature_gaussian_decoder", + "UnisharpFeatureConfig", + "UnisharpFeatureModel", +] diff --git a/unisharp/models/__pycache__/__init__.cpython-310.pyc b/unisharp/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb82799f887f7084478be9722cc34b54de977e33 Binary files /dev/null and b/unisharp/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/unisharp/models/__pycache__/__init__.cpython-313.pyc b/unisharp/models/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2a66736f2d7eb98af2b114dd2c9e140b6519345 Binary files /dev/null and b/unisharp/models/__pycache__/__init__.cpython-313.pyc differ diff --git a/unisharp/models/__pycache__/blocks.cpython-310.pyc b/unisharp/models/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..672d96def3b923895582727899e83ff35849a69c Binary files /dev/null and b/unisharp/models/__pycache__/blocks.cpython-310.pyc differ diff --git a/unisharp/models/__pycache__/blocks.cpython-313.pyc b/unisharp/models/__pycache__/blocks.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d74f84f8b51851068ceee56c1c571201db19f5cb Binary files /dev/null and b/unisharp/models/__pycache__/blocks.cpython-313.pyc differ diff --git a/unisharp/models/__pycache__/decoder.cpython-310.pyc b/unisharp/models/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2a225f6c2c5eed5dbd9d21a1c77cf651a237b22 Binary files /dev/null and b/unisharp/models/__pycache__/decoder.cpython-310.pyc differ diff --git a/unisharp/models/__pycache__/decoder.cpython-313.pyc b/unisharp/models/__pycache__/decoder.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8d775d68f8b620d51e3aea3f1adb4967b483dd1 Binary files /dev/null and 
b/unisharp/models/__pycache__/decoder.cpython-313.pyc differ diff --git a/unisharp/models/__pycache__/feature_gaussian_decoder.cpython-310.pyc b/unisharp/models/__pycache__/feature_gaussian_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91538f05736d8db66fa6e9b67d5fcda8deb6f4a7 Binary files /dev/null and b/unisharp/models/__pycache__/feature_gaussian_decoder.cpython-310.pyc differ diff --git a/unisharp/models/__pycache__/feature_gaussian_decoder.cpython-313.pyc b/unisharp/models/__pycache__/feature_gaussian_decoder.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3327d278e2701a0e47698e2a0bd7f0d42513a9fa Binary files /dev/null and b/unisharp/models/__pycache__/feature_gaussian_decoder.cpython-313.pyc differ diff --git a/unisharp/models/__pycache__/gaussian_composer.cpython-310.pyc b/unisharp/models/__pycache__/gaussian_composer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf1922a8e01d71e49e518e692cb4f2ed011de838 Binary files /dev/null and b/unisharp/models/__pycache__/gaussian_composer.cpython-310.pyc differ diff --git a/unisharp/models/__pycache__/gaussian_composer.cpython-313.pyc b/unisharp/models/__pycache__/gaussian_composer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..250baef36d78eb3b944aa6e267f7e705ea20f7f7 Binary files /dev/null and b/unisharp/models/__pycache__/gaussian_composer.cpython-313.pyc differ diff --git a/unisharp/models/__pycache__/gaussian_initializer.cpython-310.pyc b/unisharp/models/__pycache__/gaussian_initializer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daf24bc0082db06790a5c2a9e4144b4172117855 Binary files /dev/null and b/unisharp/models/__pycache__/gaussian_initializer.cpython-310.pyc differ diff --git a/unisharp/models/__pycache__/gaussian_initializer.cpython-313.pyc b/unisharp/models/__pycache__/gaussian_initializer.cpython-313.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..3175135f123d4bade2343dac28dc401e6dd8b628 Binary files /dev/null and b/unisharp/models/__pycache__/gaussian_initializer.cpython-313.pyc differ diff --git a/unisharp/models/__pycache__/heads.cpython-310.pyc b/unisharp/models/__pycache__/heads.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39277e62874cfd6e20c65779e61dbd86aea9788c Binary files /dev/null and b/unisharp/models/__pycache__/heads.cpython-310.pyc differ diff --git a/unisharp/models/__pycache__/heads.cpython-313.pyc b/unisharp/models/__pycache__/heads.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6664be39432f2298d401378071b3d16795775ad Binary files /dev/null and b/unisharp/models/__pycache__/heads.cpython-313.pyc differ diff --git a/unisharp/models/__pycache__/unik3d_feature_extractor.cpython-310.pyc b/unisharp/models/__pycache__/unik3d_feature_extractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..022e2230ffd5c1ac16db3731fa4b9089a1ee733d Binary files /dev/null and b/unisharp/models/__pycache__/unik3d_feature_extractor.cpython-310.pyc differ diff --git a/unisharp/models/__pycache__/unik3d_feature_extractor.cpython-313.pyc b/unisharp/models/__pycache__/unik3d_feature_extractor.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16ff13ace5d3ee289d7c7d3fa0f726eb005f8bf2 Binary files /dev/null and b/unisharp/models/__pycache__/unik3d_feature_extractor.cpython-313.pyc differ diff --git a/unisharp/models/__pycache__/unisharp_feature.cpython-310.pyc b/unisharp/models/__pycache__/unisharp_feature.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..045d5d9f49e7738e14fdec9efe268de475f2518b Binary files /dev/null and b/unisharp/models/__pycache__/unisharp_feature.cpython-310.pyc differ diff --git a/unisharp/models/__pycache__/unisharp_feature.cpython-313.pyc 
b/unisharp/models/__pycache__/unisharp_feature.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e5f00bf718a7f527b01bdad2accfba8ec3088f5 Binary files /dev/null and b/unisharp/models/__pycache__/unisharp_feature.cpython-313.pyc differ diff --git a/unisharp/models/__pycache__/unisharp_params.cpython-310.pyc b/unisharp/models/__pycache__/unisharp_params.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebfe7c9b82d9f729ca3c5ff35995d29853fdc789 Binary files /dev/null and b/unisharp/models/__pycache__/unisharp_params.cpython-310.pyc differ diff --git a/unisharp/models/__pycache__/unisharp_params.cpython-313.pyc b/unisharp/models/__pycache__/unisharp_params.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c877b33d04bf8868542596cd568b53f592464300 Binary files /dev/null and b/unisharp/models/__pycache__/unisharp_params.cpython-313.pyc differ diff --git a/unisharp/models/blocks.py b/unisharp/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..96e729c7d9eb0a68f45dab8de4856c20ebb921fc --- /dev/null +++ b/unisharp/models/blocks.py @@ -0,0 +1,208 @@ + +from __future__ import annotations + +from typing import Literal + +import torch +from torch import nn +from torch.nn import functional as F + +NormLayerName = Literal["noop", "batch_norm", "group_norm", "instance_norm"] +UpsamplingMode = Literal["transposed_conv", "nearest", "bilinear"] + + +class CircularAwareConvTranspose2d(nn.ConvTranspose2d): + + circular_horizontal: bool + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self.circular_horizontal = False + + def forward(self, input: torch.Tensor, output_size: list[int] | None = None) -> torch.Tensor: + if not bool(self.circular_horizontal): + return super().forward(input, output_size=output_size) + pad_w = self._padding_w() + if pad_w <= 0: + return super().forward(input, 
output_size=output_size) + x = F.pad(input, (pad_w, pad_w, 0, 0), mode="circular") + out = super().forward(x, output_size=None) + crop_w = pad_w * self._stride_w() + if crop_w > 0: + out = out[..., crop_w:-crop_w] + return out + + def _padding_w(self) -> int: + if isinstance(self.padding, tuple): + return int(self.padding[1] if len(self.padding) > 1 else self.padding[0]) + return int(self.padding) + + def _stride_w(self) -> int: + if isinstance(self.stride, tuple): + return int(self.stride[1] if len(self.stride) > 1 else self.stride[0]) + return int(self.stride) + + +def norm_layer_2d(num_features: int, norm_type: NormLayerName, num_groups: int = 8) -> nn.Module: + if norm_type == "noop": + return nn.Identity() + elif norm_type == "batch_norm": + return nn.BatchNorm2d(num_features=num_features) + elif norm_type == "group_norm": + return nn.GroupNorm(num_channels=num_features, num_groups=num_groups) + elif norm_type == "instance_norm": + return nn.InstanceNorm2d(num_features=num_features) + else: + raise ValueError(f"Invalid normalization layer type: {norm_type}") + + +def upsampling_layer(upsampling_mode: UpsamplingMode, scale_factor: int, dim_in: int) -> nn.Module: + if upsampling_mode == "transposed_conv": + return CircularAwareConvTranspose2d( + in_channels=dim_in, + out_channels=dim_in, + kernel_size=scale_factor * 2, + stride=scale_factor, + padding=scale_factor // 2, + bias=False, + ) + elif upsampling_mode in ("nearest", "bilinear"): + return nn.Upsample(scale_factor=scale_factor, mode=upsampling_mode) + else: + raise ValueError(f"Invalid upsampling mode {upsampling_mode}.") + + +class ResidualBlock(nn.Module): + + def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None: + super().__init__() + self.residual = residual + self.shortcut = shortcut + + def forward(self, x: torch.Tensor) -> torch.Tensor: + delta_x = self.residual(x) + + if self.shortcut is not None: + x = self.shortcut(x) + + return x + delta_x + + +def 
class FeatureFusionBlock2d(nn.Module):
    """Fuse a feature map with an optional skip input and project it.

    ``forward`` adds ``resnet1(x1)`` to ``x0`` when ``x1`` is given, then
    applies ``resnet2``, the optional upsampling layer ``deconv`` and the
    final 1x1 ``out_conv``.
    """

    deconv: nn.Module

    def __init__(
        self,
        dim_in: int,
        dim_out: int | None = None,
        upsampling_mode: UpsamplingMode | None = None,
        batch_norm: bool = False,
    ):
        """Build the block.

        Args:
            dim_in: Channel count of both inputs and of the residual branches.
            dim_out: Channel count after ``out_conv``; defaults to ``dim_in``.
            upsampling_mode: When given, a x2 upsampling layer of that kind is
                inserted before ``out_conv``; otherwise an empty
                ``nn.Sequential`` (a no-op) is used.
            batch_norm: Whether the residual branches contain ``BatchNorm2d``.
        """
        super().__init__()
        out_channels = dim_in if dim_out is None else dim_out

        self.resnet1 = self._residual_block(dim_in, batch_norm)
        self.resnet2 = self._residual_block(dim_in, batch_norm)

        if upsampling_mode is None:
            self.deconv = nn.Sequential()
        else:
            self.deconv = upsampling_layer(upsampling_mode, scale_factor=2, dim_in=dim_in)

        self.out_conv = nn.Conv2d(
            dim_in,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
        """Return the fused, optionally upsampled and projected features."""
        fused = x0
        if x1 is not None:
            skip = self.resnet1(x1)
            fused = self.skip_add.add(fused, skip)

        return self.out_conv(self.deconv(self.resnet2(fused)))

    @staticmethod
    def _residual_block(num_features: int, batch_norm: bool):
        """Build a residual branch of two ReLU -> 3x3 conv (-> BatchNorm) stages."""
        stages: list[nn.Module] = []
        for _ in range(2):
            stages.append(nn.ReLU(False))
            stages.append(
                nn.Conv2d(
                    num_features,
                    num_features,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    # The conv bias is dropped when a BatchNorm follows it.
                    bias=not batch_norm,
                )
            )
            if batch_norm:
                stages.append(nn.BatchNorm2d(num_features))

        return ResidualBlock(nn.Sequential(*stages))
convs.append( + nn.Conv2d( + self.dims_encoder[i], + self.dims_decoder[i], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + ) + self.convs = nn.ModuleList(convs) + + fusions = [] + for i in range(num_encoders): + fusions.append( + FeatureFusionBlock2d( + dim_in=self.dims_decoder[i], + dim_out=self.dims_decoder[i - 1] if i != 0 else self.dim_out, + upsampling_mode=upsampling_mode if i != 0 else None, + batch_norm=False, + ) + ) + self.fusions = nn.ModuleList(fusions) + self.grad_checkpointing = grad_checkpointing + + @torch.jit.ignore + def set_grad_checkpointing(self, is_enabled=True): + self.grad_checkpointing = is_enabled + + def _checkpoint(self, fn, *args): + if self.grad_checkpointing: + return checkpoint(fn, *args, use_reentrant=False) + return fn(*args) + + def forward(self, encodings: list[torch.Tensor]) -> torch.Tensor: + num_levels = len(encodings) + num_encoders = len(self.dims_encoder) + if num_levels != num_encoders: + raise ValueError( + f"Encoder output levels={num_levels} at runtime " + f"mismatch with expected levels={num_encoders}." 
class CircularAwareConv2d(nn.Conv2d):
    """``nn.Conv2d`` that can wrap around horizontally on demand.

    With ``circular_horizontal`` left at ``False`` the layer is a plain
    ``nn.Conv2d``. When it is ``True`` the layer pads the width circularly and
    the height with zeros itself, then runs an unpadded convolution.
    """

    # Runtime switch; ``False`` after construction, toggled by callers.
    circular_horizontal: bool

    @classmethod
    def from_conv2d(cls, conv: nn.Conv2d) -> "CircularAwareConv2d":
        """Create an equivalent layer that copies ``conv``'s configuration and weights."""
        out = cls(
            in_channels=conv.in_channels,
            out_channels=conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=conv.bias is not None,
            padding_mode=conv.padding_mode,
            device=conv.weight.device,
            dtype=conv.weight.dtype,
        )
        with torch.no_grad():
            out.weight.copy_(conv.weight)
            if conv.bias is not None and out.bias is not None:
                out.bias.copy_(conv.bias)
        return out

    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        super().__init__(*args, **kwargs)
        self.circular_horizontal = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Convolve ``input``, wrapping the width when circular mode is on."""
        if not bool(self.circular_horizontal):
            return super().forward(input)

        pad_h, pad_w = self._padding_hw()
        x = input
        if pad_w > 0:
            x = F.pad(x, (pad_w, pad_w, 0, 0), mode="circular")
        if pad_h > 0:
            # NOTE: the vertical border is always zero-padded in circular
            # mode, whatever ``padding_mode`` the layer was built with.
            x = F.pad(x, (0, 0, pad_h, pad_h), mode="constant", value=0.0)
        # Padding was applied manually above, so the convolution uses none.
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            0,
            self.dilation,
            self.groups,
        )

    def _padding_hw(self) -> tuple[int, int]:
        """Return the symmetric ``(pad_h, pad_w)`` of this layer.

        String padding is resolved here: ``"valid"`` means no padding and
        ``"same"`` means ``dilation * (kernel_size - 1) // 2`` per side.

        Raises:
            ValueError: For ``"same"`` padding that would be asymmetric (odd
                ``dilation * (kernel_size - 1)``) or an unknown padding string.
        """
        padding = self.padding
        if isinstance(padding, str):
            if padding == "valid":
                return 0, 0
            if padding == "same":
                totals = [
                    int(d) * (int(k) - 1)
                    for d, k in zip(self.dilation, self.kernel_size)
                ]
                if any(total % 2 for total in totals):
                    raise ValueError(
                        "Circular mode requires symmetric 'same' padding, "
                        f"got kernel_size={self.kernel_size} dilation={self.dilation}"
                    )
                return totals[0] // 2, totals[1] // 2
            raise ValueError(f"Unsupported padding for circular mode: {padding!r}")
        if isinstance(padding, tuple):
            if len(padding) == 2:
                return int(padding[0]), int(padding[1])
            return int(padding[0]), int(padding[0])
        return int(padding), int(padding)
class Feature3DProjector(nn.Module):
    """Project every level of a feature pyramid with its own bias-free 1x1 conv."""

    def __init__(
        self,
        dims_in: list[int],
        dims_out: list[int],
    ):
        """Create one projector per ``(dims_in[i], dims_out[i])`` pair.

        Raises:
            ValueError: If the two lists differ in length.
        """
        super().__init__()

        n_in, n_out = len(dims_in), len(dims_out)
        if n_in != n_out:
            raise ValueError(
                "dims_in and dims_out must have same length, "
                f"got {n_in} vs {n_out}"
            )

        layers = []
        for channels_in, channels_out in zip(dims_in, dims_out):
            layers.append(nn.Conv2d(channels_in, channels_out, kernel_size=1, bias=False))
        self.projectors = nn.ModuleList(layers)

        self.dims_out = dims_out
        self.num_levels = n_in

    def forward(self, pyramid_features: list[torch.Tensor]) -> list[torch.Tensor]:
        """Return the projected pyramid, level by level.

        Raises:
            ValueError: If the number of levels differs from ``num_levels``.
        """
        n_given = len(pyramid_features)
        if n_given != self.num_levels:
            raise ValueError(
                f"Expected {self.num_levels} pyramid features, got {n_given}"
            )

        projected = []
        for layer, level in zip(self.projectors, pyramid_features):
            projected.append(layer(level))
        return projected
class FeatureGaussianDecoder(nn.Module):
    """Decode a 3-D feature pyramid plus 2-D features into two feature maps.

    The pyramid is projected and decoded by a ``MultiresConvDecoder``,
    optionally upsampled, fused with the encoded 2-D features, and passed
    through separate texture and geometry heads.
    """

    def __init__(
        self,
        params: FeatureGaussianDecoderParams,
    ):
        """Build all submodules from ``params``.

        Raises:
            ValueError: If ``params.stride_out`` is not 1 or 2, or if
                ``params.dim_2d_out`` differs from ``params.dim_decoder_out``.
        """
        super().__init__()

        self.params = params
        self.stride_out = params.stride_out
        self.norm_type = params.norm_type
        self.norm_num_groups = int(params.norm_num_groups)
        if int(self.stride_out) not in (1, 2):
            raise ValueError(f"FeatureGaussianDecoder only supports stride_out 1 or 2, got {self.stride_out}")
        # Validate the channel contract before allocating any submodule so a
        # bad configuration fails fast instead of after building the network.
        if int(params.dim_2d_out) != int(params.dim_decoder_out):
            raise ValueError(
                "FeatureFusionBlock2d requires 2D skip channels to match decoder channels, "
                f"got dim_2d_out={params.dim_2d_out}, dim_decoder_out={params.dim_decoder_out}"
            )

        self.feature_3d_projector = Feature3DProjector(
            dims_in=list(params.dims_3d_in),
            dims_out=list(params.dims_3d_out),
        )

        self.decoder = MultiresConvDecoder(
            dims_encoder=list(params.dims_3d_out),
            dims_decoder=params.dim_decoder_out,
        )
        if int(self.stride_out) == 1:
            # stride_out == 1 adds one x2 upsampling stage after the decoder.
            self.upsample = _create_project_upsample_block(
                params.dim_decoder_out,
                params.dim_decoder_out,
                upsample_layers=1,
            )
        else:
            self.upsample = nn.Identity()

        self.feature_2d_encoder = Feature2DEncoder(
            dim_in=params.dim_2d_in,
            dim_out=params.dim_2d_out,
        )

        self.fusion = FeatureFusionBlock2d(
            params.dim_decoder_out,
            params.dim_2d_out,
        )

        self.texture_head = self._create_head(
            params.dim_decoder_out,
            params.dim_texture_out,
        )

        self.geometry_head = self._create_head(
            params.dim_decoder_out,
            params.dim_geometry_out,
        )

        self.dim_out = params.dim_texture_out

        self.fused_upsampler = None

        # Swap every plain Conv2d for its circular-aware counterpart so
        # ``forward`` can switch wrap-around on and off per call.
        _convert_conv2d_modules_to_circular_aware(self)

    def _create_head(self, dim_in: int, dim_out: int) -> nn.Module:
        """Build an output head: two residual blocks, ReLU, 1x1 conv, ReLU."""
        return nn.Sequential(
            residual_block_2d(
                dim_in=dim_in,
                dim_out=dim_in,
                dim_hidden=dim_in // 2,
                norm_type=self.norm_type,
                norm_num_groups=self.norm_num_groups,
            ),
            residual_block_2d(
                dim_in=dim_in,
                dim_hidden=dim_in // 2,
                dim_out=dim_in,
                norm_type=self.norm_type,
                norm_num_groups=self.norm_num_groups,
            ),
            nn.ReLU(),
            nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1),
            nn.ReLU(),
        )

    def forward(
        self,
        features_2d: torch.Tensor,
        features_3d_pyramid: list[torch.Tensor],
        *,
        circular_horizontal: bool = False,
        target_hw: tuple[int, int] | None = None,
    ) -> ImageFeatures:
        """Compute texture and geometry feature maps.

        Args:
            features_2d: 2-D features; resized to the decoder grid before fusion.
            features_3d_pyramid: Pyramid levels in any order; they are sorted
                by spatial area, largest first, before projection.
            circular_horizontal: Enables horizontal wrap-around in every
                circular-aware layer of this module for this and later calls.
            target_hw: Optional ``(height, width)`` the decoder output is
                bilinearly resized to before fusion.

        Raises:
            ValueError: If ``target_hw`` contains a non-positive size.
            RuntimeError: If the fused grid does not match ``target_hw``.
        """
        requested_hw: tuple[int, int] | None = None
        if target_hw is not None:
            requested_hw = (int(target_hw[0]), int(target_hw[1]))
            # Checked up front so an invalid request fails before any compute.
            if requested_hw[0] <= 0 or requested_hw[1] <= 0:
                raise ValueError(f"target_hw must be positive, got {target_hw}")

        _set_circular_horizontal(self, bool(circular_horizontal))
        features_3d_sorted = sorted(
            features_3d_pyramid,
            key=lambda t: int(t.shape[-2]) * int(t.shape[-1]),
            reverse=True
        )

        pyramid_projected = self.feature_3d_projector(features_3d_sorted)

        decoder_out = self.decoder(pyramid_projected).contiguous()
        decoder_out = self.upsample(decoder_out).contiguous()
        if requested_hw is not None and tuple(decoder_out.shape[-2:]) != requested_hw:
            decoder_out = F.interpolate(
                decoder_out,
                size=requested_hw,
                mode="bilinear",
                align_corners=False,
            ).contiguous()

        target_h, target_w = decoder_out.shape[-2:]
        features_2d_proj = self.feature_2d_encoder(
            features_2d,
            target_h=target_h,
            target_w=target_w,
        )

        fused = self.fusion(decoder_out, features_2d_proj)

        if requested_hw is not None and tuple(fused.shape[-2:]) != requested_hw:
            raise RuntimeError(
                "Feature decoder grid must match base Gaussian grid before heads, "
                f"got fused={tuple(fused.shape[-2:])} target={requested_hw}. "
                "Only high-channel decoder features may be adapted before fusion."
            )

        texture_features = self.texture_head(fused)
        geometry_features = self.geometry_head(fused)

        return ImageFeatures(
            texture_features=texture_features,
            geometry_features=geometry_features,
        )

    @property
    def stride(self) -> int:
        """Output stride configured via ``params.stride_out``."""
        return self.stride_out
class PanoGaussianComposer(nn.Module):
    """Combine predicted deltas with base Gaussian values into ``Gaussians3D``.

    ``forward`` reads the 14 delta channels as: 0:3 mean offsets (two tangent
    offsets and one inverse-distance offset), 3:6 scale, 6:10 quaternion,
    10:13 color and 13:14 opacity.
    """

    def __init__(
        self,
        delta_factor: DeltaFactor,
        min_scale: float,
        max_scale: float,
        color_activation_type: math_utils.ActivationType,
        opacity_activation_type: math_utils.ActivationType,
        color_space: ColorSpace,
        base_scale_on_predicted_mean: bool = True,
        scale_factor: int = 1,
        delta_rho_limit: float = 2.0,
    ) -> None:
        """Store the configuration and precompute the scale-activation constants."""
        super().__init__()
        self.delta_factor = delta_factor
        self.min_scale = float(min_scale)
        self.max_scale = float(max_scale)
        self.color_activation_type = color_activation_type
        self.opacity_activation_type = opacity_activation_type
        self.color_space = color_space
        self.base_scale_on_predicted_mean = bool(base_scale_on_predicted_mean)
        self.scale_factor = int(max(1, scale_factor))
        self.delta_rho_limit = float(delta_rho_limit)

        (
            self.ray_scale_min_factor,
            self.ray_scale_max_factor,
        ) = self._ray_cell_coverage_scale_bounds()
        (
            self.ray_radial_min_factor,
            self.ray_radial_max_factor,
        ) = self._ray_radial_scale_bounds()
        (
            self.ray_anisotropy_min_factor,
            self.ray_anisotropy_max_factor,
        ) = self._ray_tangent_anisotropy_bounds()
        (
            self._ray_scale_const_a,
            self._ray_scale_const_b,
        ) = self._get_scale_activation_constant(
            self.ray_scale_max_factor,
            self.ray_scale_min_factor,
        )
        (
            self._ray_radial_const_a,
            self._ray_radial_const_b,
        ) = self._get_scale_activation_constant(
            self.ray_radial_max_factor,
            self.ray_radial_min_factor,
        )
        (
            self._ray_anisotropy_const_a,
            self._ray_anisotropy_const_b,
        ) = self._get_scale_activation_constant(
            self.ray_anisotropy_max_factor,
            self.ray_anisotropy_min_factor,
        )

    @staticmethod
    def _get_scale_activation_constant(max_scale: float, min_scale: float) -> tuple[float, float]:
        """Return the slope ``a`` and offset ``b`` of the bounded sigmoid.

        ``b`` is chosen so that a zero delta maps to a factor of exactly 1.
        Both bounds must differ from 1, otherwise ``a`` divides by zero.
        """
        constant_a = (max_scale - min_scale) / (1 - min_scale) / (max_scale - 1)
        constant_b = math_utils.inverse_sigmoid(
            torch.tensor((1.0 - min_scale) / (max_scale - min_scale))
        ).item()
        return constant_a, constant_b

    @staticmethod
    def _ray_cell_coverage_scale_bounds() -> tuple[float, float]:
        """Return the ``(min, max)`` multiplier for the tangent scale."""
        min_factor = 0.1
        max_factor = 5.0
        return float(min_factor), float(max_factor)

    @staticmethod
    def _ray_radial_scale_bounds() -> tuple[float, float]:
        """Return the ``(min, max)`` multiplier for the radial scale."""
        return 0.1, 5.0

    @staticmethod
    def _ray_tangent_anisotropy_bounds() -> tuple[float, float]:
        """Return the ``(min, max)`` tangent anisotropy, symmetric around 1."""
        max_factor = 4.0
        min_factor = 1.0 / max_factor
        return float(min_factor), float(max_factor)

    @staticmethod
    def _smooth_scale_delta(
        learned_delta_scale: torch.Tensor,
        *,
        circular_horizontal: bool,
        kernel_size: int = 9,
    ) -> torch.Tensor:
        """Box-filter a ``[B, C, L, H, W]`` tensor over ``H`` and ``W``.

        The width is padded circularly when ``circular_horizontal`` is set and
        by replication otherwise; the height is always replicated. An even
        ``kernel_size`` is bumped to the next odd value so the output keeps the
        input's spatial size.
        """
        k = int(kernel_size)
        if k <= 1:
            return learned_delta_scale
        if k % 2 == 0:
            # An even window with pad = k // 2 would yield H+1 x W+1 outputs
            # and break the reshape below.
            k += 1
        pad = k // 2
        b, c, l, h, w = learned_delta_scale.shape
        # Fold the layer axis into the batch so 2-D pooling can be used.
        x = learned_delta_scale.permute(0, 2, 1, 3, 4).reshape(b * l, c, h, w)
        if bool(circular_horizontal):
            x = F.pad(x, (pad, pad, 0, 0), mode="circular")
            x = F.pad(x, (0, 0, pad, pad), mode="replicate")
        else:
            x = F.pad(x, (pad, pad, pad, pad), mode="replicate")
        x = F.avg_pool2d(x, kernel_size=k, stride=1)
        return x.reshape(b, l, c, h, w).permute(0, 2, 1, 3, 4).contiguous()

    @staticmethod
    def _apply_global_scale(values: torch.Tensor, global_scale: torch.Tensor) -> torch.Tensor:
        """Multiply ``values`` by ``global_scale`` along the leading axis.

        ``global_scale`` is reshaped to ``[-1, 1, ..., 1]`` with as many
        trailing singleton axes as ``values`` needs, so it scales along the
        first axis for flattened ``[B, N, C]`` and grid ``[B, C, L, H, W]``
        layouts alike.
        """
        # assumes global_scale holds one value per batch element — TODO confirm
        view_shape = (-1,) + (1,) * (values.ndim - 1)
        return values * global_scale.reshape(view_shape)

    def apply_delta_rho(self, learned_delta_rho: torch.Tensor) -> torch.Tensor:
        """Scale the inverse-distance delta and soft-limit it with ``tanh``.

        A non-positive ``delta_rho_limit`` disables the limit.
        """
        raw_d_rho = self.delta_factor.z * learned_delta_rho
        limit = float(self.delta_rho_limit)
        if limit <= 0.0:
            return raw_d_rho
        return limit * torch.tanh(raw_d_rho / limit)

    def apply_scale_factor(self, learned_delta_scale: torch.Tensor) -> torch.Tensor:
        """Map three scale-delta channels to bounded scale multipliers.

        Channel 0 is the tangent scale, channel 1 the tangent anisotropy and
        channel 2 the radial scale. Returns ``[tangent_u, tangent_v, radial]``
        concatenated along dim 1.
        """
        tangent = (self.ray_scale_max_factor - self.ray_scale_min_factor) * torch.sigmoid(
            self._ray_scale_const_a * self.delta_factor.scale * learned_delta_scale[:, 0:1]
            + self._ray_scale_const_b
        ) + self.ray_scale_min_factor
        anisotropy = (
            (self.ray_anisotropy_max_factor - self.ray_anisotropy_min_factor)
            * torch.sigmoid(
                self._ray_anisotropy_const_a
                * self.delta_factor.scale
                * learned_delta_scale[:, 1:2]
                + self._ray_anisotropy_const_b
            )
            + self.ray_anisotropy_min_factor
        )
        radial = (self.ray_radial_max_factor - self.ray_radial_min_factor) * torch.sigmoid(
            self._ray_radial_const_a * self.delta_factor.scale * learned_delta_scale[:, 2:3]
            + self._ray_radial_const_b
        ) + self.ray_radial_min_factor

        # Split the anisotropy evenly: u is stretched and v shrunk by sqrt(a),
        # then both are clamped back into the tangent bounds.
        sqrt_anisotropy = anisotropy.sqrt()
        tangent_u = (tangent * sqrt_anisotropy).clamp(
            min=self.ray_scale_min_factor,
            max=self.ray_scale_max_factor,
        )
        tangent_v = (tangent / sqrt_anisotropy.clamp(min=1e-6)).clamp(
            min=self.ray_scale_min_factor,
            max=self.ray_scale_max_factor,
        )
        return torch.cat([tangent_u, tangent_v, radial], dim=1)

    def forward(
        self,
        delta: torch.Tensor,
        base_values: PanoGaussianBaseValues,
        global_scale: torch.Tensor | None = None,
        flatten_output: bool = True,
    ) -> Gaussians3D:
        """Compose the final Gaussians.

        Args:
            delta: ``[B, 14, L, H, W]`` predicted deltas.
            base_values: Base Gaussian values on the same ``L x H x W`` grid.
            global_scale: Optional multiplier applied to means and scales.
            flatten_output: When ``True`` the grid axes are flattened to one
                Gaussian axis (``[B, N, C]``, opacities ``[B, N]``); otherwise
                the ``[B, C, L, H, W]`` layout is kept.

        Raises:
            ValueError: If ``delta`` has the wrong rank or channel count, or
                its layer count or grid differs from ``base_values``.
        """
        if delta.ndim != 5 or int(delta.shape[1]) != 14:
            raise ValueError(f"Expected delta shape [B,14,L,H,W], got {tuple(delta.shape)}")
        if int(delta.shape[2]) != int(base_values.rays.shape[2]):
            raise ValueError(
                "Delta layer count must match base Gaussian layers, "
                f"got delta L={int(delta.shape[2])} base L={int(base_values.rays.shape[2])}"
            )
        base_h, base_w = int(base_values.rays.shape[-2]), int(base_values.rays.shape[-1])
        delta_h, delta_w = int(delta.shape[-2]), int(delta.shape[-1])
        if (delta_h, delta_w) != (base_h, base_w):
            raise ValueError(
                "Delta grid must match base Gaussian grid before composition, "
                f"got delta={(delta_h, delta_w)} base={(base_h, base_w)}"
            )

        mean = self._forward_mean(base_values, delta[:, 0:3])

        base_scales = base_values.scales
        if self.base_scale_on_predicted_mean:
            # Rescale the base scales by predicted radius / base radius.
            radius_pred = mean.norm(dim=1, keepdim=True).clamp(min=1e-4)
            radius_base_inv = base_values.inv_distance.clamp(min=1e-6)
            scale_ratio = (radius_pred * radius_base_inv).clamp(min=1e-6)
            base_scales = base_scales * scale_ratio
        scale_delta = self._smooth_scale_delta(
            delta[:, 3:6],
            circular_horizontal=_infer_circular_horizontal(base_values.rays[:, :, 0]),
        )
        scales = self._scale_activation(base_scales, scale_delta)
        quat_raw = base_values.quaternions + self.delta_factor.quaternion * delta[:, 6:10]
        quat_norm = quat_raw.norm(dim=1, keepdim=True)
        base_quats = base_values.quaternions / base_values.quaternions.norm(dim=1, keepdim=True).clamp(min=1e-8)
        # Fall back to the normalized base rotation where the sum degenerates.
        quats = torch.where(quat_norm > 1e-8, quat_raw / quat_norm.clamp(min=1e-8), base_quats)
        colors = self._color_activation(base_values.colors, delta[:, 10:13])
        opacities = self._opacity_activation(base_values.opacities, delta[:, 13:14])

        if flatten_output:
            mean = mean.permute(0, 2, 3, 4, 1).flatten(1, 3)
            scales = scales.permute(0, 2, 3, 4, 1).flatten(1, 3)
            quats = quats.permute(0, 2, 3, 4, 1).flatten(1, 3)
            colors = colors.permute(0, 2, 3, 4, 1).flatten(1, 3)
            opacities = opacities.squeeze(1).flatten(1)

        if global_scale is not None:
            # Works for both layouts; a fixed ``[:, None, None]`` view would
            # line up with the wrong axes of the unflattened 5-D tensors.
            mean = self._apply_global_scale(mean, global_scale)
            scales = self._apply_global_scale(scales, global_scale)

        return Gaussians3D(
            mean_vectors=mean,
            singular_values=scales,
            quaternions=quats,
            colors=colors,
            opacities=opacities,
        )

    def _forward_mean(self, base: PanoGaussianBaseValues, learned_delta: torch.Tensor) -> torch.Tensor:
        """Return the Gaussian centres: shifted unit rays times predicted distance.

        Raises:
            ValueError: If ``base.angular_cell`` is not ``[B, 2, 1, H, W]``-like.
        """
        rays = _safe_normalize(base.rays)
        b, _, l, h, w = rays.shape
        # The tangent basis is built from the first layer and shared by all layers.
        rays_2d = rays[:, :, 0]
        e1_2d, e2_2d = _build_tangent_basis(
            rays_2d,
            circular_horizontal=_infer_circular_horizontal(rays_2d),
        )
        e1 = e1_2d[:, :, None].expand(b, -1, l, h, w)
        e2 = e2_2d[:, :, None].expand(b, -1, l, h, w)
        angular_cell = base.angular_cell.to(device=learned_delta.device, dtype=learned_delta.dtype)
        if angular_cell.ndim != 5 or int(angular_cell.shape[1]) != 2:
            raise ValueError(f"Expected angular_cell shape [B,2,1,H,W], got {tuple(angular_cell.shape)}")
        cell_u = angular_cell[:, 0:1]
        cell_v = angular_cell[:, 1:2]

        # Tangent offsets are expressed in units of the local angular cell.
        du = self.delta_factor.xy * learned_delta[:, 0:1] * cell_u
        dv = self.delta_factor.xy * learned_delta[:, 1:2] * cell_v
        d_rho = self.apply_delta_rho(learned_delta[:, 2:3])

        # The inverse distance is updated in softplus space so it stays positive.
        rho0 = base.inv_distance
        rho = F.softplus(math_utils.inverse_softplus(rho0.clamp(min=1e-6)) + d_rho)
        r = 1.0 / (rho + 1e-4)

        ray_new = rays + du * e1 + dv * e2
        ray_new = _safe_normalize(ray_new)

        return r * ray_new

    def _scale_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
        """Multiply the base scales by the bounded multipliers from the delta."""
        scale_factor = self.apply_scale_factor(learned_delta)
        return base * scale_factor

    def _color_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
        """Add the color delta in pre-activation space and re-activate."""
        # Clamp the base into the activation's invertible range first.
        if self.color_activation_type == "sigmoid":
            base = torch.clamp(base, min=0.01, max=0.99)
        elif self.color_activation_type in ("exp", "softplus"):
            base = torch.clamp(base, min=0.01)

        activation = math_utils.create_activation_pair(self.color_activation_type)
        colors: torch.Tensor = activation.forward(
            activation.inverse(base) + self.delta_factor.color * learned_delta
        )
        if self.color_space == "linearRGB":
            colors = sRGB2linearRGB(colors)
        return colors

    def _opacity_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
        """Add the opacity delta in pre-activation space and re-activate."""
        activation = math_utils.create_activation_pair(self.opacity_activation_type)
        return activation.forward(
            activation.inverse(base) + self.delta_factor.opacity * learned_delta
        )
def _rescale_distance(
    distance: torch.Tensor,
    dist_min: float = 1.0,
    dist_max: float = 1e2,
    scale_quantile: float = 0.02,
    scale_floor: float = 0.1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rescale distances so a robust per-sample minimum lands near ``dist_min``.

    The robust minimum is the ``scale_quantile`` order statistic over the last
    three dimensions, computed on a strided subsample of large maps, with
    non-finite and non-positive entries treated as ``dist_max``.

    Returns:
        ``(rescaled, factor)``: the rescaled distances clamped to at most
        ``dist_max``, and the multiplicative factor that was applied.
    """
    height, width = int(distance.shape[-2]), int(distance.shape[-1])
    # Strides are 1 for maps up to 256 pixels per side, which leaves them untouched.
    stride_h = max(1, height // 256)
    stride_w = max(1, width // 256)
    subsampled = distance[..., ::stride_h, ::stride_w]

    values = subsampled.flatten(subsampled.ndim - 3)
    invalid = ~(torch.isfinite(values) & (values > 0.0))
    values = values.masked_fill(invalid, float(dist_max))

    count = int(values.shape[-1])
    rank = int(round(float(scale_quantile) * float(count)))
    rank = max(1, min(count, rank))
    robust_min, _ = values.kthvalue(rank, dim=-1)
    robust_min = robust_min.clamp(min=float(scale_floor), max=float(dist_max))

    factor = dist_min / (robust_min + 1e-6)
    rescaled = (distance * factor[..., None, None, None]).clamp(max=dist_max)
    return rescaled, factor
def _format_cell_size_override_uv(
    value: torch.Tensor,
    *,
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Normalize a cell-size override to shape ``[batch_size, 2, 1, 1, 1]``.

    Accepted layouts of ``value``:

    * 0-D, or 1-D with one element: one size shared by both axes and all samples.
    * 1-D with two elements and ``batch_size > 1``: one ``(u, v)`` pair shared
      by all samples.
    * any other 1-D tensor: reshaped to ``[batch_size, -1]``.
    * N-D: reshaped to ``[value.shape[0], -1]``.

    After reshaping, a single column is duplicated for both axes and any
    columns beyond the first two are dropped. The result is clamped to at
    least ``1e-6``.

    Raises:
        ValueError: If the values cannot be arranged into ``batch_size`` rows.
    """
    n = int(batch_size)
    cell = value.to(device=device, dtype=dtype)

    if cell.ndim == 0:
        cell = cell.view(1, 1).expand(n, 2)
    elif cell.ndim > 1:
        cell = cell.reshape(int(cell.shape[0]), -1)
    else:
        count = int(cell.numel())
        if count == 1:
            cell = cell.view(1, 1).expand(n, 2)
        elif count == 2 and n > 1:
            # Two values with a larger batch are read as one shared (u, v) pair.
            cell = cell.view(1, 2).expand(n, 2)
        else:
            if n <= 0 or count == 0 or count % n != 0:
                raise ValueError(
                    f"grid_cell_size_override must have batch size {n}, "
                    f"got {count} values"
                )
            cell = cell.view(n, -1)

    if int(cell.shape[1]) == 1:
        cell = cell.expand(-1, 2)
    elif int(cell.shape[1]) > 2:
        # Same truncation for every layout: only the (u, v) columns are used.
        cell = cell[:, :2]

    if int(cell.shape[0]) != n:
        raise ValueError(f"grid_cell_size_override must have batch size {n}, got {int(cell.shape[0])}")
    return cell.clamp(min=1e-6).reshape(n, 2, 1, 1, 1)
F.pad(x, (pad, pad, 0, 0), mode="circular") + else: + x = F.pad(x, (pad, pad, 0, 0), mode="replicate") + x = F.pad(x, (0, 0, pad, pad), mode="replicate") + x = F.avg_pool2d(x, kernel_size=k, stride=1) + return x.reshape(b, c, one, h, w).clamp(min=1e-6, max=0.25) + + +def _fallback_tangent_basis(rays: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + b, _, h, w = rays.shape + r = rays.reshape(b, 3, -1) + up1 = torch.tensor([0.0, 1.0, 0.0], device=r.device, dtype=r.dtype)[:, None] + up2 = torch.tensor([1.0, 0.0, 0.0], device=r.device, dtype=r.dtype)[:, None] + dot = (r * up1).sum(dim=1, keepdim=True).abs() + up = torch.where(dot > 0.9, up2.expand_as(up1), up1.expand_as(up1)) + e_u = _safe_normalize(torch.cross(up.expand_as(r), r, dim=1)) + e_v = _safe_normalize(torch.cross(r, e_u, dim=1)) + return e_u.reshape(b, 3, h, w), e_v.reshape(b, 3, h, w) + + +def _central_ray_difference(rays: torch.Tensor, *, dim: int, circular: bool) -> torch.Tensor: + if int(rays.shape[dim]) <= 1: + return torch.zeros_like(rays) + if bool(circular): + return 0.5 * ( + torch.roll(rays, shifts=-1, dims=dim) - torch.roll(rays, shifts=1, dims=dim) + ) + forward = torch.roll(rays, shifts=-1, dims=dim) + backward = torch.roll(rays, shifts=1, dims=dim) + diff = 0.5 * (forward - backward) + sl_first = [slice(None)] * rays.ndim + sl_first[dim] = 0 + sl_second = [slice(None)] * rays.ndim + sl_second[dim] = 1 + diff[tuple(sl_first)] = rays[tuple(sl_second)] - rays[tuple(sl_first)] + sl_last = [slice(None)] * rays.ndim + sl_last[dim] = -1 + sl_prev = [slice(None)] * rays.ndim + sl_prev[dim] = -2 + diff[tuple(sl_last)] = rays[tuple(sl_last)] - rays[tuple(sl_prev)] + return diff + + +def _build_tangent_basis(rays: torch.Tensor, *, circular_horizontal: bool) -> tuple[torch.Tensor, torch.Tensor]: + r = _safe_normalize(rays) + fallback_u, fallback_v = _fallback_tangent_basis(r) + + du = _central_ray_difference(r, dim=-1, circular=bool(circular_horizontal)) + du = du - (du * r).sum(dim=1, keepdim=True) * 
r + du_norm = du.norm(dim=1, keepdim=True) + e_u = torch.where(du_norm > 1e-7, du / du_norm.clamp(min=1e-7), fallback_u) + e_u = _safe_normalize(e_u) + + dv = _central_ray_difference(r, dim=-2, circular=False) + dv = dv - (dv * r).sum(dim=1, keepdim=True) * r + dv = dv - (dv * e_u).sum(dim=1, keepdim=True) * e_u + dv_norm = dv.norm(dim=1, keepdim=True) + e_v_fallback = _safe_normalize(torch.cross(r, e_u, dim=1)) + e_v = torch.where(dv_norm > 1e-7, dv / dv_norm.clamp(min=1e-7), e_v_fallback) + handed = (torch.cross(e_u, e_v, dim=1) * r).sum(dim=1, keepdim=True) + e_v = torch.where(handed < 0.0, -e_v, e_v) + e_v = torch.where(torch.isfinite(e_v).all(dim=1, keepdim=True), e_v, fallback_v) + return e_u, _safe_normalize(e_v) + + +def _rotmat_to_quat_wxyz(rot: torch.Tensor) -> torch.Tensor: + m00 = rot[:, 0, 0] + m01 = rot[:, 0, 1] + m02 = rot[:, 0, 2] + m10 = rot[:, 1, 0] + m11 = rot[:, 1, 1] + m12 = rot[:, 1, 2] + m20 = rot[:, 2, 0] + m21 = rot[:, 2, 1] + m22 = rot[:, 2, 2] + qw = 0.5 * torch.sqrt((1.0 + m00 + m11 + m22).clamp(min=1e-8)) + qx = torch.copysign(0.5 * torch.sqrt((1.0 + m00 - m11 - m22).clamp(min=1e-8)), m21 - m12) + qy = torch.copysign(0.5 * torch.sqrt((1.0 - m00 + m11 - m22).clamp(min=1e-8)), m02 - m20) + qz = torch.copysign(0.5 * torch.sqrt((1.0 - m00 - m11 + m22).clamp(min=1e-8)), m10 - m01) + quat = torch.stack([qw, qx, qy, qz], dim=1) + return quat / quat.norm(dim=1, keepdim=True).clamp(min=1e-8) + + +class PanoInitializer(nn.Module): + + def __init__(self, params: PanoInitializerParams) -> None: + super().__init__() + self.params = params + + def forward( + self, + image: torch.Tensor, + rays: torch.Tensor, + distance: torch.Tensor, + angular_cell_rays: torch.Tensor | None = None, + grid_cell_size_override: torch.Tensor | None = None, + force_grid_cell_size_override: bool = False, + target_hw: tuple[int, int] | None = None, + ) -> PanoInitializerOutput: + p = self.params + b, _, h, w = image.shape + device = image.device + + global_scale: 
torch.Tensor | None = None + if p.normalize_distance: + distance, factor = _rescale_distance(distance) + global_scale = 1.0 / factor + + stride = int(p.stride) + if target_hw is None: + img_ds = _downsample_avg(image, stride, circular_horizontal=bool(p.circular_horizontal)) + rays_ds = _downsample_avg(rays, stride, circular_horizontal=bool(p.circular_horizontal)) + angular_rays_ds = ( + _downsample_avg(angular_cell_rays, stride, circular_horizontal=bool(p.circular_horizontal)) + if torch.is_tensor(angular_cell_rays) + else rays_ds + ) + dist_ds = _downsample_avg( + distance, stride, circular_horizontal=bool(p.circular_horizontal) + ).clamp(min=1e-4) + pool_hw: tuple[int, int] | None = None + fallback_stride_u = fallback_stride_v = float(stride) + else: + target_h, target_w = int(target_hw[0]), int(target_hw[1]) + img_ds = _resize_to_grid(image, (target_h, target_w), mode="area") + rays_ds = _resize_to_grid(rays, (target_h, target_w), mode="bilinear") + angular_rays_ds = ( + _resize_to_grid(angular_cell_rays, (target_h, target_w), mode="bilinear") + if torch.is_tensor(angular_cell_rays) + else rays_ds + ) + dist_ds = _resize_to_grid(distance, (target_h, target_w), mode="area").clamp(min=1e-4) + pool_hw = (target_h, target_w) + fallback_stride_u = float(w) / float(max(target_w, 1)) + fallback_stride_v = float(h) / float(max(target_h, 1)) + + rays_ds_n = _safe_normalize(rays_ds) + angular_rays_ds_n = _safe_normalize(angular_rays_ds) + angular_cell = _ray_angular_cell( + angular_rays_ds_n, + circular_horizontal=bool(p.circular_horizontal), + ).to(dtype=dist_ds.dtype) + if torch.is_tensor(grid_cell_size_override): + cell_size = _format_cell_size_override_uv( + grid_cell_size_override, + batch_size=b, + device=device, + dtype=dist_ds.dtype, + ) + else: + cell_size = _grid_cell_size_uv( + batch_size=b, + image_h=h, + image_w=w, + stride=0.5 * (fallback_stride_u + fallback_stride_v), + circular_horizontal=bool(p.circular_horizontal), + device=device, + dtype=dist_ds.dtype, 
+ ) + cell_size_full = cell_size.expand(-1, -1, 1, int(rays_ds_n.shape[-2]), int(rays_ds_n.shape[-1])) + if bool(force_grid_cell_size_override) and torch.is_tensor(grid_cell_size_override): + angular_cell = cell_size_full + else: + valid_angular = torch.isfinite(angular_cell) & (angular_cell > 1e-6) + angular_cell = torch.where(valid_angular, angular_cell, cell_size_full) + angular_cell = _smooth_angular_cell( + angular_cell, + kernel_size=9, + circular_horizontal=bool(p.circular_horizontal), + ) + angular_cell = angular_cell.clamp(min=1e-6, max=0.25) + + inv = 1.0 / dist_ds.clamp(min=1e-4) + + inv_full = 1.0 / distance.clamp(min=1e-4) + if p.first_layer_depth_option == "surface_min": + inv1_src = inv_full[:, 0:1] + inv1 = ( + F.adaptive_max_pool2d(inv1_src, output_size=pool_hw) + if pool_hw is not None + else F.max_pool2d(inv1_src, kernel_size=stride, stride=stride) + ) + else: + inv1_src = -inv_full[:, 0:1] + inv1 = -( + F.adaptive_max_pool2d(inv1_src, output_size=pool_hw) + if pool_hw is not None + else F.max_pool2d(inv1_src, kernel_size=stride, stride=stride) + ) + + if int(p.num_layers) == 1: + inv_L = inv1[:, :, None] + else: + following = inv_full if inv_full.shape[1] == 1 else inv_full[:, 1:2] + if p.rest_layer_depth_option == "surface_min": + inv2_src = following + inv2 = ( + F.adaptive_max_pool2d(inv2_src, output_size=pool_hw) + if pool_hw is not None + else F.max_pool2d(inv2_src, kernel_size=stride, stride=stride) + ) + else: + inv2_src = -following + inv2 = -( + F.adaptive_max_pool2d(inv2_src, output_size=pool_hw) + if pool_hw is not None + else F.max_pool2d(inv2_src, kernel_size=stride, stride=stride) + ) + inv_L = torch.cat([inv1[:, :, None], inv2[:, :, None]], dim=2) + + L = inv_L.shape[2] + rays_L = rays_ds_n[:, :, None].repeat(1, 1, L, 1, 1) + inv_dist_L = inv_L.clamp(min=1e-6) + dist_L = 1.0 / inv_dist_L.clamp(min=1e-6) + + cell_u = angular_cell[:, 0:1].to(dtype=dist_L.dtype, device=dist_L.device) + cell_v = angular_cell[:, 
1:2].to(dtype=dist_L.dtype, device=dist_L.device) + cell_r = 0.5 * (cell_u + cell_v) + scales = torch.cat( + [ + dist_L * cell_u, + dist_L * cell_v, + dist_L * cell_r, + ], + dim=1, + ) * float(p.scale_factor) + + e_u, e_v = _build_tangent_basis( + rays_ds_n, + circular_horizontal=bool(p.circular_horizontal), + ) + rot = torch.stack([e_u, e_v, rays_ds_n], dim=-1) + rot_flat = rot.permute(0, 2, 3, 1, 4).reshape(-1, 3, 3) + quat_flat = _rotmat_to_quat_wxyz(rot_flat) + quaternions = quat_flat.reshape(b, int(rays_ds_n.shape[-2]), int(rays_ds_n.shape[-1]), 4) + quaternions = quaternions.permute(0, 3, 1, 2)[:, :, None].repeat(1, 1, L, 1, 1) + + colors = img_ds[:, :, None].repeat(1, 1, L, 1, 1).clamp(0.0, 1.0) + + opacities = torch.full( + (b, 1, L, 1, 1), + float(p.opacity_init), + device=device, + dtype=torch.float32, + ) + + base = PanoGaussianBaseValues( + rays=rays_L, + inv_distance=inv_dist_L, + angular_cell=angular_cell, + scales=scales, + quaternions=quaternions, + colors=colors, + opacities=opacities, + ) + + inv_dist_1 = inv_dist_L[:, :, 0] + feat = torch.cat([2.0 * img_ds - 1.0, 2.0 * inv_dist_1 - 1.0, rays_ds_n], dim=1) + + return PanoInitializerOutput( + gaussian_base_values=base, + feature_input=feat, + global_scale=global_scale, + grid_cell_size=0.5 * (angular_cell[:, 0:1] + angular_cell[:, 1:2]), + ) + diff --git a/unisharp/models/heads.py b/unisharp/models/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..4b610e2d5a932ea604856ffcd137825889cf841c --- /dev/null +++ b/unisharp/models/heads.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import torch +from torch import nn + +from .feature_gaussian_decoder import ImageFeatures + + +class DirectPredictionHead(nn.Module): + + def __init__(self, feature_dim: int, num_layers: int) -> None: + super().__init__() + self.feature_dim = int(feature_dim) + self.num_layers = int(num_layers) + + self.geometry_prediction_head = nn.Conv2d(self.feature_dim, 10 * self.num_layers, 
kernel_size=1) + self.texture_prediction_head = nn.Conv2d(self.feature_dim, 4 * self.num_layers, kernel_size=1) + + self.reset_parameters() + + def reset_parameters(self) -> None: + with torch.no_grad(): + self.geometry_prediction_head.weight.zero_() + self.geometry_prediction_head.bias.zero_() + self.texture_prediction_head.weight.zero_() + self.texture_prediction_head.bias.zero_() + + def _copy_conv_weight_prefix(self, dst: nn.Conv2d, src: torch.Tensor) -> bool: + if not isinstance(src, torch.Tensor): + return False + src_4d = src.reshape(src.shape[0], src.shape[1], 1, 1) if src.ndim == 2 else src + if src_4d.ndim != 4 or src_4d.shape[0] != dst.weight.shape[0]: + return False + channels = min(int(src_4d.shape[1]), self.feature_dim, int(dst.weight.shape[1])) + if channels <= 0 or tuple(src_4d.shape[2:]) != (1, 1): + return False + dst.weight[:, :channels].copy_(src_4d[:, :channels].to(device=dst.weight.device, dtype=dst.weight.dtype)) + return True + + def _copy_conv_weight_range( + self, + dst: nn.Conv2d, + src: torch.Tensor, + *, + src_start: int, + dst_start: int, + count: int, + ) -> bool: + if not isinstance(src, torch.Tensor) or int(count) <= 0: + return False + src_4d = src.reshape(src.shape[0], src.shape[1], 1, 1) if src.ndim == 2 else src + if src_4d.ndim != 4 or tuple(src_4d.shape[2:]) != (1, 1): + return False + src_end = int(src_start) + int(count) + dst_end = int(dst_start) + int(count) + if src_start < 0 or dst_start < 0 or src_end > int(src_4d.shape[0]) or dst_end > int(dst.weight.shape[0]): + return False + channels = min(int(src_4d.shape[1]), self.feature_dim, int(dst.weight.shape[1])) + if channels <= 0: + return False + dst.weight[dst_start:dst_end, :channels].copy_( + src_4d[src_start:src_end, :channels].to(device=dst.weight.device, dtype=dst.weight.dtype) + ) + return True + + @staticmethod + def _copy_bias_range( + dst: nn.Parameter, + src: torch.Tensor, + *, + src_start: int, + dst_start: int, + count: int, + ) -> bool: + if not 
isinstance(src, torch.Tensor) or src.ndim != 1 or int(count) <= 0: + return False + src_end = int(src_start) + int(count) + dst_end = int(dst_start) + int(count) + if src_start < 0 or dst_start < 0 or src_end > int(src.shape[0]) or dst_end > int(dst.shape[0]): + return False + dst[dst_start:dst_end].copy_(src[src_start:src_end].to(device=dst.device, dtype=dst.dtype)) + return True + + def init_from_legacy_direct_state(self, state: dict[str, torch.Tensor]) -> int: + copied = 0 + with torch.no_grad(): + geo_w = state.get( + "geometry_weight", + state.get("geometry_prediction_head.weight", state.get("geo_fc2.weight")), + ) + geo_b = state.get( + "geometry_bias", + state.get("geometry_prediction_head.bias", state.get("geo_fc2.bias")), + ) + tex_w = state.get( + "texture_weight", + state.get("texture_prediction_head.weight", state.get("tex_fc2.weight")), + ) + tex_b = state.get( + "texture_bias", + state.get("texture_prediction_head.bias", state.get("tex_fc2.bias")), + ) + + l = self.num_layers + if isinstance(geo_w, torch.Tensor): + if int(geo_w.shape[0]) == 10 * l and self._copy_conv_weight_prefix(self.geometry_prediction_head, geo_w): + copied += 1 + elif self._copy_conv_weight_range( + self.geometry_prediction_head, + geo_w, + src_start=0, + dst_start=0, + count=3 * l, + ): + copied += 1 + if isinstance(geo_b, torch.Tensor): + if tuple(geo_b.shape) == (10 * l,): + self.geometry_prediction_head.bias.copy_( + geo_b.to(device=self.geometry_prediction_head.bias.device, dtype=self.geometry_prediction_head.bias.dtype) + ) + copied += 1 + elif self._copy_bias_range(self.geometry_prediction_head.bias, geo_b, src_start=0, dst_start=0, count=3 * l): + copied += 1 + if isinstance(tex_w, torch.Tensor): + if int(tex_w.shape[0]) == 4 * l and self._copy_conv_weight_prefix(self.texture_prediction_head, tex_w): + copied += 1 + elif int(tex_w.shape[0]) == 11 * l: + if self._copy_conv_weight_range(self.geometry_prediction_head, tex_w, src_start=0, dst_start=3 * l, count=7 * l): + 
def forward(self, image_features: ImageFeatures) -> torch.Tensor:
    """Predict per-layer Gaussian deltas from decoded image features.

    The geometry head reads ``image_features.geometry_features`` and the
    texture head reads ``image_features.texture_features``. Each head's channel
    axis is split into ``(attributes, num_layers)`` with 10 geometry and 4
    texture attributes, and the two results are concatenated on the attribute
    axis.

    Returns:
        Tensor with 14 attributes on dim 1 and ``num_layers`` on dim 2,
        followed by the heads' remaining output dims.
    """
    layers = self.num_layers
    geometry = self.geometry_prediction_head(image_features.geometry_features)
    texture = self.texture_prediction_head(image_features.texture_features)
    return torch.cat(
        (
            geometry.unflatten(1, (10, layers)),
            texture.unflatten(1, (4, layers)),
        ),
        dim=1,
    )
output = orig_forward(*args, **kwargs) + self._unisharp_last_encoder_output = output + return output + + encoder.forward = types.MethodType(wrapped_forward, encoder) # type: ignore[method-assign] + encoder._unisharp_encoder_wrapped = True + + +def extract_unik3d_2d_feature_layers( + encoder_output: Any, + target_h: int, + target_w: int, +) -> list[torch.Tensor]: + def _to_bchw(feat_in: torch.Tensor) -> torch.Tensor: + if feat_in.ndim == 4: + if feat_in.shape[-1] > feat_in.shape[1]: + return feat_in.permute(0, 3, 1, 2).contiguous() + return feat_in + if feat_in.ndim == 3: + B, N, C = feat_in.shape + import math + + ratio = target_h / max(target_w, 1) + pH = max(1, int(math.sqrt(N * ratio) + 0.5)) + pW = max(1, N // pH) + while pH * pW < N: + pW += 1 + return feat_in[:, : pH * pW, :].transpose(1, 2).reshape(B, C, pH, pW) + raise TypeError(f"Unsupported spatial feature ndim={feat_in.ndim}") + + def _resize(feat_in: torch.Tensor) -> torch.Tensor: + if feat_in.shape[-2:] != (target_h, target_w): + feat_in = torch.nn.functional.interpolate( + feat_in, + size=(target_h, target_w), + mode="bilinear", + align_corners=False, + ) + return feat_in + + feats: list[torch.Tensor] = [] + + if isinstance(encoder_output, (list, tuple)) and len(encoder_output) == 2: + spatial_or_cls, _ = encoder_output + if isinstance(spatial_or_cls, (list, tuple)) and len(spatial_or_cls) > 0: + spatial_candidates = [x for x in spatial_or_cls if isinstance(x, torch.Tensor)] + if len(spatial_candidates) > 0: + n = len(spatial_candidates) + if n <= 4: + idxs = list(range(n)) + else: + idxs = sorted({n // 4, n // 2, (3 * n) // 4, n - 1}) + feats = [_resize(_to_bchw(spatial_candidates[i])) for i in idxs] + channels = {int(f.shape[1]) for f in feats} + if len(channels) != 1: + raise ValueError( + f"Selected DINO spatial features must share channels, got {sorted(channels)}" + ) + return feats + + if len(feats) == 0: + def _find_spatial(x: Any, depth: int = 0) -> torch.Tensor | None: + if isinstance(x, 
torch.Tensor): + if x.ndim == 4: + return x + if x.ndim == 3 and x.shape[1] > 2: + return x + return None + if isinstance(x, (list, tuple)) and depth < 3: + for elem in x: + result = _find_spatial(elem, depth + 1) + if result is not None: + return result + return None + + feat = _find_spatial(encoder_output) + else: + feat = None + + if feat is None or not isinstance(feat, torch.Tensor): + raise TypeError( + f"Cannot extract spatial 2D features from encoder_output of type {type(encoder_output)}. " + "Expected (spatial_list, cls_list) from DINOv2 encoder." + ) + return [_resize(_to_bchw(feat))] + + +def extract_unik3d_2d_features( + encoder_output: Any, + target_h: int, + target_w: int, +) -> torch.Tensor: + feats = extract_unik3d_2d_feature_layers(encoder_output, target_h, target_w) + return torch.stack(feats, dim=0).mean(dim=0) + + +class DINOFeatureLayerFusion(nn.Module): + + def __init__(self, dim: int, max_layers: int = 4) -> None: + super().__init__() + self.dim = int(dim) + self.max_layers = int(max_layers) + self.proj = nn.ModuleList( + [nn.Conv2d(self.dim, self.dim, kernel_size=1, bias=False) for _ in range(self.max_layers)] + ) + self.layer_logits = nn.Parameter(torch.zeros(self.max_layers)) + self.reset_parameters() + + def reset_parameters(self) -> None: + for proj in self.proj: + nn.init.zeros_(proj.weight) + eye = torch.eye(self.dim, dtype=proj.weight.dtype).view(self.dim, self.dim, 1, 1) + with torch.no_grad(): + proj.weight.copy_(eye) + + def forward(self, features: list[torch.Tensor]) -> torch.Tensor: + if not features: + raise ValueError("Expected at least one DINO feature layer.") + if len(features) > self.max_layers: + raise ValueError(f"Expected at most {self.max_layers} feature layers, got {len(features)}") + channels = {int(f.shape[1]) for f in features} + if channels != {self.dim}: + raise ValueError(f"Expected DINO feature channels {self.dim}, got {sorted(channels)}") + weights = torch.softmax(self.layer_logits[: len(features)], dim=0) + 
class UniK3DFeatureExtractor(nn.Module):
    """Run a UniK3D model and expose its 2D encoder and 3D decoder features.

    The wrapped model is instrumented at construction time so that the decoder
    pyramid and the encoder output are captured on every forward pass. The
    encoder layers are fused into a single 2D feature map by a learnable
    ``DINOFeatureLayerFusion``.
    """

    def __init__(
        self,
        unik3d_model: nn.Module,
        dino_feature_dim: int = 1024,
    ):
        super().__init__()
        from unisharp.utils.unik3d_adapter import _enable_unik3d_decoder_feature_capture

        self.unik3d = unik3d_model
        self.dino_layer_fusion = DINOFeatureLayerFusion(dim=int(dino_feature_dim), max_layers=4)

        _enable_unik3d_decoder_feature_capture(self.unik3d)
        _enable_unik3d_encoder_feature_capture(self.unik3d)

        # Best effort: models without this attribute path are left as they are.
        try:
            radial = self.unik3d.pixel_decoder.radial_module  # type: ignore[attr-defined]
            radial._unisharp_detach_rays_embeddings = True
        except Exception:
            pass

        # Most recent raw UniK3D output, kept for callers that need it later.
        self._unisharp_last_unik3d_output: dict[str, torch.Tensor] | None = None

    def train(self, mode: bool = True) -> "UniK3DFeatureExtractor":
        """Switch training mode; delegates to ``nn.Module.train`` unchanged."""
        super().train(mode)
        return self

    def _extract_features_from_output(
        self,
        output: dict[str, torch.Tensor],
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Turn a raw UniK3D output dict into ``(features_2d, features_3d_pyramid)``.

        The 2D features are resized to the spatial size of the last pyramid
        entry.

        Raises:
            RuntimeError: If ``pyramid_features`` is missing from ``output`` or
                the 2D features cannot be produced (original error chained).
        """
        self._unisharp_last_unik3d_output = output

        pyramid = output.get("pyramid_features")
        if pyramid is None:
            raise RuntimeError(
                "Failed to capture pyramid_features from UniK3D. "
                "Ensure _enable_unik3d_decoder_feature_capture() is called."
            )

        grid_h, grid_w = pyramid[-1].shape[-2:]

        try:
            captured = getattr(
                self.unik3d.pixel_encoder,  # type: ignore[attr-defined]
                "_unisharp_last_encoder_output",
                None,
            )
            if captured is None:
                raise RuntimeError("Failed to capture encoder output from UniK3D.")
            fused_2d = self.dino_layer_fusion(
                extract_unik3d_2d_feature_layers(
                    captured,
                    target_h=grid_h,
                    target_w=grid_w,
                )
            )
        except Exception as e:
            LOGGER.error(f"Failed to extract 2D features: {e}")
            raise RuntimeError(
                "Failed to extract DINO 2D features. The Gaussian decoder expects "
                "DINO-channel features here; lowres UniK3D decoder features are not "
                "a safe fallback because their channels do not match Feature2DEncoder."
            ) from e

        return fused_2d, pyramid

    def forward(
        self,
        rgb_u8: torch.Tensor,
        target_h: int,
        target_w: int,
        intrinsics: torch.Tensor | None = None,
        camera_params: torch.Tensor | None = None,
        camera_model: str | None = None,
        hfov: float | None = None,
        vfov: float | None = None,
        validity_mask: torch.Tensor | None = None,
        use_predicted_rays: bool = False,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Run UniK3D with the most specific camera description available.

        Dispatch order: predicted rays (``use_predicted_rays``), pinhole
        ``intrinsics``, fisheye ``camera_params``, and finally rays built from
        ``camera_model`` with ``hfov``/``vfov`` defaulting to ``2*pi``/``pi``.
        ``target_h`` and ``target_w`` are accepted but not read here.

        Returns:
            ``(features_2d, features_3d_pyramid)`` from
            ``_extract_features_from_output``.
        """
        from unisharp.utils.unik3d_adapter import (
            build_unik3d_camera_rays,
            forward_unik3d_camera_rays,
            forward_unik3d_fisheye624,
            forward_unik3d_pinhole,
        )

        if use_predicted_rays:
            output = forward_unik3d_camera_rays(
                self.unik3d,
                rgb_u8,
                normalize=True,
                validity_mask=validity_mask,
            )
        elif torch.is_tensor(intrinsics):
            output = forward_unik3d_pinhole(
                self.unik3d,
                rgb_u8,
                intrinsics=intrinsics,
                normalize=True,
            )
        elif torch.is_tensor(camera_params):
            output = forward_unik3d_fisheye624(
                self.unik3d,
                rgb_u8,
                camera_params=camera_params,
                normalize=True,
                validity_mask=validity_mask,
            )
        else:
            horizontal = float(2.0 * torch.pi) if hfov is None else float(hfov)
            vertical = float(torch.pi) if vfov is None else float(vfov)
            _, rays, _, _ = build_unik3d_camera_rays(
                rgb_u8,
                device=next(self.unik3d.parameters()).device,
                camera_model=camera_model,
                hfov=horizontal,
                vfov=vertical,
            )
            output = forward_unik3d_camera_rays(
                self.unik3d,
                rgb_u8,
                normalize=True,
                validity_mask=validity_mask,
                rays=rays,
            )

        return self._extract_features_from_output(output)
copy.deepcopy(radial_module.depth_mlp) # type: ignore[attr-defined] + self.to_depth_lr = copy.deepcopy(radial_module.to_depth_lr) # type: ignore[attr-defined] + self.to_depth_hr = copy.deepcopy(radial_module.to_depth_hr) # type: ignore[attr-defined] + self._set_last_depth_conv_channels(self.out_channels) + + for p in self.parameters(): + p.requires_grad_(True) + + def _set_last_depth_conv_channels(self, out_channels: int) -> None: + if not isinstance(self.to_depth_hr, nn.Sequential) or len(self.to_depth_hr) == 0: + raise TypeError("Expected UniK3D radial_module.to_depth_hr to be a non-empty nn.Sequential.") + last = self.to_depth_hr[-1] + if not isinstance(last, nn.Conv2d): + raise TypeError(f"Expected final UniK3D depth head layer to be Conv2d, got {type(last)!r}.") + if int(last.out_channels) == int(out_channels): + return + + new_last = nn.Conv2d( + in_channels=int(last.in_channels), + out_channels=int(out_channels), + kernel_size=last.kernel_size, + stride=last.stride, + padding=last.padding, + dilation=last.dilation, + groups=int(last.groups), + bias=last.bias is not None, + padding_mode=last.padding_mode, + ) + with torch.no_grad(): + if int(last.out_channels) == 1: + new_last.weight.copy_(last.weight.repeat(int(out_channels), 1, 1, 1)) + if last.bias is not None and new_last.bias is not None: + new_last.bias.copy_(last.bias.repeat(int(out_channels))) + elif int(last.out_channels) >= int(out_channels): + new_last.weight.copy_(last.weight[: int(out_channels)]) + if last.bias is not None and new_last.bias is not None: + new_last.bias.copy_(last.bias[: int(out_channels)]) + else: + repeat = (int(out_channels) + int(last.out_channels) - 1) // int(last.out_channels) + new_last.weight.copy_(last.weight.repeat(repeat, 1, 1, 1)[: int(out_channels)]) + if last.bias is not None and new_last.bias is not None: + new_last.bias.copy_(last.bias.repeat(repeat)[: int(out_channels)]) + self.to_depth_hr[-1] = new_last + + def _radial_out_features(self, features_3d_pyramid: 
def forward(
    self,
    features_3d_pyramid: list[torch.Tensor],
    *,
    internal_hw: tuple[int, int],
) -> torch.Tensor:
    """Decode the radial feature pyramid into the depth head's output map.

    Every ``depth_mlp`` layer is applied to its pyramid level in
    channels-last layout, but only the result of the last pair is kept. That
    result is resized to the last level's spatial size, passed through
    ``to_depth_lr``, resized to ``internal_hw`` and finally through
    ``to_depth_hr``. Both resizes are bilinear with ``align_corners=True``.

    Raises:
        RuntimeError: If no (layer, feature) pair was available.
    """
    radial_features = self._radial_out_features(features_3d_pyramid)
    finest = radial_features[-1]
    grid_hw = (int(finest.shape[-2]), int(finest.shape[-1]))

    depth_features: torch.Tensor | None = None
    for mlp, level in zip(self.depth_mlp, radial_features):
        # The MLPs expect channels-last input; convert there and back.
        depth_features = mlp(level.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    if depth_features is None:
        raise RuntimeError("UniK3D copied depth head received no radial features.")

    depth_features = F.interpolate(
        depth_features,
        size=grid_hw,
        mode="bilinear",
        align_corners=True,
    )
    low_res = self.to_depth_lr(depth_features)
    high_res_input = F.interpolate(
        low_res,
        size=(int(internal_hw[0]), int(internal_hw[1])),
        mode="bilinear",
        align_corners=True,
    )
    return self.to_depth_hr(high_res_input)
== "vitl": + dino_feature_dim = 1024 + decoder_params = FeatureGaussianDecoderParams( + dims_3d_in=(128, 256, 512, 512), + dims_3d_out=(256, 512, 1024, 1024), + dim_2d_in=1024, + dim_2d_out=256, + dim_decoder_out=256, + dim_texture_out=32, + dim_geometry_out=32, + stride_out=int(max(1, config.initializer_stride)), + ) + else: + dino_feature_dim = 768 + decoder_params = FeatureGaussianDecoderParams( + dims_3d_in=(96, 192, 384, 384), + dims_3d_out=(256, 512, 768, 768), + dim_2d_in=768, + dim_2d_out=256, + dim_decoder_out=256, + dim_texture_out=32, + dim_geometry_out=32, + stride_out=int(max(1, config.initializer_stride)), + ) + + self.feature_extractor = UniK3DFeatureExtractor( + unik3d_model=unik3d_model, + dino_feature_dim=dino_feature_dim, + ) + + self.decoder_params = decoder_params + self.feature_decoder = create_feature_gaussian_decoder(decoder_params) + + params = PanoPredictorParams() + params.initializer.stride = config.initializer_stride + params.initializer.scale_factor = float(config.initializer_scale_factor) + params.num_monodepth_layers = 2 + params.initializer.num_layers = 2 + + self.init_model = PanoInitializer(params.initializer) + self.prediction_head = DirectPredictionHead( + feature_dim=32, + num_layers=params.initializer.num_layers, + ) + decoder_stride = int(getattr(self.feature_decoder, "stride", 1)) + init_stride = int(max(1, config.initializer_stride)) + if decoder_stride != init_stride: + raise ValueError( + "Feature decoder stride must match initializer stride so base/features/head/delta " + "share one Gaussian grid, " + f"got decoder_stride={decoder_stride}, initializer_stride={init_stride}" + ) + self.gaussian_composer = PanoGaussianComposer( + delta_factor=params.delta_factor, + min_scale=params.min_scale, + max_scale=params.max_scale, + color_activation_type=params.color_activation_type, + opacity_activation_type=params.opacity_activation_type, + color_space="linearRGB", + 
base_scale_on_predicted_mean=params.base_scale_on_predicted_mean, + scale_factor=decoder_stride // init_stride, + delta_rho_limit=float(getattr(config, "delta_rho_limit", 2.0)), + ) + + radial_module = self.feature_extractor.unik3d.pixel_decoder.radial_module # type: ignore[attr-defined] + self.second_layer_depth_head = UniK3DCopiedDepthHead( + radial_module, + out_channels=1, + ) + + self.params = params + + def train(self, mode: bool = True) -> "UnisharpFeatureModel": + super().train(mode) + return self + + def _set_initializer_circular_mode(self, circular_horizontal: bool) -> None: + self.init_model.params.circular_horizontal = bool(circular_horizontal) + + def _initializer_grid_cell_size_override( + self, + *, + camera_intrinsics: torch.Tensor | None, + camera_params: torch.Tensor | None, + image: torch.Tensor, + is_spherical: bool, + ) -> torch.Tensor | None: + if bool(is_spherical): + return None + stride = float(max(1, int(self.init_model.params.stride))) + if torch.is_tensor(camera_intrinsics): + k = camera_intrinsics.to(device=image.device, dtype=torch.float32) + if k.ndim == 2: + k = k.unsqueeze(0) + fx = k[:, 0, 0].clamp(min=1.0) + fy = k[:, 1, 1].clamp(min=1.0) + return torch.stack([stride / fx, stride / fy], dim=1) + if torch.is_tensor(camera_params): + params = camera_params.to(device=image.device, dtype=torch.float32) + if params.ndim == 1: + params = params.unsqueeze(0) + if int(params.shape[-1]) == 15: + fx = fy = params[:, 0].clamp(min=1.0) + else: + fx = params[:, 0].clamp(min=1.0) + fy = params[:, 1].clamp(min=1.0) + return torch.stack([stride / fx, stride / fy], dim=1) + return None + + def _predict_delta( + self, + image_features: ImageFeatures, + ) -> torch.Tensor: + return self.prediction_head(image_features) + + @staticmethod + def _strip_module_prefix(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + out: dict[str, torch.Tensor] = {} + for k, v in state.items(): + if isinstance(k, str) and k.startswith("module."): + 
out[k[len("module.") :]] = v + else: + out[k] = v + return out + + @staticmethod + def _distance_layers_from_unik3d_logradius( + *, + logradius_layers: torch.Tensor, + internal_rays: torch.Tensor | None, + final_rays: torch.Tensor, + unik3d_output: dict[str, Any], + ) -> torch.Tensor: + radius_layers = torch.exp(logradius_layers.clamp(min=-8.0, max=8.0) + 2.0) + if internal_rays is None: + return F.interpolate( + radius_layers, + size=final_rays.shape[-2:], + mode="bilinear", + align_corners=False, + ).clamp(min=1e-4) + + padded_hw = unik3d_output.get("_unisharp_postprocess_padded_hw", None) + paddings = unik3d_output.get("_unisharp_postprocess_paddings", None) + interpolation_mode = unik3d_output.get("_unisharp_postprocess_interpolation_mode", None) + if padded_hw is None or paddings is None or interpolation_mode is None: + return F.interpolate( + radius_layers, + size=final_rays.shape[-2:], + mode="bilinear", + align_corners=False, + ).clamp(min=1e-4) + + from unisharp.utils.unik3d_adapter import postprocess_unik3d_tensor + + bsz, num_layers, h_int, w_int = radius_layers.shape + radius_post = postprocess_unik3d_tensor( + radius_layers.reshape(bsz * num_layers, 1, h_int, w_int), + padded_hw=tuple(int(x) for x in padded_hw), + paddings=tuple(int(x) for x in paddings), + interpolation_mode=str(interpolation_mode), + ) + h_out, w_out = int(radius_post.shape[-2]), int(radius_post.shape[-1]) + return radius_post.reshape(bsz, num_layers, h_out, w_out).clamp(min=1e-4) + + @staticmethod + def _compute_spherical_distortion_drop_prob_map( + *, + batch_size: int, + h: int, + w: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + yy = (torch.arange(h, device=device, dtype=dtype) + 0.5) / float(max(1, h)) + lat = (yy - 0.5) * torch.pi + row_keep = torch.cos(lat).clamp(min=0.0) + row_keep = row_keep / row_keep.max().clamp(min=1e-6) + row_drop = 1.0 - row_keep + drop_2d = row_drop[:, None].expand(h, w) + return drop_2d.unsqueeze(0).expand(int(batch_size), 
-1, -1).contiguous() + + @staticmethod + def _sample_dropout_mask( + drop_prob_map: torch.Tensor, + *, + num_layers: int, + ) -> torch.Tensor: + if num_layers <= 0: + raise ValueError(f"num_layers must be positive, got {num_layers}") + b, h, w = drop_prob_map.shape + p = torch.zeros((b, 1, int(num_layers), h, w), device=drop_prob_map.device, dtype=drop_prob_map.dtype) + if int(num_layers) > 1: + p[:, :, 1] = drop_prob_map[:, None] + rnd = torch.rand_like(p) + return (rnd < p).to(dtype=drop_prob_map.dtype) + + @staticmethod + def _apply_dropout_to_gaussians( + gaussians: Gaussians3D, + dropout_mask: torch.Tensor | None, + ) -> Gaussians3D: + if dropout_mask is None: + return gaussians + if dropout_mask.ndim != 5: + raise ValueError(f"Expected dropout mask shape [B,1,L,H,W], got {tuple(dropout_mask.shape)}") + mask_flat = dropout_mask[:, 0].flatten(1).to( + device=gaussians.opacities.device, + dtype=gaussians.opacities.dtype, + ).clamp(0.0, 1.0) + if mask_flat.shape != gaussians.opacities.shape: + raise ValueError( + "Dropout mask must match flattened Gaussian opacity shape, " + f"got {tuple(mask_flat.shape)} vs {tuple(gaussians.opacities.shape)}" + ) + return Gaussians3D( + mean_vectors=gaussians.mean_vectors, + singular_values=gaussians.singular_values, + quaternions=gaussians.quaternions, + colors=gaussians.colors, + opacities=gaussians.opacities * (1.0 - mask_flat), + ) + + @staticmethod + def _apply_flat_opacity_dropout( + gaussians: Gaussians3D, + dropout_mask_flat: torch.Tensor | None, + ) -> Gaussians3D: + if dropout_mask_flat is None: + return gaussians + mask = dropout_mask_flat.to(device=gaussians.opacities.device, dtype=gaussians.opacities.dtype).clamp(0.0, 1.0) + if tuple(mask.shape) != tuple(gaussians.opacities.shape): + raise ValueError( + "Flat dropout mask must match Gaussian opacity shape, " + f"got {tuple(mask.shape)} vs {tuple(gaussians.opacities.shape)}" + ) + return Gaussians3D( + mean_vectors=gaussians.mean_vectors, + 
singular_values=gaussians.singular_values, + quaternions=gaussians.quaternions, + colors=gaussians.colors, + opacities=gaussians.opacities * (1.0 - mask), + ) + + def _build_spherical_dropout_prob_map( + self, + base_values: Any, + *, + is_spherical: bool, + ) -> torch.Tensor | None: + rays = base_values.rays + if rays.ndim != 5 or rays.shape[2] <= 1: + return None + if not bool(is_spherical): + return None + if not bool(self.training): + return None + with torch.no_grad(): + bsz = int(base_values.rays.shape[0]) + h = int(base_values.rays.shape[-2]) + w = int(base_values.rays.shape[-1]) + drop_prob_map = self._compute_spherical_distortion_drop_prob_map( + batch_size=bsz, + h=h, + w=w, + device=base_values.rays.device, + dtype=base_values.rays.dtype, + ) + return drop_prob_map + + @staticmethod + def _strip_known_prefixes_for_unik3d(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + out: dict[str, torch.Tensor] = {} + for k, v in state.items(): + kk = str(k) + if kk.startswith("module."): + kk = kk[len("module.") :] + if kk.startswith("unik3d."): + kk = kk[len("unik3d.") :] + out[kk] = v + return out + + @staticmethod + def _looks_like_unik3d_state_dict(state: dict[str, torch.Tensor]) -> bool: + if len(state) == 0: + return False + probe = UnisharpFeatureModel._strip_known_prefixes_for_unik3d(state) + prefixes = ("pixel_encoder.", "pixel_decoder.", "head.") + return any(any(str(k).startswith(p) for p in prefixes) for k in probe.keys()) + + def _load_unik3d_state_dict(self, state: dict[str, torch.Tensor]) -> tuple[list[str], list[str]]: + state_norm = self._strip_known_prefixes_for_unik3d(state) + incompatible = self.feature_extractor.unik3d.load_state_dict(state_norm, strict=False) + missing = [f"feature_extractor.unik3d.{k}" for k in list(getattr(incompatible, "missing_keys", []))] + unexpected = [f"feature_extractor.unik3d.{k}" for k in list(getattr(incompatible, "unexpected_keys", []))] + return missing, unexpected + + @staticmethod + def 
_center_embed_or_crop_kernel(src: torch.Tensor, dst_shape: torch.Size) -> torch.Tensor: + out = src.new_zeros(tuple(dst_shape)) + src_h, src_w = int(src.shape[-2]), int(src.shape[-1]) + dst_h, dst_w = int(dst_shape[-2]), int(dst_shape[-1]) + copy_h = min(src_h, dst_h) + copy_w = min(src_w, dst_w) + src_y0 = max(0, (src_h - copy_h) // 2) + src_x0 = max(0, (src_w - copy_w) // 2) + dst_y0 = max(0, (dst_h - copy_h) // 2) + dst_x0 = max(0, (dst_w - copy_w) // 2) + out[..., dst_y0 : dst_y0 + copy_h, dst_x0 : dst_x0 + copy_w] = src[ + ..., src_y0 : src_y0 + copy_h, src_x0 : src_x0 + copy_w + ] + return out + + @classmethod + def _load_module_state_shape_compat( + cls, + module: nn.Module, + state: dict[str, torch.Tensor], + ) -> tuple[list[str], list[str]]: + current = module.state_dict() + filtered: dict[str, torch.Tensor] = {} + mismatched: list[str] = [] + migrated: list[str] = [] + for key, value in state.items(): + dst = current.get(key) + if not torch.is_tensor(value) or not torch.is_tensor(dst): + continue + if tuple(value.shape) == tuple(dst.shape): + filtered[key] = value.to(dtype=dst.dtype) + continue + if ( + str(key).endswith(".deconv.weight") + and value.ndim == 4 + and dst.ndim == 4 + and tuple(value.shape[:2]) == tuple(dst.shape[:2]) + ): + filtered[key] = cls._center_embed_or_crop_kernel(value, dst.shape).to(dtype=dst.dtype) + migrated.append(f"{key}:shape{tuple(value.shape)}->{tuple(dst.shape)}") + continue + mismatched.append(f"{key}:shape{tuple(value.shape)}->{tuple(dst.shape)}") + incompatible = module.load_state_dict(filtered, strict=False) + missing = list(getattr(incompatible, "missing_keys", [])) + unexpected = list(getattr(incompatible, "unexpected_keys", [])) + mismatched + unexpected.extend([f"migrated:{item}" for item in migrated]) + return missing, unexpected + + def _load_prediction_head_compat( + self, + state: dict[str, torch.Tensor], + ) -> tuple[list[str], list[str], int]: + current = self.prediction_head.state_dict() + filtered: 
dict[str, torch.Tensor] = {} + shape_mismatch: list[str] = [] + for key, value in state.items(): + dst = current.get(key) + if not torch.is_tensor(value) or not torch.is_tensor(dst): + continue + if tuple(value.shape) != tuple(dst.shape): + shape_mismatch.append(str(key)) + continue + filtered[key] = value.to(dtype=dst.dtype) + incompatible = self.prediction_head.load_state_dict(filtered, strict=False) + missing = list(getattr(incompatible, "missing_keys", [])) + unexpected = list(getattr(incompatible, "unexpected_keys", [])) + shape_mismatch + copied_legacy = 0 + if len(state) > 0 and len(missing) > 0: + copied_legacy = self.prediction_head.init_from_legacy_direct_state(state) + if copied_legacy > 0: + missing = [ + k + for k in missing + if k + not in ( + "geometry_prediction_head.weight", + "geometry_prediction_head.bias", + "texture_prediction_head.weight", + "texture_prediction_head.bias", + ) + ] + unexpected = [ + k + for k in unexpected + if k + not in ( + "geometry_weight", + "geometry_bias", + "texture_weight", + "texture_bias", + "geometry_prediction_head.weight", + "geometry_prediction_head.bias", + "texture_prediction_head.weight", + "texture_prediction_head.bias", + "geo_fc2.weight", + "geo_fc2.bias", + "tex_fc2.weight", + "tex_fc2.bias", + ) + ] + return missing, unexpected, copied_legacy + + def _load_depth_head_compat(self, state: dict[str, torch.Tensor]) -> tuple[list[str], list[str]]: + state = dict(state) + current = self.second_layer_depth_head.state_dict() + filtered: dict[str, torch.Tensor] = {} + unexpected: list[str] = [] + for key, value in state.items(): + dst = current.get(key) + if not torch.is_tensor(value) or not torch.is_tensor(dst): + unexpected.append(str(key)) + continue + if tuple(value.shape) != tuple(dst.shape): + unexpected.append(str(key)) + continue + filtered[key] = value.to(dtype=dst.dtype) + incompatible = self.second_layer_depth_head.load_state_dict(filtered, strict=False) + missing = list(getattr(incompatible, 
"missing_keys", [])) + unexpected.extend(list(getattr(incompatible, "unexpected_keys", []))) + return missing, unexpected + + def forward( + self, + image: torch.Tensor, + image_u8: torch.Tensor | None = None, + camera_intrinsics: torch.Tensor | None = None, + camera_params: torch.Tensor | None = None, + camera_model: str | None = None, + depth_gt: torch.Tensor | None = None, + distance_init_cap_m: float | None = None, + validity_mask: torch.Tensor | None = None, + return_aux: bool = False, + ) -> dict[str, Any] | Any: + _, _, H, W = image.shape + + import numpy as np + camera_model_name_input = str(camera_model or "").strip().lower() + if ( + camera_model_name_input == "" + and camera_intrinsics is None + and camera_params is None + and int(W) == 2 * int(H) + ): + camera_model = "spherical" + if image_u8 is None: + image_u8 = (image * 255.0).round().clamp(0, 255).to(torch.uint8) + features_2d, features_3d_pyramid = self.feature_extractor.forward( + rgb_u8=image_u8, + target_h=H, + target_w=W, + intrinsics=camera_intrinsics, + camera_params=camera_params, + camera_model=camera_model, + hfov=float(2.0 * np.pi), + vfov=float(np.pi), + validity_mask=validity_mask, + use_predicted_rays=False, + ) + unik3d_output = self.feature_extractor._unisharp_last_unik3d_output + + if unik3d_output is None: + raise RuntimeError("Missing cached UniK3D output from feature_extractor forward.") + + rays = unik3d_output["rays"] + distance = unik3d_output["distance"] + try: + from unisharp.utils.unik3d_adapter import build_unik3d_camera_rays + + _, gt_rays, _, _ = build_unik3d_camera_rays( + image_u8, + device=rays.device, + intrinsics=camera_intrinsics, + camera_params=camera_params, + camera_model=camera_model, + hfov=float(2.0 * np.pi), + vfov=float(np.pi), + ) + except Exception as exc: + raise RuntimeError( + "Failed to build calibrated camera rays required by the fixed gt-override geometry path." 
+ ) from exc + geometry_rays = gt_rays.to(device=rays.device, dtype=rays.dtype) + if tuple(geometry_rays.shape[-2:]) != tuple(rays.shape[-2:]): + geometry_rays = F.interpolate(geometry_rays, size=rays.shape[-2:], mode="bilinear", align_corners=False) + geometry_rays = geometry_rays / torch.norm(geometry_rays, dim=1, keepdim=True).clamp(min=1e-5) + + camera_model_name = str(camera_model or "").strip().lower() + render_geometry_rays = geometry_rays.detach() + + is_spherical = bool( + (camera_intrinsics is None) + and (camera_params is None) + and camera_model_name in {"spherical", "erp", "panorama"} + ) + + internal_rays_raw = unik3d_output.get("_unisharp_internal_rays", None) + internal_rays = internal_rays_raw if torch.is_tensor(internal_rays_raw) else None + internal_hw = ( + (int(internal_rays.shape[-2]), int(internal_rays.shape[-1])) + if internal_rays is not None + else (int(image.shape[-2]), int(image.shape[-1])) + ) + logradius_extra = self.second_layer_depth_head( + features_3d_pyramid, + internal_hw=internal_hw, + ) + if int(logradius_extra.shape[1]) != 1: + raise RuntimeError( + f"copied UniK3D depth head output channels ({int(logradius_extra.shape[1])}) must be 1" + ) + extra_distance_layers = self._distance_layers_from_unik3d_logradius( + logradius_layers=logradius_extra, + internal_rays=internal_rays.detach() if torch.is_tensor(internal_rays) else None, + final_rays=render_geometry_rays, + unik3d_output=unik3d_output, + ) + distance_layers = torch.cat([distance.clamp(min=1e-4), extra_distance_layers], dim=1) + distance_ray_align_factor = None + + max_distance_m = float(getattr(self.config, "max_distance_m", DEFAULT_MAX_DEPTH_M)) + finite_cap_m = max_distance_m if max_distance_m > 0.0 else DEFAULT_MAX_DEPTH_M + distance_layers = torch.nan_to_num( + distance_layers, + nan=finite_cap_m, + posinf=finite_cap_m, + neginf=1e-4, + ).clamp(min=1e-4) + + circular_horizontal = bool(is_spherical) + 
self._set_initializer_circular_mode(circular_horizontal=circular_horizontal) + + init_cap_m = float(distance_init_cap_m) if distance_init_cap_m is not None else max_distance_m + if max_distance_m > 0.0 and init_cap_m > 0.0: + init_cap_m = min(max_distance_m, init_cap_m) + elif max_distance_m > 0.0: + init_cap_m = max_distance_m + distance_layers_for_supervision = distance_layers + init_distance_layers = distance_layers.clamp(min=1e-4) + if init_cap_m > 0.0: + init_distance_layers = init_distance_layers.clamp(max=init_cap_m) + + init_rays = render_geometry_rays + init_layer0_distance = init_distance_layers[:, 0:1] + if bool(getattr(self.config, "detach_init_layer0_distance", True)): + init_layer0_distance = init_layer0_distance.detach() + init_distance_layers = torch.cat( + [ + init_layer0_distance, + init_distance_layers[:, 1:], + ], + dim=1, + ) + + feat_2d = features_2d + feat_3d = features_3d_pyramid + + scale_cell_intrinsics = camera_intrinsics + scale_cell_camera_params = camera_params + scale_cell_rays = None + + init_output = self.init_model( + image=image, + rays=init_rays, + distance=init_distance_layers, + angular_cell_rays=scale_cell_rays, + grid_cell_size_override=self._initializer_grid_cell_size_override( + camera_intrinsics=scale_cell_intrinsics, + camera_params=scale_cell_camera_params, + image=image, + is_spherical=bool(is_spherical), + ), + target_hw=None, + ) + base_values = init_output.gaussian_base_values + + base_hw = (int(base_values.rays.shape[-2]), int(base_values.rays.shape[-1])) + + decoded_features = self.feature_decoder( + feat_2d, + feat_3d, + circular_horizontal=bool(is_spherical), + target_hw=base_hw, + ) + feature_hw = ( + int(decoded_features.texture_features.shape[-2]), + int(decoded_features.texture_features.shape[-1]), + ) + if tuple(decoded_features.geometry_features.shape[-2:]) != feature_hw: + raise RuntimeError( + "Texture and geometry feature grids must match, " + f"got texture={feature_hw} 
geometry={tuple(decoded_features.geometry_features.shape[-2:])}" + ) + if feature_hw != base_hw: + raise RuntimeError( + "Decoded feature grid must match initializer Gaussian grid, " + f"got features={feature_hw} base={base_hw}" + ) + + dropout_prob_map: torch.Tensor | None = None + dropout_mask: torch.Tensor | None = None + dropout_prob_map = self._build_spherical_dropout_prob_map( + base_values, + is_spherical=bool(is_spherical), + ) + if bool(is_spherical) and (dropout_prob_map is not None): + dropout_mask = self._sample_dropout_mask( + dropout_prob_map, + num_layers=int(base_values.rays.shape[2]), + ) + + image_features = decoded_features + + delta = self._predict_delta(image_features) + + gaussians = self.gaussian_composer( + delta=delta, + base_values=base_values, + global_scale=init_output.global_scale, + flatten_output=True, + ) + gaussians = self._apply_dropout_to_gaussians(gaussians, dropout_mask) + + if not return_aux: + return gaussians + else: + return { + "gaussians": gaussians, + "gaussian_base_values": init_output.gaussian_base_values, + "gaussian_base_values_for_composer": base_values, + "delta": delta, + "delta_rho_applied": self.gaussian_composer.apply_delta_rho(delta[:, 2:3]), + "scale_factor_applied": self.gaussian_composer.apply_scale_factor( + self.gaussian_composer._smooth_scale_delta( + delta[:, 3:6], + circular_horizontal=bool(is_spherical), + ) + ), + "unik3d_rays": rays, + "unik3d_ray_conditioning_rays": unik3d_output.get("ray_conditioning_rays", None), + "unik3d_gt_rays": gt_rays, + "geometry_rays": geometry_rays, + "initializer_geometry_rays": render_geometry_rays, + "detach_init_layer0_distance": bool(getattr(self.config, "detach_init_layer0_distance", True)), + "unik3d_distance": distance, + "distance_ray_align_factor": distance_ray_align_factor, + "distance_layers": distance_layers_for_supervision, + "init_distance_layers": init_distance_layers, + "decoded_image_features": image_features, + "gaussian_dropout_mask": dropout_mask, + 
"unik3d_features_2d": feat_2d, + "unik3d_features_3d": feat_3d, + "initializer_output": init_output, + "second_layer_dropout_prob_map": dropout_prob_map, + "second_layer_dropout_mask": dropout_mask, + } + + def get_trainable_parameters(self) -> list[torch.nn.Parameter]: + params = [] + + params.extend([p for p in self.feature_decoder.parameters() if p.requires_grad]) + + params.extend([p for p in self.init_model.parameters() if p.requires_grad]) + params.extend([p for p in self.prediction_head.parameters() if p.requires_grad]) + params.extend([p for p in self.gaussian_composer.parameters() if p.requires_grad]) + params.extend([p for p in self.second_layer_depth_head.parameters() if p.requires_grad]) + + for _, p in self.feature_extractor.named_parameters(): + if not p.requires_grad: + continue + params.append(p) + + return params + + def load_from_checkpoint(self, ckpt_path: str, strict: bool = False) -> tuple[list[str], list[str]]: + try: + payload = torch.load(ckpt_path, map_location="cpu", weights_only=False) + except TypeError: + payload = torch.load(ckpt_path, map_location="cpu") + + def _finish(result: tuple[list[str], list[str]]) -> tuple[list[str], list[str]]: + return result + + if isinstance(payload, dict): + if "feature_extractor" in payload: + missing_keys: list[str] = [] + unexpected_keys: list[str] = [] + + def _merge_incompatible(prefix: str, incompatible: Any) -> None: + missing = list(getattr(incompatible, "missing_keys", [])) + unexpected = list(getattr(incompatible, "unexpected_keys", [])) + missing_keys.extend([f"{prefix}.{k}" for k in missing]) + unexpected_keys.extend([f"{prefix}.{k}" for k in unexpected]) + + def _missing_all(prefix: str, module: nn.Module) -> None: + missing_keys.extend([f"{prefix}.{k}" for k in module.state_dict().keys()]) + + def _load_module_compat( + prefix: str, + module: nn.Module, + state: dict[str, torch.Tensor], + *, + resize_spatial_kernels: bool = False, + ) -> None: + target_state = module.state_dict() + 
filtered: dict[str, torch.Tensor] = {} + for key, value in state.items(): + target = target_state.get(key) + if target is None: + unexpected_keys.append(f"{prefix}.{key}") + continue + if tuple(value.shape) == tuple(target.shape): + filtered[key] = value + continue + can_resize = ( + bool(resize_spatial_kernels) + and isinstance(value, torch.Tensor) + and isinstance(target, torch.Tensor) + and value.ndim == 4 + and target.ndim == 4 + and tuple(value.shape[:2]) == tuple(target.shape[:2]) + ) + if can_resize: + resized = F.interpolate( + value.to(dtype=torch.float32), + size=tuple(int(x) for x in target.shape[-2:]), + mode="bilinear", + align_corners=False, + ).to(dtype=target.dtype) + filtered[key] = resized + else: + unexpected_keys.append(f"{prefix}.{key}") + _merge_incompatible(prefix, module.load_state_dict(filtered, strict=False)) + + if "feature_extractor" in payload and isinstance(payload["feature_extractor"], dict): + state = self._strip_module_prefix(payload["feature_extractor"]) + _merge_incompatible( + "feature_extractor", + self.feature_extractor.load_state_dict(state, strict=False), + ) + else: + _missing_all("feature_extractor", self.feature_extractor) + + if "feature_decoder" in payload and isinstance(payload["feature_decoder"], dict): + state = self._strip_module_prefix(payload["feature_decoder"]) + _load_module_compat( + "feature_decoder", + self.feature_decoder, + state, + resize_spatial_kernels=True, + ) + else: + _missing_all("feature_decoder", self.feature_decoder) + + if "init_model" in payload and isinstance(payload["init_model"], dict): + state = self._strip_module_prefix(payload["init_model"]) + _merge_incompatible( + "init_model", + self.init_model.load_state_dict(state, strict=False), + ) + else: + _missing_all("init_model", self.init_model) + + if "prediction_head" in payload and isinstance(payload["prediction_head"], dict): + state = self._strip_module_prefix(payload["prediction_head"]) + miss, unexp, _ = 
self._load_prediction_head_compat(state) + missing_keys.extend([f"prediction_head.{k}" for k in miss]) + unexpected_keys.extend([f"prediction_head.{k}" for k in unexp]) + else: + _missing_all("prediction_head", self.prediction_head) + + if "gaussian_composer" in payload and isinstance(payload["gaussian_composer"], dict): + state = self._strip_module_prefix(payload["gaussian_composer"]) + _merge_incompatible( + "gaussian_composer", + self.gaussian_composer.load_state_dict(state, strict=False), + ) + else: + _missing_all("gaussian_composer", self.gaussian_composer) + + if "second_layer_depth_head" in payload and isinstance(payload["second_layer_depth_head"], dict): + state = self._strip_module_prefix(payload["second_layer_depth_head"]) + miss, unexp = self._load_depth_head_compat(state) + missing_keys.extend([f"second_layer_depth_head.{k}" for k in miss]) + unexpected_keys.extend([f"second_layer_depth_head.{k}" for k in unexp]) + else: + _missing_all("second_layer_depth_head", self.second_layer_depth_head) + + expected_top = { + "step", + "feature_extractor", + "feature_decoder", + "init_model", + "prediction_head", + "gaussian_composer", + "second_layer_depth_head", + "config", + "decoder_params", + "optimizer", + "use_feature_only", + "unik3d_backbone", + } + for k in payload.keys(): + if k not in expected_top: + unexpected_keys.append(f"payload.{k}") + if strict and (missing_keys or unexpected_keys): + raise RuntimeError( + "Feature-only checkpoint is incompatible: " + f"missing={missing_keys[:20]} unexpected={unexpected_keys[:20]}" + ) + return _finish((missing_keys, unexpected_keys)) + + elif "model" in payload: + state = payload["model"] + return _finish(self._load_unisharp_checkpoint(state, strict=strict)) + elif "state_dict" in payload and isinstance(payload["state_dict"], dict): + state_dict = self._strip_module_prefix(payload["state_dict"]) + if self._looks_like_unik3d_state_dict(state_dict): + return _finish(self._load_unik3d_state_dict(state_dict)) + 
return _finish(self.load_state_dict(state_dict, strict=strict)) + else: + raw_state = self._strip_module_prefix(payload) + if self._looks_like_unik3d_state_dict(raw_state): + return _finish(self._load_unik3d_state_dict(raw_state)) + return _finish(self.load_state_dict(payload, strict=strict)) + else: + return _finish(self.load_state_dict(payload, strict=strict)) + + def _load_unisharp_checkpoint(self, state: dict, strict: bool = False) -> tuple[list[str], list[str]]: + state = self._strip_module_prefix(state) + init_state = {k.replace("init_model.", ""): v for k, v in state.items() if k.startswith("init_model.")} + comp_state = {k.replace("gaussian_composer.", ""): v for k, v in state.items() if k.startswith("gaussian_composer.")} + depth_head_state = { + k.replace("second_layer_depth_head.", ""): v + for k, v in state.items() + if k.startswith("second_layer_depth_head.") + } + missing_keys = [] + unexpected_keys = [] + + if init_state: + m, u = self.init_model.load_state_dict(init_state, strict=False) + missing_keys.extend(m) + unexpected_keys.extend(u) + + + if comp_state: + m, u = self.gaussian_composer.load_state_dict(comp_state, strict=False) + missing_keys.extend(m) + unexpected_keys.extend(u) + + if depth_head_state: + m, u = self._load_depth_head_compat(depth_head_state) + missing_keys.extend(m) + unexpected_keys.extend(u) + + known_prefixes = ( + "init_model.", + "prediction_head.", + "gaussian_composer.", + "second_layer_depth_head.", + ) + for k in state.keys(): + if not any(str(k).startswith(pref) for pref in known_prefixes): + unexpected_keys.append(str(k)) + return (missing_keys, unexpected_keys) + + def save_checkpoint(self, path: str, step: int, optimizer: torch.optim.Optimizer | None = None) -> None: + from dataclasses import asdict + + ckpt = { + "step": step, + "feature_extractor": self.feature_extractor.state_dict(), + "feature_decoder": self.feature_decoder.state_dict(), + "init_model": self.init_model.state_dict(), + "prediction_head": 
@dataclasses.dataclass
class PanoPredictorParams:
    """Top-level parameter bundle for the panoramic Gaussian predictor.

    Groups the initializer settings with the delta scaling, activation and
    scale-range options, plus the UniK3D backbone selection.
    """

    # Nested initializer settings; a fresh PanoInitializerParams per instance.
    initializer: PanoInitializerParams = dataclasses.field(default_factory=PanoInitializerParams)

    # Per-attribute factors for predicted deltas (xy, z, color, opacity, scale, quaternion).
    delta_factor: DeltaFactor = dataclasses.field(default_factory=DeltaFactor)
    # Activation names for color and opacity; allowed values are defined by
    # math_utils.ActivationType.
    color_activation_type: math_utils.ActivationType = "sigmoid"
    opacity_activation_type: math_utils.ActivationType = "sigmoid"
    # Upper and lower limits for Gaussian scale — presumably applied as a clamp
    # by the composer; TODO confirm.
    max_scale: float = 10.0
    min_scale: float = 0.0
    # NOTE(review): looks like this bases the Gaussian scale on the predicted
    # mean rather than the initial one — confirm against the composer.
    base_scale_on_predicted_mean: bool = True

    # UniK3D backbone variant name and whether to start from pretrained weights.
    unik3d_backbone: str = "vitl"
    unik3d_pretrained: bool = True
    # Number of monodepth layers; the default matches the initializer's num_layers default.
    num_monodepth_layers: int = 2
b/unisharp/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/unisharp/utils/__pycache__/__init__.cpython-310.pyc b/unisharp/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f8a636e96bb79fbba115e8da37a6b405b275597 Binary files /dev/null and b/unisharp/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/unisharp/utils/__pycache__/__init__.cpython-313.pyc b/unisharp/utils/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd15add97de66c1fd61c675c1161e2f181ee5f1a Binary files /dev/null and b/unisharp/utils/__pycache__/__init__.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/camera_projection.cpython-313.pyc b/unisharp/utils/__pycache__/camera_projection.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85dd4f116cf5332bd02ab775bb8c627c6723cffe Binary files /dev/null and b/unisharp/utils/__pycache__/camera_projection.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/camera_utils.cpython-313.pyc b/unisharp/utils/__pycache__/camera_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c0bacb7313dff6bb9f591e123fe5f5517c12c16 Binary files /dev/null and b/unisharp/utils/__pycache__/camera_utils.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/color_space.cpython-310.pyc b/unisharp/utils/__pycache__/color_space.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bed3d0a735df66532061a55913e12190d82d1d68 Binary files /dev/null and b/unisharp/utils/__pycache__/color_space.cpython-310.pyc differ diff --git a/unisharp/utils/__pycache__/color_space.cpython-313.pyc b/unisharp/utils/__pycache__/color_space.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..7328de4cc51bcef92231efabb55fd2182ef2916d Binary files /dev/null and b/unisharp/utils/__pycache__/color_space.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/fisheye_geer.cpython-313.pyc b/unisharp/utils/__pycache__/fisheye_geer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9393baac3112f151dddc1eb3bb8017e0260a052a Binary files /dev/null and b/unisharp/utils/__pycache__/fisheye_geer.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/gaussians.cpython-310.pyc b/unisharp/utils/__pycache__/gaussians.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a564e35650013219eed7e67f5725ff922287afc Binary files /dev/null and b/unisharp/utils/__pycache__/gaussians.cpython-310.pyc differ diff --git a/unisharp/utils/__pycache__/gaussians.cpython-313.pyc b/unisharp/utils/__pycache__/gaussians.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a07752dbc5049c8491c333687d53e480921c8454 Binary files /dev/null and b/unisharp/utils/__pycache__/gaussians.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/gsplat.cpython-310.pyc b/unisharp/utils/__pycache__/gsplat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86e0e57b07fef100cf30ad6d46bdf1fd9819d359 Binary files /dev/null and b/unisharp/utils/__pycache__/gsplat.cpython-310.pyc differ diff --git a/unisharp/utils/__pycache__/gsplat.cpython-313.pyc b/unisharp/utils/__pycache__/gsplat.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f9033b66d629e690bc4743251f51d637cc441a7 Binary files /dev/null and b/unisharp/utils/__pycache__/gsplat.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/io.cpython-313.pyc b/unisharp/utils/__pycache__/io.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b20c8946088f7a7ad882e6e2f1982afe0216faef Binary files 
/dev/null and b/unisharp/utils/__pycache__/io.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/linalg.cpython-310.pyc b/unisharp/utils/__pycache__/linalg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..245d5284ff70db368be05d7067ba980d4c269ec5 Binary files /dev/null and b/unisharp/utils/__pycache__/linalg.cpython-310.pyc differ diff --git a/unisharp/utils/__pycache__/linalg.cpython-313.pyc b/unisharp/utils/__pycache__/linalg.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..159aff28851682ae29c83711bbaae66d934be805 Binary files /dev/null and b/unisharp/utils/__pycache__/linalg.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/logging.cpython-310.pyc b/unisharp/utils/__pycache__/logging.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5fff5571edb2e51ccc521e339a5869f48449102 Binary files /dev/null and b/unisharp/utils/__pycache__/logging.cpython-310.pyc differ diff --git a/unisharp/utils/__pycache__/logging.cpython-313.pyc b/unisharp/utils/__pycache__/logging.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea34194760c7c58da8d5f55807a9a754887cbf3a Binary files /dev/null and b/unisharp/utils/__pycache__/logging.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/math.cpython-310.pyc b/unisharp/utils/__pycache__/math.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7b4a42133f6bad0d8d961a1495e796878801604 Binary files /dev/null and b/unisharp/utils/__pycache__/math.cpython-310.pyc differ diff --git a/unisharp/utils/__pycache__/math.cpython-313.pyc b/unisharp/utils/__pycache__/math.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..052b76029f821c2802e944313a736973b0553b66 Binary files /dev/null and b/unisharp/utils/__pycache__/math.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/pano.cpython-313.pyc 
b/unisharp/utils/__pycache__/pano.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cd4dfc3c03dda8a20b698cc24bdca8a0eb096b4 Binary files /dev/null and b/unisharp/utils/__pycache__/pano.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/pixel_convention.cpython-310.pyc b/unisharp/utils/__pycache__/pixel_convention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4458965f45520d26602eccca8ca352441c6d512c Binary files /dev/null and b/unisharp/utils/__pycache__/pixel_convention.cpython-310.pyc differ diff --git a/unisharp/utils/__pycache__/pixel_convention.cpython-313.pyc b/unisharp/utils/__pycache__/pixel_convention.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c826c1bdb51212740b05dd6456cdeca90fb6bf0 Binary files /dev/null and b/unisharp/utils/__pycache__/pixel_convention.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/rayfit_camera.cpython-313.pyc b/unisharp/utils/__pycache__/rayfit_camera.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1d509c3abe83d4245777cc5b0f82bae30223da7 Binary files /dev/null and b/unisharp/utils/__pycache__/rayfit_camera.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/unified_vis.cpython-313.pyc b/unisharp/utils/__pycache__/unified_vis.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15ef8565971ee0cb7bbddf39c3bbed3c415853fe Binary files /dev/null and b/unisharp/utils/__pycache__/unified_vis.cpython-313.pyc differ diff --git a/unisharp/utils/__pycache__/unik3d_adapter.cpython-310.pyc b/unisharp/utils/__pycache__/unik3d_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6373661657b6fda01175742c8938b1fe3546ec49 Binary files /dev/null and b/unisharp/utils/__pycache__/unik3d_adapter.cpython-310.pyc differ diff --git a/unisharp/utils/__pycache__/unik3d_adapter.cpython-313.pyc 
PoseConvention = Literal["c2w", "w2c", "c2w_t_w2c", "w2c_t_camcenter"]


def build_extrinsics_w2c(
    R: torch.Tensor, t: torch.Tensor, convention: PoseConvention = "c2w"
) -> torch.Tensor:
    """Assemble a 4x4 float32 world-to-camera matrix from a rotation and a vector.

    ``convention`` states how ``R`` and ``t`` are to be interpreted:

    * ``"w2c"``: ``R`` and ``t`` are used as given.
    * ``"c2w"``: the rigid transform ``[R | t]`` is inverted, i.e. the result
      holds ``R.T`` and ``-(R.T @ t)``.
    * ``"c2w_t_w2c"``: ``R`` is transposed, ``t`` is used as given.
    * ``"w2c_t_camcenter"``: ``R`` is used as given and the translation is
      ``-(R @ t)``.

    Args:
        R: 3x3 rotation matrix (indexed as a 2-D tensor).
        t: 3-vector.
        convention: One of the ``PoseConvention`` literals.

    Returns:
        A 4x4 float32 tensor on ``R``'s device whose last row is ``[0, 0, 0, 1]``.

    Raises:
        ValueError: If ``convention`` is not one of the supported literals.
    """
    rot = R.to(torch.float32)
    trans = t.to(torch.float32)

    if convention == "w2c":
        rot_w2c, trans_w2c = rot, trans
    elif convention == "c2w":
        rot_w2c = rot.T
        trans_w2c = -(rot.T @ trans)
    elif convention == "c2w_t_w2c":
        rot_w2c, trans_w2c = rot.T, trans
    elif convention == "w2c_t_camcenter":
        rot_w2c, trans_w2c = rot, -(rot @ trans)
    else:
        raise ValueError(f"Unsupported convention: {convention}")

    w2c = torch.eye(4, dtype=torch.float32, device=rot.device)
    w2c[:3, :3] = rot_w2c
    w2c[:3, 3] = trans_w2c
    return w2c
def view_frustum_mask_cubemap_union(
    depth_novel: torch.Tensor,
    extr_novel_w2c: torch.Tensor,
    extr_source_w2c: torch.Tensor,
    face_w: int,
    margin: float = 1.05,
    source_depth: torch.Tensor | None = None,
    source_occlusion_tolerance_m: float = 0.0,
    source_occlusion_tolerance_ratio: float = 0.10,
    source_visibility_radius_px: int = 0,
) -> torch.Tensor:
    """Mark novel-view cubemap pixels whose 3-D point is seen by any source face.

    Every novel pixel is lifted to 3-D by scaling its unit-length ray with
    ``depth_novel`` (so the depth is a distance along the ray, not a z-depth),
    moved to world space, and projected into each of the six source cubemap
    faces. A pixel is kept if, for at least one source face, the point lies in
    front of the face and its tangent-plane coordinates ``X/Z`` and ``Y/Z``
    are within ``[-margin, margin]``.

    When ``source_depth`` is given, the point must additionally not be
    occluded in that face: its distance to the source camera has to be at most
    the sampled source distance plus
    ``source_occlusion_tolerance_m + source_occlusion_tolerance_ratio * |d|``.

    Args:
        depth_novel: Per-face ray distance, indexed as ``(6, face_w, face_w, 1)``.
        extr_novel_w2c: World-to-camera matrix of the novel view.
            NOTE(review): assumed to be a single 4x4 matrix, as required by
            ``cubemap_face_cameras`` -- confirm against callers.
        extr_source_w2c: World-to-camera matrix of the source view (same
            assumption as above).
        face_w: Side length in pixels of each (square) cube face.
        margin: Bound on ``|X/Z|`` and ``|Y/Z|`` for a point to count as
            inside a source face.
        source_depth: Optional source ray distances, ``(6, H, W, 1)`` or
            ``(6, 1, H, W)``; resized with nearest-neighbour sampling if its
            spatial size differs from ``face_w``.
        source_occlusion_tolerance_m: Absolute slack for the occlusion test.
        source_occlusion_tolerance_ratio: Slack proportional to the sampled
            source distance.
        source_visibility_radius_px: If positive, the source distance map is
            min-filtered (and its validity mask dilated) with a window of
            ``2 * radius + 1`` pixels before sampling.

    Returns:
        Bool tensor of shape ``(6, face_w, face_w)``.

    Raises:
        ValueError: If ``depth_novel`` or ``source_depth`` has an unexpected shape.
    """
    # All geometry is done in float32 with autocast disabled.
    with torch.autocast(device_type=depth_novel.device.type, enabled=False):
        depth_novel = depth_novel.to(dtype=torch.float32)
        extr_novel_w2c = extr_novel_w2c.to(dtype=torch.float32)
        extr_source_w2c = extr_source_w2c.to(dtype=torch.float32)

        device = depth_novel.device
        H = face_w
        W = face_w
        if tuple(depth_novel.shape[:3]) != (6, H, W):
            raise ValueError("depth_novel must be (6,face_w,face_w,1)")

        intr = get_pinhole_intrinsics_4x4(face_w).to(device=device)
        fx, fy = intr[0, 0], intr[1, 1]
        cx, cy = intr[0, 2], intr[1, 2]

        # Unit-length camera rays through every pixel of one face.
        uu, vv = integer_pixel_center_grid(H, W, device=device, dtype=torch.float32)
        x = (uu - cx) / fx
        y = (vv - cy) / fy
        z = torch.ones_like(x)
        rays_cam = torch.stack([x, y, z], dim=-1)
        rays_cam = rays_cam / torch.norm(rays_cam, dim=-1, keepdim=True).clamp(min=1e-6)

        extr_faces_novel = cubemap_face_cameras(extr_novel_w2c, device=device)
        extr_faces_src = cubemap_face_cameras(extr_source_w2c, device=device)
        cam2world_novel = torch.linalg.inv(extr_faces_novel)

        depth = depth_novel[..., 0].to(torch.float32)
        depth_valid = torch.isfinite(depth) & (depth > 0.0)
        # Zero out invalid depths so NaN/inf never enter the matrix products.
        depth = torch.where(depth_valid, depth, torch.zeros_like(depth))
        # Broadcasting replaces an explicit 6x repeat of the ray grid.
        xyz_cam = rays_cam[None] * depth[..., None]
        xyz_cam_h = torch.cat(
            [xyz_cam, torch.ones_like(xyz_cam[..., :1])], dim=-1
        )
        # Row-vector convention: p_world^T = p_cam^T @ (cam2world)^T.
        xyz_world_h = torch.bmm(
            xyz_cam_h.reshape(6, -1, 4),
            cam2world_novel.transpose(-1, -2),
        ).reshape(6, H, W, 4)

        src_depth_for_min = None
        src_depth_valid = None
        if torch.is_tensor(source_depth):
            src_depth = source_depth.to(device=device, dtype=torch.float32)
            if src_depth.ndim != 4:
                raise ValueError(f"Expected source_depth shape (6,H,W,1) or (6,1,H,W), got {tuple(src_depth.shape)}")
            if int(src_depth.shape[0]) != 6:
                raise ValueError(f"Expected source_depth first dimension to be 6, got {tuple(src_depth.shape)}")
            if int(src_depth.shape[-1]) == 1:
                src_depth_bchw = src_depth.permute(0, 3, 1, 2).contiguous()
            elif int(src_depth.shape[1]) == 1:
                src_depth_bchw = src_depth.contiguous()
            else:
                raise ValueError(f"Expected source_depth shape (6,H,W,1) or (6,1,H,W), got {tuple(src_depth.shape)}")
            if tuple(src_depth_bchw.shape[-2:]) != (H, W):
                src_depth_bchw = F.interpolate(src_depth_bchw, size=(H, W), mode="nearest")
            src_depth_valid = torch.isfinite(src_depth_bchw) & (src_depth_bchw > 0.0)
            # Invalid pixels get a huge distance so they never win the min-filter.
            invalid_depth_fill = 1.0e9
            src_depth_for_min = torch.where(
                src_depth_valid,
                src_depth_bchw,
                torch.full_like(src_depth_bchw, invalid_depth_fill),
            )
            radius = max(int(source_visibility_radius_px), 0)
            if radius > 0:
                kernel = 2 * radius + 1
                padded_depth = F.pad(src_depth_for_min, (radius, radius, radius, radius), value=invalid_depth_fill)
                # Min-pool expressed as a negated max-pool.
                src_depth_for_min = -F.max_pool2d(-padded_depth, kernel_size=kernel, stride=1)
                src_depth_valid = (
                    F.max_pool2d(src_depth_valid.to(dtype=torch.float32), kernel_size=kernel, stride=1, padding=radius)
                    > 0.0
                )

        mask_any = torch.zeros((6, H, W), dtype=torch.bool, device=device)
        for j in range(6):
            ext = extr_faces_src[j]
            xyz_src = xyz_world_h @ ext.T
            X, Y = xyz_src[..., 0], xyz_src[..., 1]
            z_src = xyz_src[..., 2]
            # Clamp only the divisor. The in-front test below must use the raw
            # z: after clamping the value is always >= 1e-6, which would make
            # the test vacuous and admit near-axis points behind the face.
            Z = z_src.clamp(min=1e-6)
            x_ndc = X / Z
            y_ndc = Y / Z
            inside = (
                (x_ndc >= -margin)
                & (x_ndc <= margin)
                & (y_ndc >= -margin)
                & (y_ndc <= margin)
                & (z_src > 0)
            )
            inside = inside & depth_valid
            if src_depth_for_min is not None and src_depth_valid is not None:
                u = (x_ndc * fx) + cx
                v = (y_ndc * fy) + cy
                # Pixel indices -> [-1, 1] for grid_sample(align_corners=True).
                sample_grid = torch.stack(
                    [
                        (u / max(float(W - 1), 1.0)) * 2.0 - 1.0,
                        (v / max(float(H - 1), 1.0)) * 2.0 - 1.0,
                    ],
                    dim=-1,
                )
                # The same source face j is sampled by all six novel faces.
                src_depth_face = src_depth_for_min[j : j + 1].expand(6, -1, -1, -1)
                src_valid_face = src_depth_valid[j : j + 1].to(dtype=torch.float32).expand(6, -1, -1, -1)
                sampled_src_dist = F.grid_sample(
                    src_depth_face,
                    sample_grid,
                    mode="bilinear",
                    padding_mode="zeros",
                    align_corners=True,
                )[:, 0]
                sampled_src_valid = (
                    F.grid_sample(
                        src_valid_face,
                        sample_grid,
                        mode="nearest",
                        padding_mode="zeros",
                        align_corners=True,
                    )[:, 0]
                    > 0.5
                )
                projected_src_dist = torch.linalg.vector_norm(xyz_src[..., :3], dim=-1)
                tolerance = float(source_occlusion_tolerance_m) + float(source_occlusion_tolerance_ratio) * sampled_src_dist.abs()
                source_visible = sampled_src_valid & torch.isfinite(sampled_src_dist) & (
                    projected_src_dist <= sampled_src_dist + tolerance
                )
                inside = inside & source_visible
            mask_any |= inside
        return mask_any
+ dst_k3 = dst_k3.unsqueeze(0) + src_k = src_k3.to(device=device, dtype=dtype) + dst_k = dst_k3.to(device=device, dtype=dtype) + bsz, _, src_h, src_w = depth.shape + if int(src_k.shape[0]) == 1 and bsz > 1: + src_k = src_k.expand(bsz, -1, -1) + if int(dst_k.shape[0]) == 1 and bsz > 1: + dst_k = dst_k.expand(bsz, -1, -1) + if int(src_k.shape[0]) != bsz or int(dst_k.shape[0]) != bsz: + raise ValueError( + f"Batch mismatch: depth B={bsz}, src_k={tuple(src_k.shape)}, dst_k={tuple(dst_k.shape)}" + ) + dst_h, dst_w = ( + (int(dst_hw[0]), int(dst_hw[1])) + if dst_hw is not None + else (int(src_h), int(src_w)) + ) + + yy, xx = torch.meshgrid( + torch.arange(src_h, device=device, dtype=dtype), + torch.arange(src_w, device=device, dtype=dtype), + indexing="ij", + ) + xx_flat = xx.reshape(-1) + yy_flat = yy.reshape(-1) + out_depth: list[torch.Tensor] = [] + out_valid: list[torch.Tensor] = [] + inf = torch.tensor(float("inf"), device=device, dtype=dtype) + + for b in range(bsz): + z = depth[b, 0].reshape(-1) + valid = torch.isfinite(z) & (z > 0.0) + if not bool(valid.any()): + z_out = torch.zeros((dst_h * dst_w,), device=device, dtype=dtype) + v_out = torch.zeros_like(z_out, dtype=torch.bool) + out_depth.append(z_out.reshape(1, dst_h, dst_w)) + out_valid.append(v_out.reshape(1, dst_h, dst_w)) + continue + + fx_s = src_k[b, 0, 0].clamp(min=1e-6) + fy_s = src_k[b, 1, 1].clamp(min=1e-6) + cx_s = src_k[b, 0, 2] + cy_s = src_k[b, 1, 2] + fx_d = dst_k[b, 0, 0].clamp(min=1e-6) + fy_d = dst_k[b, 1, 1].clamp(min=1e-6) + cx_d = dst_k[b, 0, 2] + cy_d = dst_k[b, 1, 2] + + z_v = z[valid] + x = (xx_flat[valid] - cx_s) * z_v / fx_s + y = (yy_flat[valid] - cy_s) * z_v / fy_s + u = fx_d * (x / z_v.clamp(min=1e-6)) + cx_d + v = fy_d * (y / z_v.clamp(min=1e-6)) + cy_d + + u0 = torch.floor(u) + v0 = torch.floor(v) + lin_parts: list[torch.Tensor] = [] + z_parts: list[torch.Tensor] = [] + for du in (0.0, 1.0): + for dv in (0.0, 1.0): + ui = (u0 + du).to(torch.long) + vi = (v0 + dv).to(torch.long) + 
class CameraType(Enum):
    """Camera models distinguished by this module."""

    PINHOLE = "pinhole"
    SPHERICAL = "spherical"


def detect_camera_type(camera_intrinsics: torch.Tensor | None) -> CameraType:
    """Return ``SPHERICAL`` when no intrinsics are supplied, else ``PINHOLE``."""
    if camera_intrinsics is None:
        return CameraType.SPHERICAL
    return CameraType.PINHOLE


def transform_gaussians_to_world(
    gaussians: Any,
    src_w2c: torch.Tensor,
) -> Any:
    """Move Gaussians from the source camera frame into the world frame.

    ``src_w2c`` is inverted to obtain the camera-to-world rotation and
    translation. Means are rotated and translated, and orientations are
    left-multiplied by the rotation's quaternion and re-normalised. Scales,
    colours and opacities are passed through untouched.

    Args:
        gaussians: Object exposing ``mean_vectors``, ``singular_values``,
            ``quaternions``, ``colors`` and ``opacities`` and constructible
            from those five keyword arguments.
            NOTE(review): the ``[None, None, :]`` broadcasts suggest tensors
            with two leading dimensions -- confirm against callers.
        src_w2c: 4x4 world-to-camera matrix (indexed as a 2-D tensor).

    Returns:
        A new object of the same type as ``gaussians``; means and quaternions
        keep their original dtypes.
    """
    cam_to_world = torch.linalg.inv(src_w2c).to(torch.float32)
    rotation = cam_to_world[:3, :3]
    translation = cam_to_world[:3, 3]

    means_cam = gaussians.mean_vectors.to(torch.float32)
    means_world = means_cam @ rotation.T + translation[None, None, :]

    rotation_quat = rotmat_to_quat_wxyz(rotation)
    quats_cam = gaussians.quaternions
    quats_world = quat_mul_wxyz(
        rotation_quat[None, None, :].expand_as(quats_cam),
        quats_cam.to(torch.float32),
    )
    # Re-normalise; the clamp guards against division by a zero norm.
    quat_norm = quats_world.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    quats_world = quats_world / quat_norm

    return type(gaussians)(
        mean_vectors=means_world.to(gaussians.mean_vectors.dtype),
        singular_values=gaussians.singular_values,
        quaternions=quats_world.to(quats_cam.dtype),
        colors=gaussians.colors,
        opacities=gaussians.opacities,
    )
def quat_mul_wxyz(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Return the Hamilton product ``q1 * q2`` of quaternions stored as (w, x, y, z).

    The last dimension of both inputs holds the four components; the leading
    dimensions are combined element-wise.
    """
    w1, x1, y1, z1 = q1.unbind(dim=-1)
    w2, x2, y2, z2 = q2.unbind(dim=-1)

    w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
    z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2

    return torch.stack([w, x, y, z], dim=-1)


def to_k4(k3: torch.Tensor) -> torch.Tensor:
    """Embed 3x3 intrinsics in the top-left corner of a 4x4 identity matrix.

    Args:
        k3: Tensor of shape ``(3, 3)`` or ``(..., 3, 3)`` with any number of
            leading batch dimensions.

    Returns:
        A freshly allocated tensor of shape ``(..., 4, 4)`` with ``k3``'s dtype
        and device; entries outside the 3x3 block are those of the identity.
    """
    # ``repeat`` materialises one writable identity per batch element in a
    # single copy (the previous expand + contiguous + clone copied twice) and
    # handles the unbatched case via an empty leading shape.
    k4 = torch.eye(4, dtype=k3.dtype, device=k3.device).repeat(*k3.shape[:-2], 1, 1)
    k4[..., :3, :3] = k3
    return k4
source_img_h) + src_w = int(img_w if source_img_w is None else source_img_w) + + y_coords, x_coords = torch.meshgrid( + torch.arange(img_h, device=device, dtype=torch.float32), + torch.arange(img_w, device=device, dtype=torch.float32), + indexing="ij", + ) + + fx_t = tgt_k3[0, 0, 0] + fy_t = tgt_k3[0, 1, 1] + cx_t = tgt_k3[0, 0, 2] + cy_t = tgt_k3[0, 1, 2] + + z = depth[0, 0] + x_cam = (x_coords - cx_t) * z / fx_t + y_cam = (y_coords - cy_t) * z / fy_t + + pts_tgt_cam = torch.stack([x_cam, y_cam, z, torch.ones_like(z)], dim=0) + pts_tgt_cam = pts_tgt_cam.reshape(4, -1) + + tgt_c2w = torch.linalg.inv(tgt_w2c[0]).to(torch.float32) + pts_world = tgt_c2w @ pts_tgt_cam + + pts_src_cam = src_w2c[0].to(torch.float32) @ pts_world + + fx_s = src_k3[0, 0, 0] + fy_s = src_k3[0, 1, 1] + cx_s = src_k3[0, 0, 2] + cy_s = src_k3[0, 1, 2] + + x_src = pts_src_cam[0] / pts_src_cam[2].clamp(min=1e-6) + y_src = pts_src_cam[1] / pts_src_cam[2].clamp(min=1e-6) + + u_src = fx_s * x_src + cx_s + v_src = fy_s * y_src + cy_s + + margin = max(float(frustum_margin), 1.0) + margin_x = 0.5 * (margin - 1.0) * float(src_w) + margin_y = 0.5 * (margin - 1.0) * float(src_h) + valid_depth = (torch.isfinite(z) & (z > 0)).reshape(-1) + valid = ( + (u_src >= -margin_x) & (u_src < float(src_w) + margin_x) & + (v_src >= -margin_y) & (v_src < float(src_h) + margin_y) & + valid_depth & + (pts_src_cam[2] > 0) + ) + + if torch.is_tensor(source_depth): + if source_depth.ndim == 3: + source_depth = source_depth.unsqueeze(1) + if source_depth.ndim != 4 or int(source_depth.shape[0]) != 1 or int(source_depth.shape[1]) != 1: + raise ValueError(f"Expected source_depth shape (1,1,H,W), got {tuple(source_depth.shape)}") + + src_depth = source_depth.to(device=device, dtype=torch.float32) + if tuple(src_depth.shape[-2:]) != (src_h, src_w): + src_depth = F.interpolate(src_depth, size=(src_h, src_w), mode="nearest") + + src_depth_valid = torch.isfinite(src_depth) & (src_depth > 0.0) + invalid_depth_fill = 1.0e9 + 
def resize_batch(
    batch: dict[str, torch.Tensor],
    target_h: int,
    target_w: int,
    keys_to_resize: tuple[str, ...] | list[str] = ("image", "image_u8", "depth"),
) -> dict[str, torch.Tensor]:
    """Resize selected image-like entries of ``batch`` to ``(target_h, target_w)``.

    The interpolation mode is chosen from the key name:

    * keys ending in ``"_u8"``: converted to float, resized bilinearly,
      rounded, clamped to ``[0, 255]`` and cast back to ``uint8``;
    * other keys containing ``"depth"``: nearest-neighbour, so no new depth
      values are invented;
    * everything else: bilinear with ``align_corners=False``.

    Keys missing from ``batch`` and tensors whose last two dimensions already
    match the target are skipped.

    Args:
        batch: Mapping from key to tensor. It is modified in place.
        target_h: Output height.
        target_w: Output width.
        keys_to_resize: Keys to consider. The default is an immutable tuple so
            it cannot be altered between calls.

    Returns:
        The same ``batch`` object that was passed in.
    """
    target_size = (target_h, target_w)
    for key in keys_to_resize:
        if key not in batch:
            continue

        tensor = batch[key]
        if tensor.shape[-2:] == target_size:
            continue

        if key.endswith("_u8"):
            tensor = F.interpolate(
                tensor.float(),
                size=target_size,
                mode="bilinear",
                align_corners=False,
            ).round().clamp(0, 255).to(torch.uint8)
        elif "depth" in key:
            tensor = F.interpolate(
                tensor,
                size=target_size,
                mode="nearest",
            )
        else:
            tensor = F.interpolate(
                tensor,
                size=target_size,
                mode="bilinear",
                align_corners=False,
            )

        batch[key] = tensor

    return batch
def linearRGB2sRGB(linearRGB: torch.Tensor) -> torch.Tensor:
    """Apply the piecewise sRGB transfer function to linear RGB values.

    Values at or below the threshold are scaled by 12.92; values above it go
    through ``1.055 * x ** (1 / 2.4) - 0.055``. Before the power is evaluated,
    inputs that belong to the linear segment are replaced by the threshold, so
    the unused power branch is never evaluated at zero or negative values.
    """
    THRESHOLD = 0.0031308

    on_linear_segment = linearRGB <= THRESHOLD
    # Same substitution the shared where-helper performs for its false branch.
    power_input = torch.where(~on_linear_segment, linearRGB, THRESHOLD)

    linear_part = linearRGB * 12.92
    power_part = 1.055 * (power_input ** (1 / 2.4)) - 0.055
    return torch.where(on_linear_segment, linear_part, power_part)
def build_fisheye624_raymap(
    camera_params: torch.Tensor,
    *,
    image_h: int,
    image_w: int,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Unproject every pixel of a Fisheye624 camera into a unit-length ray.

    Args:
        camera_params: Fisheye624 parameter vector, or a batch of them with
            the batch on the first dimension; a 1-D input is treated as a
            batch of one.
        image_h: Image height in pixels.
        image_w: Image width in pixels.
        device: Device on which the pixel grid and rays are built.
        dtype: Dtype of the returned rays (the camera model itself is always
            evaluated on float32 parameters and pixel coordinates).

    Returns:
        Rays normalised along dimension 1.
        NOTE(review): presumably shaped ``(B, 3, image_h, image_w)`` -- this
        depends on ``Fisheye624.unproject``; confirm.
    """
    batched = camera_params.unsqueeze(0) if camera_params.ndim == 1 else camera_params
    batched = batched.to(device=device, dtype=torch.float32)
    camera = _load_fisheye624_class()(params=batched)

    num_cams = int(batched.shape[0])
    grid_u, grid_v = integer_pixel_center_grid(
        int(image_h), int(image_w), device=device, dtype=torch.float32
    )
    # One shared (u, v) grid, viewed once per camera without copying.
    pixel_uv = torch.stack([grid_u, grid_v], dim=0)[None].expand(num_cams, -1, -1, -1)

    directions = camera.unproject(pixel_uv).to(dtype=dtype)
    lengths = torch.norm(directions, dim=1, keepdim=True).clamp(min=1e-6)
    return directions / lengths
device=dev) + tan_phi = torch.linspace(-max(ty_max, 1e-3), max(ty_max, 1e-3), steps=int(image_h), device=dev) + return rays, tan_theta.contiguous(), tan_phi.contiguous() + + +def _depth_from_accumulated_invdepth( + invdepth_accum: torch.Tensor, + alpha: torch.Tensor, + *, + eps: float = 1e-8, + alpha_min: float = 1e-4, +) -> tuple[torch.Tensor, torch.Tensor]: + if invdepth_accum.ndim == 2: + invdepth_accum = invdepth_accum.unsqueeze(0).unsqueeze(0) + elif invdepth_accum.ndim == 3: + invdepth_accum = invdepth_accum.unsqueeze(1) + alpha_1 = alpha[:, :1].to(device=invdepth_accum.device, dtype=invdepth_accum.dtype) + valid = torch.isfinite(invdepth_accum) & torch.isfinite(alpha_1) & (alpha_1 > float(alpha_min)) + invdepth = torch.where(valid, invdepth_accum / alpha_1.clamp(min=eps), torch.zeros_like(invdepth_accum)) + depth = torch.where(invdepth > eps, 1.0 / invdepth.clamp(min=eps), torch.zeros_like(invdepth)) + return depth, invdepth + + +def compute_fisheye624_frustum_mask( + *, + depth_distance_m: torch.Tensor, + tgt_w2c: torch.Tensor, + src_w2c: torch.Tensor, + tgt_camera_params: torch.Tensor, + src_camera_params: torch.Tensor, + src_valid_mask: torch.Tensor | None = None, + source_depth_distance_m: torch.Tensor | None = None, + source_occlusion_tolerance_m: float = 0.0, + source_occlusion_tolerance_ratio: float = 0.10, + source_visibility_radius_px: int = 0, + edge_eps_px: float = 0.501, +) -> torch.Tensor: + if depth_distance_m.ndim != 4 or int(depth_distance_m.shape[0]) != 1: + raise ValueError(f"Expected depth shape (1,1,H,W), got {tuple(depth_distance_m.shape)}") + dev = depth_distance_m.device + dtype = torch.float32 + _, _, h, w = depth_distance_m.shape + rays = build_fisheye624_raymap( + tgt_camera_params.to(device=dev, dtype=dtype), + image_h=int(h), + image_w=int(w), + device=dev, + dtype=dtype, + ) + depth = depth_distance_m.to(dtype=dtype) + valid = torch.isfinite(depth[:, 0]) & (depth[:, 0] > 0.0) + xyz_tgt = rays * depth + xyz_tgt_h = 
torch.cat([xyz_tgt, torch.ones_like(depth)], dim=1) + xyz_world = torch.einsum("bij,bjhw->bihw", torch.linalg.inv(tgt_w2c.to(dtype=dtype)), xyz_tgt_h) + xyz_src = torch.einsum("bij,bjhw->bihw", src_w2c.to(dtype=dtype), xyz_world)[:, :3] + + Fisheye624 = _load_fisheye624_class() + src_cam = Fisheye624(params=src_camera_params.to(device=dev, dtype=dtype)) + uv_src = src_cam.project(xyz_src).to(dtype=dtype) + proj_mask = getattr(src_cam, "projection_mask", None) + h_src = int(src_valid_mask.shape[-2]) if torch.is_tensor(src_valid_mask) else int(h) + w_src = int(src_valid_mask.shape[-1]) if torch.is_tensor(src_valid_mask) else int(w) + u = uv_src[:, 0] + v = uv_src[:, 1] + valid = valid & torch.isfinite(u) & torch.isfinite(v) + eps = float(edge_eps_px) + valid = valid & (u >= -eps) & (u <= float(w_src - 1) + eps) & (v >= -eps) & (v <= float(h_src - 1) + eps) + if torch.is_tensor(proj_mask): + valid = valid & proj_mask[:, 0].to(device=dev, dtype=torch.bool) + if torch.is_tensor(src_valid_mask): + uv_grid = torch.stack( + [ + (u / max(float(w_src - 1), 1.0)) * 2.0 - 1.0, + (v / max(float(h_src - 1), 1.0)) * 2.0 - 1.0, + ], + dim=-1, + ).clamp(-1.0, 1.0) + src_valid_proj = F.grid_sample( + src_valid_mask.to(device=dev, dtype=dtype), + uv_grid, + mode="bilinear", + padding_mode="zeros", + align_corners=True, + ) + valid = valid & (src_valid_proj[:, 0] > 0.5) + if torch.is_tensor(source_depth_distance_m): + src_depth = source_depth_distance_m.to(device=dev, dtype=dtype) + if src_depth.ndim == 3: + src_depth = src_depth.unsqueeze(1) + if src_depth.ndim != 4 or int(src_depth.shape[0]) != 1 or int(src_depth.shape[1]) != 1: + raise ValueError(f"Expected source_depth_distance_m shape (1,1,H,W), got {tuple(src_depth.shape)}") + if tuple(src_depth.shape[-2:]) != (h_src, w_src): + src_depth = F.interpolate(src_depth, size=(h_src, w_src), mode="nearest") + src_depth_valid = torch.isfinite(src_depth) & (src_depth > 0.0) + invalid_depth_fill = 1.0e9 + src_depth_for_min = torch.where( 
def render_gaussians_fisheye624(
    gaussians_world: Any,
    *,
    extrinsics_w2c: torch.Tensor,
    camera_params: torch.Tensor,
    image_h: int,
    image_w: int,
    valid_mask: torch.Tensor | None = None,
    near_threshold: float = 0.2,
    asso_mode: int = 0,
) -> dict[str, torch.Tensor]:
    """Render world-space Gaussians through a Fisheye624 camera with 3DGEER.

    Only the first batch element of the Gaussians, of ``extrinsics_w2c`` and
    (when 2-D) of ``camera_params`` is rendered. The rasterizer is invoked
    twice with identical geometry: once with the real colours and once with
    all-ones colours, whose first channel is used as the alpha map.

    Args:
        gaussians_world: Object exposing ``mean_vectors``, ``singular_values``,
            ``quaternions``, ``colors`` and ``opacities``, each indexable by
            batch.
        extrinsics_w2c: Batched world-to-camera matrices; element 0 is used.
        camera_params: Fisheye624 parameters, 1-D or batched 2-D. Entries 0-3
            are passed to the rasterizer as focal lengths and principal point
            and entries 4-7 as distortion coefficients.
        image_h: Output height in pixels.
        image_w: Output width in pixels.
        valid_mask: Optional mask; it restricts the tangent extents and is
            multiplied into the returned ``valid_mask``.
        near_threshold: Forwarded to the rasterizer settings.
        asso_mode: Forwarded to the rasterizer settings.

    Returns:
        Dict with ``color``, ``alpha``, ``depth_distance``, ``invdepth`` and
        ``valid_mask``.

    Raises:
        ImportError: If the 3DGEER rasterizer extension cannot be imported.
    """
    dgr = _load_geer_rasterizer()
    dev = extrinsics_w2c.device
    dtype = torch.float32

    means = gaussians_world.mean_vectors[0].to(device=dev, dtype=dtype)
    scales = gaussians_world.singular_values[0].to(device=dev, dtype=dtype)
    rotations = gaussians_world.quaternions[0].to(device=dev, dtype=dtype)
    colors = gaussians_world.colors[0].to(device=dev, dtype=dtype)
    opacities = gaussians_world.opacities[0].to(device=dev, dtype=dtype)
    if opacities.ndim == 1:
        opacities = opacities.unsqueeze(-1)

    # The rasterizer receives the transposed w2c matrix, so the camera centre
    # is read from the last *row* of its inverse.
    viewmatrix = extrinsics_w2c[0].to(device=dev, dtype=dtype).transpose(0, 1).contiguous()
    campos = torch.linalg.inv(viewmatrix)[3, :3].contiguous()
    params = camera_params[0].to(device=dev, dtype=dtype) if camera_params.ndim == 2 else camera_params.to(device=dev, dtype=dtype)
    rays, tan_theta, tan_phi = build_fisheye624_tangent_arrays(
        params.unsqueeze(0),
        image_h=int(image_h),
        image_w=int(image_w),
        valid_mask=valid_mask,
    )
    raymap = rays[0].permute(1, 2, 0).contiguous()
    bg = torch.zeros(3, device=dev, dtype=dtype)
    tan_theta = tan_theta.to(device=dev, dtype=dtype).contiguous()
    tan_phi = tan_phi.to(device=dev, dtype=dtype).contiguous()
    # Largest tangent magnitudes, computed once and reused for both the
    # rasterizer field of view and the angular validity mask below.
    tan_theta_max = torch.max(torch.abs(tan_theta))
    tan_phi_max = torch.max(torch.abs(tan_phi))
    raster_settings = dgr.GaussianRasterizationSettings(
        image_height=int(image_h),
        image_width=int(image_w),
        tanfovx=float(tan_theta_max.item()),
        tanfovy=float(tan_phi_max.item()),
        bg=bg,
        scale_modifier=1.0,
        viewmatrix=viewmatrix,
        mirror_transformed_tan_theta=tan_theta,
        mirror_transformed_tan_phi=tan_phi,
        tan_theta=tan_theta,
        tan_phi=tan_phi,
        focal_x=float(params[0].item()),
        focal_y=float(params[1].item()),
        principal_x=float(params[2].item()),
        principal_y=float(params[3].item()),
        distortion_coeffs=params[4:8].contiguous(),
        raymap=raymap,
        sh_degree=0,
        campos=campos,
        prefiltered=False,
        debug=False,
        antialiasing=False,
        render_mode=1,
        near_threshold=float(near_threshold),
        asso_mode=int(asso_mode),
    )
    rasterizer = dgr.GaussianRasterizer(raster_settings=raster_settings)
    means2d = torch.zeros_like(means)
    rendered_rgb, _radii, invdepth, _kernel_times, _ranges = rasterizer(
        means3D=means,
        means2D=means2d,
        opacities=opacities,
        colors_precomp=colors,
        scales=scales,
        rotations=rotations,
    )
    # Second pass with white Gaussians on the black background: the rendered
    # intensity is then the accumulated opacity.
    alpha_rgb, _radii_a, _invdepth_a, _kernel_times_a, _ranges_a = rasterizer(
        means3D=means,
        means2D=means2d,
        opacities=opacities,
        colors_precomp=torch.ones_like(colors),
        scales=scales,
        rotations=rotations,
    )
    alpha = alpha_rgb[:1].unsqueeze(0).clamp(0.0, 1.0)
    depth_distance, invdepth = _depth_from_accumulated_invdepth(invdepth, alpha)
    # Pixels whose ray tangents exceed the sampled extents are flagged invalid.
    z = rays[:, 2].clamp(min=1e-4)
    ray_tan_x = (rays[:, 0] / z).abs()
    ray_tan_y = (rays[:, 1] / z).abs()
    angular_valid = (
        (ray_tan_x <= tan_theta_max.clamp(min=1e-3))
        & (ray_tan_y <= tan_phi_max.clamp(min=1e-3))
    ).unsqueeze(1).to(dtype=dtype)
    render_valid = angular_valid
    if torch.is_tensor(valid_mask):
        render_valid = render_valid * valid_mask.to(device=dev, dtype=dtype)
    return {
        "color": rendered_rgb.unsqueeze(0).clamp(0.0, 1.0),
        "alpha": alpha,
        "depth_distance": depth_distance,
        "invdepth": invdepth,
        "valid_mask": render_valid,
    }
def get_unprojection_matrix(
    extrinsics: torch.Tensor,
    intrinsics: torch.Tensor,
    image_shape: tuple[int, int],
) -> torch.Tensor:
    """Return the inverse of ``ndc_matrix @ intrinsics @ extrinsics``.

    Args:
        extrinsics: 4x4 extrinsics matrix.
        intrinsics: 4x4 intrinsics matrix; its device and dtype are used for
            the internally built NDC matrix.
        image_shape: ``(image_width, image_height)`` in pixels.

    Returns:
        The 4x4 matrix mapping NDC coordinates back through the intrinsics
        and extrinsics.
    """
    image_width, image_height = image_shape
    # Build the NDC matrix in the dtype of the intrinsics so that float64
    # inputs do not fail with a dtype mismatch in the matmul below.
    dtype = intrinsics.dtype if intrinsics.is_floating_point() else torch.get_default_dtype()
    ndc_matrix = torch.tensor(
        [
            [2.0 / image_width, 0.0, -1.0, 0.0],
            [0.0, 2.0 / image_height, -1.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        device=intrinsics.device,
        dtype=dtype,
    )
    return torch.linalg.inv(ndc_matrix @ intrinsics @ extrinsics)
def load_ply(path: Path) -> tuple[Gaussians3D, SceneMetaData]:
    """Load Gaussians and scene metadata from a ply file.

    Args:
        path: Location of the ply file.

    Returns:
        The Gaussians (batch dimension of 1, colours in linearRGB when the
        file stores sRGB) and a ``SceneMetaData`` holding the first focal
        length, ``(width, height)`` and the colour space.

    Raises:
        KeyError: If the ``vertex`` element or one of the per-vertex
            properties read below is missing.
        ValueError: If intrinsics or extrinsics have an unexpected length.
    """
    from plyfile import PlyData

    plydata = PlyData.read(path)

    vertices = next((el for el in plydata.elements if el.name == "vertex"), None)
    if vertices is None:
        raise KeyError("Incompatible ply file: element vertex not found in ply elements.")

    position_props = ["x", "y", "z"]
    color_props = [f"f_dc_{i}" for i in range(3)]
    scale_props = [f"scale_{i}" for i in range(3)]
    # Quaternions have four components; all of them are read below.
    rotation_props = [f"rot_{i}" for i in range(4)]

    # Validate every property that is actually accessed, including opacity.
    for prop in position_props + color_props + scale_props + rotation_props + ["opacity"]:
        if prop not in vertices:
            raise KeyError(f"Incompatible ply file: property {prop} not found in ply elements.")

    def _columns(names: list[str]) -> np.ndarray:
        # Stack the named per-vertex properties into an (N, len(names)) array.
        return np.stack([np.asarray(vertices[name]) for name in names], axis=1)

    mean_vectors = _columns(position_props)
    scale_logits = _columns(scale_props)
    quaternions = _columns(rotation_props)
    spherical_harmonics_deg0 = _columns(color_props)

    colors = convert_spherical_harmonics_to_rgb(spherical_harmonics_deg0)

    opacity_logits = np.asarray(vertices["opacity"])[..., None]

    # Everything that is not the vertex element may carry scene metadata; the
    # first element providing a given key wins.
    supplement_elements = [element for element in plydata.elements if element.name != "vertex"]
    supplement_data: dict[str, Any] = {}
    supplement_keys = ["extrinsic", "intrinsic", "color_space", "image_size"]

    for element in supplement_elements:
        for key in supplement_keys:
            if key not in supplement_data and key in element:
                supplement_data[key] = np.asarray(element[key])

    if "intrinsic" in supplement_data:
        intrinsics_data = supplement_data["intrinsic"]

        if "image_size" not in supplement_data:
            # Legacy layout: (fx, fy, width, height).
            if len(intrinsics_data) != 4:
                raise ValueError(
                    "Expect legacy intrinsics with len=4 containing image size, "
                    f"but received len={len(intrinsics_data)}"
                )
            focal_length_px = (intrinsics_data[0], intrinsics_data[1])
            width = int(intrinsics_data[2])
            height = int(intrinsics_data[3])

        else:
            # Current layout: flattened 3x3 matrix plus a separate image size.
            if len(intrinsics_data) != 9:
                raise ValueError(
                    "Expect 9 elements in intrinsics, " f"but received {len(intrinsics_data)}."
                )
            intrinsics_matrix = intrinsics_data.reshape((3, 3))
            focal_length_px = (intrinsics_matrix[0, 0], intrinsics_matrix[1, 1])

            image_size_data = supplement_data["image_size"]
            # Cast to plain ints, matching the legacy branch.
            width = int(image_size_data[0])
            height = int(image_size_data[1])

    else:
        focal_length_px = (512, 512)
        width = 640
        height = 480

    # NOTE(review): the extrinsics matrix is assembled but not returned; the
    # block is kept because it rejects files with malformed extrinsics.
    extrinsics_data = supplement_data.get("extrinsic", np.eye(4).flatten())
    extrinsics_matrix = np.eye(4)

    if len(extrinsics_data) == 12:
        extrinsics_matrix[:3] = extrinsics_data.reshape((3, 4))
        extrinsics_matrix[:3, :3] = extrinsics_matrix[:3, :3].copy().T
    elif len(extrinsics_data) == 16:
        extrinsics_matrix[:] = extrinsics_data.reshape((4, 4))
    else:
        raise ValueError(f"Unrecognized extrinsics matrix shape {len(extrinsics_data)}")

    color_space_index = supplement_data.get("color_space", 1)
    color_space = cs_utils.decode_color_space(color_space_index)
    colors = torch.from_numpy(colors).view(1, -1, 3).float()

    if color_space == "sRGB":
        colors = cs_utils.sRGB2linearRGB(colors.flatten(0, 1)).view(1, -1, 3)
        color_space = "linearRGB"

    mean_vectors = torch.from_numpy(mean_vectors).view(1, -1, 3).float()
    quaternions = torch.from_numpy(quaternions).view(1, -1, 4).float()
    # Scales and opacities are stored as logits in the file.
    singular_values = torch.exp(torch.from_numpy(scale_logits).view(1, -1, 3)).float()
    opacities = torch.sigmoid(torch.from_numpy(opacity_logits).view(1, -1)).float()

    gaussians = Gaussians3D(
        mean_vectors=mean_vectors,
        quaternions=quaternions,
        singular_values=singular_values,
        opacities=opacities,
        colors=colors,
    )
    metadata = SceneMetaData(focal_length_px[0], (width, height), color_space)
    return gaussians, metadata
torch.log(gaussians.singular_values).flatten(0, 1) + quaternions = gaussians.quaternions.flatten(0, 1) + + colors = convert_rgb_to_spherical_harmonics( + cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1)) + ) + color_space_index = cs_utils.encode_color_space("sRGB") + + opacity_logits = _inverse_sigmoid(gaussians.opacities).flatten(0, 1).unsqueeze(-1) + + attributes = torch.cat( + ( + xyz, + colors, + opacity_logits, + scale_logits, + quaternions, + ), + dim=1, + ) + + dtype_full = [ + (attribute, "f4") + for attribute in ["x", "y", "z"] + + [f"f_dc_{i}" for i in range(3)] + + ["opacity"] + + [f"scale_{i}" for i in range(3)] + + [f"rot_{i}" for i in range(4)] + ] + + num_gaussians = len(xyz) + elements = np.empty(num_gaussians, dtype=dtype_full) + elements[:] = list(map(tuple, attributes.detach().cpu().numpy())) + vertex_elements = PlyElement.describe(elements, "vertex") + + image_height, image_width = image_shape + + dtype_image_size = [("image_size", "u4")] + image_size_array = np.empty(2, dtype=dtype_image_size) + image_size_array[:] = np.array([image_width, image_height]) + image_size_element = PlyElement.describe(image_size_array, "image_size") + + dtype_intrinsic = [("intrinsic", "f4")] + intrinsic_array = np.empty(9, dtype=dtype_intrinsic) + intrinsic = np.array( + [ + f_px, + 0, + image_width * 0.5, + 0, + f_px, + image_height * 0.5, + 0, + 0, + 1, + ] + ) + intrinsic_array[:] = intrinsic.flatten() + intrinsic_element = PlyElement.describe(intrinsic_array, "intrinsic") + + dtype_extrinsic = [("extrinsic", "f4")] + extrinsic_array = np.empty(16, dtype=dtype_extrinsic) + extrinsic_array[:] = np.eye(4).flatten() + extrinsic_element = PlyElement.describe(extrinsic_array, "extrinsic") + + dtype_frames = [("frame", "i4")] + frame_array = np.empty(2, dtype=dtype_frames) + frame_array[:] = np.array([1, num_gaussians], dtype=np.int32) + frame_element = PlyElement.describe(frame_array, "frame") + + dtype_disparity = [("disparity", "f4")] + disparity_array = 
class GSplatRenderer(nn.Module):
    """Render batches of 3D Gaussians with ``gsplat``.

    Cameras are rendered one at a time; a single Gaussian batch may be shared
    by all cameras, or there may be exactly one Gaussian batch per camera.
    """

    color_space: cs_utils.ColorSpace
    background_color: BackgroundColor

    def __init__(
        self,
        color_space: cs_utils.ColorSpace = "sRGB",
        background_color: BackgroundColor = "black",
        low_pass_filter_eps: float = 1e-2,
    ) -> None:
        """Store the rendering configuration.

        Args:
            color_space: ``"sRGB"`` leaves the rendered colours untouched;
                ``"linearRGB"`` converts them with ``linearRGB2sRGB``.
            background_color: Background composited behind the render.
            low_pass_filter_eps: Passed to gsplat as ``eps2d``.
        """
        super().__init__()
        self.color_space = color_space
        self.background_color = background_color
        self.low_pass_filter_eps = low_pass_filter_eps

    def forward(
        self,
        gaussians: Gaussians3D,
        extrinsics: torch.Tensor,
        intrinsics: torch.Tensor,
        image_width: int,
        image_height: int,
    ) -> RenderingOutputs:
        """Render colour, depth and alpha for every camera in the batch.

        Args:
            gaussians: Batched Gaussians; batch size 1 or the camera count.
            extrinsics: Per-camera view matrices, batched along dim 0.
            intrinsics: Per-camera intrinsics; the top-left 3x3 block is used.
            image_width: Output width in pixels.
            image_height: Output height in pixels.

        Returns:
            Outputs concatenated along the camera dimension.

        Raises:
            ValueError: On mismatched batch sizes or an unsupported colour
                space.
        """
        gaussian_batch_size = len(gaussians.mean_vectors)
        camera_batch_size = int(extrinsics.shape[0])
        if int(intrinsics.shape[0]) != camera_batch_size:
            raise ValueError(
                f"Expected intrinsics batch to match extrinsics batch, got "
                f"{tuple(intrinsics.shape)} vs {tuple(extrinsics.shape)}"
            )
        if gaussian_batch_size not in (1, camera_batch_size):
            raise ValueError(
                f"Unsupported batch combination: gaussians={gaussian_batch_size}, cameras={camera_batch_size}. "
                "Expected either one Gaussian batch for many cameras or one-to-one batches."
            )
        outputs_list: list[RenderingOutputs] = []

        for ib in range(camera_batch_size):
            g_idx = 0 if gaussian_batch_size == 1 else ib
            means = gaussians.mean_vectors[g_idx].to(dtype=torch.float32)
            quats = gaussians.quaternions[g_idx].to(dtype=torch.float32)
            scales = gaussians.singular_values[g_idx].to(dtype=torch.float32)
            opacities = gaussians.opacities[g_idx].to(dtype=torch.float32)
            colors_in = gaussians.colors[g_idx].to(dtype=torch.float32)
            viewmats = extrinsics[ib : ib + 1].to(dtype=torch.float32)
            Ks = intrinsics[ib : ib + 1, :3, :3].to(dtype=torch.float32)

            # The returned metadata is not needed. An earlier revision built
            # 2D covariances from it, but the masked writes went into
            # temporary copies and the result was never read.
            colors, alphas, _meta = gsplat.rendering.rasterization(
                means=means,
                quats=quats,
                scales=scales,
                opacities=opacities,
                colors=colors_in,
                viewmats=viewmats,
                Ks=Ks,
                width=image_width,
                height=image_height,
                render_mode="RGB+D",
                rasterize_mode="classic",
                absgrad=False,
                packed=False,
                eps2d=self.low_pass_filter_eps,
            )

            # "RGB+D" packs depth as a fourth channel after the colours.
            rendered_color = colors[..., 0:3].permute([0, 3, 1, 2])
            rendered_depth_unnormalized = colors[..., 3:4].permute([0, 3, 1, 2])
            rendered_alpha = alphas.permute([0, 3, 1, 2])

            rendered_color = self.compose_with_background(
                rendered_color, rendered_alpha, self.background_color
            )

            if self.color_space == "linearRGB":
                rendered_color = cs_utils.linearRGB2sRGB(rendered_color)
            elif self.color_space != "sRGB":
                raise ValueError(f"Unsupported ColorSpace type: {self.color_space!r}")

            # Divide out the accumulated alpha; the clip avoids division by
            # zero on uncovered pixels.
            rendered_depth = rendered_depth_unnormalized / torch.clip(rendered_alpha, min=1e-8)

            outputs_list.append(
                RenderingOutputs(
                    color=rendered_color,
                    depth=rendered_depth,
                    alpha=rendered_alpha,
                )
            )

        return RenderingOutputs(
            color=torch.cat([item.color for item in outputs_list], dim=0).contiguous(),
            depth=torch.cat([item.depth for item in outputs_list], dim=0).contiguous(),
            alpha=torch.cat([item.alpha for item in outputs_list], dim=0).contiguous(),
        )

    @staticmethod
    def compose_with_background(
        rendered_rgb: torch.Tensor,
        rendered_alpha: torch.Tensor,
        background_color: BackgroundColor,
    ) -> torch.Tensor:
        """Composite the render over the requested background.

        The background is weighted by ``1 - rendered_alpha`` and added to the
        rendered colour. ``"black"`` returns the input unchanged.

        Raises:
            ValueError: If ``background_color`` is not supported.
        """
        if background_color == "black":
            return rendered_rgb
        if background_color == "white":
            return rendered_rgb + (1.0 - rendered_alpha)
        if background_color == "random_color":
            # One random colour shared by every pixel.
            random_color = torch.rand(3, dtype=rendered_rgb.dtype, device=rendered_rgb.device)
            return rendered_rgb + (1.0 - rendered_alpha) * random_color[None, :, None, None]
        if background_color == "random_pixel":
            return rendered_rgb + (1.0 - rendered_alpha) * torch.rand_like(rendered_rgb)
        raise ValueError("Unsupported BackgroundColor type.")

    @staticmethod
    def _conics_to_covars2d(conics: torch.Tensor, eps=1e-8) -> torch.Tensor:
        """Invert packed symmetric 2x2 matrices ``(a, b, c)``.

        Args:
            conics: Tensor whose last dimension holds ``(a, b, c)`` of the
                matrix ``[[a, b], [b, c]]``.
            eps: Added to the determinant and used as lower bound of its
                reciprocal.

        Returns:
            Tensor of shape ``conics.shape[:-1] + (2, 2)`` in the dtype and on
            the device of ``conics``; non-finite entries are replaced by 0.
        """
        a = conics[..., 0]
        b = conics[..., 1]
        c = conics[..., 2]
        det = 1 / (a * c - b**2 + eps)
        det = det.clamp(min=eps)
        covars2d = torch.zeros(
            *conics.shape[:-1], 2, 2, device=conics.device, dtype=conics.dtype
        )
        covars2d[..., 1, 1] = a * det
        covars2d[..., 0, 0] = c * det
        covars2d[..., 0, 1] = -b * det
        covars2d[..., 1, 0] = -b * det
        covars2d = torch.nan_to_num(covars2d, nan=0.0, posinf=0.0, neginf=0.0)
        return covars2d
def load_rgb(
    path: Path, auto_rotate: bool = True, remove_alpha: bool = True
) -> tuple[np.ndarray, list[bytes] | None, float]:
    """Load an image together with its ICC profile and focal length in pixels.

    Args:
        path: Image file; ``.heic`` files are read through ``pillow_heif``.
        auto_rotate: Apply the EXIF orientations 3, 6 and 8; other
            orientations are ignored with a warning.
        remove_alpha: Keep only the first three channels.

    Returns:
        ``(image, icc_profile, f_px)``: the image array, the ``icc_profile``
        entry of the PIL image info (or ``None``) and the focal length
        converted to pixels.

    Raises:
        ImportError: If a ``.heic`` file is requested without ``pillow_heif``.
    """
    LOGGER.debug("Loading image %s ...", path)

    if path.suffix.lower() in [".heic"]:
        if pillow_heif is None:
            raise ImportError("pillow_heif is required to read .heic images but is not installed.")
        heif_file = pillow_heif.open_heif(path, convert_hdr_to_8bit=True)
        img_pil = heif_file.to_pillow()
    else:
        img_pil = Image.open(path)

    img_exif = extract_exif(img_pil)
    icc_profile = img_pil.info.get("icc_profile", None)

    if auto_rotate:
        exif_orientation = img_exif.get("Orientation", 1)
        if exif_orientation == 3:
            img_pil = img_pil.transpose(Image.ROTATE_180)
        elif exif_orientation == 6:
            img_pil = img_pil.transpose(Image.ROTATE_270)
        elif exif_orientation == 8:
            img_pil = img_pil.transpose(Image.ROTATE_90)
        elif exif_orientation != 1:
            LOGGER.warning("Ignoring image orientation %s.", exif_orientation)

    f_35mm = img_exif.get("FocalLengthIn35mmFilm", img_exif.get("FocalLenIn35mmFilm", None))
    if f_35mm is None or f_35mm < 1:
        f_35mm = img_exif.get("FocalLength", None)
        if f_35mm is None:
            # Logger.warn() no longer exists on Python 3.13; use warning().
            LOGGER.warning(
                "Did not find focallength in exif data of %s - Setting to 30mm.", path
            )
            f_35mm = 30.0
        # NOTE(review): the nesting of this check was reconstructed from a
        # whitespace-collapsed listing - confirm it belongs inside the
        # fallback branch rather than after it.
        if f_35mm < 10.0:
            LOGGER.info("Found focal length below 10mm, assuming it's not for 35mm.")
            f_35mm *= 8.4

    img = np.asarray(img_pil)
    # Expand single-channel images to three identical channels.
    if img.ndim < 3 or img.shape[2] == 1:
        img = np.dstack((img, img, img))

    if remove_alpha:
        img = img[:, :, :3]

    LOGGER.debug("\tHxW: %sx%s", img.shape[0], img.shape[1])
    LOGGER.debug("\tfocal length @ 35mm film: %smm", f_35mm)
    f_px = convert_focallength(img.shape[1], img.shape[0], f_35mm)
    LOGGER.debug("\tfocal length: %.2fpx", f_px)

    return img, icc_profile, f_px
def write_image(
    image: np.ndarray,
    output_io: IO[bytes],
    format="jpg",
    icc_profile: list[bytes] | None = None,
    jpeg_quality: int = 92,
):
    """Encode ``image`` and write it to ``output_io``.

    Args:
        image: Image array accepted by ``PIL.Image.fromarray``.
        output_io: Binary stream that receives the encoded bytes.
        format: Pillow format name, case-insensitive; ``"jpg"`` is accepted
            as an alias of ``"JPEG"``.
        icc_profile: ICC profile embedded in the file (not applied on the
            TIFF path).
        jpeg_quality: Quality setting, used for JPEG output only.
    """
    # Pillow registers the codec as "JPEG"; "JPG" is only a file extension, so
    # the default value would otherwise be rejected by Image.save().
    pil_format = format.upper()
    if pil_format == "JPG":
        pil_format = "JPEG"

    pil_config = {}
    if pil_format == "JPEG":
        pil_config["quality"] = jpeg_quality

    image_pil = Image.fromarray(image)

    if pil_format == "TIFF":
        # TIFF is encoded into an in-memory buffer first and then copied to
        # the destination stream.
        bytes_io = io.BytesIO()
        image_pil.save(bytes_io, format="TIFF")
        output_io.write(bytes_io.getvalue())
        return

    image_pil.save(output_io, pil_format, icc_profile=icc_profile, **pil_config)
class VideoWriter(OutputWriter):
    """Write rendered frames, and optionally colourised depth, to video files."""

    def __init__(self, output_path: Path, fps: float = 30.0, render_depth: bool = True) -> None:
        """Open the colour writer and, if requested, the depth writer.

        Args:
            output_path: Target of the colour video; parent directories are
                created. The depth video uses the same path with the suffix
                replaced by ``.depth.mp4``.
            fps: Frame rate handed to both writers.
            render_depth: Whether to write the depth video as well.

        Raises:
            ImportError: If ``imageio`` is not installed.
        """
        if iio is None:
            raise ImportError("imageio is required for VideoWriter but is not installed.")
        output_path.parent.mkdir(exist_ok=True, parents=True)
        self.output_path = output_path
        self.max_depth_estimate = None
        self.depth_writer = None
        self.image_writer = iio.get_writer(output_path, fps=fps)

        if render_depth:
            try:
                self.depth_writer = iio.get_writer(output_path.with_suffix(".depth.mp4"), fps=fps)
            except Exception:
                # Do not leak the already opened colour writer.
                self.image_writer.close()
                raise

    def add_frame(self, image: torch.Tensor, depth: torch.Tensor) -> None:
        """Append one colour frame and, if enabled, its colourised depth."""
        image_np = image.detach().cpu().numpy()
        self.image_writer.append_data(image_np)

        if self.depth_writer is not None:
            # The first frame fixes the depth range used for the whole video.
            if self.max_depth_estimate is None:
                self.max_depth_estimate = depth.max().item()

            colored_depth_pt = colorize_depth(
                depth,
                min(self.max_depth_estimate, METRIC_DEPTH_MAX_CLAMP_METER),  # type: ignore[call-overload]
            )
            colored_depth_np = (
                colored_depth_pt.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
            )
            self.depth_writer.append_data(colored_depth_np)

    def close(self):
        """Close both writers, even if closing the colour writer fails."""
        try:
            self.image_writer.close()
        finally:
            if self.depth_writer is not None:
                self.depth_writer.close()
def configure(log_level: int, log_path: Path | None = None, prefix: str | None = None) -> None:
    """Reset and configure the logger named ``prefix``.

    All handlers and filters currently attached to the logger are removed,
    then a stdout handler (and optionally a file handler) is installed. A
    fixed set of chatty third-party loggers is limited to WARNING.

    Args:
        log_level: Level set on the logger.
        log_path: If given, log records are also written to this file, which
            is truncated on open.
        prefix: Logger name; ``None`` selects the root logger.
    """
    logger = logging.getLogger(prefix)

    # Iterate over snapshots: removeHandler()/removeFilter() mutate the live
    # lists, and iterating those directly skips every other entry.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)

    for log_filter in list(logger.filters):
        logger.removeFilter(log_filter)

    logger.setLevel(log_level)

    formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(formatter)
    logger.addHandler(stdout_handler)

    if log_path is not None:
        file_handler = logging.FileHandler(log_path, mode="w")
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    noisy_libs = [
        "PIL",
        "PIL.PngImagePlugin",
        "urllib3",
        "matplotlib",
        "imageio",
    ]
    for name in noisy_libs:
        logging.getLogger(name).setLevel(logging.WARNING)
def inverse_softplus(tensor: torch.Tensor, eps: float = 1e-06) -> torch.Tensor:
    """Return ``x`` such that ``softplus(x) == tensor``.

    Uses ``log(exp(y) - 1) = y + log(1 - exp(-y))``. The term
    ``1 - exp(-y)`` is evaluated with ``expm1`` so that small inputs do not
    lose precision to cancellation.

    Args:
        tensor: Softplus outputs.
        eps: Lower bound applied to ``tensor`` before inversion, which keeps
            the logarithm finite.

    Returns:
        Tensor of the same shape as ``tensor``.
    """
    tensor = tensor.clamp_min(eps)
    return tensor + torch.log(-torch.expm1(-tensor))
+ ctx: Any, grad_in: torch.Tensor + ) -> tuple[torch.Tensor, None, None, None]: + grad_out = grad_in.clone() + (tensor,) = ctx.saved_tensors + + if ctx.min is not None: + mask_min = tensor < ctx.min + grad_out[mask_min] = -ctx.pushback + + if ctx.max is not None: + mask_max = tensor > ctx.max + grad_out[mask_max] = ctx.pushback + + return grad_out, None, None, None + + +def clamp_with_pushback( + tensor: torch.Tensor, + min: float | None = None, + max: float | None = None, + pushback: float = 1e-2, +) -> torch.Tensor: + output = ClampWithPushback.apply(tensor, min, max, pushback) + assert isinstance(output, torch.Tensor) + return output + + +def hard_sigmoid_with_pushback(x: torch.Tensor, slope: float = 1.0 / 6.0) -> torch.Tensor: + return clamp_with_pushback(slope * x + 0.5, min=0.0, max=1.0) + + +def relu_with_pushback(x: torch.Tensor) -> torch.Tensor: + return clamp_with_pushback(x, min=0.0) diff --git a/unisharp/utils/metrics.py b/unisharp/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..eac01eb829ce21addcd161c19636407331c8a54a --- /dev/null +++ b/unisharp/utils/metrics.py @@ -0,0 +1,428 @@ + +from __future__ import annotations + +import logging +import os +from pathlib import Path +import re +from typing import Any +from typing import Optional + +import torch +from torch import Tensor +from torch.nn import functional as F + +LOGGER = logging.getLogger(__name__) +METRIC_MASK_CACHE_VERSION = "v3_source_bounds" + + +def _compute_psnr(gt: Tensor, pred: Tensor, eps: float = 1e-8) -> Tensor: + mse = torch.mean((gt - pred) ** 2, dim=(1, 2, 3)).clamp_min(eps) + return -10.0 * torch.log10(mse) + + +def _gaussian_kernel(size: int = 11, sigma: float = 1.5, device: Optional[torch.device] = None) -> Tensor: + coords = torch.arange(size, device=device, dtype=torch.float32) - size // 2 + g = torch.exp(-(coords**2) / (2 * sigma * sigma)) + g = g / g.sum().clamp_min(1e-8) + k2d = torch.outer(g, g) + return k2d + + +def _compute_ssim_map(gt: 
Tensor, pred: Tensor) -> Tensor: + b, c, _, _ = gt.shape + kernel = _gaussian_kernel(device=gt.device).to(dtype=gt.dtype) + kernel = kernel.view(1, 1, kernel.shape[0], kernel.shape[1]).repeat(c, 1, 1, 1) + + mu_x = F.conv2d(pred, kernel, padding=5, groups=c) + mu_y = F.conv2d(gt, kernel, padding=5, groups=c) + mu_x2 = mu_x * mu_x + mu_y2 = mu_y * mu_y + mu_xy = mu_x * mu_y + + sigma_x2 = F.conv2d(pred * pred, kernel, padding=5, groups=c) - mu_x2 + sigma_y2 = F.conv2d(gt * gt, kernel, padding=5, groups=c) - mu_y2 + sigma_xy = F.conv2d(pred * gt, kernel, padding=5, groups=c) - mu_xy + + c1 = (0.01**2) + c2 = (0.03**2) + ssim_map = ((2 * mu_xy + c1) * (2 * sigma_xy + c2)) / ( + (mu_x2 + mu_y2 + c1) * (sigma_x2 + sigma_y2 + c2) + 1e-8 + ) + return ssim_map.view(b, c, gt.shape[-2], gt.shape[-1]) + + +def _compute_ssim(gt: Tensor, pred: Tensor) -> Tensor: + ssim_map = _compute_ssim_map(gt, pred) + return ssim_map.view(int(ssim_map.shape[0]), -1).mean(dim=1) + + +class _LPIPSLike: + + def __init__(self, device: torch.device): + self.device = device + self.net = None + try: + import lpips # type: ignore + + self.net = lpips.LPIPS(net="alex").to(device).eval() + LOGGER.info("LPIPS backend: lpips/alex") + except Exception: + LOGGER.warning("LPIPS package not available, fallback to normalized L1 proxy.") + + @torch.no_grad() + def __call__(self, gt: Tensor, pred: Tensor) -> Tensor: + if self.net is not None: + gt_n = gt * 2.0 - 1.0 + pred_n = pred * 2.0 - 1.0 + val = self.net(pred_n, gt_n, normalize=False) + return val.view(val.shape[0]) + l1 = torch.mean(torch.abs(gt - pred), dim=(1, 2, 3)) + return l1.clamp_min(0.0) + + +class MetricsCalculator: + + def __init__(self, device: torch.device = None, compute_lpips: bool = True): + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.compute_lpips = bool(compute_lpips) + self.lpips_calculator = _LPIPSLike(self.device) if self.compute_lpips else None + + @torch.no_grad() + def 
@torch.no_grad()
def compute_masked_rgb_metrics(
    pred: Tensor,
    gt: Tensor,
    mask: Tensor,
    metrics_calc: MetricsCalculator,
) -> dict[str, float]:
    """Compute PSNR / SSIM / LPIPS restricted to the pixels selected by ``mask``.

    Args:
        pred: Predicted images, indexed as ``(B, C, H, W)``.
        gt: Reference images with the same shape as ``pred``.
        mask: Validity mask; values above 0.5 count as valid.
            NOTE(review): assumed to be ``(B, 1, H, W)`` so it broadcasts over
            channels - confirm with callers.
        metrics_calc: Provides ``compute_lpips_value(gt, pred)``.

    Returns:
        A dict with ``psnr``, ``ssim``, ``lpips`` and ``coverage`` (fraction of
        valid mask entries). Fewer than 8 valid pixels gives NaN metrics and a
        coverage of 0.0.
    """
    # Binarise on pred's device so a mask on another device does not break the
    # element-wise products below.
    m = (mask.to(device=pred.device, dtype=torch.float32) > 0.5).to(dtype=torch.float32)
    mask_sum = m.sum()
    if float(mask_sum.item()) < 8.0:
        nan = float("nan")
        return {"psnr": nan, "ssim": nan, "lpips": nan, "coverage": 0.0}

    channels = float(pred.shape[1])
    mse = (((pred - gt) ** 2) * m).sum() / (mask_sum * channels).clamp(min=1.0)
    psnr = float((-10.0 * torch.log10(mse.clamp_min(1e-8))).item())

    ssim_map = _compute_ssim_map(gt, pred).mean(dim=1, keepdim=True)
    ssim = float(((ssim_map * m).sum() / mask_sum.clamp(min=1.0)).item())

    # Crop to the bounding box of the union of all masks in the batch. Using
    # only the first sample's mask would cut off (or miss) valid regions of the
    # other samples. At least 8 pixels are set here, so the box is never empty.
    footprint = m.amax(dim=(0, 1)) > 0.5
    rows = torch.nonzero(footprint.any(dim=1)).flatten()
    cols = torch.nonzero(footprint.any(dim=0)).flatten()
    y0, y1 = int(rows.min().item()), int(rows.max().item()) + 1
    x0, x1 = int(cols.min().item()), int(cols.max().item()) + 1
    pred_c = pred[:, :, y0:y1, x0:x1]
    gt_c = gt[:, :, y0:y1, x0:x1]
    m_c = m[:, :, y0:y1, x0:x1]
    # Outside the mask the prediction is replaced by the reference, so those
    # pixels contribute no difference.
    pred_blend = pred_c * m_c + gt_c * (1.0 - m_c)
    try:
        lpips_val = float(metrics_calc.compute_lpips_value(gt_c, pred_blend).mean().item())
    except Exception:
        # Deliberate best effort: if the LPIPS backend rejects the crop, report
        # a masked mean absolute error instead of failing the evaluation.
        l1 = torch.abs(pred_c - gt_c)
        lpips_val = float(((l1 * m_c).sum() / (m_c.sum() * channels).clamp(min=1.0)).item())

    return {
        "psnr": psnr,
        "ssim": ssim,
        "lpips": lpips_val,
        "coverage": float((m > 0.5).float().mean().item()),
    }
def _batch_get(batch: Any, *names: str) -> Any:
    """Return the first non-None value found under any of ``names``.

    Each name is tried as a dict key (when ``batch`` is a dict) and then as an
    attribute, in the order given.

    Args:
        batch: A dict, an object with attributes, or anything else.
        *names: Candidate key / attribute names, highest priority first.

    Returns:
        The first value that is not None, or None when nothing matches.
    """
    is_mapping = isinstance(batch, dict)
    for key in names:
        candidate = batch[key] if is_mapping and key in batch else None
        if candidate is None:
            # A missing or None dict entry falls through to the attribute lookup.
            candidate = getattr(batch, key, None)
        if candidate is not None:
            return candidate
    return None
@torch.no_grad()
def metric_mask_from_pinhole_batch(
    batch: Any,
    *,
    dataset: str,
    cache_dir: Path | None = None,
    device: torch.device | None = None,
) -> Tensor | None:
    """Build one metric mask per batch item, from geometry or from the cache.

    For every item the mask is computed with ``compute_source_frustum_mask``
    when a target depth row exists. If depth is missing or that computation
    raises, the mask is loaded with ``load_cached_metric_mask`` instead.

    Args:
        batch: Dict or object holding the target image (``tgt_img`` /
            ``tgt_rgb_u8``), poses (``src_w2c``, ``tgt_w2c``), intrinsics
            (``src_k`` / ``src_intrinsics``, ``tgt_k`` / ``tgt_intrinsics``) and
            optionally depth, source image, ``scene``, ``src_idx``, ``tgt_idx``.
            NOTE(review): image height/width are read from the last two axes,
            i.e. a channels-first layout is assumed - confirm for ``*_rgb_u8``.
        dataset: Dataset name used for the cache lookup.
        cache_dir: Mask cache root, or None to disable the cache fallback.
        device: Output device; defaults to the target image's device.

    Returns:
        The masks concatenated along dim 0 as float32, or None when required
        inputs are missing or any item has neither a computable nor a cached mask.
    """
    tgt_img = _batch_get(batch, "tgt_img", "tgt_rgb_u8")
    src_img = _batch_get(batch, "src_img", "src_rgb_u8")
    src_w2c = _as_batched_pose(_batch_get(batch, "src_w2c"))
    tgt_w2c = _as_batched_pose(_batch_get(batch, "tgt_w2c"))
    src_k = _as_batched_k(_batch_get(batch, "src_k", "src_intrinsics"))
    tgt_k = _as_batched_k(_batch_get(batch, "tgt_k", "tgt_intrinsics"))
    tgt_depth = _as_batched_depth(_batch_get(batch, "tgt_depth", "tgt_depth_m", "tgt_depth_m_orig"))
    if not torch.is_tensor(tgt_img) or src_w2c is None or tgt_w2c is None or src_k is None or tgt_k is None:
        return None

    if device is None:
        device = tgt_img.device
    target_h, target_w = int(tgt_img.shape[-2]), int(tgt_img.shape[-1])
    source_h, source_w = (target_h, target_w)
    if torch.is_tensor(src_img):
        source_h, source_w = int(src_img.shape[-2]), int(src_img.shape[-1])
    batch_size = int(tgt_img.shape[0])
    scene_value = _batch_get(batch, "scene")
    src_idx_value = _batch_get(batch, "src_idx")
    tgt_idx_value = _batch_get(batch, "tgt_idx")

    def _row(tensor: Tensor, index: int) -> Tensor:
        """Slice one batch row (batch axis kept); reuse the last row if ``tensor`` is shorter."""
        i = min(index, int(tensor.shape[0]) - 1)
        return tensor[i : i + 1].to(device)

    masks: list[Tensor] = []
    for b in range(batch_size):
        if tgt_depth is not None and b < int(tgt_depth.shape[0]):
            try:
                mask_b = compute_source_frustum_mask(
                    depth=tgt_depth[b : b + 1].to(device=device),
                    tgt_w2c=_row(tgt_w2c, b),
                    src_w2c=_row(src_w2c, b),
                    src_k3=_row(src_k, b),
                    tgt_k3=_row(tgt_k, b),
                    target_hw=(target_h, target_w),
                    source_hw=(source_h, source_w),
                )
            except Exception as exc:
                # Best effort: fall back to the cache, but leave a trace so a
                # systematic geometry failure does not go unnoticed.
                LOGGER.warning(
                    "Frustum mask computation failed for batch item %d (%s); trying cached mask.",
                    b,
                    exc,
                )
            else:
                masks.append(mask_b)
                continue
        cached = load_cached_metric_mask(
            cache_dir,
            dataset=dataset,
            scene=_item_at(scene_value, b, "unknown"),
            src_idx=_item_at(src_idx_value, b, -1),
            tgt_idx=_item_at(tgt_idx_value, b, -1),
            height=target_h,
            width=target_w,
            device=device,
        )
        if cached is None:
            # One missing mask invalidates the whole batch.
            return None
        masks.append(cached)
    if not masks:
        return None
    return torch.cat(masks, dim=0).to(device=device, dtype=torch.float32)
class Cube2Equirec(nn.Module):
    """Resample a six-face cubemap feature volume onto an equirectangular grid.

    A fixed sampling grid is built once at construction. Each output pixel
    stores an in-face ``(u, v)`` position and a face coordinate, and ``forward``
    reads it with one 3-D ``grid_sample`` over the stacked faces.
    """

    def __init__(self, face_w: int, equ_h: int, equ_w: int) -> None:
        """Store the sizes and precompute the sampling grid.

        Args:
            face_w: Side length of one cube face (stored only).
            equ_h: Height of the equirectangular output.
            equ_w: Width of the equirectangular output; must be a multiple of 4.

        Raises:
            ValueError: If the output size is not positive or ``equ_w`` is not
                a multiple of 4.
        """
        super().__init__()
        if equ_h <= 0 or equ_w <= 0 or equ_w % 4 != 0:
            # The face-assignment map is built from four strips of width
            # equ_w // 4; any other width fails later with an opaque
            # boolean-index shape error.
            raise ValueError(
                f"equ_h and equ_w must be positive and equ_w a multiple of 4, got {equ_h}x{equ_w}"
            )
        self.face_w = face_w
        self.equ_h = equ_h
        self.equ_w = equ_w

        tp, sample_grid = self._build_sample_grid()
        # Non-persistent: rebuilt from the sizes, never stored in a state dict.
        self.register_buffer("tp", tp, persistent=False)
        self.register_buffer("sample_grid", sample_grid, persistent=False)

    def _build_sample_grid(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(face_coordinate_map, sample_grid)`` as CPU float32 tensors.

        ``sample_grid`` has shape ``(1, 1, equ_h, equ_w, 3)`` with the last axis
        ordered ``(u, v, face)``, all in ``[-1, 1]``.
        """
        equ_h, equ_w = self.equ_h, self.equ_w
        device = torch.device("cpu")
        quarter = equ_w // 4
        shift = 3 * equ_w // 8

        # Face index per pixel: four vertical strips (0..3), rolled horizontally.
        tp = np.roll(np.arange(4).repeat(quarter)[None, :].repeat(equ_h, 0), shift, 1)

        # Rows above the per-column boundary are assigned to face 4; the
        # vertically flipped region is assigned to face 5.
        mask = np.zeros((equ_h, quarter), np.bool_)
        idx = np.linspace(-np.pi, np.pi, quarter) / 4
        idx = equ_h // 2 - np.round(np.arctan(np.cos(idx)) * equ_h / np.pi).astype(int)
        for i, j in enumerate(idx):
            mask[:j, i] = 1
        mask = np.roll(np.concatenate([mask] * 4, 1), shift, 1)
        tp[mask] = 4
        tp[np.flip(mask, 0)] = 5

        # Longitude / latitude at pixel centres.
        lon = ((np.linspace(0, equ_w - 1, num=equ_w, dtype=np.float32) + 0.5) / equ_w - 0.5) * 2 * np.pi
        lat = -((np.linspace(0, equ_h - 1, num=equ_h, dtype=np.float32) + 0.5) / equ_h - 0.5) * np.pi
        lon, lat = np.meshgrid(lon, lat)
        coor_u = np.zeros((equ_h, equ_w), dtype=np.float32)
        coor_v = np.zeros((equ_h, equ_w), dtype=np.float32)
        for i in range(4):
            m = tp == i
            coor_u[m] = 0.5 * np.tan(lon[m] - np.pi * i / 2)
            coor_v[m] = -0.5 * np.tan(lat[m]) / np.cos(lon[m] - np.pi * i / 2)
        m = tp == 4
        c = 0.5 * np.tan(np.pi / 2 - lat[m])
        coor_u[m] = c * np.sin(lon[m])
        coor_v[m] = c * np.cos(lon[m])
        m = tp == 5
        c = 0.5 * np.tan(np.pi / 2 - np.abs(lat[m]))
        coor_u[m] = c * np.sin(lon[m])
        coor_v[m] = -c * np.cos(lon[m])

        # In-face coordinates go from [-0.5, 0.5] to the [-1, 1] range of grid_sample.
        coor_u = (np.clip(coor_u, -0.5, 0.5)) * 2
        coor_v = (np.clip(coor_v, -0.5, 0.5)) * 2

        # Face index 0..5 -> [-1, 1]; with align_corners=True this lands on
        # depth slice ``face`` of the six-slice volume.
        tp_t = torch.from_numpy(tp.astype(np.float32) / 2.5 - 1.0).to(device)
        u_t = torch.from_numpy(coor_u).to(device)
        v_t = torch.from_numpy(coor_v).to(device)

        sample_grid = torch.stack([u_t, v_t, tp_t], dim=-1).view(1, 1, equ_h, equ_w, 3)
        return tp_t, sample_grid

    def forward(self, cube_feat: torch.Tensor) -> torch.Tensor:
        """Project cubemap features to an equirectangular map.

        Args:
            cube_feat: Tensor indexed as ``(B, C, 6, H, W)``.
                NOTE(review): the face order is presumably the module's
                ``_FACE_ORDER`` (U, B, L, F, R, D) - confirm with callers.

        Returns:
            Tensor of shape ``(B, C, equ_h, equ_w)``. The input is not modified.
        """
        batch = cube_feat.shape[0]
        # Take input faces 3, 4, 1, 2, 0, 5 to match the grid's face indices.
        reordered = cube_feat[:, :, [3, 4, 1, 2, 0, 5], :, :]
        # The last two faces are flipped along both spatial axes.
        faces = torch.cat(
            [reordered[:, :, :4], torch.flip(reordered[:, :, 4:], [3, 4])],
            dim=2,
        )
        grid = self.sample_grid.to(device=cube_feat.device)
        if cube_feat.is_floating_point():
            # grid_sample needs matching dtypes; follow the features so that
            # half / double inputs work too.
            grid = grid.to(dtype=cube_feat.dtype)
        # expand() broadcasts the shared grid without copying it per batch item.
        grid = grid.expand(batch, -1, -1, -1, -1)
        equi_feat = F.grid_sample(
            faces,
            grid,
            padding_mode="border",
            align_corners=True,
        )
        return equi_feat.squeeze(2)
def scale_intrinsics_align_corners_false(
    k3: torch.Tensor,
    *,
    sx: float,
    sy: float,
) -> torch.Tensor:
    """Return a copy of ``k3`` rescaled for an image resized by ``(sx, sy)``.

    Focal lengths are multiplied by the scale factors. Principal points are
    shifted by +0.5, scaled, then shifted back by -0.5.

    Args:
        k3: Intrinsics whose last two axes are indexed as a 3x3 matrix.
        sx: Horizontal scale factor.
        sy: Vertical scale factor.

    Returns:
        A new tensor; ``k3`` is left untouched.
    """
    scaled = k3.clone()
    # Row 0 carries (fx, cx) and scales with sx; row 1 carries (fy, cy) and
    # scales with sy.
    for row, factor in ((0, float(sx)), (1, float(sy))):
        scaled[..., row, row] = k3[..., row, row] * factor
        scaled[..., row, 2] = (k3[..., row, 2] + 0.5) * factor - 0.5
    return scaled
def _solve_linear(a: torch.Tensor, b: torch.Tensor, *, ridge: float = 1e-6) -> torch.Tensor:
    """Solve the ridge-regularised normal equations ``(A^T A + ridge * I) x = A^T b``.

    Both operands are converted to float32 and the computation runs with
    autocast disabled.

    Args:
        a: Design matrix, indexed as ``(num_samples, num_params)``.
        b: Right-hand side; moved to ``a``'s device.
        ridge: Value added to the diagonal of ``A^T A``.

    Returns:
        The float32 solution ``x``.
    """
    design = a.to(dtype=torch.float32)
    target = b.to(device=design.device, dtype=design.dtype)
    n_params = int(design.shape[1])
    with torch.autocast(device_type=design.device.type, enabled=False):
        design_t = design.transpose(0, 1)
        gram = design_t @ design
        damping = torch.eye(n_params, device=design.device, dtype=design.dtype) * float(ridge)
        rhs = design_t @ target
        return torch.linalg.solve(gram + damping, rhs)
def scale_pinhole_intrinsics(
    intrinsics: torch.Tensor,
    *,
    src_hw: tuple[int, int],
    dst_hw: tuple[int, int],
) -> torch.Tensor:
    """Rescale pinhole intrinsics from resolution ``src_hw`` to ``dst_hw``.

    Args:
        intrinsics: Intrinsics passed to ``scale_intrinsics_align_corners_false``.
        src_hw: ``(height, width)`` the intrinsics currently refer to.
        dst_hw: ``(height, width)`` to convert to.

    Returns:
        ``intrinsics`` itself (not a copy) when both sizes are equal, otherwise
        the result of ``scale_intrinsics_align_corners_false``.
    """
    src_dims = tuple(int(v) for v in src_hw)
    dst_dims = tuple(int(v) for v in dst_hw)
    if src_dims == dst_dims:
        return intrinsics
    src_h, src_w = src_dims[0], src_dims[1]
    dst_h, dst_w = dst_dims[0], dst_dims[1]
    # max(..., 1) keeps a zero source size from dividing by zero.
    scale_x = float(dst_w) / float(max(src_w, 1))
    scale_y = float(dst_h) / float(max(src_h, 1))
    return scale_intrinsics_align_corners_false(intrinsics, sx=scale_x, sy=scale_y)
def scale_fisheye624_params(
    params: torch.Tensor,
    *,
    src_hw: tuple[int, int],
    dst_hw: tuple[int, int],
) -> torch.Tensor:
    """Rescale the focal / principal-point columns of fisheye parameters.

    Columns 0 and 1 are multiplied by the width and height ratios. Columns 2
    and 3 are shifted by +0.5, scaled, and shifted back by -0.5. All remaining
    columns are copied unchanged.

    Args:
        params: Parameter rows, indexed as ``(B, N)`` with ``N >= 4``.
        src_hw: ``(height, width)`` the parameters currently refer to.
        dst_hw: ``(height, width)`` to convert to.

    Returns:
        ``params`` itself when both sizes are equal, otherwise a scaled copy.
    """
    src_dims = tuple(int(v) for v in src_hw)
    dst_dims = tuple(int(v) for v in dst_hw)
    if src_dims == dst_dims:
        return params
    src_h, src_w = src_dims[0], src_dims[1]
    dst_h, dst_w = dst_dims[0], dst_dims[1]
    # max(..., 1) keeps a zero source size from dividing by zero.
    scale_x = float(dst_w) / float(max(src_w, 1))
    scale_y = float(dst_h) / float(max(src_h, 1))

    scaled = params.clone()
    for focal_col, center_col, factor in ((0, 2, scale_x), (1, 3, scale_y)):
        scaled[:, focal_col].mul_(factor)
        scaled[:, center_col] = (scaled[:, center_col] + 0.5) * factor - 0.5
    return scaled
def _range_from(
    depth_list: list[torch.Tensor | None],
    *,
    max_samples: int = 1 << 23,
) -> tuple[float, float]:
    """Derive a shared ``(min, max)`` display range from several depth maps.

    Finite, strictly positive values of all maps are pooled, and the 1st and
    99th percentiles of the pool are returned.

    Args:
        depth_list: Depth tensors; ``None`` entries and maps with 8 or fewer
            valid values are ignored.
        max_samples: Upper bound on the number of pooled values passed to
            ``torch.quantile``. Larger pools are thinned with a regular stride.

    Returns:
        ``(vmin, vmax)`` with ``vmin >= 0`` and ``vmax >= vmin + 1e-3``, or
        ``(0.0, 10.0)`` when no map contributes values.
    """
    pooled = []
    for depth in depth_list:
        if depth is None:
            continue
        usable = depth[torch.isfinite(depth) & (depth > 0.0)]
        if usable.numel() > 8:
            if usable.dtype not in (torch.float32, torch.float64):
                # torch.quantile only accepts float32 / float64 inputs.
                usable = usable.to(dtype=torch.float32)
            pooled.append(usable)
    if not pooled:
        return (0.0, 10.0)

    samples = torch.cat(pooled, dim=0)
    limit = max(1, int(max_samples))
    if samples.numel() > limit:
        # torch.quantile rejects very large inputs, and two concatenated
        # high-resolution panoramas can exceed that bound. A strided subsample
        # is enough for a display range.
        stride = (samples.numel() + limit - 1) // limit
        samples = samples[::stride]

    probs = torch.tensor([0.01, 0.99], device=samples.device, dtype=samples.dtype)
    low, high = torch.quantile(samples, probs).tolist()
    vmin = max(0.0, float(low))
    vmax = max(vmin + 1e-3, float(high))
    return (vmin, vmax)
def _to_face_u8_list(cube_img: torch.Tensor, face_count: int = 6) -> list[np.ndarray]:
    """Split a stack of cube faces into a list of arrays via ``_to_u8_hwc``.

    Accepts a 4-D face stack, or a 5-D tensor whose leading axis has size one.
    Faces may be channels-first (axis 1 has size 3) or channels-last (last axis
    has size 3); the channels-first test is tried first.

    Args:
        cube_img: Face stack whose first axis (after the optional leading axis
            is dropped) must equal ``face_count``.
        face_count: Expected number of faces.

    Returns:
        One array per face, or an empty list when the layout is not recognised.
    """
    stack = cube_img[0] if cube_img.ndim == 5 and cube_img.shape[0] == 1 else cube_img
    if stack.ndim != 4 or stack.shape[0] != face_count:
        return []
    if stack.shape[1] == 3:
        return [_to_u8_hwc(face) for face in stack]
    if stack.shape[-1] == 3:
        # Channels-last faces are moved to channels-first before conversion.
        return [_to_u8_hwc(face.permute(2, 0, 1).contiguous()) for face in stack]
    return []
def save_pair_visualization(
    out_file: Path,
    *,
    src_gt: torch.Tensor,
    src_pred: torch.Tensor,
    src_alpha: torch.Tensor,
    tgt_gt: torch.Tensor,
    tgt_pred: torch.Tensor,
    tgt_alpha: torch.Tensor,
    src_gt_depth: torch.Tensor | None = None,
    tgt_gt_depth: torch.Tensor | None = None,
    src_pred_depth: torch.Tensor | None = None,
    tgt_pred_depth: torch.Tensor | None = None,
    src_unik3d_depth: torch.Tensor | None = None,
    tgt_unik3d_depth: torch.Tensor | None = None,
    dataset_name: str | None = None,
    scene: str | None = None,
    step: int | None = None,
    src_idx: int | None = None,
    tgt_idx: int | None = None,
    src_pose_w2c: torch.Tensor | None = None,
    tgt_pose_w2c: torch.Tensor | None = None,
    src_cube_gt_u8: torch.Tensor | None = None,
    src_cube_pred_linear: torch.Tensor | None = None,
    src_cube_alpha: torch.Tensor | None = None,
    tgt_cube_gt_u8: torch.Tensor | None = None,
    tgt_cube_pred_linear: torch.Tensor | None = None,
    tgt_cube_alpha: torch.Tensor | None = None,
) -> None:
    """Write a source/target comparison image, plus a cubemap image when available.

    The main image has two rows (source, then target). Each row shows ground
    truth, prediction, error map and alpha, followed by optional depth columns
    (ground truth, rendered, UniK3D). A text header with dataset, scene, step,
    indices and poses is placed above the grid. When all six cubemap tensors
    are given and ``_make_cube_rows`` accepts them, a second file named
    ``<stem>_cubemap<suffix>`` is written next to ``out_file``.

    Only batch element 0 of each image tensor is drawn.
    NOTE(review): images are indexed as ``(B, C, H, W)`` and depths as
    ``(B, 1, H, W)`` below - confirm value ranges and colour spaces with callers.
    """
    out_file.parent.mkdir(parents=True, exist_ok=True)

    src_a = src_alpha.clamp(0.0, 1.0)
    tgt_a = tgt_alpha.clamp(0.0, 1.0)
    # Divide the predicted colour by alpha (floored at 1e-4 so near-zero alpha
    # cannot blow up), then convert to sRGB.
    src_vis_lin = (src_pred / src_a.clamp(min=1e-4)).clamp(0.0, 1.0)
    tgt_vis_lin = (tgt_pred / tgt_a.clamp(min=1e-4)).clamp(0.0, 1.0)
    src_vis = linearRGB2sRGB(src_vis_lin).clamp(0.0, 1.0)
    tgt_vis = linearRGB2sRGB(tgt_vis_lin).clamp(0.0, 1.0)
    # Per-pixel absolute error, averaged over channels.
    src_err = (src_vis - src_gt).abs().mean(dim=1, keepdim=True)
    tgt_err = (tgt_vis - tgt_gt).abs().mean(dim=1, keepdim=True)
    # Shared colour-scale maximum for both error maps: the 99th percentile of
    # all error values, kept within [1e-3, 0.5].
    # NOTE(review): torch.quantile can reject very large inputs - verify this
    # holds for the largest image sizes in use.
    vmax = float(
        max(
            1e-3,
            min(float(torch.quantile(torch.cat([src_err.flatten(), tgt_err.flatten()]), 0.99).item()), 0.5),
        )
    )
    src_err_u8 = colorize_scalar_map(src_err[0, 0], val_min=0.0, val_max=vmax, color_map="turbo")
    tgt_err_u8 = colorize_scalar_map(tgt_err[0, 0], val_min=0.0, val_max=vmax, color_map="turbo")
    src_alpha_u8 = colorize_alpha(src_alpha)[0]
    tgt_alpha_u8 = colorize_alpha(tgt_alpha)[0]

    # Ground-truth and rendered depth columns need both views. The UniK3D
    # column is shown when either view has it (the missing one is drawn blank).
    has_gt_depth = (src_gt_depth is not None) and (tgt_gt_depth is not None)
    has_render_depth = (src_pred_depth is not None) and (tgt_pred_depth is not None)
    has_unik3d_depth = (src_unik3d_depth is not None) or (tgt_unik3d_depth is not None)

    # All-zero placeholder with the same shape as one image tile.
    base_hwc = _to_u8_hwc(src_gt[0])
    blank = np.zeros_like(base_hwc)

    if has_gt_depth:
        # With ground truth available, every depth column uses its value range.
        gt_min, gt_max = _range_from([src_gt_depth, tgt_gt_depth])
        render_min, render_max = gt_min, gt_max
        unik_min, unik_max = gt_min, gt_max
    else:
        # Otherwise the rendered and UniK3D depths share one range.
        gt_min, gt_max = (0.0, 10.0)
        shared_min, shared_max = _range_from(
            [src_pred_depth, tgt_pred_depth, src_unik3d_depth, tgt_unik3d_depth]
        )
        render_min, render_max = shared_min, shared_max
        unik_min, unik_max = shared_min, shared_max

    src_cols = [_to_u8_hwc(src_gt[0]), _to_u8_hwc(src_vis[0]), _to_u8_hwc(src_err_u8), _to_u8_hwc(src_alpha_u8)]
    tgt_cols = [_to_u8_hwc(tgt_gt[0]), _to_u8_hwc(tgt_vis[0]), _to_u8_hwc(tgt_err_u8), _to_u8_hwc(tgt_alpha_u8)]
    if has_gt_depth:
        src_cols.append(_depth_u8_or_blank(src_gt_depth, gt_min, gt_max, blank, mask_invalid_black=True))
        tgt_cols.append(_depth_u8_or_blank(tgt_gt_depth, gt_min, gt_max, blank, mask_invalid_black=True))
    if has_render_depth:
        src_cols.append(_depth_u8_or_blank(src_pred_depth, render_min, render_max, blank, mask_invalid_black=True))
        tgt_cols.append(_depth_u8_or_blank(tgt_pred_depth, render_min, render_max, blank, mask_invalid_black=True))
    if has_unik3d_depth:
        src_cols.append(_depth_u8_or_blank(src_unik3d_depth, unik_min, unik_max, blank, mask_invalid_black=False))
        tgt_cols.append(_depth_u8_or_blank(tgt_unik3d_depth, unik_min, unik_max, blank, mask_invalid_black=False))

    erp_grid = _concat_grid(rows=[src_cols, tgt_cols], pad=6, pad_value=0)
    # Header text; missing metadata is printed as 'unknown' / -1.
    lines = [
        f"dataset={str(dataset_name) if dataset_name is not None else 'unknown'} scene={str(scene) if scene is not None else 'unknown'} step={int(step) if step is not None else -1}",
        f"src_idx={int(src_idx) if src_idx is not None else -1} tgt_idx={int(tgt_idx) if tgt_idx is not None else -1}",
        f"src_w2c={_pose_to_text(src_pose_w2c)}",
        f"tgt_w2c={_pose_to_text(tgt_pose_w2c)}",
    ]
    grid = _append_text_header(erp_grid, lines)

    # None when any cubemap input is missing or has an unexpected layout.
    cube_rows = _make_cube_rows(
        src_cube_gt_u8=src_cube_gt_u8,
        src_cube_pred_linear=src_cube_pred_linear,
        src_cube_alpha=src_cube_alpha,
        tgt_cube_gt_u8=tgt_cube_gt_u8,
        tgt_cube_pred_linear=tgt_cube_pred_linear,
        tgt_cube_alpha=tgt_cube_alpha,
    )
    save_image(grid, out_file)
    if cube_rows is not None:
        cube_grid = _concat_grid(rows=cube_rows, pad=6, pad_value=0)
        cube_lines = lines + ["cubemap_rows=src_gt/src_pred/src_err/tgt_gt/tgt_pred/tgt_err"]
        cube_grid = _append_text_header(cube_grid, cube_lines)
        # Written alongside the main image with a "_cubemap" suffix on the stem.
        cube_file = out_file.with_name(f"{out_file.stem}_cubemap{out_file.suffix}")
        save_image(cube_grid, cube_file)
Exception: + return + if getattr(rm, "_unisharp_process_wrapped", False): + return + + import types + + try: + orig_process = rm.process + except Exception: + return + + def wrapped_process(self, features_list, rays_embeddings): # type: ignore[no-untyped-def] + if bool(getattr(self, "_unisharp_detach_rays_embeddings", False)): + rays_embeddings = rays_embeddings.detach() + out_features, init_latents = orig_process(features_list, rays_embeddings) + self._unisharp_last_out_features = out_features + self._unisharp_last_init_latents = init_latents + return out_features, init_latents + + rm.process = types.MethodType(wrapped_process, rm) # type: ignore[method-assign] + rm._unisharp_process_wrapped = True + + +def _enable_unik3d_force_gt_rays(model: torch.nn.Module) -> None: + try: + pixel_decoder = model.pixel_decoder # type: ignore[attr-defined] + except Exception: + return + if getattr(pixel_decoder, "_unisharp_run_camera_wrapped", False): + return + + import types + from einops import rearrange + + try: + orig_run_camera = pixel_decoder.run_camera + except Exception: + return + + def wrapped_run_camera(self, cls_tokens, original_shapes, rays_gt): # type: ignore[no-untyped-def] + force_gt = bool(getattr(self, "_unisharp_force_gt_rays", False)) and torch.is_tensor(rays_gt) + if force_gt: + old_camera_gt = getattr(self, "camera_gt", None) + try: + self.camera_gt = False + intrinsics, pred_rays = orig_run_camera(cls_tokens, original_shapes, rays_gt) + finally: + if old_camera_gt is not None: + self.camera_gt = old_camera_gt + self._unisharp_last_pred_rays_flat = pred_rays + rays = rearrange(rays_gt, "b c h w -> b (h w) c") + return intrinsics, rays + + intrinsics, rays = orig_run_camera(cls_tokens, original_shapes, rays_gt) + self._unisharp_last_pred_rays_flat = rays + return intrinsics, rays + + pixel_decoder.run_camera = types.MethodType(wrapped_run_camera, pixel_decoder) # type: ignore[method-assign] + pixel_decoder._unisharp_run_camera_wrapped = True + + +def 
postprocess_unik3d_tensor( + tensor: torch.Tensor, + *, + padded_hw: tuple[int, int], + paddings: tuple[int, int, int, int], + interpolation_mode: str, +) -> torch.Tensor: + _ = _try_import_unik3d() + from unik3d.models.unik3d import _postprocess # type: ignore + + return _postprocess( + tensor, + padded_hw, + paddings=paddings, + interpolation_mode=interpolation_mode, + ) + + +def _erp_rays_panosplatt3r_opencv( + h: int, w: int, device: torch.device, dtype: torch.dtype = torch.float32 +) -> torch.Tensor: + xs = (torch.arange(w, device=device, dtype=dtype) + 0.5) / float(w) + ys = (torch.arange(h, device=device, dtype=dtype) + 0.5) / float(h) + lon = (xs - 0.5) * (2.0 * torch.pi) + lat = -(ys - 0.5) * torch.pi + lat2d, lon2d = torch.meshgrid(lat, lon, indexing="ij") + cos_lat = torch.cos(lat2d) + x = torch.sin(lon2d) * cos_lat + y_up = torch.sin(lat2d) + z = torch.cos(lon2d) * cos_lat + y = -y_up + rays = torch.stack([x, y, z], dim=0) + rays = rays / torch.norm(rays, dim=0, keepdim=True).clamp(min=1e-6) + return rays + + +def _pinhole_rays_opencv( + intrinsics: torch.Tensor, + h: int, + w: int, + *, + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + if intrinsics.ndim == 2: + intrinsics = intrinsics.unsqueeze(0) + if intrinsics.ndim != 3 or tuple(intrinsics.shape[1:]) != (3, 3): + raise ValueError(f"Expected intrinsics (B,3,3), got {tuple(intrinsics.shape)}") + B = int(intrinsics.shape[0]) + K = intrinsics.to(device=device, dtype=dtype) + fx = K[:, 0, 0].view(B, 1, 1) + fy = K[:, 1, 1].view(B, 1, 1) + cx = K[:, 0, 2].view(B, 1, 1) + cy = K[:, 1, 2].view(B, 1, 1) + + uu, vv = integer_pixel_center_grid(h, w, device=device, dtype=dtype) + uu = uu[None].expand(B, -1, -1) + vv = vv[None].expand(B, -1, -1) + + x = (uu - cx) / fx + y = (vv - cy) / fy + z = torch.ones_like(x) + rays = torch.stack([x, y, z], dim=1) + rays = rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-6) + return rays + + +def 
_fisheye624_rays_opencv_integer_centers( + fisheye624_cls: Any, + params: torch.Tensor, + h: int, + w: int, + *, + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + if params.ndim == 1: + params = params.unsqueeze(0) + params = params.to(device=device, dtype=torch.float32) + uu, vv = integer_pixel_center_grid(int(h), int(w), device=device, dtype=torch.float32) + uv = torch.stack([uu, vv], dim=0).unsqueeze(0).expand(int(params.shape[0]), -1, -1, -1) + rays = fisheye624_cls(params=params).unproject(uv).to(dtype=dtype) + return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-6) + + +def _fill_invalid_rgb_with_valid_mean( + rgb: torch.Tensor, + validity_mask: torch.Tensor | None, +) -> torch.Tensor: + if validity_mask is None: + return rgb + if rgb.ndim != 4: + raise ValueError(f"Expected rgb shape (B,3,H,W), got {tuple(rgb.shape)}") + mask = validity_mask.to(device=rgb.device, dtype=torch.float32) + if mask.ndim == 3: + mask = mask.unsqueeze(1) + if mask.ndim != 4 or int(mask.shape[1]) != 1: + raise ValueError(f"Expected validity_mask shape (B,1,H,W), got {tuple(mask.shape)}") + if int(mask.shape[0]) == 1 and int(rgb.shape[0]) > 1: + mask = mask.expand(int(rgb.shape[0]), -1, -1, -1) + if tuple(mask.shape[-2:]) != tuple(rgb.shape[-2:]): + import torch.nn.functional as F + + mask = F.interpolate(mask, size=rgb.shape[-2:], mode="nearest") + + rgb_f = rgb.to(dtype=torch.float32) + valid = mask > 0.5 + count = valid.to(dtype=rgb_f.dtype).sum(dim=(2, 3), keepdim=True).clamp(min=1.0) + fill = (rgb_f * valid.to(dtype=rgb_f.dtype)).sum(dim=(2, 3), keepdim=True) / count + return torch.where(valid.expand_as(rgb_f), rgb_f, fill) + + +def build_unik3d_camera_rays( + rgb_u8: torch.Tensor, + *, + device: torch.device, + intrinsics: torch.Tensor | None = None, + camera_params: torch.Tensor | None = None, + camera_model: str | None = None, + hfov: float | None = None, + vfov: float | None = None, +) -> tuple[Any, torch.Tensor | None, 
torch.Tensor | None, torch.Tensor | None]: + if rgb_u8.ndim == 3: + rgb_u8 = rgb_u8.unsqueeze(0) + if rgb_u8.ndim != 4: + raise ValueError(f"Expected rgb_u8 shape (3,H,W) or (B,3,H,W), got {tuple(rgb_u8.shape)}") + bsz, _, h, w = rgb_u8.shape + + _ = _try_import_unik3d() + from unik3d.utils.camera import BatchCamera, Fisheye624, Pinhole, Spherical # type: ignore + + if intrinsics is not None: + if intrinsics.ndim == 2: + intrinsics = intrinsics.unsqueeze(0) + if intrinsics.shape != (bsz, 3, 3): + raise ValueError( + f"Expected intrinsics shape {(bsz, 3, 3)}, got {tuple(intrinsics.shape)}" + ) + intrinsics_orig = intrinsics.to(device=device, dtype=torch.float32).clone() + cameras = [ + BatchCamera.from_camera(Pinhole(K=intrinsics_orig[i : i + 1].clone())) + for i in range(bsz) + ] + camera = torch.cat(cameras, dim=0).to(device) + rays_k = _pinhole_rays_opencv(intrinsics_orig, h, w, device=device, dtype=torch.float32) + return camera, rays_k, intrinsics_orig, rays_k + + if camera_params is not None: + cam_model = str(camera_model or "").lower() + if cam_model != "fisheye624": + raise ValueError( + f"Unsupported camera_model={camera_model!r}; fisheye training now expects OPENCV_FISHEYE/Fisheye624." 
+ ) + if camera_params.ndim == 1: + camera_params = camera_params.unsqueeze(0) + expected_dim = 16 + if camera_params.shape != (bsz, expected_dim): + raise ValueError( + f"Expected camera_params shape {(bsz, expected_dim)}, got {tuple(camera_params.shape)}" + ) + camera_params = camera_params.to(device=device, dtype=torch.float32).clone() + cameras = [BatchCamera.from_camera(Fisheye624(params=camera_params[i : i + 1].clone())) for i in range(bsz)] + camera = torch.cat(cameras, dim=0).to(device) + rays = _fisheye624_rays_opencv_integer_centers( + Fisheye624, + camera_params, + h, + w, + device=device, + dtype=torch.float32, + ) + return camera, rays, None, rays + + params = torch.tensor( + [ + 1.0, + 1.0, + (w - 1) / 2.0, + (h - 1) / 2.0, + float(w), + float(h), + float(hfov if hfov is not None else 2.0 * torch.pi) / 2.0, + float(vfov if vfov is not None else torch.pi) / 2.0, + ], + dtype=torch.float32, + device=device, + ) + camera = BatchCamera.from_camera(Spherical(params=params)).to(device) + rays = _erp_rays_panosplatt3r_opencv(h, w, device=device, dtype=torch.float32) + rays = rays.unsqueeze(0).expand(bsz, -1, -1, -1).contiguous() + return camera, rays, None, rays + + +def _try_import_unik3d() -> Any: + try: + import unik3d # type: ignore + + return unik3d + except Exception: + repo_root = Path(__file__).resolve() + ml_unisharp_root = repo_root.parents[2] + unik3d_root = ml_unisharp_root / "UniK3D" + if unik3d_root.exists(): + sys.path.insert(0, str(unik3d_root)) + import unik3d # type: ignore + + return unik3d + + +def _get_ml_unisharp_root() -> Path: + repo_root = Path(__file__).resolve() + return repo_root.parents[2] + + +def _setup_unik3d_repo_caches(cache_root: Path | None = None) -> None: + if cache_root is None: + cache_root = _get_ml_unisharp_root() / "UniK3D" / "checkpoints" + hf_cache = cache_root / "huggingface" + torchhub_cache = cache_root / "torchhub" + hf_cache.mkdir(parents=True, exist_ok=True) + torchhub_cache.mkdir(parents=True, 
def load_unik3d_model(
    backbone: str = "vitl",
    pretrained: bool = True,
    device: torch.device | None = None,
    cache_root: Path | None = None,
) -> torch.nn.Module:
    """Load a UniK3D model in eval mode and move it to ``device``.

    The local ``hubconf.UniK3D`` entry point is tried first. If that fails for
    any reason, the model is fetched with
    ``UniK3D.from_pretrained("lpiccinelli/unik3d-<backbone>")`` -- unless
    ``HF_HUB_OFFLINE`` is set to a true value, in which case the original
    failure is re-raised as ``RuntimeError`` instead of touching the Hub.

    Args:
        backbone: Backbone identifier, used both for ``hubconf`` and for the
            Hub repository name.
        pretrained: Forwarded to ``hubconf.UniK3D``.
        device: Target device; defaults to CUDA when available, else CPU.
        cache_root: Optional cache directory forwarded to the cache setup and
            to ``hubconf.UniK3D``.

    Returns:
        The loaded model, in eval mode, with a ``resolution_level`` attribute
        (set to 0 when the model did not define one and allows setting it).

    Raises:
        RuntimeError: If the local load fails while ``HF_HUB_OFFLINE`` is on.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    _ = _try_import_unik3d()
    _setup_unik3d_repo_caches(cache_root=cache_root)
    try:
        from hubconf import UniK3D as UniK3DHub  # type: ignore

        model = UniK3DHub(backbone=backbone, pretrained=pretrained, config_variant="eval", cache_root=cache_root)
    except Exception as exc:
        # huggingface_hub treats boolean environment variables
        # case-insensitively, so "true"/"on"/"yes" must block the fallback too.
        offline_flag = os.environ.get("HF_HUB_OFFLINE", "").strip().upper()
        if offline_flag in {"1", "ON", "YES", "TRUE"}:
            raise RuntimeError("UniK3D local/offline load failed; refusing to fall back to HuggingFace Hub.") from exc
        from unik3d.models import UniK3D  # type: ignore

        # NOTE(review): this fallback always loads pretrained weights and
        # therefore ignores ``pretrained=False`` -- confirm this is intended.
        model = UniK3D.from_pretrained(f"lpiccinelli/unik3d-{backbone}")

    model.eval()
    if not hasattr(model, "resolution_level"):
        try:
            setattr(model, "resolution_level", 0)
        except Exception:
            # Best effort only: the attribute is optional for callers that
            # check ``hasattr(model, "resolution_level")``.
            pass
    model.to(device)
    return model
+ if rgb_u8.ndim != 4: + raise ValueError(f"Expected rgb_u8 shape (3,H,W) or (B,3,H,W), got {tuple(rgb_u8.shape)}") + + _ = _try_import_unik3d() + from unik3d.models.unik3d import ( # type: ignore + IMAGENET_DATASET_MEAN, + IMAGENET_DATASET_STD, + _postprocess, + get_paddings, + get_resize_factor, + ) + import torch.nn.functional as F + import torchvision.transforms.v2.functional as TF + + device = next(model.parameters()).device + bsz, _, h, w = rgb_u8.shape + rgb = _fill_invalid_rgb_with_valid_mean(rgb_u8.to(device), validity_mask) + + ratio_bounds = model.shape_constraints["ratio_bounds"] # type: ignore[attr-defined] + pixels_bounds = [ + model.shape_constraints["pixels_min"], # type: ignore[attr-defined] + model.shape_constraints["pixels_max"], # type: ignore[attr-defined] + ] + if hasattr(model, "resolution_level"): + pixels_range = pixels_bounds[1] - pixels_bounds[0] + interval = pixels_range / 10 + new_lowbound = model.resolution_level * interval + pixels_bounds[0] + new_upbound = (model.resolution_level + 1) * interval + pixels_bounds[0] + pixels_bounds = (new_lowbound, new_upbound) + + paddings, (padded_h, padded_w) = get_paddings((h, w), ratio_bounds) + pad_left, pad_right, pad_top, pad_bottom = paddings + resize_factor, (new_h, new_w) = get_resize_factor((padded_h, padded_w), pixels_bounds) + + if normalize: + rgb_f = TF.normalize( + rgb.float() / 255.0, + mean=IMAGENET_DATASET_MEAN, + std=IMAGENET_DATASET_STD, + ) + else: + rgb_f = rgb.float() / 255.0 + rgb_f = F.pad(rgb_f, (pad_left, pad_right, pad_top, pad_bottom), value=0.0) + rgb_f = F.interpolate(rgb_f, size=(new_h, new_w), mode="bilinear", align_corners=False) + + validity_mask_resized = None + if torch.is_tensor(validity_mask): + validity_mask_resized = validity_mask.to(device=device, dtype=torch.float32) + if validity_mask_resized.ndim == 3: + validity_mask_resized = validity_mask_resized.unsqueeze(1) + if int(validity_mask_resized.shape[0]) == 1 and bsz > 1: + validity_mask_resized = 
validity_mask_resized.expand(bsz, -1, -1, -1) + if tuple(validity_mask_resized.shape[-2:]) != (h, w): + validity_mask_resized = F.interpolate(validity_mask_resized, size=(h, w), mode="nearest") + validity_mask_resized = F.pad( + validity_mask_resized, + (max(0, pad_left), max(0, pad_right), max(0, pad_top), max(0, pad_bottom)), + value=0.0, + ) + validity_mask_resized = F.interpolate(validity_mask_resized, size=(new_h, new_w), mode="nearest") + + inputs: dict[str, Any] = {"image": rgb_f} + if validity_mask_resized is not None: + inputs["validity_mask"] = validity_mask_resized > 0.5 + rays_resized = None + if torch.is_tensor(rays): + rays_resized = rays.to(device=device, dtype=torch.float32) + if rays_resized.ndim == 3: + rays_resized = rays_resized.unsqueeze(0) + if rays_resized.ndim != 4 or int(rays_resized.shape[1]) != 3: + raise ValueError(f"Expected rays shape (3,H,W) or (B,3,H,W), got {tuple(rays_resized.shape)}") + if int(rays_resized.shape[0]) == 1 and bsz > 1: + rays_resized = rays_resized.expand(bsz, -1, -1, -1) + if tuple(rays_resized.shape[-2:]) != (h, w): + rays_resized = F.interpolate(rays_resized, size=(h, w), mode="bilinear", align_corners=False) + rays_resized = rays_resized / torch.norm(rays_resized, dim=1, keepdim=True).clamp(min=1e-5) + rays_resized = F.pad( + rays_resized, + (max(0, pad_left), max(0, pad_right), max(0, pad_top), max(0, pad_bottom)), + value=0.0, + ) + rays_resized = F.interpolate(rays_resized, size=(new_h, new_w), mode="bilinear", align_corners=False) + rays_resized = rays_resized / torch.norm(rays_resized, dim=1, keepdim=True).clamp(min=1e-5) + inputs["rays"] = rays_resized + + _enable_unik3d_decoder_feature_capture(model) + _enable_unik3d_force_gt_rays(model) + pixel_decoder = getattr(model, "pixel_decoder", None) + old_force = bool(getattr(pixel_decoder, "_unisharp_force_gt_rays", False)) if pixel_decoder is not None else False + if pixel_decoder is not None: + pixel_decoder._unisharp_force_gt_rays = 
torch.is_tensor(rays_resized) + try: + _, model_outputs = model.encode_decode(inputs, image_metas={}) + finally: + if pixel_decoder is not None: + pixel_decoder._unisharp_force_gt_rays = old_force + + out: dict[str, torch.Tensor] = {} + out["confidence"] = _postprocess( + model_outputs["confidence"], + (padded_h, padded_w), + paddings=paddings, + interpolation_mode=model.interpolation_mode, # type: ignore[attr-defined] + ) + distance = _postprocess( + model_outputs["distance"], + (padded_h, padded_w), + paddings=paddings, + interpolation_mode=model.interpolation_mode, # type: ignore[attr-defined] + ).clamp(min=1e-4) + points = _postprocess( + model_outputs["points"], + (padded_h, padded_w), + paddings=paddings, + interpolation_mode=model.interpolation_mode, # type: ignore[attr-defined] + ) + rays_out = _postprocess( + model_outputs["rays"], + (padded_h, padded_w), + paddings=paddings, + interpolation_mode=model.interpolation_mode, # type: ignore[attr-defined] + ) + pred_rays_out = rays_out + try: + pred_flat = getattr(model.pixel_decoder, "_unisharp_last_pred_rays_flat", None) # type: ignore[attr-defined] + if torch.is_tensor(pred_flat): + pred_internal = pred_flat.reshape(bsz, new_h, new_w, 3).permute(0, 3, 1, 2).contiguous() + pred_rays_out = _postprocess( + pred_internal, + (padded_h, padded_w), + paddings=paddings, + interpolation_mode=model.interpolation_mode, # type: ignore[attr-defined] + ) + except Exception: + pred_rays_out = rays_out + out["points"] = points + out["depth"] = points[:, -1:] + out["distance"] = distance + out["rays"] = pred_rays_out / torch.norm(pred_rays_out, dim=1, keepdim=True).clamp(min=1e-5) + out["ray_conditioning_rays"] = rays_out / torch.norm(rays_out, dim=1, keepdim=True).clamp(min=1e-5) + out["lowres_features"] = model_outputs.get("lowres_features", None) + out["_unisharp_internal_rays"] = model_outputs["rays"] + out["_unisharp_postprocess_padded_hw"] = (int(padded_h), int(padded_w)) + out["_unisharp_postprocess_paddings"] = 
def forward_unik3d_pinhole(
    model: torch.nn.Module,
    rgb_u8: torch.Tensor,
    intrinsics: torch.Tensor,
    normalize: bool = True,
) -> dict[str, torch.Tensor]:
    """Run UniK3D on pinhole images, conditioning on rays built from ``intrinsics``.

    The intrinsics are moved to the model's device as float32, turned into
    per-pixel rays by ``build_unik3d_camera_rays`` and handed to
    ``forward_unik3d_camera_rays`` together with the image.

    Args:
        model: UniK3D model; its first parameter decides the device.
        rgb_u8: Input image(s), forwarded unchanged.
        intrinsics: Pinhole intrinsics matrix/matrices.
        normalize: Forwarded to ``forward_unik3d_camera_rays``.

    Returns:
        The output dictionary of ``forward_unik3d_camera_rays``.
    """
    target_device = next(model.parameters()).device
    k_on_device = intrinsics.to(device=target_device, dtype=torch.float32)
    # Only the second element (the rays) of the 4-tuple is needed here.
    camera_rays = build_unik3d_camera_rays(
        rgb_u8,
        device=target_device,
        intrinsics=k_on_device,
    )[1]
    return forward_unik3d_camera_rays(
        model,
        rgb_u8,
        normalize=normalize,
        rays=camera_rays,
    )
def colorize_scalar_map(
    scalar_map: torch.Tensor, val_min=0.0, val_max=1.0, color_map: str = "jet"
) -> torch.Tensor:
    """Map a scalar tensor to uint8 RGB colors using a matplotlib colormap.

    Values are normalized linearly from ``[val_min, val_max]`` to ``[0, 1]``
    and clipped before the colormap lookup. When ``val_min == val_max`` the
    range is degenerate; instead of dividing by zero, values strictly above
    ``val_min`` map to the top of the colormap and all others to the bottom.

    Args:
        scalar_map: Tensor with 2, 3 or 4 dimensions (``(..., H, W)``).
        val_min: Value mapped to the low end of the colormap.
        val_max: Value mapped to the high end of the colormap.
        color_map: Name passed to ``matplotlib.pyplot.get_cmap``.

    Returns:
        A uint8 CPU tensor with the color channel inserted before the last two
        dimensions: ``(3, H, W)``, ``(B, 3, H, W)`` or ``(B, C, 3, H, W)``.

    Raises:
        ValueError: If ``scalar_map`` has an unsupported number of dimensions.
    """
    if scalar_map.ndim not in (2, 3, 4):
        raise ValueError("Only scalar maps of 2 or 3 or 4 dimensions supported.")

    cmap = plt.get_cmap(color_map)

    scalar_map_np = scalar_map.detach().cpu().float().numpy()
    span = val_max - val_min
    if span == 0:
        # Degenerate range: avoid 0/0 -> NaN and x/0 -> inf in the normalization.
        scalar_map_np = (scalar_map_np > val_min).astype(np.float32)
    else:
        scalar_map_np = (scalar_map_np - val_min) / span
    scalar_map_np = np.clip(scalar_map_np, a_min=0.0, a_max=1.0)

    # The colormap returns RGBA in [0, 1]; drop alpha and scale to 0..255.
    color_map_np = cmap(scalar_map_np)[..., :3]
    tensor = torch.as_tensor(color_map_np * 255.0, dtype=torch.uint8)

    # Move the trailing color axis in front of the spatial axes.
    if tensor.ndim == 3:
        return tensor.permute(2, 0, 1)
    if tensor.ndim == 4:
        return tensor.permute(0, 3, 1, 2)
    if tensor.ndim == 5:
        return tensor.permute(0, 1, 4, 2, 3)
    # Unreachable given the ndim check above; raised explicitly so the guard
    # also survives ``python -O`` (which strips ``assert`` statements).
    raise AssertionError("Invalid tensor shape encountered.")
def pseudo_depth_safe_key(value: Any) -> str:
    """Turn an arbitrary value into a filesystem-friendly key.

    The value is stringified and stripped, path separators (``\\`` and ``/``)
    become ``__``, every remaining run of characters outside
    ``[A-Za-z0-9_.-]`` collapses to a single ``_``, and the result is cut to
    180 characters. An empty result yields ``"unknown"``.
    """
    cleaned = str(value).strip()
    for separator in ("\\", "/"):
        cleaned = cleaned.replace(separator, "__")
    cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "_", cleaned)
    if not cleaned:
        return "unknown"
    return cleaned[:180]
def distance_to_z_depth_pinhole(distance_1hw: torch.Tensor, intrinsics_k3: torch.Tensor) -> torch.Tensor:
    """Convert per-pixel ray distance to z-depth for a pinhole camera.

    Each pixel ``(u, v)`` (integer coordinates starting at 0) is unprojected
    to ``(x, y, 1)`` with ``x = (u - cx) / fx`` and ``y = (v - cy) / fy``; the
    distance is then multiplied by ``1 / sqrt(x^2 + y^2 + 1)``.

    Args:
        distance_1hw: Distance map of shape ``(H, W)`` or ``(1, H, W)``.
        intrinsics_k3: Intrinsics indexed as ``K[0, 0]``, ``K[1, 1]``,
            ``K[0, 2]``, ``K[1, 2]`` (fx, fy, cx, cy).

    Returns:
        Float32 z-depth of shape ``(1, H, W)`` on the device of the input.

    Raises:
        ValueError: If the distance map is not ``(1, H, W)`` after promotion.
    """
    if distance_1hw.ndim == 2:
        distance_1hw = distance_1hw.unsqueeze(0)
    if not (distance_1hw.ndim == 3 and int(distance_1hw.shape[0]) == 1):
        raise ValueError(f"Expected distance shape (1,H,W), got {tuple(distance_1hw.shape)}")

    radial = distance_1hw.to(torch.float32)
    height, width = int(radial.shape[-2]), int(radial.shape[-1])
    cam = intrinsics_k3.to(dtype=torch.float32, device=radial.device)
    # Focal lengths are floored to keep the division finite.
    focal_x = cam[0, 0].clamp(min=1e-6)
    focal_y = cam[1, 1].clamp(min=1e-6)

    rows = torch.arange(height, device=radial.device, dtype=torch.float32)
    cols = torch.arange(width, device=radial.device, dtype=torch.float32)
    v_grid, u_grid = torch.meshgrid(rows, cols, indexing="ij")
    x_norm = (u_grid - cam[0, 2]) / focal_x
    y_norm = (v_grid - cam[1, 2]) / focal_y

    # z-component of the unit ray through each pixel.
    unit_ray_z = 1.0 / torch.sqrt(x_norm * x_norm + y_norm * y_norm + 1.0).clamp(min=1e-8)
    return (radial[0] * unit_ray_z).unsqueeze(0)
def load_validation_pseudo_distance(
    root: Path | None,
    dataset: str,
    scene: Any,
    frame_idx: Any,
) -> torch.Tensor | None:
    """Load a cached pseudo ray-distance map for one validation frame.

    The file is located with ``validation_pseudo_depth_path`` (``.npz``); when
    that file does not exist a ``.pt`` file with the same stem is tried.

    Accepted payloads:

    * ``.npz`` with ``distance_m``, or with ``depth_m`` whose ``depth_kind``
      (default ``"distance"``) normalizes to ``"distance"``.
    * ``.pt`` holding a dict with the same keys, or a bare tensor/array.

    Loading is best effort: any failure, an unexpected shape, a payload that
    is z-depth rather than distance, or a map without a single finite positive
    value yields ``None``.

    Args:
        root: Cache root directory, or ``None`` to disable the lookup.
        dataset: Dataset name used in the cache path.
        scene: Scene identifier used in the cache path.
        frame_idx: Frame identifier used in the cache path.

    Returns:
        Float32 tensor of shape ``(1, H, W)`` where non-finite or non-positive
        entries are set to 0, or ``None`` when nothing usable was found.
    """
    if root is None:
        return None
    path = validation_pseudo_depth_path(Path(root), dataset, scene, frame_idx)
    if not path.exists():
        legacy_path = path.with_suffix(".pt")
        if not legacy_path.exists():
            return None
        path = legacy_path
    try:
        if path.suffix.lower() == ".npz":
            with np.load(path, allow_pickle=False) as payload_np:
                if "distance_m" in payload_np:
                    depth = torch.from_numpy(np.array(payload_np["distance_m"], copy=True))
                elif "depth_m" in payload_np:
                    depth_kind_raw = payload_np["depth_kind"] if "depth_kind" in payload_np else "distance"
                    depth_kind = normalize_depth_kind(
                        depth_kind_raw.tolist() if hasattr(depth_kind_raw, "tolist") else depth_kind_raw,
                        default="distance",
                    )
                    if depth_kind != "distance":
                        return None
                    depth = torch.from_numpy(np.array(payload_np["depth_m"], copy=True))
                else:
                    return None
        else:
            # NOTE(review): torch_load_any unpickles the file; only point
            # ``root`` at trusted cache directories.
            payload = torch_load_any(path)
            if isinstance(payload, dict):
                depth = payload.get("distance_m", None)
                # Accept NumPy arrays as well as tensors for ``distance_m``;
                # previously an ndarray here was discarded in favour of
                # ``depth_m`` even though arrays are converted further below.
                if not (torch.is_tensor(depth) or isinstance(depth, np.ndarray)):
                    depth = payload.get("depth_m", None)
                    if normalize_depth_kind(payload.get("depth_kind", "distance"), default="distance") != "distance":
                        return None
            else:
                depth = payload
        if isinstance(depth, np.ndarray):
            depth = torch.from_numpy(depth)
        if not torch.is_tensor(depth):
            return None
        if depth.ndim == 2:
            depth = depth.unsqueeze(0)
        if depth.ndim != 3 or int(depth.shape[0]) != 1:
            return None
        depth = depth.to(torch.float32)
        valid = torch.isfinite(depth) & (depth > 0.0)
        if int(valid.sum().item()) <= 0:
            return None
        return torch.where(valid, depth, torch.zeros_like(depth))
    except Exception:
        # Deliberate best effort: a corrupt or incompatible cache entry must
        # not abort validation, it just counts as "no pseudo distance".
        return None
0) -> list[str]: + if path is None: + return [] + path = Path(path) + if not path.exists(): + raise FileNotFoundError(path) + lines = [line.strip() for line in path.read_text(encoding="utf-8").splitlines() if line.strip()] + return lines[: int(max_lines)] if int(max_lines) > 0 else lines + + +def colmap_scene_roots(root: Path) -> list[Path]: + if (root / "images_4").exists() or (root / "images").exists() or (root / "poses_bounds.npy").exists() or (root / "sparse" / "0").exists(): + return [root] + return sorted( + [ + p + for p in root.iterdir() + if p.is_dir() and ((p / "images_4").exists() or (p / "images").exists() or (p / "poses_bounds.npy").exists() or (p / "sparse" / "0").exists()) + ] + ) if root.exists() else [] + + +def colmap_image_dir(scene_root: Path) -> Path: + if (scene_root / "images_4").exists(): + return scene_root / "images_4" + if (scene_root / "images").exists(): + return scene_root / "images" + return scene_root + + +def load_hm3d_pose(scene_dir: Path) -> tuple[np.ndarray, np.ndarray]: + rot_path = scene_dir / "rotation.npy" + trans_path = scene_dir / "translation.npy" + if rot_path.exists() and trans_path.exists(): + return np.load(rot_path).astype(np.float32), np.load(trans_path).astype(np.float32) + meta = torch_load_any(scene_dir / "meta.pt") + if isinstance(meta, dict) and ("R" in meta) and ("t" in meta): + return np.asarray(meta["R"], dtype=np.float32), np.asarray(meta["t"], dtype=np.float32) + cams = meta["cameras"].to(torch.float32) + return cams[:, :3, :3].cpu().numpy(), cams[:, :3, 3].cpu().numpy() + + +def nerf_c2w_to_opencv_c2w(c2w_in: torch.Tensor | np.ndarray) -> torch.Tensor: + c2w = torch.as_tensor(c2w_in, dtype=torch.float32).clone() + if tuple(c2w.shape) != (4, 4): + raise ValueError(f"Expected c2w shape (4,4), got {tuple(c2w.shape)}") + c2w[:3, 1:3] *= -1.0 + return c2w + + +def colmap_pose_scale_from_bounds(root: Path, bd_factor: float = 0.75) -> float: + poses_bounds_path = root / "poses_bounds.npy" + if not 
poses_bounds_path.exists(): + return 1.0 + poses_bounds = np.load(poses_bounds_path) + if poses_bounds.ndim != 2 or poses_bounds.shape[1] < 17: + return 1.0 + bounds = np.asarray(poses_bounds[:, -2:], dtype=np.float32) + finite = bounds[np.isfinite(bounds)] + positive = finite[finite > 0.0] + if positive.size == 0: + return 1.0 + min_bound = float(positive.min()) + return 1.0 if min_bound <= 0.0 else 1.0 / max(min_bound * float(bd_factor), 1e-6) + + +def qvec2rotmat(qvec: np.ndarray) -> np.ndarray: + q = np.asarray(qvec, dtype=np.float64) + if q.shape != (4,): + raise ValueError(f"Expected qvec shape (4,), got {tuple(q.shape)}") + w, x, y, z = q.tolist() + return np.array( + [ + [1.0 - 2.0 * y * y - 2.0 * z * z, 2.0 * x * y - 2.0 * w * z, 2.0 * x * z + 2.0 * w * y], + [2.0 * x * y + 2.0 * w * z, 1.0 - 2.0 * x * x - 2.0 * z * z, 2.0 * y * z - 2.0 * w * x], + [2.0 * x * z - 2.0 * w * y, 2.0 * y * z + 2.0 * w * x, 1.0 - 2.0 * x * x - 2.0 * y * y], + ], + dtype=np.float32, + ) + + +_COLMAP_CAMERA_MODEL_NUM_PARAMS = { + 0: ("SIMPLE_PINHOLE", 3), + 1: ("PINHOLE", 4), + 2: ("SIMPLE_RADIAL", 4), + 3: ("RADIAL", 5), + 4: ("OPENCV", 8), + 5: ("OPENCV_FISHEYE", 8), + 6: ("FULL_OPENCV", 12), + 7: ("FOV", 5), + 8: ("SIMPLE_RADIAL_FISHEYE", 4), + 9: ("RADIAL_FISHEYE", 5), + 10: ("THIN_PRISM_FISHEYE", 12), +} + + +def _colmap_camera_to_k(camera: dict[str, float | int | str]) -> torch.Tensor: + model = str(camera["model"]) + params = list(camera["params"]) # type: ignore[arg-type] + if model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL", "SIMPLE_RADIAL_FISHEYE", "RADIAL_FISHEYE"): + fx = fy = float(params[0]) + cx = float(params[1]) + cy = float(params[2]) + elif model in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV", "FOV", "THIN_PRISM_FISHEYE"): + fx = float(params[0]) + fy = float(params[1]) + cx = float(params[2]) + cy = float(params[3]) + else: + raise ValueError(f"Unsupported COLMAP camera model: {model}") + k = torch.eye(3, dtype=torch.float32) + k[0, 0] = fx + 
k[1, 1] = fy + k[0, 2] = cx + k[1, 2] = cy + return k + + +def _read_cameras_txt(cameras_txt: Path) -> dict[int, dict[str, float | int | str | list[float]]]: + cameras: dict[int, dict[str, float | int | str | list[float]]] = {} + for raw in cameras_txt.read_text(encoding="utf-8").splitlines(): + line = raw.strip() + if (not line) or line.startswith("#"): + continue + parts = line.split() + camera_id = int(parts[0]) + cameras[camera_id] = { + "model": str(parts[1]), + "width": int(parts[2]), + "height": int(parts[3]), + "params": [float(x) for x in parts[4:]], + } + return cameras + + +def _read_cameras_bin(cameras_bin: Path) -> dict[int, dict[str, float | int | str | list[float]]]: + cameras: dict[int, dict[str, float | int | str | list[float]]] = {} + data = cameras_bin.read_bytes() + offset = 0 + (num_cameras,) = struct.unpack_from(" dict[str, torch.Tensor | int | str | list[float]]: + w2c = torch.eye(4, dtype=torch.float32) + w2c[:3, :3] = torch.from_numpy(qvec2rotmat(qvec)) + w2c[:3, 3] = torch.from_numpy(tvec.astype(np.float32) * float(pose_scale)) + return { + "image_id": int(image_id), + "camera_id": int(camera_id), + "camera_model": str(camera["model"]), + "camera_params": [float(x) for x in camera["params"]], # type: ignore[index] + "w2c": w2c, + "k": _colmap_camera_to_k(camera), + "width": int(camera["width"]), + "height": int(camera["height"]), + } + + +def _read_images_txt( + images_txt: Path, + cameras: dict[int, dict[str, float | int | str | list[float]]], + pose_scale: float, +) -> dict[str, dict[str, torch.Tensor | int | str | list[float]]]: + entries: dict[str, dict[str, torch.Tensor | int | str | list[float]]] = {} + lines = images_txt.read_text(encoding="utf-8").splitlines() + line_idx = 0 + while line_idx < len(lines): + line = lines[line_idx].strip() + line_idx += 1 + if (not line) or line.startswith("#"): + continue + parts = line.split() + if len(parts) < 10: + continue + qvec = np.array([float(x) for x in parts[1:5]], dtype=np.float64) + 
tvec = np.array([float(x) for x in parts[5:8]], dtype=np.float32) + camera_id = int(parts[8]) + image_name = str(parts[9]) + entries[image_name] = _colmap_entry_for_image( + image_id=-1, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + camera=cameras[camera_id], + pose_scale=float(pose_scale), + ) + if line_idx < len(lines): + line_idx += 1 + return entries + + +def _read_images_bin( + images_bin: Path, + cameras: dict[int, dict[str, float | int | str | list[float]]], + pose_scale: float, +) -> dict[str, dict[str, torch.Tensor | int | str | list[float]]]: + entries: dict[str, dict[str, torch.Tensor | int | str | list[float]]] = {} + data = images_bin.read_bytes() + offset = 0 + (num_images,) = struct.unpack_from(" dict[str, dict[str, torch.Tensor | int | str | list[float]]] | None: + sparse_dir = root / "sparse" / "0" + cameras_txt = sparse_dir / "cameras.txt" + images_txt = sparse_dir / "images.txt" + cameras_bin = sparse_dir / "cameras.bin" + images_bin = sparse_dir / "images.bin" + if cameras_txt.exists() and images_txt.exists(): + cameras = _read_cameras_txt(cameras_txt) + return _read_images_txt(images_txt, cameras, pose_scale=float(pose_scale)) + if cameras_bin.exists() and images_bin.exists(): + cameras = _read_cameras_bin(cameras_bin) + return _read_images_bin(images_bin, cameras, pose_scale=float(pose_scale)) + return None + + +def load_scaled_colmap_entries(root: Path) -> dict[str, dict[str, torch.Tensor | int | str | list[float]]] | None: + return load_colmap_entries(root, pose_scale=float(colmap_pose_scale_from_bounds(root))) + + +def wild_validation_roots(data_root: Path) -> list[Path]: + if data_root.is_file(): + return [Path(line.strip()) for line in data_root.read_text(encoding="utf-8").splitlines() if line.strip()] + if (data_root / "scenes").is_dir(): + return [data_root] + roots: list[Path] = [] + if data_root.is_dir(): + for child in sorted([p for p in data_root.iterdir() if p.is_dir()]): + if (child / "scenes").is_dir(): + roots.append(child) 
+ elif (child / child.name / "scenes").is_dir(): + roots.append(child / child.name) + if roots: + return roots + return [data_root] diff --git a/unisharp/validation/run_validation.py b/unisharp/validation/run_validation.py new file mode 100644 index 0000000000000000000000000000000000000000..5b7c9e84d0cf707261a60d80cf93167d5f4e0cb2 --- /dev/null +++ b/unisharp/validation/run_validation.py @@ -0,0 +1,2205 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import csv +from dataclasses import dataclass +import json +import logging +import math +import os +import random +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Callable, Iterable, Iterator + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from tqdm import tqdm + +REPO_ROOT = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(REPO_ROOT)) + +from unisharp.datasets.panogs import panogs_collate # noqa: E402 +from unisharp.datasets.scannetpp_fisheye import ScannetppFisheyeDataset # noqa: E402 +from unisharp.datasets.sim_panorama import _EquirecToCube, SimPanoramaDataset # noqa: E402 +from unisharp.datasets.wildrgbd import WildRGBDDataset # noqa: E402 +from unisharp.losses import UnisharpLoss, UnisharpLossWeights # noqa: E402 +from unisharp.models.unisharp_feature import UnisharpFeatureConfig, UnisharpFeatureModel # noqa: E402 +from unisharp.utils.color_space import linearRGB2sRGB # noqa: E402 +from unisharp import DEFAULT_MAX_DEPTH_M # noqa: E402 +from unisharp.utils.io import save_image # noqa: E402 +from unisharp.utils.metrics import ( # noqa: E402 + MetricsCalculator, + compute_masked_rgb_metrics, + default_metric_mask_cache_dir, + metric_mask_from_pinhole_batch, +) +from unisharp.utils.vis import colorize_alpha, colorize_scalar_map # noqa: E402 +from unisharp.validation.io_common import ( # noqa: E402 + decode_rgb_u8 as _decode_rgb_u8, + distance_to_z_depth_pinhole as 
_distance_to_z_depth_pinhole, + colmap_image_dir as _colmap_image_dir, + colmap_scene_roots as _colmap_scene_roots, + load_colmap_entries as _load_colmap_entries, + load_hm3d_pose as _load_hm3d_pose, + load_scaled_colmap_entries as _load_scaled_colmap_entries, + load_png_depth_m as _load_png_depth_m, + load_png_rgb_u8 as _load_png_rgb_u8, + load_validation_pseudo_depth as _load_validation_pseudo_depth, + load_validation_pseudo_distance as _load_validation_pseudo_distance, + nerf_c2w_to_opencv_c2w as _nerf_c2w_to_opencv_c2w, + normalize_depth_kind as _normalize_depth_kind, + read_manifest_lines as _read_manifest_lines, + resolve_replica_test_root as _resolve_replica_test_root, + resize_k3_align_corners_false as _resize_k3_align_corners_false, + torch_load_any as _torch_load_any, + wild_validation_roots as _wild_validation_roots, +) + + +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") +LOGGER = logging.getLogger(__name__) + +ValidationTag = str | list[str] +ValidationItem = tuple[str, Any, ValidationTag, str] + +METRIC_FIELDS = ["psnr", "ssim", "lpips"] + + +def _configure_torchhub_cache() -> Path: + torchhub_dir = REPO_ROOT / "checkpoints" / "torchhub" + torchhub_dir.mkdir(parents=True, exist_ok=True) + os.environ["TORCH_HOME"] = str(torchhub_dir) + torch.hub.set_dir(str(torchhub_dir)) + return torchhub_dir + + +def _training_config_for_checkpoint(checkpoint_path: Path) -> dict[str, Any]: + payload: dict[str, Any] = {} + try: + ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + except TypeError: + ckpt = torch.load(checkpoint_path, map_location="cpu") + except Exception: + ckpt = None + if isinstance(ckpt, dict): + cfg = ckpt.get("config", None) + if isinstance(cfg, dict): + payload.update(cfg) + config_path = Path(checkpoint_path).parent / "config.json" + if not config_path.exists(): + return payload + try: + json_payload = json.loads(config_path.read_text(encoding="utf-8")) + except Exception: + 
return payload + if isinstance(json_payload, dict): + payload.update(json_payload) + return payload + + +def _fill_arg_from_config(args: argparse.Namespace, attr: str, config: dict[str, Any], key: str, default: float) -> None: + if getattr(args, attr, None) is not None: + return + value = config.get(key, default) + try: + setattr(args, attr, float(value)) + except Exception: + setattr(args, attr, float(default)) + + +def _fill_int_arg_from_config(args: argparse.Namespace, attr: str, config: dict[str, Any], key: str, default: int) -> None: + if getattr(args, attr, None) is not None: + return + value = config.get(key, default) + try: + setattr(args, attr, int(value)) + except Exception: + setattr(args, attr, int(default)) + + +def _apply_training_depth_config_defaults(args: argparse.Namespace) -> None: + config = _training_config_for_checkpoint(Path(args.checkpoint)) + _fill_arg_from_config(args, "max_depth_m", config, "max_depth_m", DEFAULT_MAX_DEPTH_M) + _fill_arg_from_config(args, "sim_far_depth_invalid_m", config, "sim_far_depth_invalid_m", 30.0) + _fill_arg_from_config(args, "sim_far_depth_invalid_max_frac", config, "sim_far_depth_invalid_max_frac", 1.0) + _fill_arg_from_config(args, "re10k_pseudo_far_depth_invalid_m", config, "re10k_pseudo_far_depth_invalid_m", 30.0) + _fill_arg_from_config(args, "scanetpp_fisheye_far_depth_invalid_m", config, "scanetpp_fisheye_far_depth_invalid_m", 30.0) + _fill_arg_from_config(args, "low_pass_filter_eps", config, "render_low_pass_filter_eps", 1e-2) + + +def _append_metrics_row(csv_path: Path, row: dict[str, float]) -> None: + fieldnames = list(METRIC_FIELDS) + row_out = {k: row.get(k, float("nan")) for k in fieldnames} + if csv_path.exists(): + try: + with csv_path.open("r", newline="") as f: + reader = csv.reader(f) + existing_header = next(reader, []) + if existing_header: + if fieldnames != existing_header: + with csv_path.open("r", newline="") as f: + old_rows = list(csv.DictReader(f)) + with csv_path.open("w", 
newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for r in old_rows: + writer.writerow({k: r.get(k, float("nan")) for k in fieldnames}) + except Exception: + pass + write_header = not csv_path.exists() or csv_path.stat().st_size == 0 + with csv_path.open("a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + if write_header: + writer.writeheader() + writer.writerow(row_out) + + +def _append_sample_metrics_row(csv_path: Path, group_key: str, tag: str, row: dict[str, float]) -> None: + fieldnames = ["group", "tag", *METRIC_FIELDS] + row_out = {"group": group_key, "tag": tag, **{k: row.get(k, float("nan")) for k in METRIC_FIELDS}} + write_header = not csv_path.exists() or csv_path.stat().st_size == 0 + with csv_path.open("a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + if write_header: + writer.writeheader() + writer.writerow(row_out) + + +def _feature_config_from_checkpoint(checkpoint_path: Path, ckpt: dict[str, Any]) -> UnisharpFeatureConfig: + cfg = UnisharpFeatureConfig() + merged: dict[str, Any] = {} + cfg_payload = ckpt.get("config", {}) + if isinstance(cfg_payload, dict): + merged.update(cfg_payload) + for key in cfg.__dict__.keys(): + if key in ckpt: + merged[key] = ckpt[key] + config_path = Path(checkpoint_path).parent / "config.json" + if config_path.exists(): + try: + sidecar = json.loads(config_path.read_text(encoding="utf-8")) + except Exception: + sidecar = None + if isinstance(sidecar, dict): + merged.update({k: v for k, v in sidecar.items() if k in cfg.__dict__}) + for k in cfg.__dict__.keys(): + if k in merged: + setattr(cfg, k, merged[k]) + return cfg + + +def _load_model(checkpoint_path: Path, device: torch.device) -> tuple[UnisharpFeatureModel, int]: + try: + ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + except TypeError: + ckpt = torch.load(checkpoint_path, map_location="cpu") + if not isinstance(ckpt, dict): + raise 
ValueError(f"Expected feature-only checkpoint dict, got {type(ckpt)} from {checkpoint_path}") + + cfg = _feature_config_from_checkpoint(checkpoint_path, ckpt) + model = UnisharpFeatureModel(cfg).to(device) + model.load_from_checkpoint(str(checkpoint_path), strict=True) + model.eval() + return model, int(ckpt.get("step", 0)) + + +def _build_trainer(model: UnisharpFeatureModel, device: torch.device, args: argparse.Namespace) -> Any: + from unisharp.cli.unified_trainer import UnifiedTrainer + zero_w = UnisharpLossWeights( + lambda_color=0.0, + lambda_alpha=0.0, + lambda_percep=0.0, + lambda_depth=0.0, + lambda_tv=0.0, + lambda_grad=0.0, + lambda_grad_img=0.0, + lambda_delta=0.0, + lambda_splat=0.0, + ) + loss_fn = UnisharpLoss(zero_w).to(device) + max_depth_m = float(getattr(args, "max_depth_m", getattr(model.config, "max_distance_m", DEFAULT_MAX_DEPTH_M))) + loss_fn.SUPERVISION_MAX_DEPTH_M = max_depth_m + from unisharp.utils.gsplat import GSplatRenderer + + renderer = GSplatRenderer( + color_space="sRGB", + background_color="black", + low_pass_filter_eps=float(getattr(args, "low_pass_filter_eps", 1e-2)), + ).to(device) + return UnifiedTrainer( + model=model, + renderer=renderer, + loss_fn=loss_fn, + device=device, + enable_tgt_unik3d_vis=False, + max_depth_m=max_depth_m, + sim_far_depth_invalid_m=float(getattr(args, "sim_far_depth_invalid_m", 30.0)), + re10k_pseudo_far_depth_invalid_m=float(getattr(args, "re10k_pseudo_far_depth_invalid_m", 30.0)), + scanetpp_fisheye_far_depth_invalid_m=float(getattr(args, "scanetpp_fisheye_far_depth_invalid_m", 30.0)), + ) + + +def _make_pinhole_batch( + *, + src_img: torch.Tensor, + tgt_img: torch.Tensor, + src_w2c: torch.Tensor, + tgt_w2c: torch.Tensor, + src_k: torch.Tensor, + tgt_k: torch.Tensor, + scene: str, + src_idx: int | list[int], + tgt_idx: int | list[int], + src_depth: torch.Tensor | None = None, + tgt_depth: torch.Tensor | None = None, + src_img_orig: torch.Tensor | None = None, + tgt_img_orig: torch.Tensor | None = 
None, + src_k_orig: torch.Tensor | None = None, + tgt_k_orig: torch.Tensor | None = None, + src_depth_orig: torch.Tensor | None = None, + tgt_depth_orig: torch.Tensor | None = None, +) -> SimpleNamespace: + batch_size = int(tgt_img.shape[0]) if torch.is_tensor(tgt_img) and tgt_img.ndim == 4 else 1 + scene_values = [scene] * batch_size if isinstance(scene, str) else list(scene) + src_idx_values = [int(src_idx)] * batch_size if isinstance(src_idx, int) else [int(x) for x in src_idx] + tgt_idx_values = [int(tgt_idx)] if isinstance(tgt_idx, int) else [int(x) for x in tgt_idx] + return SimpleNamespace( + src_rgb_u8=src_img, + tgt_rgb_u8=tgt_img, + src_depth_m=src_depth, + tgt_depth_m=tgt_depth, + src_rgb_u8_orig=(src_img if src_img_orig is None else src_img_orig), + tgt_rgb_u8_orig=(tgt_img if tgt_img_orig is None else tgt_img_orig), + src_depth_m_orig=(src_depth if src_depth_orig is None else src_depth_orig), + tgt_depth_m_orig=(tgt_depth if tgt_depth_orig is None else tgt_depth_orig), + src_w2c=src_w2c, + tgt_w2c=tgt_w2c, + src_intrinsics=src_k, + tgt_intrinsics=tgt_k, + src_intrinsics_orig=(src_k if src_k_orig is None else src_k_orig), + tgt_intrinsics_orig=(tgt_k if tgt_k_orig is None else tgt_k_orig), + scene=scene_values, + src_idx=torch.tensor(src_idx_values, dtype=torch.long), + tgt_idx=torch.tensor(tgt_idx_values, dtype=torch.long), + share_src_forward=True, + collect_all_vis=True, + ) + + +def _make_scanetpp_fisheye_batch( + *, + scene: str, + src_pos: int, + tgt_positions: list[int], + src_frame: dict[str, Any], + tgt_frames: list[dict[str, Any]], + src_loaded: dict[str, torch.Tensor], + tgt_loaded: list[dict[str, torch.Tensor]], +) -> SimpleNamespace: + n = int(len(tgt_positions)) + src_rgb = src_loaded["rgb_u8"].unsqueeze(0).repeat(n, 1, 1, 1) + src_depth = src_loaded["depth_m"].unsqueeze(0).repeat(n, 1, 1, 1) + src_mask = src_loaded["valid_mask"].unsqueeze(0).repeat(n, 1, 1, 1) + src_w2c = src_frame["w2c"].to(torch.float32).unsqueeze(0).repeat(n, 1, 1) + 
src_cam = src_loaded["camera_params"].to(torch.float32).unsqueeze(0).repeat(n, 1) + return SimpleNamespace( + src_rgb_u8=src_rgb, + tgt_rgb_u8=torch.stack([item["rgb_u8"] for item in tgt_loaded], dim=0), + src_depth_m=src_depth, + tgt_depth_m=torch.stack([item["depth_m"] for item in tgt_loaded], dim=0), + src_valid_mask=src_mask, + tgt_valid_mask=torch.stack([item["valid_mask"] for item in tgt_loaded], dim=0), + src_w2c=src_w2c, + tgt_w2c=torch.stack([frame["w2c"].to(torch.float32) for frame in tgt_frames], dim=0), + src_camera_params=src_cam, + tgt_camera_params=torch.stack([item["camera_params"].to(torch.float32) for item in tgt_loaded], dim=0), + src_idx=torch.full((n,), int(src_pos), dtype=torch.long), + tgt_idx=torch.tensor([int(x) for x in tgt_positions], dtype=torch.long), + scene=[str(scene)] * n, + camera_model="fisheye624", + share_src_forward=True, + collect_all_vis=True, + ) + + + +@dataclass +class _PinholeTargetAdapter: + idx: int + img: torch.Tensor + w2c: torch.Tensor + k: torch.Tensor + depth: torch.Tensor | None = None + + +@dataclass +class _PinholeGroupAdapter: + scene: str + group_key: str + src_idx: int + src_img: torch.Tensor + src_w2c: torch.Tensor + src_k: torch.Tensor + tgt_indices: list[int] + load_target: Callable[[int], _PinholeTargetAdapter | None] + src_depth: torch.Tensor | None = None + + +def _iter_manifest_parts( + args: argparse.Namespace, + *, + expected_parts: int, +) -> Iterator[tuple[int, list[str]]]: + manifest_in = _read_manifest_lines(getattr(args, "manifest_file", None), max_lines=_manifest_max_groups(args)) + for group_idx, raw in enumerate(manifest_in): + parts = raw.split("|") + if len(parts) == int(expected_parts): + yield group_idx, parts + + +def _yield_pinhole_group_batches( + dataset: str, + adapter: _PinholeGroupAdapter, + args: argparse.Namespace, +) -> Iterator[ValidationItem]: + batch_size = max(1, int(getattr(args, "validation_batch_size", 1))) + pending: list[_PinholeTargetAdapter] = [] + + def 
_flush(targets: list[_PinholeTargetAdapter]) -> Iterator[ValidationItem]: + if not targets: + return + n = len(targets) + src_img_orig = adapter.src_img.clone() + if n > 1: + src_img_orig = src_img_orig.repeat(n, 1, 1, 1) + tgt_img_orig = torch.cat([t.img.clone() for t in targets], dim=0) + src_k_orig = adapter.src_k.clone() + if n > 1: + src_k_orig = src_k_orig.repeat(n, 1, 1) + tgt_k_orig = torch.cat([t.k.clone() for t in targets], dim=0) + src_w2c = adapter.src_w2c + if n > 1: + src_w2c = src_w2c.repeat(n, 1, 1) + tgt_w2c = torch.cat([t.w2c for t in targets], dim=0) + src_depth_orig = None if adapter.src_depth is None else adapter.src_depth.clone() + if src_depth_orig is not None and n > 1: + src_depth_orig = src_depth_orig.repeat(n, 1, 1, 1) + tgt_depth_values = [t.depth for t in targets] + tgt_depth_orig = None + if all(torch.is_tensor(d) for d in tgt_depth_values): + tgt_depth_orig = torch.cat([d for d in tgt_depth_values if torch.is_tensor(d)], dim=0) + batch = _make_pinhole_batch( + src_img=src_img_orig, + tgt_img=tgt_img_orig, + src_w2c=src_w2c, + tgt_w2c=tgt_w2c, + src_k=src_k_orig, + tgt_k=tgt_k_orig, + scene=adapter.scene, + src_idx=[adapter.src_idx] * n, + tgt_idx=[int(t.idx) for t in targets], + src_depth=src_depth_orig, + tgt_depth=tgt_depth_orig, + src_img_orig=src_img_orig, + tgt_img_orig=tgt_img_orig, + src_k_orig=src_k_orig, + tgt_k_orig=tgt_k_orig, + src_depth_orig=src_depth_orig, + tgt_depth_orig=tgt_depth_orig, + ) + tags = [f"{adapter.group_key}_t{int(t.idx):05d}" for t in targets] + yield (dataset, batch, tags[0] if len(tags) == 1 else tags, adapter.group_key) + + for tgt_idx in adapter.tgt_indices: + tgt = adapter.load_target(int(tgt_idx)) + if tgt is None: + continue + pending.append(tgt) + if len(pending) >= batch_size: + yield from _flush(pending) + pending = [] + if pending: + yield from _flush(pending) + + +def _yield_panogs_group_batches( + dataset: str, + group_key: str, + samples: list[Any], + tags: list[str], + args: 
argparse.Namespace, +) -> Iterator[ValidationItem]: + batch_size = max(1, int(getattr(args, "validation_batch_size", 1))) + if len(samples) != len(tags): + raise ValueError(f"Expected samples/tags length match, got {len(samples)} vs {len(tags)}") + for start in range(0, len(samples), batch_size): + end = min(len(samples), start + batch_size) + batch_tags = tags[start:end] + batch = panogs_collate(samples[start:end]) + object.__setattr__(batch, "collect_all_vis", True) + yield ( + dataset, + batch, + batch_tags[0] if len(batch_tags) == 1 else batch_tags, + group_key, + ) + + +def _mask_connected_to_border(mask_2d: torch.Tensor) -> torch.Tensor: + h, w = int(mask_2d.shape[0]), int(mask_2d.shape[1]) + if h <= 0 or w <= 0: + return torch.zeros_like(mask_2d, dtype=torch.bool) + border = torch.zeros_like(mask_2d, dtype=torch.bool) + border[0, :] = True + border[-1, :] = True + border[:, 0] = True + border[:, -1] = True + frontier = mask_2d & border + visited = frontier.clone() + kernel = torch.tensor( + [[[[0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]]], + device=mask_2d.device, + dtype=torch.float32, + ) + for _ in range(h + w): + if not bool(frontier.any()): + break + neigh = F.conv2d(frontier[None, None].to(torch.float32), kernel, padding=1)[0, 0] > 0.0 + new_frontier = neigh & mask_2d & (~visited) + visited = visited | new_frontier + frontier = new_frontier + return visited + + +def _compute_metrics_from_vis( + vis: dict[str, Any], + metrics_calc: MetricsCalculator, +) -> dict[str, float]: + tgt_gt = vis["tgt_gt"].detach().to(torch.float32).clamp(0, 1) + tgt_alpha = vis["tgt_alpha"].detach().to(torch.float32).clamp(0.0, 1.0) + tgt_pred = linearRGB2sRGB( + (vis["tgt_pred"].detach().to(torch.float32) / tgt_alpha.clamp(min=1e-4)).clamp(0.0, 1.0) + ).clamp(0, 1) + + geom_mask = vis.get("tgt_metric_mask", None) + if torch.is_tensor(geom_mask): + geom_mask = geom_mask.detach().to(device=tgt_pred.device, dtype=torch.float32) + if geom_mask.ndim == 3: + geom_mask = 
geom_mask.unsqueeze(1) + if tuple(geom_mask.shape[-2:]) != tuple(tgt_pred.shape[-2:]): + geom_mask = F.interpolate(geom_mask, size=tgt_pred.shape[-2:], mode="nearest") + tgt_geom = compute_masked_rgb_metrics( + pred=tgt_pred, + gt=tgt_gt, + mask=geom_mask, + metrics_calc=metrics_calc, + ) + else: + tgt_geom = metrics_calc.compute_rgb_metrics(tgt_pred, tgt_gt) + + return { + "psnr": float(tgt_geom["psnr"]), + "ssim": float(tgt_geom["ssim"]), + "lpips": float(tgt_geom["lpips"]), + } + + +def _save_vis_from_payload(vis: dict[str, Any], vis_dir: Path, tag: str, step: int) -> None: + from unisharp.utils.unified_vis import save_pair_visualization + + save_pair_visualization( + vis_dir / f"step_{int(step):07d}_{tag}.png", + src_gt=vis["src_gt"], + src_pred=vis["src_pred"], + src_alpha=vis["src_alpha"], + tgt_gt=vis["tgt_gt"], + tgt_pred=vis["tgt_pred"], + tgt_alpha=vis["tgt_alpha"], + src_gt_depth=vis.get("src_gt_depth", None), + tgt_gt_depth=vis.get("tgt_gt_depth", None), + src_pred_depth=vis.get("src_pred_depth", None), + tgt_pred_depth=vis.get("tgt_pred_depth", None), + src_unik3d_depth=vis.get("src_unik3d_depth", None), + tgt_unik3d_depth=vis.get("tgt_unik3d_depth", None), + dataset_name=str(vis.get("dataset_name", "unknown")), + scene=str(vis.get("scene", "unknown")), + step=int(step), + src_idx=int(vis.get("src_idx", -1)), + tgt_idx=int(vis.get("tgt_idx", -1)), + src_pose_w2c=vis.get("src_pose_w2c", None), + tgt_pose_w2c=vis.get("tgt_pose_w2c", None), + src_cube_gt_u8=vis.get("src_cube_gt_u8", None), + src_cube_pred_linear=vis.get("src_cube_pred_linear", None), + src_cube_alpha=vis.get("src_cube_alpha", None), + tgt_cube_gt_u8=vis.get("tgt_cube_gt_u8", None), + tgt_cube_pred_linear=vis.get("tgt_cube_pred_linear", None), + tgt_cube_alpha=vis.get("tgt_cube_alpha", None), + ) + + +def _save_group_pair_pngs(group_dir: Path, group_items: list[dict[str, Any]]) -> None: + visual_items = [item for item in group_items if isinstance(item.get("vis", None), dict)] + if not 
visual_items: + return + group_dir.mkdir(parents=True, exist_ok=True) + + def _save_mask(mask: torch.Tensor | None, path: Path) -> None: + if not torch.is_tensor(mask): + return + mask_rgb = mask.detach().to(torch.float32).clamp(0.0, 1.0) + if mask_rgb.ndim == 4: + mask_rgb = mask_rgb[0] + if mask_rgb.ndim == 3 and int(mask_rgb.shape[0]) == 1: + mask_rgb = mask_rgb.repeat(3, 1, 1) + if mask_rgb.ndim == 2: + mask_rgb = mask_rgb[None].repeat(3, 1, 1) + if mask_rgb.ndim == 3: + save_image(_to_u8_hwc(mask_rgb), path) + + def _save_alpha(alpha: torch.Tensor | None, path: Path) -> None: + if not torch.is_tensor(alpha): + return + a = alpha.detach().to(torch.float32).clamp(0.0, 1.0) + if a.ndim == 4: + a = a[0] + if a.ndim == 3 and int(a.shape[0]) == 1: + a = a.repeat(3, 1, 1) + if a.ndim == 2: + a = a[None].repeat(3, 1, 1) + if a.ndim == 3: + save_image(_to_u8_hwc(a), path) + + def _noextrap_mask_from_vis(vis: dict[str, Any], which: str) -> torch.Tensor | None: + pred = vis.get(f"{which}_pred", None) + alpha = vis.get(f"{which}_alpha", None) + if not (torch.is_tensor(pred) and torch.is_tensor(alpha)): + return None + pred_pm = linearRGB2sRGB(pred.detach().to(torch.float32).clamp(min=0.0)).clamp(0.0, 1.0) + alpha = alpha.detach().to(torch.float32).clamp(0.0, 1.0) + masks: list[torch.Tensor] = [] + for bi in range(int(pred_pm.shape[0])): + black = pred_pm[bi : bi + 1].max(dim=1, keepdim=True).values <= float(2.0 / 255.0) + low_alpha = alpha[bi : bi + 1] <= float(0.02) + extrap_border = _mask_connected_to_border((black & low_alpha)[0, 0]) + masks.append((~extrap_border)[None, None].to(torch.float32)) + return torch.cat(masks, dim=0) + + src_row = _build_perspective_row(visual_items[0]["vis"], "src") + save_image(src_row[0], group_dir / "src_gt.png") + save_image(src_row[1], group_dir / "src_pred.png") + _save_alpha(visual_items[0]["vis"].get("src_alpha", None), group_dir / "src_alpha.png") + _save_mask(_noextrap_mask_from_vis(visual_items[0]["vis"], "src"), group_dir / 
"src_noextrap_mask.png") + _save_mask(visual_items[0]["vis"].get("src_metric_mask", None), group_dir / "src_mask.png") + for idx, item in enumerate(visual_items): + tgt_row = _build_perspective_row(item["vis"], "tgt") + save_image(tgt_row[0], group_dir / f"tgt_{idx:03d}_gt.png") + save_image(tgt_row[1], group_dir / f"tgt_{idx:03d}_pred.png") + _save_alpha(item["vis"].get("tgt_alpha", None), group_dir / f"tgt_{idx:03d}_alpha.png") + _save_mask(_noextrap_mask_from_vis(item["vis"], "tgt"), group_dir / f"tgt_{idx:03d}_noextrap_mask.png") + _save_mask(item["vis"].get("tgt_training_mask", None), group_dir / f"tgt_{idx:03d}_training_mask.png") + _save_mask(item["vis"].get("tgt_metric_mask", None), group_dir / f"tgt_{idx:03d}_mask.png") + + +def _aggregate_rows(rows: list[dict[str, float]]) -> dict[str, float]: + agg: dict[str, float] = {} + if not rows: + return agg + keys = sorted(set().union(*[set(r.keys()) for r in rows])) + for k in keys: + arr = np.array([r.get(k, np.nan) for r in rows], dtype=np.float64) + agg[k] = _safe_nanmean(arr) + agg["num_samples"] = float(len(rows)) + return agg + + +def _safe_nanmean(values: Any) -> float: + arr = np.asarray(values, dtype=np.float64) + if arr.size == 0: + return float("nan") + if not np.isfinite(arr).any(): + return float("nan") + return float(np.nanmean(arr)) + + +def _to_u8_hwc(img_chw: torch.Tensor) -> np.ndarray: + if img_chw.dtype == torch.uint8: + return img_chw.permute(1, 2, 0).detach().cpu().numpy() + x = img_chw.detach().to(torch.float32).clamp(0.0, 1.0) + return (x * 255.0).round().to(torch.uint8).permute(1, 2, 0).cpu().numpy() + + +def _concat_grid(rows: list[list[np.ndarray]], pad: int = 6, pad_value: int = 0) -> np.ndarray: + row_imgs: list[np.ndarray] = [] + for r in rows: + padded: list[np.ndarray] = [] + for i, im in enumerate(r): + padded.append(im) + if i != len(r) - 1 and pad > 0: + padded.append(np.full((im.shape[0], pad, 3), pad_value, dtype=np.uint8)) + row_imgs.append(np.concatenate(padded, axis=1)) + 
merged: list[np.ndarray] = [] + for i, im in enumerate(row_imgs): + merged.append(im) + if i != len(row_imgs) - 1 and pad > 0: + merged.append(np.full((pad, im.shape[1], 3), pad_value, dtype=np.uint8)) + return np.concatenate(merged, axis=0) + + +def _resize_panel_np(panel: np.ndarray, out_h: int, out_w: int) -> np.ndarray: + if panel.shape[0] == out_h and panel.shape[1] == out_w: + return panel + return np.asarray(Image.fromarray(panel).resize((int(out_w), int(out_h)), resample=Image.BILINEAR)) + + +def _normalize_rows_for_grid(rows: list[list[np.ndarray]]) -> list[list[np.ndarray]]: + if not rows or not rows[0]: + return rows + ref_h, ref_w = rows[0][0].shape[:2] + return [[_resize_panel_np(panel, ref_h, ref_w) for panel in row] for row in rows] + + +def _save_gif(frames: list[np.ndarray], out_file: Path, duration_ms: int = 250) -> None: + if not frames: + return + out_file.parent.mkdir(parents=True, exist_ok=True) + pil_frames = [Image.fromarray(frame) for frame in frames] + pil_frames[0].save( + out_file, + save_all=True, + append_images=pil_frames[1:], + duration=int(duration_ms), + loop=0, + disposal=2, + ) + + +def _depth_range(depth: torch.Tensor | None, fallback: tuple[float, float] = (0.0, 10.0)) -> tuple[float, float]: + if not torch.is_tensor(depth): + return fallback + valid = depth[torch.isfinite(depth) & (depth > 0.0)] + if int(valid.numel()) < 8: + return fallback + valid = valid.to(torch.float32).flatten() + if int(valid.numel()) > 262144: + step = max(1, int(valid.numel()) // 262144) + valid = valid[::step] + vmin = float(torch.quantile(valid, 0.01).item()) + vmax = float(torch.quantile(valid, 0.99).item()) + vmin = max(0.0, vmin) + vmax = max(vmin + 1e-3, vmax) + return (vmin, vmax) + + +def _depth_panel(depth: torch.Tensor | None, val_min: float, val_max: float, blank: np.ndarray) -> np.ndarray: + if not torch.is_tensor(depth): + return blank + d = depth.detach().to(torch.float32) + valid = torch.isfinite(d) & (d > 0.0) + if 
def _build_perspective_row(vis: dict[str, Any], which: str) -> list[np.ndarray]:
    """Build the six uint8 panels of one visualisation row.

    Args:
        vis: Visualisation tensors keyed by ``"{which}_gt"``,
            ``"{which}_pred"``, ``"{which}_alpha"`` and optionally
            ``"{which}_gt_depth"`` / ``"{which}_pred_depth"``.
        which: Either ``"src"`` or ``"tgt"``.

    Returns:
        ``[gt, prediction, error map, alpha, gt depth, predicted depth]``.
        Both depth panels share one colour range, taken from the ground
        truth depth when it is a tensor and from the prediction otherwise.

    Raises:
        ValueError: If ``which`` is neither ``"src"`` nor ``"tgt"``.
    """
    if which != "src" and which != "tgt":
        raise ValueError(f"which must be src/tgt, got {which}")

    target = vis[f"{which}_gt"].detach().to(torch.float32).clamp(0.0, 1.0)
    rendered = vis[f"{which}_pred"].detach().to(torch.float32)
    opacity = vis[f"{which}_alpha"].detach().to(torch.float32).clamp(0.0, 1.0)
    # Divide by alpha (floored to avoid division by zero) before conversion.
    alpha_normalized = (rendered / opacity.clamp(min=1e-4)).clamp(0.0, 1.0)
    rendered_vis = linearRGB2sRGB(alpha_normalized).clamp(0.0, 1.0)

    abs_err = (rendered_vis - target).abs().mean(dim=1, keepdim=True)
    err_samples = abs_err.flatten()
    sample_count = int(err_samples.numel())
    if sample_count > 262144:
        err_samples = err_samples[:: max(1, sample_count // 262144)]
    err_p99 = float(torch.quantile(err_samples, 0.99).item())
    # Colour scale tops out at the 99th percentile, kept within [1e-3, 0.5].
    err_cap = float(max(1e-3, min(err_p99, 0.5)))
    err_panel = _to_u8_hwc(colorize_scalar_map(abs_err[0, 0], val_min=0.0, val_max=err_cap, color_map="turbo"))
    alpha_panel = _to_u8_hwc(colorize_alpha(opacity)[0])

    gt_panel = _to_u8_hwc(target[0])
    placeholder = np.zeros_like(gt_panel)
    depth_gt = vis.get(f"{which}_gt_depth", None)
    depth_pred = vis.get(f"{which}_pred_depth", None)
    near, far = _depth_range(depth_gt if torch.is_tensor(depth_gt) else depth_pred)
    return [
        gt_panel,
        _to_u8_hwc(rendered_vis[0]),
        err_panel,
        alpha_panel,
        _depth_panel(depth_gt, near, far, placeholder),
        _depth_panel(depth_pred, near, far, placeholder),
    ]
def _cube_faces_u8(cube_img: torch.Tensor, face_count: int = 6) -> list[np.ndarray]:
    """Split a stack of cube faces into uint8 images.

    A leading axis of size one on 5-D input is dropped first. Two 4-D
    layouts are accepted: ``face_count`` on axis 0 with a size-3 axis 1,
    or ``face_count`` on axis 0 with a size-3 last axis (permuted to the
    first layout before conversion). Anything else yields an empty list.
    """
    stack = cube_img[0] if (cube_img.ndim == 5 and cube_img.shape[0] == 1) else cube_img
    if stack.ndim != 4 or stack.shape[0] != face_count:
        return []
    if stack.shape[1] == 3:
        return [_to_u8_hwc(stack[i]) for i in range(face_count)]
    if stack.shape[-1] == 3:
        return [_to_u8_hwc(stack[i].permute(2, 0, 1).contiguous()) for i in range(face_count)]
    return []


def _build_hm3d_front_gif_frame(vis: dict[str, Any], which: str) -> np.ndarray:
    """Return the GIF frame for one view, preferring cube face index 3.

    Falls back to ``_build_perspective_gif_frame`` when the cube prediction
    or alpha is missing, or when the cube does not split into six faces.
    """
    cube_rgb = vis.get(f"{which}_cube_pred_linear", None)
    cube_opacity = vis.get(f"{which}_cube_alpha", None)
    if not (torch.is_tensor(cube_rgb) and torch.is_tensor(cube_opacity)):
        return _build_perspective_gif_frame(vis, which)
    ratio = cube_rgb.detach().to(torch.float32) / cube_opacity.detach().to(torch.float32).clamp(min=1e-4)
    faces = _cube_faces_u8(linearRGB2sRGB(ratio.clamp(0.0, 1.0)).clamp(0.0, 1.0))
    if len(faces) != 6:
        return _build_perspective_gif_frame(vis, which)
    return faces[3]


def _save_group_gif(
    *,
    dataset: str,
    group_dir: Path,
    group_key: str,
    step: int,
    group_items: list[dict[str, Any]],
) -> None:
    """Save a GIF of the source frame followed by up to ten target frames.

    Items whose ``"vis"`` entry is not a dict are ignored; nothing is
    written when none remain. ``"hm3d"`` and ``"replica"`` use the cube
    front-face builder, every other dataset the perspective builder.
    """
    vis_dicts = [entry["vis"] for entry in group_items if isinstance(entry.get("vis", None), dict)]
    if not vis_dicts:
        return
    if dataset in {"hm3d", "replica"}:
        make_frame = _build_hm3d_front_gif_frame
    else:
        make_frame = _build_perspective_gif_frame
    frames = [make_frame(vis_dicts[0], "src")]
    frames += [make_frame(vis, "tgt") for vis in vis_dicts[:10]]
    _save_gif(frames, group_dir / f"step_{int(step):07d}_{group_key}.gif")
def _manifest_max_groups(args: argparse.Namespace) -> int:
    """Return the non-negative manifest group limit configured on ``args``.

    A missing attribute and an explicit ``None`` (for example an argparse
    option declared with ``default=None``) both yield 0; negative values
    are clamped to 0.
    """
    configured = getattr(args, "manifest_max_groups", None)
    if configured is None:
        return 0
    return max(0, int(configured))


def _validation_pseudo_root(args: argparse.Namespace) -> Path | None:
    """Return the validation pseudo-depth root directory.

    Uses ``args.validation_pseudo_depth_root`` when it is set, otherwise a
    hard-coded default location. As written this never returns ``None``;
    the annotation is left as-is for callers.
    """
    root = getattr(args, "validation_pseudo_depth_root", None)
    if root is not None:
        return Path(root)
    return Path("/media/team_data/ML4_team/datasets/sharp/validation_unik3d_pseudo_depth")


def _re10k_pseudo_scene_key(scene: Any) -> str:
    """Turn a scene identifier into a single path component.

    Surrounding whitespace is stripped and both kinds of path separator
    become ``"__"``. An empty result maps to ``"unknown_scene"``.
    """
    key = str(scene).strip().replace("\\", "__").replace("/", "__")
    return key or "unknown_scene"


def _re10k_training_pseudo_depth_path(args: argparse.Namespace, scene: Any, frame_idx: Any) -> Path | None:
    """Build the ``.pt`` path of a RE10K pseudo-depth frame.

    Returns ``None`` when ``args.re10k_pseudo_depth_root`` is unset. The
    split directory (``args.split``, default ``"test"``) is appended
    unless the root already ends with it. ``frame_idx`` is zero-padded to
    five digits when it converts to ``int`` and used verbatim otherwise.
    """
    root = getattr(args, "re10k_pseudo_depth_root", None)
    if root is None:
        return None
    root = Path(root)
    split = str(getattr(args, "split", "test"))
    base = root if root.name == split else root / split
    try:
        frame_key = f"{int(frame_idx):05d}"
    except Exception:
        # Deliberately broad: any index that cannot be read as an int
        # (string ids, odd tensor shapes, ...) is used verbatim.
        frame_key = str(frame_idx)
    return base / _re10k_pseudo_scene_key(scene) / f"{frame_key}.pt"
def _load_val_pseudo_depth(
    args: argparse.Namespace,
    *,
    dataset: str,
    scene: Any,
    frame_idx: Any,
    intrinsics_k3: torch.Tensor,
) -> torch.Tensor | None:
    """Load pseudo depth for one validation frame.

    For ``"re10k"`` the training pseudo-depth loader is tried first and its
    result is used when it is a tensor. Every other case goes through the
    generic validation loader rooted at ``_validation_pseudo_root(args)``.
    """
    frame_kwargs = {"scene": scene, "frame_idx": frame_idx, "intrinsics_k3": intrinsics_k3}
    if str(dataset) == "re10k":
        re10k_depth = _load_re10k_training_pseudo_depth(args, **frame_kwargs)
        if torch.is_tensor(re10k_depth):
            return re10k_depth
    return _load_validation_pseudo_depth(_validation_pseudo_root(args), dataset=dataset, **frame_kwargs)


def _load_val_pseudo_depth_b1hw(
    args: argparse.Namespace,
    *,
    dataset: str,
    scene: Any,
    frame_idx: Any,
    intrinsics_k3: torch.Tensor,
) -> torch.Tensor | None:
    """Same as ``_load_val_pseudo_depth`` with a new leading axis, or ``None``."""
    loaded = _load_val_pseudo_depth(
        args,
        dataset=dataset,
        scene=scene,
        frame_idx=frame_idx,
        intrinsics_k3=intrinsics_k3,
    )
    if not torch.is_tensor(loaded):
        return None
    return loaded.unsqueeze(0)


def _load_val_pseudo_distance_b1hw(
    args: argparse.Namespace,
    *,
    dataset: str,
    scene: Any,
    frame_idx: Any,
) -> torch.Tensor | None:
    """Load the validation pseudo distance map with a new leading axis, or ``None``."""
    loaded = _load_validation_pseudo_distance(
        _validation_pseudo_root(args),
        dataset=dataset,
        scene=scene,
        frame_idx=frame_idx,
    )
    if not torch.is_tensor(loaded):
        return None
    return loaded.unsqueeze(0)
def _iter_wildrgbd_manifest_items(args: argparse.Namespace) -> Iterator[ValidationItem]:
    """Yield validation batches for every WildRGBD group listed in the manifest.

    Each manifest entry has three parts: scene name (``"<root>/<scene>"``),
    source frame id and a comma-separated list of target frame ids. Entries
    whose scene is unknown, or whose source frame lacks a pose, RGB image or
    depth image, are skipped.
    """
    # Index every scene directory under each validation root by "<root>/<scene>".
    root_map: dict[str, Path] = {}
    for root in _wild_validation_roots(Path(args.data_root)):
        scene_parent = root / "scenes"
        for scene_dir in sorted([p for p in scene_parent.iterdir() if p.is_dir()]) if scene_parent.exists() else []:
            root_map[f"{root.name}/{scene_dir.name}"] = scene_dir
    for group_idx, parts in _iter_manifest_parts(args, expected_parts=3):
        scene_name = str(parts[0])
        src_idx = int(parts[1])
        tgt_indices = [int(x) for x in parts[2].split(",") if x.strip()]
        scene_dir = root_map.get(scene_name)
        if scene_dir is None:
            continue
        pose_ids_np, w2c_map, intr = WildRGBDDataset._load_scene_pose_and_k(scene_dir)
        pose_ids = {int(x) for x in pose_ids_np.tolist()}
        rgb_ids = WildRGBDDataset._collect_frame_ids(scene_dir / "rgb")
        dep_ids = WildRGBDDataset._collect_frame_ids(scene_dir / "depth")
        # A frame is usable only when pose, RGB and depth are all present.
        valid_ids = pose_ids & rgb_ids & dep_ids
        if int(src_idx) not in valid_ids:
            continue
        # The dataset instance is needed for its depth loader; its root is the
        # directory two levels above the scene directory.
        ds_loader = WildRGBDDataset(root=scene_dir.parent.parent, split="scenes", scene_list_file=None)
        group_key = f"{scene_name}_g{group_idx:05d}"
        adapter = _PinholeGroupAdapter(
            scene=scene_name,
            group_key=group_key,
            src_idx=src_idx,
            src_img=WildRGBDDataset._load_rgb_u8(WildRGBDDataset._resolve_img_path(scene_dir / "rgb", src_idx)).unsqueeze(0),
            src_depth=ds_loader._load_depth_m(WildRGBDDataset._resolve_img_path(scene_dir / "depth", src_idx)).unsqueeze(0),
            src_w2c=torch.from_numpy(w2c_map[src_idx]).to(torch.float32).unsqueeze(0),
            src_k=intr.to(torch.float32).unsqueeze(0).clone(),
            tgt_indices=tgt_indices,
            # Per-group state is bound through default arguments so the loader
            # does not depend on the loop variables' later values. Targets
            # outside ``valid_ids`` load as ``None``.
            load_target=lambda tgt_idx, scene_dir=scene_dir, valid_ids=valid_ids, ds_loader=ds_loader, intr=intr, w2c_map=w2c_map: None
            if int(tgt_idx) not in valid_ids
            else _PinholeTargetAdapter(
                idx=int(tgt_idx),
                img=WildRGBDDataset._load_rgb_u8(
                    WildRGBDDataset._resolve_img_path(scene_dir / "rgb", int(tgt_idx))
                ).unsqueeze(0),
                depth=ds_loader._load_depth_m(
                    WildRGBDDataset._resolve_img_path(scene_dir / "depth", int(tgt_idx))
                ).unsqueeze(0),
                w2c=torch.from_numpy(w2c_map[int(tgt_idx)]).to(torch.float32).unsqueeze(0),
                k=intr.to(torch.float32).unsqueeze(0).clone(),
            ),
        )
        yield from _yield_pinhole_group_batches("wildrgbd", adapter, args)
def _iter_replica_manifest_items(args: argparse.Namespace) -> Iterator[ValidationItem]:
    """Yield validation batches for every Replica group listed in the manifest.

    Manifest lines have the form ``scene|src_idx|tgt_idx,tgt_idx,...``.
    Lines with a different number of ``|``-separated parts, and scenes
    missing any of the four data directories, are skipped.
    """
    root = _resolve_replica_test_root(Path(args.data_root))
    manifest_in = _read_manifest_lines(getattr(args, "manifest_file", None), max_lines=_manifest_max_groups(args))
    for group_idx, raw in enumerate(manifest_in):
        parts = raw.split("|")
        if len(parts) != 3:
            continue
        scene_name = str(parts[0])
        src_idx = int(parts[1])
        tgt_indices = [int(x) for x in parts[2].split(",") if x.strip()]
        scene_dir = root / scene_name
        pano_dir = scene_dir / "pano"
        depth_dir = scene_dir / "pano_depth"
        cube_dir = scene_dir / "cubemaps"
        cube_depth_dir = scene_dir / "cubemaps_depth"
        if not (pano_dir.exists() and depth_dir.exists() and cube_dir.exists() and cube_depth_dir.exists()):
            continue
        # Poses are read with the HM3D pose loader.
        R_np, t_np = _load_hm3d_pose(scene_dir)
        group_key = f"replica_{scene_dir.name}_g{group_idx:05d}"
        # Source data is loaded once and shared by every target of the group.
        src_rgb = _load_png_rgb_u8(pano_dir / f"{src_idx:05d}.png")
        src_dep = _load_png_depth_m(depth_dir / f"{src_idx:05d}.png")
        src_cube = _torch_load_any(cube_dir / f"{src_idx:05d}.torch")
        src_cdep = _torch_load_any(cube_depth_dir / f"{src_idx:05d}.torch")
        src_R = torch.from_numpy(R_np[src_idx])
        src_t = torch.from_numpy(t_np[src_idx])
        batch_size = max(1, int(getattr(args, "validation_batch_size", 1)))
        samples: list[Any] = []
        tags: list[str] = []
        for tgt_idx in tgt_indices:
            tgt_rgb = _load_png_rgb_u8(pano_dir / f"{tgt_idx:05d}.png")
            tgt_dep = _load_png_depth_m(depth_dir / f"{tgt_idx:05d}.png")
            tgt_cube = _torch_load_any(cube_dir / f"{tgt_idx:05d}.torch")
            tgt_cdep = _torch_load_any(cube_depth_dir / f"{tgt_idx:05d}.torch")
            sample = SimpleNamespace(
                src_erp_rgb_u8=src_rgb,
                tgt_erp_rgb_u8=tgt_rgb,
                src_erp_depth_m=src_dep,
                tgt_erp_depth_m=tgt_dep,
                src_cube_rgb_u8=src_cube,
                tgt_cube_rgb_u8=tgt_cube,
                src_cube_depth_m=src_cdep,
                tgt_cube_depth_m=tgt_cdep,
                src_R=src_R,
                src_t=src_t,
                tgt_R=torch.from_numpy(R_np[tgt_idx]),
                tgt_t=torch.from_numpy(t_np[tgt_idx]),
                src_idx=src_idx,
                tgt_idx=tgt_idx,
                scene=scene_dir.name,
            )
            samples.append(sample)
            tags.append(f"{group_key}_t{tgt_idx:05d}")
            # Emit as soon as a full batch has accumulated, then start over.
            if len(samples) >= batch_size:
                yield from _yield_panogs_group_batches("replica", group_key, samples, tags, args)
                samples = []
                tags = []
        # Flush the remainder. NOTE(review): ``samples`` may be empty here;
        # presumably the helper yields nothing in that case - confirm.
        yield from _yield_panogs_group_batches("replica", group_key, samples, tags, args)
def _iter_scannetpp_manifest_items(args: argparse.Namespace) -> Iterator[ValidationItem]:
    """Yield validation batches for every ScanNet++ group listed in the manifest.

    Each manifest entry has four parts: the chunk file, the sample key, the
    source frame index and a comma-separated list of target indices. An
    entry is skipped when the chunk holds no usable sample, the sample key
    does not match, fewer than 11 cameras are present, or the source index
    is out of range. Out-of-range targets load as ``None``, matching the
    other pinhole iterators in this file.
    """
    for group_idx, parts in _iter_manifest_parts(args, expected_parts=4):
        tf = Path(parts[0])
        sample_key = str(parts[1])
        src_idx = int(parts[2])
        tgt_indices = [int(x) for x in parts[3].split(",") if x.strip()]
        payload = _torch_load_any(tf)
        if isinstance(payload, list):
            # An empty chunk used to raise IndexError; skip it instead.
            sample_raw = payload[0] if payload else None
        else:
            sample_raw = payload
        if not isinstance(sample_raw, dict):
            continue
        if str(sample_raw.get("key", tf.stem)) != sample_key:
            continue
        cameras = sample_raw["cameras"].to(torch.float32)
        images = sample_raw["images"]
        if not isinstance(images, list) or int(cameras.shape[0]) < 11:
            continue
        # Only indices valid for both the image list and the camera table
        # can be served.
        num_frames = min(len(images), int(cameras.shape[0]))
        if not (0 <= src_idx < num_frames):
            continue

        w2c_all = []
        intr_all = []
        for cam in cameras:
            # Layout per camera row: four intrinsics values, two size values,
            # then a flattened 3x4 world-to-camera matrix.
            fx_n, fy_n, cx_n, cy_n, w0, h0 = cam[:6]
            w2c = torch.eye(4, dtype=torch.float32)
            w2c[:3, :] = cam[6:].reshape(3, 4)
            k = torch.eye(3, dtype=torch.float32)
            k[0, 0] = fx_n * w0
            k[1, 1] = fy_n * h0
            k[0, 2] = cx_n * w0
            k[1, 2] = cy_n * h0
            w2c_all.append(w2c)
            intr_all.append(k)
        w2c_t = torch.stack(w2c_all, dim=0)
        intr_t = torch.stack(intr_all, dim=0)
        group_key = f"{sample_key}_g{group_idx:05d}"

        def _load_target(
            tgt_idx: Any,
            images: list = images,
            intr_t: torch.Tensor = intr_t,
            w2c_t: torch.Tensor = w2c_t,
            sample_key: str = sample_key,
            num_frames: int = num_frames,
        ) -> Any:
            # Per-group state is bound through defaults so the loader never
            # observes values from a later loop iteration.
            tgt = int(tgt_idx)
            if not (0 <= tgt < num_frames):
                return None
            return _PinholeTargetAdapter(
                idx=tgt,
                img=_decode_rgb_u8(images[tgt]).unsqueeze(0),
                w2c=w2c_t[tgt].unsqueeze(0),
                k=intr_t[tgt].unsqueeze(0),
                depth=_load_val_pseudo_depth_b1hw(
                    args,
                    dataset="scannetpp",
                    scene=sample_key,
                    frame_idx=tgt,
                    intrinsics_k3=intr_t[tgt],
                ),
            )

        adapter = _PinholeGroupAdapter(
            scene=sample_key,
            group_key=group_key,
            src_idx=int(src_idx),
            src_img=_decode_rgb_u8(images[src_idx]).unsqueeze(0),
            src_depth=_load_val_pseudo_depth_b1hw(
                args, dataset="scannetpp", scene=sample_key, frame_idx=src_idx, intrinsics_k3=intr_t[src_idx]
            ),
            src_w2c=w2c_t[src_idx].unsqueeze(0),
            src_k=intr_t[src_idx].unsqueeze(0),
            tgt_indices=tgt_indices,
            load_target=_load_target,
        )
        yield from _yield_pinhole_group_batches("scannetpp", adapter, args)
ScannetppFisheyeDataset( + root=root, + scene_list_file=None, + min_frame_gap=1, + max_frame_gap=10, + pair_max_translation_m=float(args.pair_max_translation_m), + shuffle_scene=False, + shuffle_frame=False, + skip_bad=True, + batch_size_hint=1, + depth_max_m=float(getattr(args, "max_depth_m", DEFAULT_MAX_DEPTH_M)), + far_depth_invalid_m=float(getattr(args, "scanetpp_fisheye_far_depth_invalid_m", 30.0)), + seed=int(args.seed), + ) + batch_size = max(1, int(getattr(args, "validation_batch_size", 1))) + for group_idx, parts in _iter_manifest_parts(args, expected_parts=4): + scene_id = str(parts[0]) + scene_dir = Path(parts[1]) + if not scene_dir.is_absolute(): + scene_dir = root / scene_dir + src_pos = int(parts[2]) + tgt_positions = [int(x) for x in parts[3].split(",") if x.strip()] + try: + camera_params, frames = loader._load_scene_frames(scene_id, scene_dir) + except Exception as exc: + LOGGER.warning("Skip ScanNet++ fisheye scene=%s: %s", scene_id, str(exc)) + continue + if not (0 <= src_pos < len(frames)): + continue + try: + src_loaded = loader._load_frame_tensor(frames[src_pos], camera_params) + except Exception as exc: + LOGGER.warning("Skip ScanNet++ fisheye src scene=%s src=%d: %s", scene_id, int(src_pos), str(exc)) + continue + pending_pos: list[int] = [] + pending_frames: list[dict[str, Any]] = [] + pending_loaded: list[dict[str, torch.Tensor]] = [] + + def _flush() -> Iterator[tuple[str, Any, str | list[str], str]]: + if not pending_pos: + return + group_key = f"scanetpp_fisheye_{scene_id}_g{group_idx:05d}" + batch = _make_scanetpp_fisheye_batch( + scene=scene_id, + src_pos=src_pos, + tgt_positions=list(pending_pos), + src_frame=frames[src_pos], + tgt_frames=list(pending_frames), + src_loaded=src_loaded, + tgt_loaded=list(pending_loaded), + ) + tags = [f"{group_key}_t{int(t):05d}" for t in pending_pos] + yield ("scanetpp_fisheye", batch, tags[0] if len(tags) == 1 else tags, group_key) + + for tgt_pos in tgt_positions: + if not (0 <= int(tgt_pos) < 
def _load_smx_sim_fisheye_scene(scene_dir: Path) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Read ``transforms.json`` of a scene and list its usable frames.

    A frame is kept only when it has a ``transform_matrix`` and an image
    file, looked up first at ``scene_dir / file_path`` and then at
    ``source_image``. Directories do not count as images, so a frame with
    an empty or missing path is dropped rather than resolving to
    ``scene_dir`` or the current directory.

    Args:
        scene_dir: Directory containing ``transforms.json``.

    Returns:
        ``(meta, frames)``: the parsed JSON, and one dict per kept frame
        with ``image_name``, ``image_path``, ``source_image``, ``w2c`` (the
        inverse of ``transform_matrix``), ``idx`` (``source_image_index``,
        defaulting to the position in the JSON list), ``pos`` (position in
        the JSON list) and ``yaw_pitch_roll_deg``.

    Raises:
        OSError, json.JSONDecodeError: If the metadata file cannot be read
            or parsed.
    """
    meta = json.loads((scene_dir / "transforms.json").read_text(encoding="utf-8"))
    frames: list[dict[str, Any]] = []
    # ``or []`` also tolerates an explicit null "frames" entry.
    for local_idx, frame in enumerate(meta.get("frames") or []):
        source_image = Path(str(frame.get("source_image", "")))
        image_path = scene_dir / Path(str(frame.get("file_path", "")))
        if not image_path.is_file():
            image_path = source_image
        if not image_path.is_file() or frame.get("transform_matrix") is None:
            continue
        c2w = torch.tensor(frame["transform_matrix"], dtype=torch.float32)
        frames.append(
            {
                "image_name": image_path.name,
                "image_path": image_path,
                "source_image": source_image,
                "w2c": torch.linalg.inv(c2w),
                "idx": int(frame.get("source_image_index", local_idx)),
                "pos": int(local_idx),
                "yaw_pitch_roll_deg": list(frame.get("yaw_pitch_roll_deg", [0.0, 0.0, 0.0])),
            }
        )
    return meta, frames
def _smx_sim_rotation_yaw_pitch_roll(yaw_deg: float, pitch_deg: float, roll_deg: float) -> torch.Tensor:
    """Return the 3x3 float32 rotation ``R_yaw @ R_pitch @ R_roll``.

    Angles are given in degrees. In the matrices below yaw rotates about
    the second axis, pitch about the first and roll about the third.
    """
    yaw_rad, pitch_rad, roll_rad = (math.radians(float(angle)) for angle in (yaw_deg, pitch_deg, roll_deg))
    cos_y, sin_y = math.cos(yaw_rad), math.sin(yaw_rad)
    cos_p, sin_p = math.cos(pitch_rad), math.sin(pitch_rad)
    cos_r, sin_r = math.cos(roll_rad), math.sin(roll_rad)

    yaw_m = torch.tensor(
        [[cos_y, 0.0, sin_y], [0.0, 1.0, 0.0], [-sin_y, 0.0, cos_y]],
        dtype=torch.float32,
    )
    pitch_m = torch.tensor(
        [[1.0, 0.0, 0.0], [0.0, cos_p, sin_p], [0.0, -sin_p, cos_p]],
        dtype=torch.float32,
    )
    roll_m = torch.tensor(
        [[cos_r, -sin_r, 0.0], [sin_r, cos_r, 0.0], [0.0, 0.0, 1.0]],
        dtype=torch.float32,
    )
    combined = yaw_m @ pitch_m @ roll_m
    return combined.to(torch.float32)
def _smx_sim_project_erp_tensor_to_fisheye(
    tensor: torch.Tensor,
    grid: torch.Tensor,
    valid_mask: torch.Tensor,
    *,
    mode: str = "bilinear",
) -> torch.Tensor:
    """Resample ``tensor`` through ``grid`` and zero it outside ``valid_mask``.

    Args:
        tensor: 3-D or 4-D input; a 3-D input gets a leading batch axis.
        grid: Sampling grid without a batch axis; it is broadcast over the
            batch and passed to ``F.grid_sample`` (zero padding,
            ``align_corners=False``).
        valid_mask: 3-D or 4-D mask multiplied into the sampled result.
        mode: Interpolation mode forwarded to ``F.grid_sample``.

    Returns:
        The masked, resampled float32 tensor with a batch axis.

    Raises:
        ValueError: If ``tensor`` or ``valid_mask`` is neither 3-D nor 4-D.
    """
    source = tensor.detach().to(torch.float32)
    if source.ndim == 3:
        source = source.unsqueeze(0)
    if source.ndim != 4:
        raise ValueError(f"Expected 3D/4D tensor for SMX projection, got shape={tuple(source.shape)}")

    batch = int(source.shape[0])
    target_device = source.device
    batched_grid = grid.to(device=target_device, dtype=torch.float32).unsqueeze(0).expand(batch, -1, -1, -1)

    mask = valid_mask.to(device=target_device, dtype=source.dtype)
    if mask.ndim == 3:
        mask = mask.unsqueeze(0)
    if mask.ndim != 4:
        raise ValueError(f"Expected 3D/4D SMX valid mask, got shape={tuple(mask.shape)}")
    mask = mask.expand(batch, -1, -1, -1)

    sampled = F.grid_sample(source, batched_grid, mode=mode, padding_mode="zeros", align_corners=False)
    return sampled * mask
getattr(batch, name, None) + if not torch.is_tensor(value) or idx >= int(value.shape[0]): + return None + return value[idx : idx + 1].to(device=device) + + device = vis["tgt_pred"].device if torch.is_tensor(vis.get("tgt_pred", None)) else torch.device("cpu") + src_gt = _batch_value("smx_fisheye_src_rgb_u8", i, device) + tgt_gt = _batch_value("smx_fisheye_tgt_rgb_u8", i, device) + if src_gt is not None: + vis["src_gt"] = (src_gt.to(torch.float32) / 255.0).clamp(0.0, 1.0) + if tgt_gt is not None: + vis["tgt_gt"] = (tgt_gt.to(torch.float32) / 255.0).clamp(0.0, 1.0) + + for prefix, grid, mask in (("src", src_grid, src_mask), ("tgt", tgt_grid, tgt_mask)): + pred_key = f"{prefix}_pred" + alpha_key = f"{prefix}_alpha" + if torch.is_tensor(vis.get(pred_key, None)): + vis[pred_key] = _smx_sim_project_erp_tensor_to_fisheye(vis[pred_key], grid, mask, mode="bilinear") + if torch.is_tensor(vis.get(alpha_key, None)): + vis[alpha_key] = _smx_sim_project_erp_tensor_to_fisheye(vis[alpha_key], grid, mask, mode="bilinear") + for depth_key in (f"{prefix}_gt_depth", f"{prefix}_pred_depth", f"{prefix}_unik3d_depth"): + if torch.is_tensor(vis.get(depth_key, None)): + vis[depth_key] = _smx_sim_project_erp_tensor_to_fisheye(vis[depth_key], grid, mask, mode="bilinear") + + vis["src_metric_mask"] = src_mask.to(device=device, dtype=torch.float32) + vis["tgt_metric_mask"] = tgt_mask.to(device=device, dtype=torch.float32) + vis["dataset_name"] = "smx_sim_fisheye" + vis["projection_pipeline"] = "source_fisheye_to_source_pano_infer_target_pano_to_target_fisheye" + return vis + + +def _yield_smx_sim_fisheye_pano_batches( + *, + group_key: str, + samples: list[Any], + tags: list[str], + args: argparse.Namespace, +) -> Iterator[ValidationItem]: + batch_size = max(1, int(getattr(args, "validation_batch_size", 1))) + if len(samples) != len(tags): + raise ValueError(f"Expected samples/tags length match, got {len(samples)} vs {len(tags)}") + for start in range(0, len(samples), batch_size): + end = 
def _smx_frame_position_from_w2c(frame: dict[str, Any], meta: dict[str, Any]) -> torch.Tensor:
    """Return the camera position of ``frame`` derived from its world-to-camera matrix.

    The translation column of the inverted ``frame["w2c"]`` is divided by
    ``meta["position_scale"]`` (skipped when the scale is missing, ``None`` or
    effectively zero), re-ordered to ``(y, -z, x)`` and multiplied by ``0.01``.
    NOTE(review): the axis swap and the 0.01 factor look like a conversion
    from the simulator's frame/centimetres to the model's frame/metres - confirm.

    Args:
        frame: Frame record with a ``"w2c"`` entry (tensor or nested sequence).
        meta: Scene metadata; only ``"position_scale"`` is read.

    Returns:
        A float32 tensor with three elements. The input matrix is not modified.
    """
    w2c = frame["w2c"]
    if torch.is_tensor(w2c):
        w2c_t = w2c.detach().clone().to(torch.float32)
    else:
        w2c_t = torch.as_tensor(w2c, dtype=torch.float32)
    raw_xyz = torch.linalg.inv(w2c_t)[:3, 3].clone()

    # An explicit ``None`` in the metadata means "no scale", same as a missing key.
    scale_value = meta.get("position_scale", 1.0)
    raw_scale = 1.0 if scale_value is None else float(scale_value)
    if abs(raw_scale) > 1e-8:
        raw_xyz = raw_xyz / raw_scale
    return torch.stack([raw_xyz[1], -raw_xyz[2], raw_xyz[0]], dim=0).to(torch.float32) * 0.01


def _iter_smx_sim_fisheye_manifest_items(args: argparse.Namespace) -> Iterator[ValidationItem]:
    """Yield SMX SIM fisheye validation batches for every manifest group.

    Each manifest row provides ``(scene_id, scene_dir, src_pos, tgt_positions)``;
    a relative ``scene_dir`` is resolved against ``args.data_root``. For every
    usable target frame a sample is built from the source/target panoramas plus
    the fisheye image, validity mask and sampling grid of both frames. Samples
    are handed to ``_yield_smx_sim_fisheye_pano_batches`` in chunks of
    ``args.validation_batch_size``.

    Scenes or source frames that fail to load are logged and skipped, and so
    are individual target frames; out-of-range positions are skipped silently.
    """
    root = Path(args.data_root)
    batch_size = max(1, int(getattr(args, "validation_batch_size", 1)))
    for group_idx, parts in _iter_manifest_parts(args, expected_parts=4):
        scene_id = str(parts[0])
        scene_dir = Path(parts[1])
        if not scene_dir.is_absolute():
            scene_dir = root / scene_dir
        src_pos = int(parts[2])
        tgt_positions = [int(x) for x in parts[3].split(",") if x.strip()]
        try:
            meta, frames = _load_smx_sim_fisheye_scene(scene_dir)
        except Exception as exc:
            LOGGER.warning("Skip SMX SIM fisheye scene=%s: %s", scene_id, str(exc))
            continue
        if not (0 <= src_pos < len(frames)):
            continue

        def _load_frame(frame: dict[str, Any]) -> dict[str, torch.Tensor]:
            # Load the panorama, its cube-map version and the fisheye view of one frame.
            source_image = Path(str(frame.get("source_image", "")))
            if not source_image.exists():
                raise FileNotFoundError(source_image)
            erp_rgb = _load_png_rgb_u8(source_image)
            erp_h = int(erp_rgb.shape[-2])
            erp_w = int(erp_rgb.shape[-1])
            converter = _EquirecToCube(equ_h=erp_h, equ_w=erp_w, face_w=max(1, erp_h // 2))
            cube_rgb = converter.run_rgb(erp_rgb)
            # No depth is available for this dataset: all-zero placeholders are used.
            erp_depth = torch.zeros((1, erp_h, erp_w), dtype=torch.float32)
            face_w = int(converter.face_w)
            cube_depth = torch.zeros((6, face_w, face_w, 1), dtype=torch.float32)
            fish_rgb = _load_png_rgb_u8(Path(frame["image_path"]))
            grid, valid = _smx_sim_fisheye_grid(
                meta=meta,
                frame=frame,
                erp_h=erp_h,
                erp_w=erp_w,
                fish_h=int(fish_rgb.shape[-2]),
                fish_w=int(fish_rgb.shape[-1]),
            )
            valid = (valid * _smx_sim_fisheye_valid_mask(fish_rgb, meta)).clamp(0.0, 1.0)
            return {
                "erp_rgb_u8": erp_rgb,
                "erp_depth_m": erp_depth,
                "cube_rgb_u8": cube_rgb,
                "cube_depth_m": cube_depth,
                "fish_rgb_u8": fish_rgb,
                "fish_valid_mask": valid,
                "fish_grid": grid,
            }

        src_frame = frames[src_pos]
        try:
            src_loaded = _load_frame(src_frame)
        except Exception as exc:
            LOGGER.warning("Skip SMX SIM fisheye src scene=%s src=%d: %s", scene_id, int(src_pos), str(exc))
            continue

        group_key = f"smx_sim_fisheye_{scene_id}_g{group_idx:05d}"
        pending_samples: list[Any] = []
        pending_tags: list[str] = []

        for tgt_pos in tgt_positions:
            if not (0 <= tgt_pos < len(frames)):
                continue
            tgt_frame = frames[tgt_pos]
            try:
                tgt_loaded = _load_frame(tgt_frame)
            except Exception as exc:
                # Best effort: one unreadable target must not drop the whole group,
                # but it is reported just like a failing source frame.
                LOGGER.warning(
                    "Skip SMX SIM fisheye tgt scene=%s tgt=%d: %s", scene_id, int(tgt_pos), str(exc)
                )
                continue
            sample = SimpleNamespace(
                src_erp_rgb_u8=src_loaded["erp_rgb_u8"],
                tgt_erp_rgb_u8=tgt_loaded["erp_rgb_u8"],
                src_erp_depth_m=src_loaded["erp_depth_m"],
                tgt_erp_depth_m=tgt_loaded["erp_depth_m"],
                src_cube_rgb_u8=src_loaded["cube_rgb_u8"],
                tgt_cube_rgb_u8=tgt_loaded["cube_rgb_u8"],
                src_cube_depth_m=src_loaded["cube_depth_m"],
                tgt_cube_depth_m=tgt_loaded["cube_depth_m"],
                # Rotations are fixed to identity; only the translations differ.
                src_R=torch.eye(3, dtype=torch.float32),
                src_t=_smx_frame_position_from_w2c(src_frame, meta),
                tgt_R=torch.eye(3, dtype=torch.float32),
                tgt_t=_smx_frame_position_from_w2c(tgt_frame, meta),
                src_idx=int(src_frame.get("idx", src_pos)),
                tgt_idx=int(tgt_frame.get("idx", tgt_pos)),
                scene=scene_id,
                smx_fisheye_src_rgb_u8=src_loaded["fish_rgb_u8"],
                smx_fisheye_tgt_rgb_u8=tgt_loaded["fish_rgb_u8"],
                smx_fisheye_src_valid_mask=src_loaded["fish_valid_mask"],
                smx_fisheye_tgt_valid_mask=tgt_loaded["fish_valid_mask"],
                smx_fisheye_src_grid=src_loaded["fish_grid"],
                smx_fisheye_tgt_grid=tgt_loaded["fish_grid"],
            )
            pending_samples.append(sample)
            pending_tags.append(f"{group_key}_t{int(tgt_pos):05d}")
            if len(pending_samples) >= batch_size:
                yield from _yield_smx_sim_fisheye_pano_batches(
                    group_key=group_key,
                    samples=pending_samples,
                    tags=pending_tags,
                    args=args,
                )
                # Start fresh lists so batches already handed out are never mutated.
                pending_samples = []
                pending_tags = []
        if pending_samples:
            yield from _yield_smx_sim_fisheye_pano_batches(
                group_key=group_key,
                samples=pending_samples,
                tags=pending_tags,
                args=args,
            )
colmap_entries] + image_map = {p.name: p for p in image_paths} + src_name = str(parts[1]) + tgt_names = [x for x in parts[2].split(",") if x.strip()] + if src_name not in image_map: + continue + group_key = f"tat_{scene_name}_g{group_idx:05d}" + src_img = _load_png_rgb_u8(image_map[src_name]).unsqueeze(0) + src_meta = colmap_entries[src_name] + src_k = src_meta["k"].unsqueeze(0).clone() + src_w2c = src_meta["w2c"].unsqueeze(0).clone() + ref_h = int(src_meta["height"]) + ref_w = int(src_meta["width"]) + if (int(src_img.shape[-2]) != ref_h) or (int(src_img.shape[-1]) != ref_w): + sx0 = float(int(src_img.shape[-1])) / float(ref_w) + sy0 = float(int(src_img.shape[-2])) / float(ref_h) + src_k = _resize_k3_align_corners_false(src_k, sx=sx0, sy=sy0) + name_to_idx = {p.name: i for i, p in enumerate(image_paths)} + src_idx = int(name_to_idx[src_name]) + + def _load_tat_target(tgt_idx: int) -> _PinholeTargetAdapter | None: + tgt_name = image_paths[int(tgt_idx)].name + if tgt_name not in image_map or tgt_name not in colmap_entries: + return None + tgt_img = _load_png_rgb_u8(image_map[tgt_name]).unsqueeze(0) + tgt_meta = colmap_entries[tgt_name] + tgt_k = tgt_meta["k"].unsqueeze(0).clone() + if (int(tgt_img.shape[-2]) != int(tgt_meta["height"])) or (int(tgt_img.shape[-1]) != int(tgt_meta["width"])): + sx0 = float(int(tgt_img.shape[-1])) / float(int(tgt_meta["width"])) + sy0 = float(int(tgt_img.shape[-2])) / float(int(tgt_meta["height"])) + tgt_k = _resize_k3_align_corners_false(tgt_k, sx=sx0, sy=sy0) + return _PinholeTargetAdapter( + idx=int(tgt_idx), + img=tgt_img, + w2c=tgt_meta["w2c"].unsqueeze(0).clone(), + k=tgt_k, + depth=_load_val_pseudo_depth_b1hw( + args, dataset="tat", scene=scene_name, frame_idx=int(tgt_idx), intrinsics_k3=tgt_k[0] + ), + ) + + adapter = _PinholeGroupAdapter( + scene=scene_name, + group_key=group_key, + src_idx=src_idx, + src_img=src_img, + src_depth=_load_val_pseudo_depth_b1hw( + args, dataset="tat", scene=scene_name, frame_idx=src_idx, 
def _dl3dv_frame_id_from_name(name: str) -> int:
    """Return the integer after the last underscore of the file stem.

    Example: ``"frame_00012.png"`` gives ``12``.

    Raises:
        ValueError: If that last stem component is not an integer.
    """
    return int(Path(name).stem.rsplit("_", 1)[-1])


def _load_dl3dv_scene(scene_dir: Path) -> tuple[dict[int, Path], dict[int, torch.Tensor], dict[int, torch.Tensor]] | None:
    """Load image paths, world-to-camera matrices and intrinsics of a DL3DV scene.

    Reads ``transforms.json`` and the PNG files in ``images_4`` below
    ``scene_dir``. PNG files whose name does not end in a frame number are
    ignored. The intrinsics from the JSON (``fl_x``, ``fl_y``, ``cx``, ``cy``
    for a ``w`` x ``h`` image) are rescaled when the first image on disk has a
    different size.

    Returns:
        ``(image_paths, w2c_map, intr_map)``, all keyed by frame id, or ``None``
        when ``transforms.json`` or ``images_4`` is missing or no usable PNG
        exists. Poses are only stored for frames that also have an image.
    """
    transforms_path = scene_dir / "transforms.json"
    image_dir = scene_dir / "images_4"
    if not (transforms_path.exists() and image_dir.exists()):
        return None
    meta = json.loads(transforms_path.read_text(encoding="utf-8"))

    # Sorted so that the image used to probe the resolution is deterministic.
    image_paths: dict[int, Path] = {}
    for png_path in sorted(image_dir.glob("*.png")):
        try:
            image_paths[_dl3dv_frame_id_from_name(png_path.name)] = png_path
        except ValueError:
            # A stray PNG without a frame number must not break the whole scene.
            continue
    if not image_paths:
        return None

    orig_w = int(meta["w"])
    orig_h = int(meta["h"])
    k = torch.eye(3, dtype=torch.float32)
    k[0, 0] = float(meta["fl_x"])
    k[1, 1] = float(meta["fl_y"])
    k[0, 2] = float(meta["cx"])
    k[1, 2] = float(meta["cy"])

    example_path = next(iter(image_paths.values()))
    with Image.open(example_path) as img:
        cur_w, cur_h = int(img.size[0]), int(img.size[1])
    k_cur = k.clone()
    if cur_h != orig_h or cur_w != orig_w:
        k_cur = _resize_k3_align_corners_false(
            k_cur.unsqueeze(0),
            sx=float(cur_w) / float(orig_w),
            sy=float(cur_h) / float(orig_h),
        )[0]

    w2c_map: dict[int, torch.Tensor] = {}
    intr_map: dict[int, torch.Tensor] = {}
    for frame in meta.get("frames", []):
        rel_path = str(frame.get("file_path", ""))
        try:
            frame_id = _dl3dv_frame_id_from_name(Path(rel_path).name)
        except ValueError:
            continue
        if frame_id not in image_paths:
            continue
        c2w = _nerf_c2w_to_opencv_c2w(frame["transform_matrix"])
        w2c_map[frame_id] = torch.linalg.inv(c2w)
        intr_map[frame_id] = k_cur.clone()
    return image_paths, w2c_map, intr_map


def _resolve_dl3dv_scene_dir(args: argparse.Namespace, scene_name: str, scene_dir_raw: str) -> Path | None:
    """Locate the directory of a DL3DV scene.

    Candidates are tried in this order and the first existing one wins:

    1. ``scene_dir_raw`` as given, unless it is empty or blank;
    2. ``<data_root>/<a>/<b>/<b>`` for a ``scene_name`` of the form ``"a/b"``;
    3. ``<data_root>/<a>/<b>``.

    Returns:
        The resolved path, or ``None`` when no candidate exists.
    """
    # ``Path("")`` is the current directory and always exists, so an empty
    # manifest field must not be accepted as a scene directory.
    if str(scene_dir_raw).strip():
        scene_dir = Path(scene_dir_raw)
        if scene_dir.exists():
            return scene_dir

    root = Path(args.data_root)
    parts = scene_name.split("/", 1)
    if len(parts) == 2:
        base = root / parts[0] / parts[1]
        for candidate in (base / parts[1], base):
            if candidate.exists():
                return candidate
    return None
def _iter_dataset_items(args: argparse.Namespace) -> Iterable[ValidationItem]:
    """Return the validation item iterator for ``args.dataset``.

    A manifest file is mandatory. Dataset names without a dedicated iterator
    are served by the HM3D manifest iterator.

    Raises:
        ValueError: If ``args.manifest_file`` is missing or ``None``.
    """
    if getattr(args, "manifest_file", None) is None:
        raise ValueError("Validation requires --manifest-file. Build manifests first with scripts/build_validation_manifests.py.")

    # One manifest iterator per dataset name.
    iterators = {
        "re10k": _iter_re10k_manifest_items,
        "dl3dv": _iter_dl3dv_manifest_items,
        "replica": _iter_replica_manifest_items,
        "sim": _iter_sim_manifest_items,
        "wildrgbd": _iter_wildrgbd_manifest_items,
        "scannetpp": _iter_scannetpp_manifest_items,
        "scanetpp_fisheye": _iter_scanetpp_fisheye_manifest_items,
        "smx_sim_fisheye": _iter_smx_sim_fisheye_manifest_items,
        "tat": _iter_tat_manifest_items,
    }
    make_items = iterators.get(str(args.dataset), _iter_hm3d_manifest_items)
    return make_items(args)
def _write_validation_failure_csv(
    out_dir: Path,
    dataset: str,
    step: int,
    failure_rows: list[dict[str, Any]],
) -> Path | None:
    """Write the per-sample failure log; return its path, or ``None`` if there is nothing to write."""
    if not failure_rows:
        return None
    fail_csv = out_dir / f"validation_failures_{dataset}_step_{int(step):07d}.csv"
    with fail_csv.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=["step", "sample_idx", "dataset", "tag", "group_key", "error"],
        )
        writer.writeheader()
        writer.writerows(failure_rows)
    return fail_csv


def run_validation(args: argparse.Namespace) -> None:
    """Run one validation round and append its metrics to the CSV files.

    Batches from ``_iter_dataset_items`` are pushed through the trainer. Every
    visualisation payload yields one metrics row, appended to
    ``validation_sample_metrics_<dataset>.csv``. Consecutive batches with the
    same group key form a group, which is finalised when the key changes or
    the iterator ends; the aggregate over all groups is appended to
    ``validation_metrics_<dataset>.csv``.

    A failing batch is logged, recorded in
    ``validation_failures_<dataset>_step_<step>.csv`` and skipped.

    Raises:
        RuntimeError: If a batch fails with an error mentioning CUDA (the
            failure log is written first), or if no group produced metrics and
            the dataset is not one of ``scannetpp``, ``scanetpp_fisheye``, ``tat``.
    """
    seed = int(args.seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    dev = torch.device(args.device)
    model, step = _load_model(Path(args.checkpoint), dev)
    trainer = _build_trainer(model, dev, args)
    metrics_calc = MetricsCalculator(device=dev, compute_lpips=not bool(getattr(args, "fast_metrics", False)))

    dataset = str(args.dataset)
    items = _iter_dataset_items(args)

    if getattr(args, "out_dir", None) is not None:
        out_dir = Path(args.out_dir)
    else:
        out_dir = Path(args.checkpoint).parent / f"validation_{dataset}"
    out_dir.mkdir(parents=True, exist_ok=True)
    vis_dir = out_dir / "vis"
    vis_dir.mkdir(parents=True, exist_ok=True)
    sample_csv = out_dir / f"validation_sample_metrics_{dataset}.csv"

    # Loop-invariant; also tolerates a programmatic Namespace without the attribute.
    raw_cache_dir = getattr(args, "metric_mask_cache_dir", None)
    mask_cache_dir = Path(raw_cache_dir) if raw_cache_dir is not None else None

    # Only the first few items of a group keep their payload, which bounds memory.
    max_vis_items_per_group = 10

    group_rows: list[dict[str, float]] = []
    failure_rows: list[dict[str, Any]] = []
    current_group_key: str | None = None
    current_group_items: list[dict[str, Any]] = []
    num_rows = 0

    def _close_group(key: str, members: list[dict[str, Any]]) -> None:
        # Finalise one group and fold its metrics into the running totals.
        nonlocal num_rows
        group_row = _finalize_validation_group(
            dataset=dataset,
            step=int(step),
            vis_dir=vis_dir,
            group_key=key,
            group_items=members,
        )
        if group_row is not None:
            group_rows.append(group_row)
            num_rows += len(members)

    LOGGER.info("Validation start: dataset=%s checkpoint=%s", dataset, str(args.checkpoint))
    pbar = tqdm(items, desc=f"validate_{dataset}", leave=False, disable=True)
    for i, (dataset_name, batch, tag, group_key) in enumerate(pbar):
        if current_group_key is not None and group_key != current_group_key:
            _close_group(current_group_key, current_group_items)
            pbar.set_postfix(groups=len(group_rows), targets=num_rows, refresh=False)
            current_group_items = []
        current_group_key = group_key
        try:
            with torch.no_grad():
                result = trainer.process_batch(
                    batch,
                    dataset_name=dataset_name,
                    step=int(step),
                    need_vis=True,
                )
            vis_payloads = result.get("vis_payloads", None)
            if isinstance(vis_payloads, list) and vis_payloads:
                vis_list = [v for v in vis_payloads if isinstance(v, dict)]
            else:
                vis = result.get("vis_payload", None)
                vis_list = [vis] if isinstance(vis, dict) else []
            if not vis_list:
                continue
            if str(dataset_name) == "smx_sim_fisheye":
                vis_list = [_project_smx_sim_fisheye_vis(vis, batch, j) for j, vis in enumerate(vis_list)]
            tags = tag if isinstance(tag, list) else [tag]
            metric_mask = metric_mask_from_pinhole_batch(
                batch,
                dataset=str(dataset_name),
                cache_dir=mask_cache_dir,
                device=dev,
            )
            for j, vis in enumerate(vis_list):
                # Keep the mask produced upstream before it may be overwritten below.
                if torch.is_tensor(vis.get("tgt_metric_mask", None)):
                    vis["tgt_training_mask"] = vis["tgt_metric_mask"].detach()
                if torch.is_tensor(metric_mask):
                    vis_b = int(vis["tgt_gt"].shape[0]) if torch.is_tensor(vis.get("tgt_gt", None)) else 1
                    if len(vis_list) == 1 and vis_b == int(metric_mask.shape[0]):
                        # A single payload covering the whole batch takes the full mask.
                        vis["tgt_metric_mask"] = metric_mask.detach()
                    elif j < int(metric_mask.shape[0]):
                        vis["tgt_metric_mask"] = metric_mask[j : j + 1].detach()
                row = _compute_metrics_from_vis(
                    vis,
                    metrics_calc=metrics_calc,
                )
                item_tag = str(tags[j]) if j < len(tags) else str(tag)
                _append_sample_metrics_row(sample_csv, str(group_key), item_tag, row)
                item = {
                    "dataset_name": dataset_name,
                    "tag": item_tag,
                    "row": row,
                }
                if len(current_group_items) < max_vis_items_per_group:
                    item["vis"] = vis
                current_group_items.append(item)
            pbar.set_postfix(
                groups=len(group_rows),
                targets=num_rows + len(current_group_items),
                refresh=False,
            )
        except Exception as e:
            LOGGER.warning("Skip %s sample idx=%d: %s", dataset, int(i), str(e))
            failure_rows.append(
                {
                    "step": int(step),
                    "sample_idx": int(i),
                    "dataset": str(dataset),
                    "tag": str(tag),
                    "group_key": str(group_key),
                    "error": str(e),
                }
            )
            if "cuda" in str(e).lower():
                # Persist what was collected so far: the raise below would
                # otherwise discard the failure log of this round.
                try:
                    _write_validation_failure_csv(out_dir, dataset, int(step), failure_rows)
                except OSError as write_exc:
                    LOGGER.warning("Could not write validation failure log: %s", str(write_exc))
                raise RuntimeError(
                    f"CUDA error during {dataset} validation at sample idx={int(i)}; "
                    "the CUDA context may be corrupted, so this validation round must fail."
                ) from e

    _write_validation_failure_csv(out_dir, dataset, int(step), failure_rows)

    if current_group_key is not None:
        _close_group(current_group_key, current_group_items)

    if not group_rows:
        if dataset in ("scannetpp", "scanetpp_fisheye", "tat"):
            LOGGER.warning("No validation samples processed for dataset=%s; skip this round.", dataset)
            return
        raise RuntimeError(f"No validation samples processed for dataset={dataset}")

    agg = _aggregate_rows(group_rows)
    agg["step"] = float(step)

    csv_main = out_dir / f"validation_metrics_{dataset}.csv"
    _append_metrics_row(csv_main, agg)

    LOGGER.info(
        "Validation done: dataset=%s groups=%d samples=%d psnr=%.3f ssim=%.4f lpips=%.4f",
        dataset,
        int(len(group_rows)),
        int(num_rows),
        float(agg.get("psnr", float("nan"))),
        float(agg.get("ssim", float("nan"))),
        float(agg.get("lpips", float("nan"))),
    )


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser of the validation script.

    NOTE(review): several defaults are absolute paths below
    ``/media/team_data``; on other machines they have to be overridden.
    """
    p = argparse.ArgumentParser(description="Unified UniSharp validation")
    p.add_argument("--checkpoint", type=Path, required=True)
    p.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=[
            "re10k",
            "dl3dv",
            "hm3d",
            "replica",
            "sim",
            "wildrgbd",
            "scannetpp",
            "scanetpp_fisheye",
            "smx_sim_fisheye",
            "tat",
        ],
    )
    p.add_argument("--data-root", type=Path, required=True)
    p.add_argument("--device", type=str, default="cuda:0")
    p.add_argument("--seed", type=int, default=42)

    p.add_argument("--max-index-gap", type=int, default=10)
    p.add_argument("--pair-max-translation-m", type=float, default=0.5)
    p.add_argument("--pair-min-overlap", type=float, default=0.6)
    p.add_argument("--split", type=str, default="test")
    # Optional for the parser, but validation refuses to run without a manifest.
    p.add_argument("--manifest-file", type=Path, default=None)
    p.add_argument("--manifest-max-groups", type=int, default=0)
    p.add_argument("--validation-batch-size", type=int, default=10)
    p.add_argument("--out-dir", type=Path, default=None)
    p.add_argument("--fast-metrics", action="store_true", help="Skip LPIPS during validation; keep PSNR/SSIM/depth metrics.")
    p.add_argument("--metric-mask-cache-dir", type=Path, default=default_metric_mask_cache_dir())
    p.add_argument("--max-depth-m", type=float, default=None)
    p.add_argument("--sim-far-depth-invalid-m", type=float, default=None)
    p.add_argument("--sim-far-depth-invalid-max-frac", type=float, default=None)
    p.add_argument("--re10k-pseudo-far-depth-invalid-m", type=float, default=None)
    p.add_argument("--scanetpp-fisheye-far-depth-invalid-m", type=float, default=None)
    p.add_argument("--low-pass-filter-eps", type=float, default=None)
    p.add_argument(
        "--validation-pseudo-depth-root",
        type=Path,
        default=Path("/media/team_data/ML4_team/datasets/sharp/validation_unik3d_pseudo_depth"),
    )
    p.add_argument("--sim-pose-root", type=Path, default=Path("/media/team_data/ML4_team/datasets/smx_sim/30cm"))
    p.add_argument(
        "--re10k-pseudo-depth-root",
        type=Path,
        default=Path("/media/team_data/ML4_team/datasets/nopose/re10k_unik3d_pseudo_depth/test"),
    )
    return p


def main() -> None:
    """Parse the command line, apply the depth-config defaults and run validation."""
    _configure_torchhub_cache()
    args = build_parser().parse_args()
    _apply_training_depth_config_defaults(args)
    run_validation(args)


if __name__ == "__main__":
    main()