Spaces:
Running on Zero
Running on Zero
Upload 119 files
Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- unisharp/.DS_Store +0 -0
- unisharp/__init__.py +1 -0
- unisharp/cli/__init__.py +13 -0
- unisharp/cli/__main__.py +12 -0
- unisharp/cli/__pycache__/__init__.cpython-310.pyc +0 -0
- unisharp/cli/__pycache__/__init__.cpython-313.pyc +0 -0
- unisharp/cli/__pycache__/mixed_sampler.cpython-313.pyc +0 -0
- unisharp/cli/__pycache__/train_feature.cpython-310.pyc +0 -0
- unisharp/cli/__pycache__/train_feature.cpython-313.pyc +0 -0
- unisharp/cli/__pycache__/train_utils.cpython-313.pyc +0 -0
- unisharp/cli/__pycache__/unified_trainer.cpython-313.pyc +3 -0
- unisharp/cli/mixed_sampler.py +80 -0
- unisharp/cli/train_feature.py +1410 -0
- unisharp/cli/train_utils.py +130 -0
- unisharp/cli/unified_trainer.py +1966 -0
- unisharp/datasets/__pycache__/dl3dv.cpython-310.pyc +0 -0
- unisharp/datasets/__pycache__/dl3dv.cpython-313.pyc +0 -0
- unisharp/datasets/__pycache__/pair_sampling.cpython-310.pyc +0 -0
- unisharp/datasets/__pycache__/pair_sampling.cpython-313.pyc +0 -0
- unisharp/datasets/__pycache__/panogs.cpython-310.pyc +0 -0
- unisharp/datasets/__pycache__/panogs.cpython-313.pyc +0 -0
- unisharp/datasets/__pycache__/re10k.cpython-310.pyc +0 -0
- unisharp/datasets/__pycache__/re10k.cpython-313.pyc +0 -0
- unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-310.pyc +0 -0
- unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-313.pyc +0 -0
- unisharp/datasets/__pycache__/sim_panorama.cpython-310.pyc +0 -0
- unisharp/datasets/__pycache__/sim_panorama.cpython-313.pyc +0 -0
- unisharp/datasets/__pycache__/wildrgbd.cpython-310.pyc +0 -0
- unisharp/datasets/__pycache__/wildrgbd.cpython-313.pyc +0 -0
- unisharp/datasets/dl3dv.py +305 -0
- unisharp/datasets/pair_sampling.py +99 -0
- unisharp/datasets/panogs.py +555 -0
- unisharp/datasets/re10k.py +718 -0
- unisharp/datasets/scannetpp_fisheye.py +491 -0
- unisharp/datasets/sim_panorama.py +497 -0
- unisharp/datasets/wildrgbd.py +352 -0
- unisharp/losses/__init__.py +4 -0
- unisharp/losses/__pycache__/__init__.cpython-310.pyc +0 -0
- unisharp/losses/__pycache__/__init__.cpython-313.pyc +0 -0
- unisharp/losses/__pycache__/unisharp_loss.cpython-310.pyc +0 -0
- unisharp/losses/__pycache__/unisharp_loss.cpython-313.pyc +0 -0
- unisharp/losses/unisharp_loss.py +1120 -0
- unisharp/models/__init__.py +23 -0
- unisharp/models/__pycache__/__init__.cpython-310.pyc +0 -0
- unisharp/models/__pycache__/__init__.cpython-313.pyc +0 -0
- unisharp/models/__pycache__/blocks.cpython-310.pyc +0 -0
- unisharp/models/__pycache__/blocks.cpython-313.pyc +0 -0
- unisharp/models/__pycache__/decoder.cpython-310.pyc +0 -0
- unisharp/models/__pycache__/decoder.cpython-313.pyc +0 -0
.gitattributes
CHANGED
|
@@ -3,3 +3,4 @@
|
|
| 3 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 4 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 5 |
examples/omnirooms/*.jpg filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 3 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 4 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 5 |
examples/omnirooms/*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
unisharp/cli/__pycache__/unified_trainer.cpython-313.pyc filter=lfs diff=lfs merge=lfs -text
|
unisharp/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
unisharp/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
DEFAULT_MAX_DEPTH_M: float = 100.0
|
unisharp/cli/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import click
|
| 4 |
+
|
| 5 |
+
from .train_feature import train_feature_cli
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@click.group()
def main_cli():
    # Intentionally empty: this group only hosts subcommands that are attached
    # with ``add_command``. A comment is used instead of a docstring so the
    # help text click generates for the group is left untouched.
    return None
|
| 11 |
+
|
| 12 |
+
main_cli.add_command(train_feature_cli, "train-feature")
|
| 13 |
+
|
unisharp/cli/__main__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from unisharp.cli import main_cli
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def main() -> None:
    """Invoke the ``main_cli`` command group imported from ``unisharp.cli``."""
    main_cli()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
if __name__ == "__main__":
|
| 11 |
+
main()
|
| 12 |
+
|
unisharp/cli/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (480 Bytes). View file
|
|
|
unisharp/cli/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (581 Bytes). View file
|
|
|
unisharp/cli/__pycache__/mixed_sampler.cpython-313.pyc
ADDED
|
Binary file (5.33 kB). View file
|
|
|
unisharp/cli/__pycache__/train_feature.cpython-310.pyc
ADDED
|
Binary file (42 kB). View file
|
|
|
unisharp/cli/__pycache__/train_feature.cpython-313.pyc
ADDED
|
Binary file (74 kB). View file
|
|
|
unisharp/cli/__pycache__/train_utils.cpython-313.pyc
ADDED
|
Binary file (7.03 kB). View file
|
|
|
unisharp/cli/__pycache__/unified_trainer.cpython-313.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2698667885fba54eef04bacbbee4bbf897c0dc3df57e6fe7d10ba185a76d2ed
|
| 3 |
+
size 103553
|
unisharp/cli/mixed_sampler.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import random
|
| 5 |
+
from typing import Any, Iterator
|
| 6 |
+
|
| 7 |
+
from torch.utils.data import Dataset, IterableDataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LazyDataLoaderIterator:
    """Iterator wrapper that defers ``iter(dataloader)`` until first use.

    The wrapped object's iterator is only created on the first ``next()``
    call, so constructing this wrapper never touches the dataloader.
    """

    def __init__(self, dataloader: Any):
        """Store ``dataloader`` without iterating it yet."""
        self.dataloader = dataloader
        # Populated lazily by ``__next__``.
        self.iterator: Iterator[Any] | None = None

    def __iter__(self) -> "LazyDataLoaderIterator":
        """Return ``self`` so the object satisfies the iterator protocol.

        Without this method ``iter(obj)``, ``for`` loops and ``list(obj)``
        raise ``TypeError`` even though ``__next__`` is defined.
        """
        return self

    def __next__(self) -> Any:
        """Return the next item, creating the underlying iterator on demand.

        Raises:
            StopIteration: propagated from the underlying iterator.
        """
        if self.iterator is None:
            self.iterator = iter(self.dataloader)
        return next(self.iterator)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MixedDatasetSampler:
|
| 23 |
+
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
datasets: dict[str, Dataset | IterableDataset],
|
| 27 |
+
weights: dict[str, float],
|
| 28 |
+
iterators: dict[str, Iterator[Any]],
|
| 29 |
+
seed: int | None = None,
|
| 30 |
+
):
|
| 31 |
+
self.datasets = datasets
|
| 32 |
+
self.weights = weights
|
| 33 |
+
self.iterators = iterators
|
| 34 |
+
self._rng = random.Random(seed)
|
| 35 |
+
|
| 36 |
+
if len(weights) == 0:
|
| 37 |
+
raise ValueError("weights is empty")
|
| 38 |
+
for name, w in weights.items():
|
| 39 |
+
if float(w) <= 0.0:
|
| 40 |
+
raise ValueError(f"Dataset weight must be > 0, got {name}={float(w)}")
|
| 41 |
+
if name not in datasets:
|
| 42 |
+
raise ValueError(f"Unknown dataset in weights: {name}")
|
| 43 |
+
if name not in iterators:
|
| 44 |
+
raise ValueError(f"Missing iterator for dataset: {name}")
|
| 45 |
+
|
| 46 |
+
total_weight = float(sum(float(v) for v in weights.values()))
|
| 47 |
+
self.probs = {name: float(w) / total_weight for name, w in weights.items()}
|
| 48 |
+
self.dataset_names = list(datasets.keys())
|
| 49 |
+
self.prob_list = [self.probs[name] for name in self.dataset_names]
|
| 50 |
+
|
| 51 |
+
def sample(self) -> tuple[str, Any]:
|
| 52 |
+
dataset_name = self.choose_dataset_name()
|
| 53 |
+
batch = self.next_batch(dataset_name)
|
| 54 |
+
return dataset_name, batch
|
| 55 |
+
|
| 56 |
+
def choose_dataset_name(self, allowed_dataset_names: list[str] | None = None) -> str:
|
| 57 |
+
if allowed_dataset_names is None:
|
| 58 |
+
names = self.dataset_names
|
| 59 |
+
probs = self.prob_list
|
| 60 |
+
else:
|
| 61 |
+
names = [name for name in self.dataset_names if name in set(allowed_dataset_names)]
|
| 62 |
+
if len(names) == 0:
|
| 63 |
+
raise ValueError("No allowed dataset names available for sampling.")
|
| 64 |
+
probs = [self.probs[name] for name in names]
|
| 65 |
+
return self._rng.choices(names, weights=probs, k=1)[0]
|
| 66 |
+
|
| 67 |
+
def next_batch(self, dataset_name: str) -> Any:
|
| 68 |
+
if dataset_name not in self.iterators:
|
| 69 |
+
raise ValueError(f"Unknown dataset iterator: {dataset_name}")
|
| 70 |
+
try:
|
| 71 |
+
batch = next(self.iterators[dataset_name])
|
| 72 |
+
except StopIteration as exc:
|
| 73 |
+
raise StopIteration(f"Dataset {dataset_name} exhausted") from exc
|
| 74 |
+
return batch
|
| 75 |
+
|
| 76 |
+
def get_sampling_stats(self) -> dict[str, float]:
|
| 77 |
+
return {
|
| 78 |
+
"probabilities": self.probs.copy(),
|
| 79 |
+
"sampling": self.weights.copy(),
|
| 80 |
+
}
|
unisharp/cli/train_feature.py
ADDED
|
@@ -0,0 +1,1410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import csv
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
from dataclasses import fields, is_dataclass, replace
|
| 11 |
+
from datetime import datetime, timedelta
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import Any
|
| 14 |
+
|
| 15 |
+
import click
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 21 |
+
from torch.utils.data import DataLoader
|
| 22 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 23 |
+
|
| 24 |
+
from unisharp.datasets.re10k import Re10KDataset, re10k_collate, re10k_passthrough
|
| 25 |
+
from unisharp.datasets.wildrgbd import WildRGBDDataset, wildrgbd_collate
|
| 26 |
+
from unisharp.datasets.dl3dv import DL3DVDataset
|
| 27 |
+
from unisharp.datasets.scannetpp_fisheye import ScannetppFisheyeDataset, scannetpp_fisheye_passthrough
|
| 28 |
+
from unisharp.datasets.sim_panorama import SimPanoramaDataset
|
| 29 |
+
from unisharp.datasets.panogs import PanOGSDataset, panogs_collate
|
| 30 |
+
from unisharp.losses import UnisharpLoss, UnisharpLossWeights
|
| 31 |
+
from unisharp.models.unisharp_feature import UnisharpFeatureModel, UnisharpFeatureConfig
|
| 32 |
+
from unisharp.utils import logging as logging_utils
|
| 33 |
+
from unisharp import DEFAULT_MAX_DEPTH_M
|
| 34 |
+
from unisharp.utils.gsplat import GSplatRenderer
|
| 35 |
+
from unisharp.utils.io import save_image
|
| 36 |
+
from unisharp.utils.rayfit_camera import scale_pinhole_intrinsics
|
| 37 |
+
from unisharp.utils.unified_vis import save_pair_visualization
|
| 38 |
+
|
| 39 |
+
from .mixed_sampler import LazyDataLoaderIterator, MixedDatasetSampler # type: ignore[import]
|
| 40 |
+
from .train_utils import warmup_cosine_lr # type: ignore[import]
|
| 41 |
+
|
| 42 |
+
LOGGER = logging.getLogger(__name__)
|
| 43 |
+
REPO_ROOT = Path(__file__).resolve().parents[2]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _default_dataset_manifest_file(name: str) -> Path:
    """Locate the manifest file ``name`` under a ``dataset_manifests`` folder.

    The folder next to the repository root (``REPO_ROOT.parent``) is used when
    the file exists there; otherwise the path inside ``REPO_ROOT`` is returned,
    whether or not it exists.
    """
    manifests_dirname = "dataset_manifests"
    outside_repo = REPO_ROOT.parent / manifests_dirname / name
    return outside_repo if outside_repo.exists() else REPO_ROOT / manifests_dirname / name
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
DEFAULT_WILDRGBD_ROOTS_FILE = _default_dataset_manifest_file("wildrgbd_roots.txt")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _multiple_aligned_hw(hw: tuple[int, int], multiple: int) -> tuple[int, int]:
|
| 57 |
+
h, w = int(hw[0]), int(hw[1])
|
| 58 |
+
m = int(multiple)
|
| 59 |
+
if m <= 1:
|
| 60 |
+
return h, w
|
| 61 |
+
out_h = max(m, (h // m) * m)
|
| 62 |
+
out_w = max(m, (w // m) * m)
|
| 63 |
+
return min(out_h, h), min(out_w, w)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _erp_multiple_aligned_hw(hw: tuple[int, int], multiple: int) -> tuple[int, int]:
|
| 67 |
+
h, w = int(hw[0]), int(hw[1])
|
| 68 |
+
m = int(multiple)
|
| 69 |
+
if m <= 1:
|
| 70 |
+
return h, w
|
| 71 |
+
max_h_from_h = h // m
|
| 72 |
+
max_h_from_w = w // (2 * m)
|
| 73 |
+
h_units = min(max_h_from_h, max_h_from_w)
|
| 74 |
+
if h_units <= 0:
|
| 75 |
+
return h, w
|
| 76 |
+
out_h = h_units * m
|
| 77 |
+
return out_h, 2 * out_h
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _resize_chw_tensor(x: torch.Tensor, dst_hw: tuple[int, int], *, kind: str) -> torch.Tensor:
|
| 81 |
+
if not torch.is_tensor(x) or x.ndim < 3:
|
| 82 |
+
return x
|
| 83 |
+
src_hw = (int(x.shape[-2]), int(x.shape[-1]))
|
| 84 |
+
if src_hw == tuple(int(v) for v in dst_hw):
|
| 85 |
+
return x
|
| 86 |
+
orig_dtype = x.dtype
|
| 87 |
+
flat = x.reshape(-1, int(x.shape[-3]), src_hw[0], src_hw[1]).to(dtype=torch.float32)
|
| 88 |
+
if kind == "image":
|
| 89 |
+
y = F.interpolate(flat, size=dst_hw, mode="bilinear", align_corners=False)
|
| 90 |
+
y = y.round().clamp(0.0, 255.0).to(dtype=orig_dtype) if orig_dtype == torch.uint8 else y.to(dtype=orig_dtype)
|
| 91 |
+
elif kind == "ray":
|
| 92 |
+
y = F.interpolate(flat, size=dst_hw, mode="bilinear", align_corners=False)
|
| 93 |
+
y = y / torch.linalg.vector_norm(y, dim=1, keepdim=True).clamp(min=1e-6)
|
| 94 |
+
y = y.to(dtype=orig_dtype)
|
| 95 |
+
else:
|
| 96 |
+
y = F.interpolate(flat, size=dst_hw, mode="nearest").to(dtype=orig_dtype)
|
| 97 |
+
return y.reshape(*x.shape[:-2], int(dst_hw[0]), int(dst_hw[1])).contiguous()
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _resize_cube_tensor(x: torch.Tensor, dst_hw: tuple[int, int], *, kind: str) -> torch.Tensor:
|
| 101 |
+
if not torch.is_tensor(x) or x.ndim < 4:
|
| 102 |
+
return x
|
| 103 |
+
src_hw = (int(x.shape[-3]), int(x.shape[-2]))
|
| 104 |
+
if src_hw == tuple(int(v) for v in dst_hw):
|
| 105 |
+
return x
|
| 106 |
+
orig_dtype = x.dtype
|
| 107 |
+
channels = int(x.shape[-1])
|
| 108 |
+
flat = x.reshape(-1, src_hw[0], src_hw[1], channels).permute(0, 3, 1, 2).to(dtype=torch.float32)
|
| 109 |
+
if kind == "image":
|
| 110 |
+
y = F.interpolate(flat, size=dst_hw, mode="bilinear", align_corners=False)
|
| 111 |
+
y = y.round().clamp(0.0, 255.0).to(dtype=orig_dtype) if orig_dtype == torch.uint8 else y.to(dtype=orig_dtype)
|
| 112 |
+
else:
|
| 113 |
+
y = F.interpolate(flat, size=dst_hw, mode="nearest").to(dtype=orig_dtype)
|
| 114 |
+
y = y.permute(0, 2, 3, 1)
|
| 115 |
+
return y.reshape(*x.shape[:-3], int(dst_hw[0]), int(dst_hw[1]), channels).contiguous()
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _training_batch_src_hw(batch: Any) -> tuple[int, int] | None:
|
| 119 |
+
for name in ("src_rgb_u8", "src_erp_rgb_u8"):
|
| 120 |
+
value = getattr(batch, name, None)
|
| 121 |
+
if torch.is_tensor(value) and value.ndim >= 3:
|
| 122 |
+
return int(value.shape[-2]), int(value.shape[-1])
|
| 123 |
+
return None
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _scale_fisheye624_params_any(params: torch.Tensor, *, src_hw: tuple[int, int], dst_hw: tuple[int, int]) -> torch.Tensor:
|
| 127 |
+
if tuple(int(x) for x in src_hw) == tuple(int(x) for x in dst_hw):
|
| 128 |
+
return params
|
| 129 |
+
src_h, src_w = int(src_hw[0]), int(src_hw[1])
|
| 130 |
+
dst_h, dst_w = int(dst_hw[0]), int(dst_hw[1])
|
| 131 |
+
sx = float(dst_w) / float(max(src_w, 1))
|
| 132 |
+
sy = float(dst_h) / float(max(src_h, 1))
|
| 133 |
+
out = params.clone()
|
| 134 |
+
out[..., 0] *= sx
|
| 135 |
+
out[..., 1] *= sy
|
| 136 |
+
out[..., 2] = (out[..., 2] + 0.5) * sx - 0.5
|
| 137 |
+
out[..., 3] = (out[..., 3] + 0.5) * sy - 0.5
|
| 138 |
+
return out
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _resize_training_batch_to_multiple(batch: Any, multiple: int) -> Any:
    """Resize a dataclass batch so its image sizes align to ``multiple``.

    Tensor fields with at least three dimensions are resized by name suffix:
    ``*_rgb_u8`` as images, ``*_depth_m`` and ``*_valid_mask`` with nearest
    neighbour, ``*_rays`` as rays. ``*_rgb_u8`` / ``*_depth_m`` fields whose
    name contains ``_cube_`` are treated as channels-last. Tensor
    ``*_intrinsics`` and ``*_camera_params`` are rescaled to the new view
    size. The batch is returned unchanged when ``multiple <= 1``, when it is
    not a dataclass, or when no source RGB size can be determined.
    """
    if int(multiple) <= 1 or not is_dataclass(batch):
        return batch
    if _training_batch_src_hw(batch) is None:
        return batch

    step = int(multiple)

    def _view_hw(prefix: str) -> tuple[int, int] | None:
        # Size of the view's RGB tensor; the plain field is tried before ERP.
        for suffix in ("_rgb_u8", "_erp_rgb_u8"):
            rgb = getattr(batch, f"{prefix}{suffix}", None)
            if torch.is_tensor(rgb) and rgb.ndim >= 3:
                return int(rgb.shape[-2]), int(rgb.shape[-1])
        return None

    def _aligned_view_hw(prefix: str, hw: tuple[int, int]) -> tuple[int, int]:
        # Views carrying an ERP RGB tensor use the (H, 2H) alignment rule.
        has_erp = torch.is_tensor(getattr(batch, f"{prefix}_erp_rgb_u8", None))
        align = _erp_multiple_aligned_hw if has_erp else _multiple_aligned_hw
        return align(hw, step)

    def _field_dst_hw(name: str, value: torch.Tensor) -> tuple[int, int]:
        # Prefer the owning view's RGB size so every field of a view matches;
        # fall back to the field's own size when the view has no RGB tensor.
        prefix = "tgt" if name.startswith("tgt_") else "src"
        view_hw = _view_hw(prefix)
        if view_hw is not None:
            return _aligned_view_hw(prefix, view_hw)
        own_hw = (int(value.shape[-2]), int(value.shape[-1]))
        align = _erp_multiple_aligned_hw if "_erp_" in name else _multiple_aligned_hw
        return align(own_hw, step)

    def _cube_dst_hw(value: torch.Tensor) -> tuple[int, int]:
        return _multiple_aligned_hw((int(value.shape[-3]), int(value.shape[-2])), step)

    updates: dict[str, Any] = {}
    for field in fields(batch):
        name = field.name
        value = getattr(batch, name)
        if not torch.is_tensor(value) or value.ndim < 3:
            continue
        if name.endswith("_rgb_u8") or name.endswith("_depth_m"):
            kind = "image" if name.endswith("_rgb_u8") else "depth"
            if "_cube_" in name:
                updates[name] = _resize_cube_tensor(value, _cube_dst_hw(value), kind=kind)
            else:
                updates[name] = _resize_chw_tensor(value, _field_dst_hw(name, value), kind=kind)
        elif name.endswith("_valid_mask"):
            updates[name] = _resize_chw_tensor(value, _field_dst_hw(name, value), kind="depth")
        elif name.endswith("_rays"):
            updates[name] = _resize_chw_tensor(value, _field_dst_hw(name, value), kind="ray")

    # Camera models must follow the images they describe.
    rescalers = (
        (("src_intrinsics", "tgt_intrinsics"), scale_pinhole_intrinsics),
        (("src_camera_params", "tgt_camera_params"), _scale_fisheye624_params_any),
    )
    for attr_names, rescale in rescalers:
        for attr in attr_names:
            camera = getattr(batch, attr, None)
            if not torch.is_tensor(camera):
                continue
            prefix = "tgt" if attr.startswith("tgt_") else "src"
            view_hw = _view_hw(prefix)
            if view_hw is None:
                continue
            updates[attr] = rescale(
                camera,
                src_hw=view_hw,
                dst_hw=_aligned_view_hw(prefix, view_hw),
            )

    if not updates:
        return batch
    return replace(batch, **updates)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _build_optimizer_param_groups(
|
| 221 |
+
raw_model: UnisharpFeatureModel,
|
| 222 |
+
) -> tuple[list[torch.nn.Parameter], list[torch.nn.Parameter], list[torch.nn.Parameter]]:
|
| 223 |
+
base_params: list[torch.nn.Parameter] = []
|
| 224 |
+
unik3d_encoder_params: list[torch.nn.Parameter] = []
|
| 225 |
+
unik3d_decoder_params: list[torch.nn.Parameter] = []
|
| 226 |
+
for name, param in raw_model.named_parameters():
|
| 227 |
+
if not param.requires_grad:
|
| 228 |
+
continue
|
| 229 |
+
if name.startswith("feature_extractor.unik3d.pixel_encoder."):
|
| 230 |
+
unik3d_encoder_params.append(param)
|
| 231 |
+
elif name.startswith("second_layer_depth_head."):
|
| 232 |
+
unik3d_decoder_params.append(param)
|
| 233 |
+
elif name.startswith("feature_extractor.unik3d."):
|
| 234 |
+
unik3d_decoder_params.append(param)
|
| 235 |
+
else:
|
| 236 |
+
base_params.append(param)
|
| 237 |
+
return base_params, unik3d_encoder_params, unik3d_decoder_params
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def _count_numel(params: list[torch.nn.Parameter]) -> int:
|
| 241 |
+
return int(sum(int(p.numel()) for p in params))
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def _configure_torchhub_cache() -> Path:
    """Point torch.hub at ``REPO_ROOT/checkpoints/torchhub`` and return that path.

    Creates the directory if needed, sets the ``TORCH_HOME`` environment
    variable and calls ``torch.hub.set_dir`` with the same location.
    """
    cache_dir = REPO_ROOT / "checkpoints" / "torchhub"
    cache_dir.mkdir(parents=True, exist_ok=True)

    cache_dir_text = str(cache_dir)
    os.environ["TORCH_HOME"] = cache_dir_text
    torch.hub.set_dir(cache_dir_text)
    return cache_dir
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def _ddp_is_enabled() -> bool:
|
| 253 |
+
return int(os.environ.get("WORLD_SIZE", "1")) > 1
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _ddp_setup(device: str, ddp_timeout_hours: float = 8.0) -> tuple[torch.device, int, int, bool]:
    """Initialise the NCCL process group when running distributed.

    Returns ``(device, rank, world_size, is_main)``. Outside DDP this is
    ``(torch.device(device), 0, 1, True)`` and no process group is created.
    Under DDP, ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` are read from the
    environment and the timeout is floored at 0.25 hours.

    Raises:
        RuntimeError: under DDP when ``device`` is not exactly ``"cuda"`` or
            CUDA is unavailable.
    """
    if not _ddp_is_enabled():
        return torch.device(device), 0, 1, True

    if device != "cuda":
        raise RuntimeError("DDP currently supports CUDA only.")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available.")

    env = os.environ
    local_rank = int(env.get("LOCAL_RANK", "0"))
    rank = int(env.get("RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    is_main = rank == 0

    torch.cuda.set_device(local_rank)
    timeout_hours = max(float(ddp_timeout_hours), 0.25)

    if is_main:
        # Printed with flush so it is visible even if init hangs afterwards.
        nccl_net = env.get("NCCL_NET", "<unset>")
        nccl_ib_disable = env.get("NCCL_IB_DISABLE", "<unset>")
        print(
            "[ddp_setup] init_process_group backend=nccl "
            f"world_size={world_size} NCCL_NET={nccl_net} "
            f"NCCL_IB_DISABLE={nccl_ib_disable}",
            flush=True,
        )
    dist.init_process_group(backend="nccl", timeout=timedelta(hours=timeout_hours))
    if is_main:
        print("[ddp_setup] init_process_group done", flush=True)

    return torch.device("cuda", local_rank), rank, world_size, is_main
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def _ddp_broadcast_path(p: Path, is_main: bool) -> Path:
    """Return rank 0's path on every rank; outside DDP return ``p`` untouched."""
    if _ddp_is_enabled():
        # Non-main ranks send a placeholder that the broadcast overwrites.
        payload: list[str] = [str(p) if is_main else ""]
        dist.broadcast_object_list(payload, src=0)
        return Path(payload[0])
    return p
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def _ddp_broadcast_str(value: str, is_main: bool) -> str:
    """Return rank 0's string on every rank; outside DDP return ``value`` untouched."""
    if _ddp_is_enabled():
        # Non-main ranks send a placeholder that the broadcast overwrites.
        payload: list[str] = [str(value) if is_main else ""]
        dist.broadcast_object_list(payload, src=0)
        return str(payload[0])
    return value
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def _ddp_any_bool(flag: bool, device: torch.device) -> bool:
    """Return True if ``flag`` is truthy on at least one rank.

    Outside DDP this is simply ``bool(flag)``.
    """
    if not _ddp_is_enabled():
        return bool(flag)
    # A MAX all-reduce over 0/1 votes acts as a logical OR across ranks.
    vote = torch.tensor(int(bool(flag)), device=device, dtype=torch.int32)
    dist.all_reduce(vote, op=dist.ReduceOp.MAX)
    return int(vote.item()) != 0
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _env_flag(name: str, default: bool = False) -> bool:
    """Interpret an environment variable as a boolean switch.

    Args:
        name: Environment variable to look up.
        default: Value used (coerced with ``bool``) when the variable is not
            set at all.

    Returns:
        True when the variable is set and, after stripping whitespace and
        lower-casing, equals one of ``1``, ``true``, ``yes`` or ``on``.
        Any other set value (including the empty string) yields False.
    """
    try:
        text = os.environ[name]
    except KeyError:
        return bool(default)
    return text.strip().lower() in ("1", "true", "yes", "on")
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _is_oom_exception(exc: BaseException) -> bool:
    """Decide whether an exception looks like an out-of-memory failure.

    An instance of ``torch.cuda.OutOfMemoryError`` always counts. Any other
    exception counts when its lower-cased message contains one of the known
    allocation-failure markers listed below.

    Args:
        exc: The exception to classify.

    Returns:
        True if the exception is treated as an OOM condition, else False.
    """
    if isinstance(exc, torch.cuda.OutOfMemoryError):
        return True

    text = str(exc).lower()
    # Markers are compared against the lower-cased message, so they must be
    # lower-case themselves.
    for marker in (
        "out of memory",
        "cuda error: out of memory",
        "cublas_status_alloc_failed",
        "cudnn_status_alloc_failed",
        "defaultcpuallocator",
    ):
        if marker in text:
            return True
    return False
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def _ddp_barrier(device: torch.device) -> None:
    """Synchronize all ranks; do nothing when DDP is not enabled.

    Args:
        device: This process's device. When it is a CUDA device with an
            explicit index, that index is passed to ``dist.barrier`` as
            ``device_ids``; otherwise the barrier is called without it.
    """
    if not _ddp_is_enabled():
        return

    has_cuda_index = device.type == "cuda" and device.index is not None
    if has_cuda_index:
        # Presumably pins the barrier to this rank's GPU -- confirm against
        # the torch.distributed backend in use.
        dist.barrier(device_ids=[device.index])
        return
    dist.barrier()
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def _maybe_set_dataset_epoch(dataset: Any, epoch: int) -> None:
    """Call ``dataset.set_epoch(int(epoch))`` if the dataset supports it.

    Datasets without a callable ``set_epoch`` attribute are left untouched.

    Args:
        dataset: Any object; only its optional ``set_epoch`` attribute is used.
        epoch: Epoch number, converted with ``int`` before being passed on.
    """
    hook = getattr(dataset, "set_epoch", None)
    if not callable(hook):
        return
    hook(int(epoch))
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def _ddp_mean(x: torch.Tensor) -> torch.Tensor:
    """Average a tensor across all ranks.

    Args:
        x: Local tensor to average.

    Returns:
        ``x`` itself (same object, not detached) when ``_ddp_is_enabled()``
        is false. Otherwise a new detached tensor holding the SUM all-reduce
        of ``x`` divided by the world size.
    """
    if not _ddp_is_enabled():
        return x

    # Reduce a detached copy so the caller's tensor is not modified in place.
    total = x.detach().clone()
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return total / float(dist.get_world_size())
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def _save_train_vis(
    out_dir: Path,
    step: int,
    src_gt: torch.Tensor,
    src_pred: torch.Tensor,
    src_alpha: torch.Tensor,
    tgt_gt: torch.Tensor,
    tgt_pred: torch.Tensor,
    tgt_alpha: torch.Tensor,
    src_gt_depth: torch.Tensor | None = None,
    tgt_gt_depth: torch.Tensor | None = None,
    src_pred_depth: torch.Tensor | None = None,
    tgt_pred_depth: torch.Tensor | None = None,
    src_unik3d_depth: torch.Tensor | None = None,
    tgt_unik3d_depth: torch.Tensor | None = None,
    dataset_name: str | None = None,
    scene: str | None = None,
    src_idx: int | None = None,
    tgt_idx: int | None = None,
    src_pose_w2c: torch.Tensor | None = None,
    tgt_pose_w2c: torch.Tensor | None = None,
    src_metric_mask: torch.Tensor | None = None,
    tgt_metric_mask: torch.Tensor | None = None,
    src_cube_gt_u8: torch.Tensor | None = None,
    src_cube_pred_linear: torch.Tensor | None = None,
    src_cube_alpha: torch.Tensor | None = None,
    tgt_cube_gt_u8: torch.Tensor | None = None,
    tgt_cube_pred_linear: torch.Tensor | None = None,
    tgt_cube_alpha: torch.Tensor | None = None,
) -> None:
    """Write one training visualization PNG under ``<out_dir>/vis``.

    The ``vis`` directory is created if needed, the destination is logged,
    and every image/depth/pose argument is forwarded by keyword to
    ``save_pair_visualization``. The file is named ``step_<step>.png`` with
    the step zero-padded to seven digits.

    Args:
        out_dir: Run output directory; the image goes into its ``vis`` child.
        step: Training step, used for the file name and forwarded as ``step``.
        src_*/tgt_*: Source/target view tensors and metadata, passed through
            unchanged to ``save_pair_visualization``.
        src_metric_mask, tgt_metric_mask: Accepted but not used by this
            function (see the note in the body).
    """
    step_number = int(step)
    vis_dir = out_dir / "vis"
    vis_dir.mkdir(parents=True, exist_ok=True)

    png_path = vis_dir / f"step_{step_number:07d}.png"
    LOGGER.info("Saving train visualization: %s", str(png_path))

    # NOTE(review): src_metric_mask / tgt_metric_mask are part of the
    # signature but are not forwarded below -- confirm whether
    # save_pair_visualization is expected to receive them.
    save_pair_visualization(
        png_path,
        src_gt=src_gt,
        src_pred=src_pred,
        src_alpha=src_alpha,
        tgt_gt=tgt_gt,
        tgt_pred=tgt_pred,
        tgt_alpha=tgt_alpha,
        src_gt_depth=src_gt_depth,
        tgt_gt_depth=tgt_gt_depth,
        src_pred_depth=src_pred_depth,
        tgt_pred_depth=tgt_pred_depth,
        src_unik3d_depth=src_unik3d_depth,
        tgt_unik3d_depth=tgt_unik3d_depth,
        dataset_name=dataset_name,
        scene=scene,
        step=step_number,
        src_idx=src_idx,
        tgt_idx=tgt_idx,
        src_pose_w2c=src_pose_w2c,
        tgt_pose_w2c=tgt_pose_w2c,
        src_cube_gt_u8=src_cube_gt_u8,
        src_cube_pred_linear=src_cube_pred_linear,
        src_cube_alpha=src_cube_alpha,
        tgt_cube_gt_u8=tgt_cube_gt_u8,
        tgt_cube_pred_linear=tgt_cube_pred_linear,
        tgt_cube_alpha=tgt_cube_alpha,
    )
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def _read_nonempty_lines(path: Path) -> list[str]:
    """Read a UTF-8 text file and return its non-blank lines, stripped.

    Args:
        path: File to read.

    Returns:
        Each line with surrounding whitespace removed, in file order,
        omitting lines that are empty or whitespace-only.
    """
    kept: list[str] = []
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        text = raw_line.strip()
        if text:
            kept.append(text)
    return kept
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def _resolve_manifest_file(manifest_dir: Path | None, filename: str) -> Path | None:
    """Locate an optional manifest file inside a manifest directory.

    Args:
        manifest_dir: Directory expected to hold manifests, or None when no
            manifest directory was configured.
        filename: Name of the manifest file to look for.

    Returns:
        ``manifest_dir / filename`` when that path exists; None when no
        directory was given or the path does not exist.
    """
    if manifest_dir is not None:
        candidate = Path(manifest_dir) / filename
        if candidate.exists():
            return candidate
    return None
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
@click.command()
|
| 429 |
+
@click.option("--data-root-re10k", type=click.Path(path_type=Path, exists=True), default=None)
|
| 430 |
+
@click.option("--data-root-hm3d", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/panogs"))
|
| 431 |
+
@click.option("--data-root-sim", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/smx_sim"))
|
| 432 |
+
@click.option("--sim-pose-root", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/smx_sim/30cm"))
|
| 433 |
+
@click.option("--data-root-wildrgbd", type=click.Path(path_type=Path, exists=True), default=None)
|
| 434 |
+
@click.option("--wild-roots-file", type=click.Path(path_type=Path, exists=True, dir_okay=False), default=DEFAULT_WILDRGBD_ROOTS_FILE)
|
| 435 |
+
@click.option("--data-root-dl3dv", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/sharp/DL3DV-ALL-960P"))
|
| 436 |
+
@click.option("--data-root-dl3dv-depth", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/sharp/DL3DV-ALL-960P_da3_outputs"))
|
| 437 |
+
@click.option("--data-root-scanetpp", type=click.Path(path_type=Path, exists=True), default=Path("/media/team_data/ML4_team/datasets/scan"))
|
| 438 |
+
@click.option("--dataset-manifest-dir", type=click.Path(path_type=Path, file_okay=False), default=None)
|
| 439 |
+
@click.option("--out-root", type=click.Path(path_type=Path, file_okay=False), required=True)
|
| 440 |
+
@click.option("--run-name", type=str, default=None)
|
| 441 |
+
@click.option("--steps", type=int, default=1000000)
|
| 442 |
+
@click.option("--batch-size", type=int, default=2)
|
| 443 |
+
@click.option("--num-workers", type=int, default=1)
|
| 444 |
+
@click.option("--warmup", type=int, default=75000)
|
| 445 |
+
@click.option("--lr0", type=float, default=1.2e-4)
|
| 446 |
+
@click.option("--lr1", type=float, default=1.6e-5)
|
| 447 |
+
@click.option("--unik3d-lr0", type=float, default=2.5e-5, help="UniK3D decoder/head peak LR.")
|
| 448 |
+
@click.option("--unik3d-lr1", type=float, default=2.5e-6, help="UniK3D decoder/head final LR.")
|
| 449 |
+
@click.option("--unik3d-encoder-lr0", type=float, default=1.5e-6, help="UniK3D pixel_encoder peak LR.")
|
| 450 |
+
@click.option("--unik3d-encoder-lr1", type=float, default=1.5e-7, help="UniK3D pixel_encoder final LR.")
|
| 451 |
+
@click.option("--grad-clip-norm", type=float, default=1.0, show_default=True)
|
| 452 |
+
@click.option("--max-step-grad-norm", type=float, default=100000.0, show_default=True, help="Skip optimizer step when pre-clip grad norm exceeds this value. 0 disables.")
|
| 453 |
+
@click.option("--max-depth-m", type=float, default=DEFAULT_MAX_DEPTH_M, show_default=True)
|
| 454 |
+
@click.option("--sim-far-depth-invalid-m", type=float, default=30.0, show_default=True)
|
| 455 |
+
@click.option("--sim-far-depth-invalid-max-frac", type=float, default=1.0, show_default=True)
|
| 456 |
+
@click.option("--sim-max-long-edge", type=int, default=512, show_default=True, help="Resize SIM ERP frames before cubemap conversion. 0 keeps native resolution.")
|
| 457 |
+
@click.option("--train-resize-multiple", type=int, default=256, show_default=True, help="Before model forward, downsize training inputs to the largest H/W divisible by this value. 0 disables.")
|
| 458 |
+
@click.option("--pinhole-train-size", type=int, default=0, show_default=True, help="Resize pinhole training datasets to NxN before model forward. 0 keeps dataset native resolution.")
|
| 459 |
+
@click.option("--scanetpp-fisheye-far-depth-invalid-m", type=float, default=30.0, show_default=True)
|
| 460 |
+
@click.option("--max-index-gap", type=int, default=10)
|
| 461 |
+
@click.option("--device", type=str, default="cuda")
|
| 462 |
+
@click.option("--render-low-pass-filter-eps", type=float, default=1e-2, show_default=True)
|
| 463 |
+
@click.option("--ddp-timeout-hours", type=float, default=8.0)
|
| 464 |
+
@click.option("--save-every", type=int, default=5000)
|
| 465 |
+
@click.option("--log-every", type=int, default=50)
|
| 466 |
+
@click.option("--vis-every", type=int, default=500)
|
| 467 |
+
@click.option("--unik3d-backbone", type=click.Choice(["vitb", "vitl"]), default="vitl")
|
| 468 |
+
@click.option("--unik3d-resolution-level", type=click.IntRange(0, 9), default=0, show_default=True)
|
| 469 |
+
@click.option("--initializer-stride", type=click.IntRange(1, 2), default=1)
|
| 470 |
+
@click.option("--initializer-scale-factor", type=float, default=1.5, show_default=True)
|
| 471 |
+
@click.option("--lambda-aux-ray", type=float, default=3.0)
|
| 472 |
+
@click.option("--lambda-aux-depth-scale", type=float, default=3.0)
|
| 473 |
+
@click.option("--lambda-aux-depth2-scale", type=float, default=1.0)
|
| 474 |
+
@click.option("--lambda-color", type=float, default=1.0)
|
| 475 |
+
@click.option("--lambda-alpha", type=float, default=1.5)
|
| 476 |
+
@click.option("--alpha-tail-min", type=float, default=0.99, show_default=True, help="Alpha value below which local tail coverage loss is applied.")
|
| 477 |
+
@click.option("--alpha-tail-weight", type=float, default=0.0, show_default=True, help="Extra normalized tail weight for local low-alpha holes.")
|
| 478 |
+
@click.option("--lambda-percep", type=float, default=1.0)
|
| 479 |
+
@click.option("--lambda-depth", type=float, default=0.5)
|
| 480 |
+
@click.option("--lambda-tv", type=float, default=1.0)
|
| 481 |
+
@click.option("--lambda-grad", type=float, default=1.0)
|
| 482 |
+
@click.option("--lambda-grad-img", type=float, default=0.2)
|
| 483 |
+
@click.option("--lambda-edge-rgb", type=float, default=0.0, show_default=True, help="Weight for GT RGB edge-band gradient matching.")
|
| 484 |
+
@click.option("--lambda-delta", type=float, default=1.0)
|
| 485 |
+
@click.option("--lambda-delta-rho", type=float, default=0.01, show_default=True)
|
| 486 |
+
@click.option("--lambda-splat", type=float, default=1.0)
|
| 487 |
+
@click.option("--lambda-edge-splat", type=float, default=0.0, show_default=True, help="Weight for stricter projected-sigma penalty on GT depth-edge bands.")
|
| 488 |
+
@click.option("--lambda-grid", type=float, default=0.05, show_default=True, help="Weight for Gaussian-grid 2x2 checkerboard residual regularization.")
|
| 489 |
+
@click.option("--delta-clip", type=float, default=10.0, show_default=True)
|
| 490 |
+
@click.option("--raw-delta-clip", type=float, default=400.0, show_default=True)
|
| 491 |
+
@click.option("--raw-delta-rho-clip", type=float, default=5.0, show_default=True)
|
| 492 |
+
@click.option("--delta-rho-limit", type=float, default=2.0, show_default=True)
|
| 493 |
+
@click.option("--splat-sigma-min", type=float, default=1e-1, show_default=True, help="Minimum projected screen-space variance for L_splat.")
|
| 494 |
+
@click.option("--splat-sigma-max", type=float, default=1e2, show_default=True, help="Maximum projected screen-space variance for L_splat.")
|
| 495 |
+
@click.option("--edge-splat-sigma-max", type=float, default=2.0, show_default=True, help="Maximum projected variance on depth-edge bands for L_edge_splat.")
|
| 496 |
+
@click.option("--depth-edge-log-threshold", type=float, default=0.05, show_default=True, help="Log-depth jump threshold used to build L_edge_splat edge bands.")
|
| 497 |
+
@click.option("--depth-edge-dilate-px", type=int, default=2, show_default=True, help="Dilation radius in pixels for L_edge_splat depth-edge bands.")
|
| 498 |
+
@click.option("--target-mask-erode-px", type=int, default=0, show_default=True, help="Erode source-visible target masks by this many pixels before target supervision.")
|
| 499 |
+
@click.option("--dataset-weight-re10k", type=float, default=1.0)
|
| 500 |
+
@click.option("--dataset-weight-hm3d", type=float, default=1.0)
|
| 501 |
+
@click.option("--dataset-weight-sim", type=float, default=1.0)
|
| 502 |
+
@click.option("--dataset-weight-wildrgbd", type=float, default=1.0)
|
| 503 |
+
@click.option("--dataset-weight-dl3dv", type=float, default=1.0)
|
| 504 |
+
@click.option("--dataset-weight-scanetpp", type=float, default=0.0)
|
| 505 |
+
@click.option(
|
| 506 |
+
"--re10k-pseudo-depth-root",
|
| 507 |
+
type=click.Path(path_type=Path, file_okay=False),
|
| 508 |
+
default=Path("/media/team_data/ML4_team/datasets/nopose/re10k_unik3d_pseudo_depth"),
|
| 509 |
+
)
|
| 510 |
+
@click.option("--re10k-pseudo-depth-autogen/--no-re10k-pseudo-depth-autogen", default=True)
|
| 511 |
+
@click.option("--re10k-pseudo-depth-backbone", type=click.Choice(["vitb", "vitl"]), default="vitl")
|
| 512 |
+
@click.option("--re10k-pseudo-depth-device", type=str, default="cpu")
|
| 513 |
+
@click.option("--re10k-pseudo-lock-timeout-sec", type=float, default=120.0)
|
| 514 |
+
@click.option("--re10k-pseudo-lock-stale-sec", type=float, default=1800.0)
|
| 515 |
+
@click.option("--re10k-pseudo-far-depth-invalid-m", type=float, default=30.0)
|
| 516 |
+
@click.option("--seed", type=int, default=None)
|
| 517 |
+
@click.option("-v", "--verbose", is_flag=True)
|
| 518 |
+
def train_feature_cli(
|
| 519 |
+
data_root_re10k: Path | None,
|
| 520 |
+
data_root_hm3d: Path | None,
|
| 521 |
+
data_root_sim: Path | None,
|
| 522 |
+
sim_pose_root: Path | None,
|
| 523 |
+
data_root_wildrgbd: Path | None,
|
| 524 |
+
wild_roots_file: Path,
|
| 525 |
+
data_root_dl3dv: Path | None,
|
| 526 |
+
data_root_dl3dv_depth: Path | None,
|
| 527 |
+
data_root_scanetpp: Path | None,
|
| 528 |
+
dataset_manifest_dir: Path | None,
|
| 529 |
+
out_root: Path,
|
| 530 |
+
run_name: str | None,
|
| 531 |
+
steps: int,
|
| 532 |
+
batch_size: int,
|
| 533 |
+
num_workers: int,
|
| 534 |
+
warmup: int,
|
| 535 |
+
lr0: float,
|
| 536 |
+
lr1: float,
|
| 537 |
+
unik3d_lr0: float,
|
| 538 |
+
unik3d_lr1: float,
|
| 539 |
+
unik3d_encoder_lr0: float,
|
| 540 |
+
unik3d_encoder_lr1: float,
|
| 541 |
+
grad_clip_norm: float,
|
| 542 |
+
max_step_grad_norm: float,
|
| 543 |
+
max_depth_m: float,
|
| 544 |
+
sim_far_depth_invalid_m: float,
|
| 545 |
+
sim_far_depth_invalid_max_frac: float,
|
| 546 |
+
sim_max_long_edge: int,
|
| 547 |
+
train_resize_multiple: int,
|
| 548 |
+
pinhole_train_size: int,
|
| 549 |
+
scanetpp_fisheye_far_depth_invalid_m: float,
|
| 550 |
+
max_index_gap: int,
|
| 551 |
+
device: str,
|
| 552 |
+
render_low_pass_filter_eps: float,
|
| 553 |
+
ddp_timeout_hours: float,
|
| 554 |
+
save_every: int,
|
| 555 |
+
log_every: int,
|
| 556 |
+
vis_every: int,
|
| 557 |
+
unik3d_backbone: str,
|
| 558 |
+
unik3d_resolution_level: int,
|
| 559 |
+
initializer_stride: int,
|
| 560 |
+
initializer_scale_factor: float,
|
| 561 |
+
lambda_aux_ray: float,
|
| 562 |
+
lambda_aux_depth_scale: float,
|
| 563 |
+
lambda_aux_depth2_scale: float,
|
| 564 |
+
lambda_color: float,
|
| 565 |
+
lambda_alpha: float,
|
| 566 |
+
alpha_tail_min: float,
|
| 567 |
+
alpha_tail_weight: float,
|
| 568 |
+
lambda_percep: float,
|
| 569 |
+
lambda_depth: float,
|
| 570 |
+
lambda_tv: float,
|
| 571 |
+
lambda_grad: float,
|
| 572 |
+
lambda_grad_img: float,
|
| 573 |
+
lambda_edge_rgb: float,
|
| 574 |
+
lambda_delta: float,
|
| 575 |
+
lambda_delta_rho: float,
|
| 576 |
+
lambda_splat: float,
|
| 577 |
+
lambda_edge_splat: float,
|
| 578 |
+
lambda_grid: float,
|
| 579 |
+
delta_clip: float,
|
| 580 |
+
raw_delta_clip: float,
|
| 581 |
+
raw_delta_rho_clip: float,
|
| 582 |
+
delta_rho_limit: float,
|
| 583 |
+
splat_sigma_min: float,
|
| 584 |
+
splat_sigma_max: float,
|
| 585 |
+
edge_splat_sigma_max: float,
|
| 586 |
+
depth_edge_log_threshold: float,
|
| 587 |
+
depth_edge_dilate_px: int,
|
| 588 |
+
target_mask_erode_px: int,
|
| 589 |
+
dataset_weight_re10k: float,
|
| 590 |
+
dataset_weight_hm3d: float,
|
| 591 |
+
dataset_weight_sim: float,
|
| 592 |
+
dataset_weight_wildrgbd: float,
|
| 593 |
+
dataset_weight_dl3dv: float,
|
| 594 |
+
dataset_weight_scanetpp: float,
|
| 595 |
+
re10k_pseudo_depth_root: Path,
|
| 596 |
+
re10k_pseudo_depth_autogen: bool,
|
| 597 |
+
re10k_pseudo_depth_backbone: str,
|
| 598 |
+
re10k_pseudo_depth_device: str,
|
| 599 |
+
re10k_pseudo_lock_timeout_sec: float,
|
| 600 |
+
re10k_pseudo_lock_stale_sec: float,
|
| 601 |
+
re10k_pseudo_far_depth_invalid_m: float,
|
| 602 |
+
seed: int | None,
|
| 603 |
+
verbose: bool,
|
| 604 |
+
) -> None:
|
| 605 |
+
detach_init_layer0_distance = True
|
| 606 |
+
|
| 607 |
+
log_level = logging.DEBUG if verbose else logging.INFO
|
| 608 |
+
logging_utils.configure(log_level)
|
| 609 |
+
if float(max_depth_m) <= 0.0:
|
| 610 |
+
raise ValueError("--max-depth-m must be positive.")
|
| 611 |
+
if float(grad_clip_norm) <= 0.0:
|
| 612 |
+
raise ValueError("--grad-clip-norm must be positive.")
|
| 613 |
+
if float(max_step_grad_norm) < 0.0:
|
| 614 |
+
raise ValueError("--max-step-grad-norm must be non-negative.")
|
| 615 |
+
if float(render_low_pass_filter_eps) < 0.0:
|
| 616 |
+
raise ValueError("--render-low-pass-filter-eps must be non-negative.")
|
| 617 |
+
if not (0.0 <= float(sim_far_depth_invalid_max_frac) <= 1.0):
|
| 618 |
+
raise ValueError("--sim-far-depth-invalid-max-frac must be in [0, 1].")
|
| 619 |
+
if int(sim_max_long_edge) < 0:
|
| 620 |
+
raise ValueError("--sim-max-long-edge must be non-negative.")
|
| 621 |
+
if int(train_resize_multiple) < 0:
|
| 622 |
+
raise ValueError("--train-resize-multiple must be non-negative.")
|
| 623 |
+
if int(pinhole_train_size) < 0:
|
| 624 |
+
raise ValueError("--pinhole-train-size must be non-negative.")
|
| 625 |
+
if float(scanetpp_fisheye_far_depth_invalid_m) < 0.0:
|
| 626 |
+
raise ValueError("--scanetpp-fisheye-far-depth-invalid-m must be non-negative.")
|
| 627 |
+
if float(delta_clip) < 0.0:
|
| 628 |
+
raise ValueError("--delta-clip must be non-negative.")
|
| 629 |
+
if float(raw_delta_clip) < 0.0:
|
| 630 |
+
raise ValueError("--raw-delta-clip must be non-negative.")
|
| 631 |
+
if float(raw_delta_rho_clip) < 0.0:
|
| 632 |
+
raise ValueError("--raw-delta-rho-clip must be non-negative.")
|
| 633 |
+
if float(lambda_grid) < 0.0:
|
| 634 |
+
raise ValueError("--lambda-grid must be non-negative.")
|
| 635 |
+
if float(lambda_edge_rgb) < 0.0:
|
| 636 |
+
raise ValueError("--lambda-edge-rgb must be non-negative.")
|
| 637 |
+
if float(lambda_edge_splat) < 0.0:
|
| 638 |
+
raise ValueError("--lambda-edge-splat must be non-negative.")
|
| 639 |
+
if float(edge_splat_sigma_max) < 0.0:
|
| 640 |
+
raise ValueError("--edge-splat-sigma-max must be non-negative.")
|
| 641 |
+
if float(depth_edge_log_threshold) < 0.0:
|
| 642 |
+
raise ValueError("--depth-edge-log-threshold must be non-negative.")
|
| 643 |
+
if int(depth_edge_dilate_px) < 0:
|
| 644 |
+
raise ValueError("--depth-edge-dilate-px must be non-negative.")
|
| 645 |
+
if int(target_mask_erode_px) < 0:
|
| 646 |
+
raise ValueError("--target-mask-erode-px must be non-negative.")
|
| 647 |
+
if not (0.0 <= float(alpha_tail_min) <= 1.0):
|
| 648 |
+
raise ValueError("--alpha-tail-min must be in [0, 1].")
|
| 649 |
+
if float(alpha_tail_weight) < 0.0:
|
| 650 |
+
raise ValueError("--alpha-tail-weight must be non-negative.")
|
| 651 |
+
if float(delta_rho_limit) < 0.0:
|
| 652 |
+
raise ValueError("--delta-rho-limit must be non-negative.")
|
| 653 |
+
if float(splat_sigma_min) < 0.0:
|
| 654 |
+
raise ValueError("--splat-sigma-min must be non-negative.")
|
| 655 |
+
if float(splat_sigma_max) <= float(splat_sigma_min):
|
| 656 |
+
raise ValueError("--splat-sigma-max must be greater than --splat-sigma-min.")
|
| 657 |
+
dev, rank, world_size, is_main = _ddp_setup(device, ddp_timeout_hours=ddp_timeout_hours)
|
| 658 |
+
|
| 659 |
+
if seed is not None:
|
| 660 |
+
s = int(seed)
|
| 661 |
+
random.seed(s + rank)
|
| 662 |
+
np.random.seed(s + rank)
|
| 663 |
+
torch.manual_seed(s + rank)
|
| 664 |
+
if torch.cuda.is_available():
|
| 665 |
+
torch.cuda.manual_seed_all(s + rank)
|
| 666 |
+
|
| 667 |
+
if is_main and (run_name is None or run_name.strip() == ""):
|
| 668 |
+
run_name = f"unified_feature_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 669 |
+
if run_name is None:
|
| 670 |
+
run_name = "unified_feature_ddp"
|
| 671 |
+
out_dir = _ddp_broadcast_path(Path(out_root) / run_name, is_main=is_main)
|
| 672 |
+
logging_utils.configure(log_level)
|
| 673 |
+
if not is_main:
|
| 674 |
+
logging.getLogger().setLevel(logging.WARNING)
|
| 675 |
+
LOGGER.setLevel(logging.WARNING)
|
| 676 |
+
_configure_torchhub_cache()
|
| 677 |
+
re10k_enabled_for_train = bool(float(dataset_weight_re10k) > 0.0)
|
| 678 |
+
hm3d_enabled_for_train = bool(float(dataset_weight_hm3d) > 0.0)
|
| 679 |
+
sim_enabled_for_train = bool(float(dataset_weight_sim) > 0.0)
|
| 680 |
+
dl3dv_enabled_for_train = bool(float(dataset_weight_dl3dv) > 0.0)
|
| 681 |
+
scanetpp_enabled_for_train = bool(float(dataset_weight_scanetpp) > 0.0)
|
| 682 |
+
wild_roots = _read_nonempty_lines(wild_roots_file) if wild_roots_file.exists() else []
|
| 683 |
+
re10k_manifest = _resolve_manifest_file(dataset_manifest_dir, "re10k_train_chunks.txt")
|
| 684 |
+
hm3d_manifest = _resolve_manifest_file(dataset_manifest_dir, "hm3d_train_scenes.txt")
|
| 685 |
+
sim_manifest = _resolve_manifest_file(dataset_manifest_dir, "sim_train_scenes.txt")
|
| 686 |
+
wildrgbd_manifest = _resolve_manifest_file(dataset_manifest_dir, "wildrgbd_train_scenes.txt")
|
| 687 |
+
dl3dv_manifest = _resolve_manifest_file(dataset_manifest_dir, "dl3dv_train_scenes.txt")
|
| 688 |
+
scanetpp_manifest = _resolve_manifest_file(dataset_manifest_dir, "scanetpp_fisheye_train_scenes.txt")
|
| 689 |
+
wildrgbd_enabled_for_train = bool(
|
| 690 |
+
((data_root_wildrgbd is not None) or bool(wild_roots)) and (float(dataset_weight_wildrgbd) > 0.0)
|
| 691 |
+
)
|
| 692 |
+
if re10k_enabled_for_train and data_root_re10k is None:
|
| 693 |
+
raise ValueError("dataset_weight_re10k>0 but --data-root-re10k is not provided.")
|
| 694 |
+
if hm3d_enabled_for_train and data_root_hm3d is None:
|
| 695 |
+
raise ValueError("dataset_weight_hm3d>0 but --data-root-hm3d is not provided.")
|
| 696 |
+
if sim_enabled_for_train and (data_root_sim is None or sim_pose_root is None):
|
| 697 |
+
raise ValueError("dataset_weight_sim>0 but --data-root-sim / --sim-pose-root is missing.")
|
| 698 |
+
if sim_enabled_for_train and sim_manifest is None:
|
| 699 |
+
raise ValueError("dataset_weight_sim>0 but sim_train_scenes.txt is missing from --dataset-manifest-dir.")
|
| 700 |
+
if float(dataset_weight_wildrgbd) > 0.0 and (data_root_wildrgbd is None) and (not wild_roots):
|
| 701 |
+
raise ValueError("dataset_weight_wildrgbd>0 but neither --data-root-wildrgbd nor --wild-roots-file is provided.")
|
| 702 |
+
if dl3dv_enabled_for_train and (data_root_dl3dv is None or data_root_dl3dv_depth is None):
|
| 703 |
+
raise ValueError("dataset_weight_dl3dv>0 but --data-root-dl3dv / --data-root-dl3dv-depth is missing.")
|
| 704 |
+
if scanetpp_enabled_for_train and data_root_scanetpp is None:
|
| 705 |
+
raise ValueError("dataset_weight_scanetpp>0 but --data-root-scanetpp is missing.")
|
| 706 |
+
|
| 707 |
+
if is_main:
|
| 708 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 709 |
+
LOGGER.info(
|
| 710 |
+
"Training start: out=%s branch=gt-override scratch_unik3d_pretrained backbone=%s steps=%d batch=%d",
|
| 711 |
+
str(out_dir),
|
| 712 |
+
str(unik3d_backbone),
|
| 713 |
+
int(steps),
|
| 714 |
+
int(batch_size),
|
| 715 |
+
)
|
| 716 |
+
LOGGER.info(
|
| 717 |
+
"Loss weights: color=%.3g alpha=%.3g depth=%.3g percep=%.3g aux_ray=%.3g aux_depth0=%.3g aux_depth1=%.3g",
|
| 718 |
+
float(lambda_color),
|
| 719 |
+
float(lambda_alpha),
|
| 720 |
+
float(lambda_depth),
|
| 721 |
+
float(lambda_percep),
|
| 722 |
+
float(lambda_aux_ray),
|
| 723 |
+
float(lambda_aux_depth_scale),
|
| 724 |
+
float(lambda_aux_depth2_scale),
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
dataset_seed = int(seed) if seed is not None else 12345
|
| 728 |
+
pinhole_output_h = int(pinhole_train_size) if int(pinhole_train_size) > 0 else None
|
| 729 |
+
pinhole_output_w = int(pinhole_train_size) if int(pinhole_train_size) > 0 else None
|
| 730 |
+
|
| 731 |
+
re10k_ds = None
|
| 732 |
+
if re10k_enabled_for_train:
|
| 733 |
+
re10k_ds = Re10KDataset(
|
| 734 |
+
root=data_root_re10k,
|
| 735 |
+
chunks_file=re10k_manifest,
|
| 736 |
+
split="train",
|
| 737 |
+
min_frame_gap=1,
|
| 738 |
+
max_frame_gap=int(max_index_gap),
|
| 739 |
+
pair_max_translation_m=0.5,
|
| 740 |
+
pair_min_overlap=0.6,
|
| 741 |
+
output_h=pinhole_output_h,
|
| 742 |
+
output_w=pinhole_output_w,
|
| 743 |
+
shuffle_chunk=True,
|
| 744 |
+
shuffle_example=True,
|
| 745 |
+
ddp_rank=rank,
|
| 746 |
+
ddp_world_size=world_size,
|
| 747 |
+
pseudo_depth_root=re10k_pseudo_depth_root,
|
| 748 |
+
pseudo_depth_autogen=bool(re10k_pseudo_depth_autogen),
|
| 749 |
+
pseudo_depth_backbone=str(re10k_pseudo_depth_backbone),
|
| 750 |
+
pseudo_depth_device=str(re10k_pseudo_depth_device),
|
| 751 |
+
pseudo_lock_timeout_sec=float(re10k_pseudo_lock_timeout_sec),
|
| 752 |
+
pseudo_lock_stale_sec=float(re10k_pseudo_lock_stale_sec),
|
| 753 |
+
batch_size_hint=int(batch_size),
|
| 754 |
+
depth_max_m=float(max_depth_m),
|
| 755 |
+
pseudo_far_depth_invalid_m=float(re10k_pseudo_far_depth_invalid_m),
|
| 756 |
+
seed=dataset_seed,
|
| 757 |
+
)
|
| 758 |
+
hm3d_train_root = None
|
| 759 |
+
if data_root_hm3d is not None:
|
| 760 |
+
hm3d_train_root = data_root_hm3d / "train" if (data_root_hm3d / "train").exists() else data_root_hm3d
|
| 761 |
+
|
| 762 |
+
hm3d_ds = None
|
| 763 |
+
if hm3d_enabled_for_train:
|
| 764 |
+
hm3d_ds = PanOGSDataset(
|
| 765 |
+
root=hm3d_train_root,
|
| 766 |
+
index_manifest_path=hm3d_manifest,
|
| 767 |
+
src_tgt_max_index_gap=int(max_index_gap),
|
| 768 |
+
use_cubemap_supervision=True,
|
| 769 |
+
pair_sampling=True,
|
| 770 |
+
pair_max_translation_m=0.5,
|
| 771 |
+
pair_min_depth_overlap=0.6,
|
| 772 |
+
pair_overlap_face_w=64,
|
| 773 |
+
pair_overlap_margin=1.05,
|
| 774 |
+
pair_max_tries=48,
|
| 775 |
+
depth_max_m=float(max_depth_m),
|
| 776 |
+
)
|
| 777 |
+
sim_ds = None
|
| 778 |
+
if sim_enabled_for_train:
|
| 779 |
+
sim_ds = SimPanoramaDataset(
|
| 780 |
+
root=data_root_sim,
|
| 781 |
+
pose_root=sim_pose_root,
|
| 782 |
+
scene_list_file=sim_manifest,
|
| 783 |
+
max_index_gap=int(max_index_gap),
|
| 784 |
+
pair_max_translation_m=0.5,
|
| 785 |
+
pair_min_depth_overlap=0.6,
|
| 786 |
+
pairs_per_chunk=15,
|
| 787 |
+
chunk_size=30,
|
| 788 |
+
shuffle_scene=True,
|
| 789 |
+
ddp_rank=rank,
|
| 790 |
+
ddp_world_size=world_size,
|
| 791 |
+
depth_max_m=float(max_depth_m),
|
| 792 |
+
far_depth_invalid_m=float(sim_far_depth_invalid_m),
|
| 793 |
+
far_depth_invalid_max_frac=float(sim_far_depth_invalid_max_frac),
|
| 794 |
+
max_long_edge=int(sim_max_long_edge),
|
| 795 |
+
seed=dataset_seed,
|
| 796 |
+
)
|
| 797 |
+
wildrgbd_ds = None
|
| 798 |
+
if wildrgbd_enabled_for_train:
|
| 799 |
+
wild_dataset_roots = [Path(p) for p in wild_roots]
|
| 800 |
+
if data_root_wildrgbd is not None:
|
| 801 |
+
wild_dataset_roots.append(data_root_wildrgbd)
|
| 802 |
+
wildrgbd_ds = WildRGBDDataset(
|
| 803 |
+
root=None,
|
| 804 |
+
scene_list_file=wildrgbd_manifest,
|
| 805 |
+
split="scenes",
|
| 806 |
+
min_frame_gap=1,
|
| 807 |
+
max_frame_gap=int(max_index_gap),
|
| 808 |
+
pair_max_translation_m=0.5,
|
| 809 |
+
pair_min_overlap=0.6,
|
| 810 |
+
output_h=pinhole_output_h,
|
| 811 |
+
output_w=pinhole_output_w,
|
| 812 |
+
shuffle_scene=True,
|
| 813 |
+
shuffle_frame=False,
|
| 814 |
+
ddp_rank=rank,
|
| 815 |
+
ddp_world_size=world_size,
|
| 816 |
+
roots=wild_dataset_roots,
|
| 817 |
+
depth_max_m=float(max_depth_m),
|
| 818 |
+
seed=dataset_seed,
|
| 819 |
+
)
|
| 820 |
+
dl3dv_ds = None
|
| 821 |
+
if dl3dv_enabled_for_train:
|
| 822 |
+
dl3dv_ds = DL3DVDataset(
|
| 823 |
+
root=data_root_dl3dv,
|
| 824 |
+
depth_root=data_root_dl3dv_depth,
|
| 825 |
+
scene_specs_file=dl3dv_manifest,
|
| 826 |
+
min_frame_gap=1,
|
| 827 |
+
max_frame_gap=int(max_index_gap),
|
| 828 |
+
pair_max_translation_m=0.5,
|
| 829 |
+
pair_min_overlap=0.6,
|
| 830 |
+
output_h=pinhole_output_h,
|
| 831 |
+
output_w=pinhole_output_w,
|
| 832 |
+
shuffle_scene=True,
|
| 833 |
+
shuffle_frame=False,
|
| 834 |
+
ddp_rank=rank,
|
| 835 |
+
ddp_world_size=world_size,
|
| 836 |
+
batch_size_hint=int(batch_size),
|
| 837 |
+
depth_max_m=float(max_depth_m),
|
| 838 |
+
seed=dataset_seed,
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
scanetpp_ds = None
|
| 842 |
+
if scanetpp_enabled_for_train:
|
| 843 |
+
scanetpp_ds = ScannetppFisheyeDataset(
|
| 844 |
+
root=data_root_scanetpp,
|
| 845 |
+
scene_list_file=scanetpp_manifest,
|
| 846 |
+
min_frame_gap=1,
|
| 847 |
+
max_frame_gap=int(max_index_gap),
|
| 848 |
+
pair_max_translation_m=0.5,
|
| 849 |
+
shuffle_scene=True,
|
| 850 |
+
shuffle_frame=False,
|
| 851 |
+
ddp_rank=rank,
|
| 852 |
+
ddp_world_size=world_size,
|
| 853 |
+
batch_size_hint=int(batch_size),
|
| 854 |
+
depth_max_m=float(max_depth_m),
|
| 855 |
+
far_depth_invalid_m=float(scanetpp_fisheye_far_depth_invalid_m),
|
| 856 |
+
seed=dataset_seed,
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
hm3d_sampler = None
|
| 860 |
+
if hm3d_ds is not None and _ddp_is_enabled():
|
| 861 |
+
hm3d_sampler = DistributedSampler(hm3d_ds, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
|
| 862 |
+
|
| 863 |
+
re10k_num_workers = int(num_workers)
|
| 864 |
+
if re10k_ds is not None and bool(re10k_pseudo_depth_autogen) and re10k_num_workers > 0:
|
| 865 |
+
re10k_num_workers = 0
|
| 866 |
+
if is_main:
|
| 867 |
+
LOGGER.warning(
|
| 868 |
+
"RE10K pseudo-depth auto-generate enabled: force re10k dataloader num_workers=%d (requested=%d).",
|
| 869 |
+
int(re10k_num_workers),
|
| 870 |
+
int(num_workers),
|
| 871 |
+
)
|
| 872 |
+
if re10k_ds is not None and batch_size > 1 and re10k_num_workers > 0:
|
| 873 |
+
re10k_num_workers = 0
|
| 874 |
+
if is_main:
|
| 875 |
+
LOGGER.warning(
|
| 876 |
+
"Dynamic-resolution RE10K batching requires ordered same-resolution samples: force re10k dataloader num_workers=%d (requested=%d).",
|
| 877 |
+
int(re10k_num_workers),
|
| 878 |
+
int(num_workers),
|
| 879 |
+
)
|
| 880 |
+
|
| 881 |
+
highres_pin_memory = os.environ.get("HIGHRES_TRAIN_PIN_MEMORY", "0").strip().lower() in {"1", "true", "yes", "on"}
|
| 882 |
+
standard_pin_memory = os.environ.get("TRAIN_PIN_MEMORY", "1").strip().lower() in {"1", "true", "yes", "on"}
|
| 883 |
+
try:
|
| 884 |
+
train_prefetch_factor = max(1, int(os.environ.get("TRAIN_PREFETCH_FACTOR", "1").strip()))
|
| 885 |
+
except Exception:
|
| 886 |
+
train_prefetch_factor = 1
|
| 887 |
+
def _loader_worker_kwargs(worker_count: int, *, pin_memory: bool) -> dict[str, Any]:
|
| 888 |
+
kwargs: dict[str, Any] = {
|
| 889 |
+
"num_workers": int(worker_count),
|
| 890 |
+
"pin_memory": bool(pin_memory),
|
| 891 |
+
}
|
| 892 |
+
if int(worker_count) > 0:
|
| 893 |
+
kwargs["prefetch_factor"] = int(train_prefetch_factor)
|
| 894 |
+
return kwargs
|
| 895 |
+
|
| 896 |
+
re10k_dl = None
|
| 897 |
+
if re10k_ds is not None:
|
| 898 |
+
re10k_dl = DataLoader(
|
| 899 |
+
re10k_ds,
|
| 900 |
+
batch_size=None,
|
| 901 |
+
**_loader_worker_kwargs(re10k_num_workers, pin_memory=standard_pin_memory),
|
| 902 |
+
collate_fn=re10k_passthrough,
|
| 903 |
+
)
|
| 904 |
+
|
| 905 |
+
hm3d_dl = None
|
| 906 |
+
if hm3d_ds is not None:
|
| 907 |
+
hm3d_dl = DataLoader(
|
| 908 |
+
hm3d_ds,
|
| 909 |
+
batch_size=batch_size,
|
| 910 |
+
shuffle=(hm3d_sampler is None),
|
| 911 |
+
sampler=hm3d_sampler,
|
| 912 |
+
**_loader_worker_kwargs(num_workers, pin_memory=highres_pin_memory),
|
| 913 |
+
collate_fn=panogs_collate,
|
| 914 |
+
)
|
| 915 |
+
|
| 916 |
+
sim_dl = None
|
| 917 |
+
if sim_ds is not None:
|
| 918 |
+
sim_dl = DataLoader(
|
| 919 |
+
sim_ds,
|
| 920 |
+
batch_size=batch_size,
|
| 921 |
+
**_loader_worker_kwargs(num_workers, pin_memory=highres_pin_memory),
|
| 922 |
+
collate_fn=panogs_collate,
|
| 923 |
+
)
|
| 924 |
+
|
| 925 |
+
wildrgbd_dl = None
|
| 926 |
+
if wildrgbd_ds is not None:
|
| 927 |
+
wildrgbd_dl = DataLoader(
|
| 928 |
+
wildrgbd_ds,
|
| 929 |
+
batch_size=batch_size,
|
| 930 |
+
**_loader_worker_kwargs(num_workers, pin_memory=standard_pin_memory),
|
| 931 |
+
collate_fn=wildrgbd_collate,
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
dl3dv_dl = None
|
| 935 |
+
if dl3dv_ds is not None:
|
| 936 |
+
dl3dv_dl = DataLoader(
|
| 937 |
+
dl3dv_ds,
|
| 938 |
+
batch_size=None,
|
| 939 |
+
**_loader_worker_kwargs(num_workers, pin_memory=standard_pin_memory),
|
| 940 |
+
collate_fn=re10k_passthrough,
|
| 941 |
+
)
|
| 942 |
+
|
| 943 |
+
scanetpp_dl = None
|
| 944 |
+
if scanetpp_ds is not None:
|
| 945 |
+
scanetpp_dl = DataLoader(
|
| 946 |
+
scanetpp_ds,
|
| 947 |
+
batch_size=None,
|
| 948 |
+
**_loader_worker_kwargs(num_workers, pin_memory=highres_pin_memory),
|
| 949 |
+
collate_fn=scannetpp_fisheye_passthrough,
|
| 950 |
+
)
|
| 951 |
+
|
| 952 |
+
candidate_datasets: dict[str, Any] = {}
|
| 953 |
+
candidate_dataloaders: dict[str, DataLoader] = {}
|
| 954 |
+
candidate_weights: dict[str, float] = {}
|
| 955 |
+
if re10k_ds is not None and re10k_dl is not None:
|
| 956 |
+
candidate_datasets["re10k"] = re10k_ds
|
| 957 |
+
candidate_dataloaders["re10k"] = re10k_dl
|
| 958 |
+
candidate_weights["re10k"] = float(dataset_weight_re10k)
|
| 959 |
+
if hm3d_ds is not None and hm3d_dl is not None:
|
| 960 |
+
candidate_datasets["hm3d"] = hm3d_ds
|
| 961 |
+
candidate_dataloaders["hm3d"] = hm3d_dl
|
| 962 |
+
candidate_weights["hm3d"] = float(dataset_weight_hm3d)
|
| 963 |
+
if sim_ds is not None and sim_dl is not None:
|
| 964 |
+
candidate_datasets["sim"] = sim_ds
|
| 965 |
+
candidate_dataloaders["sim"] = sim_dl
|
| 966 |
+
candidate_weights["sim"] = float(dataset_weight_sim)
|
| 967 |
+
if wildrgbd_ds is not None and wildrgbd_dl is not None:
|
| 968 |
+
candidate_datasets["wildrgbd"] = wildrgbd_ds
|
| 969 |
+
candidate_dataloaders["wildrgbd"] = wildrgbd_dl
|
| 970 |
+
candidate_weights["wildrgbd"] = float(dataset_weight_wildrgbd)
|
| 971 |
+
if dl3dv_ds is not None and dl3dv_dl is not None:
|
| 972 |
+
candidate_datasets["dl3dv"] = dl3dv_ds
|
| 973 |
+
candidate_dataloaders["dl3dv"] = dl3dv_dl
|
| 974 |
+
candidate_weights["dl3dv"] = float(dataset_weight_dl3dv)
|
| 975 |
+
if scanetpp_ds is not None and scanetpp_dl is not None:
|
| 976 |
+
candidate_datasets["scanetpp_fisheye"] = scanetpp_ds
|
| 977 |
+
candidate_dataloaders["scanetpp_fisheye"] = scanetpp_dl
|
| 978 |
+
candidate_weights["scanetpp_fisheye"] = float(dataset_weight_scanetpp)
|
| 979 |
+
|
| 980 |
+
datasets: dict[str, Any] = {}
|
| 981 |
+
dataloaders: dict[str, DataLoader] = {}
|
| 982 |
+
sampling: dict[str, float] = {}
|
| 983 |
+
for name, w in candidate_weights.items():
|
| 984 |
+
if float(w) > 0.0:
|
| 985 |
+
datasets[name] = candidate_datasets[name]
|
| 986 |
+
dataloaders[name] = candidate_dataloaders[name]
|
| 987 |
+
sampling[name] = float(w)
|
| 988 |
+
elif is_main:
|
| 989 |
+
LOGGER.warning("Skip dataset in mixed sampler: %s (weight=%.4f <= 0)", name, float(w))
|
| 990 |
+
|
| 991 |
+
if len(datasets) == 0:
|
| 992 |
+
raise ValueError("No dataset selected for mixed sampler (all dataset weights <= 0).")
|
| 993 |
+
for name, dataset in datasets.items():
|
| 994 |
+
_maybe_set_dataset_epoch(dataset, 0)
|
| 995 |
+
iterators = {name: LazyDataLoaderIterator(dl) for name, dl in dataloaders.items()}
|
| 996 |
+
sampler_seed = int(seed + rank) if seed is not None else int(12345 + rank)
|
| 997 |
+
sampler = MixedDatasetSampler(
|
| 998 |
+
datasets=datasets,
|
| 999 |
+
weights=sampling,
|
| 1000 |
+
iterators=iterators,
|
| 1001 |
+
seed=sampler_seed,
|
| 1002 |
+
)
|
| 1003 |
+
|
| 1004 |
+
config = UnisharpFeatureConfig(
|
| 1005 |
+
unik3d_backbone=unik3d_backbone,
|
| 1006 |
+
unik3d_resolution_level=int(unik3d_resolution_level),
|
| 1007 |
+
initializer_stride=int(initializer_stride),
|
| 1008 |
+
initializer_scale_factor=float(initializer_scale_factor),
|
| 1009 |
+
detach_init_layer0_distance=bool(detach_init_layer0_distance),
|
| 1010 |
+
delta_rho_limit=float(delta_rho_limit),
|
| 1011 |
+
)
|
| 1012 |
+
setattr(config, "max_distance_m", float(max_depth_m))
|
| 1013 |
+
|
| 1014 |
+
model = UnisharpFeatureModel(config).to(dev).train()
|
| 1015 |
+
|
| 1016 |
+
if _ddp_is_enabled():
|
| 1017 |
+
model = DDP(
|
| 1018 |
+
model,
|
| 1019 |
+
device_ids=[dev.index],
|
| 1020 |
+
output_device=dev.index,
|
| 1021 |
+
find_unused_parameters=True,
|
| 1022 |
+
gradient_as_bucket_view=True,
|
| 1023 |
+
)
|
| 1024 |
+
|
| 1025 |
+
raw_model = model.module if isinstance(model, DDP) else model
|
| 1026 |
+
base_params, unik3d_encoder_params, unik3d_decoder_params = _build_optimizer_param_groups(raw_model)
|
| 1027 |
+
unik3d_params = unik3d_encoder_params + unik3d_decoder_params
|
| 1028 |
+
trainable_params = base_params + unik3d_params
|
| 1029 |
+
if len(trainable_params) == 0:
|
| 1030 |
+
raise RuntimeError("No trainable parameters found.")
|
| 1031 |
+
if len(unik3d_params) == 0:
|
| 1032 |
+
raise RuntimeError(
|
| 1033 |
+
"No UniK3D parameters were collected for the default unfreeze training path. "
|
| 1034 |
+
"Please check parameter naming."
|
| 1035 |
+
)
|
| 1036 |
+
depth_head_params = [p for p in raw_model.second_layer_depth_head.parameters() if p.requires_grad]
|
| 1037 |
+
if len(depth_head_params) == 0:
|
| 1038 |
+
raise RuntimeError("Depth heads have no trainable parameters; depth branch would not train.")
|
| 1039 |
+
|
| 1040 |
+
opt_groups: list[dict[str, Any]] = [{"params": base_params, "lr": float(lr0), "group_name": "base"}]
|
| 1041 |
+
if len(unik3d_encoder_params) > 0:
|
| 1042 |
+
opt_groups.append(
|
| 1043 |
+
{
|
| 1044 |
+
"params": unik3d_encoder_params,
|
| 1045 |
+
"lr": float(unik3d_encoder_lr0),
|
| 1046 |
+
"group_name": "unik3d_encoder",
|
| 1047 |
+
}
|
| 1048 |
+
)
|
| 1049 |
+
if len(unik3d_decoder_params) > 0:
|
| 1050 |
+
opt_groups.append(
|
| 1051 |
+
{
|
| 1052 |
+
"params": unik3d_decoder_params,
|
| 1053 |
+
"lr": float(unik3d_lr0),
|
| 1054 |
+
"group_name": "unik3d_decoder",
|
| 1055 |
+
}
|
| 1056 |
+
)
|
| 1057 |
+
opt = torch.optim.Adam(opt_groups)
|
| 1058 |
+
if is_main:
|
| 1059 |
+
LOGGER.info(
|
| 1060 |
+
"Model ready: scratch heads, pretrained UniK3D, trainable_params=%d",
|
| 1061 |
+
_count_numel(trainable_params),
|
| 1062 |
+
)
|
| 1063 |
+
if dev.type == "cuda":
|
| 1064 |
+
scaler = torch.amp.GradScaler("cuda", enabled=True)
|
| 1065 |
+
else:
|
| 1066 |
+
scaler = torch.amp.GradScaler("cpu", enabled=False)
|
| 1067 |
+
|
| 1068 |
+
renderer = GSplatRenderer(
|
| 1069 |
+
color_space="sRGB",
|
| 1070 |
+
background_color="black",
|
| 1071 |
+
low_pass_filter_eps=float(render_low_pass_filter_eps),
|
| 1072 |
+
).to(dev)
|
| 1073 |
+
|
| 1074 |
+
loss_w = UnisharpLossWeights(
|
| 1075 |
+
lambda_color=float(lambda_color),
|
| 1076 |
+
lambda_alpha=float(lambda_alpha),
|
| 1077 |
+
lambda_percep=float(lambda_percep),
|
| 1078 |
+
lambda_depth=float(lambda_depth),
|
| 1079 |
+
lambda_tv=float(lambda_tv),
|
| 1080 |
+
lambda_grad=float(lambda_grad),
|
| 1081 |
+
lambda_grad_img=float(lambda_grad_img),
|
| 1082 |
+
lambda_edge_rgb=float(lambda_edge_rgb),
|
| 1083 |
+
lambda_delta=float(lambda_delta),
|
| 1084 |
+
lambda_delta_rho=float(lambda_delta_rho),
|
| 1085 |
+
lambda_splat=float(lambda_splat),
|
| 1086 |
+
lambda_edge_splat=float(lambda_edge_splat),
|
| 1087 |
+
lambda_grid=float(lambda_grid),
|
| 1088 |
+
)
|
| 1089 |
+
loss_fn = UnisharpLoss(
|
| 1090 |
+
weights=loss_w,
|
| 1091 |
+
delta_clip=float(delta_clip),
|
| 1092 |
+
raw_delta_clip=float(raw_delta_clip),
|
| 1093 |
+
raw_delta_rho_clip=float(raw_delta_rho_clip),
|
| 1094 |
+
alpha_tail_min=float(alpha_tail_min),
|
| 1095 |
+
alpha_tail_weight=float(alpha_tail_weight),
|
| 1096 |
+
splat_sigma_min=float(splat_sigma_min),
|
| 1097 |
+
splat_sigma_max=float(splat_sigma_max),
|
| 1098 |
+
edge_splat_sigma_max=float(edge_splat_sigma_max),
|
| 1099 |
+
depth_edge_log_threshold=float(depth_edge_log_threshold),
|
| 1100 |
+
depth_edge_dilate_px=int(depth_edge_dilate_px),
|
| 1101 |
+
).to(dev)
|
| 1102 |
+
loss_fn.SUPERVISION_MAX_DEPTH_M = float(max_depth_m)
|
| 1103 |
+
|
| 1104 |
+
if is_main:
|
| 1105 |
+
config_dict = {
|
| 1106 |
+
"max_depth_m": float(max_depth_m),
|
| 1107 |
+
"sim_far_depth_invalid_m": float(sim_far_depth_invalid_m),
|
| 1108 |
+
"sim_far_depth_invalid_max_frac": float(sim_far_depth_invalid_max_frac),
|
| 1109 |
+
"re10k_pseudo_far_depth_invalid_m": float(re10k_pseudo_far_depth_invalid_m),
|
| 1110 |
+
"scanetpp_fisheye_far_depth_invalid_m": float(scanetpp_fisheye_far_depth_invalid_m),
|
| 1111 |
+
"render_low_pass_filter_eps": float(render_low_pass_filter_eps),
|
| 1112 |
+
}
|
| 1113 |
+
(out_dir / "config.json").write_text(
|
| 1114 |
+
json.dumps(config_dict, ensure_ascii=False, indent=2, sort_keys=True) + "\n",
|
| 1115 |
+
encoding="utf-8",
|
| 1116 |
+
)
|
| 1117 |
+
|
| 1118 |
+
loss_csv = out_dir / "losses.csv"
|
| 1119 |
+
loss_csv_fields = [
|
| 1120 |
+
"loss",
|
| 1121 |
+
"src_loss",
|
| 1122 |
+
"tgt_loss",
|
| 1123 |
+
"dataset",
|
| 1124 |
+
]
|
| 1125 |
+
if is_main:
|
| 1126 |
+
with loss_csv.open("w", newline="") as f:
|
| 1127 |
+
csv.DictWriter(f, fieldnames=loss_csv_fields).writeheader()
|
| 1128 |
+
|
| 1129 |
+
if is_main:
|
| 1130 |
+
LOGGER.info("Training loop started.")
|
| 1131 |
+
|
| 1132 |
+
from unisharp.cli.unified_trainer import UnifiedTrainer
|
| 1133 |
+
|
| 1134 |
+
trainer = UnifiedTrainer(
|
| 1135 |
+
model=model,
|
| 1136 |
+
renderer=renderer,
|
| 1137 |
+
loss_fn=loss_fn,
|
| 1138 |
+
device=dev,
|
| 1139 |
+
max_depth_m=float(max_depth_m),
|
| 1140 |
+
sim_far_depth_invalid_m=float(sim_far_depth_invalid_m),
|
| 1141 |
+
re10k_pseudo_far_depth_invalid_m=float(re10k_pseudo_far_depth_invalid_m),
|
| 1142 |
+
scanetpp_fisheye_far_depth_invalid_m=float(scanetpp_fisheye_far_depth_invalid_m),
|
| 1143 |
+
aux_ray_loss_weight=float(lambda_aux_ray),
|
| 1144 |
+
aux_depth_scale_loss_weight=float(lambda_aux_depth_scale),
|
| 1145 |
+
aux_depth2_scale_loss_weight=float(lambda_aux_depth2_scale),
|
| 1146 |
+
target_mask_erode_px=int(target_mask_erode_px),
|
| 1147 |
+
)
|
| 1148 |
+
skip_forward_oom = _env_flag("TRAIN_SKIP_FORWARD_OOM", default=True)
|
| 1149 |
+
|
| 1150 |
+
dataset_epochs: dict[str, int] = {name: 0 for name in dataloaders.keys()}
|
| 1151 |
+
dataset_samplers: dict[str, DistributedSampler | None] = {"hm3d": hm3d_sampler}
|
| 1152 |
+
|
| 1153 |
+
for step in range(1, steps + 1):
|
| 1154 |
+
lr = warmup_cosine_lr(step, warmup, steps, lr0, lr1)
|
| 1155 |
+
lr_unik3d_encoder = warmup_cosine_lr(step, warmup, steps, unik3d_encoder_lr0, unik3d_encoder_lr1)
|
| 1156 |
+
lr_unik3d_decoder = warmup_cosine_lr(step, warmup, steps, unik3d_lr0, unik3d_lr1)
|
| 1157 |
+
for g in opt.param_groups:
|
| 1158 |
+
if g.get("group_name") == "unik3d_encoder":
|
| 1159 |
+
g["lr"] = lr_unik3d_encoder
|
| 1160 |
+
elif g.get("group_name") == "unik3d_decoder":
|
| 1161 |
+
g["lr"] = lr_unik3d_decoder
|
| 1162 |
+
else:
|
| 1163 |
+
g["lr"] = lr
|
| 1164 |
+
|
| 1165 |
+
if _ddp_is_enabled():
|
| 1166 |
+
batch = None
|
| 1167 |
+
available_dataset_names = list(dataloaders.keys())
|
| 1168 |
+
dataset_name = ""
|
| 1169 |
+
for _dataset_attempt in range(max(1, len(dataloaders))):
|
| 1170 |
+
dataset_name = _ddp_broadcast_str(
|
| 1171 |
+
sampler.choose_dataset_name(available_dataset_names) if is_main else "",
|
| 1172 |
+
is_main=is_main,
|
| 1173 |
+
)
|
| 1174 |
+
|
| 1175 |
+
local_exhausted = False
|
| 1176 |
+
try:
|
| 1177 |
+
batch = sampler.next_batch(dataset_name)
|
| 1178 |
+
except StopIteration:
|
| 1179 |
+
local_exhausted = True
|
| 1180 |
+
|
| 1181 |
+
exhausted_any = _ddp_any_bool(local_exhausted, device=dev)
|
| 1182 |
+
if exhausted_any:
|
| 1183 |
+
dataset_epochs[dataset_name] = dataset_epochs.get(dataset_name, 0) + 1
|
| 1184 |
+
ds_sampler = dataset_samplers.get(dataset_name, None)
|
| 1185 |
+
if ds_sampler is not None:
|
| 1186 |
+
ds_sampler.set_epoch(dataset_epochs[dataset_name])
|
| 1187 |
+
_maybe_set_dataset_epoch(datasets[dataset_name], dataset_epochs[dataset_name])
|
| 1188 |
+
iterators[dataset_name] = iter(dataloaders[dataset_name])
|
| 1189 |
+
sampler.iterators = iterators
|
| 1190 |
+
batch = None
|
| 1191 |
+
|
| 1192 |
+
local_exhausted = False
|
| 1193 |
+
try:
|
| 1194 |
+
batch = sampler.next_batch(dataset_name)
|
| 1195 |
+
except StopIteration:
|
| 1196 |
+
local_exhausted = True
|
| 1197 |
+
exhausted_any = _ddp_any_bool(local_exhausted, device=dev)
|
| 1198 |
+
|
| 1199 |
+
if not exhausted_any:
|
| 1200 |
+
break
|
| 1201 |
+
|
| 1202 |
+
batch = None
|
| 1203 |
+
available_dataset_names = [name for name in available_dataset_names if name != dataset_name]
|
| 1204 |
+
if len(available_dataset_names) == 0:
|
| 1205 |
+
break
|
| 1206 |
+
if batch is None:
|
| 1207 |
+
raise RuntimeError(f"Failed to fetch synchronized DDP batch for dataset={dataset_name}")
|
| 1208 |
+
else:
|
| 1209 |
+
try:
|
| 1210 |
+
dataset_name, batch = sampler.sample()
|
| 1211 |
+
except StopIteration as e:
|
| 1212 |
+
msg = str(e)
|
| 1213 |
+
exhausted_name = None
|
| 1214 |
+
if msg.startswith("Dataset ") and msg.endswith(" exhausted"):
|
| 1215 |
+
exhausted_name = msg[len("Dataset ") : -len(" exhausted")]
|
| 1216 |
+
if exhausted_name is None or exhausted_name not in dataloaders:
|
| 1217 |
+
raise
|
| 1218 |
+
dataset_epochs[exhausted_name] = dataset_epochs.get(exhausted_name, 0) + 1
|
| 1219 |
+
ds_sampler = dataset_samplers.get(exhausted_name, None)
|
| 1220 |
+
if ds_sampler is not None:
|
| 1221 |
+
ds_sampler.set_epoch(dataset_epochs[exhausted_name])
|
| 1222 |
+
_maybe_set_dataset_epoch(datasets[exhausted_name], dataset_epochs[exhausted_name])
|
| 1223 |
+
iterators[exhausted_name] = iter(dataloaders[exhausted_name])
|
| 1224 |
+
sampler.iterators = iterators
|
| 1225 |
+
dataset_name, batch = sampler.sample()
|
| 1226 |
+
|
| 1227 |
+
batch = _resize_training_batch_to_multiple(batch, int(train_resize_multiple))
|
| 1228 |
+
|
| 1229 |
+
opt.zero_grad(set_to_none=True)
|
| 1230 |
+
|
| 1231 |
+
autocast_enabled = dev.type == "cuda"
|
| 1232 |
+
if autocast_enabled and torch.cuda.is_bf16_supported():
|
| 1233 |
+
autocast_dtype = torch.bfloat16
|
| 1234 |
+
else:
|
| 1235 |
+
autocast_dtype = torch.float16 if autocast_enabled else torch.bfloat16
|
| 1236 |
+
|
| 1237 |
+
need_vis = bool(is_main and vis_every > 0 and (step % vis_every == 0))
|
| 1238 |
+
result: dict[str, Any] | None = None
|
| 1239 |
+
forward_oom_local = False
|
| 1240 |
+
forward_oom_error = ""
|
| 1241 |
+
try:
|
| 1242 |
+
with torch.autocast(device_type=dev.type, enabled=autocast_enabled, dtype=autocast_dtype):
|
| 1243 |
+
result = trainer.process_batch(
|
| 1244 |
+
batch,
|
| 1245 |
+
dataset_name,
|
| 1246 |
+
step,
|
| 1247 |
+
need_vis=need_vis,
|
| 1248 |
+
)
|
| 1249 |
+
except Exception as e:
|
| 1250 |
+
if skip_forward_oom and _is_oom_exception(e):
|
| 1251 |
+
forward_oom_local = True
|
| 1252 |
+
forward_oom_error = str(e)
|
| 1253 |
+
opt.zero_grad(set_to_none=True)
|
| 1254 |
+
if dev.type == "cuda":
|
| 1255 |
+
torch.cuda.empty_cache()
|
| 1256 |
+
else:
|
| 1257 |
+
raise
|
| 1258 |
+
|
| 1259 |
+
forward_oom_any = _ddp_any_bool(forward_oom_local, device=dev)
|
| 1260 |
+
if forward_oom_any:
|
| 1261 |
+
opt.zero_grad(set_to_none=True)
|
| 1262 |
+
if result is not None:
|
| 1263 |
+
del result
|
| 1264 |
+
result = None
|
| 1265 |
+
if dev.type == "cuda":
|
| 1266 |
+
torch.cuda.empty_cache()
|
| 1267 |
+
if is_main:
|
| 1268 |
+
LOGGER.error(
|
| 1269 |
+
"Skipping optimizer step=%d because forward OOM occurred on at least one rank | dataset=%s",
|
| 1270 |
+
int(step),
|
| 1271 |
+
str(dataset_name),
|
| 1272 |
+
)
|
| 1273 |
+
continue
|
| 1274 |
+
|
| 1275 |
+
if result is None:
|
| 1276 |
+
raise RuntimeError(f"Forward returned no result for dataset={dataset_name} step={step}")
|
| 1277 |
+
total_loss = result["total"]
|
| 1278 |
+
local_nonfinite_loss = not bool(torch.isfinite(total_loss.detach()).item())
|
| 1279 |
+
nonfinite_loss_any = _ddp_any_bool(local_nonfinite_loss, device=dev)
|
| 1280 |
+
if nonfinite_loss_any:
|
| 1281 |
+
opt.zero_grad(set_to_none=True)
|
| 1282 |
+
if is_main:
|
| 1283 |
+
LOGGER.error(
|
| 1284 |
+
"Skipping optimizer step=%d because loss is non-finite on at least one rank | dataset=%s",
|
| 1285 |
+
int(step),
|
| 1286 |
+
str(dataset_name),
|
| 1287 |
+
)
|
| 1288 |
+
continue
|
| 1289 |
+
|
| 1290 |
+
try:
|
| 1291 |
+
scaler.scale(total_loss).backward()
|
| 1292 |
+
except Exception as e:
|
| 1293 |
+
raise
|
| 1294 |
+
try:
|
| 1295 |
+
scaler.unscale_(opt)
|
| 1296 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=float(grad_clip_norm))
|
| 1297 |
+
except Exception as e:
|
| 1298 |
+
LOGGER.error("Gradient unscale/clip failed at step=%d: %s", int(step), str(e))
|
| 1299 |
+
raise
|
| 1300 |
+
grad_norm_value = float(grad_norm.detach().to(dtype=torch.float32).cpu().item()) if torch.is_tensor(grad_norm) else float(grad_norm)
|
| 1301 |
+
local_nonfinite_grad = not np.isfinite(grad_norm_value)
|
| 1302 |
+
nonfinite_grad_any = _ddp_any_bool(local_nonfinite_grad, device=dev)
|
| 1303 |
+
if nonfinite_grad_any:
|
| 1304 |
+
opt.zero_grad(set_to_none=True)
|
| 1305 |
+
scaler.update()
|
| 1306 |
+
if is_main:
|
| 1307 |
+
LOGGER.error(
|
| 1308 |
+
"Skipping optimizer step=%d because grad norm is non-finite on at least one rank | dataset=%s | local_grad_norm=%s",
|
| 1309 |
+
int(step),
|
| 1310 |
+
str(dataset_name),
|
| 1311 |
+
str(grad_norm_value),
|
| 1312 |
+
)
|
| 1313 |
+
continue
|
| 1314 |
+
local_huge_grad = bool(float(max_step_grad_norm) > 0.0 and grad_norm_value > float(max_step_grad_norm))
|
| 1315 |
+
huge_grad_any = _ddp_any_bool(local_huge_grad, device=dev)
|
| 1316 |
+
if huge_grad_any:
|
| 1317 |
+
opt.zero_grad(set_to_none=True)
|
| 1318 |
+
scaler.update()
|
| 1319 |
+
if is_main:
|
| 1320 |
+
LOGGER.error(
|
| 1321 |
+
"Skipping optimizer step=%d because grad norm exceeded max-step-grad-norm on at least one rank | dataset=%s | local_grad_norm=%.6g | threshold=%.6g",
|
| 1322 |
+
int(step),
|
| 1323 |
+
str(dataset_name),
|
| 1324 |
+
float(grad_norm_value),
|
| 1325 |
+
float(max_step_grad_norm),
|
| 1326 |
+
)
|
| 1327 |
+
continue
|
| 1328 |
+
scaler.step(opt)
|
| 1329 |
+
scaler.update()
|
| 1330 |
+
|
| 1331 |
+
if log_every > 0 and step % log_every == 0:
|
| 1332 |
+
loss_v = float(_ddp_mean(total_loss.detach()).item())
|
| 1333 |
+
src_v = float(_ddp_mean(result["src"].detach()).item())
|
| 1334 |
+
tgt_v = float(_ddp_mean(result["tgt"].detach()).item())
|
| 1335 |
+
row = {
|
| 1336 |
+
"loss": loss_v,
|
| 1337 |
+
"src_loss": src_v,
|
| 1338 |
+
"tgt_loss": tgt_v,
|
| 1339 |
+
"dataset": str(dataset_name),
|
| 1340 |
+
}
|
| 1341 |
+
if is_main:
|
| 1342 |
+
LOGGER.info(
|
| 1343 |
+
"step=%d dataset=%s loss=%.6f src_loss=%.6f tgt_loss=%.6f",
|
| 1344 |
+
step,
|
| 1345 |
+
dataset_name,
|
| 1346 |
+
loss_v,
|
| 1347 |
+
src_v,
|
| 1348 |
+
tgt_v,
|
| 1349 |
+
)
|
| 1350 |
+
row_csv = dict(row)
|
| 1351 |
+
for k in ("loss", "src_loss", "tgt_loss"):
|
| 1352 |
+
v = float(row_csv.get(k, float("nan")))
|
| 1353 |
+
row_csv[k] = "" if not np.isfinite(v) else f"{v:.4f}"
|
| 1354 |
+
with loss_csv.open("a", newline="") as f:
|
| 1355 |
+
csv.DictWriter(f, fieldnames=loss_csv_fields).writerow(row_csv)
|
| 1356 |
+
|
| 1357 |
+
if need_vis and result.get("vis_payload"):
|
| 1358 |
+
vis = result["vis_payload"]
|
| 1359 |
+
_save_train_vis(
|
| 1360 |
+
out_dir,
|
| 1361 |
+
step,
|
| 1362 |
+
vis["src_gt"],
|
| 1363 |
+
vis["src_pred"],
|
| 1364 |
+
vis["src_alpha"],
|
| 1365 |
+
vis["tgt_gt"],
|
| 1366 |
+
vis["tgt_pred"],
|
| 1367 |
+
vis["tgt_alpha"],
|
| 1368 |
+
src_gt_depth=vis.get("src_gt_depth"),
|
| 1369 |
+
tgt_gt_depth=vis.get("tgt_gt_depth"),
|
| 1370 |
+
src_pred_depth=vis.get("src_pred_depth"),
|
| 1371 |
+
tgt_pred_depth=vis.get("tgt_pred_depth"),
|
| 1372 |
+
src_unik3d_depth=vis.get("src_unik3d_depth"),
|
| 1373 |
+
tgt_unik3d_depth=vis.get("tgt_unik3d_depth"),
|
| 1374 |
+
dataset_name=vis.get("dataset_name"),
|
| 1375 |
+
scene=vis.get("scene"),
|
| 1376 |
+
src_idx=vis.get("src_idx"),
|
| 1377 |
+
tgt_idx=vis.get("tgt_idx"),
|
| 1378 |
+
src_pose_w2c=vis.get("src_pose_w2c"),
|
| 1379 |
+
tgt_pose_w2c=vis.get("tgt_pose_w2c"),
|
| 1380 |
+
src_metric_mask=vis.get("src_metric_mask"),
|
| 1381 |
+
tgt_metric_mask=vis.get("tgt_metric_mask"),
|
| 1382 |
+
src_cube_gt_u8=vis.get("src_cube_gt_u8"),
|
| 1383 |
+
src_cube_pred_linear=vis.get("src_cube_pred_linear"),
|
| 1384 |
+
src_cube_alpha=vis.get("src_cube_alpha"),
|
| 1385 |
+
tgt_cube_gt_u8=vis.get("tgt_cube_gt_u8"),
|
| 1386 |
+
tgt_cube_pred_linear=vis.get("tgt_cube_pred_linear"),
|
| 1387 |
+
tgt_cube_alpha=vis.get("tgt_cube_alpha"),
|
| 1388 |
+
)
|
| 1389 |
+
|
| 1390 |
+
if need_vis:
|
| 1391 |
+
if "vis" in locals():
|
| 1392 |
+
del vis
|
| 1393 |
+
if dev.type == "cuda":
|
| 1394 |
+
torch.cuda.empty_cache()
|
| 1395 |
+
del result
|
| 1396 |
+
del total_loss
|
| 1397 |
+
batch = None
|
| 1398 |
+
|
| 1399 |
+
if is_main and (save_every > 0) and (step % save_every == 0):
|
| 1400 |
+
path = out_dir / f"step_{step:07d}.pt"
|
| 1401 |
+
raw_model.save_checkpoint(str(path), step, opt)
|
| 1402 |
+
LOGGER.info("💾 Saved checkpoint: %s", str(path))
|
| 1403 |
+
|
| 1404 |
+
|
| 1405 |
+
if _ddp_is_enabled():
|
| 1406 |
+
_ddp_barrier(dev)
|
| 1407 |
+
dist.destroy_process_group()
|
| 1408 |
+
|
| 1409 |
+
if is_main:
|
| 1410 |
+
LOGGER.info("✅ Training completed!")
|
unisharp/cli/train_utils.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def quat_mul_wxyz(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Return the Hamilton product ``q1 * q2`` of quaternions in (w, x, y, z) order.

    The four components are read from, and written back to, the last
    dimension of the inputs. All arithmetic is element-wise, so any leading
    dimensions follow the usual broadcasting rules.

    Args:
        q1: Left operand; its last dimension must unpack into four components.
        q2: Right operand; same layout as ``q1``.

    Returns:
        The product quaternion, components stacked along a new last dimension.
    """
    aw, ax, ay, az = torch.unbind(q1, dim=-1)
    bw, bx, by, bz = torch.unbind(q2, dim=-1)

    components = (
        aw * bw - ax * bx - ay * by - az * bz,  # scalar (w) part
        aw * bx + ax * bw + ay * bz - az * by,  # x part
        aw * by - ax * bz + ay * bw + az * bx,  # y part
        aw * bz + ax * by - ay * bx + az * bw,  # z part
    )
    return torch.stack(components, dim=-1)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def rotmat_to_quat_wxyz(Rm: torch.Tensor) -> torch.Tensor:
    """Convert rotation matrices to unit quaternions in (w, x, y, z) order.

    The conversion branches on the trace and on the largest diagonal entry so
    that the square root is always taken of the dominant term.

    Args:
        Rm: Either a single matrix indexed as ``Rm[row, col]`` (only rows and
            columns 0..2 are read), or a stack of such matrices with any
            number of leading batch dimensions.

    Returns:
        A tensor of shape ``(4,)`` for a single matrix, or ``(*batch, 4)`` for
        a stack. Each quaternion is divided by its norm (clamped to at least
        ``1e-8``) before being returned.
    """
    if Rm.dim() > 2:
        # Batched input: convert each matrix with the scalar path below so the
        # per-matrix result is exactly what a single-matrix call would give.
        lead = tuple(Rm.shape[:-2])
        if 0 in lead:
            return Rm.new_zeros(lead + (4,))
        flat = Rm.reshape(-1, Rm.shape[-2], Rm.shape[-1])
        quats = torch.stack([rotmat_to_quat_wxyz(m) for m in flat], dim=0)
        return quats.reshape(lead + (4,))

    m00, m01, m02 = Rm[0, 0], Rm[0, 1], Rm[0, 2]
    m10, m11, m12 = Rm[1, 0], Rm[1, 1], Rm[1, 2]
    m20, m21, m22 = Rm[2, 0], Rm[2, 1], Rm[2, 2]
    tr = m00 + m11 + m22
    if tr > 0.0:
        # Positive trace: w is the dominant component.
        s = torch.sqrt(tr + 1.0) * 2.0
        w = 0.25 * s
        x = (m21 - m12) / s
        y = (m02 - m20) / s
        z = (m10 - m01) / s
    elif (m00 > m11) and (m00 > m22):
        # m00 is the largest diagonal entry: x is dominant.
        s = torch.sqrt(1.0 + m00 - m11 - m22) * 2.0
        w = (m21 - m12) / s
        x = 0.25 * s
        y = (m01 + m10) / s
        z = (m02 + m20) / s
    elif m11 > m22:
        # m11 is the largest diagonal entry: y is dominant.
        s = torch.sqrt(1.0 + m11 - m00 - m22) * 2.0
        w = (m02 - m20) / s
        x = (m01 + m10) / s
        y = 0.25 * s
        z = (m12 + m21) / s
    else:
        # m22 is the largest diagonal entry: z is dominant.
        s = torch.sqrt(1.0 + m22 - m00 - m11) * 2.0
        w = (m10 - m01) / s
        x = (m02 + m20) / s
        y = (m12 + m21) / s
        z = 0.25 * s
    q = torch.stack([w, x, y, z])
    # Guard the normalisation against a zero-length result.
    return q / q.norm().clamp(min=1e-8)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def to_k4(k3: torch.Tensor) -> torch.Tensor:
    """Embed a batch of 3x3 matrices into the top-left corner of 4x4 matrices.

    Args:
        k3: Tensor whose first dimension is the batch size; it is written into
            the ``[:, :3, :3]`` slice of the result.

    Returns:
        A new ``(batch, 4, 4)`` tensor with the dtype and device of ``k3``.
        The last row and last column are zero, apart from a 1 in the
        bottom-right corner.
    """
    batch = k3.shape[0]
    k4 = torch.zeros((batch, 4, 4), dtype=k3.dtype, device=k3.device)
    k4[:, :3, :3] = k3
    # Homogeneous corner element; every other padded entry stays zero.
    k4[:, 3, 3] = 1
    return k4
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def warmup_cosine_lr(step: int, warmup: int, total: int, lr0: float, lr1: float) -> float:
    """Return the learning rate for ``step`` under linear warmup + cosine decay.

    For ``step <= warmup`` the rate ramps linearly from 0 to ``lr0``. After
    that it follows a half cosine from ``lr0`` down to ``lr1``, reaching
    ``lr1`` at ``step == total``.

    Args:
        step: Current optimisation step.
        warmup: Number of warmup steps; values below 1 are treated as 1 in the
            warmup division.
        total: Step at which the cosine decay reaches ``lr1``.
        lr0: Peak learning rate reached at the end of warmup.
        lr1: Final learning rate at the end of the schedule.

    Returns:
        The learning rate as a built-in ``float``.
    """
    if step <= warmup:
        return float(lr0 * float(step) / float(max(1, warmup)))
    progress = (step - warmup) / float(max(1, total - warmup))
    # Clamp so steps past ``total`` stay at ``lr1`` instead of following the
    # cosine back up towards ``lr0``.
    progress = min(1.0, progress)
    cosine = 0.5 * (1 + np.cos(np.pi * progress))
    # np.cos yields a NumPy scalar; convert so callers get a plain float.
    return float(lr1 + (lr0 - lr1) * cosine)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@torch.no_grad()
def compute_frustum_mask(
    depth: torch.Tensor,
    tgt_w2c: torch.Tensor,
    src_w2c: torch.Tensor,
    src_k3: torch.Tensor,
    tgt_k3: torch.Tensor,
    img_h: int,
    img_w: int,
    source_img_h: int | None = None,
    source_img_w: int | None = None,
    depth_min: float = 0.05,
    margin: float = 0.05,
) -> torch.Tensor:
    """Mark target pixels whose back-projected point lands inside the source view.

    Each target pixel is lifted to 3D using ``depth`` and the target
    intrinsics, moved through the inverse of ``tgt_w2c`` and then through
    ``src_w2c``, and projected with the source intrinsics. Only index 0 of the
    leading dimension of every tensor argument is read. Runs under
    ``torch.no_grad()``; all maths is done in float32.

    Args:
        depth: Depth tensor; ``depth[0, 0]`` is used as the per-pixel depth.
        tgt_w2c: Target pose; ``tgt_w2c[0]`` is inverted and its ``[:3, :3]``
            and ``[:3, 3]`` parts are applied as rotation and translation.
        src_w2c: Source pose; ``src_w2c[0]`` is applied the same way, uninverted.
        src_k3: Source intrinsics; entries (0,0), (1,1), (0,2), (1,2) of
            ``src_k3[0]`` are read as fx, fy, cx, cy.
        tgt_k3: Target intrinsics, read like ``src_k3``.
        img_h: Height of the target pixel grid.
        img_w: Width of the target pixel grid.
        source_img_h: Source image height; defaults to ``img_h``.
        source_img_w: Source image width; defaults to ``img_w``.
        depth_min: Pixels with depth not greater than this are masked out.
        margin: Extra tolerance added to the normalised ``[-1, 1]`` bounds.

    Returns:
        A float mask of shape ``(1, 1, img_h, img_w)`` holding 1.0 where the
        pixel has valid depth and projects inside the source frustum, else 0.0.
    """
    device = depth.device
    work_dtype = torch.float32
    src_rows = int(source_img_h) if source_img_h is not None else int(img_h)
    src_cols = int(source_img_w) if source_img_w is not None else int(img_w)

    depth_map = depth[0, 0].to(work_dtype)
    depth_ok = depth_map > depth_min

    # Pixel-coordinate grids for the target image (row index, column index).
    row_idx, col_idx = torch.meshgrid(
        torch.arange(img_h, device=device, dtype=work_dtype),
        torch.arange(img_w, device=device, dtype=work_dtype),
        indexing="ij",
    )

    # Back-project target pixels into the target camera frame.
    tgt_fx = tgt_k3[0, 0, 0].to(work_dtype)
    tgt_fy = tgt_k3[0, 1, 1].to(work_dtype)
    tgt_cx = tgt_k3[0, 0, 2].to(work_dtype)
    tgt_cy = tgt_k3[0, 1, 2].to(work_dtype)
    cam_x = (col_idx - tgt_cx) / tgt_fx * depth_map
    cam_y = (row_idx - tgt_cy) / tgt_fy * depth_map
    pts_tgt = torch.stack([cam_x, cam_y, depth_map], dim=-1).reshape(-1, 3)

    # Target camera frame -> world frame -> source camera frame.
    tgt_c2w = torch.linalg.inv(tgt_w2c[0].to(work_dtype))
    pts_world = pts_tgt @ tgt_c2w[:3, :3].T + tgt_c2w[:3, 3][None, :]
    src_pose = src_w2c[0].to(work_dtype)
    pts_src = pts_world @ src_pose[:3, :3].T + src_pose[:3, 3][None, :]

    # Project into the source image; clamp z only for the division so the
    # sign test below still sees the raw depth.
    z_raw = pts_src[:, 2]
    z_safe = z_raw.clamp(min=1e-4)
    src_fx = src_k3[0, 0, 0].to(work_dtype)
    src_fy = src_k3[0, 1, 1].to(work_dtype)
    src_cx = src_k3[0, 0, 2].to(work_dtype)
    src_cy = src_k3[0, 1, 2].to(work_dtype)
    src_u = pts_src[:, 0] / z_safe * src_fx + src_cx
    src_v = pts_src[:, 1] / z_safe * src_fy + src_cy

    # Normalise pixel coordinates so the image spans [-1, 1] on both axes.
    centre_u = (src_cols - 1) * 0.5
    centre_v = (src_rows - 1) * 0.5
    ndc_x = (src_u - centre_u) / centre_u
    ndc_y = (src_v - centre_v) / centre_v

    inside_x = ndc_x.abs() <= 1.0 + margin
    inside_y = ndc_y.abs() <= 1.0 + margin
    in_front = z_raw > 0
    inside = inside_x & inside_y & in_front

    visible = (inside.reshape(img_h, img_w) & depth_ok).float()
    return visible[None, None]
|
unisharp/cli/unified_trainer.py
ADDED
|
@@ -0,0 +1,1966 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
import os
|
| 5 |
+
from typing import Any, Callable
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from unisharp.utils.gsplat import GSplatRenderer
|
| 12 |
+
from unisharp.losses import UnisharpLoss
|
| 13 |
+
from unisharp.utils.camera_utils import (
|
| 14 |
+
transform_gaussians_to_world,
|
| 15 |
+
to_k4,
|
| 16 |
+
compute_frustum_mask,
|
| 17 |
+
)
|
| 18 |
+
from unisharp.utils.fisheye_geer import (
|
| 19 |
+
compute_fisheye624_frustum_mask,
|
| 20 |
+
render_gaussians_fisheye624,
|
| 21 |
+
)
|
| 22 |
+
from unisharp.utils.camera_projection import cubemap_face_cameras, build_extrinsics_w2c, view_frustum_mask_cubemap_union
|
| 23 |
+
from unisharp.utils.pano import Cube2Equirec, get_pinhole_intrinsics_4x4
|
| 24 |
+
from unisharp import DEFAULT_MAX_DEPTH_M
|
| 25 |
+
from unisharp.utils.pixel_convention import integer_pixel_center_grid
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class _ModeStrategy:
    """Per-batch bundle of data and callbacks consumed by ``UnifiedTrainer._run_strategy_loop``.

    ``UnifiedTrainer.process_batch`` obtains one of these from a
    ``_build_*_strategy`` method and hands it to the loop, which only touches
    the fields declared here.
    """

    # Number of samples in the batch; the loop iterates ``range(int(batch_size))``.
    batch_size: int
    # Batched Gaussian container. The loop slices ``mean_vectors``,
    # ``singular_values``, ``quaternions``, ``colors`` and ``opacities`` with
    # ``[b : b + 1]`` and rebuilds a single-sample container via ``type(gaussians)(...)``.
    gaussians: Any
    # Called as ``make_world_gaussians(b, g_b)`` with the sample index and the
    # single-sample Gaussians; its result is forwarded to ``make_sample``.
    make_world_gaussians: Callable[[int, Any], Any]
    # Called as ``make_sample(b, g_world, want_vis)``; returns the dict of
    # predictions / targets / optional loss-term specs read by the loop.
    make_sample: Callable[[int, Any, bool], dict[str, Any]]
    # When True, visualisation is requested for every sample instead of only
    # for sample 0 (still gated by the loop's ``need_vis`` flag).
    collect_all_vis: bool = False
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class UnifiedTrainer:
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
model: nn.Module,
|
| 43 |
+
renderer: GSplatRenderer,
|
| 44 |
+
loss_fn: UnisharpLoss,
|
| 45 |
+
device: torch.device,
|
| 46 |
+
enable_tgt_unik3d_vis: bool = True,
|
| 47 |
+
max_depth_m: float = DEFAULT_MAX_DEPTH_M,
|
| 48 |
+
sim_far_depth_invalid_m: float = 30.0,
|
| 49 |
+
re10k_pseudo_far_depth_invalid_m: float = 30.0,
|
| 50 |
+
scanetpp_fisheye_far_depth_invalid_m: float = 30.0,
|
| 51 |
+
aux_ray_loss_weight: float = 3.0,
|
| 52 |
+
aux_depth_scale_loss_weight: float = 3.0,
|
| 53 |
+
aux_depth2_scale_loss_weight: float = 1.0,
|
| 54 |
+
target_mask_erode_px: int = 0,
|
| 55 |
+
):
|
| 56 |
+
self.model = model
|
| 57 |
+
self.renderer = renderer
|
| 58 |
+
self.loss_fn = loss_fn
|
| 59 |
+
self.device = device
|
| 60 |
+
self.enable_tgt_unik3d_vis = bool(enable_tgt_unik3d_vis)
|
| 61 |
+
self.max_depth_m = float(max_depth_m)
|
| 62 |
+
self.sim_far_depth_invalid_m = float(sim_far_depth_invalid_m)
|
| 63 |
+
self.re10k_pseudo_far_depth_invalid_m = float(re10k_pseudo_far_depth_invalid_m)
|
| 64 |
+
self.scanetpp_fisheye_far_depth_invalid_m = float(scanetpp_fisheye_far_depth_invalid_m)
|
| 65 |
+
self.aux_ray_loss_weight = float(aux_ray_loss_weight)
|
| 66 |
+
self.aux_depth_scale_loss_weight = float(aux_depth_scale_loss_weight)
|
| 67 |
+
self.aux_depth2_scale_loss_weight = float(aux_depth2_scale_loss_weight)
|
| 68 |
+
self.target_mask_erode_px = max(int(target_mask_erode_px), 0)
|
| 69 |
+
|
| 70 |
+
@staticmethod
|
| 71 |
+
def _erode_supervision_mask(mask: torch.Tensor, radius_px: int, *, circular_h: bool = False) -> torch.Tensor:
|
| 72 |
+
radius = max(int(radius_px), 0)
|
| 73 |
+
if radius <= 0:
|
| 74 |
+
return mask
|
| 75 |
+
if not torch.is_tensor(mask):
|
| 76 |
+
return mask
|
| 77 |
+
m = mask.to(dtype=torch.float32).clamp(0.0, 1.0)
|
| 78 |
+
if m.ndim == 3:
|
| 79 |
+
m = m.unsqueeze(1)
|
| 80 |
+
invalid = 1.0 - m
|
| 81 |
+
kernel = 2 * radius + 1
|
| 82 |
+
if bool(circular_h):
|
| 83 |
+
invalid = F.pad(invalid, (radius, radius, 0, 0), mode="circular")
|
| 84 |
+
invalid = F.pad(invalid, (0, 0, radius, radius), mode="constant", value=0.0)
|
| 85 |
+
dilated_invalid = F.max_pool2d(invalid, kernel_size=kernel, stride=1)
|
| 86 |
+
else:
|
| 87 |
+
dilated_invalid = F.max_pool2d(invalid, kernel_size=kernel, stride=1, padding=radius)
|
| 88 |
+
return (m * (1.0 - dilated_invalid)).to(device=mask.device, dtype=mask.dtype)
|
| 89 |
+
|
| 90 |
+
def _aux_ray_losses(
|
| 91 |
+
self,
|
| 92 |
+
*,
|
| 93 |
+
pred_rays: torch.Tensor | None,
|
| 94 |
+
gt_rays: torch.Tensor | None,
|
| 95 |
+
mask: torch.Tensor | None,
|
| 96 |
+
pred_distance: torch.Tensor | None = None,
|
| 97 |
+
pred_distance2: torch.Tensor | None = None,
|
| 98 |
+
gt_distance: torch.Tensor | None = None,
|
| 99 |
+
gt_distance2: torch.Tensor | None = None,
|
| 100 |
+
depth_mask: torch.Tensor | None = None,
|
| 101 |
+
depth_mask2: torch.Tensor | None = None,
|
| 102 |
+
) -> dict[str, torch.Tensor]:
|
| 103 |
+
out: dict[str, torch.Tensor] = {}
|
| 104 |
+
if torch.is_tensor(pred_rays) and torch.is_tensor(gt_rays) and self.aux_ray_loss_weight > 0.0:
|
| 105 |
+
out["unik3d_ray"] = self.aux_ray_loss_weight * self._unik3d_polar_ray_loss(
|
| 106 |
+
pred_rays,
|
| 107 |
+
gt_rays,
|
| 108 |
+
mask,
|
| 109 |
+
)
|
| 110 |
+
if torch.is_tensor(pred_distance) and torch.is_tensor(gt_distance):
|
| 111 |
+
out["unik3d_depth_scale"] = self.aux_depth_scale_loss_weight * self._unik3d_scale_depth_loss(
|
| 112 |
+
pred_distance,
|
| 113 |
+
gt_distance,
|
| 114 |
+
depth_mask if torch.is_tensor(depth_mask) else mask,
|
| 115 |
+
)
|
| 116 |
+
depth2_target = gt_distance2 if torch.is_tensor(gt_distance2) else gt_distance
|
| 117 |
+
if torch.is_tensor(pred_distance2) and torch.is_tensor(depth2_target):
|
| 118 |
+
depth2_mask = depth_mask2 if torch.is_tensor(depth_mask2) else depth_mask
|
| 119 |
+
out["unik3d_depth2_scale"] = self.aux_depth2_scale_loss_weight * self._unik3d_scale_depth_loss(
|
| 120 |
+
pred_distance2,
|
| 121 |
+
depth2_target,
|
| 122 |
+
depth2_mask if torch.is_tensor(depth2_mask) else mask,
|
| 123 |
+
)
|
| 124 |
+
return out
|
| 125 |
+
|
| 126 |
+
# Class-level depth cap, initialised from the package-wide DEFAULT_MAX_DEPTH_M.
# NOTE(review): not referenced by the methods visible in this part of the file;
# the per-instance ``max_depth_m`` is what the visible clamping helpers use.
DEPTH_SUPERVISION_MAX_M: float = DEFAULT_MAX_DEPTH_M
|
| 127 |
+
|
| 128 |
+
def _distance_init_cap_for_dataset(self, dataset_name: str) -> float | None:
|
| 129 |
+
name = str(dataset_name).lower()
|
| 130 |
+
if name == "re10k" and self.re10k_pseudo_far_depth_invalid_m > 0.0:
|
| 131 |
+
return self.re10k_pseudo_far_depth_invalid_m
|
| 132 |
+
if name == "sim" and self.sim_far_depth_invalid_m > 0.0:
|
| 133 |
+
return self.sim_far_depth_invalid_m
|
| 134 |
+
if name in {"scanetpp_fisheye", "scannetpp_fisheye"} and self.scanetpp_fisheye_far_depth_invalid_m > 0.0:
|
| 135 |
+
return self.scanetpp_fisheye_far_depth_invalid_m
|
| 136 |
+
return None
|
| 137 |
+
|
| 138 |
+
@staticmethod
|
| 139 |
+
def _unik3d_polar_ray_loss(
|
| 140 |
+
pred_rays: torch.Tensor | None,
|
| 141 |
+
gt_rays: torch.Tensor | None,
|
| 142 |
+
mask: torch.Tensor | None,
|
| 143 |
+
) -> torch.Tensor:
|
| 144 |
+
if not torch.is_tensor(pred_rays) or not torch.is_tensor(gt_rays):
|
| 145 |
+
device = pred_rays.device if torch.is_tensor(pred_rays) else torch.device("cpu")
|
| 146 |
+
return torch.zeros((), device=device, dtype=torch.float32)
|
| 147 |
+
pred = pred_rays.to(dtype=torch.float32)
|
| 148 |
+
gt = gt_rays.to(device=pred.device, dtype=torch.float32)
|
| 149 |
+
if pred.ndim == 3:
|
| 150 |
+
pred = pred.unsqueeze(0)
|
| 151 |
+
if gt.ndim == 3:
|
| 152 |
+
gt = gt.unsqueeze(0)
|
| 153 |
+
if tuple(pred.shape) != tuple(gt.shape):
|
| 154 |
+
gt = F.interpolate(gt, size=pred.shape[-2:], mode="bilinear", align_corners=False)
|
| 155 |
+
gt = gt / torch.norm(gt, dim=1, keepdim=True).clamp(min=1e-5)
|
| 156 |
+
pred = pred / torch.norm(pred, dim=1, keepdim=True).clamp(min=1e-5)
|
| 157 |
+
gt = gt / torch.norm(gt, dim=1, keepdim=True).clamp(min=1e-5)
|
| 158 |
+
|
| 159 |
+
px, py, pz = pred.unbind(dim=1)
|
| 160 |
+
gx, gy, gz = gt.unbind(dim=1)
|
| 161 |
+
polar_pred = torch.acos(pz.clamp(min=-0.99999, max=0.99999))
|
| 162 |
+
polar_gt = torch.acos(gz.clamp(min=-0.99999, max=0.99999))
|
| 163 |
+
az_pred = torch.atan2(py, px.abs().clamp(min=1e-5) * (2.0 * (px > 0).to(px.dtype) - 1.0))
|
| 164 |
+
az_gt = torch.atan2(gy, gx.abs().clamp(min=1e-5) * (2.0 * (gx > 0).to(gx.dtype) - 1.0))
|
| 165 |
+
polar_error = (polar_pred - polar_gt).abs()
|
| 166 |
+
az_delta = az_pred - az_gt
|
| 167 |
+
az_error = torch.atan2(torch.sin(az_delta), torch.cos(az_delta)).abs()
|
| 168 |
+
quantile_weight = torch.ones_like(polar_error)
|
| 169 |
+
quantile_weight[(polar_gt > polar_pred) & (polar_gt > torch.pi / 2)] = 1.4
|
| 170 |
+
quantile_weight[(polar_gt <= polar_pred) & (polar_gt > torch.pi / 2)] = 0.6
|
| 171 |
+
|
| 172 |
+
if torch.is_tensor(mask):
|
| 173 |
+
m = mask.to(device=pred.device, dtype=torch.float32)
|
| 174 |
+
if m.ndim == 3:
|
| 175 |
+
m = m.unsqueeze(1)
|
| 176 |
+
if tuple(m.shape[-2:]) != tuple(pred.shape[-2:]):
|
| 177 |
+
m = F.interpolate(m, size=pred.shape[-2:], mode="nearest")
|
| 178 |
+
m = m[:, 0].clamp(0.0, 1.0)
|
| 179 |
+
else:
|
| 180 |
+
m = torch.ones_like(polar_error)
|
| 181 |
+
denom = m.sum(dim=(-1, -2), keepdim=False).clamp(min=1.0)
|
| 182 |
+
mean_polar = (polar_error * quantile_weight * m).sum(dim=(-1, -2)) / denom
|
| 183 |
+
mean_azimuth = (az_error * m).sum(dim=(-1, -2)) / denom
|
| 184 |
+
mean_error = (3.0 * mean_polar + mean_azimuth) / 4.0
|
| 185 |
+
return torch.sqrt(mean_error + 1e-4).mean()
|
| 186 |
+
|
| 187 |
+
@staticmethod
|
| 188 |
+
def _unik3d_scale_depth_loss(
|
| 189 |
+
pred_distance: torch.Tensor,
|
| 190 |
+
gt_distance: torch.Tensor,
|
| 191 |
+
mask: torch.Tensor | None,
|
| 192 |
+
) -> torch.Tensor:
|
| 193 |
+
pred = UnifiedTrainer._as_b1hw_depth(pred_distance).to(dtype=torch.float32)
|
| 194 |
+
gt = UnifiedTrainer._as_b1hw_depth(gt_distance).to(device=pred.device, dtype=torch.float32)
|
| 195 |
+
if tuple(gt.shape[-2:]) != tuple(pred.shape[-2:]):
|
| 196 |
+
gt = F.interpolate(gt, size=pred.shape[-2:], mode="nearest")
|
| 197 |
+
valid = torch.isfinite(pred) & torch.isfinite(gt) & (pred > 0.0) & (gt > 0.0)
|
| 198 |
+
if torch.is_tensor(mask):
|
| 199 |
+
m = mask.to(device=pred.device)
|
| 200 |
+
if m.ndim == 3:
|
| 201 |
+
m = m.unsqueeze(1)
|
| 202 |
+
if tuple(m.shape[-2:]) != tuple(pred.shape[-2:]):
|
| 203 |
+
m = F.interpolate(m.to(dtype=torch.float32), size=pred.shape[-2:], mode="nearest")
|
| 204 |
+
valid = valid & (m[:, :1] > 0.5)
|
| 205 |
+
err = (gt.clamp(min=1e-4).log() - pred.clamp(min=1e-4).log()).abs()
|
| 206 |
+
err = torch.where(valid, err, torch.zeros_like(err))
|
| 207 |
+
denom = valid.to(dtype=err.dtype).sum(dim=(-2, -1)).clamp(min=1.0)
|
| 208 |
+
per_image = err.sum(dim=(-2, -1)) / denom
|
| 209 |
+
return torch.sqrt(per_image.clamp(min=0.0)).mean()
|
| 210 |
+
|
| 211 |
+
def _base_model(self) -> nn.Module:
|
| 212 |
+
return self.model.module if hasattr(self.model, "module") else self.model
|
| 213 |
+
|
| 214 |
+
def process_batch(
|
| 215 |
+
self,
|
| 216 |
+
batch: Any,
|
| 217 |
+
dataset_name: str,
|
| 218 |
+
step: int,
|
| 219 |
+
need_vis: bool = False,
|
| 220 |
+
) -> dict[str, Any]:
|
| 221 |
+
if hasattr(batch, "src_rgb_u8") and hasattr(batch, "src_intrinsics"):
|
| 222 |
+
strategy = self._build_pinhole_strategy(
|
| 223 |
+
batch,
|
| 224 |
+
step,
|
| 225 |
+
need_vis=need_vis,
|
| 226 |
+
dataset_name=str(dataset_name),
|
| 227 |
+
)
|
| 228 |
+
elif hasattr(batch, "src_rgb_u8") and hasattr(batch, "src_camera_params"):
|
| 229 |
+
strategy = self._build_fisheye_strategy(
|
| 230 |
+
batch,
|
| 231 |
+
step,
|
| 232 |
+
need_vis=need_vis,
|
| 233 |
+
dataset_name=str(dataset_name),
|
| 234 |
+
)
|
| 235 |
+
elif hasattr(batch, "src_erp_rgb_u8") and hasattr(batch, "src_cube_depth_m"):
|
| 236 |
+
strategy = self._build_spherical_strategy(
|
| 237 |
+
batch,
|
| 238 |
+
step,
|
| 239 |
+
need_vis=need_vis,
|
| 240 |
+
dataset_name=str(dataset_name),
|
| 241 |
+
)
|
| 242 |
+
else:
|
| 243 |
+
raise ValueError(f"Unknown batch schema for dataset={dataset_name}")
|
| 244 |
+
return self._run_strategy_loop(
|
| 245 |
+
strategy,
|
| 246 |
+
need_vis=need_vis,
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
def _run_strategy_loop(
    self,
    strategy: _ModeStrategy,
    need_vis: bool = False,
) -> dict[str, Any]:
    """Evaluate source/target/auxiliary losses sample by sample and average them.

    For every sample index the batched Gaussians are sliced to a single
    sample, converted with ``strategy.make_world_gaussians`` and turned into
    a sample dict by ``strategy.make_sample``. Source and target losses come
    either from a list of keyword specs (``*_loss_terms``) or from a single
    ``_compute_view_loss`` call assembled from the sample's entries.

    Args:
        strategy: Data and callbacks for the current batch.
        need_vis: Request visualisation payloads. Only sample 0 is asked
            for one unless ``strategy.collect_all_vis`` is true.

    Returns:
        Dict with ``"total"``, ``"src"`` and ``"tgt"`` (sums divided by the
        batch size), ``"loss_breakdown"`` (detached per-term averages
        prefixed ``src_`` / ``tgt_`` / ``aux_``), ``"batch_stats"``,
        ``"vis_payload"`` (payload of sample 0, or ``None``) and
        ``"vis_payloads"`` (all collected payloads).
    """
    # Differentiable running sums.
    total_loss = torch.zeros((), device=self.device)
    src_sum = torch.zeros((), device=self.device)
    tgt_sum = torch.zeros((), device=self.device)
    # Detached per-term running sums, used only for logging.
    src_log_sum: dict[str, torch.Tensor] = {}
    tgt_log_sum: dict[str, torch.Tensor] = {}
    aux_log_sum: dict[str, torch.Tensor] = {}
    vis_payload = None
    vis_payloads: list[dict[str, Any]] = []

    def _accumulate_loss_terms(term_specs: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        # Each spec is a kwargs dict for _compute_view_loss; results are summed per key.
        merged: dict[str, torch.Tensor] = {}
        for spec in term_specs:
            term_losses = self._compute_view_loss(**spec)
            for k, v in term_losses.items():
                merged[k] = merged.get(k, torch.zeros((), device=self.device)) + v
        return merged

    collect_all_vis = bool(getattr(strategy, "collect_all_vis", False))
    for b in range(int(strategy.batch_size)):
        # Rebuild a single-sample Gaussian container of the same type.
        g = strategy.gaussians
        g_b = type(g)(
            mean_vectors=g.mean_vectors[b : b + 1],
            singular_values=g.singular_values[b : b + 1],
            quaternions=g.quaternions[b : b + 1],
            colors=g.colors[b : b + 1],
            opacities=g.opacities[b : b + 1],
        )
        g_world = strategy.make_world_gaussians(b, g_b)
        sample = strategy.make_sample(
            b,
            g_world,
            bool(need_vis and (collect_all_vis or b == 0)),
        )

        # Source view: explicit term list if provided, otherwise one loss call.
        if isinstance(sample.get("src_loss_terms", None), list):
            src_losses = _accumulate_loss_terms(sample["src_loss_terms"])
        else:
            src_losses = self._compute_view_loss(
                pred_rgb_linear=sample["src_pred_rgb_linear"],
                pred_alpha=sample["src_pred_alpha"],
                pred_depth_m=sample["src_pred_depth_m"],
                pred_depth2_m=sample.get("src_pred_depth2_m", None),
                gt_rgb_u8=sample["src_gt_rgb_u8"],
                gt_depth_m=sample["src_gt_depth_m"],
                mask=sample["src_mask"],
                apply_color=bool(sample.get("src_apply_color", True)),
                apply_alpha=bool(sample.get("src_apply_alpha", True)),
                apply_depth=bool(sample.get("src_apply_depth", True)),
                apply_percep=False,
                apply_tv=True,
                apply_grad=bool(sample.get("src_apply_grad", True)),
                apply_grad_img=bool(sample.get("src_apply_grad_img", True)),
                apply_splat=bool(sample.get("src_apply_splat", True)),
                grad_img_circular_h=sample.get("src_grad_img_circular_h", None),
                gaussian_scales=sample.get("gaussian_scales", None),
                gaussian_quaternions=sample.get("gaussian_quaternions", None),
                gaussian_angular_cell=sample.get("gaussian_angular_cell", None),
                delta_xy=sample.get("delta_xy", None),
                delta_rho=sample.get("delta_rho", None),
                delta_grid=sample.get("delta_grid", None),
                gaussian_mean_vectors=sample.get("gaussian_mean_vectors", None),
                gaussian_base_mean_vectors=sample.get("gaussian_base_mean_vectors", None),
                gaussian_opacities=sample.get("gaussian_opacities", None),
                gauss_grid_shape=sample.get("gauss_grid_shape", None),
                projected_scale_factor=sample.get("projected_scale_factor", None),
                projection_model=sample.get("projection_model", None),
                projection_intrinsics=sample.get("projection_intrinsics", None),
                projection_camera_params=sample.get("projection_camera_params", None),
                depth_mask=sample.get("src_depth_mask", None),
            )
        # Optional extra source terms are added key-wise on top.
        if isinstance(sample.get("src_extra_loss_terms", None), list):
            extra_src_losses = _accumulate_loss_terms(sample["src_extra_loss_terms"])
            for k, v in extra_src_losses.items():
                src_losses[k] = src_losses.get(k, torch.zeros((), device=self.device)) + v
        # Target view: same scheme; the Gaussian-parameter inputs are passed as None here.
        if isinstance(sample.get("tgt_loss_terms", None), list):
            tgt_losses = _accumulate_loss_terms(sample["tgt_loss_terms"])
        else:
            tgt_losses = self._compute_view_loss(
                pred_rgb_linear=sample["tgt_pred_rgb_linear"],
                pred_alpha=sample["tgt_pred_alpha"],
                pred_depth_m=sample["tgt_pred_depth_m"],
                pred_depth2_m=sample.get("tgt_pred_depth2_m", None),
                gt_rgb_u8=sample["tgt_gt_rgb_u8"],
                gt_depth_m=sample["tgt_gt_depth_m"],
                mask=sample["tgt_mask"],
                apply_color=bool(sample.get("tgt_apply_color", True)),
                apply_alpha=bool(sample.get("tgt_apply_alpha", True)),
                apply_depth=bool(sample.get("tgt_apply_depth", True)),
                apply_percep=bool(sample.get("tgt_apply_percep", False)),
                apply_tv=False,
                apply_grad=False,
                apply_grad_img=bool(sample.get("tgt_apply_grad_img", True)),
                apply_splat=bool(sample.get("tgt_apply_splat", False)),
                grad_img_circular_h=sample.get("tgt_grad_img_circular_h", None),
                gaussian_scales=None,
                gaussian_quaternions=None,
                delta_xy=None,
                delta_rho=None,
                gaussian_mean_vectors=None,
                gaussian_base_mean_vectors=None,
                gaussian_opacities=None,
                gauss_grid_shape=None,
                projected_scale_factor=sample.get("projected_scale_factor", None),
                projection_model=sample.get("projection_model", None),
                projection_intrinsics=sample.get("projection_intrinsics", None),
                projection_camera_params=sample.get("projection_camera_params", None),
                depth_mask=sample.get("tgt_depth_mask", None),
            )
        if isinstance(sample.get("tgt_extra_loss_terms", None), list):
            extra_tgt_losses = _accumulate_loss_terms(sample["tgt_extra_loss_terms"])
            for k, v in extra_tgt_losses.items():
                tgt_losses[k] = tgt_losses.get(k, torch.zeros((), device=self.device)) + v

        # Auxiliary losses are summed as-is; non-tensor values become float32 scalars.
        aux_total = torch.zeros((), device=self.device)
        raw_aux = sample.get("aux_losses", None)
        if isinstance(raw_aux, dict):
            for k, v in raw_aux.items():
                if torch.is_tensor(v):
                    vv = v.to(device=self.device)
                else:
                    vv = torch.tensor(float(v), device=self.device, dtype=torch.float32)
                aux_total = aux_total + vv
                aux_log_sum[str(k)] = aux_log_sum.get(str(k), torch.zeros((), device=self.device)) + vv.detach()

        # Both loss dicts must provide a "total" entry.
        src_sum = src_sum + src_losses["total"]
        tgt_sum = tgt_sum + tgt_losses["total"]
        total_loss = total_loss + src_losses["total"] + tgt_losses["total"] + aux_total
        for k, v in src_losses.items():
            src_log_sum[k] = src_log_sum.get(k, torch.zeros((), device=self.device)) + v.detach()
        for k, v in tgt_losses.items():
            tgt_log_sum[k] = tgt_log_sum.get(k, torch.zeros((), device=self.device)) + v.detach()

        if need_vis and isinstance(sample.get("vis_payload", None), dict):
            vis_payloads.append(sample["vis_payload"])
            if b == 0:
                vis_payload = sample["vis_payload"]

    # Convert sums into per-sample means.
    # NOTE(review): a batch_size of 0 makes bs == 0.0 here - confirm callers never pass one.
    bs = float(strategy.batch_size)
    total_loss = total_loss / bs
    src_sum = src_sum / bs
    tgt_sum = tgt_sum / bs
    loss_breakdown: dict[str, torch.Tensor] = {}
    for k, v in src_log_sum.items():
        loss_breakdown[f"src_{k}"] = v / bs
    for k, v in tgt_log_sum.items():
        loss_breakdown[f"tgt_{k}"] = v / bs
    for k, v in aux_log_sum.items():
        loss_breakdown[f"aux_{k}"] = v / bs
    batch_stats = {
        "batch_size": int(strategy.batch_size),
        "gaussian_count": int(strategy.gaussians.mean_vectors.shape[1]),
    }

    return {
        "total": total_loss,
        "src": src_sum,
        "tgt": tgt_sum,
        "loss_breakdown": loss_breakdown,
        "batch_stats": batch_stats,
        "vis_payload": vis_payload,
        "vis_payloads": vis_payloads,
    }
|
| 416 |
+
|
| 417 |
+
@staticmethod
|
| 418 |
+
def _first_item(x: Any, default: Any = None) -> Any:
|
| 419 |
+
if x is None:
|
| 420 |
+
return default
|
| 421 |
+
if isinstance(x, (list, tuple)):
|
| 422 |
+
return x[0] if len(x) > 0 else default
|
| 423 |
+
if torch.is_tensor(x):
|
| 424 |
+
if x.numel() == 0:
|
| 425 |
+
return default
|
| 426 |
+
return x.flatten()[0].item()
|
| 427 |
+
return x
|
| 428 |
+
|
| 429 |
+
@staticmethod
|
| 430 |
+
def _item_at(x: Any, index: int, default: Any = None) -> Any:
|
| 431 |
+
if x is None:
|
| 432 |
+
return default
|
| 433 |
+
if isinstance(x, (list, tuple)):
|
| 434 |
+
return x[index] if 0 <= int(index) < len(x) else default
|
| 435 |
+
if torch.is_tensor(x):
|
| 436 |
+
if x.numel() == 0:
|
| 437 |
+
return default
|
| 438 |
+
if x.ndim == 0:
|
| 439 |
+
return x.item()
|
| 440 |
+
if 0 <= int(index) < int(x.shape[0]):
|
| 441 |
+
item = x[int(index)]
|
| 442 |
+
return item.item() if item.numel() == 1 else item
|
| 443 |
+
return default
|
| 444 |
+
return x
|
| 445 |
+
|
| 446 |
+
@staticmethod
|
| 447 |
+
def _finite_quantile(x: torch.Tensor, q: float, default: float = float("nan")) -> torch.Tensor:
|
| 448 |
+
vals = x[torch.isfinite(x)]
|
| 449 |
+
if int(vals.numel()) <= 0:
|
| 450 |
+
return torch.tensor(float(default), device=x.device, dtype=torch.float32)
|
| 451 |
+
vals = vals.to(torch.float32).flatten()
|
| 452 |
+
if int(vals.numel()) > 262144:
|
| 453 |
+
step = max(1, int(vals.numel()) // 262144)
|
| 454 |
+
vals = vals[::step]
|
| 455 |
+
return torch.quantile(vals, float(q))
|
| 456 |
+
|
| 457 |
+
def _clamp_distance_for_supervision(
|
| 458 |
+
self,
|
| 459 |
+
depth_m: torch.Tensor | None,
|
| 460 |
+
*,
|
| 461 |
+
max_depth_m: float | None = None,
|
| 462 |
+
clamp_max: bool = True,
|
| 463 |
+
) -> torch.Tensor | None:
|
| 464 |
+
if not torch.is_tensor(depth_m):
|
| 465 |
+
return None
|
| 466 |
+
cap = float(self.max_depth_m if max_depth_m is None else max_depth_m)
|
| 467 |
+
out = depth_m.to(dtype=torch.float32)
|
| 468 |
+
valid = torch.isfinite(out) & (out > 0.0)
|
| 469 |
+
if bool(clamp_max):
|
| 470 |
+
sanitized = out.clamp(min=1e-4, max=cap)
|
| 471 |
+
else:
|
| 472 |
+
sanitized = out.clamp(min=1e-4)
|
| 473 |
+
return torch.where(valid, sanitized, torch.zeros_like(out))
|
| 474 |
+
|
| 475 |
+
@staticmethod
|
| 476 |
+
def _rendered_depth_valid_for_inv_loss(
|
| 477 |
+
depth_m: torch.Tensor,
|
| 478 |
+
alpha: torch.Tensor,
|
| 479 |
+
*,
|
| 480 |
+
alpha_min: float | None = None,
|
| 481 |
+
depth_min_m: float = 1e-3,
|
| 482 |
+
) -> torch.Tensor:
|
| 483 |
+
depth = depth_m.detach()
|
| 484 |
+
valid = torch.isfinite(depth) & (depth > float(depth_min_m))
|
| 485 |
+
if alpha_min is not None:
|
| 486 |
+
a = alpha.detach().to(device=depth.device)
|
| 487 |
+
valid = valid & (a[:, :1] > float(alpha_min))
|
| 488 |
+
return valid.to(dtype=depth.dtype)
|
| 489 |
+
|
| 490 |
+
def _pinhole_z_to_supervision_distance(
|
| 491 |
+
self,
|
| 492 |
+
z_depth_b1hw: torch.Tensor | None,
|
| 493 |
+
k3_b33: torch.Tensor | None,
|
| 494 |
+
*,
|
| 495 |
+
clamp_max: bool = True,
|
| 496 |
+
) -> torch.Tensor | None:
|
| 497 |
+
if not torch.is_tensor(z_depth_b1hw) or not torch.is_tensor(k3_b33):
|
| 498 |
+
return None
|
| 499 |
+
dist = self._z_depth_to_distance_pinhole(z_depth_b1hw, k3_b33)
|
| 500 |
+
return self._clamp_distance_for_supervision(dist, clamp_max=bool(clamp_max))
|
| 501 |
+
|
| 502 |
+
@staticmethod
|
| 503 |
+
def _sanitize_positive_depth(depth_m: torch.Tensor | None) -> torch.Tensor | None:
|
| 504 |
+
if not torch.is_tensor(depth_m):
|
| 505 |
+
return None
|
| 506 |
+
out = depth_m.to(dtype=torch.float32)
|
| 507 |
+
valid = torch.isfinite(out) & (out > 0.0)
|
| 508 |
+
return torch.where(valid, out, torch.zeros_like(out))
|
| 509 |
+
|
| 510 |
+
@staticmethod
|
| 511 |
+
def _as_b1hw_depth(depth: torch.Tensor) -> torch.Tensor:
|
| 512 |
+
if depth.ndim == 3:
|
| 513 |
+
return depth.unsqueeze(1)
|
| 514 |
+
if depth.ndim == 4 and depth.shape[1] == 1:
|
| 515 |
+
return depth
|
| 516 |
+
raise ValueError(f"Expected depth shape (B,H,W) or (B,1,H,W), got {tuple(depth.shape)}")
|
| 517 |
+
|
| 518 |
+
@staticmethod
|
| 519 |
+
def _as_bchw_rgb_u8(image: torch.Tensor) -> torch.Tensor:
|
| 520 |
+
if image.ndim == 3 and image.shape[0] == 3:
|
| 521 |
+
return image.unsqueeze(0)
|
| 522 |
+
if image.ndim == 4 and image.shape[1] == 3:
|
| 523 |
+
return image
|
| 524 |
+
raise ValueError(f"Expected image shape (3,H,W) or (B,3,H,W), got {tuple(image.shape)}")
|
| 525 |
+
|
| 526 |
+
@staticmethod
|
| 527 |
+
def _as_b33_intrinsics(intrinsics: torch.Tensor) -> torch.Tensor:
|
| 528 |
+
if intrinsics.ndim == 2 and tuple(intrinsics.shape) == (3, 3):
|
| 529 |
+
return intrinsics.unsqueeze(0)
|
| 530 |
+
if intrinsics.ndim == 3 and tuple(intrinsics.shape[1:]) == (3, 3):
|
| 531 |
+
return intrinsics
|
| 532 |
+
raise ValueError(
|
| 533 |
+
f"Expected intrinsics shape (3,3) or (B,3,3), got {tuple(intrinsics.shape)}"
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
@staticmethod
|
| 537 |
+
def _as_b9_camera_params(camera_params: torch.Tensor) -> torch.Tensor:
|
| 538 |
+
if camera_params.ndim == 1 and int(camera_params.shape[0]) == 9:
|
| 539 |
+
return camera_params.unsqueeze(0)
|
| 540 |
+
if camera_params.ndim == 2 and int(camera_params.shape[1]) == 9:
|
| 541 |
+
return camera_params
|
| 542 |
+
raise ValueError(f"Expected camera_params shape (9,) or (B,9), got {tuple(camera_params.shape)}")
|
| 543 |
+
|
| 544 |
+
@staticmethod
|
| 545 |
+
def _as_b16_camera_params(camera_params: torch.Tensor) -> torch.Tensor:
|
| 546 |
+
if camera_params.ndim == 1 and int(camera_params.shape[0]) == 16:
|
| 547 |
+
return camera_params.unsqueeze(0)
|
| 548 |
+
if camera_params.ndim == 2 and int(camera_params.shape[1]) == 16:
|
| 549 |
+
return camera_params
|
| 550 |
+
raise ValueError(f"Expected camera_params shape (16,) or (B,16), got {tuple(camera_params.shape)}")
|
| 551 |
+
|
| 552 |
+
@staticmethod
|
| 553 |
+
def _as_b44_pose(extrinsics: torch.Tensor) -> torch.Tensor:
|
| 554 |
+
if extrinsics.ndim == 2 and tuple(extrinsics.shape) == (4, 4):
|
| 555 |
+
return extrinsics.unsqueeze(0)
|
| 556 |
+
if extrinsics.ndim == 3 and tuple(extrinsics.shape[1:]) == (4, 4):
|
| 557 |
+
return extrinsics
|
| 558 |
+
raise ValueError(
|
| 559 |
+
f"Expected extrinsics shape (4,4) or (B,4,4), got {tuple(extrinsics.shape)}"
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
@staticmethod
|
| 563 |
+
def _pick_depth_for_pinhole_frustum_mask(
|
| 564 |
+
gt_depth: torch.Tensor | None,
|
| 565 |
+
pred_depth: torch.Tensor,
|
| 566 |
+
min_valid_px: int = 8,
|
| 567 |
+
) -> torch.Tensor:
|
| 568 |
+
if torch.is_tensor(gt_depth):
|
| 569 |
+
gt_depth = UnifiedTrainer._as_b1hw_depth(gt_depth)
|
| 570 |
+
valid = torch.isfinite(gt_depth) & (gt_depth > 0.0)
|
| 571 |
+
if int(valid.sum().item()) >= int(min_valid_px):
|
| 572 |
+
return gt_depth
|
| 573 |
+
return pred_depth
|
| 574 |
+
|
| 575 |
+
@staticmethod
|
| 576 |
+
def _pick_depth_for_fisheye_frustum_mask(
|
| 577 |
+
gt_depth: torch.Tensor | None,
|
| 578 |
+
pred_depth: torch.Tensor,
|
| 579 |
+
gt_valid_mask: torch.Tensor | None = None,
|
| 580 |
+
min_valid_px: int = 8,
|
| 581 |
+
) -> torch.Tensor:
|
| 582 |
+
if torch.is_tensor(gt_depth):
|
| 583 |
+
gt_depth = UnifiedTrainer._as_b1hw_depth(gt_depth)
|
| 584 |
+
if torch.is_tensor(gt_valid_mask):
|
| 585 |
+
gt_valid = gt_depth > 0.0
|
| 586 |
+
gt_valid = gt_valid & (gt_valid_mask > 0.5)
|
| 587 |
+
else:
|
| 588 |
+
gt_valid = torch.isfinite(gt_depth) & (gt_depth > 0.0)
|
| 589 |
+
if int(gt_valid.sum().item()) >= int(min_valid_px):
|
| 590 |
+
return gt_depth
|
| 591 |
+
return pred_depth
|
| 592 |
+
|
| 593 |
+
@staticmethod
|
| 594 |
+
def _as_cubemap_depth_hw1(depth: torch.Tensor) -> torch.Tensor:
|
| 595 |
+
if depth.ndim != 4:
|
| 596 |
+
raise ValueError(f"Expected 4D cubemap depth, got shape={tuple(depth.shape)}")
|
| 597 |
+
if depth.shape[-1] == 1:
|
| 598 |
+
return depth
|
| 599 |
+
if depth.shape[1] == 1:
|
| 600 |
+
return depth.permute(0, 2, 3, 1).contiguous()
|
| 601 |
+
raise ValueError(f"Unsupported cubemap depth shape={tuple(depth.shape)}")
|
| 602 |
+
|
| 603 |
+
def _pick_depth_for_cubemap_frustum_mask(
    self,
    gt_depth_cube: torch.Tensor | None,
    pred_depth_cube: torch.Tensor,
    face_w: int,
    min_valid_px: int = 8,
) -> torch.Tensor:
    """Choose which cubemap depth drives the frustum mask.

    The prediction is always normalised to channels-last first. When
    ground truth is a tensor it is converted from z-depth to distance,
    resized (nearest neighbour) to ``face_w x face_w`` if needed, and
    returned when it has at least ``min_valid_px`` finite, strictly
    positive pixels. Otherwise the normalised prediction is returned.
    """
    fallback = self._as_cubemap_depth_hw1(pred_depth_cube)
    if not torch.is_tensor(gt_depth_cube):
        return fallback

    gt_z = self._as_cubemap_depth_hw1(gt_depth_cube)
    candidate = self._as_cubemap_depth_hw1(self._cubemap_z_depth_to_distance(gt_z))

    side = int(face_w)
    needs_resize = candidate.shape[1] != side or candidate.shape[2] != side
    if needs_resize:
        # F.interpolate wants channels-first, so round-trip through that layout.
        channels_first = candidate.permute(0, 3, 1, 2)
        resized = F.interpolate(channels_first, size=(side, side), mode="nearest")
        candidate = resized.permute(0, 2, 3, 1).contiguous()

    plane = candidate[..., 0]
    usable = torch.isfinite(plane) & (plane > 0.0)
    if int(usable.sum().item()) >= int(min_valid_px):
        return candidate
    return fallback
|
| 625 |
+
|
| 626 |
+
@staticmethod
def _distance_to_z_depth_pinhole(
    distance_b1hw: torch.Tensor,
    intrinsics_b33: torch.Tensor,
) -> torch.Tensor:
    """Convert per-pixel ray distance to pinhole z-depth.

    For each pixel the normalised camera coordinates
    ``x = (u - cx) / fx`` and ``y = (v - cy) / fy`` give the factor
    ``ray_z = 1 / sqrt(x^2 + y^2 + 1)``; the result is
    ``distance * ray_z`` with shape ``(B, 1, H, W)``.
    """
    dist = UnifiedTrainer._as_b1hw_depth(distance_b1hw)
    k3 = UnifiedTrainer._as_b33_intrinsics(intrinsics_b33)
    batch, _, height, width = dist.shape
    device, dtype = dist.device, dist.dtype

    # presumably returns (u, v) pixel-coordinate grids of shape (H, W) —
    # the helper is defined elsewhere; confirm against its definition.
    grid_u, grid_v = integer_pixel_center_grid(height, width, device=device, dtype=dtype)
    grid_u = grid_u.unsqueeze(0).expand(batch, -1, -1)
    grid_v = grid_v.unsqueeze(0).expand(batch, -1, -1)

    def _entry(row: int, col: int) -> torch.Tensor:
        # One intrinsics entry per batch item, broadcastable over (H, W).
        return k3[:, row, col].view(batch, 1, 1).to(dtype=dtype, device=device)

    focal_x, focal_y = _entry(0, 0), _entry(1, 1)
    center_x, center_y = _entry(0, 2), _entry(1, 2)

    x = (grid_u - center_x) / focal_x
    y = (grid_v - center_y) / focal_y
    ray_norm = torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8)
    ray_z = 1.0 / ray_norm
    return dist * ray_z.unsqueeze(1)
|
| 647 |
+
|
| 648 |
+
@staticmethod
def _z_depth_to_distance_pinhole(
    z_depth_b1hw: torch.Tensor,
    intrinsics_b33: torch.Tensor,
) -> torch.Tensor:
    """Convert pinhole z-depth to per-pixel ray distance.

    For each pixel the normalised camera coordinates
    ``x = (u - cx) / fx`` and ``y = (v - cy) / fy`` give the factor
    ``ray_z = 1 / sqrt(x^2 + y^2 + 1)``; the result is
    ``z_depth / ray_z`` (with ``ray_z`` floored at ``1e-8``) and has
    shape ``(B, 1, H, W)``.
    """
    z_depth = UnifiedTrainer._as_b1hw_depth(z_depth_b1hw)
    k3 = UnifiedTrainer._as_b33_intrinsics(intrinsics_b33)
    batch, _, height, width = z_depth.shape
    device, dtype = z_depth.device, z_depth.dtype

    # presumably returns (u, v) pixel-coordinate grids of shape (H, W) —
    # the helper is defined elsewhere; confirm against its definition.
    grid_u, grid_v = integer_pixel_center_grid(height, width, device=device, dtype=dtype)
    grid_u = grid_u.unsqueeze(0).expand(batch, -1, -1)
    grid_v = grid_v.unsqueeze(0).expand(batch, -1, -1)

    def _entry(row: int, col: int) -> torch.Tensor:
        # One intrinsics entry per batch item, broadcastable over (H, W).
        return k3[:, row, col].view(batch, 1, 1).to(dtype=dtype, device=device)

    focal_x, focal_y = _entry(0, 0), _entry(1, 1)
    center_x, center_y = _entry(0, 2), _entry(1, 2)

    x = (grid_u - center_x) / focal_x
    y = (grid_v - center_y) / focal_y
    ray_norm = torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8)
    ray_z = 1.0 / ray_norm
    divisor = ray_z.unsqueeze(1).clamp(min=1e-8)
    return z_depth / divisor
|
| 669 |
+
|
| 670 |
+
def _cubemap_z_depth_to_distance(
|
| 671 |
+
self,
|
| 672 |
+
depth_cube: torch.Tensor,
|
| 673 |
+
) -> torch.Tensor:
|
| 674 |
+
if depth_cube.ndim != 4:
|
| 675 |
+
raise ValueError(f"Expected 4D cubemap depth, got {tuple(depth_cube.shape)}")
|
| 676 |
+
if depth_cube.shape[-1] == 1:
|
| 677 |
+
depth_61hw = depth_cube.permute(0, 3, 1, 2).contiguous()
|
| 678 |
+
elif depth_cube.shape[1] == 1:
|
| 679 |
+
depth_61hw = depth_cube
|
| 680 |
+
else:
|
| 681 |
+
raise ValueError(f"Unsupported cubemap depth shape={tuple(depth_cube.shape)}")
|
| 682 |
+
|
| 683 |
+
_, _, h, w = depth_61hw.shape
|
| 684 |
+
intr = get_pinhole_intrinsics_4x4(int(w)).to(
|
| 685 |
+
device=depth_61hw.device,
|
| 686 |
+
dtype=depth_61hw.dtype,
|
| 687 |
+
)
|
| 688 |
+
fx = intr[0, 0]
|
| 689 |
+
fy = intr[1, 1]
|
| 690 |
+
cx = intr[0, 2]
|
| 691 |
+
cy = intr[1, 2]
|
| 692 |
+
uu, vv = integer_pixel_center_grid(h, w, device=depth_61hw.device, dtype=depth_61hw.dtype)
|
| 693 |
+
x = (uu - cx) / fx
|
| 694 |
+
y = (vv - cy) / fy
|
| 695 |
+
ray_z = 1.0 / torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8)
|
| 696 |
+
dist = depth_61hw / ray_z.view(1, 1, h, w).clamp(min=1e-8)
|
| 697 |
+
valid = torch.isfinite(dist) & (depth_61hw > 0.0)
|
| 698 |
+
dist = torch.where(valid, dist.clamp(min=1e-4), torch.zeros_like(dist))
|
| 699 |
+
return dist
|
| 700 |
+
|
| 701 |
+
def _collect_regularization_inputs(
|
| 702 |
+
self,
|
| 703 |
+
out: dict[str, Any],
|
| 704 |
+
gaussians: Any,
|
| 705 |
+
b: int,
|
| 706 |
+
projected_scale_factor: float | None,
|
| 707 |
+
) -> dict[str, Any]:
|
| 708 |
+
delta_b = out.get("delta", None)
|
| 709 |
+
delta_xy_raw = None
|
| 710 |
+
if torch.is_tensor(delta_b):
|
| 711 |
+
delta_xy_raw = delta_b[b : b + 1, 0:2]
|
| 712 |
+
delta_rho_raw = delta_b[b : b + 1, 2:3]
|
| 713 |
+
delta_grid_raw = delta_b[b : b + 1]
|
| 714 |
+
else:
|
| 715 |
+
delta_rho_raw = None
|
| 716 |
+
delta_grid_raw = None
|
| 717 |
+
delta_rho_applied_all = out.get("delta_rho_applied", None)
|
| 718 |
+
delta_rho_applied = (
|
| 719 |
+
delta_rho_applied_all[b : b + 1]
|
| 720 |
+
if torch.is_tensor(delta_rho_applied_all)
|
| 721 |
+
else None
|
| 722 |
+
)
|
| 723 |
+
scale_factor_applied_all = out.get("scale_factor_applied", None)
|
| 724 |
+
scale_factor_applied = (
|
| 725 |
+
scale_factor_applied_all[b : b + 1]
|
| 726 |
+
if torch.is_tensor(scale_factor_applied_all)
|
| 727 |
+
else None
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
scales_b = gaussians.singular_values[b : b + 1]
|
| 731 |
+
means_b = gaussians.mean_vectors[b : b + 1]
|
| 732 |
+
quats_b = gaussians.quaternions[b : b + 1]
|
| 733 |
+
opac_b = gaussians.opacities[b : b + 1]
|
| 734 |
+
|
| 735 |
+
base_values = out.get("gaussian_base_values", None)
|
| 736 |
+
gauss_grid_shape = None
|
| 737 |
+
base_means_b = None
|
| 738 |
+
base_scales_b = None
|
| 739 |
+
angular_cell_b = None
|
| 740 |
+
if base_values is not None and hasattr(base_values, "rays"):
|
| 741 |
+
_, _, l, hb, wb = base_values.rays.shape
|
| 742 |
+
gauss_grid_shape = (int(l), int(hb), int(wb))
|
| 743 |
+
inv_dist_b = base_values.inv_distance[b : b + 1].clamp(min=1e-6)
|
| 744 |
+
base_rays_b = F.normalize(base_values.rays[b : b + 1], dim=1, eps=1e-6)
|
| 745 |
+
base_means_grid = base_rays_b / inv_dist_b
|
| 746 |
+
base_scales_b = base_values.scales[b : b + 1]
|
| 747 |
+
init_output = out.get("initializer_output", None)
|
| 748 |
+
global_scale = (
|
| 749 |
+
init_output.global_scale[b : b + 1]
|
| 750 |
+
if init_output is not None
|
| 751 |
+
and getattr(init_output, "global_scale", None) is not None
|
| 752 |
+
else None
|
| 753 |
+
)
|
| 754 |
+
if torch.is_tensor(global_scale):
|
| 755 |
+
base_means_grid = base_means_grid * global_scale.view(-1, 1, 1, 1, 1)
|
| 756 |
+
base_scales_b = base_scales_b * global_scale.view(-1, 1, 1, 1, 1)
|
| 757 |
+
base_means_b = base_means_grid.permute(0, 2, 3, 4, 1).flatten(1, 3)
|
| 758 |
+
angular_cell = getattr(base_values, "angular_cell", None)
|
| 759 |
+
angular_cell_b = angular_cell[b : b + 1] if torch.is_tensor(angular_cell) else None
|
| 760 |
+
|
| 761 |
+
return {
|
| 762 |
+
"delta_xy_eff": delta_xy_raw,
|
| 763 |
+
"delta_rho_raw": delta_rho_raw,
|
| 764 |
+
"delta_grid": delta_grid_raw,
|
| 765 |
+
"delta_rho_applied": delta_rho_applied,
|
| 766 |
+
"scale_factor_applied": scale_factor_applied,
|
| 767 |
+
"gaussian_scales": scales_b,
|
| 768 |
+
"gaussian_quaternions": quats_b,
|
| 769 |
+
"gaussian_angular_cell": angular_cell_b,
|
| 770 |
+
"gaussian_mean_vectors": means_b,
|
| 771 |
+
"gaussian_base_mean_vectors": base_means_b,
|
| 772 |
+
"gaussian_base_scales": base_scales_b,
|
| 773 |
+
"gaussian_opacities": opac_b,
|
| 774 |
+
"gauss_grid_shape": gauss_grid_shape,
|
| 775 |
+
"projected_scale_factor": projected_scale_factor,
|
| 776 |
+
}
|
| 777 |
+
|
| 778 |
+
def _compute_view_loss(
|
| 779 |
+
self,
|
| 780 |
+
*,
|
| 781 |
+
pred_rgb_linear: torch.Tensor,
|
| 782 |
+
pred_alpha: torch.Tensor,
|
| 783 |
+
pred_depth_m: torch.Tensor,
|
| 784 |
+
pred_depth2_m: torch.Tensor | None,
|
| 785 |
+
gt_rgb_u8: torch.Tensor,
|
| 786 |
+
gt_depth_m: torch.Tensor,
|
| 787 |
+
mask: torch.Tensor,
|
| 788 |
+
apply_color: bool,
|
| 789 |
+
apply_alpha: bool,
|
| 790 |
+
apply_depth: bool,
|
| 791 |
+
apply_percep: bool,
|
| 792 |
+
apply_tv: bool,
|
| 793 |
+
apply_grad: bool,
|
| 794 |
+
apply_grad_img: bool,
|
| 795 |
+
grad_img_circular_h: bool | None = None,
|
| 796 |
+
gaussian_scales: torch.Tensor | None = None,
|
| 797 |
+
gaussian_quaternions: torch.Tensor | None = None,
|
| 798 |
+
gaussian_angular_cell: torch.Tensor | None = None,
|
| 799 |
+
delta_xy: torch.Tensor | None = None,
|
| 800 |
+
delta_rho: torch.Tensor | None = None,
|
| 801 |
+
delta_grid: torch.Tensor | None = None,
|
| 802 |
+
gaussian_mean_vectors: torch.Tensor | None = None,
|
| 803 |
+
gaussian_base_mean_vectors: torch.Tensor | None = None,
|
| 804 |
+
gaussian_opacities: torch.Tensor | None = None,
|
| 805 |
+
gauss_grid_shape: tuple[int, int, int] | None = None,
|
| 806 |
+
projected_scale_factor: float | torch.Tensor | None = None,
|
| 807 |
+
projection_model: str | None = None,
|
| 808 |
+
projection_intrinsics: torch.Tensor | None = None,
|
| 809 |
+
projection_camera_params: torch.Tensor | None = None,
|
| 810 |
+
loss_scale: float = 1.0,
|
| 811 |
+
apply_splat: bool | None = None,
|
| 812 |
+
depth_mask: torch.Tensor | None = None,
|
| 813 |
+
) -> dict[str, torch.Tensor]:
|
| 814 |
+
losses = self.loss_fn(
|
| 815 |
+
pred_rgb_linear=pred_rgb_linear,
|
| 816 |
+
pred_alpha=pred_alpha,
|
| 817 |
+
pred_depth_m=pred_depth_m,
|
| 818 |
+
pred_depth2_m=pred_depth2_m,
|
| 819 |
+
gt_rgb_u8=gt_rgb_u8,
|
| 820 |
+
gt_depth_m=gt_depth_m,
|
| 821 |
+
mask=mask,
|
| 822 |
+
depth_mask=depth_mask,
|
| 823 |
+
gaussian_scales=gaussian_scales,
|
| 824 |
+
gaussian_quaternions=gaussian_quaternions,
|
| 825 |
+
gaussian_angular_cell=gaussian_angular_cell,
|
| 826 |
+
delta_xy=delta_xy,
|
| 827 |
+
delta_rho=delta_rho,
|
| 828 |
+
delta_grid=delta_grid,
|
| 829 |
+
apply_color=bool(apply_color),
|
| 830 |
+
apply_alpha=bool(apply_alpha),
|
| 831 |
+
apply_depth=bool(apply_depth),
|
| 832 |
+
apply_percep=bool(apply_percep),
|
| 833 |
+
apply_tv=bool(apply_tv),
|
| 834 |
+
apply_grad=bool(apply_grad),
|
| 835 |
+
apply_grad_img=bool(apply_grad_img),
|
| 836 |
+
grad_img_circular_h=grad_img_circular_h,
|
| 837 |
+
apply_delta=bool(torch.is_tensor(delta_xy) or torch.is_tensor(delta_rho)),
|
| 838 |
+
apply_splat=bool(torch.is_tensor(gaussian_scales)) if apply_splat is None else bool(apply_splat),
|
| 839 |
+
gaussian_mean_vectors=gaussian_mean_vectors,
|
| 840 |
+
gaussian_base_mean_vectors=gaussian_base_mean_vectors,
|
| 841 |
+
gaussian_opacities=gaussian_opacities,
|
| 842 |
+
gauss_grid_shape=gauss_grid_shape,
|
| 843 |
+
projected_scale_factor=projected_scale_factor,
|
| 844 |
+
projection_model=projection_model,
|
| 845 |
+
projection_intrinsics=projection_intrinsics,
|
| 846 |
+
projection_camera_params=projection_camera_params,
|
| 847 |
+
)
|
| 848 |
+
scale = float(loss_scale)
|
| 849 |
+
if abs(scale - 1.0) > 1e-8:
|
| 850 |
+
losses = {k: (v * scale) for k, v in losses.items()}
|
| 851 |
+
return losses
|
| 852 |
+
|
| 853 |
+
def _build_pinhole_strategy(
    self,
    batch: Any,
    step: int,
    need_vis: bool = False,
    dataset_name: str = "re10k",
) -> _ModeStrategy:
    """Build the per-batch strategy for pinhole-camera data.

    Moves the batch tensors to ``self.device``, runs ``self.model`` on the
    source view(s) with ``camera_model="pinhole"``, and returns a
    ``_ModeStrategy`` whose ``make_sample`` closure renders the source and
    target views for one batch element and assembles the loss inputs
    (plus an optional visualisation payload).

    Args:
        batch: Object exposing ``src_rgb_u8``, ``tgt_rgb_u8``, ``src_w2c``,
            ``tgt_w2c``, ``src_intrinsics`` and ``tgt_intrinsics``; the depth
            and ``*_orig`` attributes are optional and read via ``getattr``.
        step: Not referenced in this function body.
        need_vis: Not referenced in this function body; visualisation is
            driven by the ``enable_vis`` argument of ``make_sample``.
        dataset_name: Used to look up the distance-init cap and recorded in
            the visualisation payload.
    """
    src_u8 = self._as_bchw_rgb_u8(batch.src_rgb_u8.to(self.device, non_blocking=True))
    tgt_u8 = self._as_bchw_rgb_u8(batch.tgt_rgb_u8.to(self.device, non_blocking=True))
    src_u8_orig = getattr(batch, "src_rgb_u8_orig", None)
    tgt_u8_orig = getattr(batch, "tgt_rgb_u8_orig", None)
    src_depth_gt = getattr(batch, "src_depth_m", None)
    tgt_depth_gt = getattr(batch, "tgt_depth_m", None)
    src_depth_gt_orig = getattr(batch, "src_depth_m_orig", None)
    tgt_depth_gt_orig = getattr(batch, "tgt_depth_m_orig", None)
    # Depth supervision requires BOTH source and target GT depth.
    has_depth_gt = torch.is_tensor(src_depth_gt) and torch.is_tensor(tgt_depth_gt)
    if has_depth_gt:
        src_depth_gt = self._as_b1hw_depth(
            src_depth_gt.to(self.device, non_blocking=True).to(torch.float32)
        )
        tgt_depth_gt = self._as_b1hw_depth(
            tgt_depth_gt.to(self.device, non_blocking=True).to(torch.float32)
        )
    has_depth_gt_orig = torch.is_tensor(src_depth_gt_orig) and torch.is_tensor(tgt_depth_gt_orig)
    if has_depth_gt_orig:
        src_depth_gt_orig = self._as_b1hw_depth(
            src_depth_gt_orig.to(self.device, non_blocking=True).to(torch.float32)
        )
        tgt_depth_gt_orig = self._as_b1hw_depth(
            tgt_depth_gt_orig.to(self.device, non_blocking=True).to(torch.float32)
        )
    src_w2c = self._as_b44_pose(batch.src_w2c.to(self.device, non_blocking=True).to(torch.float32))
    tgt_w2c = self._as_b44_pose(batch.tgt_w2c.to(self.device, non_blocking=True).to(torch.float32))
    src_k3 = self._as_b33_intrinsics(batch.src_intrinsics.to(self.device, non_blocking=True).to(torch.float32))
    tgt_k3 = self._as_b33_intrinsics(batch.tgt_intrinsics.to(self.device, non_blocking=True).to(torch.float32))
    src_k3_orig = getattr(batch, "src_intrinsics_orig", None)
    tgt_k3_orig = getattr(batch, "tgt_intrinsics_orig", None)
    # "orig" visualisation needs both images and both intrinsics.
    has_orig_vis = (
        torch.is_tensor(src_u8_orig)
        and torch.is_tensor(tgt_u8_orig)
        and torch.is_tensor(src_k3_orig)
        and torch.is_tensor(tgt_k3_orig)
    )
    if has_orig_vis:
        src_u8_orig = self._as_bchw_rgb_u8(src_u8_orig.to(self.device, non_blocking=True))
        tgt_u8_orig = self._as_bchw_rgb_u8(tgt_u8_orig.to(self.device, non_blocking=True))
        src_k3_orig = self._as_b33_intrinsics(
            src_k3_orig.to(self.device, non_blocking=True).to(torch.float32)
        )
        tgt_k3_orig = self._as_b33_intrinsics(
            tgt_k3_orig.to(self.device, non_blocking=True).to(torch.float32)
        )
    src_depth_gt_dist = None
    tgt_depth_gt_dist = None
    src_unik3d_gt_dist = None
    if has_depth_gt:
        # GT z-depth is converted to clamped ray distance; the two source
        # names below refer to the same tensor.
        src_unik3d_gt_dist = self._pinhole_z_to_supervision_distance(src_depth_gt, src_k3)
        src_depth_gt_dist = src_unik3d_gt_dist
        tgt_depth_gt_dist = self._pinhole_z_to_supervision_distance(tgt_depth_gt, tgt_k3)

    # uint8 images scaled to [0, 1] floats.
    src = src_u8.float().clamp(0, 255) / 255.0
    # NOTE(review): `tgt` is not referenced again in this function body.
    tgt = tgt_u8.float().clamp(0, 255) / 255.0
    distance_init_cap_m = self._distance_init_cap_for_dataset(dataset_name)

    # Sharing is only meaningful when the batch has more than one element.
    share_src_forward = bool(getattr(batch, "share_src_forward", False)) and int(src.shape[0]) > 1

    def _repeat_first_dim(value: Any, batch_size: int) -> Any:
        # Tile tensors whose leading dim is 1 up to batch_size; recurse into
        # namedtuple-like objects (anything with `_fields`); pass the rest through.
        if torch.is_tensor(value):
            if value.ndim > 0 and int(value.shape[0]) == 1:
                return value.repeat(batch_size, *([1] * (value.ndim - 1)))
            return value
        if hasattr(value, "_fields"):
            return type(value)(*[_repeat_first_dim(getattr(value, field), batch_size) for field in value._fields])
        return value

    if share_src_forward:
        # Run the model once on the first source view and replicate its outputs.
        out_single = self.model(
            image=src[0:1],
            image_u8=src_u8[0:1],
            camera_intrinsics=src_k3[0:1],
            camera_model="pinhole",
            depth_gt=(src_depth_gt_dist[0:1] if torch.is_tensor(src_depth_gt_dist) else None),
            distance_init_cap_m=distance_init_cap_m,
            return_aux=True,
        )
        out = {k: _repeat_first_dim(v, int(src.shape[0])) for k, v in out_single.items()}
    else:
        out = self.model(
            image=src,
            image_u8=src_u8,
            camera_intrinsics=src_k3,
            camera_model="pinhole",
            depth_gt=src_depth_gt_dist,
            distance_init_cap_m=distance_init_cap_m,
            return_aux=True,
        )
    gaussians = out["gaussians"]
    src_render_k3 = src_k3
    tgt_render_k3 = tgt_k3
    src_depth_gt_z_render = src_depth_gt if has_depth_gt else None
    tgt_depth_gt_z_render = tgt_depth_gt if has_depth_gt else None
    src_depth_gt_render_valid = (torch.isfinite(src_depth_gt) & (src_depth_gt > 0.0)) if has_depth_gt else None
    tgt_depth_gt_render_valid = (torch.isfinite(tgt_depth_gt) & (tgt_depth_gt > 0.0)) if has_depth_gt else None
    aux_ray_target_all = out.get("unik3d_gt_rays", None)
    def make_world_gaussians(b: int, g_b: Any) -> Any:
        # Identity: the per-sample Gaussians are returned unchanged.
        return g_b

    def make_sample(b: int, g_world: Any, enable_vis: bool) -> dict[str, Any]:
        # Render batch element `b` from the source and target cameras and
        # collect everything the loss needs for that element.
        src_h = int(src_u8.shape[-2])
        src_w = int(src_u8.shape[-1])
        tgt_h = int(tgt_u8.shape[-2])
        tgt_w = int(tgt_u8.shape[-1])
        # The source view is rendered with identity extrinsics; the target
        # uses the pose relative to the source: tgt_w2c @ inv(src_w2c).
        ident = torch.eye(4, dtype=src_w2c.dtype, device=self.device).unsqueeze(0)
        rel_tgt_w2c = tgt_w2c[b : b + 1] @ torch.linalg.inv(src_w2c[b : b + 1])
        src_k_render_b = src_render_k3[b : b + 1]
        tgt_k_render_b = tgt_render_k3[b : b + 1]
        src_out = self.renderer(
            g_world,
            extrinsics=ident,
            intrinsics=to_k4(src_k_render_b),
            image_width=src_w,
            image_height=src_h,
        )
        tgt_out = self.renderer(
            g_world,
            extrinsics=rel_tgt_w2c,
            intrinsics=to_k4(tgt_k_render_b),
            image_width=tgt_w,
            image_height=tgt_h,
        )

        # Placeholders used when no GT depth is available.
        zeros_src_depth = torch.zeros((1, 1, src_h, src_w), dtype=torch.float32, device=self.device)
        zeros_tgt_depth = torch.zeros((1, 1, tgt_h, tgt_w), dtype=torch.float32, device=self.device)
        ones_mask = torch.ones_like(zeros_src_depth)
        fx_b = float(src_k_render_b[0, 0, 0].item())
        fy_b = float(src_k_render_b[0, 1, 1].item())
        # Mean of the source focal lengths.
        proj_scale_pinhole = 0.5 * (fx_b + fy_b)
        reg_inputs = self._collect_regularization_inputs(
            out=out,
            gaussians=gaussians,
            b=b,
            projected_scale_factor=proj_scale_pinhole,
        )

        src_depth_for_visibility = None
        tgt_gt_depth_for_mask = (
            tgt_depth_gt_z_render[b : b + 1]
            if has_depth_gt and torch.is_tensor(tgt_depth_gt_z_render)
            else None
        )
        if has_depth_gt:
            src_depth_for_visibility = (
                src_depth_gt_z_render[b : b + 1]
                if torch.is_tensor(src_depth_gt_z_render)
                else src_depth_gt[b : b + 1]
            )

        # Prefer GT depth for the frustum mask; fall back to rendered depth.
        tgt_depth_for_mask = self._pick_depth_for_pinhole_frustum_mask(
            gt_depth=tgt_gt_depth_for_mask,
            pred_depth=tgt_out.depth,
        )
        tgt_frustum_mask = compute_frustum_mask(
            depth=tgt_depth_for_mask,
            tgt_w2c=tgt_w2c[b : b + 1],
            src_w2c=src_w2c[b : b + 1],
            src_k3=src_k_render_b,
            tgt_k3=tgt_k_render_b,
            img_h=tgt_h,
            img_w=tgt_w,
            source_img_h=src_h,
            source_img_w=src_w,
            source_depth=src_depth_for_visibility,
        )
        # Keep the un-eroded mask for visualisation.
        tgt_frustum_mask_raw = tgt_frustum_mask
        tgt_frustum_mask = self._erode_supervision_mask(
            tgt_frustum_mask,
            self.target_mask_erode_px,
            circular_h=False,
        )
        # Source "predicted depth" comes from the model's distance layers
        # (layer 0), not from the renderer output.
        src_depth_pred = self._clamp_distance_for_supervision(
            out["distance_layers"][b : b + 1, 0:1],
            clamp_max=False,
        )
        # NOTE(review): the `is not None` guard below runs after
        # out["distance_layers"] has already been indexed above, so it cannot
        # protect against a None value — confirm whether it is needed.
        src_depth2_pred = (
            self._clamp_distance_for_supervision(out["distance_layers"][b : b + 1, 1:2], clamp_max=False)
            if out["distance_layers"] is not None and out["distance_layers"].shape[1] > 1
            else None
        )
        src_depth2_gt_for_aux = (
            src_unik3d_gt_dist[b : b + 1]
            if torch.is_tensor(src_unik3d_gt_dist)
            else None
        )
        src_depth2_mask_for_aux = src_depth_gt[b : b + 1] > 0.0 if has_depth_gt else None
        # Rendered target z-depth is converted to ray distance without an upper clamp.
        tgt_depth_pred = self._pinhole_z_to_supervision_distance(
            tgt_out.depth,
            tgt_k_render_b,
            clamp_max=False,
        )
        tgt_depth_loss_mask = self._rendered_depth_valid_for_inv_loss(tgt_depth_pred, tgt_out.alpha)
        if torch.is_tensor(tgt_depth_gt_render_valid):
            # Additionally require valid GT depth at each pixel.
            tgt_depth_loss_mask = tgt_depth_loss_mask * tgt_depth_gt_render_valid[b : b + 1].to(
                device=tgt_depth_loss_mask.device,
                dtype=tgt_depth_loss_mask.dtype,
            )
        tgt_extra_loss_terms: list[dict[str, Any]] = []
        vis_payload = None
        if enable_vis:
            vis_src_u8 = src_u8[b : b + 1]
            vis_tgt_u8 = tgt_u8[b : b + 1]
            vis_src_depth_gt = (src_depth_gt[b : b + 1] if has_depth_gt else None)
            vis_tgt_depth_gt = (tgt_depth_gt[b : b + 1] if has_depth_gt else None)
            vis_src_out = src_out
            vis_tgt_out = tgt_out
            if has_orig_vis:
                # Re-render with the "orig" intrinsics at the "orig" image size
                # for visualisation only.
                vis_src_u8 = src_u8_orig[b : b + 1]
                vis_tgt_u8 = tgt_u8_orig[b : b + 1]
                vis_src_depth_gt = (src_depth_gt_orig[b : b + 1] if has_depth_gt_orig else None)
                vis_tgt_depth_gt = (tgt_depth_gt_orig[b : b + 1] if has_depth_gt_orig else None)
                vis_src_render_k3 = src_k3_orig[b : b + 1]
                vis_tgt_render_k3 = tgt_k3_orig[b : b + 1]
                vis_src_out = self.renderer(
                    g_world,
                    extrinsics=ident,
                    intrinsics=to_k4(vis_src_render_k3),
                    image_width=int(vis_src_u8.shape[-1]),
                    image_height=int(vis_src_u8.shape[-2]),
                )
                vis_tgt_out = self.renderer(
                    g_world,
                    extrinsics=rel_tgt_w2c,
                    intrinsics=to_k4(vis_tgt_render_k3),
                    image_width=int(vis_tgt_u8.shape[-1]),
                    image_height=int(vis_tgt_u8.shape[-2]),
                )
            src_unik3d_depth = None
            tgt_unik3d_depth = None
            raw_dist = out.get("unik3d_distance", None)
            if torch.is_tensor(raw_dist):
                # Best effort: convert the raw distance to z-depth for display;
                # on any failure the raw distance itself is shown.
                try:
                    conditioning_rays = out.get("unik3d_ray_conditioning_rays", None)
                    if not torch.is_tensor(conditioning_rays):
                        conditioning_rays = out.get("unik3d_rays", None)
                    ray_z = (
                        conditioning_rays[b : b + 1, 2:3].detach()
                        if torch.is_tensor(conditioning_rays)
                        else None
                    )
                    if torch.is_tensor(ray_z):
                        if tuple(ray_z.shape[-2:]) != tuple(raw_dist.shape[-2:]):
                            ray_z = F.interpolate(ray_z, size=raw_dist.shape[-2:], mode="bilinear", align_corners=False)
                        src_unik3d_depth = raw_dist[b : b + 1, 0:1].detach() * ray_z
                    else:
                        src_unik3d_depth = self._distance_to_z_depth_pinhole(
                            raw_dist[b : b + 1, 0:1].detach(),
                            src_k_render_b,
                        )
                except Exception:
                    src_unik3d_depth = raw_dist[b : b + 1, 0:1].detach()
            # NOTE(review): the nesting of this block was reconstructed from a
            # whitespace-stripped source; it is placed as a sibling of the
            # `raw_dist` check above — confirm against the original file.
            if self.enable_tgt_unik3d_vis:
                # Best effort: any failure leaves the target depth as None.
                try:
                    with torch.no_grad():
                        from unisharp.utils.unik3d_adapter import forward_unik3d_pinhole

                        unik_tgt = forward_unik3d_pinhole(
                            self._base_model().feature_extractor.unik3d,
                            rgb_u8=tgt_u8[b : b + 1],
                            intrinsics=tgt_k3[b : b + 1],
                            normalize=True,
                        )
                        dist_tgt = unik_tgt.get("distance", None) if isinstance(unik_tgt, dict) else None
                        if torch.is_tensor(dist_tgt):
                            try:
                                tgt_unik3d_depth = self._distance_to_z_depth_pinhole(
                                    dist_tgt[:, 0:1].detach(),
                                    tgt_k_render_b,
                                )
                            except Exception:
                                tgt_unik3d_depth = dist_tgt[:, 0:1].detach()
                except Exception:
                    tgt_unik3d_depth = None

            # Everything in the payload is detached from the autograd graph.
            vis_payload = {
                "src_gt": (vis_src_u8.float() / 255.0).detach(),
                "src_pred": vis_src_out.color.clamp(0, 1).detach(),
                "src_alpha": vis_src_out.alpha.detach(),
                "src_gt_depth": (vis_src_depth_gt.detach() if torch.is_tensor(vis_src_depth_gt) else None),
                "src_pred_depth": vis_src_out.depth.detach(),
                "src_unik3d_depth": src_unik3d_depth,
                "tgt_gt": (vis_tgt_u8.float() / 255.0).detach(),
                "tgt_pred": vis_tgt_out.color.clamp(0, 1).detach(),
                "tgt_alpha": vis_tgt_out.alpha.detach(),
                "tgt_gt_depth": (vis_tgt_depth_gt.detach() if torch.is_tensor(vis_tgt_depth_gt) else None),
                "tgt_pred_depth": vis_tgt_out.depth.detach(),
                "tgt_unik3d_depth": tgt_unik3d_depth,
                "dataset_name": str(dataset_name),
                "scene": str(self._item_at(getattr(batch, "scene", None), b, "unknown")),
                "src_idx": int(self._item_at(getattr(batch, "src_idx", None), b, -1)),
                "tgt_idx": int(self._item_at(getattr(batch, "tgt_idx", None), b, -1)),
                "src_pose_w2c": src_w2c[b : b + 1].detach(),
                "tgt_pose_w2c": tgt_w2c[b : b + 1].detach(),
                "tgt_metric_mask_raw": tgt_frustum_mask_raw.detach(),
                "tgt_metric_mask": tgt_frustum_mask.detach(),
            }

        return {
            "src_pred_rgb_linear": src_out.color,
            "src_pred_alpha": src_out.alpha,
            "src_pred_depth_m": src_depth_pred,
            "src_pred_depth2_m": src_depth2_pred,
            "src_gt_rgb_u8": src_u8[b : b + 1],
            "src_gt_depth_m": (src_depth_gt_dist[b : b + 1] if has_depth_gt and src_depth_gt_dist is not None else zeros_src_depth),
            "src_mask": ones_mask,
            # The source depth term is always disabled here.
            "src_apply_depth": False,
            "src_apply_grad": bool(has_depth_gt),
            "src_apply_grad_img": bool(has_depth_gt),
            "src_grad_img_circular_h": False,
            "tgt_pred_rgb_linear": tgt_out.color,
            "tgt_pred_alpha": tgt_out.alpha,
            "tgt_pred_depth_m": tgt_depth_pred,
            "tgt_gt_rgb_u8": tgt_u8[b : b + 1],
            "tgt_gt_depth_m": (tgt_depth_gt_dist[b : b + 1] if has_depth_gt and tgt_depth_gt_dist is not None else zeros_tgt_depth),
            "tgt_mask": tgt_frustum_mask,
            "tgt_depth_mask": tgt_depth_loss_mask,
            "tgt_apply_depth": bool(has_depth_gt),
            "tgt_apply_grad_img": bool(has_depth_gt),
            "tgt_grad_img_circular_h": False,
            "tgt_apply_percep": bool(float(self.loss_fn.w.lambda_percep) > 0.0),
            "tgt_extra_loss_terms": tgt_extra_loss_terms,
            "aux_losses": self._aux_ray_losses(
                pred_rays=(
                    out.get("unik3d_rays", None)[b : b + 1]
                    if torch.is_tensor(out.get("unik3d_rays", None))
                    else None
                ),
                gt_rays=(
                    aux_ray_target_all[b : b + 1]
                    if torch.is_tensor(aux_ray_target_all)
                    else None
                ),
                mask=ones_mask,
                pred_distance=(
                    out["unik3d_distance"][b : b + 1, 0:1]
                    if torch.is_tensor(out.get("unik3d_distance", None))
                    else None
                ),
                pred_distance2=src_depth2_pred,
                gt_distance=(
                    src_unik3d_gt_dist[b : b + 1]
                    if torch.is_tensor(src_unik3d_gt_dist)
                    else None
                ),
                gt_distance2=src_depth2_gt_for_aux,
                depth_mask=(src_depth_gt[b : b + 1] > 0.0 if has_depth_gt else None),
                depth_mask2=src_depth2_mask_for_aux,
            ),
            "gaussian_scales": reg_inputs["gaussian_scales"],
            "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
            "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
            "delta_xy": reg_inputs["delta_xy_eff"],
            "delta_rho": reg_inputs["delta_rho_raw"],
            "delta_grid": reg_inputs["delta_grid"],
            "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
            "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
            "gaussian_opacities": reg_inputs["gaussian_opacities"],
            "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
            "projected_scale_factor": reg_inputs["projected_scale_factor"],
            "projection_model": "pinhole",
            "projection_intrinsics": src_k_render_b,
            "vis_payload": vis_payload,
        }

    return _ModeStrategy(
        batch_size=int(src.shape[0]),
        gaussians=gaussians,
        make_world_gaussians=make_world_gaussians,
        make_sample=make_sample,
        collect_all_vis=bool(getattr(batch, "collect_all_vis", False)),
    )
|
| 1232 |
+
|
| 1233 |
+
def _build_fisheye624_strategy(
    self,
    batch: Any,
    step: int,
    need_vis: bool = False,
    dataset_name: str = "scannetpp_fisheye",
) -> _ModeStrategy:
    """Build the per-batch strategy for fisheye624 source/target pairs.

    Moves the batch tensors to ``self.device``, runs the model once on the
    source image, and returns a ``_ModeStrategy`` whose ``make_sample``
    closure renders the predicted gaussians into the source and target
    fisheye cameras and assembles the loss-term dictionaries for sample ``b``.

    Args:
        batch: Batch object exposing ``src_*`` / ``tgt_*`` tensors (RGB,
            depth, valid mask, w2c pose, camera params).
        step: Unused; accepted for signature parity with the other builders.
        need_vis: Not read in this builder; visualisation is controlled by
            the ``enable_vis`` argument of ``make_sample``.
        dataset_name: Used to look up the distance-init cap and recorded in
            the visualisation payload.

    Returns:
        A ``_ModeStrategy`` carrying the gaussians and the two closures.
    """
    del step

    def to_device(t: Any) -> Any:
        return t.to(self.device, non_blocking=True)

    def to_device_f32(t: Any) -> Any:
        return to_device(t).to(torch.float32)

    src_u8 = self._as_bchw_rgb_u8(to_device(batch.src_rgb_u8))
    tgt_u8 = self._as_bchw_rgb_u8(to_device(batch.tgt_rgb_u8))
    src_depth_gt = self._clamp_distance_for_supervision(
        self._as_b1hw_depth(to_device_f32(batch.src_depth_m))
    )
    tgt_depth_gt = self._clamp_distance_for_supervision(
        self._as_b1hw_depth(to_device_f32(batch.tgt_depth_m))
    )
    src_valid_mask = self._as_b1hw_depth(to_device_f32(batch.src_valid_mask))
    tgt_valid_mask = self._as_b1hw_depth(to_device_f32(batch.tgt_valid_mask))
    src_w2c = self._as_b44_pose(to_device_f32(batch.src_w2c))
    tgt_w2c = self._as_b44_pose(to_device_f32(batch.tgt_w2c))
    src_cam_params = self._as_b16_camera_params(to_device_f32(batch.src_camera_params))
    tgt_cam_params = self._as_b16_camera_params(to_device_f32(batch.tgt_camera_params))
    distance_init_cap_m = self._distance_init_cap_for_dataset(dataset_name)

    out = self.model(
        image=src_u8.float().clamp(0, 255) / 255.0,
        image_u8=src_u8,
        camera_intrinsics=None,
        camera_params=src_cam_params,
        camera_model="fisheye624",
        depth_gt=src_depth_gt,
        distance_init_cap_m=distance_init_cap_m,
        validity_mask=src_valid_mask,
        return_aux=True,
    )

    gaussians = out["gaussians"]
    aux_ray_target_all = out.get("unik3d_gt_rays", None)

    def make_world_gaussians(b: int, g_b: Any) -> Any:
        return transform_gaussians_to_world(g_b, src_w2c[b])

    def make_sample(b: int, g_world: Any, enable_vis: bool) -> dict[str, Any]:
        one = slice(b, b + 1)  # keeps the batch dimension for sample b
        src_h = int(src_u8.shape[-2])
        src_w = int(src_u8.shape[-1])
        tgt_h = int(tgt_u8.shape[-2])
        tgt_w = int(tgt_u8.shape[-1])

        src_cam_b = src_cam_params[one]
        tgt_cam_b = tgt_cam_params[one]
        src_valid_b = src_valid_mask[one]
        tgt_valid_b = tgt_valid_mask[one]
        src_depth_gt_b = src_depth_gt[one]
        tgt_depth_gt_b = tgt_depth_gt[one]

        src_render = render_gaussians_fisheye624(
            g_world,
            extrinsics_w2c=src_w2c[one],
            camera_params=src_cam_b,
            image_h=src_h,
            image_w=src_w,
            valid_mask=src_valid_b,
        )
        tgt_render = render_gaussians_fisheye624(
            g_world,
            extrinsics_w2c=tgt_w2c[one],
            camera_params=tgt_cam_b,
            image_h=tgt_h,
            image_w=tgt_w,
            valid_mask=tgt_valid_b,
        )
        reg_inputs = self._collect_regularization_inputs(
            out=out,
            gaussians=gaussians,
            b=b,
            projected_scale_factor=None,
        )

        # Source supervision mask: dataset validity AND renderer validity.
        # The same product feeds the frustum-mask computation below.
        src_mask = src_valid_b * src_render["valid_mask"]
        src_depth_mask = src_mask

        tgt_depth_for_mask = self._pick_depth_for_fisheye_frustum_mask(
            gt_depth=tgt_depth_gt_b,
            pred_depth=tgt_render["depth_distance"],
            gt_valid_mask=tgt_valid_b,
        )
        tgt_frustum_mask = compute_fisheye624_frustum_mask(
            depth_distance_m=tgt_depth_for_mask,
            tgt_w2c=tgt_w2c[one],
            src_w2c=src_w2c[one],
            tgt_camera_params=tgt_cam_b,
            src_camera_params=src_cam_b,
            src_valid_mask=src_mask,
            source_depth_distance_m=src_depth_gt_b,
        )
        tgt_mask_raw = tgt_valid_b * tgt_render["valid_mask"] * tgt_frustum_mask
        tgt_mask = self._erode_supervision_mask(
            tgt_mask_raw,
            self.target_mask_erode_px,
            circular_h=False,
        )

        distance_layers = out["distance_layers"]
        src_depth_pred = self._clamp_distance_for_supervision(
            distance_layers[one, 0:1],
            clamp_max=False,
        )
        src_depth2_pred = (
            self._clamp_distance_for_supervision(distance_layers[one, 1:2], clamp_max=False)
            if distance_layers is not None and distance_layers.shape[1] > 1
            else None
        )
        tgt_depth_pred = self._clamp_distance_for_supervision(
            tgt_render["depth_distance"], clamp_max=False
        )
        tgt_depth_loss_mask = self._rendered_depth_valid_for_inv_loss(
            tgt_depth_pred, tgt_render["alpha"]
        )

        # Regularisation inputs shared by the extra source term and the
        # top-level sample dictionary.
        reg_fields = {
            "gaussian_scales": reg_inputs["gaussian_scales"],
            "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
            "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
            "delta_xy": reg_inputs["delta_xy_eff"],
            "delta_rho": reg_inputs["delta_rho_raw"],
            "delta_grid": reg_inputs["delta_grid"],
            "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
            "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
            "gaussian_opacities": reg_inputs["gaussian_opacities"],
            "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
            "projected_scale_factor": reg_inputs["projected_scale_factor"],
        }

        # Rendered source view: colour + alpha supervision only.
        src_loss_terms = [
            {
                "pred_rgb_linear": src_render["color"],
                "pred_alpha": src_render["alpha"],
                "pred_depth_m": src_render["depth_distance"],
                "pred_depth2_m": None,
                "gt_rgb_u8": src_u8[one],
                "gt_depth_m": src_depth_gt_b,
                "mask": src_mask,
                "apply_color": True,
                "apply_alpha": True,
                "apply_depth": False,
                "apply_percep": False,
                "apply_tv": False,
                "apply_grad": False,
                "apply_grad_img": False,
                "grad_img_circular_h": False,
                "gaussian_scales": None,
                "gaussian_quaternions": None,
                "gaussian_angular_cell": None,
                "delta_xy": None,
                "gaussian_mean_vectors": None,
                "gaussian_opacities": None,
                "gauss_grid_shape": None,
                "projected_scale_factor": None,
                "apply_splat": False,
                "loss_scale": 1.0,
            }
        ]
        # Predicted distance layers: TV / gradient / splat terms; the RGB and
        # alpha slots are zero placeholders because colour/alpha are disabled.
        src_extra_loss_terms = [
            {
                "pred_rgb_linear": torch.zeros((1, 3, src_h, src_w), dtype=torch.float32, device=self.device),
                "pred_alpha": torch.zeros((1, 1, src_h, src_w), dtype=torch.float32, device=self.device),
                "pred_depth_m": src_depth_pred,
                "pred_depth2_m": src_depth2_pred,
                "gt_rgb_u8": torch.zeros((1, 3, src_h, src_w), dtype=torch.uint8, device=self.device),
                "gt_depth_m": src_depth_gt_b,
                "mask": src_depth_mask,
                "apply_color": False,
                "apply_alpha": False,
                "apply_depth": False,
                "apply_percep": False,
                "apply_tv": True,
                "apply_grad": True,
                "apply_grad_img": True,
                "grad_img_circular_h": False,
                **reg_fields,
                "projected_scale_factor": None,
                "projection_model": "fisheye624",
                "projection_camera_params": src_cam_b,
                "apply_splat": True,
                "loss_scale": 1.0,
            }
        ]
        tgt_extra_loss_terms = []

        unik3d_rays = out.get("unik3d_rays", None)
        unik3d_distance = out.get("unik3d_distance", None)

        vis_payload = None
        if enable_vis:
            src_unik3d_depth = (
                unik3d_distance[one, 0:1].detach() if torch.is_tensor(unik3d_distance) else None
            )
            tgt_unik3d_depth = None
            if self.enable_tgt_unik3d_vis:
                # Best-effort, visualisation-only extra forward pass: any
                # failure simply leaves the target UniK3D panel empty.
                try:
                    with torch.no_grad():
                        from unisharp.utils.unik3d_adapter import forward_unik3d_fisheye624

                        unik_tgt = forward_unik3d_fisheye624(
                            self._base_model().feature_extractor.unik3d,
                            rgb_u8=tgt_u8[one],
                            camera_params=tgt_cam_b,
                            normalize=True,
                            validity_mask=tgt_valid_b,
                        )
                        dist_tgt = unik_tgt.get("distance", None) if isinstance(unik_tgt, dict) else None
                        if torch.is_tensor(dist_tgt):
                            tgt_unik3d_depth = dist_tgt[:, 0:1].detach()
                except Exception:
                    tgt_unik3d_depth = None
            vis_payload = {
                "src_gt": (src_u8[one].float() / 255.0).detach(),
                "src_pred": src_render["color"].clamp(0, 1).detach(),
                "src_alpha": src_render["alpha"].detach(),
                "src_gt_depth": src_depth_gt_b.detach(),
                "src_pred_depth": src_render["depth_distance"].detach(),
                "src_unik3d_depth": src_unik3d_depth,
                "src_metric_mask": src_mask.detach(),
                "tgt_gt": (tgt_u8[one].float() / 255.0).detach(),
                "tgt_pred": tgt_render["color"].clamp(0, 1).detach(),
                "tgt_alpha": tgt_render["alpha"].detach(),
                "tgt_gt_depth": tgt_depth_gt_b.detach(),
                "tgt_pred_depth": tgt_depth_pred.detach(),
                "tgt_unik3d_depth": tgt_unik3d_depth,
                "dataset_name": str(dataset_name),
                # Index metadata with b so each visualised sample carries its
                # own scene / frame ids rather than those of the first sample.
                "scene": str(self._item_at(getattr(batch, "scene", None), b, "unknown")),
                "src_idx": int(self._item_at(getattr(batch, "src_idx", None), b, -1)),
                "tgt_idx": int(self._item_at(getattr(batch, "tgt_idx", None), b, -1)),
                "src_pose_w2c": src_w2c[one].detach(),
                "tgt_pose_w2c": tgt_w2c[one].detach(),
                "tgt_metric_mask_raw": tgt_mask_raw.detach(),
                "tgt_metric_mask": tgt_mask.detach(),
            }

        return {
            "src_loss_terms": src_loss_terms,
            "src_extra_loss_terms": src_extra_loss_terms,
            "tgt_pred_rgb_linear": tgt_render["color"],
            "tgt_pred_alpha": tgt_render["alpha"],
            "tgt_pred_depth_m": tgt_depth_pred,
            "tgt_gt_rgb_u8": tgt_u8[one],
            "tgt_gt_depth_m": tgt_depth_gt_b,
            "tgt_mask": tgt_mask,
            "tgt_depth_mask": tgt_depth_loss_mask,
            "tgt_apply_depth": True,
            "tgt_apply_grad_img": True,
            "tgt_apply_splat": False,
            "tgt_grad_img_circular_h": False,
            "tgt_apply_percep": bool(float(self.loss_fn.w.lambda_percep) > 0.0),
            "tgt_extra_loss_terms": tgt_extra_loss_terms,
            "aux_losses": self._aux_ray_losses(
                pred_rays=unik3d_rays[one] if torch.is_tensor(unik3d_rays) else None,
                gt_rays=aux_ray_target_all[one] if torch.is_tensor(aux_ray_target_all) else None,
                mask=src_valid_b,
                pred_distance=(
                    unik3d_distance[one, 0:1] if torch.is_tensor(unik3d_distance) else None
                ),
                pred_distance2=None,
                gt_distance=src_depth_gt_b,
                depth_mask=src_valid_b,
            ),
            **reg_fields,
            "projection_model": "fisheye624",
            "projection_camera_params": src_cam_b,
            "vis_payload": vis_payload,
        }

    return _ModeStrategy(
        batch_size=int(src_u8.shape[0]),
        gaussians=gaussians,
        make_world_gaussians=make_world_gaussians,
        make_sample=make_sample,
        collect_all_vis=bool(getattr(batch, "collect_all_vis", False)),
    )
|
| 1519 |
+
|
| 1520 |
+
def _build_fisheye_strategy(
    self,
    batch: Any,
    step: int,
    need_vis: bool = False,
    dataset_name: str = "fisheye",
) -> _ModeStrategy:
    """Dispatch a fisheye batch to the matching strategy builder.

    The camera model is read from ``batch.camera_model`` (defaulting to
    ``"fisheye624"`` when the attribute is absent) and compared
    case-insensitively. Only ``"fisheye624"`` is handled; it is forwarded to
    ``_build_fisheye624_strategy`` with the same arguments.

    Raises:
        ValueError: If the batch declares any other camera model.
    """
    requested_model = str(getattr(batch, "camera_model", "fisheye624")).lower()
    if requested_model == "fisheye624":
        return self._build_fisheye624_strategy(
            batch,
            step,
            need_vis=need_vis,
            dataset_name=dataset_name,
        )
    raise ValueError(
        f"Unsupported fisheye camera_model={requested_model!r}; expected 'fisheye624'."
    )
|
| 1538 |
+
|
| 1539 |
+
def _build_spherical_strategy(
    self,
    batch: Any,
    step: int,
    need_vis: bool = False,
    dataset_name: str = "hm3d",
) -> _ModeStrategy:
    """Build the per-batch strategy for equirectangular (ERP) panoramas.

    Moves the ERP / cubemap tensors and poses to ``self.device``, expresses
    both camera poses relative to the source camera, runs the model once on
    the source ERP image, and returns a ``_ModeStrategy`` whose
    ``make_sample`` closure renders cubemaps for the source and target views,
    converts them to ERP, and assembles the loss-term dictionaries.

    ``step`` and ``need_vis`` are accepted but not read by this builder.
    """

    def on_device(t: Any) -> Any:
        return t.to(self.device, non_blocking=True)

    src_erp_u8 = on_device(batch.src_erp_rgb_u8)
    tgt_erp_u8 = on_device(batch.tgt_erp_rgb_u8)
    src_erp_depth = self._clamp_distance_for_supervision(on_device(batch.src_erp_depth_m))
    tgt_erp_depth = self._clamp_distance_for_supervision(on_device(batch.tgt_erp_depth_m))
    src_cdep = self._sanitize_positive_depth(on_device(batch.src_cube_depth_m))
    tgt_cdep = self._sanitize_positive_depth(on_device(batch.tgt_cube_depth_m))
    disable_depth_gt = bool(getattr(batch, "disable_depth_gt", False))

    src_R = on_device(batch.src_R)
    src_t = on_device(batch.src_t)
    tgt_R = on_device(batch.tgt_R)
    tgt_t = on_device(batch.tgt_t)

    cur_bs = int(src_erp_u8.shape[0])
    erp_h = int(src_erp_u8.shape[-2])
    erp_w = int(src_erp_u8.shape[-1])
    if torch.is_tensor(batch.src_cube_depth_m):
        cube_face_w = int(batch.src_cube_depth_m.shape[2])
    else:
        cube_face_w = max(1, erp_h // 2)

    dataset_key = str(dataset_name).lower()
    use_flip_yz = dataset_key not in {"sim", "smx_sim_fisheye"}
    pose_convs_per_sample = ["c2w"] * cur_bs
    flip_yz_per_sample = [bool(use_flip_yz)] * cur_bs

    def stack_w2c(rot: Any, trans: Any) -> Any:
        per_sample = [
            build_extrinsics_w2c(rot[i], trans[i], pose_convs_per_sample[i])
            for i in range(cur_bs)
        ]
        return torch.stack(per_sample, dim=0)

    extr_src_base = stack_w2c(src_R, src_t)
    extr_tgt_base = stack_w2c(tgt_R, tgt_t)

    # NOTE(review): the source this was reconstructed from had lost its
    # indentation; the whole pose normalisation is assumed to run with
    # autocast disabled (float32) -- confirm against the original file.
    with torch.autocast("cuda", enabled=False):
        c2w_src = torch.linalg.inv(extr_src_base.to(torch.float32))
        c2w_tgt = torch.linalg.inv(extr_tgt_base.to(torch.float32))

        flip_mask = torch.tensor(flip_yz_per_sample, device=c2w_src.device, dtype=torch.bool)
        negate_relative_z = False
        if bool(flip_mask.any().item()):
            # The axis-flip convention can be overridden via the environment.
            flip_mode = os.environ.get("PANO_POSE_FLIP_CONVENTION", "flip_yz_negate_rel_z").strip().lower()
            negate_relative_z = flip_mode in {
                "flip_yz_negate_rel_z",
                "flip_yz_invert_z_translation",
                "flip_yz_neg_z",
            }
            if flip_mode in {"flip_y_only", "y", "y_only"}:
                diag = [1.0, -1.0, 1.0, 1.0]
            elif flip_mode in {"none", "identity", "no_flip"}:
                diag = [1.0, 1.0, 1.0, 1.0]
            else:
                diag = [1.0, -1.0, -1.0, 1.0]
            D = torch.diag(torch.tensor(diag, device=c2w_src.device, dtype=torch.float32))
            c2w_src = c2w_src.clone()
            c2w_tgt = c2w_tgt.clone()
            c2w_src[flip_mask] = c2w_src[flip_mask] @ D
            c2w_tgt[flip_mask] = c2w_tgt[flip_mask] @ D

        # Re-express both poses relative to the (flipped) source camera.
        ref_inv = torch.linalg.inv(c2w_src.to(torch.float32))
        c2w_src = ref_inv @ c2w_src
        c2w_tgt = ref_inv @ c2w_tgt
        if negate_relative_z:
            c2w_tgt = c2w_tgt.clone()
            c2w_tgt[flip_mask, 2, 3] *= -1.0

        extr_src = torch.linalg.inv(c2w_src).to(dtype=extr_src_base.dtype)
        extr_tgt = torch.linalg.inv(c2w_tgt).to(dtype=extr_tgt_base.dtype)

    src_erp = (src_erp_u8.float() / 255.0).clamp(0, 1)
    distance_init_cap_m = self._distance_init_cap_for_dataset(dataset_name)

    out = self.model(
        image=src_erp,
        image_u8=src_erp_u8,
        camera_intrinsics=None,
        camera_model="spherical",
        depth_gt=None if disable_depth_gt else src_erp_depth,
        distance_init_cap_m=distance_init_cap_m,
        return_aux=True,
    )
    gaussians = out["gaussians"]
    aux_ray_target_all = out.get("unik3d_gt_rays", None)

    # Mean of the horizontal and vertical pixels-per-radian of the ERP image.
    erp_proj_scale = 0.5 * (
        float(erp_w) / (2.0 * 3.141592653589793)
        + float(erp_h) / 3.141592653589793
    )

    def make_world_gaussians(b: int, g_b: Any) -> Any:
        return transform_gaussians_to_world(g_b, extr_src[b])

    def make_sample(b: int, g_world: Any, enable_vis: bool) -> dict[str, Any]:
        one = slice(b, b + 1)  # keeps the batch dimension for sample b

        def to_erp(cube: Any) -> Any:
            return self._cube_to_erp(cube, equ_h=erp_h, equ_w=erp_w, face_w=cube_face_w)

        def gt_cube_depth(cdep: Any) -> Any:
            # Ground-truth cubemap depth for this sample, or None when GT
            # depth is disabled or unavailable.
            if disable_depth_gt or not torch.is_tensor(cdep):
                return None
            return cdep[one][0]

        src_rgb, src_depth, src_alpha = self._render_cubemap(g_world, extr_src[b], face_w=cube_face_w)
        tgt_rgb, tgt_depth, tgt_alpha = self._render_cubemap(g_world, extr_tgt[b], face_w=cube_face_w)

        src_erp_pred = to_erp(src_rgb)
        tgt_erp_pred = to_erp(tgt_rgb)
        src_erp_alpha = to_erp(src_alpha)
        tgt_erp_alpha = to_erp(tgt_alpha)
        src_depth_dist = self._clamp_distance_for_supervision(
            self._cubemap_z_depth_to_distance(src_depth),
            clamp_max=False,
        )
        tgt_depth_dist = self._clamp_distance_for_supervision(
            self._cubemap_z_depth_to_distance(tgt_depth),
            clamp_max=False,
        )
        src_erp_depth_render = to_erp(src_depth_dist).clamp(min=1e-4)

        distance_layers = out["distance_layers"]
        src_erp_depth_pred = self._clamp_distance_for_supervision(
            distance_layers[one, 0:1],
            clamp_max=False,
        )
        if distance_layers is not None and distance_layers.shape[1] > 1:
            src_erp_depth2_pred = self._clamp_distance_for_supervision(
                distance_layers[one, 1:2], clamp_max=False
            )
        else:
            src_erp_depth2_pred = None
        tgt_erp_depth_pred = to_erp(tgt_depth_dist).clamp(min=1e-4)
        tgt_depth_loss_mask = self._rendered_depth_valid_for_inv_loss(tgt_erp_depth_pred, tgt_erp_alpha)

        depth_novel = self._pick_depth_for_cubemap_frustum_mask(
            gt_depth_cube=gt_cube_depth(tgt_cdep),
            pred_depth_cube=tgt_depth_dist,
            face_w=cube_face_w,
        )
        source_depth_for_visibility = self._pick_depth_for_cubemap_frustum_mask(
            gt_depth_cube=gt_cube_depth(src_cdep),
            pred_depth_cube=src_depth_dist,
            face_w=cube_face_w,
        )
        mask_bool = view_frustum_mask_cubemap_union(
            depth_novel=depth_novel,
            extr_novel_w2c=extr_tgt[b],
            extr_source_w2c=extr_src[b],
            face_w=int(cube_face_w),
            source_depth=source_depth_for_visibility,
        )
        mask_erp = to_erp(mask_bool[:, None].to(torch.float32))

        gt_src_erp_u8 = src_erp_u8[one]
        gt_tgt_erp_u8 = tgt_erp_u8[one]
        gt_src_erp_depth = src_erp_depth[one]
        gt_tgt_erp_depth = tgt_erp_depth[one]
        gt_src_cube_u8 = batch.src_cube_rgb_u8[b].to(self.device, non_blocking=True).permute(0, 3, 1, 2).contiguous()
        gt_tgt_cube_u8 = batch.tgt_cube_rgb_u8[b].to(self.device, non_blocking=True).permute(0, 3, 1, 2).contiguous()

        if disable_depth_gt:
            src_valid = torch.ones_like(gt_src_erp_depth)
            tgt_valid = torch.ones_like(gt_tgt_erp_depth)
        else:
            src_valid = (gt_src_erp_depth > 0.0).to(dtype=torch.float32)
            tgt_valid = (gt_tgt_erp_depth > 0.0).to(dtype=torch.float32)
        src_mask = torch.ones_like(src_valid)
        tgt_mask_raw = (mask_erp.to(dtype=torch.float32) * tgt_valid).clamp(0.0, 1.0)
        tgt_mask = self._erode_supervision_mask(
            tgt_mask_raw,
            self.target_mask_erode_px,
            circular_h=True,
        )

        src_cube_mask = torch.ones_like(src_alpha)
        if dataset_key == "hm3d" and (not disable_depth_gt) and torch.is_tensor(src_cdep):
            # For hm3d, restrict source cubemap supervision to pixels with
            # positive ground-truth cubemap depth.
            src_cube_valid = (src_cdep[one][0, ..., 0] > 0.0).to(dtype=src_alpha.dtype).unsqueeze(1)
            if tuple(src_cube_valid.shape[-2:]) != tuple(src_alpha.shape[-2:]):
                src_cube_valid = F.interpolate(
                    src_cube_valid,
                    size=src_alpha.shape[-2:],
                    mode="nearest",
                )
            src_cube_mask = src_cube_valid.to(device=src_alpha.device, dtype=src_alpha.dtype).clamp(0.0, 1.0)
        tgt_cube_valid = (depth_novel[..., 0] > 0.0).to(dtype=torch.float32).unsqueeze(1)
        tgt_cube_mask = self._erode_supervision_mask(
            (mask_bool[:, None].to(dtype=torch.float32) * tgt_cube_valid).clamp(0.0, 1.0),
            self.target_mask_erode_px,
            circular_h=False,
        )

        # Zero placeholders for the slots of a loss term that are disabled.
        src_cube_depth_zeros = torch.zeros_like(src_alpha)
        tgt_cube_depth_zeros = torch.zeros_like(tgt_alpha)
        src_erp_rgb_zeros = torch.zeros_like(src_erp_pred)
        tgt_erp_rgb_zeros = torch.zeros_like(tgt_erp_pred)
        src_erp_u8_zeros = torch.zeros_like(gt_src_erp_u8)
        tgt_erp_u8_zeros = torch.zeros_like(gt_tgt_erp_u8)

        reg_inputs = self._collect_regularization_inputs(
            out=out,
            gaussians=gaussians,
            b=b,
            projected_scale_factor=erp_proj_scale,
        )
        # Loss terms that carry no gaussian regularisation inputs.
        no_reg = {
            "gaussian_scales": None,
            "gaussian_quaternions": None,
            "gaussian_angular_cell": None,
            "delta_xy": None,
            "gaussian_mean_vectors": None,
            "gaussian_opacities": None,
            "gauss_grid_shape": None,
            "projected_scale_factor": None,
        }
        # Regularisation inputs shared by the source ERP term and the
        # top-level sample dictionary.
        reg_fields = {
            "gaussian_scales": reg_inputs["gaussian_scales"],
            "gaussian_quaternions": reg_inputs["gaussian_quaternions"],
            "gaussian_angular_cell": reg_inputs["gaussian_angular_cell"],
            "delta_xy": reg_inputs["delta_xy_eff"],
            "delta_rho": reg_inputs["delta_rho_raw"],
            "delta_grid": reg_inputs["delta_grid"],
            "gaussian_mean_vectors": reg_inputs["gaussian_mean_vectors"],
            "gaussian_base_mean_vectors": reg_inputs["gaussian_base_mean_vectors"],
            "gaussian_opacities": reg_inputs["gaussian_opacities"],
            "gauss_grid_shape": reg_inputs["gauss_grid_shape"],
            "projected_scale_factor": reg_inputs["projected_scale_factor"],
            "projection_model": "erp",
        }

        vis_payload = None
        if enable_vis:
            raw_dist = out.get("unik3d_distance", None)
            src_unik3d_depth = raw_dist[one, 0:1].detach() if torch.is_tensor(raw_dist) else None
            tgt_unik3d_depth = None

            def cube_gt_or_none(attr: str) -> Any:
                cube = getattr(batch, attr, None)
                if hasattr(batch, attr) and torch.is_tensor(cube):
                    return cube[b].detach()
                return None

            vis_payload = {
                "src_gt": (gt_src_erp_u8.float() / 255.0).detach(),
                "src_pred": src_erp_pred.clamp(0, 1).detach(),
                "src_alpha": src_erp_alpha.detach(),
                "src_gt_depth": None if disable_depth_gt else gt_src_erp_depth.detach(),
                "src_pred_depth": src_erp_depth_render.detach(),
                "src_unik3d_depth": src_unik3d_depth,
                "tgt_gt": (gt_tgt_erp_u8.float() / 255.0).detach(),
                "tgt_pred": tgt_erp_pred.clamp(0, 1).detach(),
                "tgt_alpha": tgt_erp_alpha.detach(),
                "tgt_gt_depth": None if disable_depth_gt else gt_tgt_erp_depth.detach(),
                "tgt_pred_depth": tgt_erp_depth_pred.detach(),
                "tgt_unik3d_depth": tgt_unik3d_depth,
                "dataset_name": str(dataset_name),
                "scene": str(self._item_at(getattr(batch, "scene", None), b, "unknown")),
                "src_idx": int(self._item_at(getattr(batch, "src_idx", None), b, -1)),
                "tgt_idx": int(self._item_at(getattr(batch, "tgt_idx", None), b, -1)),
                "src_pose_w2c": extr_src[one].detach(),
                "tgt_pose_w2c": extr_tgt[one].detach(),
                "src_cube_gt_u8": cube_gt_or_none("src_cube_rgb_u8"),
                "tgt_cube_gt_u8": cube_gt_or_none("tgt_cube_rgb_u8"),
                "src_cube_pred_linear": src_rgb.detach(),
                "tgt_cube_pred_linear": tgt_rgb.detach(),
                "src_cube_alpha": src_alpha.detach(),
                "tgt_cube_alpha": tgt_alpha.detach(),
                "tgt_metric_mask_raw": tgt_mask_raw.detach(),
                "tgt_metric_mask": tgt_mask.detach(),
            }

        # Target view: colour/alpha on the cubemap faces, depth on the ERP.
        tgt_loss_terms = [
            {
                "pred_rgb_linear": tgt_rgb,
                "pred_alpha": tgt_alpha,
                "pred_depth_m": tgt_cube_depth_zeros,
                "pred_depth2_m": None,
                "gt_rgb_u8": gt_tgt_cube_u8,
                "gt_depth_m": tgt_cube_depth_zeros,
                "mask": tgt_cube_mask,
                "apply_color": True,
                "apply_alpha": True,
                "apply_depth": False,
                "apply_percep": bool(float(self.loss_fn.w.lambda_percep) > 0.0),
                "apply_tv": False,
                "apply_grad": False,
                "apply_grad_img": False,
                "grad_img_circular_h": False,
                **no_reg,
            },
            {
                "pred_rgb_linear": tgt_erp_rgb_zeros,
                "pred_alpha": torch.zeros_like(tgt_erp_depth_pred),
                "pred_depth_m": tgt_erp_depth_pred,
                "pred_depth2_m": None,
                "gt_rgb_u8": tgt_erp_u8_zeros,
                "gt_depth_m": gt_tgt_erp_depth,
                "mask": tgt_mask,
                "depth_mask": tgt_depth_loss_mask,
                "apply_color": False,
                "apply_alpha": False,
                "apply_depth": not disable_depth_gt,
                "apply_percep": False,
                "apply_tv": False,
                "apply_grad": False,
                "apply_grad_img": not disable_depth_gt,
                "grad_img_circular_h": True,
                **no_reg,
                "projected_scale_factor": reg_inputs["projected_scale_factor"],
            },
        ]
        # Source view: colour/alpha on the cubemap faces, regularisers on the
        # predicted ERP distance layers.
        src_loss_terms = [
            {
                "pred_rgb_linear": src_rgb,
                "pred_alpha": src_alpha,
                "pred_depth_m": src_cube_depth_zeros,
                "pred_depth2_m": None,
                "gt_rgb_u8": gt_src_cube_u8,
                "gt_depth_m": src_cube_depth_zeros,
                "mask": src_cube_mask,
                "apply_color": True,
                "apply_alpha": True,
                "apply_depth": False,
                "apply_percep": False,
                "apply_tv": False,
                "apply_grad": False,
                "apply_grad_img": False,
                "grad_img_circular_h": False,
                **no_reg,
            },
            {
                "pred_rgb_linear": src_erp_rgb_zeros,
                "pred_alpha": torch.zeros_like(src_erp_depth_pred),
                "pred_depth_m": src_erp_depth_pred,
                "pred_depth2_m": src_erp_depth2_pred,
                "gt_rgb_u8": src_erp_u8_zeros,
                "gt_depth_m": gt_src_erp_depth,
                "mask": src_mask,
                "apply_color": False,
                "apply_alpha": False,
                "apply_depth": False,
                "apply_percep": False,
                "apply_tv": True,
                "apply_grad": False,
                "apply_grad_img": not disable_depth_gt,
                "grad_img_circular_h": True,
                **reg_fields,
            },
        ]

        unik3d_rays = out.get("unik3d_rays", None)
        unik3d_distance = out.get("unik3d_distance", None)
        return {
            "src_loss_terms": src_loss_terms,
            "tgt_loss_terms": tgt_loss_terms,
            **reg_fields,
            "aux_losses": self._aux_ray_losses(
                pred_rays=unik3d_rays[one] if torch.is_tensor(unik3d_rays) else None,
                gt_rays=aux_ray_target_all[one] if torch.is_tensor(aux_ray_target_all) else None,
                mask=src_valid,
                pred_distance=(
                    unik3d_distance[one, 0:1] if torch.is_tensor(unik3d_distance) else None
                ),
                pred_distance2=src_erp_depth2_pred,
                gt_distance=None if disable_depth_gt else gt_src_erp_depth,
                depth_mask=src_valid,
            ),
            "vis_payload": vis_payload,
        }

    return _ModeStrategy(
        batch_size=int(cur_bs),
        gaussians=gaussians,
        make_world_gaussians=make_world_gaussians,
        make_sample=make_sample,
        collect_all_vis=bool(getattr(batch, "collect_all_vis", False)),
    )
|
| 1944 |
+
|
| 1945 |
+
def _render_cubemap(
    self,
    gaussians: Any,
    extr_w2c: torch.Tensor,
    face_w: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Render ``gaussians`` through six cubemap face cameras.

    Args:
        gaussians: Object exposing ``mean_vectors`` (its device is used for
            all camera tensors); passed unchanged to ``self.renderer``.
        extr_w2c: Extrinsics handed to ``cubemap_face_cameras`` to derive the
            per-face extrinsics (presumably a world-to-camera matrix, going
            by the name -- confirm against that helper).
        face_w: Side length of each square face in pixels.

    Returns:
        ``(color, depth, alpha)`` from the renderer output, each made
        contiguous.
    """
    side = int(face_w)
    target_device = gaussians.mean_vectors.device

    # One shared pinhole intrinsics matrix, broadcast (not copied) to 6 faces.
    shared_intrinsics = get_pinhole_intrinsics_4x4(side).to(device=target_device)
    face_intrinsics = shared_intrinsics[None].expand(6, -1, -1)
    face_extrinsics = cubemap_face_cameras(extr_w2c, device=target_device)

    rendered = self.renderer(
        gaussians,
        extrinsics=face_extrinsics,
        intrinsics=face_intrinsics,
        image_width=side,
        image_height=side,
    )
    color = rendered.color.contiguous()
    depth = rendered.depth.contiguous()
    alpha = rendered.alpha.contiguous()
    return color, depth, alpha
|
| 1962 |
+
|
| 1963 |
+
def _cube_to_erp(self, cube: torch.Tensor, equ_h: int, equ_w: int, face_w: int) -> torch.Tensor:
    """Convert a cubemap tensor to an equirectangular image via ``Cube2Equirec``.

    The first two axes of the 4-D ``cube`` are swapped and a leading batch
    axis of size 1 is added before the projection module is applied.
    (Presumably ``cube`` arrives as (faces, channels, H, W) -- confirm
    against ``Cube2Equirec``'s expected layout.)
    """
    batched = cube.permute(1, 0, 2, 3).unsqueeze(0)
    projector = Cube2Equirec(face_w=int(face_w), equ_h=int(equ_h), equ_w=int(equ_w))
    projector = projector.to(device=batched.device)
    return projector(batched)
|
unisharp/datasets/__pycache__/dl3dv.cpython-310.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
unisharp/datasets/__pycache__/dl3dv.cpython-313.pyc
ADDED
|
Binary file (19.5 kB). View file
|
|
|
unisharp/datasets/__pycache__/pair_sampling.cpython-310.pyc
ADDED
|
Binary file (3.78 kB). View file
|
|
|
unisharp/datasets/__pycache__/pair_sampling.cpython-313.pyc
ADDED
|
Binary file (6.66 kB). View file
|
|
|
unisharp/datasets/__pycache__/panogs.cpython-310.pyc
ADDED
|
Binary file (17.4 kB). View file
|
|
|
unisharp/datasets/__pycache__/panogs.cpython-313.pyc
ADDED
|
Binary file (32.5 kB). View file
|
|
|
unisharp/datasets/__pycache__/re10k.cpython-310.pyc
ADDED
|
Binary file (19.8 kB). View file
|
|
|
unisharp/datasets/__pycache__/re10k.cpython-313.pyc
ADDED
|
Binary file (37.4 kB). View file
|
|
|
unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-310.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
unisharp/datasets/__pycache__/scannetpp_fisheye.cpython-313.pyc
ADDED
|
Binary file (32.4 kB). View file
|
|
|
unisharp/datasets/__pycache__/sim_panorama.cpython-310.pyc
ADDED
|
Binary file (18.9 kB). View file
|
|
|
unisharp/datasets/__pycache__/sim_panorama.cpython-313.pyc
ADDED
|
Binary file (34 kB). View file
|
|
|
unisharp/datasets/__pycache__/wildrgbd.cpython-310.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
unisharp/datasets/__pycache__/wildrgbd.cpython-313.pyc
ADDED
|
Binary file (21.4 kB). View file
|
|
|
unisharp/datasets/dl3dv.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from collections import defaultdict, deque
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from torch.utils.data import IterableDataset
|
| 14 |
+
|
| 15 |
+
from unisharp.datasets.pair_sampling import (
|
| 16 |
+
project_overlap_ratio,
|
| 17 |
+
resize_k3_align_corners_false,
|
| 18 |
+
resize_rgb_u8_chw_high_quality,
|
| 19 |
+
select_targets_for_source,
|
| 20 |
+
)
|
| 21 |
+
from unisharp.datasets.re10k import Re10KPairSample, re10k_collate
|
| 22 |
+
from unisharp import DEFAULT_MAX_DEPTH_M
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class DL3DVDataset(IterableDataset):
    """Iterable dataset that streams source/target frame pairs from DL3DV scenes.

    Scenes come either from a manifest file (one ``name|scene_dir|depth_dir``
    entry per line) or from scanning ``root`` for
    ``<bucket>/<scene>/[<inner>/]transforms.json`` with an ``images_4`` folder
    and a matching depth folder under ``depth_root``.

    For each source frame a target is drawn at random from the candidates that
    pass the index-gap, translation and projected-overlap filters.  Samples are
    ``Re10KPairSample`` objects; when ``batch_size_hint > 1`` they are grouped
    by image size and yielded pre-collated through ``re10k_collate``.
    """

    def __init__(
        self,
        root: Path,
        depth_root: Path,
        scene_specs_file: Path | None = None,
        min_frame_gap: int = 1,
        max_frame_gap: int = 32,
        pair_max_translation_m: float = 0.5,
        pair_min_overlap: float = 0.6,
        pair_overlap_sample_h: int = 32,
        pair_overlap_sample_w: int = 56,
        output_h: int | None = None,
        output_w: int | None = None,
        shuffle_scene: bool = True,
        shuffle_frame: bool = False,
        ddp_rank: int = 0,
        ddp_world_size: int = 1,
        batch_size_hint: int = 1,
        depth_max_m: float = DEFAULT_MAX_DEPTH_M,
        seed: int = 0,
        verify_manifest_paths: bool = False,
    ) -> None:
        """Store the configuration and resolve the list of scenes.

        Raises:
            FileNotFoundError: ``scene_specs_file`` is given but does not exist.
            RuntimeError: no usable scene was found.
        """
        super().__init__()
        self.root = Path(root)
        self.depth_root = Path(depth_root)
        self.min_frame_gap = int(min_frame_gap)
        self.max_frame_gap = int(max_frame_gap)
        self.pair_max_translation_m = float(pair_max_translation_m)
        self.pair_min_overlap = float(pair_min_overlap)
        self.pair_overlap_sample_h = int(pair_overlap_sample_h)
        self.pair_overlap_sample_w = int(pair_overlap_sample_w)
        self.output_h = int(output_h) if output_h is not None else None
        self.output_w = int(output_w) if output_w is not None else None
        self.shuffle_scene = bool(shuffle_scene)
        self.shuffle_frame = bool(shuffle_frame)
        self.ddp_rank = int(ddp_rank)
        self.ddp_world_size = int(ddp_world_size)
        self.batch_size_hint = int(max(1, batch_size_hint))
        self.depth_max_m = float(depth_max_m)
        self.seed = int(seed)
        self.epoch = 0
        self.verify_manifest_paths = bool(verify_manifest_paths)
        self.scene_specs_file = Path(scene_specs_file) if scene_specs_file is not None else None
        self.scene_specs = self._load_scene_specs()
        if not self.scene_specs:
            raise RuntimeError(f"No valid DL3DV scenes found under {self.root}")

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch that seeds scene/frame shuffling in ``__iter__``."""
        self.epoch = int(epoch)

    def _load_scene_specs(self) -> list[tuple[str, Path, Path]]:
        """Return ``(scene_name, scene_dir, depth_dir)`` triples.

        Without a manifest the directory tree is scanned.  With a manifest,
        blank lines and lines that do not have exactly three ``|``-separated
        fields are ignored; paths are checked for existence only when
        ``verify_manifest_paths`` is set.
        """
        if self.scene_specs_file is None:
            return self._scan_scenes()
        if not self.scene_specs_file.exists():
            raise FileNotFoundError(self.scene_specs_file)
        out: list[tuple[str, Path, Path]] = []
        for raw in self.scene_specs_file.read_text(encoding="utf-8").splitlines():
            line = raw.strip()
            if not line:
                continue
            parts = line.split("|")
            if len(parts) != 3:
                continue
            scene_name, scene_dir_raw, depth_dir_raw = parts
            scene_dir = Path(scene_dir_raw)
            depth_dir = Path(depth_dir_raw)
            if (not self.verify_manifest_paths) or (scene_dir.exists() and depth_dir.exists()):
                out.append((scene_name, scene_dir, depth_dir))
        return out

    def _scan_scenes(self) -> list[tuple[str, Path, Path]]:
        """Discover scenes under ``self.root`` that have poses, images and depth."""
        out: list[tuple[str, Path, Path]] = []
        for bucket_dir in sorted(p for p in self.root.iterdir() if p.is_dir()):
            for scene_stub in sorted(p for p in bucket_dir.iterdir() if p.is_dir()):
                # A scene may be wrapped in one extra directory level; the first
                # sub-directory returned by iterdir() is used when there is one.
                inner_dirs = [p for p in scene_stub.iterdir() if p.is_dir()]
                scene_dir = inner_dirs[0] if inner_dirs else scene_stub
                transforms_path = scene_dir / "transforms.json"
                image_dir = scene_dir / "images_4"
                depth_dir = self.depth_root / bucket_dir.name / scene_stub.name / "exports" / "mini_npz" / "per_image"
                if transforms_path.exists() and image_dir.exists() and depth_dir.exists():
                    scene_name = f"{bucket_dir.name}/{scene_stub.name}"
                    out.append((scene_name, scene_dir, depth_dir))
        return out

    @staticmethod
    def _load_rgb_u8(path: Path) -> torch.Tensor:
        """Load an image as a contiguous uint8 CHW RGB tensor."""
        # The context manager releases the file handle deterministically.
        with Image.open(path) as img:
            arr = np.asarray(img.convert("RGB"), dtype=np.uint8).copy()
        return torch.from_numpy(arr).permute(2, 0, 1).contiguous()

    def _load_depth_m(self, path: Path) -> torch.Tensor:
        """Load the ``depth`` array of an ``.npz`` file as a (1, H, W) float tensor.

        Non-finite values become 0 and the result is clipped to
        ``[0, self.depth_max_m]``.
        """
        # NpzFile keeps the archive open until closed, so close it explicitly.
        with np.load(path) as payload:
            depth = payload["depth"].astype(np.float32)
        depth[~np.isfinite(depth)] = 0.0
        depth = np.clip(depth, a_min=0.0, a_max=self.depth_max_m)
        return torch.from_numpy(depth).unsqueeze(0)

    @staticmethod
    def _resize_depth_to_image(depth: torch.Tensor, image_hw: tuple[int, int]) -> torch.Tensor:
        """Nearest-neighbour resize of a (C, H, W) depth map to ``image_hw``."""
        target_h, target_w = int(image_hw[0]), int(image_hw[1])
        if depth.shape[-2:] == (target_h, target_w):
            return depth
        return F.interpolate(
            depth.unsqueeze(0),
            size=(target_h, target_w),
            mode="nearest",
        ).squeeze(0)

    @staticmethod
    def _frame_id_from_name(name: str) -> int:
        """Parse the integer after the last ``_`` of the file stem.

        Raises:
            ValueError: the trailing token is not an integer.
        """
        stem = Path(name).stem
        return int(stem.split("_")[-1])

    @classmethod
    def _try_frame_id(cls, name: str) -> int | None:
        """Like ``_frame_id_from_name`` but return ``None`` for unparsable names."""
        try:
            return cls._frame_id_from_name(name)
        except ValueError:
            return None

    @classmethod
    def _index_by_frame_id(cls, paths) -> dict[int, Path]:
        """Map frame id -> path for an iterable of paths, skipping names without an id."""
        indexed: dict[int, Path] = {}
        for path in paths:
            frame_id = cls._try_frame_id(path.name)
            if frame_id is not None:
                indexed[frame_id] = path
        return indexed

    def _load_scene(
        self,
        scene_name: str,
        scene_dir: Path,
        depth_dir: Path,
    ) -> tuple[list[int], dict[int, Path], dict[int, Path], dict[int, torch.Tensor], dict[int, torch.Tensor], torch.Tensor]:
        """Read poses and intrinsics for one scene.

        Returns:
            ``(valid_ids, image_paths, depth_paths, w2c_map, intr_map, k)``
            where ``valid_ids`` is the sorted list of frame ids that have an
            image, a depth file and a pose; ``intr_map`` holds intrinsics
            rescaled to the on-disk image size; ``k`` is the unscaled matrix
            from ``transforms.json``.
        """
        meta = json.loads((scene_dir / "transforms.json").read_text(encoding="utf-8"))
        orig_w = int(meta["w"])
        orig_h = int(meta["h"])
        k = torch.eye(3, dtype=torch.float32)
        k[0, 0] = float(meta["fl_x"])
        k[1, 1] = float(meta["fl_y"])
        k[0, 2] = float(meta["cx"])
        k[1, 2] = float(meta["cy"])

        image_dir = scene_dir / "images_4"
        # Files whose name carries no numeric frame id are ignored instead of
        # aborting (and thereby silently discarding) the whole scene.
        image_paths = self._index_by_frame_id(image_dir.glob("*.png"))
        depth_paths = self._index_by_frame_id(depth_dir.glob("*.npz"))
        w2c_map: dict[int, torch.Tensor] = {}
        intr_map: dict[int, torch.Tensor] = {}
        valid_ids: list[int] = []

        # Intrinsics rescaled to the stored image size; computed once from the
        # first usable frame and shared (as copies) by all frames.
        k_scaled = None
        for frame in meta.get("frames", []):
            frame_name = Path(str(frame.get("file_path", ""))).name
            frame_id = self._try_frame_id(frame_name)
            if frame_id is None or frame_id not in image_paths or frame_id not in depth_paths:
                continue
            c2w = torch.tensor(frame["transform_matrix"], dtype=torch.float32)
            # Flip the camera y and z axes (presumably OpenGL -> OpenCV
            # convention -- confirm against the data source).
            c2w[:3, 1:3] *= -1.0
            if k_scaled is None:
                example_img = self._load_rgb_u8(image_paths[frame_id])
                cur_h, cur_w = int(example_img.shape[1]), int(example_img.shape[2])
                k_scaled = k.clone()
                if cur_h != orig_h or cur_w != orig_w:
                    sx = float(cur_w) / float(orig_w)
                    sy = float(cur_h) / float(orig_h)
                    k_scaled = resize_k3_align_corners_false(k_scaled, sx=sx, sy=sy)
            w2c_map[frame_id] = torch.linalg.inv(c2w)
            intr_map[frame_id] = k_scaled.clone()
            valid_ids.append(frame_id)
        valid_ids = sorted(valid_ids)
        return valid_ids, image_paths, depth_paths, w2c_map, intr_map, k

    def __iter__(self):
        """Yield pair samples (or collated batches when ``batch_size_hint > 1``).

        Source frames are sharded round-robin over
        ``ddp_world_size * num_workers`` shards.  Scenes or samples that fail
        to load are skipped (best effort).
        """
        scenes = list(self.scene_specs)
        order_rng = random.Random(self.seed + self.epoch)
        if self.shuffle_scene:
            order_rng.shuffle(scenes)
        pending_by_hw: dict[tuple[int, int], deque[Re10KPairSample]] = defaultdict(deque)
        worker_info = torch.utils.data.get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        total_shards = max(1, self.ddp_world_size * num_workers)
        shard_id = self.ddp_rank * num_workers + worker_id
        src_unit_index = 0

        for scene_order_idx, (scene_name, scene_dir, depth_dir) in enumerate(scenes):
            try:
                valid_ids, image_paths, depth_paths, w2c_map, intr_map, _ = self._load_scene(scene_name, scene_dir, depth_dir)
            except Exception:
                # Best effort: an unreadable scene must not stop the stream.
                continue
            if len(valid_ids) < 2:
                continue
            src_order = list(valid_ids)
            scene_rng = random.Random(self.seed + self.epoch * 1000003 + scene_order_idx)
            if self.shuffle_frame:
                scene_rng.shuffle(src_order)
            centers = torch.stack([torch.linalg.inv(w2c_map[i])[:3, 3] for i in valid_ids], dim=0)
            frame_to_pos = {fid: pos for pos, fid in enumerate(valid_ids)}

            # (width, height) per frame id, so a source image is opened once
            # rather than once per candidate target.
            size_cache: dict[int, tuple[int, int]] = {}

            def image_size(fid: int) -> tuple[int, int]:
                size = size_cache.get(fid)
                if size is None:
                    with Image.open(image_paths[fid]) as img:
                        size = (int(img.size[0]), int(img.size[1]))
                    size_cache[fid] = size
                return size

            def overlap_avg(src_pos: int, tgt_pos: int) -> float:
                """Mean of the src->tgt and tgt->src projected overlap ratios."""
                src_fid = int(valid_ids[src_pos])
                tgt_fid = int(valid_ids[tgt_pos])
                w, h = image_size(src_fid)
                forward = project_overlap_ratio(
                    src_w2c=w2c_map[src_fid],
                    tgt_w2c=w2c_map[tgt_fid],
                    src_k=intr_map[src_fid],
                    tgt_k=intr_map[tgt_fid],
                    h=h,
                    w=w,
                    sample_h=self.pair_overlap_sample_h,
                    sample_w=self.pair_overlap_sample_w,
                )
                backward = project_overlap_ratio(
                    src_w2c=w2c_map[tgt_fid],
                    tgt_w2c=w2c_map[src_fid],
                    src_k=intr_map[tgt_fid],
                    tgt_k=intr_map[src_fid],
                    h=h,
                    w=w,
                    sample_h=self.pair_overlap_sample_h,
                    sample_w=self.pair_overlap_sample_w,
                )
                return float(0.5 * (forward + backward))

            for src_idx in src_order:
                # Round-robin sharding: every shard advances the same counter
                # and keeps only the source frames assigned to it.
                if src_unit_index % total_shards != shard_id:
                    src_unit_index += 1
                    continue
                src_unit_index += 1
                src_pos = int(frame_to_pos[int(src_idx)])
                tgt_pos_list = select_targets_for_source(
                    src_idx=src_pos,
                    candidate_indices=list(range(len(valid_ids))),
                    centers=centers,
                    min_index_gap=int(self.min_frame_gap),
                    max_index_gap=int(self.max_frame_gap),
                    pair_max_translation_m=float(self.pair_max_translation_m),
                    pair_min_overlap=float(self.pair_min_overlap),
                    overlap_score_fn=overlap_avg,
                )
                if not tgt_pos_list:
                    continue
                tgt_idx = int(valid_ids[scene_rng.choice(tgt_pos_list)])
                try:
                    src_img = self._load_rgb_u8(image_paths[int(src_idx)])
                    tgt_img = self._load_rgb_u8(image_paths[int(tgt_idx)])
                    src_depth = self._load_depth_m(depth_paths[int(src_idx)])
                    tgt_depth = self._load_depth_m(depth_paths[int(tgt_idx)])
                except Exception:
                    # Best effort: skip pairs whose files cannot be read.
                    continue
                src_depth = self._resize_depth_to_image(src_depth, (int(src_img.shape[1]), int(src_img.shape[2])))
                tgt_depth = self._resize_depth_to_image(tgt_depth, (int(tgt_img.shape[1]), int(tgt_img.shape[2])))
                src_intr = intr_map[int(src_idx)].clone()
                tgt_intr = intr_map[int(tgt_idx)].clone()
                if self.output_h is not None and self.output_w is not None:
                    oh, ow = int(src_img.shape[1]), int(src_img.shape[2])
                    if oh != self.output_h or ow != self.output_w:
                        # The scale is derived from the source image and
                        # applied to the target as well.
                        sx = float(self.output_w) / float(ow)
                        sy = float(self.output_h) / float(oh)
                        src_img = resize_rgb_u8_chw_high_quality(src_img, size=(self.output_h, self.output_w))
                        tgt_img = resize_rgb_u8_chw_high_quality(tgt_img, size=(self.output_h, self.output_w))
                        src_depth = F.interpolate(src_depth[None], size=(self.output_h, self.output_w), mode="nearest")[0]
                        tgt_depth = F.interpolate(tgt_depth[None], size=(self.output_h, self.output_w), mode="nearest")[0]
                        src_intr = resize_k3_align_corners_false(src_intr, sx=sx, sy=sy)
                        tgt_intr = resize_k3_align_corners_false(tgt_intr, sx=sx, sy=sy)
                sample = Re10KPairSample(
                    src_rgb_u8=src_img,
                    tgt_rgb_u8=tgt_img,
                    src_w2c=w2c_map[int(src_idx)],
                    tgt_w2c=w2c_map[int(tgt_idx)],
                    src_intrinsics=src_intr,
                    tgt_intrinsics=tgt_intr,
                    src_idx=int(src_idx),
                    tgt_idx=int(tgt_idx),
                    scene=scene_name,
                    src_depth_m=src_depth,
                    tgt_depth_m=tgt_depth,
                )
                hw_key = (int(sample.src_rgb_u8.shape[1]), int(sample.src_rgb_u8.shape[2]))
                bucket = pending_by_hw[hw_key]
                bucket.append(sample)
                if self.batch_size_hint <= 1:
                    yield bucket.popleft()
                    continue
                # NOTE(review): samples left in a bucket when iteration ends
                # are never yielded (drop-last behaviour); kept as in the
                # original -- confirm this is intended.
                while len(bucket) >= self.batch_size_hint:
                    packed = [bucket.popleft() for _ in range(self.batch_size_hint)]
                    yield re10k_collate(packed)
|
unisharp/datasets/pair_sampling.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Callable
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from unisharp.utils.pixel_convention import scale_intrinsics_align_corners_false
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def resize_k3_align_corners_false(k: torch.Tensor, *, sx: float, sy: float) -> torch.Tensor:
    """Rescale intrinsics ``k`` by ``sx``/``sy``.

    Delegates to ``scale_intrinsics_align_corners_false`` after coercing both
    scale factors to ``float``.
    """
    scale_x = float(sx)
    scale_y = float(sy)
    return scale_intrinsics_align_corners_false(k, sx=scale_x, sy=scale_y)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def resize_rgb_u8_chw_high_quality(image: torch.Tensor, *, size: tuple[int, int]) -> torch.Tensor:
    """Resize a 3-D (CHW) image tensor with antialiased bicubic interpolation.

    Args:
        image: 3-D tensor; it is converted to float32 for interpolation.
        size: Target ``(height, width)``.

    Returns:
        A contiguous tensor.  If the spatial size already matches, the input
        is returned (made contiguous) without resampling; otherwise the
        result is rounded, clamped to ``[0, 255]`` and cast to ``uint8``.

    Raises:
        ValueError: ``image`` is not a 3-D tensor.
    """
    is_tensor = torch.is_tensor(image)
    if not (is_tensor and image.ndim == 3):
        described = tuple(image.shape) if is_tensor else type(image)
        raise ValueError(f"Expected CHW tensor, got {described}")

    target_hw = (int(size[0]), int(size[1]))
    if tuple(image.shape[-2:]) == target_hw:
        return image.contiguous()

    as_float_batch = image.unsqueeze(0).to(torch.float32)
    resampled = F.interpolate(
        as_float_batch,
        size=target_hw,
        mode="bicubic",
        align_corners=False,
        antialias=True,
    )
    # Bicubic kernels can overshoot, hence the clamp before the uint8 cast.
    quantized = resampled[0].round().clamp(0, 255)
    return quantized.to(torch.uint8).contiguous()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def project_overlap_ratio(
    src_w2c: torch.Tensor,
    tgt_w2c: torch.Tensor,
    src_k: torch.Tensor,
    tgt_k: torch.Tensor,
    h: int,
    w: int,
    src_hw: tuple[int, int] | None = None,
    tgt_hw: tuple[int, int] | None = None,
    sample_h: int = 32,
    sample_w: int = 56,
    proxy_depth: float = 1.0,
) -> float:
    """Estimate how much of the source view lands inside the target view.

    A ``sample_h x sample_w`` grid of source pixels is unprojected along unit
    rays to distance ``proxy_depth``, moved into the target camera and
    projected with ``tgt_k``.  The result is the fraction of samples that are
    in front of the target camera and inside its image bounds.

    Args:
        src_w2c, tgt_w2c: 4x4 world-to-camera matrices.
        src_k, tgt_k: Intrinsics; only ``[0, 0]``, ``[1, 1]``, ``[0, 2]`` and
            ``[1, 2]`` are read.
        h, w: Image size used for both views unless overridden.
        src_hw, tgt_hw: Optional per-view ``(height, width)`` overrides.
        sample_h, sample_w: Sampling grid resolution.
        proxy_depth: Distance along each unit ray at which points are placed.

    Returns:
        Overlap ratio in ``[0, 1]`` as a Python float.
    """
    device = src_w2c.device
    src_h, src_w = tuple(int(v) for v in (src_hw or (h, w)))
    tgt_h, tgt_w = tuple(int(v) for v in (tgt_hw or (h, w)))
    ys = torch.linspace(0, src_h - 1, steps=sample_h, device=device)
    xs = torch.linspace(0, src_w - 1, steps=sample_w, device=device)
    vv, uu = torch.meshgrid(ys, xs, indexing="ij")
    u = uu.reshape(-1)
    v = vv.reshape(-1)

    fx, fy = src_k[0, 0], src_k[1, 1]
    cx, cy = src_k[0, 2], src_k[1, 2]
    x = (u - cx) / fx
    y = (v - cy) / fy
    z = torch.ones_like(x)
    rays = torch.stack([x, y, z], dim=-1)
    rays = rays / torch.norm(rays, dim=-1, keepdim=True).clamp(min=1e-6)
    pts_src = rays * float(proxy_depth)

    src_c2w = torch.linalg.inv(src_w2c)
    pts_src_h = torch.cat([pts_src, torch.ones_like(pts_src[:, :1])], dim=-1)
    pts_w = (src_c2w @ pts_src_h.T).T
    pts_tgt = (tgt_w2c @ pts_w.T).T

    xt, yt, z_raw = pts_tgt[:, 0], pts_tgt[:, 1], pts_tgt[:, 2]
    # The in-front test must use the unclamped depth: after clamping every
    # value is positive, which would let points behind the target camera
    # (on its optical axis) count as visible.
    in_front = z_raw > 0.0
    zt = z_raw.clamp(min=1e-6)  # only guards the division below
    ut = tgt_k[0, 0] * (xt / zt) + tgt_k[0, 2]
    vt = tgt_k[1, 1] * (yt / zt) + tgt_k[1, 2]
    inside = in_front & (ut >= 0.0) & (ut <= float(tgt_w - 1)) & (vt >= 0.0) & (vt <= float(tgt_h - 1))
    return float(inside.float().mean().item())
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def select_targets_for_source(
    *,
    src_idx: int,
    candidate_indices: list[int],
    centers: torch.Tensor,
    min_index_gap: int,
    max_index_gap: int,
    pair_max_translation_m: float,
    pair_min_overlap: float,
    overlap_score_fn: Callable[[int, int], float],
) -> list[int]:
    """Filter ``candidate_indices`` down to acceptable targets for ``src_idx``.

    A candidate is kept when, in this order, it
      1. differs from the source index,
      2. has an index gap within ``[min_index_gap, max_index_gap]``,
      3. lies within ``pair_max_translation_m`` (L2 distance between the
         corresponding rows of ``centers``), and
      4. scores at least ``pair_min_overlap`` with ``overlap_score_fn``.

    The overlap function is only called for candidates that pass the cheaper
    checks.  Candidate order is preserved in the result.
    """
    anchor = int(src_idx)
    gap_lo, gap_hi = int(min_index_gap), int(max_index_gap)
    max_shift = float(pair_max_translation_m)
    min_score = float(pair_min_overlap)
    anchor_center = centers[anchor]

    def accepted(cand: int) -> bool:
        if cand == anchor:
            return False
        if not gap_lo <= abs(cand - anchor) <= gap_hi:
            return False
        shift = float(torch.norm(centers[cand] - anchor_center, p=2).item())
        if shift > max_shift:
            return False
        return float(overlap_score_fn(anchor, cand)) >= min_score

    return [cand for cand in map(int, candidate_indices) if accepted(cand)]
|
unisharp/datasets/panogs.py
ADDED
|
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Literal
|
| 6 |
+
from typing import cast
|
| 7 |
+
import tarfile
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from torch.utils.data import Dataset
|
| 13 |
+
|
| 14 |
+
from unisharp import DEFAULT_MAX_DEPTH_M
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
MAX_DEPTH_M = DEFAULT_MAX_DEPTH_M
|
| 18 |
+
|
| 19 |
+
_PAIR_RECIPE_FIXED: tuple[str, bool] = ("c2w", True)
|
| 20 |
+
|
| 21 |
+
_PAIR_CONVENTIONS: tuple[str, ...] = ("c2w",)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _torch_load_any(path: Path) -> object:
|
| 25 |
+
try:
|
| 26 |
+
return torch.load(path, map_location="cpu", weights_only=False)
|
| 27 |
+
except TypeError:
|
| 28 |
+
return torch.load(path, map_location="cpu")
|
| 29 |
+
except (KeyError, tarfile.ReadError, EOFError, OSError, RuntimeError) as e:
|
| 30 |
+
raise RuntimeError(f"torch.load failed (possibly incomplete/corrupted): {path}") from e
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass(frozen=True)
class PanOGSSample:
    """Immutable record for one source/target panorama pair.

    Field meanings below are inferred from the names (``erp`` presumably
    stands for equirectangular projection, ``_m`` for metres, ``_u8`` for
    uint8) -- confirm against the dataset's ``__getitem__``.
    """

    # Equirectangular colour and depth for both views.
    src_erp_rgb_u8: torch.Tensor
    tgt_erp_rgb_u8: torch.Tensor
    src_erp_depth_m: torch.Tensor
    tgt_erp_depth_m: torch.Tensor

    # Cubemap colour and depth for both views.
    src_cube_rgb_u8: torch.Tensor
    tgt_cube_rgb_u8: torch.Tensor
    src_cube_depth_m: torch.Tensor
    tgt_cube_depth_m: torch.Tensor

    # Pose of each view (presumably rotation matrix and translation vector).
    src_R: torch.Tensor
    src_t: torch.Tensor
    tgt_R: torch.Tensor
    tgt_t: torch.Tensor

    # Frame indices and the scene name the pair was drawn from.
    src_idx: int
    tgt_idx: int
    scene: str
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _load_erp_rgb_u8(path: Path) -> torch.Tensor:
    """Load a 3-channel image file as a contiguous uint8 CHW tensor.

    Raises:
        ValueError: the decoded array is not ``(H, W, 3)`` (for example a
            grayscale or RGBA file).
    """
    # Decode inside the context manager so the file handle is released
    # deterministically instead of whenever the Image is garbage-collected.
    with Image.open(path) as handle:
        img = np.array(handle)
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f"Expected RGB image at {path}, got shape={img.shape}")
    return torch.from_numpy(img.astype(np.uint8)).permute(2, 0, 1).contiguous()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _load_depth_png(path: Path) -> torch.Tensor:
    """Load a depth image and return its raw pixel values as a tensor.

    No unit conversion is done here; the dtype is whatever the decoded image
    array has.
    """
    # NOTE(review): 16-bit PNGs presumably decode to uint16, which older
    # PyTorch releases cannot wrap with ``from_numpy`` -- confirm the minimum
    # supported torch version.
    # Decode inside the context manager so the file handle is closed promptly.
    with Image.open(path) as handle:
        dep = np.array(handle)
    return torch.from_numpy(dep)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _depth_to_meters(depth: torch.Tensor, max_depth_m: float = DEFAULT_MAX_DEPTH_M) -> torch.Tensor:
    """Convert a raw depth tensor to float32 metres clamped to ``[0, max_depth_m]``.

    Non-finite values (NaN, +/-inf) are replaced by 0.  If the largest
    remaining value exceeds 200 the data is treated as millimetres and
    divided by 1000.  The input tensor is never modified.

    Args:
        depth: Raw depth values of any numeric dtype.
        max_depth_m: Upper clamp in metres.

    Returns:
        A new float32 tensor with the same shape as ``depth``.
    """
    depth_f = depth.to(torch.float32)
    # Sanitise out of place: ``.to`` returns the input itself when it is
    # already float32, so an in-place write would corrupt the caller's data.
    # Doing this before the unit heuristic also keeps NaN/inf from deciding
    # (or suppressing) the millimetre conversion.
    depth_f = torch.where(torch.isfinite(depth_f), depth_f, torch.zeros_like(depth_f))
    maxv = float(depth_f.max().item()) if depth_f.numel() else 0.0
    if maxv > 200.0:
        depth_f = depth_f / 1000.0
    return depth_f.clamp(min=0.0, max=float(max_depth_m))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class PanOGSDataset(Dataset[PanOGSSample]):
|
| 78 |
+
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
root: Path,
|
| 82 |
+
index_manifest_path: Path | None = None,
|
| 83 |
+
src_tgt_max_index_gap: int = 25,
|
| 84 |
+
use_cubemap_supervision: bool = True,
|
| 85 |
+
pair_sampling: bool = True,
|
| 86 |
+
pair_max_translation_m: float = 0.5,
|
| 87 |
+
pair_min_depth_overlap: float = 0.6,
|
| 88 |
+
pair_overlap_face_w: int = 64,
|
| 89 |
+
pair_overlap_margin: float = 1.05,
|
| 90 |
+
pair_max_tries: int = 48,
|
| 91 |
+
depth_max_m: float = DEFAULT_MAX_DEPTH_M,
|
| 92 |
+
) -> None:
|
| 93 |
+
self.root = root
|
| 94 |
+
self.src_tgt_max_index_gap = int(src_tgt_max_index_gap)
|
| 95 |
+
self.use_cubemap_supervision = use_cubemap_supervision
|
| 96 |
+
self.pair_sampling = bool(pair_sampling)
|
| 97 |
+
self.pair_max_translation_m = float(pair_max_translation_m)
|
| 98 |
+
self.pair_min_depth_overlap = float(pair_min_depth_overlap)
|
| 99 |
+
self.pair_overlap_face_w = int(pair_overlap_face_w)
|
| 100 |
+
self.pair_overlap_margin = float(pair_overlap_margin)
|
| 101 |
+
self.pair_max_tries = int(pair_max_tries)
|
| 102 |
+
self.depth_max_m = float(depth_max_m)
|
| 103 |
+
self.index_manifest_path = Path(index_manifest_path) if index_manifest_path is not None else None
|
| 104 |
+
|
| 105 |
+
self._pair_valid_tgts: dict[tuple[str, int], list[int]] = {}
|
| 106 |
+
self._pair_overlap_cache: dict[tuple[str, int, int], float] = {}
|
| 107 |
+
|
| 108 |
+
if not root.exists():
|
| 109 |
+
raise FileNotFoundError(root)
|
| 110 |
+
|
| 111 |
+
self.scenes = sorted([p for p in root.iterdir() if p.is_dir()])
|
| 112 |
+
if not self.scenes:
|
| 113 |
+
raise RuntimeError(f"No scene folders found in {root}")
|
| 114 |
+
|
| 115 |
+
self._pose_cache: dict[str, tuple[np.ndarray, np.ndarray]] = {}
|
| 116 |
+
self._meta_paths: dict[str, Path] = {}
|
| 117 |
+
self._num_frames: dict[str, int] = {}
|
| 118 |
+
self._available_frames: dict[str, list[int]] = {}
|
| 119 |
+
|
| 120 |
+
if self.index_manifest_path is not None:
|
| 121 |
+
if not self.index_manifest_path.exists():
|
| 122 |
+
raise FileNotFoundError(self.index_manifest_path)
|
| 123 |
+
valid_scenes: list[Path] = []
|
| 124 |
+
for raw in self.index_manifest_path.read_text(encoding="utf-8").splitlines():
|
| 125 |
+
line = raw.strip()
|
| 126 |
+
if not line:
|
| 127 |
+
continue
|
| 128 |
+
parts = line.split("|")
|
| 129 |
+
scene_name = parts[0].strip()
|
| 130 |
+
if not scene_name:
|
| 131 |
+
continue
|
| 132 |
+
scene_dir = root / scene_name
|
| 133 |
+
meta_path = scene_dir / "meta.pt"
|
| 134 |
+
if not meta_path.exists():
|
| 135 |
+
continue
|
| 136 |
+
if len(parts) >= 2:
|
| 137 |
+
try:
|
| 138 |
+
n_pose = int(parts[1])
|
| 139 |
+
except ValueError:
|
| 140 |
+
n_pose = 0
|
| 141 |
+
else:
|
| 142 |
+
n_pose = 0
|
| 143 |
+
if n_pose <= 0:
|
| 144 |
+
continue
|
| 145 |
+
self._meta_paths[scene_name] = meta_path
|
| 146 |
+
self._num_frames[scene_name] = n_pose
|
| 147 |
+
self._available_frames[scene_name] = list(range(n_pose))
|
| 148 |
+
valid_scenes.append(scene_dir)
|
| 149 |
+
self.scenes = valid_scenes
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
if not self._available_frames:
|
| 153 |
+
|
| 154 |
+
valid_scenes = []
|
| 155 |
+
for scene_i, scene_dir in enumerate(self.scenes):
|
| 156 |
+
meta_path = scene_dir / "meta.pt"
|
| 157 |
+
if not meta_path.exists():
|
| 158 |
+
continue
|
| 159 |
+
|
| 160 |
+
ex = _torch_load_any(meta_path)
|
| 161 |
+
cams = ex.get("cameras", None)
|
| 162 |
+
if not isinstance(cams, torch.Tensor):
|
| 163 |
+
raise ValueError(f"meta.pt missing 'cameras' tensor in {scene_dir}")
|
| 164 |
+
if cams.ndim != 3 or tuple(cams.shape[1:]) != (4, 4):
|
| 165 |
+
raise ValueError(f"Bad meta.pt cameras shape {tuple(cams.shape)} in {scene_dir}")
|
| 166 |
+
n_pose = int(cams.shape[0])
|
| 167 |
+
|
| 168 |
+
frames = list(range(n_pose))
|
| 169 |
+
|
| 170 |
+
name = scene_dir.name
|
| 171 |
+
self._meta_paths[name] = meta_path
|
| 172 |
+
self._num_frames[name] = n_pose
|
| 173 |
+
self._available_frames[name] = frames
|
| 174 |
+
valid_scenes.append(scene_dir)
|
| 175 |
+
|
| 176 |
+
self.scenes = valid_scenes
|
| 177 |
+
|
| 178 |
+
def _get_pose(self, scene: str) -> tuple[np.ndarray, np.ndarray]:
|
| 179 |
+
cached = self._pose_cache.get(scene)
|
| 180 |
+
if cached is not None:
|
| 181 |
+
return cached
|
| 182 |
+
|
| 183 |
+
meta_path = self._meta_paths.get(scene)
|
| 184 |
+
if meta_path is None:
|
| 185 |
+
raise FileNotFoundError(f"meta.pt not indexed for scene={scene} under {self.root}")
|
| 186 |
+
ex = _torch_load_any(meta_path)
|
| 187 |
+
cams = ex.get("cameras", None)
|
| 188 |
+
if not isinstance(cams, torch.Tensor):
|
| 189 |
+
raise ValueError(f"meta.pt missing 'cameras' tensor for scene={scene}")
|
| 190 |
+
cams = cams.to(torch.float32)
|
| 191 |
+
if cams.ndim != 3 or tuple(cams.shape[1:]) != (4, 4):
|
| 192 |
+
raise ValueError(f"Bad meta.pt cameras shape {tuple(cams.shape)} for scene={scene}")
|
| 193 |
+
R = cams[:, :3, :3].cpu().numpy()
|
| 194 |
+
t = cams[:, :3, 3].cpu().numpy()
|
| 195 |
+
out = (R, t)
|
| 196 |
+
self._pose_cache[scene] = out
|
| 197 |
+
return out
|
| 198 |
+
|
| 199 |
+
def __len__(self) -> int:
    """Return the number of entries in ``self._index``."""
    return len(self._index)
|
| 201 |
+
|
| 202 |
+
def _sample_target(self, scene: str, src_idx: int) -> int:
|
| 203 |
+
frames = self._available_frames[scene]
|
| 204 |
+
if len(frames) <= 1:
|
| 205 |
+
return src_idx
|
| 206 |
+
effective_gap = self.src_tgt_max_index_gap
|
| 207 |
+
candidates = [i for i in frames if i != src_idx and abs(i - src_idx) <= effective_gap]
|
| 208 |
+
if not candidates:
|
| 209 |
+
return src_idx
|
| 210 |
+
j = int(torch.randint(low=0, high=len(candidates), size=(1,)).item())
|
| 211 |
+
return int(candidates[j])
|
| 212 |
+
|
| 213 |
+
def _candidate_targets_by_translation(self, scene: str, src_idx: int) -> list[int]:
|
| 214 |
+
frames = self._available_frames[scene]
|
| 215 |
+
if len(frames) <= 1:
|
| 216 |
+
return []
|
| 217 |
+
R_np, t_np = self._get_pose(scene)
|
| 218 |
+
if not (0 <= src_idx < len(t_np) and 0 <= src_idx < len(R_np)):
|
| 219 |
+
return []
|
| 220 |
+
th = float(self.pair_max_translation_m)
|
| 221 |
+
|
| 222 |
+
def _cam_center_from(R: np.ndarray, t: np.ndarray, conv: str) -> np.ndarray:
|
| 223 |
+
if conv in ("c2w", "w2c_t_camcenter"):
|
| 224 |
+
return t
|
| 225 |
+
if conv == "w2c":
|
| 226 |
+
return -(R.transpose(0, 2, 1) @ t[..., None])[..., 0]
|
| 227 |
+
if conv == "c2w_t_w2c":
|
| 228 |
+
return -(R @ t[..., None])[..., 0]
|
| 229 |
+
raise ValueError(conv)
|
| 230 |
+
|
| 231 |
+
def _min_dist(idxs: np.ndarray) -> np.ndarray:
|
| 232 |
+
R_sub = R_np[idxs].astype(np.float32)
|
| 233 |
+
t_sub = t_np[idxs].astype(np.float32)
|
| 234 |
+
R_src = R_np[int(src_idx) : int(src_idx) + 1].astype(np.float32)
|
| 235 |
+
t_src = t_np[int(src_idx) : int(src_idx) + 1].astype(np.float32)
|
| 236 |
+
d_min = None
|
| 237 |
+
for conv in _PAIR_CONVENTIONS:
|
| 238 |
+
C_src = _cam_center_from(R_src, t_src, conv)[0]
|
| 239 |
+
C_sub = _cam_center_from(R_sub, t_sub, conv)
|
| 240 |
+
d = np.linalg.norm(C_sub - C_src[None, :], axis=1)
|
| 241 |
+
d_min = d if (d_min is None) else np.minimum(d_min, d)
|
| 242 |
+
assert d_min is not None
|
| 243 |
+
return d_min
|
| 244 |
+
|
| 245 |
+
effective_gap = self.src_tgt_max_index_gap
|
| 246 |
+
cand0 = np.array([i for i in frames if i != src_idx and abs(i - src_idx) <= effective_gap], dtype=np.int64)
|
| 247 |
+
if cand0.size > 0:
|
| 248 |
+
d0 = _min_dist(cand0)
|
| 249 |
+
ok0 = cand0[d0 < th]
|
| 250 |
+
if ok0.size > 0:
|
| 251 |
+
return [int(x) for x in ok0.tolist()]
|
| 252 |
+
return []
|
| 253 |
+
|
| 254 |
+
def _resize_cube_depth(self, depth: torch.Tensor, face_w: int) -> torch.Tensor:
|
| 255 |
+
if depth.ndim != 4 or depth.shape[0] != 6 or depth.shape[-1] != 1:
|
| 256 |
+
raise ValueError(f"Expected cube depth shape (6,H,W,1), got {tuple(depth.shape)}")
|
| 257 |
+
H = int(depth.shape[1])
|
| 258 |
+
W = int(depth.shape[2])
|
| 259 |
+
if H == face_w and W == face_w:
|
| 260 |
+
return depth.to(dtype=torch.float32)
|
| 261 |
+
import torch.nn.functional as F
|
| 262 |
+
|
| 263 |
+
x = depth.permute(0, 3, 1, 2).to(dtype=torch.float32)
|
| 264 |
+
x = F.interpolate(x, size=(face_w, face_w), mode="bilinear", align_corners=False)
|
| 265 |
+
return x.permute(0, 2, 3, 1).contiguous()
|
| 266 |
+
|
| 267 |
+
@staticmethod
|
| 268 |
+
def _cubemap_z_depth_to_distance(depth: torch.Tensor) -> torch.Tensor:
|
| 269 |
+
if depth.ndim != 4 or depth.shape[0] != 6 or depth.shape[-1] != 1:
|
| 270 |
+
raise ValueError(f"Expected cube depth shape (6,H,W,1), got {tuple(depth.shape)}")
|
| 271 |
+
from unisharp.utils.pano import get_pinhole_intrinsics_4x4
|
| 272 |
+
|
| 273 |
+
h = int(depth.shape[1])
|
| 274 |
+
w = int(depth.shape[2])
|
| 275 |
+
if h != w:
|
| 276 |
+
raise ValueError(f"Expected square cubemap faces, got {(h, w)}")
|
| 277 |
+
depth_61hw = depth.permute(0, 3, 1, 2).to(dtype=torch.float32).contiguous()
|
| 278 |
+
intr = get_pinhole_intrinsics_4x4(w).to(device=depth_61hw.device, dtype=depth_61hw.dtype)
|
| 279 |
+
ys = torch.arange(h, device=depth_61hw.device, dtype=depth_61hw.dtype)
|
| 280 |
+
xs = torch.arange(w, device=depth_61hw.device, dtype=depth_61hw.dtype)
|
| 281 |
+
vv, uu = torch.meshgrid(ys, xs, indexing="ij")
|
| 282 |
+
x = (uu - intr[0, 2]) / intr[0, 0].clamp(min=1e-8)
|
| 283 |
+
y = (vv - intr[1, 2]) / intr[1, 1].clamp(min=1e-8)
|
| 284 |
+
ray_z = 1.0 / torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8)
|
| 285 |
+
dist = depth_61hw / ray_z.view(1, 1, h, w).clamp(min=1e-8)
|
| 286 |
+
valid = torch.isfinite(dist) & (dist > 0.0)
|
| 287 |
+
dist = torch.where(valid, dist, torch.zeros_like(dist))
|
| 288 |
+
return dist.permute(0, 2, 3, 1).contiguous()
|
| 289 |
+
|
| 290 |
+
def _pair_depth_overlap_score(
    self,
    *,
    src_R: torch.Tensor,
    src_t: torch.Tensor,
    tgt_R: torch.Tensor,
    tgt_t: torch.Tensor,
    src_cube_depth_m: torch.Tensor,
    tgt_cube_depth_m: torch.Tensor,
) -> float:
    """Score how much a source and a target view overlap, using cubemap depth.

    All work is done on CPU in float32. Both depth cubemaps are resized to
    ``self.pair_overlap_face_w`` and converted to ray distance, the two poses
    are expressed relative to the source camera using the pose recipe in
    ``_PAIR_RECIPE_FIXED``, and ``view_frustum_mask_cubemap_union`` is evaluated
    in both directions.

    Returns:
        The mean of two fractions: masked valid target pixels over all valid
        target pixels, and the same for the source. ``0.0`` when either view
        has fewer than 10 valid (finite, positive) depth pixels.
    """
    from unisharp.utils.camera_projection import build_extrinsics_w2c, view_frustum_mask_cubemap_union  # noqa: WPS433

    cpu = torch.device("cpu")
    src_R = src_R.to(device=cpu, dtype=torch.float32)
    src_t = src_t.to(device=cpu, dtype=torch.float32)
    tgt_R = tgt_R.to(device=cpu, dtype=torch.float32)
    tgt_t = tgt_t.to(device=cpu, dtype=torch.float32)

    face_w = int(self.pair_overlap_face_w)
    margin = float(self.pair_overlap_margin)
    src_dist = self._cubemap_z_depth_to_distance(
        self._resize_cube_depth(src_cube_depth_m.to(device=cpu), face_w=face_w)
    )
    tgt_dist = self._cubemap_z_depth_to_distance(
        self._resize_cube_depth(tgt_cube_depth_m.to(device=cpu), face_w=face_w)
    )

    pose_conv, flip_yz = _PAIR_RECIPE_FIXED
    w2c_src = build_extrinsics_w2c(src_R, src_t, pose_conv)
    w2c_tgt = build_extrinsics_w2c(tgt_R, tgt_t, pose_conv)

    # Keep the matrix inversions out of any surrounding CPU autocast region.
    with torch.autocast(device_type="cpu", enabled=False):
        c2w_src = torch.linalg.inv(w2c_src)
        c2w_tgt = torch.linalg.inv(w2c_tgt)
        if bool(flip_yz):
            # Negate the camera y and z axes.
            axis_flip = torch.diag(torch.tensor([1.0, -1.0, -1.0, 1.0], dtype=torch.float32, device=cpu))
            c2w_src = c2w_src @ axis_flip
            c2w_tgt = c2w_tgt @ axis_flip
        # Re-express both poses in the source camera's frame.
        to_src_frame = torch.linalg.inv(c2w_src)
        c2w_src = to_src_frame @ c2w_src
        c2w_tgt = to_src_frame @ c2w_tgt
        w2c_src_rel = torch.linalg.inv(c2w_src)
        w2c_tgt_rel = torch.linalg.inv(c2w_tgt)

    tgt_in_src = view_frustum_mask_cubemap_union(
        depth_novel=tgt_dist,
        extr_novel_w2c=w2c_tgt_rel,
        extr_source_w2c=w2c_src_rel,
        face_w=face_w,
        margin=margin,
    )
    src_in_tgt = view_frustum_mask_cubemap_union(
        depth_novel=src_dist,
        extr_novel_w2c=w2c_src_rel,
        extr_source_w2c=w2c_tgt_rel,
        face_w=face_w,
        margin=margin,
    )

    tgt_valid = torch.isfinite(tgt_dist[..., 0]) & (tgt_dist[..., 0] > 0.0)
    src_valid = torch.isfinite(src_dist[..., 0]) & (src_dist[..., 0] > 0.0)
    n_tgt = float(tgt_valid.sum().item())
    n_src = float(src_valid.sum().item())
    if n_tgt < 10 or n_src < 10:
        return 0.0

    frac_tgt = float((tgt_in_src & tgt_valid).sum().item()) / n_tgt
    frac_src = float((src_in_tgt & src_valid).sum().item()) / n_src
    return 0.5 * (frac_tgt + frac_src)
|
| 356 |
+
|
| 357 |
+
def __getitem__(self, idx: int) -> PanOGSSample:
    """Load one source/target panorama pair and return it as a ``PanOGSSample``.

    The entry at ``idx`` (modulo the index length) selects a scene and a source
    frame. Loading is retried on I/O-style failures: up to 8 attempts per index
    entry, each with a freshly drawn source frame from the same scene, and up
    to 16 different index entries overall.

    When both ``pair_sampling`` and ``use_cubemap_supervision`` are enabled,
    target candidates come from the per-source cache of previously accepted
    targets or from ``_candidate_targets_by_translation``, and a candidate is
    accepted only if ``_pair_depth_overlap_score`` reaches
    ``pair_min_depth_overlap``. Otherwise the target drawn by
    ``_sample_target`` is used directly. Without cubemap supervision the cube
    fields are zero-filled placeholders.

    Raises:
        RuntimeError: if no pair could be loaded after all retries; the last
            underlying error is chained as the cause.
    """
    src_erp: torch.Tensor | None = None
    tgt_erp: torch.Tensor | None = None
    src_dep: torch.Tensor | None = None
    tgt_dep: torch.Tensor | None = None
    src_cube: torch.Tensor | None = None
    tgt_cube: torch.Tensor | None = None
    src_cdep: torch.Tensor | None = None
    tgt_cdep: torch.Tensor | None = None
    last_err: Exception | None = None

    pair_filtering = bool(self.pair_sampling and self.use_cubemap_supervision)

    max_outer = 16
    for outer in range(max_outer):
        scene, src_idx = self._index[int(idx) % len(self._index)]
        scene_dir = self.root / scene
        tgt_idx = self._sample_target(scene, src_idx)

        # Never carry a (possibly half-loaded) source over from a previous
        # index entry: it could belong to a different scene.
        src_erp = None
        src_dep = None
        src_cube = None
        src_cdep = None

        max_retries = 8
        ok = False
        for _ in range(max_retries):
            try:
                if src_erp is None:
                    src_erp = _load_erp_rgb_u8(scene_dir / "pano" / f"{src_idx:05d}.png")
                    src_dep = _depth_to_meters(
                        _load_depth_png(scene_dir / "pano_depth" / f"{src_idx:05d}.png"),
                        max_depth_m=self.depth_max_m,
                    )
                    if self.use_cubemap_supervision:
                        src_cube_any = _torch_load_any(scene_dir / "cubemaps" / f"{src_idx:05d}.torch")
                        src_cdep_any = _torch_load_any(scene_dir / "cubemaps_depth" / f"{src_idx:05d}.torch")
                        if not all(isinstance(x, torch.Tensor) for x in [src_cube_any, src_cdep_any]):
                            raise RuntimeError("Bad .torch payload for src (expected Tensor).")
                        src_cube = cast(torch.Tensor, src_cube_any)
                        src_cdep = cast(torch.Tensor, src_cdep_any).to(torch.float32).clamp(min=0.0, max=self.depth_max_m)
                    else:
                        src_cube = torch.zeros((6, 256, 256, 3), dtype=torch.uint8)
                        src_cdep = torch.zeros((6, 256, 256, 1), dtype=torch.float32)

                candidates: list[int] = []
                if pair_filtering:
                    key = (scene, int(src_idx))
                    cached = self._pair_valid_tgts.get(key)
                    if cached:
                        candidates = list(cached)
                    else:
                        candidates = self._candidate_targets_by_translation(scene, int(src_idx))
                if not candidates:
                    candidates = [int(tgt_idx)]

                tried: set[int] = set()
                found = False
                max_try = max(1, self.pair_max_tries) if pair_filtering else 1
                for _try in range(max_try):
                    pool = [
                        c
                        for c in candidates
                        if int(c) not in tried and int(c) != int(src_idx)
                    ]
                    if not pool:
                        break
                    j = int(torch.randint(0, len(pool), (1,)).item())
                    tgt_idx = int(pool[j])
                    tried.add(int(tgt_idx))

                    if self.use_cubemap_supervision:
                        tgt_cdep_any = _torch_load_any(scene_dir / "cubemaps_depth" / f"{tgt_idx:05d}.torch")
                        if not isinstance(tgt_cdep_any, torch.Tensor):
                            raise RuntimeError("Bad .torch payload for tgt depth (expected Tensor).")
                        tgt_cdep = cast(torch.Tensor, tgt_cdep_any).to(torch.float32).clamp(min=0.0, max=self.depth_max_m)
                    else:
                        tgt_cdep = torch.zeros((6, 256, 256, 1), dtype=torch.float32)

                    if pair_filtering:
                        k = (scene, int(src_idx), int(tgt_idx))
                        score = self._pair_overlap_cache.get(k)
                        if score is None:
                            R_np, t_np = self._get_pose(scene)
                            src_R = torch.from_numpy(R_np[int(src_idx)])
                            src_t = torch.from_numpy(t_np[int(src_idx)])
                            tgt_R = torch.from_numpy(R_np[int(tgt_idx)])
                            tgt_t = torch.from_numpy(t_np[int(tgt_idx)])
                            score = self._pair_depth_overlap_score(
                                src_R=src_R,
                                src_t=src_t,
                                tgt_R=tgt_R,
                                tgt_t=tgt_t,
                                src_cube_depth_m=cast(torch.Tensor, src_cdep),
                                tgt_cube_depth_m=cast(torch.Tensor, tgt_cdep),
                            )
                            self._pair_overlap_cache[k] = float(score)
                        if float(score) < float(self.pair_min_depth_overlap):
                            continue
                        # Remember the accepted target once; re-appending it on
                        # every call would grow the cached list without bound.
                        valid_tgts = self._pair_valid_tgts.setdefault((scene, int(src_idx)), [])
                        if int(tgt_idx) not in valid_tgts:
                            valid_tgts.append(int(tgt_idx))

                    tgt_erp = _load_erp_rgb_u8(scene_dir / "pano" / f"{tgt_idx:05d}.png")
                    tgt_dep = _depth_to_meters(
                        _load_depth_png(scene_dir / "pano_depth" / f"{tgt_idx:05d}.png"),
                        max_depth_m=self.depth_max_m,
                    )
                    if self.use_cubemap_supervision:
                        tgt_cube_any = _torch_load_any(scene_dir / "cubemaps" / f"{tgt_idx:05d}.torch")
                        if not isinstance(tgt_cube_any, torch.Tensor):
                            raise RuntimeError("Bad .torch payload for tgt RGB cubemap (expected Tensor).")
                        tgt_cube = cast(torch.Tensor, tgt_cube_any)
                    else:
                        tgt_cube = torch.zeros((6, 256, 256, 3), dtype=torch.uint8)

                    found = True
                    break

                if not found:
                    raise RuntimeError(
                        f"No valid tgt found for scene={scene} src={src_idx} within constraints "
                        f"(trans<{self.pair_max_translation_m}m, overlap>{self.pair_min_depth_overlap})."
                    )

                ok = True
                break
            except (FileNotFoundError, RuntimeError, EOFError, KeyError, tarfile.ReadError, OSError) as e:
                last_err = e
                frames = self._available_frames.get(scene, [])
                if not frames:
                    break
                # Retry within the same scene using a different source frame.
                src_idx = int(frames[int(torch.randint(0, len(frames), (1,)).item())])
                tgt_idx = self._sample_target(scene, src_idx)
                src_erp = None
                src_dep = None
                src_cube = None
                src_cdep = None

        if ok:
            break

        # Move to a different index entry for the next attempt.
        idx = int(idx) + 9973 + outer * 13
    else:
        raise RuntimeError(f"PanOGS __getitem__ failed after retries. last_err={last_err}") from last_err

    assert src_erp is not None and tgt_erp is not None
    assert src_dep is not None and tgt_dep is not None
    assert src_cube is not None and tgt_cube is not None
    assert src_cdep is not None and tgt_cdep is not None

    src_dep = src_dep.to(torch.float32).unsqueeze(0)
    tgt_dep = tgt_dep.to(torch.float32).unsqueeze(0)

    R_np, t_np = self._get_pose(scene)
    src_R = torch.from_numpy(R_np[src_idx])
    src_t = torch.from_numpy(t_np[src_idx])
    tgt_R = torch.from_numpy(R_np[tgt_idx])
    tgt_t = torch.from_numpy(t_np[tgt_idx])

    return PanOGSSample(
        src_erp_rgb_u8=src_erp,
        tgt_erp_rgb_u8=tgt_erp,
        src_erp_depth_m=src_dep,
        tgt_erp_depth_m=tgt_dep,
        src_cube_rgb_u8=src_cube,
        tgt_cube_rgb_u8=tgt_cube,
        src_cube_depth_m=src_cdep,
        tgt_cube_depth_m=tgt_cdep,
        src_R=src_R,
        src_t=src_t,
        tgt_R=tgt_R,
        tgt_t=tgt_t,
        src_idx=src_idx,
        tgt_idx=tgt_idx,
        scene=scene,
    )
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def panogs_collate(batch: list[PanOGSSample]) -> PanOGSSample:
    """Merge a list of ``PanOGSSample`` items into one batched ``PanOGSSample``.

    Image, depth and pose fields are stacked along a new leading dimension when
    the first item's value is a tensor, and otherwise kept as a plain list.
    ``src_idx``, ``tgt_idx`` and ``scene`` always become lists.
    """

    def _gather(field: str):
        values = [getattr(sample, field) for sample in batch]
        if isinstance(values[0], torch.Tensor):
            return torch.stack(values, dim=0)
        return values

    stacked_fields = (
        "src_erp_rgb_u8",
        "tgt_erp_rgb_u8",
        "src_erp_depth_m",
        "tgt_erp_depth_m",
        "src_cube_rgb_u8",
        "tgt_cube_rgb_u8",
        "src_cube_depth_m",
        "tgt_cube_depth_m",
        "src_R",
        "src_t",
        "tgt_R",
        "tgt_t",
    )
    merged = {field: _gather(field) for field in stacked_fields}
    return PanOGSSample(
        **merged,
        src_idx=[sample.src_idx for sample in batch],  # type: ignore[arg-type]
        tgt_idx=[sample.tgt_idx for sample in batch],  # type: ignore[arg-type]
        scene=[sample.scene for sample in batch],  # type: ignore[arg-type]
    )
|
| 555 |
+
|
unisharp/datasets/re10k.py
ADDED
|
@@ -0,0 +1,718 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict, deque
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import random
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torchvision.transforms as tf
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from torch.utils.data import IterableDataset
|
| 16 |
+
|
| 17 |
+
from unisharp.datasets.pair_sampling import (
|
| 18 |
+
project_overlap_ratio,
|
| 19 |
+
resize_k3_align_corners_false,
|
| 20 |
+
resize_rgb_u8_chw_high_quality,
|
| 21 |
+
select_targets_for_source,
|
| 22 |
+
)
|
| 23 |
+
from unisharp import DEFAULT_MAX_DEPTH_M
|
| 24 |
+
from unisharp.utils.pixel_convention import normalized_intrinsics_to_integer_pixel_k
|
| 25 |
+
from unisharp.utils.unik3d_adapter import infer_unik3d_pinhole, load_unik3d_model
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
LOGGER = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _torch_load_any(path: Path) -> object:
|
| 32 |
+
try:
|
| 33 |
+
return torch.load(path, map_location="cpu", weights_only=False)
|
| 34 |
+
except TypeError:
|
| 35 |
+
return torch.load(path, map_location="cpu")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _pack_re10k_batch(batch: list["Re10KPairSample"]) -> "Re10KPairSample":
    """Collate a list of ``Re10KPairSample`` items into one batched sample.

    Image, pose and intrinsics tensors are stacked along a new leading
    dimension after checking that all items share one shape. ``src_idx``,
    ``tgt_idx`` and ``scene`` become lists. A depth field is stacked only when
    every item provides a tensor for it; otherwise it is ``None`` for the
    whole batch.

    Raises:
        RuntimeError: if tensors that should be stacked differ in shape.
    """

    def _stack_checked(tensors, prefix):
        # Stack after verifying that every tensor matches the first one's shape.
        expected = tuple(tensors[0].shape)
        for pos in range(1, len(tensors)):
            found = tuple(tensors[pos].shape)
            if found != expected:
                raise RuntimeError(prefix + f"ref={expected} mismatch_idx={pos} got={found}")
        return torch.stack(tensors, dim=0)

    def _required(field):
        values = [getattr(sample, field) for sample in batch]
        if not isinstance(values[0], torch.Tensor):
            return values
        return _stack_checked(values, "RE10K collate got mixed tensor shapes: ")

    def _optional_depth(field):
        values = [getattr(sample, field) for sample in batch]
        if not all(torch.is_tensor(v) for v in values):
            return None
        return _stack_checked(values, "RE10K collate got mixed depth shapes: ")

    return Re10KPairSample(
        src_rgb_u8=_required("src_rgb_u8"),
        tgt_rgb_u8=_required("tgt_rgb_u8"),
        src_w2c=_required("src_w2c"),
        tgt_w2c=_required("tgt_w2c"),
        src_intrinsics=_required("src_intrinsics"),
        tgt_intrinsics=_required("tgt_intrinsics"),
        src_idx=[sample.src_idx for sample in batch],  # type: ignore[arg-type]
        tgt_idx=[sample.tgt_idx for sample in batch],  # type: ignore[arg-type]
        scene=[sample.scene for sample in batch],  # type: ignore[arg-type]
        src_depth_m=_optional_depth("src_depth_m"),  # type: ignore[arg-type]
        tgt_depth_m=_optional_depth("tgt_depth_m"),  # type: ignore[arg-type]
    )
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def re10k_passthrough(batch: "Re10KPairSample") -> "Re10KPairSample":
    """Return ``batch`` unchanged.

    NOTE(review): presumably meant as a collate function for loaders whose
    dataset already yields packed batches — confirm against the caller.
    """
    return batch
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@dataclass(frozen=True)
class Re10KPairSample:
    """Immutable record describing one RE10K source/target view pair.

    ``_pack_re10k_batch`` reuses this class for whole batches: tensor fields
    are then stacked along a new leading dimension and ``src_idx``,
    ``tgt_idx`` and ``scene`` hold lists instead of scalars.
    """

    # RGB images of the two views (names suggest uint8 — TODO confirm layout).
    src_rgb_u8: torch.Tensor
    tgt_rgb_u8: torch.Tensor
    # Pose matrices of the two views (names suggest world-to-camera).
    src_w2c: torch.Tensor
    tgt_w2c: torch.Tensor
    # Camera intrinsics of the two views.
    src_intrinsics: torch.Tensor
    tgt_intrinsics: torch.Tensor
    # Frame indices of the two views and the name of the scene they come from.
    src_idx: int
    tgt_idx: int
    scene: str
    # Optional per-view depth maps (names suggest metres); None when unavailable.
    src_depth_m: torch.Tensor | None = None
    tgt_depth_m: torch.Tensor | None = None
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class Re10KDataset(IterableDataset):
|
| 98 |
+
|
| 99 |
+
def __init__(
    self,
    root: Path,
    chunks_file: Path | None = None,
    split: str = "train",
    min_frame_gap: int = 1,
    max_frame_gap: int = 32,
    pair_max_translation_m: float = 0.5,
    pair_min_overlap: float = 0.6,
    pair_overlap_sample_h: int = 32,
    pair_overlap_sample_w: int = 56,
    pair_max_tries: int = 32,
    output_h: int | None = None,
    output_w: int | None = None,
    shuffle_chunk: bool = True,
    shuffle_example: bool = True,
    ddp_rank: int = 0,
    ddp_world_size: int = 1,
    pseudo_depth_root: Path | None = None,
    pseudo_depth_autogen: bool = True,
    pseudo_depth_backbone: str = "vitl",
    pseudo_depth_device: str = "cpu",
    pseudo_lock_timeout_sec: float = 120.0,
    pseudo_lock_stale_sec: float = 1800.0,
    pseudo_wait_poll_sec: float = 0.25,
    batch_size_hint: int = 1,
    depth_max_m: float = DEFAULT_MAX_DEPTH_M,
    pseudo_far_depth_invalid_m: float = 30.0,
    seed: int = 0,
) -> None:
    """Store the configuration and discover the ``.torch`` chunk files.

    Chunks are taken from ``chunks_file`` when given (one path per line, blank
    lines ignored, relative paths resolved against ``root / split``, entries
    without a ``.torch`` suffix dropped) and otherwise from every ``.torch``
    file directly inside ``root / split``. The resulting list is sorted.

    Numeric options are coerced to ``int``/``float``. The pseudo-depth lock
    timings are floored at 1.0 s (timeout), 30.0 s (stale) and 0.05 s (poll),
    and ``batch_size_hint`` at 1.

    Raises:
        FileNotFoundError: if ``chunks_file`` is given but missing, or if no
            ``chunks_file`` is given and ``root / split`` does not exist.
        RuntimeError: if no ``.torch`` chunk is found.
    """
    super().__init__()

    # Pair-selection settings.
    self.root = root
    self.split = split
    self.min_frame_gap = int(min_frame_gap)
    self.max_frame_gap = int(max_frame_gap)
    self.pair_max_translation_m = float(pair_max_translation_m)
    self.pair_min_overlap = float(pair_min_overlap)
    self.pair_overlap_sample_h = int(pair_overlap_sample_h)
    self.pair_overlap_sample_w = int(pair_overlap_sample_w)
    self.pair_max_tries = int(pair_max_tries)

    # Output size, shuffling and distributed settings.
    self.output_h = None if output_h is None else int(output_h)
    self.output_w = None if output_w is None else int(output_w)
    self.shuffle_chunk = bool(shuffle_chunk)
    self.shuffle_example = bool(shuffle_example)
    self.ddp_rank = int(ddp_rank)
    self.ddp_world_size = int(ddp_world_size)
    self.to_tensor = tf.ToTensor()

    # Pseudo-depth settings.
    self.pseudo_depth_root = None if pseudo_depth_root is None else Path(pseudo_depth_root)
    self.pseudo_depth_autogen = bool(pseudo_depth_autogen)
    self.pseudo_depth_backbone = str(pseudo_depth_backbone)
    self.pseudo_depth_device = str(pseudo_depth_device)
    self.pseudo_lock_timeout_sec = float(max(1.0, pseudo_lock_timeout_sec))
    self.pseudo_lock_stale_sec = float(max(30.0, pseudo_lock_stale_sec))
    self.pseudo_wait_poll_sec = float(max(0.05, pseudo_wait_poll_sec))
    self.batch_size_hint = int(max(1, batch_size_hint))
    self.depth_max_m = float(depth_max_m)
    self.pseudo_far_depth_invalid_m = float(pseudo_far_depth_invalid_m)
    self._pseudo_model: torch.nn.Module | None = None
    self.seed = int(seed)
    self.epoch = 0

    # Chunk discovery.
    self.chunks_file = None if chunks_file is None else Path(chunks_file)
    split_dir = self.root / self.split
    if self.chunks_file is None:
        if not split_dir.exists():
            raise FileNotFoundError(split_dir)
        discovered = [entry for entry in split_dir.iterdir() if entry.suffix == ".torch"]
    else:
        if not self.chunks_file.exists():
            raise FileNotFoundError(self.chunks_file)
        listed = (raw.strip() for raw in self.chunks_file.read_text(encoding="utf-8").splitlines())
        as_paths = (Path(line) for line in listed if line)
        resolved = (p if p.is_absolute() else split_dir / p for p in as_paths)
        discovered = [p for p in resolved if p.suffix == ".torch"]
    self.chunks = sorted(discovered)

    if not self.chunks:
        source = split_dir if self.chunks_file is None else self.chunks_file
        raise RuntimeError(f"No .torch chunks found for {source}")
|
| 183 |
+
|
| 184 |
+
def set_epoch(self, epoch: int) -> None:
    """Record the current epoch and ensure the pseudo-depth split directory exists.

    When ``self.pseudo_depth_root`` is set, ``pseudo_depth_root / split`` is
    created (with parents) if it is missing.
    """
    self.epoch = int(epoch)

    depth_root = self.pseudo_depth_root
    if depth_root is None:
        return
    split_dir = depth_root / self.split
    split_dir.mkdir(parents=True, exist_ok=True)
|
| 189 |
+
|
| 190 |
+
@staticmethod
|
| 191 |
+
def _decode_image_u8(image_bytes_tensor: torch.Tensor) -> torch.Tensor:
|
| 192 |
+
if image_bytes_tensor.dtype != torch.uint8:
|
| 193 |
+
raise ValueError(f"Expected uint8 bytes tensor, got {image_bytes_tensor.dtype}")
|
| 194 |
+
image = Image.open(BytesIO(image_bytes_tensor.numpy().tobytes())).convert("RGB")
|
| 195 |
+
chw_float = tf.ToTensor()(image)
|
| 196 |
+
return (chw_float * 255.0).round().to(torch.uint8)
|
| 197 |
+
|
| 198 |
+
@staticmethod
|
| 199 |
+
def _convert_pose_row_to_w2c(poses: torch.Tensor) -> torch.Tensor:
|
| 200 |
+
t = poses.shape[0]
|
| 201 |
+
w2c = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat(t, 1, 1)
|
| 202 |
+
w2c[:, :3] = poses[:, 6:].reshape(t, 3, 4).to(torch.float32)
|
| 203 |
+
return w2c
|
| 204 |
+
|
| 205 |
+
@staticmethod
def _convert_intrinsics_to_pixel(poses: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Convert the intrinsics stored in pose rows to pixel units for an ``h`` x ``w`` image.

    The first four columns of ``poses`` are read as ``fx, fy, cx, cy`` and
    passed, together with the image size, to
    ``normalized_intrinsics_to_integer_pixel_k``, whose result is returned
    unchanged.
    """
    fx, fy, cx, cy = poses[:, 0], poses[:, 1], poses[:, 2], poses[:, 3]
    return normalized_intrinsics_to_integer_pixel_k(
        fx,
        fy,
        cx,
        cy,
        height=int(h),
        width=int(w),
    )
|
| 218 |
+
|
| 219 |
+
@staticmethod
|
| 220 |
+
def _sanitize_scene(scene: str) -> str:
|
| 221 |
+
s = str(scene).strip()
|
| 222 |
+
s = s.replace("\\", "__").replace("/", "__")
|
| 223 |
+
return s if len(s) > 0 else "unknown_scene"
|
| 224 |
+
|
| 225 |
+
def _pseudo_depth_path(self, scene: str, frame_idx: int) -> Path | None:
|
| 226 |
+
if self.pseudo_depth_root is None:
|
| 227 |
+
return None
|
| 228 |
+
scene_key = self._sanitize_scene(scene)
|
| 229 |
+
return self.pseudo_depth_root / self.split / scene_key / f"{int(frame_idx):05d}.pt"
|
| 230 |
+
|
| 231 |
+
@staticmethod
|
| 232 |
+
def _load_pseudo_depth(path: Path) -> tuple[torch.Tensor | None, str]:
|
| 233 |
+
if not path.exists():
|
| 234 |
+
return None, "unknown"
|
| 235 |
+
try:
|
| 236 |
+
payload = _torch_load_any(path)
|
| 237 |
+
depth_kind = "distance"
|
| 238 |
+
if isinstance(payload, dict):
|
| 239 |
+
depth = payload.get("depth_m", None)
|
| 240 |
+
depth_kind = str(payload.get("depth_kind", "distance")).strip().lower()
|
| 241 |
+
if depth_kind not in ("distance", "zdepth"):
|
| 242 |
+
depth_kind = "distance"
|
| 243 |
+
else:
|
| 244 |
+
depth = payload
|
| 245 |
+
if not torch.is_tensor(depth):
|
| 246 |
+
return None, "unknown"
|
| 247 |
+
if depth.ndim == 3 and depth.shape[0] == 1:
|
| 248 |
+
depth = depth[0]
|
| 249 |
+
if depth.ndim != 2:
|
| 250 |
+
return None, "unknown"
|
| 251 |
+
depth = depth.to(torch.float32)
|
| 252 |
+
valid = torch.isfinite(depth) & (depth > 0.0)
|
| 253 |
+
if int(valid.sum().item()) <= 0:
|
| 254 |
+
return None, "unknown"
|
| 255 |
+
return depth.unsqueeze(0), depth_kind
|
| 256 |
+
except Exception:
|
| 257 |
+
return None, "unknown"
|
| 258 |
+
|
| 259 |
+
@staticmethod
|
| 260 |
+
def _distance_to_z_depth(depth_1hw: torch.Tensor, intrinsics_k3: torch.Tensor) -> torch.Tensor:
|
| 261 |
+
if depth_1hw.ndim != 3 or depth_1hw.shape[0] != 1:
|
| 262 |
+
raise ValueError(f"Expected depth shape (1,H,W), got {tuple(depth_1hw.shape)}")
|
| 263 |
+
d = depth_1hw.to(torch.float32)
|
| 264 |
+
h = int(d.shape[-2])
|
| 265 |
+
w = int(d.shape[-1])
|
| 266 |
+
k = intrinsics_k3.to(dtype=torch.float32, device=d.device)
|
| 267 |
+
fx = k[0, 0]
|
| 268 |
+
fy = k[1, 1]
|
| 269 |
+
cx = k[0, 2]
|
| 270 |
+
cy = k[1, 2]
|
| 271 |
+
ys = torch.arange(h, device=d.device, dtype=torch.float32)
|
| 272 |
+
xs = torch.arange(w, device=d.device, dtype=torch.float32)
|
| 273 |
+
vv, uu = torch.meshgrid(ys, xs, indexing="ij")
|
| 274 |
+
x = (uu - cx) / fx
|
| 275 |
+
y = (vv - cy) / fy
|
| 276 |
+
ray_z = 1.0 / torch.sqrt(x * x + y * y + 1.0).clamp(min=1e-8)
|
| 277 |
+
z = d[0] * ray_z
|
| 278 |
+
return z.unsqueeze(0)
|
| 279 |
+
|
| 280 |
+
@staticmethod
def _sanitize_pseudo_depth(
    depth_1hw: torch.Tensor,
    *,
    max_depth_m: float = DEFAULT_MAX_DEPTH_M,
    far_depth_invalid_m: float = 30.0,
) -> torch.Tensor:
    """Zero out invalid depth values and cap the remaining ones.

    A value is valid when it is finite, strictly positive and, if
    ``far_depth_invalid_m`` is positive, not larger than that threshold.
    Valid values are clamped to at most ``max_depth_m``; all others become 0.

    Note:
        When the input has no finite positive value at all, the float32 input
        is returned as-is (not zeroed).

    Returns:
        Float32 tensor with the same shape as ``depth_1hw``.
    """
    depth = depth_1hw.to(torch.float32)
    keep = torch.isfinite(depth) & (depth > 0.0)
    if not bool(keep.any().item()):
        return depth
    far_limit = float(far_depth_invalid_m)
    if far_limit > 0.0:
        keep = keep & (depth <= far_limit)
    capped = depth.clamp(max=float(max_depth_m))
    return torch.where(keep, capped, torch.zeros_like(depth))
|
| 297 |
+
|
| 298 |
+
def _get_or_create_pseudo_model(self) -> torch.nn.Module:
|
| 299 |
+
if self._pseudo_model is None:
|
| 300 |
+
dev = torch.device(self.pseudo_depth_device)
|
| 301 |
+
self._pseudo_model = load_unik3d_model(
|
| 302 |
+
backbone=self.pseudo_depth_backbone,
|
| 303 |
+
pretrained=True,
|
| 304 |
+
device=dev,
|
| 305 |
+
)
|
| 306 |
+
self._pseudo_model.eval()
|
| 307 |
+
LOGGER.info(
|
| 308 |
+
"Re10K pseudo-depth model loaded (split=%s, device=%s, backbone=%s)",
|
| 309 |
+
self.split,
|
| 310 |
+
str(dev),
|
| 311 |
+
self.pseudo_depth_backbone,
|
| 312 |
+
)
|
| 313 |
+
return self._pseudo_model
|
| 314 |
+
|
| 315 |
+
def _save_pseudo_depth_atomic(
|
| 316 |
+
self,
|
| 317 |
+
path: Path,
|
| 318 |
+
depth_2d: torch.Tensor,
|
| 319 |
+
scene: str,
|
| 320 |
+
frame_idx: int,
|
| 321 |
+
) -> None:
|
| 322 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 323 |
+
tmp = path.parent / f".tmp_{os.getpid()}_{int(time.time() * 1e6)}_{random.randint(0, 10_000_000)}.pt"
|
| 324 |
+
payload = {
|
| 325 |
+
"depth_m": depth_2d.to(torch.float16),
|
| 326 |
+
"depth_kind": "distance",
|
| 327 |
+
"scene": str(scene),
|
| 328 |
+
"frame_idx": int(frame_idx),
|
| 329 |
+
}
|
| 330 |
+
torch.save(payload, tmp)
|
| 331 |
+
os.replace(tmp, path)
|
| 332 |
+
|
| 333 |
+
def _acquire_lock_or_wait_for_file(self, target: Path) -> tuple[bool, bool]:
|
| 334 |
+
lock_dir = Path(str(target) + ".lock")
|
| 335 |
+
start = time.time()
|
| 336 |
+
while True:
|
| 337 |
+
if target.exists():
|
| 338 |
+
return False, True
|
| 339 |
+
try:
|
| 340 |
+
lock_dir.mkdir(parents=False, exist_ok=False)
|
| 341 |
+
meta = lock_dir / "owner.txt"
|
| 342 |
+
meta.write_text(f"pid={os.getpid()} time={time.time():.3f}\n", encoding="utf-8")
|
| 343 |
+
return True, False
|
| 344 |
+
except FileExistsError:
|
| 345 |
+
try:
|
| 346 |
+
mtime = lock_dir.stat().st_mtime
|
| 347 |
+
if (time.time() - float(mtime)) > self.pseudo_lock_stale_sec:
|
| 348 |
+
for p in lock_dir.iterdir():
|
| 349 |
+
try:
|
| 350 |
+
p.unlink()
|
| 351 |
+
except Exception:
|
| 352 |
+
pass
|
| 353 |
+
lock_dir.rmdir()
|
| 354 |
+
continue
|
| 355 |
+
except Exception:
|
| 356 |
+
pass
|
| 357 |
+
if (time.time() - start) >= self.pseudo_lock_timeout_sec:
|
| 358 |
+
return False, False
|
| 359 |
+
time.sleep(self.pseudo_wait_poll_sec)
|
| 360 |
+
except Exception:
|
| 361 |
+
return False, False
|
| 362 |
+
|
| 363 |
+
def _release_lock(self, target: Path) -> None:
|
| 364 |
+
lock_dir = Path(str(target) + ".lock")
|
| 365 |
+
if not lock_dir.exists():
|
| 366 |
+
return
|
| 367 |
+
try:
|
| 368 |
+
for p in lock_dir.iterdir():
|
| 369 |
+
try:
|
| 370 |
+
p.unlink()
|
| 371 |
+
except Exception:
|
| 372 |
+
pass
|
| 373 |
+
lock_dir.rmdir()
|
| 374 |
+
except Exception:
|
| 375 |
+
pass
|
| 376 |
+
|
| 377 |
+
def _get_pseudo_depth_for_frame(
|
| 378 |
+
self,
|
| 379 |
+
*,
|
| 380 |
+
scene: str,
|
| 381 |
+
frame_idx: int,
|
| 382 |
+
rgb_u8: torch.Tensor,
|
| 383 |
+
intrinsics_k3: torch.Tensor,
|
| 384 |
+
) -> torch.Tensor | None:
|
| 385 |
+
path = self._pseudo_depth_path(scene, frame_idx)
|
| 386 |
+
if path is None:
|
| 387 |
+
return None
|
| 388 |
+
depth, depth_kind = self._load_pseudo_depth(path)
|
| 389 |
+
if depth is not None:
|
| 390 |
+
if depth_kind != "zdepth":
|
| 391 |
+
try:
|
| 392 |
+
depth = self._distance_to_z_depth(
|
| 393 |
+
self._sanitize_pseudo_depth(
|
| 394 |
+
depth,
|
| 395 |
+
max_depth_m=self.depth_max_m,
|
| 396 |
+
far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
|
| 397 |
+
),
|
| 398 |
+
intrinsics_k3=intrinsics_k3,
|
| 399 |
+
)
|
| 400 |
+
except Exception:
|
| 401 |
+
return None
|
| 402 |
+
else:
|
| 403 |
+
depth = self._sanitize_pseudo_depth(
|
| 404 |
+
depth,
|
| 405 |
+
max_depth_m=self.depth_max_m,
|
| 406 |
+
far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
|
| 407 |
+
)
|
| 408 |
+
return depth
|
| 409 |
+
if not self.pseudo_depth_autogen:
|
| 410 |
+
return None
|
| 411 |
+
|
| 412 |
+
acquired, ready = self._acquire_lock_or_wait_for_file(path)
|
| 413 |
+
if ready:
|
| 414 |
+
depth, depth_kind = self._load_pseudo_depth(path)
|
| 415 |
+
if depth is None:
|
| 416 |
+
return None
|
| 417 |
+
if depth_kind != "zdepth":
|
| 418 |
+
try:
|
| 419 |
+
depth = self._distance_to_z_depth(
|
| 420 |
+
self._sanitize_pseudo_depth(
|
| 421 |
+
depth,
|
| 422 |
+
max_depth_m=self.depth_max_m,
|
| 423 |
+
far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
|
| 424 |
+
),
|
| 425 |
+
intrinsics_k3=intrinsics_k3,
|
| 426 |
+
)
|
| 427 |
+
except Exception:
|
| 428 |
+
return None
|
| 429 |
+
else:
|
| 430 |
+
depth = self._sanitize_pseudo_depth(
|
| 431 |
+
depth,
|
| 432 |
+
max_depth_m=self.depth_max_m,
|
| 433 |
+
far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
|
| 434 |
+
)
|
| 435 |
+
return depth
|
| 436 |
+
if not acquired:
|
| 437 |
+
depth, depth_kind = self._load_pseudo_depth(path)
|
| 438 |
+
if depth is None:
|
| 439 |
+
return None
|
| 440 |
+
if depth_kind != "zdepth":
|
| 441 |
+
try:
|
| 442 |
+
depth = self._distance_to_z_depth(
|
| 443 |
+
self._sanitize_pseudo_depth(
|
| 444 |
+
depth,
|
| 445 |
+
max_depth_m=self.depth_max_m,
|
| 446 |
+
far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
|
| 447 |
+
),
|
| 448 |
+
intrinsics_k3=intrinsics_k3,
|
| 449 |
+
)
|
| 450 |
+
except Exception:
|
| 451 |
+
return None
|
| 452 |
+
else:
|
| 453 |
+
depth = self._sanitize_pseudo_depth(
|
| 454 |
+
depth,
|
| 455 |
+
max_depth_m=self.depth_max_m,
|
| 456 |
+
far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
|
| 457 |
+
)
|
| 458 |
+
return depth
|
| 459 |
+
|
| 460 |
+
try:
|
| 461 |
+
depth, depth_kind = self._load_pseudo_depth(path)
|
| 462 |
+
if depth is not None:
|
| 463 |
+
if depth_kind != "zdepth":
|
| 464 |
+
try:
|
| 465 |
+
depth = self._distance_to_z_depth(
|
| 466 |
+
self._sanitize_pseudo_depth(
|
| 467 |
+
depth,
|
| 468 |
+
max_depth_m=self.depth_max_m,
|
| 469 |
+
far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
|
| 470 |
+
),
|
| 471 |
+
intrinsics_k3=intrinsics_k3,
|
| 472 |
+
)
|
| 473 |
+
except Exception:
|
| 474 |
+
return None
|
| 475 |
+
else:
|
| 476 |
+
depth = self._sanitize_pseudo_depth(
|
| 477 |
+
depth,
|
| 478 |
+
max_depth_m=self.depth_max_m,
|
| 479 |
+
far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
|
| 480 |
+
)
|
| 481 |
+
return depth
|
| 482 |
+
model = self._get_or_create_pseudo_model()
|
| 483 |
+
out = infer_unik3d_pinhole(
|
| 484 |
+
model,
|
| 485 |
+
rgb_u8=rgb_u8.unsqueeze(0),
|
| 486 |
+
intrinsics=intrinsics_k3.unsqueeze(0),
|
| 487 |
+
)
|
| 488 |
+
dist = out.get("distance", None) if isinstance(out, dict) else None
|
| 489 |
+
if not torch.is_tensor(dist) or dist.ndim != 4 or dist.shape[1] != 1:
|
| 490 |
+
return None
|
| 491 |
+
dist_1hw = self._sanitize_pseudo_depth(
|
| 492 |
+
dist[0:1, 0:1].detach().to(torch.float32).cpu()[0],
|
| 493 |
+
max_depth_m=self.depth_max_m,
|
| 494 |
+
far_depth_invalid_m=self.pseudo_far_depth_invalid_m,
|
| 495 |
+
)
|
| 496 |
+
valid = torch.isfinite(dist_1hw) & (dist_1hw > 0.0)
|
| 497 |
+
if int(valid.sum().item()) <= 0:
|
| 498 |
+
return None
|
| 499 |
+
self._save_pseudo_depth_atomic(
|
| 500 |
+
path,
|
| 501 |
+
depth_2d=dist_1hw[0],
|
| 502 |
+
scene=scene,
|
| 503 |
+
frame_idx=frame_idx,
|
| 504 |
+
)
|
| 505 |
+
return self._distance_to_z_depth(dist_1hw, intrinsics_k3=intrinsics_k3.cpu())
|
| 506 |
+
except Exception as e:
|
| 507 |
+
LOGGER.warning(
|
| 508 |
+
"Pseudo-depth generation failed scene=%s frame=%d: %s",
|
| 509 |
+
str(scene),
|
| 510 |
+
int(frame_idx),
|
| 511 |
+
str(e),
|
| 512 |
+
)
|
| 513 |
+
return None
|
| 514 |
+
finally:
|
| 515 |
+
self._release_lock(path)
|
| 516 |
+
|
| 517 |
+
def _candidate_target_indices(
|
| 518 |
+
self,
|
| 519 |
+
src_idx: int,
|
| 520 |
+
num_frames: int,
|
| 521 |
+
w2c_all: torch.Tensor,
|
| 522 |
+
intr_all: torch.Tensor,
|
| 523 |
+
h: int,
|
| 524 |
+
w: int,
|
| 525 |
+
) -> list[int]:
|
| 526 |
+
if num_frames < 2:
|
| 527 |
+
return []
|
| 528 |
+
centers = torch.linalg.inv(w2c_all)[:, :3, 3].to(torch.float32)
|
| 529 |
+
sample_h = int(self.pair_overlap_sample_h)
|
| 530 |
+
sample_w = int(self.pair_overlap_sample_w)
|
| 531 |
+
return select_targets_for_source(
|
| 532 |
+
src_idx=int(src_idx),
|
| 533 |
+
candidate_indices=list(range(num_frames)),
|
| 534 |
+
centers=centers,
|
| 535 |
+
min_index_gap=int(self.min_frame_gap),
|
| 536 |
+
max_index_gap=int(self.max_frame_gap),
|
| 537 |
+
pair_max_translation_m=float(self.pair_max_translation_m),
|
| 538 |
+
pair_min_overlap=float(self.pair_min_overlap),
|
| 539 |
+
overlap_score_fn=lambda si, tj: float(
|
| 540 |
+
0.5
|
| 541 |
+
* (
|
| 542 |
+
project_overlap_ratio(
|
| 543 |
+
src_w2c=w2c_all[si],
|
| 544 |
+
tgt_w2c=w2c_all[tj],
|
| 545 |
+
src_k=intr_all[si],
|
| 546 |
+
tgt_k=intr_all[tj],
|
| 547 |
+
h=h,
|
| 548 |
+
w=w,
|
| 549 |
+
sample_h=sample_h,
|
| 550 |
+
sample_w=sample_w,
|
| 551 |
+
)
|
| 552 |
+
+ project_overlap_ratio(
|
| 553 |
+
src_w2c=w2c_all[tj],
|
| 554 |
+
tgt_w2c=w2c_all[si],
|
| 555 |
+
src_k=intr_all[tj],
|
| 556 |
+
tgt_k=intr_all[si],
|
| 557 |
+
h=h,
|
| 558 |
+
w=w,
|
| 559 |
+
sample_h=sample_h,
|
| 560 |
+
sample_w=sample_w,
|
| 561 |
+
)
|
| 562 |
+
)
|
| 563 |
+
),
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
def __iter__(self):
    """Stream source/target pair samples from this shard's chunk files.

    Chunks are optionally shuffled (train split only), then split round-robin
    across ``ddp_world_size * num_workers`` shards. For every scene example a
    target frame is drawn for each source frame, both images are decoded,
    pseudo-depth is fetched, and everything is optionally resized to
    ``(output_h, output_w)``.

    Yields:
        With ``batch_size_hint <= 1``: individual ``Re10KPairSample`` objects.
        Otherwise: the result of ``_pack_re10k_batch`` for groups of
        ``batch_size_hint`` samples that share the same image resolution.
        Samples left over at the end are dropped (logged at debug level).
    """
    chunks = list(self.chunks)
    # Chunk order depends only on (seed, epoch).
    order_rng = random.Random(self.seed + self.epoch)
    if self.shuffle_chunk and self.split == "train":
        order_rng.shuffle(chunks)
    # Samples waiting to be batched, bucketed by (H, W) so a batch never mixes resolutions.
    pending_by_hw: dict[tuple[int, int], deque[Re10KPairSample]] = defaultdict(deque)

    # Shard chunks across DDP ranks and DataLoader workers (round-robin by index).
    worker_info = torch.utils.data.get_worker_info()
    num_workers = worker_info.num_workers if worker_info is not None else 1
    worker_id = worker_info.id if worker_info is not None else 0
    total_shards = max(1, self.ddp_world_size * num_workers)
    shard_id = self.ddp_rank * num_workers + worker_id
    chunks = [chunk for i, chunk in enumerate(chunks) if i % total_shards == shard_id]

    for chunk_order_idx, chunk_path in enumerate(chunks):
        chunk = _torch_load_any(chunk_path)
        if not isinstance(chunk, list):
            continue
        examples = list(chunk)
        # Per-chunk RNG; note chunk_order_idx is the position inside this shard's list.
        chunk_rng = random.Random(self.seed + self.epoch * 1000003 + chunk_order_idx)
        if self.shuffle_example and self.split == "train":
            chunk_rng.shuffle(examples)

        for example in examples:
            # Skip anything that does not look like a scene record.
            if not isinstance(example, dict):
                continue
            if "cameras" not in example or "images" not in example:
                continue
            poses = example["cameras"]
            images = example["images"]
            scene = str(example.get("key", "unknown"))
            if not torch.is_tensor(poses) or not isinstance(images, list):
                continue
            # One 18-value row per frame: intrinsics in columns 0..3, 3x4 pose in columns 6..17.
            if poses.ndim != 2 or poses.shape[1] != 18:
                continue
            if len(images) != int(poses.shape[0]):
                continue

            # Decode the first frame only to learn the image size (also fails for empty lists).
            try:
                src_probe = self._decode_image_u8(images[0])
            except Exception:
                continue
            h, w = int(src_probe.shape[1]), int(src_probe.shape[2])
            w2c_all = self._convert_pose_row_to_w2c(poses)
            intr_all = self._convert_intrinsics_to_pixel(poses, h=h, w=w)
            src_indices = list(range(len(images)))
            if self.shuffle_example and self.split == "train":
                chunk_rng.shuffle(src_indices)
            for src_idx in src_indices:
                tgt_candidates = self._candidate_target_indices(
                    int(src_idx),
                    len(images),
                    w2c_all=w2c_all,
                    intr_all=intr_all,
                    h=h,
                    w=w,
                )
                if not tgt_candidates:
                    continue
                tgt_idx = chunk_rng.choice(tgt_candidates)

                try:
                    src_img = self._decode_image_u8(images[src_idx])
                    tgt_img = self._decode_image_u8(images[tgt_idx])
                except Exception:
                    continue
                if src_img.shape != tgt_img.shape:
                    continue
                src_intr = intr_all[src_idx].clone()
                tgt_intr = intr_all[tgt_idx].clone()
                # Pseudo-depth is fetched at the original resolution with the unscaled intrinsics.
                src_depth = self._get_pseudo_depth_for_frame(
                    scene=scene,
                    frame_idx=int(src_idx),
                    rgb_u8=src_img,
                    intrinsics_k3=intr_all[src_idx].to(torch.float32),
                )
                tgt_depth = self._get_pseudo_depth_for_frame(
                    scene=scene,
                    frame_idx=int(tgt_idx),
                    rgb_u8=tgt_img,
                    intrinsics_k3=intr_all[tgt_idx].to(torch.float32),
                )
                # When pseudo-depth is configured, a pair without both depth maps is unusable.
                if self.pseudo_depth_root is not None and (
                    (not torch.is_tensor(src_depth)) or (not torch.is_tensor(tgt_depth))
                ):
                    continue
                if self.output_h is not None and self.output_w is not None:
                    oh, ow = int(src_img.shape[1]), int(src_img.shape[2])
                    if oh > 0 and ow > 0 and (oh != self.output_h or ow != self.output_w):
                        sx = float(self.output_w) / float(ow)
                        sy = float(self.output_h) / float(oh)
                        src_img = resize_rgb_u8_chw_high_quality(src_img, size=(self.output_h, self.output_w))
                        tgt_img = resize_rgb_u8_chw_high_quality(tgt_img, size=(self.output_h, self.output_w))
                        src_intr = resize_k3_align_corners_false(src_intr, sx=sx, sy=sy)
                        tgt_intr = resize_k3_align_corners_false(tgt_intr, sx=sx, sy=sy)
                        # NOTE(review): bilinear resampling blends zero (invalid) pixels with
                        # valid depths along mask borders -- confirm this is acceptable downstream.
                        if torch.is_tensor(src_depth):
                            src_depth = (
                                torch.nn.functional.interpolate(
                                    src_depth[None],
                                    size=(self.output_h, self.output_w),
                                    mode="bilinear",
                                    align_corners=False,
                                )
                                .squeeze(0)
                                .to(torch.float32)
                            )
                        if torch.is_tensor(tgt_depth):
                            tgt_depth = (
                                torch.nn.functional.interpolate(
                                    tgt_depth[None],
                                    size=(self.output_h, self.output_w),
                                    mode="bilinear",
                                    align_corners=False,
                                )
                                .squeeze(0)
                                .to(torch.float32)
                            )

                sample = Re10KPairSample(
                    src_rgb_u8=src_img,
                    tgt_rgb_u8=tgt_img,
                    src_w2c=w2c_all[src_idx],
                    tgt_w2c=w2c_all[tgt_idx],
                    src_intrinsics=src_intr,
                    tgt_intrinsics=tgt_intr,
                    src_idx=int(src_idx),
                    tgt_idx=int(tgt_idx),
                    scene=scene,
                    src_depth_m=src_depth,
                    tgt_depth_m=tgt_depth,
                )
                hw_key = (int(sample.src_rgb_u8.shape[1]), int(sample.src_rgb_u8.shape[2]))
                bucket = pending_by_hw[hw_key]
                bucket.append(sample)
                # Unbatched mode: hand the sample straight back out.
                if self.batch_size_hint <= 1:
                    yield bucket.popleft()
                    continue
                # Batched mode: emit full same-resolution batches as soon as they are available.
                while len(bucket) >= self.batch_size_hint:
                    packed = [bucket.popleft() for _ in range(self.batch_size_hint)]
                    yield _pack_re10k_batch(packed)

    # Whatever is still queued cannot fill a batch and is discarded.
    dropped = sum(len(bucket) for bucket in pending_by_hw.values())
    if dropped > 0 and self.split == "train" and self.batch_size_hint > 1:
        LOGGER.debug(
            "Dropped %d RE10K leftover samples that could not form a same-resolution batch of size %d.",
            int(dropped),
            int(self.batch_size_hint),
        )
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def re10k_collate(batch: list[Re10KPairSample]) -> Re10KPairSample:
    """Collate a list of pair samples by delegating to ``_pack_re10k_batch``."""
    return _pack_re10k_batch(batch)
|
| 718 |
+
|
unisharp/datasets/scannetpp_fisheye.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import random
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from torch.utils.data import IterableDataset
|
| 14 |
+
|
| 15 |
+
from unisharp import DEFAULT_MAX_DEPTH_M
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)
# File suffixes in lower- and upper-case form; presumably used to recognise
# image files -- the lookup code is outside this excerpt.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"}
# Directory names that presumably may hold depth/distance maps -- verify against the scene loader.
DEPTH_DIR_NAMES = ("depth", "depths", "distance", "distances", "depth_maps")
# Directory names that presumably may hold validity masks -- verify against the scene loader.
MASK_DIR_NAMES = ("masks", "mask")
# Local alias of the package-wide default maximum depth.
DEPTH_MAX_M = DEFAULT_MAX_DEPTH_M
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _qvec_to_rotmat(qvec: np.ndarray) -> np.ndarray:
|
| 26 |
+
q = np.asarray(qvec, dtype=np.float64)
|
| 27 |
+
return np.array(
|
| 28 |
+
[
|
| 29 |
+
[1 - 2 * q[2] ** 2 - 2 * q[3] ** 2, 2 * q[1] * q[2] - 2 * q[0] * q[3], 2 * q[3] * q[1] + 2 * q[0] * q[2]],
|
| 30 |
+
[2 * q[1] * q[2] + 2 * q[0] * q[3], 1 - 2 * q[1] ** 2 - 2 * q[3] ** 2, 2 * q[2] * q[3] - 2 * q[0] * q[1]],
|
| 31 |
+
[2 * q[3] * q[1] - 2 * q[0] * q[2], 2 * q[2] * q[3] + 2 * q[0] * q[1], 1 - 2 * q[1] ** 2 - 2 * q[2] ** 2],
|
| 32 |
+
],
|
| 33 |
+
dtype=np.float64,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _read_colmap_w2c(images_txt: Path) -> dict[str, torch.Tensor]:
|
| 38 |
+
poses: dict[str, torch.Tensor] = {}
|
| 39 |
+
if not images_txt.exists():
|
| 40 |
+
return poses
|
| 41 |
+
with images_txt.open("r", encoding="utf-8") as f:
|
| 42 |
+
for raw in f:
|
| 43 |
+
line = raw.strip()
|
| 44 |
+
if not line or line.startswith("#"):
|
| 45 |
+
continue
|
| 46 |
+
parts = line.split()
|
| 47 |
+
if len(parts) < 10:
|
| 48 |
+
continue
|
| 49 |
+
try:
|
| 50 |
+
qvec = np.asarray([float(x) for x in parts[1:5]], dtype=np.float64)
|
| 51 |
+
tvec = np.asarray([float(x) for x in parts[5:8]], dtype=np.float64)
|
| 52 |
+
image_name = parts[9]
|
| 53 |
+
except Exception:
|
| 54 |
+
continue
|
| 55 |
+
w2c = np.eye(4, dtype=np.float32)
|
| 56 |
+
w2c[:3, :3] = _qvec_to_rotmat(qvec).astype(np.float32)
|
| 57 |
+
w2c[:3, 3] = tvec.astype(np.float32)
|
| 58 |
+
poses[Path(image_name).name] = torch.from_numpy(w2c)
|
| 59 |
+
return poses
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _opencv_fisheye_to_fisheye624_params(meta: dict[str, object]) -> torch.Tensor:
|
| 63 |
+
if str(meta.get("camera_model", "")) != "OPENCV_FISHEYE":
|
| 64 |
+
raise RuntimeError(f"Unsupported ScanNet++ camera_model={meta.get('camera_model')!r}; expected OPENCV_FISHEYE.")
|
| 65 |
+
return torch.tensor(
|
| 66 |
+
[
|
| 67 |
+
float(meta["fl_x"]),
|
| 68 |
+
float(meta["fl_y"]),
|
| 69 |
+
float(meta["cx"]),
|
| 70 |
+
float(meta["cy"]),
|
| 71 |
+
float(meta.get("k1", 0.0)),
|
| 72 |
+
float(meta.get("k2", 0.0)),
|
| 73 |
+
float(meta.get("k3", 0.0)),
|
| 74 |
+
float(meta.get("k4", 0.0)),
|
| 75 |
+
0.0,
|
| 76 |
+
0.0,
|
| 77 |
+
0.0,
|
| 78 |
+
0.0,
|
| 79 |
+
0.0,
|
| 80 |
+
0.0,
|
| 81 |
+
0.0,
|
| 82 |
+
0.0,
|
| 83 |
+
],
|
| 84 |
+
dtype=torch.float32,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _camera_hw_from_meta(meta: dict[str, object]) -> tuple[int, int] | None:
|
| 89 |
+
h = meta.get("h", meta.get("height", None))
|
| 90 |
+
w = meta.get("w", meta.get("width", None))
|
| 91 |
+
if h is None or w is None:
|
| 92 |
+
return None
|
| 93 |
+
try:
|
| 94 |
+
h_i, w_i = int(h), int(w)
|
| 95 |
+
except Exception:
|
| 96 |
+
return None
|
| 97 |
+
return (h_i, w_i) if h_i > 0 and w_i > 0 else None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _scale_fisheye624_params(
|
| 101 |
+
params: torch.Tensor,
|
| 102 |
+
*,
|
| 103 |
+
src_hw: tuple[int, int],
|
| 104 |
+
dst_hw: tuple[int, int],
|
| 105 |
+
) -> torch.Tensor:
|
| 106 |
+
if tuple(int(x) for x in src_hw) == tuple(int(x) for x in dst_hw):
|
| 107 |
+
return params.clone()
|
| 108 |
+
src_h, src_w = int(src_hw[0]), int(src_hw[1])
|
| 109 |
+
dst_h, dst_w = int(dst_hw[0]), int(dst_hw[1])
|
| 110 |
+
sx = float(dst_w) / float(max(src_w, 1))
|
| 111 |
+
sy = float(dst_h) / float(max(src_h, 1))
|
| 112 |
+
out = params.clone()
|
| 113 |
+
out[..., 0] *= sx
|
| 114 |
+
out[..., 1] *= sy
|
| 115 |
+
out[..., 2] = (out[..., 2] + 0.5) * sx - 0.5
|
| 116 |
+
out[..., 3] = (out[..., 3] + 0.5) * sy - 0.5
|
| 117 |
+
return out
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _stack_batch(batch: list["ScannetppFisheyePairSample"]) -> "ScannetppFisheyePairSample":
    """Merge per-pair samples into one batched sample.

    Every tensor field is stacked along a new leading axis; ``src_idx``,
    ``tgt_idx`` and ``scene`` become plain lists, and ``camera_model`` is set
    to ``"fisheye624"``.
    """
    tensor_fields = (
        "src_rgb_u8",
        "tgt_rgb_u8",
        "src_depth_m",
        "tgt_depth_m",
        "src_valid_mask",
        "tgt_valid_mask",
        "src_w2c",
        "tgt_w2c",
        "src_camera_params",
        "tgt_camera_params",
    )
    stacked = {
        field: torch.stack([getattr(item, field) for item in batch], dim=0)
        for field in tensor_fields
    }
    return ScannetppFisheyePairSample(
        **stacked,
        src_idx=[item.src_idx for item in batch],  # type: ignore[arg-type]
        tgt_idx=[item.tgt_idx for item in batch],  # type: ignore[arg-type]
        scene=[item.scene for item in batch],  # type: ignore[arg-type]
        camera_model="fisheye624",
    )
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@dataclass(frozen=True)
class ScannetppFisheyePairSample:
    """Immutable record describing one source/target fisheye frame pair.

    ``_stack_batch`` reuses this class for batches: tensor fields are then
    stacked along a leading axis and ``src_idx`` / ``tgt_idx`` / ``scene``
    hold lists instead of scalars.
    """

    # Source / target images (uint8 by name -- layout not established in this excerpt).
    src_rgb_u8: torch.Tensor
    tgt_rgb_u8: torch.Tensor
    # Source / target depth maps, presumably in metres -- confirm against the loader.
    src_depth_m: torch.Tensor
    tgt_depth_m: torch.Tensor
    # Source / target per-pixel validity masks.
    src_valid_mask: torch.Tensor
    tgt_valid_mask: torch.Tensor
    # Source / target world-to-camera matrices.
    src_w2c: torch.Tensor
    tgt_w2c: torch.Tensor
    # Source / target camera parameter vectors for the model named by ``camera_model``.
    src_camera_params: torch.Tensor
    tgt_camera_params: torch.Tensor
    # Frame indices of the pair and the scene identifier.
    src_idx: int
    tgt_idx: int
    scene: str
    # Name of the camera model the parameter vectors belong to.
    camera_model: str = "fisheye624"
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def scannetpp_fisheye_passthrough(batch: ScannetppFisheyePairSample) -> ScannetppFisheyePairSample:
    """Return ``batch`` unchanged (identity function; presumably a collate hook for pre-batched data)."""
    return batch
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class ScannetppFisheyeDataset(IterableDataset):
|
| 162 |
+
|
| 163 |
+
def __init__(
    self,
    root: Path,
    scene_list_file: Path | None = None,
    min_frame_gap: int = 1,
    max_frame_gap: int = 10,
    pair_max_translation_m: float = 0.5,
    shuffle_scene: bool = True,
    shuffle_frame: bool = True,
    skip_bad: bool = True,
    ddp_rank: int = 0,
    ddp_world_size: int = 1,
    batch_size_hint: int = 1,
    depth_max_m: float = DEFAULT_MAX_DEPTH_M,
    far_depth_invalid_m: float = 30.0,
    seed: int = 0,
) -> None:
    """Store the configuration and discover the scenes to iterate over.

    Args:
        root: Dataset root; relative scene paths are resolved against it.
        scene_list_file: Optional text file listing scenes. When it is absent
            or missing on disk, scenes are discovered by globbing ``root``.
        min_frame_gap: Stored as int; consumed outside this excerpt.
        max_frame_gap: Stored as int; consumed outside this excerpt.
        pair_max_translation_m: Stored as float; consumed outside this excerpt.
        shuffle_scene: Stored as bool; consumed outside this excerpt.
        shuffle_frame: Stored as bool; consumed outside this excerpt.
        skip_bad: Stored as bool; consumed outside this excerpt.
        ddp_rank: Stored as int; consumed outside this excerpt.
        ddp_world_size: Stored as int; consumed outside this excerpt.
        batch_size_hint: Stored as int, raised to at least 1.
        depth_max_m: Upper clamp applied to loaded depth maps.
        far_depth_invalid_m: Depths above this value are zeroed when the
            threshold is positive.
        seed: Stored as int; consumed outside this excerpt.

    Raises:
        RuntimeError: If no scene could be found.
    """
    super().__init__()
    self.root = Path(root)
    self.min_frame_gap = int(min_frame_gap)
    self.max_frame_gap = int(max_frame_gap)
    self.pair_max_translation_m = float(pair_max_translation_m)
    self.shuffle_scene = bool(shuffle_scene)
    self.shuffle_frame = bool(shuffle_frame)
    self.skip_bad = bool(skip_bad)
    self.ddp_rank = int(ddp_rank)
    self.ddp_world_size = int(ddp_world_size)
    self.batch_size_hint = int(max(1, batch_size_hint))
    self.depth_max_m = float(depth_max_m)
    self.far_depth_invalid_m = float(far_depth_invalid_m)
    self.seed = int(seed)
    # Epoch counter; changed through set_epoch().
    self.epoch = 0
    # Must run after self.root is set, because scene discovery reads it.
    self.scene_specs = self._load_scene_specs(scene_list_file)
    if not self.scene_specs:
        raise RuntimeError(f"No ScanNet++ fisheye scenes found under {self.root}")
|
| 198 |
+
|
| 199 |
+
def set_epoch(self, epoch: int) -> None:
    """Record the current epoch number as an int on ``self.epoch``."""
    self.epoch = int(epoch)
|
| 201 |
+
|
| 202 |
+
def _load_scene_specs(self, scene_list_file: Path | None) -> list[tuple[str, Path]]:
|
| 203 |
+
specs: list[tuple[str, Path]] = []
|
| 204 |
+
if scene_list_file is not None and Path(scene_list_file).exists():
|
| 205 |
+
for raw in Path(scene_list_file).read_text(encoding="utf-8").splitlines():
|
| 206 |
+
line = raw.strip()
|
| 207 |
+
if not line:
|
| 208 |
+
continue
|
| 209 |
+
parts = line.split("|")
|
| 210 |
+
if len(parts) == 1:
|
| 211 |
+
scene_dir = Path(parts[0])
|
| 212 |
+
scene_id = scene_dir.name
|
| 213 |
+
else:
|
| 214 |
+
scene_id = parts[0]
|
| 215 |
+
scene_dir = Path(parts[1])
|
| 216 |
+
if not scene_dir.is_absolute():
|
| 217 |
+
scene_dir = self.root / scene_dir
|
| 218 |
+
specs.append((scene_id, scene_dir))
|
| 219 |
+
return specs
|
| 220 |
+
for transforms in sorted(self.root.glob("*/nerfstudio/transforms.json")):
|
| 221 |
+
specs.append((transforms.parent.parent.name, transforms.parent.parent))
|
| 222 |
+
for transforms in sorted(self.root.glob("*/*/nerfstudio/transforms.json")):
|
| 223 |
+
specs.append((f"{transforms.parent.parent.parent.name}/{transforms.parent.parent.name}", transforms.parent.parent))
|
| 224 |
+
return specs
|
| 225 |
+
|
| 226 |
+
@staticmethod
def _load_rgb(path: Path) -> torch.Tensor:
    """Read an image file as a contiguous uint8 tensor of shape ``(3, H, W)``.

    The image is converted to RGB, copied out of PIL while the file is still
    open, and permuted from height-width-channel to channel-first order.
    """
    with Image.open(path) as image:
        rgb = image.convert("RGB")
        hwc = np.array(rgb, dtype=np.uint8)
    chw = torch.from_numpy(hwc).permute(2, 0, 1)
    return chw.contiguous()
|
| 231 |
+
|
| 232 |
+
@staticmethod
|
| 233 |
+
def _load_mask(path: Path, image_hw: tuple[int, int]) -> torch.Tensor | None:
|
| 234 |
+
if not path.exists():
|
| 235 |
+
return None
|
| 236 |
+
with Image.open(path) as image:
|
| 237 |
+
arr = np.asarray(image.convert("L"), dtype=np.uint8).copy()
|
| 238 |
+
mask = torch.from_numpy(arr).unsqueeze(0).to(torch.float32) / 255.0
|
| 239 |
+
if tuple(mask.shape[-2:]) != tuple(image_hw):
|
| 240 |
+
mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=image_hw, mode="nearest").squeeze(0)
|
| 241 |
+
return (mask > 0.5).to(torch.float32)
|
| 242 |
+
|
| 243 |
+
def _load_depth_map(self, path: Path) -> tuple[torch.Tensor, str]:
|
| 244 |
+
depth_kind = "distance"
|
| 245 |
+
if path.suffix.lower() == ".npz":
|
| 246 |
+
payload = np.load(path, allow_pickle=False)
|
| 247 |
+
for key in ("distance_m", "depth_m", "distance", "depth"):
|
| 248 |
+
if key in payload:
|
| 249 |
+
arr = payload[key]
|
| 250 |
+
if key in {"distance_m", "distance"}:
|
| 251 |
+
depth_kind = "distance"
|
| 252 |
+
elif "depth_kind" in payload:
|
| 253 |
+
depth_kind = str(np.asarray(payload["depth_kind"]).item()).strip().lower()
|
| 254 |
+
break
|
| 255 |
+
else:
|
| 256 |
+
raise RuntimeError(f"Unsupported ScanNet++ depth payload keys at {path}")
|
| 257 |
+
else:
|
| 258 |
+
arr = np.load(path)
|
| 259 |
+
depth = torch.from_numpy(np.asarray(arr, dtype=np.float32).copy())
|
| 260 |
+
if depth.ndim == 3 and depth.shape[0] == 1:
|
| 261 |
+
depth = depth[0]
|
| 262 |
+
if depth.ndim != 2:
|
| 263 |
+
raise RuntimeError(f"Expected 2D fisheye depth at {path}, got shape={tuple(depth.shape)}")
|
| 264 |
+
depth = depth.unsqueeze(0)
|
| 265 |
+
valid = torch.isfinite(depth) & (depth > 0.0)
|
| 266 |
+
if self.far_depth_invalid_m > 0.0:
|
| 267 |
+
valid = valid & (depth <= self.far_depth_invalid_m)
|
| 268 |
+
depth = torch.where(valid, depth, torch.zeros_like(depth))
|
| 269 |
+
if depth_kind in {"radial", "radius", "dist"}:
|
| 270 |
+
depth_kind = "distance"
|
| 271 |
+
if depth_kind not in {"distance", "z"}:
|
| 272 |
+
raise RuntimeError(f"Unsupported fisheye depth_kind={depth_kind!r} at {path}")
|
| 273 |
+
return depth.clamp(min=0.0, max=self.depth_max_m), depth_kind
|
| 274 |
+
|
| 275 |
+
@staticmethod
def _fisheye_z_depth_to_distance(z_depth: torch.Tensor, camera_params: torch.Tensor) -> torch.Tensor:
    """Convert z-depth to distance along the ray by dividing by the ray's z component.

    The ray map is built for the spatial size of ``z_depth`` from
    ``camera_params``. Pixels with a non-finite or non-positive depth, or whose
    ray z component is non-finite or not above 1e-4, are set to zero.
    """
    from unisharp.utils.fisheye_geer import build_fisheye624_raymap

    height = int(z_depth.shape[-2])
    width = int(z_depth.shape[-1])
    raymap = build_fisheye624_raymap(
        camera_params.unsqueeze(0),
        image_h=height,
        image_w=width,
        device=z_depth.device,
        dtype=torch.float32,
    )
    # Keep only channel 2 of the ray map and drop the batch dimension.
    ray_z = raymap[:, 2:3].squeeze(0).to(device=z_depth.device, dtype=z_depth.dtype)
    depth_ok = torch.isfinite(z_depth) & (z_depth > 0.0)
    ray_ok = torch.isfinite(ray_z) & (ray_z > 1e-4)
    # The clamp only guards the division; such pixels are zeroed below anyway.
    along_ray = z_depth / ray_z.clamp(min=1e-4)
    return torch.where(depth_ok & ray_ok, along_ray, torch.zeros_like(z_depth))
|
| 291 |
+
|
| 292 |
+
def _resolve_image_path(self, scene_dir: Path, image_name: str) -> Path | None:
    """Find the image file for ``image_name`` inside a scene directory.

    For each of ``scene_dir`` and ``scene_dir / "dslr"`` three locations are
    probed in order: the relative path as given, ``images/<name>`` and
    ``resized_images/<name>``. The first existing path whose suffix is in
    ``IMAGE_SUFFIXES`` is returned, otherwise None.
    """
    rel = Path(image_name)
    probes: list[Path] = []
    for base in (scene_dir, scene_dir / "dslr"):
        probes.append(base / rel)
        probes.append(base / "images" / rel.name)
        probes.append(base / "resized_images" / rel.name)
    return next(
        (probe for probe in probes if probe.exists() and probe.suffix in IMAGE_SUFFIXES),
        None,
    )
|
| 306 |
+
|
| 307 |
+
def _resolve_depth_path(self, scene_dir: Path, image_name: str) -> Path | None:
    """Find the depth file that belongs to ``image_name``.

    Candidates are ``<base>/<depth_dir>/<name><suffix>`` for every base in
    (``scene_dir``, ``scene_dir / "dslr"``), every entry of
    ``DEPTH_DIR_NAMES``, both the image stem and the full image file name, and
    the suffixes ``.npz`` then ``.npy``. The first existing candidate wins;
    None is returned when nothing exists.
    """
    image_rel = Path(image_name)
    candidates = (
        base / depth_dir_name / f"{name}{suffix}"
        for base in (scene_dir, scene_dir / "dslr")
        for depth_dir_name in DEPTH_DIR_NAMES
        for name in (image_rel.stem, image_rel.name)
        for suffix in (".npz", ".npy")
    )
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None
|
| 320 |
+
|
| 321 |
+
def _resolve_mask_path(self, scene_dir: Path, image_name: str, mask_name: str | None) -> Path | None:
    """Find a mask file for a frame, or return None if there is none.

    The file name of ``mask_name`` (when given) is tried first, then
    ``<image stem>.png``. For each base (``scene_dir``, then
    ``scene_dir / "dslr"``) and each name, the file directly under the base is
    probed before the sub-directories listed in ``MASK_DIR_NAMES``.
    """
    fallback_name = f"{Path(image_name).stem}.png"
    names = [Path(mask_name).name, fallback_name] if mask_name else [fallback_name]
    for base in (scene_dir, scene_dir / "dslr"):
        for name in names:
            probes = [base / name]
            probes.extend(base / mask_dir_name / name for mask_dir_name in MASK_DIR_NAMES)
            hit = next((probe for probe in probes if probe.exists()), None)
            if hit is not None:
                return hit
    return None
|
| 337 |
+
|
| 338 |
+
def _load_scene_frames(self, scene_id: str, scene_dir: Path) -> tuple[torch.Tensor, list[dict[str, object]]]:
    """Index the usable frames of one scene.

    Reads ``nerfstudio/transforms.json`` (falling back to the copy under
    ``dslr/``), derives the camera parameters from it, and builds one record
    per entry of ``frames`` + ``test_frames`` that has a file name, is not
    flagged bad (when ``self.skip_bad`` is set), has both an image and a depth
    file on disk, and has a pose. Poses come from the COLMAP ``images.txt``
    when the image is listed there, otherwise from the inverse of the frame's
    ``transform_matrix``.

    Returns:
        ``(camera_params, frames)`` with ``frames`` sorted by image name.
        Each record's ``"idx"`` is its position in acceptance order, assigned
        before the sort.
    """
    transforms_path = scene_dir / "nerfstudio" / "transforms.json"
    if not transforms_path.exists():
        transforms_path = scene_dir / "dslr" / "nerfstudio" / "transforms.json"
    meta = json.loads(transforms_path.read_text(encoding="utf-8"))
    camera_params = _opencv_fisheye_to_fisheye624_params(meta)
    camera_hw = _camera_hw_from_meta(meta)
    w2c_by_name = _read_colmap_w2c(scene_dir / "colmap" / "images.txt")
    if not w2c_by_name:
        w2c_by_name = _read_colmap_w2c(scene_dir / "dslr" / "colmap" / "images.txt")

    frames: list[dict[str, object]] = []
    for frame in [*meta.get("frames", []), *meta.get("test_frames", [])]:
        image_name = Path(str(frame.get("file_path", ""))).name
        if not image_name or (self.skip_bad and bool(frame.get("is_bad", False))):
            continue
        image_path = self._resolve_image_path(scene_dir, image_name)
        depth_path = self._resolve_depth_path(scene_dir, image_name)
        if image_path is None or depth_path is None:
            continue
        w2c = w2c_by_name.get(image_name)
        if w2c is None:
            # No COLMAP pose: fall back to the camera-to-world matrix in the JSON.
            c2w_raw = frame.get("transform_matrix")
            if c2w_raw is None:
                continue
            w2c = torch.linalg.inv(torch.tensor(c2w_raw, dtype=torch.float32))
        center = torch.linalg.inv(w2c)[:3, 3]
        record: dict[str, object] = {
            "image_name": image_name,
            "image_path": image_path,
            "depth_path": depth_path,
            "mask_path": self._resolve_mask_path(scene_dir, image_name, frame.get("mask_path")),
            "w2c": w2c.to(torch.float32),
            "center": center.to(torch.float32),
            "idx": len(frames),
            "scene": scene_id,
            "camera_hw": _camera_hw_from_meta(frame) or camera_hw,
        }
        frames.append(record)
    frames.sort(key=lambda item: str(item["image_name"]))
    return camera_params, frames
|
| 382 |
+
|
| 383 |
+
def _load_frame_tensor(self, frame: dict[str, object], camera_params: torch.Tensor) -> dict[str, torch.Tensor]:
    """Load the tensors for one indexed frame.

    Loads the RGB image and depth map, rescales a copy of ``camera_params``
    from the frame's ``camera_hw`` to the image size when ``camera_hw`` is a
    tuple, resizes the depth to the image size (nearest neighbour) when they
    differ, converts z-depth to ray distance, and builds a validity mask.

    Returns:
        A dict with ``rgb_u8``, ``depth_m`` (clamped to
        ``[0, self.depth_max_m]``), ``valid_mask`` and ``camera_params``.
    """
    rgb = self._load_rgb(frame["image_path"])  # type: ignore[arg-type]
    target_hw = (int(rgb.shape[-2]), int(rgb.shape[-1]))
    params = camera_params.clone()
    camera_hw = frame.get("camera_hw", None)
    if isinstance(camera_hw, tuple):
        params = _scale_fisheye624_params(params, src_hw=camera_hw, dst_hw=target_hw)

    depth, depth_kind = self._load_depth_map(frame["depth_path"])  # type: ignore[arg-type]
    if tuple(depth.shape[-2:]) != tuple(rgb.shape[-2:]):
        resized = torch.nn.functional.interpolate(depth.unsqueeze(0), size=target_hw, mode="nearest")
        depth = resized.squeeze(0)
    if depth_kind == "z":
        depth = self._fisheye_z_depth_to_distance(depth, params)

    valid = (torch.isfinite(depth) & (depth > 0.0)).to(torch.float32)
    mask_path = frame.get("mask_path", None)
    # NOTE(review): the scraped source lost its indentation; the fallback branch
    # below is assumed to pair with the outer isinstance check - confirm.
    if isinstance(mask_path, Path):
        mask = self._load_mask(mask_path, target_hw)
        if mask is not None:
            valid = valid * mask
    else:
        # Without a mask file, treat pixels whose channel sum is <= 1 as invalid.
        nonblack = rgb.to(torch.float32).sum(dim=0, keepdim=True) > 1.0
        valid = valid * nonblack.to(torch.float32)
    return {
        "rgb_u8": rgb,
        "depth_m": depth.clamp(min=0.0, max=self.depth_max_m),
        "valid_mask": valid,
        "camera_params": params,
    }
|
| 413 |
+
|
| 414 |
+
def _iter_scene_pairs(self, scene_id: str, scene_dir: Path, rng: random.Random):
|
| 415 |
+
try:
|
| 416 |
+
camera_params, frames = self._load_scene_frames(scene_id, scene_dir)
|
| 417 |
+
except Exception as exc:
|
| 418 |
+
LOGGER.debug("Skip ScanNet++ scene %s: %s", str(scene_id), str(exc))
|
| 419 |
+
return
|
| 420 |
+
if len(frames) < 2:
|
| 421 |
+
return
|
| 422 |
+
loaded: dict[int, dict[str, torch.Tensor]] = {}
|
| 423 |
+
|
| 424 |
+
def get_loaded(pos: int) -> dict[str, torch.Tensor]:
|
| 425 |
+
if pos not in loaded:
|
| 426 |
+
loaded[pos] = self._load_frame_tensor(frames[pos], camera_params)
|
| 427 |
+
return loaded[pos]
|
| 428 |
+
|
| 429 |
+
order = list(range(len(frames)))
|
| 430 |
+
if self.shuffle_frame:
|
| 431 |
+
rng.shuffle(order)
|
| 432 |
+
for src_pos in order:
|
| 433 |
+
src_item = frames[src_pos]
|
| 434 |
+
src_center = src_item["center"]
|
| 435 |
+
assert torch.is_tensor(src_center)
|
| 436 |
+
candidates: list[int] = []
|
| 437 |
+
for tgt_pos in range(max(0, src_pos - self.max_frame_gap), min(len(frames), src_pos + self.max_frame_gap + 1)):
|
| 438 |
+
if tgt_pos == src_pos:
|
| 439 |
+
continue
|
| 440 |
+
gap = abs(tgt_pos - src_pos)
|
| 441 |
+
if gap < self.min_frame_gap:
|
| 442 |
+
continue
|
| 443 |
+
tgt_center = frames[tgt_pos]["center"]
|
| 444 |
+
assert torch.is_tensor(tgt_center)
|
| 445 |
+
if float(torch.norm(tgt_center - src_center, p=2).item()) > self.pair_max_translation_m:
|
| 446 |
+
continue
|
| 447 |
+
candidates.append(tgt_pos)
|
| 448 |
+
if not candidates:
|
| 449 |
+
continue
|
| 450 |
+
tgt_pos = rng.choice(candidates)
|
| 451 |
+
try:
|
| 452 |
+
src_loaded = get_loaded(src_pos)
|
| 453 |
+
tgt_loaded = get_loaded(tgt_pos)
|
| 454 |
+
except Exception:
|
| 455 |
+
continue
|
| 456 |
+
yield ScannetppFisheyePairSample(
|
| 457 |
+
src_rgb_u8=src_loaded["rgb_u8"],
|
| 458 |
+
tgt_rgb_u8=tgt_loaded["rgb_u8"],
|
| 459 |
+
src_depth_m=src_loaded["depth_m"],
|
| 460 |
+
tgt_depth_m=tgt_loaded["depth_m"],
|
| 461 |
+
src_valid_mask=src_loaded["valid_mask"],
|
| 462 |
+
tgt_valid_mask=tgt_loaded["valid_mask"],
|
| 463 |
+
src_w2c=src_item["w2c"], # type: ignore[arg-type]
|
| 464 |
+
tgt_w2c=frames[tgt_pos]["w2c"], # type: ignore[arg-type]
|
| 465 |
+
src_camera_params=src_loaded["camera_params"],
|
| 466 |
+
tgt_camera_params=tgt_loaded["camera_params"],
|
| 467 |
+
src_idx=int(src_item["idx"]),
|
| 468 |
+
tgt_idx=int(frames[tgt_pos]["idx"]),
|
| 469 |
+
scene=str(scene_id),
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
def __iter__(self):
|
| 473 |
+
worker = torch.utils.data.get_worker_info()
|
| 474 |
+
worker_id = 0 if worker is None else int(worker.id)
|
| 475 |
+
num_workers = 1 if worker is None else int(worker.num_workers)
|
| 476 |
+
rng = random.Random(self.seed + 1009 * self.epoch + 97 * self.ddp_rank + 17 * worker_id)
|
| 477 |
+
specs = list(self.scene_specs)
|
| 478 |
+
if self.shuffle_scene:
|
| 479 |
+
rng.shuffle(specs)
|
| 480 |
+
specs = specs[self.ddp_rank :: max(self.ddp_world_size, 1)]
|
| 481 |
+
specs = specs[worker_id :: num_workers]
|
| 482 |
+
pending: dict[tuple[int, int], list[ScannetppFisheyePairSample]] = {}
|
| 483 |
+
for scene_id, scene_dir in specs:
|
| 484 |
+
for sample in self._iter_scene_pairs(scene_id, scene_dir, rng):
|
| 485 |
+
hw = (int(sample.src_rgb_u8.shape[-2]), int(sample.src_rgb_u8.shape[-1]))
|
| 486 |
+
bucket = pending.setdefault(hw, [])
|
| 487 |
+
bucket.append(sample)
|
| 488 |
+
while len(bucket) >= self.batch_size_hint:
|
| 489 |
+
packed = bucket[: self.batch_size_hint]
|
| 490 |
+
del bucket[: self.batch_size_hint]
|
| 491 |
+
yield _stack_batch(packed)
|
unisharp/datasets/sim_panorama.py
ADDED
|
@@ -0,0 +1,497 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import csv
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import random
|
| 9 |
+
import re
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from torch.utils.data import IterableDataset
|
| 16 |
+
|
| 17 |
+
from unisharp.datasets.panogs import PanOGSSample
|
| 18 |
+
from unisharp import DEFAULT_MAX_DEPTH_M
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
import h5py
|
| 22 |
+
except ImportError:
|
| 23 |
+
h5py = None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
_NUM_RE = re.compile(r"(\d+)(?!.*\d)")
|
| 27 |
+
_SIM_CACHE_VERSION = 6
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _default_dataset_manifest_dir() -> Path:
|
| 31 |
+
repo_root = Path(__file__).resolve().parents[2]
|
| 32 |
+
parent_path = repo_root.parent / "dataset_manifests"
|
| 33 |
+
if parent_path.exists():
|
| 34 |
+
return parent_path
|
| 35 |
+
return repo_root / "dataset_manifests"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _frame_index_from_name(name: str) -> int | None:
    """Return the last run of digits in the stem of ``name`` as an int.

    Returns None when the stem contains no digit.
    """
    found = _NUM_RE.search(Path(name).stem)
    return int(found.group(1)) if found is not None else None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _sim_csv_xyz_to_training_position(x: float, y: float, z: float) -> torch.Tensor:
|
| 46 |
+
return torch.tensor([float(y), -float(z), float(x)], dtype=torch.float32)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class _EquirecToCube:
|
| 50 |
+
def __init__(self, equ_h: int, equ_w: int, face_w: int) -> None:
|
| 51 |
+
self.equ_h = int(equ_h)
|
| 52 |
+
self.equ_w = int(equ_w)
|
| 53 |
+
self.face_w = int(face_w)
|
| 54 |
+
self.grid = self._build_grid()
|
| 55 |
+
rng = torch.linspace(-0.5, 0.5, steps=self.face_w, dtype=torch.float32)
|
| 56 |
+
xx, yy = torch.meshgrid(rng, -rng, indexing="xy")
|
| 57 |
+
self.ray_z = (1.0 / torch.sqrt((2.0 * xx) ** 2 + (2.0 * yy) ** 2 + 1.0)).contiguous()
|
| 58 |
+
|
| 59 |
+
def _build_grid(self) -> torch.Tensor:
|
| 60 |
+
face_w = self.face_w
|
| 61 |
+
rng = torch.linspace(-0.5, 0.5, steps=face_w, dtype=torch.float32)
|
| 62 |
+
grid = torch.stack(torch.meshgrid(rng, -rng, indexing="xy"), dim=-1)
|
| 63 |
+
xyz = torch.zeros((6, face_w, face_w, 3), dtype=torch.float32)
|
| 64 |
+
|
| 65 |
+
xyz[0, :, :, 0] = grid[:, :, 0]
|
| 66 |
+
xyz[0, :, :, 1] = grid[:, :, 1]
|
| 67 |
+
xyz[0, :, :, 2] = 0.5
|
| 68 |
+
|
| 69 |
+
xyz[1, :, :, 2] = torch.flip(grid[:, :, 0], dims=[1])
|
| 70 |
+
xyz[1, :, :, 1] = torch.flip(grid[:, :, 1], dims=[1])
|
| 71 |
+
xyz[1, :, :, 0] = 0.5
|
| 72 |
+
|
| 73 |
+
xyz[2, :, :, 0] = torch.flip(grid[:, :, 0], dims=[1])
|
| 74 |
+
xyz[2, :, :, 1] = torch.flip(grid[:, :, 1], dims=[1])
|
| 75 |
+
xyz[2, :, :, 2] = -0.5
|
| 76 |
+
|
| 77 |
+
xyz[3, :, :, 2] = grid[:, :, 0]
|
| 78 |
+
xyz[3, :, :, 1] = grid[:, :, 1]
|
| 79 |
+
xyz[3, :, :, 0] = -0.5
|
| 80 |
+
|
| 81 |
+
xyz[4, :, :, 0] = torch.flip(grid[:, :, 0], dims=[0])
|
| 82 |
+
xyz[4, :, :, 2] = torch.flip(grid[:, :, 1], dims=[0])
|
| 83 |
+
xyz[4, :, :, 1] = 0.5
|
| 84 |
+
|
| 85 |
+
xyz[5, :, :, 0] = grid[:, :, 0]
|
| 86 |
+
xyz[5, :, :, 2] = grid[:, :, 1]
|
| 87 |
+
xyz[5, :, :, 1] = -0.5
|
| 88 |
+
|
| 89 |
+
xyz = xyz[[4, 2, 3, 0, 1, 5]]
|
| 90 |
+
x = xyz[..., 0]
|
| 91 |
+
y = xyz[..., 1]
|
| 92 |
+
z = xyz[..., 2]
|
| 93 |
+
lon = torch.atan2(x, z)
|
| 94 |
+
c = torch.sqrt(x * x + z * z).clamp(min=1e-8)
|
| 95 |
+
lat = torch.atan2(y, c)
|
| 96 |
+
grid_x = lon / np.pi
|
| 97 |
+
grid_y = (-2.0 * lat / np.pi).clamp(min=-1.0, max=1.0)
|
| 98 |
+
return torch.stack([grid_x, grid_y], dim=-1).contiguous()
|
| 99 |
+
|
| 100 |
+
def run_depth(self, depth_1hw: torch.Tensor) -> torch.Tensor:
|
| 101 |
+
depth = depth_1hw.unsqueeze(0).to(torch.float32)
|
| 102 |
+
if tuple(depth.shape[-2:]) != (self.equ_h, self.equ_w):
|
| 103 |
+
depth = F.interpolate(depth, size=(self.equ_h, self.equ_w), mode="nearest")
|
| 104 |
+
depth_faces = F.grid_sample(
|
| 105 |
+
depth.expand(6, -1, -1, -1),
|
| 106 |
+
self.grid,
|
| 107 |
+
mode="nearest",
|
| 108 |
+
padding_mode="border",
|
| 109 |
+
align_corners=True,
|
| 110 |
+
)
|
| 111 |
+
depth_faces = depth_faces[:, 0] * self.ray_z.to(depth_faces.device, depth_faces.dtype)
|
| 112 |
+
return depth_faces.unsqueeze(-1).to(torch.float32).cpu()
|
| 113 |
+
|
| 114 |
+
def run(self, rgb_chw: torch.Tensor, depth_1hw: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 115 |
+
rgb = rgb_chw.unsqueeze(0).to(torch.float32) / 255.0
|
| 116 |
+
if tuple(rgb.shape[-2:]) != (self.equ_h, self.equ_w):
|
| 117 |
+
rgb = F.interpolate(rgb, size=(self.equ_h, self.equ_w), mode="bilinear", align_corners=True)
|
| 118 |
+
rgb_faces = F.grid_sample(
|
| 119 |
+
rgb.expand(6, -1, -1, -1),
|
| 120 |
+
self.grid,
|
| 121 |
+
mode="bilinear",
|
| 122 |
+
padding_mode="border",
|
| 123 |
+
align_corners=True,
|
| 124 |
+
)
|
| 125 |
+
cube_rgb = (rgb_faces.permute(0, 2, 3, 1).clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)
|
| 126 |
+
cube_depth = self.run_depth(depth_1hw)
|
| 127 |
+
return cube_rgb.cpu(), cube_depth
|
| 128 |
+
|
| 129 |
+
def run_rgb(self, rgb_chw: torch.Tensor) -> torch.Tensor:
|
| 130 |
+
rgb = rgb_chw.unsqueeze(0).to(torch.float32) / 255.0
|
| 131 |
+
if tuple(rgb.shape[-2:]) != (self.equ_h, self.equ_w):
|
| 132 |
+
rgb = F.interpolate(rgb, size=(self.equ_h, self.equ_w), mode="bilinear", align_corners=True)
|
| 133 |
+
rgb_faces = F.grid_sample(
|
| 134 |
+
rgb.expand(6, -1, -1, -1),
|
| 135 |
+
self.grid,
|
| 136 |
+
mode="bilinear",
|
| 137 |
+
padding_mode="border",
|
| 138 |
+
align_corners=True,
|
| 139 |
+
)
|
| 140 |
+
return (rgb_faces.permute(0, 2, 3, 1).clamp(0.0, 1.0) * 255.0).round().to(torch.uint8).cpu()
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@dataclass(frozen=True)
|
| 144 |
+
class _SimFrame:
|
| 145 |
+
frame_idx: int
|
| 146 |
+
rgb_path: Path
|
| 147 |
+
depth_path: Path
|
| 148 |
+
position_xyz: torch.Tensor
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class SimPanoramaDataset(IterableDataset):
|
| 152 |
+
def __init__(
    self,
    root: Path,
    pose_root: Path,
    scene_names: list[str] | None = None,
    scene_list_file: Path | None = None,
    position_scale: float = 0.01,
    max_index_gap: int = 10,
    pair_max_translation_m: float = 0.5,
    pair_min_depth_overlap: float = 0.6,
    pair_overlap_margin: float = 1.05,
    pairs_per_chunk: int = 15,
    chunk_size: int = 30,
    shuffle_scene: bool = True,
    ddp_rank: int = 0,
    ddp_world_size: int = 1,
    depth_max_m: float = DEFAULT_MAX_DEPTH_M,
    far_depth_invalid_m: float = 30.0,
    far_depth_invalid_max_frac: float = 1.0,
    max_long_edge: int = 0,
    seed: int = 0,
) -> None:
    """Configure the dataset and resolve which scenes it covers.

    Scene selection: with ``scene_list_file`` the non-blank lines of that file
    define the scenes (in file order), filtered down to ``scene_names`` when
    those are also given; without it ``scene_names`` is used as is. Blank
    names are dropped.

    Side effect: creates the ``sim_cache`` directory under the default
    dataset-manifest directory.

    Raises:
        FileNotFoundError: if ``scene_list_file`` is given but does not exist.
        ValueError: if no scene remains after selection.
    """
    super().__init__()
    self.root = Path(root)
    self.pose_root = Path(pose_root)
    self.scene_list_file = None if scene_list_file is None else Path(scene_list_file)

    requested_scene_names: list[str] = []
    for raw_name in scene_names or []:
        cleaned = str(raw_name).strip()
        if cleaned:
            requested_scene_names.append(cleaned)

    if self.scene_list_file is None:
        self.scene_names = requested_scene_names
    else:
        if not self.scene_list_file.exists():
            raise FileNotFoundError(self.scene_list_file)
        manifest_lines = self.scene_list_file.read_text(encoding="utf-8").splitlines()
        manifest_scene_names = [line.strip() for line in manifest_lines if line.strip()]
        if requested_scene_names:
            requested = set(requested_scene_names)
            self.scene_names = [name for name in manifest_scene_names if name in requested]
        else:
            self.scene_names = manifest_scene_names
    if not self.scene_names:
        raise ValueError("SimPanoramaDataset requires scene_names or scene_list_file.")

    # Pair-sampling parameters.
    self.position_scale = float(position_scale)
    self.max_index_gap = int(max_index_gap)
    self.pair_max_translation_m = float(pair_max_translation_m)
    self.pair_min_depth_overlap = float(pair_min_depth_overlap)
    self.pair_overlap_margin = float(pair_overlap_margin)
    self.pairs_per_chunk = int(pairs_per_chunk)
    self.chunk_size = int(chunk_size)

    # Sharding and randomness.
    self.shuffle_scene = bool(shuffle_scene)
    self.ddp_rank = int(ddp_rank)
    self.ddp_world_size = int(ddp_world_size)
    self.seed = int(seed)
    self.epoch = 0

    # Depth handling and optional down-scaling (negative limits become 0).
    self.depth_max_m = float(depth_max_m)
    self.far_depth_invalid_m = float(far_depth_invalid_m)
    self.far_depth_invalid_max_frac = float(far_depth_invalid_max_frac)
    self.max_long_edge = max(int(max_long_edge), 0)

    # On-disk and in-memory frame index caches.
    self.cache_dir = _default_dataset_manifest_dir() / "sim_cache"
    self.cache_dir.mkdir(parents=True, exist_ok=True)
    self._scene_frames_cache: dict[str, list[_SimFrame]] = {}
|
| 215 |
+
|
| 216 |
+
def set_epoch(self, epoch: int) -> None:
    """Store ``epoch`` (coerced to int) on the dataset as ``self.epoch``."""
    epoch_index = int(epoch)
    self.epoch = epoch_index
|
| 218 |
+
|
| 219 |
+
@staticmethod
|
| 220 |
+
def _is_depth_path(path: Path) -> bool:
|
| 221 |
+
tokens = [part.lower() for part in path.parts]
|
| 222 |
+
name = path.name.lower()
|
| 223 |
+
return ("depth" in name) or any("depth" in token for token in tokens)
|
| 224 |
+
|
| 225 |
+
@staticmethod
|
| 226 |
+
def _is_image_path(path: Path) -> bool:
|
| 227 |
+
return path.suffix.lower() in (".png", ".jpg", ".jpeg", ".webp")
|
| 228 |
+
|
| 229 |
+
@staticmethod
def _load_rgb(path: Path) -> torch.Tensor:
    """Read an image file and return it as a contiguous uint8 CHW tensor in RGB."""
    with Image.open(path) as handle:
        pixels_hwc = np.array(handle.convert("RGB"), dtype=np.uint8)
    channels_first = torch.from_numpy(pixels_hwc).permute(2, 0, 1)
    return channels_first.contiguous()
|
| 235 |
+
|
| 236 |
+
@staticmethod
def _image_hw(path: Path) -> tuple[int, int]:
    """Return ``(height, width)`` of an image file.

    Only ``Image.size`` is read; the method itself does not request the pixel
    data.
    """
    with Image.open(path) as handle:
        size_wh = handle.size
    return int(size_wh[1]), int(size_wh[0])
|
| 241 |
+
|
| 242 |
+
def _load_depth(self, path: Path) -> torch.Tensor:
|
| 243 |
+
suffix = path.suffix.lower()
|
| 244 |
+
if suffix == ".npy":
|
| 245 |
+
dep = np.load(path)
|
| 246 |
+
elif suffix == ".npz":
|
| 247 |
+
payload = np.load(path)
|
| 248 |
+
key = "depth" if "depth" in payload.files else payload.files[0]
|
| 249 |
+
dep = payload[key]
|
| 250 |
+
elif suffix in (".h5", ".hdf5"):
|
| 251 |
+
if h5py is None:
|
| 252 |
+
raise ImportError("h5py is required to read sim .h5 depth files but is not installed.")
|
| 253 |
+
with h5py.File(path, "r") as f:
|
| 254 |
+
keys = list(f.keys())
|
| 255 |
+
if not keys:
|
| 256 |
+
raise RuntimeError(f"Empty sim depth file: {path}")
|
| 257 |
+
dep = f[keys[0]][()]
|
| 258 |
+
else:
|
| 259 |
+
with Image.open(path) as img:
|
| 260 |
+
dep = np.asarray(img)
|
| 261 |
+
dep = dep.astype(np.float32)
|
| 262 |
+
if dep.ndim == 3:
|
| 263 |
+
dep = dep[..., 0]
|
| 264 |
+
dep[~np.isfinite(dep)] = 0.0
|
| 265 |
+
if self.far_depth_invalid_m > 0.0:
|
| 266 |
+
far = dep > self.far_depth_invalid_m
|
| 267 |
+
if 0.0 < float(far.mean()) <= self.far_depth_invalid_max_frac:
|
| 268 |
+
dep[far] = 0.0
|
| 269 |
+
dep = np.clip(dep, a_min=0.0, a_max=self.depth_max_m)
|
| 270 |
+
return torch.from_numpy(dep).unsqueeze(0)
|
| 271 |
+
|
| 272 |
+
def _resize_erp_if_needed(self, rgb: torch.Tensor, depth: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 273 |
+
if self.max_long_edge <= 0:
|
| 274 |
+
return rgb, depth
|
| 275 |
+
h = int(rgb.shape[-2])
|
| 276 |
+
w = int(rgb.shape[-1])
|
| 277 |
+
long_edge = max(h, w)
|
| 278 |
+
if long_edge <= self.max_long_edge:
|
| 279 |
+
return rgb, depth
|
| 280 |
+
scale = float(self.max_long_edge) / float(long_edge)
|
| 281 |
+
new_h = max(2, int(round(float(h) * scale)))
|
| 282 |
+
new_w = max(2, int(round(float(w) * scale)))
|
| 283 |
+
rgb_f = rgb.unsqueeze(0).to(dtype=torch.float32)
|
| 284 |
+
rgb_resized = F.interpolate(rgb_f, size=(new_h, new_w), mode="bilinear", align_corners=False)
|
| 285 |
+
rgb_out = rgb_resized[0].round().clamp(0.0, 255.0).to(dtype=torch.uint8).contiguous()
|
| 286 |
+
depth_f = depth.unsqueeze(0).to(dtype=torch.float32)
|
| 287 |
+
depth_out = F.interpolate(depth_f, size=(new_h, new_w), mode="nearest")[0].contiguous()
|
| 288 |
+
return rgb_out, depth_out
|
| 289 |
+
|
| 290 |
+
def _pose_csv_for_scene(self, scene_name: str) -> Path:
|
| 291 |
+
direct = self.pose_root / f"{scene_name}.csv"
|
| 292 |
+
if direct.exists():
|
| 293 |
+
return direct
|
| 294 |
+
matches = sorted(self.pose_root.glob(f"*{scene_name}*.csv"))
|
| 295 |
+
if matches:
|
| 296 |
+
return matches[0]
|
| 297 |
+
raise FileNotFoundError(f"No pose csv found for sim scene={scene_name} under {self.pose_root}")
|
| 298 |
+
|
| 299 |
+
def _parse_pose_csv(self, csv_path: Path) -> list[tuple[int, torch.Tensor]]:
    """Parse a pose CSV into ``(frame_index, position)`` tuples.

    Column names are matched case-insensitively after stripping whitespace.
    The frame index comes from the first non-empty frame-like column: the last
    run of digits in its value, else the value parsed as a number; the row
    number is used when neither works. The position comes from the x/y/z
    alias columns; if any axis is missing, the first three values in the row
    that parse as floats are used instead. Positions are remapped to the
    training convention and multiplied by ``self.position_scale``.

    Raises:
        RuntimeError: if the CSV has no data rows.
        ValueError: if a row yields fewer than three numeric values in the
            fallback path.
    """
    with csv_path.open("r", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    if not rows:
        raise RuntimeError(f"Empty sim pose csv: {csv_path}")

    frame_columns = ("frame", "frame_idx", "idx", "index", "id", "image", "filename", "name")
    axis_columns = (
        ("x", "tx", "pos_x", "world_x"),
        ("y", "ty", "pos_y", "world_y"),
        ("z", "tz", "pos_z", "world_z"),
    )
    poses: list[tuple[int, torch.Tensor]] = []
    for row_idx, row in enumerate(rows):
        lower = {str(k).strip().lower(): v for k, v in row.items()}

        frame_val = None
        for column in frame_columns:
            if column not in lower or not str(lower[column]).strip():
                continue
            text = str(lower[column])
            frame_val = _frame_index_from_name(text)
            if frame_val is None:
                try:
                    frame_val = int(float(text.strip()))
                except Exception:
                    frame_val = None
            # Only the first populated frame column is consulted.
            break

        x, y, z = [
            next((lower[name] for name in lower if name in aliases), None)
            for aliases in axis_columns
        ]
        if x is None or y is None or z is None:
            numeric_vals = []
            for val in row.values():
                try:
                    numeric_vals.append(float(str(val).strip()))
                except Exception:
                    continue
            if len(numeric_vals) < 3:
                raise ValueError(f"Failed to parse xyz from sim csv row: {row}")
            x, y, z = numeric_vals[:3]

        pos = _sim_csv_xyz_to_training_position(float(x), float(y), float(z)) * self.position_scale
        frame_index = row_idx if frame_val is None else frame_val
        poses.append((int(frame_index), pos))
    return poses
|
| 333 |
+
|
| 334 |
+
def _scan_scene_frames(self, scene_name: str) -> list[_SimFrame]:
    """Walk a scene directory and pair RGB, depth and pose entries by frame index.

    Every file under ``self.root / scene_name`` whose name contains a number
    is classified as depth (a depth-like path with a supported depth suffix)
    or as an image. The first file seen for a frame index wins. A frame is
    emitted, in pose-CSV order, for each pose entry that has both an image and
    a depth file.

    Raises:
        FileNotFoundError: if the scene directory does not exist (and, via the
            pose lookup, if no pose CSV is found).
    """
    scene_dir = self.root / scene_name
    if not scene_dir.exists():
        raise FileNotFoundError(scene_dir)

    depth_suffixes = (".png", ".npy", ".npz", ".exr", ".h5", ".hdf5")
    image_map: dict[int, Path] = {}
    depth_map: dict[int, Path] = {}
    for path in scene_dir.rglob("*"):
        if not path.is_file():
            continue
        idx = _frame_index_from_name(path.name)
        if idx is None:
            continue
        if self._is_depth_path(path) and path.suffix.lower() in depth_suffixes:
            if idx not in depth_map:
                depth_map[idx] = path
        elif self._is_image_path(path):
            if idx not in image_map:
                image_map[idx] = path

    pose_entries = self._parse_pose_csv(self._pose_csv_for_scene(scene_name))
    frames: list[_SimFrame] = []
    for frame_idx, pos in pose_entries:
        key = int(frame_idx)
        rgb_path = image_map.get(key)
        depth_path = depth_map.get(key)
        if rgb_path is None or depth_path is None:
            continue
        frames.append(_SimFrame(frame_idx=key, rgb_path=rgb_path, depth_path=depth_path, position_xyz=pos))
    return frames
|
| 358 |
+
|
| 359 |
+
@staticmethod
|
| 360 |
+
def _atomic_torch_save(path: Path, payload: object) -> None:
|
| 361 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 362 |
+
tmp_path = path.with_suffix(path.suffix + f".tmp.{os.getpid()}")
|
| 363 |
+
torch.save(payload, tmp_path)
|
| 364 |
+
os.replace(tmp_path, path)
|
| 365 |
+
|
| 366 |
+
def _scene_index_cache_path(self, scene_name: str) -> Path:
    """Return the cache file path for a scene's frame index.

    Slashes in the scene name are replaced by ``__``; the position scale and
    the cache version are part of the file name.
    """
    scene_key = "__".join(scene_name.split("/"))
    file_name = f"{scene_key}_ps{self.position_scale:g}_frames_v{_SIM_CACHE_VERSION}.pt"
    return self.cache_dir / file_name
|
| 369 |
+
|
| 370 |
+
def _load_or_build_scene_frames(self, scene_name: str) -> list[_SimFrame]:
    """Return the frame index of a scene, using memory and disk caches.

    Lookup order: the in-memory ``self._scene_frames_cache``, then the on-disk
    cache file, then a fresh directory scan. A fresh scan is written back to
    disk on a best-effort basis.

    Args:
        scene_name: Scene folder, relative to the dataset root.

    Returns:
        The list of ``_SimFrame`` entries for the scene.

    Raises:
        Exception: Whatever ``_scan_scene_frames`` raises when a scan is needed.
    """
    cached = self._scene_frames_cache.get(scene_name)
    if cached is not None:
        return cached

    cache_path = self._scene_index_cache_path(scene_name)
    frames: list[_SimFrame] | None = None
    if cache_path.exists():
        try:
            # NOTE(review): torch.load unpickles the file; the cache directory
            # must only contain files written by this code.
            payload = torch.load(cache_path, map_location="cpu")
            frames = [
                _SimFrame(
                    frame_idx=int(item["frame_idx"]),
                    rgb_path=Path(str(item["rgb_path"])),
                    depth_path=Path(str(item["depth_path"])),
                    position_xyz=torch.tensor(item["position_xyz"], dtype=torch.float32),
                )
                for item in payload["frames"]
            ]
        except Exception:
            # Unreadable or malformed cache file: fall through and rebuild it.
            frames = None

    if frames is None:
        frames = self._scan_scene_frames(scene_name)
        payload = {
            "scene": scene_name,
            "frames": [
                {
                    "frame_idx": int(frame.frame_idx),
                    "rgb_path": str(frame.rgb_path),
                    "depth_path": str(frame.depth_path),
                    "position_xyz": frame.position_xyz.tolist(),
                }
                for frame in frames
            ],
        }
        try:
            self._atomic_torch_save(cache_path, payload)
        except (OSError, RuntimeError) as exc:
            # The scan itself succeeded; an unwritable cache directory must not
            # make the scene unusable. Warn and continue without the disk cache.
            import warnings

            warnings.warn(
                f"Could not write scene index cache {cache_path}: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )

    self._scene_frames_cache[scene_name] = frames
    return frames
|
| 420 |
+
|
| 421 |
+
def _random_chunk_pairs(self, chunk: list[_SimFrame], rng: random.Random) -> list[tuple[int, int]]:
|
| 422 |
+
if len(chunk) < self.chunk_size:
|
| 423 |
+
return []
|
| 424 |
+
indices = list(range(len(chunk)))
|
| 425 |
+
rng.shuffle(indices)
|
| 426 |
+
max_pairs = min(self.pairs_per_chunk, len(indices) // 2)
|
| 427 |
+
return [(indices[2 * i], indices[2 * i + 1]) for i in range(max_pairs)]
|
| 428 |
+
|
| 429 |
+
def __iter__(self):
    """Yield ``PanOGSSample`` source/target pairs for this rank and worker.

    Scenes are optionally shuffled with a seed derived from ``self.seed`` and
    ``self.epoch``. Each scene's frames are cut into consecutive chunks of
    ``self.chunk_size`` (a trailing short chunk is dropped), and random pairs
    are drawn per chunk. Pairs are numbered globally and dealt round-robin to
    ``ddp_world_size * num_workers`` shards. Scenes that fail to load and
    chunks whose converter cannot be built are skipped silently.

    Both rotations in the yielded sample are the identity; the translations
    are the frames' ``position_xyz``.
    """
    ordered_scenes = list(self.scene_names)
    scene_shuffler = random.Random(self.seed + self.epoch)
    if self.shuffle_scene:
        scene_shuffler.shuffle(ordered_scenes)

    info = torch.utils.data.get_worker_info()
    if info is None:
        num_workers, worker_id = 1, 0
    else:
        num_workers, worker_id = info.num_workers, info.id
    total_shards = max(1, self.ddp_world_size * num_workers)
    shard_id = self.ddp_rank * num_workers + worker_id
    pair_counter = 0

    for scene_pos, scene_name in enumerate(ordered_scenes):
        try:
            frames = self._load_or_build_scene_frames(scene_name)
        except Exception:
            continue
        if len(frames) < self.chunk_size:
            continue

        for start in range(0, len(frames), self.chunk_size):
            chunk = frames[start : start + self.chunk_size]
            if len(chunk) < self.chunk_size:
                break

            try:
                # The converter is sized from the first frame of the chunk.
                equ_h, equ_w = self._image_hw(chunk[0].rgb_path)
                long_edge = max(equ_h, equ_w)
                if self.max_long_edge > 0 and long_edge > self.max_long_edge:
                    shrink = float(self.max_long_edge) / float(long_edge)
                    equ_h = max(2, int(round(float(equ_h) * shrink)))
                    equ_w = max(2, int(round(float(equ_w) * shrink)))
                converter = _EquirecToCube(equ_h=equ_h, equ_w=equ_w, face_w=max(1, equ_h // 2))
            except Exception:
                continue

            def read_frame(local_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                # Load one frame of the current chunk as ERP and cube-map tensors.
                entry = chunk[local_idx]
                erp_rgb = self._load_rgb(entry.rgb_path)
                erp_depth = self._load_depth(entry.depth_path)
                erp_rgb, erp_depth = self._resize_erp_if_needed(erp_rgb, erp_depth)
                cube_rgb, cube_depth = converter.run(erp_rgb, erp_depth)
                return erp_rgb, erp_depth, cube_rgb, cube_depth

            # Per-chunk generator so every shard derives the same pairs.
            chunk_rng = random.Random(
                self.seed + self.epoch * 1000003 + scene_pos * 1009 + start
            )
            for src_local, tgt_local in self._random_chunk_pairs(chunk, chunk_rng):
                unit = pair_counter
                pair_counter += 1
                if unit % total_shards != shard_id:
                    continue

                src_rgb, src_depth, src_cube_rgb, src_cube_depth = read_frame(src_local)
                tgt_rgb, tgt_depth, tgt_cube_rgb, tgt_cube_depth = read_frame(tgt_local)
                src_frame = chunk[src_local]
                tgt_frame = chunk[tgt_local]
                yield PanOGSSample(
                    src_erp_rgb_u8=src_rgb,
                    tgt_erp_rgb_u8=tgt_rgb,
                    src_erp_depth_m=src_depth,
                    tgt_erp_depth_m=tgt_depth,
                    src_cube_rgb_u8=src_cube_rgb,
                    tgt_cube_rgb_u8=tgt_cube_rgb,
                    src_cube_depth_m=src_cube_depth,
                    tgt_cube_depth_m=tgt_cube_depth,
                    src_R=torch.eye(3, dtype=torch.float32),
                    src_t=src_frame.position_xyz.clone(),
                    tgt_R=torch.eye(3, dtype=torch.float32),
                    tgt_t=tgt_frame.position_xyz.clone(),
                    src_idx=int(src_frame.frame_idx),
                    tgt_idx=int(tgt_frame.frame_idx),
                    scene=str(scene_name),
                )
|
unisharp/datasets/wildrgbd.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from torch.utils.data import IterableDataset
|
| 13 |
+
|
| 14 |
+
from unisharp.datasets.pair_sampling import project_overlap_ratio, resize_k3_align_corners_false, resize_rgb_u8_chw_high_quality
|
| 15 |
+
from unisharp import DEFAULT_MAX_DEPTH_M
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass(frozen=True)
class WildRGBDPairSample:
    """Immutable source/target RGB-D frame pair from one WildRGBD scene.

    ``WildRGBDDataset`` yields one instance per pair. ``wildrgbd_collate``
    reuses the same class for a batch, in which case the tensor fields carry a
    leading batch dimension and the index/scene fields hold lists.
    """

    # Colour images, channel-first uint8 as produced by ``_load_rgb_u8``.
    src_rgb_u8: torch.Tensor
    tgt_rgb_u8: torch.Tensor
    # Depth maps with a leading singleton channel, float32, clipped to the
    # dataset's ``depth_max_m``.
    src_depth_m: torch.Tensor
    tgt_depth_m: torch.Tensor
    # World-to-camera matrices (inverse of the poses in ``cam_poses.txt``).
    src_w2c: torch.Tensor
    tgt_w2c: torch.Tensor
    # 3x3 intrinsics, rescaled when the dataset resizes its output.
    src_intrinsics: torch.Tensor
    tgt_intrinsics: torch.Tensor
    # Frame ids of the two views and the "<parent>/<scene>" name.
    src_idx: int
    tgt_idx: int
    scene: str
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class WildRGBDDataset(IterableDataset):
|
| 34 |
+
|
| 35 |
+
def __init__(
    self,
    root: Path | None = None,
    scene_list_file: Path | None = None,
    split: str = "train",
    min_frame_gap: int = 1,
    max_frame_gap: int = 32,
    pair_max_translation_m: float = 0.5,
    pair_min_overlap: float = 0.6,
    pair_overlap_sample_h: int = 32,
    pair_overlap_sample_w: int = 56,
    pair_max_tries: int = 32,
    output_h: int | None = None,
    output_w: int | None = None,
    shuffle_scene: bool = True,
    shuffle_frame: bool = True,
    ddp_rank: int = 0,
    ddp_world_size: int = 1,
    roots: list[Path] | None = None,
    depth_max_m: float = DEFAULT_MAX_DEPTH_M,
    seed: int = 0,
    verify_manifest_paths: bool = False,
) -> None:
    """Configure pair sampling and discover the scene folders.

    Scene folders come either from ``scene_list_file`` (one path per
    non-blank line) or, when no manifest is given, from every sub-directory
    of ``<root>/<split>`` for each configured root.

    Args:
        root: Single dataset root; used only when ``roots`` is ``None``.
        scene_list_file: Optional manifest of scene directories.
        split: Sub-folder of each root that holds the scenes.
        min_frame_gap: Smallest allowed frame-id distance within a pair.
        max_frame_gap: Largest allowed frame-id distance within a pair.
        pair_max_translation_m: Largest allowed camera-centre distance.
        pair_min_overlap: Minimum mean of the two directed overlap ratios.
        pair_overlap_sample_h: ``sample_h`` passed to the overlap check.
        pair_overlap_sample_w: ``sample_w`` passed to the overlap check.
        pair_max_tries: Maximum number of candidates tested per source frame.
        output_h: Output height; resizing happens only if ``output_w`` is set too.
        output_w: Output width; resizing happens only if ``output_h`` is set too.
        shuffle_scene: Shuffle the scene order each epoch.
        shuffle_frame: Shuffle the source-frame order within a scene.
        ddp_rank: Rank of this process, used for sharding.
        ddp_world_size: Number of processes, used for sharding.
        roots: Several dataset roots; takes precedence over ``root``.
        depth_max_m: Upper clip value applied to loaded depth.
        seed: Base seed of all random generators.
        verify_manifest_paths: Drop manifest entries that are not directories.

    Raises:
        ValueError: If neither ``root`` nor a non-empty ``roots`` is given.
        FileNotFoundError: If the manifest or a split directory is missing.
        RuntimeError: If no scene folder is found.
    """
    super().__init__()
    self.root = root
    self.split = split

    # Pair-selection settings.
    self.min_frame_gap = int(min_frame_gap)
    self.max_frame_gap = int(max_frame_gap)
    self.pair_max_translation_m = float(pair_max_translation_m)
    self.pair_min_overlap = float(pair_min_overlap)
    self.pair_overlap_sample_h = int(pair_overlap_sample_h)
    self.pair_overlap_sample_w = int(pair_overlap_sample_w)
    self.pair_max_tries = int(pair_max_tries)

    # Output size, ordering and sharding settings.
    self.output_h = None if output_h is None else int(output_h)
    self.output_w = None if output_w is None else int(output_w)
    self.shuffle_scene = bool(shuffle_scene)
    self.shuffle_frame = bool(shuffle_frame)
    self.ddp_rank = int(ddp_rank)
    self.ddp_world_size = int(ddp_world_size)
    self.depth_max_m = float(depth_max_m)
    self.seed = int(seed)
    self.epoch = 0
    self.verify_manifest_paths = bool(verify_manifest_paths)

    if roots is not None:
        self.roots = [Path(p) for p in roots]
    elif root is not None:
        self.roots = [Path(root)]
    else:
        self.roots = []
    if not self.roots:
        raise ValueError("WildRGBDDataset requires at least one root path.")

    self.scene_dirs: list[Path] = []
    self.scene_list_file = None if scene_list_file is None else Path(scene_list_file)
    if self.scene_list_file is None:
        # No manifest: every directory under <root>/<split> is a scene.
        for ds_root in self.roots:
            split_dir = ds_root / self.split
            if not split_dir.exists():
                raise FileNotFoundError(split_dir)
            self.scene_dirs.extend(sorted(p for p in split_dir.iterdir() if p.is_dir()))
    else:
        if not self.scene_list_file.exists():
            raise FileNotFoundError(self.scene_list_file)
        manifest_text = self.scene_list_file.read_text(encoding="utf-8")
        for entry in (raw.strip() for raw in manifest_text.splitlines()):
            if not entry:
                continue
            candidate = Path(entry)
            if self.verify_manifest_paths and not candidate.is_dir():
                continue
            self.scene_dirs.append(candidate)

    if not self.scene_dirs:
        raise RuntimeError("No scene folders found in the configured WildRGBD roots.")
|
| 101 |
+
|
| 102 |
+
def set_epoch(self, epoch: int) -> None:
    """Store the epoch number that ``__iter__`` mixes into its random seeds."""
    self.epoch = int(epoch)
|
| 104 |
+
|
| 105 |
+
@staticmethod
|
| 106 |
+
def _load_scene_pose_and_k(scene_dir: Path) -> tuple[np.ndarray, dict[int, np.ndarray], torch.Tensor]:
|
| 107 |
+
metadata_path = scene_dir / "metadata"
|
| 108 |
+
with metadata_path.open("r", encoding="utf-8") as f:
|
| 109 |
+
meta = json.load(f)
|
| 110 |
+
k_raw = np.asarray(meta["K"], dtype=np.float32).reshape(3, 3).T
|
| 111 |
+
k = torch.from_numpy(k_raw.copy()).to(torch.float32)
|
| 112 |
+
|
| 113 |
+
pose_path = scene_dir / "cam_poses.txt"
|
| 114 |
+
pose_rows = np.genfromtxt(str(pose_path), dtype=np.float32)
|
| 115 |
+
if pose_rows.ndim == 1:
|
| 116 |
+
pose_rows = pose_rows[None, :]
|
| 117 |
+
if pose_rows.shape[1] < 17:
|
| 118 |
+
raise ValueError(f"Bad cam_poses.txt shape={pose_rows.shape} at {pose_path}")
|
| 119 |
+
frame_ids = pose_rows[:, 0].astype(np.int64)
|
| 120 |
+
c2w = pose_rows[:, 1:17].reshape(-1, 4, 4).astype(np.float32)
|
| 121 |
+
w2c = np.linalg.inv(c2w).astype(np.float32)
|
| 122 |
+
w2c_map = {int(fid): w2c[i] for i, fid in enumerate(frame_ids.tolist())}
|
| 123 |
+
return frame_ids, w2c_map, k
|
| 124 |
+
|
| 125 |
+
@staticmethod
|
| 126 |
+
def _collect_frame_ids(folder: Path) -> set[int]:
|
| 127 |
+
ids: set[int] = set()
|
| 128 |
+
if not folder.exists():
|
| 129 |
+
return ids
|
| 130 |
+
for p in folder.iterdir():
|
| 131 |
+
if not p.is_file():
|
| 132 |
+
continue
|
| 133 |
+
if p.suffix.lower() not in (".png", ".jpg", ".jpeg"):
|
| 134 |
+
continue
|
| 135 |
+
try:
|
| 136 |
+
ids.add(int(p.stem))
|
| 137 |
+
except ValueError:
|
| 138 |
+
continue
|
| 139 |
+
return ids
|
| 140 |
+
|
| 141 |
+
@staticmethod
|
| 142 |
+
def _resolve_img_path(folder: Path, idx: int) -> Path:
|
| 143 |
+
for ext in (".png", ".jpg", ".jpeg"):
|
| 144 |
+
p = folder / f"{idx:05d}{ext}"
|
| 145 |
+
if p.exists():
|
| 146 |
+
return p
|
| 147 |
+
raise FileNotFoundError(folder / f"{idx:05d}.png")
|
| 148 |
+
|
| 149 |
+
@staticmethod
def _load_rgb_u8(path: Path) -> torch.Tensor:
    """Load an image file as a channel-first uint8 RGB tensor.

    Args:
        path: Image file readable by PIL.

    Returns:
        A contiguous uint8 tensor with the channel axis first (the HxWx3
        array is permuted to 3xHxW).
    """
    # Open through a context manager so the file handle is released even for
    # formats PIL keeps open after decoding.
    with Image.open(path) as img:
        arr = np.asarray(img.convert("RGB"), dtype=np.uint8).copy()
    return torch.from_numpy(arr).permute(2, 0, 1).contiguous()
|
| 154 |
+
|
| 155 |
+
def _load_depth_m(self, depth_path: Path) -> torch.Tensor:
    """Load a single-channel depth image as a float32 tensor.

    Values are divided by 1000 when the largest value exceeds 200 (presumably
    to convert millimetres to metres - confirm against the dataset format).
    Non-finite values become 0 and the result is clipped to
    ``[0, self.depth_max_m]``.

    Args:
        depth_path: Depth image readable by PIL.

    Returns:
        A float32 tensor with a leading singleton channel axis (1xHxW).

    Raises:
        ValueError: If the image is not two-dimensional.
    """
    # Open through a context manager so the file handle is released promptly.
    with Image.open(depth_path) as img:
        dep = np.asarray(img)
    if dep.ndim != 2:
        raise ValueError(f"Expected single-channel depth at {depth_path}, got {dep.shape}")
    depth = dep.astype(np.float32)
    if float(np.nanmax(depth)) > 200.0:
        depth = depth / 1000.0
    depth[~np.isfinite(depth)] = 0.0
    depth = np.clip(depth, a_min=0.0, a_max=self.depth_max_m)
    return torch.from_numpy(depth).unsqueeze(0).to(torch.float32)
|
| 165 |
+
|
| 166 |
+
@staticmethod
|
| 167 |
+
def _scene_name(scene_dir: Path) -> str:
|
| 168 |
+
parent = scene_dir.parent.parent.name if scene_dir.parent.name == "scenes" else scene_dir.parent.name
|
| 169 |
+
return f"{parent}/{scene_dir.name}"
|
| 170 |
+
|
| 171 |
+
def _sample_target_for_src(
    self,
    src_idx: int,
    valid_ids: list[int],
    w2c_map: dict[int, np.ndarray],
    intr: torch.Tensor,
    h: int,
    w: int,
    rng: random.Random,
) -> int | None:
    """Pick a target frame for ``src_idx`` that overlaps it sufficiently.

    Candidates are frames whose id distance to the source lies within
    ``[min_frame_gap, max_frame_gap]`` and whose camera centre is at most
    ``pair_max_translation_m`` away. They are shuffled with ``rng`` and up to
    ``pair_max_tries`` of them are tested; the first whose mean of the two
    directed ``project_overlap_ratio`` values reaches ``pair_min_overlap`` is
    returned.

    Args:
        src_idx: Frame id of the source view.
        valid_ids: Frame ids that may serve as target.
        w2c_map: World-to-camera matrix per frame id.
        intr: Intrinsics used for both views.
        h: Image height passed to the overlap check.
        w: Image width passed to the overlap check.
        rng: Random generator; consumed only when candidates exist.

    Returns:
        The chosen target frame id, or ``None`` if no candidate qualifies.
    """
    src_id = int(src_idx)

    def pose_of(frame_id: int) -> torch.Tensor:
        # World-to-camera matrix of a frame as a float32 tensor.
        return torch.from_numpy(w2c_map[int(frame_id)]).to(torch.float32)

    def center_of(pose: torch.Tensor) -> torch.Tensor:
        # Camera centre in world coordinates.
        return torch.linalg.inv(pose)[:3, 3]

    src_pose = pose_of(src_id)
    src_center = center_of(src_pose)

    candidates: list[int] = []
    for other in valid_ids:
        other_id = int(other)
        if other_id == src_id:
            continue
        gap = abs(other_id - src_id)
        if not (self.min_frame_gap <= gap <= self.max_frame_gap):
            continue
        distance = torch.norm(center_of(pose_of(other_id)) - src_center, p=2).item()
        if distance > self.pair_max_translation_m:
            continue
        candidates.append(other_id)
    if not candidates:
        return None

    rng.shuffle(candidates)
    tries = min(self.pair_max_tries, len(candidates))
    k = intr.to(torch.float32)
    shared = {
        "src_k": k,
        "tgt_k": k,
        "h": h,
        "w": w,
        "sample_h": self.pair_overlap_sample_h,
        "sample_w": self.pair_overlap_sample_w,
    }
    for cand in candidates[:tries]:
        cand_pose = pose_of(cand)
        forward = project_overlap_ratio(src_w2c=src_pose, tgt_w2c=cand_pose, **shared)
        backward = project_overlap_ratio(src_w2c=cand_pose, tgt_w2c=src_pose, **shared)
        if 0.5 * (forward + backward) >= self.pair_min_overlap:
            return int(cand)
    return None
|
| 228 |
+
|
| 229 |
+
def __iter__(self):
    """Yield ``WildRGBDPairSample`` pairs for this rank and worker.

    Scenes are optionally shuffled with a seed derived from ``self.seed`` and
    ``self.epoch``. Within a scene, frames that have a pose, an RGB file and a
    depth file are the source candidates; they are numbered globally and
    dealt round-robin to ``ddp_world_size * num_workers`` shards. For each
    source owned by this shard a target is sampled, both frames are loaded
    and, if an output size is configured, resized together with their
    intrinsics. Scenes or frames that fail to load are skipped silently.
    """
    scene_queue = list(self.scene_dirs)
    order_rng = random.Random(self.seed + self.epoch)
    if self.shuffle_scene:
        order_rng.shuffle(scene_queue)

    info = torch.utils.data.get_worker_info()
    if info is None:
        num_workers, worker_id = 1, 0
    else:
        num_workers, worker_id = info.num_workers, info.id
    total_shards = max(1, self.ddp_world_size * num_workers)
    shard_id = self.ddp_rank * num_workers + worker_id
    src_counter = 0

    def read_frame(scene_dir: Path, frame_id: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Resolve both files first, then decode RGB followed by depth.
        rgb_path = self._resolve_img_path(scene_dir / "rgb", frame_id)
        depth_path = self._resolve_img_path(scene_dir / "depth", frame_id)
        return self._load_rgb_u8(rgb_path), self._load_depth_m(depth_path)

    def resize_depth(depth: torch.Tensor, size: tuple[int, int]) -> torch.Tensor:
        # Nearest-neighbour resize avoids blending depth values.
        return F.interpolate(depth.unsqueeze(0), size=size, mode="nearest").squeeze(0)

    for scene_pos, scene_dir in enumerate(scene_queue):
        try:
            pose_ids_np, w2c_map, intr = self._load_scene_pose_and_k(scene_dir)
        except Exception:
            continue
        pose_ids = {int(x) for x in pose_ids_np.tolist()}
        rgb_ids = self._collect_frame_ids(scene_dir / "rgb")
        dep_ids = self._collect_frame_ids(scene_dir / "depth")
        valid_ids = sorted(pose_ids & rgb_ids & dep_ids)
        if len(valid_ids) < 2:
            continue

        src_order = list(valid_ids)
        # One generator per scene drives the frame shuffle and target sampling.
        scene_rng = random.Random(self.seed + self.epoch * 1000003 + scene_pos)
        if self.shuffle_frame:
            scene_rng.shuffle(src_order)

        for src_idx in src_order:
            unit = src_counter
            src_counter += 1
            if unit % total_shards != shard_id:
                continue

            try:
                src_img, src_depth = read_frame(scene_dir, int(src_idx))
            except Exception:
                continue

            h, w = int(src_img.shape[1]), int(src_img.shape[2])
            tgt_idx = self._sample_target_for_src(
                src_idx=int(src_idx),
                valid_ids=valid_ids,
                w2c_map=w2c_map,
                intr=intr,
                h=h,
                w=w,
                rng=scene_rng,
            )
            if tgt_idx is None:
                continue

            try:
                tgt_img, tgt_depth = read_frame(scene_dir, int(tgt_idx))
            except Exception:
                continue
            if src_img.shape != tgt_img.shape:
                continue

            src_intr = intr.clone()
            tgt_intr = intr.clone()
            if self.output_h is not None and self.output_w is not None:
                oh, ow = int(src_img.shape[1]), int(src_img.shape[2])
                needs_resize = oh != self.output_h or ow != self.output_w
                if oh > 0 and ow > 0 and needs_resize:
                    target_hw = (self.output_h, self.output_w)
                    sx = float(self.output_w) / float(ow)
                    sy = float(self.output_h) / float(oh)
                    src_img = resize_rgb_u8_chw_high_quality(src_img, size=target_hw)
                    tgt_img = resize_rgb_u8_chw_high_quality(tgt_img, size=target_hw)
                    src_depth = resize_depth(src_depth, target_hw)
                    tgt_depth = resize_depth(tgt_depth, target_hw)
                    src_intr = resize_k3_align_corners_false(src_intr, sx=sx, sy=sy)
                    tgt_intr = resize_k3_align_corners_false(tgt_intr, sx=sx, sy=sy)

            yield WildRGBDPairSample(
                src_rgb_u8=src_img,
                tgt_rgb_u8=tgt_img,
                src_depth_m=src_depth,
                tgt_depth_m=tgt_depth,
                src_w2c=torch.from_numpy(w2c_map[int(src_idx)]).to(torch.float32),
                tgt_w2c=torch.from_numpy(w2c_map[int(tgt_idx)]).to(torch.float32),
                src_intrinsics=src_intr,
                tgt_intrinsics=tgt_intr,
                src_idx=int(src_idx),
                tgt_idx=int(tgt_idx),
                scene=self._scene_name(scene_dir),
            )
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def wildrgbd_collate(batch: list[WildRGBDPairSample]) -> WildRGBDPairSample:
    """Merge a list of pair samples into one batched ``WildRGBDPairSample``.

    Image, depth, pose and intrinsics fields are stacked along a new leading
    dimension when they hold tensors (otherwise kept as a list). The frame
    ids and scene names are always returned as plain lists.
    """
    tensor_fields = (
        "src_rgb_u8",
        "tgt_rgb_u8",
        "src_depth_m",
        "tgt_depth_m",
        "src_w2c",
        "tgt_w2c",
        "src_intrinsics",
        "tgt_intrinsics",
    )
    list_fields = ("src_idx", "tgt_idx", "scene")

    def stacked(field_name):
        values = [getattr(sample, field_name) for sample in batch]
        if isinstance(values[0], torch.Tensor):
            return torch.stack(values, dim=0)
        return values

    collated = {name: stacked(name) for name in tensor_fields}
    for name in list_fields:
        collated[name] = [getattr(sample, name) for sample in batch]
    return WildRGBDPairSample(**collated)
|
| 352 |
+
|
unisharp/losses/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .unisharp_loss import UnisharpLoss, UnisharpLossWeights
|
| 2 |
+
|
| 3 |
+
__all__ = ["UnisharpLoss", "UnisharpLossWeights"]
|
| 4 |
+
|
unisharp/losses/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (307 Bytes). View file
|
|
|
unisharp/losses/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (323 Bytes). View file
|
|
|
unisharp/losses/__pycache__/unisharp_loss.cpython-310.pyc
ADDED
|
Binary file (32.9 kB). View file
|
|
|
unisharp/losses/__pycache__/unisharp_loss.cpython-313.pyc
ADDED
|
Binary file (71.3 kB). View file
|
|
|
unisharp/losses/unisharp_loss.py
ADDED
|
@@ -0,0 +1,1120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
import math
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from unisharp import DEFAULT_MAX_DEPTH_M
|
| 10 |
+
from unisharp.utils import linalg
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _masked_mean(x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
|
| 14 |
+
if m.dtype != x.dtype:
|
| 15 |
+
m = m.to(dtype=x.dtype)
|
| 16 |
+
while m.ndim < x.ndim:
|
| 17 |
+
m = m.unsqueeze(1)
|
| 18 |
+
m_expanded = m.expand_as(x)
|
| 19 |
+
return (x * m_expanded).sum() / m_expanded.sum().clamp(min=1.0)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _finite_masked_mean_flat(x: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
|
| 23 |
+
mask = valid.to(device=x.device, dtype=torch.bool) & torch.isfinite(x)
|
| 24 |
+
x_safe = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
|
| 25 |
+
safe = torch.where(mask, x_safe, torch.zeros_like(x_safe))
|
| 26 |
+
return safe.sum() / mask.to(dtype=x.dtype).sum().clamp(min=1.0)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _finite_abs_mean(x: torch.Tensor) -> torch.Tensor:
|
| 30 |
+
mask = torch.isfinite(x)
|
| 31 |
+
x_safe = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
|
| 32 |
+
safe_abs = torch.where(mask, x_safe.abs(), torch.zeros_like(x_safe))
|
| 33 |
+
return safe_abs.sum() / mask.to(dtype=x.dtype).sum().clamp(min=1.0)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Projection-model names grouped as equirectangular (ERP) aliases.
# NOTE(review): presumably matched against a projection-model string to select
# ERP-specific handling - usage is not visible here, confirm.
_ERP_PROJECTION_MODELS = {"erp", "spherical", "equirect", "equirectangular"}
|
| 37 |
+
# Projection-model names grouped as fisheye variants.
# NOTE(review): presumably matched against a projection-model string to select
# fisheye-specific handling - usage is not visible here, confirm.
_FISHEYE_PROJECTION_MODELS = {"fisheye624", "opencv_fisheye"}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _tv_l1(img: torch.Tensor) -> torch.Tensor:
|
| 41 |
+
zero = torch.zeros((), device=img.device, dtype=img.dtype)
|
| 42 |
+
dx = (img[..., :, 1:] - img[..., :, :-1]).abs().mean() if int(img.shape[-1]) > 1 else zero
|
| 43 |
+
dy = (img[..., 1:, :] - img[..., :-1, :]).abs().mean() if int(img.shape[-2]) > 1 else zero
|
| 44 |
+
return dx + dy
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _tv_l1_circular_h(img: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
zero = torch.zeros((), device=img.device, dtype=img.dtype)
|
| 49 |
+
dx = (torch.roll(img, shifts=-1, dims=-1) - img).abs().mean() if int(img.shape[-1]) > 1 else zero
|
| 50 |
+
dy = (img[..., 1:, :] - img[..., :-1, :]).abs().mean() if int(img.shape[-2]) > 1 else zero
|
| 51 |
+
return dx + dy
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _checkerboard_l1_5d(x: torch.Tensor, *, circular_h: bool) -> torch.Tensor:
|
| 55 |
+
if x.ndim != 5:
|
| 56 |
+
raise ValueError(f"Expected [B,C,L,H,W], got {tuple(x.shape)}")
|
| 57 |
+
if int(x.shape[-2]) < 2 or int(x.shape[-1]) < 2:
|
| 58 |
+
return torch.zeros((), device=x.device, dtype=x.dtype)
|
| 59 |
+
x = x.to(dtype=torch.float32)
|
| 60 |
+
if bool(circular_h):
|
| 61 |
+
top = x[..., :-1, :]
|
| 62 |
+
bottom = x[..., 1:, :]
|
| 63 |
+
response = top - torch.roll(top, shifts=-1, dims=-1) - bottom + torch.roll(bottom, shifts=-1, dims=-1)
|
| 64 |
+
else:
|
| 65 |
+
response = x[..., :-1, :-1] - x[..., :-1, 1:] - x[..., 1:, :-1] + x[..., 1:, 1:]
|
| 66 |
+
return _finite_abs_mean(response)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def _delta_grid_checkerboard_loss(delta_grid: torch.Tensor, *, circular_h: bool) -> torch.Tensor:
|
| 70 |
+
if delta_grid.ndim != 5 or int(delta_grid.shape[1]) < 14:
|
| 71 |
+
raise ValueError(f"Expected delta grid [B,14,L,H,W], got {tuple(delta_grid.shape)}")
|
| 72 |
+
delta = delta_grid.to(dtype=torch.float32)
|
| 73 |
+
parts = [
|
| 74 |
+
delta[:, 3:6],
|
| 75 |
+
0.1 * delta[:, 10:13],
|
| 76 |
+
delta[:, 13:14],
|
| 77 |
+
]
|
| 78 |
+
return torch.stack([_checkerboard_l1_5d(part, circular_h=circular_h) for part in parts]).mean()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _avg_pool2d_circular_h(x: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
|
| 82 |
+
if kernel_size <= 1 and stride <= 1:
|
| 83 |
+
return x
|
| 84 |
+
x = F.pad(x, (kernel_size - 1, 0, 0, 0), mode="circular")
|
| 85 |
+
return F.avg_pool2d(x, kernel_size=kernel_size, stride=stride)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _resize_max_side(img: torch.Tensor, max_side: int, *, mode: str = "bilinear") -> torch.Tensor:
|
| 89 |
+
if max_side <= 0:
|
| 90 |
+
return img
|
| 91 |
+
h, w = int(img.shape[-2]), int(img.shape[-1])
|
| 92 |
+
ms = max(h, w)
|
| 93 |
+
if ms <= max_side:
|
| 94 |
+
return img
|
| 95 |
+
scale = float(max_side) / float(ms)
|
| 96 |
+
nh = max(1, int(math.floor(h * scale)))
|
| 97 |
+
nw = max(1, int(math.floor(w * scale)))
|
| 98 |
+
if mode in ("bilinear", "bicubic"):
|
| 99 |
+
return F.interpolate(img, size=(nh, nw), mode=mode, align_corners=False)
|
| 100 |
+
return F.interpolate(img, size=(nh, nw), mode=mode)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _gram_matrix(fmap: torch.Tensor) -> torch.Tensor:
|
| 104 |
+
b, c, h, w = fmap.shape
|
| 105 |
+
x = fmap.reshape(b, c, h * w)
|
| 106 |
+
return x @ x.transpose(1, 2)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class _ResNet50Perceptual(nn.Module):
    """Frozen ImageNet ResNet-50 used as a fixed feature extractor.

    ``forward`` returns the outputs of the four residual stages.  All
    parameters have ``requires_grad`` disabled and the module always stays in
    evaluation mode.
    """

    def __init__(self) -> None:
        super().__init__()
        from torchvision.models import resnet50

        # Only the *import* of the weights enum is allowed to fall back to the
        # legacy API.  A failure while building or downloading the default
        # weights must surface instead of silently loading a different
        # checkpoint through ``pretrained=True``.
        try:
            from torchvision.models import ResNet50_Weights
        except ImportError:
            net = resnet50(pretrained=True)
        else:
            net = resnet50(weights=ResNet50_Weights.DEFAULT)

        net.eval()
        net.requires_grad_(False)

        self.conv1 = net.conv1
        self.bn1 = net.bn1
        self.relu = net.relu
        self.maxpool = net.maxpool
        self.layer1 = net.layer1
        self.layer2 = net.layer2
        self.layer3 = net.layer3
        self.layer4 = net.layer4

        # ImageNet channel statistics used to normalise inputs in ``forward``.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer("_mean", mean, persistent=False)
        self.register_buffer("_std", std, persistent=False)

    def train(self, mode: bool = True) -> "_ResNet50Perceptual":
        """Keep the extractor in eval mode regardless of ``mode``.

        A parent module's ``.train()`` call recurses into children; without
        this override the BatchNorm layers would switch to batch statistics
        and update their running averages, changing the features.
        """
        del mode
        return super().train(False)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Return ``[layer1, layer2, layer3, layer4]`` features for ``x``.

        ``x`` is clamped to ``[0, 1]`` and normalised with the ImageNet mean
        and standard deviation before entering the network.
        """
        x = x.clamp(0.0, 1.0)
        x = (x - self._mean) / self._std
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        f1 = self.layer1(x)
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        f4 = self.layer4(f3)
        return [f1, f2, f3, f4]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _to_linear_rgb(img_srgb: torch.Tensor) -> torch.Tensor:
    """Clamp ``img_srgb`` to ``[0, 1]`` and pass it through ``sRGB2linearRGB``."""
    # Imported lazily, as in the rest of this module's optional helpers.
    from unisharp.utils.color_space import sRGB2linearRGB

    clamped = img_srgb.clamp(0.0, 1.0)
    return sRGB2linearRGB(clamped)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@dataclass
class UnisharpLossWeights:
    """Scalar weights carried by ``UnisharpLoss`` as ``self.w``.

    In the constructor of ``UnisharpLoss`` a positive ``lambda_percep``
    triggers creation of the ResNet-50 perceptual network.
    NOTE(review): the remaining fields presumably scale the loss terms their
    names suggest; confirm against the loss's forward pass.
    """

    lambda_color: float = 1.0
    lambda_alpha: float = 1.5
    # A value > 0 makes ``UnisharpLoss.__init__`` build ``_ResNet50Perceptual``.
    lambda_percep: float = 3.0
    lambda_depth: float = 0.5
    lambda_tv: float = 1.0
    lambda_grad: float = 1.0
    # The terms below default to 0.0 (except ``lambda_grad_img``).
    lambda_delta: float = 0.0
    lambda_delta_rho: float = 0.0
    lambda_splat: float = 0.0
    lambda_edge_splat: float = 0.0
    lambda_grid: float = 0.0
    lambda_grad_img: float = 0.2
    lambda_edge_rgb: float = 0.0
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class UnisharpLoss(nn.Module):
|
| 178 |
+
|
| 179 |
+
# Upper bound used by `_sanitize_supervision_depth` when `clamp_max` is true;
# taken from the package-level `DEFAULT_MAX_DEPTH_M`.
SUPERVISION_MAX_DEPTH_M: float = DEFAULT_MAX_DEPTH_M
|
| 180 |
+
|
| 181 |
+
def __init__(
    self,
    weights: UnisharpLossWeights | None = None,
    *,
    grad_sigma: float = 1e-2,
    grad_eps: float = 1e-2,
    delta_clip: float = 10.0,
    raw_delta_clip: float = 400.0,
    raw_delta_rho_clip: float = 5.0,
    alpha_tail_min: float = 0.99,
    alpha_tail_weight: float = 0.0,
    splat_sigma_min: float = 1e-1,
    splat_sigma_max: float = 1e2,
    edge_splat_sigma_max: float = 2.0,
    depth_edge_log_threshold: float = 0.05,
    depth_edge_dilate_px: int = 2,
    percep_max_side: int = 384,
    grad_img_scales: int = 4,
    grad_img_circular_h: bool = True,
) -> None:
    """Store the loss configuration and build the fixed helper tensors.

    Every keyword option is coerced to ``float``, ``int`` or ``bool`` and
    stored under its own name.  Two 3x3 Sobel kernels are registered as
    non-persistent buffers, and the perceptual network is created only when
    ``weights.lambda_percep`` is positive.
    """
    super().__init__()
    self.w = weights or UnisharpLossWeights()

    float_options = {
        "grad_sigma": grad_sigma,
        "grad_eps": grad_eps,
        "delta_clip": delta_clip,
        "raw_delta_clip": raw_delta_clip,
        "raw_delta_rho_clip": raw_delta_rho_clip,
        "alpha_tail_min": alpha_tail_min,
        "alpha_tail_weight": alpha_tail_weight,
        "splat_sigma_min": splat_sigma_min,
        "splat_sigma_max": splat_sigma_max,
        "edge_splat_sigma_max": edge_splat_sigma_max,
        "depth_edge_log_threshold": depth_edge_log_threshold,
    }
    for option_name, raw_value in float_options.items():
        setattr(self, option_name, float(raw_value))

    int_options = {
        "depth_edge_dilate_px": depth_edge_dilate_px,
        "percep_max_side": percep_max_side,
        "grad_img_scales": grad_img_scales,
    }
    for option_name, raw_value in int_options.items():
        setattr(self, option_name, int(raw_value))

    self.grad_img_circular_h = bool(grad_img_circular_h)

    # Horizontal and vertical Sobel stencils, shaped for ``F.conv2d``.
    horizontal_rows = [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]
    vertical_rows = [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]
    self.register_buffer("_sobel_kx", torch.tensor(horizontal_rows).view(1, 1, 3, 3), persistent=False)
    self.register_buffer("_sobel_ky", torch.tensor(vertical_rows).view(1, 1, 3, 3), persistent=False)

    # The pretrained backbone is only instantiated when it will be used.
    self._percep_net: nn.Module | None = (
        _ResNet50Perceptual() if self.w.lambda_percep > 0 else None
    )
|
| 231 |
+
|
| 232 |
+
@staticmethod
|
| 233 |
+
def _flatten_gaussian_xyz(x: torch.Tensor | None, gauss_grid_shape: tuple[int, int, int] | None = None) -> torch.Tensor | None:
|
| 234 |
+
if not torch.is_tensor(x):
|
| 235 |
+
return None
|
| 236 |
+
if x.ndim == 5:
|
| 237 |
+
return x.permute(0, 2, 3, 4, 1).flatten(1, 3)
|
| 238 |
+
if x.ndim == 3 and int(x.shape[-1]) == 3:
|
| 239 |
+
return x
|
| 240 |
+
if x.ndim == 2 and gauss_grid_shape is not None:
|
| 241 |
+
return x.unsqueeze(-1)
|
| 242 |
+
return None
|
| 243 |
+
|
| 244 |
+
@staticmethod
|
| 245 |
+
def _flatten_gaussian_quat(
|
| 246 |
+
x: torch.Tensor | None,
|
| 247 |
+
gauss_grid_shape: tuple[int, int, int] | None = None,
|
| 248 |
+
) -> torch.Tensor | None:
|
| 249 |
+
if not torch.is_tensor(x):
|
| 250 |
+
return None
|
| 251 |
+
if x.ndim == 5 and int(x.shape[1]) == 4:
|
| 252 |
+
return x.permute(0, 2, 3, 4, 1).flatten(1, 3)
|
| 253 |
+
if x.ndim == 3 and int(x.shape[-1]) == 4:
|
| 254 |
+
return x
|
| 255 |
+
if x.ndim == 2 and gauss_grid_shape is not None:
|
| 256 |
+
return x.unsqueeze(-1)
|
| 257 |
+
return None
|
| 258 |
+
|
| 259 |
+
@staticmethod
|
| 260 |
+
def _flatten_gaussian_scalar(
|
| 261 |
+
x: torch.Tensor | None,
|
| 262 |
+
gauss_grid_shape: tuple[int, int, int] | None = None,
|
| 263 |
+
) -> torch.Tensor | None:
|
| 264 |
+
if not torch.is_tensor(x):
|
| 265 |
+
return None
|
| 266 |
+
if x.ndim == 5:
|
| 267 |
+
return x[:, 0].flatten(1)
|
| 268 |
+
if x.ndim == 4:
|
| 269 |
+
return x.flatten(1)
|
| 270 |
+
if x.ndim == 3 and int(x.shape[-1]) == 1:
|
| 271 |
+
return x[..., 0]
|
| 272 |
+
if x.ndim == 2:
|
| 273 |
+
return x
|
| 274 |
+
return None
|
| 275 |
+
|
| 276 |
+
@staticmethod
|
| 277 |
+
def _central_disparity_gradient(inv_depth: torch.Tensor, *, circular_h: bool) -> torch.Tensor:
|
| 278 |
+
if circular_h:
|
| 279 |
+
gx = 0.5 * (torch.roll(inv_depth, shifts=-1, dims=-1) - torch.roll(inv_depth, shifts=1, dims=-1)).abs()
|
| 280 |
+
else:
|
| 281 |
+
padded_x = F.pad(inv_depth, (1, 1, 0, 0), mode="replicate")
|
| 282 |
+
gx = 0.5 * (padded_x[..., 2:] - padded_x[..., :-2]).abs()
|
| 283 |
+
padded_y = F.pad(inv_depth, (0, 0, 1, 1), mode="replicate")
|
| 284 |
+
gy = 0.5 * (padded_y[..., 2:, :] - padded_y[..., :-2, :]).abs()
|
| 285 |
+
return torch.sqrt(gx * gx + gy * gy + 1e-12)
|
| 286 |
+
|
| 287 |
+
@staticmethod
|
| 288 |
+
def _sample_map_at_uv(feat: torch.Tensor, u: torch.Tensor, v: torch.Tensor, valid: torch.Tensor) -> torch.Tensor:
|
| 289 |
+
b, _, h, w = feat.shape
|
| 290 |
+
valid_bool = valid.to(dtype=torch.bool) & torch.isfinite(u) & torch.isfinite(v)
|
| 291 |
+
u_safe = torch.where(valid_bool, u, torch.zeros_like(u)).clamp(0.0, float(max(w - 1, 0)))
|
| 292 |
+
v_safe = torch.where(valid_bool, v, torch.zeros_like(v)).clamp(0.0, float(max(h - 1, 0)))
|
| 293 |
+
grid_x = (u_safe / max(float(w - 1), 1.0)) * 2.0 - 1.0
|
| 294 |
+
grid_y = (v_safe / max(float(h - 1), 1.0)) * 2.0 - 1.0
|
| 295 |
+
grid = torch.stack([grid_x, grid_y], dim=-1).view(b, -1, 1, 2)
|
| 296 |
+
sampled = F.grid_sample(feat, grid, mode="bilinear", padding_mode="zeros", align_corners=True)
|
| 297 |
+
return sampled[:, 0, :, 0] * valid_bool.to(dtype=feat.dtype)
|
| 298 |
+
|
| 299 |
+
@staticmethod
|
| 300 |
+
def _expand_camera_params(camera_params: torch.Tensor, *, batch_size: int, device: torch.device) -> torch.Tensor:
|
| 301 |
+
params = camera_params.to(device=device, dtype=torch.float32)
|
| 302 |
+
if params.ndim == 1:
|
| 303 |
+
params = params.unsqueeze(0)
|
| 304 |
+
if int(params.shape[0]) == 1 and int(batch_size) > 1:
|
| 305 |
+
params = params.expand(int(batch_size), -1)
|
| 306 |
+
return params
|
| 307 |
+
|
| 308 |
+
@staticmethod
def _project_fisheye624_points_px_stable(
    pts: torch.Tensor,
    camera_params: torch.Tensor,
    *,
    image_h: int,
    image_w: int,
    finite: torch.Tensor,
    require_in_bounds: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Project ``[B, N, 3]`` points with a radial/tangential/thin-prism fisheye model.

    Args:
        pts: Points to project; the last axis holds ``(x, y, z)``.
        camera_params: Parameter vector(s), expanded to ``[B, P]``.  Read as
            focal/centre first (``f, cx, cy`` when ``P == 15``, otherwise
            ``fx, fy, cx, cy``) and, counted from the end, six radial
            coefficients, two tangential coefficients and four thin-prism
            coefficients.
        image_h: Image height, used only for the in-bounds test.
        image_w: Image width, used only for the in-bounds test.
        finite: ``[B, N]`` boolean mask of points whose coordinates are finite.
        require_in_bounds: When true, points projecting outside
            ``[0, image_w - 1] x [0, image_h - 1]`` are marked invalid.

    Returns:
        ``(u, v, valid, radius)``: pixel coordinates, validity mask, and the
        Euclidean norm of each input point clamped to at least ``1e-6``.
    """
    b, _, _ = pts.shape
    params = UnisharpLoss._expand_camera_params(camera_params, batch_size=b, device=pts.device)
    radius = torch.linalg.vector_norm(pts, dim=-1).clamp(min=1e-6)
    # Only points clearly in front of the camera are projected.
    front = pts[..., 2] > (radius * 1e-4).clamp(min=1e-4)
    projectable = finite & front

    # Non-projectable points are swapped for (0, 0, 1) so the math below stays
    # finite; they are excluded again through ``valid``.
    safe_pts = torch.zeros_like(pts)
    safe_pts[..., 2] = 1.0
    pts_proj = torch.where(projectable.unsqueeze(-1), pts, safe_pts)
    x, y, z = pts_proj.unbind(dim=-1)
    z_safe = z.clamp(min=1e-4)

    ab = torch.stack([x / z_safe, y / z_safe], dim=-1)
    r = torch.sqrt((ab * ab).sum(dim=-1, keepdim=True) + 1e-12)
    theta = torch.atan(r)
    unit_ab = ab / r

    # Radial coefficients are addressed from the END of the vector, like the
    # tangential and thin-prism terms below.  A fixed ``4:10`` slice is only
    # right for the 16-value layout; with 15 values it would skip the first
    # radial coefficient and swallow the first tangential one.
    coeffs = params[:, -12:-6].reshape(b, 1, 6)
    theta_powers = torch.cat([theta.pow(3 + i * 2) for i in range(6)], dim=-1)
    theta_distorted = theta + (theta_powers * coeffs).sum(dim=-1, keepdim=True)
    uv_dist = theta_distorted * unit_ab

    # Tangential distortion.
    p0 = params[..., -6].reshape(b, 1)
    p1 = params[..., -5].reshape(b, 1)
    xr = uv_dist[..., 0]
    yr = uv_dist[..., 1]
    xr_sq = xr.square()
    yr_sq = yr.square()
    rd_sq = xr_sq + yr_sq
    uv_x = uv_dist[..., 0] + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
    uv_y = uv_dist[..., 1] + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0

    # Thin-prism distortion.
    s0 = params[..., -4].reshape(b, 1)
    s1 = params[..., -3].reshape(b, 1)
    s2 = params[..., -2].reshape(b, 1)
    s3 = params[..., -1].reshape(b, 1)
    rd_4 = rd_sq.square()
    uv_x = uv_x + s0 * rd_sq + s1 * rd_4
    uv_y = uv_y + s2 * rd_sq + s3 * rd_4

    if int(params.shape[-1]) == 15:
        # Single shared focal length.
        fx = fy = params[..., 0:1]
        cx = params[..., 1:2]
        cy = params[..., 2:3]
    else:
        fx = params[..., 0:1]
        fy = params[..., 1:2]
        cx = params[..., 2:3]
        cy = params[..., 3:4]
    u = uv_x * fx + cx
    v = uv_y * fy + cy
    valid = projectable & torch.isfinite(u) & torch.isfinite(v)
    if require_in_bounds:
        valid = valid & (u >= 0.0) & (u <= float(image_w - 1)) & (v >= 0.0) & (v <= float(image_h - 1))
    return u, v, valid, radius
|
| 374 |
+
|
| 375 |
+
@staticmethod
def _project_points_px(
    points: torch.Tensor,
    *,
    projection_model: str | None,
    image_h: int,
    image_w: int,
    intrinsics: torch.Tensor | None = None,
    camera_params: torch.Tensor | None = None,
    require_in_bounds: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Project ``[B, N, 3]`` points to pixel coordinates.

    ``projection_model`` is lower-cased (``None`` means ``"pinhole"``) and
    selects one of three branches:

    * equirectangular names: longitude/latitude mapping.  This branch always
      applies the in-bounds test, whatever ``require_in_bounds`` says.
    * fisheye names with a ``camera_params`` tensor: delegated to
      ``_project_fisheye624_points_px_stable``.
    * everything else: pinhole projection using ``intrinsics`` or, when no
      intrinsics tensor is given, a default camera derived from the image size.

    Returns:
        ``(u, v, valid, range)`` where ``range`` is the point radius for the
        equirectangular and fisheye branches and the clamped ``z`` for pinhole.
    """
    pts_raw = points.to(dtype=torch.float32)
    # Remember which points were finite before NaN/Inf are zeroed out.
    finite = torch.isfinite(pts_raw).all(dim=-1)
    pts = torch.nan_to_num(pts_raw, nan=0.0, posinf=0.0, neginf=0.0)
    b, n, _ = pts.shape
    x, y, z = pts.unbind(dim=-1)
    model = (projection_model or "pinhole").lower()

    if model in _ERP_PROJECTION_MODELS:
        radius_sq_raw = (pts * pts).sum(dim=-1)
        # A point at the origin has no direction and cannot be projected.
        direction_valid = finite & (radius_sq_raw > 1e-12)
        # Invalid points are replaced by (0, 0, 1) to keep atan2 well-defined.
        safe_pts = torch.zeros_like(pts)
        safe_pts[..., 2] = 1.0
        pts_erp = torch.where(direction_valid.unsqueeze(-1), pts, safe_pts)
        x, y, z = pts_erp.unbind(dim=-1)
        radius_sq = (pts_erp * pts_erp).sum(dim=-1)
        radius = torch.sqrt(radius_sq + 1e-12)
        horizontal_sq = x.square() + z.square()
        horizontal = torch.sqrt(horizontal_sq + 1e-12)
        # Longitude is treated as undefined for directions whose horizontal
        # component is below this fraction of the radius (near the poles).
        pole_angle_eps = max(1e-4, 0.5 * math.pi / float(max(image_h, image_w, 1)))
        lon_valid = horizontal > radius * pole_angle_eps
        lon_x = torch.where(lon_valid, x, torch.zeros_like(x))
        lon_z = torch.where(lon_valid, z, torch.ones_like(z))
        # lon == 0 (the +z direction) lands on the horizontal image centre.
        lon = torch.atan2(lon_x, lon_z)
        # Positive y maps towards larger v.
        pitch_down = torch.atan2(y, horizontal)
        u = (lon / (2.0 * math.pi) + 0.5) * float(max(image_w, 1)) - 0.5
        v = (0.5 + pitch_down / math.pi) * float(max(image_h, 1)) - 0.5
        valid = direction_valid & lon_valid
        valid = (
            valid
            & torch.isfinite(u)
            & torch.isfinite(v)
            & (u >= 0.0)
            & (u <= float(image_w - 1))
            & (v >= 0.0)
            & (v <= float(image_h - 1))
        )
        return u, v, valid, radius.clamp(min=1e-6)

    if model in _FISHEYE_PROJECTION_MODELS and torch.is_tensor(camera_params):
        return UnisharpLoss._project_fisheye624_points_px_stable(
            pts,
            camera_params,
            image_h=image_h,
            image_w=image_w,
            finite=finite,
            require_in_bounds=require_in_bounds,
        )

    # Pinhole: only points strictly in front of the camera are valid.
    valid = finite & (z > 1e-4)
    if not torch.is_tensor(intrinsics):
        # Default camera: focal equal to the longer image side, principal
        # point at the image centre.
        fx = torch.full((b, 1), float(max(image_w, image_h)), device=pts.device, dtype=torch.float32)
        fy = fx.clone()
        cx = torch.full((b, 1), 0.5 * float(max(image_w - 1, 1)), device=pts.device, dtype=torch.float32)
        cy = torch.full((b, 1), 0.5 * float(max(image_h - 1, 1)), device=pts.device, dtype=torch.float32)
    else:
        k = intrinsics.to(device=pts.device, dtype=torch.float32)
        if k.ndim == 2:
            k = k.unsqueeze(0)
        if int(k.shape[0]) == 1 and b > 1:
            k = k.expand(b, -1, -1)
        fx = k[:, 0, 0:1]
        fy = k[:, 1, 1:2]
        cx = k[:, 0, 2:3]
        cy = k[:, 1, 2:3]
    # Clamp z so the division stays finite for points flagged invalid above.
    z_safe = z.clamp(min=1e-4)
    u = fx * (x / z_safe) + cx
    v = fy * (y / z_safe) + cy
    valid = valid & torch.isfinite(u) & torch.isfinite(v)
    if require_in_bounds:
        valid = valid & (u >= 0.0) & (u <= float(image_w - 1)) & (v >= 0.0) & (v <= float(image_h - 1))
    return u, v, valid, z_safe
|
| 457 |
+
|
| 458 |
+
def _projected_sigma_px(
    self,
    *,
    gaussian_scales: torch.Tensor,
    gaussian_quaternions: torch.Tensor | None,
    gaussian_mean_vectors: torch.Tensor,
    valid: torch.Tensor,
    projection_model: str | None,
    image_h: int,
    image_w: int,
    intrinsics: torch.Tensor | None = None,
    camera_params: torch.Tensor | None = None,
    projected_scale_factor: float | torch.Tensor | None = None,
) -> torch.Tensor:
    """Estimate each Gaussian's squared screen-space extent in pixels.

    Three strategies are tried in order:

    1. Equirectangular models: the first two scale components divided by the
       point radius are converted to pixels per axis; the larger square wins.
    2. Usable quaternions: the two tangent axes are projected through
       ``_project_points_px`` and the largest eigenvalue of the resulting
       2x2 covariance is returned.
    3. Fallback: the larger tangent scale is divided by range and multiplied
       by a focal length (or by ``projected_scale_factor`` when neither
       fisheye parameters nor intrinsics are available), then squared.

    Entries that are invalid or non-finite are returned as 0.  When scales
    or means cannot be flattened, an all-zero float32 tensor shaped like
    ``valid`` is returned.
    """
    scales = self._flatten_gaussian_xyz(gaussian_scales)
    quats = self._flatten_gaussian_quat(gaussian_quaternions)
    means = self._flatten_gaussian_xyz(gaussian_mean_vectors)
    if scales is None or means is None:
        return torch.zeros_like(valid, dtype=torch.float32)
    # Track finiteness before the values are sanitised.
    valid = valid.to(dtype=torch.bool) & torch.isfinite(scales).all(dim=-1) & torch.isfinite(means).all(dim=-1)
    scales = torch.nan_to_num(scales.to(dtype=torch.float32), nan=0.0, posinf=0.0, neginf=0.0).abs()
    means = torch.nan_to_num(means.to(dtype=torch.float32), nan=0.0, posinf=0.0, neginf=0.0)
    model = (projection_model or "pinhole").lower()
    if model in _ERP_PROJECTION_MODELS:
        radius = torch.norm(means, dim=-1).clamp(min=1e-4)
        # Angular size times pixels-per-radian along each image axis.
        sigma_u = scales[..., 0] / radius * (float(max(image_w, 1)) / (2.0 * math.pi))
        sigma_v = scales[..., 1] / radius * (float(max(image_h, 1)) / math.pi)
        sigma_px = torch.maximum(sigma_u.square(), sigma_v.square())
        valid = valid & torch.isfinite(sigma_px)
        sigma_px = torch.nan_to_num(sigma_px, nan=0.0, posinf=0.0, neginf=0.0)
        return torch.where(valid, sigma_px, torch.zeros_like(sigma_px))

    if quats is not None and tuple(quats.shape[:2]) == tuple(means.shape[:2]):
        quats = torch.nan_to_num(quats.to(dtype=torch.float32), nan=0.0, posinf=0.0, neginf=0.0)
        quat_norm = quats.norm(dim=-1, keepdim=True)
        # After nan_to_num the isfinite test is always true; the norm test is
        # what rejects degenerate (zeroed) quaternions.
        valid = valid & torch.isfinite(quats).all(dim=-1) & (quat_norm.squeeze(-1) > 1e-8)
        quats = quats / quat_norm.clamp(min=1e-8)
        rotations = linalg.rotation_matrices_from_quaternions(quats)
        # Only the first two local axes are used as the tangent plane.
        tangent_scales = scales[..., :2]
        tangent_rotations = rotations[..., :, :2]
        axis_offsets = (tangent_rotations * tangent_scales[..., None, :]).transpose(-1, -2)
        axis_points = means[:, :, None, :] + axis_offsets
        u0, v0, valid0, _ = self._project_points_px(
            means,
            projection_model=projection_model,
            image_h=image_h,
            image_w=image_w,
            intrinsics=intrinsics,
            camera_params=camera_params,
            require_in_bounds=False,
        )
        b, n, axis_count, _ = axis_points.shape
        u1, v1, valid1, _ = self._project_points_px(
            axis_points.reshape(b, n * axis_count, 3),
            projection_model=projection_model,
            image_h=image_h,
            image_w=image_w,
            intrinsics=intrinsics,
            camera_params=camera_params,
            require_in_bounds=False,
        )
        u1 = u1.reshape(b, n, axis_count)
        v1 = v1.reshape(b, n, axis_count)
        valid1 = valid1.reshape(b, n, axis_count)
        du = u1 - u0[..., None]
        dv = v1 - v0[..., None]
        # NOTE(review): equirectangular models already returned above, so this
        # wrap-around correction looks unreachable; confirm before relying on it.
        if (projection_model or "pinhole").lower() in _ERP_PROJECTION_MODELS:
            width = float(max(image_w, 1))
            du = torch.remainder(du + 0.5 * width, width) - 0.5 * width
        # 2x2 covariance built from the projected axis offsets.
        cov_xx = (du * du).sum(dim=-1)
        cov_xy = (du * dv).sum(dim=-1)
        cov_yy = (dv * dv).sum(dim=-1)
        trace = cov_xx + cov_yy
        disc = (cov_xx - cov_yy).square() + 4.0 * cov_xy.square()
        # Largest eigenvalue of the symmetric 2x2 covariance.
        sigma_px = 0.5 * (trace + (disc.clamp(min=0.0) + 1e-12).sqrt())
        valid = valid & valid0 & valid1.all(dim=-1) & torch.isfinite(sigma_px)
        sigma_px = torch.nan_to_num(sigma_px, nan=0.0, posinf=0.0, neginf=0.0)
        return torch.where(valid, sigma_px, torch.zeros_like(sigma_px))

    sigma_screen_3d = scales[..., :2].to(dtype=torch.float32).abs().amax(dim=-1).clamp(min=1e-8)
    if model in {"fisheye624", "opencv_fisheye"} and torch.is_tensor(camera_params):
        params = camera_params.to(device=means.device, dtype=torch.float32)
        if params.ndim == 1:
            params = params.unsqueeze(0)
        if int(params.shape[0]) == 1 and int(means.shape[0]) > 1:
            params = params.expand(int(means.shape[0]), -1)
        if int(params.shape[-1]) == 15:
            # 15-value layout stores a single focal length first.
            focal = params[:, 0:1].clamp(min=1.0)
        else:
            focal = 0.5 * (params[:, 0:1] + params[:, 1:2]).clamp(min=1.0)
        radius = torch.norm(means, dim=-1).clamp(min=1e-4)
        sigma_px = (sigma_screen_3d / radius * focal).square()
    elif torch.is_tensor(intrinsics):
        k = intrinsics.to(device=means.device, dtype=torch.float32)
        if k.ndim == 2:
            k = k.unsqueeze(0)
        if int(k.shape[0]) == 1 and int(means.shape[0]) > 1:
            k = k.expand(int(means.shape[0]), -1, -1)
        focal = 0.5 * (k[:, 0, 0:1] + k[:, 1, 1:2]).clamp(min=1.0)
        depth = means[..., 2].clamp(min=1e-4)
        sigma_px = (sigma_screen_3d / depth * focal).square()
    else:
        # NOTE(review): indentation is not recoverable from the reviewed view.
        # The scale-factor handling and the final square are placed inside
        # this branch because that is the reading under which every branch
        # squares exactly once; confirm against the repository.
        depth = torch.norm(means, dim=-1).clamp(min=1e-4)
        sigma_px = sigma_screen_3d / depth
        if torch.is_tensor(projected_scale_factor):
            sigma_px = sigma_px * projected_scale_factor.to(device=sigma_px.device, dtype=sigma_px.dtype)
        elif projected_scale_factor is not None:
            sigma_px = sigma_px * float(projected_scale_factor)
        sigma_px = sigma_px.square()
    valid = valid & torch.isfinite(sigma_px)
    sigma_px = torch.nan_to_num(sigma_px, nan=0.0, posinf=0.0, neginf=0.0)
    return torch.where(valid, sigma_px, torch.zeros_like(sigma_px))
|
| 570 |
+
|
| 571 |
+
def _depth_edge_band(
|
| 572 |
+
self,
|
| 573 |
+
depth_m: torch.Tensor,
|
| 574 |
+
valid_weight: torch.Tensor,
|
| 575 |
+
*,
|
| 576 |
+
circular_h: bool,
|
| 577 |
+
) -> torch.Tensor:
|
| 578 |
+
depth = depth_m.to(dtype=torch.float32)
|
| 579 |
+
if depth.ndim == 3:
|
| 580 |
+
depth = depth.unsqueeze(1)
|
| 581 |
+
valid = torch.isfinite(depth) & (depth > 0.0) & (valid_weight[:, :1].to(dtype=torch.float32) > 0.5)
|
| 582 |
+
log_depth = torch.where(valid, depth.clamp(min=1e-4).log(), torch.zeros_like(depth))
|
| 583 |
+
|
| 584 |
+
if bool(circular_h):
|
| 585 |
+
right = torch.roll(log_depth, shifts=-1, dims=-1)
|
| 586 |
+
valid_right = valid & torch.roll(valid, shifts=-1, dims=-1)
|
| 587 |
+
edge_x = (right - log_depth).abs() > float(self.depth_edge_log_threshold)
|
| 588 |
+
edge_x = edge_x & valid_right
|
| 589 |
+
else:
|
| 590 |
+
edge_x = torch.zeros_like(valid)
|
| 591 |
+
edge_x[..., :, :-1] = (
|
| 592 |
+
(log_depth[..., :, 1:] - log_depth[..., :, :-1]).abs() > float(self.depth_edge_log_threshold)
|
| 593 |
+
) & valid[..., :, 1:] & valid[..., :, :-1]
|
| 594 |
+
|
| 595 |
+
edge_y = torch.zeros_like(valid)
|
| 596 |
+
edge_y[..., :-1, :] = (
|
| 597 |
+
(log_depth[..., 1:, :] - log_depth[..., :-1, :]).abs() > float(self.depth_edge_log_threshold)
|
| 598 |
+
) & valid[..., 1:, :] & valid[..., :-1, :]
|
| 599 |
+
edge = (edge_x | edge_y).to(dtype=torch.float32)
|
| 600 |
+
|
| 601 |
+
radius = max(int(self.depth_edge_dilate_px), 0)
|
| 602 |
+
if radius <= 0:
|
| 603 |
+
return edge
|
| 604 |
+
kernel = 2 * radius + 1
|
| 605 |
+
if bool(circular_h):
|
| 606 |
+
edge = F.pad(edge, (radius, radius, 0, 0), mode="circular")
|
| 607 |
+
edge = F.pad(edge, (0, 0, radius, radius), mode="constant", value=0.0)
|
| 608 |
+
return F.max_pool2d(edge, kernel_size=kernel, stride=1)
|
| 609 |
+
return F.max_pool2d(edge, kernel_size=kernel, stride=1, padding=radius)
|
| 610 |
+
|
| 611 |
+
def _ray_cell_sigma(
|
| 612 |
+
self,
|
| 613 |
+
*,
|
| 614 |
+
gaussian_scales: torch.Tensor,
|
| 615 |
+
gaussian_mean_vectors: torch.Tensor,
|
| 616 |
+
gaussian_angular_cell: torch.Tensor,
|
| 617 |
+
gauss_grid_shape: tuple[int, int, int] | None,
|
| 618 |
+
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
| 619 |
+
scales = self._flatten_gaussian_xyz(gaussian_scales, gauss_grid_shape)
|
| 620 |
+
means = self._flatten_gaussian_xyz(gaussian_mean_vectors, gauss_grid_shape)
|
| 621 |
+
if scales is None or means is None:
|
| 622 |
+
return None, None
|
| 623 |
+
if not torch.is_tensor(gaussian_angular_cell):
|
| 624 |
+
return None, None
|
| 625 |
+
cell = gaussian_angular_cell.to(device=scales.device, dtype=torch.float32)
|
| 626 |
+
if cell.ndim != 5 or int(cell.shape[1]) != 2:
|
| 627 |
+
return None, None
|
| 628 |
+
if gauss_grid_shape is None:
|
| 629 |
+
return None, None
|
| 630 |
+
l, h, w = (int(gauss_grid_shape[0]), int(gauss_grid_shape[1]), int(gauss_grid_shape[2]))
|
| 631 |
+
if tuple(cell.shape[-2:]) != (h, w):
|
| 632 |
+
return None, None
|
| 633 |
+
if int(cell.shape[2]) == 1 and l > 1:
|
| 634 |
+
cell = cell.expand(-1, -1, l, -1, -1)
|
| 635 |
+
elif int(cell.shape[2]) != l:
|
| 636 |
+
return None, None
|
| 637 |
+
cell_flat = cell.permute(0, 2, 3, 4, 1).flatten(1, 3)
|
| 638 |
+
if int(cell_flat.shape[1]) != int(scales.shape[1]):
|
| 639 |
+
return None, None
|
| 640 |
+
radius = torch.linalg.norm(means.to(dtype=torch.float32), dim=-1, keepdim=True).clamp(min=1e-4)
|
| 641 |
+
tangent = scales[..., :2].to(dtype=torch.float32).abs()
|
| 642 |
+
sigma_cells = (tangent / radius / cell_flat.clamp(min=1e-6)).square()
|
| 643 |
+
valid = torch.isfinite(sigma_cells).all(dim=-1) & torch.isfinite(radius.squeeze(-1))
|
| 644 |
+
sigma_cells = torch.nan_to_num(sigma_cells, nan=0.0, posinf=0.0, neginf=0.0)
|
| 645 |
+
return sigma_cells, valid
|
| 646 |
+
|
| 647 |
+
def _dynamic_splat_sigma_limits(
|
| 648 |
+
self,
|
| 649 |
+
*,
|
| 650 |
+
sigma_proj: torch.Tensor,
|
| 651 |
+
projection_model: str | None,
|
| 652 |
+
image_h: int,
|
| 653 |
+
image_w: int,
|
| 654 |
+
intrinsics: torch.Tensor | None = None,
|
| 655 |
+
camera_params: torch.Tensor | None = None,
|
| 656 |
+
projected_scale_factor: float | torch.Tensor | None = None,
|
| 657 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 658 |
+
del projection_model, image_h, image_w, intrinsics, camera_params, projected_scale_factor
|
| 659 |
+
return (
|
| 660 |
+
torch.as_tensor(self.splat_sigma_min, device=sigma_proj.device, dtype=sigma_proj.dtype),
|
| 661 |
+
torch.as_tensor(self.splat_sigma_max, device=sigma_proj.device, dtype=sigma_proj.dtype),
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
def _sanitize_supervision_depth(self, depth_m: torch.Tensor, *, clamp_max: bool = True) -> torch.Tensor:
|
| 665 |
+
depth = depth_m.to(torch.float32)
|
| 666 |
+
valid = torch.isfinite(depth) & (depth > 0.0)
|
| 667 |
+
depth = torch.where(valid, depth, torch.zeros_like(depth))
|
| 668 |
+
if bool(valid.any().item()):
|
| 669 |
+
depth = depth.clone()
|
| 670 |
+
if bool(clamp_max):
|
| 671 |
+
depth[valid] = depth[valid].clamp(min=1e-4, max=float(self.SUPERVISION_MAX_DEPTH_M))
|
| 672 |
+
else:
|
| 673 |
+
depth[valid] = depth[valid].clamp(min=1e-4)
|
| 674 |
+
return depth
|
| 675 |
+
|
| 676 |
+
def _sobel_gradient_loss_erp(
    self,
    pred_depth_m: torch.Tensor,
    gt_depth_m: torch.Tensor,
    depth_weight: torch.Tensor,
    circular_h: bool | None = None,
) -> torch.Tensor:
    """Multi-scale Sobel gradient magnitude of the log-depth error.

    The error map ``log(pred) - log(gt)`` is filtered with the Sobel kernels
    at up to ``self.grad_img_scales`` resolutions.  At each scale the
    gradient magnitude is averaged over pixels whose full 3x3 neighbourhood
    is valid, and the per-scale means are averaged.  ``circular_h`` overrides
    ``self.grad_img_circular_h`` when it is not ``None``.

    Returns a zero scalar when no scale could be evaluated (either spatial
    side already below 4 pixels, or ``grad_img_scales`` is 0).
    """
    dtype = pred_depth_m.dtype
    device = pred_depth_m.device

    kx = self._sobel_kx.to(dtype=dtype, device=device)  # type: ignore[attr-defined]
    ky = self._sobel_ky.to(dtype=dtype, device=device)  # type: ignore[attr-defined]

    log_pred = torch.log(pred_depth_m.clamp(min=1e-4))
    log_gt = torch.log(gt_depth_m.clamp(min=1e-4))
    log_diff = log_pred - log_gt

    # Binarise the weight map: only pixels with weight above 0.5 count.
    mask = depth_weight.to(dtype=dtype).clamp(min=0.0, max=1.0)
    valid_mask = (mask > 0.5).to(dtype=dtype)
    log_diff = torch.where(valid_mask > 0.5, log_diff, torch.zeros_like(log_diff))

    total = torch.zeros((), device=device, dtype=dtype)
    n_computed = 0

    use_circular_h = self.grad_img_circular_h if circular_h is None else bool(circular_h)
    ones_kernel = torch.ones((1, 1, 3, 3), device=device, dtype=dtype)
    for _s in range(self.grad_img_scales):
        # Stop once the map is too small for another scale.
        if min(log_diff.shape[-2:]) < 4:
            break

        if use_circular_h:
            # Wrap horizontally; reflect (values) / replicate (mask) vertically.
            padded = F.pad(log_diff, (1, 1, 0, 0), mode="circular")
            padded = F.pad(padded, (0, 0, 1, 1), mode="reflect")
            padded_mask = F.pad(valid_mask, (1, 1, 0, 0), mode="circular")
            padded_mask = F.pad(padded_mask, (0, 0, 1, 1), mode="replicate")
        else:
            padded = F.pad(log_diff, (1, 1, 1, 1), mode="reflect")
            padded_mask = F.pad(valid_mask, (1, 1, 1, 1), mode="replicate")

        gx = F.conv2d(padded, kx)
        gy = F.conv2d(padded, ky)
        grad_mag = torch.sqrt(gx * gx + gy * gy + 1e-8)

        # A stencil counts only when all nine mask entries under it are valid.
        stencil_valid = (F.conv2d(padded_mask, ones_kernel) >= 8.999).to(dtype=dtype)
        n_valid = stencil_valid.sum().clamp(min=1.0)
        total = total + (grad_mag * stencil_valid).sum() / n_valid
        n_computed += 1

        if _s < self.grad_img_scales - 1:
            # Mask-weighted 2x average pooling prepares the next scale.
            if use_circular_h:
                pooled_mask = _avg_pool2d_circular_h(valid_mask, kernel_size=2, stride=2)
                pooled_diff = _avg_pool2d_circular_h(log_diff * valid_mask, kernel_size=2, stride=2)
            else:
                pooled_mask = F.avg_pool2d(valid_mask, kernel_size=2, stride=2)
                pooled_diff = F.avg_pool2d(log_diff * valid_mask, kernel_size=2, stride=2)
            log_diff = pooled_diff / pooled_mask.clamp(min=1e-6)
            # A pooled pixel stays valid only if its whole window was valid.
            valid_mask = (pooled_mask > 0.999).to(dtype=dtype)
            log_diff = torch.where(valid_mask > 0.5, log_diff, torch.zeros_like(log_diff))

    if n_computed == 0:
        return torch.zeros((), device=device, dtype=dtype)
    return total / float(n_computed)
|
| 738 |
+
|
| 739 |
+
def _sobel_xy_rgb(self, img: torch.Tensor, *, circular_h: bool) -> tuple[torch.Tensor, torch.Tensor]:
    """Filter every channel of ``img`` with the module's two 3x3 kernels.

    The kernels are read from ``self._sobel_kx`` and ``self._sobel_ky``,
    moved to the dtype/device of ``img`` and expanded to one filter per
    channel, so the convolution is depthwise (``groups == channels``).

    Args:
        img: Image batch; dimension 1 is treated as the channel axis and the
            last two dimensions are padded and convolved.
        circular_h: When true, the last (width) axis is wrapped around with
            circular padding and only the height axis is reflect-padded;
            otherwise all four borders are reflect-padded.

    Returns:
        ``(gx, gy)``: the responses to ``_sobel_kx`` and ``_sobel_ky``. One
        pixel of padding on each side with a 3x3 kernel keeps the spatial
        size equal to that of ``img``.
    """
    n_channels = int(img.shape[1])

    def per_channel(kernel: torch.Tensor) -> torch.Tensor:
        # One identical 3x3 filter per input channel for the grouped convolution.
        return kernel.to(dtype=img.dtype, device=img.device).expand(n_channels, 1, 3, 3)

    weight_x = per_channel(self._sobel_kx)  # type: ignore[attr-defined]
    weight_y = per_channel(self._sobel_ky)  # type: ignore[attr-defined]

    if circular_h:
        # Wrap left/right first, then mirror top/bottom.
        wrapped = F.pad(img, (1, 1, 0, 0), mode="circular")
        bordered = F.pad(wrapped, (0, 0, 1, 1), mode="reflect")
    else:
        bordered = F.pad(img, (1, 1, 1, 1), mode="reflect")

    grad_x = F.conv2d(bordered, weight_x, groups=n_channels)
    grad_y = F.conv2d(bordered, weight_y, groups=n_channels)
    return grad_x, grad_y
|
| 752 |
+
|
| 753 |
+
def _edge_rgb_gradient_loss(
    self,
    pred_rgb_linear: torch.Tensor,
    gt_rgb_linear: torch.Tensor,
    valid_weight: torch.Tensor,
    depth_edge_band: torch.Tensor | None,
    *,
    circular_h: bool,
) -> torch.Tensor:
    """L1 difference of image gradients, averaged over ground-truth edge pixels.

    All arithmetic is done in float32; the scalar result is cast back to the
    dtype of ``pred_rgb_linear``.

    Args:
        pred_rgb_linear: Predicted image batch.
        gt_rgb_linear: Reference image batch (moved to the prediction's device).
        valid_weight: Validity weights; clamped to ``[0, 1]`` and only the
            first channel is used.
        depth_edge_band: Optional map that boosts the edge weight. It is
            clamped to ``[0, 1]``, resized with nearest-neighbour
            interpolation if its spatial size differs, and applied as
            ``1 + band`` (first channel only). Ignored when not a tensor.
        circular_h: Forwarded to ``self._sobel_xy_rgb`` and used to pick
            circular padding along the width axis for the validity stencil.

    Returns:
        Scalar tensor: the weighted sum of per-pixel gradient differences
        divided by the weight sum (the divisor is clamped to at least 1).
    """
    out_dtype = pred_rgb_linear.dtype
    dev = pred_rgb_linear.device
    f32 = torch.float32

    pred32 = pred_rgb_linear.to(dtype=f32)
    gt32 = gt_rgb_linear.to(device=dev, dtype=f32)
    validity = valid_weight.to(device=dev, dtype=f32).clamp(0.0, 1.0)[:, :1]

    pred_gx, pred_gy = self._sobel_xy_rgb(pred32, circular_h=circular_h)
    gt_gx, gt_gy = self._sobel_xy_rgb(gt32, circular_h=circular_h)
    gt_mag = torch.sqrt(gt_gx.square() + gt_gy.square() + 1e-8).mean(dim=1, keepdim=True)

    # Per-image adaptive threshold: mean + 0.5 * std of the reference gradient
    # magnitude, never below 0.02. No gradient flows through the edge mask.
    gt_mag_const = gt_mag.detach()
    per_image = gt_mag_const.flatten(2)
    mu = per_image.mean(dim=-1, keepdim=True).unsqueeze(-1)
    sigma = per_image.std(dim=-1, keepdim=True, unbiased=False).unsqueeze(-1)
    threshold = torch.clamp(mu + 0.5 * sigma, min=0.02)
    rgb_edge = (gt_mag_const > threshold).to(dtype=f32)

    edge_weight = rgb_edge
    if torch.is_tensor(depth_edge_band):
        boost = depth_edge_band.to(device=dev, dtype=f32).clamp(0.0, 1.0)
        target_hw = gt_mag.shape[-2:]
        if tuple(boost.shape[-2:]) != tuple(target_hw):
            boost = F.interpolate(boost, size=target_hw, mode="nearest")
        edge_weight = rgb_edge * (1.0 + boost[:, :1])

    # A pixel counts only if its whole 3x3 neighbourhood is fully valid
    # (box-filter sum of the clamped weights reaches 9).
    box = torch.ones((1, 1, 3, 3), device=dev, dtype=f32)
    if circular_h:
        ring = F.pad(validity, (1, 1, 0, 0), mode="circular")
        ring = F.pad(ring, (0, 0, 1, 1), mode="replicate")
    else:
        ring = F.pad(validity, (1, 1, 1, 1), mode="replicate")
    stencil_valid = (F.conv2d(ring, box) >= 8.999).to(dtype=f32)

    grad_l1 = (pred_gx - gt_gx).abs() + (pred_gy - gt_gy).abs()
    grad_l1 = grad_l1.mean(dim=1, keepdim=True)

    final_weight = edge_weight * stencil_valid
    numerator = (grad_l1 * final_weight).sum().to(dtype=out_dtype)
    denominator = final_weight.sum().clamp(min=1.0).to(dtype=out_dtype)
    return numerator / denominator
|
| 797 |
+
|
| 798 |
+
def forward(
    self,
    pred_rgb_linear: torch.Tensor,
    pred_alpha: torch.Tensor,
    pred_depth_m: torch.Tensor,
    gt_rgb_u8: torch.Tensor,
    gt_depth_m: torch.Tensor,
    pred_depth2_m: torch.Tensor | None = None,
    mask: torch.Tensor | None = None,
    depth_mask: torch.Tensor | None = None,
    delta_xy: torch.Tensor | None = None,
    delta_rho: torch.Tensor | None = None,
    delta_grid: torch.Tensor | None = None,
    gaussian_scales: torch.Tensor | None = None,
    gaussian_quaternions: torch.Tensor | None = None,
    gaussian_mean_vectors: torch.Tensor | None = None,
    gaussian_base_mean_vectors: torch.Tensor | None = None,
    gaussian_angular_cell: torch.Tensor | None = None,
    gaussian_opacities: torch.Tensor | None = None,
    gauss_grid_shape: tuple[int, int, int] | None = None,
    projected_scale_factor: float | torch.Tensor | None = None,
    projection_model: str | None = None,
    projection_intrinsics: torch.Tensor | None = None,
    projection_camera_params: torch.Tensor | None = None,
    apply_color: bool = True,
    apply_alpha: bool = True,
    apply_depth: bool = True,
    apply_percep: bool = False,
    apply_tv: bool = True,
    apply_grad: bool = True,
    apply_delta: bool = True,
    apply_splat: bool = True,
    apply_grad_img: bool = True,
    grad_img_circular_h: bool | None = None,
) -> dict[str, torch.Tensor]:
    """Compute every individual loss term plus their weighted total.

    Each term is evaluated only when its ``apply_*`` switch is on and the
    matching ``self.w.lambda_*`` weight is positive; otherwise the term is a
    scalar zero on the prediction's device. ``edge_rgb`` is gated by
    ``apply_color``, ``delta_rho`` by ``apply_delta``, ``edge_splat`` by
    ``apply_splat``; ``grid`` has no switch and only needs a positive weight
    and a tensor ``delta_grid``.

    ``gaussian_base_mean_vectors`` and ``gaussian_angular_cell`` are accepted
    for interface compatibility but are not read by this method.

    ``grad_img_circular_h`` selects horizontally circular behaviour for the
    tv / grad / edge / grid terms (``None`` means off for those); it is also
    passed through unchanged to ``self._sobel_gradient_loss_erp``.

    Returns:
        Dict with keys, in insertion order: ``color``, ``alpha``, ``depth``,
        ``tv``, ``grad``, ``grad_img``, ``edge_rgb``, ``delta``,
        ``delta_rho``, ``grid``, ``splat``, ``edge_splat``, ``percep_feat``,
        ``percep_gram``, ``percep``, ``total``.

    Raises:
        RuntimeError: if an enabled ``grad``, ``delta``, ``splat`` or
            ``edge_splat`` term is missing the inputs it needs.
    """
    device = pred_rgb_linear.device
    out_dtype = pred_rgb_linear.dtype
    weights = self.w

    def scalar_zero() -> torch.Tensor:
        # Placeholder for a disabled term.
        return torch.zeros((), device=device)

    losses: dict[str, torch.Tensor] = {}
    circular_h = False if grad_img_circular_h is None else bool(grad_img_circular_h)

    # ---- targets and masks -------------------------------------------------
    gt_rgb = gt_rgb_u8.to(device).float() / 255.0
    gt_rgb_linear = _to_linear_rgb(gt_rgb)
    pred_depth_m = self._sanitize_supervision_depth(pred_depth_m.to(device), clamp_max=False)
    if pred_depth2_m is not None:
        pred_depth2_m = self._sanitize_supervision_depth(pred_depth2_m.to(device), clamp_max=False)
    gt_depth_raw = self._sanitize_supervision_depth(gt_depth_m.to(device))
    depth_valid = torch.isfinite(gt_depth_raw) & (gt_depth_raw > 0.0)
    gt_depth = gt_depth_raw.clamp(min=1e-4)

    m = torch.ones_like(pred_alpha) if mask is None else mask.to(device).to(out_dtype)
    depth_dtype = pred_depth_m.dtype
    depth_weight = depth_valid.to(dtype=depth_dtype) * m[:, :1].to(dtype=depth_dtype)
    if depth_mask is not None:
        depth_weight = depth_weight * depth_mask.to(device).to(dtype=depth_dtype)[:, :1]
    pred_rgb_rendered = pred_rgb_linear.clamp(0.0, 1.0)

    # ---- color -------------------------------------------------------------
    if apply_color and weights.lambda_color > 0:
        losses["color"] = _masked_mean((pred_rgb_rendered - gt_rgb_linear).abs(), m)
    else:
        losses["color"] = scalar_zero()

    # ---- alpha (BCE towards 1, plus optional low-alpha tail penalty) ---------
    if apply_alpha and weights.lambda_alpha > 0:
        a = pred_alpha.clamp(1e-6, 1.0 - 1e-6)
        a32 = a.to(dtype=torch.float32)
        # NOTE(review): the source this was reconstructed from had lost its
        # indentation; the autocast-disabled region is assumed to wrap only
        # the BCE call — confirm against the original file.
        with torch.autocast(device_type=a.device.type, enabled=False):
            alpha_bce = F.binary_cross_entropy(
                a32,
                torch.ones_like(a, dtype=torch.float32),
                reduction="none",
            )
        alpha_loss = _masked_mean(alpha_bce, m)
        tail_floor = torch.as_tensor(
            self.alpha_tail_min,
            device=a.device,
            dtype=torch.float32,
        ).clamp(min=0.0, max=1.0)
        tail_gain = torch.as_tensor(
            max(0.0, self.alpha_tail_weight),
            device=a.device,
            dtype=torch.float32,
        )
        if self.alpha_tail_min > 0.0 and self.alpha_tail_weight > 0.0:
            # Relative shortfall below the floor, averaged over masked pixels that fall short.
            tail = F.relu(tail_floor - a32) / tail_floor.clamp(min=1e-6)
            tail_mask = m[:, :1].to(dtype=torch.bool) & (tail > 0.0)
            alpha_loss = alpha_loss + tail_gain * _finite_masked_mean_flat(tail, tail_mask)
        losses["alpha"] = alpha_loss.to(dtype=out_dtype)
    else:
        losses["alpha"] = scalar_zero()

    # ---- depth (L1 in inverse depth) ---------------------------------------
    if apply_depth and weights.lambda_depth > 0:
        inv_pred = 1.0 / pred_depth_m.clamp(min=1e-4)
        inv_gt = torch.zeros_like(inv_pred)
        inv_gt[depth_valid] = 1.0 / gt_depth[depth_valid]
        losses["depth"] = _masked_mean((inv_pred - inv_gt).abs(), depth_weight)
    else:
        losses["depth"] = scalar_zero()

    # ---- total variation on the second depth layer ---------------------------
    if apply_tv and weights.lambda_tv > 0 and (pred_depth2_m is not None):
        inv_second = 1.0 / pred_depth2_m.clamp(min=1e-4)
        losses["tv"] = _tv_l1_circular_h(inv_second) if circular_h else _tv_l1(inv_second)
    else:
        losses["tv"] = scalar_zero()

    # ---- project Gaussian centres into the image (shared by later terms) -----
    image_h, image_w = int(pred_depth_m.shape[-2]), int(pred_depth_m.shape[-1])
    projection_points = self._flatten_gaussian_xyz(gaussian_mean_vectors, gauss_grid_shape)
    projected_u = projected_v = None
    projected_valid = None
    if projection_points is not None:
        projected_u, projected_v, projected_valid, _projected_depth = self._project_points_px(
            projection_points,
            projection_model=projection_model,
            image_h=image_h,
            image_w=image_w,
            intrinsics=projection_intrinsics,
            camera_params=projection_camera_params,
        )

    # ---- grad (opacity-weighted disparity-gradient penalty per Gaussian) -----
    if apply_grad and weights.lambda_grad > 0:
        inv_first = 1.0 / pred_depth_m.clamp(min=1e-4)
        op_flat = self._flatten_gaussian_scalar(gaussian_opacities, gauss_grid_shape)
        if projected_u is None or projected_v is None or projected_valid is None or op_flat is None:
            raise RuntimeError(
                "L_grad requires gaussian_mean_vectors, gaussian_opacities, "
                "gauss_grid_shape, and projection metadata. The old "
                "pred_alpha image-space fallback is disabled for ray-local training."
            )
        grad_map = self._central_disparity_gradient(inv_first, circular_h=circular_h)
        grad_at_gauss = self._sample_map_at_uv(grad_map, projected_u, projected_v, projected_valid)
        inv_sigma = 1.0 / max(self.grad_sigma, 1e-8)
        penalty = 1.0 - torch.exp(-inv_sigma * F.relu(grad_at_gauss - self.grad_eps))
        weight = projected_valid & torch.isfinite(grad_at_gauss) & torch.isfinite(op_flat)
        mask_at_gauss = self._sample_map_at_uv(m[:, :1], projected_u, projected_v, projected_valid)
        weight = weight & (mask_at_gauss > 0.5)
        grad_value = op_flat.to(dtype=penalty.dtype).clamp(0, 1) * penalty
        losses["grad"] = _finite_masked_mean_flat(grad_value, weight)
    else:
        losses["grad"] = scalar_zero()

    # ---- grad_img ------------------------------------------------------------
    if apply_grad_img and weights.lambda_grad_img > 0:
        losses["grad_img"] = self._sobel_gradient_loss_erp(
            pred_depth_m=pred_depth_m,
            gt_depth_m=gt_depth,
            depth_weight=depth_weight,
            circular_h=grad_img_circular_h,
        )
    else:
        losses["grad_img"] = scalar_zero()

    # ---- edge_rgb ------------------------------------------------------------
    if apply_color and weights.lambda_edge_rgb > 0:
        depth_edge_for_rgb = self._depth_edge_band(gt_depth, depth_weight, circular_h=circular_h)
        losses["edge_rgb"] = self._edge_rgb_gradient_loss(
            pred_rgb_linear=pred_rgb_rendered,
            gt_rgb_linear=gt_rgb_linear,
            valid_weight=m,
            depth_edge_band=depth_edge_for_rgb,
            circular_h=circular_h,
        )
    else:
        losses["edge_rgb"] = scalar_zero()

    # ---- delta (hinge on raw xy offsets beyond the clip) -----------------------
    if apply_delta and weights.lambda_delta > 0:
        if delta_xy is None:
            raise RuntimeError(
                "L_delta requires raw delta_xy in ray-local training. The old "
                "screen-space pixel displacement fallback is disabled."
            )
        excess_x = F.relu(delta_xy[:, 0:1].abs() - self.raw_delta_clip)
        excess_y = F.relu(delta_xy[:, 1:2].abs() - self.raw_delta_clip)
        losses["delta"] = (excess_x + excess_y).mean()
    else:
        losses["delta"] = scalar_zero()

    # ---- delta_rho (same hinge, non-finite entries excluded from the mean) ------
    if apply_delta and weights.lambda_delta_rho > 0 and delta_rho is not None:
        dz = delta_rho.to(device=device, dtype=out_dtype)
        finite = torch.isfinite(dz)
        dz_safe = torch.nan_to_num(dz, nan=0.0, posinf=0.0, neginf=0.0)
        rho_penalty = F.relu(dz_safe.abs() - self.raw_delta_rho_clip)
        rho_penalty = torch.where(finite, rho_penalty, torch.zeros_like(rho_penalty))
        n_finite = finite.to(dtype=rho_penalty.dtype).sum().clamp(min=1.0)
        losses["delta_rho"] = rho_penalty.sum() / n_finite
    else:
        losses["delta_rho"] = scalar_zero()

    # ---- grid ----------------------------------------------------------------
    if weights.lambda_grid > 0 and torch.is_tensor(delta_grid):
        losses["grid"] = _delta_grid_checkerboard_loss(
            delta_grid.to(device=device),
            circular_h=circular_h,
        ).to(dtype=out_dtype)
    else:
        losses["grid"] = scalar_zero()

    # ---- splat / edge_splat ----------------------------------------------------
    def projected_sigma():
        # Shared by both splat terms; evaluated at most once per forward pass.
        return self._projected_sigma_px(
            gaussian_scales=gaussian_scales,
            gaussian_quaternions=gaussian_quaternions,
            gaussian_mean_vectors=gaussian_mean_vectors,
            valid=projected_valid,
            projection_model=projection_model,
            image_h=image_h,
            image_w=image_w,
            intrinsics=projection_intrinsics,
            camera_params=projection_camera_params,
            projected_scale_factor=projected_scale_factor,
        )

    sigma_proj = None
    valid_splat = None
    if apply_splat and weights.lambda_splat > 0:
        if gaussian_scales is None:
            raise RuntimeError("L_splat requires gaussian_scales for projected screen-space variance.")
        if gaussian_mean_vectors is None or projected_valid is None:
            raise RuntimeError(
                "L_splat requires gaussian_mean_vectors and projection metadata "
                "to compute projected screen-space variance."
            )
        sigma_proj = projected_sigma()
        valid_splat = projected_valid & torch.isfinite(sigma_proj)
        sigma_lo = torch.as_tensor(
            self.splat_sigma_min,
            device=sigma_proj.device,
            dtype=sigma_proj.dtype,
        )
        sigma_hi = torch.as_tensor(
            self.splat_sigma_max,
            device=sigma_proj.device,
            dtype=sigma_proj.dtype,
        )
        # Two-sided hinge: penalise footprints outside [sigma_lo, sigma_hi].
        splat_penalty = F.relu(sigma_lo - sigma_proj) + F.relu(sigma_proj - sigma_hi)
        losses["splat"] = _finite_masked_mean_flat(splat_penalty, valid_splat)
    else:
        losses["splat"] = scalar_zero()

    if apply_splat and weights.lambda_edge_splat > 0:
        if gaussian_scales is None:
            raise RuntimeError("L_edge_splat requires gaussian_scales for projected screen-space variance.")
        if gaussian_mean_vectors is None or projected_valid is None:
            raise RuntimeError(
                "L_edge_splat requires gaussian_mean_vectors and projection metadata "
                "to sample source depth-edge bands."
            )
        if sigma_proj is None or valid_splat is None:
            sigma_proj = projected_sigma()
            valid_splat = projected_valid & torch.isfinite(sigma_proj)
        edge_band = self._depth_edge_band(gt_depth, depth_weight, circular_h=circular_h)
        edge_at_gauss = self._sample_map_at_uv(edge_band, projected_u, projected_v, projected_valid)
        edge_valid = valid_splat & torch.isfinite(edge_at_gauss) & (edge_at_gauss > 0.5)
        edge_sigma_max = torch.as_tensor(
            self.edge_splat_sigma_max,
            device=sigma_proj.device,
            dtype=sigma_proj.dtype,
        )
        losses["edge_splat"] = _finite_masked_mean_flat(F.relu(sigma_proj - edge_sigma_max), edge_valid)
    else:
        losses["edge_splat"] = scalar_zero()

    # ---- perceptual (feature + Gram terms averaged over layers) -----------------
    zero = scalar_zero()
    losses["percep_feat"] = zero
    losses["percep_gram"] = zero
    if apply_percep and weights.lambda_percep > 0 and (self._percep_net is not None):
        from unisharp.utils.color_space import linearRGB2sRGB

        pred_srgb = linearRGB2sRGB(pred_rgb_rendered.to(torch.float32)).clamp(0, 1)
        gt_srgb = gt_rgb.clamp(0, 1)

        pred_srgb = _resize_max_side(pred_srgb, self.percep_max_side, mode="bilinear")
        gt_srgb = _resize_max_side(gt_srgb, self.percep_max_side, mode="bilinear")

        feats_p = self._percep_net(pred_srgb)
        feats_g = self._percep_net(gt_srgb)
        feat_sum = scalar_zero()
        gram_sum = scalar_zero()
        for feat_pred, feat_gt in zip(feats_p, feats_g):
            n_ch, feat_h, feat_w = feat_pred.shape[1], feat_pred.shape[2], feat_pred.shape[3]
            lam_gram = 10.0 / float(max(1, n_ch * n_ch))
            lam_feat = 1.0 / float(max(1, n_ch * feat_h * feat_w))
            sq_diff = (feat_pred - feat_gt).pow(2)
            layer_feat = (sq_diff.sum(dim=[1, 2, 3]) * lam_feat).mean()
            gram_norm = float(max(1, feat_h * feat_w))
            gram_pred = _gram_matrix(feat_pred) / gram_norm
            gram_gt = _gram_matrix(feat_gt) / gram_norm
            layer_gram = ((gram_pred - gram_gt).pow(2).sum(dim=[1, 2]) * lam_gram).mean()
            feat_sum = feat_sum + layer_feat
            gram_sum = gram_sum + layer_gram
        layer_count = float(max(1, len(feats_p)))
        losses["percep_feat"] = feat_sum / layer_count
        losses["percep_gram"] = gram_sum / layer_count
        losses["percep"] = losses["percep_feat"] + losses["percep_gram"]
    else:
        losses["percep"] = scalar_zero()

    # ---- weighted total (terms added left to right in this fixed order) ----------
    weighted_terms = (
        (weights.lambda_color, "color"),
        (weights.lambda_alpha, "alpha"),
        (weights.lambda_percep, "percep"),
        (weights.lambda_depth, "depth"),
        (weights.lambda_tv, "tv"),
        (weights.lambda_grad, "grad"),
        (weights.lambda_grad_img, "grad_img"),
        (weights.lambda_edge_rgb, "edge_rgb"),
        (weights.lambda_delta, "delta"),
        (weights.lambda_delta_rho, "delta_rho"),
        (weights.lambda_splat, "splat"),
        (weights.lambda_edge_splat, "edge_splat"),
        (weights.lambda_grid, "grid"),
    )
    (first_lambda, first_key), *remaining = weighted_terms
    total = first_lambda * losses[first_key]
    for lam, key in remaining:
        total = total + lam * losses[key]
    losses["total"] = total
    return losses
|
| 1120 |
+
|
unisharp/models/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from .feature_gaussian_decoder import (
|
| 5 |
+
FeatureGaussianDecoder,
|
| 6 |
+
FeatureGaussianDecoderParams,
|
| 7 |
+
ImageFeatures,
|
| 8 |
+
create_feature_gaussian_decoder,
|
| 9 |
+
)
|
| 10 |
+
from .unisharp_params import PanoPredictorParams
|
| 11 |
+
from .unisharp_feature import UnisharpFeatureConfig, UnisharpFeatureModel
|
| 12 |
+
from .unik3d_feature_extractor import UniK3DFeatureExtractor
|
| 13 |
+
|
| 14 |
+
# Names exported by ``from <this package> import *``; each one is imported
# from a sibling module earlier in this file.
__all__ = [
"PanoPredictorParams",
"UniK3DFeatureExtractor",
"FeatureGaussianDecoder",
"FeatureGaussianDecoderParams",
"ImageFeatures",
"create_feature_gaussian_decoder",
"UnisharpFeatureConfig",
"UnisharpFeatureModel",
]
|
unisharp/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (709 Bytes). View file
|
|
|
unisharp/models/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (734 Bytes). View file
|
|
|
unisharp/models/__pycache__/blocks.cpython-310.pyc
ADDED
|
Binary file (6.21 kB). View file
|
|
|
unisharp/models/__pycache__/blocks.cpython-313.pyc
ADDED
|
Binary file (9.24 kB). View file
|
|
|
unisharp/models/__pycache__/decoder.cpython-310.pyc
ADDED
|
Binary file (3.11 kB). View file
|
|
|
unisharp/models/__pycache__/decoder.cpython-313.pyc
ADDED
|
Binary file (5.09 kB). View file
|
|
|