Insta360-Research commited on
Commit
8b03647
·
verified ·
1 Parent(s): 5ea10ea

Upload 119 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. unisharp/.DS_Store +0 -0
  3. unisharp/__init__.py +1 -0
  4. unisharp/cli/__init__.py +13 -0
  5. unisharp/cli/__main__.py +12 -0
  6. unisharp/cli/__pycache__/__init__.cpython-310.pyc +0 -0
  7. unisharp/cli/__pycache__/__init__.cpython-313.pyc +0 -0
  8. unisharp/cli/__pycache__/mixed_sampler.cpython-313.pyc +0 -0
  9. unisharp/cli/__pycache__/train_feature.cpython-310.pyc +0 -0
  10. unisharp/cli/__pycache__/train_feature.cpython-313.pyc +0 -0
  11. unisharp/cli/__pycache__/train_utils.cpython-313.pyc +0 -0
  12. unisharp/cli/__pycache__/unified_trainer.cpython-313.pyc +3 -0
  13. unisharp/cli/mixed_sampler.py +80 -0
  14. unisharp/cli/train_feature.py +1410 -0
  15. unisharp/cli/train_utils.py +130 -0
  16. unisharp/cli/unified_trainer.py +1966 -0
  17. unisharp/datasets/__pycache__/dl3dv.cpython-310.pyc +0 -0
  18. unisharp/datasets/__pycache__/dl3dv.cpython-313.pyc +0 -0
  19. unisharp/datasets/__pycache__/pair_sampling.cpython-310.pyc +0 -0
  20. unisharp/datasets/__pycache__/pair_sampling.cpython-313.pyc +0 -0
  21. unisharp/datasets/__pycache__/panogs.cpython-310.pyc +0 -0
  22. unisharp/datasets/__pycache__/panogs.cpython-313.pyc +0 -0
  23. unisharp/datasets/__pycache__/re10k.cpython-310.pyc +0 -0
  24. unisharp/datasets/__pycache__/re10k.cpython-313.pyc +0 -0
  25. unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-310.pyc +0 -0
  26. unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-313.pyc +0 -0
  27. unisharp/datasets/__pycache__/sim_panorama.cpython-310.pyc +0 -0
  28. unisharp/datasets/__pycache__/sim_panorama.cpython-313.pyc +0 -0
  29. unisharp/datasets/__pycache__/wildrgbd.cpython-310.pyc +0 -0
  30. unisharp/datasets/__pycache__/wildrgbd.cpython-313.pyc +0 -0
  31. unisharp/datasets/dl3dv.py +305 -0
  32. unisharp/datasets/pair_sampling.py +99 -0
  33. unisharp/datasets/panogs.py +555 -0
  34. unisharp/datasets/re10k.py +718 -0
  35. unisharp/datasets/scannetpp_fisheye.py +491 -0
  36. unisharp/datasets/sim_panorama.py +497 -0
  37. unisharp/datasets/wildrgbd.py +352 -0
  38. unisharp/losses/__init__.py +4 -0
  39. unisharp/losses/__pycache__/__init__.cpython-310.pyc +0 -0
  40. unisharp/losses/__pycache__/__init__.cpython-313.pyc +0 -0
  41. unisharp/losses/__pycache__/unisharp_loss.cpython-310.pyc +0 -0
  42. unisharp/losses/__pycache__/unisharp_loss.cpython-313.pyc +0 -0
  43. unisharp/losses/unisharp_loss.py +1120 -0
  44. unisharp/models/__init__.py +23 -0
  45. unisharp/models/__pycache__/__init__.cpython-310.pyc +0 -0
  46. unisharp/models/__pycache__/__init__.cpython-313.pyc +0 -0
  47. unisharp/models/__pycache__/blocks.cpython-310.pyc +0 -0
  48. unisharp/models/__pycache__/blocks.cpython-313.pyc +0 -0
  49. unisharp/models/__pycache__/decoder.cpython-310.pyc +0 -0
  50. unisharp/models/__pycache__/decoder.cpython-313.pyc +0 -0
.gitattributes CHANGED
@@ -3,3 +3,4 @@
3
  *.safetensors filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
5
  examples/omnirooms/*.jpg filter=lfs diff=lfs merge=lfs -text
 
 
3
  *.safetensors filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
5
  examples/omnirooms/*.jpg filter=lfs diff=lfs merge=lfs -text
6
+ unisharp/cli/__pycache__/unified_trainer.cpython-313.pyc filter=lfs diff=lfs merge=lfs -text
unisharp/.DS_Store ADDED
Binary file (6.15 kB). View file
 
unisharp/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ DEFAULT_MAX_DEPTH_M: float = 100.0
unisharp/cli/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import click
4
+
5
+ from .train_feature import train_feature_cli
6
+
7
+
8
+ @click.group()
9
+ def main_cli():
10
+ pass
11
+
12
+ main_cli.add_command(train_feature_cli, "train-feature")
13
+
unisharp/cli/__main__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from unisharp.cli import main_cli
4
+
5
+
6
+ def main() -> None:
7
+ main_cli()
8
+
9
+
10
+ if __name__ == "__main__":
11
+ main()
12
+
unisharp/cli/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (480 Bytes). View file
 
unisharp/cli/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (581 Bytes). View file
 
unisharp/cli/__pycache__/mixed_sampler.cpython-313.pyc ADDED
Binary file (5.33 kB). View file
 
unisharp/cli/__pycache__/train_feature.cpython-310.pyc ADDED
Binary file (42 kB). View file
 
unisharp/cli/__pycache__/train_feature.cpython-313.pyc ADDED
Binary file (74 kB). View file
 
unisharp/cli/__pycache__/train_utils.cpython-313.pyc ADDED
Binary file (7.03 kB). View file
 
unisharp/cli/__pycache__/unified_trainer.cpython-313.pyc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2698667885fba54eef04bacbbee4bbf897c0dc3df57e6fe7d10ba185a76d2ed
3
+ size 103553
unisharp/cli/mixed_sampler.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ import random
5
+ from typing import Any, Iterator
6
+
7
+ from torch.utils.data import Dataset, IterableDataset
8
+
9
+
10
+ class LazyDataLoaderIterator:
11
+
12
+ def __init__(self, dataloader: Any):
13
+ self.dataloader = dataloader
14
+ self.iterator: Iterator[Any] | None = None
15
+
16
+ def __next__(self) -> Any:
17
+ if self.iterator is None:
18
+ self.iterator = iter(self.dataloader)
19
+ return next(self.iterator)
20
+
21
+
22
+ class MixedDatasetSampler:
23
+
24
+ def __init__(
25
+ self,
26
+ datasets: dict[str, Dataset | IterableDataset],
27
+ weights: dict[str, float],
28
+ iterators: dict[str, Iterator[Any]],
29
+ seed: int | None = None,
30
+ ):
31
+ self.datasets = datasets
32
+ self.weights = weights
33
+ self.iterators = iterators
34
+ self._rng = random.Random(seed)
35
+
36
+ if len(weights) == 0:
37
+ raise ValueError("weights is empty")
38
+ for name, w in weights.items():
39
+ if float(w) <= 0.0:
40
+ raise ValueError(f"Dataset weight must be > 0, got {name}={float(w)}")
41
+ if name not in datasets:
42
+ raise ValueError(f"Unknown dataset in weights: {name}")
43
+ if name not in iterators:
44
+ raise ValueError(f"Missing iterator for dataset: {name}")
45
+
46
+ total_weight = float(sum(float(v) for v in weights.values()))
47
+ self.probs = {name: float(w) / total_weight for name, w in weights.items()}
48
+ self.dataset_names = list(datasets.keys())
49
+ self.prob_list = [self.probs[name] for name in self.dataset_names]
50
+
51
+ def sample(self) -> tuple[str, Any]:
52
+ dataset_name = self.choose_dataset_name()
53
+ batch = self.next_batch(dataset_name)
54
+ return dataset_name, batch
55
+
56
+ def choose_dataset_name(self, allowed_dataset_names: list[str] | None = None) -> str:
57
+ if allowed_dataset_names is None:
58
+ names = self.dataset_names
59
+ probs = self.prob_list
60
+ else:
61
+ names = [name for name in self.dataset_names if name in set(allowed_dataset_names)]
62
+ if len(names) == 0:
63
+ raise ValueError("No allowed dataset names available for sampling.")
64
+ probs = [self.probs[name] for name in names]
65
+ return self._rng.choices(names, weights=probs, k=1)[0]
66
+
67
+ def next_batch(self, dataset_name: str) -> Any:
68
+ if dataset_name not in self.iterators:
69
+ raise ValueError(f"Unknown dataset iterator: {dataset_name}")
70
+ try:
71
+ batch = next(self.iterators[dataset_name])
72
+ except StopIteration as exc:
73
+ raise StopIteration(f"Dataset {dataset_name} exhausted") from exc
74
+ return batch
75
+
76
+ def get_sampling_stats(self) -> dict[str, float]:
77
+ return {
78
+ "probabilities": self.probs.copy(),
79
+ "sampling": self.weights.copy(),
80
+ }
unisharp/cli/train_feature.py ADDED
@@ -0,0 +1,1410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import json
5
+ import logging
6
+ import os
7
+ import random
8
+ import sys
9
+ import time
10
+ from dataclasses import fields, is_dataclass, replace
11
+ from datetime import datetime, timedelta
12
+ from pathlib import Path
13
+ from typing import Any
14
+
15
+ import click
16
+ import numpy as np
17
+ import torch
18
+ import torch.distributed as dist
19
+ import torch.nn.functional as F
20
+ from torch.nn.parallel import DistributedDataParallel as DDP
21
+ from torch.utils.data import DataLoader
22
+ from torch.utils.data.distributed import DistributedSampler
23
+
24
+ from unisharp.datasets.re10k import Re10KDataset, re10k_collate, re10k_passthrough
25
+ from unisharp.datasets.wildrgbd import WildRGBDDataset, wildrgbd_collate
26
+ from unisharp.datasets.dl3dv import DL3DVDataset
27
+ from unisharp.datasets.scannetpp_fisheye import ScannetppFisheyeDataset, scannetpp_fisheye_passthrough
28
+ from unisharp.datasets.sim_panorama import SimPanoramaDataset
29
+ from unisharp.datasets.panogs import PanOGSDataset, panogs_collate
30
+ from unisharp.losses import UnisharpLoss, UnisharpLossWeights
31
+ from unisharp.models.unisharp_feature import UnisharpFeatureModel, UnisharpFeatureConfig
32
+ from unisharp.utils import logging as logging_utils
33
+ from unisharp import DEFAULT_MAX_DEPTH_M
34
+ from unisharp.utils.gsplat import GSplatRenderer
35
+ from unisharp.utils.io import save_image
36
+ from unisharp.utils.rayfit_camera import scale_pinhole_intrinsics
37
+ from unisharp.utils.unified_vis import save_pair_visualization
38
+
39
+ from .mixed_sampler import LazyDataLoaderIterator, MixedDatasetSampler # type: ignore[import]
40
+ from .train_utils import warmup_cosine_lr # type: ignore[import]
41
+
42
+ LOGGER = logging.getLogger(__name__)
43
+ REPO_ROOT = Path(__file__).resolve().parents[2]
44
+
45
+
46
+ def _default_dataset_manifest_file(name: str) -> Path:
47
+ parent_path = REPO_ROOT.parent / "dataset_manifests" / name
48
+ if parent_path.exists():
49
+ return parent_path
50
+ return REPO_ROOT / "dataset_manifests" / name
51
+
52
+
53
+ DEFAULT_WILDRGBD_ROOTS_FILE = _default_dataset_manifest_file("wildrgbd_roots.txt")
54
+
55
+
56
+ def _multiple_aligned_hw(hw: tuple[int, int], multiple: int) -> tuple[int, int]:
57
+ h, w = int(hw[0]), int(hw[1])
58
+ m = int(multiple)
59
+ if m <= 1:
60
+ return h, w
61
+ out_h = max(m, (h // m) * m)
62
+ out_w = max(m, (w // m) * m)
63
+ return min(out_h, h), min(out_w, w)
64
+
65
+
66
+ def _erp_multiple_aligned_hw(hw: tuple[int, int], multiple: int) -> tuple[int, int]:
67
+ h, w = int(hw[0]), int(hw[1])
68
+ m = int(multiple)
69
+ if m <= 1:
70
+ return h, w
71
+ max_h_from_h = h // m
72
+ max_h_from_w = w // (2 * m)
73
+ h_units = min(max_h_from_h, max_h_from_w)
74
+ if h_units <= 0:
75
+ return h, w
76
+ out_h = h_units * m
77
+ return out_h, 2 * out_h
78
+
79
+
80
+ def _resize_chw_tensor(x: torch.Tensor, dst_hw: tuple[int, int], *, kind: str) -> torch.Tensor:
81
+ if not torch.is_tensor(x) or x.ndim < 3:
82
+ return x
83
+ src_hw = (int(x.shape[-2]), int(x.shape[-1]))
84
+ if src_hw == tuple(int(v) for v in dst_hw):
85
+ return x
86
+ orig_dtype = x.dtype
87
+ flat = x.reshape(-1, int(x.shape[-3]), src_hw[0], src_hw[1]).to(dtype=torch.float32)
88
+ if kind == "image":
89
+ y = F.interpolate(flat, size=dst_hw, mode="bilinear", align_corners=False)
90
+ y = y.round().clamp(0.0, 255.0).to(dtype=orig_dtype) if orig_dtype == torch.uint8 else y.to(dtype=orig_dtype)
91
+ elif kind == "ray":
92
+ y = F.interpolate(flat, size=dst_hw, mode="bilinear", align_corners=False)
93
+ y = y / torch.linalg.vector_norm(y, dim=1, keepdim=True).clamp(min=1e-6)
94
+ y = y.to(dtype=orig_dtype)
95
+ else:
96
+ y = F.interpolate(flat, size=dst_hw, mode="nearest").to(dtype=orig_dtype)
97
+ return y.reshape(*x.shape[:-2], int(dst_hw[0]), int(dst_hw[1])).contiguous()
98
+
99
+
100
+ def _resize_cube_tensor(x: torch.Tensor, dst_hw: tuple[int, int], *, kind: str) -> torch.Tensor:
101
+ if not torch.is_tensor(x) or x.ndim < 4:
102
+ return x
103
+ src_hw = (int(x.shape[-3]), int(x.shape[-2]))
104
+ if src_hw == tuple(int(v) for v in dst_hw):
105
+ return x
106
+ orig_dtype = x.dtype
107
+ channels = int(x.shape[-1])
108
+ flat = x.reshape(-1, src_hw[0], src_hw[1], channels).permute(0, 3, 1, 2).to(dtype=torch.float32)
109
+ if kind == "image":
110
+ y = F.interpolate(flat, size=dst_hw, mode="bilinear", align_corners=False)
111
+ y = y.round().clamp(0.0, 255.0).to(dtype=orig_dtype) if orig_dtype == torch.uint8 else y.to(dtype=orig_dtype)
112
+ else:
113
+ y = F.interpolate(flat, size=dst_hw, mode="nearest").to(dtype=orig_dtype)
114
+ y = y.permute(0, 2, 3, 1)
115
+ return y.reshape(*x.shape[:-3], int(dst_hw[0]), int(dst_hw[1]), channels).contiguous()
116
+
117
+
118
+ def _training_batch_src_hw(batch: Any) -> tuple[int, int] | None:
119
+ for name in ("src_rgb_u8", "src_erp_rgb_u8"):
120
+ value = getattr(batch, name, None)
121
+ if torch.is_tensor(value) and value.ndim >= 3:
122
+ return int(value.shape[-2]), int(value.shape[-1])
123
+ return None
124
+
125
+
126
+ def _scale_fisheye624_params_any(params: torch.Tensor, *, src_hw: tuple[int, int], dst_hw: tuple[int, int]) -> torch.Tensor:
127
+ if tuple(int(x) for x in src_hw) == tuple(int(x) for x in dst_hw):
128
+ return params
129
+ src_h, src_w = int(src_hw[0]), int(src_hw[1])
130
+ dst_h, dst_w = int(dst_hw[0]), int(dst_hw[1])
131
+ sx = float(dst_w) / float(max(src_w, 1))
132
+ sy = float(dst_h) / float(max(src_h, 1))
133
+ out = params.clone()
134
+ out[..., 0] *= sx
135
+ out[..., 1] *= sy
136
+ out[..., 2] = (out[..., 2] + 0.5) * sx - 0.5
137
+ out[..., 3] = (out[..., 3] + 0.5) * sy - 0.5
138
+ return out
139
+
140
+
141
+ def _resize_training_batch_to_multiple(batch: Any, multiple: int) -> Any:
142
+ if int(multiple) <= 1 or not is_dataclass(batch):
143
+ return batch
144
+ src_hw = _training_batch_src_hw(batch)
145
+ if src_hw is None:
146
+ return batch
147
+
148
+ def _view_hw(prefix: str) -> tuple[int, int] | None:
149
+ for rgb_name in (f"{prefix}_rgb_u8", f"{prefix}_erp_rgb_u8"):
150
+ rgb = getattr(batch, rgb_name, None)
151
+ if torch.is_tensor(rgb) and rgb.ndim >= 3:
152
+ return int(rgb.shape[-2]), int(rgb.shape[-1])
153
+ return None
154
+
155
+ def _aligned_view_hw(prefix: str, hw: tuple[int, int]) -> tuple[int, int]:
156
+ is_view_erp = torch.is_tensor(getattr(batch, f"{prefix}_erp_rgb_u8", None))
157
+ return (
158
+ _erp_multiple_aligned_hw(hw, int(multiple))
159
+ if bool(is_view_erp)
160
+ else _multiple_aligned_hw(hw, int(multiple))
161
+ )
162
+
163
+ def _field_dst_hw(name: str, value: torch.Tensor) -> tuple[int, int]:
164
+ prefix = "tgt" if name.startswith("tgt_") else "src"
165
+ view_hw = _view_hw(prefix)
166
+ if view_hw is not None:
167
+ return _aligned_view_hw(prefix, view_hw)
168
+ hw = (int(value.shape[-2]), int(value.shape[-1]))
169
+ return _erp_multiple_aligned_hw(hw, int(multiple)) if "_erp_" in name else _multiple_aligned_hw(hw, int(multiple))
170
+
171
+ updates: dict[str, Any] = {}
172
+ for field in fields(batch):
173
+ name = field.name
174
+ value = getattr(batch, name)
175
+ if not torch.is_tensor(value):
176
+ continue
177
+ if name.endswith("_rgb_u8") and value.ndim >= 3:
178
+ if "_cube_" in name:
179
+ cube_hw = _multiple_aligned_hw((int(value.shape[-3]), int(value.shape[-2])), int(multiple))
180
+ updates[name] = _resize_cube_tensor(value, cube_hw, kind="image")
181
+ else:
182
+ updates[name] = _resize_chw_tensor(value, _field_dst_hw(name, value), kind="image")
183
+ elif name.endswith("_depth_m") and value.ndim >= 3:
184
+ if "_cube_" in name:
185
+ cube_hw = _multiple_aligned_hw((int(value.shape[-3]), int(value.shape[-2])), int(multiple))
186
+ updates[name] = _resize_cube_tensor(value, cube_hw, kind="depth")
187
+ else:
188
+ updates[name] = _resize_chw_tensor(value, _field_dst_hw(name, value), kind="depth")
189
+ elif name.endswith("_valid_mask") and value.ndim >= 3:
190
+ updates[name] = _resize_chw_tensor(value, _field_dst_hw(name, value), kind="depth")
191
+ elif name.endswith("_rays") and value.ndim >= 3:
192
+ updates[name] = _resize_chw_tensor(value, _field_dst_hw(name, value), kind="ray")
193
+
194
+ for intr_name in ("src_intrinsics", "tgt_intrinsics"):
195
+ intr = getattr(batch, intr_name, None)
196
+ if torch.is_tensor(intr):
197
+ prefix = "tgt" if intr_name.startswith("tgt_") else "src"
198
+ view_hw = _view_hw(prefix)
199
+ if view_hw is not None:
200
+ updates[intr_name] = scale_pinhole_intrinsics(
201
+ intr,
202
+ src_hw=view_hw,
203
+ dst_hw=_aligned_view_hw(prefix, view_hw),
204
+ )
205
+ for params_name in ("src_camera_params", "tgt_camera_params"):
206
+ params = getattr(batch, params_name, None)
207
+ if torch.is_tensor(params):
208
+ prefix = "tgt" if params_name.startswith("tgt_") else "src"
209
+ view_hw = _view_hw(prefix)
210
+ if view_hw is not None:
211
+ updates[params_name] = _scale_fisheye624_params_any(
212
+ params,
213
+ src_hw=view_hw,
214
+ dst_hw=_aligned_view_hw(prefix, view_hw),
215
+ )
216
+
217
+ return replace(batch, **updates) if updates else batch
218
+
219
+
220
+ def _build_optimizer_param_groups(
221
+ raw_model: UnisharpFeatureModel,
222
+ ) -> tuple[list[torch.nn.Parameter], list[torch.nn.Parameter], list[torch.nn.Parameter]]:
223
+ base_params: list[torch.nn.Parameter] = []
224
+ unik3d_encoder_params: list[torch.nn.Parameter] = []
225
+ unik3d_decoder_params: list[torch.nn.Parameter] = []
226
+ for name, param in raw_model.named_parameters():
227
+ if not param.requires_grad:
228
+ continue
229
+ if name.startswith("feature_extractor.unik3d.pixel_encoder."):
230
+ unik3d_encoder_params.append(param)
231
+ elif name.startswith("second_layer_depth_head."):
232
+ unik3d_decoder_params.append(param)
233
+ elif name.startswith("feature_extractor.unik3d."):
234
+ unik3d_decoder_params.append(param)
235
+ else:
236
+ base_params.append(param)
237
+ return base_params, unik3d_encoder_params, unik3d_decoder_params
238
+
239
+
240
+ def _count_numel(params: list[torch.nn.Parameter]) -> int:
241
+ return int(sum(int(p.numel()) for p in params))
242
+
243
+
244
+ def _configure_torchhub_cache() -> Path:
245
+ torchhub_dir = REPO_ROOT / "checkpoints" / "torchhub"
246
+ torchhub_dir.mkdir(parents=True, exist_ok=True)
247
+ os.environ["TORCH_HOME"] = str(torchhub_dir)
248
+ torch.hub.set_dir(str(torchhub_dir))
249
+ return torchhub_dir
250
+
251
+
252
+ def _ddp_is_enabled() -> bool:
253
+ return int(os.environ.get("WORLD_SIZE", "1")) > 1
254
+
255
+
256
+ def _ddp_setup(device: str, ddp_timeout_hours: float = 8.0) -> tuple[torch.device, int, int, bool]:
257
+ if not _ddp_is_enabled():
258
+ dev = torch.device(device)
259
+ return dev, 0, 1, True
260
+
261
+ if device != "cuda":
262
+ raise RuntimeError("DDP currently supports CUDA only.")
263
+ if not torch.cuda.is_available():
264
+ raise RuntimeError("CUDA not available.")
265
+
266
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
267
+ rank = int(os.environ.get("RANK", "0"))
268
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
269
+ torch.cuda.set_device(local_rank)
270
+ timeout_hours = max(float(ddp_timeout_hours), 0.25)
271
+ if rank == 0:
272
+ print(
273
+ "[ddp_setup] init_process_group backend=nccl "
274
+ f"world_size={world_size} NCCL_NET={os.environ.get('NCCL_NET', '<unset>')} "
275
+ f"NCCL_IB_DISABLE={os.environ.get('NCCL_IB_DISABLE', '<unset>')}",
276
+ flush=True,
277
+ )
278
+ dist.init_process_group(backend="nccl", timeout=timedelta(hours=timeout_hours))
279
+ if rank == 0:
280
+ print("[ddp_setup] init_process_group done", flush=True)
281
+ dev = torch.device("cuda", local_rank)
282
+ return dev, rank, world_size, (rank == 0)
283
+
284
+
285
+ def _ddp_broadcast_path(p: Path, is_main: bool) -> Path:
286
+ if not _ddp_is_enabled():
287
+ return p
288
+ obj_list: list[str] = [str(p) if is_main else ""]
289
+ dist.broadcast_object_list(obj_list, src=0)
290
+ return Path(obj_list[0])
291
+
292
+
293
+ def _ddp_broadcast_str(value: str, is_main: bool) -> str:
294
+ if not _ddp_is_enabled():
295
+ return value
296
+ obj_list: list[str] = [str(value) if is_main else ""]
297
+ dist.broadcast_object_list(obj_list, src=0)
298
+ return str(obj_list[0])
299
+
300
+
301
+ def _ddp_any_bool(flag: bool, device: torch.device) -> bool:
302
+ if not _ddp_is_enabled():
303
+ return bool(flag)
304
+ x = torch.tensor(1 if flag else 0, device=device, dtype=torch.int32)
305
+ dist.all_reduce(x, op=dist.ReduceOp.MAX)
306
+ return bool(int(x.item()) != 0)
307
+
308
+
309
+ def _env_flag(name: str, default: bool = False) -> bool:
310
+ raw = os.environ.get(name)
311
+ if raw is None:
312
+ return bool(default)
313
+ return raw.strip().lower() in {"1", "true", "yes", "on"}
314
+
315
+
316
+ def _is_oom_exception(exc: BaseException) -> bool:
317
+ if isinstance(exc, torch.cuda.OutOfMemoryError):
318
+ return True
319
+ msg = str(exc).lower()
320
+ oom_markers = (
321
+ "out of memory",
322
+ "cuda error: out of memory",
323
+ "cublas_status_alloc_failed",
324
+ "cudnn_status_alloc_failed",
325
+ "defaultcpuallocator",
326
+ )
327
+ return any(marker in msg for marker in oom_markers)
328
+
329
+
330
+ def _ddp_barrier(device: torch.device) -> None:
331
+ if not _ddp_is_enabled():
332
+ return
333
+ if device.type == "cuda" and device.index is not None:
334
+ dist.barrier(device_ids=[device.index])
335
+ else:
336
+ dist.barrier()
337
+
338
+
339
+ def _maybe_set_dataset_epoch(dataset: Any, epoch: int) -> None:
340
+ set_epoch = getattr(dataset, "set_epoch", None)
341
+ if callable(set_epoch):
342
+ set_epoch(int(epoch))
343
+
344
+
345
+ def _ddp_mean(x: torch.Tensor) -> torch.Tensor:
346
+ if not _ddp_is_enabled():
347
+ return x
348
+ y = x.detach().clone()
349
+ dist.all_reduce(y, op=dist.ReduceOp.SUM)
350
+ y = y / float(dist.get_world_size())
351
+ return y
352
+
353
+
354
+ def _save_train_vis(
355
+ out_dir: Path,
356
+ step: int,
357
+ src_gt: torch.Tensor,
358
+ src_pred: torch.Tensor,
359
+ src_alpha: torch.Tensor,
360
+ tgt_gt: torch.Tensor,
361
+ tgt_pred: torch.Tensor,
362
+ tgt_alpha: torch.Tensor,
363
+ src_gt_depth: torch.Tensor | None = None,
364
+ tgt_gt_depth: torch.Tensor | None = None,
365
+ src_pred_depth: torch.Tensor | None = None,
366
+ tgt_pred_depth: torch.Tensor | None = None,
367
+ src_unik3d_depth: torch.Tensor | None = None,
368
+ tgt_unik3d_depth: torch.Tensor | None = None,
369
+ dataset_name: str | None = None,
370
+ scene: str | None = None,
371
+ src_idx: int | None = None,
372
+ tgt_idx: int | None = None,
373
+ src_pose_w2c: torch.Tensor | None = None,
374
+ tgt_pose_w2c: torch.Tensor | None = None,
375
+ src_metric_mask: torch.Tensor | None = None,
376
+ tgt_metric_mask: torch.Tensor | None = None,
377
+ src_cube_gt_u8: torch.Tensor | None = None,
378
+ src_cube_pred_linear: torch.Tensor | None = None,
379
+ src_cube_alpha: torch.Tensor | None = None,
380
+ tgt_cube_gt_u8: torch.Tensor | None = None,
381
+ tgt_cube_pred_linear: torch.Tensor | None = None,
382
+ tgt_cube_alpha: torch.Tensor | None = None,
383
+ ) -> None:
384
+ vis_dir = out_dir / "vis"
385
+ vis_dir.mkdir(parents=True, exist_ok=True)
386
+ LOGGER.info("Saving train visualization: %s", str(vis_dir / f"step_{int(step):07d}.png"))
387
+ save_pair_visualization(
388
+ vis_dir / f"step_{int(step):07d}.png",
389
+ src_gt=src_gt,
390
+ src_pred=src_pred,
391
+ src_alpha=src_alpha,
392
+ tgt_gt=tgt_gt,
393
+ tgt_pred=tgt_pred,
394
+ tgt_alpha=tgt_alpha,
395
+ src_gt_depth=src_gt_depth,
396
+ tgt_gt_depth=tgt_gt_depth,
397
+ src_pred_depth=src_pred_depth,
398
+ tgt_pred_depth=tgt_pred_depth,
399
+ src_unik3d_depth=src_unik3d_depth,
400
+ tgt_unik3d_depth=tgt_unik3d_depth,
401
+ dataset_name=dataset_name,
402
+ scene=scene,
403
+ step=int(step),
404
+ src_idx=src_idx,
405
+ tgt_idx=tgt_idx,
406
+ src_pose_w2c=src_pose_w2c,
407
+ tgt_pose_w2c=tgt_pose_w2c,
408
+ src_cube_gt_u8=src_cube_gt_u8,
409
+ src_cube_pred_linear=src_cube_pred_linear,
410
+ src_cube_alpha=src_cube_alpha,
411
+ tgt_cube_gt_u8=tgt_cube_gt_u8,
412
+ tgt_cube_pred_linear=tgt_cube_pred_linear,
413
+ tgt_cube_alpha=tgt_cube_alpha,
414
+ )
415
+
416
+
417
+ def _read_nonempty_lines(path: Path) -> list[str]:
418
+ return [line.strip() for line in path.read_text(encoding="utf-8").splitlines() if line.strip()]
419
+
420
+
421
+ def _resolve_manifest_file(manifest_dir: Path | None, filename: str) -> Path | None:
422
+ if manifest_dir is None:
423
+ return None
424
+ path = Path(manifest_dir) / filename
425
+ return path if path.exists() else None
426
+
427
+
428
+ @click.command()
429
+ @click.option("--data-root-re10k", type=click.Path(path_type=Path, exists=True), default=None)
430
+ @click.option("--data-root-hm3d", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/panogs"))
431
+ @click.option("--data-root-sim", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/smx_sim"))
432
+ @click.option("--sim-pose-root", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/smx_sim/30cm"))
433
+ @click.option("--data-root-wildrgbd", type=click.Path(path_type=Path, exists=True), default=None)
434
+ @click.option("--wild-roots-file", type=click.Path(path_type=Path, exists=True, dir_okay=False), default=DEFAULT_WILDRGBD_ROOTS_FILE)
435
+ @click.option("--data-root-dl3dv", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/sharp/DL3DV-ALL-960P"))
436
+ @click.option("--data-root-dl3dv-depth", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/sharp/DL3DV-ALL-960P_da3_outputs"))
437
+ @click.option("--data-root-scanetpp", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/scan"))
438
+ @click.option("--dataset-manifest-dir", type=click.Path(path_type=Path, file_okay=False), default=None)
439
+ @click.option("--out-root", type=click.Path(path_type=Path, file_okay=False), required=True)
440
+ @click.option("--run-name", type=str, default=None)
441
+ @click.option("--steps", type=int, default=1000000)
442
+ @click.option("--batch-size", type=int, default=2)
443
+ @click.option("--num-workers", type=int, default=1)
444
+ @click.option("--warmup", type=int, default=75000)
445
+ @click.option("--lr0", type=float, default=1.2e-4)
446
+ @click.option("--lr1", type=float, default=1.6e-5)
447
+ @click.option("--unik3d-lr0", type=float, default=2.5e-5, help="UniK3D decoder/head peak LR.")
448
+ @click.option("--unik3d-lr1", type=float, default=2.5e-6, help="UniK3D decoder/head final LR.")
449
+ @click.option("--unik3d-encoder-lr0", type=float, default=1.5e-6, help="UniK3D pixel_encoder peak LR.")
450
+ @click.option("--unik3d-encoder-lr1", type=float, default=1.5e-7, help="UniK3D pixel_encoder final LR.")
451
+ @click.option("--grad-clip-norm", type=float, default=1.0, show_default=True)
452
+ @click.option("--max-step-grad-norm", type=float, default=100000.0, show_default=True, help="Skip optimizer step when pre-clip grad norm exceeds this value. 0 disables.")
453
+ @click.option("--max-depth-m", type=float, default=DEFAULT_MAX_DEPTH_M, show_default=True)
454
+ @click.option("--sim-far-depth-invalid-m", type=float, default=30.0, show_default=True)
455
+ @click.option("--sim-far-depth-invalid-max-frac", type=float, default=1.0, show_default=True)
456
+ @click.option("--sim-max-long-edge", type=int, default=512, show_default=True, help="Resize SIM ERP frames before cubemap conversion. 0 keeps native resolution.")
457
+ @click.option("--train-resize-multiple", type=int, default=256, show_default=True, help="Before model forward, downsize training inputs to the largest H/W divisible by this value. 0 disables.")
458
+ @click.option("--pinhole-train-size", type=int, default=0, show_default=True, help="Resize pinhole training datasets to NxN before model forward. 0 keeps dataset native resolution.")
459
+ @click.option("--scanetpp-fisheye-far-depth-invalid-m", type=float, default=30.0, show_default=True)
460
+ @click.option("--max-index-gap", type=int, default=10)
461
+ @click.option("--device", type=str, default="cuda")
462
+ @click.option("--render-low-pass-filter-eps", type=float, default=1e-2, show_default=True)
463
+ @click.option("--ddp-timeout-hours", type=float, default=8.0)
464
+ @click.option("--save-every", type=int, default=5000)
465
+ @click.option("--log-every", type=int, default=50)
466
+ @click.option("--vis-every", type=int, default=500)
467
+ @click.option("--unik3d-backbone", type=click.Choice(["vitb", "vitl"]), default="vitl")
468
+ @click.option("--unik3d-resolution-level", type=click.IntRange(0, 9), default=0, show_default=True)
469
+ @click.option("--initializer-stride", type=click.IntRange(1, 2), default=1)
470
+ @click.option("--initializer-scale-factor", type=float, default=1.5, show_default=True)
471
+ @click.option("--lambda-aux-ray", type=float, default=3.0)
472
+ @click.option("--lambda-aux-depth-scale", type=float, default=3.0)
473
+ @click.option("--lambda-aux-depth2-scale", type=float, default=1.0)
474
+ @click.option("--lambda-color", type=float, default=1.0)
475
+ @click.option("--lambda-alpha", type=float, default=1.5)
476
+ @click.option("--alpha-tail-min", type=float, default=0.99, show_default=True, help="Alpha value below which local tail coverage loss is applied.")
477
+ @click.option("--alpha-tail-weight", type=float, default=0.0, show_default=True, help="Extra normalized tail weight for local low-alpha holes.")
478
+ @click.option("--lambda-percep", type=float, default=1.0)
479
+ @click.option("--lambda-depth", type=float, default=0.5)
480
+ @click.option("--lambda-tv", type=float, default=1.0)
481
+ @click.option("--lambda-grad", type=float, default=1.0)
482
+ @click.option("--lambda-grad-img", type=float, default=0.2)
483
+ @click.option("--lambda-edge-rgb", type=float, default=0.0, show_default=True, help="Weight for GT RGB edge-band gradient matching.")
484
+ @click.option("--lambda-delta", type=float, default=1.0)
485
+ @click.option("--lambda-delta-rho", type=float, default=0.01, show_default=True)
486
+ @click.option("--lambda-splat", type=float, default=1.0)
487
+ @click.option("--lambda-edge-splat", type=float, default=0.0, show_default=True, help="Weight for stricter projected-sigma penalty on GT depth-edge bands.")
488
+ @click.option("--lambda-grid", type=float, default=0.05, show_default=True, help="Weight for Gaussian-grid 2x2 checkerboard residual regularization.")
489
+ @click.option("--delta-clip", type=float, default=10.0, show_default=True)
490
+ @click.option("--raw-delta-clip", type=float, default=400.0, show_default=True)
491
+ @click.option("--raw-delta-rho-clip", type=float, default=5.0, show_default=True)
492
+ @click.option("--delta-rho-limit", type=float, default=2.0, show_default=True)
493
+ @click.option("--splat-sigma-min", type=float, default=1e-1, show_default=True, help="Minimum projected screen-space variance for L_splat.")
494
+ @click.option("--splat-sigma-max", type=float, default=1e2, show_default=True, help="Maximum projected screen-space variance for L_splat.")
495
+ @click.option("--edge-splat-sigma-max", type=float, default=2.0, show_default=True, help="Maximum projected variance on depth-edge bands for L_edge_splat.")
496
+ @click.option("--depth-edge-log-threshold", type=float, default=0.05, show_default=True, help="Log-depth jump threshold used to build L_edge_splat edge bands.")
497
+ @click.option("--depth-edge-dilate-px", type=int, default=2, show_default=True, help="Dilation radius in pixels for L_edge_splat depth-edge bands.")
498
+ @click.option("--target-mask-erode-px", type=int, default=0, show_default=True, help="Erode source-visible target masks by this many pixels before target supervision.")
499
+ @click.option("--dataset-weight-re10k", type=float, default=1.0)
500
+ @click.option("--dataset-weight-hm3d", type=float, default=1.0)
501
+ @click.option("--dataset-weight-sim", type=float, default=1.0)
502
+ @click.option("--dataset-weight-wildrgbd", type=float, default=1.0)
503
+ @click.option("--dataset-weight-dl3dv", type=float, default=1.0)
504
+ @click.option("--dataset-weight-scanetpp", type=float, default=0.0)
505
+ @click.option(
506
+ "--re10k-pseudo-depth-root",
507
+ type=click.Path(path_type=Path, file_okay=False),
508
+ default=Path("/media/team_data/ML4_team/datasets/nopose/re10k_unik3d_pseudo_depth"),
509
+ )
510
+ @click.option("--re10k-pseudo-depth-autogen/--no-re10k-pseudo-depth-autogen", default=True)
511
+ @click.option("--re10k-pseudo-depth-backbone", type=click.Choice(["vitb", "vitl"]), default="vitl")
512
+ @click.option("--re10k-pseudo-depth-device", type=str, default="cpu")
513
+ @click.option("--re10k-pseudo-lock-timeout-sec", type=float, default=120.0)
514
+ @click.option("--re10k-pseudo-lock-stale-sec", type=float, default=1800.0)
515
+ @click.option("--re10k-pseudo-far-depth-invalid-m", type=float, default=30.0)
516
+ @click.option("--seed", type=int, default=None)
517
+ @click.option("-v", "--verbose", is_flag=True)
518
+ def train_feature_cli(
519
+ data_root_re10k: Path | None,
520
+ data_root_hm3d: Path | None,
521
+ data_root_sim: Path | None,
522
+ sim_pose_root: Path | None,
523
+ data_root_wildrgbd: Path | None,
524
+ wild_roots_file: Path,
525
+ data_root_dl3dv: Path | None,
526
+ data_root_dl3dv_depth: Path | None,
527
+ data_root_scanetpp: Path | None,
528
+ dataset_manifest_dir: Path | None,
529
+ out_root: Path,
530
+ run_name: str | None,
531
+ steps: int,
532
+ batch_size: int,
533
+ num_workers: int,
534
+ warmup: int,
535
+ lr0: float,
536
+ lr1: float,
537
+ unik3d_lr0: float,
538
+ unik3d_lr1: float,
539
+ unik3d_encoder_lr0: float,
540
+ unik3d_encoder_lr1: float,
541
+ grad_clip_norm: float,
542
+ max_step_grad_norm: float,
543
+ max_depth_m: float,
544
+ sim_far_depth_invalid_m: float,
545
+ sim_far_depth_invalid_max_frac: float,
546
+ sim_max_long_edge: int,
547
+ train_resize_multiple: int,
548
+ pinhole_train_size: int,
549
+ scanetpp_fisheye_far_depth_invalid_m: float,
550
+ max_index_gap: int,
551
+ device: str,
552
+ render_low_pass_filter_eps: float,
553
+ ddp_timeout_hours: float,
554
+ save_every: int,
555
+ log_every: int,
556
+ vis_every: int,
557
+ unik3d_backbone: str,
558
+ unik3d_resolution_level: int,
559
+ initializer_stride: int,
560
+ initializer_scale_factor: float,
561
+ lambda_aux_ray: float,
562
+ lambda_aux_depth_scale: float,
563
+ lambda_aux_depth2_scale: float,
564
+ lambda_color: float,
565
+ lambda_alpha: float,
566
+ alpha_tail_min: float,
567
+ alpha_tail_weight: float,
568
+ lambda_percep: float,
569
+ lambda_depth: float,
570
+ lambda_tv: float,
571
+ lambda_grad: float,
572
+ lambda_grad_img: float,
573
+ lambda_edge_rgb: float,
574
+ lambda_delta: float,
575
+ lambda_delta_rho: float,
576
+ lambda_splat: float,
577
+ lambda_edge_splat: float,
578
+ lambda_grid: float,
579
+ delta_clip: float,
580
+ raw_delta_clip: float,
581
+ raw_delta_rho_clip: float,
582
+ delta_rho_limit: float,
583
+ splat_sigma_min: float,
584
+ splat_sigma_max: float,
585
+ edge_splat_sigma_max: float,
586
+ depth_edge_log_threshold: float,
587
+ depth_edge_dilate_px: int,
588
+ target_mask_erode_px: int,
589
+ dataset_weight_re10k: float,
590
+ dataset_weight_hm3d: float,
591
+ dataset_weight_sim: float,
592
+ dataset_weight_wildrgbd: float,
593
+ dataset_weight_dl3dv: float,
594
+ dataset_weight_scanetpp: float,
595
+ re10k_pseudo_depth_root: Path,
596
+ re10k_pseudo_depth_autogen: bool,
597
+ re10k_pseudo_depth_backbone: str,
598
+ re10k_pseudo_depth_device: str,
599
+ re10k_pseudo_lock_timeout_sec: float,
600
+ re10k_pseudo_lock_stale_sec: float,
601
+ re10k_pseudo_far_depth_invalid_m: float,
602
+ seed: int | None,
603
+ verbose: bool,
604
+ ) -> None:
605
+ detach_init_layer0_distance = True
606
+
607
+ log_level = logging.DEBUG if verbose else logging.INFO
608
+ logging_utils.configure(log_level)
609
+ if float(max_depth_m) <= 0.0:
610
+ raise ValueError("--max-depth-m must be positive.")
611
+ if float(grad_clip_norm) <= 0.0:
612
+ raise ValueError("--grad-clip-norm must be positive.")
613
+ if float(max_step_grad_norm) < 0.0:
614
+ raise ValueError("--max-step-grad-norm must be non-negative.")
615
+ if float(render_low_pass_filter_eps) < 0.0:
616
+ raise ValueError("--render-low-pass-filter-eps must be non-negative.")
617
+ if not (0.0 <= float(sim_far_depth_invalid_max_frac) <= 1.0):
618
+ raise ValueError("--sim-far-depth-invalid-max-frac must be in [0, 1].")
619
+ if int(sim_max_long_edge) < 0:
620
+ raise ValueError("--sim-max-long-edge must be non-negative.")
621
+ if int(train_resize_multiple) < 0:
622
+ raise ValueError("--train-resize-multiple must be non-negative.")
623
+ if int(pinhole_train_size) < 0:
624
+ raise ValueError("--pinhole-train-size must be non-negative.")
625
+ if float(scanetpp_fisheye_far_depth_invalid_m) < 0.0:
626
+ raise ValueError("--scanetpp-fisheye-far-depth-invalid-m must be non-negative.")
627
+ if float(delta_clip) < 0.0:
628
+ raise ValueError("--delta-clip must be non-negative.")
629
+ if float(raw_delta_clip) < 0.0:
630
+ raise ValueError("--raw-delta-clip must be non-negative.")
631
+ if float(raw_delta_rho_clip) < 0.0:
632
+ raise ValueError("--raw-delta-rho-clip must be non-negative.")
633
+ if float(lambda_grid) < 0.0:
634
+ raise ValueError("--lambda-grid must be non-negative.")
635
+ if float(lambda_edge_rgb) < 0.0:
636
+ raise ValueError("--lambda-edge-rgb must be non-negative.")
637
+ if float(lambda_edge_splat) < 0.0:
638
+ raise ValueError("--lambda-edge-splat must be non-negative.")
639
+ if float(edge_splat_sigma_max) < 0.0:
640
+ raise ValueError("--edge-splat-sigma-max must be non-negative.")
641
+ if float(depth_edge_log_threshold) < 0.0:
642
+ raise ValueError("--depth-edge-log-threshold must be non-negative.")
643
+ if int(depth_edge_dilate_px) < 0:
644
+ raise ValueError("--depth-edge-dilate-px must be non-negative.")
645
+ if int(target_mask_erode_px) < 0:
646
+ raise ValueError("--target-mask-erode-px must be non-negative.")
647
+ if not (0.0 <= float(alpha_tail_min) <= 1.0):
648
+ raise ValueError("--alpha-tail-min must be in [0, 1].")
649
+ if float(alpha_tail_weight) < 0.0:
650
+ raise ValueError("--alpha-tail-weight must be non-negative.")
651
+ if float(delta_rho_limit) < 0.0:
652
+ raise ValueError("--delta-rho-limit must be non-negative.")
653
+ if float(splat_sigma_min) < 0.0:
654
+ raise ValueError("--splat-sigma-min must be non-negative.")
655
+ if float(splat_sigma_max) <= float(splat_sigma_min):
656
+ raise ValueError("--splat-sigma-max must be greater than --splat-sigma-min.")
657
+ dev, rank, world_size, is_main = _ddp_setup(device, ddp_timeout_hours=ddp_timeout_hours)
658
+
659
+ if seed is not None:
660
+ s = int(seed)
661
+ random.seed(s + rank)
662
+ np.random.seed(s + rank)
663
+ torch.manual_seed(s + rank)
664
+ if torch.cuda.is_available():
665
+ torch.cuda.manual_seed_all(s + rank)
666
+
667
+ if is_main and (run_name is None or run_name.strip() == ""):
668
+ run_name = f"unified_feature_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
669
+ if run_name is None:
670
+ run_name = "unified_feature_ddp"
671
+ out_dir = _ddp_broadcast_path(Path(out_root) / run_name, is_main=is_main)
672
+ logging_utils.configure(log_level)
673
+ if not is_main:
674
+ logging.getLogger().setLevel(logging.WARNING)
675
+ LOGGER.setLevel(logging.WARNING)
676
+ _configure_torchhub_cache()
677
+ re10k_enabled_for_train = bool(float(dataset_weight_re10k) > 0.0)
678
+ hm3d_enabled_for_train = bool(float(dataset_weight_hm3d) > 0.0)
679
+ sim_enabled_for_train = bool(float(dataset_weight_sim) > 0.0)
680
+ dl3dv_enabled_for_train = bool(float(dataset_weight_dl3dv) > 0.0)
681
+ scanetpp_enabled_for_train = bool(float(dataset_weight_scanetpp) > 0.0)
682
+ wild_roots = _read_nonempty_lines(wild_roots_file) if wild_roots_file.exists() else []
683
+ re10k_manifest = _resolve_manifest_file(dataset_manifest_dir, "re10k_train_chunks.txt")
684
+ hm3d_manifest = _resolve_manifest_file(dataset_manifest_dir, "hm3d_train_scenes.txt")
685
+ sim_manifest = _resolve_manifest_file(dataset_manifest_dir, "sim_train_scenes.txt")
686
+ wildrgbd_manifest = _resolve_manifest_file(dataset_manifest_dir, "wildrgbd_train_scenes.txt")
687
+ dl3dv_manifest = _resolve_manifest_file(dataset_manifest_dir, "dl3dv_train_scenes.txt")
688
+ scanetpp_manifest = _resolve_manifest_file(dataset_manifest_dir, "scanetpp_fisheye_train_scenes.txt")
689
+ wildrgbd_enabled_for_train = bool(
690
+ ((data_root_wildrgbd is not None) or bool(wild_roots)) and (float(dataset_weight_wildrgbd) > 0.0)
691
+ )
692
+ if re10k_enabled_for_train and data_root_re10k is None:
693
+ raise ValueError("dataset_weight_re10k>0 but --data-root-re10k is not provided.")
694
+ if hm3d_enabled_for_train and data_root_hm3d is None:
695
+ raise ValueError("dataset_weight_hm3d>0 but --data-root-hm3d is not provided.")
696
+ if sim_enabled_for_train and (data_root_sim is None or sim_pose_root is None):
697
+ raise ValueError("dataset_weight_sim>0 but --data-root-sim / --sim-pose-root is missing.")
698
+ if sim_enabled_for_train and sim_manifest is None:
699
+ raise ValueError("dataset_weight_sim>0 but sim_train_scenes.txt is missing from --dataset-manifest-dir.")
700
+ if float(dataset_weight_wildrgbd) > 0.0 and (data_root_wildrgbd is None) and (not wild_roots):
701
+ raise ValueError("dataset_weight_wildrgbd>0 but neither --data-root-wildrgbd nor --wild-roots-file is provided.")
702
+ if dl3dv_enabled_for_train and (data_root_dl3dv is None or data_root_dl3dv_depth is None):
703
+ raise ValueError("dataset_weight_dl3dv>0 but --data-root-dl3dv / --data-root-dl3dv-depth is missing.")
704
+ if scanetpp_enabled_for_train and data_root_scanetpp is None:
705
+ raise ValueError("dataset_weight_scanetpp>0 but --data-root-scanetpp is missing.")
706
+
707
+ if is_main:
708
+ out_dir.mkdir(parents=True, exist_ok=True)
709
+ LOGGER.info(
710
+ "Training start: out=%s branch=gt-override scratch_unik3d_pretrained backbone=%s steps=%d batch=%d",
711
+ str(out_dir),
712
+ str(unik3d_backbone),
713
+ int(steps),
714
+ int(batch_size),
715
+ )
716
+ LOGGER.info(
717
+ "Loss weights: color=%.3g alpha=%.3g depth=%.3g percep=%.3g aux_ray=%.3g aux_depth0=%.3g aux_depth1=%.3g",
718
+ float(lambda_color),
719
+ float(lambda_alpha),
720
+ float(lambda_depth),
721
+ float(lambda_percep),
722
+ float(lambda_aux_ray),
723
+ float(lambda_aux_depth_scale),
724
+ float(lambda_aux_depth2_scale),
725
+ )
726
+
727
+ dataset_seed = int(seed) if seed is not None else 12345
728
+ pinhole_output_h = int(pinhole_train_size) if int(pinhole_train_size) > 0 else None
729
+ pinhole_output_w = int(pinhole_train_size) if int(pinhole_train_size) > 0 else None
730
+
731
+ re10k_ds = None
732
+ if re10k_enabled_for_train:
733
+ re10k_ds = Re10KDataset(
734
+ root=data_root_re10k,
735
+ chunks_file=re10k_manifest,
736
+ split="train",
737
+ min_frame_gap=1,
738
+ max_frame_gap=int(max_index_gap),
739
+ pair_max_translation_m=0.5,
740
+ pair_min_overlap=0.6,
741
+ output_h=pinhole_output_h,
742
+ output_w=pinhole_output_w,
743
+ shuffle_chunk=True,
744
+ shuffle_example=True,
745
+ ddp_rank=rank,
746
+ ddp_world_size=world_size,
747
+ pseudo_depth_root=re10k_pseudo_depth_root,
748
+ pseudo_depth_autogen=bool(re10k_pseudo_depth_autogen),
749
+ pseudo_depth_backbone=str(re10k_pseudo_depth_backbone),
750
+ pseudo_depth_device=str(re10k_pseudo_depth_device),
751
+ pseudo_lock_timeout_sec=float(re10k_pseudo_lock_timeout_sec),
752
+ pseudo_lock_stale_sec=float(re10k_pseudo_lock_stale_sec),
753
+ batch_size_hint=int(batch_size),
754
+ depth_max_m=float(max_depth_m),
755
+ pseudo_far_depth_invalid_m=float(re10k_pseudo_far_depth_invalid_m),
756
+ seed=dataset_seed,
757
+ )
758
+ hm3d_train_root = None
759
+ if data_root_hm3d is not None:
760
+ hm3d_train_root = data_root_hm3d / "train" if (data_root_hm3d / "train").exists() else data_root_hm3d
761
+
762
+ hm3d_ds = None
763
+ if hm3d_enabled_for_train:
764
+ hm3d_ds = PanOGSDataset(
765
+ root=hm3d_train_root,
766
+ index_manifest_path=hm3d_manifest,
767
+ src_tgt_max_index_gap=int(max_index_gap),
768
+ use_cubemap_supervision=True,
769
+ pair_sampling=True,
770
+ pair_max_translation_m=0.5,
771
+ pair_min_depth_overlap=0.6,
772
+ pair_overlap_face_w=64,
773
+ pair_overlap_margin=1.05,
774
+ pair_max_tries=48,
775
+ depth_max_m=float(max_depth_m),
776
+ )
777
+ sim_ds = None
778
+ if sim_enabled_for_train:
779
+ sim_ds = SimPanoramaDataset(
780
+ root=data_root_sim,
781
+ pose_root=sim_pose_root,
782
+ scene_list_file=sim_manifest,
783
+ max_index_gap=int(max_index_gap),
784
+ pair_max_translation_m=0.5,
785
+ pair_min_depth_overlap=0.6,
786
+ pairs_per_chunk=15,
787
+ chunk_size=30,
788
+ shuffle_scene=True,
789
+ ddp_rank=rank,
790
+ ddp_world_size=world_size,
791
+ depth_max_m=float(max_depth_m),
792
+ far_depth_invalid_m=float(sim_far_depth_invalid_m),
793
+ far_depth_invalid_max_frac=float(sim_far_depth_invalid_max_frac),
794
+ max_long_edge=int(sim_max_long_edge),
795
+ seed=dataset_seed,
796
+ )
797
+ wildrgbd_ds = None
798
+ if wildrgbd_enabled_for_train:
799
+ wild_dataset_roots = [Path(p) for p in wild_roots]
800
+ if data_root_wildrgbd is not None:
801
+ wild_dataset_roots.append(data_root_wildrgbd)
802
+ wildrgbd_ds = WildRGBDDataset(
803
+ root=None,
804
+ scene_list_file=wildrgbd_manifest,
805
+ split="scenes",
806
+ min_frame_gap=1,
807
+ max_frame_gap=int(max_index_gap),
808
+ pair_max_translation_m=0.5,
809
+ pair_min_overlap=0.6,
810
+ output_h=pinhole_output_h,
811
+ output_w=pinhole_output_w,
812
+ shuffle_scene=True,
813
+ shuffle_frame=False,
814
+ ddp_rank=rank,
815
+ ddp_world_size=world_size,
816
+ roots=wild_dataset_roots,
817
+ depth_max_m=float(max_depth_m),
818
+ seed=dataset_seed,
819
+ )
820
+ dl3dv_ds = None
821
+ if dl3dv_enabled_for_train:
822
+ dl3dv_ds = DL3DVDataset(
823
+ root=data_root_dl3dv,
824
+ depth_root=data_root_dl3dv_depth,
825
+ scene_specs_file=dl3dv_manifest,
826
+ min_frame_gap=1,
827
+ max_frame_gap=int(max_index_gap),
828
+ pair_max_translation_m=0.5,
829
+ pair_min_overlap=0.6,
830
+ output_h=pinhole_output_h,
831
+ output_w=pinhole_output_w,
832
+ shuffle_scene=True,
833
+ shuffle_frame=False,
834
+ ddp_rank=rank,
835
+ ddp_world_size=world_size,
836
+ batch_size_hint=int(batch_size),
837
+ depth_max_m=float(max_depth_m),
838
+ seed=dataset_seed,
839
+ )
840
+
841
+ scanetpp_ds = None
842
+ if scanetpp_enabled_for_train:
843
+ scanetpp_ds = ScannetppFisheyeDataset(
844
+ root=data_root_scanetpp,
845
+ scene_list_file=scanetpp_manifest,
846
+ min_frame_gap=1,
847
+ max_frame_gap=int(max_index_gap),
848
+ pair_max_translation_m=0.5,
849
+ shuffle_scene=True,
850
+ shuffle_frame=False,
851
+ ddp_rank=rank,
852
+ ddp_world_size=world_size,
853
+ batch_size_hint=int(batch_size),
854
+ depth_max_m=float(max_depth_m),
855
+ far_depth_invalid_m=float(scanetpp_fisheye_far_depth_invalid_m),
856
+ seed=dataset_seed,
857
+ )
858
+
859
+ hm3d_sampler = None
860
+ if hm3d_ds is not None and _ddp_is_enabled():
861
+ hm3d_sampler = DistributedSampler(hm3d_ds, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
862
+
863
+ re10k_num_workers = int(num_workers)
864
+ if re10k_ds is not None and bool(re10k_pseudo_depth_autogen) and re10k_num_workers > 0:
865
+ re10k_num_workers = 0
866
+ if is_main:
867
+ LOGGER.warning(
868
+ "RE10K pseudo-depth auto-generate enabled: force re10k dataloader num_workers=%d (requested=%d).",
869
+ int(re10k_num_workers),
870
+ int(num_workers),
871
+ )
872
+ if re10k_ds is not None and batch_size > 1 and re10k_num_workers > 0:
873
+ re10k_num_workers = 0
874
+ if is_main:
875
+ LOGGER.warning(
876
+ "Dynamic-resolution RE10K batching requires ordered same-resolution samples: force re10k dataloader num_workers=%d (requested=%d).",
877
+ int(re10k_num_workers),
878
+ int(num_workers),
879
+ )
880
+
881
+ highres_pin_memory = os.environ.get("HIGHRES_TRAIN_PIN_MEMORY", "0").strip().lower() in {"1", "true", "yes", "on"}
882
+ standard_pin_memory = os.environ.get("TRAIN_PIN_MEMORY", "1").strip().lower() in {"1", "true", "yes", "on"}
883
+ try:
884
+ train_prefetch_factor = max(1, int(os.environ.get("TRAIN_PREFETCH_FACTOR", "1").strip()))
885
+ except Exception:
886
+ train_prefetch_factor = 1
887
+ def _loader_worker_kwargs(worker_count: int, *, pin_memory: bool) -> dict[str, Any]:
888
+ kwargs: dict[str, Any] = {
889
+ "num_workers": int(worker_count),
890
+ "pin_memory": bool(pin_memory),
891
+ }
892
+ if int(worker_count) > 0:
893
+ kwargs["prefetch_factor"] = int(train_prefetch_factor)
894
+ return kwargs
895
+
896
+ re10k_dl = None
897
+ if re10k_ds is not None:
898
+ re10k_dl = DataLoader(
899
+ re10k_ds,
900
+ batch_size=None,
901
+ **_loader_worker_kwargs(re10k_num_workers, pin_memory=standard_pin_memory),
902
+ collate_fn=re10k_passthrough,
903
+ )
904
+
905
+ hm3d_dl = None
906
+ if hm3d_ds is not None:
907
+ hm3d_dl = DataLoader(
908
+ hm3d_ds,
909
+ batch_size=batch_size,
910
+ shuffle=(hm3d_sampler is None),
911
+ sampler=hm3d_sampler,
912
+ **_loader_worker_kwargs(num_workers, pin_memory=highres_pin_memory),
913
+ collate_fn=panogs_collate,
914
+ )
915
+
916
+ sim_dl = None
917
+ if sim_ds is not None:
918
+ sim_dl = DataLoader(
919
+ sim_ds,
920
+ batch_size=batch_size,
921
+ **_loader_worker_kwargs(num_workers, pin_memory=highres_pin_memory),
922
+ collate_fn=panogs_collate,
923
+ )
924
+
925
+ wildrgbd_dl = None
926
+ if wildrgbd_ds is not None:
927
+ wildrgbd_dl = DataLoader(
928
+ wildrgbd_ds,
929
+ batch_size=batch_size,
930
+ **_loader_worker_kwargs(num_workers, pin_memory=standard_pin_memory),
931
+ collate_fn=wildrgbd_collate,
932
+ )
933
+
934
+ dl3dv_dl = None
935
+ if dl3dv_ds is not None:
936
+ dl3dv_dl = DataLoader(
937
+ dl3dv_ds,
938
+ batch_size=None,
939
+ **_loader_worker_kwargs(num_workers, pin_memory=standard_pin_memory),
940
+ collate_fn=re10k_passthrough,
941
+ )
942
+
943
+ scanetpp_dl = None
944
+ if scanetpp_ds is not None:
945
+ scanetpp_dl = DataLoader(
946
+ scanetpp_ds,
947
+ batch_size=None,
948
+ **_loader_worker_kwargs(num_workers, pin_memory=highres_pin_memory),
949
+ collate_fn=scannetpp_fisheye_passthrough,
950
+ )
951
+
952
+ candidate_datasets: dict[str, Any] = {}
953
+ candidate_dataloaders: dict[str, DataLoader] = {}
954
+ candidate_weights: dict[str, float] = {}
955
+ if re10k_ds is not None and re10k_dl is not None:
956
+ candidate_datasets["re10k"] = re10k_ds
957
+ candidate_dataloaders["re10k"] = re10k_dl
958
+ candidate_weights["re10k"] = float(dataset_weight_re10k)
959
+ if hm3d_ds is not None and hm3d_dl is not None:
960
+ candidate_datasets["hm3d"] = hm3d_ds
961
+ candidate_dataloaders["hm3d"] = hm3d_dl
962
+ candidate_weights["hm3d"] = float(dataset_weight_hm3d)
963
+ if sim_ds is not None and sim_dl is not None:
964
+ candidate_datasets["sim"] = sim_ds
965
+ candidate_dataloaders["sim"] = sim_dl
966
+ candidate_weights["sim"] = float(dataset_weight_sim)
967
+ if wildrgbd_ds is not None and wildrgbd_dl is not None:
968
+ candidate_datasets["wildrgbd"] = wildrgbd_ds
969
+ candidate_dataloaders["wildrgbd"] = wildrgbd_dl
970
+ candidate_weights["wildrgbd"] = float(dataset_weight_wildrgbd)
971
+ if dl3dv_ds is not None and dl3dv_dl is not None:
972
+ candidate_datasets["dl3dv"] = dl3dv_ds
973
+ candidate_dataloaders["dl3dv"] = dl3dv_dl
974
+ candidate_weights["dl3dv"] = float(dataset_weight_dl3dv)
975
+ if scanetpp_ds is not None and scanetpp_dl is not None:
976
+ candidate_datasets["scanetpp_fisheye"] = scanetpp_ds
977
+ candidate_dataloaders["scanetpp_fisheye"] = scanetpp_dl
978
+ candidate_weights["scanetpp_fisheye"] = float(dataset_weight_scanetpp)
979
+
980
+ datasets: dict[str, Any] = {}
981
+ dataloaders: dict[str, DataLoader] = {}
982
+ sampling: dict[str, float] = {}
983
+ for name, w in candidate_weights.items():
984
+ if float(w) > 0.0:
985
+ datasets[name] = candidate_datasets[name]
986
+ dataloaders[name] = candidate_dataloaders[name]
987
+ sampling[name] = float(w)
988
+ elif is_main:
989
+ LOGGER.warning("Skip dataset in mixed sampler: %s (weight=%.4f <= 0)", name, float(w))
990
+
991
+ if len(datasets) == 0:
992
+ raise ValueError("No dataset selected for mixed sampler (all dataset weights <= 0).")
993
+ for name, dataset in datasets.items():
994
+ _maybe_set_dataset_epoch(dataset, 0)
995
+ iterators = {name: LazyDataLoaderIterator(dl) for name, dl in dataloaders.items()}
996
+ sampler_seed = int(seed + rank) if seed is not None else int(12345 + rank)
997
+ sampler = MixedDatasetSampler(
998
+ datasets=datasets,
999
+ weights=sampling,
1000
+ iterators=iterators,
1001
+ seed=sampler_seed,
1002
+ )
1003
+
1004
+ config = UnisharpFeatureConfig(
1005
+ unik3d_backbone=unik3d_backbone,
1006
+ unik3d_resolution_level=int(unik3d_resolution_level),
1007
+ initializer_stride=int(initializer_stride),
1008
+ initializer_scale_factor=float(initializer_scale_factor),
1009
+ detach_init_layer0_distance=bool(detach_init_layer0_distance),
1010
+ delta_rho_limit=float(delta_rho_limit),
1011
+ )
1012
+ setattr(config, "max_distance_m", float(max_depth_m))
1013
+
1014
+ model = UnisharpFeatureModel(config).to(dev).train()
1015
+
1016
+ if _ddp_is_enabled():
1017
+ model = DDP(
1018
+ model,
1019
+ device_ids=[dev.index],
1020
+ output_device=dev.index,
1021
+ find_unused_parameters=True,
1022
+ gradient_as_bucket_view=True,
1023
+ )
1024
+
1025
+ raw_model = model.module if isinstance(model, DDP) else model
1026
+ base_params, unik3d_encoder_params, unik3d_decoder_params = _build_optimizer_param_groups(raw_model)
1027
+ unik3d_params = unik3d_encoder_params + unik3d_decoder_params
1028
+ trainable_params = base_params + unik3d_params
1029
+ if len(trainable_params) == 0:
1030
+ raise RuntimeError("No trainable parameters found.")
1031
+ if len(unik3d_params) == 0:
1032
+ raise RuntimeError(
1033
+ "No UniK3D parameters were collected for the default unfreeze training path. "
1034
+ "Please check parameter naming."
1035
+ )
1036
+ depth_head_params = [p for p in raw_model.second_layer_depth_head.parameters() if p.requires_grad]
1037
+ if len(depth_head_params) == 0:
1038
+ raise RuntimeError("Depth heads have no trainable parameters; depth branch would not train.")
1039
+
1040
+ opt_groups: list[dict[str, Any]] = [{"params": base_params, "lr": float(lr0), "group_name": "base"}]
1041
+ if len(unik3d_encoder_params) > 0:
1042
+ opt_groups.append(
1043
+ {
1044
+ "params": unik3d_encoder_params,
1045
+ "lr": float(unik3d_encoder_lr0),
1046
+ "group_name": "unik3d_encoder",
1047
+ }
1048
+ )
1049
+ if len(unik3d_decoder_params) > 0:
1050
+ opt_groups.append(
1051
+ {
1052
+ "params": unik3d_decoder_params,
1053
+ "lr": float(unik3d_lr0),
1054
+ "group_name": "unik3d_decoder",
1055
+ }
1056
+ )
1057
+ opt = torch.optim.Adam(opt_groups)
1058
+ if is_main:
1059
+ LOGGER.info(
1060
+ "Model ready: scratch heads, pretrained UniK3D, trainable_params=%d",
1061
+ _count_numel(trainable_params),
1062
+ )
1063
+ if dev.type == "cuda":
1064
+ scaler = torch.amp.GradScaler("cuda", enabled=True)
1065
+ else:
1066
+ scaler = torch.amp.GradScaler("cpu", enabled=False)
1067
+
1068
+ renderer = GSplatRenderer(
1069
+ color_space="sRGB",
1070
+ background_color="black",
1071
+ low_pass_filter_eps=float(render_low_pass_filter_eps),
1072
+ ).to(dev)
1073
+
1074
+ loss_w = UnisharpLossWeights(
1075
+ lambda_color=float(lambda_color),
1076
+ lambda_alpha=float(lambda_alpha),
1077
+ lambda_percep=float(lambda_percep),
1078
+ lambda_depth=float(lambda_depth),
1079
+ lambda_tv=float(lambda_tv),
1080
+ lambda_grad=float(lambda_grad),
1081
+ lambda_grad_img=float(lambda_grad_img),
1082
+ lambda_edge_rgb=float(lambda_edge_rgb),
1083
+ lambda_delta=float(lambda_delta),
1084
+ lambda_delta_rho=float(lambda_delta_rho),
1085
+ lambda_splat=float(lambda_splat),
1086
+ lambda_edge_splat=float(lambda_edge_splat),
1087
+ lambda_grid=float(lambda_grid),
1088
+ )
1089
+ loss_fn = UnisharpLoss(
1090
+ weights=loss_w,
1091
+ delta_clip=float(delta_clip),
1092
+ raw_delta_clip=float(raw_delta_clip),
1093
+ raw_delta_rho_clip=float(raw_delta_rho_clip),
1094
+ alpha_tail_min=float(alpha_tail_min),
1095
+ alpha_tail_weight=float(alpha_tail_weight),
1096
+ splat_sigma_min=float(splat_sigma_min),
1097
+ splat_sigma_max=float(splat_sigma_max),
1098
+ edge_splat_sigma_max=float(edge_splat_sigma_max),
1099
+ depth_edge_log_threshold=float(depth_edge_log_threshold),
1100
+ depth_edge_dilate_px=int(depth_edge_dilate_px),
1101
+ ).to(dev)
1102
+ loss_fn.SUPERVISION_MAX_DEPTH_M = float(max_depth_m)
1103
+
1104
+ if is_main:
1105
+ config_dict = {
1106
+ "max_depth_m": float(max_depth_m),
1107
+ "sim_far_depth_invalid_m": float(sim_far_depth_invalid_m),
1108
+ "sim_far_depth_invalid_max_frac": float(sim_far_depth_invalid_max_frac),
1109
+ "re10k_pseudo_far_depth_invalid_m": float(re10k_pseudo_far_depth_invalid_m),
1110
+ "scanetpp_fisheye_far_depth_invalid_m": float(scanetpp_fisheye_far_depth_invalid_m),
1111
+ "render_low_pass_filter_eps": float(render_low_pass_filter_eps),
1112
+ }
1113
+ (out_dir / "config.json").write_text(
1114
+ json.dumps(config_dict, ensure_ascii=False, indent=2, sort_keys=True) + "\n",
1115
+ encoding="utf-8",
1116
+ )
1117
+
1118
+ loss_csv = out_dir / "losses.csv"
1119
+ loss_csv_fields = [
1120
+ "loss",
1121
+ "src_loss",
1122
+ "tgt_loss",
1123
+ "dataset",
1124
+ ]
1125
+ if is_main:
1126
+ with loss_csv.open("w", newline="") as f:
1127
+ csv.DictWriter(f, fieldnames=loss_csv_fields).writeheader()
1128
+
1129
+ if is_main:
1130
+ LOGGER.info("Training loop started.")
1131
+
1132
+ from unisharp.cli.unified_trainer import UnifiedTrainer
1133
+
1134
+ trainer = UnifiedTrainer(
1135
+ model=model,
1136
+ renderer=renderer,
1137
+ loss_fn=loss_fn,
1138
+ device=dev,
1139
+ max_depth_m=float(max_depth_m),
1140
+ sim_far_depth_invalid_m=float(sim_far_depth_invalid_m),
1141
+ re10k_pseudo_far_depth_invalid_m=float(re10k_pseudo_far_depth_invalid_m),
1142
+ scanetpp_fisheye_far_depth_invalid_m=float(scanetpp_fisheye_far_depth_invalid_m),
1143
+ aux_ray_loss_weight=float(lambda_aux_ray),
1144
+ aux_depth_scale_loss_weight=float(lambda_aux_depth_scale),
1145
+ aux_depth2_scale_loss_weight=float(lambda_aux_depth2_scale),
1146
+ target_mask_erode_px=int(target_mask_erode_px),
1147
+ )
1148
+ skip_forward_oom = _env_flag("TRAIN_SKIP_FORWARD_OOM", default=True)
1149
+
1150
+ dataset_epochs: dict[str, int] = {name: 0 for name in dataloaders.keys()}
1151
+ dataset_samplers: dict[str, DistributedSampler | None] = {"hm3d": hm3d_sampler}
1152
+
1153
+ for step in range(1, steps + 1):
1154
+ lr = warmup_cosine_lr(step, warmup, steps, lr0, lr1)
1155
+ lr_unik3d_encoder = warmup_cosine_lr(step, warmup, steps, unik3d_encoder_lr0, unik3d_encoder_lr1)
1156
+ lr_unik3d_decoder = warmup_cosine_lr(step, warmup, steps, unik3d_lr0, unik3d_lr1)
1157
+ for g in opt.param_groups:
1158
+ if g.get("group_name") == "unik3d_encoder":
1159
+ g["lr"] = lr_unik3d_encoder
1160
+ elif g.get("group_name") == "unik3d_decoder":
1161
+ g["lr"] = lr_unik3d_decoder
1162
+ else:
1163
+ g["lr"] = lr
1164
+
1165
+ if _ddp_is_enabled():
1166
+ batch = None
1167
+ available_dataset_names = list(dataloaders.keys())
1168
+ dataset_name = ""
1169
+ for _dataset_attempt in range(max(1, len(dataloaders))):
1170
+ dataset_name = _ddp_broadcast_str(
1171
+ sampler.choose_dataset_name(available_dataset_names) if is_main else "",
1172
+ is_main=is_main,
1173
+ )
1174
+
1175
+ local_exhausted = False
1176
+ try:
1177
+ batch = sampler.next_batch(dataset_name)
1178
+ except StopIteration:
1179
+ local_exhausted = True
1180
+
1181
+ exhausted_any = _ddp_any_bool(local_exhausted, device=dev)
1182
+ if exhausted_any:
1183
+ dataset_epochs[dataset_name] = dataset_epochs.get(dataset_name, 0) + 1
1184
+ ds_sampler = dataset_samplers.get(dataset_name, None)
1185
+ if ds_sampler is not None:
1186
+ ds_sampler.set_epoch(dataset_epochs[dataset_name])
1187
+ _maybe_set_dataset_epoch(datasets[dataset_name], dataset_epochs[dataset_name])
1188
+ iterators[dataset_name] = iter(dataloaders[dataset_name])
1189
+ sampler.iterators = iterators
1190
+ batch = None
1191
+
1192
+ local_exhausted = False
1193
+ try:
1194
+ batch = sampler.next_batch(dataset_name)
1195
+ except StopIteration:
1196
+ local_exhausted = True
1197
+ exhausted_any = _ddp_any_bool(local_exhausted, device=dev)
1198
+
1199
+ if not exhausted_any:
1200
+ break
1201
+
1202
+ batch = None
1203
+ available_dataset_names = [name for name in available_dataset_names if name != dataset_name]
1204
+ if len(available_dataset_names) == 0:
1205
+ break
1206
+ if batch is None:
1207
+ raise RuntimeError(f"Failed to fetch synchronized DDP batch for dataset={dataset_name}")
1208
+ else:
1209
+ try:
1210
+ dataset_name, batch = sampler.sample()
1211
+ except StopIteration as e:
1212
+ msg = str(e)
1213
+ exhausted_name = None
1214
+ if msg.startswith("Dataset ") and msg.endswith(" exhausted"):
1215
+ exhausted_name = msg[len("Dataset ") : -len(" exhausted")]
1216
+ if exhausted_name is None or exhausted_name not in dataloaders:
1217
+ raise
1218
+ dataset_epochs[exhausted_name] = dataset_epochs.get(exhausted_name, 0) + 1
1219
+ ds_sampler = dataset_samplers.get(exhausted_name, None)
1220
+ if ds_sampler is not None:
1221
+ ds_sampler.set_epoch(dataset_epochs[exhausted_name])
1222
+ _maybe_set_dataset_epoch(datasets[exhausted_name], dataset_epochs[exhausted_name])
1223
+ iterators[exhausted_name] = iter(dataloaders[exhausted_name])
1224
+ sampler.iterators = iterators
1225
+ dataset_name, batch = sampler.sample()
1226
+
1227
+ batch = _resize_training_batch_to_multiple(batch, int(train_resize_multiple))
1228
+
1229
+ opt.zero_grad(set_to_none=True)
1230
+
1231
+ autocast_enabled = dev.type == "cuda"
1232
+ if autocast_enabled and torch.cuda.is_bf16_supported():
1233
+ autocast_dtype = torch.bfloat16
1234
+ else:
1235
+ autocast_dtype = torch.float16 if autocast_enabled else torch.bfloat16
1236
+
1237
+ need_vis = bool(is_main and vis_every > 0 and (step % vis_every == 0))
1238
+ result: dict[str, Any] | None = None
1239
+ forward_oom_local = False
1240
+ forward_oom_error = ""
1241
+ try:
1242
+ with torch.autocast(device_type=dev.type, enabled=autocast_enabled, dtype=autocast_dtype):
1243
+ result = trainer.process_batch(
1244
+ batch,
1245
+ dataset_name,
1246
+ step,
1247
+ need_vis=need_vis,
1248
+ )
1249
+ except Exception as e:
1250
+ if skip_forward_oom and _is_oom_exception(e):
1251
+ forward_oom_local = True
1252
+ forward_oom_error = str(e)
1253
+ opt.zero_grad(set_to_none=True)
1254
+ if dev.type == "cuda":
1255
+ torch.cuda.empty_cache()
1256
+ else:
1257
+ raise
1258
+
1259
+ forward_oom_any = _ddp_any_bool(forward_oom_local, device=dev)
1260
+ if forward_oom_any:
1261
+ opt.zero_grad(set_to_none=True)
1262
+ if result is not None:
1263
+ del result
1264
+ result = None
1265
+ if dev.type == "cuda":
1266
+ torch.cuda.empty_cache()
1267
+ if is_main:
1268
+ LOGGER.error(
1269
+ "Skipping optimizer step=%d because forward OOM occurred on at least one rank | dataset=%s",
1270
+ int(step),
1271
+ str(dataset_name),
1272
+ )
1273
+ continue
1274
+
1275
+ if result is None:
1276
+ raise RuntimeError(f"Forward returned no result for dataset={dataset_name} step={step}")
1277
+ total_loss = result["total"]
1278
+ local_nonfinite_loss = not bool(torch.isfinite(total_loss.detach()).item())
1279
+ nonfinite_loss_any = _ddp_any_bool(local_nonfinite_loss, device=dev)
1280
+ if nonfinite_loss_any:
1281
+ opt.zero_grad(set_to_none=True)
1282
+ if is_main:
1283
+ LOGGER.error(
1284
+ "Skipping optimizer step=%d because loss is non-finite on at least one rank | dataset=%s",
1285
+ int(step),
1286
+ str(dataset_name),
1287
+ )
1288
+ continue
1289
+
1290
+ try:
1291
+ scaler.scale(total_loss).backward()
1292
+ except Exception as e:
1293
+ raise
1294
+ try:
1295
+ scaler.unscale_(opt)
1296
+ grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=float(grad_clip_norm))
1297
+ except Exception as e:
1298
+ LOGGER.error("Gradient unscale/clip failed at step=%d: %s", int(step), str(e))
1299
+ raise
1300
+ grad_norm_value = float(grad_norm.detach().to(dtype=torch.float32).cpu().item()) if torch.is_tensor(grad_norm) else float(grad_norm)
1301
+ local_nonfinite_grad = not np.isfinite(grad_norm_value)
1302
+ nonfinite_grad_any = _ddp_any_bool(local_nonfinite_grad, device=dev)
1303
+ if nonfinite_grad_any:
1304
+ opt.zero_grad(set_to_none=True)
1305
+ scaler.update()
1306
+ if is_main:
1307
+ LOGGER.error(
1308
+ "Skipping optimizer step=%d because grad norm is non-finite on at least one rank | dataset=%s | local_grad_norm=%s",
1309
+ int(step),
1310
+ str(dataset_name),
1311
+ str(grad_norm_value),
1312
+ )
1313
+ continue
1314
+ local_huge_grad = bool(float(max_step_grad_norm) > 0.0 and grad_norm_value > float(max_step_grad_norm))
1315
+ huge_grad_any = _ddp_any_bool(local_huge_grad, device=dev)
1316
+ if huge_grad_any:
1317
+ opt.zero_grad(set_to_none=True)
1318
+ scaler.update()
1319
+ if is_main:
1320
+ LOGGER.error(
1321
+ "Skipping optimizer step=%d because grad norm exceeded max-step-grad-norm on at least one rank | dataset=%s | local_grad_norm=%.6g | threshold=%.6g",
1322
+ int(step),
1323
+ str(dataset_name),
1324
+ float(grad_norm_value),
1325
+ float(max_step_grad_norm),
1326
+ )
1327
+ continue
1328
+ scaler.step(opt)
1329
+ scaler.update()
1330
+
1331
+ if log_every > 0 and step % log_every == 0:
1332
+ loss_v = float(_ddp_mean(total_loss.detach()).item())
1333
+ src_v = float(_ddp_mean(result["src"].detach()).item())
1334
+ tgt_v = float(_ddp_mean(result["tgt"].detach()).item())
1335
+ row = {
1336
+ "loss": loss_v,
1337
+ "src_loss": src_v,
1338
+ "tgt_loss": tgt_v,
1339
+ "dataset": str(dataset_name),
1340
+ }
1341
+ if is_main:
1342
+ LOGGER.info(
1343
+ "step=%d dataset=%s loss=%.6f src_loss=%.6f tgt_loss=%.6f",
1344
+ step,
1345
+ dataset_name,
1346
+ loss_v,
1347
+ src_v,
1348
+ tgt_v,
1349
+ )
1350
+ row_csv = dict(row)
1351
+ for k in ("loss", "src_loss", "tgt_loss"):
1352
+ v = float(row_csv.get(k, float("nan")))
1353
+ row_csv[k] = "" if not np.isfinite(v) else f"{v:.4f}"
1354
+ with loss_csv.open("a", newline="") as f:
1355
+ csv.DictWriter(f, fieldnames=loss_csv_fields).writerow(row_csv)
1356
+
1357
+ if need_vis and result.get("vis_payload"):
1358
+ vis = result["vis_payload"]
1359
+ _save_train_vis(
1360
+ out_dir,
1361
+ step,
1362
+ vis["src_gt"],
1363
+ vis["src_pred"],
1364
+ vis["src_alpha"],
1365
+ vis["tgt_gt"],
1366
+ vis["tgt_pred"],
1367
+ vis["tgt_alpha"],
1368
+ src_gt_depth=vis.get("src_gt_depth"),
1369
+ tgt_gt_depth=vis.get("tgt_gt_depth"),
1370
+ src_pred_depth=vis.get("src_pred_depth"),
1371
+ tgt_pred_depth=vis.get("tgt_pred_depth"),
1372
+ src_unik3d_depth=vis.get("src_unik3d_depth"),
1373
+ tgt_unik3d_depth=vis.get("tgt_unik3d_depth"),
1374
+ dataset_name=vis.get("dataset_name"),
1375
+ scene=vis.get("scene"),
1376
+ src_idx=vis.get("src_idx"),
1377
+ tgt_idx=vis.get("tgt_idx"),
1378
+ src_pose_w2c=vis.get("src_pose_w2c"),
1379
+ tgt_pose_w2c=vis.get("tgt_pose_w2c"),
1380
+ src_metric_mask=vis.get("src_metric_mask"),
1381
+ tgt_metric_mask=vis.get("tgt_metric_mask"),
1382
+ src_cube_gt_u8=vis.get("src_cube_gt_u8"),
1383
+ src_cube_pred_linear=vis.get("src_cube_pred_linear"),
1384
+ src_cube_alpha=vis.get("src_cube_alpha"),
1385
+ tgt_cube_gt_u8=vis.get("tgt_cube_gt_u8"),
1386
+ tgt_cube_pred_linear=vis.get("tgt_cube_pred_linear"),
1387
+ tgt_cube_alpha=vis.get("tgt_cube_alpha"),
1388
+ )
1389
+
1390
+ if need_vis:
1391
+ if "vis" in locals():
1392
+ del vis
1393
+ if dev.type == "cuda":
1394
+ torch.cuda.empty_cache()
1395
+ del result
1396
+ del total_loss
1397
+ batch = None
1398
+
1399
+ if is_main and (save_every > 0) and (step % save_every == 0):
1400
+ path = out_dir / f"step_{step:07d}.pt"
1401
+ raw_model.save_checkpoint(str(path), step, opt)
1402
+ LOGGER.info("💾 Saved checkpoint: %s", str(path))
1403
+
1404
+
1405
+ if _ddp_is_enabled():
1406
+ _ddp_barrier(dev)
1407
+ dist.destroy_process_group()
1408
+
1409
+ if is_main:
1410
+ LOGGER.info("✅ Training completed!")
unisharp/cli/train_utils.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def quat_mul_wxyz(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
9
+ w1, x1, y1, z1 = q1.unbind(dim=-1)
10
+ w2, x2, y2, z2 = q2.unbind(dim=-1)
11
+ w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
12
+ x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
13
+ y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
14
+ z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
15
+ return torch.stack([w, x, y, z], dim=-1)
16
+
17
+
18
+ def rotmat_to_quat_wxyz(Rm: torch.Tensor) -> torch.Tensor:
19
+ m00, m01, m02 = Rm[0, 0], Rm[0, 1], Rm[0, 2]
20
+ m10, m11, m12 = Rm[1, 0], Rm[1, 1], Rm[1, 2]
21
+ m20, m21, m22 = Rm[2, 0], Rm[2, 1], Rm[2, 2]
22
+ tr = m00 + m11 + m22
23
+ if tr > 0.0:
24
+ s = torch.sqrt(tr + 1.0) * 2.0
25
+ w = 0.25 * s
26
+ x = (m21 - m12) / s
27
+ y = (m02 - m20) / s
28
+ z = (m10 - m01) / s
29
+ elif (m00 > m11) and (m00 > m22):
30
+ s = torch.sqrt(1.0 + m00 - m11 - m22) * 2.0
31
+ w = (m21 - m12) / s
32
+ x = 0.25 * s
33
+ y = (m01 + m10) / s
34
+ z = (m02 + m20) / s
35
+ elif m11 > m22:
36
+ s = torch.sqrt(1.0 + m11 - m00 - m22) * 2.0
37
+ w = (m02 - m20) / s
38
+ x = (m01 + m10) / s
39
+ y = 0.25 * s
40
+ z = (m12 + m21) / s
41
+ else:
42
+ s = torch.sqrt(1.0 + m22 - m00 - m11) * 2.0
43
+ w = (m10 - m01) / s
44
+ x = (m02 + m20) / s
45
+ y = (m12 + m21) / s
46
+ z = 0.25 * s
47
+ q = torch.stack([w, x, y, z])
48
+ return q / q.norm().clamp(min=1e-8)
49
+
50
+
51
+ def to_k4(k3: torch.Tensor) -> torch.Tensor:
52
+ b = k3.shape[0]
53
+ out = torch.eye(4, dtype=k3.dtype, device=k3.device).unsqueeze(0).repeat(b, 1, 1)
54
+ out[:, :3, :3] = k3
55
+ return out
56
+
57
+
58
+ def warmup_cosine_lr(step: int, warmup: int, total: int, lr0: float, lr1: float) -> float:
59
+ if step <= warmup:
60
+ return lr0 * float(step) / float(max(1, warmup))
61
+ t = (step - warmup) / float(max(1, total - warmup))
62
+ cos = 0.5 * (1 + np.cos(np.pi * t))
63
+ return lr1 + (lr0 - lr1) * cos
64
+
65
+
66
+ @torch.no_grad()
67
+ def compute_frustum_mask(
68
+ depth: torch.Tensor,
69
+ tgt_w2c: torch.Tensor,
70
+ src_w2c: torch.Tensor,
71
+ src_k3: torch.Tensor,
72
+ tgt_k3: torch.Tensor,
73
+ img_h: int,
74
+ img_w: int,
75
+ source_img_h: int | None = None,
76
+ source_img_w: int | None = None,
77
+ depth_min: float = 0.05,
78
+ margin: float = 0.05,
79
+ ) -> torch.Tensor:
80
+ dev = depth.device
81
+ f32 = torch.float32
82
+ src_h = int(img_h if source_img_h is None else source_img_h)
83
+ src_w = int(img_w if source_img_w is None else source_img_w)
84
+
85
+ d = depth[0, 0].to(f32)
86
+ valid = d > depth_min
87
+
88
+ vy, vx = torch.meshgrid(
89
+ torch.arange(img_h, device=dev, dtype=f32),
90
+ torch.arange(img_w, device=dev, dtype=f32),
91
+ indexing="ij",
92
+ )
93
+
94
+ fx_t = tgt_k3[0, 0, 0].to(f32)
95
+ fy_t = tgt_k3[0, 1, 1].to(f32)
96
+ cx_t = tgt_k3[0, 0, 2].to(f32)
97
+ cy_t = tgt_k3[0, 1, 2].to(f32)
98
+ X_t = (vx - cx_t) / fx_t * d
99
+ Y_t = (vy - cy_t) / fy_t * d
100
+ Z_t = d
101
+ pts_t = torch.stack([X_t, Y_t, Z_t], dim=-1).reshape(-1, 3)
102
+
103
+ c2w_t = torch.linalg.inv(tgt_w2c[0].to(f32))
104
+ pts_w = pts_t @ c2w_t[:3, :3].T + c2w_t[:3, 3][None, :]
105
+
106
+ w2c_s = src_w2c[0].to(f32)
107
+ pts_s = pts_w @ w2c_s[:3, :3].T + w2c_s[:3, 3][None, :]
108
+
109
+ Z_s = pts_s[:, 2].clamp(min=1e-4)
110
+ fx_s = src_k3[0, 0, 0].to(f32)
111
+ fy_s = src_k3[0, 1, 1].to(f32)
112
+ cx_s = src_k3[0, 0, 2].to(f32)
113
+ cy_s = src_k3[0, 1, 2].to(f32)
114
+ u_s = pts_s[:, 0] / Z_s * fx_s + cx_s
115
+ v_s = pts_s[:, 1] / Z_s * fy_s + cy_s
116
+
117
+ half_w = (src_w - 1) * 0.5
118
+ half_h = (src_h - 1) * 0.5
119
+ x_ndc = (u_s - half_w) / half_w
120
+ y_ndc = (v_s - half_h) / half_h
121
+
122
+ in_frust = (
123
+ (x_ndc.abs() <= 1.0 + margin)
124
+ & (y_ndc.abs() <= 1.0 + margin)
125
+ & (pts_s[:, 2] > 0)
126
+ )
127
+
128
+ mask = in_frust.reshape(img_h, img_w).float()
129
+ mask = mask * valid.float()
130
+ return mask[None, None]
unisharp/cli/unified_trainer.py ADDED
@@ -0,0 +1,1966 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import os
5
+ from typing import Any, Callable
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from unisharp.utils.gsplat import GSplatRenderer
12
+ from unisharp.losses import UnisharpLoss
13
+ from unisharp.utils.camera_utils import (
14
+ transform_gaussians_to_world,
15
+ to_k4,
16
+ compute_frustum_mask,
17
+ )
18
+ from unisharp.utils.fisheye_geer import (
19
+ compute_fisheye624_frustum_mask,
20
+ render_gaussians_fisheye624,
21
+ )
22
+ from unisharp.utils.camera_projection import cubemap_face_cameras, build_extrinsics_w2c, view_frustum_mask_cubemap_union
23
+ from unisharp.utils.pano import Cube2Equirec, get_pinhole_intrinsics_4x4
24
+ from unisharp import DEFAULT_MAX_DEPTH_M
25
+ from unisharp.utils.pixel_convention import integer_pixel_center_grid
26
+
27
+
28
+ @dataclass
29
+ class _ModeStrategy:
30
+
31
+ batch_size: int
32
+ gaussians: Any
33
+ make_world_gaussians: Callable[[int, Any], Any]
34
+ make_sample: Callable[[int, Any, bool], dict[str, Any]]
35
+ collect_all_vis: bool = False
36
+
37
+
38
+ class UnifiedTrainer:
39
+
40
+ def __init__(
41
+ self,
42
+ model: nn.Module,
43
+ renderer: GSplatRenderer,
44
+ loss_fn: UnisharpLoss,
45
+ device: torch.device,
46
+ enable_tgt_unik3d_vis: bool = True,
47
+ max_depth_m: float = DEFAULT_MAX_DEPTH_M,
48
+ sim_far_depth_invalid_m: float = 30.0,
49
+ re10k_pseudo_far_depth_invalid_m: float = 30.0,
50
+ scanetpp_fisheye_far_depth_invalid_m: float = 30.0,
51
+ aux_ray_loss_weight: float = 3.0,
52
+ aux_depth_scale_loss_weight: float = 3.0,
53
+ aux_depth2_scale_loss_weight: float = 1.0,
54
+ target_mask_erode_px: int = 0,
55
+ ):
56
+ self.model = model
57
+ self.renderer = renderer
58
+ self.loss_fn = loss_fn
59
+ self.device = device
60
+ self.enable_tgt_unik3d_vis = bool(enable_tgt_unik3d_vis)
61
+ self.max_depth_m = float(max_depth_m)
62
+ self.sim_far_depth_invalid_m = float(sim_far_depth_invalid_m)
63
+ self.re10k_pseudo_far_depth_invalid_m = float(re10k_pseudo_far_depth_invalid_m)
64
+ self.scanetpp_fisheye_far_depth_invalid_m = float(scanetpp_fisheye_far_depth_invalid_m)
65
+ self.aux_ray_loss_weight = float(aux_ray_loss_weight)
66
+ self.aux_depth_scale_loss_weight = float(aux_depth_scale_loss_weight)
67
+ self.aux_depth2_scale_loss_weight = float(aux_depth2_scale_loss_weight)
68
+ self.target_mask_erode_px = max(int(target_mask_erode_px), 0)
69
+
70
+ @staticmethod
71
+ def _erode_supervision_mask(mask: torch.Tensor, radius_px: int, *, circular_h: bool = False) -> torch.Tensor:
72
+ radius = max(int(radius_px), 0)
73
+ if radius <= 0:
74
+ return mask
75
+ if not torch.is_tensor(mask):
76
+ return mask
77
+ m = mask.to(dtype=torch.float32).clamp(0.0, 1.0)
78
+ if m.ndim == 3:
79
+ m = m.unsqueeze(1)
80
+ invalid = 1.0 - m
81
+ kernel = 2 * radius + 1
82
+ if bool(circular_h):
83
+ invalid = F.pad(invalid, (radius, radius, 0, 0), mode="circular")
84
+ invalid = F.pad(invalid, (0, 0, radius, radius), mode="constant", value=0.0)
85
+ dilated_invalid = F.max_pool2d(invalid, kernel_size=kernel, stride=1)
86
+ else:
87
+ dilated_invalid = F.max_pool2d(invalid, kernel_size=kernel, stride=1, padding=radius)
88
+ return (m * (1.0 - dilated_invalid)).to(device=mask.device, dtype=mask.dtype)
89
+
90
+ def _aux_ray_losses(
91
+ self,
92
+ *,
93
+ pred_rays: torch.Tensor | None,
94
+ gt_rays: torch.Tensor | None,
95
+ mask: torch.Tensor | None,
96
+ pred_distance: torch.Tensor | None = None,
97
+ pred_distance2: torch.Tensor | None = None,
98
+ gt_distance: torch.Tensor | None = None,
99
+ gt_distance2: torch.Tensor | None = None,
100
+ depth_mask: torch.Tensor | None = None,
101
+ depth_mask2: torch.Tensor | None = None,
102
+ ) -> dict[str, torch.Tensor]:
103
+ out: dict[str, torch.Tensor] = {}
104
+ if torch.is_tensor(pred_rays) and torch.is_tensor(gt_rays) and self.aux_ray_loss_weight > 0.0:
105
+ out["unik3d_ray"] = self.aux_ray_loss_weight * self._unik3d_polar_ray_loss(
106
+ pred_rays,
107
+ gt_rays,
108
+ mask,
109
+ )
110
+ if torch.is_tensor(pred_distance) and torch.is_tensor(gt_distance):
111
+ out["unik3d_depth_scale"] = self.aux_depth_scale_loss_weight * self._unik3d_scale_depth_loss(
112
+ pred_distance,
113
+ gt_distance,
114
+ depth_mask if torch.is_tensor(depth_mask) else mask,
115
+ )
116
+ depth2_target = gt_distance2 if torch.is_tensor(gt_distance2) else gt_distance
117
+ if torch.is_tensor(pred_distance2) and torch.is_tensor(depth2_target):
118
+ depth2_mask = depth_mask2 if torch.is_tensor(depth_mask2) else depth_mask
119
+ out["unik3d_depth2_scale"] = self.aux_depth2_scale_loss_weight * self._unik3d_scale_depth_loss(
120
+ pred_distance2,
121
+ depth2_target,
122
+ depth2_mask if torch.is_tensor(depth2_mask) else mask,
123
+ )
124
+ return out
125
+
126
+ DEPTH_SUPERVISION_MAX_M: float = DEFAULT_MAX_DEPTH_M
127
+
128
+ def _distance_init_cap_for_dataset(self, dataset_name: str) -> float | None:
129
+ name = str(dataset_name).lower()
130
+ if name == "re10k" and self.re10k_pseudo_far_depth_invalid_m > 0.0:
131
+ return self.re10k_pseudo_far_depth_invalid_m
132
+ if name == "sim" and self.sim_far_depth_invalid_m > 0.0:
133
+ return self.sim_far_depth_invalid_m
134
+ if name in {"scanetpp_fisheye", "scannetpp_fisheye"} and self.scanetpp_fisheye_far_depth_invalid_m > 0.0:
135
+ return self.scanetpp_fisheye_far_depth_invalid_m
136
+ return None
137
+
138
+ @staticmethod
139
+ def _unik3d_polar_ray_loss(
140
+ pred_rays: torch.Tensor | None,
141
+ gt_rays: torch.Tensor | None,
142
+ mask: torch.Tensor | None,
143
+ ) -> torch.Tensor:
144
+ if not torch.is_tensor(pred_rays) or not torch.is_tensor(gt_rays):
145
+ device = pred_rays.device if torch.is_tensor(pred_rays) else torch.device("cpu")
146
+ return torch.zeros((), device=device, dtype=torch.float32)
147
+ pred = pred_rays.to(dtype=torch.float32)
148
+ gt = gt_rays.to(device=pred.device, dtype=torch.float32)
149
+ if pred.ndim == 3:
150
+ pred = pred.unsqueeze(0)
151
+ if gt.ndim == 3:
152
+ gt = gt.unsqueeze(0)
153
+ if tuple(pred.shape) != tuple(gt.shape):
154
+ gt = F.interpolate(gt, size=pred.shape[-2:], mode="bilinear", align_corners=False)
155
+ gt = gt / torch.norm(gt, dim=1, keepdim=True).clamp(min=1e-5)
156
+ pred = pred / torch.norm(pred, dim=1, keepdim=True).clamp(min=1e-5)
157
+ gt = gt / torch.norm(gt, dim=1, keepdim=True).clamp(min=1e-5)
158
+
159
+ px, py, pz = pred.unbind(dim=1)
160
+ gx, gy, gz = gt.unbind(dim=1)
161
+ polar_pred = torch.acos(pz.clamp(min=-0.99999, max=0.99999))
162
+ polar_gt = torch.acos(gz.clamp(min=-0.99999, max=0.99999))
163
+ az_pred = torch.atan2(py, px.abs().clamp(min=1e-5) * (2.0 * (px > 0).to(px.dtype) - 1.0))
164
+ az_gt = torch.atan2(gy, gx.abs().clamp(min=1e-5) * (2.0 * (gx > 0).to(gx.dtype) - 1.0))
165
+ polar_error = (polar_pred - polar_gt).abs()
166
+ az_delta = az_pred - az_gt
167
+ az_error = torch.atan2(torch.sin(az_delta), torch.cos(az_delta)).abs()
168
+ quantile_weight = torch.ones_like(polar_error)
169
+ quantile_weight[(polar_gt > polar_pred) & (polar_gt > torch.pi / 2)] = 1.4
170
+ quantile_weight[(polar_gt <= polar_pred) & (polar_gt > torch.pi / 2)] = 0.6
171
+
172
+ if torch.is_tensor(mask):
173
+ m = mask.to(device=pred.device, dtype=torch.float32)
174
+ if m.ndim == 3:
175
+ m = m.unsqueeze(1)
176
+ if tuple(m.shape[-2:]) != tuple(pred.shape[-2:]):
177
+ m = F.interpolate(m, size=pred.shape[-2:], mode="nearest")
178
+ m = m[:, 0].clamp(0.0, 1.0)
179
+ else:
180
+ m = torch.ones_like(polar_error)
181
+ denom = m.sum(dim=(-1, -2), keepdim=False).clamp(min=1.0)
182
+ mean_polar = (polar_error * quantile_weight * m).sum(dim=(-1, -2)) / denom
183
+ mean_azimuth = (az_error * m).sum(dim=(-1, -2)) / denom
184
+ mean_error = (3.0 * mean_polar + mean_azimuth) / 4.0
185
+ return torch.sqrt(mean_error + 1e-4).mean()
186
+
187
+ @staticmethod
188
+ def _unik3d_scale_depth_loss(
189
+ pred_distance: torch.Tensor,
190
+ gt_distance: torch.Tensor,
191
+ mask: torch.Tensor | None,
192
+ ) -> torch.Tensor:
193
+ pred = UnifiedTrainer._as_b1hw_depth(pred_distance).to(dtype=torch.float32)
194
+ gt = UnifiedTrainer._as_b1hw_depth(gt_distance).to(device=pred.device, dtype=torch.float32)
195
+ if tuple(gt.shape[-2:]) != tuple(pred.shape[-2:]):
196
+ gt = F.interpolate(gt, size=pred.shape[-2:], mode="nearest")
197
+ valid = torch.isfinite(pred) & torch.isfinite(gt) & (pred > 0.0) & (gt > 0.0)
198
+ if torch.is_tensor(mask):
199
+ m = mask.to(device=pred.device)
200
+ if m.ndim == 3:
201
+ m = m.unsqueeze(1)
202
+ if tuple(m.shape[-2:]) != tuple(pred.shape[-2:]):
203
+ m = F.interpolate(m.to(dtype=torch.float32), size=pred.shape[-2:], mode="nearest")
204
+ valid = valid & (m[:, :1] > 0.5)
205
+ err = (gt.clamp(min=1e-4).log() - pred.clamp(min=1e-4).log()).abs()
206
+ err = torch.where(valid, err, torch.zeros_like(err))
207
+ denom = valid.to(dtype=err.dtype).sum(dim=(-2, -1)).clamp(min=1.0)
208
+ per_image = err.sum(dim=(-2, -1)) / denom
209
+ return torch.sqrt(per_image.clamp(min=0.0)).mean()
210
+
211
+ def _base_model(self) -> nn.Module:
212
+ return self.model.module if hasattr(self.model, "module") else self.model
213
+
214
+ def process_batch(
215
+ self,
216
+ batch: Any,
217
+ dataset_name: str,
218
+ step: int,
219
+ need_vis: bool = False,
220
+ ) -> dict[str, Any]:
221
+ if hasattr(batch, "src_rgb_u8") and hasattr(batch, "src_intrinsics"):
222
+ strategy = self._build_pinhole_strategy(
223
+ batch,
224
+ step,
225
+ need_vis=need_vis,
226
+ dataset_name=str(dataset_name),
227
+ )
228
+ elif hasattr(batch, "src_rgb_u8") and hasattr(batch, "src_camera_params"):
229
+ strategy = self._build_fisheye_strategy(
230
+ batch,
231
+ step,
232
+ need_vis=need_vis,
233
+ dataset_name=str(dataset_name),
234
+ )
235
+ elif hasattr(batch, "src_erp_rgb_u8") and hasattr(batch, "src_cube_depth_m"):
236
+ strategy = self._build_spherical_strategy(
237
+ batch,
238
+ step,
239
+ need_vis=need_vis,
240
+ dataset_name=str(dataset_name),
241
+ )
242
+ else:
243
+ raise ValueError(f"Unknown batch schema for dataset={dataset_name}")
244
+ return self._run_strategy_loop(
245
+ strategy,
246
+ need_vis=need_vis,
247
+ )
248
+
249
+ def _run_strategy_loop(
250
+ self,
251
+ strategy: _ModeStrategy,
252
+ need_vis: bool = False,
253
+ ) -> dict[str, Any]:
254
+ total_loss = torch.zeros((), device=self.device)
255
+ src_sum = torch.zeros((), device=self.device)
256
+ tgt_sum = torch.zeros((), device=self.device)
257
+ src_log_sum: dict[str, torch.Tensor] = {}
258
+ tgt_log_sum: dict[str, torch.Tensor] = {}
259
+ aux_log_sum: dict[str, torch.Tensor] = {}
260
+ vis_payload = None
261
+ vis_payloads: list[dict[str, Any]] = []
262
+
263
+ def _accumulate_loss_terms(term_specs: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
264
+ merged: dict[str, torch.Tensor] = {}
265
+ for spec in term_specs:
266
+ term_losses = self._compute_view_loss(**spec)
267
+ for k, v in term_losses.items():
268
+ merged[k] = merged.get(k, torch.zeros((), device=self.device)) + v
269
+ return merged
270
+
271
+ collect_all_vis = bool(getattr(strategy, "collect_all_vis", False))
272
+ for b in range(int(strategy.batch_size)):
273
+ g = strategy.gaussians
274
+ g_b = type(g)(
275
+ mean_vectors=g.mean_vectors[b : b + 1],
276
+ singular_values=g.singular_values[b : b + 1],
277
+ quaternions=g.quaternions[b : b + 1],
278
+ colors=g.colors[b : b + 1],
279
+ opacities=g.opacities[b : b + 1],
280
+ )
281
+ g_world = strategy.make_world_gaussians(b, g_b)
282
+ sample = strategy.make_sample(
283
+ b,
284
+ g_world,
285
+ bool(need_vis and (collect_all_vis or b == 0)),
286
+ )
287
+
288
+ if isinstance(sample.get("src_loss_terms", None), list):
289
+ src_losses = _accumulate_loss_terms(sample["src_loss_terms"])
290
+ else:
291
+ src_losses = self._compute_view_loss(
292
+ pred_rgb_linear=sample["src_pred_rgb_linear"],
293
+ pred_alpha=sample["src_pred_alpha"],
294
+ pred_depth_m=sample["src_pred_depth_m"],
295
+ pred_depth2_m=sample.get("src_pred_depth2_m", None),
296
+ gt_rgb_u8=sample["src_gt_rgb_u8"],
297
+ gt_depth_m=sample["src_gt_depth_m"],
298
+ mask=sample["src_mask"],
299
+ apply_color=bool(sample.get("src_apply_color", True)),
300
+ apply_alpha=bool(sample.get("src_apply_alpha", True)),
301
+ apply_depth=bool(sample.get("src_apply_depth", True)),
302
+ apply_percep=False,
303
+ apply_tv=True,
304
+ apply_grad=bool(sample.get("src_apply_grad", True)),
305
+ apply_grad_img=bool(sample.get("src_apply_grad_img", True)),
306
+ apply_splat=bool(sample.get("src_apply_splat", True)),
307
+ grad_img_circular_h=sample.get("src_grad_img_circular_h", None),
308
+ gaussian_scales=sample.get("gaussian_scales", None),
309
+ gaussian_quaternions=sample.get("gaussian_quaternions", None),
310
+ gaussian_angular_cell=sample.get("gaussian_angular_cell", None),
311
+ delta_xy=sample.get("delta_xy", None),
312
+ delta_rho=sample.get("delta_rho", None),
313
+ delta_grid=sample.get("delta_grid", None),
314
+ gaussian_mean_vectors=sample.get("gaussian_mean_vectors", None),
315
+ gaussian_base_mean_vectors=sample.get("gaussian_base_mean_vectors", None),
316
+ gaussian_opacities=sample.get("gaussian_opacities", None),
317
+ gauss_grid_shape=sample.get("gauss_grid_shape", None),
318
+ projected_scale_factor=sample.get("projected_scale_factor", None),
319
+ projection_model=sample.get("projection_model", None),
320
+ projection_intrinsics=sample.get("projection_intrinsics", None),
321
+ projection_camera_params=sample.get("projection_camera_params", None),
322
+ depth_mask=sample.get("src_depth_mask", None),
323
+ )
324
+ if isinstance(sample.get("src_extra_loss_terms", None), list):
325
+ extra_src_losses = _accumulate_loss_terms(sample["src_extra_loss_terms"])
326
+ for k, v in extra_src_losses.items():
327
+ src_losses[k] = src_losses.get(k, torch.zeros((), device=self.device)) + v
328
+ if isinstance(sample.get("tgt_loss_terms", None), list):
329
+ tgt_losses = _accumulate_loss_terms(sample["tgt_loss_terms"])
330
+ else:
331
+ tgt_losses = self._compute_view_loss(
332
+ pred_rgb_linear=sample["tgt_pred_rgb_linear"],
333
+ pred_alpha=sample["tgt_pred_alpha"],
334
+ pred_depth_m=sample["tgt_pred_depth_m"],
335
+ pred_depth2_m=sample.get("tgt_pred_depth2_m", None),
336
+ gt_rgb_u8=sample["tgt_gt_rgb_u8"],
337
+ gt_depth_m=sample["tgt_gt_depth_m"],
338
+ mask=sample["tgt_mask"],
339
+ apply_color=bool(sample.get("tgt_apply_color", True)),
340
+ apply_alpha=bool(sample.get("tgt_apply_alpha", True)),
341
+ apply_depth=bool(sample.get("tgt_apply_depth", True)),
342
+ apply_percep=bool(sample.get("tgt_apply_percep", False)),
343
+ apply_tv=False,
344
+ apply_grad=False,
345
+ apply_grad_img=bool(sample.get("tgt_apply_grad_img", True)),
346
+ apply_splat=bool(sample.get("tgt_apply_splat", False)),
347
+ grad_img_circular_h=sample.get("tgt_grad_img_circular_h", None),
348
+ gaussian_scales=None,
349
+ gaussian_quaternions=None,
350
+ delta_xy=None,
351
+ delta_rho=None,
352
+ gaussian_mean_vectors=None,
353
+ gaussian_base_mean_vectors=None,
354
+ gaussian_opacities=None,
355
+ gauss_grid_shape=None,
356
+ projected_scale_factor=sample.get("projected_scale_factor", None),
357
+ projection_model=sample.get("projection_model", None),
358
+ projection_intrinsics=sample.get("projection_intrinsics", None),
359
+ projection_camera_params=sample.get("projection_camera_params", None),
360
+ depth_mask=sample.get("tgt_depth_mask", None),
361
+ )
362
+ if isinstance(sample.get("tgt_extra_loss_terms", None), list):
363
+ extra_tgt_losses = _accumulate_loss_terms(sample["tgt_extra_loss_terms"])
364
+ for k, v in extra_tgt_losses.items():
365
+ tgt_losses[k] = tgt_losses.get(k, torch.zeros((), device=self.device)) + v
366
+
367
+ aux_total = torch.zeros((), device=self.device)
368
+ raw_aux = sample.get("aux_losses", None)
369
+ if isinstance(raw_aux, dict):
370
+ for k, v in raw_aux.items():
371
+ if torch.is_tensor(v):
372
+ vv = v.to(device=self.device)
373
+ else:
374
+ vv = torch.tensor(float(v), device=self.device, dtype=torch.float32)
375
+ aux_total = aux_total + vv
376
+ aux_log_sum[str(k)] = aux_log_sum.get(str(k), torch.zeros((), device=self.device)) + vv.detach()
377
+
378
+ src_sum = src_sum + src_losses["total"]
379
+ tgt_sum = tgt_sum + tgt_losses["total"]
380
+ total_loss = total_loss + src_losses["total"] + tgt_losses["total"] + aux_total
381
+ for k, v in src_losses.items():
382
+ src_log_sum[k] = src_log_sum.get(k, torch.zeros((), device=self.device)) + v.detach()
383
+ for k, v in tgt_losses.items():
384
+ tgt_log_sum[k] = tgt_log_sum.get(k, torch.zeros((), device=self.device)) + v.detach()
385
+
386
+ if need_vis and isinstance(sample.get("vis_payload", None), dict):
387
+ vis_payloads.append(sample["vis_payload"])
388
+ if b == 0:
389
+ vis_payload = sample["vis_payload"]
390
+
391
+ bs = float(strategy.batch_size)
392
+ total_loss = total_loss / bs
393
+ src_sum = src_sum / bs
394
+ tgt_sum = tgt_sum / bs
395
+ loss_breakdown: dict[str, torch.Tensor] = {}
396
+ for k, v in src_log_sum.items():
397
+ loss_breakdown[f"src_{k}"] = v / bs
398
+ for k, v in tgt_log_sum.items():
399
+ loss_breakdown[f"tgt_{k}"] = v / bs
400
+ for k, v in aux_log_sum.items():
401
+ loss_breakdown[f"aux_{k}"] = v / bs
402
+ batch_stats = {
403
+ "batch_size": int(strategy.batch_size),
404
+ "gaussian_count": int(strategy.gaussians.mean_vectors.shape[1]),
405
+ }
406
+
407
+ return {
408
+ "total": total_loss,
409
+ "src": src_sum,
410
+ "tgt": tgt_sum,
411
+ "loss_breakdown": loss_breakdown,
412
+ "batch_stats": batch_stats,
413
+ "vis_payload": vis_payload,
414
+ "vis_payloads": vis_payloads,
415
+ }
416
+
417
+ @staticmethod
418
+ def _first_item(x: Any, default: Any = None) -> Any:
419
+ if x is None:
420
+ return default
421
+ if isinstance(x, (list, tuple)):
422
+ return x[0] if len(x) > 0 else default
423
+ if torch.is_tensor(x):
424
+ if x.numel() == 0:
425
+ return default
426
+ return x.flatten()[0].item()
427
+ return x
428
+
429
+ @staticmethod
430
+ def _item_at(x: Any, index: int, default: Any = None) -> Any:
431
+ if x is None:
432
+ return default
433
+ if isinstance(x, (list, tuple)):
434
+ return x[index] if 0 <= int(index) < len(x) else default
435
+ if torch.is_tensor(x):
436
+ if x.numel() == 0:
437
+ return default
438
+ if x.ndim == 0:
439
+ return x.item()
440
+ if 0 <= int(index) < int(x.shape[0]):
441
+ item = x[int(index)]
442
+ return item.item() if item.numel() == 1 else item
443
+ return default
444
+ return x
445
+
446
+ @staticmethod
447
+ def _finite_quantile(x: torch.Tensor, q: float, default: float = float("nan")) -> torch.Tensor:
448
+ vals = x[torch.isfinite(x)]
449
+ if int(vals.numel()) <= 0:
450
+ return torch.tensor(float(default), device=x.device, dtype=torch.float32)
451
+ vals = vals.to(torch.float32).flatten()
452
+ if int(vals.numel()) > 262144:
453
+ step = max(1, int(vals.numel()) // 262144)
454
+ vals = vals[::step]
455
+ return torch.quantile(vals, float(q))
456
+
457
+ def _clamp_distance_for_supervision(
458
+ self,
459
+ depth_m: torch.Tensor | None,
460
+ *,
461
+ max_depth_m: float | None = None,
462
+ clamp_max: bool = True,
463
+ ) -> torch.Tensor | None:
464
+ if not torch.is_tensor(depth_m):
465
+ return None
466
+ cap = float(self.max_depth_m if max_depth_m is None else max_depth_m)
467
+ out = depth_m.to(dtype=torch.float32)
468
+ valid = torch.isfinite(out) & (out > 0.0)
469
+ if bool(clamp_max):
470
+ sanitized = out.clamp(min=1e-4, max=cap)
471
+ else:
472
+ sanitized = out.clamp(min=1e-4)
473
+ return torch.where(valid, sanitized, torch.zeros_like(out))
474
+
475
+ @staticmethod
476
+ def _rendered_depth_valid_for_inv_loss(
477
+ depth_m: torch.Tensor,
478
+ alpha: torch.Tensor,
479
+ *,
480
+ alpha_min: float | None = None,
481
+ depth_min_m: float = 1e-3,
482
+ ) -> torch.Tensor:
483
+ depth = depth_m.detach()
484
+ valid = torch.isfinite(depth) & (depth > float(depth_min_m))
485
+ if alpha_min is not None:
486
+ a = alpha.detach().to(device=depth.device)
487
+ valid = valid & (a[:, :1] > float(alpha_min))
488
+ return valid.to(dtype=depth.dtype)
489
+
490
+ def _pinhole_z_to_supervision_distance(
491
+ self,
492
+ z_depth_b1hw: torch.Tensor | None,
493
+ k3_b33: torch.Tensor | None,
494
+ *,
495
+ clamp_max: bool = True,
496
+ ) -> torch.Tensor | None:
497
+ if not torch.is_tensor(z_depth_b1hw) or not torch.is_tensor(k3_b33):
498
+ return None
499
+ dist = self._z_depth_to_distance_pinhole(z_depth_b1hw, k3_b33)
500
+ return self._clamp_distance_for_supervision(dist, clamp_max=bool(clamp_max))
501
+
502
+ @staticmethod
503
+ def _sanitize_positive_depth(depth_m: torch.Tensor | None) -> torch.Tensor | None:
504
+ if not torch.is_tensor(depth_m):
505
+ return None
506
+ out = depth_m.to(dtype=torch.float32)
507
+ valid = torch.isfinite(out) & (out > 0.0)
508
+ return torch.where(valid, out, torch.zeros_like(out))
509
+
510
+ @staticmethod
511
+ def _as_b1hw_depth(depth: torch.Tensor) -> torch.Tensor:
512
+ if depth.ndim == 3:
513
+ return depth.unsqueeze(1)
514
+ if depth.ndim == 4 and depth.shape[1] == 1:
515
+ return depth
516
+ raise ValueError(f"Expected depth shape (B,H,W) or (B,1,H,W), got {tuple(depth.shape)}")
517
+
518
+ @staticmethod
519
+ def _as_bchw_rgb_u8(image: torch.Tensor) -> torch.Tensor:
520
+ if image.ndim == 3 and image.shape[0] == 3:
521
+ return image.unsqueeze(0)
522
+ if image.ndim == 4 and image.shape[1] == 3:
523
+ return image
524
+ raise ValueError(f"Expected image shape (3,H,W) or (B,3,H,W), got {tuple(image.shape)}")
525
+
526
+ @staticmethod
527
+ def _as_b33_intrinsics(intrinsics: torch.Tensor) -> torch.Tensor:
528
+ if intrinsics.ndim == 2 and tuple(intrinsics.shape) == (3, 3):
529
+ return intrinsics.unsqueeze(0)
530
+ if intrinsics.ndim == 3 and tuple(intrinsics.shape[1:]) == (3, 3):
531
+ return intrinsics
532
+ raise ValueError(
533
+ f"Expected intrinsics shape (3,3) or (B,3,3), got {tuple(intrinsics.shape)}"
534
+ )
535
+
536
+ @staticmethod
537
+ def _as_b9_camera_params(camera_params: torch.Tensor) -> torch.Tensor:
538
+ if camera_params.ndim == 1 and int(camera_params.shape[0]) == 9:
539
+ return camera_params.unsqueeze(0)
540
+ if camera_params.ndim == 2 and int(camera_params.shape[1]) == 9:
541
+ return camera_params
542
+ raise ValueError(f"Expected camera_params shape (9,) or (B,9), got {tuple(camera_params.shape)}")
543
+
544
+ @staticmethod
545
+ def _as_b16_camera_params(camera_params: torch.Tensor) -> torch.Tensor:
546
+ if camera_params.ndim == 1 and int(camera_params.shape[0]) == 16:
547
+ return camera_params.unsqueeze(0)
548
+ if camera_params.ndim == 2 and int(camera_params.shape[1]) == 16:
549
+ return camera_params
550
+ raise ValueError(f"Expected camera_params shape (16,) or (B,16), got {tuple(camera_params.shape)}")
551
+
552
+ @staticmethod
553
+ def _as_b44_pose(extrinsics: torch.Tensor) -> torch.Tensor:
554
+ if extrinsics.ndim == 2 and tuple(extrinsics.shape) == (4, 4):
555
+ return extrinsics.unsqueeze(0)
556
+ if extrinsics.ndim == 3 and tuple(extrinsics.shape[1:]) == (4, 4):
557
+ return extrinsics
558
+ raise ValueError(
559
+ f"Expected extrinsics shape (4,4) or (B,4,4), got {tuple(extrinsics.shape)}"
560
+ )
561
+
562
+ @staticmethod
563
+ def _pick_depth_for_pinhole_frustum_mask(
564
+ gt_depth: torch.Tensor | None,
565
+ pred_depth: torch.Tensor,
566
+ min_valid_px: int = 8,
567
+ ) -> torch.Tensor:
568
+ if torch.is_tensor(gt_depth):
569
+ gt_depth = UnifiedTrainer._as_b1hw_depth(gt_depth)
570
+ valid = torch.isfinite(gt_depth) & (gt_depth > 0.0)
571
+ if int(valid.sum().item()) >= int(min_valid_px):
572
+ return gt_depth
573
+ return pred_depth
574
+
575
+ @staticmethod
576
+ def _pick_depth_for_fisheye_frustum_mask(
577
+ gt_depth: torch.Tensor | None,
578
+ pred_depth: torch.Tensor,
579
+ gt_valid_mask: torch.Tensor | None = None,
580
+ min_valid_px: int = 8,
581
+ ) -> torch.Tensor:
582
+ if torch.is_tensor(gt_depth):
583
+ gt_depth = UnifiedTrainer._as_b1hw_depth(gt_depth)
584
+ if torch.is_tensor(gt_valid_mask):
585
+ gt_valid = gt_depth > 0.0
586
+ gt_valid = gt_valid & (gt_valid_mask > 0.5)
587
+ else:
588
+ gt_valid = torch.isfinite(gt_depth) & (gt_depth > 0.0)
589
+ if int(gt_valid.sum().item()) >= int(min_valid_px):
590
+ return gt_depth
591
+ return pred_depth
592
+
593
+ @staticmethod
594
+ def _as_cubemap_depth_hw1(depth: torch.Tensor) -> torch.Tensor:
595
+ if depth.ndim != 4:
596
+ raise ValueError(f"Expected 4D cubemap depth, got shape={tuple(depth.shape)}")
597
+ if depth.shape[-1] == 1:
598
+ return depth
599
+ if depth.shape[1] == 1:
600
+ return depth.permute(0, 2, 3, 1).contiguous()
601
+ raise ValueError(f"Unsupported cubemap depth shape={tuple(depth.shape)}")
602
+
603
+ def _pick_depth_for_cubemap_frustum_mask(
604
+ self,
605
+ gt_depth_cube: torch.Tensor | None,
606
+ pred_depth_cube: torch.Tensor,
607
+ face_w: int,
608
+ min_valid_px: int = 8,
609
+ ) -> torch.Tensor:
610
+ pred_hw1 = self._as_cubemap_depth_hw1(pred_depth_cube)
611
+ if torch.is_tensor(gt_depth_cube):
612
+ gt_hw1 = self._as_cubemap_depth_hw1(gt_depth_cube)
613
+ gt_dist = self._cubemap_z_depth_to_distance(gt_hw1)
614
+ gt_hw1 = self._as_cubemap_depth_hw1(gt_dist)
615
+ if gt_hw1.shape[1] != int(face_w) or gt_hw1.shape[2] != int(face_w):
616
+ gt_hw1 = F.interpolate(
617
+ gt_hw1.permute(0, 3, 1, 2),
618
+ size=(int(face_w), int(face_w)),
619
+ mode="nearest",
620
+ ).permute(0, 2, 3, 1).contiguous()
621
+ valid = torch.isfinite(gt_hw1[..., 0]) & (gt_hw1[..., 0] > 0.0)
622
+ if int(valid.sum().item()) >= int(min_valid_px):
623
+ return gt_hw1
624
+ return pred_hw1
625
+
626
+ @staticmethod
627
+ def _distance_to_z_depth_pinhole(
628
+ distance_b1hw: torch.Tensor,
629
+ intrinsics_b33: torch.Tensor,
630
+ ) -> torch.Tensor:
631
+ distance_b1hw = UnifiedTrainer._as_b1hw_depth(distance_b1hw)
632
+ intrinsics_b33 = UnifiedTrainer._as_b33_intrinsics(intrinsics_b33)
633
+ b, _, h, w = distance_b1hw.shape
634
+ dev = distance_b1hw.device
635
+ dtype = distance_b1hw.dtype
636
+ uu, vv = integer_pixel_center_grid(h, w, device=dev, dtype=dtype)
637
+ uu = uu.unsqueeze(0).expand(b, -1, -1)
638
+ vv = vv.unsqueeze(0).expand(b, -1, -1)
639
+ fx = intrinsics_b33[:, 0, 0].view(b, 1, 1).to(dtype=dtype, device=dev)
640
+ fy = intrinsics_b33[:, 1, 1].view(b, 1, 1).to(dtype=dtype, device=dev)
641
+ cx = intrinsics_b33[:, 0, 2].view(b, 1, 1).to(dtype=dtype, device=dev)
642
+ cy = intrinsics_b33[:, 1, 2].view(b, 1, 1).to(dtype=dtype, device=dev)
643
+ x = (uu - cx) / fx
644
+ y = (vv - cy) / fy
645
+ ray_z = 1.0 / torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8)
646
+ return distance_b1hw * ray_z.unsqueeze(1)
647
+
648
+ @staticmethod
649
+ def _z_depth_to_distance_pinhole(
650
+ z_depth_b1hw: torch.Tensor,
651
+ intrinsics_b33: torch.Tensor,
652
+ ) -> torch.Tensor:
653
+ z_depth_b1hw = UnifiedTrainer._as_b1hw_depth(z_depth_b1hw)
654
+ intrinsics_b33 = UnifiedTrainer._as_b33_intrinsics(intrinsics_b33)
655
+ b, _, h, w = z_depth_b1hw.shape
656
+ dev = z_depth_b1hw.device
657
+ dtype = z_depth_b1hw.dtype
658
+ uu, vv = integer_pixel_center_grid(h, w, device=dev, dtype=dtype)
659
+ uu = uu.unsqueeze(0).expand(b, -1, -1)
660
+ vv = vv.unsqueeze(0).expand(b, -1, -1)
661
+ fx = intrinsics_b33[:, 0, 0].view(b, 1, 1).to(dtype=dtype, device=dev)
662
+ fy = intrinsics_b33[:, 1, 1].view(b, 1, 1).to(dtype=dtype, device=dev)
663
+ cx = intrinsics_b33[:, 0, 2].view(b, 1, 1).to(dtype=dtype, device=dev)
664
+ cy = intrinsics_b33[:, 1, 2].view(b, 1, 1).to(dtype=dtype, device=dev)
665
+ x = (uu - cx) / fx
666
+ y = (vv - cy) / fy
667
+ ray_z = 1.0 / torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8)
668
+ return z_depth_b1hw / ray_z.unsqueeze(1).clamp(min=1e-8)
669
+
670
+ def _cubemap_z_depth_to_distance(
671
+ self,
672
+ depth_cube: torch.Tensor,
673
+ ) -> torch.Tensor:
674
+ if depth_cube.ndim != 4:
675
+ raise ValueError(f"Expected 4D cubemap depth, got {tuple(depth_cube.shape)}")
676
+ if depth_cube.shape[-1] == 1:
677
+ depth_61hw = depth_cube.permute(0, 3, 1, 2).contiguous()
678
+ elif depth_cube.shape[1] == 1:
679
+ depth_61hw = depth_cube
680
+ else:
681
+ raise ValueError(f"Unsupported cubemap depth shape={tuple(depth_cube.shape)}")
682
+
683
+ _, _, h, w = depth_61hw.shape
684
+ intr = get_pinhole_intrinsics_4x4(int(w)).to(
685
+ device=depth_61hw.device,
686
+ dtype=depth_61hw.dtype,
687
+ )
688
+ fx = intr[0, 0]
689
+ fy = intr[1, 1]
690
+ cx = intr[0, 2]
691
+ cy = intr[1, 2]
692
+ uu, vv = integer_pixel_center_grid(h, w, device=depth_61hw.device, dtype=depth_61hw.dtype)
693
+ x = (uu - cx) / fx
694
+ y = (vv - cy) / fy
695
+ ray_z = 1.0 / torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8)
696
+ dist = depth_61hw / ray_z.view(1, 1, h, w).clamp(min=1e-8)
697
+ valid = torch.isfinite(dist) & (depth_61hw > 0.0)
698
+ dist = torch.where(valid, dist.clamp(min=1e-4), torch.zeros_like(dist))
699
+ return dist
700
+
701
+ def _collect_regularization_inputs(
702
+ self,
703
+ out: dict[str, Any],
704
+ gaussians: Any,
705
+ b: int,
706
+ projected_scale_factor: float | None,
707
+ ) -> dict[str, Any]:
708
+ delta_b = out.get("delta", None)
709
+ delta_xy_raw = None
710
+ if torch.is_tensor(delta_b):
711
+ delta_xy_raw = delta_b[b : b + 1, 0:2]
712
+ delta_rho_raw = delta_b[b : b + 1, 2:3]
713
+ delta_grid_raw = delta_b[b : b + 1]
714
+ else:
715
+ delta_rho_raw = None
716
+ delta_grid_raw = None
717
+ delta_rho_applied_all = out.get("delta_rho_applied", None)
718
+ delta_rho_applied = (
719
+ delta_rho_applied_all[b : b + 1]
720
+ if torch.is_tensor(delta_rho_applied_all)
721
+ else None
722
+ )
723
+ scale_factor_applied_all = out.get("scale_factor_applied", None)
724
+ scale_factor_applied = (
725
+ scale_factor_applied_all[b : b + 1]
726
+ if torch.is_tensor(scale_factor_applied_all)
727
+ else None
728
+ )
729
+
730
+ scales_b = gaussians.singular_values[b : b + 1]
731
+ means_b = gaussians.mean_vectors[b : b + 1]
732
+ quats_b = gaussians.quaternions[b : b + 1]
733
+ opac_b = gaussians.opacities[b : b + 1]
734
+
735
+ base_values = out.get("gaussian_base_values", None)
736
+ gauss_grid_shape = None
737
+ base_means_b = None
738
+ base_scales_b = None
739
+ angular_cell_b = None
740
+ if base_values is not None and hasattr(base_values, "rays"):
741
+ _, _, l, hb, wb = base_values.rays.shape
742
+ gauss_grid_shape = (int(l), int(hb), int(wb))
743
+ inv_dist_b = base_values.inv_distance[b : b + 1].clamp(min=1e-6)
744
+ base_rays_b = F.normalize(base_values.rays[b : b + 1], dim=1, eps=1e-6)
745
+ base_means_grid = base_rays_b / inv_dist_b
746
+ base_scales_b = base_values.scales[b : b + 1]
747
+ init_output = out.get("initializer_output", None)
748
+ global_scale = (
749
+ init_output.global_scale[b : b + 1]
750
+ if init_output is not None
751
+ and getattr(init_output, "global_scale", None) is not None
752
+ else None
753
+ )
754
+ if torch.is_tensor(global_scale):
755
+ base_means_grid = base_means_grid * global_scale.view(-1, 1, 1, 1, 1)
756
+ base_scales_b = base_scales_b * global_scale.view(-1, 1, 1, 1, 1)
757
+ base_means_b = base_means_grid.permute(0, 2, 3, 4, 1).flatten(1, 3)
758
+ angular_cell = getattr(base_values, "angular_cell", None)
759
+ angular_cell_b = angular_cell[b : b + 1] if torch.is_tensor(angular_cell) else None
760
+
761
+ return {
762
+ "delta_xy_eff": delta_xy_raw,
763
+ "delta_rho_raw": delta_rho_raw,
764
+ "delta_grid": delta_grid_raw,
765
+ "delta_rho_applied": delta_rho_applied,
766
+ "scale_factor_applied": scale_factor_applied,
767
+ "gaussian_scales": scales_b,
768
+ "gaussian_quaternions": quats_b,
769
+ "gaussian_angular_cell": angular_cell_b,
770
+ "gaussian_mean_vectors": means_b,
771
+ "gaussian_base_mean_vectors": base_means_b,
772
+ "gaussian_base_scales": base_scales_b,
773
+ "gaussian_opacities": opac_b,
774
+ "gauss_grid_shape": gauss_grid_shape,
775
+ "projected_scale_factor": projected_scale_factor,
776
+ }
777
+
778
+ def _compute_view_loss(
779
+ self,
780
+ *,
781
+ pred_rgb_linear: torch.Tensor,
782
+ pred_alpha: torch.Tensor,
783
+ pred_depth_m: torch.Tensor,
784
+ pred_depth2_m: torch.Tensor | None,
785
+ gt_rgb_u8: torch.Tensor,
786
+ gt_depth_m: torch.Tensor,
787
+ mask: torch.Tensor,
788
+ apply_color: bool,
789
+ apply_alpha: bool,
790
+ apply_depth: bool,
791
+ apply_percep: bool,
792
+ apply_tv: bool,
793
+ apply_grad: bool,
794
+ apply_grad_img: bool,
795
+ grad_img_circular_h: bool | None = None,
796
+ gaussian_scales: torch.Tensor | None = None,
797
+ gaussian_quaternions: torch.Tensor | None = None,
798
+ gaussian_angular_cell: torch.Tensor | None = None,
799
+ delta_xy: torch.Tensor | None = None,
800
+ delta_rho: torch.Tensor | None = None,
801
+ delta_grid: torch.Tensor | None = None,
802
+ gaussian_mean_vectors: torch.Tensor | None = None,
803
+ gaussian_base_mean_vectors: torch.Tensor | None = None,
804
+ gaussian_opacities: torch.Tensor | None = None,
805
+ gauss_grid_shape: tuple[int, int, int] | None = None,
806
+ projected_scale_factor: float | torch.Tensor | None = None,
807
+ projection_model: str | None = None,
808
+ projection_intrinsics: torch.Tensor | None = None,
809
+ projection_camera_params: torch.Tensor | None = None,
810
+ loss_scale: float = 1.0,
811
+ apply_splat: bool | None = None,
812
+ depth_mask: torch.Tensor | None = None,
813
+ ) -> dict[str, torch.Tensor]:
814
+ losses = self.loss_fn(
815
+ pred_rgb_linear=pred_rgb_linear,
816
+ pred_alpha=pred_alpha,
817
+ pred_depth_m=pred_depth_m,
818
+ pred_depth2_m=pred_depth2_m,
819
+ gt_rgb_u8=gt_rgb_u8,
820
+ gt_depth_m=gt_depth_m,
821
+ mask=mask,
822
+ depth_mask=depth_mask,
823
+ gaussian_scales=gaussian_scales,
824
+ gaussian_quaternions=gaussian_quaternions,
825
+ gaussian_angular_cell=gaussian_angular_cell,
826
+ delta_xy=delta_xy,
827
+ delta_rho=delta_rho,
828
+ delta_grid=delta_grid,
829
+ apply_color=bool(apply_color),
830
+ apply_alpha=bool(apply_alpha),
831
+ apply_depth=bool(apply_depth),
832
+ apply_percep=bool(apply_percep),
833
+ apply_tv=bool(apply_tv),
834
+ apply_grad=bool(apply_grad),
835
+ apply_grad_img=bool(apply_grad_img),
836
+ grad_img_circular_h=grad_img_circular_h,
837
+ apply_delta=bool(torch.is_tensor(delta_xy) or torch.is_tensor(delta_rho)),
838
+ apply_splat=bool(torch.is_tensor(gaussian_scales)) if apply_splat is None else bool(apply_splat),
839
+ gaussian_mean_vectors=gaussian_mean_vectors,
840
+ gaussian_base_mean_vectors=gaussian_base_mean_vectors,
841
+ gaussian_opacities=gaussian_opacities,
842
+ gauss_grid_shape=gauss_grid_shape,
843
+ projected_scale_factor=projected_scale_factor,
844
+ projection_model=projection_model,
845
+ projection_intrinsics=projection_intrinsics,
846
+ projection_camera_params=projection_camera_params,
847
+ )
848
+ scale = float(loss_scale)
849
+ if abs(scale - 1.0) > 1e-8:
850
+ losses = {k: (v * scale) for k, v in losses.items()}
851
+ return losses
852
+
853
+ def _build_pinhole_strategy(
854
+ self,
855
+ batch: Any,
856
+ step: int,
857
+ need_vis: bool = False,
858
+ dataset_name: str = "re10k",
859
+ ) -> _ModeStrategy:
860
+ src_u8 = self._as_bchw_rgb_u8(batch.src_rgb_u8.to(self.device, non_blocking=True))
861
+ tgt_u8 = self._as_bchw_rgb_u8(batch.tgt_rgb_u8.to(self.device, non_blocking=True))
862
+ src_u8_orig = getattr(batch, "src_rgb_u8_orig", None)
863
+ tgt_u8_orig = getattr(batch, "tgt_rgb_u8_orig", None)
864
+ src_depth_gt = getattr(batch, "src_depth_m", None)
865
+ tgt_depth_gt = getattr(batch, "tgt_depth_m", None)
866
+ src_depth_gt_orig = getattr(batch, "src_depth_m_orig", None)
867
+ tgt_depth_gt_orig = getattr(batch, "tgt_depth_m_orig", None)
868
+ has_depth_gt = torch.is_tensor(src_depth_gt) and torch.is_tensor(tgt_depth_gt)
869
+ if has_depth_gt:
870
+ src_depth_gt = self._as_b1hw_depth(
871
+ src_depth_gt.to(self.device, non_blocking=True).to(torch.float32)
872
+ )
873
+ tgt_depth_gt = self._as_b1hw_depth(
874
+ tgt_depth_gt.to(self.device, non_blocking=True).to(torch.float32)
875
+ )
876
+ has_depth_gt_orig = torch.is_tensor(src_depth_gt_orig) and torch.is_tensor(tgt_depth_gt_orig)
877
+ if has_depth_gt_orig:
878
+ src_depth_gt_orig = self._as_b1hw_depth(
879
+ src_depth_gt_orig.to(self.device, non_blocking=True).to(torch.float32)
880
+ )
881
+ tgt_depth_gt_orig = self._as_b1hw_depth(
882
+ tgt_depth_gt_orig.to(self.device, non_blocking=True).to(torch.float32)
883
+ )
884
+ src_w2c = self._as_b44_pose(batch.src_w2c.to(self.device, non_blocking=True).to(torch.float32))
885
+ tgt_w2c = self._as_b44_pose(batch.tgt_w2c.to(self.device, non_blocking=True).to(torch.float32))
886
+ src_k3 = self._as_b33_intrinsics(batch.src_intrinsics.to(self.device, non_blocking=True).to(torch.float32))
887
+ tgt_k3 = self._as_b33_intrinsics(batch.tgt_intrinsics.to(self.device, non_blocking=True).to(torch.float32))
888
+ src_k3_orig = getattr(batch, "src_intrinsics_orig", None)
889
+ tgt_k3_orig = getattr(batch, "tgt_intrinsics_orig", None)
890
+ has_orig_vis = (
891
+ torch.is_tensor(src_u8_orig)
892
+ and torch.is_tensor(tgt_u8_orig)
893
+ and torch.is_tensor(src_k3_orig)
894
+ and torch.is_tensor(tgt_k3_orig)
895
+ )
896
+ if has_orig_vis:
897
+ src_u8_orig = self._as_bchw_rgb_u8(src_u8_orig.to(self.device, non_blocking=True))
898
+ tgt_u8_orig = self._as_bchw_rgb_u8(tgt_u8_orig.to(self.device, non_blocking=True))
899
+ src_k3_orig = self._as_b33_intrinsics(
900
+ src_k3_orig.to(self.device, non_blocking=True).to(torch.float32)
901
+ )
902
+ tgt_k3_orig = self._as_b33_intrinsics(
903
+ tgt_k3_orig.to(self.device, non_blocking=True).to(torch.float32)
904
+ )
905
+ src_depth_gt_dist = None
906
+ tgt_depth_gt_dist = None
907
+ src_unik3d_gt_dist = None
908
+ if has_depth_gt:
909
+ src_unik3d_gt_dist = self._pinhole_z_to_supervision_distance(src_depth_gt, src_k3)
910
+ src_depth_gt_dist = src_unik3d_gt_dist
911
+ tgt_depth_gt_dist = self._pinhole_z_to_supervision_distance(tgt_depth_gt, tgt_k3)
912
+
913
+ src = src_u8.float().clamp(0, 255) / 255.0
914
+ tgt = tgt_u8.float().clamp(0, 255) / 255.0
915
+ distance_init_cap_m = self._distance_init_cap_for_dataset(dataset_name)
916
+
917
+ share_src_forward = bool(getattr(batch, "share_src_forward", False)) and int(src.shape[0]) > 1
918
+
919
+ def _repeat_first_dim(value: Any, batch_size: int) -> Any:
920
+ if torch.is_tensor(value):
921
+ if value.ndim > 0 and int(value.shape[0]) == 1:
922
+ return value.repeat(batch_size, *([1] * (value.ndim - 1)))
923
+ return value
924
+ if hasattr(value, "_fields"):
925
+ return type(value)(*[_repeat_first_dim(getattr(value, field), batch_size) for field in value._fields])
926
+ return value
927
+
928
+ if share_src_forward:
929
+ out_single = self.model(
930
+ image=src[0:1],
931
+ image_u8=src_u8[0:1],
932
+ camera_intrinsics=src_k3[0:1],
933
+ camera_model="pinhole",
934
+ depth_gt=(src_depth_gt_dist[0:1] if torch.is_tensor(src_depth_gt_dist) else None),
935
+ distance_init_cap_m=distance_init_cap_m,
936
+ return_aux=True,
937
+ )
938
+ out = {k: _repeat_first_dim(v, int(src.shape[0])) for k, v in out_single.items()}
939
+ else:
940
+ out = self.model(
941
+ image=src,
942
+ image_u8=src_u8,
943
+ camera_intrinsics=src_k3,
944
+ camera_model="pinhole",
945
+ depth_gt=src_depth_gt_dist,
946
+ distance_init_cap_m=distance_init_cap_m,
947
+ return_aux=True,
948
+ )
949
+ gaussians = out["gaussians"]
950
+ src_render_k3 = src_k3
951
+ tgt_render_k3 = tgt_k3
952
+ src_depth_gt_z_render = src_depth_gt if has_depth_gt else None
953
+ tgt_depth_gt_z_render = tgt_depth_gt if has_depth_gt else None
954
+ src_depth_gt_render_valid = (torch.isfinite(src_depth_gt) & (src_depth_gt > 0.0)) if has_depth_gt else None
955
+ tgt_depth_gt_render_valid = (torch.isfinite(tgt_depth_gt) & (tgt_depth_gt > 0.0)) if has_depth_gt else None
956
+ aux_ray_target_all = out.get("unik3d_gt_rays", None)
957
+ def make_world_gaussians(b: int, g_b: Any) -> Any:
958
+ return g_b
959
+
960
+ def make_sample(b: int, g_world: Any, enable_vis: bool) -> dict[str, Any]:
961
+ src_h = int(src_u8.shape[-2])
962
+ src_w = int(src_u8.shape[-1])
963
+ tgt_h = int(tgt_u8.shape[-2])
964
+ tgt_w = int(tgt_u8.shape[-1])
965
+ ident = torch.eye(4, dtype=src_w2c.dtype, device=self.device).unsqueeze(0)
966
+ rel_tgt_w2c = tgt_w2c[b : b + 1] @ torch.linalg.inv(src_w2c[b : b + 1])
967
+ src_k_render_b = src_render_k3[b : b + 1]
968
+ tgt_k_render_b = tgt_render_k3[b : b + 1]
969
+ src_out = self.renderer(
970
+ g_world,
971
+ extrinsics=ident,
972
+ intrinsics=to_k4(src_k_render_b),
973
+ image_width=src_w,
974
+ image_height=src_h,
975
+ )
976
+ tgt_out = self.renderer(
977
+ g_world,
978
+ extrinsics=rel_tgt_w2c,
979
+ intrinsics=to_k4(tgt_k_render_b),
980
+ image_width=tgt_w,
981
+ image_height=tgt_h,
982
+ )
983
+
984
+ zeros_src_depth = torch.zeros((1, 1, src_h, src_w), dtype=torch.float32, device=self.device)
985
+ zeros_tgt_depth = torch.zeros((1, 1, tgt_h, tgt_w), dtype=torch.float32, device=self.device)
986
+ ones_mask = torch.ones_like(zeros_src_depth)
987
+ fx_b = float(src_k_render_b[0, 0, 0].item())
988
+ fy_b = float(src_k_render_b[0, 1, 1].item())
989
+ proj_scale_pinhole = 0.5 * (fx_b + fy_b)
990
+ reg_inputs = self._collect_regularization_inputs(
991
+ out=out,
992
+ gaussians=gaussians,
993
+ b=b,
994
+ projected_scale_factor=proj_scale_pinhole,
995
+ )
996
+
997
+ src_depth_for_visibility = None
998
+ tgt_gt_depth_for_mask = (
999
+ tgt_depth_gt_z_render[b : b + 1]
1000
+ if has_depth_gt and torch.is_tensor(tgt_depth_gt_z_render)
1001
+ else None
1002
+ )
1003
+ if has_depth_gt:
1004
+ src_depth_for_visibility = (
1005
+ src_depth_gt_z_render[b : b + 1]
1006
+ if torch.is_tensor(src_depth_gt_z_render)
1007
+ else src_depth_gt[b : b + 1]
1008
+ )
1009
+
1010
+ tgt_depth_for_mask = self._pick_depth_for_pinhole_frustum_mask(
1011
+ gt_depth=tgt_gt_depth_for_mask,
1012
+ pred_depth=tgt_out.depth,
1013
+ )
1014
+ tgt_frustum_mask = compute_frustum_mask(
1015
+ depth=tgt_depth_for_mask,
1016
+ tgt_w2c=tgt_w2c[b : b + 1],
1017
+ src_w2c=src_w2c[b : b + 1],
1018
+ src_k3=src_k_render_b,
1019
+ tgt_k3=tgt_k_render_b,
1020
+ img_h=tgt_h,
1021
+ img_w=tgt_w,
1022
+ source_img_h=src_h,
1023
+ source_img_w=src_w,
1024
+ source_depth=src_depth_for_visibility,
1025
+ )
1026
+ tgt_frustum_mask_raw = tgt_frustum_mask
1027
+ tgt_frustum_mask = self._erode_supervision_mask(
1028
+ tgt_frustum_mask,
1029
+ self.target_mask_erode_px,
1030
+ circular_h=False,
1031
+ )
1032
+ src_depth_pred = self._clamp_distance_for_supervision(
1033
+ out["distance_layers"][b : b + 1, 0:1],
1034
+ clamp_max=False,
1035
+ )
1036
+ src_depth2_pred = (
1037
+ self._clamp_distance_for_supervision(out["distance_layers"][b : b + 1, 1:2], clamp_max=False)
1038
+ if out["distance_layers"] is not None and out["distance_layers"].shape[1] > 1
1039
+ else None
1040
+ )
1041
+ src_depth2_gt_for_aux = (
1042
+ src_unik3d_gt_dist[b : b + 1]
1043
+ if torch.is_tensor(src_unik3d_gt_dist)
1044
+ else None
1045
+ )
1046
+ src_depth2_mask_for_aux = src_depth_gt[b : b + 1] > 0.0 if has_depth_gt else None
1047
+ tgt_depth_pred = self._pinhole_z_to_supervision_distance(
1048
+ tgt_out.depth,
1049
+ tgt_k_render_b,
1050
+ clamp_max=False,
1051
+ )
1052
+ tgt_depth_loss_mask = self._rendered_depth_valid_for_inv_loss(tgt_depth_pred, tgt_out.alpha)
1053
+ if torch.is_tensor(tgt_depth_gt_render_valid):
1054
+ tgt_depth_loss_mask = tgt_depth_loss_mask * tgt_depth_gt_render_valid[b : b + 1].to(
1055
+ device=tgt_depth_loss_mask.device,
1056
+ dtype=tgt_depth_loss_mask.dtype,
1057
+ )
1058
+ tgt_extra_loss_terms: list[dict[str, Any]] = []
1059
+ vis_payload = None
1060
+ if enable_vis:
1061
+ vis_src_u8 = src_u8[b : b + 1]
1062
+ vis_tgt_u8 = tgt_u8[b : b + 1]
1063
+ vis_src_depth_gt = (src_depth_gt[b : b + 1] if has_depth_gt else None)
1064
+ vis_tgt_depth_gt = (tgt_depth_gt[b : b + 1] if has_depth_gt else None)
1065
+ vis_src_out = src_out
1066
+ vis_tgt_out = tgt_out
1067
+ if has_orig_vis:
1068
+ vis_src_u8 = src_u8_orig[b : b + 1]
1069
+ vis_tgt_u8 = tgt_u8_orig[b : b + 1]
1070
+ vis_src_depth_gt = (src_depth_gt_orig[b : b + 1] if has_depth_gt_orig else None)
1071
+ vis_tgt_depth_gt = (tgt_depth_gt_orig[b : b + 1] if has_depth_gt_orig else None)
1072
+ vis_src_render_k3 = src_k3_orig[b : b + 1]
1073
+ vis_tgt_render_k3 = tgt_k3_orig[b : b + 1]
1074
+ vis_src_out = self.renderer(
1075
+ g_world,
1076
+ extrinsics=ident,
1077
+ intrinsics=to_k4(vis_src_render_k3),
1078
+ image_width=int(vis_src_u8.shape[-1]),
1079
+ image_height=int(vis_src_u8.shape[-2]),
1080
+ )
1081
+ vis_tgt_out = self.renderer(
1082
+ g_world,
1083
+ extrinsics=rel_tgt_w2c,
1084
+ intrinsics=to_k4(vis_tgt_render_k3),
1085
+ image_width=int(vis_tgt_u8.shape[-1]),
1086
+ image_height=int(vis_tgt_u8.shape[-2]),
1087
+ )
1088
+ src_unik3d_depth = None
1089
+ tgt_unik3d_depth = None
1090
+ raw_dist = out.get("unik3d_distance", None)
1091
+ if torch.is_tensor(raw_dist):
1092
+ try:
1093
+ conditioning_rays = out.get("unik3d_ray_conditioning_rays", None)
1094
+ if not torch.is_tensor(conditioning_rays):
1095
+ conditioning_rays = out.get("unik3d_rays", None)
1096
+ ray_z = (
1097
+ conditioning_rays[b : b + 1, 2:3].detach()
1098
+ if torch.is_tensor(conditioning_rays)
1099
+ else None
1100
+ )
1101
+ if torch.is_tensor(ray_z):
1102
+ if tuple(ray_z.shape[-2:]) != tuple(raw_dist.shape[-2:]):
1103
+ ray_z = F.interpolate(ray_z, size=raw_dist.shape[-2:], mode="bilinear", align_corners=False)
1104
+ src_unik3d_depth = raw_dist[b : b + 1, 0:1].detach() * ray_z
1105
+ else:
1106
+ src_unik3d_depth = self._distance_to_z_depth_pinhole(
1107
+ raw_dist[b : b + 1, 0:1].detach(),
1108
+ src_k_render_b,
1109
+ )
1110
+ except Exception:
1111
+ src_unik3d_depth = raw_dist[b : b + 1, 0:1].detach()
1112
+ if self.enable_tgt_unik3d_vis:
1113
+ try:
1114
+ with torch.no_grad():
1115
+ from unisharp.utils.unik3d_adapter import forward_unik3d_pinhole
1116
+
1117
+ unik_tgt = forward_unik3d_pinhole(
1118
+ self._base_model().feature_extractor.unik3d,
1119
+ rgb_u8=tgt_u8[b : b + 1],
1120
+ intrinsics=tgt_k3[b : b + 1],
1121
+ normalize=True,
1122
+ )
1123
+ dist_tgt = unik_tgt.get("distance", None) if isinstance(unik_tgt, dict) else None
1124
+ if torch.is_tensor(dist_tgt):
1125
+ try:
1126
+ tgt_unik3d_depth = self._distance_to_z_depth_pinhole(
1127
+ dist_tgt[:, 0:1].detach(),
1128
+ tgt_k_render_b,
1129
+ )
1130
+ except Exception:
1131
+ tgt_unik3d_depth = dist_tgt[:, 0:1].detach()
1132
+ except Exception:
1133
+ tgt_unik3d_depth = None
1134
+
1135
+ vis_payload = {
1136
+ "src_gt": (vis_src_u8.float() / 255.0).detach(),
1137
+ "src_pred": vis_src_out.color.clamp(0, 1).detach(),
1138
+ "src_alpha": vis_src_out.alpha.detach(),
1139
+ "src_gt_depth": (vis_src_depth_gt.detach() if torch.is_tensor(vis_src_depth_gt) else None),
1140
+ "src_pred_depth": vis_src_out.depth.detach(),
1141
+ "src_unik3d_depth": src_unik3d_depth,
1142
+ "tgt_gt": (vis_tgt_u8.float() / 255.0).detach(),
1143
+ "tgt_pred": vis_tgt_out.color.clamp(0, 1).detach(),
1144
+ "tgt_alpha": vis_tgt_out.alpha.detach(),
1145
+ "tgt_gt_depth": (vis_tgt_depth_gt.detach() if torch.is_tensor(vis_tgt_depth_gt) else None),
1146
+ "tgt_pred_depth": vis_tgt_out.depth.detach(),
1147
+ "tgt_unik3d_depth": tgt_unik3d_depth,
1148
+ "dataset_name": str(dataset_name),
1149
+ "scene": str(self._item_at(getattr(batch, "scene", None), b, "unknown")),
1150
+ "src_idx": int(self._item_at(getattr(batch, "src_idx", None), b, -1)),
1151
+ "tgt_idx": int(self._item_at(getattr(batch, "tgt_idx", None), b, -1)),
1152
+ "src_pose_w2c": src_w2c[b : b + 1].detach(),
1153
+ "tgt_pose_w2c": tgt_w2c[b : b + 1].detach(),
1154
+ "tgt_metric_mask_raw": tgt_frustum_mask_raw.detach(),
1155
+ "tgt_metric_mask": tgt_frustum_mask.detach(),
1156
+ }
1157
+
1158
+ return {
1159
+ "src_pred_rgb_linear": src_out.color,
1160
+ "src_pred_alpha": src_out.alpha,
1161
+ "src_pred_depth_m": src_depth_pred,
1162
+ "src_pred_depth2_m": src_depth2_pred,
1163
+ "src_gt_rgb_u8": src_u8[b : b + 1],
1164
+ "src_gt_depth_m": (src_depth_gt_dist[b : b + 1] if has_depth_gt and src_depth_gt_dist is not None else zeros_src_depth),
1165
+ "src_mask": ones_mask,
1166
+ "src_apply_depth": False,
1167
+ "src_apply_grad": bool(has_depth_gt),
1168
+ "src_apply_grad_img": bool(has_depth_gt),
1169
+ "src_grad_img_circular_h": False,
1170
+ "tgt_pred_rgb_linear": tgt_out.color,
1171
+ "tgt_pred_alpha": tgt_out.alpha,
1172
+ "tgt_pred_depth_m": tgt_depth_pred,
1173
+ "tgt_gt_rgb_u8": tgt_u8[b : b + 1],
1174
+ "tgt_gt_depth_m": (tgt_depth_gt_dist[b : b + 1] if has_depth_gt and tgt_depth_gt_dist is not None else zeros_tgt_depth),
1175
+ "tgt_mask": tgt_frustum_mask,
1176
+ "tgt_depth_mask": tgt_depth_loss_mask,
1177
+ "tgt_apply_depth": bool(has_depth_gt),
1178
+ "tgt_apply_grad_img": bool(has_depth_gt),
1179
+ "tgt_grad_img_circular_h": False,
1180
+ "tgt_apply_percep": bool(float(self.loss_fn.w.lambda_percep) > 0.0),
1181
+ "tgt_extra_loss_terms": tgt_extra_loss_terms,
1182
+ "aux_losses": self._aux_ray_losses(
1183
+ pred_rays=(
1184
+ out.get("unik3d_rays", None)[b : b + 1]
1185
+ if torch.is_tensor(out.get("unik3d_rays", None))
1186
+ else None
1187
+ ),
1188
+ gt_rays=(
1189
+ aux_ray_target_all[b : b + 1]
1190
+ if torch.is_tensor(aux_ray_target_all)
1191
+ else None
1192
+ ),
1193
+ mask=ones_mask,
1194
+ pred_distance=(
1195
+ out["unik3d_distance"][b : b + 1, 0:1]
1196
+ if torch.is_tensor(out.get("unik3d_distance", None))
1197
+ else None
1198
+ ),
1199
+ pred_distance2=src_depth2_pred,
1200
+ gt_distance=(
1201
+ src_unik3d_gt_dist[b : b + 1]
1202
+ if torch.is_tensor(src_unik3d_gt_dist)
1203
+ else None
1204
+ ),
1205
+ gt_distance2=src_depth2_gt_for_aux,
1206
+ depth_mask=(src_depth_gt[b : b + 1] > 0.0 if has_depth_gt else None),
1207
+ depth_mask2=src_depth2_mask_for_aux,
1208
+ ),
1209
+ "gaussian_scales": reg_inputs["gaussian_scales"],
1210
+ "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
1211
+ "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
1212
+ "delta_xy": reg_inputs["delta_xy_eff"],
1213
+ "delta_rho": reg_inputs["delta_rho_raw"],
1214
+ "delta_grid": reg_inputs["delta_grid"],
1215
+ "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
1216
+ "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
1217
+ "gaussian_opacities": reg_inputs["gaussian_opacities"],
1218
+ "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
1219
+ "projected_scale_factor": reg_inputs["projected_scale_factor"],
1220
+ "projection_model": "pinhole",
1221
+ "projection_intrinsics": src_k_render_b,
1222
+ "vis_payload": vis_payload,
1223
+ }
1224
+
1225
+ return _ModeStrategy(
1226
+ batch_size=int(src.shape[0]),
1227
+ gaussians=gaussians,
1228
+ make_world_gaussians=make_world_gaussians,
1229
+ make_sample=make_sample,
1230
+ collect_all_vis=bool(getattr(batch, "collect_all_vis", False)),
1231
+ )
1232
+
1233
+ def _build_fisheye624_strategy(
1234
+ self,
1235
+ batch: Any,
1236
+ step: int,
1237
+ need_vis: bool = False,
1238
+ dataset_name: str = "scannetpp_fisheye",
1239
+ ) -> _ModeStrategy:
1240
+ del step
1241
+ src_u8 = self._as_bchw_rgb_u8(batch.src_rgb_u8.to(self.device, non_blocking=True))
1242
+ tgt_u8 = self._as_bchw_rgb_u8(batch.tgt_rgb_u8.to(self.device, non_blocking=True))
1243
+ src_depth_gt = self._clamp_distance_for_supervision(
1244
+ self._as_b1hw_depth(batch.src_depth_m.to(self.device, non_blocking=True).to(torch.float32))
1245
+ )
1246
+ tgt_depth_gt = self._clamp_distance_for_supervision(
1247
+ self._as_b1hw_depth(batch.tgt_depth_m.to(self.device, non_blocking=True).to(torch.float32))
1248
+ )
1249
+ src_valid_mask = self._as_b1hw_depth(batch.src_valid_mask.to(self.device, non_blocking=True).to(torch.float32))
1250
+ tgt_valid_mask = self._as_b1hw_depth(batch.tgt_valid_mask.to(self.device, non_blocking=True).to(torch.float32))
1251
+ src_w2c = self._as_b44_pose(batch.src_w2c.to(self.device, non_blocking=True).to(torch.float32))
1252
+ tgt_w2c = self._as_b44_pose(batch.tgt_w2c.to(self.device, non_blocking=True).to(torch.float32))
1253
+ src_cam_params = self._as_b16_camera_params(
1254
+ batch.src_camera_params.to(self.device, non_blocking=True).to(torch.float32)
1255
+ )
1256
+ tgt_cam_params = self._as_b16_camera_params(
1257
+ batch.tgt_camera_params.to(self.device, non_blocking=True).to(torch.float32)
1258
+ )
1259
+ distance_init_cap_m = self._distance_init_cap_for_dataset(dataset_name)
1260
+
1261
+ out = self.model(
1262
+ image=src_u8.float().clamp(0, 255) / 255.0,
1263
+ image_u8=src_u8,
1264
+ camera_intrinsics=None,
1265
+ camera_params=src_cam_params,
1266
+ camera_model="fisheye624",
1267
+ depth_gt=src_depth_gt,
1268
+ distance_init_cap_m=distance_init_cap_m,
1269
+ validity_mask=src_valid_mask,
1270
+ return_aux=True,
1271
+ )
1272
+
1273
+ gaussians = out["gaussians"]
1274
+ src_render_cam_params = src_cam_params
1275
+ tgt_render_cam_params = tgt_cam_params
1276
+ src_render_valid_mask = src_valid_mask
1277
+ tgt_render_valid_mask = tgt_valid_mask
1278
+ aux_ray_target_all = out.get("unik3d_gt_rays", None)
1279
+
1280
+ def make_world_gaussians(b: int, g_b: Any) -> Any:
1281
+ return transform_gaussians_to_world(g_b, src_w2c[b])
1282
+
1283
+ def make_sample(b: int, g_world: Any, enable_vis: bool) -> dict[str, Any]:
1284
+ src_h = int(src_u8.shape[-2])
1285
+ src_w = int(src_u8.shape[-1])
1286
+ tgt_h = int(tgt_u8.shape[-2])
1287
+ tgt_w = int(tgt_u8.shape[-1])
1288
+ src_render = render_gaussians_fisheye624(
1289
+ g_world,
1290
+ extrinsics_w2c=src_w2c[b : b + 1],
1291
+ camera_params=src_render_cam_params[b : b + 1],
1292
+ image_h=src_h,
1293
+ image_w=src_w,
1294
+ valid_mask=src_render_valid_mask[b : b + 1],
1295
+ )
1296
+ tgt_render = render_gaussians_fisheye624(
1297
+ g_world,
1298
+ extrinsics_w2c=tgt_w2c[b : b + 1],
1299
+ camera_params=tgt_render_cam_params[b : b + 1],
1300
+ image_h=tgt_h,
1301
+ image_w=tgt_w,
1302
+ valid_mask=tgt_render_valid_mask[b : b + 1],
1303
+ )
1304
+ reg_inputs = self._collect_regularization_inputs(
1305
+ out=out,
1306
+ gaussians=gaussians,
1307
+ b=b,
1308
+ projected_scale_factor=None,
1309
+ )
1310
+ tgt_depth_for_mask = self._pick_depth_for_fisheye_frustum_mask(
1311
+ gt_depth=tgt_depth_gt[b : b + 1],
1312
+ pred_depth=tgt_render["depth_distance"],
1313
+ gt_valid_mask=tgt_valid_mask[b : b + 1],
1314
+ )
1315
+ tgt_frustum_mask = compute_fisheye624_frustum_mask(
1316
+ depth_distance_m=tgt_depth_for_mask,
1317
+ tgt_w2c=tgt_w2c[b : b + 1],
1318
+ src_w2c=src_w2c[b : b + 1],
1319
+ tgt_camera_params=tgt_render_cam_params[b : b + 1],
1320
+ src_camera_params=src_render_cam_params[b : b + 1],
1321
+ src_valid_mask=src_render_valid_mask[b : b + 1] * src_render["valid_mask"],
1322
+ source_depth_distance_m=src_depth_gt[b : b + 1],
1323
+ )
1324
+ src_mask = src_render_valid_mask[b : b + 1] * src_render["valid_mask"]
1325
+ src_depth_mask = src_mask
1326
+ tgt_mask = tgt_render_valid_mask[b : b + 1] * tgt_render["valid_mask"] * tgt_frustum_mask
1327
+ tgt_mask_raw = tgt_mask
1328
+ tgt_mask = self._erode_supervision_mask(
1329
+ tgt_mask,
1330
+ self.target_mask_erode_px,
1331
+ circular_h=False,
1332
+ )
1333
+ src_depth_pred = self._clamp_distance_for_supervision(
1334
+ out["distance_layers"][b : b + 1, 0:1],
1335
+ clamp_max=False,
1336
+ )
1337
+ src_depth2_pred = (
1338
+ self._clamp_distance_for_supervision(out["distance_layers"][b : b + 1, 1:2], clamp_max=False)
1339
+ if out["distance_layers"] is not None and out["distance_layers"].shape[1] > 1
1340
+ else None
1341
+ )
1342
+ tgt_depth_pred = self._clamp_distance_for_supervision(tgt_render["depth_distance"], clamp_max=False)
1343
+ tgt_depth_loss_mask = self._rendered_depth_valid_for_inv_loss(tgt_depth_pred, tgt_render["alpha"])
1344
+
1345
+ src_loss_terms = [
1346
+ {
1347
+ "pred_rgb_linear": src_render["color"],
1348
+ "pred_alpha": src_render["alpha"],
1349
+ "pred_depth_m": src_render["depth_distance"],
1350
+ "pred_depth2_m": None,
1351
+ "gt_rgb_u8": src_u8[b : b + 1],
1352
+ "gt_depth_m": src_depth_gt[b : b + 1],
1353
+ "mask": src_mask,
1354
+ "apply_color": True,
1355
+ "apply_alpha": True,
1356
+ "apply_depth": False,
1357
+ "apply_percep": False,
1358
+ "apply_tv": False,
1359
+ "apply_grad": False,
1360
+ "apply_grad_img": False,
1361
+ "grad_img_circular_h": False,
1362
+ "gaussian_scales": None,
1363
+ "gaussian_quaternions": None,
1364
+ "gaussian_angular_cell": None,
1365
+ "delta_xy": None,
1366
+ "gaussian_mean_vectors": None,
1367
+ "gaussian_opacities": None,
1368
+ "gauss_grid_shape": None,
1369
+ "projected_scale_factor": None,
1370
+ "apply_splat": False,
1371
+ "loss_scale": 1.0,
1372
+ }
1373
+ ]
1374
+ src_extra_loss_terms = [
1375
+ {
1376
+ "pred_rgb_linear": torch.zeros((1, 3, src_h, src_w), dtype=torch.float32, device=self.device),
1377
+ "pred_alpha": torch.zeros((1, 1, src_h, src_w), dtype=torch.float32, device=self.device),
1378
+ "pred_depth_m": src_depth_pred,
1379
+ "pred_depth2_m": src_depth2_pred,
1380
+ "gt_rgb_u8": torch.zeros((1, 3, src_h, src_w), dtype=torch.uint8, device=self.device),
1381
+ "gt_depth_m": src_depth_gt[b : b + 1],
1382
+ "mask": src_depth_mask,
1383
+ "apply_color": False,
1384
+ "apply_alpha": False,
1385
+ "apply_depth": False,
1386
+ "apply_percep": False,
1387
+ "apply_tv": True,
1388
+ "apply_grad": True,
1389
+ "apply_grad_img": True,
1390
+ "grad_img_circular_h": False,
1391
+ "gaussian_scales": reg_inputs["gaussian_scales"],
1392
+ "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
1393
+ "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
1394
+ "delta_xy": reg_inputs["delta_xy_eff"],
1395
+ "delta_rho": reg_inputs["delta_rho_raw"],
1396
+ "delta_grid": reg_inputs["delta_grid"],
1397
+ "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
1398
+ "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
1399
+ "gaussian_opacities": reg_inputs["gaussian_opacities"],
1400
+ "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
1401
+ "projected_scale_factor": None,
1402
+ "projection_model": "fisheye624",
1403
+ "projection_camera_params": src_render_cam_params[b : b + 1],
1404
+ "apply_splat": True,
1405
+ "loss_scale": 1.0,
1406
+ }
1407
+ ]
1408
+ tgt_extra_loss_terms = []
1409
+
1410
+ vis_payload = None
1411
+ if enable_vis:
1412
+ src_unik3d_depth = out["unik3d_distance"][b : b + 1, 0:1].detach() if torch.is_tensor(out.get("unik3d_distance", None)) else None
1413
+ tgt_unik3d_depth = None
1414
+ if (
1415
+ tgt_unik3d_depth is None
1416
+ and self.enable_tgt_unik3d_vis
1417
+ ):
1418
+ try:
1419
+ with torch.no_grad():
1420
+ from unisharp.utils.unik3d_adapter import forward_unik3d_fisheye624
1421
+
1422
+ unik_tgt = forward_unik3d_fisheye624(
1423
+ self._base_model().feature_extractor.unik3d,
1424
+ rgb_u8=tgt_u8[b : b + 1],
1425
+ camera_params=tgt_render_cam_params[b : b + 1],
1426
+ normalize=True,
1427
+ validity_mask=tgt_valid_mask[b : b + 1],
1428
+ )
1429
+ dist_tgt = unik_tgt.get("distance", None) if isinstance(unik_tgt, dict) else None
1430
+ if torch.is_tensor(dist_tgt):
1431
+ tgt_unik3d_depth = dist_tgt[:, 0:1].detach()
1432
+ except Exception:
1433
+ tgt_unik3d_depth = None
1434
+ vis_payload = {
1435
+ "src_gt": (src_u8[b : b + 1].float() / 255.0).detach(),
1436
+ "src_pred": src_render["color"].clamp(0, 1).detach(),
1437
+ "src_alpha": src_render["alpha"].detach(),
1438
+ "src_gt_depth": src_depth_gt[b : b + 1].detach(),
1439
+ "src_pred_depth": src_render["depth_distance"].detach(),
1440
+ "src_unik3d_depth": src_unik3d_depth,
1441
+ "src_metric_mask": src_mask.detach(),
1442
+ "tgt_gt": (tgt_u8[b : b + 1].float() / 255.0).detach(),
1443
+ "tgt_pred": tgt_render["color"].clamp(0, 1).detach(),
1444
+ "tgt_alpha": tgt_render["alpha"].detach(),
1445
+ "tgt_gt_depth": tgt_depth_gt[b : b + 1].detach(),
1446
+ "tgt_pred_depth": tgt_depth_pred.detach(),
1447
+ "tgt_unik3d_depth": tgt_unik3d_depth,
1448
+ "dataset_name": str(dataset_name),
1449
+ "scene": str(self._first_item(getattr(batch, "scene", None), "unknown")),
1450
+ "src_idx": int(self._first_item(getattr(batch, "src_idx", None), -1)),
1451
+ "tgt_idx": int(self._first_item(getattr(batch, "tgt_idx", None), -1)),
1452
+ "src_pose_w2c": src_w2c[b : b + 1].detach(),
1453
+ "tgt_pose_w2c": tgt_w2c[b : b + 1].detach(),
1454
+ "tgt_metric_mask_raw": tgt_mask_raw.detach(),
1455
+ "tgt_metric_mask": tgt_mask.detach(),
1456
+ }
1457
+
1458
+
1459
+ return {
1460
+ "src_loss_terms": src_loss_terms,
1461
+ "src_extra_loss_terms": src_extra_loss_terms,
1462
+ "tgt_pred_rgb_linear": tgt_render["color"],
1463
+ "tgt_pred_alpha": tgt_render["alpha"],
1464
+ "tgt_pred_depth_m": tgt_depth_pred,
1465
+ "tgt_gt_rgb_u8": tgt_u8[b : b + 1],
1466
+ "tgt_gt_depth_m": tgt_depth_gt[b : b + 1],
1467
+ "tgt_mask": tgt_mask,
1468
+ "tgt_depth_mask": tgt_depth_loss_mask,
1469
+ "tgt_apply_depth": True,
1470
+ "tgt_apply_grad_img": True,
1471
+ "tgt_apply_splat": False,
1472
+ "tgt_grad_img_circular_h": False,
1473
+ "tgt_apply_percep": bool(float(self.loss_fn.w.lambda_percep) > 0.0),
1474
+ "tgt_extra_loss_terms": tgt_extra_loss_terms,
1475
+ "aux_losses": self._aux_ray_losses(
1476
+ pred_rays=(
1477
+ out.get("unik3d_rays", None)[b : b + 1]
1478
+ if torch.is_tensor(out.get("unik3d_rays", None))
1479
+ else None
1480
+ ),
1481
+ gt_rays=(
1482
+ aux_ray_target_all[b : b + 1]
1483
+ if torch.is_tensor(aux_ray_target_all)
1484
+ else None
1485
+ ),
1486
+ mask=src_render_valid_mask[b : b + 1],
1487
+ pred_distance=(
1488
+ out["unik3d_distance"][b : b + 1, 0:1]
1489
+ if torch.is_tensor(out.get("unik3d_distance", None))
1490
+ else None
1491
+ ),
1492
+ pred_distance2=None,
1493
+ gt_distance=src_depth_gt[b : b + 1],
1494
+ depth_mask=src_valid_mask[b : b + 1],
1495
+ ),
1496
+ "gaussian_scales": reg_inputs["gaussian_scales"],
1497
+ "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
1498
+ "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
1499
+ "delta_xy": reg_inputs["delta_xy_eff"],
1500
+ "delta_rho": reg_inputs["delta_rho_raw"],
1501
+ "delta_grid": reg_inputs["delta_grid"],
1502
+ "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
1503
+ "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
1504
+ "gaussian_opacities": reg_inputs["gaussian_opacities"],
1505
+ "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
1506
+ "projected_scale_factor": reg_inputs["projected_scale_factor"],
1507
+ "projection_model": "fisheye624",
1508
+ "projection_camera_params": src_render_cam_params[b : b + 1],
1509
+ "vis_payload": vis_payload,
1510
+ }
1511
+
1512
+ return _ModeStrategy(
1513
+ batch_size=int(src_u8.shape[0]),
1514
+ gaussians=gaussians,
1515
+ make_world_gaussians=make_world_gaussians,
1516
+ make_sample=make_sample,
1517
+ collect_all_vis=bool(getattr(batch, "collect_all_vis", False)),
1518
+ )
1519
+
1520
+ def _build_fisheye_strategy(
1521
+ self,
1522
+ batch: Any,
1523
+ step: int,
1524
+ need_vis: bool = False,
1525
+ dataset_name: str = "fisheye",
1526
+ ) -> _ModeStrategy:
1527
+ camera_model = str(getattr(batch, "camera_model", "fisheye624")).lower()
1528
+ if camera_model != "fisheye624":
1529
+ raise ValueError(
1530
+ f"Unsupported fisheye camera_model={camera_model!r}; expected 'fisheye624'."
1531
+ )
1532
+ return self._build_fisheye624_strategy(
1533
+ batch,
1534
+ step,
1535
+ need_vis=need_vis,
1536
+ dataset_name=dataset_name,
1537
+ )
1538
+
1539
+ def _build_spherical_strategy(
1540
+ self,
1541
+ batch: Any,
1542
+ step: int,
1543
+ need_vis: bool = False,
1544
+ dataset_name: str = "hm3d",
1545
+ ) -> _ModeStrategy:
1546
+ src_erp_u8 = batch.src_erp_rgb_u8.to(self.device, non_blocking=True)
1547
+ tgt_erp_u8 = batch.tgt_erp_rgb_u8.to(self.device, non_blocking=True)
1548
+ src_erp_depth = self._clamp_distance_for_supervision(
1549
+ batch.src_erp_depth_m.to(self.device, non_blocking=True)
1550
+ )
1551
+ tgt_erp_depth = self._clamp_distance_for_supervision(
1552
+ batch.tgt_erp_depth_m.to(self.device, non_blocking=True)
1553
+ )
1554
+ src_cdep = self._sanitize_positive_depth(
1555
+ batch.src_cube_depth_m.to(self.device, non_blocking=True)
1556
+ )
1557
+ tgt_cdep = self._sanitize_positive_depth(
1558
+ batch.tgt_cube_depth_m.to(self.device, non_blocking=True)
1559
+ )
1560
+ disable_depth_gt = bool(getattr(batch, "disable_depth_gt", False))
1561
+
1562
+ src_R = batch.src_R.to(self.device, non_blocking=True)
1563
+ src_t = batch.src_t.to(self.device, non_blocking=True)
1564
+ tgt_R = batch.tgt_R.to(self.device, non_blocking=True)
1565
+ tgt_t = batch.tgt_t.to(self.device, non_blocking=True)
1566
+
1567
+ cur_bs = int(src_erp_u8.shape[0])
1568
+ erp_h = int(src_erp_u8.shape[-2])
1569
+ erp_w = int(src_erp_u8.shape[-1])
1570
+ cube_face_w = int(batch.src_cube_depth_m.shape[2]) if torch.is_tensor(batch.src_cube_depth_m) else max(1, erp_h // 2)
1571
+
1572
+ use_flip_yz = str(dataset_name).lower() not in {"sim", "smx_sim_fisheye"}
1573
+ pose_convs_per_sample = ["c2w"] * cur_bs
1574
+ flip_yz_per_sample = [bool(use_flip_yz)] * cur_bs
1575
+
1576
+ extr_src_base = torch.stack(
1577
+ [build_extrinsics_w2c(src_R[i], src_t[i], pose_convs_per_sample[i]) for i in range(cur_bs)],
1578
+ dim=0
1579
+ )
1580
+ extr_tgt_base = torch.stack(
1581
+ [build_extrinsics_w2c(tgt_R[i], tgt_t[i], pose_convs_per_sample[i]) for i in range(cur_bs)],
1582
+ dim=0
1583
+ )
1584
+
1585
+ with torch.autocast("cuda", enabled=False):
1586
+ c2w_src = torch.linalg.inv(extr_src_base.to(torch.float32))
1587
+ c2w_tgt = torch.linalg.inv(extr_tgt_base.to(torch.float32))
1588
+
1589
+ flip_mask = torch.tensor(flip_yz_per_sample, device=c2w_src.device, dtype=torch.bool)
1590
+ negate_relative_z = False
1591
+ if bool(flip_mask.any().item()):
1592
+ flip_mode = os.environ.get("PANO_POSE_FLIP_CONVENTION", "flip_yz_negate_rel_z").strip().lower()
1593
+ negate_relative_z = flip_mode in {
1594
+ "flip_yz_negate_rel_z",
1595
+ "flip_yz_invert_z_translation",
1596
+ "flip_yz_neg_z",
1597
+ }
1598
+ if flip_mode in {"flip_y_only", "y", "y_only"}:
1599
+ diag = [1.0, -1.0, 1.0, 1.0]
1600
+ elif flip_mode in {"none", "identity", "no_flip"}:
1601
+ diag = [1.0, 1.0, 1.0, 1.0]
1602
+ else:
1603
+ diag = [1.0, -1.0, -1.0, 1.0]
1604
+ D = torch.diag(torch.tensor(diag, device=c2w_src.device, dtype=torch.float32))
1605
+ c2w_src = c2w_src.clone()
1606
+ c2w_tgt = c2w_tgt.clone()
1607
+ c2w_src[flip_mask] = c2w_src[flip_mask] @ D
1608
+ c2w_tgt[flip_mask] = c2w_tgt[flip_mask] @ D
1609
+
1610
+ ref_inv = torch.linalg.inv(c2w_src.to(torch.float32))
1611
+ c2w_src = ref_inv @ c2w_src
1612
+ c2w_tgt = ref_inv @ c2w_tgt
1613
+ if negate_relative_z:
1614
+ c2w_tgt = c2w_tgt.clone()
1615
+ c2w_tgt[flip_mask, 2, 3] *= -1.0
1616
+
1617
+ extr_src = torch.linalg.inv(c2w_src).to(dtype=extr_src_base.dtype)
1618
+ extr_tgt = torch.linalg.inv(c2w_tgt).to(dtype=extr_tgt_base.dtype)
1619
+
1620
+ src_erp = (src_erp_u8.float() / 255.0).clamp(0, 1)
1621
+ distance_init_cap_m = self._distance_init_cap_for_dataset(dataset_name)
1622
+
1623
+ out = self.model(
1624
+ image=src_erp,
1625
+ image_u8=src_erp_u8,
1626
+ camera_intrinsics=None,
1627
+ camera_model="spherical",
1628
+ depth_gt=None if disable_depth_gt else src_erp_depth,
1629
+ distance_init_cap_m=distance_init_cap_m,
1630
+ return_aux=True,
1631
+ )
1632
+ gaussians = out["gaussians"]
1633
+ aux_ray_target_all = out.get("unik3d_gt_rays", None)
1634
+
1635
+ def make_world_gaussians(b: int, g_b: Any) -> Any:
1636
+ return transform_gaussians_to_world(g_b, extr_src[b])
1637
+
1638
+ def make_sample(b: int, g_world: Any, enable_vis: bool) -> dict[str, Any]:
1639
+ src_rgb, src_depth, src_alpha = self._render_cubemap(g_world, extr_src[b], face_w=cube_face_w)
1640
+ tgt_rgb, tgt_depth, tgt_alpha = self._render_cubemap(g_world, extr_tgt[b], face_w=cube_face_w)
1641
+
1642
+ src_erp_pred = self._cube_to_erp(src_rgb, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w)
1643
+ tgt_erp_pred = self._cube_to_erp(tgt_rgb, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w)
1644
+ src_erp_alpha = self._cube_to_erp(src_alpha, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w)
1645
+ tgt_erp_alpha = self._cube_to_erp(tgt_alpha, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w)
1646
+ src_depth_dist = self._clamp_distance_for_supervision(
1647
+ self._cubemap_z_depth_to_distance(src_depth),
1648
+ clamp_max=False,
1649
+ )
1650
+ tgt_depth_dist = self._clamp_distance_for_supervision(
1651
+ self._cubemap_z_depth_to_distance(tgt_depth),
1652
+ clamp_max=False,
1653
+ )
1654
+ src_erp_depth_render = self._cube_to_erp(
1655
+ src_depth_dist, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w
1656
+ ).clamp(min=1e-4)
1657
+
1658
+ src_erp_depth_pred = self._clamp_distance_for_supervision(
1659
+ out["distance_layers"][b : b + 1, 0:1],
1660
+ clamp_max=False,
1661
+ )
1662
+ src_erp_depth2_pred = (
1663
+ self._clamp_distance_for_supervision(out["distance_layers"][b : b + 1, 1:2], clamp_max=False)
1664
+ if out["distance_layers"] is not None and out["distance_layers"].shape[1] > 1
1665
+ else None
1666
+ )
1667
+ tgt_erp_depth_pred = self._cube_to_erp(
1668
+ tgt_depth_dist, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w
1669
+ ).clamp(min=1e-4)
1670
+ tgt_depth_loss_mask = self._rendered_depth_valid_for_inv_loss(tgt_erp_depth_pred, tgt_erp_alpha)
1671
+ depth_novel = self._pick_depth_for_cubemap_frustum_mask(
1672
+ gt_depth_cube=None if disable_depth_gt else (tgt_cdep[b : b + 1][0] if torch.is_tensor(tgt_cdep) else None),
1673
+ pred_depth_cube=tgt_depth_dist,
1674
+ face_w=cube_face_w,
1675
+ )
1676
+ source_depth_for_visibility = self._pick_depth_for_cubemap_frustum_mask(
1677
+ gt_depth_cube=None if disable_depth_gt else (src_cdep[b : b + 1][0] if torch.is_tensor(src_cdep) else None),
1678
+ pred_depth_cube=src_depth_dist,
1679
+ face_w=cube_face_w,
1680
+ )
1681
+ mask_bool = view_frustum_mask_cubemap_union(
1682
+ depth_novel=depth_novel,
1683
+ extr_novel_w2c=extr_tgt[b],
1684
+ extr_source_w2c=extr_src[b],
1685
+ face_w=int(cube_face_w),
1686
+ source_depth=source_depth_for_visibility,
1687
+ )
1688
+ mask_erp = self._cube_to_erp(
1689
+ mask_bool[:, None].to(torch.float32), equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w
1690
+ )
1691
+
1692
+ gt_src_erp_u8 = src_erp_u8[b : b + 1]
1693
+ gt_tgt_erp_u8 = tgt_erp_u8[b : b + 1]
1694
+ gt_src_erp_depth = src_erp_depth[b : b + 1]
1695
+ gt_tgt_erp_depth = tgt_erp_depth[b : b + 1]
1696
+ gt_src_cube_u8 = batch.src_cube_rgb_u8[b].to(self.device, non_blocking=True).permute(0, 3, 1, 2).contiguous()
1697
+ gt_tgt_cube_u8 = batch.tgt_cube_rgb_u8[b].to(self.device, non_blocking=True).permute(0, 3, 1, 2).contiguous()
1698
+
1699
+ src_valid = torch.ones_like(gt_src_erp_depth) if disable_depth_gt else (gt_src_erp_depth > 0.0).to(dtype=torch.float32)
1700
+ tgt_valid = torch.ones_like(gt_tgt_erp_depth) if disable_depth_gt else (gt_tgt_erp_depth > 0.0).to(dtype=torch.float32)
1701
+ src_mask = torch.ones_like(src_valid)
1702
+ tgt_mask = (mask_erp.to(dtype=torch.float32) * tgt_valid).clamp(0.0, 1.0)
1703
+ tgt_mask_raw = tgt_mask
1704
+ tgt_mask = self._erode_supervision_mask(
1705
+ tgt_mask,
1706
+ self.target_mask_erode_px,
1707
+ circular_h=True,
1708
+ )
1709
+ src_cube_mask = torch.ones_like(src_alpha)
1710
+ if str(dataset_name).lower() == "hm3d" and (not disable_depth_gt) and torch.is_tensor(src_cdep):
1711
+ src_cube_valid = (src_cdep[b : b + 1][0, ..., 0] > 0.0).to(dtype=src_alpha.dtype).unsqueeze(1)
1712
+ if tuple(src_cube_valid.shape[-2:]) != tuple(src_alpha.shape[-2:]):
1713
+ src_cube_valid = F.interpolate(
1714
+ src_cube_valid,
1715
+ size=src_alpha.shape[-2:],
1716
+ mode="nearest",
1717
+ )
1718
+ src_cube_mask = src_cube_valid.to(device=src_alpha.device, dtype=src_alpha.dtype).clamp(0.0, 1.0)
1719
+ tgt_cube_valid = (depth_novel[..., 0] > 0.0).to(dtype=torch.float32).unsqueeze(1)
1720
+ tgt_cube_mask = (mask_bool[:, None].to(dtype=torch.float32) * tgt_cube_valid).clamp(0.0, 1.0)
1721
+ tgt_cube_mask = self._erode_supervision_mask(
1722
+ tgt_cube_mask,
1723
+ self.target_mask_erode_px,
1724
+ circular_h=False,
1725
+ )
1726
+
1727
+ src_cube_depth_zeros = torch.zeros_like(src_alpha)
1728
+ tgt_cube_depth_zeros = torch.zeros_like(tgt_alpha)
1729
+ src_erp_rgb_zeros = torch.zeros_like(src_erp_pred)
1730
+ tgt_erp_rgb_zeros = torch.zeros_like(tgt_erp_pred)
1731
+ src_erp_u8_zeros = torch.zeros_like(gt_src_erp_u8)
1732
+ tgt_erp_u8_zeros = torch.zeros_like(gt_tgt_erp_u8)
1733
+
1734
+ erp_proj_scale = 0.5 * (
1735
+ float(erp_w) / (2.0 * 3.141592653589793)
1736
+ + float(erp_h) / 3.141592653589793
1737
+ )
1738
+ reg_inputs = self._collect_regularization_inputs(
1739
+ out=out,
1740
+ gaussians=gaussians,
1741
+ b=b,
1742
+ projected_scale_factor=erp_proj_scale,
1743
+ )
1744
+
1745
+ vis_payload = None
1746
+ if enable_vis:
1747
+ src_unik3d_depth = None
1748
+ tgt_unik3d_depth = None
1749
+ raw_dist = out.get("unik3d_distance", None)
1750
+ if torch.is_tensor(raw_dist):
1751
+ src_unik3d_depth = raw_dist[b : b + 1, 0:1].detach()
1752
+
1753
+ vis_payload = {
1754
+ "src_gt": (gt_src_erp_u8.float() / 255.0).detach(),
1755
+ "src_pred": src_erp_pred.clamp(0, 1).detach(),
1756
+ "src_alpha": src_erp_alpha.detach(),
1757
+ "src_gt_depth": None if disable_depth_gt else gt_src_erp_depth.detach(),
1758
+ "src_pred_depth": src_erp_depth_render.detach(),
1759
+ "src_unik3d_depth": src_unik3d_depth,
1760
+ "tgt_gt": (gt_tgt_erp_u8.float() / 255.0).detach(),
1761
+ "tgt_pred": tgt_erp_pred.clamp(0, 1).detach(),
1762
+ "tgt_alpha": tgt_erp_alpha.detach(),
1763
+ "tgt_gt_depth": None if disable_depth_gt else gt_tgt_erp_depth.detach(),
1764
+ "tgt_pred_depth": tgt_erp_depth_pred.detach(),
1765
+ "tgt_unik3d_depth": tgt_unik3d_depth,
1766
+ "dataset_name": str(dataset_name),
1767
+ "scene": str(self._item_at(getattr(batch, "scene", None), b, "unknown")),
1768
+ "src_idx": int(self._item_at(getattr(batch, "src_idx", None), b, -1)),
1769
+ "tgt_idx": int(self._item_at(getattr(batch, "tgt_idx", None), b, -1)),
1770
+ "src_pose_w2c": extr_src[b : b + 1].detach(),
1771
+ "tgt_pose_w2c": extr_tgt[b : b + 1].detach(),
1772
+ "src_cube_gt_u8": (
1773
+ batch.src_cube_rgb_u8[b].detach()
1774
+ if hasattr(batch, "src_cube_rgb_u8") and torch.is_tensor(batch.src_cube_rgb_u8)
1775
+ else None
1776
+ ),
1777
+ "tgt_cube_gt_u8": (
1778
+ batch.tgt_cube_rgb_u8[b].detach()
1779
+ if hasattr(batch, "tgt_cube_rgb_u8") and torch.is_tensor(batch.tgt_cube_rgb_u8)
1780
+ else None
1781
+ ),
1782
+ "src_cube_pred_linear": src_rgb.detach(),
1783
+ "tgt_cube_pred_linear": tgt_rgb.detach(),
1784
+ "src_cube_alpha": src_alpha.detach(),
1785
+ "tgt_cube_alpha": tgt_alpha.detach(),
1786
+ "tgt_metric_mask_raw": tgt_mask_raw.detach(),
1787
+ "tgt_metric_mask": tgt_mask.detach(),
1788
+ }
1789
+
1790
+ tgt_loss_terms = [
1791
+ {
1792
+ "pred_rgb_linear": tgt_rgb,
1793
+ "pred_alpha": tgt_alpha,
1794
+ "pred_depth_m": tgt_cube_depth_zeros,
1795
+ "pred_depth2_m": None,
1796
+ "gt_rgb_u8": gt_tgt_cube_u8,
1797
+ "gt_depth_m": tgt_cube_depth_zeros,
1798
+ "mask": tgt_cube_mask,
1799
+ "apply_color": True,
1800
+ "apply_alpha": True,
1801
+ "apply_depth": False,
1802
+ "apply_percep": bool(float(self.loss_fn.w.lambda_percep) > 0.0),
1803
+ "apply_tv": False,
1804
+ "apply_grad": False,
1805
+ "apply_grad_img": False,
1806
+ "grad_img_circular_h": False,
1807
+ "gaussian_scales": None,
1808
+ "gaussian_quaternions": None,
1809
+ "gaussian_angular_cell": None,
1810
+ "delta_xy": None,
1811
+ "gaussian_mean_vectors": None,
1812
+ "gaussian_opacities": None,
1813
+ "gauss_grid_shape": None,
1814
+ "projected_scale_factor": None,
1815
+ },
1816
+ {
1817
+ "pred_rgb_linear": tgt_erp_rgb_zeros,
1818
+ "pred_alpha": torch.zeros_like(tgt_erp_depth_pred),
1819
+ "pred_depth_m": tgt_erp_depth_pred,
1820
+ "pred_depth2_m": None,
1821
+ "gt_rgb_u8": tgt_erp_u8_zeros,
1822
+ "gt_depth_m": gt_tgt_erp_depth,
1823
+ "mask": tgt_mask,
1824
+ "depth_mask": tgt_depth_loss_mask,
1825
+ "apply_color": False,
1826
+ "apply_alpha": False,
1827
+ "apply_depth": not disable_depth_gt,
1828
+ "apply_percep": False,
1829
+ "apply_tv": False,
1830
+ "apply_grad": False,
1831
+ "apply_grad_img": not disable_depth_gt,
1832
+ "grad_img_circular_h": True,
1833
+ "gaussian_scales": None,
1834
+ "gaussian_quaternions": None,
1835
+ "gaussian_angular_cell": None,
1836
+ "delta_xy": None,
1837
+ "gaussian_mean_vectors": None,
1838
+ "gaussian_opacities": None,
1839
+ "gauss_grid_shape": None,
1840
+ "projected_scale_factor": reg_inputs["projected_scale_factor"],
1841
+ },
1842
+ ]
1843
+ return {
1844
+ "src_loss_terms": [
1845
+ {
1846
+ "pred_rgb_linear": src_rgb,
1847
+ "pred_alpha": src_alpha,
1848
+ "pred_depth_m": src_cube_depth_zeros,
1849
+ "pred_depth2_m": None,
1850
+ "gt_rgb_u8": gt_src_cube_u8,
1851
+ "gt_depth_m": src_cube_depth_zeros,
1852
+ "mask": src_cube_mask,
1853
+ "apply_color": True,
1854
+ "apply_alpha": True,
1855
+ "apply_depth": False,
1856
+ "apply_percep": False,
1857
+ "apply_tv": False,
1858
+ "apply_grad": False,
1859
+ "apply_grad_img": False,
1860
+ "grad_img_circular_h": False,
1861
+ "gaussian_scales": None,
1862
+ "gaussian_quaternions": None,
1863
+ "gaussian_angular_cell": None,
1864
+ "delta_xy": None,
1865
+ "gaussian_mean_vectors": None,
1866
+ "gaussian_opacities": None,
1867
+ "gauss_grid_shape": None,
1868
+ "projected_scale_factor": None,
1869
+ },
1870
+ {
1871
+ "pred_rgb_linear": src_erp_rgb_zeros,
1872
+ "pred_alpha": torch.zeros_like(src_erp_depth_pred),
1873
+ "pred_depth_m": src_erp_depth_pred,
1874
+ "pred_depth2_m": src_erp_depth2_pred,
1875
+ "gt_rgb_u8": src_erp_u8_zeros,
1876
+ "gt_depth_m": gt_src_erp_depth,
1877
+ "mask": src_mask,
1878
+ "apply_color": False,
1879
+ "apply_alpha": False,
1880
+ "apply_depth": False,
1881
+ "apply_percep": False,
1882
+ "apply_tv": True,
1883
+ "apply_grad": False,
1884
+ "apply_grad_img": not disable_depth_gt,
1885
+ "grad_img_circular_h": True,
1886
+ "gaussian_scales": reg_inputs["gaussian_scales"],
1887
+ "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
1888
+ "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
1889
+ "delta_xy": reg_inputs["delta_xy_eff"],
1890
+ "delta_rho": reg_inputs["delta_rho_raw"],
1891
+ "delta_grid": reg_inputs["delta_grid"],
1892
+ "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
1893
+ "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
1894
+ "gaussian_opacities": reg_inputs["gaussian_opacities"],
1895
+ "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
1896
+ "projected_scale_factor": reg_inputs["projected_scale_factor"],
1897
+ "projection_model": "erp",
1898
+ },
1899
+ ],
1900
+ "tgt_loss_terms": tgt_loss_terms,
1901
+ "gaussian_scales": reg_inputs["gaussian_scales"],
1902
+ "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
1903
+ "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
1904
+ "delta_xy": reg_inputs["delta_xy_eff"],
1905
+ "delta_rho": reg_inputs["delta_rho_raw"],
1906
+ "delta_grid": reg_inputs["delta_grid"],
1907
+ "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
1908
+ "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
1909
+ "gaussian_opacities": reg_inputs["gaussian_opacities"],
1910
+ "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
1911
+ "projected_scale_factor": reg_inputs["projected_scale_factor"],
1912
+ "projection_model": "erp",
1913
+ "aux_losses": self._aux_ray_losses(
1914
+ pred_rays=(
1915
+ out.get("unik3d_rays", None)[b : b + 1]
1916
+ if torch.is_tensor(out.get("unik3d_rays", None))
1917
+ else None
1918
+ ),
1919
+ gt_rays=(
1920
+ aux_ray_target_all[b : b + 1]
1921
+ if torch.is_tensor(aux_ray_target_all)
1922
+ else None
1923
+ ),
1924
+ mask=src_valid,
1925
+ pred_distance=(
1926
+ out["unik3d_distance"][b : b + 1, 0:1]
1927
+ if torch.is_tensor(out.get("unik3d_distance", None))
1928
+ else None
1929
+ ),
1930
+ pred_distance2=src_erp_depth2_pred,
1931
+ gt_distance=None if disable_depth_gt else gt_src_erp_depth,
1932
+ depth_mask=src_valid,
1933
+ ),
1934
+ "vis_payload": vis_payload,
1935
+ }
1936
+
1937
+ return _ModeStrategy(
1938
+ batch_size=int(cur_bs),
1939
+ gaussians=gaussians,
1940
+ make_world_gaussians=make_world_gaussians,
1941
+ make_sample=make_sample,
1942
+ collect_all_vis=bool(getattr(batch, "collect_all_vis", False)),
1943
+ )
1944
+
1945
+ def _render_cubemap(
1946
+ self,
1947
+ gaussians: Any,
1948
+ extr_w2c: torch.Tensor,
1949
+ face_w: int,
1950
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1951
+ device = gaussians.mean_vectors.device
1952
+ intr = get_pinhole_intrinsics_4x4(int(face_w)).to(device=device)[None].expand(6, -1, -1)
1953
+ extr_faces = cubemap_face_cameras(extr_w2c, device=device)
1954
+ out = self.renderer(
1955
+ gaussians,
1956
+ extrinsics=extr_faces,
1957
+ intrinsics=intr,
1958
+ image_width=int(face_w),
1959
+ image_height=int(face_w),
1960
+ )
1961
+ return out.color.contiguous(), out.depth.contiguous(), out.alpha.contiguous()
1962
+
1963
+ def _cube_to_erp(self, cube: torch.Tensor, equ_h: int, equ_w: int, face_w: int) -> torch.Tensor:
1964
+ cube = cube.permute(1, 0, 2, 3).unsqueeze(0)
1965
+ c2e = Cube2Equirec(face_w=int(face_w), equ_h=int(equ_h), equ_w=int(equ_w)).to(device=cube.device)
1966
+ return c2e(cube)
unisharp/datasets/__pycache__/dl3dv.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
unisharp/datasets/__pycache__/dl3dv.cpython-313.pyc ADDED
Binary file (19.5 kB). View file
 
unisharp/datasets/__pycache__/pair_sampling.cpython-310.pyc ADDED
Binary file (3.78 kB). View file
 
unisharp/datasets/__pycache__/pair_sampling.cpython-313.pyc ADDED
Binary file (6.66 kB). View file
 
unisharp/datasets/__pycache__/panogs.cpython-310.pyc ADDED
Binary file (17.4 kB). View file
 
unisharp/datasets/__pycache__/panogs.cpython-313.pyc ADDED
Binary file (32.5 kB). View file
 
unisharp/datasets/__pycache__/re10k.cpython-310.pyc ADDED
Binary file (19.8 kB). View file
 
unisharp/datasets/__pycache__/re10k.cpython-313.pyc ADDED
Binary file (37.4 kB). View file
 
unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-310.pyc ADDED
Binary file (17.5 kB). View file
 
unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-313.pyc ADDED
Binary file (32.4 kB). View file
 
unisharp/datasets/__pycache__/sim_panorama.cpython-310.pyc ADDED
Binary file (18.9 kB). View file
 
unisharp/datasets/__pycache__/sim_panorama.cpython-313.pyc ADDED
Binary file (34 kB). View file
 
unisharp/datasets/__pycache__/wildrgbd.cpython-310.pyc ADDED
Binary file (11.7 kB). View file
 
unisharp/datasets/__pycache__/wildrgbd.cpython-313.pyc ADDED
Binary file (21.4 kB). View file
 
unisharp/datasets/dl3dv.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from collections import defaultdict, deque
5
+ import json
6
+ from pathlib import Path
7
+ import random
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from PIL import Image
13
+ from torch.utils.data import IterableDataset
14
+
15
+ from unisharp.datasets.pair_sampling import (
16
+ project_overlap_ratio,
17
+ resize_k3_align_corners_false,
18
+ resize_rgb_u8_chw_high_quality,
19
+ select_targets_for_source,
20
+ )
21
+ from unisharp.datasets.re10k import Re10KPairSample, re10k_collate
22
+ from unisharp import DEFAULT_MAX_DEPTH_M
23
+
24
+
25
+ class DL3DVDataset(IterableDataset):
26
+ def __init__(
27
+ self,
28
+ root: Path,
29
+ depth_root: Path,
30
+ scene_specs_file: Path | None = None,
31
+ min_frame_gap: int = 1,
32
+ max_frame_gap: int = 32,
33
+ pair_max_translation_m: float = 0.5,
34
+ pair_min_overlap: float = 0.6,
35
+ pair_overlap_sample_h: int = 32,
36
+ pair_overlap_sample_w: int = 56,
37
+ output_h: int | None = None,
38
+ output_w: int | None = None,
39
+ shuffle_scene: bool = True,
40
+ shuffle_frame: bool = False,
41
+ ddp_rank: int = 0,
42
+ ddp_world_size: int = 1,
43
+ batch_size_hint: int = 1,
44
+ depth_max_m: float = DEFAULT_MAX_DEPTH_M,
45
+ seed: int = 0,
46
+ verify_manifest_paths: bool = False,
47
+ ) -> None:
48
+ super().__init__()
49
+ self.root = Path(root)
50
+ self.depth_root = Path(depth_root)
51
+ self.min_frame_gap = int(min_frame_gap)
52
+ self.max_frame_gap = int(max_frame_gap)
53
+ self.pair_max_translation_m = float(pair_max_translation_m)
54
+ self.pair_min_overlap = float(pair_min_overlap)
55
+ self.pair_overlap_sample_h = int(pair_overlap_sample_h)
56
+ self.pair_overlap_sample_w = int(pair_overlap_sample_w)
57
+ self.output_h = int(output_h) if output_h is not None else None
58
+ self.output_w = int(output_w) if output_w is not None else None
59
+ self.shuffle_scene = bool(shuffle_scene)
60
+ self.shuffle_frame = bool(shuffle_frame)
61
+ self.ddp_rank = int(ddp_rank)
62
+ self.ddp_world_size = int(ddp_world_size)
63
+ self.batch_size_hint = int(max(1, batch_size_hint))
64
+ self.depth_max_m = float(depth_max_m)
65
+ self.seed = int(seed)
66
+ self.epoch = 0
67
+ self.verify_manifest_paths = bool(verify_manifest_paths)
68
+ self.scene_specs_file = Path(scene_specs_file) if scene_specs_file is not None else None
69
+ self.scene_specs = self._load_scene_specs()
70
+ if not self.scene_specs:
71
+ raise RuntimeError(f"No valid DL3DV scenes found under {self.root}")
72
+
73
+ def set_epoch(self, epoch: int) -> None:
74
+ self.epoch = int(epoch)
75
+
76
+ def _load_scene_specs(self) -> list[tuple[str, Path, Path]]:
77
+ if self.scene_specs_file is None:
78
+ return self._scan_scenes()
79
+ if not self.scene_specs_file.exists():
80
+ raise FileNotFoundError(self.scene_specs_file)
81
+ out: list[tuple[str, Path, Path]] = []
82
+ for raw in self.scene_specs_file.read_text(encoding="utf-8").splitlines():
83
+ line = raw.strip()
84
+ if not line:
85
+ continue
86
+ parts = line.split("|")
87
+ if len(parts) != 3:
88
+ continue
89
+ scene_name, scene_dir_raw, depth_dir_raw = parts
90
+ scene_dir = Path(scene_dir_raw)
91
+ depth_dir = Path(depth_dir_raw)
92
+ if (not self.verify_manifest_paths) or (scene_dir.exists() and depth_dir.exists()):
93
+ out.append((scene_name, scene_dir, depth_dir))
94
+ return out
95
+
96
+ def _scan_scenes(self) -> list[tuple[str, Path, Path]]:
97
+ out: list[tuple[str, Path, Path]] = []
98
+ for bucket_dir in sorted([p for p in self.root.iterdir() if p.is_dir()]):
99
+ for scene_stub in sorted([p for p in bucket_dir.iterdir() if p.is_dir()]):
100
+ inner_dirs = [p for p in scene_stub.iterdir() if p.is_dir()]
101
+ scene_dir = inner_dirs[0] if inner_dirs else scene_stub
102
+ transforms_path = scene_dir / "transforms.json"
103
+ image_dir = scene_dir / "images_4"
104
+ depth_dir = self.depth_root / bucket_dir.name / scene_stub.name / "exports" / "mini_npz" / "per_image"
105
+ if transforms_path.exists() and image_dir.exists() and depth_dir.exists():
106
+ scene_name = f"{bucket_dir.name}/{scene_stub.name}"
107
+ out.append((scene_name, scene_dir, depth_dir))
108
+ return out
109
+
110
+ @staticmethod
111
+ def _load_rgb_u8(path: Path) -> torch.Tensor:
112
+ arr = np.asarray(Image.open(path).convert("RGB"), dtype=np.uint8).copy()
113
+ return torch.from_numpy(arr).permute(2, 0, 1).contiguous()
114
+
115
+ def _load_depth_m(self, path: Path) -> torch.Tensor:
116
+ payload = np.load(path)
117
+ depth = payload["depth"].astype(np.float32)
118
+ depth[~np.isfinite(depth)] = 0.0
119
+ depth = np.clip(depth, a_min=0.0, a_max=self.depth_max_m)
120
+ return torch.from_numpy(depth).unsqueeze(0)
121
+
122
+ @staticmethod
123
+ def _resize_depth_to_image(depth: torch.Tensor, image_hw: tuple[int, int]) -> torch.Tensor:
124
+ target_h, target_w = int(image_hw[0]), int(image_hw[1])
125
+ if depth.shape[-2:] == (target_h, target_w):
126
+ return depth
127
+ return F.interpolate(
128
+ depth.unsqueeze(0),
129
+ size=(target_h, target_w),
130
+ mode="nearest",
131
+ ).squeeze(0)
132
+
133
+ @staticmethod
134
+ def _frame_id_from_name(name: str) -> int:
135
+ stem = Path(name).stem
136
+ return int(stem.split("_")[-1])
137
+
138
+ def _load_scene(
139
+ self,
140
+ scene_name: str,
141
+ scene_dir: Path,
142
+ depth_dir: Path,
143
+ ) -> tuple[list[int], dict[int, Path], dict[int, Path], dict[int, torch.Tensor], dict[int, torch.Tensor], torch.Tensor]:
144
+ meta = json.loads((scene_dir / "transforms.json").read_text())
145
+ orig_w = int(meta["w"])
146
+ orig_h = int(meta["h"])
147
+ k = torch.eye(3, dtype=torch.float32)
148
+ k[0, 0] = float(meta["fl_x"])
149
+ k[1, 1] = float(meta["fl_y"])
150
+ k[0, 2] = float(meta["cx"])
151
+ k[1, 2] = float(meta["cy"])
152
+
153
+ image_dir = scene_dir / "images_4"
154
+ image_paths = {self._frame_id_from_name(p.name): p for p in image_dir.glob("*.png")}
155
+ depth_paths = {self._frame_id_from_name(p.name): p for p in depth_dir.glob("*.npz")}
156
+ w2c_map: dict[int, torch.Tensor] = {}
157
+ intr_map: dict[int, torch.Tensor] = {}
158
+ valid_ids: list[int] = []
159
+
160
+ example_img = None
161
+ for frame in meta.get("frames", []):
162
+ rel_path = str(frame.get("file_path", ""))
163
+ frame_name = Path(rel_path).name
164
+ frame_id = self._frame_id_from_name(frame_name)
165
+ if frame_id not in image_paths or frame_id not in depth_paths:
166
+ continue
167
+ c2w = torch.tensor(frame["transform_matrix"], dtype=torch.float32)
168
+ c2w[:3, 1:3] *= -1.0
169
+ if example_img is None:
170
+ example_img = self._load_rgb_u8(image_paths[frame_id])
171
+ cur_h, cur_w = int(example_img.shape[1]), int(example_img.shape[2])
172
+ k_cur = k.clone()
173
+ if cur_h != orig_h or cur_w != orig_w:
174
+ sx = float(cur_w) / float(orig_w)
175
+ sy = float(cur_h) / float(orig_h)
176
+ k_cur = resize_k3_align_corners_false(k_cur, sx=sx, sy=sy)
177
+ w2c_map[frame_id] = torch.linalg.inv(c2w)
178
+ intr_map[frame_id] = k_cur
179
+ valid_ids.append(frame_id)
180
+ valid_ids = sorted(valid_ids)
181
+ return valid_ids, image_paths, depth_paths, w2c_map, intr_map, k
182
+
183
+ def __iter__(self):
184
+ scenes = list(self.scene_specs)
185
+ order_rng = random.Random(self.seed + self.epoch)
186
+ if self.shuffle_scene:
187
+ order_rng.shuffle(scenes)
188
+ pending_by_hw: dict[tuple[int, int], deque[Re10KPairSample]] = defaultdict(deque)
189
+ worker_info = torch.utils.data.get_worker_info()
190
+ num_workers = worker_info.num_workers if worker_info is not None else 1
191
+ worker_id = worker_info.id if worker_info is not None else 0
192
+ total_shards = max(1, self.ddp_world_size * num_workers)
193
+ shard_id = self.ddp_rank * num_workers + worker_id
194
+ src_unit_index = 0
195
+
196
+ for scene_order_idx, (scene_name, scene_dir, depth_dir) in enumerate(scenes):
197
+ try:
198
+ valid_ids, image_paths, depth_paths, w2c_map, intr_map, _ = self._load_scene(scene_name, scene_dir, depth_dir)
199
+ except Exception:
200
+ continue
201
+ if len(valid_ids) < 2:
202
+ continue
203
+ src_order = list(valid_ids)
204
+ scene_rng = random.Random(self.seed + self.epoch * 1000003 + scene_order_idx)
205
+ if self.shuffle_frame:
206
+ scene_rng.shuffle(src_order)
207
+ centers = torch.stack([torch.linalg.inv(w2c_map[i])[:3, 3] for i in valid_ids], dim=0)
208
+ frame_to_pos = {fid: pos for pos, fid in enumerate(valid_ids)}
209
+
210
+ def overlap_avg(src_pos: int, tgt_pos: int) -> float:
211
+ src_fid = int(valid_ids[src_pos])
212
+ tgt_fid = int(valid_ids[tgt_pos])
213
+ src_img_path = image_paths[src_fid]
214
+ with Image.open(src_img_path) as img:
215
+ w = int(img.size[0])
216
+ h = int(img.size[1])
217
+ return float(
218
+ 0.5
219
+ * (
220
+ project_overlap_ratio(
221
+ src_w2c=w2c_map[src_fid],
222
+ tgt_w2c=w2c_map[tgt_fid],
223
+ src_k=intr_map[src_fid],
224
+ tgt_k=intr_map[tgt_fid],
225
+ h=h,
226
+ w=w,
227
+ sample_h=self.pair_overlap_sample_h,
228
+ sample_w=self.pair_overlap_sample_w,
229
+ )
230
+ + project_overlap_ratio(
231
+ src_w2c=w2c_map[tgt_fid],
232
+ tgt_w2c=w2c_map[src_fid],
233
+ src_k=intr_map[tgt_fid],
234
+ tgt_k=intr_map[src_fid],
235
+ h=h,
236
+ w=w,
237
+ sample_h=self.pair_overlap_sample_h,
238
+ sample_w=self.pair_overlap_sample_w,
239
+ )
240
+ )
241
+ )
242
+
243
+ for src_idx in src_order:
244
+ if src_unit_index % total_shards != shard_id:
245
+ src_unit_index += 1
246
+ continue
247
+ src_unit_index += 1
248
+ src_pos = int(frame_to_pos[int(src_idx)])
249
+ tgt_pos_list = select_targets_for_source(
250
+ src_idx=src_pos,
251
+ candidate_indices=list(range(len(valid_ids))),
252
+ centers=centers,
253
+ min_index_gap=int(self.min_frame_gap),
254
+ max_index_gap=int(self.max_frame_gap),
255
+ pair_max_translation_m=float(self.pair_max_translation_m),
256
+ pair_min_overlap=float(self.pair_min_overlap),
257
+ overlap_score_fn=overlap_avg,
258
+ )
259
+ if not tgt_pos_list:
260
+ continue
261
+ tgt_idx = int(valid_ids[scene_rng.choice(tgt_pos_list)])
262
+ try:
263
+ src_img = self._load_rgb_u8(image_paths[int(src_idx)])
264
+ tgt_img = self._load_rgb_u8(image_paths[int(tgt_idx)])
265
+ src_depth = self._load_depth_m(depth_paths[int(src_idx)])
266
+ tgt_depth = self._load_depth_m(depth_paths[int(tgt_idx)])
267
+ except Exception:
268
+ continue
269
+ src_depth = self._resize_depth_to_image(src_depth, (int(src_img.shape[1]), int(src_img.shape[2])))
270
+ tgt_depth = self._resize_depth_to_image(tgt_depth, (int(tgt_img.shape[1]), int(tgt_img.shape[2])))
271
+ src_intr = intr_map[int(src_idx)].clone()
272
+ tgt_intr = intr_map[int(tgt_idx)].clone()
273
+ if self.output_h is not None and self.output_w is not None:
274
+ oh, ow = int(src_img.shape[1]), int(src_img.shape[2])
275
+ if oh != self.output_h or ow != self.output_w:
276
+ sx = float(self.output_w) / float(ow)
277
+ sy = float(self.output_h) / float(oh)
278
+ src_img = resize_rgb_u8_chw_high_quality(src_img, size=(self.output_h, self.output_w))
279
+ tgt_img = resize_rgb_u8_chw_high_quality(tgt_img, size=(self.output_h, self.output_w))
280
+ src_depth = F.interpolate(src_depth[None], size=(self.output_h, self.output_w), mode="nearest")[0]
281
+ tgt_depth = F.interpolate(tgt_depth[None], size=(self.output_h, self.output_w), mode="nearest")[0]
282
+ src_intr = resize_k3_align_corners_false(src_intr, sx=sx, sy=sy)
283
+ tgt_intr = resize_k3_align_corners_false(tgt_intr, sx=sx, sy=sy)
284
+ sample = Re10KPairSample(
285
+ src_rgb_u8=src_img,
286
+ tgt_rgb_u8=tgt_img,
287
+ src_w2c=w2c_map[int(src_idx)],
288
+ tgt_w2c=w2c_map[int(tgt_idx)],
289
+ src_intrinsics=src_intr,
290
+ tgt_intrinsics=tgt_intr,
291
+ src_idx=int(src_idx),
292
+ tgt_idx=int(tgt_idx),
293
+ scene=scene_name,
294
+ src_depth_m=src_depth,
295
+ tgt_depth_m=tgt_depth,
296
+ )
297
+ hw_key = (int(sample.src_rgb_u8.shape[1]), int(sample.src_rgb_u8.shape[2]))
298
+ bucket = pending_by_hw[hw_key]
299
+ bucket.append(sample)
300
+ if self.batch_size_hint <= 1:
301
+ yield bucket.popleft()
302
+ continue
303
+ while len(bucket) >= self.batch_size_hint:
304
+ packed = [bucket.popleft() for _ in range(self.batch_size_hint)]
305
+ yield re10k_collate(packed)
unisharp/datasets/pair_sampling.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from typing import Callable
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from unisharp.utils.pixel_convention import scale_intrinsics_align_corners_false
10
+
11
+
12
+ def resize_k3_align_corners_false(k: torch.Tensor, *, sx: float, sy: float) -> torch.Tensor:
13
+ return scale_intrinsics_align_corners_false(k, sx=float(sx), sy=float(sy))
14
+
15
+
16
+ def resize_rgb_u8_chw_high_quality(image: torch.Tensor, *, size: tuple[int, int]) -> torch.Tensor:
17
+ if not torch.is_tensor(image) or image.ndim != 3:
18
+ raise ValueError(f"Expected CHW tensor, got {tuple(image.shape) if torch.is_tensor(image) else type(image)}")
19
+ dst_h, dst_w = int(size[0]), int(size[1])
20
+ if tuple(image.shape[-2:]) == (dst_h, dst_w):
21
+ return image.contiguous()
22
+ resized = F.interpolate(
23
+ image.unsqueeze(0).to(torch.float32),
24
+ size=(dst_h, dst_w),
25
+ mode="bicubic",
26
+ align_corners=False,
27
+ antialias=True,
28
+ )
29
+ return resized[0].round().clamp(0, 255).to(torch.uint8).contiguous()
30
+
31
+
32
+ def project_overlap_ratio(
33
+ src_w2c: torch.Tensor,
34
+ tgt_w2c: torch.Tensor,
35
+ src_k: torch.Tensor,
36
+ tgt_k: torch.Tensor,
37
+ h: int,
38
+ w: int,
39
+ src_hw: tuple[int, int] | None = None,
40
+ tgt_hw: tuple[int, int] | None = None,
41
+ sample_h: int = 32,
42
+ sample_w: int = 56,
43
+ proxy_depth: float = 1.0,
44
+ ) -> float:
45
+ device = src_w2c.device
46
+ src_h, src_w = tuple(int(v) for v in (src_hw or (h, w)))
47
+ tgt_h, tgt_w = tuple(int(v) for v in (tgt_hw or (h, w)))
48
+ ys = torch.linspace(0, src_h - 1, steps=sample_h, device=device)
49
+ xs = torch.linspace(0, src_w - 1, steps=sample_w, device=device)
50
+ vv, uu = torch.meshgrid(ys, xs, indexing="ij")
51
+ u = uu.reshape(-1)
52
+ v = vv.reshape(-1)
53
+
54
+ fx, fy = src_k[0, 0], src_k[1, 1]
55
+ cx, cy = src_k[0, 2], src_k[1, 2]
56
+ x = (u - cx) / fx
57
+ y = (v - cy) / fy
58
+ z = torch.ones_like(x)
59
+ rays = torch.stack([x, y, z], dim=-1)
60
+ rays = rays / torch.norm(rays, dim=-1, keepdim=True).clamp(min=1e-6)
61
+ pts_src = rays * float(proxy_depth)
62
+
63
+ src_c2w = torch.linalg.inv(src_w2c)
64
+ pts_src_h = torch.cat([pts_src, torch.ones_like(pts_src[:, :1])], dim=-1)
65
+ pts_w = (src_c2w @ pts_src_h.T).T
66
+ pts_tgt = (tgt_w2c @ pts_w.T).T
67
+ xt, yt, zt = pts_tgt[:, 0], pts_tgt[:, 1], pts_tgt[:, 2].clamp(min=1e-6)
68
+ ut = tgt_k[0, 0] * (xt / zt) + tgt_k[0, 2]
69
+ vt = tgt_k[1, 1] * (yt / zt) + tgt_k[1, 2]
70
+ inside = (zt > 0.0) & (ut >= 0.0) & (ut <= float(tgt_w - 1)) & (vt >= 0.0) & (vt <= float(tgt_h - 1))
71
+ return float(inside.float().mean().item())
72
+
73
+
74
+ def select_targets_for_source(
75
+ *,
76
+ src_idx: int,
77
+ candidate_indices: list[int],
78
+ centers: torch.Tensor,
79
+ min_index_gap: int,
80
+ max_index_gap: int,
81
+ pair_max_translation_m: float,
82
+ pair_min_overlap: float,
83
+ overlap_score_fn: Callable[[int, int], float],
84
+ ) -> list[int]:
85
+ src_c = centers[int(src_idx)]
86
+ tgt_cands: list[int] = []
87
+ for j in candidate_indices:
88
+ j = int(j)
89
+ if j == int(src_idx):
90
+ continue
91
+ gap = abs(int(j) - int(src_idx))
92
+ if gap < int(min_index_gap) or gap > int(max_index_gap):
93
+ continue
94
+ trans = float(torch.norm(centers[j] - src_c, p=2).item())
95
+ if trans > float(pair_max_translation_m):
96
+ continue
97
+ if float(overlap_score_fn(int(src_idx), j)) >= float(pair_min_overlap):
98
+ tgt_cands.append(j)
99
+ return tgt_cands
unisharp/datasets/panogs.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Literal
6
+ from typing import cast
7
+ import tarfile
8
+
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image
12
+ from torch.utils.data import Dataset
13
+
14
+ from unisharp import DEFAULT_MAX_DEPTH_M
15
+
16
+
17
+ MAX_DEPTH_M = DEFAULT_MAX_DEPTH_M
18
+
19
+ _PAIR_RECIPE_FIXED: tuple[str, bool] = ("c2w", True)
20
+
21
+ _PAIR_CONVENTIONS: tuple[str, ...] = ("c2w",)
22
+
23
+
24
+ def _torch_load_any(path: Path) -> object:
25
+ try:
26
+ return torch.load(path, map_location="cpu", weights_only=False)
27
+ except TypeError:
28
+ return torch.load(path, map_location="cpu")
29
+ except (KeyError, tarfile.ReadError, EOFError, OSError, RuntimeError) as e:
30
+ raise RuntimeError(f"torch.load failed (possibly incomplete/corrupted): {path}") from e
31
+
32
+
33
+ @dataclass(frozen=True)
34
+ class PanOGSSample:
35
+
36
+ src_erp_rgb_u8: torch.Tensor
37
+ tgt_erp_rgb_u8: torch.Tensor
38
+ src_erp_depth_m: torch.Tensor
39
+ tgt_erp_depth_m: torch.Tensor
40
+
41
+ src_cube_rgb_u8: torch.Tensor
42
+ tgt_cube_rgb_u8: torch.Tensor
43
+ src_cube_depth_m: torch.Tensor
44
+ tgt_cube_depth_m: torch.Tensor
45
+
46
+ src_R: torch.Tensor
47
+ src_t: torch.Tensor
48
+ tgt_R: torch.Tensor
49
+ tgt_t: torch.Tensor
50
+
51
+ src_idx: int
52
+ tgt_idx: int
53
+ scene: str
54
+
55
+
56
+ def _load_erp_rgb_u8(path: Path) -> torch.Tensor:
57
+ img = np.array(Image.open(path))
58
+ if img.ndim != 3 or img.shape[2] != 3:
59
+ raise ValueError(f"Expected RGB image at {path}, got shape={img.shape}")
60
+ return torch.from_numpy(img.astype(np.uint8)).permute(2, 0, 1).contiguous()
61
+
62
+
63
+ def _load_depth_png(path: Path) -> torch.Tensor:
64
+ dep = np.array(Image.open(path))
65
+ return torch.from_numpy(dep)
66
+
67
+
68
+ def _depth_to_meters(depth: torch.Tensor, max_depth_m: float = DEFAULT_MAX_DEPTH_M) -> torch.Tensor:
69
+ depth_f = depth.to(torch.float32)
70
+ maxv = float(depth_f.max().item()) if depth_f.numel() else 0.0
71
+ if maxv > 200.0:
72
+ depth_f = depth_f / 1000.0
73
+ depth_f[~torch.isfinite(depth_f)] = 0.0
74
+ return depth_f.clamp(min=0.0, max=float(max_depth_m))
75
+
76
+
77
+ class PanOGSDataset(Dataset[PanOGSSample]):
78
+
79
+ def __init__(
80
+ self,
81
+ root: Path,
82
+ index_manifest_path: Path | None = None,
83
+ src_tgt_max_index_gap: int = 25,
84
+ use_cubemap_supervision: bool = True,
85
+ pair_sampling: bool = True,
86
+ pair_max_translation_m: float = 0.5,
87
+ pair_min_depth_overlap: float = 0.6,
88
+ pair_overlap_face_w: int = 64,
89
+ pair_overlap_margin: float = 1.05,
90
+ pair_max_tries: int = 48,
91
+ depth_max_m: float = DEFAULT_MAX_DEPTH_M,
92
+ ) -> None:
93
+ self.root = root
94
+ self.src_tgt_max_index_gap = int(src_tgt_max_index_gap)
95
+ self.use_cubemap_supervision = use_cubemap_supervision
96
+ self.pair_sampling = bool(pair_sampling)
97
+ self.pair_max_translation_m = float(pair_max_translation_m)
98
+ self.pair_min_depth_overlap = float(pair_min_depth_overlap)
99
+ self.pair_overlap_face_w = int(pair_overlap_face_w)
100
+ self.pair_overlap_margin = float(pair_overlap_margin)
101
+ self.pair_max_tries = int(pair_max_tries)
102
+ self.depth_max_m = float(depth_max_m)
103
+ self.index_manifest_path = Path(index_manifest_path) if index_manifest_path is not None else None
104
+
105
+ self._pair_valid_tgts: dict[tuple[str, int], list[int]] = {}
106
+ self._pair_overlap_cache: dict[tuple[str, int, int], float] = {}
107
+
108
+ if not root.exists():
109
+ raise FileNotFoundError(root)
110
+
111
+ self.scenes = sorted([p for p in root.iterdir() if p.is_dir()])
112
+ if not self.scenes:
113
+ raise RuntimeError(f"No scene folders found in {root}")
114
+
115
+ self._pose_cache: dict[str, tuple[np.ndarray, np.ndarray]] = {}
116
+ self._meta_paths: dict[str, Path] = {}
117
+ self._num_frames: dict[str, int] = {}
118
+ self._available_frames: dict[str, list[int]] = {}
119
+
120
+ if self.index_manifest_path is not None:
121
+ if not self.index_manifest_path.exists():
122
+ raise FileNotFoundError(self.index_manifest_path)
123
+ valid_scenes: list[Path] = []
124
+ for raw in self.index_manifest_path.read_text(encoding="utf-8").splitlines():
125
+ line = raw.strip()
126
+ if not line:
127
+ continue
128
+ parts = line.split("|")
129
+ scene_name = parts[0].strip()
130
+ if not scene_name:
131
+ continue
132
+ scene_dir = root / scene_name
133
+ meta_path = scene_dir / "meta.pt"
134
+ if not meta_path.exists():
135
+ continue
136
+ if len(parts) >= 2:
137
+ try:
138
+ n_pose = int(parts[1])
139
+ except ValueError:
140
+ n_pose = 0
141
+ else:
142
+ n_pose = 0
143
+ if n_pose <= 0:
144
+ continue
145
+ self._meta_paths[scene_name] = meta_path
146
+ self._num_frames[scene_name] = n_pose
147
+ self._available_frames[scene_name] = list(range(n_pose))
148
+ valid_scenes.append(scene_dir)
149
+ self.scenes = valid_scenes
150
+
151
+
152
+ if not self._available_frames:
153
+
154
+ valid_scenes = []
155
+ for scene_i, scene_dir in enumerate(self.scenes):
156
+ meta_path = scene_dir / "meta.pt"
157
+ if not meta_path.exists():
158
+ continue
159
+
160
+ ex = _torch_load_any(meta_path)
161
+ cams = ex.get("cameras", None)
162
+ if not isinstance(cams, torch.Tensor):
163
+ raise ValueError(f"meta.pt missing 'cameras' tensor in {scene_dir}")
164
+ if cams.ndim != 3 or tuple(cams.shape[1:]) != (4, 4):
165
+ raise ValueError(f"Bad meta.pt cameras shape {tuple(cams.shape)} in {scene_dir}")
166
+ n_pose = int(cams.shape[0])
167
+
168
+ frames = list(range(n_pose))
169
+
170
+ name = scene_dir.name
171
+ self._meta_paths[name] = meta_path
172
+ self._num_frames[name] = n_pose
173
+ self._available_frames[name] = frames
174
+ valid_scenes.append(scene_dir)
175
+
176
+ self.scenes = valid_scenes
177
+
178
+ def _get_pose(self, scene: str) -> tuple[np.ndarray, np.ndarray]:
179
+ cached = self._pose_cache.get(scene)
180
+ if cached is not None:
181
+ return cached
182
+
183
+ meta_path = self._meta_paths.get(scene)
184
+ if meta_path is None:
185
+ raise FileNotFoundError(f"meta.pt not indexed for scene={scene} under {self.root}")
186
+ ex = _torch_load_any(meta_path)
187
+ cams = ex.get("cameras", None)
188
+ if not isinstance(cams, torch.Tensor):
189
+ raise ValueError(f"meta.pt missing 'cameras' tensor for scene={scene}")
190
+ cams = cams.to(torch.float32)
191
+ if cams.ndim != 3 or tuple(cams.shape[1:]) != (4, 4):
192
+ raise ValueError(f"Bad meta.pt cameras shape {tuple(cams.shape)} for scene={scene}")
193
+ R = cams[:, :3, :3].cpu().numpy()
194
+ t = cams[:, :3, 3].cpu().numpy()
195
+ out = (R, t)
196
+ self._pose_cache[scene] = out
197
+ return out
198
+
199
+ def __len__(self) -> int:
200
+ return len(self._index)
201
+
202
+ def _sample_target(self, scene: str, src_idx: int) -> int:
203
+ frames = self._available_frames[scene]
204
+ if len(frames) <= 1:
205
+ return src_idx
206
+ effective_gap = self.src_tgt_max_index_gap
207
+ candidates = [i for i in frames if i != src_idx and abs(i - src_idx) <= effective_gap]
208
+ if not candidates:
209
+ return src_idx
210
+ j = int(torch.randint(low=0, high=len(candidates), size=(1,)).item())
211
+ return int(candidates[j])
212
+
213
+ def _candidate_targets_by_translation(self, scene: str, src_idx: int) -> list[int]:
214
+ frames = self._available_frames[scene]
215
+ if len(frames) <= 1:
216
+ return []
217
+ R_np, t_np = self._get_pose(scene)
218
+ if not (0 <= src_idx < len(t_np) and 0 <= src_idx < len(R_np)):
219
+ return []
220
+ th = float(self.pair_max_translation_m)
221
+
222
+ def _cam_center_from(R: np.ndarray, t: np.ndarray, conv: str) -> np.ndarray:
223
+ if conv in ("c2w", "w2c_t_camcenter"):
224
+ return t
225
+ if conv == "w2c":
226
+ return -(R.transpose(0, 2, 1) @ t[..., None])[..., 0]
227
+ if conv == "c2w_t_w2c":
228
+ return -(R @ t[..., None])[..., 0]
229
+ raise ValueError(conv)
230
+
231
+ def _min_dist(idxs: np.ndarray) -> np.ndarray:
232
+ R_sub = R_np[idxs].astype(np.float32)
233
+ t_sub = t_np[idxs].astype(np.float32)
234
+ R_src = R_np[int(src_idx) : int(src_idx) + 1].astype(np.float32)
235
+ t_src = t_np[int(src_idx) : int(src_idx) + 1].astype(np.float32)
236
+ d_min = None
237
+ for conv in _PAIR_CONVENTIONS:
238
+ C_src = _cam_center_from(R_src, t_src, conv)[0]
239
+ C_sub = _cam_center_from(R_sub, t_sub, conv)
240
+ d = np.linalg.norm(C_sub - C_src[None, :], axis=1)
241
+ d_min = d if (d_min is None) else np.minimum(d_min, d)
242
+ assert d_min is not None
243
+ return d_min
244
+
245
+ effective_gap = self.src_tgt_max_index_gap
246
+ cand0 = np.array([i for i in frames if i != src_idx and abs(i - src_idx) <= effective_gap], dtype=np.int64)
247
+ if cand0.size > 0:
248
+ d0 = _min_dist(cand0)
249
+ ok0 = cand0[d0 < th]
250
+ if ok0.size > 0:
251
+ return [int(x) for x in ok0.tolist()]
252
+ return []
253
+
254
+ def _resize_cube_depth(self, depth: torch.Tensor, face_w: int) -> torch.Tensor:
255
+ if depth.ndim != 4 or depth.shape[0] != 6 or depth.shape[-1] != 1:
256
+ raise ValueError(f"Expected cube depth shape (6,H,W,1), got {tuple(depth.shape)}")
257
+ H = int(depth.shape[1])
258
+ W = int(depth.shape[2])
259
+ if H == face_w and W == face_w:
260
+ return depth.to(dtype=torch.float32)
261
+ import torch.nn.functional as F
262
+
263
+ x = depth.permute(0, 3, 1, 2).to(dtype=torch.float32)
264
+ x = F.interpolate(x, size=(face_w, face_w), mode="bilinear", align_corners=False)
265
+ return x.permute(0, 2, 3, 1).contiguous()
266
+
267
+ @staticmethod
268
+ def _cubemap_z_depth_to_distance(depth: torch.Tensor) -> torch.Tensor:
269
+ if depth.ndim != 4 or depth.shape[0] != 6 or depth.shape[-1] != 1:
270
+ raise ValueError(f"Expected cube depth shape (6,H,W,1), got {tuple(depth.shape)}")
271
+ from unisharp.utils.pano import get_pinhole_intrinsics_4x4
272
+
273
+ h = int(depth.shape[1])
274
+ w = int(depth.shape[2])
275
+ if h != w:
276
+ raise ValueError(f"Expected square cubemap faces, got {(h, w)}")
277
+ depth_61hw = depth.permute(0, 3, 1, 2).to(dtype=torch.float32).contiguous()
278
+ intr = get_pinhole_intrinsics_4x4(w).to(device=depth_61hw.device, dtype=depth_61hw.dtype)
279
+ ys = torch.arange(h, device=depth_61hw.device, dtype=depth_61hw.dtype)
280
+ xs = torch.arange(w, device=depth_61hw.device, dtype=depth_61hw.dtype)
281
+ vv, uu = torch.meshgrid(ys, xs, indexing="ij")
282
+ x = (uu - intr[0, 2]) / intr[0, 0].clamp(min=1e-8)
283
+ y = (vv - intr[1, 2]) / intr[1, 1].clamp(min=1e-8)
284
+ ray_z = 1.0 / torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8)
285
+ dist = depth_61hw / ray_z.view(1, 1, h, w).clamp(min=1e-8)
286
+ valid = torch.isfinite(dist) & (dist > 0.0)
287
+ dist = torch.where(valid, dist, torch.zeros_like(dist))
288
+ return dist.permute(0, 2, 3, 1).contiguous()
289
+
290
+ def _pair_depth_overlap_score(
291
+ self,
292
+ *,
293
+ src_R: torch.Tensor,
294
+ src_t: torch.Tensor,
295
+ tgt_R: torch.Tensor,
296
+ tgt_t: torch.Tensor,
297
+ src_cube_depth_m: torch.Tensor,
298
+ tgt_cube_depth_m: torch.Tensor,
299
+ ) -> float:
300
+ from unisharp.utils.camera_projection import build_extrinsics_w2c, view_frustum_mask_cubemap_union # noqa: WPS433
301
+
302
+ device = torch.device("cpu")
303
+ src_R = src_R.to(device=device, dtype=torch.float32)
304
+ src_t = src_t.to(device=device, dtype=torch.float32)
305
+ tgt_R = tgt_R.to(device=device, dtype=torch.float32)
306
+ tgt_t = tgt_t.to(device=device, dtype=torch.float32)
307
+
308
+ face_w = int(self.pair_overlap_face_w)
309
+ margin = float(self.pair_overlap_margin)
310
+ src_d = self._cubemap_z_depth_to_distance(self._resize_cube_depth(src_cube_depth_m.to(device=device), face_w=face_w))
311
+ tgt_d = self._cubemap_z_depth_to_distance(self._resize_cube_depth(tgt_cube_depth_m.to(device=device), face_w=face_w))
312
+
313
+ def _score_one(recipe: tuple[str, bool]) -> float:
314
+ pose_conv, flip_yz = recipe
315
+ extr_src = build_extrinsics_w2c(src_R, src_t, pose_conv)
316
+ extr_tgt = build_extrinsics_w2c(tgt_R, tgt_t, pose_conv)
317
+
318
+ with torch.autocast(device_type="cpu", enabled=False):
319
+ c2w_src = torch.linalg.inv(extr_src)
320
+ c2w_tgt = torch.linalg.inv(extr_tgt)
321
+ if bool(flip_yz):
322
+ D = torch.diag(torch.tensor([1.0, -1.0, -1.0, 1.0], dtype=torch.float32, device=device))
323
+ c2w_src = c2w_src @ D
324
+ c2w_tgt = c2w_tgt @ D
325
+ ref_inv = torch.linalg.inv(c2w_src)
326
+ c2w_src = ref_inv @ c2w_src
327
+ c2w_tgt = ref_inv @ c2w_tgt
328
+ extr_src_n = torch.linalg.inv(c2w_src)
329
+ extr_tgt_n = torch.linalg.inv(c2w_tgt)
330
+
331
+ m_tgt_in_src = view_frustum_mask_cubemap_union(
332
+ depth_novel=tgt_d,
333
+ extr_novel_w2c=extr_tgt_n,
334
+ extr_source_w2c=extr_src_n,
335
+ face_w=face_w,
336
+ margin=margin,
337
+ )
338
+ m_src_in_tgt = view_frustum_mask_cubemap_union(
339
+ depth_novel=src_d,
340
+ extr_novel_w2c=extr_src_n,
341
+ extr_source_w2c=extr_tgt_n,
342
+ face_w=face_w,
343
+ margin=margin,
344
+ )
345
+ tgt_valid = torch.isfinite(tgt_d[..., 0]) & (tgt_d[..., 0] > 0.0)
346
+ src_valid = torch.isfinite(src_d[..., 0]) & (src_d[..., 0] > 0.0)
347
+ denom_t = float(tgt_valid.sum().item())
348
+ denom_s = float(src_valid.sum().item())
349
+ if denom_t < 10 or denom_s < 10:
350
+ return 0.0
351
+ a = float((m_tgt_in_src & tgt_valid).sum().item()) / denom_t
352
+ b = float((m_src_in_tgt & src_valid).sum().item()) / denom_s
353
+ return 0.5 * (a + b)
354
+
355
+ return _score_one(_PAIR_RECIPE_FIXED)
356
+
357
+ def __getitem__(self, idx: int) -> PanOGSSample:
358
+ src_erp: torch.Tensor | None = None
359
+ tgt_erp: torch.Tensor | None = None
360
+ src_dep: torch.Tensor | None = None
361
+ tgt_dep: torch.Tensor | None = None
362
+ src_cube: torch.Tensor | None = None
363
+ tgt_cube: torch.Tensor | None = None
364
+ src_cdep: torch.Tensor | None = None
365
+ tgt_cdep: torch.Tensor | None = None
366
+ last_err: Exception | None = None
367
+
368
+ max_outer = 16
369
+ for outer in range(max_outer):
370
+ scene, src_idx = self._index[int(idx) % len(self._index)]
371
+ scene_dir = self.root / scene
372
+ tgt_idx = self._sample_target(scene, src_idx)
373
+
374
+ max_retries = 8
375
+ ok = False
376
+ for _ in range(max_retries):
377
+ try:
378
+ if src_erp is None:
379
+ src_erp = _load_erp_rgb_u8(scene_dir / "pano" / f"{src_idx:05d}.png")
380
+ src_dep = _depth_to_meters(
381
+ _load_depth_png(scene_dir / "pano_depth" / f"{src_idx:05d}.png"),
382
+ max_depth_m=self.depth_max_m,
383
+ )
384
+ if self.use_cubemap_supervision:
385
+ src_cube_any = _torch_load_any(scene_dir / "cubemaps" / f"{src_idx:05d}.torch")
386
+ src_cdep_any = _torch_load_any(scene_dir / "cubemaps_depth" / f"{src_idx:05d}.torch")
387
+ if not all(isinstance(x, torch.Tensor) for x in [src_cube_any, src_cdep_any]):
388
+ raise RuntimeError("Bad .torch payload for src (expected Tensor).")
389
+ src_cube = cast(torch.Tensor, src_cube_any)
390
+ src_cdep = cast(torch.Tensor, src_cdep_any).to(torch.float32).clamp(min=0.0, max=self.depth_max_m)
391
+ else:
392
+ src_cube = torch.zeros((6, 256, 256, 3), dtype=torch.uint8)
393
+ src_cdep = torch.zeros((6, 256, 256, 1), dtype=torch.float32)
394
+
395
+ candidates: list[int] = []
396
+ if self.pair_sampling and self.use_cubemap_supervision:
397
+ key = (scene, int(src_idx))
398
+ cached = self._pair_valid_tgts.get(key)
399
+ if cached:
400
+ candidates = list(cached)
401
+ else:
402
+ candidates = self._candidate_targets_by_translation(scene, int(src_idx))
403
+ if not candidates:
404
+ candidates = [int(tgt_idx)]
405
+
406
+ tried: set[int] = set()
407
+ found = False
408
+ max_try = (
409
+ 1
410
+ if (not self.pair_sampling or not self.use_cubemap_supervision)
411
+ else max(1, self.pair_max_tries)
412
+ )
413
+ for _try in range(max_try):
414
+ pool = [
415
+ c
416
+ for c in candidates
417
+ if int(c) not in tried and int(c) != int(src_idx)
418
+ ]
419
+ if not pool:
420
+ break
421
+ j = int(torch.randint(0, len(pool), (1,)).item())
422
+ tgt_idx = int(pool[j])
423
+ tried.add(int(tgt_idx))
424
+
425
+ if self.use_cubemap_supervision:
426
+ tgt_cdep_any = _torch_load_any(scene_dir / "cubemaps_depth" / f"{tgt_idx:05d}.torch")
427
+ if not isinstance(tgt_cdep_any, torch.Tensor):
428
+ raise RuntimeError("Bad .torch payload for tgt depth (expected Tensor).")
429
+ tgt_cdep = cast(torch.Tensor, tgt_cdep_any).to(torch.float32).clamp(min=0.0, max=self.depth_max_m)
430
+ else:
431
+ tgt_cdep = torch.zeros((6, 256, 256, 1), dtype=torch.float32)
432
+
433
+ if self.pair_sampling and self.use_cubemap_supervision:
434
+ k = (scene, int(src_idx), int(tgt_idx))
435
+ score = self._pair_overlap_cache.get(k)
436
+ if score is None:
437
+ R_np, t_np = self._get_pose(scene)
438
+ src_R = torch.from_numpy(R_np[int(src_idx)])
439
+ src_t = torch.from_numpy(t_np[int(src_idx)])
440
+ tgt_R = torch.from_numpy(R_np[int(tgt_idx)])
441
+ tgt_t = torch.from_numpy(t_np[int(tgt_idx)])
442
+ score = self._pair_depth_overlap_score(
443
+ src_R=src_R,
444
+ src_t=src_t,
445
+ tgt_R=tgt_R,
446
+ tgt_t=tgt_t,
447
+ src_cube_depth_m=cast(torch.Tensor, src_cdep),
448
+ tgt_cube_depth_m=cast(torch.Tensor, tgt_cdep),
449
+ )
450
+ self._pair_overlap_cache[k] = float(score)
451
+ if float(score) < float(self.pair_min_depth_overlap):
452
+ continue
453
+ kk = (scene, int(src_idx))
454
+ self._pair_valid_tgts.setdefault(kk, []).append(int(tgt_idx))
455
+
456
+ tgt_erp = _load_erp_rgb_u8(scene_dir / "pano" / f"{tgt_idx:05d}.png")
457
+ tgt_dep = _depth_to_meters(
458
+ _load_depth_png(scene_dir / "pano_depth" / f"{tgt_idx:05d}.png"),
459
+ max_depth_m=self.depth_max_m,
460
+ )
461
+ if self.use_cubemap_supervision:
462
+ tgt_cube_any = _torch_load_any(scene_dir / "cubemaps" / f"{tgt_idx:05d}.torch")
463
+ if not isinstance(tgt_cube_any, torch.Tensor):
464
+ raise RuntimeError("Bad .torch payload for tgt RGB cubemap (expected Tensor).")
465
+ tgt_cube = cast(torch.Tensor, tgt_cube_any)
466
+ else:
467
+ tgt_cube = torch.zeros((6, 256, 256, 3), dtype=torch.uint8)
468
+
469
+ found = True
470
+ break
471
+
472
+ if not found:
473
+ raise RuntimeError(
474
+ f"No valid tgt found for scene={scene} src={src_idx} within constraints "
475
+ f"(trans<{self.pair_max_translation_m}m, overlap>{self.pair_min_depth_overlap})."
476
+ )
477
+
478
+ ok = True
479
+ break
480
+ except (FileNotFoundError, RuntimeError, EOFError, KeyError, tarfile.ReadError, OSError) as e:
481
+ last_err = e
482
+ frames = self._available_frames.get(scene, [])
483
+ if not frames:
484
+ break
485
+ src_idx = int(frames[int(torch.randint(0, len(frames), (1,)).item())])
486
+ tgt_idx = self._sample_target(scene, src_idx)
487
+ src_erp = None
488
+ src_dep = None
489
+ src_cube = None
490
+ src_cdep = None
491
+
492
+ if ok:
493
+ break
494
+
495
+ idx = int(idx) + 9973 + outer * 13
496
+ else:
497
+ raise RuntimeError(f"PanOGS __getitem__ failed after retries. last_err={last_err}")
498
+
499
+ assert src_erp is not None and tgt_erp is not None
500
+ assert src_dep is not None and tgt_dep is not None
501
+ assert src_cube is not None and tgt_cube is not None
502
+ assert src_cdep is not None and tgt_cdep is not None
503
+
504
+ src_dep = src_dep.to(torch.float32).unsqueeze(0)
505
+ tgt_dep = tgt_dep.to(torch.float32).unsqueeze(0)
506
+
507
+ R_np, t_np = self._get_pose(scene)
508
+ src_R = torch.from_numpy(R_np[src_idx])
509
+ src_t = torch.from_numpy(t_np[src_idx])
510
+ tgt_R = torch.from_numpy(R_np[tgt_idx])
511
+ tgt_t = torch.from_numpy(t_np[tgt_idx])
512
+
513
+ return PanOGSSample(
514
+ src_erp_rgb_u8=src_erp,
515
+ tgt_erp_rgb_u8=tgt_erp,
516
+ src_erp_depth_m=src_dep,
517
+ tgt_erp_depth_m=tgt_dep,
518
+ src_cube_rgb_u8=src_cube,
519
+ tgt_cube_rgb_u8=tgt_cube,
520
+ src_cube_depth_m=src_cdep,
521
+ tgt_cube_depth_m=tgt_cdep,
522
+ src_R=src_R,
523
+ src_t=src_t,
524
+ tgt_R=tgt_R,
525
+ tgt_t=tgt_t,
526
+ src_idx=src_idx,
527
+ tgt_idx=tgt_idx,
528
+ scene=scene,
529
+ )
530
+
531
+
532
+ def panogs_collate(batch: list[PanOGSSample]) -> PanOGSSample:
533
+ def stack(xs):
534
+ if isinstance(xs[0], torch.Tensor):
535
+ return torch.stack(xs, dim=0)
536
+ return xs
537
+
538
+ return PanOGSSample(
539
+ src_erp_rgb_u8=stack([b.src_erp_rgb_u8 for b in batch]),
540
+ tgt_erp_rgb_u8=stack([b.tgt_erp_rgb_u8 for b in batch]),
541
+ src_erp_depth_m=stack([b.src_erp_depth_m for b in batch]),
542
+ tgt_erp_depth_m=stack([b.tgt_erp_depth_m for b in batch]),
543
+ src_cube_rgb_u8=stack([b.src_cube_rgb_u8 for b in batch]),
544
+ tgt_cube_rgb_u8=stack([b.tgt_cube_rgb_u8 for b in batch]),
545
+ src_cube_depth_m=stack([b.src_cube_depth_m for b in batch]),
546
+ tgt_cube_depth_m=stack([b.tgt_cube_depth_m for b in batch]),
547
+ src_R=stack([b.src_R for b in batch]),
548
+ src_t=stack([b.src_t for b in batch]),
549
+ tgt_R=stack([b.tgt_R for b in batch]),
550
+ tgt_t=stack([b.tgt_t for b in batch]),
551
+ src_idx=[b.src_idx for b in batch], # type: ignore[arg-type]
552
+ tgt_idx=[b.tgt_idx for b in batch], # type: ignore[arg-type]
553
+ scene=[b.scene for b in batch], # type: ignore[arg-type]
554
+ )
555
+
unisharp/datasets/re10k.py ADDED
@@ -0,0 +1,718 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict, deque
4
+ from dataclasses import dataclass
5
+ from io import BytesIO
6
+ import logging
7
+ import os
8
+ from pathlib import Path
9
+ import random
10
+ import time
11
+
12
+ import torch
13
+ import torchvision.transforms as tf
14
+ from PIL import Image
15
+ from torch.utils.data import IterableDataset
16
+
17
+ from unisharp.datasets.pair_sampling import (
18
+ project_overlap_ratio,
19
+ resize_k3_align_corners_false,
20
+ resize_rgb_u8_chw_high_quality,
21
+ select_targets_for_source,
22
+ )
23
+ from unisharp import DEFAULT_MAX_DEPTH_M
24
+ from unisharp.utils.pixel_convention import normalized_intrinsics_to_integer_pixel_k
25
+ from unisharp.utils.unik3d_adapter import infer_unik3d_pinhole, load_unik3d_model
26
+
27
+
28
+ LOGGER = logging.getLogger(__name__)
29
+
30
+
31
+ def _torch_load_any(path: Path) -> object:
32
+ try:
33
+ return torch.load(path, map_location="cpu", weights_only=False)
34
+ except TypeError:
35
+ return torch.load(path, map_location="cpu")
36
+
37
+
38
+ def _pack_re10k_batch(batch: list["Re10KPairSample"]) -> "Re10KPairSample":
39
+ def stack(xs):
40
+ if isinstance(xs[0], torch.Tensor):
41
+ ref_shape = tuple(xs[0].shape)
42
+ for idx, x in enumerate(xs[1:], start=1):
43
+ if tuple(x.shape) != ref_shape:
44
+ raise RuntimeError(
45
+ "RE10K collate got mixed tensor shapes: "
46
+ f"ref={ref_shape} mismatch_idx={idx} got={tuple(x.shape)}"
47
+ )
48
+ return torch.stack(xs, dim=0)
49
+ return xs
50
+
51
+ def stack_optional_depth(xs):
52
+ if all(torch.is_tensor(x) for x in xs):
53
+ ref_shape = tuple(xs[0].shape)
54
+ for idx, x in enumerate(xs[1:], start=1):
55
+ if tuple(x.shape) != ref_shape:
56
+ raise RuntimeError(
57
+ "RE10K collate got mixed depth shapes: "
58
+ f"ref={ref_shape} mismatch_idx={idx} got={tuple(x.shape)}"
59
+ )
60
+ return torch.stack(xs, dim=0)
61
+ return None
62
+
63
+ return Re10KPairSample(
64
+ src_rgb_u8=stack([b.src_rgb_u8 for b in batch]),
65
+ tgt_rgb_u8=stack([b.tgt_rgb_u8 for b in batch]),
66
+ src_w2c=stack([b.src_w2c for b in batch]),
67
+ tgt_w2c=stack([b.tgt_w2c for b in batch]),
68
+ src_intrinsics=stack([b.src_intrinsics for b in batch]),
69
+ tgt_intrinsics=stack([b.tgt_intrinsics for b in batch]),
70
+ src_idx=[b.src_idx for b in batch], # type: ignore[arg-type]
71
+ tgt_idx=[b.tgt_idx for b in batch], # type: ignore[arg-type]
72
+ scene=[b.scene for b in batch], # type: ignore[arg-type]
73
+ src_depth_m=stack_optional_depth([b.src_depth_m for b in batch]), # type: ignore[arg-type]
74
+ tgt_depth_m=stack_optional_depth([b.tgt_depth_m for b in batch]), # type: ignore[arg-type]
75
+ )
76
+
77
+
78
+ def re10k_passthrough(batch: "Re10KPairSample") -> "Re10KPairSample":
79
+ return batch
80
+
81
+
82
+ @dataclass(frozen=True)
83
+ class Re10KPairSample:
84
+ src_rgb_u8: torch.Tensor
85
+ tgt_rgb_u8: torch.Tensor
86
+ src_w2c: torch.Tensor
87
+ tgt_w2c: torch.Tensor
88
+ src_intrinsics: torch.Tensor
89
+ tgt_intrinsics: torch.Tensor
90
+ src_idx: int
91
+ tgt_idx: int
92
+ scene: str
93
+ src_depth_m: torch.Tensor | None = None
94
+ tgt_depth_m: torch.Tensor | None = None
95
+
96
+
97
+ class Re10KDataset(IterableDataset):
98
+
99
+ def __init__(
100
+ self,
101
+ root: Path,
102
+ chunks_file: Path | None = None,
103
+ split: str = "train",
104
+ min_frame_gap: int = 1,
105
+ max_frame_gap: int = 32,
106
+ pair_max_translation_m: float = 0.5,
107
+ pair_min_overlap: float = 0.6,
108
+ pair_overlap_sample_h: int = 32,
109
+ pair_overlap_sample_w: int = 56,
110
+ pair_max_tries: int = 32,
111
+ output_h: int | None = None,
112
+ output_w: int | None = None,
113
+ shuffle_chunk: bool = True,
114
+ shuffle_example: bool = True,
115
+ ddp_rank: int = 0,
116
+ ddp_world_size: int = 1,
117
+ pseudo_depth_root: Path | None = None,
118
+ pseudo_depth_autogen: bool = True,
119
+ pseudo_depth_backbone: str = "vitl",
120
+ pseudo_depth_device: str = "cpu",
121
+ pseudo_lock_timeout_sec: float = 120.0,
122
+ pseudo_lock_stale_sec: float = 1800.0,
123
+ pseudo_wait_poll_sec: float = 0.25,
124
+ batch_size_hint: int = 1,
125
+ depth_max_m: float = DEFAULT_MAX_DEPTH_M,
126
+ pseudo_far_depth_invalid_m: float = 30.0,
127
+ seed: int = 0,
128
+ ) -> None:
129
+ super().__init__()
130
+ self.root = root
131
+ self.split = split
132
+ self.min_frame_gap = int(min_frame_gap)
133
+ self.max_frame_gap = int(max_frame_gap)
134
+ self.pair_max_translation_m = float(pair_max_translation_m)
135
+ self.pair_min_overlap = float(pair_min_overlap)
136
+ self.pair_overlap_sample_h = int(pair_overlap_sample_h)
137
+ self.pair_overlap_sample_w = int(pair_overlap_sample_w)
138
+ self.pair_max_tries = int(pair_max_tries)
139
+ self.output_h = int(output_h) if output_h is not None else None
140
+ self.output_w = int(output_w) if output_w is not None else None
141
+ self.shuffle_chunk = bool(shuffle_chunk)
142
+ self.shuffle_example = bool(shuffle_example)
143
+ self.ddp_rank = int(ddp_rank)
144
+ self.ddp_world_size = int(ddp_world_size)
145
+ self.to_tensor = tf.ToTensor()
146
+ self.pseudo_depth_root = Path(pseudo_depth_root) if pseudo_depth_root is not None else None
147
+ self.pseudo_depth_autogen = bool(pseudo_depth_autogen)
148
+ self.pseudo_depth_backbone = str(pseudo_depth_backbone)
149
+ self.pseudo_depth_device = str(pseudo_depth_device)
150
+ self.pseudo_lock_timeout_sec = float(max(1.0, pseudo_lock_timeout_sec))
151
+ self.pseudo_lock_stale_sec = float(max(30.0, pseudo_lock_stale_sec))
152
+ self.pseudo_wait_poll_sec = float(max(0.05, pseudo_wait_poll_sec))
153
+ self.batch_size_hint = int(max(1, batch_size_hint))
154
+ self.depth_max_m = float(depth_max_m)
155
+ self.pseudo_far_depth_invalid_m = float(pseudo_far_depth_invalid_m)
156
+ self._pseudo_model: torch.nn.Module | None = None
157
+ self.seed = int(seed)
158
+ self.epoch = 0
159
+
160
+ self.chunks_file = Path(chunks_file) if chunks_file is not None else None
161
+ split_dir = self.root / self.split
162
+ if self.chunks_file is not None:
163
+ if not self.chunks_file.exists():
164
+ raise FileNotFoundError(self.chunks_file)
165
+ chunks: list[Path] = []
166
+ for raw in self.chunks_file.read_text(encoding="utf-8").splitlines():
167
+ line = raw.strip()
168
+ if not line:
169
+ continue
170
+ p = Path(line)
171
+ if not p.is_absolute():
172
+ p = split_dir / p
173
+ if p.suffix == ".torch":
174
+ chunks.append(p)
175
+ self.chunks = sorted(chunks)
176
+ else:
177
+ if not split_dir.exists():
178
+ raise FileNotFoundError(split_dir)
179
+ self.chunks = sorted([p for p in split_dir.iterdir() if p.suffix == ".torch"])
180
+ if not self.chunks:
181
+ source = self.chunks_file if self.chunks_file is not None else split_dir
182
+ raise RuntimeError(f"No .torch chunks found for {source}")
183
+
184
+ def set_epoch(self, epoch: int) -> None:
185
+ self.epoch = int(epoch)
186
+
187
+ if self.pseudo_depth_root is not None:
188
+ (self.pseudo_depth_root / self.split).mkdir(parents=True, exist_ok=True)
189
+
190
+ @staticmethod
191
+ def _decode_image_u8(image_bytes_tensor: torch.Tensor) -> torch.Tensor:
192
+ if image_bytes_tensor.dtype != torch.uint8:
193
+ raise ValueError(f"Expected uint8 bytes tensor, got {image_bytes_tensor.dtype}")
194
+ image = Image.open(BytesIO(image_bytes_tensor.numpy().tobytes())).convert("RGB")
195
+ chw_float = tf.ToTensor()(image)
196
+ return (chw_float * 255.0).round().to(torch.uint8)
197
+
198
+ @staticmethod
199
+ def _convert_pose_row_to_w2c(poses: torch.Tensor) -> torch.Tensor:
200
+ t = poses.shape[0]
201
+ w2c = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat(t, 1, 1)
202
+ w2c[:, :3] = poses[:, 6:].reshape(t, 3, 4).to(torch.float32)
203
+ return w2c
204
+
205
+ @staticmethod
206
+ def _convert_intrinsics_to_pixel(poses: torch.Tensor, h: int, w: int) -> torch.Tensor:
207
+ t = poses.shape[0]
208
+ fx, fy, cx, cy = poses[:, 0], poses[:, 1], poses[:, 2], poses[:, 3]
209
+ del t
210
+ return normalized_intrinsics_to_integer_pixel_k(
211
+ fx,
212
+ fy,
213
+ cx,
214
+ cy,
215
+ height=int(h),
216
+ width=int(w),
217
+ )
218
+
219
+ @staticmethod
220
+ def _sanitize_scene(scene: str) -> str:
221
+ s = str(scene).strip()
222
+ s = s.replace("\\", "__").replace("/", "__")
223
+ return s if len(s) > 0 else "unknown_scene"
224
+
225
+ def _pseudo_depth_path(self, scene: str, frame_idx: int) -> Path | None:
226
+ if self.pseudo_depth_root is None:
227
+ return None
228
+ scene_key = self._sanitize_scene(scene)
229
+ return self.pseudo_depth_root / self.split / scene_key / f"{int(frame_idx):05d}.pt"
230
+
231
+ @staticmethod
232
+ def _load_pseudo_depth(path: Path) -> tuple[torch.Tensor | None, str]:
233
+ if not path.exists():
234
+ return None, "unknown"
235
+ try:
236
+ payload = _torch_load_any(path)
237
+ depth_kind = "distance"
238
+ if isinstance(payload, dict):
239
+ depth = payload.get("depth_m", None)
240
+ depth_kind = str(payload.get("depth_kind", "distance")).strip().lower()
241
+ if depth_kind not in ("distance", "zdepth"):
242
+ depth_kind = "distance"
243
+ else:
244
+ depth = payload
245
+ if not torch.is_tensor(depth):
246
+ return None, "unknown"
247
+ if depth.ndim == 3 and depth.shape[0] == 1:
248
+ depth = depth[0]
249
+ if depth.ndim != 2:
250
+ return None, "unknown"
251
+ depth = depth.to(torch.float32)
252
+ valid = torch.isfinite(depth) & (depth > 0.0)
253
+ if int(valid.sum().item()) <= 0:
254
+ return None, "unknown"
255
+ return depth.unsqueeze(0), depth_kind
256
+ except Exception:
257
+ return None, "unknown"
258
+
259
+ @staticmethod
260
+ def _distance_to_z_depth(depth_1hw: torch.Tensor, intrinsics_k3: torch.Tensor) -> torch.Tensor:
261
+ if depth_1hw.ndim != 3 or depth_1hw.shape[0] != 1:
262
+ raise ValueError(f"Expected depth shape (1,H,W), got {tuple(depth_1hw.shape)}")
263
+ d = depth_1hw.to(torch.float32)
264
+ h = int(d.shape[-2])
265
+ w = int(d.shape[-1])
266
+ k = intrinsics_k3.to(dtype=torch.float32, device=d.device)
267
+ fx = k[0, 0]
268
+ fy = k[1, 1]
269
+ cx = k[0, 2]
270
+ cy = k[1, 2]
271
+ ys = torch.arange(h, device=d.device, dtype=torch.float32)
272
+ xs = torch.arange(w, device=d.device, dtype=torch.float32)
273
+ vv, uu = torch.meshgrid(ys, xs, indexing="ij")
274
+ x = (uu - cx) / fx
275
+ y = (vv - cy) / fy
276
+ ray_z = 1.0 / torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8)
277
+ z = d[0] * ray_z
278
+ return z.unsqueeze(0)
279
+
280
+ @staticmethod
281
+ def _sanitize_pseudo_depth(
282
+ depth_1hw: torch.Tensor,
283
+ *,
284
+ max_depth_m: float = DEFAULT_MAX_DEPTH_M,
285
+ far_depth_invalid_m: float = 30.0,
286
+ ) -> torch.Tensor:
287
+ d = depth_1hw.to(torch.float32)
288
+ valid = torch.isfinite(d) & (d > 0.0)
289
+ if int(valid.sum().item()) <= 0:
290
+ return d
291
+ out = d.clone()
292
+ if float(far_depth_invalid_m) > 0.0:
293
+ valid = valid & (out <= float(far_depth_invalid_m))
294
+ out = torch.where(valid, out, torch.zeros_like(out))
295
+ out[valid] = out[valid].clamp(max=float(max_depth_m))
296
+ return out
297
+
298
+ def _get_or_create_pseudo_model(self) -> torch.nn.Module:
299
+ if self._pseudo_model is None:
300
+ dev = torch.device(self.pseudo_depth_device)
301
+ self._pseudo_model = load_unik3d_model(
302
+ backbone=self.pseudo_depth_backbone,
303
+ pretrained=True,
304
+ device=dev,
305
+ )
306
+ self._pseudo_model.eval()
307
+ LOGGER.info(
308
+ "Re10K pseudo-depth model loaded (split=%s, device=%s, backbone=%s)",
309
+ self.split,
310
+ str(dev),
311
+ self.pseudo_depth_backbone,
312
+ )
313
+ return self._pseudo_model
314
+
315
+ def _save_pseudo_depth_atomic(
316
+ self,
317
+ path: Path,
318
+ depth_2d: torch.Tensor,
319
+ scene: str,
320
+ frame_idx: int,
321
+ ) -> None:
322
+ path.parent.mkdir(parents=True, exist_ok=True)
323
+ tmp = path.parent / f".tmp_{os.getpid()}_{int(time.time() * 1e6)}_{random.randint(0, 10_000_000)}.pt"
324
+ payload = {
325
+ "depth_m": depth_2d.to(torch.float16),
326
+ "depth_kind": "distance",
327
+ "scene": str(scene),
328
+ "frame_idx": int(frame_idx),
329
+ }
330
+ torch.save(payload, tmp)
331
+ os.replace(tmp, path)
332
+
333
+ def _acquire_lock_or_wait_for_file(self, target: Path) -> tuple[bool, bool]:
334
+ lock_dir = Path(str(target) + ".lock")
335
+ start = time.time()
336
+ while True:
337
+ if target.exists():
338
+ return False, True
339
+ try:
340
+ lock_dir.mkdir(parents=False, exist_ok=False)
341
+ meta = lock_dir / "owner.txt"
342
+ meta.write_text(f"pid={os.getpid()} time={time.time():.3f}\n", encoding="utf-8")
343
+ return True, False
344
+ except FileExistsError:
345
+ try:
346
+ mtime = lock_dir.stat().st_mtime
347
+ if (time.time() - float(mtime)) > self.pseudo_lock_stale_sec:
348
+ for p in lock_dir.iterdir():
349
+ try:
350
+ p.unlink()
351
+ except Exception:
352
+ pass
353
+ lock_dir.rmdir()
354
+ continue
355
+ except Exception:
356
+ pass
357
+ if (time.time() - start) >= self.pseudo_lock_timeout_sec:
358
+ return False, False
359
+ time.sleep(self.pseudo_wait_poll_sec)
360
+ except Exception:
361
+ return False, False
362
+
363
+ def _release_lock(self, target: Path) -> None:
364
+ lock_dir = Path(str(target) + ".lock")
365
+ if not lock_dir.exists():
366
+ return
367
+ try:
368
+ for p in lock_dir.iterdir():
369
+ try:
370
+ p.unlink()
371
+ except Exception:
372
+ pass
373
+ lock_dir.rmdir()
374
+ except Exception:
375
+ pass
376
+
377
+ def _get_pseudo_depth_for_frame(
378
+ self,
379
+ *,
380
+ scene: str,
381
+ frame_idx: int,
382
+ rgb_u8: torch.Tensor,
383
+ intrinsics_k3: torch.Tensor,
384
+ ) -> torch.Tensor | None:
385
+ path = self._pseudo_depth_path(scene, frame_idx)
386
+ if path is None:
387
+ return None
388
+ depth, depth_kind = self._load_pseudo_depth(path)
389
+ if depth is not None:
390
+ if depth_kind != "zdepth":
391
+ try:
392
+ depth = self._distance_to_z_depth(
393
+ self._sanitize_pseudo_depth(
394
+ depth,
395
+ max_depth_m=self.depth_max_m,
396
+ far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
397
+ ),
398
+ intrinsics_k3=intrinsics_k3,
399
+ )
400
+ except Exception:
401
+ return None
402
+ else:
403
+ depth = self._sanitize_pseudo_depth(
404
+ depth,
405
+ max_depth_m=self.depth_max_m,
406
+ far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
407
+ )
408
+ return depth
409
+ if not self.pseudo_depth_autogen:
410
+ return None
411
+
412
+ acquired, ready = self._acquire_lock_or_wait_for_file(path)
413
+ if ready:
414
+ depth, depth_kind = self._load_pseudo_depth(path)
415
+ if depth is None:
416
+ return None
417
+ if depth_kind != "zdepth":
418
+ try:
419
+ depth = self._distance_to_z_depth(
420
+ self._sanitize_pseudo_depth(
421
+ depth,
422
+ max_depth_m=self.depth_max_m,
423
+ far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
424
+ ),
425
+ intrinsics_k3=intrinsics_k3,
426
+ )
427
+ except Exception:
428
+ return None
429
+ else:
430
+ depth = self._sanitize_pseudo_depth(
431
+ depth,
432
+ max_depth_m=self.depth_max_m,
433
+ far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
434
+ )
435
+ return depth
436
+ if not acquired:
437
+ depth, depth_kind = self._load_pseudo_depth(path)
438
+ if depth is None:
439
+ return None
440
+ if depth_kind != "zdepth":
441
+ try:
442
+ depth = self._distance_to_z_depth(
443
+ self._sanitize_pseudo_depth(
444
+ depth,
445
+ max_depth_m=self.depth_max_m,
446
+ far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
447
+ ),
448
+ intrinsics_k3=intrinsics_k3,
449
+ )
450
+ except Exception:
451
+ return None
452
+ else:
453
+ depth = self._sanitize_pseudo_depth(
454
+ depth,
455
+ max_depth_m=self.depth_max_m,
456
+ far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
457
+ )
458
+ return depth
459
+
460
+ try:
461
+ depth, depth_kind = self._load_pseudo_depth(path)
462
+ if depth is not None:
463
+ if depth_kind != "zdepth":
464
+ try:
465
+ depth = self._distance_to_z_depth(
466
+ self._sanitize_pseudo_depth(
467
+ depth,
468
+ max_depth_m=self.depth_max_m,
469
+ far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
470
+ ),
471
+ intrinsics_k3=intrinsics_k3,
472
+ )
473
+ except Exception:
474
+ return None
475
+ else:
476
+ depth = self._sanitize_pseudo_depth(
477
+ depth,
478
+ max_depth_m=self.depth_max_m,
479
+ far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
480
+ )
481
+ return depth
482
+ model = self._get_or_create_pseudo_model()
483
+ out = infer_unik3d_pinhole(
484
+ model,
485
+ rgb_u8=rgb_u8.unsqueeze(0),
486
+ intrinsics=intrinsics_k3.unsqueeze(0),
487
+ )
488
+ dist = out.get("distance", None) if isinstance(out, dict) else None
489
+ if not torch.is_tensor(dist) or dist.ndim != 4 or dist.shape[1] != 1:
490
+ return None
491
+ dist_1hw = self._sanitize_pseudo_depth(
492
+ dist[0:1, 0:1].detach().to(torch.float32).cpu()[0],
493
+ max_depth_m=self.depth_max_m,
494
+ far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
495
+ )
496
+ valid = torch.isfinite(dist_1hw) & (dist_1hw > 0.0)
497
+ if int(valid.sum().item()) <= 0:
498
+ return None
499
+ self._save_pseudo_depth_atomic(
500
+ path,
501
+ depth_2d=dist_1hw[0],
502
+ scene=scene,
503
+ frame_idx=frame_idx,
504
+ )
505
+ return self._distance_to_z_depth(dist_1hw, intrinsics_k3=intrinsics_k3.cpu())
506
+ except Exception as e:
507
+ LOGGER.warning(
508
+ "Pseudo-depth generation failed scene=%s frame=%d: %s",
509
+ str(scene),
510
+ int(frame_idx),
511
+ str(e),
512
+ )
513
+ return None
514
+ finally:
515
+ self._release_lock(path)
516
+
517
+ def _candidate_target_indices(
518
+ self,
519
+ src_idx: int,
520
+ num_frames: int,
521
+ w2c_all: torch.Tensor,
522
+ intr_all: torch.Tensor,
523
+ h: int,
524
+ w: int,
525
+ ) -> list[int]:
526
+ if num_frames < 2:
527
+ return []
528
+ centers = torch.linalg.inv(w2c_all)[:, :3, 3].to(torch.float32)
529
+ sample_h = int(self.pair_overlap_sample_h)
530
+ sample_w = int(self.pair_overlap_sample_w)
531
+ return select_targets_for_source(
532
+ src_idx=int(src_idx),
533
+ candidate_indices=list(range(num_frames)),
534
+ centers=centers,
535
+ min_index_gap=int(self.min_frame_gap),
536
+ max_index_gap=int(self.max_frame_gap),
537
+ pair_max_translation_m=float(self.pair_max_translation_m),
538
+ pair_min_overlap=float(self.pair_min_overlap),
539
+ overlap_score_fn=lambda si, tj: float(
540
+ 0.5
541
+ * (
542
+ project_overlap_ratio(
543
+ src_w2c=w2c_all[si],
544
+ tgt_w2c=w2c_all[tj],
545
+ src_k=intr_all[si],
546
+ tgt_k=intr_all[tj],
547
+ h=h,
548
+ w=w,
549
+ sample_h=sample_h,
550
+ sample_w=sample_w,
551
+ )
552
+ + project_overlap_ratio(
553
+ src_w2c=w2c_all[tj],
554
+ tgt_w2c=w2c_all[si],
555
+ src_k=intr_all[tj],
556
+ tgt_k=intr_all[si],
557
+ h=h,
558
+ w=w,
559
+ sample_h=sample_h,
560
+ sample_w=sample_w,
561
+ )
562
+ )
563
+ ),
564
+ )
565
+
566
+ def __iter__(self):
567
+ chunks = list(self.chunks)
568
+ order_rng = random.Random(self.seed + self.epoch)
569
+ if self.shuffle_chunk and self.split == "train":
570
+ order_rng.shuffle(chunks)
571
+ pending_by_hw: dict[tuple[int, int], deque[Re10KPairSample]] = defaultdict(deque)
572
+
573
+ worker_info = torch.utils.data.get_worker_info()
574
+ num_workers = worker_info.num_workers if worker_info is not None else 1
575
+ worker_id = worker_info.id if worker_info is not None else 0
576
+ total_shards = max(1, self.ddp_world_size * num_workers)
577
+ shard_id = self.ddp_rank * num_workers + worker_id
578
+ chunks = [chunk for i, chunk in enumerate(chunks) if i % total_shards == shard_id]
579
+
580
+ for chunk_order_idx, chunk_path in enumerate(chunks):
581
+ chunk = _torch_load_any(chunk_path)
582
+ if not isinstance(chunk, list):
583
+ continue
584
+ examples = list(chunk)
585
+ chunk_rng = random.Random(self.seed + self.epoch * 1000003 + chunk_order_idx)
586
+ if self.shuffle_example and self.split == "train":
587
+ chunk_rng.shuffle(examples)
588
+
589
+ for example in examples:
590
+ if not isinstance(example, dict):
591
+ continue
592
+ if "cameras" not in example or "images" not in example:
593
+ continue
594
+ poses = example["cameras"]
595
+ images = example["images"]
596
+ scene = str(example.get("key", "unknown"))
597
+ if not torch.is_tensor(poses) or not isinstance(images, list):
598
+ continue
599
+ if poses.ndim != 2 or poses.shape[1] != 18:
600
+ continue
601
+ if len(images) != int(poses.shape[0]):
602
+ continue
603
+
604
+ try:
605
+ src_probe = self._decode_image_u8(images[0])
606
+ except Exception:
607
+ continue
608
+ h, w = int(src_probe.shape[1]), int(src_probe.shape[2])
609
+ w2c_all = self._convert_pose_row_to_w2c(poses)
610
+ intr_all = self._convert_intrinsics_to_pixel(poses, h=h, w=w)
611
+ src_indices = list(range(len(images)))
612
+ if self.shuffle_example and self.split == "train":
613
+ chunk_rng.shuffle(src_indices)
614
+ for src_idx in src_indices:
615
+ tgt_candidates = self._candidate_target_indices(
616
+ int(src_idx),
617
+ len(images),
618
+ w2c_all=w2c_all,
619
+ intr_all=intr_all,
620
+ h=h,
621
+ w=w,
622
+ )
623
+ if not tgt_candidates:
624
+ continue
625
+ tgt_idx = chunk_rng.choice(tgt_candidates)
626
+
627
+ try:
628
+ src_img = self._decode_image_u8(images[src_idx])
629
+ tgt_img = self._decode_image_u8(images[tgt_idx])
630
+ except Exception:
631
+ continue
632
+ if src_img.shape != tgt_img.shape:
633
+ continue
634
+ src_intr = intr_all[src_idx].clone()
635
+ tgt_intr = intr_all[tgt_idx].clone()
636
+ src_depth = self._get_pseudo_depth_for_frame(
637
+ scene=scene,
638
+ frame_idx=int(src_idx),
639
+ rgb_u8=src_img,
640
+ intrinsics_k3=intr_all[src_idx].to(torch.float32),
641
+ )
642
+ tgt_depth = self._get_pseudo_depth_for_frame(
643
+ scene=scene,
644
+ frame_idx=int(tgt_idx),
645
+ rgb_u8=tgt_img,
646
+ intrinsics_k3=intr_all[tgt_idx].to(torch.float32),
647
+ )
648
+ if self.pseudo_depth_root is not None and (
649
+ (not torch.is_tensor(src_depth)) or (not torch.is_tensor(tgt_depth))
650
+ ):
651
+ continue
652
+ if self.output_h is not None and self.output_w is not None:
653
+ oh, ow = int(src_img.shape[1]), int(src_img.shape[2])
654
+ if oh > 0 and ow > 0 and (oh != self.output_h or ow != self.output_w):
655
+ sx = float(self.output_w) / float(ow)
656
+ sy = float(self.output_h) / float(oh)
657
+ src_img = resize_rgb_u8_chw_high_quality(src_img, size=(self.output_h, self.output_w))
658
+ tgt_img = resize_rgb_u8_chw_high_quality(tgt_img, size=(self.output_h, self.output_w))
659
+ src_intr = resize_k3_align_corners_false(src_intr, sx=sx, sy=sy)
660
+ tgt_intr = resize_k3_align_corners_false(tgt_intr, sx=sx, sy=sy)
661
+ if torch.is_tensor(src_depth):
662
+ src_depth = (
663
+ torch.nn.functional.interpolate(
664
+ src_depth[None],
665
+ size=(self.output_h, self.output_w),
666
+ mode="bilinear",
667
+ align_corners=False,
668
+ )
669
+ .squeeze(0)
670
+ .to(torch.float32)
671
+ )
672
+ if torch.is_tensor(tgt_depth):
673
+ tgt_depth = (
674
+ torch.nn.functional.interpolate(
675
+ tgt_depth[None],
676
+ size=(self.output_h, self.output_w),
677
+ mode="bilinear",
678
+ align_corners=False,
679
+ )
680
+ .squeeze(0)
681
+ .to(torch.float32)
682
+ )
683
+
684
+ sample = Re10KPairSample(
685
+ src_rgb_u8=src_img,
686
+ tgt_rgb_u8=tgt_img,
687
+ src_w2c=w2c_all[src_idx],
688
+ tgt_w2c=w2c_all[tgt_idx],
689
+ src_intrinsics=src_intr,
690
+ tgt_intrinsics=tgt_intr,
691
+ src_idx=int(src_idx),
692
+ tgt_idx=int(tgt_idx),
693
+ scene=scene,
694
+ src_depth_m=src_depth,
695
+ tgt_depth_m=tgt_depth,
696
+ )
697
+ hw_key = (int(sample.src_rgb_u8.shape[1]), int(sample.src_rgb_u8.shape[2]))
698
+ bucket = pending_by_hw[hw_key]
699
+ bucket.append(sample)
700
+ if self.batch_size_hint <= 1:
701
+ yield bucket.popleft()
702
+ continue
703
+ while len(bucket) >= self.batch_size_hint:
704
+ packed = [bucket.popleft() for _ in range(self.batch_size_hint)]
705
+ yield _pack_re10k_batch(packed)
706
+
707
+ dropped = sum(len(bucket) for bucket in pending_by_hw.values())
708
+ if dropped > 0 and self.split == "train" and self.batch_size_hint > 1:
709
+ LOGGER.debug(
710
+ "Dropped %d RE10K leftover samples that could not form a same-resolution batch of size %d.",
711
+ int(dropped),
712
+ int(self.batch_size_hint),
713
+ )
714
+
715
+
716
+ def re10k_collate(batch: list[Re10KPairSample]) -> Re10KPairSample:
717
+ return _pack_re10k_batch(batch)
718
+
unisharp/datasets/scannetpp_fisheye.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ import json
6
+ import logging
7
+ from pathlib import Path
8
+ import random
9
+
10
+ import numpy as np
11
+ import torch
12
+ from PIL import Image
13
+ from torch.utils.data import IterableDataset
14
+
15
+ from unisharp import DEFAULT_MAX_DEPTH_M
16
+
17
+
18
+ LOGGER = logging.getLogger(__name__)
19
+ IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}
20
+ DEPTH_DIR_NAMES = ("depth", "depths", "distance", "distances", "depth_maps")
21
+ MASK_DIR_NAMES = ("masks", "mask")
22
+ DEPTH_MAX_M = DEFAULT_MAX_DEPTH_M
23
+
24
+
25
+ def _qvec_to_rotmat(qvec: np.ndarray) -> np.ndarray:
26
+ q = np.asarray(qvec, dtype=np.float64)
27
+ return np.array(
28
+ [
29
+ [1 - 2 * q[2] ** 2 - 2 * q[3] ** 2, 2 * q[1] * q[2] - 2 * q[0] * q[3], 2 * q[3] * q[1] + 2 * q[0] * q[2]],
30
+ [2 * q[1] * q[2] + 2 * q[0] * q[3], 1 - 2 * q[1] ** 2 - 2 * q[3] ** 2, 2 * q[2] * q[3] - 2 * q[0] * q[1]],
31
+ [2 * q[3] * q[1] - 2 * q[0] * q[2], 2 * q[2] * q[3] + 2 * q[0] * q[1], 1 - 2 * q[1] ** 2 - 2 * q[2] ** 2],
32
+ ],
33
+ dtype=np.float64,
34
+ )
35
+
36
+
37
+ def _read_colmap_w2c(images_txt: Path) -> dict[str, torch.Tensor]:
38
+ poses: dict[str, torch.Tensor] = {}
39
+ if not images_txt.exists():
40
+ return poses
41
+ with images_txt.open("r", encoding="utf-8") as f:
42
+ for raw in f:
43
+ line = raw.strip()
44
+ if not line or line.startswith("#"):
45
+ continue
46
+ parts = line.split()
47
+ if len(parts) < 10:
48
+ continue
49
+ try:
50
+ qvec = np.asarray([float(x) for x in parts[1:5]], dtype=np.float64)
51
+ tvec = np.asarray([float(x) for x in parts[5:8]], dtype=np.float64)
52
+ image_name = parts[9]
53
+ except Exception:
54
+ continue
55
+ w2c = np.eye(4, dtype=np.float32)
56
+ w2c[:3, :3] = _qvec_to_rotmat(qvec).astype(np.float32)
57
+ w2c[:3, 3] = tvec.astype(np.float32)
58
+ poses[Path(image_name).name] = torch.from_numpy(w2c)
59
+ return poses
60
+
61
+
62
+ def _opencv_fisheye_to_fisheye624_params(meta: dict[str, object]) -> torch.Tensor:
63
+ if str(meta.get("camera_model", "")) != "OPENCV_FISHEYE":
64
+ raise RuntimeError(f"Unsupported ScanNet++ camera_model={meta.get('camera_model')!r}; expected OPENCV_FISHEYE.")
65
+ return torch.tensor(
66
+ [
67
+ float(meta["fl_x"]),
68
+ float(meta["fl_y"]),
69
+ float(meta["cx"]),
70
+ float(meta["cy"]),
71
+ float(meta.get("k1", 0.0)),
72
+ float(meta.get("k2", 0.0)),
73
+ float(meta.get("k3", 0.0)),
74
+ float(meta.get("k4", 0.0)),
75
+ 0.0,
76
+ 0.0,
77
+ 0.0,
78
+ 0.0,
79
+ 0.0,
80
+ 0.0,
81
+ 0.0,
82
+ 0.0,
83
+ ],
84
+ dtype=torch.float32,
85
+ )
86
+
87
+
88
+ def _camera_hw_from_meta(meta: dict[str, object]) -> tuple[int, int] | None:
89
+ h = meta.get("h", meta.get("height", None))
90
+ w = meta.get("w", meta.get("width", None))
91
+ if h is None or w is None:
92
+ return None
93
+ try:
94
+ h_i, w_i = int(h), int(w)
95
+ except Exception:
96
+ return None
97
+ return (h_i, w_i) if h_i > 0 and w_i > 0 else None
98
+
99
+
100
+ def _scale_fisheye624_params(
101
+ params: torch.Tensor,
102
+ *,
103
+ src_hw: tuple[int, int],
104
+ dst_hw: tuple[int, int],
105
+ ) -> torch.Tensor:
106
+ if tuple(int(x) for x in src_hw) == tuple(int(x) for x in dst_hw):
107
+ return params.clone()
108
+ src_h, src_w = int(src_hw[0]), int(src_hw[1])
109
+ dst_h, dst_w = int(dst_hw[0]), int(dst_hw[1])
110
+ sx = float(dst_w) / float(max(src_w, 1))
111
+ sy = float(dst_h) / float(max(src_h, 1))
112
+ out = params.clone()
113
+ out[..., 0] *= sx
114
+ out[..., 1] *= sy
115
+ out[..., 2] = (out[..., 2] + 0.5) * sx - 0.5
116
+ out[..., 3] = (out[..., 3] + 0.5) * sy - 0.5
117
+ return out
118
+
119
+
120
+ def _stack_batch(batch: list["ScannetppFisheyePairSample"]) -> "ScannetppFisheyePairSample":
121
+ return ScannetppFisheyePairSample(
122
+ src_rgb_u8=torch.stack([b.src_rgb_u8 for b in batch], dim=0),
123
+ tgt_rgb_u8=torch.stack([b.tgt_rgb_u8 for b in batch], dim=0),
124
+ src_depth_m=torch.stack([b.src_depth_m for b in batch], dim=0),
125
+ tgt_depth_m=torch.stack([b.tgt_depth_m for b in batch], dim=0),
126
+ src_valid_mask=torch.stack([b.src_valid_mask for b in batch], dim=0),
127
+ tgt_valid_mask=torch.stack([b.tgt_valid_mask for b in batch], dim=0),
128
+ src_w2c=torch.stack([b.src_w2c for b in batch], dim=0),
129
+ tgt_w2c=torch.stack([b.tgt_w2c for b in batch], dim=0),
130
+ src_camera_params=torch.stack([b.src_camera_params for b in batch], dim=0),
131
+ tgt_camera_params=torch.stack([b.tgt_camera_params for b in batch], dim=0),
132
+ src_idx=[b.src_idx for b in batch], # type: ignore[arg-type]
133
+ tgt_idx=[b.tgt_idx for b in batch], # type: ignore[arg-type]
134
+ scene=[b.scene for b in batch], # type: ignore[arg-type]
135
+ camera_model="fisheye624",
136
+ )
137
+
138
+
139
+ @dataclass(frozen=True)
140
+ class ScannetppFisheyePairSample:
141
+ src_rgb_u8: torch.Tensor
142
+ tgt_rgb_u8: torch.Tensor
143
+ src_depth_m: torch.Tensor
144
+ tgt_depth_m: torch.Tensor
145
+ src_valid_mask: torch.Tensor
146
+ tgt_valid_mask: torch.Tensor
147
+ src_w2c: torch.Tensor
148
+ tgt_w2c: torch.Tensor
149
+ src_camera_params: torch.Tensor
150
+ tgt_camera_params: torch.Tensor
151
+ src_idx: int
152
+ tgt_idx: int
153
+ scene: str
154
+ camera_model: str = "fisheye624"
155
+
156
+
157
+ def scannetpp_fisheye_passthrough(batch: ScannetppFisheyePairSample) -> ScannetppFisheyePairSample:
158
+ return batch
159
+
160
+
161
+ class ScannetppFisheyeDataset(IterableDataset):
162
+
163
+ def __init__(
164
+ self,
165
+ root: Path,
166
+ scene_list_file: Path | None = None,
167
+ min_frame_gap: int = 1,
168
+ max_frame_gap: int = 10,
169
+ pair_max_translation_m: float = 0.5,
170
+ shuffle_scene: bool = True,
171
+ shuffle_frame: bool = True,
172
+ skip_bad: bool = True,
173
+ ddp_rank: int = 0,
174
+ ddp_world_size: int = 1,
175
+ batch_size_hint: int = 1,
176
+ depth_max_m: float = DEFAULT_MAX_DEPTH_M,
177
+ far_depth_invalid_m: float = 30.0,
178
+ seed: int = 0,
179
+ ) -> None:
180
+ super().__init__()
181
+ self.root = Path(root)
182
+ self.min_frame_gap = int(min_frame_gap)
183
+ self.max_frame_gap = int(max_frame_gap)
184
+ self.pair_max_translation_m = float(pair_max_translation_m)
185
+ self.shuffle_scene = bool(shuffle_scene)
186
+ self.shuffle_frame = bool(shuffle_frame)
187
+ self.skip_bad = bool(skip_bad)
188
+ self.ddp_rank = int(ddp_rank)
189
+ self.ddp_world_size = int(ddp_world_size)
190
+ self.batch_size_hint = int(max(1, batch_size_hint))
191
+ self.depth_max_m = float(depth_max_m)
192
+ self.far_depth_invalid_m = float(far_depth_invalid_m)
193
+ self.seed = int(seed)
194
+ self.epoch = 0
195
+ self.scene_specs = self._load_scene_specs(scene_list_file)
196
+ if not self.scene_specs:
197
+ raise RuntimeError(f"No ScanNet++ fisheye scenes found under {self.root}")
198
+
199
+ def set_epoch(self, epoch: int) -> None:
200
+ self.epoch = int(epoch)
201
+
202
+ def _load_scene_specs(self, scene_list_file: Path | None) -> list[tuple[str, Path]]:
203
+ specs: list[tuple[str, Path]] = []
204
+ if scene_list_file is not None and Path(scene_list_file).exists():
205
+ for raw in Path(scene_list_file).read_text(encoding="utf-8").splitlines():
206
+ line = raw.strip()
207
+ if not line:
208
+ continue
209
+ parts = line.split("|")
210
+ if len(parts) == 1:
211
+ scene_dir = Path(parts[0])
212
+ scene_id = scene_dir.name
213
+ else:
214
+ scene_id = parts[0]
215
+ scene_dir = Path(parts[1])
216
+ if not scene_dir.is_absolute():
217
+ scene_dir = self.root / scene_dir
218
+ specs.append((scene_id, scene_dir))
219
+ return specs
220
+ for transforms in sorted(self.root.glob("*/nerfstudio/transforms.json")):
221
+ specs.append((transforms.parent.parent.name, transforms.parent.parent))
222
+ for transforms in sorted(self.root.glob("*/*/nerfstudio/transforms.json")):
223
+ specs.append((f"{transforms.parent.parent.parent.name}/{transforms.parent.parent.name}", transforms.parent.parent))
224
+ return specs
225
+
226
+ @staticmethod
227
+ def _load_rgb(path: Path) -> torch.Tensor:
228
+ with Image.open(path) as image:
229
+ arr = np.asarray(image.convert("RGB"), dtype=np.uint8).copy()
230
+ return torch.from_numpy(arr).permute(2, 0, 1).contiguous()
231
+
232
+ @staticmethod
233
+ def _load_mask(path: Path, image_hw: tuple[int, int]) -> torch.Tensor | None:
234
+ if not path.exists():
235
+ return None
236
+ with Image.open(path) as image:
237
+ arr = np.asarray(image.convert("L"), dtype=np.uint8).copy()
238
+ mask = torch.from_numpy(arr).unsqueeze(0).to(torch.float32) / 255.0
239
+ if tuple(mask.shape[-2:]) != tuple(image_hw):
240
+ mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=image_hw, mode="nearest").squeeze(0)
241
+ return (mask > 0.5).to(torch.float32)
242
+
243
+ def _load_depth_map(self, path: Path) -> tuple[torch.Tensor, str]:
244
+ depth_kind = "distance"
245
+ if path.suffix.lower() == ".npz":
246
+ payload = np.load(path, allow_pickle=False)
247
+ for key in ("distance_m", "depth_m", "distance", "depth"):
248
+ if key in payload:
249
+ arr = payload[key]
250
+ if key in {"distance_m", "distance"}:
251
+ depth_kind = "distance"
252
+ elif "depth_kind" in payload:
253
+ depth_kind = str(np.asarray(payload["depth_kind"]).item()).strip().lower()
254
+ break
255
+ else:
256
+ raise RuntimeError(f"Unsupported ScanNet++ depth payload keys at {path}")
257
+ else:
258
+ arr = np.load(path)
259
+ depth = torch.from_numpy(np.asarray(arr, dtype=np.float32).copy())
260
+ if depth.ndim == 3 and depth.shape[0] == 1:
261
+ depth = depth[0]
262
+ if depth.ndim != 2:
263
+ raise RuntimeError(f"Expected 2D fisheye depth at {path}, got shape={tuple(depth.shape)}")
264
+ depth = depth.unsqueeze(0)
265
+ valid = torch.isfinite(depth) & (depth > 0.0)
266
+ if self.far_depth_invalid_m > 0.0:
267
+ valid = valid & (depth <= self.far_depth_invalid_m)
268
+ depth = torch.where(valid, depth, torch.zeros_like(depth))
269
+ if depth_kind in {"radial", "radius", "dist"}:
270
+ depth_kind = "distance"
271
+ if depth_kind not in {"distance", "z"}:
272
+ raise RuntimeError(f"Unsupported fisheye depth_kind={depth_kind!r} at {path}")
273
+ return depth.clamp(min=0.0, max=self.depth_max_m), depth_kind
274
+
275
+ @staticmethod
276
+ def _fisheye_z_depth_to_distance(z_depth: torch.Tensor, camera_params: torch.Tensor) -> torch.Tensor:
277
+ from unisharp.utils.fisheye_geer import build_fisheye624_raymap
278
+
279
+ h, w = int(z_depth.shape[-2]), int(z_depth.shape[-1])
280
+ rays = build_fisheye624_raymap(
281
+ camera_params.unsqueeze(0),
282
+ image_h=h,
283
+ image_w=w,
284
+ device=z_depth.device,
285
+ dtype=torch.float32,
286
+ )
287
+ ray_z = rays[:, 2:3].squeeze(0).to(device=z_depth.device, dtype=z_depth.dtype)
288
+ valid = torch.isfinite(z_depth) & (z_depth > 0.0) & torch.isfinite(ray_z) & (ray_z > 1e-4)
289
+ distance = z_depth / ray_z.clamp(min=1e-4)
290
+ return torch.where(valid, distance, torch.zeros_like(z_depth))
291
+
292
+ def _resolve_image_path(self, scene_dir: Path, image_name: str) -> Path | None:
293
+ rel = Path(image_name)
294
+ candidates = [
295
+ scene_dir / rel,
296
+ scene_dir / "images" / rel.name,
297
+ scene_dir / "resized_images" / rel.name,
298
+ scene_dir / "dslr" / rel,
299
+ scene_dir / "dslr" / "images" / rel.name,
300
+ scene_dir / "dslr" / "resized_images" / rel.name,
301
+ ]
302
+ for path in candidates:
303
+ if path.exists() and path.suffix in IMAGE_SUFFIXES:
304
+ return path
305
+ return None
306
+
307
+ def _resolve_depth_path(self, scene_dir: Path, image_name: str) -> Path | None:
308
+ stem = Path(image_name).stem
309
+ names = [stem, Path(image_name).name]
310
+ bases = [scene_dir, scene_dir / "dslr"]
311
+ for base in bases:
312
+ for depth_dir_name in DEPTH_DIR_NAMES:
313
+ depth_dir = base / depth_dir_name
314
+ for name in names:
315
+ for suffix in (".npz", ".npy"):
316
+ path = depth_dir / f"{name}{suffix}"
317
+ if path.exists():
318
+ return path
319
+ return None
320
+
321
+ def _resolve_mask_path(self, scene_dir: Path, image_name: str, mask_name: str | None) -> Path | None:
322
+ names = []
323
+ if mask_name:
324
+ names.append(Path(mask_name).name)
325
+ names.append(f"{Path(image_name).stem}.png")
326
+ bases = [scene_dir, scene_dir / "dslr"]
327
+ for base in bases:
328
+ for name in names:
329
+ direct = base / name
330
+ if direct.exists():
331
+ return direct
332
+ for mask_dir_name in MASK_DIR_NAMES:
333
+ path = base / mask_dir_name / name
334
+ if path.exists():
335
+ return path
336
+ return None
337
+
338
+ def _load_scene_frames(self, scene_id: str, scene_dir: Path) -> tuple[torch.Tensor, list[dict[str, object]]]:
339
+ transforms_path = scene_dir / "nerfstudio" / "transforms.json"
340
+ if not transforms_path.exists():
341
+ transforms_path = scene_dir / "dslr" / "nerfstudio" / "transforms.json"
342
+ meta = json.loads(transforms_path.read_text(encoding="utf-8"))
343
+ camera_params = _opencv_fisheye_to_fisheye624_params(meta)
344
+ camera_hw = _camera_hw_from_meta(meta)
345
+ w2c_by_name = _read_colmap_w2c(scene_dir / "colmap" / "images.txt")
346
+ if not w2c_by_name:
347
+ w2c_by_name = _read_colmap_w2c(scene_dir / "dslr" / "colmap" / "images.txt")
348
+
349
+ raw_frames = list(meta.get("frames", [])) + list(meta.get("test_frames", []))
350
+ frames: list[dict[str, object]] = []
351
+ for frame in raw_frames:
352
+ image_name = Path(str(frame.get("file_path", ""))).name
353
+ if not image_name:
354
+ continue
355
+ if self.skip_bad and bool(frame.get("is_bad", False)):
356
+ continue
357
+ image_path = self._resolve_image_path(scene_dir, image_name)
358
+ depth_path = self._resolve_depth_path(scene_dir, image_name)
359
+ if image_path is None or depth_path is None:
360
+ continue
361
+ w2c = w2c_by_name.get(image_name)
362
+ if w2c is None and frame.get("transform_matrix") is not None:
363
+ c2w = torch.tensor(frame["transform_matrix"], dtype=torch.float32)
364
+ w2c = torch.linalg.inv(c2w)
365
+ if w2c is None:
366
+ continue
367
+ center = torch.linalg.inv(w2c)[:3, 3]
368
+ frames.append(
369
+ {
370
+ "image_name": image_name,
371
+ "image_path": image_path,
372
+ "depth_path": depth_path,
373
+ "mask_path": self._resolve_mask_path(scene_dir, image_name, frame.get("mask_path")),
374
+ "w2c": w2c.to(torch.float32),
375
+ "center": center.to(torch.float32),
376
+ "idx": len(frames),
377
+ "scene": scene_id,
378
+ "camera_hw": _camera_hw_from_meta(frame) or camera_hw,
379
+ }
380
+ )
381
+ return camera_params, sorted(frames, key=lambda x: str(x["image_name"]))
382
+
383
+ def _load_frame_tensor(self, frame: dict[str, object], camera_params: torch.Tensor) -> dict[str, torch.Tensor]:
384
+ rgb = self._load_rgb(frame["image_path"]) # type: ignore[arg-type]
385
+ rgb_hw = (int(rgb.shape[-2]), int(rgb.shape[-1]))
386
+ camera_hw = frame.get("camera_hw", None)
387
+ params = camera_params.clone()
388
+ if isinstance(camera_hw, tuple):
389
+ params = _scale_fisheye624_params(params, src_hw=camera_hw, dst_hw=rgb_hw)
390
+ depth, depth_kind = self._load_depth_map(frame["depth_path"]) # type: ignore[arg-type]
391
+ if tuple(depth.shape[-2:]) != tuple(rgb.shape[-2:]):
392
+ depth = torch.nn.functional.interpolate(
393
+ depth.unsqueeze(0),
394
+ size=(int(rgb.shape[-2]), int(rgb.shape[-1])),
395
+ mode="nearest",
396
+ ).squeeze(0)
397
+ if depth_kind == "z":
398
+ depth = self._fisheye_z_depth_to_distance(depth, params)
399
+ valid = (torch.isfinite(depth) & (depth > 0.0)).to(torch.float32)
400
+ mask_path = frame.get("mask_path", None)
401
+ if isinstance(mask_path, Path):
402
+ mask = self._load_mask(mask_path, (int(rgb.shape[-2]), int(rgb.shape[-1])))
403
+ if mask is not None:
404
+ valid = valid * mask
405
+ else:
406
+ valid = valid * (rgb.to(torch.float32).sum(dim=0, keepdim=True) > 1.0).to(torch.float32)
407
+ return {
408
+ "rgb_u8": rgb,
409
+ "depth_m": depth.clamp(min=0.0, max=self.depth_max_m),
410
+ "valid_mask": valid,
411
+ "camera_params": params,
412
+ }
413
+
414
+ def _iter_scene_pairs(self, scene_id: str, scene_dir: Path, rng: random.Random):
415
+ try:
416
+ camera_params, frames = self._load_scene_frames(scene_id, scene_dir)
417
+ except Exception as exc:
418
+ LOGGER.debug("Skip ScanNet++ scene %s: %s", str(scene_id), str(exc))
419
+ return
420
+ if len(frames) < 2:
421
+ return
422
+ loaded: dict[int, dict[str, torch.Tensor]] = {}
423
+
424
+ def get_loaded(pos: int) -> dict[str, torch.Tensor]:
425
+ if pos not in loaded:
426
+ loaded[pos] = self._load_frame_tensor(frames[pos], camera_params)
427
+ return loaded[pos]
428
+
429
+ order = list(range(len(frames)))
430
+ if self.shuffle_frame:
431
+ rng.shuffle(order)
432
+ for src_pos in order:
433
+ src_item = frames[src_pos]
434
+ src_center = src_item["center"]
435
+ assert torch.is_tensor(src_center)
436
+ candidates: list[int] = []
437
+ for tgt_pos in range(max(0, src_pos - self.max_frame_gap), min(len(frames), src_pos + self.max_frame_gap + 1)):
438
+ if tgt_pos == src_pos:
439
+ continue
440
+ gap = abs(tgt_pos - src_pos)
441
+ if gap < self.min_frame_gap:
442
+ continue
443
+ tgt_center = frames[tgt_pos]["center"]
444
+ assert torch.is_tensor(tgt_center)
445
+ if float(torch.norm(tgt_center - src_center, p=2).item()) > self.pair_max_translation_m:
446
+ continue
447
+ candidates.append(tgt_pos)
448
+ if not candidates:
449
+ continue
450
+ tgt_pos = rng.choice(candidates)
451
+ try:
452
+ src_loaded = get_loaded(src_pos)
453
+ tgt_loaded = get_loaded(tgt_pos)
454
+ except Exception:
455
+ continue
456
+ yield ScannetppFisheyePairSample(
457
+ src_rgb_u8=src_loaded["rgb_u8"],
458
+ tgt_rgb_u8=tgt_loaded["rgb_u8"],
459
+ src_depth_m=src_loaded["depth_m"],
460
+ tgt_depth_m=tgt_loaded["depth_m"],
461
+ src_valid_mask=src_loaded["valid_mask"],
462
+ tgt_valid_mask=tgt_loaded["valid_mask"],
463
+ src_w2c=src_item["w2c"], # type: ignore[arg-type]
464
+ tgt_w2c=frames[tgt_pos]["w2c"], # type: ignore[arg-type]
465
+ src_camera_params=src_loaded["camera_params"],
466
+ tgt_camera_params=tgt_loaded["camera_params"],
467
+ src_idx=int(src_item["idx"]),
468
+ tgt_idx=int(frames[tgt_pos]["idx"]),
469
+ scene=str(scene_id),
470
+ )
471
+
472
+ def __iter__(self):
473
+ worker = torch.utils.data.get_worker_info()
474
+ worker_id = 0 if worker is None else int(worker.id)
475
+ num_workers = 1 if worker is None else int(worker.num_workers)
476
+ rng = random.Random(self.seed + 1009 * self.epoch + 97 * self.ddp_rank + 17 * worker_id)
477
+ specs = list(self.scene_specs)
478
+ if self.shuffle_scene:
479
+ rng.shuffle(specs)
480
+ specs = specs[self.ddp_rank :: max(self.ddp_world_size, 1)]
481
+ specs = specs[worker_id :: num_workers]
482
+ pending: dict[tuple[int, int], list[ScannetppFisheyePairSample]] = {}
483
+ for scene_id, scene_dir in specs:
484
+ for sample in self._iter_scene_pairs(scene_id, scene_dir, rng):
485
+ hw = (int(sample.src_rgb_u8.shape[-2]), int(sample.src_rgb_u8.shape[-1]))
486
+ bucket = pending.setdefault(hw, [])
487
+ bucket.append(sample)
488
+ while len(bucket) >= self.batch_size_hint:
489
+ packed = bucket[: self.batch_size_hint]
490
+ del bucket[: self.batch_size_hint]
491
+ yield _stack_batch(packed)
unisharp/datasets/sim_panorama.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ import csv
5
+ from dataclasses import dataclass
6
+ import os
7
+ from pathlib import Path
8
+ import random
9
+ import re
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from PIL import Image
15
+ from torch.utils.data import IterableDataset
16
+
17
+ from unisharp.datasets.panogs import PanOGSSample
18
+ from unisharp import DEFAULT_MAX_DEPTH_M
19
+
20
+ try:
21
+ import h5py
22
+ except ImportError:
23
+ h5py = None
24
+
25
+
26
+ _NUM_RE = re.compile(r"(\d+)(?!.*\d)")
27
+ _SIM_CACHE_VERSION = 6
28
+
29
+
30
+ def _default_dataset_manifest_dir() -> Path:
31
+ repo_root = Path(__file__).resolve().parents[2]
32
+ parent_path = repo_root.parent / "dataset_manifests"
33
+ if parent_path.exists():
34
+ return parent_path
35
+ return repo_root / "dataset_manifests"
36
+
37
+
38
+ def _frame_index_from_name(name: str) -> int | None:
39
+ match = _NUM_RE.search(Path(name).stem)
40
+ if match is None:
41
+ return None
42
+ return int(match.group(1))
43
+
44
+
45
+ def _sim_csv_xyz_to_training_position(x: float, y: float, z: float) -> torch.Tensor:
46
+ return torch.tensor([float(y), -float(z), float(x)], dtype=torch.float32)
47
+
48
+
49
+ class _EquirecToCube:
50
+ def __init__(self, equ_h: int, equ_w: int, face_w: int) -> None:
51
+ self.equ_h = int(equ_h)
52
+ self.equ_w = int(equ_w)
53
+ self.face_w = int(face_w)
54
+ self.grid = self._build_grid()
55
+ rng = torch.linspace(-0.5, 0.5, steps=self.face_w, dtype=torch.float32)
56
+ xx, yy = torch.meshgrid(rng, -rng, indexing="xy")
57
+ self.ray_z = (1.0 / torch.sqrt((2.0 * xx) ** 2 + (2.0 * yy) ** 2 + 1.0)).contiguous()
58
+
59
+ def _build_grid(self) -> torch.Tensor:
60
+ face_w = self.face_w
61
+ rng = torch.linspace(-0.5, 0.5, steps=face_w, dtype=torch.float32)
62
+ grid = torch.stack(torch.meshgrid(rng, -rng, indexing="xy"), dim=-1)
63
+ xyz = torch.zeros((6, face_w, face_w, 3), dtype=torch.float32)
64
+
65
+ xyz[0, :, :, 0] = grid[:, :, 0]
66
+ xyz[0, :, :, 1] = grid[:, :, 1]
67
+ xyz[0, :, :, 2] = 0.5
68
+
69
+ xyz[1, :, :, 2] = torch.flip(grid[:, :, 0], dims=[1])
70
+ xyz[1, :, :, 1] = torch.flip(grid[:, :, 1], dims=[1])
71
+ xyz[1, :, :, 0] = 0.5
72
+
73
+ xyz[2, :, :, 0] = torch.flip(grid[:, :, 0], dims=[1])
74
+ xyz[2, :, :, 1] = torch.flip(grid[:, :, 1], dims=[1])
75
+ xyz[2, :, :, 2] = -0.5
76
+
77
+ xyz[3, :, :, 2] = grid[:, :, 0]
78
+ xyz[3, :, :, 1] = grid[:, :, 1]
79
+ xyz[3, :, :, 0] = -0.5
80
+
81
+ xyz[4, :, :, 0] = torch.flip(grid[:, :, 0], dims=[0])
82
+ xyz[4, :, :, 2] = torch.flip(grid[:, :, 1], dims=[0])
83
+ xyz[4, :, :, 1] = 0.5
84
+
85
+ xyz[5, :, :, 0] = grid[:, :, 0]
86
+ xyz[5, :, :, 2] = grid[:, :, 1]
87
+ xyz[5, :, :, 1] = -0.5
88
+
89
+ xyz = xyz[[4, 2, 3, 0, 1, 5]]
90
+ x = xyz[..., 0]
91
+ y = xyz[..., 1]
92
+ z = xyz[..., 2]
93
+ lon = torch.atan2(x, z)
94
+ c = torch.sqrt(x * x + z * z).clamp(min=1e-8)
95
+ lat = torch.atan2(y, c)
96
+ grid_x = lon / np.pi
97
+ grid_y = (-2.0 * lat / np.pi).clamp(min=-1.0, max=1.0)
98
+ return torch.stack([grid_x, grid_y], dim=-1).contiguous()
99
+
100
+ def run_depth(self, depth_1hw: torch.Tensor) -> torch.Tensor:
101
+ depth = depth_1hw.unsqueeze(0).to(torch.float32)
102
+ if tuple(depth.shape[-2:]) != (self.equ_h, self.equ_w):
103
+ depth = F.interpolate(depth, size=(self.equ_h, self.equ_w), mode="nearest")
104
+ depth_faces = F.grid_sample(
105
+ depth.expand(6, -1, -1, -1),
106
+ self.grid,
107
+ mode="nearest",
108
+ padding_mode="border",
109
+ align_corners=True,
110
+ )
111
+ depth_faces = depth_faces[:, 0] * self.ray_z.to(depth_faces.device, depth_faces.dtype)
112
+ return depth_faces.unsqueeze(-1).to(torch.float32).cpu()
113
+
114
+ def run(self, rgb_chw: torch.Tensor, depth_1hw: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
115
+ rgb = rgb_chw.unsqueeze(0).to(torch.float32) / 255.0
116
+ if tuple(rgb.shape[-2:]) != (self.equ_h, self.equ_w):
117
+ rgb = F.interpolate(rgb, size=(self.equ_h, self.equ_w), mode="bilinear", align_corners=True)
118
+ rgb_faces = F.grid_sample(
119
+ rgb.expand(6, -1, -1, -1),
120
+ self.grid,
121
+ mode="bilinear",
122
+ padding_mode="border",
123
+ align_corners=True,
124
+ )
125
+ cube_rgb = (rgb_faces.permute(0, 2, 3, 1).clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)
126
+ cube_depth = self.run_depth(depth_1hw)
127
+ return cube_rgb.cpu(), cube_depth
128
+
129
+ def run_rgb(self, rgb_chw: torch.Tensor) -> torch.Tensor:
130
+ rgb = rgb_chw.unsqueeze(0).to(torch.float32) / 255.0
131
+ if tuple(rgb.shape[-2:]) != (self.equ_h, self.equ_w):
132
+ rgb = F.interpolate(rgb, size=(self.equ_h, self.equ_w), mode="bilinear", align_corners=True)
133
+ rgb_faces = F.grid_sample(
134
+ rgb.expand(6, -1, -1, -1),
135
+ self.grid,
136
+ mode="bilinear",
137
+ padding_mode="border",
138
+ align_corners=True,
139
+ )
140
+ return (rgb_faces.permute(0, 2, 3, 1).clamp(0.0, 1.0) * 255.0).round().to(torch.uint8).cpu()
141
+
142
+
143
+ @dataclass(frozen=True)
144
+ class _SimFrame:
145
+ frame_idx: int
146
+ rgb_path: Path
147
+ depth_path: Path
148
+ position_xyz: torch.Tensor
149
+
150
+
151
+ class SimPanoramaDataset(IterableDataset):
152
+ def __init__(
153
+ self,
154
+ root: Path,
155
+ pose_root: Path,
156
+ scene_names: list[str] | None = None,
157
+ scene_list_file: Path | None = None,
158
+ position_scale: float = 0.01,
159
+ max_index_gap: int = 10,
160
+ pair_max_translation_m: float = 0.5,
161
+ pair_min_depth_overlap: float = 0.6,
162
+ pair_overlap_margin: float = 1.05,
163
+ pairs_per_chunk: int = 15,
164
+ chunk_size: int = 30,
165
+ shuffle_scene: bool = True,
166
+ ddp_rank: int = 0,
167
+ ddp_world_size: int = 1,
168
+ depth_max_m: float = DEFAULT_MAX_DEPTH_M,
169
+ far_depth_invalid_m: float = 30.0,
170
+ far_depth_invalid_max_frac: float = 1.0,
171
+ max_long_edge: int = 0,
172
+ seed: int = 0,
173
+ ) -> None:
174
+ super().__init__()
175
+ self.root = Path(root)
176
+ self.pose_root = Path(pose_root)
177
+ self.scene_list_file = Path(scene_list_file) if scene_list_file is not None else None
178
+ requested_scene_names = [str(name).strip() for name in (scene_names or []) if str(name).strip()]
179
+ if self.scene_list_file is not None:
180
+ if not self.scene_list_file.exists():
181
+ raise FileNotFoundError(self.scene_list_file)
182
+ manifest_scene_names = [
183
+ line.strip()
184
+ for line in self.scene_list_file.read_text(encoding="utf-8").splitlines()
185
+ if line.strip()
186
+ ]
187
+ if requested_scene_names:
188
+ requested = set(requested_scene_names)
189
+ self.scene_names = [name for name in manifest_scene_names if name in requested]
190
+ else:
191
+ self.scene_names = manifest_scene_names
192
+ else:
193
+ self.scene_names = requested_scene_names
194
+ if not self.scene_names:
195
+ raise ValueError("SimPanoramaDataset requires scene_names or scene_list_file.")
196
+ self.position_scale = float(position_scale)
197
+ self.max_index_gap = int(max_index_gap)
198
+ self.pair_max_translation_m = float(pair_max_translation_m)
199
+ self.pair_min_depth_overlap = float(pair_min_depth_overlap)
200
+ self.pair_overlap_margin = float(pair_overlap_margin)
201
+ self.pairs_per_chunk = int(pairs_per_chunk)
202
+ self.chunk_size = int(chunk_size)
203
+ self.shuffle_scene = bool(shuffle_scene)
204
+ self.ddp_rank = int(ddp_rank)
205
+ self.ddp_world_size = int(ddp_world_size)
206
+ self.seed = int(seed)
207
+ self.depth_max_m = float(depth_max_m)
208
+ self.far_depth_invalid_m = float(far_depth_invalid_m)
209
+ self.far_depth_invalid_max_frac = float(far_depth_invalid_max_frac)
210
+ self.max_long_edge = max(int(max_long_edge), 0)
211
+ self.epoch = 0
212
+ self.cache_dir = _default_dataset_manifest_dir() / "sim_cache"
213
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
214
+ self._scene_frames_cache: dict[str, list[_SimFrame]] = {}
215
+
216
+ def set_epoch(self, epoch: int) -> None:
217
+ self.epoch = int(epoch)
218
+
219
+ @staticmethod
220
+ def _is_depth_path(path: Path) -> bool:
221
+ tokens = [part.lower() for part in path.parts]
222
+ name = path.name.lower()
223
+ return ("depth" in name) or any("depth" in token for token in tokens)
224
+
225
+ @staticmethod
226
+ def _is_image_path(path: Path) -> bool:
227
+ return path.suffix.lower() in (".png", ".jpg", ".jpeg", ".webp")
228
+
229
+ @staticmethod
230
+ def _load_rgb(path: Path) -> torch.Tensor:
231
+ with Image.open(path) as img:
232
+ img = img.convert("RGB")
233
+ arr = np.asarray(img, dtype=np.uint8).copy()
234
+ return torch.from_numpy(arr).permute(2, 0, 1).contiguous()
235
+
236
+ @staticmethod
237
+ def _image_hw(path: Path) -> tuple[int, int]:
238
+ with Image.open(path) as img:
239
+ width, height = img.size
240
+ return int(height), int(width)
241
+
242
+ def _load_depth(self, path: Path) -> torch.Tensor:
243
+ suffix = path.suffix.lower()
244
+ if suffix == ".npy":
245
+ dep = np.load(path)
246
+ elif suffix == ".npz":
247
+ payload = np.load(path)
248
+ key = "depth" if "depth" in payload.files else payload.files[0]
249
+ dep = payload[key]
250
+ elif suffix in (".h5", ".hdf5"):
251
+ if h5py is None:
252
+ raise ImportError("h5py is required to read sim .h5 depth files but is not installed.")
253
+ with h5py.File(path, "r") as f:
254
+ keys = list(f.keys())
255
+ if not keys:
256
+ raise RuntimeError(f"Empty sim depth file: {path}")
257
+ dep = f[keys[0]][()]
258
+ else:
259
+ with Image.open(path) as img:
260
+ dep = np.asarray(img)
261
+ dep = dep.astype(np.float32)
262
+ if dep.ndim == 3:
263
+ dep = dep[..., 0]
264
+ dep[~np.isfinite(dep)] = 0.0
265
+ if self.far_depth_invalid_m > 0.0:
266
+ far = dep > self.far_depth_invalid_m
267
+ if 0.0 < float(far.mean()) <= self.far_depth_invalid_max_frac:
268
+ dep[far] = 0.0
269
+ dep = np.clip(dep, a_min=0.0, a_max=self.depth_max_m)
270
+ return torch.from_numpy(dep).unsqueeze(0)
271
+
272
+ def _resize_erp_if_needed(self, rgb: torch.Tensor, depth: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
273
+ if self.max_long_edge <= 0:
274
+ return rgb, depth
275
+ h = int(rgb.shape[-2])
276
+ w = int(rgb.shape[-1])
277
+ long_edge = max(h, w)
278
+ if long_edge <= self.max_long_edge:
279
+ return rgb, depth
280
+ scale = float(self.max_long_edge) / float(long_edge)
281
+ new_h = max(2, int(round(float(h) * scale)))
282
+ new_w = max(2, int(round(float(w) * scale)))
283
+ rgb_f = rgb.unsqueeze(0).to(dtype=torch.float32)
284
+ rgb_resized = F.interpolate(rgb_f, size=(new_h, new_w), mode="bilinear", align_corners=False)
285
+ rgb_out = rgb_resized[0].round().clamp(0.0, 255.0).to(dtype=torch.uint8).contiguous()
286
+ depth_f = depth.unsqueeze(0).to(dtype=torch.float32)
287
+ depth_out = F.interpolate(depth_f, size=(new_h, new_w), mode="nearest")[0].contiguous()
288
+ return rgb_out, depth_out
289
+
290
+ def _pose_csv_for_scene(self, scene_name: str) -> Path:
291
+ direct = self.pose_root / f"{scene_name}.csv"
292
+ if direct.exists():
293
+ return direct
294
+ matches = sorted(self.pose_root.glob(f"*{scene_name}*.csv"))
295
+ if matches:
296
+ return matches[0]
297
+ raise FileNotFoundError(f"No pose csv found for sim scene={scene_name} under {self.pose_root}")
298
+
299
+ def _parse_pose_csv(self, csv_path: Path) -> list[tuple[int, torch.Tensor]]:
300
+ with csv_path.open("r", encoding="utf-8") as f:
301
+ rows = list(csv.DictReader(f))
302
+ if not rows:
303
+ raise RuntimeError(f"Empty sim pose csv: {csv_path}")
304
+ poses: list[tuple[int, torch.Tensor]] = []
305
+ for row_idx, row in enumerate(rows):
306
+ lower = {str(k).strip().lower(): v for k, v in row.items()}
307
+ frame_val = None
308
+ for key in ("frame", "frame_idx", "idx", "index", "id", "image", "filename", "name"):
309
+ if key in lower and str(lower[key]).strip():
310
+ frame_val = _frame_index_from_name(str(lower[key]))
311
+ if frame_val is None:
312
+ try:
313
+ frame_val = int(float(str(lower[key]).strip()))
314
+ except Exception:
315
+ frame_val = None
316
+ break
317
+ x = next((lower[k] for k in lower if k in ("x", "tx", "pos_x", "world_x")), None)
318
+ y = next((lower[k] for k in lower if k in ("y", "ty", "pos_y", "world_y")), None)
319
+ z = next((lower[k] for k in lower if k in ("z", "tz", "pos_z", "world_z")), None)
320
+ if x is None or y is None or z is None:
321
+ numeric_vals = []
322
+ for val in row.values():
323
+ try:
324
+ numeric_vals.append(float(str(val).strip()))
325
+ except Exception:
326
+ continue
327
+ if len(numeric_vals) < 3:
328
+ raise ValueError(f"Failed to parse xyz from sim csv row: {row}")
329
+ x, y, z = numeric_vals[:3]
330
+ pos = _sim_csv_xyz_to_training_position(float(x), float(y), float(z)) * self.position_scale
331
+ poses.append((int(frame_val if frame_val is not None else row_idx), pos))
332
+ return poses
333
+
334
+ def _scan_scene_frames(self, scene_name: str) -> list[_SimFrame]:
335
+ scene_dir = self.root / scene_name
336
+ if not scene_dir.exists():
337
+ raise FileNotFoundError(scene_dir)
338
+ all_files = [p for p in scene_dir.rglob("*") if p.is_file()]
339
+ image_map: dict[int, Path] = {}
340
+ depth_map: dict[int, Path] = {}
341
+ for path in all_files:
342
+ idx = _frame_index_from_name(path.name)
343
+ if idx is None:
344
+ continue
345
+ if self._is_depth_path(path) and path.suffix.lower() in (".png", ".npy", ".npz", ".exr", ".h5", ".hdf5"):
346
+ depth_map.setdefault(idx, path)
347
+ elif self._is_image_path(path):
348
+ image_map.setdefault(idx, path)
349
+ pose_entries = self._parse_pose_csv(self._pose_csv_for_scene(scene_name))
350
+ frames: list[_SimFrame] = []
351
+ for frame_idx, pos in pose_entries:
352
+ rgb_path = image_map.get(int(frame_idx))
353
+ depth_path = depth_map.get(int(frame_idx))
354
+ if rgb_path is None or depth_path is None:
355
+ continue
356
+ frames.append(_SimFrame(frame_idx=int(frame_idx), rgb_path=rgb_path, depth_path=depth_path, position_xyz=pos))
357
+ return frames
358
+
359
+ @staticmethod
360
+ def _atomic_torch_save(path: Path, payload: object) -> None:
361
+ path.parent.mkdir(parents=True, exist_ok=True)
362
+ tmp_path = path.with_suffix(path.suffix + f".tmp.{os.getpid()}")
363
+ torch.save(payload, tmp_path)
364
+ os.replace(tmp_path, path)
365
+
366
+ def _scene_index_cache_path(self, scene_name: str) -> Path:
367
+ scene_key = scene_name.replace("/", "__")
368
+ return self.cache_dir / f"{scene_key}_ps{self.position_scale:g}_frames_v{_SIM_CACHE_VERSION}.pt"
369
+
370
+ def _load_or_build_scene_frames(self, scene_name: str) -> list[_SimFrame]:
371
+ cached = self._scene_frames_cache.get(scene_name)
372
+ if cached is not None:
373
+ return cached
374
+ cache_path = self._scene_index_cache_path(scene_name)
375
+ frames: list[_SimFrame]
376
+ if cache_path.exists():
377
+ try:
378
+ payload = torch.load(cache_path, map_location="cpu")
379
+ frames = [
380
+ _SimFrame(
381
+ frame_idx=int(item["frame_idx"]),
382
+ rgb_path=Path(str(item["rgb_path"])),
383
+ depth_path=Path(str(item["depth_path"])),
384
+ position_xyz=torch.tensor(item["position_xyz"], dtype=torch.float32),
385
+ )
386
+ for item in payload["frames"]
387
+ ]
388
+ except Exception:
389
+ frames = self._scan_scene_frames(scene_name)
390
+ payload = {
391
+ "scene": scene_name,
392
+ "frames": [
393
+ {
394
+ "frame_idx": int(frame.frame_idx),
395
+ "rgb_path": str(frame.rgb_path),
396
+ "depth_path": str(frame.depth_path),
397
+ "position_xyz": frame.position_xyz.tolist(),
398
+ }
399
+ for frame in frames
400
+ ],
401
+ }
402
+ self._atomic_torch_save(cache_path, payload)
403
+ else:
404
+ frames = self._scan_scene_frames(scene_name)
405
+ payload = {
406
+ "scene": scene_name,
407
+ "frames": [
408
+ {
409
+ "frame_idx": int(frame.frame_idx),
410
+ "rgb_path": str(frame.rgb_path),
411
+ "depth_path": str(frame.depth_path),
412
+ "position_xyz": frame.position_xyz.tolist(),
413
+ }
414
+ for frame in frames
415
+ ],
416
+ }
417
+ self._atomic_torch_save(cache_path, payload)
418
+ self._scene_frames_cache[scene_name] = frames
419
+ return frames
420
+
421
+ def _random_chunk_pairs(self, chunk: list[_SimFrame], rng: random.Random) -> list[tuple[int, int]]:
422
+ if len(chunk) < self.chunk_size:
423
+ return []
424
+ indices = list(range(len(chunk)))
425
+ rng.shuffle(indices)
426
+ max_pairs = min(self.pairs_per_chunk, len(indices) // 2)
427
+ return [(indices[2 * i], indices[2 * i + 1]) for i in range(max_pairs)]
428
+
429
+ def __iter__(self):
430
+ scene_names = list(self.scene_names)
431
+ order_rng = random.Random(self.seed + self.epoch)
432
+ if self.shuffle_scene:
433
+ order_rng.shuffle(scene_names)
434
+ worker_info = torch.utils.data.get_worker_info()
435
+ num_workers = worker_info.num_workers if worker_info is not None else 1
436
+ worker_id = worker_info.id if worker_info is not None else 0
437
+ total_shards = max(1, self.ddp_world_size * num_workers)
438
+ shard_id = self.ddp_rank * num_workers + worker_id
439
+ pair_unit_index = 0
440
+
441
+ for scene_order_idx, scene_name in enumerate(scene_names):
442
+ try:
443
+ frames = self._load_or_build_scene_frames(scene_name)
444
+ except Exception:
445
+ continue
446
+ if len(frames) < self.chunk_size:
447
+ continue
448
+ for start in range(0, len(frames), self.chunk_size):
449
+ chunk = frames[start : start + self.chunk_size]
450
+ if len(chunk) < self.chunk_size:
451
+ break
452
+ try:
453
+ equ_h, equ_w = self._image_hw(chunk[0].rgb_path)
454
+ if self.max_long_edge > 0 and max(equ_h, equ_w) > self.max_long_edge:
455
+ scale = float(self.max_long_edge) / float(max(equ_h, equ_w))
456
+ equ_h = max(2, int(round(float(equ_h) * scale)))
457
+ equ_w = max(2, int(round(float(equ_w) * scale)))
458
+ face_w = max(1, equ_h // 2)
459
+ converter = _EquirecToCube(equ_h=equ_h, equ_w=equ_w, face_w=face_w)
460
+ except Exception:
461
+ continue
462
+ def load_frame(local_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
463
+ frame = chunk[local_idx]
464
+ rgb = self._load_rgb(frame.rgb_path)
465
+ depth = self._load_depth(frame.depth_path)
466
+ rgb, depth = self._resize_erp_if_needed(rgb, depth)
467
+ cube_rgb, cube_depth = converter.run(rgb, depth)
468
+ return rgb, depth, cube_rgb, cube_depth
469
+
470
+ chunk_rng = random.Random(
471
+ self.seed + self.epoch * 1000003 + scene_order_idx * 1009 + start
472
+ )
473
+ pairs = self._random_chunk_pairs(chunk, chunk_rng)
474
+ for src_local, tgt_local in pairs:
475
+ if pair_unit_index % total_shards != shard_id:
476
+ pair_unit_index += 1
477
+ continue
478
+ pair_unit_index += 1
479
+ src_rgb, src_depth, src_cube_rgb, src_cube_depth = load_frame(src_local)
480
+ tgt_rgb, tgt_depth, tgt_cube_rgb, tgt_cube_depth = load_frame(tgt_local)
481
+ yield PanOGSSample(
482
+ src_erp_rgb_u8=src_rgb,
483
+ tgt_erp_rgb_u8=tgt_rgb,
484
+ src_erp_depth_m=src_depth,
485
+ tgt_erp_depth_m=tgt_depth,
486
+ src_cube_rgb_u8=src_cube_rgb,
487
+ tgt_cube_rgb_u8=tgt_cube_rgb,
488
+ src_cube_depth_m=src_cube_depth,
489
+ tgt_cube_depth_m=tgt_cube_depth,
490
+ src_R=torch.eye(3, dtype=torch.float32),
491
+ src_t=chunk[src_local].position_xyz.clone(),
492
+ tgt_R=torch.eye(3, dtype=torch.float32),
493
+ tgt_t=chunk[tgt_local].position_xyz.clone(),
494
+ src_idx=int(chunk[src_local].frame_idx),
495
+ tgt_idx=int(chunk[tgt_local].frame_idx),
496
+ scene=str(scene_name),
497
+ )
unisharp/datasets/wildrgbd.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import json
5
+ from pathlib import Path
6
+ import random
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from PIL import Image
12
+ from torch.utils.data import IterableDataset
13
+
14
+ from unisharp.datasets.pair_sampling import project_overlap_ratio, resize_k3_align_corners_false, resize_rgb_u8_chw_high_quality
15
+ from unisharp import DEFAULT_MAX_DEPTH_M
16
+
17
+
18
+ @dataclass(frozen=True)
19
+ class WildRGBDPairSample:
20
+ src_rgb_u8: torch.Tensor
21
+ tgt_rgb_u8: torch.Tensor
22
+ src_depth_m: torch.Tensor
23
+ tgt_depth_m: torch.Tensor
24
+ src_w2c: torch.Tensor
25
+ tgt_w2c: torch.Tensor
26
+ src_intrinsics: torch.Tensor
27
+ tgt_intrinsics: torch.Tensor
28
+ src_idx: int
29
+ tgt_idx: int
30
+ scene: str
31
+
32
+
33
+ class WildRGBDDataset(IterableDataset):
34
+
35
+ def __init__(
36
+ self,
37
+ root: Path | None = None,
38
+ scene_list_file: Path | None = None,
39
+ split: str = "train",
40
+ min_frame_gap: int = 1,
41
+ max_frame_gap: int = 32,
42
+ pair_max_translation_m: float = 0.5,
43
+ pair_min_overlap: float = 0.6,
44
+ pair_overlap_sample_h: int = 32,
45
+ pair_overlap_sample_w: int = 56,
46
+ pair_max_tries: int = 32,
47
+ output_h: int | None = None,
48
+ output_w: int | None = None,
49
+ shuffle_scene: bool = True,
50
+ shuffle_frame: bool = True,
51
+ ddp_rank: int = 0,
52
+ ddp_world_size: int = 1,
53
+ roots: list[Path] | None = None,
54
+ depth_max_m: float = DEFAULT_MAX_DEPTH_M,
55
+ seed: int = 0,
56
+ verify_manifest_paths: bool = False,
57
+ ) -> None:
58
+ super().__init__()
59
+ self.root = root
60
+ self.split = split
61
+ self.min_frame_gap = int(min_frame_gap)
62
+ self.max_frame_gap = int(max_frame_gap)
63
+ self.pair_max_translation_m = float(pair_max_translation_m)
64
+ self.pair_min_overlap = float(pair_min_overlap)
65
+ self.pair_overlap_sample_h = int(pair_overlap_sample_h)
66
+ self.pair_overlap_sample_w = int(pair_overlap_sample_w)
67
+ self.pair_max_tries = int(pair_max_tries)
68
+ self.output_h = int(output_h) if output_h is not None else None
69
+ self.output_w = int(output_w) if output_w is not None else None
70
+ self.shuffle_scene = bool(shuffle_scene)
71
+ self.shuffle_frame = bool(shuffle_frame)
72
+ self.ddp_rank = int(ddp_rank)
73
+ self.ddp_world_size = int(ddp_world_size)
74
+ self.depth_max_m = float(depth_max_m)
75
+ self.seed = int(seed)
76
+ self.epoch = 0
77
+ self.verify_manifest_paths = bool(verify_manifest_paths)
78
+ self.roots = [Path(p) for p in roots] if roots is not None else ([Path(root)] if root is not None else [])
79
+ if not self.roots:
80
+ raise ValueError("WildRGBDDataset requires at least one root path.")
81
+ self.scene_dirs: list[Path] = []
82
+ self.scene_list_file = Path(scene_list_file) if scene_list_file is not None else None
83
+ if self.scene_list_file is not None:
84
+ if not self.scene_list_file.exists():
85
+ raise FileNotFoundError(self.scene_list_file)
86
+ for raw in self.scene_list_file.read_text(encoding="utf-8").splitlines():
87
+ line = raw.strip()
88
+ if not line:
89
+ continue
90
+ scene_dir = Path(line)
91
+ if (not self.verify_manifest_paths) or scene_dir.is_dir():
92
+ self.scene_dirs.append(scene_dir)
93
+ else:
94
+ for ds_root in self.roots:
95
+ split_dir = ds_root / self.split
96
+ if not split_dir.exists():
97
+ raise FileNotFoundError(split_dir)
98
+ self.scene_dirs.extend(sorted([p for p in split_dir.iterdir() if p.is_dir()]))
99
+ if not self.scene_dirs:
100
+ raise RuntimeError("No scene folders found in the configured WildRGBD roots.")
101
+
102
+ def set_epoch(self, epoch: int) -> None:
103
+ self.epoch = int(epoch)
104
+
105
+ @staticmethod
106
+ def _load_scene_pose_and_k(scene_dir: Path) -> tuple[np.ndarray, dict[int, np.ndarray], torch.Tensor]:
107
+ metadata_path = scene_dir / "metadata"
108
+ with metadata_path.open("r", encoding="utf-8") as f:
109
+ meta = json.load(f)
110
+ k_raw = np.asarray(meta["K"], dtype=np.float32).reshape(3, 3).T
111
+ k = torch.from_numpy(k_raw.copy()).to(torch.float32)
112
+
113
+ pose_path = scene_dir / "cam_poses.txt"
114
+ pose_rows = np.genfromtxt(str(pose_path), dtype=np.float32)
115
+ if pose_rows.ndim == 1:
116
+ pose_rows = pose_rows[None, :]
117
+ if pose_rows.shape[1] < 17:
118
+ raise ValueError(f"Bad cam_poses.txt shape={pose_rows.shape} at {pose_path}")
119
+ frame_ids = pose_rows[:, 0].astype(np.int64)
120
+ c2w = pose_rows[:, 1:17].reshape(-1, 4, 4).astype(np.float32)
121
+ w2c = np.linalg.inv(c2w).astype(np.float32)
122
+ w2c_map = {int(fid): w2c[i] for i, fid in enumerate(frame_ids.tolist())}
123
+ return frame_ids, w2c_map, k
124
+
125
+ @staticmethod
126
+ def _collect_frame_ids(folder: Path) -> set[int]:
127
+ ids: set[int] = set()
128
+ if not folder.exists():
129
+ return ids
130
+ for p in folder.iterdir():
131
+ if not p.is_file():
132
+ continue
133
+ if p.suffix.lower() not in (".png", ".jpg", ".jpeg"):
134
+ continue
135
+ try:
136
+ ids.add(int(p.stem))
137
+ except ValueError:
138
+ continue
139
+ return ids
140
+
141
+ @staticmethod
142
+ def _resolve_img_path(folder: Path, idx: int) -> Path:
143
+ for ext in (".png", ".jpg", ".jpeg"):
144
+ p = folder / f"{idx:05d}{ext}"
145
+ if p.exists():
146
+ return p
147
+ raise FileNotFoundError(folder / f"{idx:05d}.png")
148
+
149
+ @staticmethod
150
+ def _load_rgb_u8(path: Path) -> torch.Tensor:
151
+ img = Image.open(path).convert("RGB")
152
+ arr = np.asarray(img, dtype=np.uint8).copy()
153
+ return torch.from_numpy(arr).permute(2, 0, 1).contiguous()
154
+
155
+ def _load_depth_m(self, depth_path: Path) -> torch.Tensor:
156
+ dep = np.asarray(Image.open(depth_path))
157
+ if dep.ndim != 2:
158
+ raise ValueError(f"Expected single-channel depth at {depth_path}, got {dep.shape}")
159
+ depth = dep.astype(np.float32)
160
+ if float(np.nanmax(depth)) > 200.0:
161
+ depth = depth / 1000.0
162
+ depth[~np.isfinite(depth)] = 0.0
163
+ depth = np.clip(depth, a_min=0.0, a_max=self.depth_max_m)
164
+ return torch.from_numpy(depth).unsqueeze(0).to(torch.float32)
165
+
166
+ @staticmethod
167
+ def _scene_name(scene_dir: Path) -> str:
168
+ parent = scene_dir.parent.parent.name if scene_dir.parent.name == "scenes" else scene_dir.parent.name
169
+ return f"{parent}/{scene_dir.name}"
170
+
171
+ def _sample_target_for_src(
172
+ self,
173
+ src_idx: int,
174
+ valid_ids: list[int],
175
+ w2c_map: dict[int, np.ndarray],
176
+ intr: torch.Tensor,
177
+ h: int,
178
+ w: int,
179
+ rng: random.Random,
180
+ ) -> int | None:
181
+ src_w2c = torch.from_numpy(w2c_map[int(src_idx)]).to(torch.float32)
182
+ src_center = torch.linalg.inv(src_w2c)[:3, 3]
183
+ candidates: list[int] = []
184
+ for j in valid_ids:
185
+ if int(j) == int(src_idx):
186
+ continue
187
+ gap = abs(int(j) - int(src_idx))
188
+ if gap < self.min_frame_gap or gap > self.max_frame_gap:
189
+ continue
190
+ jw2c = torch.from_numpy(w2c_map[int(j)]).to(torch.float32)
191
+ jcenter = torch.linalg.inv(jw2c)[:3, 3]
192
+ trans = torch.norm(jcenter - src_center, p=2).item()
193
+ if trans > self.pair_max_translation_m:
194
+ continue
195
+ candidates.append(int(j))
196
+ if not candidates:
197
+ return None
198
+
199
+ rng.shuffle(candidates)
200
+ tries = min(self.pair_max_tries, len(candidates))
201
+ src_k = intr.to(torch.float32)
202
+ src_w2c_t = src_w2c.to(torch.float32)
203
+ for j in candidates[:tries]:
204
+ tgt_w2c_t = torch.from_numpy(w2c_map[int(j)]).to(torch.float32)
205
+ ov_st = project_overlap_ratio(
206
+ src_w2c=src_w2c_t,
207
+ tgt_w2c=tgt_w2c_t,
208
+ src_k=src_k,
209
+ tgt_k=src_k,
210
+ h=h,
211
+ w=w,
212
+ sample_h=self.pair_overlap_sample_h,
213
+ sample_w=self.pair_overlap_sample_w,
214
+ )
215
+ ov_ts = project_overlap_ratio(
216
+ src_w2c=tgt_w2c_t,
217
+ tgt_w2c=src_w2c_t,
218
+ src_k=src_k,
219
+ tgt_k=src_k,
220
+ h=h,
221
+ w=w,
222
+ sample_h=self.pair_overlap_sample_h,
223
+ sample_w=self.pair_overlap_sample_w,
224
+ )
225
+ if 0.5 * (ov_st + ov_ts) >= self.pair_min_overlap:
226
+ return int(j)
227
+ return None
228
+
229
+ def __iter__(self):
230
+ scenes = list(self.scene_dirs)
231
+ order_rng = random.Random(self.seed + self.epoch)
232
+ if self.shuffle_scene:
233
+ order_rng.shuffle(scenes)
234
+
235
+ worker_info = torch.utils.data.get_worker_info()
236
+ num_workers = worker_info.num_workers if worker_info is not None else 1
237
+ worker_id = worker_info.id if worker_info is not None else 0
238
+ total_shards = max(1, self.ddp_world_size * num_workers)
239
+ shard_id = self.ddp_rank * num_workers + worker_id
240
+ src_unit_index = 0
241
+
242
+ for scene_order_idx, scene_dir in enumerate(scenes):
243
+ try:
244
+ pose_ids_np, w2c_map, intr = self._load_scene_pose_and_k(scene_dir)
245
+ except Exception:
246
+ continue
247
+ pose_ids = {int(x) for x in pose_ids_np.tolist()}
248
+ rgb_ids = self._collect_frame_ids(scene_dir / "rgb")
249
+ dep_ids = self._collect_frame_ids(scene_dir / "depth")
250
+ valid_ids = sorted(list(pose_ids & rgb_ids & dep_ids))
251
+ if len(valid_ids) < 2:
252
+ continue
253
+
254
+ src_order = list(valid_ids)
255
+ scene_rng = random.Random(self.seed + self.epoch * 1000003 + scene_order_idx)
256
+ if self.shuffle_frame:
257
+ scene_rng.shuffle(src_order)
258
+
259
+ for src_idx in src_order:
260
+ if src_unit_index % total_shards != shard_id:
261
+ src_unit_index += 1
262
+ continue
263
+ src_unit_index += 1
264
+ try:
265
+ rgb_src_path = self._resolve_img_path(scene_dir / "rgb", int(src_idx))
266
+ dep_src_path = self._resolve_img_path(scene_dir / "depth", int(src_idx))
267
+ src_img = self._load_rgb_u8(rgb_src_path)
268
+ src_depth = self._load_depth_m(dep_src_path)
269
+ except Exception:
270
+ continue
271
+
272
+ h, w = int(src_img.shape[1]), int(src_img.shape[2])
273
+ tgt_idx = self._sample_target_for_src(
274
+ src_idx=int(src_idx),
275
+ valid_ids=valid_ids,
276
+ w2c_map=w2c_map,
277
+ intr=intr,
278
+ h=h,
279
+ w=w,
280
+ rng=scene_rng,
281
+ )
282
+ if tgt_idx is None:
283
+ continue
284
+
285
+ try:
286
+ rgb_tgt_path = self._resolve_img_path(scene_dir / "rgb", int(tgt_idx))
287
+ dep_tgt_path = self._resolve_img_path(scene_dir / "depth", int(tgt_idx))
288
+ tgt_img = self._load_rgb_u8(rgb_tgt_path)
289
+ tgt_depth = self._load_depth_m(dep_tgt_path)
290
+ except Exception:
291
+ continue
292
+
293
+ if src_img.shape != tgt_img.shape:
294
+ continue
295
+
296
+ src_intr = intr.clone()
297
+ tgt_intr = intr.clone()
298
+ if self.output_h is not None and self.output_w is not None:
299
+ oh, ow = int(src_img.shape[1]), int(src_img.shape[2])
300
+ if oh > 0 and ow > 0 and (oh != self.output_h or ow != self.output_w):
301
+ sx = float(self.output_w) / float(ow)
302
+ sy = float(self.output_h) / float(oh)
303
+ src_img = resize_rgb_u8_chw_high_quality(src_img, size=(self.output_h, self.output_w))
304
+ tgt_img = resize_rgb_u8_chw_high_quality(tgt_img, size=(self.output_h, self.output_w))
305
+ src_depth = F.interpolate(
306
+ src_depth.unsqueeze(0),
307
+ size=(self.output_h, self.output_w),
308
+ mode="nearest",
309
+ ).squeeze(0)
310
+ tgt_depth = F.interpolate(
311
+ tgt_depth.unsqueeze(0),
312
+ size=(self.output_h, self.output_w),
313
+ mode="nearest",
314
+ ).squeeze(0)
315
+ src_intr = resize_k3_align_corners_false(src_intr, sx=sx, sy=sy)
316
+ tgt_intr = resize_k3_align_corners_false(tgt_intr, sx=sx, sy=sy)
317
+
318
+ yield WildRGBDPairSample(
319
+ src_rgb_u8=src_img,
320
+ tgt_rgb_u8=tgt_img,
321
+ src_depth_m=src_depth,
322
+ tgt_depth_m=tgt_depth,
323
+ src_w2c=torch.from_numpy(w2c_map[int(src_idx)]).to(torch.float32),
324
+ tgt_w2c=torch.from_numpy(w2c_map[int(tgt_idx)]).to(torch.float32),
325
+ src_intrinsics=src_intr,
326
+ tgt_intrinsics=tgt_intr,
327
+ src_idx=int(src_idx),
328
+ tgt_idx=int(tgt_idx),
329
+ scene=self._scene_name(scene_dir),
330
+ )
331
+
332
+
333
+ def wildrgbd_collate(batch: list[WildRGBDPairSample]) -> WildRGBDPairSample:
334
+ def stack(xs):
335
+ if isinstance(xs[0], torch.Tensor):
336
+ return torch.stack(xs, dim=0)
337
+ return xs
338
+
339
+ return WildRGBDPairSample(
340
+ src_rgb_u8=stack([b.src_rgb_u8 for b in batch]),
341
+ tgt_rgb_u8=stack([b.tgt_rgb_u8 for b in batch]),
342
+ src_depth_m=stack([b.src_depth_m for b in batch]),
343
+ tgt_depth_m=stack([b.tgt_depth_m for b in batch]),
344
+ src_w2c=stack([b.src_w2c for b in batch]),
345
+ tgt_w2c=stack([b.tgt_w2c for b in batch]),
346
+ src_intrinsics=stack([b.src_intrinsics for b in batch]),
347
+ tgt_intrinsics=stack([b.tgt_intrinsics for b in batch]),
348
+ src_idx=[b.src_idx for b in batch], # type: ignore[arg-type]
349
+ tgt_idx=[b.tgt_idx for b in batch], # type: ignore[arg-type]
350
+ scene=[b.scene for b in batch], # type: ignore[arg-type]
351
+ )
352
+
unisharp/losses/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .unisharp_loss import UnisharpLoss, UnisharpLossWeights
2
+
3
+ __all__ = ["UnisharpLoss", "UnisharpLossWeights"]
4
+
unisharp/losses/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (307 Bytes). View file
 
unisharp/losses/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (323 Bytes). View file
 
unisharp/losses/__pycache__/unisharp_loss.cpython-310.pyc ADDED
Binary file (32.9 kB). View file
 
unisharp/losses/__pycache__/unisharp_loss.cpython-313.pyc ADDED
Binary file (71.3 kB). View file
 
unisharp/losses/unisharp_loss.py ADDED
@@ -0,0 +1,1120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn
5
+ from dataclasses import dataclass
6
+ import math
7
+ import torch.nn.functional as F
8
+
9
+ from unisharp import DEFAULT_MAX_DEPTH_M
10
+ from unisharp.utils import linalg
11
+
12
+
13
+ def _masked_mean(x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
14
+ if m.dtype != x.dtype:
15
+ m = m.to(dtype=x.dtype)
16
+ while m.ndim < x.ndim:
17
+ m = m.unsqueeze(1)
18
+ m_expanded = m.expand_as(x)
19
+ return (x * m_expanded).sum() / m_expanded.sum().clamp(min=1.0)
20
+
21
+
22
+ def _finite_masked_mean_flat(x: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
23
+ mask = valid.to(device=x.device, dtype=torch.bool) & torch.isfinite(x)
24
+ x_safe = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
25
+ safe = torch.where(mask, x_safe, torch.zeros_like(x_safe))
26
+ return safe.sum() / mask.to(dtype=x.dtype).sum().clamp(min=1.0)
27
+
28
+
29
+ def _finite_abs_mean(x: torch.Tensor) -> torch.Tensor:
30
+ mask = torch.isfinite(x)
31
+ x_safe = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
32
+ safe_abs = torch.where(mask, x_safe.abs(), torch.zeros_like(x_safe))
33
+ return safe_abs.sum() / mask.to(dtype=x.dtype).sum().clamp(min=1.0)
34
+
35
+
36
+ _ERP_PROJECTION_MODELS = {"erp", "spherical", "equirect", "equirectangular"}
37
+ _FISHEYE_PROJECTION_MODELS = {"fisheye624", "opencv_fisheye"}
38
+
39
+
40
+ def _tv_l1(img: torch.Tensor) -> torch.Tensor:
41
+ zero = torch.zeros((), device=img.device, dtype=img.dtype)
42
+ dx = (img[..., :, 1:] - img[..., :, :-1]).abs().mean() if int(img.shape[-1]) > 1 else zero
43
+ dy = (img[..., 1:, :] - img[..., :-1, :]).abs().mean() if int(img.shape[-2]) > 1 else zero
44
+ return dx + dy
45
+
46
+
47
+ def _tv_l1_circular_h(img: torch.Tensor) -> torch.Tensor:
48
+ zero = torch.zeros((), device=img.device, dtype=img.dtype)
49
+ dx = (torch.roll(img, shifts=-1, dims=-1) - img).abs().mean() if int(img.shape[-1]) > 1 else zero
50
+ dy = (img[..., 1:, :] - img[..., :-1, :]).abs().mean() if int(img.shape[-2]) > 1 else zero
51
+ return dx + dy
52
+
53
+
54
+ def _checkerboard_l1_5d(x: torch.Tensor, *, circular_h: bool) -> torch.Tensor:
55
+ if x.ndim != 5:
56
+ raise ValueError(f"Expected [B,C,L,H,W], got {tuple(x.shape)}")
57
+ if int(x.shape[-2]) < 2 or int(x.shape[-1]) < 2:
58
+ return torch.zeros((), device=x.device, dtype=x.dtype)
59
+ x = x.to(dtype=torch.float32)
60
+ if bool(circular_h):
61
+ top = x[..., :-1, :]
62
+ bottom = x[..., 1:, :]
63
+ response = top - torch.roll(top, shifts=-1, dims=-1) - bottom + torch.roll(bottom, shifts=-1, dims=-1)
64
+ else:
65
+ response = x[..., :-1, :-1] - x[..., :-1, 1:] - x[..., 1:, :-1] + x[..., 1:, 1:]
66
+ return _finite_abs_mean(response)
67
+
68
+
69
+ def _delta_grid_checkerboard_loss(delta_grid: torch.Tensor, *, circular_h: bool) -> torch.Tensor:
70
+ if delta_grid.ndim != 5 or int(delta_grid.shape[1]) < 14:
71
+ raise ValueError(f"Expected delta grid [B,14,L,H,W], got {tuple(delta_grid.shape)}")
72
+ delta = delta_grid.to(dtype=torch.float32)
73
+ parts = [
74
+ delta[:, 3:6],
75
+ 0.1 * delta[:, 10:13],
76
+ delta[:, 13:14],
77
+ ]
78
+ return torch.stack([_checkerboard_l1_5d(part, circular_h=circular_h) for part in parts]).mean()
79
+
80
+
81
+ def _avg_pool2d_circular_h(x: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
82
+ if kernel_size <= 1 and stride <= 1:
83
+ return x
84
+ x = F.pad(x, (kernel_size - 1, 0, 0, 0), mode="circular")
85
+ return F.avg_pool2d(x, kernel_size=kernel_size, stride=stride)
86
+
87
+
88
+ def _resize_max_side(img: torch.Tensor, max_side: int, *, mode: str = "bilinear") -> torch.Tensor:
89
+ if max_side <= 0:
90
+ return img
91
+ h, w = int(img.shape[-2]), int(img.shape[-1])
92
+ ms = max(h, w)
93
+ if ms <= max_side:
94
+ return img
95
+ scale = float(max_side) / float(ms)
96
+ nh = max(1, int(math.floor(h * scale)))
97
+ nw = max(1, int(math.floor(w * scale)))
98
+ if mode in ("bilinear", "bicubic"):
99
+ return F.interpolate(img, size=(nh, nw), mode=mode, align_corners=False)
100
+ return F.interpolate(img, size=(nh, nw), mode=mode)
101
+
102
+
103
+ def _gram_matrix(fmap: torch.Tensor) -> torch.Tensor:
104
+ b, c, h, w = fmap.shape
105
+ x = fmap.reshape(b, c, h * w)
106
+ return x @ x.transpose(1, 2)
107
+
108
+
109
+ class _ResNet50Perceptual(nn.Module):
110
+
111
+ def __init__(self) -> None:
112
+ super().__init__()
113
+ try:
114
+ from torchvision.models import resnet50, ResNet50_Weights
115
+
116
+ net = resnet50(weights=ResNet50_Weights.DEFAULT)
117
+ except Exception:
118
+ from torchvision.models import resnet50
119
+
120
+ net = resnet50(pretrained=True)
121
+
122
+ net.eval()
123
+ net.requires_grad_(False)
124
+
125
+ self.conv1 = net.conv1
126
+ self.bn1 = net.bn1
127
+ self.relu = net.relu
128
+ self.maxpool = net.maxpool
129
+ self.layer1 = net.layer1
130
+ self.layer2 = net.layer2
131
+ self.layer3 = net.layer3
132
+ self.layer4 = net.layer4
133
+
134
+ mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
135
+ std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
136
+ self.register_buffer("_mean", mean, persistent=False)
137
+ self.register_buffer("_std", std, persistent=False)
138
+
139
+ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
140
+ x = x.clamp(0.0, 1.0)
141
+ x = (x - self._mean) / self._std
142
+ x = self.conv1(x)
143
+ x = self.bn1(x)
144
+ x = self.relu(x)
145
+ x = self.maxpool(x)
146
+ f1 = self.layer1(x)
147
+ f2 = self.layer2(f1)
148
+ f3 = self.layer3(f2)
149
+ f4 = self.layer4(f3)
150
+ return [f1, f2, f3, f4]
151
+
152
+
153
+ def _to_linear_rgb(img_srgb: torch.Tensor) -> torch.Tensor:
154
+ from unisharp.utils.color_space import sRGB2linearRGB
155
+
156
+ return sRGB2linearRGB(img_srgb.clamp(0.0, 1.0))
157
+
158
+
159
+ @dataclass
160
+ class UnisharpLossWeights:
161
+
162
+ lambda_color: float = 1.0
163
+ lambda_alpha: float = 1.5
164
+ lambda_percep: float = 3.0
165
+ lambda_depth: float = 0.5
166
+ lambda_tv: float = 1.0
167
+ lambda_grad: float = 1.0
168
+ lambda_delta: float = 0.0
169
+ lambda_delta_rho: float = 0.0
170
+ lambda_splat: float = 0.0
171
+ lambda_edge_splat: float = 0.0
172
+ lambda_grid: float = 0.0
173
+ lambda_grad_img: float = 0.2
174
+ lambda_edge_rgb: float = 0.0
175
+
176
+
177
+ class UnisharpLoss(nn.Module):
178
+
179
+ SUPERVISION_MAX_DEPTH_M: float = DEFAULT_MAX_DEPTH_M
180
+
181
+ def __init__(
182
+ self,
183
+ weights: UnisharpLossWeights | None = None,
184
+ *,
185
+ grad_sigma: float = 1e-2,
186
+ grad_eps: float = 1e-2,
187
+ delta_clip: float = 10.0,
188
+ raw_delta_clip: float = 400.0,
189
+ raw_delta_rho_clip: float = 5.0,
190
+ alpha_tail_min: float = 0.99,
191
+ alpha_tail_weight: float = 0.0,
192
+ splat_sigma_min: float = 1e-1,
193
+ splat_sigma_max: float = 1e2,
194
+ edge_splat_sigma_max: float = 2.0,
195
+ depth_edge_log_threshold: float = 0.05,
196
+ depth_edge_dilate_px: int = 2,
197
+ percep_max_side: int = 384,
198
+ grad_img_scales: int = 4,
199
+ grad_img_circular_h: bool = True,
200
+ ) -> None:
201
+ super().__init__()
202
+ self.w = weights or UnisharpLossWeights()
203
+ self.grad_sigma = float(grad_sigma)
204
+ self.grad_eps = float(grad_eps)
205
+ self.delta_clip = float(delta_clip)
206
+ self.raw_delta_clip = float(raw_delta_clip)
207
+ self.raw_delta_rho_clip = float(raw_delta_rho_clip)
208
+ self.alpha_tail_min = float(alpha_tail_min)
209
+ self.alpha_tail_weight = float(alpha_tail_weight)
210
+ self.splat_sigma_min = float(splat_sigma_min)
211
+ self.splat_sigma_max = float(splat_sigma_max)
212
+ self.edge_splat_sigma_max = float(edge_splat_sigma_max)
213
+ self.depth_edge_log_threshold = float(depth_edge_log_threshold)
214
+ self.depth_edge_dilate_px = int(depth_edge_dilate_px)
215
+ self.percep_max_side = int(percep_max_side)
216
+ self.grad_img_scales = int(grad_img_scales)
217
+ self.grad_img_circular_h = bool(grad_img_circular_h)
218
+
219
+ sobel_kx = torch.tensor(
220
+ [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]
221
+ ).view(1, 1, 3, 3)
222
+ sobel_ky = torch.tensor(
223
+ [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]
224
+ ).view(1, 1, 3, 3)
225
+ self.register_buffer("_sobel_kx", sobel_kx, persistent=False)
226
+ self.register_buffer("_sobel_ky", sobel_ky, persistent=False)
227
+
228
+ self._percep_net: nn.Module | None = None
229
+ if self.w.lambda_percep > 0:
230
+ self._percep_net = _ResNet50Perceptual()
231
+
232
+ @staticmethod
233
+ def _flatten_gaussian_xyz(x: torch.Tensor | None, gauss_grid_shape: tuple[int, int, int] | None = None) -> torch.Tensor | None:
234
+ if not torch.is_tensor(x):
235
+ return None
236
+ if x.ndim == 5:
237
+ return x.permute(0, 2, 3, 4, 1).flatten(1, 3)
238
+ if x.ndim == 3 and int(x.shape[-1]) == 3:
239
+ return x
240
+ if x.ndim == 2 and gauss_grid_shape is not None:
241
+ return x.unsqueeze(-1)
242
+ return None
243
+
244
+ @staticmethod
245
+ def _flatten_gaussian_quat(
246
+ x: torch.Tensor | None,
247
+ gauss_grid_shape: tuple[int, int, int] | None = None,
248
+ ) -> torch.Tensor | None:
249
+ if not torch.is_tensor(x):
250
+ return None
251
+ if x.ndim == 5 and int(x.shape[1]) == 4:
252
+ return x.permute(0, 2, 3, 4, 1).flatten(1, 3)
253
+ if x.ndim == 3 and int(x.shape[-1]) == 4:
254
+ return x
255
+ if x.ndim == 2 and gauss_grid_shape is not None:
256
+ return x.unsqueeze(-1)
257
+ return None
258
+
259
+ @staticmethod
260
+ def _flatten_gaussian_scalar(
261
+ x: torch.Tensor | None,
262
+ gauss_grid_shape: tuple[int, int, int] | None = None,
263
+ ) -> torch.Tensor | None:
264
+ if not torch.is_tensor(x):
265
+ return None
266
+ if x.ndim == 5:
267
+ return x[:, 0].flatten(1)
268
+ if x.ndim == 4:
269
+ return x.flatten(1)
270
+ if x.ndim == 3 and int(x.shape[-1]) == 1:
271
+ return x[..., 0]
272
+ if x.ndim == 2:
273
+ return x
274
+ return None
275
+
276
+ @staticmethod
277
+ def _central_disparity_gradient(inv_depth: torch.Tensor, *, circular_h: bool) -> torch.Tensor:
278
+ if circular_h:
279
+ gx = 0.5 * (torch.roll(inv_depth, shifts=-1, dims=-1) - torch.roll(inv_depth, shifts=1, dims=-1)).abs()
280
+ else:
281
+ padded_x = F.pad(inv_depth, (1, 1, 0, 0), mode="replicate")
282
+ gx = 0.5 * (padded_x[..., 2:] - padded_x[..., :-2]).abs()
283
+ padded_y = F.pad(inv_depth, (0, 0, 1, 1), mode="replicate")
284
+ gy = 0.5 * (padded_y[..., 2:, :] - padded_y[..., :-2, :]).abs()
285
+ return torch.sqrt(gx * gx + gy * gy + 1e-12)
286
+
287
+ @staticmethod
288
+ def _sample_map_at_uv(feat: torch.Tensor, u: torch.Tensor, v: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
289
+ b, _, h, w = feat.shape
290
+ valid_bool = valid.to(dtype=torch.bool) & torch.isfinite(u) & torch.isfinite(v)
291
+ u_safe = torch.where(valid_bool, u, torch.zeros_like(u)).clamp(0.0, float(max(w - 1, 0)))
292
+ v_safe = torch.where(valid_bool, v, torch.zeros_like(v)).clamp(0.0, float(max(h - 1, 0)))
293
+ grid_x = (u_safe / max(float(w - 1), 1.0)) * 2.0 - 1.0
294
+ grid_y = (v_safe / max(float(h - 1), 1.0)) * 2.0 - 1.0
295
+ grid = torch.stack([grid_x, grid_y], dim=-1).view(b, -1, 1, 2)
296
+ sampled = F.grid_sample(feat, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
297
+ return sampled[:, 0, :, 0] * valid_bool.to(dtype=feat.dtype)
298
+
299
+ @staticmethod
300
+ def _expand_camera_params(camera_params: torch.Tensor, *, batch_size: int, device: torch.device) -> torch.Tensor:
301
+ params = camera_params.to(device=device, dtype=torch.float32)
302
+ if params.ndim == 1:
303
+ params = params.unsqueeze(0)
304
+ if int(params.shape[0]) == 1 and int(batch_size) > 1:
305
+ params = params.expand(int(batch_size), -1)
306
+ return params
307
+
308
+ @staticmethod
309
+ def _project_fisheye624_points_px_stable(
310
+ pts: torch.Tensor,
311
+ camera_params: torch.Tensor,
312
+ *,
313
+ image_h: int,
314
+ image_w: int,
315
+ finite: torch.Tensor,
316
+ require_in_bounds: bool = True,
317
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
318
+ b, n, _ = pts.shape
319
+ params = UnisharpLoss._expand_camera_params(camera_params, batch_size=b, device=pts.device)
320
+ x, y, z = pts.unbind(dim=-1)
321
+ radius = torch.linalg.vector_norm(pts, dim=-1).clamp(min=1e-6)
322
+ front = z > (radius * 1e-4).clamp(min=1e-4)
323
+ projectable = finite & front
324
+
325
+ safe_pts = torch.zeros_like(pts)
326
+ safe_pts[..., 2] = 1.0
327
+ pts_proj = torch.where(projectable.unsqueeze(-1), pts, safe_pts)
328
+ x, y, z = pts_proj.unbind(dim=-1)
329
+ z_safe = z.clamp(min=1e-4)
330
+
331
+ ab = torch.stack([x / z_safe, y / z_safe], dim=-1)
332
+ r = torch.sqrt((ab * ab).sum(dim=-1, keepdim=True) + 1e-12)
333
+ theta = torch.atan(r)
334
+ unit_ab = ab / r
335
+
336
+ coeffs = params[:, 4:10].reshape(b, 1, 6)
337
+ theta_powers = torch.cat([theta.pow(3 + i * 2) for i in range(6)], dim=-1)
338
+ theta_distorted = theta + (theta_powers * coeffs).sum(dim=-1, keepdim=True)
339
+ uv_dist = theta_distorted * unit_ab
340
+
341
+ p0 = params[..., -6].reshape(b, 1)
342
+ p1 = params[..., -5].reshape(b, 1)
343
+ xr = uv_dist[..., 0]
344
+ yr = uv_dist[..., 1]
345
+ xr_sq = xr.square()
346
+ yr_sq = yr.square()
347
+ rd_sq = xr_sq + yr_sq
348
+ uv_x = uv_dist[..., 0] + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
349
+ uv_y = uv_dist[..., 1] + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
350
+
351
+ s0 = params[..., -4].reshape(b, 1)
352
+ s1 = params[..., -3].reshape(b, 1)
353
+ s2 = params[..., -2].reshape(b, 1)
354
+ s3 = params[..., -1].reshape(b, 1)
355
+ rd_4 = rd_sq.square()
356
+ uv_x = uv_x + s0 * rd_sq + s1 * rd_4
357
+ uv_y = uv_y + s2 * rd_sq + s3 * rd_4
358
+
359
+ if int(params.shape[-1]) == 15:
360
+ fx = fy = params[..., 0:1]
361
+ cx = params[..., 1:2]
362
+ cy = params[..., 2:3]
363
+ else:
364
+ fx = params[..., 0:1]
365
+ fy = params[..., 1:2]
366
+ cx = params[..., 2:3]
367
+ cy = params[..., 3:4]
368
+ u = uv_x * fx + cx
369
+ v = uv_y * fy + cy
370
+ valid = projectable & torch.isfinite(u) & torch.isfinite(v)
371
+ if require_in_bounds:
372
+ valid = valid & (u >= 0.0) & (u <= float(image_w - 1)) & (v >= 0.0) & (v <= float(image_h - 1))
373
+ return u, v, valid, radius
374
+
375
+ @staticmethod
376
+ def _project_points_px(
377
+ points: torch.Tensor,
378
+ *,
379
+ projection_model: str | None,
380
+ image_h: int,
381
+ image_w: int,
382
+ intrinsics: torch.Tensor | None = None,
383
+ camera_params: torch.Tensor | None = None,
384
+ require_in_bounds: bool = True,
385
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
386
+ pts_raw = points.to(dtype=torch.float32)
387
+ finite = torch.isfinite(pts_raw).all(dim=-1)
388
+ pts = torch.nan_to_num(pts_raw, nan=0.0, posinf=0.0, neginf=0.0)
389
+ b, n, _ = pts.shape
390
+ x, y, z = pts.unbind(dim=-1)
391
+ model = (projection_model or "pinhole").lower()
392
+
393
+ if model in _ERP_PROJECTION_MODELS:
394
+ radius_sq_raw = (pts * pts).sum(dim=-1)
395
+ direction_valid = finite & (radius_sq_raw > 1e-12)
396
+ safe_pts = torch.zeros_like(pts)
397
+ safe_pts[..., 2] = 1.0
398
+ pts_erp = torch.where(direction_valid.unsqueeze(-1), pts, safe_pts)
399
+ x, y, z = pts_erp.unbind(dim=-1)
400
+ radius_sq = (pts_erp * pts_erp).sum(dim=-1)
401
+ radius = torch.sqrt(radius_sq + 1e-12)
402
+ horizontal_sq = x.square() + z.square()
403
+ horizontal = torch.sqrt(horizontal_sq + 1e-12)
404
+ pole_angle_eps = max(1e-4, 0.5 * math.pi / float(max(image_h, image_w, 1)))
405
+ lon_valid = horizontal > radius * pole_angle_eps
406
+ lon_x = torch.where(lon_valid, x, torch.zeros_like(x))
407
+ lon_z = torch.where(lon_valid, z, torch.ones_like(z))
408
+ lon = torch.atan2(lon_x, lon_z)
409
+ pitch_down = torch.atan2(y, horizontal)
410
+ u = (lon / (2.0 * math.pi) + 0.5) * float(max(image_w, 1)) - 0.5
411
+ v = (0.5 + pitch_down / math.pi) * float(max(image_h, 1)) - 0.5
412
+ valid = direction_valid & lon_valid
413
+ valid = (
414
+ valid
415
+ & torch.isfinite(u)
416
+ & torch.isfinite(v)
417
+ & (u >= 0.0)
418
+ & (u <= float(image_w - 1))
419
+ & (v >= 0.0)
420
+ & (v <= float(image_h - 1))
421
+ )
422
+ return u, v, valid, radius.clamp(min=1e-6)
423
+
424
+ if model in _FISHEYE_PROJECTION_MODELS and torch.is_tensor(camera_params):
425
+ return UnisharpLoss._project_fisheye624_points_px_stable(
426
+ pts,
427
+ camera_params,
428
+ image_h=image_h,
429
+ image_w=image_w,
430
+ finite=finite,
431
+ require_in_bounds=require_in_bounds,
432
+ )
433
+
434
+ valid = finite & (z > 1e-4)
435
+ if not torch.is_tensor(intrinsics):
436
+ fx = torch.full((b, 1), float(max(image_w, image_h)), device=pts.device, dtype=torch.float32)
437
+ fy = fx.clone()
438
+ cx = torch.full((b, 1), 0.5 * float(max(image_w - 1, 1)), device=pts.device, dtype=torch.float32)
439
+ cy = torch.full((b, 1), 0.5 * float(max(image_h - 1, 1)), device=pts.device, dtype=torch.float32)
440
+ else:
441
+ k = intrinsics.to(device=pts.device, dtype=torch.float32)
442
+ if k.ndim == 2:
443
+ k = k.unsqueeze(0)
444
+ if int(k.shape[0]) == 1 and b > 1:
445
+ k = k.expand(b, -1, -1)
446
+ fx = k[:, 0, 0:1]
447
+ fy = k[:, 1, 1:2]
448
+ cx = k[:, 0, 2:3]
449
+ cy = k[:, 1, 2:3]
450
+ z_safe = z.clamp(min=1e-4)
451
+ u = fx * (x / z_safe) + cx
452
+ v = fy * (y / z_safe) + cy
453
+ valid = valid & torch.isfinite(u) & torch.isfinite(v)
454
+ if require_in_bounds:
455
+ valid = valid & (u >= 0.0) & (u <= float(image_w - 1)) & (v >= 0.0) & (v <= float(image_h - 1))
456
+ return u, v, valid, z_safe
457
+
458
+ def _projected_sigma_px(
459
+ self,
460
+ *,
461
+ gaussian_scales: torch.Tensor,
462
+ gaussian_quaternions: torch.Tensor | None,
463
+ gaussian_mean_vectors: torch.Tensor,
464
+ valid: torch.Tensor,
465
+ projection_model: str | None,
466
+ image_h: int,
467
+ image_w: int,
468
+ intrinsics: torch.Tensor | None = None,
469
+ camera_params: torch.Tensor | None = None,
470
+ projected_scale_factor: float | torch.Tensor | None = None,
471
+ ) -> torch.Tensor:
472
+ scales = self._flatten_gaussian_xyz(gaussian_scales)
473
+ quats = self._flatten_gaussian_quat(gaussian_quaternions)
474
+ means = self._flatten_gaussian_xyz(gaussian_mean_vectors)
475
+ if scales is None or means is None:
476
+ return torch.zeros_like(valid, dtype=torch.float32)
477
+ valid = valid.to(dtype=torch.bool) & torch.isfinite(scales).all(dim=-1) & torch.isfinite(means).all(dim=-1)
478
+ scales = torch.nan_to_num(scales.to(dtype=torch.float32), nan=0.0, posinf=0.0, neginf=0.0).abs()
479
+ means = torch.nan_to_num(means.to(dtype=torch.float32), nan=0.0, posinf=0.0, neginf=0.0)
480
+ model = (projection_model or "pinhole").lower()
481
+ if model in _ERP_PROJECTION_MODELS:
482
+ radius = torch.norm(means, dim=-1).clamp(min=1e-4)
483
+ sigma_u = scales[..., 0] / radius * (float(max(image_w, 1)) / (2.0 * math.pi))
484
+ sigma_v = scales[..., 1] / radius * (float(max(image_h, 1)) / math.pi)
485
+ sigma_px = torch.maximum(sigma_u.square(), sigma_v.square())
486
+ valid = valid & torch.isfinite(sigma_px)
487
+ sigma_px = torch.nan_to_num(sigma_px, nan=0.0, posinf=0.0, neginf=0.0)
488
+ return torch.where(valid, sigma_px, torch.zeros_like(sigma_px))
489
+
490
+ if quats is not None and tuple(quats.shape[:2]) == tuple(means.shape[:2]):
491
+ quats = torch.nan_to_num(quats.to(dtype=torch.float32), nan=0.0, posinf=0.0, neginf=0.0)
492
+ quat_norm = quats.norm(dim=-1, keepdim=True)
493
+ valid = valid & torch.isfinite(quats).all(dim=-1) & (quat_norm.squeeze(-1) > 1e-8)
494
+ quats = quats / quat_norm.clamp(min=1e-8)
495
+ rotations = linalg.rotation_matrices_from_quaternions(quats)
496
+ tangent_scales = scales[..., :2]
497
+ tangent_rotations = rotations[..., :, :2]
498
+ axis_offsets = (tangent_rotations * tangent_scales[..., None, :]).transpose(-1, -2)
499
+ axis_points = means[:, :, None, :] + axis_offsets
500
+ u0, v0, valid0, _ = self._project_points_px(
501
+ means,
502
+ projection_model=projection_model,
503
+ image_h=image_h,
504
+ image_w=image_w,
505
+ intrinsics=intrinsics,
506
+ camera_params=camera_params,
507
+ require_in_bounds=False,
508
+ )
509
+ b, n, axis_count, _ = axis_points.shape
510
+ u1, v1, valid1, _ = self._project_points_px(
511
+ axis_points.reshape(b, n * axis_count, 3),
512
+ projection_model=projection_model,
513
+ image_h=image_h,
514
+ image_w=image_w,
515
+ intrinsics=intrinsics,
516
+ camera_params=camera_params,
517
+ require_in_bounds=False,
518
+ )
519
+ u1 = u1.reshape(b, n, axis_count)
520
+ v1 = v1.reshape(b, n, axis_count)
521
+ valid1 = valid1.reshape(b, n, axis_count)
522
+ du = u1 - u0[..., None]
523
+ dv = v1 - v0[..., None]
524
+ if (projection_model or "pinhole").lower() in _ERP_PROJECTION_MODELS:
525
+ width = float(max(image_w, 1))
526
+ du = torch.remainder(du + 0.5 * width, width) - 0.5 * width
527
+ cov_xx = (du * du).sum(dim=-1)
528
+ cov_xy = (du * dv).sum(dim=-1)
529
+ cov_yy = (dv * dv).sum(dim=-1)
530
+ trace = cov_xx + cov_yy
531
+ disc = (cov_xx - cov_yy).square() + 4.0 * cov_xy.square()
532
+ sigma_px = 0.5 * (trace + (disc.clamp(min=0.0) + 1e-12).sqrt())
533
+ valid = valid & valid0 & valid1.all(dim=-1) & torch.isfinite(sigma_px)
534
+ sigma_px = torch.nan_to_num(sigma_px, nan=0.0, posinf=0.0, neginf=0.0)
535
+ return torch.where(valid, sigma_px, torch.zeros_like(sigma_px))
536
+
537
+ sigma_screen_3d = scales[..., :2].to(dtype=torch.float32).abs().amax(dim=-1).clamp(min=1e-8)
538
+ if model in {"fisheye624", "opencv_fisheye"} and torch.is_tensor(camera_params):
539
+ params = camera_params.to(device=means.device, dtype=torch.float32)
540
+ if params.ndim == 1:
541
+ params = params.unsqueeze(0)
542
+ if int(params.shape[0]) == 1 and int(means.shape[0]) > 1:
543
+ params = params.expand(int(means.shape[0]), -1)
544
+ if int(params.shape[-1]) == 15:
545
+ focal = params[:, 0:1].clamp(min=1.0)
546
+ else:
547
+ focal = 0.5 * (params[:, 0:1] + params[:, 1:2]).clamp(min=1.0)
548
+ radius = torch.norm(means, dim=-1).clamp(min=1e-4)
549
+ sigma_px = (sigma_screen_3d / radius * focal).square()
550
+ elif torch.is_tensor(intrinsics):
551
+ k = intrinsics.to(device=means.device, dtype=torch.float32)
552
+ if k.ndim == 2:
553
+ k = k.unsqueeze(0)
554
+ if int(k.shape[0]) == 1 and int(means.shape[0]) > 1:
555
+ k = k.expand(int(means.shape[0]), -1, -1)
556
+ focal = 0.5 * (k[:, 0, 0:1] + k[:, 1, 1:2]).clamp(min=1.0)
557
+ depth = means[..., 2].clamp(min=1e-4)
558
+ sigma_px = (sigma_screen_3d / depth * focal).square()
559
+ else:
560
+ depth = torch.norm(means, dim=-1).clamp(min=1e-4)
561
+ sigma_px = sigma_screen_3d / depth
562
+ if torch.is_tensor(projected_scale_factor):
563
+ sigma_px = sigma_px * projected_scale_factor.to(device=sigma_px.device, dtype=sigma_px.dtype)
564
+ elif projected_scale_factor is not None:
565
+ sigma_px = sigma_px * float(projected_scale_factor)
566
+ sigma_px = sigma_px.square()
567
+ valid = valid & torch.isfinite(sigma_px)
568
+ sigma_px = torch.nan_to_num(sigma_px, nan=0.0, posinf=0.0, neginf=0.0)
569
+ return torch.where(valid, sigma_px, torch.zeros_like(sigma_px))
570
+
571
+ def _depth_edge_band(
572
+ self,
573
+ depth_m: torch.Tensor,
574
+ valid_weight: torch.Tensor,
575
+ *,
576
+ circular_h: bool,
577
+ ) -> torch.Tensor:
578
+ depth = depth_m.to(dtype=torch.float32)
579
+ if depth.ndim == 3:
580
+ depth = depth.unsqueeze(1)
581
+ valid = torch.isfinite(depth) & (depth > 0.0) & (valid_weight[:, :1].to(dtype=torch.float32) > 0.5)
582
+ log_depth = torch.where(valid, depth.clamp(min=1e-4).log(), torch.zeros_like(depth))
583
+
584
+ if bool(circular_h):
585
+ right = torch.roll(log_depth, shifts=-1, dims=-1)
586
+ valid_right = valid & torch.roll(valid, shifts=-1, dims=-1)
587
+ edge_x = (right - log_depth).abs() > float(self.depth_edge_log_threshold)
588
+ edge_x = edge_x & valid_right
589
+ else:
590
+ edge_x = torch.zeros_like(valid)
591
+ edge_x[..., :, :-1] = (
592
+ (log_depth[..., :, 1:] - log_depth[..., :, :-1]).abs() > float(self.depth_edge_log_threshold)
593
+ ) & valid[..., :, 1:] & valid[..., :, :-1]
594
+
595
+ edge_y = torch.zeros_like(valid)
596
+ edge_y[..., :-1, :] = (
597
+ (log_depth[..., 1:, :] - log_depth[..., :-1, :]).abs() > float(self.depth_edge_log_threshold)
598
+ ) & valid[..., 1:, :] & valid[..., :-1, :]
599
+ edge = (edge_x | edge_y).to(dtype=torch.float32)
600
+
601
+ radius = max(int(self.depth_edge_dilate_px), 0)
602
+ if radius <= 0:
603
+ return edge
604
+ kernel = 2 * radius + 1
605
+ if bool(circular_h):
606
+ edge = F.pad(edge, (radius, radius, 0, 0), mode="circular")
607
+ edge = F.pad(edge, (0, 0, radius, radius), mode="constant", value=0.0)
608
+ return F.max_pool2d(edge, kernel_size=kernel, stride=1)
609
+ return F.max_pool2d(edge, kernel_size=kernel, stride=1, padding=radius)
610
+
611
+ def _ray_cell_sigma(
612
+ self,
613
+ *,
614
+ gaussian_scales: torch.Tensor,
615
+ gaussian_mean_vectors: torch.Tensor,
616
+ gaussian_angular_cell: torch.Tensor,
617
+ gauss_grid_shape: tuple[int, int, int] | None,
618
+ ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
619
+ scales = self._flatten_gaussian_xyz(gaussian_scales, gauss_grid_shape)
620
+ means = self._flatten_gaussian_xyz(gaussian_mean_vectors, gauss_grid_shape)
621
+ if scales is None or means is None:
622
+ return None, None
623
+ if not torch.is_tensor(gaussian_angular_cell):
624
+ return None, None
625
+ cell = gaussian_angular_cell.to(device=scales.device, dtype=torch.float32)
626
+ if cell.ndim != 5 or int(cell.shape[1]) != 2:
627
+ return None, None
628
+ if gauss_grid_shape is None:
629
+ return None, None
630
+ l, h, w = (int(gauss_grid_shape[0]), int(gauss_grid_shape[1]), int(gauss_grid_shape[2]))
631
+ if tuple(cell.shape[-2:]) != (h, w):
632
+ return None, None
633
+ if int(cell.shape[2]) == 1 and l > 1:
634
+ cell = cell.expand(-1, -1, l, -1, -1)
635
+ elif int(cell.shape[2]) != l:
636
+ return None, None
637
+ cell_flat = cell.permute(0, 2, 3, 4, 1).flatten(1, 3)
638
+ if int(cell_flat.shape[1]) != int(scales.shape[1]):
639
+ return None, None
640
+ radius = torch.linalg.norm(means.to(dtype=torch.float32), dim=-1, keepdim=True).clamp(min=1e-4)
641
+ tangent = scales[..., :2].to(dtype=torch.float32).abs()
642
+ sigma_cells = (tangent / radius / cell_flat.clamp(min=1e-6)).square()
643
+ valid = torch.isfinite(sigma_cells).all(dim=-1) & torch.isfinite(radius.squeeze(-1))
644
+ sigma_cells = torch.nan_to_num(sigma_cells, nan=0.0, posinf=0.0, neginf=0.0)
645
+ return sigma_cells, valid
646
+
647
+ def _dynamic_splat_sigma_limits(
648
+ self,
649
+ *,
650
+ sigma_proj: torch.Tensor,
651
+ projection_model: str | None,
652
+ image_h: int,
653
+ image_w: int,
654
+ intrinsics: torch.Tensor | None = None,
655
+ camera_params: torch.Tensor | None = None,
656
+ projected_scale_factor: float | torch.Tensor | None = None,
657
+ ) -> tuple[torch.Tensor, torch.Tensor]:
658
+ del projection_model, image_h, image_w, intrinsics, camera_params, projected_scale_factor
659
+ return (
660
+ torch.as_tensor(self.splat_sigma_min, device=sigma_proj.device, dtype=sigma_proj.dtype),
661
+ torch.as_tensor(self.splat_sigma_max, device=sigma_proj.device, dtype=sigma_proj.dtype),
662
+ )
663
+
664
+ def _sanitize_supervision_depth(self, depth_m: torch.Tensor, *, clamp_max: bool = True) -> torch.Tensor:
665
+ depth = depth_m.to(torch.float32)
666
+ valid = torch.isfinite(depth) & (depth > 0.0)
667
+ depth = torch.where(valid, depth, torch.zeros_like(depth))
668
+ if bool(valid.any().item()):
669
+ depth = depth.clone()
670
+ if bool(clamp_max):
671
+ depth[valid] = depth[valid].clamp(min=1e-4, max=float(self.SUPERVISION_MAX_DEPTH_M))
672
+ else:
673
+ depth[valid] = depth[valid].clamp(min=1e-4)
674
+ return depth
675
+
676
+ def _sobel_gradient_loss_erp(
677
+ self,
678
+ pred_depth_m: torch.Tensor,
679
+ gt_depth_m: torch.Tensor,
680
+ depth_weight: torch.Tensor,
681
+ circular_h: bool | None = None,
682
+ ) -> torch.Tensor:
683
+ dtype = pred_depth_m.dtype
684
+ device = pred_depth_m.device
685
+
686
+ kx = self._sobel_kx.to(dtype=dtype, device=device) # type: ignore[attr-defined]
687
+ ky = self._sobel_ky.to(dtype=dtype, device=device) # type: ignore[attr-defined]
688
+
689
+ log_pred = torch.log(pred_depth_m.clamp(min=1e-4))
690
+ log_gt = torch.log(gt_depth_m.clamp(min=1e-4))
691
+ log_diff = log_pred - log_gt
692
+
693
+ mask = depth_weight.to(dtype=dtype).clamp(min=0.0, max=1.0)
694
+ valid_mask = (mask > 0.5).to(dtype=dtype)
695
+ log_diff = torch.where(valid_mask > 0.5, log_diff, torch.zeros_like(log_diff))
696
+
697
+ total = torch.zeros((), device=device, dtype=dtype)
698
+ n_computed = 0
699
+
700
+ use_circular_h = self.grad_img_circular_h if circular_h is None else bool(circular_h)
701
+ ones_kernel = torch.ones((1, 1, 3, 3), device=device, dtype=dtype)
702
+ for _s in range(self.grad_img_scales):
703
+ if min(log_diff.shape[-2:]) < 4:
704
+ break
705
+
706
+ if use_circular_h:
707
+ padded = F.pad(log_diff, (1, 1, 0, 0), mode="circular")
708
+ padded = F.pad(padded, (0, 0, 1, 1), mode="reflect")
709
+ padded_mask = F.pad(valid_mask, (1, 1, 0, 0), mode="circular")
710
+ padded_mask = F.pad(padded_mask, (0, 0, 1, 1), mode="replicate")
711
+ else:
712
+ padded = F.pad(log_diff, (1, 1, 1, 1), mode="reflect")
713
+ padded_mask = F.pad(valid_mask, (1, 1, 1, 1), mode="replicate")
714
+
715
+ gx = F.conv2d(padded, kx)
716
+ gy = F.conv2d(padded, ky)
717
+ grad_mag = torch.sqrt(gx * gx + gy * gy + 1e-8)
718
+
719
+ stencil_valid = (F.conv2d(padded_mask, ones_kernel) >= 8.999).to(dtype=dtype)
720
+ n_valid = stencil_valid.sum().clamp(min=1.0)
721
+ total = total + (grad_mag * stencil_valid).sum() / n_valid
722
+ n_computed += 1
723
+
724
+ if _s < self.grad_img_scales - 1:
725
+ if use_circular_h:
726
+ pooled_mask = _avg_pool2d_circular_h(valid_mask, kernel_size=2, stride=2)
727
+ pooled_diff = _avg_pool2d_circular_h(log_diff * valid_mask, kernel_size=2, stride=2)
728
+ else:
729
+ pooled_mask = F.avg_pool2d(valid_mask, kernel_size=2, stride=2)
730
+ pooled_diff = F.avg_pool2d(log_diff * valid_mask, kernel_size=2, stride=2)
731
+ log_diff = pooled_diff / pooled_mask.clamp(min=1e-6)
732
+ valid_mask = (pooled_mask > 0.999).to(dtype=dtype)
733
+ log_diff = torch.where(valid_mask > 0.5, log_diff, torch.zeros_like(log_diff))
734
+
735
+ if n_computed == 0:
736
+ return torch.zeros((), device=device, dtype=dtype)
737
+ return total / float(n_computed)
738
+
739
+ def _sobel_xy_rgb(self, img: torch.Tensor, *, circular_h: bool) -> tuple[torch.Tensor, torch.Tensor]:
740
+ channels = int(img.shape[1])
741
+ kx = self._sobel_kx.to(dtype=img.dtype, device=img.device).expand(channels, 1, 3, 3) # type: ignore[attr-defined]
742
+ ky = self._sobel_ky.to(dtype=img.dtype, device=img.device).expand(channels, 1, 3, 3) # type: ignore[attr-defined]
743
+ if bool(circular_h):
744
+ padded = F.pad(img, (1, 1, 0, 0), mode="circular")
745
+ padded = F.pad(padded, (0, 0, 1, 1), mode="reflect")
746
+ else:
747
+ padded = F.pad(img, (1, 1, 1, 1), mode="reflect")
748
+ return (
749
+ F.conv2d(padded, kx, groups=channels),
750
+ F.conv2d(padded, ky, groups=channels),
751
+ )
752
+
753
+ def _edge_rgb_gradient_loss(
754
+ self,
755
+ pred_rgb_linear: torch.Tensor,
756
+ gt_rgb_linear: torch.Tensor,
757
+ valid_weight: torch.Tensor,
758
+ depth_edge_band: torch.Tensor | None,
759
+ *,
760
+ circular_h: bool,
761
+ ) -> torch.Tensor:
762
+ dtype = pred_rgb_linear.dtype
763
+ device = pred_rgb_linear.device
764
+ pred = pred_rgb_linear.to(dtype=torch.float32)
765
+ gt = gt_rgb_linear.to(device=device, dtype=torch.float32)
766
+ weight = valid_weight.to(device=device, dtype=torch.float32).clamp(0.0, 1.0)[:, :1]
767
+
768
+ pred_gx, pred_gy = self._sobel_xy_rgb(pred, circular_h=circular_h)
769
+ gt_gx, gt_gy = self._sobel_xy_rgb(gt, circular_h=circular_h)
770
+ gt_mag = torch.sqrt(gt_gx.square() + gt_gy.square() + 1e-8).mean(dim=1, keepdim=True)
771
+
772
+ flat = gt_mag.detach().flatten(2)
773
+ mean = flat.mean(dim=-1, keepdim=True)[..., None]
774
+ std = flat.std(dim=-1, keepdim=True, unbiased=False)[..., None]
775
+ rgb_edge = (gt_mag.detach() > (mean + 0.5 * std).clamp(min=0.02)).to(dtype=torch.float32)
776
+
777
+ if torch.is_tensor(depth_edge_band):
778
+ edge_boost = depth_edge_band.to(device=device, dtype=torch.float32).clamp(0.0, 1.0)
779
+ if tuple(edge_boost.shape[-2:]) != tuple(gt_mag.shape[-2:]):
780
+ edge_boost = F.interpolate(edge_boost, size=gt_mag.shape[-2:], mode="nearest")
781
+ edge_weight = rgb_edge * (1.0 + edge_boost[:, :1])
782
+ else:
783
+ edge_weight = rgb_edge
784
+
785
+ ones_kernel = torch.ones((1, 1, 3, 3), device=device, dtype=torch.float32)
786
+ if bool(circular_h):
787
+ padded_weight = F.pad(weight, (1, 1, 0, 0), mode="circular")
788
+ padded_weight = F.pad(padded_weight, (0, 0, 1, 1), mode="replicate")
789
+ else:
790
+ padded_weight = F.pad(weight, (1, 1, 1, 1), mode="replicate")
791
+ stencil_valid = (F.conv2d(padded_weight, ones_kernel) >= 8.999).to(dtype=torch.float32)
792
+
793
+ diff = (pred_gx - gt_gx).abs() + (pred_gy - gt_gy).abs()
794
+ diff = diff.mean(dim=1, keepdim=True)
795
+ final_weight = edge_weight * stencil_valid
796
+ return (diff * final_weight).sum().to(dtype=dtype) / final_weight.sum().clamp(min=1.0).to(dtype=dtype)
797
+
798
+ def forward(
799
+ self,
800
+ pred_rgb_linear: torch.Tensor,
801
+ pred_alpha: torch.Tensor,
802
+ pred_depth_m: torch.Tensor,
803
+ gt_rgb_u8: torch.Tensor,
804
+ gt_depth_m: torch.Tensor,
805
+ pred_depth2_m: torch.Tensor | None = None,
806
+ mask: torch.Tensor | None = None,
807
+ depth_mask: torch.Tensor | None = None,
808
+ delta_xy: torch.Tensor | None = None,
809
+ delta_rho: torch.Tensor | None = None,
810
+ delta_grid: torch.Tensor | None = None,
811
+ gaussian_scales: torch.Tensor | None = None,
812
+ gaussian_quaternions: torch.Tensor | None = None,
813
+ gaussian_mean_vectors: torch.Tensor | None = None,
814
+ gaussian_base_mean_vectors: torch.Tensor | None = None,
815
+ gaussian_angular_cell: torch.Tensor | None = None,
816
+ gaussian_opacities: torch.Tensor | None = None,
817
+ gauss_grid_shape: tuple[int, int, int] | None = None,
818
+ projected_scale_factor: float | torch.Tensor | None = None,
819
+ projection_model: str | None = None,
820
+ projection_intrinsics: torch.Tensor | None = None,
821
+ projection_camera_params: torch.Tensor | None = None,
822
+ apply_color: bool = True,
823
+ apply_alpha: bool = True,
824
+ apply_depth: bool = True,
825
+ apply_percep: bool = False,
826
+ apply_tv: bool = True,
827
+ apply_grad: bool = True,
828
+ apply_delta: bool = True,
829
+ apply_splat: bool = True,
830
+ apply_grad_img: bool = True,
831
+ grad_img_circular_h: bool | None = None,
832
+ ) -> dict[str, torch.Tensor]:
833
+ losses: dict[str, torch.Tensor] = {}
834
+ circular_h = bool(grad_img_circular_h) if grad_img_circular_h is not None else False
835
+
836
+ gt_rgb = gt_rgb_u8.to(pred_rgb_linear.device).float() / 255.0
837
+ gt_rgb_linear = _to_linear_rgb(gt_rgb)
838
+ pred_depth_m = self._sanitize_supervision_depth(pred_depth_m.to(pred_rgb_linear.device), clamp_max=False)
839
+ if pred_depth2_m is not None:
840
+ pred_depth2_m = self._sanitize_supervision_depth(pred_depth2_m.to(pred_rgb_linear.device), clamp_max=False)
841
+ gt_depth_raw = self._sanitize_supervision_depth(gt_depth_m.to(pred_rgb_linear.device))
842
+ depth_valid = torch.isfinite(gt_depth_raw) & (gt_depth_raw > 0.0)
843
+ gt_depth = gt_depth_raw.clamp(min=1e-4)
844
+
845
+ if mask is None:
846
+ m = torch.ones_like(pred_alpha)
847
+ else:
848
+ m = mask.to(pred_rgb_linear.device).to(pred_rgb_linear.dtype)
849
+ depth_weight = depth_valid.to(dtype=pred_depth_m.dtype) * m[:, :1].to(dtype=pred_depth_m.dtype)
850
+ if depth_mask is not None:
851
+ depth_weight = depth_weight * depth_mask.to(pred_rgb_linear.device).to(dtype=pred_depth_m.dtype)[:, :1]
852
+ pred_rgb_rendered = pred_rgb_linear.clamp(0.0, 1.0)
853
+
854
+ if apply_color and self.w.lambda_color > 0:
855
+ color_l1 = (pred_rgb_rendered - gt_rgb_linear).abs()
856
+ losses["color"] = _masked_mean(color_l1, m)
857
+ else:
858
+ losses["color"] = torch.zeros((), device=pred_rgb_linear.device)
859
+
860
+ if apply_alpha and self.w.lambda_alpha > 0:
861
+ a = pred_alpha.clamp(1e-6, 1.0 - 1e-6)
862
+ with torch.autocast(device_type=a.device.type, enabled=False):
863
+ alpha_bce = F.binary_cross_entropy(
864
+ a.to(dtype=torch.float32),
865
+ torch.ones_like(a, dtype=torch.float32),
866
+ reduction="none",
867
+ )
868
+ alpha_loss = _masked_mean(alpha_bce, m)
869
+ alpha_tail_min = torch.as_tensor(
870
+ self.alpha_tail_min,
871
+ device=a.device,
872
+ dtype=torch.float32,
873
+ ).clamp(min=0.0, max=1.0)
874
+ alpha_tail_weight = torch.as_tensor(
875
+ max(0.0, self.alpha_tail_weight),
876
+ device=a.device,
877
+ dtype=torch.float32,
878
+ )
879
+ if self.alpha_tail_min > 0.0 and self.alpha_tail_weight > 0.0:
880
+ tail = F.relu(alpha_tail_min - a.to(dtype=torch.float32))
881
+ tail = tail / alpha_tail_min.clamp(min=1e-6)
882
+ tail_mask = (m[:, :1].to(dtype=torch.bool)) & (tail > 0.0)
883
+ alpha_loss = alpha_loss + alpha_tail_weight * _finite_masked_mean_flat(tail, tail_mask)
884
+ losses["alpha"] = alpha_loss.to(dtype=pred_rgb_linear.dtype)
885
+ else:
886
+ losses["alpha"] = torch.zeros((), device=pred_rgb_linear.device)
887
+
888
+ if apply_depth and self.w.lambda_depth > 0:
889
+ w_depth = depth_weight
890
+ inv_pred1 = 1.0 / pred_depth_m.clamp(min=1e-4)
891
+ inv_gt = torch.zeros_like(inv_pred1)
892
+ inv_gt[depth_valid] = 1.0 / gt_depth[depth_valid]
893
+ depth_abs = (inv_pred1 - inv_gt).abs()
894
+ losses["depth"] = _masked_mean(depth_abs, w_depth)
895
+ else:
896
+ losses["depth"] = torch.zeros((), device=pred_rgb_linear.device)
897
+
898
+ if apply_tv and self.w.lambda_tv > 0 and (pred_depth2_m is not None):
899
+ inv2 = 1.0 / pred_depth2_m.clamp(min=1e-4)
900
+ losses["tv"] = _tv_l1_circular_h(inv2) if circular_h else _tv_l1(inv2)
901
+ else:
902
+ losses["tv"] = torch.zeros((), device=pred_rgb_linear.device)
903
+
904
+ image_h, image_w = int(pred_depth_m.shape[-2]), int(pred_depth_m.shape[-1])
905
+ projection_points = self._flatten_gaussian_xyz(gaussian_mean_vectors, gauss_grid_shape)
906
+ projected_u = projected_v = None
907
+ projected_valid = None
908
+ if projection_points is not None:
909
+ projected_u, projected_v, projected_valid, _projected_depth = self._project_points_px(
910
+ projection_points,
911
+ projection_model=projection_model,
912
+ image_h=image_h,
913
+ image_w=image_w,
914
+ intrinsics=projection_intrinsics,
915
+ camera_params=projection_camera_params,
916
+ )
917
+
918
+ if apply_grad and self.w.lambda_grad > 0:
919
+ inv1 = 1.0 / pred_depth_m.clamp(min=1e-4)
920
+ op_flat = self._flatten_gaussian_scalar(gaussian_opacities, gauss_grid_shape)
921
+ if projected_u is not None and projected_v is not None and projected_valid is not None and op_flat is not None:
922
+ grad_map = self._central_disparity_gradient(inv1, circular_h=circular_h)
923
+ grad_at_gauss = self._sample_map_at_uv(grad_map, projected_u, projected_v, projected_valid)
924
+ penalty = 1.0 - torch.exp(
925
+ -(1.0 / max(self.grad_sigma, 1e-8)) * F.relu(grad_at_gauss - self.grad_eps)
926
+ )
927
+ weight = projected_valid & torch.isfinite(grad_at_gauss) & torch.isfinite(op_flat)
928
+ mask_at_gauss = self._sample_map_at_uv(m[:, :1], projected_u, projected_v, projected_valid)
929
+ weight = weight & (mask_at_gauss > 0.5)
930
+ grad_value = op_flat.to(dtype=penalty.dtype).clamp(0, 1) * penalty
931
+ losses["grad"] = _finite_masked_mean_flat(grad_value, weight)
932
+ else:
933
+ raise RuntimeError(
934
+ "L_grad requires gaussian_mean_vectors, gaussian_opacities, "
935
+ "gauss_grid_shape, and projection metadata. The old "
936
+ "pred_alpha image-space fallback is disabled for ray-local training."
937
+ )
938
+ else:
939
+ losses["grad"] = torch.zeros((), device=pred_rgb_linear.device)
940
+
941
+ if apply_grad_img and self.w.lambda_grad_img > 0:
942
+ losses["grad_img"] = self._sobel_gradient_loss_erp(
943
+ pred_depth_m=pred_depth_m,
944
+ gt_depth_m=gt_depth,
945
+ depth_weight=depth_weight,
946
+ circular_h=grad_img_circular_h,
947
+ )
948
+ else:
949
+ losses["grad_img"] = torch.zeros((), device=pred_rgb_linear.device)
950
+
951
+ if apply_color and self.w.lambda_edge_rgb > 0:
952
+ depth_edge_for_rgb = self._depth_edge_band(gt_depth, depth_weight, circular_h=circular_h)
953
+ losses["edge_rgb"] = self._edge_rgb_gradient_loss(
954
+ pred_rgb_linear=pred_rgb_rendered,
955
+ gt_rgb_linear=gt_rgb_linear,
956
+ valid_weight=m,
957
+ depth_edge_band=depth_edge_for_rgb,
958
+ circular_h=circular_h,
959
+ )
960
+ else:
961
+ losses["edge_rgb"] = torch.zeros((), device=pred_rgb_linear.device)
962
+
963
+ if apply_delta and self.w.lambda_delta > 0:
964
+ if delta_xy is not None:
965
+ dx = F.relu(delta_xy[:, 0:1].abs() - self.raw_delta_clip)
966
+ dy = F.relu(delta_xy[:, 1:2].abs() - self.raw_delta_clip)
967
+ losses["delta"] = (dx + dy).mean()
968
+ else:
969
+ del gaussian_base_mean_vectors
970
+ raise RuntimeError(
971
+ "L_delta requires raw delta_xy in ray-local training. The old "
972
+ "screen-space pixel displacement fallback is disabled."
973
+ )
974
+ else:
975
+ losses["delta"] = torch.zeros((), device=pred_rgb_linear.device)
976
+
977
+ if apply_delta and self.w.lambda_delta_rho > 0 and delta_rho is not None:
978
+ dz = delta_rho.to(device=pred_rgb_linear.device, dtype=pred_rgb_linear.dtype)
979
+ finite = torch.isfinite(dz)
980
+ dz_safe = torch.nan_to_num(dz, nan=0.0, posinf=0.0, neginf=0.0)
981
+ penalty = F.relu(dz_safe.abs() - self.raw_delta_rho_clip)
982
+ penalty = torch.where(finite, penalty, torch.zeros_like(penalty))
983
+ losses["delta_rho"] = penalty.sum() / finite.to(dtype=penalty.dtype).sum().clamp(min=1.0)
984
+ else:
985
+ losses["delta_rho"] = torch.zeros((), device=pred_rgb_linear.device)
986
+
987
+ if self.w.lambda_grid > 0 and torch.is_tensor(delta_grid):
988
+ losses["grid"] = _delta_grid_checkerboard_loss(
989
+ delta_grid.to(device=pred_rgb_linear.device),
990
+ circular_h=circular_h,
991
+ ).to(dtype=pred_rgb_linear.dtype)
992
+ else:
993
+ losses["grid"] = torch.zeros((), device=pred_rgb_linear.device)
994
+
995
+ if apply_splat and self.w.lambda_splat > 0:
996
+ if gaussian_scales is None:
997
+ raise RuntimeError("L_splat requires gaussian_scales for projected screen-space variance.")
998
+ if gaussian_mean_vectors is None or projected_valid is None:
999
+ raise RuntimeError(
1000
+ "L_splat requires gaussian_mean_vectors and projection metadata "
1001
+ "to compute projected screen-space variance."
1002
+ )
1003
+ sigma_proj = self._projected_sigma_px(
1004
+ gaussian_scales=gaussian_scales,
1005
+ gaussian_quaternions=gaussian_quaternions,
1006
+ gaussian_mean_vectors=gaussian_mean_vectors,
1007
+ valid=projected_valid,
1008
+ projection_model=projection_model,
1009
+ image_h=image_h,
1010
+ image_w=image_w,
1011
+ intrinsics=projection_intrinsics,
1012
+ camera_params=projection_camera_params,
1013
+ projected_scale_factor=projected_scale_factor,
1014
+ )
1015
+ valid_splat = projected_valid & torch.isfinite(sigma_proj)
1016
+ splat_sigma_min = torch.as_tensor(
1017
+ self.splat_sigma_min,
1018
+ device=sigma_proj.device,
1019
+ dtype=sigma_proj.dtype,
1020
+ )
1021
+ splat_sigma_max = torch.as_tensor(
1022
+ self.splat_sigma_max,
1023
+ device=sigma_proj.device,
1024
+ dtype=sigma_proj.dtype,
1025
+ )
1026
+ lower_penalty = F.relu(splat_sigma_min - sigma_proj)
1027
+ upper_penalty = F.relu(sigma_proj - splat_sigma_max)
1028
+ splat_penalty = lower_penalty + upper_penalty
1029
+ losses["splat"] = _finite_masked_mean_flat(splat_penalty, valid_splat)
1030
+ else:
1031
+ sigma_proj = None
1032
+ valid_splat = None
1033
+ losses["splat"] = torch.zeros((), device=pred_rgb_linear.device)
1034
+
1035
+ if apply_splat and self.w.lambda_edge_splat > 0:
1036
+ if gaussian_scales is None:
1037
+ raise RuntimeError("L_edge_splat requires gaussian_scales for projected screen-space variance.")
1038
+ if gaussian_mean_vectors is None or projected_valid is None:
1039
+ raise RuntimeError(
1040
+ "L_edge_splat requires gaussian_mean_vectors and projection metadata "
1041
+ "to sample source depth-edge bands."
1042
+ )
1043
+ if sigma_proj is None or valid_splat is None:
1044
+ sigma_proj = self._projected_sigma_px(
1045
+ gaussian_scales=gaussian_scales,
1046
+ gaussian_quaternions=gaussian_quaternions,
1047
+ gaussian_mean_vectors=gaussian_mean_vectors,
1048
+ valid=projected_valid,
1049
+ projection_model=projection_model,
1050
+ image_h=image_h,
1051
+ image_w=image_w,
1052
+ intrinsics=projection_intrinsics,
1053
+ camera_params=projection_camera_params,
1054
+ projected_scale_factor=projected_scale_factor,
1055
+ )
1056
+ valid_splat = projected_valid & torch.isfinite(sigma_proj)
1057
+ edge_band = self._depth_edge_band(gt_depth, depth_weight, circular_h=circular_h)
1058
+ edge_at_gauss = self._sample_map_at_uv(edge_band, projected_u, projected_v, projected_valid)
1059
+ edge_valid = valid_splat & torch.isfinite(edge_at_gauss) & (edge_at_gauss > 0.5)
1060
+ edge_sigma_max = torch.as_tensor(
1061
+ self.edge_splat_sigma_max,
1062
+ device=sigma_proj.device,
1063
+ dtype=sigma_proj.dtype,
1064
+ )
1065
+ losses["edge_splat"] = _finite_masked_mean_flat(F.relu(sigma_proj - edge_sigma_max), edge_valid)
1066
+ else:
1067
+ losses["edge_splat"] = torch.zeros((), device=pred_rgb_linear.device)
1068
+
1069
+ zero = torch.zeros((), device=pred_rgb_linear.device)
1070
+ losses["percep_feat"] = zero
1071
+ losses["percep_gram"] = zero
1072
+ if apply_percep and self.w.lambda_percep > 0 and (self._percep_net is not None):
1073
+ from unisharp.utils.color_space import linearRGB2sRGB
1074
+
1075
+ pred_srgb = linearRGB2sRGB(pred_rgb_rendered.to(torch.float32)).clamp(0, 1)
1076
+ gt_srgb = gt_rgb.clamp(0, 1)
1077
+
1078
+ pred_srgb = _resize_max_side(pred_srgb, self.percep_max_side, mode="bilinear")
1079
+ gt_srgb = _resize_max_side(gt_srgb, self.percep_max_side, mode="bilinear")
1080
+
1081
+ feats_p = self._percep_net(pred_srgb)
1082
+ feats_g = self._percep_net(gt_srgb)
1083
+ loss_feat_total = torch.zeros((), device=pred_rgb_linear.device)
1084
+ loss_gram_total = torch.zeros((), device=pred_rgb_linear.device)
1085
+ for fp, fg in zip(feats_p, feats_g):
1086
+ d, h, w = fp.shape[1], fp.shape[2], fp.shape[3]
1087
+ lam_gram = 10.0 / float(max(1, d * d))
1088
+ lam_feat = 1.0 / float(max(1, d * h * w))
1089
+ diff = (fp - fg).pow(2)
1090
+ loss_feat = (diff.sum(dim=[1, 2, 3]) * lam_feat).mean()
1091
+ gram_norm = float(max(1, h * w))
1092
+ gp = _gram_matrix(fp) / gram_norm
1093
+ gg = _gram_matrix(fg) / gram_norm
1094
+ loss_gram = ((gp - gg).pow(2).sum(dim=[1, 2]) * lam_gram).mean()
1095
+ loss_feat_total = loss_feat_total + loss_feat
1096
+ loss_gram_total = loss_gram_total + loss_gram
1097
+ layer_count = float(max(1, len(feats_p)))
1098
+ losses["percep_feat"] = loss_feat_total / layer_count
1099
+ losses["percep_gram"] = loss_gram_total / layer_count
1100
+ losses["percep"] = losses["percep_feat"] + losses["percep_gram"]
1101
+ else:
1102
+ losses["percep"] = torch.zeros((), device=pred_rgb_linear.device)
1103
+
1104
+ losses["total"] = (
1105
+ self.w.lambda_color * losses["color"]
1106
+ + self.w.lambda_alpha * losses["alpha"]
1107
+ + self.w.lambda_percep * losses["percep"]
1108
+ + self.w.lambda_depth * losses["depth"]
1109
+ + self.w.lambda_tv * losses["tv"]
1110
+ + self.w.lambda_grad * losses["grad"]
1111
+ + self.w.lambda_grad_img * losses["grad_img"]
1112
+ + self.w.lambda_edge_rgb * losses["edge_rgb"]
1113
+ + self.w.lambda_delta * losses["delta"]
1114
+ + self.w.lambda_delta_rho * losses["delta_rho"]
1115
+ + self.w.lambda_splat * losses["splat"]
1116
+ + self.w.lambda_edge_splat * losses["edge_splat"]
1117
+ + self.w.lambda_grid * losses["grid"]
1118
+ )
1119
+ return losses
1120
+
unisharp/models/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from .feature_gaussian_decoder import (
5
+ FeatureGaussianDecoder,
6
+ FeatureGaussianDecoderParams,
7
+ ImageFeatures,
8
+ create_feature_gaussian_decoder,
9
+ )
10
+ from .unisharp_params import PanoPredictorParams
11
+ from .unisharp_feature import UnisharpFeatureConfig, UnisharpFeatureModel
12
+ from .unik3d_feature_extractor import UniK3DFeatureExtractor
13
+
14
+ __all__ = [
15
+ "PanoPredictorParams",
16
+ "UniK3DFeatureExtractor",
17
+ "FeatureGaussianDecoder",
18
+ "FeatureGaussianDecoderParams",
19
+ "ImageFeatures",
20
+ "create_feature_gaussian_decoder",
21
+ "UnisharpFeatureConfig",
22
+ "UnisharpFeatureModel",
23
+ ]
unisharp/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (709 Bytes). View file
 
unisharp/models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (734 Bytes). View file
 
unisharp/models/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (6.21 kB). View file
 
unisharp/models/__pycache__/blocks.cpython-313.pyc ADDED
Binary file (9.24 kB). View file
 
unisharp/models/__pycache__/decoder.cpython-310.pyc ADDED
Binary file (3.11 kB). View file
 
unisharp/models/__pycache__/decoder.cpython-313.pyc ADDED
Binary file (5.09 kB). View file