diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..6a2e9051cbd9c691c50f08e0bdc2a2849cc21148 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+FastVGGT/assets/attn_map.png filter=lfs diff=lfs merge=lfs -text
+FastVGGT/assets/autolab_logo.png filter=lfs diff=lfs merge=lfs -text
+FastVGGT/assets/main.png filter=lfs diff=lfs merge=lfs -text
+FastVGGT/assets/vs.png filter=lfs diff=lfs merge=lfs -text
diff --git a/FastVGGT/.gitignore b/FastVGGT/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..fc1e52d1adf9afe09ea2ead132f8fa4babf666cc
--- /dev/null
+++ b/FastVGGT/.gitignore
@@ -0,0 +1,160 @@
+.hydra/
+output/
+ckpt/
+.vscode/
+dependency/
+# Byte-compiled / optimized / DLL files
+__pycache__/
+**/__pycache__/
+*.py[cod]
+*$py.class
+test_logs/
+quick_start_logs/
+logs/
+*.pth
+/data/
+*.png
+eval_results/
+.vscode/
+.cursor/
+
+# C extensions
+*.so
+LightGlue/
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Profiling data
+.prof
+
+# Folder specific to your needs
+**/tmp/
+**/outputs/skyseg.onnx
+skyseg.onnx
+
+# pixi environments
+.pixi
+*.egg-info
diff --git a/FastVGGT/.vscode/launch.json b/FastVGGT/.vscode/launch.json
new file mode 100644
index 0000000000000000000000000000000000000000..218c5dcf94c00a2c11d02a0a1deeab8f37807ebf
--- /dev/null
+++ b/FastVGGT/.vscode/launch.json
@@ -0,0 +1,85 @@
+{
+ // Use IntelliSense to learn about possible attributes.
+ // Hover to view descriptions of existing attributes.
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+ "version": "0.2.0",
+ "configurations": [
+
+ {
+ "name": "launch",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "/home/sy/code/vggt_0625/training/launch.py",
+ "console": "integratedTerminal",
+ "args": "${command:pickArgs}",
+ "env": {
+ "CUDA_VISIBLE_DEVICES": "3",
+ },
+ "cwd": "/home/sy/code/vggt_0625/training",
+ "justMyCode": true,
+ "python": "/home/sy/anaconda3/envs/vggt/bin/python"
+ }
+ ,{
+ "name": "train_scannet",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "/home/sy/code/vggt_0625/training/launch_scannet.py",
+ "console": "integratedTerminal",
+ "args": [
+ // "--config_name", "scannet",
+ // "--exp_name", "scannet_exp001",
+ // "--resume_checkpoint_path", "/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt"
+ ],
+ "env": {
+ "CUDA_VISIBLE_DEVICES": "7",
+ "WORLD_SIZE": "1",
+ "RANK": "0",
+ "MASTER_ADDR": "localhost",
+ "MASTER_PORT": "12345"
+ },
+ "cwd": "/home/sy/code/vggt_0625/training",
+ "justMyCode": true,
+ "python": "/home/sy/anaconda3/envs/vggt/bin/python"
+ }
+ ,{
+ "name": "eval_scannet",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "/home/sy/code/FastVGGT/eval/eval_scannet.py",
+ "console": "integratedTerminal",
+ "args": [
+ "--data_dir","/data/sy/scannetv2/process_scannet",
+ "--gt_ply_dir","/data/sy/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
+ "--output_path", "/home/sy/code/FastVGGT/eval_results",
+ "--merging", "0",
+ "--ckpt_path","/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt",
+ "--vis_attn_map"
+ ],
+ "env": {
+ "CUDA_VISIBLE_DEVICES": "2"
+ },
+ "justMyCode": true,
+ "python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
+ },
+ {
+ "name": "eval_cd",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "/home/sy/code/FastVGGT/eval/eval_custom.py",
+ "console": "integratedTerminal",
+ "args": [
+ "--merging", "0",
+ // "--kf","10",
+ // "--output_dir","/home/sy/code/vggt_0625/eval_results_cd",
+ "--data_path","/data/sy/segment-102751/",
+ "--vis_attn_map"
+ ],
+ "env": {
+ "CUDA_VISIBLE_DEVICES": "3"
+ },
+ "justMyCode": true,
+ // "python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
+ }
+
+ ]
+}
\ No newline at end of file
diff --git a/FastVGGT/README.md b/FastVGGT/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..274ea91f0e398fdf0e94b921fb5bb7c48fe799fd
--- /dev/null
+++ b/FastVGGT/README.md
@@ -0,0 +1,163 @@
+
+
+⚡️ FastVGGT: Training-Free Acceleration of Visual Geometry Transformer
+
+
+
+
+
+
+
+
+
+
+
+
+**[Media Analytics & Computing Laboratory](https://mac.xmu.edu.cn/)**; **[AUTOLAB](https://zhipengzhang.cn/)**
+
+
+[You Shen](https://mystorm16.github.io/), [Zhipeng Zhang](https://zhipengzhang.cn/), [Yansong Qu](https://quyans.github.io/), [Liujuan Cao](https://mac.xmu.edu.cn/ljcao/)
+
+
+
+## 📰 News
+- [Sep 8, 2025] Added custom dataset evaluation.
+- [Sep 3, 2025] Paper release.
+- [Sep 2, 2025] Code release.
+
+## 🔭 Overview
+
+FastVGGT observes **strong similarity** in attention maps and leverages it to design a training-free acceleration method for long-sequence 3D reconstruction, **achieving up to 4× faster inference without sacrificing accuracy.**
+
+
+
+
+## ⚙️ Environment Setup
+First, create a virtual environment using Conda, clone this repository to your local machine, and install the required dependencies.
+
+
+```bash
+conda create -n fastvggt python=3.10
+conda activate fastvggt
+git clone git@github.com:mystorm16/FastVGGT.git
+cd FastVGGT
+pip install -r requirements.txt
+```
+
+Next, prepare the ScanNet dataset: http://www.scan-net.org/ScanNet/
+
+Then, download the VGGT checkpoint (we use the checkpoint link provided in https://github.com/facebookresearch/vggt/tree/evaluation/evaluation):
+```bash
+wget https://huggingface.co/facebook/VGGT_tracker_fixed/resolve/main/model_tracker_fixed_e20.pt
+```
+
+Finally, configure the dataset path and VGGT checkpoint path. For example:
+```bash
+ parser.add_argument(
+ "--data_dir", type=Path, default="/data/scannetv2/process_scannet"
+ )
+ parser.add_argument(
+ "--gt_ply_dir",
+ type=Path,
+ default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
+ )
+ parser.add_argument(
+ "--ckpt_path",
+ type=str,
+ default="./ckpt/model_tracker_fixed_e20.pt",
+ )
+```
+
+
+## 💎 Observation
+
+Note: A large number of input_frames may significantly slow down saving the visualization results. Please try using a smaller number first.
+```bash
+python eval/eval_scannet.py --input_frame 30 --vis_attn_map --merging 0
+```
+
+We observe that many token-level attention maps are highly similar in each block, motivating our optimization of the Global Attention module.
+
+
+
+
+
+## 🏀 Evaluation
+### Custom Dataset
+Please organize the data according to the following directory:
+```
+/
+├── images/
+│ ├── 000000.jpg
+│ ├── 000001.jpg
+│ └── ...
+├── pose/ # Optional: Camera poses
+│ ├── 000000.txt
+│ ├── 000001.txt
+│ └── ...
+└── gt_ply/ # Optional: GT point cloud
+ └── scene_xxx.ply
+```
+- Required: `images/`
+- Additionally required when `--enable_evaluation` is enabled: `pose/` and `gt_ply/`
+
+Inference only:
+
+```bash
+python eval/eval_custom.py \
+ --data_path /path/to/your_dataset \
+ --output_path ./eval_results_custom \
+ --plot
+```
+
+Inference + Evaluation (requires `pose/` and `gt_ply/`):
+
+```bash
+python eval/eval_custom.py \
+ --data_path /path/to/your_dataset \
+ --enable_evaluation \
+ --output_path ./eval_results_custom \
+ --plot
+```
+
+### ScanNet
+Evaluate FastVGGT on the ScanNet dataset with 1,000 input images. The **--merging** parameter specifies the block index at which the merging strategy is applied:
+
+```bash
+python eval/eval_scannet.py --input_frame 1000 --merging 0
+```
+
+Evaluate Baseline VGGT on the ScanNet dataset with 1,000 input images:
+```bash
+python eval/eval_scannet.py --input_frame 1000
+```
+
+
+### 7 Scenes & NRGBD
+Evaluate across two datasets, sampling keyframes every 10 frames:
+```bash
+python eval/eval_7andN.py --kf 10
+```
+
+## 🍺 Acknowledgements
+
+- Thanks to these great repositories: [VGGT](https://github.com/facebookresearch/vggt), [Dust3r](https://github.com/naver/dust3r), [Fast3R](https://github.com/facebookresearch/fast3r), [CUT3R](https://github.com/CUT3R/CUT3R), [MV-DUSt3R+](https://github.com/facebookresearch/mvdust3r), [StreamVGGT](https://github.com/wzzheng/StreamVGGT), [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long), [ToMeSD](https://github.com/dbolya/tomesd) and many other inspiring works in the community.
+
+- Special thanks to [Jianyuan Wang](https://jytime.github.io/) for his valuable discussions and suggestions on this work.
+
+
+
+
+## ⚖️ License
+See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available.
+
+## Citation
+
+If you find this project helpful, please consider citing the following paper:
+```
+@article{shen2025fastvggt,
+ title={FastVGGT: Training-Free Acceleration of Visual Geometry Transformer},
+ author={Shen, You and Zhang, Zhipeng and Qu, Yansong and Cao, Liujuan},
+ journal={arXiv preprint arXiv:2509.02560},
+ year={2025}
+}
+```
\ No newline at end of file
diff --git a/FastVGGT/assets/attn_map.png b/FastVGGT/assets/attn_map.png
new file mode 100644
index 0000000000000000000000000000000000000000..98d6f2be8cdb49e5bccee7f902af292327fd0023
--- /dev/null
+++ b/FastVGGT/assets/attn_map.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8477957f593c203bcf41df91ac3ed0d22329e22250fdab9f8f8674340964242c
+size 2369408
diff --git a/FastVGGT/assets/autolab_logo.png b/FastVGGT/assets/autolab_logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..c6ca9a41f5c05a1b2ae66368461672f46e20bba1
--- /dev/null
+++ b/FastVGGT/assets/autolab_logo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4fcead3160cbf561c4385cc8b938a17a94652e3d849da6497f053d32d1245596
+size 5126003
diff --git a/FastVGGT/assets/maclab_logo.png b/FastVGGT/assets/maclab_logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..b96660decf945d3d786f41b9b42cb92798f9191f
Binary files /dev/null and b/FastVGGT/assets/maclab_logo.png differ
diff --git a/FastVGGT/assets/main.png b/FastVGGT/assets/main.png
new file mode 100644
index 0000000000000000000000000000000000000000..3bcab2d519689dfba3283e606cd55cb84175ede5
--- /dev/null
+++ b/FastVGGT/assets/main.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eecacb414647f01dc8a52b4aba5ff2556733f46d1b9129613e3f59aceff69685
+size 883717
diff --git a/FastVGGT/assets/vs.png b/FastVGGT/assets/vs.png
new file mode 100644
index 0000000000000000000000000000000000000000..3729d9ba86000b81d0e9c878b2db6f16c7f9c48a
--- /dev/null
+++ b/FastVGGT/assets/vs.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dac1ea5397b9f985c6fb86fc5f321313cc5a372d0917b7c1c1c7dd2cea6bec5f
+size 230157
diff --git a/FastVGGT/eval/__pycache__/base.cpython-310.pyc b/FastVGGT/eval/__pycache__/base.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49170288c162684ebafa68e5242efd320f0f35b8
Binary files /dev/null and b/FastVGGT/eval/__pycache__/base.cpython-310.pyc differ
diff --git a/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc b/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e6e0478775fe98dc22180bc43bad8073c43c907
Binary files /dev/null and b/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc differ
diff --git a/FastVGGT/eval/__pycache__/data.cpython-310.pyc b/FastVGGT/eval/__pycache__/data.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d8ec6f8e1828fa87b02a4511490ad858de2093a
Binary files /dev/null and b/FastVGGT/eval/__pycache__/data.cpython-310.pyc differ
diff --git a/FastVGGT/eval/__pycache__/data.cpython-37.pyc b/FastVGGT/eval/__pycache__/data.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d76f2c95748956a04f009e99f39cefd9cc6ec4e0
Binary files /dev/null and b/FastVGGT/eval/__pycache__/data.cpython-37.pyc differ
diff --git a/FastVGGT/eval/__pycache__/utils.cpython-310.pyc b/FastVGGT/eval/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..638db49f10789e53dcf0ff33c9aaa28b814e3e72
Binary files /dev/null and b/FastVGGT/eval/__pycache__/utils.cpython-310.pyc differ
diff --git a/FastVGGT/eval/__pycache__/utils.cpython-37.pyc b/FastVGGT/eval/__pycache__/utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fd73966947cb74c319d66d7bf221d73b5dde0ce0
Binary files /dev/null and b/FastVGGT/eval/__pycache__/utils.cpython-37.pyc differ
diff --git a/FastVGGT/eval/base.py b/FastVGGT/eval/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a716449a71f552ea408df0eb37e854cf4e92da6
--- /dev/null
+++ b/FastVGGT/eval/base.py
@@ -0,0 +1,273 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# base class for implementing datasets
+# --------------------------------------------------------
+import PIL
+import numpy as np
+import torch
+
+from dataset_utils.transforms import ImgNorm
+import dataset_utils.cropping as cropping
+from utils import depthmap_to_absolute_camera_coordinates
+
+
+class BaseStereoViewDataset:
+ """Define all basic options.
+
+ Usage:
+ class MyDataset (BaseStereoViewDataset):
+ def _get_views(self, idx, rng):
+ # overload here
+ views = []
+ views.append(dict(img=, ...))
+ return views
+ """
+
+ def __init__(
+ self,
+ *, # only keyword arguments
+ split=None,
+ resolution=None, # square_size or (width, height) or list of [(width,height), ...]
+ transform=ImgNorm,
+ aug_crop=False,
+ seed=None,
+ ):
+ self.num_views = 2
+ self.split = split
+ self._set_resolutions(resolution)
+
+ self.transform = transform
+ if isinstance(transform, str):
+ transform = eval(transform)
+
+ self.aug_crop = aug_crop
+ self.seed = seed
+
+ def __len__(self):
+ return len(self.scenes)
+
+ def get_stats(self):
+ return f"{len(self)} pairs"
+
+ def __repr__(self):
+ resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
+ return (
+ f"""{type(self).__name__}({self.get_stats()},
+ {self.split=},
+ {self.seed=},
+ resolutions={resolutions_str},
+ {self.transform=})""".replace(
+ "self.", ""
+ )
+ .replace("\n", "")
+ .replace(" ", "")
+ )
+
+ def _get_views(self, idx, resolution, rng):
+ raise NotImplementedError()
+
+ def __getitem__(self, idx):
+ if isinstance(idx, tuple):
+ # the idx is specifying the aspect-ratio
+ idx, ar_idx = idx
+ else:
+ assert len(self._resolutions) == 1
+ ar_idx = 0
+
+ # set-up the rng
+ if self.seed: # reseed for each __getitem__
+ self._rng = np.random.default_rng(seed=self.seed + idx)
+ elif not hasattr(self, "_rng"):
+ seed = torch.initial_seed() # this is different for each dataloader process
+ self._rng = np.random.default_rng(seed=seed)
+
+ # over-loaded code
+ resolution = self._resolutions[
+ ar_idx
+ ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
+ views = self._get_views(idx, resolution, self._rng)
+
+ # check data-types
+ for v, view in enumerate(views):
+ assert (
+ "pts3d" not in view
+ ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
+ view["idx"] = v
+
+ # encode the image
+ width, height = view["img"].size
+ view["true_shape"] = np.int32((height, width))
+ view["img"] = self.transform(view["img"])
+
+ assert "camera_intrinsics" in view
+ if "camera_pose" not in view:
+ view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
+ else:
+ assert np.isfinite(
+ view["camera_pose"]
+ ).all(), f"NaN in camera pose for view {view_name(view)}"
+ assert "pts3d" not in view
+ assert "valid_mask" not in view
+ assert np.isfinite(
+ view["depthmap"]
+ ).all(), f"NaN in depthmap for view {view_name(view)}"
+ pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
+
+ view["pts3d"] = pts3d
+ view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
+
+ # check all datatypes
+ for key, val in view.items():
+ res, err_msg = is_good_type(key, val)
+ assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
+ K = view["camera_intrinsics"]
+ view["img_mask"] = True
+ view["ray_mask"] = False
+ view["ray_map"] = torch.full(
+ (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
+ )
+ view["update"] = True
+ view["reset"] = False
+
+ # last thing done!
+ for view in views:
+ # transpose to make sure all views are the same size
+ transpose_to_landscape(view)
+            # this allows checking whether the RNG is in the same state each time
+ view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
+ return views
+
+ def _set_resolutions(self, resolutions):
+ """Set the resolution(s) of the dataset.
+ Params:
+ - resolutions: int or tuple or list of tuples
+ """
+ assert resolutions is not None, "undefined resolution"
+
+ if not isinstance(resolutions, list):
+ resolutions = [resolutions]
+
+ self._resolutions = []
+ for resolution in resolutions:
+ if isinstance(resolution, int):
+ width = height = resolution
+ else:
+ width, height = resolution
+ assert isinstance(
+ width, int
+ ), f"Bad type for {width=} {type(width)=}, should be int"
+ assert isinstance(
+ height, int
+ ), f"Bad type for {height=} {type(height)=}, should be int"
+ assert width >= height
+ self._resolutions.append((width, height))
+
+ def _crop_resize_if_necessary(
+ self, image, depthmap, intrinsics, resolution, rng=None, info=None
+ ):
+ """This function:
+ - first downsizes the image with LANCZOS inteprolation,
+ which is better than bilinear interpolation in
+ """
+ if not isinstance(image, PIL.Image.Image):
+ image = PIL.Image.fromarray(image)
+
+ # downscale with lanczos interpolation so that image.size == resolution
+ # cropping centered on the principal point
+ W, H = image.size
+ cx, cy = intrinsics[:2, 2].round().astype(int)
+
+ # calculate min distance to margin
+ min_margin_x = min(cx, W - cx)
+ min_margin_y = min(cy, H - cy)
+ assert min_margin_x > W / 5, f"Bad principal point in view={info}"
+ assert min_margin_y > H / 5, f"Bad principal point in view={info}"
+
+ ## Center crop
+ # Crop on the principal point, make it always centered
+ # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
+ l, t = cx - min_margin_x, cy - min_margin_y
+ r, b = cx + min_margin_x, cy + min_margin_y
+ crop_bbox = (l, t, r, b)
+
+ image, depthmap, intrinsics = cropping.crop_image_depthmap(
+ image, depthmap, intrinsics, crop_bbox
+ )
+
+ # # transpose the resolution if necessary
+ W, H = image.size # new size
+ assert resolution[0] >= resolution[1]
+ if H > 1.1 * W:
+ # image is portrait mode
+ resolution = resolution[::-1]
+ elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
+ # image is square, so we chose (portrait, landscape) randomly
+ if rng.integers(2):
+ resolution = resolution[::-1]
+
+ # high-quality Lanczos down-scaling
+ target_resolution = np.array(resolution)
+ # # if self.aug_crop > 1:
+ # # target_resolution += rng.integers(0, self.aug_crop)
+ # if resolution != (224, 224):
+ # halfw, halfh = ((2*(W//2))//16)*8, ((2*(H//2))//16)*8
+        #     ## Rescale with max factor, so one of width or height might be larger than target_resolution
+ # image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, (2*halfw, 2*halfh))
+ # else:
+ image, depthmap, intrinsics = cropping.rescale_image_depthmap(
+ image, depthmap, intrinsics, target_resolution
+ )
+ # actual cropping (if necessary) with bilinear interpolation
+ # if resolution == (224, 224):
+ intrinsics2 = cropping.camera_matrix_of_crop(
+ intrinsics, image.size, resolution, offset_factor=0.5
+ )
+ crop_bbox = cropping.bbox_from_intrinsics_in_out(
+ intrinsics, intrinsics2, resolution
+ )
+ image, depthmap, intrinsics = cropping.crop_image_depthmap(
+ image, depthmap, intrinsics, crop_bbox
+ )
+ return image, depthmap, intrinsics
+
+
+def is_good_type(key, v):
+ """returns (is_good, err_msg)"""
+ if isinstance(v, (str, int, tuple)):
+ return True, None
+ if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
+ return False, f"bad {v.dtype=}"
+ return True, None
+
+
+def view_name(view, batch_index=None):
+ def sel(x):
+ return x[batch_index] if batch_index not in (None, slice(None)) else x
+
+ db = sel(view["dataset"])
+ label = sel(view["label"])
+ instance = sel(view["instance"])
+ return f"{db}/{label}/{instance}"
+
+
+def transpose_to_landscape(view):
+ height, width = view["true_shape"]
+
+ if width < height:
+ # rectify portrait to landscape
+ assert view["img"].shape == (3, height, width)
+ view["img"] = view["img"].swapaxes(1, 2)
+
+ assert view["valid_mask"].shape == (height, width)
+ view["valid_mask"] = view["valid_mask"].swapaxes(0, 1)
+
+ assert view["depthmap"].shape == (height, width)
+ view["depthmap"] = view["depthmap"].swapaxes(0, 1)
+
+ assert view["pts3d"].shape == (height, width, 3)
+ view["pts3d"] = view["pts3d"].swapaxes(0, 1)
+
+ # transpose x and y pixels
+ view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]
diff --git a/FastVGGT/eval/criterion.py b/FastVGGT/eval/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..a63c991b5fd5b61325b56055d7d380a27a336971
--- /dev/null
+++ b/FastVGGT/eval/criterion.py
@@ -0,0 +1,534 @@
+import torch
+import torch.nn as nn
+from copy import copy, deepcopy
+
+from eval.dataset_utils.corr import geotrf, inv
+
+
+def invalid_to_nans(arr, valid_mask, ndim=999):
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = float("nan")
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr
+
+
+def invalid_to_zeros(arr, valid_mask, ndim=999):
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = 0
+ nnz = valid_mask.view(len(valid_mask), -1).sum(1)
+ else:
+ nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr, nnz
+
+
+class BaseCriterion(nn.Module):
+ def __init__(self, reduction="mean"):
+ super().__init__()
+ self.reduction = reduction
+
+
+class Criterion(nn.Module):
+ def __init__(self, criterion=None):
+ super().__init__()
+ assert isinstance(
+ criterion, BaseCriterion
+ ), f"{criterion} is not a proper criterion!"
+ self.criterion = copy(criterion)
+
+ def get_name(self):
+ return f"{type(self).__name__}({self.criterion})"
+
+ def with_reduction(self, mode="none"):
+ res = loss = deepcopy(self)
+ while loss is not None:
+ assert isinstance(loss, Criterion)
+ loss.criterion.reduction = mode # make it return the loss for each sample
+ loss = loss._loss2 # we assume loss is a Multiloss
+ return res
+
+
+class MultiLoss(nn.Module):
+ """Easily combinable losses (also keep track of individual loss values):
+ loss = MyLoss1() + 0.1*MyLoss2()
+ Usage:
+ Inherit from this class and override get_name() and compute_loss()
+ """
+
+ def __init__(self):
+ super().__init__()
+ self._alpha = 1
+ self._loss2 = None
+
+ def compute_loss(self, *args, **kwargs):
+ raise NotImplementedError()
+
+ def get_name(self):
+ raise NotImplementedError()
+
+ def __mul__(self, alpha):
+ assert isinstance(alpha, (int, float))
+ res = copy(self)
+ res._alpha = alpha
+ return res
+
+ __rmul__ = __mul__ # same
+
+ def __add__(self, loss2):
+ assert isinstance(loss2, MultiLoss)
+ res = cur = copy(self)
+
+ while cur._loss2 is not None:
+ cur = cur._loss2
+ cur._loss2 = loss2
+ return res
+
+ def __repr__(self):
+ name = self.get_name()
+ if self._alpha != 1:
+ name = f"{self._alpha:g}*{name}"
+ if self._loss2:
+ name = f"{name} + {self._loss2}"
+ return name
+
+ def forward(self, *args, **kwargs):
+ loss = self.compute_loss(*args, **kwargs)
+ if isinstance(loss, tuple):
+ loss, details = loss
+ elif loss.ndim == 0:
+ details = {self.get_name(): float(loss)}
+ else:
+ details = {}
+ loss = loss * self._alpha
+
+ if self._loss2:
+ loss2, details2 = self._loss2(*args, **kwargs)
+ loss = loss + loss2
+ details |= details2
+
+ return loss, details
+
+
+class LLoss(BaseCriterion):
+ """L-norm loss"""
+
+ def forward(self, a, b):
+ assert (
+ a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
+ ), f"Bad shape = {a.shape}"
+ dist = self.distance(a, b)
+
+ if self.reduction == "none":
+ return dist
+ if self.reduction == "sum":
+ return dist.sum()
+ if self.reduction == "mean":
+ return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
+ raise ValueError(f"bad {self.reduction=} mode")
+
+ def distance(self, a, b):
+ raise NotImplementedError()
+
+
+class L21Loss(LLoss):
+ """Euclidean distance between 3d points"""
+
+ def distance(self, a, b):
+ return torch.norm(a - b, dim=-1) # normalized L2 distance
+
+
+L21 = L21Loss()
+
+
+def get_pred_pts3d(gt, pred, use_pose=False):
+ assert use_pose is True
+ return pred["pts3d_in_other_view"] # return!
+
+
+def Sum(losses, masks, conf=None):
+ loss, mask = losses[0], masks[0]
+ if loss.ndim > 0:
+ # we are actually returning the loss for every pixels
+ if conf is not None:
+ return losses, masks, conf
+ return losses, masks
+ else:
+ # we are returning the global loss
+ for loss2 in losses[1:]:
+ loss = loss + loss2
+ return loss
+
+
+def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
+ assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3
+ assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3)
+ norm_mode, dis_mode = norm_mode.split("_")
+
+ nan_pts = []
+ nnzs = []
+
+ if norm_mode == "avg":
+ # gather all points together (joint normalization)
+
+ for i, pt in enumerate(pts):
+ nan_pt, nnz = invalid_to_zeros(pt, valids[i], ndim=3)
+ nan_pts.append(nan_pt)
+ nnzs.append(nnz)
+
+ if fix_first:
+ break
+ all_pts = torch.cat(nan_pts, dim=1)
+
+ # compute distance to origin
+ all_dis = all_pts.norm(dim=-1)
+ if dis_mode == "dis":
+ pass # do nothing
+ elif dis_mode == "log1p":
+ all_dis = torch.log1p(all_dis)
+ else:
+ raise ValueError(f"bad {dis_mode=}")
+
+ norm_factor = all_dis.sum(dim=1) / (torch.cat(nnzs).sum() + 1e-8)
+ else:
+ raise ValueError(f"Not implemented {norm_mode=}")
+
+ norm_factor = norm_factor.clip(min=1e-8)
+ while norm_factor.ndim < pts[0].ndim:
+ norm_factor.unsqueeze_(-1)
+
+ return norm_factor
+
+
+def normalize_pointcloud_t(
+ pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False
+):
+ if gt:
+ norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
+ res = []
+
+ for i, pt in enumerate(pts):
+ res.append(pt / norm_factor)
+
+ else:
+ # pts_l, pts_r = pts
+ # use pts_l and pts_r[-1] as pts to normalize
+ norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
+
+ res = []
+
+ for i in range(len(pts)):
+ res.append(pts[i] / norm_factor)
+ # res_r.append(pts_r[i] / norm_factor)
+
+ # res = [res_l, res_r]
+
+ return res, norm_factor
+
+
+@torch.no_grad()
+def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
+ # set invalid points to NaN
+ _zs = []
+ for i in range(len(zs)):
+ valid_mask = valid_masks[i] if valid_masks is not None else None
+ _z = invalid_to_nans(zs[i], valid_mask).reshape(len(zs[i]), -1)
+ _zs.append(_z)
+
+ _zs = torch.cat(_zs, dim=-1)
+
+ # compute median depth overall (ignoring nans)
+ if quantile == 0.5:
+ shift_z = torch.nanmedian(_zs, dim=-1).values
+ else:
+ shift_z = torch.nanquantile(_zs, quantile, dim=-1)
+ return shift_z # (B,)
+
+
+@torch.no_grad()
+def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
+ # set invalid points to NaN
+
+ _pts = []
+ for i in range(len(pts)):
+ valid_mask = valid_masks[i] if valid_masks is not None else None
+ _pt = invalid_to_nans(pts[i], valid_mask).reshape(len(pts[i]), -1, 3)
+ _pts.append(_pt)
+
+ _pts = torch.cat(_pts, dim=1)
+
+ # compute median center
+ _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
+ if z_only:
+ _center[..., :2] = 0 # do not center X and Y
+
+ # compute median norm
+ _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
+ scale = torch.nanmedian(_norm, dim=1).values
+ return _center[:, None, :, :], scale[:, None, None, None]
+
+
+class Regr3D_t(Criterion, MultiLoss):
+ def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True):
+ super().__init__(criterion)
+ self.norm_mode = norm_mode
+ self.gt_scale = gt_scale
+ self.fix_first = fix_first
+
+ def get_all_pts3d_t(self, gts, preds, dist_clip=None):
+ # everything is normalized w.r.t. camera of view1
+ in_camera1 = inv(gts[0]["camera_pose"])
+
+ gt_pts = []
+ valids = []
+ pr_pts = []
+
+ for i, gt in enumerate(gts):
+ # in_camera1: Bs, 4, 4 gt['pts3d']: Bs, H, W, 3
+ gt_pts.append(geotrf(in_camera1, gt["pts3d"]))
+ valid = gt["valid_mask"].clone()
+
+ if dist_clip is not None:
+ # points that are too far-away == invalid
+ dis = gt["pts3d"].norm(dim=-1)
+ valid = valid & (dis <= dist_clip)
+
+ valids.append(valid)
+ pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True))
+ # if i != len(gts)-1:
+ # pr_pts_l.append(get_pred_pts3d(gt, preds[i][0], use_pose=(i!=0)))
+
+ # if i != 0:
+ # pr_pts_r.append(get_pred_pts3d(gt, preds[i-1][1], use_pose=(i!=0)))
+
+ # pr_pts = (pr_pts_l, pr_pts_r)
+
+ if self.norm_mode:
+ pr_pts, pr_factor = normalize_pointcloud_t(
+ pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False
+ )
+ else:
+ pr_factor = None
+
+ if self.norm_mode and not self.gt_scale:
+ gt_pts, gt_factor = normalize_pointcloud_t(
+ gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True
+ )
+ else:
+ gt_factor = None
+
+ return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}
+
    def compute_frame_loss(self, gts, preds, **kw):
        """Per-frame regression loss over left (reference) and right (target)
        predictions, plus a scale-factor penalty.

        Returns (Sum(losses, masks, confs), details_dict, factor_loss).
        """
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            self.get_all_pts3d_t(gts, preds, **kw)
        )

        # NOTE(review): this unpack expects pred_pts to be a (left, right) pair,
        # but get_all_pts3d_t in this class returns a flat per-view list (its
        # pair-building code is commented out) — confirm which is intended.
        pred_pts_l, pred_pts_r = pred_pts

        loss_all = []
        mask_all = []
        conf_all = []

        # scalar running sums used only for logging/monitoring below
        loss_left = 0
        loss_right = 0
        pred_conf_l = 0
        pred_conf_r = 0

        for i in range(len(gt_pts)):

            # Left (Reference): every view except the last has a left prediction
            if i != len(gt_pts) - 1:
                frame_loss = self.criterion(
                    pred_pts_l[i][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i][0]["conf"])

                # To compare target/reference loss
                if i != 0:
                    loss_left += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy().mean()

            # Right (Target): every view except the first has a right prediction
            if i != 0:
                frame_loss = self.criterion(
                    pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]]
                )

                loss_all.append(frame_loss)
                mask_all.append(masks[i])
                conf_all.append(preds[i - 1][1]["conf"])

                # To compare target/reference loss
                if i != len(gt_pts) - 1:
                    loss_right += frame_loss.cpu().detach().numpy().mean()
                    pred_conf_r += preds[i - 1][1]["conf"].cpu().detach().numpy().mean()

        # penalize predicted normalization factors that exceed the GT factor
        if pr_factor is not None and gt_factor is not None:
            filter_factor = pr_factor[pr_factor > gt_factor]
        else:
            filter_factor = []

        if len(filter_factor) > 0:
            factor_loss = (filter_factor - gt_factor).abs().mean()
        else:
            factor_loss = 0.0

        self_name = type(self).__name__
        details = {
            # NOTE(review): the last four keys lack a '_' separator after the
            # class name (e.g. 'Regr3D_tloss_left'); kept as-is for log compatibility.
            self_name + "_pts3d_1": float(loss_all[0].mean()),
            self_name + "_pts3d_2": float(loss_all[1].mean()),
            self_name + "loss_left": float(loss_left),
            self_name + "loss_right": float(loss_right),
            self_name + "conf_left": float(pred_conf_l),
            self_name + "conf_right": float(pred_conf_r),
        }

        return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss
+
+
class ConfLoss_t(MultiLoss):
    """Weighted regression by learned confidence.
    Assuming the input pixel_loss is a pixel-level regression loss.

    Principle:
    high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
    low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)

    alpha: hyperparameter
    """

    def __init__(self, pixel_loss, alpha=1):
        super().__init__()
        assert alpha > 0
        self.alpha = alpha
        # per-pixel losses are required so they can be weighted by confidence
        self.pixel_loss = pixel_loss.with_reduction("none")

    def get_name(self):
        return f"ConfLoss({self.pixel_loss})"

    def get_conf_log(self, x):
        """Return the confidence and its log (used as a regularizer)."""
        return x, torch.log(x)

    def compute_frame_loss(self, gts, preds, **kw):
        """Confidence-weighted loss: sum_i mean(loss_i*conf_i - alpha*log conf_i)."""
        # compute per-pixel loss
        (losses, masks, confs), details, loss_factor = (
            self.pixel_loss.compute_frame_loss(gts, preds, **kw)
        )

        # weight by confidence
        conf_losses = []
        conf_sum = 0
        for i in range(len(losses)):
            conf, log_conf = self.get_conf_log(confs[i][masks[i]])
            # NOTE(review): conf.mean() is NaN when a mask is empty — confirm
            # empty masks are acceptable for the conf_mean log entry.
            conf_sum += conf.mean()
            conf_loss = losses[i] * conf - self.alpha * log_conf
            # BUGFIX: an empty mask previously yielded a python int 0, which
            # breaks torch.stack below; an empty-tensor sum() gives a zero
            # tensor with the right device/dtype instead.
            conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else conf_loss.sum()
            conf_losses.append(conf_loss)

        conf_losses = torch.stack(conf_losses) * 2.0
        conf_loss_mean = conf_losses.mean()

        return (
            conf_loss_mean,
            dict(
                conf_loss_1=float(conf_losses[0]),
                conf_loss2=float(conf_losses[1]),  # key kept as-is for log compatibility
                conf_mean=conf_sum / len(losses),
                **details,
            ),
            loss_factor,
        )
+
+
class Regr3D_t_ShiftInv(Regr3D_t):
    """Same than Regr3D but invariant to depth shift."""

    def get_all_pts3d_t(self, gts, preds):
        """Subtract the joint median depth from GT and predicted points so the
        loss becomes invariant to a global translation along z.
        """
        # compute unnormalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )

        # pred_pts_l, pred_pts_r = pred_pts
        gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts]

        pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts]
        # pred_zs.append(pred_pts_r[-1][..., 2])

        # compute median depth; [:, None, None] makes the per-batch scalar
        # broadcast over the H, W dimensions below
        gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None]
        pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, None]

        # subtract the median depth (in place on the tensors returned above)
        for i in range(len(gt_pts)):
            gt_pts[i][..., 2] -= gt_shift_z

        for i in range(len(pred_pts)):
            # for j in range(len(pred_pts[i])):
            pred_pts[i][..., 2] -= pred_shift_z

        monitoring = dict(
            monitoring,
            gt_shift_z=gt_shift_z.mean().detach(),
            pred_shift_z=pred_shift_z.mean().detach(),
        )
        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
+
+
class Regr3D_t_ScaleInv(Regr3D_t):
    """Same than Regr3D but invariant to depth shift.
    if gt_scale == True: enforce the prediction to take the same scale than GT
    """

    def get_all_pts3d_t(self, gts, preds):
        """Rescale point clouds by the ratio of the joint GT/pred scene scales
        so the loss becomes invariant to a global scale factor.
        """
        # compute depth-normalized points
        gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
            super().get_all_pts3d_t(gts, preds)
        )

        # measure scene scale

        # pred_pts_l, pred_pts_r = pred_pts

        # clone so the scale measurement is not affected by the in-place
        # multiplications below
        pred_pts_all = [
            x.clone() for x in pred_pts
        ]  # [pred_pt for pred_pt in pred_pts_l]
        # pred_pts_all.append(pred_pts_r[-1])

        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks)

        # prevent predictions to be in a ridiculous range
        pred_scale = pred_scale.clip(min=1e-3, max=1e3)

        # NOTE(review): when gt_scale is False, predictions are multiplied by
        # pred_scale/gt_scale while GT is (unconditionally) multiplied by
        # gt_scale/pred_scale — the two clouds are pushed in opposite
        # directions; confirm this asymmetry is intended (upstream variants
        # rescale only one side).
        if self.gt_scale:
            for i in range(len(pred_pts)):
                # for j in range(len(pred_pts[i])):
                pred_pts[i] *= gt_scale / pred_scale

        else:
            for i in range(len(pred_pts)):
                # for j in range(len(pred_pts[i])):
                pred_pts[i] *= pred_scale / gt_scale

        for i in range(len(gt_pts)):
            gt_pts[i] *= gt_scale / pred_scale

        monitoring = dict(
            monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()
        )

        return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
+
+
class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
    # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv
    # (via the MRO: ScaleInv.get_all_pts3d_t's super() call resolves to
    # ShiftInv, so shift normalization runs inside scale normalization)
    pass
diff --git a/FastVGGT/eval/data.py b/FastVGGT/eval/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..301788e3ae6b67091fb031b068ecbecff1eae48d
--- /dev/null
+++ b/FastVGGT/eval/data.py
@@ -0,0 +1,338 @@
+import os
+import cv2
+import numpy as np
+import os.path as osp
+from collections import deque
+from base import BaseStereoViewDataset
+import dataset_utils.cropping as cropping
+from vggt.utils.eval_utils import imread_cv2, shuffle_deque
+
+
class SevenScenes(BaseStereoViewDataset):
    """7-Scenes RGB-D evaluation dataset.

    Reads `frame-XXXXXX.color.png` / `.depth.proj.png` / `.pose.txt` triplets
    from `<ROOT>/<scene>/<seq-XX>/`, either for pre-defined (scene, frames)
    tuples or for whole sequences subsampled with a keyframe stride.
    """

    def __init__(
        self,
        num_seq=1,
        num_frames=5,
        min_thresh=10,
        max_thresh=100,
        test_id=None,
        full_video=False,
        tuple_list=None,
        seq_id=None,
        rebuttal=False,
        shuffle_seed=-1,
        kf_every=1,
        *args,
        ROOT,
        **kwargs,
    ):
        # ROOT must be set before the base class runs (it may use the dataset)
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every
        self.seq_id = seq_id
        self.rebuttal = rebuttal
        self.shuffle_seed = shuffle_seed

        # load all scenes
        self.load_all_tuples(tuple_list)
        self.load_all_scenes(ROOT)

    def __len__(self):
        # one item per tuple, or num_seq samples per scene
        if self.tuple_list is not None:
            return len(self.tuple_list)
        return len(self.scene_list) * self.num_seq

    def load_all_tuples(self, tuple_list):
        """Store the optional list of 'scene frame frame ...' strings."""
        self.tuple_list = tuple_list

    def load_all_scenes(self, base_dir):
        """Build self.scene_list, honoring the test_id/seq_id filters."""
        if self.tuple_list is not None:
            # Use pre-defined simplerecon scene_ids
            self.scene_list = [
                "stairs/seq-06",
                "stairs/seq-02",
                "pumpkin/seq-06",
                "chess/seq-01",
                "heads/seq-02",
                "fire/seq-02",
                "office/seq-03",
                "pumpkin/seq-03",
                "redkitchen/seq-07",
                "chess/seq-02",
                "office/seq-01",
                "redkitchen/seq-01",
                "fire/seq-01",
            ]
            print(f"Found {len(self.scene_list)} sequences in split {self.split}")
            return

        scenes = os.listdir(base_dir)

        # self.split is set by the base class — 'train' or 'test'
        file_split = {"train": "TrainSplit.txt", "test": "TestSplit.txt"}[self.split]

        self.scene_list = []
        for scene in scenes:
            if self.test_id is not None and scene != self.test_id:
                continue
            # read file split
            with open(osp.join(base_dir, scene, file_split)) as f:
                seq_ids = f.read().splitlines()

            for seq_id in seq_ids:
                # entries look like 'sequence6'; keep the digits, zero-pad to 2
                num_part = "".join(filter(str.isdigit, seq_id))
                seq_id = f"seq-{num_part.zfill(2)}"
                if self.seq_id is not None and seq_id != self.seq_id:
                    continue
                self.scene_list.append(f"{scene}/{seq_id}")

        print(f"Found {len(self.scene_list)} sequences in split {self.split}")

    def _get_views(self, idx, resolution, rng):
        """Load one sequence (or one pre-defined tuple) as a list of view dicts."""
        if self.tuple_list is not None:
            line = self.tuple_list[idx].split(" ")
            scene_id = line[0]
            img_idxs = line[1:]

        else:
            scene_id = self.scene_list[idx // self.num_seq]

            data_path = osp.join(self.ROOT, scene_id)
            num_files = len([name for name in os.listdir(data_path) if "color" in name])
            img_idxs = [f"{i:06d}" for i in range(num_files)]
            img_idxs = img_idxs[:: self.kf_every]

        # Intrinsics used in SimpleRecon
        fx, fy, cx, cy = 525, 525, 320, 240
        intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)

        views = []
        imgs_idxs = deque(img_idxs)
        if self.shuffle_seed >= 0:
            imgs_idxs = shuffle_deque(imgs_idxs)

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()
            impath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.color.png")
            depthpath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.depth.proj.png")
            posepath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.pose.txt")

            rgb_image = imread_cv2(impath)

            # the depth map defines the working resolution
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))

            # 65535 marks missing depth in the export
            depthmap[depthmap == 65535] = 0
            # BUGFIX: np.nan_to_num(x, 0.0) bound 0.0 to the `copy` argument,
            # not to `nan`; pass nan= explicitly (resulting values unchanged).
            depthmap = np.nan_to_num(depthmap.astype(np.float32), nan=0.0) / 1000.0

            # clamp implausible depths (meters)
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0

            camera_pose = np.loadtxt(posepath).astype(np.float32)

            if resolution != (224, 224) or self.rebuttal:
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
                )
            else:
                # resize to 512x384 first, then center-crop a 224x224 window
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="7scenes",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )
        return views
+
+
class NRGBD(BaseStereoViewDataset):
    """Neural-RGBD evaluation dataset.

    Reads `images/imgN.png`, `depth/depthN.png` and a single `poses.txt`
    (4 text lines per camera-to-world matrix, OpenGL convention) per scene.
    """

    def __init__(
        self,
        num_seq=1,
        num_frames=5,
        min_thresh=10,
        max_thresh=100,
        test_id=None,
        full_video=False,
        tuple_list=None,
        seq_id=None,
        rebuttal=False,
        shuffle_seed=-1,
        kf_every=1,
        *args,
        ROOT,
        **kwargs,
    ):
        # ROOT must be set before the base class runs
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.num_seq = num_seq
        self.num_frames = num_frames
        self.max_thresh = max_thresh
        self.min_thresh = min_thresh
        self.test_id = test_id
        self.full_video = full_video
        self.kf_every = kf_every
        self.seq_id = seq_id
        self.rebuttal = rebuttal
        self.shuffle_seed = shuffle_seed

        # load all scenes
        self.load_all_tuples(tuple_list)
        self.load_all_scenes(ROOT)

    def __len__(self):
        # one item per tuple, or num_seq samples per scene
        if self.tuple_list is not None:
            return len(self.tuple_list)
        return len(self.scene_list) * self.num_seq

    def load_all_tuples(self, tuple_list):
        """Store the optional list of 'scene frame frame ...' strings."""
        self.tuple_list = tuple_list

    def load_all_scenes(self, base_dir):
        """List scene directories, optionally restricted to test_id."""
        scenes = [
            d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
        ]

        self.scene_list = [self.test_id] if self.test_id is not None else scenes

        print(f"Found {len(self.scene_list)} sequences in split {self.split}")

    def load_poses(self, path):
        """Parse poses.txt (4 lines per 4x4 matrix).

        Returns (poses, valid): invalid entries (marked 'nan' on the first
        line of a matrix) are replaced by identity and flagged False.
        NOTE(review): only the first of the 4 lines is checked for 'nan'.
        """
        with open(path, "r") as file:
            lines = file.readlines()
        poses = []
        valid = []
        lines_per_matrix = 4
        for i in range(0, len(lines), lines_per_matrix):
            if "nan" in lines[i]:
                valid.append(False)
                poses.append(np.eye(4, 4, dtype=np.float32).tolist())
            else:
                valid.append(True)
                pose_floats = [
                    [float(x) for x in line.split()]
                    for line in lines[i : i + lines_per_matrix]
                ]
                poses.append(pose_floats)

        return np.array(poses, dtype=np.float32), valid

    def _get_views(self, idx, resolution, rng):
        """Load one scene (or one pre-defined tuple) as a list of view dicts."""
        if self.tuple_list is not None:
            line = self.tuple_list[idx].split(" ")
            scene_id = line[0]
            img_idxs = line[1:]

        else:
            scene_id = self.scene_list[idx // self.num_seq]

            num_files = len(os.listdir(os.path.join(self.ROOT, scene_id, "images")))
            img_idxs = [f"{i}" for i in range(num_files)]
            # BUGFIX: for very short sequences len(img_idxs) // 2 can be 0,
            # which is an invalid slice step; clamp the stride to >= 1.
            step = max(1, min(self.kf_every, len(img_idxs) // 2))
            img_idxs = img_idxs[::step]

        # intrinsics shared by all NRGBD scenes
        fx, fy, cx, cy = 554.2562584220408, 554.2562584220408, 320, 240
        intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)

        posepath = osp.join(self.ROOT, scene_id, "poses.txt")
        camera_poses, valids = self.load_poses(posepath)

        imgs_idxs = deque(img_idxs)
        if self.shuffle_seed >= 0:
            imgs_idxs = shuffle_deque(imgs_idxs)
        views = []

        while len(imgs_idxs) > 0:
            im_idx = imgs_idxs.popleft()

            impath = osp.join(self.ROOT, scene_id, "images", f"img{im_idx}.png")
            depthpath = osp.join(self.ROOT, scene_id, "depth", f"depth{im_idx}.png")

            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            # BUGFIX: np.nan_to_num(x, 0.0) bound 0.0 to `copy`, not `nan`.
            depthmap = np.nan_to_num(depthmap.astype(np.float32), nan=0.0) / 1000.0
            # clamp implausible depths (meters)
            depthmap[depthmap > 10] = 0
            depthmap[depthmap < 1e-3] = 0

            rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))

            camera_pose = camera_poses[int(im_idx)]
            # gl to cv: flip the y and z axes; mutates the row in camera_poses,
            # which is safe because each frame index is visited at most once
            camera_pose[:, 1:3] *= -1.0
            if resolution != (224, 224) or self.rebuttal:
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
                )
            else:
                # resize to 512x384 first, then center-crop a 224x224 window
                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
                )
                W, H = rgb_image.size
                cx = W // 2
                cy = H // 2
                l, t = cx - 112, cy - 112
                r, b = cx + 112, cy + 112
                crop_bbox = (l, t, r, b)
                rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
                    rgb_image, depthmap, intrinsics, crop_bbox
                )

            views.append(
                dict(
                    img=rgb_image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,
                    camera_intrinsics=intrinsics,
                    dataset="nrgbd",
                    label=osp.join(scene_id, im_idx),
                    instance=impath,
                )
            )

        return views
diff --git a/FastVGGT/eval/dataset_utils/__init__.py b/FastVGGT/eval/dataset_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/FastVGGT/eval/dataset_utils/__init__.py
@@ -0,0 +1 @@
+
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d00b3e8b6f5fd9ef7d1428ff771243e1557ddf9f
Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ebd5d22d44ddce633e9a05a7467e9ae0c3f9032
Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc differ
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbd10218eddf65486b375109c23cf8cecd7c3ba5
Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc differ
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..860c8970fc49d65b2a4101b77cc313a6aeac618a
Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc differ
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0fe7870fb079d9dc08ca647866ff8c77fd0f2ec0
Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc differ
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b4530be10547e3e7bf79e6cef6470880cb710a4
Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc differ
diff --git a/FastVGGT/eval/dataset_utils/corr.py b/FastVGGT/eval/dataset_utils/corr.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbf5de18b7388e38e45c3f313487957811abc483
--- /dev/null
+++ b/FastVGGT/eval/dataset_utils/corr.py
@@ -0,0 +1,234 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+
+import numpy as np
+import torch
+
+
def todevice(batch, device, callback=None, non_blocking=False):
    """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: function applied to the top-level batch before the transfer.
    non_blocking: forwarded to Tensor.to() for asynchronous host->device copies.
    """
    if callback:
        batch = callback(batch)

    if isinstance(batch, dict):
        # BUGFIX: propagate non_blocking so nested tensors also transfer
        # asynchronously (it was silently dropped in the recursion before)
        return {
            k: todevice(v, device, non_blocking=non_blocking) for k, v in batch.items()
        }

    if isinstance(batch, (tuple, list)):
        return type(batch)(todevice(x, device, non_blocking=non_blocking) for x in batch)

    x = batch
    if device == "numpy":
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x


to_device = todevice  # alias


def to_numpy(x):
    """Recursively convert all tensors inside *x* to numpy arrays."""
    return todevice(x, "numpy")
+
+
def geotrf(Trf, pts, ncol=None, norm=False):
    """Apply a geometric transformation to a list of 3-D points.

    H: 3x3 or 4x4 projection matrix (typically a Homography)
    p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the resut is projected on the z=norm plane.

    Returns an array of projected 2d points.
    """
    assert Trf.ndim >= 2
    # coerce pts to the same array library (and dtype) as Trf
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # fast path: batched torch transform of dense (B, H, W, D) point maps
    if (
        isinstance(Trf, torch.Tensor)
        and isinstance(pts, torch.Tensor)
        and Trf.ndim == 3
        and pts.ndim == 4
    ):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            # purely linear part (e.g. 3x3 applied to 3-vectors)
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d + 1:
            # affine: linear part plus translation column
            pts = (
                torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
                + Trf[:, None, None, :d, d]
            )
        else:
            raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
    else:
        # generic path: flatten leading batch dims so Trf is (B, R, C)
        if Trf.ndim >= 3:
            n = Trf.ndim - 2
            assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # points are sequences of points: flatten them into one dim
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # add a fake middle dimension so matmul broadcasts
                pts = pts[:, None, :]

        if pts.shape[-1] + 1 == Trf.shape[-1]:
            # homogeneous transform: multiply by the linear block, add translation
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        # perspective division onto the z=norm plane
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res
+
+
def inv(mat):
    """Invert a torch or numpy matrix."""
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    raise ValueError(f"bad matrix type = {type(mat)}")
+
+
def reproject_view(pts3d, view2):
    """Project world points pts3d into view2's image plane."""
    shape = view2["pts3d"].shape[:2]
    world2cam = inv(view2["camera_pose"])
    return reproject(pts3d, view2["camera_intrinsics"], world2cam, shape)
+
+
def reproject(pts3d, K, world2cam, shape):
    """Project an HxWx3 point map with K @ [R|t].

    Returns ((H, W), flat_pixel_indices_in_target_shape).
    """
    H, W, THREE = pts3d.shape
    assert THREE == 3

    projection = K @ world2cam[:3]
    # divisions by zero-depth points are expected; silence the warnings
    with np.errstate(divide="ignore", invalid="ignore"):
        pos = geotrf(projection, pts3d, norm=1, ncol=2)

    return (H, W), ravel_xy(pos, shape)
+
+
def ravel_xy(pos, shape):
    """Round (x, y) coords to integers, clamp into the HxW image, and
    return the corresponding flat (row-major) pixel indices."""
    H, W = shape
    # NaNs may flow through round/astype; we deliberately ignore the warnings
    with np.errstate(invalid="ignore"):
        xy = pos.reshape(-1, 2).round().astype(np.int32)
        x = xy[:, 0].clip(0, W - 1)
        y = xy[:, 1].clip(0, H - 1)
    return x + W * y
+
+
def unravel_xy(pos, shape):
    # convert flat indices back to (x, y) pixel coordinates
    # NOTE(review): `.base[:, ::-1]` relies on np.unravel_index returning views
    # into one shared (2, N) base array and flips (y, x) -> (x, y); this is
    # fragile across numpy versions — confirm before upgrading numpy.
    return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
+
+
def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
    """Keep only cycle-consistent matches: i -> j -> i.

    Returns (pos1, pos2), optionally prefixed by the reciprocity mask of image 1.
    """
    roundtrip = corres_2_to_1[corres_1_to_2]
    is_reciprocal1 = roundtrip == np.arange(len(corres_1_to_2))
    pos1 = np.flatnonzero(is_reciprocal1)
    pos2 = corres_1_to_2[pos1]
    if ret_recip:
        return is_reciprocal1, pos1, pos2
    return pos1, pos2
+
+
def extract_correspondences_from_pts3d(
    view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
):
    """Extract pixel correspondences between two views from their 3-D point maps.

    view1, view2: view dicts with 'pts3d', 'camera_intrinsics', 'camera_pose'.
    target_n_corres: total number of correspondences to sample, or None to
        return all reciprocal matches.
    rng: numpy random generator used for sub-sampling.
    ret_xy: if True, return (x, y) pixel coordinates instead of flat indices.
    nneg: fraction (0..1) of deliberately wrong pairs ("negatives") to include.

    Returns (pos1, pos2) when target_n_corres is None, otherwise
    (pos1, pos2, valid) where `valid` marks the true (reciprocal) pairs.
    """
    view1, view2 = to_numpy((view1, view2))

    # cross-project each view's 3-D points into the other image
    shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
    shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)

    # keep only cycle-consistent (reciprocal) matches
    is_reciprocal1, pos1, pos2 = reciprocal_1d(
        corres1_to_2, corres2_to_1, ret_recip=True
    )
    is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))

    if target_n_corres is None:
        if ret_xy:
            pos1 = unravel_xy(pos1, shape1)
            pos2 = unravel_xy(pos2, shape2)
        return pos1, pos2

    # split the budget between positives and negatives
    available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
    target_n_positives = int(target_n_corres * (1 - nneg))
    n_positives = min(len(pos1), target_n_positives)
    n_negatives = min(target_n_corres - n_positives, available_negatives)

    if n_negatives + n_positives != target_n_corres:
        # not enough negatives available: take extra positives instead
        n_positives = target_n_corres - n_negatives
        assert n_positives <= len(pos1)

    assert n_positives <= len(pos1)
    assert n_positives <= len(pos2)
    assert n_negatives <= (~is_reciprocal1).sum()
    assert n_negatives <= (~is_reciprocal2).sum()
    assert n_positives + n_negatives == target_n_corres

    valid = np.ones(n_positives, dtype=bool)
    if n_positives < len(pos1):
        # sub-sample the positives uniformly at random
        perm = rng.permutation(len(pos1))[:n_positives]
        pos1 = pos1[perm]
        pos2 = pos2[perm]

    if n_negatives > 0:
        # sample negatives independently in each image among non-reciprocal
        # pixels, so the resulting pairs are wrong on purpose

        def norm(p):
            return p / p.sum()

        pos1 = np.r_[
            pos1,
            rng.choice(
                shape1[0] * shape1[1],
                size=n_negatives,
                replace=False,
                p=norm(~is_reciprocal1),
            ),
        ]
        pos2 = np.r_[
            pos2,
            rng.choice(
                shape2[0] * shape2[1],
                size=n_negatives,
                replace=False,
                p=norm(~is_reciprocal2),
            ),
        ]
        valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]

    if ret_xy:
        pos1 = unravel_xy(pos1, shape1)
        pos2 = unravel_xy(pos2, shape2)
    return pos1, pos2, valid
diff --git a/FastVGGT/eval/dataset_utils/cropping.py b/FastVGGT/eval/dataset_utils/cropping.py
new file mode 100644
index 0000000000000000000000000000000000000000..30a9eac18d241b71538957cf7ba4767ebc323b43
--- /dev/null
+++ b/FastVGGT/eval/dataset_utils/cropping.py
@@ -0,0 +1,140 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+
+import PIL.Image
+import os
+
+from utils import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics
+
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+import cv2 # noqa
+import numpy as np # noqa
+
+try:
+ lanczos = PIL.Image.Resampling.LANCZOS
+ bicubic = PIL.Image.Resampling.BICUBIC
+except AttributeError:
+ lanczos = PIL.Image.LANCZOS
+ bicubic = PIL.Image.BICUBIC
+
+
class ImageList:
    """Convenience wrapper applying the same PIL operation to a whole set of images."""

    def __init__(self, images):
        if not isinstance(images, (tuple, list, set)):
            images = [images]
        # normalize everything to PIL images
        self.images = [
            im if isinstance(im, PIL.Image.Image) else PIL.Image.fromarray(im)
            for im in images
        ]

    def __len__(self):
        return len(self.images)

    def to_pil(self):
        """Unwrap: one PIL image, or a tuple when several are held."""
        if len(self.images) > 1:
            return tuple(self.images)
        return self.images[0]

    @property
    def size(self):
        """Common (W, H) of all wrapped images; asserts they agree."""
        sizes = [im.size for im in self.images]
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    def resize(self, *args, **kwargs):
        return ImageList(self._dispatch("resize", *args, **kwargs))

    def crop(self, *args, **kwargs):
        return ImageList(self._dispatch("crop", *args, **kwargs))

    def _dispatch(self, func, *args, **kwargs):
        # call the named PIL.Image method on every wrapped image
        return [getattr(im, func)(*args, **kwargs) for im in self.images]
+
+
def rescale_image_depthmap(
    image, depthmap, camera_intrinsics, output_resolution, force=True
):
    """Jointly rescale a (image, depthmap)
    so that (out_width, out_height) >= output_res

    image: PIL image (or anything ImageList accepts).
    depthmap: HxW array aligned with `image`, or None.
    camera_intrinsics: 3x3 matrix, rescaled consistently with the image.
    output_resolution: (W, H) lower bound on the result.
    force: if False, images already smaller than output_res pass through.
    """
    image = ImageList(image)
    input_resolution = np.array(image.size)  # (W,H)
    output_resolution = np.array(output_resolution)
    if depthmap is not None:
        # depth must be pixel-aligned with the image (shape is (H, W))
        assert tuple(depthmap.shape[:2]) == image.size[::-1]

    # scale so that BOTH dimensions reach at least the target resolution
    assert output_resolution.shape == (2,)
    scale_final = max(output_resolution / image.size) + 1e-8
    if scale_final >= 1 and not force:  # image is already smaller than what is asked
        return (image.to_pil(), depthmap, camera_intrinsics)
    output_resolution = np.floor(input_resolution * scale_final).astype(int)

    # lanczos when shrinking (better antialiasing), bicubic when enlarging
    image = image.resize(
        output_resolution, resample=lanczos if scale_final < 1 else bicubic
    )
    if depthmap is not None:
        # nearest-neighbour: never interpolate between unrelated depth values
        depthmap = cv2.resize(
            depthmap,
            output_resolution,
            fx=scale_final,
            fy=scale_final,
            interpolation=cv2.INTER_NEAREST,
        )

    # pure rescale (no crop offset here)
    camera_intrinsics = camera_matrix_of_crop(
        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
    )

    return image.to_pil(), depthmap, camera_intrinsics
+
+
def camera_matrix_of_crop(
    input_camera_matrix,
    input_resolution,
    output_resolution,
    scaling=1,
    offset_factor=0.5,
    offset=None,
):
    """Adjust a 3x3 intrinsics matrix for a scale-then-crop of the image.

    By default the crop window is centered (offset_factor=0.5); pass `offset`
    to place it explicitly.
    """
    margins = np.asarray(input_resolution) * scaling - output_resolution
    assert np.all(margins >= 0.0)
    if offset is None:
        offset = offset_factor * margins

    # work in colmap convention (pixel-center origin), then convert back
    K = opencv_to_colmap_intrinsics(input_camera_matrix)
    K[:2, :] *= scaling
    K[:2, 2] -= offset
    return colmap_to_opencv_intrinsics(K)
+
+
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
    """
    Return a crop of the input view (image, depth and intrinsics together).
    """
    left, top, right, bottom = crop_bbox

    cropped = ImageList(image).crop((left, top, right, bottom))
    depthmap = depthmap[top:bottom, left:right]

    # shift the principal point by the crop origin
    intrinsics = camera_intrinsics.copy()
    intrinsics[0, 2] -= left
    intrinsics[1, 2] -= top

    return cropped.to_pil(), depthmap, intrinsics
+
+
def bbox_from_intrinsics_in_out(
    input_camera_matrix, output_camera_matrix, output_resolution
):
    """Derive the crop bbox mapping the input principal point onto the output one."""
    out_width, out_height = output_resolution
    shift = input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]
    l, t = np.int32(np.round(shift))
    return (l, t, l + out_width, t + out_height)
diff --git a/FastVGGT/eval/dataset_utils/transforms.py b/FastVGGT/eval/dataset_utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..cec2d144e1b97cc99c191d15cdcaf20796cae94b
--- /dev/null
+++ b/FastVGGT/eval/dataset_utils/transforms.py
@@ -0,0 +1,78 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+
+import torchvision.transforms as tvf
+
+ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+
+ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
+
+
+def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
+ if isinstance(value, (int, float)):
+ if value < 0:
+ raise ValueError(f"If is a single number, it must be non negative.")
+ value = [center - float(value), center + float(value)]
+ if clip_first_on_zero:
+ value[0] = max(value[0], 0.0)
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
+ value = [float(value[0]), float(value[1])]
+ else:
+ raise TypeError(f"should be a single number or a list/tuple with length 2.")
+
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
+ raise ValueError(f"values should be between {bound}, but got {value}.")
+
+ if value[0] == value[1] == center:
+ return None
+ else:
+ return tuple(value)
+
+
+import torch
+import torchvision.transforms.functional as F
+
+
def SeqColorJitter():
    """
    Return a color jitter transform with same random parameters
    """
    # RNG consumption order matches torchvision's ColorJitter:
    # one randperm, then one uniform draw per enabled component
    order = torch.randperm(4)

    def _draw(factor_range):
        # sample one factor from the range, or None when jitter is disabled
        if factor_range is None:
            return None
        return float(torch.empty(1).uniform_(factor_range[0], factor_range[1]))

    factors = [
        _draw(_check_input(0.5)),  # brightness
        _draw(_check_input(0.5)),  # contrast
        _draw(_check_input(0.5)),  # saturation
        _draw(_check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)),  # hue
    ]
    adjusters = [
        F.adjust_brightness,
        F.adjust_contrast,
        F.adjust_saturation,
        F.adjust_hue,
    ]

    def _color_jitter(img):
        # apply the four adjustments in the sampled order, then normalize
        for op_id in order:
            factor = factors[op_id]
            if factor is not None:
                img = adjusters[op_id](img, factor)
        return ImgNorm(img)

    return _color_jitter
diff --git a/FastVGGT/eval/eval_7andN.py b/FastVGGT/eval/eval_7andN.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d6d5fbf6c159ea18bfd6fbcee0a5d23e63c0cf2
--- /dev/null
+++ b/FastVGGT/eval/eval_7andN.py
@@ -0,0 +1,497 @@
+import os
+import sys
+
+# Ensure project root is on sys.path for absolute imports like `vggt.*`
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
+if ROOT_DIR not in sys.path:
+ sys.path.insert(0, ROOT_DIR)
+
+import time
+import torch
+import argparse
+import numpy as np
+import open3d as o3d
+import os.path as osp
+from torch.utils.data import DataLoader
+from torch.utils.data._utils.collate import default_collate
+from tqdm import tqdm
+from collections import defaultdict
+import torchvision.transforms as transforms
+
+
+def get_args_parser():
+    """Build the CLI argument parser for 3D reconstruction evaluation.
+
+    Returns an `argparse.ArgumentParser` (add_help=False so it can be composed
+    into a parent parser).
+    """
+    parser = argparse.ArgumentParser("3D Reconstruction evaluation", add_help=False)
+    parser.add_argument(
+        "--ckpt_path",
+        type=str,
+        default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
+        help="ckpt name",
+    )
+    parser.add_argument("--device", type=str, default="cuda:0", help="device")
+    parser.add_argument("--model_name", type=str, default="VGGT")
+    parser.add_argument(
+        "--conf_thresh", type=float, default=0.0, help="confidence threshold"
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="/home/sy/code/FastVGGT/eval_results",
+        help="value for outdir",
+    )
+    # Long-side image size; only 224/512/518 are supported by main().
+    parser.add_argument("--size", type=int, default=518)
+    parser.add_argument("--revisit", type=int, default=1, help="revisit times")
+    parser.add_argument("--freeze", action="store_true")
+    parser.add_argument("--use_proj", action="store_true")
+    parser.add_argument(
+        "--merging", type=int, default=0, help="VGGT aggregator merging steps"
+    )
+    parser.add_argument("--kf", type=int, default=2, help="key frame")
+    return parser
+
+
+def main(args):
+    """
+    Evaluate VGGT dense reconstruction on the 7-Scenes and NRGBD datasets.
+
+    For every scene: run the model once over the key-framed video, convert the
+    per-view predictions to point maps, compare them to GT points (optionally
+    after Umeyama + ICP alignment), and accumulate accuracy / completion /
+    normal-consistency metrics into per-dataset log files.
+    """
+    # Local imports: these modules live next to this script in eval/.
+    from data import SevenScenes, NRGBD
+    from utils import accuracy, completion
+
+    # Evaluation resolution (W, H) depends on the requested long-side size.
+    if args.size == 512:
+        resolution = (512, 384)
+    elif args.size == 224:
+        resolution = 224
+    elif args.size == 518:
+        resolution = (518, 392)
+    else:
+        raise NotImplementedError
+    datasets_all = {
+        "7scenes": SevenScenes(
+            split="test",
+            ROOT="/data/sy/7scenes",
+            resolution=resolution,
+            num_seq=1,
+            full_video=True,
+            kf_every=args.kf,
+        ), # 20),
+        "NRGBD": NRGBD(
+            split="test",
+            ROOT="/data/sy/neural_rgbd_data",
+            resolution=resolution,
+            num_seq=1,
+            full_video=True,
+            kf_every=args.kf,
+        ),
+    }
+
+    device = args.device
+    model_name = args.model_name
+
+    from vggt.models.vggt import VGGT
+    from vggt.utils.pose_enc import pose_encoding_to_extri_intri
+    from vggt.utils.geometry import unproject_depth_map_to_point_map
+    from criterion import Regr3D_t_ScaleShiftInv, L21
+
+    # Force use of bf16 data type
+    dtype = torch.bfloat16
+    # Load VGGT model
+    model = VGGT(merging=args.merging, enable_point=True)
+    ckpt = torch.load(args.ckpt_path, map_location="cpu")
+
+    # ✅ Fix: load pre-trained weights
+    model.load_state_dict(
+        ckpt, strict=False
+    )  # Use strict=False due to enable_point=True difference
+
+    model = model.cuda().eval()
+    model = model.to(torch.bfloat16)
+
+    del ckpt
+    os.makedirs(osp.join(args.output_dir, f"{args.kf}"), exist_ok=True)
+
+    criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)
+
+    with torch.no_grad():
+        for name_data, dataset in datasets_all.items():
+            save_path = osp.join(osp.join(args.output_dir, f"{args.kf}"), name_data)
+            os.makedirs(save_path, exist_ok=True)
+            log_file = osp.join(save_path, "logs.txt")
+
+            # Running sums of the per-scene metrics for this dataset.
+            acc_all = 0
+            acc_all_med = 0
+            comp_all = 0
+            comp_all_med = 0
+            nc1_all = 0
+            nc1_all_med = 0
+            nc2_all = 0
+            nc2_all_med = 0
+            scene_infer_times = defaultdict(list)
+
+            for data_idx in tqdm(range(len(dataset))):
+                # Each dataset item is a list of per-frame view dicts.
+                batch = default_collate([dataset[data_idx]])
+                ignore_keys = set(
+                    [
+                        "depthmap",
+                        "dataset",
+                        "label",
+                        "instance",
+                        "idx",
+                        "true_shape",
+                        "rng",
+                    ]
+                )
+                # Move every tensor field (except metadata) to the eval device.
+                for view in batch:
+                    for name in view.keys():  # pseudo_focal
+                        if name in ignore_keys:
+                            continue
+                        if isinstance(view[name], tuple) or isinstance(
+                            view[name], list
+                        ):
+                            view[name] = [
+                                x.to(device, non_blocking=True) for x in view[name]
+                            ]
+                        else:
+                            view[name] = view[name].to(device, non_blocking=True)
+
+                pts_all = []
+                pts_gt_all = []
+                images_all = []
+                masks_all = []
+                conf_all = []
+                in_camera1 = None
+
+                # bf16 needs compute capability >= 8 (Ampere); otherwise fp16.
+                dtype = (
+                    torch.bfloat16
+                    if torch.cuda.get_device_capability()[0] >= 8
+                    else torch.float16
+                )
+                with torch.cuda.amp.autocast(dtype=dtype):
+                    # Images arrive in [-1, 1]; VGGT expects [0, 1].
+                    if isinstance(batch, dict) and "img" in batch:
+                        batch["img"] = (batch["img"] + 1.0) / 2.0
+                    elif isinstance(batch, list) and all(
+                        isinstance(v, dict) and "img" in v for v in batch
+                    ):
+                        for view in batch:
+                            view["img"] = (view["img"] + 1.0) / 2.0
+                    # Gather all `img` tensors into a single tensor of shape [N, C, H, W]
+                    imgs_tensor = torch.cat([v["img"] for v in batch], dim=0)
+
+                with torch.cuda.amp.autocast(dtype=dtype):
+                    with torch.no_grad():
+                        # Synchronize around the forward pass for accurate timing.
+                        torch.cuda.synchronize()
+                        start = time.time()
+                        preds = model(imgs_tensor)
+                        torch.cuda.synchronize()
+                        end = time.time()
+                        inference_time_ms = (end - start) * 1000
+                        print(f"Inference time: {inference_time_ms:.2f}ms")
+
+                # Wrap model outputs per-view to align with batch later
+                predictions = preds
+                views = batch  # list[dict]
+                if "pose_enc" in predictions:
+                    B, S = predictions["pose_enc"].shape[:2]
+                elif "world_points" in predictions:
+                    B, S = predictions["world_points"].shape[:2]
+                else:
+                    raise KeyError(
+                        "predictions is missing a key to infer sequence length"
+                    )
+
+                # Re-shape the batched predictions into one dict per frame.
+                ress = []
+                for s in range(S):
+                    res = {
+                        "pts3d_in_other_view": predictions["world_points"][:, s],
+                        "conf": predictions["world_points_conf"][:, s],
+                        "depth": predictions["depth"][:, s],
+                        "depth_conf": predictions["depth_conf"][:, s],
+                        "camera_pose": predictions["pose_enc"][:, s, :],
+                    }
+                    if (
+                        isinstance(views, list)
+                        and s < len(views)
+                        and "valid_mask" in views[s]
+                    ):
+                        res["valid_mask"] = views[s]["valid_mask"]
+                    if "track" in predictions:
+                        res.update(
+                            {
+                                "track": predictions["track"][:, s],
+                                "vis": (
+                                    predictions.get("vis", None)[:, s]
+                                    if "vis" in predictions
+                                    else None
+                                ),
+                                "track_conf": (
+                                    predictions.get("conf", None)[:, s]
+                                    if "conf" in predictions
+                                    else None
+                                ),
+                            }
+                        )
+                    ress.append(res)
+
+                preds = ress
+
+                # With --revisit > 1 only the last pass over the video is scored.
+                valid_length = len(preds) // args.revisit
+                if args.revisit > 1:
+                    preds = preds[-valid_length:]
+                    batch = batch[-valid_length:]
+
+                # Evaluation
+                print(f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}")
+                gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
+                    criterion.get_all_pts3d_t(batch, preds)
+                )
+
+                in_camera1 = None
+                pts_all = []
+                pts_gt_all = []
+                images_all = []
+                masks_all = []
+                conf_all = []
+
+                for j, view in enumerate(batch):
+                    if in_camera1 is None:
+                        in_camera1 = view["camera_pose"][0].cpu()
+
+                    image = view["img"].permute(0, 2, 3, 1).cpu().numpy()[0]
+                    mask = view["valid_mask"].cpu().numpy()[0]
+
+                    pts = pred_pts[j].cpu().numpy()[0]
+                    conf = preds[j]["conf"].cpu().data.numpy()[0]
+
+                    # mask = mask & (conf > 1.8)
+
+                    pts_gt = gt_pts[j].detach().cpu().numpy()[0]
+
+                    # Center-crop everything to a fixed 224x224 window so all
+                    # frames contribute equally-sized point maps.
+                    H, W = image.shape[:2]
+                    cx = W // 2
+                    cy = H // 2
+                    l, t = cx - 112, cy - 112
+                    r, b = cx + 112, cy + 112
+                    image = image[t:b, l:r]
+                    mask = mask[t:b, l:r]
+                    pts = pts[t:b, l:r]
+                    pts_gt = pts_gt[t:b, l:r]
+
+                    images_all.append(image[None, ...])
+                    pts_all.append(pts[None, ...])
+                    pts_gt_all.append(pts_gt[None, ...])
+                    masks_all.append(mask[None, ...])
+                    conf_all.append(conf[None, ...])
+
+                images_all = np.concatenate(images_all, axis=0)
+                pts_all = np.concatenate(pts_all, axis=0)
+                pts_gt_all = np.concatenate(pts_gt_all, axis=0)
+                masks_all = np.concatenate(masks_all, axis=0)
+
+                # `view` is the last frame of the loop above; its label encodes
+                # the scene directory.
+                scene_id = view["label"][0].rsplit("/", 1)[0]
+                # Record average inference time per scene
+                try:
+                    scene_infer_times[scene_id].append(float(inference_time_ms))
+                except Exception:
+                    # Best-effort bookkeeping only; never fail the evaluation.
+                    pass
+
+                save_params = {}
+
+                save_params["images_all"] = images_all
+                save_params["pts_all"] = pts_all
+                save_params["pts_gt_all"] = pts_gt_all
+                save_params["masks_all"] = masks_all
+
+                # Keep only valid pixels, then drop non-finite points.
+                pts_all_masked = pts_all[masks_all > 0]
+                pts_gt_all_masked = pts_gt_all[masks_all > 0]
+                images_all_masked = images_all[masks_all > 0]
+
+                mask = np.isfinite(pts_all_masked)
+                pts_all_masked = pts_all_masked[mask]
+
+                mask_gt = np.isfinite(pts_gt_all_masked)
+                pts_gt_all_masked = pts_gt_all_masked[mask_gt]
+                images_all_masked = images_all_masked[mask]
+
+                # Reshape to point cloud (N, 3) before sampling
+                pts_all_masked = pts_all_masked.reshape(-1, 3)
+                pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3)
+                images_all_masked = images_all_masked.reshape(-1, 3)
+
+                # If number of points exceeds threshold, sample by points
+                if pts_all_masked.shape[0] > 999999:
+                    sample_indices = np.random.choice(
+                        pts_all_masked.shape[0], 999999, replace=False
+                    )
+                    pts_all_masked = pts_all_masked[sample_indices]
+                    images_all_masked = images_all_masked[sample_indices]
+
+                # Apply the same sampling to GT point cloud
+                if pts_gt_all_masked.shape[0] > 999999:
+                    sample_indices_gt = np.random.choice(
+                        pts_gt_all_masked.shape[0], 999999, replace=False
+                    )
+                    pts_gt_all_masked = pts_gt_all_masked[sample_indices_gt]
+
+                if args.use_proj:
+
+                    def umeyama_alignment(
+                        src: np.ndarray, dst: np.ndarray, with_scale: bool = True
+                    ):
+                        # Closed-form similarity transform (scale s, rotation R,
+                        # translation t) minimizing ||dst - (s*R*src + t)||^2.
+                        assert src.shape == dst.shape
+                        N, dim = src.shape
+
+                        mu_src = src.mean(axis=0)
+                        mu_dst = dst.mean(axis=0)
+                        src_c = src - mu_src
+                        dst_c = dst - mu_dst
+
+                        Sigma = dst_c.T @ src_c / N  # (3,3)
+
+                        U, D, Vt = np.linalg.svd(Sigma)
+
+                        # Reflection guard: force det(R) = +1.
+                        S = np.eye(dim)
+                        if np.linalg.det(U) * np.linalg.det(Vt) < 0:
+                            S[-1, -1] = -1
+
+                        R = U @ S @ Vt
+
+                        if with_scale:
+                            var_src = (src_c**2).sum() / N
+                            s = (D * S.diagonal()).sum() / var_src
+                        else:
+                            s = 1.0
+
+                        t = mu_dst - s * R @ mu_src
+
+                        return s, R, t
+
+                    pts_all_masked = pts_all_masked.reshape(-1, 3)
+                    pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3)
+                    # NOTE(review): this assumes pred/GT point lists are the
+                    # same length and in correspondence after the independent
+                    # finite-masking above — confirm upstream masks agree.
+                    s, R, t = umeyama_alignment(
+                        pts_all_masked, pts_gt_all_masked, with_scale=True
+                    )
+                    pts_all_aligned = (s * (R @ pts_all_masked.T)).T + t  # (N,3)
+                    pts_all_masked = pts_all_aligned
+
+                pcd = o3d.geometry.PointCloud()
+                pcd.points = o3d.utility.Vector3dVector(pts_all_masked)
+                pcd.colors = o3d.utility.Vector3dVector(images_all_masked)
+
+                pcd_gt = o3d.geometry.PointCloud()
+                pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked)
+                pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked)
+
+                # Refine the alignment with point-to-point ICP (10 cm radius).
+                trans_init = np.eye(4)
+
+                threshold = 0.1
+                reg_p2p = o3d.pipelines.registration.registration_icp(
+                    pcd,
+                    pcd_gt,
+                    threshold,
+                    trans_init,
+                    o3d.pipelines.registration.TransformationEstimationPointToPoint(),
+                )
+
+                transformation = reg_p2p.transformation
+
+                pcd = pcd.transform(transformation)
+                pcd.estimate_normals()
+                pcd_gt.estimate_normals()
+
+                gt_normal = np.asarray(pcd_gt.normals)
+                pred_normal = np.asarray(pcd.normals)
+
+                acc, acc_med, nc1, nc1_med = accuracy(
+                    pcd_gt.points, pcd.points, gt_normal, pred_normal
+                )
+                comp, comp_med, nc2, nc2_med = completion(
+                    pcd_gt.points, pcd.points, gt_normal, pred_normal
+                )
+                # This line format must stay in sync with the module-level
+                # `pattern` regex used below to re-parse the log.
+                print(
+                    f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}"
+                )
+                # NOTE(review): open() here is never closed explicitly; the
+                # handle is dropped to GC after the print. Consider a `with`.
+                print(
+                    f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}",
+                    file=open(log_file, "a"),
+                )
+
+                acc_all += acc
+                comp_all += comp
+                nc1_all += nc1
+                nc2_all += nc2
+
+                acc_all_med += acc_med
+                comp_all_med += comp_med
+                nc1_all_med += nc1_med
+                nc2_all_med += nc2_med
+
+                # release cuda memory
+                torch.cuda.empty_cache()
+
+            # Get depth from pcd and run TSDFusion
+            to_write = ""
+            # Read the log file
+            if os.path.exists(osp.join(save_path, "logs.txt")):
+                with open(osp.join(save_path, "logs.txt"), "r") as f_sub:
+                    to_write += f_sub.read()
+
+            with open(osp.join(save_path, f"logs_all.txt"), "w") as f:
+                # Re-parse per-scene lines with the module-level `regex`
+                # (compiled at the bottom of this file) and average them.
+                log_data = to_write
+                metrics = defaultdict(list)
+                for line in log_data.strip().split("\n"):
+                    match = regex.match(line)
+                    if match:
+                        data = match.groupdict()
+                        # Exclude 'scene_id' from metrics as it's an identifier
+                        for key, value in data.items():
+                            if key != "scene_id":
+                                metrics[key].append(float(value))
+                        metrics["nc"].append(
+                            (float(data["nc1"]) + float(data["nc2"])) / 2
+                        )
+                        metrics["nc_med"].append(
+                            (float(data["nc1_med"]) + float(data["nc2_med"])) / 2
+                        )
+                mean_metrics = {
+                    metric: sum(values) / len(values)
+                    for metric, values in metrics.items()
+                }
+
+                c_name = "mean"
+                print_str = f"{c_name.ljust(20)}: "
+                for m_name in mean_metrics:
+                    print_num = np.mean(mean_metrics[m_name])
+                    print_str = print_str + f"{m_name}: {print_num:.3f} | "
+                print_str = print_str + "\n"
+                # Summarize per-scene average inference time
+                time_lines = []
+                for sid, times in scene_infer_times.items():
+                    if len(times) > 0:
+                        time_lines.append(
+                            f"Idx: {sid}, Time_avg_ms: {np.mean(times):.2f}"
+                        )
+                time_block = "\n".join(time_lines) + (
+                    "\n" if len(time_lines) > 0 else ""
+                )
+
+                f.write(to_write + time_block + print_str)
+
+
+from collections import defaultdict
+import re
+
+pattern = r"""
+ Idx:\s*(?P[^,]+),\s*
+ Acc:\s*(?P[^,]+),\s*
+ Comp:\s*(?P[^,]+),\s*
+ NC1:\s*(?P[^,]+),\s*
+ NC2:\s*(?P[^,]+)\s*-\s*
+ Acc_med:\s*(?P[^,]+),\s*
+ Compc_med:\s*(?P[^,]+),\s*
+ NC1c_med:\s*(?P[^,]+),\s*
+ NC2c_med:\s*(?P[^,]+)
+"""
+
+regex = re.compile(pattern, re.VERBOSE)
+
+
+if __name__ == "__main__":
+    # Parse CLI arguments and run the full 7-Scenes / NRGBD evaluation.
+    parser = get_args_parser()
+    args = parser.parse_args()
+
+    main(args)
diff --git a/FastVGGT/eval/eval_custom.py b/FastVGGT/eval/eval_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9e4f7ccdd6c22c2c3bef6421c52ace01d362a83
--- /dev/null
+++ b/FastVGGT/eval/eval_custom.py
@@ -0,0 +1,467 @@
+import argparse
+from pathlib import Path
+import numpy as np
+import torch
+import os
+import sys
+import matplotlib.pyplot as plt
+from scipy.spatial.transform import Rotation
+
+# Ensure project root is in sys.path for absolute imports like `vggt.*`
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
+if ROOT_DIR not in sys.path:
+ sys.path.insert(0, ROOT_DIR)
+
+from vggt.models.vggt import VGGT
+from vggt.utils.eval_utils import (
+ load_poses,
+ get_vgg_input_imgs,
+ get_sorted_image_paths,
+ build_frame_selection,
+ load_images_rgb,
+ infer_vggt_and_reconstruct,
+ evaluate_scene_and_save,
+)
+
+# Import pose visualization libraries (optional EVO support)
+try:
+ from evo.core.trajectory import PoseTrajectory3D
+ import evo.tools.plot as plot
+
+ EVO_AVAILABLE = True
+except ImportError:
+ # EVO is optional; we have a matplotlib-based fallback
+ EVO_AVAILABLE = False
+
+
+def visualize_predicted_poses(
+    all_cam_to_world_mat, frame_ids, output_scene_dir, scene_name="custom_dataset"
+):
+    """
+    Visualize the predicted camera pose trajectory (no GT comparison required).
+
+    Args:
+        all_cam_to_world_mat: List of camera-to-world transform matrices
+        frame_ids: List of frame IDs (currently unused by the plot itself)
+        output_scene_dir: Output directory (pathlib.Path)
+        scene_name: Scene name used in the plot title
+
+    Saves `predicted_trajectory.png` (XZ-plane projection) into
+    `output_scene_dir`. Never raises: failures are printed and swallowed so
+    visualization cannot break an evaluation run.
+    """
+    # Provide basic pose visualization even without EVO
+    if not EVO_AVAILABLE:
+        print("⚠️ EVO not installed; using basic matplotlib visualization")
+
+    try:
+        # Convert to numpy array; expected shape (N, 4, 4).
+        poses_est = np.array(all_cam_to_world_mat)
+
+        if len(poses_est) < 2:
+            print("⚠️ Not enough poses to generate trajectory plot")
+            return
+
+        print(f"🎨 Generating pose trajectory visualization...")
+
+        # Extract translation part
+        positions = poses_est[:, :3, 3]  # shape: (N, 3)
+
+        # Create figure - show XZ-plane projection only
+        fig, ax = plt.subplots(1, 1, figsize=(10, 8))
+
+        # XZ-plane projection
+        ax.plot(
+            positions[:, 0],
+            positions[:, 2],
+            "b-",
+            linewidth=2,
+            label="Predicted Trajectory",
+        )
+        ax.scatter(
+            positions[0, 0], positions[0, 2], color="green", s=100, label="Start"
+        )
+        ax.scatter(positions[-1, 0], positions[-1, 2], color="red", s=100, label="End")
+        ax.set_xlabel("X (m)")
+        ax.set_ylabel("Z (m)")
+        ax.set_title(f"{scene_name} - XZ-plane projection")
+        ax.legend()
+        ax.grid(True, alpha=0.3)
+
+        # Save image
+        pose_plot_path = output_scene_dir / "predicted_trajectory.png"
+        plt.savefig(pose_plot_path, dpi=300, bbox_inches="tight")
+        plt.close()
+
+        print(f"📊 Trajectory visualization saved: {pose_plot_path}")
+
+    except Exception as e:
+        # Best-effort: report and continue so the caller's pipeline survives.
+        print(f"⚠️ Failed to generate pose visualization: {e}")
+        import traceback
+
+        traceback.print_exc()
+
+
+def main():
+    """
+    Evaluation script for a Custom Dataset.
+    Supports optional evaluation and custom dataset structure.
+
+    Pipeline: parse args -> validate the dataset layout -> load model and
+    images -> run VGGT inference + reconstruction -> either evaluate against
+    GT (poses + gt_ply) or just save estimated poses and a PLY point cloud.
+    """
+    parser = argparse.ArgumentParser(
+        description="Run FastVGGT evaluation on a Custom Dataset"
+    )
+
+    # Required: dataset path
+    parser.add_argument(
+        "--data_path",
+        type=Path,
+        required=True,
+        help="Dataset path containing subfolders: color, depth, gt_ply, pose",
+    )
+
+    # Optional: enable evaluation
+    parser.add_argument(
+        "--enable_evaluation",
+        action="store_true",
+        help="Enable evaluation (requires pose and ply data)",
+    )
+
+    # Output path
+    parser.add_argument(
+        "--output_path",
+        type=Path,
+        default="./eval_results_custom",
+        help="Output path for evaluation results",
+    )
+
+    # Model parameters
+    parser.add_argument(
+        "--ckpt_path",
+        type=str,
+        default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
+        help="Model checkpoint file path",
+    )
+
+    parser.add_argument("--merging", type=int, default=0, help="Merging parameter")
+
+    # Processing parameters
+    parser.add_argument(
+        "--input_frame",
+        type=int,
+        default=200,
+        help="Maximum number of frames to process per scene",
+    )
+
+    parser.add_argument(
+        "--depth_conf_thresh",
+        type=float,
+        default=3.0,
+        help="Depth confidence threshold to filter low-confidence depth values",
+    )
+
+    # Evaluation parameters (only used when evaluation is enabled)
+    parser.add_argument(
+        "--chamfer_max_dist",
+        type=float,
+        default=0.5,
+        help="Maximum distance threshold used in Chamfer Distance computation",
+    )
+
+    parser.add_argument("--plot", action="store_true", help="Whether to generate plots")
+
+    parser.add_argument(
+        "--vis_attn_map",
+        action="store_true",
+        help="Visualize attention maps during inference",
+    )
+
+    args = parser.parse_args()
+    # Fixed seed for reproducible frame sampling / model randomness.
+    torch.manual_seed(33)
+
+    # Check data path exists
+    if not args.data_path.exists():
+        print(f"❌ Error: Data path does not exist: {args.data_path}")
+        return
+
+    # Check required subdirectories
+    # NOTE(review): help text says "color" but the code reads "images".
+    color_dir = args.data_path / "images"
+    pose_dir = args.data_path / "pose"
+
+    if not color_dir.exists():
+        print(f"❌ Error: color directory does not exist: {color_dir}")
+        return
+
+    print(f"📁 Dataset path: {args.data_path}")
+    # print(f"🔧 Enable evaluation: {'Yes' if args.enable_evaluation else 'No'}")
+
+    # If evaluation is enabled, check pose and gt_ply directories
+    if args.enable_evaluation:
+        if not pose_dir.exists():
+            print(f"❌ Error: Evaluation requires pose directory: {pose_dir}")
+            return
+
+        gt_ply_dir = args.data_path / "gt_ply"
+        if not gt_ply_dir.exists():
+            print(f"❌ Error: Evaluation requires gt_ply directory: {gt_ply_dir}")
+            return
+        print(f"📊 Evaluation will use Ground Truth")
+    else:
+        print(f"🏃 Inference only, no evaluation")
+
+    # Create output directory
+    args.output_path.mkdir(parents=True, exist_ok=True)
+    output_scene_dir = args.output_path / "custom_dataset"
+
+    # Check if already processed
+    if (output_scene_dir / "metrics.json").exists() and args.enable_evaluation:
+        print(
+            f"⚠️ Results already exist, skipping: {output_scene_dir / 'metrics.json'}"
+        )
+        return
+
+    # Force use of bf16 dtype
+    dtype = torch.bfloat16
+
+    # Load VGGT model
+    print(f"🔄 Loading model: {args.ckpt_path}")
+    model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
+    ckpt = torch.load(args.ckpt_path, map_location="cpu")
+    incompat = model.load_state_dict(ckpt, strict=False)
+    # if incompat.missing_keys or incompat.unexpected_keys:
+    #     print(f"⚠️ Partially incompatible keys when loading model: {incompat}")
+    model = model.cuda().eval()
+    model = model.to(torch.bfloat16)
+    print(f"✅ Model loaded")
+
+    # Load scene data
+    image_paths = get_sorted_image_paths(color_dir)
+    if len(image_paths) == 0:
+        print(f"❌ Error: No images found in {color_dir}")
+        return
+
+    print(f"🖼️ Found {len(image_paths)} images")
+
+    # Process pose data (if evaluation is enabled)
+    poses_gt = None
+    first_gt_pose = None
+    available_pose_frame_ids = None
+    c2ws = None
+
+    if args.enable_evaluation:
+        poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_dir)
+        if (
+            poses_gt is None
+            or first_gt_pose is None
+            or available_pose_frame_ids is None
+        ):
+            print(f"❌ Error: Failed to load pose data")
+            return
+        print(f"📐 Loaded {len(poses_gt)} poses")
+
+    # Frame selection
+    if args.enable_evaluation and available_pose_frame_ids is not None:
+        # Use pose data for frame selection
+        selected_frame_ids, selected_image_paths, selected_pose_indices = (
+            build_frame_selection(
+                image_paths, available_pose_frame_ids, args.input_frame
+            )
+        )
+        c2ws = poses_gt[selected_pose_indices]
+        image_paths = selected_image_paths
+    else:
+        # Simply take the first N frames
+        num_frames = min(len(image_paths), args.input_frame)
+        selected_frame_ids = list(range(num_frames))
+        image_paths = image_paths[:num_frames]
+
+    print(f"📋 Selected {len(image_paths)} frames for processing")
+
+    try:
+        # Load images
+        print(f"🔄 Loading images...")
+        images = load_images_rgb(image_paths)
+
+        if not images or len(images) < 3:
+            print(f"❌ Error: Not enough valid images (need at least 3)")
+            return
+
+        frame_ids = selected_frame_ids
+        images_array = np.stack(images)
+        vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
+        print(f"📐 Image patch dimensions: {patch_width}x{patch_height}")
+
+        # Update attention layer patch dimensions in the model
+        model.update_patch_dimensions(patch_width, patch_height)
+
+        # Inference + Reconstruction
+        print(f"🚀 Start inference and reconstruction...")
+        (
+            extrinsic_np,
+            intrinsic_np,
+            all_world_points,
+            all_point_colors,
+            all_cam_to_world_mat,
+            inference_time_ms,
+        ) = infer_vggt_and_reconstruct(
+            model, vgg_input, dtype, args.depth_conf_thresh, image_paths
+        )
+        print(f"⏱️ Inference time: {inference_time_ms:.2f}ms")
+
+        # Check results
+        if not all_cam_to_world_mat or not all_world_points:
+            print(f"❌ Error: Failed to obtain valid camera poses or point clouds")
+            return
+
+        # print(f"✅ Inference done, obtained {len(all_world_points)} point sets")
+
+        # Evaluation and saving
+        if args.enable_evaluation:
+            print(f"📊 Start evaluation...")
+            gt_ply_dir = args.data_path / "gt_ply"
+            metrics = evaluate_scene_and_save(
+                "custom_dataset",
+                c2ws,
+                first_gt_pose,
+                frame_ids,
+                all_cam_to_world_mat,
+                all_world_points,
+                output_scene_dir,
+                gt_ply_dir,
+                args.chamfer_max_dist,
+                inference_time_ms,
+                args.plot,
+            )
+            if metrics is not None:
+                print("📈 Evaluation results:")
+                for key, value in metrics.items():
+                    if key in [
+                        "chamfer_distance",
+                        "ate",
+                        "are",
+                        "rpe_rot",
+                        "rpe_trans",
+                        "inference_time_ms",
+                    ]:
+                        print(f"  {key}: {float(value):.4f}")
+
+            # Also visualize predicted poses in evaluation branch
+            if args.plot:
+                visualize_predicted_poses(
+                    all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset"
+                )
+        else:
+            # Save reconstruction only, no evaluation
+            print(f"💾 Saving reconstruction...")
+            output_scene_dir.mkdir(parents=True, exist_ok=True)
+
+            # Save camera poses
+            poses_output_path = output_scene_dir / "estimated_poses.txt"
+            with open(poses_output_path, "w") as f:
+                for i, pose in enumerate(all_cam_to_world_mat):
+                    f.write(f"# Frame {frame_ids[i]}\n")
+                    for row in pose:
+                        f.write(" ".join(map(str, row)) + "\n")
+                    f.write("\n")
+
+            # Save point cloud
+            if all_world_points:
+                points_output_path = output_scene_dir / "reconstructed_points.ply"
+
+                # Merge all frames' point clouds and colors
+                try:
+                    merged_point_cloud = np.vstack(all_world_points)
+                    merged_colors = (
+                        np.vstack(all_point_colors).astype(np.uint8)
+                        if all_point_colors is not None and len(all_point_colors) > 0
+                        else None
+                    )
+                    print(
+                        f"📊 Merged point clouds: {len(all_world_points)} frames, total {len(merged_point_cloud)} points"
+                    )
+
+                    # If too many points, randomly sample 100000 points
+                    max_points = 100000
+                    if len(merged_point_cloud) > max_points:
+                        print(
+                            f"🔽 Too many points, randomly sampling {max_points} points..."
+                        )
+                        # Randomly choose indices
+                        indices = np.random.choice(
+                            len(merged_point_cloud), size=max_points, replace=False
+                        )
+                        merged_point_cloud = merged_point_cloud[indices]
+                        if merged_colors is not None:
+                            merged_colors = merged_colors[indices]
+                        print(
+                            f"✅ Sampling done, kept {len(merged_point_cloud)} points"
+                        )
+
+                    # Save as PLY (with color). Header is written first; NaN/Inf
+                    # points are skipped below, so the declared vertex count can
+                    # exceed the rows actually written.
+                    # NOTE(review): that mismatch makes strict PLY readers fail;
+                    # consider filtering before writing the header.
+                    with open(points_output_path, "w") as f:
+                        f.write("ply\n")
+                        f.write("format ascii 1.0\n")
+                        f.write(f"element vertex {len(merged_point_cloud)}\n")
+                        f.write("property float x\n")
+                        f.write("property float y\n")
+                        f.write("property float z\n")
+                        if merged_colors is not None:
+                            f.write("property uchar red\n")
+                            f.write("property uchar green\n")
+                            f.write("property uchar blue\n")
+                        f.write("end_header\n")
+                        if merged_colors is None:
+                            for point in merged_point_cloud:
+                                if not (np.isnan(point).any() or np.isinf(point).any()):
+                                    f.write(
+                                        f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n"
+                                    )
+                        else:
+                            for point, color in zip(merged_point_cloud, merged_colors):
+                                # Check point validity
+                                if not (np.isnan(point).any() or np.isinf(point).any()):
+                                    r = int(np.clip(color[0], 0, 255))
+                                    g = int(np.clip(color[1], 0, 255))
+                                    b = int(np.clip(color[2], 0, 255))
+                                    f.write(
+                                        f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f} {r} {g} {b}\n"
+                                    )
+
+                    print(f"💾 Point cloud saved to: {points_output_path}")
+
+                except Exception as e:
+                    print(f"⚠️ Error saving point cloud: {e}")
+                    # If merge fails, try to log per-frame info
+                    print(f"🔍 Point cloud debug info:")
+                    for i, frame_points in enumerate(all_world_points):
+                        print(
+                            f"  Frame {i}: {frame_points.shape if hasattr(frame_points, 'shape') else type(frame_points)}"
+                        )
+                        if (
+                            hasattr(frame_points, "shape")
+                            and len(frame_points.shape) >= 2
+                        ):
+                            print(
+                                f"    Shape: {frame_points.shape}, Dtype: {frame_points.dtype}"
+                            )
+                            if frame_points.shape[0] > 0:
+                                print(
+                                    f"    Range: x[{np.min(frame_points[:, 0]):.3f}, {np.max(frame_points[:, 0]):.3f}] "
+                                    f"y[{np.min(frame_points[:, 1]):.3f}, {np.max(frame_points[:, 1]):.3f}] "
+                                    f"z[{np.min(frame_points[:, 2]):.3f}, {np.max(frame_points[:, 2]):.3f}]"
+                                )
+
+            print(f"📁 Results saved to: {output_scene_dir}")
+
+            # Visualize predicted pose trajectory
+            if args.plot:
+                visualize_predicted_poses(
+                    all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset"
+                )
+
+        print(f"🎉 Done!")
+
+    except Exception as e:
+        # Top-level boundary: report and exit gracefully.
+        print(f"❌ Error occurred during processing: {e}")
+        import traceback
+
+        traceback.print_exc()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/FastVGGT/eval/eval_scannet.py b/FastVGGT/eval/eval_scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..332132c45ab4fc3f7e768dbf3845d9cdb3ccc4eb
--- /dev/null
+++ b/FastVGGT/eval/eval_scannet.py
@@ -0,0 +1,208 @@
+import argparse
+from pathlib import Path
+import numpy as np
+import torch
+import os
+import sys
+
+# Ensure project root is in sys.path for absolute imports like `vggt.*`
+ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
+if ROOT_DIR not in sys.path:
+ sys.path.insert(0, ROOT_DIR)
+
+from vggt.models.vggt import VGGT
+from vggt.utils.eval_utils import (
+ load_poses,
+ get_vgg_input_imgs,
+ get_sorted_image_paths,
+ get_all_scenes,
+ build_frame_selection,
+ load_images_rgb,
+ infer_vggt_and_reconstruct,
+ evaluate_scene_and_save,
+ compute_average_metrics_and_save,
+)
+
+
+if __name__ == "__main__":
+    # ScanNet batch evaluation: for each sampled scene, run VGGT inference,
+    # reconstruct, evaluate against GT, and finally average the metrics.
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--data_dir", type=Path, default="/data/scannetv2/process_scannet"
+    )
+    parser.add_argument(
+        "--gt_ply_dir",
+        type=Path,
+        default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
+    )
+    parser.add_argument("--output_path", type=Path, default="./eval_results")
+    parser.add_argument("--merging", type=int, default=None)
+    # NOTE(review): type=bool on argparse is misleading — any non-empty string
+    # (including "False") parses as True; action="store_true" would be safer.
+    parser.add_argument("--plot", type=bool, default=True)
+    parser.add_argument(
+        "--depth_conf_thresh",
+        type=float,
+        default=3.0,
+        help="Depth confidence threshold for filtering low confidence depth values",
+    )
+    parser.add_argument(
+        "--chamfer_max_dist",
+        type=float,
+        default=0.5,
+        help="Maximum distance threshold in Chamfer Distance computation, distances exceeding this value will be clipped",
+    )
+    parser.add_argument(
+        "--input_frame",
+        type=int,
+        default=200,
+        help="Maximum number of frames selected for processing per scene",
+    )
+    parser.add_argument(
+        "--num_scenes",
+        type=int,
+        default=50,
+        help="Maximum number of scenes to evaluate",
+    )
+    parser.add_argument(
+        "--ckpt_path",
+        type=str,
+        default="./ckpt/model_tracker_fixed_e20.pt",
+        help="Path to the model checkpoint file",
+    )
+    parser.add_argument(
+        "--vis_attn_map",
+        action="store_true",
+        help="Whether to visualize attention maps during inference",
+    )
+    args = parser.parse_args()
+    # Fixed seed for reproducible scene/frame sampling.
+    torch.manual_seed(33)
+
+    # Scene sampling
+    scannet_scenes = get_all_scenes(args.data_dir, args.num_scenes)
+    print(f"Evaluate {len(scannet_scenes)} scenes")
+
+    all_scenes_metrics = {"scenes": {}, "average": {}}
+    # Force use of bf16 data type
+    dtype = torch.bfloat16
+    # Load VGGT model
+    model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
+    ckpt = torch.load(args.ckpt_path, map_location="cpu")
+    # strict=False: checkpoint and model heads may differ slightly.
+    incompat = model.load_state_dict(ckpt, strict=False)
+    model = model.cuda().eval()
+    model = model.to(torch.bfloat16)
+
+    # Process each scene
+    for scene in scannet_scenes:
+        scene_dir = args.data_dir / f"{scene}"
+        output_scene_dir = args.output_path / f"input_frame_{args.input_frame}" / scene
+        # Skip scenes that were already evaluated (resume support).
+        if (output_scene_dir / "metrics.json").exists():
+            continue
+
+        # Load scene data
+        images_dir = scene_dir / "color"
+        pose_path = scene_dir / "pose"
+        image_paths = get_sorted_image_paths(images_dir)
+        poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_path)
+        if (
+            poses_gt is None
+            or first_gt_pose is None
+            or available_pose_frame_ids is None
+        ):
+            print(f"Skipping scene {scene}: no pose data")
+            continue
+
+        # Frame filtering
+        selected_frame_ids, selected_image_paths, selected_pose_indices = (
+            build_frame_selection(
+                image_paths, available_pose_frame_ids, args.input_frame
+            )
+        )
+
+        # Get corresponding poses
+        c2ws = poses_gt[selected_pose_indices]
+        image_paths = selected_image_paths
+
+        if len(image_paths) == 0:
+            print(f"No images found in {images_dir}")
+            continue
+
+        print("🚩Processing", scene, f"Found {len(image_paths)} images")
+        all_cam_to_world_mat = []
+        all_world_points = []
+
+        try:
+            # Load images
+            images = load_images_rgb(image_paths)
+
+            if not images or len(images) < 3:
+                print(f"Skipping {scene}: insufficient valid images")
+                continue
+
+            frame_ids = selected_frame_ids
+            images_array = np.stack(images)
+            vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
+            print(f"Patch dimensions: {patch_width}x{patch_height}")
+
+            # Update model attention layers with dynamic patch dimensions
+            model.update_patch_dimensions(patch_width, patch_height)
+
+            # Inference + Reconstruction
+            (
+                extrinsic_np,
+                intrinsic_np,
+                all_world_points,
+                all_point_colors,
+                all_cam_to_world_mat,
+                inference_time_ms,
+            ) = infer_vggt_and_reconstruct(
+                model, vgg_input, dtype, args.depth_conf_thresh, image_paths
+            )
+            print(f"Inference time: {inference_time_ms:.2f}ms")
+
+            # Process results
+            if not all_cam_to_world_mat or not all_world_points:
+                print(
+                    f"Skipping {scene}: failed to obtain valid camera poses or point clouds"
+                )
+                continue
+
+            # Evaluate and save
+            metrics = evaluate_scene_and_save(
+                scene,
+                c2ws,
+                first_gt_pose,
+                frame_ids,
+                all_cam_to_world_mat,
+                all_world_points,
+                output_scene_dir,
+                args.gt_ply_dir,
+                args.chamfer_max_dist,
+                inference_time_ms,
+                args.plot,
+            )
+            if metrics is not None:
+                # Keep only the scalar summary metrics for the global average.
+                all_scenes_metrics["scenes"][scene] = {
+                    key: float(value)
+                    for key, value in metrics.items()
+                    if key
+                    in [
+                        "chamfer_distance",
+                        "ate",
+                        "are",
+                        "rpe_rot",
+                        "rpe_trans",
+                        "inference_time_ms",
+                    ]
+                }
+                print("Complete metrics", all_scenes_metrics["scenes"][scene])
+
+        except Exception as e:
+            # Per-scene boundary: one bad scene must not abort the batch.
+            print(f"Error processing scene {scene}: {e}")
+            import traceback
+
+            traceback.print_exc()
+
+    # Summarize average metrics and save
+    compute_average_metrics_and_save(
+        all_scenes_metrics,
+        args.output_path,
+        args.input_frame,
+    )
diff --git a/FastVGGT/eval/utils.py b/FastVGGT/eval/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8d5606560e1fc82a7b7b81df8b3b9f3b9ec8662
--- /dev/null
+++ b/FastVGGT/eval/utils.py
@@ -0,0 +1,142 @@
+import numpy as np
+from scipy.spatial import cKDTree as KDTree
+
+
+def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
+    """
+    Back-project a depth map to 3D points in the camera frame (pinhole model).
+
+    Args:
+    - depthmap (HxW array): per-pixel depth along the camera z-axis
+    - camera_intrinsics: a 3x3 matrix [[fu, 0, cu], [0, fv, cv], [0, 0, 1]]
+    - pseudo_focal: optional HxW per-pixel focal map that overrides fu/fv
+    Returns:
+        pointmap of camera-frame coordinates (HxWx3 array), and a mask specifying valid pixels.
+    """
+    camera_intrinsics = np.float32(camera_intrinsics)
+    H, W = depthmap.shape
+
+    # The back-projection below assumes no skew term in the intrinsics.
+    assert camera_intrinsics[0, 1] == 0.0
+    assert camera_intrinsics[1, 0] == 0.0
+    if pseudo_focal is None:
+        fu = camera_intrinsics[0, 0]
+        fv = camera_intrinsics[1, 1]
+    else:
+        assert pseudo_focal.shape == (H, W)
+        fu = fv = pseudo_focal
+    cu = camera_intrinsics[0, 2]
+    cv = camera_intrinsics[1, 2]
+
+    # Pixel grid: u indexes columns (x), v indexes rows (y).
+    u, v = np.meshgrid(np.arange(W), np.arange(H))
+    z_cam = depthmap
+    x_cam = (u - cu) * z_cam / fu
+    y_cam = (v - cv) * z_cam / fv
+    X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+    # Pixels with non-positive depth carry no geometry.
+    valid_mask = depthmap > 0.0
+    return X_cam, valid_mask
+
+
+def depthmap_to_absolute_camera_coordinates(
+    depthmap, camera_intrinsics, camera_pose, **kw
+):
+    """
+    Back-project a depth map to 3D points in the world frame.
+
+    Args:
+    - depthmap (HxW array): per-pixel depth along the camera z-axis
+    - camera_intrinsics: a 3x3 matrix
+    - camera_pose: a 4x3 or 4x4 cam2world matrix, or None to stay in the camera frame
+    Returns:
+        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
+    """
+    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
+
+    X_world = X_cam  # default: no pose given, keep camera-frame coordinates
+    if camera_pose is not None:
+
+        R_cam2world = camera_pose[:3, :3]
+        t_cam2world = camera_pose[:3, 3]
+
+        # X_world[v, u] = R @ X_cam[v, u] + t, applied per pixel via einsum.
+        X_world = (
+            np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
+        )
+
+    return X_world, valid_mask
+
+
+def completion_ratio(gt_points, rec_points, dist_th=0.05):
+ gen_points_kd_tree = KDTree(rec_points)
+ distances, _ = gen_points_kd_tree.query(gt_points)
+ comp_ratio = np.mean((distances < dist_th).astype(np.float32))
+ return comp_ratio
+
+
+def accuracy(gt_points, rec_points, gt_normals=None, rec_normals=None):
+    """
+    Mean/median distance from each reconstructed point to its nearest GT point.
+
+    Args:
+    - gt_points (Nx3 array): ground-truth point cloud
+    - rec_points (Mx3 array): reconstructed point cloud
+    - gt_normals / rec_normals: optional per-point normals; when both are given,
+      normal-consistency statistics are returned as well.
+    Returns:
+        (acc, acc_median) or (acc, acc_median, mean_normal_dot, median_normal_dot).
+    """
+    gt_points_kd_tree = KDTree(gt_points)
+    # workers=-1 parallelizes the nearest-neighbor query over all CPU cores.
+    distances, idx = gt_points_kd_tree.query(rec_points, workers=-1)
+    acc = np.mean(distances)
+
+    acc_median = np.median(distances)
+
+    if gt_normals is not None and rec_normals is not None:
+        # abs() makes the normal agreement orientation-agnostic.
+        normal_dot = np.sum(gt_normals[idx] * rec_normals, axis=-1)
+        normal_dot = np.abs(normal_dot)
+
+        return acc, acc_median, np.mean(normal_dot), np.median(normal_dot)
+
+    return acc, acc_median
+
+
+def completion(gt_points, rec_points, gt_normals=None, rec_normals=None):
+ gt_points_kd_tree = KDTree(rec_points)
+ distances, idx = gt_points_kd_tree.query(gt_points, workers=-1)
+ comp = np.mean(distances)
+ comp_median = np.median(distances)
+
+ if gt_normals is not None and rec_normals is not None:
+ normal_dot = np.sum(gt_normals * rec_normals[idx], axis=-1)
+ normal_dot = np.abs(normal_dot)
+
+ return comp, comp_median, np.mean(normal_dot), np.median(normal_dot)
+
+ return comp, comp_median
+
+
+def compute_iou(pred_vox, target_vox):
+ # Get voxel indices
+ v_pred_indices = [voxel.grid_index for voxel in pred_vox.get_voxels()]
+ v_target_indices = [voxel.grid_index for voxel in target_vox.get_voxels()]
+
+ # Convert to sets for set operations
+ v_pred_filled = set(tuple(np.round(x, 4)) for x in v_pred_indices)
+ v_target_filled = set(tuple(np.round(x, 4)) for x in v_target_indices)
+
+ # Compute intersection and union
+ intersection = v_pred_filled & v_target_filled
+ union = v_pred_filled | v_target_filled
+
+ # Compute IoU
+ iou = len(intersection) / len(union)
+ return iou
+
+
+def colmap_to_opencv_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] -= 0.5
+ K[1, 2] -= 0.5
+ return K
+
+
+def opencv_to_colmap_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] += 0.5
+ K[1, 2] += 0.5
+ return K
diff --git a/FastVGGT/merging/__init__.py b/FastVGGT/merging/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..156ac80ee0ef4270bc02b0954f8473ff55380cce
--- /dev/null
+++ b/FastVGGT/merging/__init__.py
@@ -0,0 +1,3 @@
+from . import merge
+
+__all__ = ["merge"]
diff --git a/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc b/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66a4e39454e637a8e8e0fa63f6708b1eb670aae9
Binary files /dev/null and b/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc differ
diff --git a/FastVGGT/merging/__pycache__/merge.cpython-310.pyc b/FastVGGT/merging/__pycache__/merge.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9fa7eff7fec3ec548a7532b57636ea9d7db57e16
Binary files /dev/null and b/FastVGGT/merging/__pycache__/merge.cpython-310.pyc differ
diff --git a/FastVGGT/merging/merge.py b/FastVGGT/merging/merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2094b657cba6c9772a85f2ebe0240367fcec09c
--- /dev/null
+++ b/FastVGGT/merging/merge.py
@@ -0,0 +1,370 @@
+import torch
+from typing import Tuple, Callable, Optional, Union
+
+
+@torch.jit.script
+def fast_similarity_chunks(
+    a: torch.Tensor, b_transposed: torch.Tensor, chunk_size: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    For each src token in `a` [B, num_src, C], find the highest dot-product
+    score and the index of the best-matching dst token in `b_transposed`
+    [B, C, num_dst], processing src tokens in chunks to bound peak memory.
+
+    Returns:
+        (node_max [B, num_src], node_idx [B, num_src]) in `a`'s original dtype.
+    """
+
+    B, num_src, C = a.shape
+    original_dtype = a.dtype
+
+    # Convert to bf16 for computation to improve performance and reduce memory usage
+    a_bf16 = a.to(torch.bfloat16)
+    b_transposed_bf16 = b_transposed.to(torch.bfloat16)
+    node_max = torch.empty(B, num_src, device=a.device, dtype=original_dtype)
+    node_idx = torch.empty(B, num_src, device=a.device, dtype=torch.long)
+
+    # Process in chunks
+    for i in range(0, num_src, chunk_size):
+        end_i = min(i + chunk_size, num_src)
+        a_chunk = a_bf16[:, i:end_i, :]  # [B, chunk_size, C]
+        scores_chunk = torch.bmm(a_chunk, b_transposed_bf16)
+        # Best dst match per src token within this chunk.
+        chunk_max_bf16, chunk_idx = torch.max(scores_chunk, dim=2)
+        chunk_max = chunk_max_bf16.to(original_dtype)
+        node_max[:, i:end_i] = chunk_max
+        node_idx[:, i:end_i] = chunk_idx
+    return node_max, node_idx
+
+
+def do_nothing(
+ x: torch.Tensor,
+ extra_tensors=None,
+ extra_tensors_2=None,
+) -> Union[
+ torch.Tensor,
+ Tuple[torch.Tensor, torch.Tensor],
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+]:
+ if extra_tensors is not None and extra_tensors_2 is not None:
+ return x, extra_tensors, extra_tensors_2
+ elif extra_tensors is not None:
+ return x, extra_tensors
+ else:
+ return x
+
+
+def token_merge_bipartite2d(
+ metric: torch.Tensor,
+ w: int,
+ h: int,
+ sx: int,
+ sy: int,
+ r: int,
+ no_rand: bool = False,
+ generator: Optional[torch.Generator] = None,
+ enable_protection: bool = False,
+) -> Tuple[Callable, Callable]:
+ """
+ Divide tokens into source (src) and destination (dst) groups, and merge r tokens from src to dst.
+ dst tokens are selected by randomly choosing one token from each (sx, sy) region.
+ Optionally protect the top 10% of tokens from merging based on importance scores.
+
+ Args:
+ - metric [B, N, C]: Tensor for similarity computation, B=batch size, N=token count, C=feature dimension
+ - w: Image width in tokens
+ - h: Image height in tokens
+ - sx: dst stride in x dimension, must divide w evenly
+ - sy: dst stride in y dimension, must divide h evenly
+ - r: Number of tokens to remove through merging
+ - no_rand: If True, disable randomness (use only top-left token)
+ - generator: Random number generator if no_rand is False and not None
+ - enable_protection: If True, enable importance protection feature
+
+ Returns:
+ - (merge, unmerge): Two functions for merging tokens and restoring pre-merge state
+ """
+ B, N, _ = metric.shape # Batch size B, total tokens N
+ if r <= 0:
+ return do_nothing, do_nothing
+
+ gather = torch.gather
+
+ tokens_per_img = w * h + 5
+ num_imgs = N // tokens_per_img
+ assert tokens_per_img * num_imgs == N, "Token count doesn't match (w*h+5)*num_imgs"
+
+ with torch.no_grad():
+ # Determine whether to compute importance scores based on enable_protection
+ if enable_protection:
+ num_protected = int(N * 0.1)
+ step = max(1, N // num_protected)
+ protected_indices = torch.arange(0, N, step, device=metric.device)[
+ :num_protected
+ ]
+ else:
+ protected_indices = None
+ num_protected = 0
+
+ # Global idx_buffer_seq of length N; -1 indicates dst, 0 indicates src (maintain original logic)
+ idx_buffer_seq = torch.zeros(N, device=metric.device, dtype=torch.int64)
+ hsy, wsx = h // sy, w // sx # Number of blocks within each image
+
+ # Mark first image entirely as dst
+ if num_imgs > 0:
+ idx_buffer_seq[:tokens_per_img] = -1
+
+ # Process other images - fully vectorized batch operations
+ if num_imgs > 1:
+ cls_indices = (
+ torch.arange(1, num_imgs, device=metric.device) * tokens_per_img
+ )
+ cls_indices = cls_indices[:, None] + torch.arange(5, device=metric.device)
+ idx_buffer_seq[cls_indices.flatten()] = -1
+ effective_h = min(hsy * sy, h)
+ effective_w = min(wsx * sx, w)
+ effective_grid_size = effective_h * effective_w
+
+ if no_rand:
+ base_pattern = torch.zeros(
+ effective_grid_size, device=metric.device, dtype=torch.int64
+ )
+ grid_starts = (
+ torch.arange(1, num_imgs, device=metric.device) * tokens_per_img + 5
+ )
+ grid_indices = grid_starts[:, None] + torch.arange(
+ effective_grid_size, device=metric.device
+ )
+ idx_buffer_seq[grid_indices.flatten()] = base_pattern.repeat(
+ num_imgs - 1
+ )
+ else:
+ total_other_imgs = num_imgs - 1
+ all_rand_idx = torch.randint(
+ sy * sx,
+ size=(total_other_imgs, hsy, wsx),
+ device=metric.device,
+ generator=generator,
+ )
+
+ scatter_src = -torch.ones(
+ total_other_imgs, hsy, wsx, device=metric.device, dtype=torch.int64
+ )
+
+ idx_buffer_batch = torch.zeros(
+ total_other_imgs,
+ hsy,
+ wsx,
+ sy * sx,
+ device=metric.device,
+ dtype=torch.int64,
+ )
+ idx_buffer_batch.scatter_(
+ dim=3,
+ index=all_rand_idx.unsqueeze(-1),
+ src=scatter_src.unsqueeze(-1),
+ )
+
+ idx_buffer_batch = (
+ idx_buffer_batch.view(total_other_imgs, hsy, wsx, sy, sx)
+ .transpose(2, 3)
+ .reshape(total_other_imgs, hsy * sy, wsx * sx)
+ )
+
+ # Batch fill to target positions - still needs a small loop here, but operations are greatly reduced
+ for i in range(total_other_imgs):
+ img_idx = i + 1
+ grid_start = img_idx * tokens_per_img + 5
+ flat_view = idx_buffer_batch[
+ i, :effective_h, :effective_w
+ ].flatten()
+ idx_buffer_seq[grid_start : grid_start + effective_grid_size] = (
+ flat_view
+ )
+
+ rand_idx = idx_buffer_seq.reshape(1, -1, 1).argsort(dim=1)
+ num_dst_orig = int((idx_buffer_seq == -1).sum())
+
+ # Original src and dst indices
+ a_idx_orig = rand_idx[:, num_dst_orig:, :]
+ b_idx_orig = rand_idx[:, :num_dst_orig, :]
+ a_idx = a_idx_orig
+ b_idx = b_idx_orig
+
+ if enable_protection:
+ protected_idx = protected_indices.unsqueeze(0).unsqueeze(-1)
+ num_protected_actual = protected_idx.shape[1]
+ else:
+ protected_idx = None
+ num_protected_actual = 0
+
+ num_src = a_idx.shape[1]
+ num_dst = b_idx.shape[1]
+
+ # Define an internal function to separate src, dst, and protected tokens
+ def split(x):
+ C = x.shape[-1]
+
+ if enable_protection:
+ src = gather(x, dim=1, index=a_idx.expand(B, num_src, C))
+ dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
+ protected = gather(
+ x, dim=1, index=protected_idx.expand(B, num_protected_actual, C)
+ )
+ return src, dst, protected
+ else:
+ src = gather(x, dim=1, index=a_idx.expand(B, num_src, C))
+ dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
+ return src, dst
+
+ # Compute cosine similarity (normalize first then dot product)
+ metric = metric / metric.norm(dim=-1, keepdim=True)
+ if enable_protection:
+ a, b, protected = split(metric)
+ else:
+ a, b = split(metric)
+
+ r = min(a.shape[1], r)
+ num_src_actual = a.shape[1]
+ chunk_size = min(5000, num_src_actual)
+
+ node_max = torch.empty(B, num_src_actual, device=a.device, dtype=a.dtype)
+ node_idx = torch.empty(B, num_src_actual, device=a.device, dtype=torch.long)
+
+ b_transposed = b.transpose(-1, -2)
+ node_max, node_idx = fast_similarity_chunks(a, b_transposed, chunk_size)
+ edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
+
+ # If protection is enabled, filter out protected tokens to ensure they are not merged
+ if enable_protection:
+ src_indices = a_idx[0, :, 0]
+ protected_mask_src = torch.isin(src_indices, protected_indices)
+ edge_flat = edge_idx[0, :, 0]
+ valid_mask = ~protected_mask_src[edge_flat]
+ valid_edges = edge_flat[valid_mask]
+
+ valid_count = valid_edges.shape[0]
+ r_actual = min(r, valid_count)
+
+ unm_idx = valid_edges[r_actual:].unsqueeze(0).unsqueeze(-1)
+ src_idx = valid_edges[:r_actual].unsqueeze(0).unsqueeze(-1)
+ else:
+ unm_idx = edge_idx[..., r:, :]
+ src_idx = edge_idx[..., :r, :]
+ r_actual = r
+
+ # Get dst token indices corresponding to each src token to be merged
+ dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
+ r = r_actual
+
+ # Define merge function to merge selected src tokens to corresponding dst tokens
+ def merge(
+ x: torch.Tensor,
+ mode: str = "mean",
+ extra_tensors=None,
+ extra_tensors_2=None,
+ ) -> Union[
+ torch.Tensor,
+ Tuple[torch.Tensor, torch.Tensor],
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ ]:
+ if enable_protection:
+ src, dst, protected = split(x)
+ else:
+ src, dst = split(x)
+
+ n, t1, c = src.shape
+
+ # Extract unmerged src tokens - using actual unm_idx size
+ unm_len = unm_idx.shape[1]
+ unm = gather(src, dim=-2, index=unm_idx.expand(n, unm_len, c))
+ src_len = src_idx.shape[1]
+ src = gather(src, dim=-2, index=src_idx.expand(n, src_len, c))
+ dst = dst.scatter_reduce(-2, dst_idx.expand(n, src_len, c), src, reduce=mode)
+
+ # ---------------- Extra tensor processing ----------------
+ merged_extra_1 = None
+ merged_extra_2 = None
+ if extra_tensors is not None:
+ E_dim = extra_tensors.shape[-1]
+ if enable_protection:
+ src_e, dst_e, protected_e = split(extra_tensors)
+ else:
+ src_e, dst_e = split(extra_tensors)
+
+ # Consistent with main tensor, only select r src tokens to be merged
+ src_e_r = gather(src_e, dim=-2, index=src_idx.expand(n, src_len, E_dim))
+ unm_e = gather(src_e, dim=-2, index=unm_idx.expand(n, unm_len, E_dim))
+
+ dst_e = dst_e.scatter_reduce(
+ -2, dst_idx.expand(n, src_len, E_dim), src_e_r, reduce=mode
+ )
+ if enable_protection:
+ merged_extra_1 = torch.cat([unm_e, dst_e, protected_e], dim=1)
+ else:
+ merged_extra_1 = torch.cat([unm_e, dst_e], dim=1)
+
+ if extra_tensors_2 is not None:
+ E_dim_2 = extra_tensors_2.shape[-1]
+ if enable_protection:
+ src_e2, dst_e2, protected_e2 = split(extra_tensors_2)
+ else:
+ src_e2, dst_e2 = split(extra_tensors_2)
+
+ src_e2_r = gather(src_e2, dim=-2, index=src_idx.expand(n, src_len, E_dim_2))
+ unm_e2 = gather(src_e2, dim=-2, index=unm_idx.expand(n, unm_len, E_dim_2))
+
+ dst_e2 = dst_e2.scatter_reduce(
+ -2, dst_idx.expand(n, src_len, E_dim_2), src_e2_r, reduce=mode
+ )
+ if enable_protection:
+ merged_extra_2 = torch.cat([unm_e2, dst_e2, protected_e2], dim=1)
+ else:
+ merged_extra_2 = torch.cat([unm_e2, dst_e2], dim=1)
+
+ if enable_protection:
+ main_result = torch.cat([unm, dst, protected], dim=1)
+ else:
+ main_result = torch.cat([unm, dst], dim=1)
+
+ if merged_extra_1 is not None and merged_extra_2 is not None:
+ return main_result, merged_extra_1, merged_extra_2
+ elif merged_extra_1 is not None:
+ return main_result, merged_extra_1
+ else:
+ return main_result
+
+ # Define unmerge function to restore pre-merge state (for decoder)
+ def unmerge(x: torch.Tensor) -> torch.Tensor:
+ unm_len = unm_idx.shape[1]
+ dst_len = num_dst
+ src_len = src_idx.shape[1]
+ unm = x[..., :unm_len, :]
+ dst = x[..., unm_len : unm_len + dst_len, :]
+
+ if enable_protection:
+ protected = x[
+ ..., unm_len + dst_len : unm_len + dst_len + num_protected_actual, :
+ ]
+
+ _, _, c = unm.shape
+ src = gather(dst, dim=-2, index=dst_idx.expand(B, src_len, c))
+ out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
+ out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
+ out.scatter_(
+ dim=-2,
+ index=gather(
+ a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx
+ ).expand(B, unm_len, c),
+ src=unm,
+ )
+
+ out.scatter_(
+ dim=-2,
+ index=gather(
+ a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx
+ ).expand(B, src_len, c),
+ src=src,
+ )
+
+ if enable_protection:
+ out.scatter_(
+ dim=-2,
+ index=protected_idx.expand(B, num_protected_actual, c),
+ src=protected,
+ )
+
+ return out
+
+ return merge, unmerge
diff --git a/FastVGGT/requirements.txt b/FastVGGT/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e343637f2d9a1550f55f14f8890c229bd467cdb2
--- /dev/null
+++ b/FastVGGT/requirements.txt
@@ -0,0 +1,15 @@
+torch==2.3.1
+torchvision==0.18.1
+numpy==1.26.1
+Pillow
+huggingface_hub
+einops
+safetensors
+evo
+open3d
+matplotlib
+scipy
+opencv-python
+scikit-image
+tqdm
+
diff --git a/FastVGGT/vggt/__init__.py b/FastVGGT/vggt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a2958f93a114495631a3ce270b99ee8ba8443f1
--- /dev/null
+++ b/FastVGGT/vggt/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
\ No newline at end of file
diff --git a/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5f3bfe676b899c95cba2c034de5292d685e03ce
Binary files /dev/null and b/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/dependency/__init__.py b/FastVGGT/vggt/dependency/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a2958f93a114495631a3ce270b99ee8ba8443f1
--- /dev/null
+++ b/FastVGGT/vggt/dependency/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
\ No newline at end of file
diff --git a/FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cd09b8397b7a4717d49afb51e64d490fa11fc205
Binary files /dev/null and b/FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc b/FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..389177102f318fe4bcfd46f6ab4a04e1a8f39196
Binary files /dev/null and b/FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/dependency/distortion.py b/FastVGGT/vggt/dependency/distortion.py
new file mode 100644
index 0000000000000000000000000000000000000000..375b747086478050d601676ffaea3b25e80690b4
--- /dev/null
+++ b/FastVGGT/vggt/dependency/distortion.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+
+def apply_distortion(points, distortion_params):
+    """
+    Apply distortion to normalized camera coordinates.
+
+    Args:
+        points: Array of normalized camera coordinates
+        distortion_params: Distortion parameters
+
+    Returns:
+        Distorted coordinates
+    """
+    # Simple passthrough for now - implement actual distortion if needed
+    # NOTE(review): distortion_params is currently ignored; callers receive
+    # the input unchanged.
+    return points
+
+
+def iterative_undistortion(points, distortion_params, max_iter=10):
+    """
+    Remove distortion from normalized camera coordinates using iterative method.
+
+    Args:
+        points: Array of distorted normalized camera coordinates
+        distortion_params: Distortion parameters
+        max_iter: Maximum number of iterations
+
+    Returns:
+        Undistorted coordinates
+    """
+    # Simple passthrough for now - implement actual undistortion if needed
+    # NOTE(review): distortion_params and max_iter are currently ignored;
+    # callers receive the input unchanged.
+    return points
+
+
+def single_undistortion(points, distortion_params):
+    """
+    Remove distortion from normalized camera coordinates using single step.
+
+    Args:
+        points: Array of distorted normalized camera coordinates
+        distortion_params: Distortion parameters
+
+    Returns:
+        Undistorted coordinates
+    """
+    # Simple passthrough for now - implement actual undistortion if needed
+    # NOTE(review): distortion_params is currently ignored; callers receive
+    # the input unchanged.
+    return points
\ No newline at end of file
diff --git a/FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..292dc30328d06ba3951a9eee0e99bc569c60a1d3
Binary files /dev/null and b/FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..838db11a0862e9907682572d8ebe3b7c1665e147
Binary files /dev/null and b/FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d6cd93db3a5987c35085d487d21f5513ea3d1fd
Binary files /dev/null and b/FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3244c2b0364f9b533aab03446edb787dbed59d5
Binary files /dev/null and b/FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fc2e8df264f6accc5412e80b8580c8c48195aaf5
Binary files /dev/null and b/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/camera_head.py b/FastVGGT/vggt/heads/camera_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..215adf39de23abd4975479d332250fcc3e2b54b9
--- /dev/null
+++ b/FastVGGT/vggt/heads/camera_head.py
@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from vggt.layers import Mlp
+from vggt.layers.block import Block
+from vggt.heads.head_act import activate_pose
+
+
+class CameraHead(nn.Module):
+    """
+    CameraHead predicts camera parameters from token representations using iterative refinement.
+
+    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
+    """
+
+    def __init__(
+        self,
+        dim_in: int = 2048,
+        trunk_depth: int = 4,
+        pose_encoding_type: str = "absT_quaR_FoV",
+        num_heads: int = 16,
+        mlp_ratio: int = 4,
+        init_values: float = 0.01,
+        trans_act: str = "linear",
+        quat_act: str = "linear",
+        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
+    ):
+        super().__init__()
+
+        if pose_encoding_type == "absT_quaR_FoV":
+            # 9-dim encoding; presumably absolute translation + quaternion + FoV
+            # per the encoding name — TODO confirm exact component split.
+            self.target_dim = 9
+        else:
+            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
+
+        self.trans_act = trans_act
+        self.quat_act = quat_act
+        self.fl_act = fl_act
+        self.trunk_depth = trunk_depth
+
+        # Build the trunk using a sequence of transformer blocks.
+        self.trunk = nn.Sequential(
+            *[
+                Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
+                for _ in range(trunk_depth)
+            ]
+        )
+
+        # Normalizations for camera token and trunk output.
+        self.token_norm = nn.LayerNorm(dim_in)
+        self.trunk_norm = nn.LayerNorm(dim_in)
+
+        # Learnable empty camera pose token.
+        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
+        self.embed_pose = nn.Linear(self.target_dim, dim_in)
+
+        # Module for producing modulation parameters: shift, scale, and a gate.
+        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
+
+        # Adaptive layer normalization without affine parameters.
+        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
+        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
+
+    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
+        """
+        Forward pass to predict camera parameters.
+
+        Args:
+            aggregated_tokens_list (list): List of token tensors from the network;
+                the last tensor is used for prediction.
+            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
+
+        Returns:
+            list: A list of predicted camera encodings (post-activation) from each iteration.
+        """
+        # Use tokens from the last block for camera prediction.
+        tokens = aggregated_tokens_list[-1]
+
+        # Extract the camera tokens (token index 0 of each frame's sequence).
+        pose_tokens = tokens[:, :, 0]
+        pose_tokens = self.token_norm(pose_tokens)
+
+        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
+        return pred_pose_enc_list
+
+    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
+        """
+        Iteratively refine camera pose predictions.
+
+        Args:
+            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
+            num_iterations (int): Number of refinement iterations.
+
+        Returns:
+            list: List of activated camera encodings from each iteration.
+        """
+        B, S, C = pose_tokens.shape  # S is expected to be 1.
+        pred_pose_enc = None
+        pred_pose_enc_list = []
+
+        for _ in range(num_iterations):
+            # Use a learned empty pose for the first iteration.
+            if pred_pose_enc is None:
+                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
+            else:
+                # Detach the previous prediction to avoid backprop through time.
+                pred_pose_enc = pred_pose_enc.detach()
+                module_input = self.embed_pose(pred_pose_enc)
+
+            # Generate modulation parameters and split them into shift, scale, and gate components.
+            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
+
+            # Adaptive layer normalization and modulation (module-level `modulate`
+            # helper), with a residual connection back to the input tokens.
+            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
+            pose_tokens_modulated = pose_tokens_modulated + pose_tokens
+
+            pose_tokens_modulated = self.trunk(pose_tokens_modulated)
+            # Compute the delta update for the pose encoding.
+            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
+
+            if pred_pose_enc is None:
+                pred_pose_enc = pred_pose_enc_delta
+            else:
+                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
+
+            # Apply final activation functions for translation, quaternion, and field-of-view.
+            activated_pose = activate_pose(
+                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
+            )
+            pred_pose_enc_list.append(activated_pose)
+
+        return pred_pose_enc_list
+
+
+def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ """
+ Modulate the input tensor using scaling and shifting parameters.
+ """
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
+ return x * (1 + scale) + shift
diff --git a/FastVGGT/vggt/heads/dpt_head.py b/FastVGGT/vggt/heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e20d65ef5bfeb23cf83ca748aedff738840b4ffd
--- /dev/null
+++ b/FastVGGT/vggt/heads/dpt_head.py
@@ -0,0 +1,598 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
+
+
+import os
+from typing import List, Dict, Tuple, Union, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .head_act import activate_head
+from .utils import create_uv_grid, position_grid_to_embed
+
+
+class DPTHead(nn.Module):
+ """
+ DPT Head for dense prediction tasks.
+
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
+ backbone and produces dense predictions by fusing multi-scale features.
+
+ Args:
+ dim_in (int): Input dimension (channels).
+ patch_size (int, optional): Patch size. Default is 14.
+ output_dim (int, optional): Number of output channels. Default is 4.
+ activation (str, optional): Activation type. Default is "inv_log".
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
+ out_channels (List[int], optional): Output channels for each intermediate layer.
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
+ """
+
+ def __init__(
+ self,
+ dim_in: int,
+ patch_size: int = 14,
+ output_dim: int = 4,
+ activation: str = "inv_log",
+ conf_activation: str = "expp1",
+ features: int = 256,
+ out_channels: List[int] = [256, 512, 1024, 1024],
+ intermediate_layer_idx: List[int] = [0, 1, 2, 3],
+ pos_embed: bool = True,
+ feature_only: bool = False,
+ down_ratio: int = 1,
+ ) -> None:
+ super(DPTHead, self).__init__()
+ self.patch_size = patch_size
+ self.activation = activation
+ self.conf_activation = conf_activation
+ self.pos_embed = pos_embed
+ self.feature_only = feature_only
+ self.down_ratio = down_ratio
+ self.intermediate_layer_idx = intermediate_layer_idx
+
+ self.norm = nn.LayerNorm(dim_in)
+
+ # Projection layers for each output channel from tokens.
+ self.projects = nn.ModuleList(
+ [
+ nn.Conv2d(
+ in_channels=dim_in,
+ out_channels=oc,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ for oc in out_channels
+ ]
+ )
+
+ # Resize layers for upsampling feature maps.
+ self.resize_layers = nn.ModuleList(
+ [
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0],
+ out_channels=out_channels[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1],
+ out_channels=out_channels[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ ),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3],
+ out_channels=out_channels[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ ]
+ )
+
+ self.scratch = _make_scratch(out_channels, features, expand=False)
+
+ # Attach additional modules to scratch.
+ self.scratch.stem_transpose = nn.Identity() # Use Identity instead of None
+ self.scratch.refinenet1 = _make_fusion_block(features)
+ self.scratch.refinenet2 = _make_fusion_block(features)
+ self.scratch.refinenet3 = _make_fusion_block(features)
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
+
+ head_features_1 = features
+ head_features_2 = 32
+
+ if feature_only:
+ self.scratch.output_conv1 = nn.Conv2d(
+ head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.scratch.output_conv1 = nn.Conv2d(
+ head_features_1,
+ head_features_1 // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+ conv2_in_channels = head_features_1 // 2
+
+ self.scratch.output_conv2 = nn.Sequential(
+ nn.Conv2d(
+ conv2_in_channels,
+ head_features_2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(
+ head_features_2, output_dim, kernel_size=1, stride=1, padding=0
+ ),
+ )
+
+ def forward(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_chunk_size: int = 8,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Forward pass through the DPT head, supports processing by chunking frames.
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
+ If None or larger than S, all frames are processed at once. Default: 8.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]:
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
+ """
+ B, S, _, H, W = images.shape
+
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
+ if frames_chunk_size is None or frames_chunk_size >= S:
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
+
+ # Otherwise, process frames in chunks to manage memory usage
+ assert frames_chunk_size > 0
+
+ # Process frames in batches
+ all_preds = []
+ all_conf = []
+
+ for frames_start_idx in range(0, S, frames_chunk_size):
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
+
+ # Process batch of frames
+ if self.feature_only:
+ chunk_output = self._forward_impl(
+ aggregated_tokens_list,
+ images,
+ patch_start_idx,
+ frames_start_idx,
+ frames_end_idx,
+ )
+ all_preds.append(chunk_output)
+ else:
+ chunk_preds, chunk_conf = self._forward_impl(
+ aggregated_tokens_list,
+ images,
+ patch_start_idx,
+ frames_start_idx,
+ frames_end_idx,
+ )
+ all_preds.append(chunk_preds)
+ all_conf.append(chunk_conf)
+
+ # Concatenate results along the sequence dimension
+ if self.feature_only:
+ return torch.cat(all_preds, dim=1)
+ else:
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
+
+ def _forward_impl(
+ self,
+ aggregated_tokens_list: List[torch.Tensor],
+ images: torch.Tensor,
+ patch_start_idx: int,
+ frames_start_idx: Optional[int] = None,
+ frames_end_idx: Optional[int] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Implementation of the forward pass through the DPT head.
+
+ This method processes a specific chunk of frames from the sequence.
+
+ Args:
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
+ images (Tensor): Input images with shape [B, S, 3, H, W].
+ patch_start_idx (int): Starting index for patch tokens.
+ frames_start_idx (int, optional): Starting index for frames to process.
+ frames_end_idx (int, optional): Ending index for frames to process.
+
+ Returns:
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
+ """
+ if frames_start_idx is not None and frames_end_idx is not None:
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
+
+ B, S, _, H, W = images.shape
+
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
+
+ out = []
+ dpt_idx = 0
+
+ for layer_idx in self.intermediate_layer_idx:
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
+
+ # Select frames if processing a chunk
+ if frames_start_idx is not None and frames_end_idx is not None:
+ x = x[:, frames_start_idx:frames_end_idx]
+
+ x = x.reshape(B * S, -1, x.shape[-1])
+
+ x = self.norm(x)
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[dpt_idx](x)
+ if self.pos_embed:
+ x = self._apply_pos_embed(x, W, H)
+ x = self.resize_layers[dpt_idx](x)
+
+ out.append(x)
+ dpt_idx += 1
+
+ # Fuse features from multiple layers.
+ out = self.scratch_forward(out)
+ # Interpolate fused output to match target image resolution.
+ out = custom_interpolate(
+ out,
+ (
+ int(patch_h * self.patch_size / self.down_ratio),
+ int(patch_w * self.patch_size / self.down_ratio),
+ ),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ if self.pos_embed:
+ out = self._apply_pos_embed(out, W, H)
+
+ if self.feature_only:
+ return out.view(B, S, *out.shape[1:])
+
+ out = self.scratch.output_conv2(out)
+ preds, conf = activate_head(
+ out, activation=self.activation, conf_activation=self.conf_activation
+ )
+
+ preds = preds.view(B, S, *preds.shape[1:])
+ conf = conf.view(B, S, *conf.shape[1:])
+ return preds, conf
+
+ def _apply_pos_embed(
+ self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
+ ) -> torch.Tensor:
+ """
+ Apply positional embedding to tensor x.
+ """
+ patch_w = x.shape[-1]
+ patch_h = x.shape[-2]
+ pos_embed = create_uv_grid(
+ patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
+ )
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
+ pos_embed = pos_embed * ratio
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
+ return x + pos_embed
+
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Forward pass through the fusion blocks.
+
+ Args:
+ features (List[Tensor]): List of feature maps from different layers.
+
+ Returns:
+ Tensor: Fused feature map.
+ """
+ layer_1, layer_2, layer_3, layer_4 = features
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ del layer_4_rn, layer_4
+
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
+ del layer_3_rn, layer_3
+
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
+ del layer_2_rn, layer_2
+
+ out = self.scratch.refinenet1(out, layer_1_rn)
+ del layer_1_rn, layer_1
+
+ out = self.scratch.output_conv1(out)
+ return out
+
+
+################################################################################
+# Modules
+################################################################################
+
+
+def _make_fusion_block(
+ features: int,
+ size: Optional[int] = None,
+ has_residual: bool = True,
+ groups: int = 1,
+) -> nn.Module:
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(inplace=True),
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=size,
+ has_residual=has_residual,
+ groups=groups,
+ )
+
+
+def _make_scratch(
+ in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
+) -> nn.Module:
+ scratch = nn.Module()
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0],
+ out_shape1,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1],
+ out_shape2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2],
+ out_shape3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3],
+ out_shape4,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ groups=groups,
+ )
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features, activation, bn, groups=1):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+ self.groups = groups
+ self.conv1 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ groups=self.groups,
+ )
+ self.conv2 = nn.Conv2d(
+ features,
+ features,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ groups=self.groups,
+ )
+
+ self.norm1 = None
+ self.norm2 = None
+
+ self.activation = activation
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.norm1 is not None:
+ out = self.norm1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.norm2 is not None:
+ out = self.norm2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None,
+ has_residual=True,
+ groups=1,
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+ self.groups = groups
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(
+ features,
+ out_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ groups=self.groups,
+ )
+
+ if has_residual:
+ self.resConfUnit1 = ResidualConvUnit(
+ features, activation, bn, groups=self.groups
+ )
+
+ self.has_residual = has_residual
+ self.resConfUnit2 = ResidualConvUnit(
+ features, activation, bn, groups=self.groups
+ )
+
+ self.size = size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if self.has_residual:
+ res = self.resConfUnit1(xs[1])
+ output = output + res
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = custom_interpolate(
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
+ )
+ output = self.out_conv(output)
+
+ return output
+
+
+def custom_interpolate(
+ x: torch.Tensor,
+ size: Optional[Tuple[int, int]] = None,
+ scale_factor: Optional[float] = None,
+ mode: str = "bilinear",
+ align_corners: bool = True,
+) -> torch.Tensor:
+ """
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
+ """
+ if size is None:
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
+
+ INT_MAX = 1610612736
+
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
+
+ if input_elements > INT_MAX:
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
+ interpolated_chunks = [
+ nn.functional.interpolate(
+ chunk, size=size, mode=mode, align_corners=align_corners
+ )
+ for chunk in chunks
+ ]
+ x = torch.cat(interpolated_chunks, dim=0)
+ return x.contiguous()
+ else:
+ return nn.functional.interpolate(
+ x, size=size, mode=mode, align_corners=align_corners
+ )
diff --git a/FastVGGT/vggt/heads/head_act.py b/FastVGGT/vggt/heads/head_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dedfcf1180a653dddc99623e60df625e5897489
--- /dev/null
+++ b/FastVGGT/vggt/heads/head_act.py
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn.functional as F
+
+
+def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
+ """
+ Activate pose parameters with specified activation functions.
+
+ Args:
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
+ trans_act: Activation type for translation component
+ quat_act: Activation type for quaternion component
+ fl_act: Activation type for focal length component
+
+ Returns:
+ Activated pose parameters tensor
+ """
+ T = pred_pose_enc[..., :3]
+ quat = pred_pose_enc[..., 3:7]
+ fl = pred_pose_enc[..., 7:] # or fov
+
+ T = base_pose_act(T, trans_act)
+ quat = base_pose_act(quat, quat_act)
+ fl = base_pose_act(fl, fl_act) # or fov
+
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
+
+ return pred_pose_enc
+
+
+def base_pose_act(pose_enc, act_type="linear"):
+ """
+ Apply basic activation function to pose parameters.
+
+ Args:
+ pose_enc: Tensor containing encoded pose parameters
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
+
+ Returns:
+ Activated pose parameters
+ """
+ if act_type == "linear":
+ return pose_enc
+ elif act_type == "inv_log":
+ return inverse_log_transform(pose_enc)
+ elif act_type == "exp":
+ return torch.exp(pose_enc)
+ elif act_type == "relu":
+ return F.relu(pose_enc)
+ else:
+ raise ValueError(f"Unknown act_type: {act_type}")
+
+
+def activate_head(out, activation="norm_exp", conf_activation="expp1"):
+ """
+ Process network output to extract 3D points and confidence values.
+
+ Args:
+ out: Network output tensor (B, C, H, W)
+ activation: Activation type for 3D points
+ conf_activation: Activation type for confidence values
+
+ Returns:
+ Tuple of (3D points tensor, confidence tensor)
+ """
+ # Move channels from last dim to the 4th dimension => (B, H, W, C)
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
+
+ # Split into xyz (first C-1 channels) and confidence (last channel)
+ xyz = fmap[:, :, :, :-1]
+ conf = fmap[:, :, :, -1]
+
+ if activation == "norm_exp":
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+ xyz_normed = xyz / d
+ pts3d = xyz_normed * torch.expm1(d)
+ elif activation == "norm":
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
+ elif activation == "exp":
+ pts3d = torch.exp(xyz)
+ elif activation == "relu":
+ pts3d = F.relu(xyz)
+ elif activation == "inv_log":
+ pts3d = inverse_log_transform(xyz)
+ elif activation == "xy_inv_log":
+ xy, z = xyz.split([2, 1], dim=-1)
+ z = inverse_log_transform(z)
+ pts3d = torch.cat([xy * z, z], dim=-1)
+ elif activation == "sigmoid":
+ pts3d = torch.sigmoid(xyz)
+ elif activation == "linear":
+ pts3d = xyz
+ else:
+ raise ValueError(f"Unknown activation: {activation}")
+
+ if conf_activation == "expp1":
+ conf_out = 1 + conf.exp()
+ elif conf_activation == "expp0":
+ conf_out = conf.exp()
+ elif conf_activation == "sigmoid":
+ conf_out = torch.sigmoid(conf)
+ else:
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
+
+ return pts3d, conf_out
+
+
+def inverse_log_transform(y):
+ """
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
+
+ Args:
+ y: Input tensor
+
+ Returns:
+ Transformed tensor
+ """
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
diff --git a/FastVGGT/vggt/heads/track_head.py b/FastVGGT/vggt/heads/track_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4f1d9bd83cca1f74f97a644a02b984904f84706
--- /dev/null
+++ b/FastVGGT/vggt/heads/track_head.py
@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+from .dpt_head import DPTHead
+from .track_modules.base_track_predictor import BaseTrackerPredictor
+
+
+class TrackHead(nn.Module):
+ """
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
+ The tracking is performed iteratively, refining predictions over multiple iterations.
+ """
+
+ def __init__(
+ self,
+ dim_in,
+ patch_size=14,
+ features=128,
+ iters=4,
+ predict_conf=True,
+ stride=2,
+ corr_levels=7,
+ corr_radius=4,
+ hidden_size=384,
+ ):
+ """
+ Initialize the TrackHead module.
+
+ Args:
+ dim_in (int): Input dimension of tokens from the backbone.
+ patch_size (int): Size of image patches used in the vision transformer.
+ features (int): Number of feature channels in the feature extractor output.
+ iters (int): Number of refinement iterations for tracking predictions.
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
+ stride (int): Stride value for the tracker predictor.
+ corr_levels (int): Number of correlation pyramid levels
+ corr_radius (int): Radius for correlation computation, controlling the search area.
+ hidden_size (int): Size of hidden layers in the tracker network.
+ """
+ super().__init__()
+
+ self.patch_size = patch_size
+
+ # Feature extractor based on DPT architecture
+ # Processes tokens into feature maps for tracking
+ self.feature_extractor = DPTHead(
+ dim_in=dim_in,
+ patch_size=patch_size,
+ features=features,
+ feature_only=True, # Only output features, no activation
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
+ pos_embed=False,
+ )
+
+ # Tracker module that predicts point trajectories
+ # Takes feature maps and predicts coordinates and visibility
+ self.tracker = BaseTrackerPredictor(
+ latent_dim=features, # Match the output_dim of feature extractor
+ predict_conf=predict_conf,
+ stride=stride,
+ corr_levels=corr_levels,
+ corr_radius=corr_radius,
+ hidden_size=hidden_size,
+ )
+
+ self.iters = iters
+
+ def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
+ """
+ Forward pass of the TrackHead.
+
+ Args:
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
+ B = batch size, S = sequence length.
+ patch_start_idx (int): Starting index for patch tokens.
+ query_points (torch.Tensor, optional): Initial query points to track.
+ If None, points are initialized by the tracker.
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
+
+ Returns:
+ tuple:
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
+ """
+ B, S, _, H, W = images.shape
+
+ # Extract features from tokens
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
+ feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
+
+ # Use default iterations if not specified
+ if iters is None:
+ iters = self.iters
+
+ # Perform tracking using the extracted features
+ coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters)
+
+ return coord_preds, vis_scores, conf_scores
diff --git a/FastVGGT/vggt/heads/track_modules/__init__.py b/FastVGGT/vggt/heads/track_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa
--- /dev/null
+++ b/FastVGGT/vggt/heads/track_modules/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5ca81dbf0ca49b601ca2662314a3391c8ba00890
Binary files /dev/null and b/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..350f8ca52c0aca644125fea47f83d49337478bcb
Binary files /dev/null and b/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81e9466ff91b16d265807533eeada2e8e674d7e0
Binary files /dev/null and b/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3fe19d5e60445f0fa06866bffaa2b8b2e60a6f8
Binary files /dev/null and b/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cac348bf405bc00bc6578760ba8ad57185f33b81
Binary files /dev/null and b/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/track_modules/base_track_predictor.py b/FastVGGT/vggt/heads/track_modules/base_track_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ce8ec4b66fff236e015d1bcaf85c8237a52be7a
--- /dev/null
+++ b/FastVGGT/vggt/heads/track_modules/base_track_predictor.py
@@ -0,0 +1,209 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+from .blocks import EfficientUpdateFormer, CorrBlock
+from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
+from .modules import Mlp
+
+
+class BaseTrackerPredictor(nn.Module):
+ def __init__(
+ self,
+ stride=1,
+ corr_levels=5,
+ corr_radius=4,
+ latent_dim=128,
+ hidden_size=384,
+ use_spaceatt=True,
+ depth=6,
+ max_scale=518,
+ predict_conf=True,
+ ):
+ super(BaseTrackerPredictor, self).__init__()
+ """
+ The base template to create a track predictor
+
+ Modified from https://github.com/facebookresearch/co-tracker/
+ and https://github.com/facebookresearch/vggsfm
+ """
+
+ self.stride = stride
+ self.latent_dim = latent_dim
+ self.corr_levels = corr_levels
+ self.corr_radius = corr_radius
+ self.hidden_size = hidden_size
+ self.max_scale = max_scale
+ self.predict_conf = predict_conf
+
+ self.flows_emb_dim = latent_dim // 2
+
+ self.corr_mlp = Mlp(
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
+ hidden_features=self.hidden_size,
+ out_features=self.latent_dim,
+ )
+
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
+
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
+
+ space_depth = depth if use_spaceatt else 0
+ time_depth = depth
+
+ self.updateformer = EfficientUpdateFormer(
+ space_depth=space_depth,
+ time_depth=time_depth,
+ input_dim=self.transformer_dim,
+ hidden_size=self.hidden_size,
+ output_dim=self.latent_dim + 2,
+ mlp_ratio=4.0,
+ add_space_attn=use_spaceatt,
+ )
+
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
+
+ # A linear layer to update track feats at each iteration
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
+
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ if predict_conf:
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
+
+ def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
+ """
+ query_points: B x N x 2, the number of batches, tracks, and xy
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
+ note HH and WW is the size of feature maps instead of original images
+ """
+ B, N, D = query_points.shape
+ B, S, C, HH, WW = fmaps.shape
+
+ assert D == 2, "Input points must be 2D coordinates"
+
+ # apply a layernorm to fmaps here
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
+
+ # Scale the input query_points because we may downsample the images
+ # by down_ratio or self.stride
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
+ # its query_points should be query_points/4
+ if down_ratio > 1:
+ query_points = query_points / float(down_ratio)
+
+ query_points = query_points / float(self.stride)
+
+ # Init with coords as the query points
+ # It means the search will start from the position of query points at the reference frames
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
+
+ # Sample/extract the features of the query points in the query frame
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
+
+ # init track feats by query feats
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
+ # back up the init coords
+ coords_backup = coords.clone()
+
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
+
+ coord_preds = []
+
+ # Iterative Refinement
+ for _ in range(iters):
+ # Detach the gradients from the last iteration
+ # (in my experience, not very important for performance)
+ coords = coords.detach()
+
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
+
+ corr_dim = fcorrs.shape[3]
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
+ fcorrs_ = self.corr_mlp(fcorrs_)
+
+ # Movement of current coords relative to query points
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
+
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
+ flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
+
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
+
+ # Concatenate them as the input for the transformers
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
+
+ # 2D positional embed
+ # TODO: this can be much simplified
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
+
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
+
+ x = transformer_input + sampled_pos_emb
+
+ # Add the query ref token to the track feats
+ query_ref_token = torch.cat(
+ [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
+ )
+ x = x + query_ref_token.to(x.device).to(x.dtype)
+
+ # B, N, S, C
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
+
+ # Compute the delta coordinates and delta track features
+ delta, _ = self.updateformer(x)
+
+ # BN, S, C
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
+ delta_coords_ = delta[:, :, :2]
+ delta_feats_ = delta[:, :, 2:]
+
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
+
+ # Update the track features
+ track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
+
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
+
+ # B x S x N x 2
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
+
+ # Force coord0 as query
+ # because we assume the query points should not be changed
+ coords[:, 0] = coords_backup[:, 0]
+
+ # The predicted tracks are in the original image scale
+ if down_ratio > 1:
+ coord_preds.append(coords * self.stride * down_ratio)
+ else:
+ coord_preds.append(coords * self.stride)
+
+ # B, S, N
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
+ if apply_sigmoid:
+ vis_e = torch.sigmoid(vis_e)
+
+ if self.predict_conf:
+ conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
+ if apply_sigmoid:
+ conf_e = torch.sigmoid(conf_e)
+ else:
+ conf_e = None
+
+ if return_feat:
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
+ else:
+ return coord_preds, vis_e, conf_e
diff --git a/FastVGGT/vggt/heads/track_modules/blocks.py b/FastVGGT/vggt/heads/track_modules/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..15c161c89ef99742b0f2c6f397c9121fe9301e08
--- /dev/null
+++ b/FastVGGT/vggt/heads/track_modules/blocks.py
@@ -0,0 +1,236 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Modified from https://github.com/facebookresearch/co-tracker/
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import bilinear_sampler
+from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
+
+
+class EfficientUpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
+ """
+
+ def __init__(
+ self,
+ space_depth=6,
+ time_depth=6,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ add_space_attn=True,
+ num_virtual_tracks=64,
+ ):
+ super().__init__()
+
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.add_space_attn = add_space_attn
+
+ # Add input LayerNorm before linear projection
+ self.input_norm = nn.LayerNorm(input_dim)
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+
+ # Add output LayerNorm before final projection
+ self.output_norm = nn.LayerNorm(hidden_size)
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
+ self.num_virtual_tracks = num_virtual_tracks
+
+ if self.add_space_attn:
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
+ else:
+ self.virual_tracks = None
+
+ self.time_blocks = nn.ModuleList(
+ [
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
+ for _ in range(time_depth)
+ ]
+ )
+
+ if add_space_attn:
+ self.space_virtual_blocks = nn.ModuleList(
+ [
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_point2virtual_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ self.space_virtual2point_blocks = nn.ModuleList(
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
+ )
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
+
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor, mask=None):
+ # Apply input LayerNorm
+ input_tensor = self.input_norm(input_tensor)
+ tokens = self.input_transform(input_tensor)
+
+ init_tokens = tokens
+
+ B, _, T, _ = tokens.shape
+
+ if self.add_space_attn:
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
+
+ _, N, _, _ = tokens.shape
+
+ j = 0
+ for i in range(len(self.time_blocks)):
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+
+ time_tokens = self.time_blocks[i](time_tokens)
+
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
+
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
+
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
+ j += 1
+
+ if self.add_space_attn:
+ tokens = tokens[:, : N - self.num_virtual_tracks]
+
+ tokens = tokens + init_tokens
+
+ # Apply output LayerNorm before final projection
+ tokens = self.output_norm(tokens)
+ flow = self.flow_head(tokens)
+
+ return flow, None
+
+
+class CorrBlock:
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
+ """
+ Build a pyramid of feature maps from the input.
+
+ fmaps: Tensor (B, S, C, H, W)
+ num_levels: number of pyramid levels (each downsampled by factor 2)
+ radius: search radius for sampling correlation
+ multiple_track_feats: if True, split the target features per pyramid level
+ padding_mode: passed to grid_sample / bilinear_sampler
+ """
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+ self.num_levels = num_levels
+ self.radius = radius
+ self.padding_mode = padding_mode
+ self.multiple_track_feats = multiple_track_feats
+
+ # Build pyramid: each level is half the spatial resolution of the previous
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
+ current_fmaps = fmaps
+ for i in range(num_levels - 1):
+ B, S, C, H, W = current_fmaps.shape
+ # Merge batch & sequence dimensions
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
+ # Avg pool down by factor 2
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
+ _, _, H_new, W_new = current_fmaps.shape
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
+ self.fmaps_pyramid.append(current_fmaps)
+
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
+ # This grid is added to the (scaled) coordinate centroids.
+ r = self.radius
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
+        # delta: every displacement within the radius, stored as (Δy, Δx) per meshgrid(dy, dx)
+ self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
+
+ def corr_sample(self, targets, coords):
+ """
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
+ volume, sample it immediately, then discard it. This saves GPU memory.
+
+ Args:
+ targets: Tensor (B, S, N, C) — features for the current targets.
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
+
+ Returns:
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
+ """
+ B, S, N, C = targets.shape
+
+ # If you have multiple track features, split them per level.
+ if self.multiple_track_feats:
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
+
+ out_pyramid = []
+ for i, fmaps in enumerate(self.fmaps_pyramid):
+ # Get current spatial resolution H, W for this pyramid level.
+ B, S, C, H, W = fmaps.shape
+ # Reshape feature maps for correlation computation:
+ # fmap2s: (B, S, C, H*W)
+ fmap2s = fmaps.view(B, S, C, H * W)
+ # Choose appropriate target features.
+ fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
+
+ # Compute correlation directly
+ corrs = compute_corr_level(fmap1, fmap2s, C)
+ corrs = corrs.view(B, S, N, H, W)
+
+ # Prepare sampling grid:
+ # Scale down the coordinates for the current level.
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
+ # Make sure our precomputed delta grid is on the same device/dtype.
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
+ # Now the grid for grid_sample is:
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
+ coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
+
+ # Sample from the correlation volume using bilinear interpolation.
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
+ corrs_sampled = bilinear_sampler(
+ corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
+ )
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
+ corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
+ out_pyramid.append(corrs_sampled)
+
+ # Concatenate all levels along the last dimension.
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
+ return out
+
+
+def compute_corr_level(fmap1, fmap2s, C):
+ # fmap1: (B, S, N, C)
+ # fmap2s: (B, S, C, H*W)
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
+ corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
+ return corrs / math.sqrt(C)
diff --git a/FastVGGT/vggt/heads/track_modules/modules.py b/FastVGGT/vggt/heads/track_modules/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..12de4f1ad76364d4665e53ac80e1037fadf98d08
--- /dev/null
+++ b/FastVGGT/vggt/heads/track_modules/modules.py
@@ -0,0 +1,204 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from typing import Callable
+import collections
+from torch import Tensor
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+
+class ResidualBlock(nn.Module):
+ """
+ ResidualBlock: construct a block of two conv layers with residual connections
+ """
+
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros"
+ )
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros")
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+ else:
+ raise NotImplementedError
+
+ if stride == 1:
+ self.downsample = None
+ else:
+ self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class AttnBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
+ mlp_ratio=4.0,
+ **block_kwargs,
+ ):
+ """
+ Self attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, mask=None):
+ # Prepare the mask for PyTorch's attention (it expects a different format)
+ # attn_mask = mask if mask is not None else None
+ # Normalize before attention
+ x = self.norm1(x)
+
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
+
+ attn_output, _ = self.attn(x, x, x)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class CrossAttnBlock(nn.Module):
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
+ """
+ Cross attention block
+ """
+ super().__init__()
+
+ self.norm1 = nn.LayerNorm(hidden_size)
+ self.norm_context = nn.LayerNorm(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size)
+
+ self.cross_attn = nn.MultiheadAttention(
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
+ )
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
+
+ def forward(self, x, context, mask=None):
+ # Normalize inputs
+ x = self.norm1(x)
+ context = self.norm_context(context)
+
+ # Apply cross attention
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
+
+ # Add & Norm
+ x = x + attn_output
+ x = x + self.mlp(self.norm2(x))
+ return x
diff --git a/FastVGGT/vggt/heads/track_modules/utils.py b/FastVGGT/vggt/heads/track_modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f1fffeaedd33c7f1c2ef54220e24a2a0e5a57b2
--- /dev/null
+++ b/FastVGGT/vggt/heads/track_modules/utils.py
@@ -0,0 +1,223 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from https://github.com/facebookresearch/vggsfm
+# and https://github.com/facebookresearch/co-tracker/tree/main
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Union
+
+
+def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
+ """
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid_size: The grid size.
+ Returns:
+ - pos_embed: The generated 2D positional embedding.
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0)
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if return_grid:
+ return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid)
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid: The grid to generate the embedding from.
+
+ Returns:
+ - emb: The generated 2D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (1, H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (1, H*W, D/2)
+
+    emb = torch.cat([emb_h, emb_w], dim=2) # (1, H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb[None].float()
+
+
+def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
+
+ Args:
+ - xy: The coordinates to generate the embedding from.
+ - C: The size of the embedding.
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
+
+ Returns:
+ - pe: The generated 2D positional embedding.
+ """
+ B, N, D = xy.shape
+ assert D == 2
+
+ x = xy[:, :, 0:1]
+ y = xy[:, :, 1:2]
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+    pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*2)
+ if cat_coords:
+        pe = torch.cat([xy, pe], dim=2) # (B, N, C*2+2)
+ return pe
+
+
+def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
+ r"""Sample a tensor using bilinear interpolation
+
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
+ convention.
+
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
+ :math:`B` is the batch size, :math:`C` is the number of channels,
+ :math:`H` is the height of the image, and :math:`W` is the width of the
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
+
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
+ that in this case the order of the components is slightly different
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
+
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
+    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
+    left-most image pixel and :math:`W-1` to the center of the right-most
+    pixel.
+
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
+    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
+    the left-most pixel and :math:`W` to the right edge of the right-most
+    pixel.
+
+ Similar conventions apply to the :math:`y` for the range
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
+ :math:`[0,T-1]` and :math:`[0,T]`.
+
+ Args:
+ input (Tensor): batch of input images.
+ coords (Tensor): batch of coordinates.
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
+
+ Returns:
+ Tensor: sampled points.
+ """
+ coords = coords.detach().clone()
+ ############################################################
+ # IMPORTANT:
+ coords = coords.to(input.device).to(input.dtype)
+ ############################################################
+
+ sizes = input.shape[2:]
+
+ assert len(sizes) in [2, 3]
+
+ if len(sizes) == 3:
+ # t x y -> x y t to match dimensions T H W in grid_sample
+ coords = coords[..., [1, 2, 0]]
+
+ if align_corners:
+ scale = torch.tensor(
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
+ )
+ else:
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
+
+ coords.mul_(scale) # coords = coords * scale
+ coords.sub_(1) # coords = coords - 1
+
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
+
+
+def sample_features4d(input, coords):
+ r"""Sample spatial features
+
+ `sample_features4d(input, coords)` samples the spatial features
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
+
+ The field is sampled at coordinates :attr:`coords` using bilinear
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
+
+ The output tensor has one feature per point, and has shape :math:`(B,
+ R, C)`.
+
+ Args:
+ input (Tensor): spatial features.
+ coords (Tensor): points.
+
+ Returns:
+ Tensor: sampled features.
+ """
+
+ B, _, _, _ = input.shape
+
+ # B R 2 -> B R 1 2
+ coords = coords.unsqueeze(2)
+
+ # B C R 1
+ feats = bilinear_sampler(input, coords)
+
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
diff --git a/FastVGGT/vggt/heads/utils.py b/FastVGGT/vggt/heads/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..533fc8ae67a75cd0a94d5ca96dc5a0513446c64f
--- /dev/null
+++ b/FastVGGT/vggt/heads/utils.py
@@ -0,0 +1,109 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+
+def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
+ """
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
+
+ Args:
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
+ embed_dim: Output channel dimension for embeddings
+
+ Returns:
+ Tensor of shape (H, W, embed_dim) with positional embeddings
+ """
+ H, W, grid_dim = pos_grid.shape
+ assert grid_dim == 2
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
+
+ # Process x and y coordinates separately
+    emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [H*W, D/2]
+    emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [H*W, D/2]
+
+    # Combine and reshape
+    emb = torch.cat([emb_x, emb_y], dim=-1) # [H*W, D]
+
+ return emb.view(H, W, embed_dim) # [H, W, D]
+
+
+def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ device = pos.device
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / omega_0**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb.float()
+
+
+# Inspired by https://github.com/microsoft/moge
+
+
+def create_uv_grid(
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
+) -> torch.Tensor:
+ """
+ Create a normalized UV grid of shape (width, height, 2).
+
+ The grid spans horizontally and vertically according to an aspect ratio,
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
+
+ Args:
+ width (int): Number of points horizontally.
+ height (int): Number of points vertically.
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
+ device (torch.device, optional): Device on which the tensor is created.
+
+ Returns:
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
+ """
+ # Derive aspect ratio if not explicitly provided
+ if aspect_ratio is None:
+ aspect_ratio = float(width) / float(height)
+
+ # Compute normalized spans for X and Y
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
+ span_x = aspect_ratio / diag_factor
+ span_y = 1.0 / diag_factor
+
+ # Establish the linspace boundaries
+ left_x = -span_x * (width - 1) / width
+ right_x = span_x * (width - 1) / width
+ top_y = -span_y * (height - 1) / height
+ bottom_y = span_y * (height - 1) / height
+
+ # Generate 1D coordinates
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
+
+ # Create 2D meshgrid (width x height) and stack into UV
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
+ uv_grid = torch.stack((uu, vv), dim=-1)
+
+ return uv_grid
diff --git a/FastVGGT/vggt/layers/__init__.py b/FastVGGT/vggt/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1
--- /dev/null
+++ b/FastVGGT/vggt/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2840683edfd99fcda3d1d5c9e18d699dba029a86
Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..42e56e451e90459b8b74d99bd0ac9049e702e1bd
Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..069b3ad92a285342f200297a7358dbf39c0ef372
Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e0bc9cb8a27db08fc0921a2d8a54a60b2b126acd
Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a416380412c8942981ab0e3f5369cebe05f9496c
Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..053bf054f2a74eefeea0536eeddd23daacd28cbe
Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20fa7427ef4c285abd1d9b48743f4c3360953591
Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..af9854cf6446fb30404f4529d63b20c20f8e63e8
Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9b1fc5deb392addff7c8620b095e0ba4d90a8de
Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d36ad71ca8dacfb3991b5b2560bc89cdfdd15c52
Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/layers/attention.py b/FastVGGT/vggt/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..aef68a4e2c628ad257a6b77b46447c8421c5b11a
--- /dev/null
+++ b/FastVGGT/vggt/layers/attention.py
@@ -0,0 +1,257 @@
+import os
+from pathlib import Path
+import time
+from PIL import Image
+import torch
+from torch import Tensor
+from torch import nn
+import torch.nn.functional as F
+from tqdm.std import tqdm
+from merging.merge import (
+ token_merge_bipartite2d,
+)
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import Image
+
XFORMERS_AVAILABLE = False  # xFormers path is hard-disabled; the pure-PyTorch fallback is always used

# Global variables for attention visualization
vis_attn_map = False  # when True, Attention.forward dumps attention heatmaps to disk
current_images = []  # file paths of the input frames, indexed per-frame for overlays
attention_map = None  # identifier (e.g. block index) embedded in the output directory name
+
+
class Attention(nn.Module):
    """Multi-head self-attention (DINOv2-style) extended with optional RoPE,
    FastVGGT bipartite token merging, and attention-map visualization hooks.

    Merging/visualization is driven by the ``global_merging`` argument of
    :meth:`forward` and the module-level globals ``vis_attn_map``,
    ``current_images`` and ``attention_map``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        kv_group_size: int = 1,
        fused_attn: bool = True,
        rope=None,
        global_merging=None,
        patch_width: int = 37,
        patch_height: int = 28,
    ) -> None:
        """Build the attention module.

        Args:
            dim: Token embedding dimension (must be divisible by num_heads).
            num_heads: Number of attention heads.
            qkv_bias / proj_bias: Whether the qkv/output projections use bias.
            attn_drop / proj_drop: Dropout probabilities for attention weights
                and the output projection.
            norm_layer: Norm applied per-head to q and k when qk_norm is True.
            qk_norm: Enable per-head q/k normalization.
            kv_group_size: Stored but currently unused by forward.
            fused_attn: Stored but currently unused; SDPA is always used.
            rope: Optional 2D rotary position embedding module.
            global_merging: Unused here; merging is controlled per-call.
            patch_width / patch_height: Patch-grid dimensions used by the
                merging and visualization paths.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # NOTE(review): unused — F.scaled_dot_product_attention applies its own
        # default 1/sqrt(head_dim) scaling.
        self.scale = self.head_dim**-0.5
        self.patch_width = patch_width
        self.patch_height = patch_height
        self.fused_attn = fused_attn  # NOTE(review): never consulted in forward

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope
        self.kv_group_size = kv_group_size  # NOTE(review): stored but unused

    def forward(self, x: Tensor, pos=None, global_merging=None) -> Tensor:
        """Self-attention over (B, N, C) tokens.

        Args:
            x: Input tokens of shape (B, N, C).
            pos: Optional (B, N, 2) patch positions for RoPE.
            global_merging: Block index; when it falls in ``merge_num`` the
                token-merging path runs (and visualization, if enabled).
        """
        # Block indices for which token merging is applied (all 24 blocks).
        merge_num = list(range(24))

        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)

        if vis_attn_map and global_merging is not None:
            # s1 Chunk computation of q@k — chunked to bound peak memory for
            # the full N x N attention matrix (accumulated on CPU).
            chunk_size = 4096
            attn_maps = []
            total_chunks = (
                q.size(-2) + chunk_size - 1
            ) // chunk_size  # Calculate total number of chunks
            start_time_total = time.time()
            with torch.no_grad():
                for start in tqdm(
                    range(0, q.size(-2), chunk_size),
                    total=total_chunks,
                    desc="Processing chunks",
                ):
                    end = min(start + chunk_size, q.size(-2))
                    q_chunk = q[:, :, start:end, :]  # (1, 16, chunk_size, 64)
                    start_time_chunk = time.time()
                    attn_chunk = q_chunk @ k.transpose(
                        -2, -1
                    )  # (1, 16, chunk_size, 34353)
                    attn_maps.append(attn_chunk.cpu())
                    end_time_chunk = time.time()
                    print(
                        f"Chunk {start}:{end} processed in {end_time_chunk - start_time_chunk:.4f} seconds"
                    )
            end_time_total = time.time()
            print(
                f"\nTotal processing time: {end_time_total - start_time_total:.4f} seconds"
            )

            attn_map = torch.cat(attn_maps, dim=-2)
            # Average attention logits over heads of the first batch element.
            attn = attn_map[0].mean(0)
            # Tokens per frame: patch grid plus 5 extra tokens — presumably
            # camera/register tokens; TODO confirm against the aggregator.
            frame_token_num = self.patch_height * self.patch_width + 5
            for target_token_idx in [
                0,
                self.patch_height * self.patch_width,
                self.patch_height * self.patch_width * 10,
            ]:  # Iterate through each image's target_token
                for image_idx in range(
                    len(current_images)
                ):  # Corresponding to which image to visualize
                    target_attn = attn[
                        target_token_idx,
                        image_idx * frame_token_num : (image_idx + 1) * frame_token_num,
                    ]
                    # Drop the 5 special tokens; keep the patch-grid attention.
                    target_attn_map = target_attn[5:].reshape(
                        self.patch_height, self.patch_width
                    )
                    # 1) Read original image to get true size (H, W)
                    image_path = current_images[image_idx]
                    p = Path(image_path)
                    parts = p.parts
                    # NOTE(review): assumes a fixed directory layout where the
                    # scene name sits 4 path components above the image file.
                    scene_name = parts[-4]

                    image = Image.open(image_path).convert("RGB")
                    img_width, img_height = image.size  # PIL size: (W, H)

                    # Upsample attention map to the original image size
                    target_attn_map = F.interpolate(
                        target_attn_map.unsqueeze(0).unsqueeze(0),  # (1,1,h,w)
                        size=(img_height, img_width),
                        mode="bilinear",
                    )
                    target_attn_map = target_attn_map.squeeze()

                    # Convert image to numpy for blending
                    img_np = np.array(image) / 255.0

                    # 2. Normalize attention map
                    target_attn_map = (target_attn_map - target_attn_map.min()) / (
                        target_attn_map.max() - target_attn_map.min()
                    )

                    # 3. Color attention map
                    cmap = plt.get_cmap("jet")
                    attn_color = cmap(target_attn_map.cpu().float().numpy())
                    attn_color = attn_color[:, :, :3]

                    # 4. Blend attention and original image
                    overlay = img_np * 0.5 + attn_color * 0.5

                    plt.imshow(overlay, cmap="viridis")
                    # ``attention_map`` here is the module-level global that
                    # names the block being visualized.
                    output_dir = f"attention_map/{scene_name}/block_{attention_map}/token_{target_token_idx}"
                    os.makedirs(output_dir, exist_ok=True)
                    output_path = os.path.join(output_dir, f"color_{image_idx}.png")
                    plt.savefig(output_path)
                    plt.close()

        if global_merging is not None and global_merging in merge_num:
            # FastVGGT token merging: fixed seed keeps the bipartite matching
            # deterministic across blocks and runs.
            generator = torch.Generator(device=x.device)
            generator.manual_seed(33)

            merge_ratio = 0.9
            r = int(x.shape[1] * merge_ratio)

            m, u = token_merge_bipartite2d(
                x,
                self.patch_width,
                self.patch_height,
                2,
                2,
                r,
                False,
                generator,
                enable_protection=True,
            )

            m_a, u_a = (m, u)

            B_q, H_q, N_q, D_q = q.shape

            # Merge q/k/v in (B, N, H*D) layout so all three share one matching.
            q_merge_in = q.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q)
            k_merge_in = k.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q)
            v_merge_in = v.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q)

            q_out, k_out, v_out = m_a(
                q_merge_in,
                mode="mean",
                extra_tensors=k_merge_in,
                extra_tensors_2=v_merge_in,
            )

            del q_merge_in, k_merge_in, v_merge_in

            N_m = q_out.shape[1]
            q = q_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3)
            k = k_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3)
            v = v_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3)

            del q_out, k_out, v_out

            # N now refers to the merged token count (used by reshape below).
            N = N_m

        x = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.attn_drop.p if self.training else 0.0,
        )
        del q, k, v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if global_merging is not None and global_merging in merge_num:
            # Unmerge the projected output back to the original token count.
            x = u_a(x)
        return x
+
+
class MemEffAttention(Attention):
    """Attention variant exposing an xFormers-style ``attn_bias`` interface.

    xFormers is unavailable in this build (``XFORMERS_AVAILABLE`` is False),
    so the parent PyTorch implementation is always used; the manual path below
    is kept as a reference fallback for when an additive bias is supplied.
    """

    def forward(
        self, x: Tensor, attn_bias=None, pos=None, global_merging=None
    ) -> Tensor:
        """Compute multi-head self-attention, optionally with RoPE and a bias.

        Args:
            x: Input tokens of shape (B, N, C).
            attn_bias: Optional additive attention bias (xFormers path only).
            pos: Optional (B, N, 2) positions for RoPE.
            global_merging: Forwarded unchanged to the parent implementation.
        """
        assert (
            pos is None or self.rope is not None
        ), "Position encoding is only supported with RoPE"
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x, pos=pos, global_merging=global_merging)

        B, N, C = x.shape
        # Bug fix: attention must be computed over the token axis. The previous
        # code kept q/k/v in (B, N, heads, head_dim) layout, so q @ k^T
        # contracted the head axis instead of head_dim (and RoPE saw the wrong
        # layout too). Move heads in front of the token axis first.
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )  # (3, B, heads, N, head_dim)
        q, k, v = qkv.unbind(0)

        if self.rope is not None and pos is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)

        # Scaled dot-product attention over the token axis.
        attn = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim**-0.5)
        if attn_bias is not None:
            attn = attn + attn_bias
        attn = F.softmax(attn, dim=-1)
        x = torch.matmul(attn, v)  # (B, heads, N, head_dim)
        x = x.transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
diff --git a/FastVGGT/vggt/layers/block.py b/FastVGGT/vggt/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffa3c18aaf79fc755630d24c24c1f4c7f75624e1
--- /dev/null
+++ b/FastVGGT/vggt/layers/block.py
@@ -0,0 +1,272 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+
+from typing import Callable, List, Any, Tuple, Dict
+import torch
+from torch import nn, Tensor
+from .attention import Attention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+XFORMERS_AVAILABLE = False
+
+
class Block(nn.Module):
    """Standard transformer block: pre-norm attention then pre-norm FFN,
    each residual branch optionally wrapped in LayerScale and DropPath."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,
        rope=None,
        merging=0,
    ) -> None:
        """Build the block. ``init_values`` enables LayerScale when truthy;
        ``merging`` is accepted for interface compatibility."""
        super().__init__()

        # Attention branch: norm -> attention -> layer-scale -> drop-path.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # FFN branch: norm -> MLP -> layer-scale -> drop-path.
        self.norm2 = norm_layer(dim)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, global_merging=None) -> Tensor:
        """Apply the attention and FFN residual branches in sequence."""
        x = x + self.ls1(self.attn(self.norm1(x), pos=pos, global_merging=global_merging))
        x = x + self.ls2(self.mlp(self.norm2(x)))
        return x
+
+
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
    pos=None,
) -> Tensor:
    """Batch-wise stochastic depth: run ``residual_func`` on a random subset
    of the batch and add the rescaled residual back at the sampled rows."""
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(batch, device=x.device)[:keep]
    subset = x[kept_idx]

    # Apply the residual branch, forwarding positions when provided
    # (positions are sub-selected to match the kept rows).
    if pos is None:
        residual = residual_func(subset)
    else:
        residual = residual_func(subset, pos=pos[kept_idx])

    # Scale up so the expected residual magnitude matches full-batch training.
    scale = batch / keep
    out = torch.index_add(
        x.flatten(1),
        0,
        kept_idx,
        residual.flatten(1).to(dtype=x.dtype),
        alpha=scale,
    )
    return out.view_as(x)
+
+
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Sample random batch indices to keep, plus the residual rescale factor
    that compensates for the dropped samples."""
    batch = x.shape[0]
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_idx = torch.randperm(batch, device=x.device)[:keep]
    return kept_idx, batch / keep
+
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add a (subset) residual back into ``x`` at the rows given by ``brange``.

    Args:
        x: Full-batch tensor of shape (B, ...).
        brange: Indices of the batch rows the residual was computed for.
        residual: Residual values for those rows (same trailing shape as x).
        residual_scale_factor: Rescale factor compensating for dropped rows.
        scaling_vector: Optional per-channel LayerScale gamma; when given, the
            fused xFormers path is taken.
    """
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(
            x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
        )
    else:
        # NOTE(review): ``scaled_index_add`` is not defined or imported in this
        # file (it comes from xformers); this branch raises NameError if reached.
        x_plus_residual = scaled_index_add(
            x,
            brange,
            residual.to(dtype=x.dtype),
            scaling=scaling_vector,
            alpha=residual_scale_factor,
        )
    return x_plus_residual
+
+
# Cache of block-diagonal attention masks keyed by the tuple of per-tensor
# (batch, seq_len) shapes, so each unique shape combination is built once.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache

    NOTE(review): depends on xformers' ``fmha`` and ``index_select_cat``, which
    are not imported in this file — this helper only works on the (currently
    disabled) xFormers nested-tensor path and raises NameError otherwise.
    """
    batch_sizes = (
        [b.shape[0] for b in branges]
        if branges is not None
        else [x.shape[0] for x in x_list]
    )
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
            1, -1, x_list[0].shape[-1]
        )
    else:
        # Concatenate all tensors into a single batch-1 sequence.
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
+
+
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Stochastic-depth residual update applied jointly to a list of token
    tensors (the nested-tensor training path).

    Returns one updated tensor per input tensor. (Return annotation corrected
    from ``Tensor`` to ``List[Tensor]``.)
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [
        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
    ]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(
        x_list, branges, residual_list, residual_scale_factors
    ):
        outputs.append(
            add_residual(
                x, brange, residual, residual_scale_factor, scaling_vector
            ).view_as(x)
        )
    return outputs
+
+
class NestedTensorBlock(Block):
    """Block variant that can process a list of token tensors as one nested
    batch via xFormers block-diagonal attention.

    NOTE(review): ``MemEffAttention`` is referenced below but not imported in
    this module — ``forward_nested`` would raise NameError if ever reached.
    In practice ``forward`` raises first because XFORMERS_AVAILABLE is False.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=(
                    self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
                ),
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # NOTE(review): checks ls1, not ls2 (upstream DINOv2 quirk);
                # harmless because ls1/ls2 are created under the same condition.
                scaling_vector=(
                    self.ls2.gamma if isinstance(self.ls1, LayerScale) else None
                ),
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            # Concatenate the list into one sequence with a block-diagonal mask.
            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        # Tensor input falls back to the plain Block path; list input requires
        # xFormers, which this build disables.
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
diff --git a/FastVGGT/vggt/layers/drop_path.py b/FastVGGT/vggt/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/FastVGGT/vggt/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples (stochastic depth), rescaling survivors
    by 1/keep_prob so the expected activation is unchanged. Identity when
    not training or when drop_prob is zero."""
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
+
+
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (per-sample stochastic depth)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Active only while the module is in training mode.
        return drop_path(x, self.drop_prob, self.training)
diff --git a/FastVGGT/vggt/layers/layer_scale.py b/FastVGGT/vggt/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ddfc51c3d87370d50175f5b4e649dac1c614ff9
--- /dev/null
+++ b/FastVGGT/vggt/layers/layer_scale.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
class LayerScale(nn.Module):
    """Learnable per-channel scaling of the residual branch (CaiT LayerScale)."""

    def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        # One learnable scale per channel, initialized to init_values.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
diff --git a/FastVGGT/vggt/layers/mlp.py b/FastVGGT/vggt/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/FastVGGT/vggt/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
class Mlp(nn.Module):
    """Two-layer transformer MLP: fc1 -> activation -> dropout -> fc2 -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Hidden and output widths default to the input width.
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        # The same dropout module is applied after both linear layers.
        return self.drop(self.fc2(self.drop(self.act(self.fc1(x)))))
diff --git a/FastVGGT/vggt/layers/patch_embed.py b/FastVGGT/vggt/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc19605e4d6e88d06355ae3b1afddc76f595aafe
--- /dev/null
+++ b/FastVGGT/vggt/layers/patch_embed.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
def make_2tuple(x):
    """Duplicate an int into (x, x); pass a 2-tuple through unchanged."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/FastVGGT/vggt/layers/rope.py b/FastVGGT/vggt/layers/rope.py
new file mode 100644
index 0000000000000000000000000000000000000000..84625de468ed89e69dd9e1579d541de71f2ebf37
--- /dev/null
+++ b/FastVGGT/vggt/layers/rope.py
@@ -0,0 +1,209 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+# Implementation of 2D Rotary Position Embeddings (RoPE).
+
+# This module provides a clean implementation of 2D Rotary Position Embeddings,
+# which extends the original RoPE concept to handle 2D spatial positions.
+
+# Inspired by:
+# https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# https://github.com/naver-ai/rope-vit
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Dict, Tuple
+
+
class PositionGetter:
    """Produce (and memoize) per-patch (y, x) grid coordinates.

    Positions for each (height, width) grid are computed once and cached;
    later calls reuse the cached tensor and expand it across the batch.

    Attributes:
        position_cache: Maps (height, width) to the precomputed positions.
    """

    def __init__(self):
        """Start with an empty cache."""
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(
        self, batch_size: int, height: int, width: int, device: torch.device
    ) -> torch.Tensor:
        """Return a (batch_size, height*width, 2) tensor of y,x coordinates.

        Rows enumerate the grid in row-major order; the coordinates are
        repeated (expanded + cloned) for every batch element.
        """
        key = (height, width)
        if key not in self.position_cache:
            rows = torch.arange(height, device=device)
            cols = torch.arange(width, device=device)
            self.position_cache[key] = torch.cartesian_prod(rows, cols)

        grid = self.position_cache[key]
        return grid.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
+
+
class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    This module applies rotary position embeddings to input tokens based on their
    2D spatial positions. It handles the position-dependent rotation of features
    separately for vertical and horizontal dimensions.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Factor to scale the computed frequencies.
            NOTE(review): stored but never used in the computation below.
        frequency_cache: Cache for storing precomputed frequency components.
            NOTE(review): unbounded; grows with each new (dim, seq_len,
            device, dtype) combination.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes frequency components for rotary embeddings.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length.
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors for frequency components.
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Compute frequency bands
            exponents = torch.arange(0, dim, 2, device=device) / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Generate position-dependent frequencies
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Compute and cache frequency components
            angles = angles.to(dtype)
            # Duplicate so the (seq_len, dim) table lines up with the
            # rotate-half layout used by _rotate_features.
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Performs feature rotation by splitting and recombining feature dimensions.

        Args:
            x: Input tensor to rotate.

        Returns:
            Rotated feature tensor.
        """
        feature_dim = x.shape[-1]
        # "Rotate half": (x1, x2) -> (-x2, x1).
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        cos_comp: torch.Tensor,
        sin_comp: torch.Tensor,
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one dimension.

        Args:
            tokens: Input token features.
            positions: Position indices.
            cos_comp: Cosine components for rotation.
            sin_comp: Sine components for rotation.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        if positions.dtype != torch.long:
            positions = positions.long()

        # Embed positions with frequency components
        # (None adds the broadcast head axis: (B, 1, N, dim)).
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        # Apply rotation
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                The feature dimension (dim) must be divisible by 4.
            positions: Position tensor of shape (batch_size, n_tokens, 2) containing
                the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs
        # NOTE(review): each half is rotate-halved again, so dim must in fact
        # be divisible by 4; this assert only checks evenness.
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert (
            positions.ndim == 3 and positions.shape[-1] == 2
        ), "Positions must have shape (batch_size, n_tokens, 2)"

        # Compute feature dimension for each spatial direction
        feature_dim = tokens.size(-1) // 2

        # Get frequency components
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(
            feature_dim, max_position, tokens.device, tokens.dtype
        )

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each dimension
        vertical_features = self._apply_1d_rope(
            vertical_features, positions[..., 0], cos_comp, sin_comp
        )
        horizontal_features = self._apply_1d_rope(
            horizontal_features, positions[..., 1], cos_comp, sin_comp
        )

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)
diff --git a/FastVGGT/vggt/layers/swiglu_ffn.py b/FastVGGT/vggt/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dd991e1deb87141ccd282098d4b9d38fed6ef25
--- /dev/null
+++ b/FastVGGT/vggt/layers/swiglu_ffn.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: one fused linear produces a gate half and
    a value half; the SiLU-activated gate modulates the value before the
    output projection. ``act_layer`` and ``drop`` are accepted for interface
    compatibility but unused."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        # w12 computes both the gate and value projections in a single matmul.
        self.w12 = nn.Linear(in_features, 2 * hidden, bias=bias)
        self.w3 = nn.Linear(hidden, out, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
+
+
# Honor the XFORMERS_DISABLED env var, although the xformers import below is
# commented out and the pure-PyTorch SwiGLUFFN is always used in this build.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# try:
# if XFORMERS_ENABLED:
# from xformers.ops import SwiGLU

# XFORMERS_AVAILABLE = True
# warnings.warn("xFormers is available (SwiGLU)")
# else:
# warnings.warn("xFormers is disabled (SwiGLU)")
# raise ImportError
# except ImportError:
# Alias used by SwiGLUFFNFused; always resolves to the PyTorch implementation.
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False

# warnings.warn("xFormers is not available (SwiGLU)")
+
+
class SwiGLUFFNFused(SwiGLU):
    """SwiGLUFFN with the hidden width rescaled by 2/3 (to keep parameter
    count comparable to a vanilla 4x MLP) and rounded up to a multiple of 8
    for hardware-friendly shapes.

    ``act_layer`` and ``drop`` are accepted for interface compatibility but
    ignored, matching SwiGLUFFN.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,  # unused
        drop: float = 0.0,  # unused
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Scale by 2/3, then round up to the next multiple of 8.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
diff --git a/FastVGGT/vggt/layers/vision_transformer.py b/FastVGGT/vggt/layers/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e1aee388a9c38168657f78fa9436685582068c6
--- /dev/null
+++ b/FastVGGT/vggt/layers/vision_transformer.py
@@ -0,0 +1,446 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from torch.nn.init import trunc_normal_
+from . import (
+ Mlp,
+ PatchEmbed,
+ SwiGLUFFNFused,
+ MemEffAttention,
+ NestedTensorBlock as Block,
+)
+
+logger = logging.getLogger("dinov2")
+
+
def named_apply(
    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over a module tree.

    Traversal is post-order when ``depth_first`` is True, pre-order otherwise;
    the root itself is visited only when ``include_root`` is True. Returns the
    (possibly mutated) root module.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        # Children are always visited (include_root=True below).
        named_apply(
            fn=fn,
            module=child,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module
+
+
class BlockChunk(nn.ModuleList):
    """A ModuleList that is itself callable: runs its blocks in sequence."""

    def forward(self, x):
        # Feed each block's output into the next one.
        for layer in self:
            x = layer(x)
        return x
+
+
class DinoVisionTransformer(nn.Module):
    """DINOv2-style Vision Transformer backbone.

    Tokenizes an image into patches, prepends a class token (and optional
    register tokens), adds interpolated positional embeddings, and runs the
    sequence through `depth` transformer blocks.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        qk_norm=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
            qk_norm: (bool) whether blocks apply QK normalization
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.num_tokens = 1  # the single [CLS] token
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset
        self.use_reentrant = False  # hardcoded to False

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional embedding covers patch tokens plus the class token.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
        )
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
            if num_register_tokens
            else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [
                x.item() for x in torch.linspace(0, drop_path_rate, depth)
            ]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append(
                    [nn.Identity()] * i + blocks_list[i : i + chunksize]
                )
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # Learned token substituted for masked patches in prepare_tokens_with_masks.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialize positional/class/register tokens, then apply the timm
        ViT init to all Linear layers via named_apply."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Resize the stored patch positional embeddings (bicubic) to match an
        input of w x h pixels. Assumes the pretrained embedding grid is square
        (enforced by the `N == M * M` assert below)."""
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        # Fast path: grid already matches the pretraining resolution.
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
            previous_dtype
        )

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patchify `x`, optionally replace masked patches with the mask
        token, prepend the class token, add positional embeddings, and insert
        register tokens (which receive no positional embedding)."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(
                masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
            )

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            # Registers go between the class token and the patch tokens.
            x = torch.cat(
                (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """Run a list of (images, masks) pairs through the blocks (the blocks
        handle list inputs) and return one output dict per pair."""
        x = [
            self.prepare_tokens_with_masks(x, masks)
            for x, masks in zip(x_list, masks_list)
        ]

        for blk in self.blocks:
            if self.training:
                # Gradient checkpointing trades compute for memory during training.
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Return the normalized class/register/patch tokens (and pre-norm
        tokens) for a batch, or one dict per entry if `x` is a list."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect raw block outputs for the requested blocks (flat list case)."""
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(
            blocks_to_take
        ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect raw block outputs for the requested blocks (chunked case,
        where each chunk is padded with nn.Identity placeholders)."""
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(
            blocks_to_take
        ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens (optionally normed / reshaped to feature maps /
        paired with class tokens) from the selected intermediate blocks."""
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Drop class + register tokens, keep only patch tokens.
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=True, **kwargs):
        """Forward pass. With is_training=True (the default here) the full
        feature dict is returned; otherwise only the class token passed
        through the (identity) head."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
+
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility).

    Linear layers get truncated-normal weights (std=0.02) and zeroed biases;
    every other module type is left untouched.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
+
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-Small DinoVisionTransformer (embed_dim=384, 12 blocks, 6 heads)."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
+
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-Base DinoVisionTransformer (embed_dim=768, 12 blocks, 12 heads)."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
+
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-Large DinoVisionTransformer (embed_dim=1024, 24 blocks, 16 heads)."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
+
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
diff --git a/FastVGGT/vggt/models/__pycache__/aggregator.cpython-310.pyc b/FastVGGT/vggt/models/__pycache__/aggregator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ada2738eee9604acf0497bcafdfc52aa6f21119f
Binary files /dev/null and b/FastVGGT/vggt/models/__pycache__/aggregator.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/models/__pycache__/vggt.cpython-310.pyc b/FastVGGT/vggt/models/__pycache__/vggt.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67f693db1cd59d1a8d87e8f5c1c77c57c9de47d2
Binary files /dev/null and b/FastVGGT/vggt/models/__pycache__/vggt.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/models/aggregator.py b/FastVGGT/vggt/models/aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e45bb204729f25a963fb083c158142ed01515c2
--- /dev/null
+++ b/FastVGGT/vggt/models/aggregator.py
@@ -0,0 +1,492 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from numpy import block
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
+from typing import Optional, Tuple, Union, List, Dict, Any
+
+from vggt.layers import PatchEmbed
+from vggt.layers.block import Block
+from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
+from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
+import time
+
+logger = logging.getLogger(__name__)
+
+_RESNET_MEAN = [0.485, 0.456, 0.406]
+_RESNET_STD = [0.229, 0.224, 0.225]
+
+
class Aggregator(nn.Module):
    """
    The Aggregator applies alternating-attention over input frames,
    as described in VGGT: Visual Geometry Grounded Transformer.

    Remember to set model.train() to enable gradient checkpointing to reduce memory usage.

    Args:
        img_size (int): Image size in pixels.
        patch_size (int): Size of each patch for PatchEmbed.
        embed_dim (int): Dimension of the token embeddings.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
        num_register_tokens (int): Number of register tokens.
        block_fn (nn.Module): The block type used for attention (Block by default).
        qkv_bias (bool): Whether to include bias in QKV projections.
        proj_bias (bool): Whether to include bias in the output projection.
        ffn_bias (bool): Whether to include bias in MLP layers.
        patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
        aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
        aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
        qk_norm (bool): Whether to apply QK normalization.
        rope_freq (int): Base frequency for rotary embedding. -1 to disable.
        init_values (float): Init scale for layer scale.
        global_merging (bool): Whether token merging is enabled for global attention blocks.
        merging (int): Block index from which merging is activated (None disables it).
        vis_attn_map (bool): Whether to export attention maps for visualization.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        patch_embed="dinov2_vitl14_reg",
        # NOTE(review): mutable default list; safe only because aa_order is
        # never mutated in this class.
        aa_order=["frame", "global"],
        aa_block_size=1,
        qk_norm=True,
        rope_freq=100,
        init_values=0.01,
        global_merging=True,
        merging=0,
        vis_attn_map=False,
    ):
        super().__init__()

        self.__build_patch_embed__(
            patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim
        )

        # Initialize rotary position embedding if frequency > 0
        self.rope = (
            RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        )
        self.position_getter = PositionGetter() if self.rope is not None else None
        self.global_merging = global_merging
        self.merging = merging
        self.vis_attn_map = vis_attn_map
        # One stack of blocks for per-frame (intra-frame) attention ...
        self.frame_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        # ... and a second stack for global (cross-frame) attention.
        self.global_blocks = nn.ModuleList(
            [
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    proj_bias=proj_bias,
                    ffn_bias=ffn_bias,
                    init_values=init_values,
                    qk_norm=qk_norm,
                    rope=self.rope,
                )
                for _ in range(depth)
            ]
        )

        self.depth = depth
        self.aa_order = aa_order
        self.patch_size = patch_size
        self.aa_block_size = aa_block_size

        # Validate that depth is divisible by aa_block_size
        if self.depth % self.aa_block_size != 0:
            raise ValueError(
                f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})"
            )

        self.aa_block_num = self.depth // self.aa_block_size

        # Note: We have two camera tokens, one for the first frame and one for the rest
        # The same applies for register tokens
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
        self.register_token = nn.Parameter(
            torch.randn(1, 2, num_register_tokens, embed_dim)
        )

        # The patch tokens start after the camera and register tokens
        self.patch_start_idx = 1 + num_register_tokens

        # Initialize parameters with small values
        nn.init.normal_(self.camera_token, std=1e-6)
        nn.init.normal_(self.register_token, std=1e-6)

        # Register normalization constants as buffers (use bf16-compatible tensor)
        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.tensor(value, dtype=torch.bfloat16).view(1, 1, 3, 1, 1),
                persistent=False,
            )

        self.use_reentrant = False  # hardcoded to False

    def __build_patch_embed__(
        self,
        patch_embed,
        img_size,
        patch_size,
        num_register_tokens,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
        embed_dim=1024,
    ):
        """
        Build the patch embed layer. If 'conv', we use a
        simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
        """

        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim,
            )
        else:
            # Map config string to the corresponding DINOv2 backbone factory.
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

        # Disable gradient updates for mask token
        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)

    def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]:
        """
        Args:
            images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width

        Returns:
            (list[torch.Tensor], int):
                The list of outputs from the attention blocks,
                and the patch_start_idx indicating where patch tokens begin.
        """
        B, S, C_in, H, W = images.shape

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images and reshape for patch embed - ensure bf16 computation
        images = images.to(torch.bfloat16)
        images = (images - self._resnet_mean) / self._resnet_std

        images = images.view(B * S, C_in, H, W)
        patch_tokens = self.patch_embed(images)
        del images

        if isinstance(patch_tokens, dict):
            # ViT patch embed returns a feature dict; keep only patch tokens.
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        patch_tokens = patch_tokens.to(torch.bfloat16)

        _, P, C = patch_tokens.shape

        # Expand camera and register tokens to match batch size and sequence length
        camera_token = slice_expand_and_flatten(
            self.camera_token.to(torch.bfloat16), B, S
        )
        register_token = slice_expand_and_flatten(
            self.register_token.to(torch.bfloat16), B, S
        )

        tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
        tokens = tokens.to(torch.bfloat16)
        del camera_token, register_token, patch_tokens
        # Explicitly clean up image data since patch embedding is complete
        # NOTE(review): images_normalized is never assigned in this function,
        # so this cleanup branch is a no-op.
        if "images_normalized" in locals():
            del images_normalized

        pos = None
        if self.rope is not None:
            # NOTE(review): device is hard-coded to "cuda" here and below, so
            # this forward fails on CPU-only runs — confirm intended.
            pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device="cuda"
            )

        if self.patch_start_idx > 0:
            # do not use position embedding for special tokens (camera and register tokens)
            # so set pos to 0 for the special tokens
            pos_original = pos
            pos = pos + 1
            pos_special = torch.zeros(
                B * S, self.patch_start_idx, 2, device="cuda", dtype=torch.long
            )
            pos = torch.cat([pos_special, pos], dim=1)
            # Clean up temporary variables
            del pos_special, pos_original

        # update P because we added special tokens
        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []
        # Block indices whose intermediates feed the DPT heads.
        # NOTE(review): hard-coded for depth == 24 — confirm if depth differs.
        block4DPT_idx = [4, 11, 17, 23]
        global_merging = None

        # Set global variables for attention visualization
        if self.vis_attn_map:
            import vggt.layers.attention as attn_module

            # Set the global variables that attention.py needs
            attn_module.vis_attn_map = True
            attn_module.current_images = self._load_image_paths()  # Load from temp file
        else:
            import vggt.layers.attention as attn_module

            attn_module.vis_attn_map = False

        for block_num in range(self.aa_block_num):
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

            need_intermediates = True if block_num in block4DPT_idx else False
            # NOTE(review): `block_num % 1 == 0` is always true, so this
            # cache cleanup runs on every block.
            if block_num % 1 == 0:
                # Clean up RoPE cache to prevent accumulation
                if hasattr(self, "rope") and self.rope is not None:
                    if hasattr(self.rope, "frequency_cache"):
                        self.rope.frequency_cache.clear()
                # Clean up position cache
                if (
                    hasattr(self, "position_getter")
                    and self.position_getter is not None
                ):
                    if hasattr(self.position_getter, "position_cache"):
                        # Keep only current size cache, clean up others
                        current_cache = self.position_getter.position_cache.copy()
                        if (
                            len(current_cache) > 1
                        ):  # If there are multiple cache entries
                            self.position_getter.position_cache.clear()
                            # Keep only the most recently used one
                            if current_cache:
                                key = list(current_cache.keys())[-1]
                                self.position_getter.position_cache[key] = (
                                    current_cache[key]
                                )
            # Avoid saving block_num to instance variable to reduce references
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = (
                        self._process_frame_attention(
                            tokens,
                            B,
                            S,
                            P,
                            C,
                            frame_idx,
                            pos=pos,
                            need_intermediates=need_intermediates,
                        )
                    )
                elif attn_type == "global":
                    # Merging kicks in once block_num reaches the configured
                    # threshold (self.merging); None disables it entirely.
                    if self.merging is None:
                        global_merging = None
                    elif self.global_merging and block_num >= self.merging:
                        global_merging = block_num
                    # Set attention_map for visualization
                    if self.vis_attn_map:
                        import vggt.layers.attention as attn_module

                        attn_module.attention_map = block_num
                    tokens, global_idx, global_intermediates = (
                        self._process_global_attention(
                            tokens,
                            B,
                            S,
                            P,
                            C,
                            global_idx,
                            pos=pos,
                            global_merging=global_merging,
                            need_intermediates=need_intermediates,
                        )
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            if block_num not in block4DPT_idx:
                if "frame_intermediates" in locals():
                    del frame_intermediates
                if "global_intermediates" in locals():
                    del global_intermediates
            else:
                # Concatenate frame- and global-attention features for the heads.
                concat_inter = torch.cat(
                    [frame_intermediates[0].detach(), global_intermediates[0].detach()],
                    dim=-1,
                )
                if concat_inter.dtype != torch.bfloat16:
                    concat_inter = concat_inter.to(torch.bfloat16)
                output_list.append(concat_inter)
                del concat_inter, frame_intermediates, global_intermediates

        # Do final cleanup before returning
        del tokens, pos
        if "pos_special" in locals():
            del pos_special
        if "pos_original" in locals():
            del pos_original
        torch.cuda.empty_cache()  # Final cleanup

        return output_list, self.patch_start_idx

    def _process_frame_attention(
        self, tokens, B, S, P, C, frame_idx, pos=None, need_intermediates=False
    ):
        """
        Process frame attention blocks. We keep tokens in shape (B*S, P, C).
        """
        # If needed, reshape tokens or positions:
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C).view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2).view(B * S, P, 2)

        intermediates = [] if need_intermediates else None

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            if need_intermediates:
                intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(
        self,
        tokens,
        B,
        S,
        P,
        C,
        global_idx,
        pos=None,
        global_merging=None,
        need_intermediates=False,
    ):
        """
        Process global attention blocks. We keep tokens in shape (B, S*P, C).
        """
        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C).view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2).view(B, S * P, 2)

        intermediates = [] if need_intermediates else None

        # by default, self.aa_block_size=1, which processes one block at a time
        for _ in range(self.aa_block_size):
            tokens = self.global_blocks[global_idx](
                tokens,
                pos=pos,
                global_merging=global_merging,
            )
            global_idx += 1
            if need_intermediates:
                intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates

    def _load_image_paths(self):
        """Load image paths from temporary file for visualization"""
        # Counterpart of VGGT.forward, which pickles the path list to the same
        # temp-file location before calling this aggregator.
        try:
            import os
            import tempfile
            import pickle

            temp_dir = tempfile.gettempdir()
            image_paths_file = os.path.join(temp_dir, "vggt_image_paths.pkl")

            if os.path.exists(image_paths_file):
                with open(image_paths_file, "rb") as f:
                    return pickle.load(f)
            else:
                return []
        except Exception as e:
            print(f"Warning: Could not load image paths for visualization: {e}")
            return []
+
+
def slice_expand_and_flatten(token_tensor, B, S):
    """
    Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
    1) Uses the first position (index=0) for the first frame only
    2) Uses the second position (index=1) for all remaining frames (S-1 frames)
    3) Expands both to match batch size B
    4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
       followed by (S-1) second-position tokens
    5) Flattens to (B*S, X, C) for processing

    Returns:
        torch.Tensor: Processed tokens with shape (B*S, X, C)
    """
    tail_shape = token_tensor.shape[2:]
    # First-frame token vs. the token shared by every remaining frame.
    first = token_tensor[:, 0:1, ...].expand(B, 1, *tail_shape)
    rest = token_tensor[:, 1:, ...].expand(B, S - 1, *tail_shape)
    # (B, S, ...) -> (B*S, ...)
    stacked = torch.cat([first, rest], dim=1)
    return stacked.view(B * S, *stacked.shape[2:])
diff --git a/FastVGGT/vggt/models/vggt.py b/FastVGGT/vggt/models/vggt.py
new file mode 100644
index 0000000000000000000000000000000000000000..2efeb5271de9916d05fd231a0a5ec1694f999bd0
--- /dev/null
+++ b/FastVGGT/vggt/models/vggt.py
@@ -0,0 +1,190 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin # used for model hub
+
+from vggt.models.aggregator import Aggregator
+from vggt.heads.camera_head import CameraHead
+from vggt.heads.dpt_head import DPTHead
+from vggt.heads.track_head import TrackHead
+
+
class VGGT(nn.Module, PyTorchModelHubMixin):
    """VGGT: an Aggregator backbone plus optional camera / point / depth /
    track prediction heads, each toggled by an `enable_*` constructor flag.
    """

    def __init__(
        self,
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        enable_camera=True,
        enable_point=False,
        enable_depth=True,
        enable_track=False,
        merging=0,
        vis_attn_map=False,
    ):
        super().__init__()

        self.vis_attn_map = vis_attn_map

        # merging / vis_attn_map are forwarded to the Aggregator, which
        # controls global-attention token merging and attention-map export.
        self.aggregator = Aggregator(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            merging=merging,
            vis_attn_map=vis_attn_map,
        )

        # Each head consumes the concatenated frame+global features (2*embed_dim).
        self.camera_head = CameraHead(dim_in=2 * embed_dim) if enable_camera else None
        self.point_head = (
            DPTHead(
                dim_in=2 * embed_dim,
                output_dim=4,
                activation="inv_log",
                conf_activation="expp1",
            )
            if enable_point
            else None
        )
        self.depth_head = (
            DPTHead(
                dim_in=2 * embed_dim,
                output_dim=2,
                activation="exp",
                conf_activation="expp1",
            )
            if enable_depth
            else None
        )
        self.track_head = (
            TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
            if enable_track
            else None
        )

    def update_patch_dimensions(self, patch_width: int, patch_height: int):
        """
        Update patch dimensions for all attention layers in the model

        Args:
            patch_width: Patch width (typically 37)
            patch_height: Patch height (typically 28)
        """

        def update_attention_in_module(module):
            # Walk the module tree and patch every layer that exposes
            # patch_width/patch_height attributes.
            for name, child in module.named_children():
                # Recursively update submodules
                update_attention_in_module(child)
                # If it is an attention layer, update its patch dimensions
                if hasattr(child, "patch_width") and hasattr(child, "patch_height"):
                    child.patch_width = patch_width
                    child.patch_height = patch_height

        # Update all attention layers in the aggregator
        update_attention_in_module(self.aggregator)

        # print(
        #     f"🔧 Updated model attention layer patch dimensions: {patch_width}x{patch_height}"
        # )

    def forward(
        self,
        images: torch.Tensor,
        query_points: torch.Tensor = None,
        image_paths: list = None,
    ):
        """
        Forward pass of the VGGT model.

        Args:
            images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
                B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
            query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
                Shape: [N, 2] or [B, N, 2], where N is the number of query points.
                Default: None
            image_paths (list, optional): List of image file paths for attention visualization.
                Only used when vis_attn_map=True. Default: None

        Returns:
            dict: A dictionary containing the following predictions:
                - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
                - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
                - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
                - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
                - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
                - images (torch.Tensor): Original input images, preserved for visualization

            If query_points is provided, also includes:
                - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
                - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
                - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
        """
        # If without batch dimension, add it
        if len(images.shape) == 4:
            images = images.unsqueeze(0)

        if query_points is not None and len(query_points.shape) == 2:
            query_points = query_points.unsqueeze(0)

        # Save image paths globally for attention visualization
        # (consumed by Aggregator._load_image_paths via the same temp file)
        if self.vis_attn_map and image_paths is not None:
            import os
            import tempfile
            import pickle

            # Create a temporary file to store image paths
            temp_dir = tempfile.gettempdir()
            image_paths_file = os.path.join(temp_dir, "vggt_image_paths.pkl")
            with open(image_paths_file, "wb") as f:
                pickle.dump(image_paths, f)

        aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        predictions = {}

        if self.camera_head is not None:
            pose_enc_list = self.camera_head(aggregated_tokens_list)
            predictions["pose_enc"] = pose_enc_list[
                -1
            ]  # pose encoding of the last iteration
            predictions["pose_enc_list"] = pose_enc_list

        if self.depth_head is not None:
            depth, depth_conf = self.depth_head(
                aggregated_tokens_list,
                images=images,
                patch_start_idx=patch_start_idx,
            )
            predictions["depth"] = depth
            predictions["depth_conf"] = depth_conf

        if self.point_head is not None:
            pts3d, pts3d_conf = self.point_head(
                aggregated_tokens_list,
                images=images,
                patch_start_idx=patch_start_idx,
            )
            predictions["world_points"] = pts3d
            predictions["world_points_conf"] = pts3d_conf

        if self.track_head is not None and query_points is not None:
            track_list, vis, conf = self.track_head(
                aggregated_tokens_list,
                images=images,
                patch_start_idx=patch_start_idx,
                query_points=query_points,
            )
            predictions["track"] = track_list[-1]  # track of the last iteration
            predictions["vis"] = vis
            predictions["conf"] = conf

        if not self.training:
            predictions["images"] = (
                images  # store the images for visualization during inference
            )

        return predictions
diff --git a/FastVGGT/vggt/utils/__init__.py b/FastVGGT/vggt/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a2958f93a114495631a3ce270b99ee8ba8443f1
--- /dev/null
+++ b/FastVGGT/vggt/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
\ No newline at end of file
diff --git a/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95269f2111ee13cbadfcbab9cd10605555f25aee
Binary files /dev/null and b/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3e08e180b2f41ce54c1912296117bf616196aae
Binary files /dev/null and b/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7a92463e6f8f34289986fcaf2ab515c9598cf51
Binary files /dev/null and b/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37978e9be49721f9e88beeb814a38e3dcb8754ad
Binary files /dev/null and b/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93fcc5eddf5936461e8cbf8ca031e150105da24e
Binary files /dev/null and b/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/utils/eval_utils.py b/FastVGGT/vggt/utils/eval_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a81e2cab2b36c97e35b5ebb110b6e66fc814efa7
--- /dev/null
+++ b/FastVGGT/vggt/utils/eval_utils.py
@@ -0,0 +1,782 @@
+from collections import deque
+from PIL import Image
+import cv2
+import io
+import random
+import time
+from pathlib import Path
+from typing import List, Tuple, Dict, Optional
+from copy import deepcopy
+import matplotlib.pyplot as plt
+import evo.tools.plot as plot
+import evo.main_ape as main_ape
+import evo.main_rpe as main_rpe
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from evo.core.metrics import PoseRelation, Unit
+from evo.core.trajectory import PoseTrajectory3D
+from scipy.linalg import svd
+import open3d as o3d # for point cloud processing and Chamfer Distance computation
+
+from scipy.spatial.transform import Rotation
+from torchvision import transforms as TF
+
+
+from vggt.utils.geometry import unproject_depth_map_to_point_map
+from vggt.utils.pose_enc import pose_encoding_to_extri_intri
+
+
def shuffle_deque(dq, seed=None):
    """Return a new deque with the elements of *dq* shuffled.

    Args:
        dq: input deque (left unmodified).
        seed: optional int for reproducible shuffling.

    Returns:
        collections.deque: a new deque containing the same elements.
    """
    # Use a local Random instance when a seed is given so we do not mutate
    # the module-global RNG state; random.Random(seed) produces the exact
    # same shuffle as random.seed(seed) + random.shuffle would have.
    rng = random.Random(seed) if seed is not None else random
    items = list(dq)
    rng.shuffle(items)
    return deque(items)
+
+
def imread_cv2(path, options=cv2.IMREAD_COLOR):
    """Load an image or depth map from *path* with OpenCV.

    EXR files are forced to cv2.IMREAD_ANYDEPTH; 3-channel images are
    converted from OpenCV's BGR order to RGB.

    Raises:
        IOError: if OpenCV cannot decode the file.
    """
    if path.endswith((".exr", "EXR")):
        options = cv2.IMREAD_ANYDEPTH
    loaded = cv2.imread(path, options)
    if loaded is None:
        raise IOError(f"Could not load image={path} with {options=}")
    if loaded.ndim == 3:
        loaded = cv2.cvtColor(loaded, cv2.COLOR_BGR2RGB)
    return loaded
+
+
# NOTE: an identical copy of this function is defined again further down in
# this module; the later definition is the one visible at import time.
def umeyama_alignment(src, dst, estimate_scale=True):
    """Least-squares similarity transform (Umeyama, 1991) mapping src -> dst.

    Args:
        src: (3, N) source points.
        dst: (3, N) target points.
        estimate_scale: also estimate a uniform scale when True, else 1.0.

    Returns:
        (scale, R, t) with dst ≈ scale * R @ src + t[:, None].
        On numerical failure, falls back to scale=1, R=I, t=centroid diff.
    """
    assert (
        src.shape == dst.shape
    ), f"Input shapes don't match: src {src.shape}, dst {dst.shape}"
    assert src.shape[0] == 3, f"Expected point cloud dimension (3,N), got {src.shape}"

    mu_src = src.mean(axis=1, keepdims=True)
    mu_dst = dst.mean(axis=1, keepdims=True)
    src0 = src - mu_src
    dst0 = dst - mu_dst

    # Cross-covariance between the centered clouds.
    cov = dst0 @ src0.T

    try:
        U, D, Vt = svd(cov)

        # Reflection guard: flip the smallest singular direction so det(R)=+1.
        sign = np.eye(3)
        if np.linalg.det(U @ Vt) < 0:
            sign[2, 2] = -1

        R = U @ sign @ Vt

        if estimate_scale:
            var_src = np.sum(src0 * src0)
            if var_src < 1e-10:
                print(
                    "Warning: Source point cloud variance close to zero, setting scale factor to 1.0"
                )
                scale = 1.0
            else:
                # trace(D @ sign) / var(src), written with the diagonal vectors.
                scale = np.sum(D * np.diag(sign)) / var_src
        else:
            scale = 1.0

        t = mu_dst.ravel() - scale * (R @ mu_src).ravel()
        return scale, R, t

    except Exception as e:
        print(f"Error in umeyama_alignment computation: {e}")
        print(
            "Returning default transformation: scale=1.0, rotation=identity matrix, translation=centroid difference"
        )
        return 1.0, np.eye(3), (mu_dst - mu_src).ravel()
+
+
def compute_chamfer_distance(points_pred, points_gt, max_dist=1.0):
    """Symmetric Chamfer Distance between two point clouds.

    Clouds larger than 100k points are deterministically subsampled, both are
    voxel-downsampled (5 cm), and per-point nearest-neighbour distances are
    clipped to *max_dist* before averaging.

    Args:
        points_pred: (N, 3) predicted points.
        points_gt: (M, 3) ground-truth points.
        max_dist: clipping threshold for individual distances.

    Returns:
        float: mean(pred->gt) + mean(gt->pred) clipped distances.
    """
    MAX_POINTS = 100000
    # Subsample with a *local* RandomState(33) so results stay reproducible
    # without clobbering NumPy's global RNG; this yields the exact same
    # indices that np.random.seed(33) + np.random.choice would produce.
    if points_pred.shape[0] > MAX_POINTS:
        indices = np.random.RandomState(33).choice(
            points_pred.shape[0], MAX_POINTS, replace=False
        )
        points_pred = points_pred[indices]

    if points_gt.shape[0] > MAX_POINTS:
        indices = np.random.RandomState(33).choice(
            points_gt.shape[0], MAX_POINTS, replace=False
        )
        points_gt = points_gt[indices]

    # Wrap as open3d point clouds for the nearest-neighbour queries.
    pcd_pred = o3d.geometry.PointCloud()
    pcd_gt = o3d.geometry.PointCloud()
    pcd_pred.points = o3d.utility.Vector3dVector(points_pred)
    pcd_gt.points = o3d.utility.Vector3dVector(points_gt)

    # Downsample to accelerate the distance computation.
    voxel_size = 0.05  # 5cm voxel size
    pcd_pred = pcd_pred.voxel_down_sample(voxel_size)
    pcd_gt = pcd_gt.voxel_down_sample(voxel_size)

    # Directed distances in both directions.
    distances1 = np.asarray(pcd_pred.compute_point_cloud_distance(pcd_gt))
    distances2 = np.asarray(pcd_gt.compute_point_cloud_distance(pcd_pred))

    distances1 = np.clip(distances1, 0, max_dist)
    distances2 = np.clip(distances2, 0, max_dist)

    # Chamfer Distance is the sum of the two directed means.
    return np.mean(distances1) + np.mean(distances2)
+
+
def load_gt_pointcloud(scene_id, gt_ply_dir):
    """Load the ground-truth scan for a scene.

    Reads <gt_ply_dir>/<scene_id>/<scene_id>_vh_clean_2.ply and returns
    (points (N, 3), colors (N, 3) or None when unavailable).
    """
    ply_path = gt_ply_dir / scene_id / (scene_id + "_vh_clean_2.ply")
    pcd = o3d.io.read_point_cloud(str(ply_path))

    points = np.asarray(pcd.points)
    colors = None
    try:
        if pcd.has_colors():
            colors = np.asarray(pcd.colors)
    except Exception:
        colors = None

    return points, colors
+
+
def eval_trajectory(poses_est, poses_gt, frame_ids, align=False):
    """Evaluate an estimated camera trajectory against a reference using evo.

    Computes ATE (translation RMSE), ARE (rotation RMSE, degrees) and
    frame-to-frame RPE (rotation and translation), and renders an x/z
    trajectory plot.

    Args:
        poses_est: (N, 4, 4) estimated poses.
        poses_gt: (N, 4, 4) reference poses.
            NOTE(review): the caller in this module passes world-to-camera
            matrices for both — confirm both inputs share one convention.
        frame_ids: per-pose timestamps forwarded to evo.
        align: if True, Umeyama-align (with scale) before each metric and
            also estimate the aligning similarity transform.

    Returns:
        tuple: ({"ate", "are", "rpe_rot", "rpe_trans"}, PIL.Image of the
        plot, 4x4 similarity transform aligning estimate to reference —
        identity when align is False or estimation fails).
    """
    # Build reference trajectory object
    traj_ref = PoseTrajectory3D(
        positions_xyz=poses_gt[:, :3, 3],  # Extract translation part
        orientations_quat_wxyz=Rotation.from_matrix(poses_gt[:, :3, :3]).as_quat(
            scalar_first=True
        ),  # Convert rotation matrix to quaternion (w,x,y,z ordering)
        timestamps=frame_ids,
    )

    # Build estimated trajectory object
    traj_est = PoseTrajectory3D(
        positions_xyz=poses_est[:, :3, 3],
        orientations_quat_wxyz=Rotation.from_matrix(poses_est[:, :3, :3]).as_quat(
            scalar_first=True
        ),
        timestamps=frame_ids,
    )

    # Calculate Absolute Trajectory Error (ATE).
    # deepcopy because evo mutates trajectories in place when aligning.
    ate_result = main_ape.ape(
        deepcopy(traj_ref),
        deepcopy(traj_est),
        est_name="traj",
        pose_relation=PoseRelation.translation_part,
        align=align,
        correct_scale=align,
        align_origin=True,
    )
    ate = ate_result.stats["rmse"]

    # Get alignment transformation matrix
    transform = np.eye(4)
    if align:
        try:
            # Use umeyama algorithm to compute optimal rigid transformation (including rotation, translation and scaling)
            aligned_xyz = ate_result.trajectories["traj"].positions_xyz
            original_xyz = traj_est.positions_xyz

            # At least 3 points needed to compute reliable transformation
            if len(aligned_xyz) >= 3 and len(original_xyz) >= 3:
                # Ensure point count matches
                min_points = min(len(aligned_xyz), len(original_xyz))
                aligned_xyz = aligned_xyz[:min_points]
                original_xyz = original_xyz[:min_points]

                # Compute transformation matrix (scaling, rotation and translation)
                try:
                    s, R, t = umeyama_alignment(
                        original_xyz.T,  # Source point cloud (3xN)
                        aligned_xyz.T,  # Target point cloud (3xN)
                        True,  # Whether to estimate scaling
                    )

                    # Build complete transformation matrix
                    transform = np.eye(4)
                    transform[:3, :3] = s * R  # Scaling and rotation
                    transform[:3, 3] = t  # Translation

                except Exception as e:
                    print(f"umeyama_alignment failed: {e}")
            else:
                print(
                    "Insufficient points, cannot reliably compute transformation matrix"
                )
        except Exception as e:
            print(f"Error computing transformation matrix: {e}")

        # If the above method fails, fallback to simple translation transformation
        if np.array_equal(transform, np.eye(4)) and hasattr(ate_result, "trajectories"):
            try:
                # Get original and aligned first position
                orig_pos = traj_est.positions_xyz[0]
                aligned_pos = ate_result.trajectories["traj"].positions_xyz[0]

                # Calculate translation part
                translation = aligned_pos - orig_pos

                # Update translation part of transformation matrix
                transform[:3, 3] = translation
                print(f"Fallback to simple translation transformation: {transform}")
            except Exception as e:
                print(f"Error building translation transformation: {e}")
                print("Will use identity matrix")

    # Calculate Absolute Rotation Error (ARE)
    are_result = main_ape.ape(
        deepcopy(traj_ref),
        deepcopy(traj_est),
        est_name="traj",
        pose_relation=PoseRelation.rotation_angle_deg,
        align=align,
        correct_scale=align,
        align_origin=True,
    )
    are = are_result.stats["rmse"]

    # Calculate Relative Pose Error (RPE) - rotation part
    rpe_rots_result = main_rpe.rpe(
        deepcopy(traj_ref),
        deepcopy(traj_est),
        est_name="traj",
        pose_relation=PoseRelation.rotation_angle_deg,
        align=align,
        correct_scale=align,
        delta=1,
        delta_unit=Unit.frames,
        rel_delta_tol=0.01,
        all_pairs=True,
        align_origin=True,
    )
    rpe_rot = rpe_rots_result.stats["rmse"]

    # Calculate Relative Pose Error (RPE) - translation part
    rpe_transs_result = main_rpe.rpe(
        deepcopy(traj_ref),
        deepcopy(traj_est),
        est_name="traj",
        pose_relation=PoseRelation.translation_part,
        align=align,
        correct_scale=align,
        delta=1,
        delta_unit=Unit.frames,
        rel_delta_tol=0.01,
        all_pairs=True,
        align_origin=True,
    )
    rpe_trans = rpe_transs_result.stats["rmse"]

    # Plot trajectory graph (top-down x/z view)
    plot_mode = plot.PlotMode.xz  # Use correct PlotMode reference
    fig = plt.figure()
    ax = plot.prepare_axis(fig, plot_mode)
    ax.set_title(f"ATE: {round(ate, 3)}, ARE: {round(are, 3)}")

    # Use reference trajectory (GT) for plotting
    plot.traj(ax, plot_mode, traj_ref, "--", "gray", "gt")

    # Use aligned trajectory for visualization, colored by per-pose ATE
    if align:
        traj_est_aligned = ate_result.trajectories["traj"]
        plot.traj_colormap(
            ax,
            traj_est_aligned,
            ate_result.np_arrays["error_array"],
            plot_mode,
            min_map=ate_result.stats["min"],
            max_map=ate_result.stats["max"],
        )
    else:
        plot.traj_colormap(
            ax,
            traj_est,
            ate_result.np_arrays["error_array"],
            plot_mode,
            min_map=ate_result.stats["min"],
            max_map=ate_result.stats["max"],
        )

    ax.legend()

    # Save image to an in-memory buffer and hand it back as a PIL image
    buffer = io.BytesIO()
    plt.savefig(buffer, format="png", dpi=90)
    buffer.seek(0)

    pillow_image = Image.open(buffer)
    pillow_image.load()  # force full read before closing the buffer
    buffer.close()
    plt.close(fig)

    return (
        {"ate": ate, "are": are, "rpe_rot": rpe_rot, "rpe_trans": rpe_trans},
        pillow_image,
        transform,
    )
+
+
def load_poses(path):
    """Load 4x4 camera poses from numbered ``*.txt`` files in *path*.

    Each file holds 16 whitespace-separated numbers (row-major 4x4 matrix).
    Files with non-finite entries or read errors are skipped. All poses are
    re-expressed relative to the first valid frame.

    Returns:
        (c2ws (N, 4, 4), first_gt_pose (4, 4) original first pose,
        frame_ids (N,)) — or (None, None, None) when nothing valid is found.
    """
    pose_files = sorted(path.glob("*.txt"), key=lambda p: int(p.stem))

    if not pose_files:
        print(f"Warning: No pose files (.txt) found in directory {path}")
        return None, None, None

    poses = []
    frame_ids = []
    for pose_file in pose_files:
        try:
            values = [float(v) for v in pose_file.read_text().strip().split()]
            pose = np.array(values).reshape(4, 4)
            # Keep only fully finite matrices (rejects inf and NaN).
            if np.isfinite(pose).all():
                poses.append(pose)
                frame_ids.append(int(pose_file.stem))
        except Exception as e:
            print(f"Error reading pose file {pose_file}: {e}")

    if not poses:
        print(f"Warning: No valid pose files found in directory {path}")
        return None, None, None

    poses = np.stack(poses)
    frame_ids = np.array(frame_ids)

    # Re-base every pose on the first frame's coordinate system.
    first_gt_pose = poses[0].copy()
    poses = np.linalg.inv(poses[0]) @ poses
    return poses, first_gt_pose, frame_ids
+
+
def get_vgg_input_imgs(images: np.ndarray):
    """Preprocess RGB frames into the network's input tensor.

    Each frame is resized to width 518 with height snapped to a multiple of
    the 14px patch size, converted to a float tensor in [0, 1], and
    center-cropped to at most 518 rows.

    Args:
        images: iterable of HxWx3 uint8 RGB frames.

    Returns:
        (batch (S, 3, H', 518), patch_width, patch_height).
        NOTE(review): the patch dimensions reflect the last processed frame,
        so all frames are assumed to share one size — confirm with callers.
    """
    to_tensor = TF.ToTensor()
    processed = []
    final_height = None
    target_width = 518

    for frame in images:
        pil_img = Image.fromarray(frame, mode="RGB")
        width, height = pil_img.size
        # Keep aspect ratio; round the new height to a multiple of 14.
        new_height = round(height * (target_width / width) / 14) * 14
        pil_img = pil_img.resize((target_width, new_height), Image.Resampling.BICUBIC)
        tensor = to_tensor(pil_img)

        # Center-crop rows when the resized frame is taller than 518.
        if new_height > 518:
            top = (new_height - 518) // 2
            tensor = tensor[:, top : top + 518, :]
            final_height = 518
        else:
            final_height = new_height

        processed.append(tensor)

    batch = torch.stack(processed)

    # Patch-grid dimensions for the 14px ViT patch size.
    patch_width = target_width // 14  # 518 // 14 = 37
    patch_height = final_height // 14

    return batch, patch_width, patch_height
+
+
def get_sorted_image_paths(images_dir):
    """Return image files in *images_dir* sorted by filename.

    Collects ``*.jpg``, ``*.png`` and ``*.jpeg`` files and sorts the combined
    list, so ordering is consistent across extensions (the previous version
    only sorted within each extension group).
    """
    image_paths = []
    for ext in ("*.jpg", "*.png", "*.jpeg"):
        image_paths.extend(images_dir.glob(ext))
    # image_paths.sort(key=lambda x: int(x.stem))
    return sorted(image_paths)
+
+
def to_homogeneous(extrinsics):
    """Promote (N, 3, 4) [R|t] extrinsics to (N, 4, 4) homogeneous matrices.

    The appended bottom row is [0, 0, 0, 1] for every matrix.
    """
    count = extrinsics.shape[0]
    # Start from a stack of identities, then overwrite the top 3x4 block.
    result = np.broadcast_to(np.eye(4), (count, 4, 4)).copy()
    result[:, :3, :4] = extrinsics
    return result
+
+
def umeyama_alignment(src, dst, estimate_scale=True):
    """Estimate the similarity transform aligning src to dst (Umeyama, 1991).

    NOTE: this is a verbatim duplicate of the definition earlier in this
    module and is the one in effect at import time.

    Args:
        src: (3, N) source points.
        dst: (3, N) target points.
        estimate_scale: estimate a uniform scale when True, otherwise 1.0.

    Returns:
        (scale, R, t) minimizing ||dst - (scale * R @ src + t)||^2; on
        numerical failure, (1.0, identity, centroid difference).
    """
    assert (
        src.shape == dst.shape
    ), f"Input shapes don't match: src {src.shape}, dst {dst.shape}"
    assert src.shape[0] == 3, f"Expected point cloud dimension (3,N), got {src.shape}"

    centroid_src = src.mean(axis=1, keepdims=True)
    centroid_dst = dst.mean(axis=1, keepdims=True)
    demeaned_src = src - centroid_src
    demeaned_dst = dst - centroid_dst

    # Cross-covariance of the centered clouds.
    covariance = demeaned_dst @ demeaned_src.T

    try:
        U, singular_values, Vt = svd(covariance)

        # Correct for reflections so the recovered rotation is proper.
        correction = np.eye(3)
        if np.linalg.det(U @ Vt) < 0:
            correction[2, 2] = -1

        rotation = U @ correction @ Vt

        if not estimate_scale:
            scale = 1.0
        else:
            source_variance = np.sum(demeaned_src * demeaned_src)
            if source_variance < 1e-10:
                print(
                    "Warning: Source point cloud variance close to zero, setting scale factor to 1.0"
                )
                scale = 1.0
            else:
                scale = np.sum(singular_values * np.diag(correction)) / source_variance

        translation = centroid_dst.ravel() - scale * (rotation @ centroid_src).ravel()
        return scale, rotation, translation

    except Exception as e:
        print(f"Error in umeyama_alignment computation: {e}")
        print(
            "Returning default transformation: scale=1.0, rotation=identity matrix, translation=centroid difference"
        )
        return 1.0, np.eye(3), (centroid_dst - centroid_src).ravel()
+
+
def align_point_clouds_scale(source_pc, target_pc):
    """Rescale *source_pc* so its bounding-box diagonal matches *target_pc*'s.

    The source cloud is scaled about its own bounding-box center and then
    recentered on the target's bounding-box center.

    Args:
        source_pc: (N, 3) points to rescale.
        target_pc: (M, 3) reference points.

    Returns:
        (aligned_points (N, 3), scale_factor) — scale_factor is 1.0 when the
        source cloud is degenerate (near-zero extent).
    """
    src_lo, src_hi = np.min(source_pc, axis=0), np.max(source_pc, axis=0)
    tgt_lo, tgt_hi = np.min(target_pc, axis=0), np.max(target_pc, axis=0)

    src_center = (src_hi + src_lo) / 2
    tgt_center = (tgt_hi + tgt_lo) / 2

    # Bounding-box diagonal lengths drive the uniform scale.
    src_diag = np.sqrt(np.sum((src_hi - src_lo) ** 2))
    tgt_diag = np.sqrt(np.sum((tgt_hi - tgt_lo) ** 2))

    if src_diag < 1e-8:
        print("Warning: Source point cloud size close to zero")
        scale_factor = 1.0
    else:
        scale_factor = tgt_diag / src_diag

    aligned = (source_pc - src_center) * scale_factor + tgt_center
    return aligned, scale_factor
+
+
def get_all_scenes(data_dir: Path, num_scenes: int) -> List[str]:
    """Return up to *num_scenes* scene directory names from *data_dir*.

    Directories are sorted alphabetically; when there are more than
    *num_scenes*, the list is evenly strided to spread the selection.
    """
    scenes = sorted(entry.name for entry in data_dir.iterdir() if entry.is_dir())
    if len(scenes) <= num_scenes:
        return scenes
    stride = max(1, len(scenes) // num_scenes)
    return scenes[::stride][:num_scenes]
+
+
def build_frame_selection(
    image_paths: List[Path],
    available_pose_frame_ids: np.ndarray,
    input_frame: int,
) -> Tuple[List[int], List[Path], List[int]]:
    """Pick up to *input_frame* frames that have both an image and a pose.

    The first valid frame is always kept; the remainder are evenly strided.

    Args:
        image_paths: image files whose stems are integer frame ids.
        available_pose_frame_ids: frame ids that have a loaded pose.
        input_frame: maximum number of frames to select.

    Returns:
        (selected_frame_ids, selected_image_paths, selected_pose_indices)
        where pose indices index into *available_pose_frame_ids*' order.
    """
    image_ids = {int(p.stem) for p in image_paths}
    valid_ids = sorted(image_ids & set(available_pose_frame_ids))

    if len(valid_ids) > input_frame:
        anchor, rest = valid_ids[0], valid_ids[1:]
        stride = max(1, len(rest) // (input_frame - 1))
        selected_frame_ids = [anchor] + rest[::stride][: input_frame - 1]
    else:
        selected_frame_ids = valid_ids

    path_by_id = {int(p.stem): p for p in image_paths}
    selected_image_paths = [
        path_by_id[fid] for fid in selected_frame_ids if fid in path_by_id
    ]

    pose_idx_by_id = {fid: i for i, fid in enumerate(available_pose_frame_ids)}
    selected_pose_indices = [
        pose_idx_by_id[fid] for fid in selected_frame_ids if fid in pose_idx_by_id
    ]

    return selected_frame_ids, selected_image_paths, selected_pose_indices
+
+
def load_images_rgb(image_paths: List[Path]) -> List[np.ndarray]:
    """Read the given image files as RGB arrays.

    Files OpenCV cannot decode are silently skipped, so the result may be
    shorter than the input list.
    """
    frames: List[np.ndarray] = []
    for path in image_paths:
        bgr = cv2.imread(str(path))
        if bgr is None:
            continue
        frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    return frames
+
+
@torch.no_grad()
def infer_vggt_and_reconstruct(
    model: torch.nn.Module,
    vgg_input: torch.Tensor,
    dtype: torch.dtype,
    depth_conf_thresh: float,
    image_paths: Optional[list] = None,
) -> Tuple[
    np.ndarray,
    np.ndarray,
    List[np.ndarray],
    List[np.ndarray],
    List[np.ndarray],
    float,
]:
    """Run VGGT on a frame batch and reconstruct a colored point cloud.

    Args:
        model: VGGT-style model returning "pose_enc", "depth", "depth_conf".
        vgg_input: (S, 3, H, W) preprocessed frames in [0, 1].
        dtype: autocast dtype for inference.
            NOTE(review): the input is cast to bfloat16 regardless of this
            value — confirm that is intentional.
        depth_conf_thresh: depths with confidence below this become NaN and
            are dropped from the point cloud.
        image_paths: optional frame paths forwarded to the model (used for
            attention visualization elsewhere).

    Returns:
        (extrinsics (S, 3, 4), intrinsics (S, 3, 3), per-frame point arrays,
        per-frame color arrays (uint8), per-frame 4x4 cam-to-world matrices,
        inference wall time in milliseconds).
    """
    # Time only the forward pass; synchronize so GPU work is fully counted.
    torch.cuda.synchronize()
    start = time.time()
    with torch.cuda.amp.autocast(dtype=dtype):
        vgg_input_cuda = vgg_input.cuda().to(torch.bfloat16)
        predictions = model(vgg_input_cuda, image_paths=image_paths)
    torch.cuda.synchronize()
    end = time.time()
    inference_time_ms = (end - start) * 1000.0

    # Decode the pose encoding into explicit camera matrices.
    extrinsic, intrinsic = pose_encoding_to_extri_intri(
        predictions["pose_enc"], (vgg_input.shape[2], vgg_input.shape[3])
    )

    # Mask out low-confidence depths by setting them to NaN.
    depth_tensor = predictions["depth"]
    depth_conf = predictions["depth_conf"]
    depth_conf_np = depth_conf[0].detach().float().cpu().numpy()
    depth_mask = depth_conf_np >= depth_conf_thresh
    depth_filtered = depth_tensor[0].detach().float().cpu().numpy()
    depth_filtered[~depth_mask] = np.nan
    depth_np = depth_filtered

    extrinsic_np = extrinsic[0].detach().float().cpu().numpy()
    intrinsic_np = intrinsic[0].detach().float().cpu().numpy()

    # Lift the (masked) depth maps into world-space point maps.
    world_points = unproject_depth_map_to_point_map(
        depth_np, extrinsic_np, intrinsic_np
    )
    all_points: List[np.ndarray] = []
    all_colors: List[np.ndarray] = []

    # Prepare RGB images aligned with vgg_input for coloring point clouds (0-255, uint8)
    vgg_np = vgg_input.detach().float().cpu().numpy()  # [S, 3, H, W] in [0,1]

    for frame_idx in range(world_points.shape[0]):
        points = world_points[frame_idx].reshape(-1, 3)
        # Drop points invalidated by the confidence mask (NaN) or overflow (inf).
        valid_mask = ~np.isnan(points).any(axis=1) & ~np.isinf(points).any(axis=1)
        valid_points = points[valid_mask]
        if len(valid_points) > 0:
            all_points.append(valid_points)

            # Generate corresponding colors from the same frame's pixels.
            img_chw = vgg_np[frame_idx]  # [3, H, W]
            img_hwc = (
                (np.transpose(img_chw, (1, 2, 0)) * 255.0).clip(0, 255).astype(np.uint8)
            )  # [H, W, 3] uint8
            rgb_flat = img_hwc.reshape(-1, 3)
            valid_colors = rgb_flat[valid_mask]
            all_colors.append(valid_colors)

    # Homogeneous 4x4 matrices for downstream trajectory evaluation.
    camera_poses = to_homogeneous(extrinsic_np)
    all_cam_to_world_mat = list(camera_poses)

    return (
        extrinsic_np,
        intrinsic_np,
        all_points,
        all_colors,
        all_cam_to_world_mat,
        inference_time_ms,
    )
+
+
def evaluate_scene_and_save(
    scene: str,
    c2ws: np.ndarray,
    first_gt_pose: np.ndarray,
    frame_ids: List[int],
    all_cam_to_world_mat: List[np.ndarray],
    all_world_points: List[np.ndarray],
    output_scene_dir: Path,
    gt_ply_dir: Path,
    chamfer_max_dist: float,
    inference_time_ms: float,
    plot_flag: bool,
) -> Optional[Dict[str, float]]:
    """Evaluate one scene (trajectory + point cloud) and write metrics.json.

    Computes aligned trajectory metrics via eval_trajectory and, when the
    ground-truth scan is available, a scale-aligned Chamfer Distance. Results
    are written to <output_scene_dir>/metrics.json (and plot.png if
    *plot_flag* is set).

    Args:
        scene: scene identifier (also the GT subdirectory name).
        c2ws: (N, 4, 4) ground-truth poses relative to the first frame.
        first_gt_pose: original (un-rebased) first GT pose, used to move the
            predicted cloud into the GT world frame.
        frame_ids: frame ids matching the poses.
        all_cam_to_world_mat: predicted per-frame 4x4 matrices.
        all_world_points: predicted per-frame (M, 3) point arrays.
        output_scene_dir: where metrics/plot are written (created if needed).
        gt_ply_dir: root of the ground-truth scans.
        chamfer_max_dist: clipping distance for the Chamfer metric.
        inference_time_ms: forwarded into the metrics dict.
        plot_flag: also save the trajectory plot when True.

    Returns:
        dict of metrics, or None when predictions are empty.
    """
    if not all_cam_to_world_mat or not all_world_points:
        print(f"Skipping {scene}: failed to obtain valid camera poses or point clouds")
        return None

    output_scene_dir.mkdir(parents=True, exist_ok=True)

    # Trajectory metrics: both estimate and GT are passed as w2c matrices.
    poses_gt = c2ws
    w2cs = np.linalg.inv(poses_gt)
    traj_est_poses = np.array(all_cam_to_world_mat)
    n = min(len(traj_est_poses), len(w2cs))
    timestamps = frame_ids[:n]
    stats_aligned, traj_plot, _ = eval_trajectory(
        traj_est_poses[:n], w2cs[:n], timestamps, align=True
    )

    # Point-cloud metric: best-effort — any failure just omits the entries.
    try:
        merged_point_cloud = np.vstack(all_world_points)
        gt_point_cloud, _ = load_gt_pointcloud(scene, gt_ply_dir)
        if gt_point_cloud is not None:
            # Move predicted points into the GT world frame via the original
            # first GT pose, then align scale before measuring Chamfer.
            homogeneous_points = np.hstack(
                [merged_point_cloud, np.ones((merged_point_cloud.shape[0], 1))]
            )
            world_points_raw = np.dot(homogeneous_points, first_gt_pose.T)[:, :3]

            world_points_scaled, scale_factor = align_point_clouds_scale(
                world_points_raw, gt_point_cloud
            )

            cd_value = compute_chamfer_distance(
                world_points_scaled, gt_point_cloud, max_dist=chamfer_max_dist
            )
            stats_aligned["chamfer_distance"] = float(cd_value)
            stats_aligned["scale_factor"] = float(scale_factor)
    except Exception as e:
        print(f"Error computing Chamfer Distance for {scene}: {e}")

    # Duplicate every metric under an "aligned_" prefix for downstream tools.
    all_metrics: Dict[str, float] = deepcopy(stats_aligned)
    for metric_name, metric_value in list(stats_aligned.items()):
        all_metrics[f"aligned_{metric_name}"] = metric_value
    all_metrics["inference_time_ms"] = float(inference_time_ms)

    with open(output_scene_dir / "metrics.json", "w") as f:
        import json

        json.dump(all_metrics, f, indent=4)
    if plot_flag:
        try:
            traj_plot.save(output_scene_dir / "plot.png")
        except Exception:
            pass

    return all_metrics
+
+
def compute_average_metrics_and_save(
    all_scenes_metrics: Dict[str, Dict[str, Dict[str, float]]],
    output_path: Path,
    input_frame: int,
) -> Dict[str, float]:
    """Average per-scene metrics, store them on the input dict, and save JSON.

    Args:
        all_scenes_metrics: {"scenes": {scene_id: {metric: value}}}; an
            "average" key is added in place.
        output_path: results root; files go to <output_path>/input_frame_<N>/.
        input_frame: number of input frames (used for the subdirectory name).

    Returns:
        dict mapping metric name -> mean over scenes (0.0 when a metric is
        absent from every scene).
    """
    import json  # single local import instead of one per with-block

    metric_names = [
        "chamfer_distance",
        "ate",
        "are",
        "rpe_rot",
        "rpe_trans",
        "inference_time_ms",
    ]
    collected: Dict[str, List[float]] = {name: [] for name in metric_names}
    for metrics in all_scenes_metrics["scenes"].values():
        for metric_name, metric_value in metrics.items():
            if metric_name in collected:
                collected[metric_name].append(float(metric_value))

    average_metrics: Dict[str, float] = {
        name: (float(np.mean(values)) if values else 0.0)
        for name, values in collected.items()
    }

    all_scenes_metrics["average"] = average_metrics
    output_path.mkdir(parents=True, exist_ok=True)

    input_frame_dir = output_path / f"input_frame_{input_frame}"
    input_frame_dir.mkdir(parents=True, exist_ok=True)

    with open(input_frame_dir / "all_scenes_metrics.json", "w") as f:
        json.dump(all_scenes_metrics, f, indent=4)

    with open(input_frame_dir / "average_metrics.json", "w") as f:
        json.dump(average_metrics, f, indent=4)

    print("\nAverage metrics:")
    for metric_name, value in average_metrics.items():
        print(f"{metric_name}: {value:.6f}")

    return average_metrics
diff --git a/FastVGGT/vggt/utils/geometry.py b/FastVGGT/vggt/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..f555516dbc8a7dd8c7b15e6fbc928a5bfe8f740b
--- /dev/null
+++ b/FastVGGT/vggt/utils/geometry.py
@@ -0,0 +1,324 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import numpy as np
+
+
+from vggt.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion
+
+
def unproject_depth_map_to_point_map(
    depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
) -> np.ndarray:
    """
    Unproject a batch of depth maps to 3D world coordinates.

    Args:
        depth_map (np.ndarray): Depth maps of shape (S, H, W, 1) or (S, H, W).
        extrinsics_cam (np.ndarray): Camera extrinsics of shape (S, 3, 4).
        intrinsics_cam (np.ndarray): Camera intrinsics of shape (S, 3, 3).

    Returns:
        np.ndarray: World coordinates of shape (S, H, W, 3).
    """
    # Accept torch tensors too; all work is done on CPU as numpy.
    if isinstance(depth_map, torch.Tensor):
        depth_map = depth_map.cpu().numpy()
    if isinstance(extrinsics_cam, torch.Tensor):
        extrinsics_cam = extrinsics_cam.cpu().numpy()
    if isinstance(intrinsics_cam, torch.Tensor):
        intrinsics_cam = intrinsics_cam.cpu().numpy()

    per_frame = []
    for idx in range(depth_map.shape[0]):
        world_points, _, _ = depth_to_world_coords_points(
            depth_map[idx].squeeze(-1), extrinsics_cam[idx], intrinsics_cam[idx]
        )
        per_frame.append(world_points)

    return np.stack(per_frame, axis=0)
+
+
def depth_to_world_coords_points(
    depth_map: np.ndarray,
    extrinsic: np.ndarray,
    intrinsic: np.ndarray,
    eps=1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Lift a single depth map into world coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        extrinsic (np.ndarray): (3, 4) cam-from-world matrix, OpenCV convention.
        intrinsic (np.ndarray): (3, 3) camera intrinsic matrix.
        eps: depths <= eps are flagged invalid in the returned mask.

    Returns:
        (world_points (H, W, 3), cam_points (H, W, 3), valid_mask (H, W)),
        or (None, None, None) when depth_map is None.
    """
    if depth_map is None:
        return None, None, None

    valid_mask = depth_map > eps

    # Back-project pixels into camera space.
    cam_points = depth_to_cam_coords_points(depth_map, intrinsic)

    # Invert the cam-from-world extrinsic; the helper is batched, hence
    # the [None] / [0] round-trip.
    cam_to_world = closed_form_inverse_se3(extrinsic[None])[0]
    rotation = cam_to_world[:3, :3]
    translation = cam_to_world[:3, 3]

    # (H, W, 3) @ (3, 3)^T + (3,) -> (H, W, 3)
    world_points = cam_points @ rotation.T + translation

    return world_points, cam_points, valid_mask
+
+
def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> np.ndarray:
    """
    Convert a depth map to camera coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3);
            must have zero skew.

    Returns:
        np.ndarray: Camera coordinates of shape (H, W, 3), float32.
        (The previous annotation/docstring claimed a tuple; a single array
        is returned.)
    """
    H, W = depth_map.shape
    assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
    assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"

    # Focal lengths and principal point
    fu, fv = intrinsic[0, 0], intrinsic[1, 1]
    cu, cv = intrinsic[0, 2], intrinsic[1, 2]

    # Pixel coordinate grid
    u, v = np.meshgrid(np.arange(W), np.arange(H))

    # Pinhole back-projection: scale normalized rays by depth
    x_cam = (u - cu) * depth_map / fu
    y_cam = (v - cv) * depth_map / fv
    z_cam = depth_map

    return np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+
def closed_form_inverse_se3(se3, R=None, T=None):
    """
    Invert a batch of SE(3) matrices in closed form.

    The inverse of [R | t] is [R^T | -R^T t]; no general matrix inversion is
    needed. If *R* and *T* are supplied they must match the rotation and
    translation blocks of *se3*; otherwise they are sliced out of it.

    Args:
        se3: (N, 4, 4) or (N, 3, 4) numpy array or torch tensor.
        R (optional): (N, 3, 3) rotations.
        T (optional): (N, 3, 1) translations.

    Returns:
        (N, 4, 4) inverses, same array/tensor type and device as the input.
    """
    is_numpy = isinstance(se3, np.ndarray)

    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
        raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")

    if R is None:
        R = se3[:, :3, :3]  # (N,3,3)
    if T is None:
        T = se3[:, :3, 3:]  # (N,3,1)

    if is_numpy:
        rot_t = np.transpose(R, (0, 2, 1))
        new_trans = -np.matmul(rot_t, T)  # -R^T t
        inverse = np.tile(np.eye(4), (len(R), 1, 1))
    else:
        rot_t = R.transpose(1, 2)  # (N,3,3)
        new_trans = -torch.bmm(rot_t, T)  # (N,3,1)
        inverse = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
        inverse = inverse.to(R.dtype).to(R.device)

    inverse[:, :3, :3] = rot_t
    inverse[:, :3, 3:] = new_trans
    return inverse
+
+
+# TODO: this code can be further cleaned up
+
+
def project_world_points_to_camera_points_batch(world_points, cam_extrinsics):
    """
    Apply per-frame extrinsics to dense world-space point maps.

    Args:
        world_points (torch.Tensor): 3D points of shape (B, S, H, W, 3).
        cam_extrinsics (torch.Tensor): extrinsics of shape (B, S, 3, 4).

    Returns:
        torch.Tensor: camera-space points of shape (B, S, H, W, 3).
    """
    # TODO: merge this into project_world_points_to_cam

    # Homogenize: (B, S, H, W, 3) -> (B, S, H, W, 4)
    homogeneous = torch.cat(
        [world_points, torch.ones_like(world_points[..., :1])], dim=-1
    )

    # (B, S, 3, 4) -> (B, S, 1, 1, 3, 4) so it broadcasts over H and W.
    extrinsics = cam_extrinsics[:, :, None, None, :, :]

    # (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) -> (B, S, H, W, 3, 1)
    return torch.matmul(extrinsics, homogeneous.unsqueeze(-1)).squeeze(-1)
+
+
+
def project_world_points_to_cam(
    world_points,
    cam_extrinsics,
    cam_intrinsics=None,
    distortion_params=None,
    default=0,
    only_points_cam=False,
):
    """
    Project 3D world points into camera coordinates and, optionally, image pixels.

    Args:
        world_points (torch.Tensor): 3D points of shape Px3.
        cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
        cam_intrinsics (torch.Tensor, optional): Intrinsic parameters of shape Bx3x3.
            Required unless `only_points_cam` is True.
        distortion_params (torch.Tensor, optional): Parameters of shape BxN used for
            radial distortion.
        default (float, optional): Value used to replace NaNs in the projected 2D points.
        only_points_cam (bool, optional): If True, skip the intrinsic projection and
            return (None, cam_points).

    Returns:
        tuple: (image_points, cam_points)
            - image_points (torch.Tensor or None): 2D pixel coordinates of shape BxPx2,
              or None when `only_points_cam` is True.
            - cam_points (torch.Tensor): Points in camera coordinates of shape Bx3xP.
    """
    device = world_points.device
    # Disable autocast: the projection math is numerically sensitive and should run
    # in full precision.
    with torch.autocast(device_type=device.type, enabled=False):
        B = cam_extrinsics.shape[0]  # Batch size, i.e., number of cameras

        # Homogenize: Px3 -> Px4
        world_points_homogeneous = torch.cat(
            [world_points, torch.ones_like(world_points[..., 0:1])], dim=1
        )
        # Share the same point set across all cameras: BxPx4 (expanded view, no copy)
        world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand(
            B, -1, -1
        )

        # Step 1: Apply extrinsic parameters
        # Transform 3D points to camera coordinate system for all cameras (Bx3xP)
        cam_points = torch.bmm(
            cam_extrinsics, world_points_homogeneous.transpose(-1, -2)
        )

        if only_points_cam:
            return None, cam_points

        # Step 2: Apply intrinsic parameters and (optional) distortion
        image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default)

    return image_points, cam_points
+
+
+
def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0):
    """
    Project camera-space 3D points to pixel coordinates.

    Args:
        cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
        cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
        distortion_params (torch.Tensor, optional): Distortion parameters of shape BxN,
            where N can be 1, 2, or 4.
        default (float, optional): Value substituted for NaNs in the output.

    Returns:
        torch.Tensor: Pixel coordinates of shape BxNx2.
    """
    # Perspective division -> points on the normalized image plane
    depth = cam_points[:, 2:3, :]
    normalized = cam_points / depth
    plane_xy = normalized[:, :2, :]

    # Optionally warp the normalized coordinates with the lens distortion model
    if distortion_params is not None:
        dist_x, dist_y = apply_distortion(distortion_params, plane_xy[:, 0], plane_xy[:, 1])
        plane_xy = torch.stack((dist_x, dist_y), dim=1)

    # Homogenize (Bx3xN) and map through the intrinsics
    homo = torch.cat((plane_xy, torch.ones_like(plane_xy[:, :1, :])), dim=1)
    projected = torch.bmm(cam_intrinsics, homo)

    # Keep only x, y and scrub NaNs (e.g. from points at zero depth)
    pixels = torch.nan_to_num(projected[:, :2, :], nan=default)

    return pixels.transpose(1, 2)  # BxNx2
+
+
+
+
def cam_from_img(pred_tracks, intrinsics, extra_params=None):
    """
    Normalize predicted 2D tracks from pixel coordinates to camera coordinates.

    Args:
        pred_tracks (torch.Tensor): Predicted tracks of shape [batch_size, num_tracks, 2].
        intrinsics (torch.Tensor): Camera intrinsics of shape [batch_size, 3, 3].
        extra_params (torch.Tensor, optional): Distortion parameters of shape BxN,
            where N can be 1, 2, or 4.

    Returns:
        torch.Tensor: Normalized tracks tensor of shape [batch_size, num_tracks, 2].
    """

    # We don't want to do intrinsics_inv = torch.inverse(intrinsics) here;
    # instead read the principal point and focal lengths directly off the matrix,
    # which is equivalent for the standard pinhole K and cheaper.
    principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2)
    focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2)
    tracks_normalized = (pred_tracks - principal_point) / focal_length

    if extra_params is not None:
        # Undo lens distortion; fall back to the single-step variant if the
        # iterative solver fails. Catch Exception (not a bare except) so that
        # KeyboardInterrupt/SystemExit still propagate.
        try:
            tracks_normalized = iterative_undistortion(
                extra_params, tracks_normalized
            )
        except Exception:
            tracks_normalized = single_undistortion(
                extra_params, tracks_normalized
            )

    return tracks_normalized
\ No newline at end of file
diff --git a/FastVGGT/vggt/utils/helper.py b/FastVGGT/vggt/utils/helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b019189c85ff86645a4cf3756632aa8d4500649
--- /dev/null
+++ b/FastVGGT/vggt/utils/helper.py
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+
+
def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray:
    """
    Cap the number of True entries in `mask` at `max_trues`.

    If the mask already has at most `max_trues` True values it is returned
    unchanged; otherwise a new mask of the same shape is built in which a
    random subset of `max_trues` of the original True positions survives.
    """
    # Flat positions of every True entry
    true_positions = np.flatnonzero(mask)

    # Already within budget: hand back the original mask untouched
    if true_positions.size <= max_trues:
        return mask

    # Randomly choose which True positions to keep (without replacement)
    kept = np.random.choice(true_positions, size=max_trues, replace=False)

    # Fresh all-False flat mask, re-enable only the sampled positions
    capped = np.zeros(mask.size, dtype=bool)
    capped[kept] = True

    # Back to the caller's shape
    return capped.reshape(mask.shape)
+
+
def create_pixel_coordinate_grid(num_frames, height, width):
    """
    Create a per-pixel (x, y, frame) coordinate grid for a stack of frames.

    Args:
        num_frames (int): Number of frames S.
        height (int): Frame height H.
        width (int): Frame width W.

    Returns:
        numpy.ndarray: Array of shape (num_frames, height, width, 3), dtype float32,
        whose last axis holds (x, y, frame_index) for each pixel of each frame.
    """
    # np.indices yields (frame, y, x) index grids in a single call; stack them
    # in (x, y, frame) order to match the documented layout.
    f_coords, y_coords, x_coords = np.indices(
        (num_frames, height, width), dtype=np.float32
    )
    points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)

    return points_xyf
diff --git a/FastVGGT/vggt/utils/load_fn.py b/FastVGGT/vggt/utils/load_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ae01ba2601f7b5039220bbae94725646dd66a41
--- /dev/null
+++ b/FastVGGT/vggt/utils/load_fn.py
@@ -0,0 +1,242 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from PIL import Image
+from torchvision import transforms as TF
+import numpy as np
+
+
def load_and_preprocess_images_square(image_path_list, target_size=1024):
    """
    Load and preprocess images by center padding to square and resizing to target size.
    Also returns the position information of original pixels after transformation.

    Args:
        image_path_list (list): List of paths to image files
        target_size (int, optional): Target size for both width and height. Defaults to 1024.

    Returns:
        tuple: (
            torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
            torch.Tensor: Tensor of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each
                image, where (x1, y1, x2, y2) is the bounding box of the original content in the
                resized square frame and (width, height) are the original image dimensions.
        )

    Raises:
        ValueError: If the input list is empty
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    images = []
    original_coords = []
    to_tensor = TF.ToTensor()

    for image_path in image_path_list:
        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)

        # Convert to RGB
        img = img.convert("RGB")

        # Get original dimensions
        width, height = img.size

        # Make the image square by padding the shorter dimension
        max_dim = max(width, height)

        # Calculate padding
        left = (max_dim - width) // 2
        top = (max_dim - height) // 2

        # Calculate scale factor for resizing
        scale = target_size / max_dim

        # Calculate final coordinates of original image in target space
        x1 = left * scale
        y1 = top * scale
        x2 = (left + width) * scale
        y2 = (top + height) * scale

        # Store original image coordinates and scale
        original_coords.append(np.array([x1, y1, x2, y2, width, height]))

        # Create a new black square image and paste original
        square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
        square_img.paste(img, (left, top))

        # Resize to target size
        square_img = square_img.resize(
            (target_size, target_size), Image.Resampling.BICUBIC
        )

        # Convert to tensor
        img_tensor = to_tensor(square_img)
        images.append(img_tensor)

    # Stack all images
    images = torch.stack(images)
    original_coords = torch.from_numpy(np.array(original_coords)).float()

    # Add additional dimension if single image to ensure correct shape
    if len(image_path_list) == 1:
        if images.dim() == 3:
            images = images.unsqueeze(0)
            original_coords = original_coords.unsqueeze(0)

    return images, original_coords
+
+
def load_and_preprocess_images(image_path_list, mode="crop"):
    """
    A quick start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to 518px and center crops height if needed.
            - "pad": Preserves all pixels by making the largest dimension 518px
              and padding the smaller dimension to reach a square shape.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": The function ensures width=518px while maintaining aspect ratio
          and height is center-cropped if larger than 518px
        - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
          and the smaller dimension is padded to reach a square shape (518x518)
        - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = 518  # 518 = 37 * 14, already divisible by the 14px patch size

    # First process all images and collect their shapes
    for image_path in image_path_list:
        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background:
        if img.mode == "RGBA":
            # Create white background
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            # Alpha composite onto the white background
            img = Image.alpha_composite(background, img)

        # Now convert to "RGB" (this step assigns white for transparent areas)
        img = img.convert("RGB")

        width, height = img.size

        if mode == "pad":
            # Make the largest dimension 518px while maintaining aspect ratio
            if width >= height:
                new_width = target_size
                new_height = (
                    round(height * (new_width / width) / 14) * 14
                )  # Make divisible by 14
            else:
                new_height = target_size
                new_width = (
                    round(width * (new_height / height) / 14) * 14
                )  # Make divisible by 14
        else:  # mode == "crop"
            # Original behavior: set width to 518px
            new_width = target_size
            # Calculate height maintaining aspect ratio, divisible by 14
            new_height = round(height * (new_width / width) / 14) * 14

        # Resize with new dimensions (width, height)
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor (0, 1)

        # Center crop height if it's larger than 518 (only in crop mode)
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]

        # For pad mode, pad to make a square of target_size x target_size
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                # Pad with white (value=1.0)
                img = torch.nn.functional.pad(
                    img,
                    (pad_left, pad_right, pad_top, pad_bottom),
                    mode="constant",
                    value=1.0,
                )

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        # Pad images if necessary so every tensor reaches max_height x max_width
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                # Pad with white (value=1.0), split evenly on both sides
                img = torch.nn.functional.pad(
                    img,
                    (pad_left, pad_right, pad_top, pad_bottom),
                    mode="constant",
                    value=1.0,
                )
            padded_images.append(img)
        images = padded_images

    images = torch.stack(images)  # concatenate images

    # Ensure correct shape when single image
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)

    return images
diff --git a/FastVGGT/vggt/utils/pose_enc.py b/FastVGGT/vggt/utils/pose_enc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d3b964330af0e62f4d36d332317ae00cb99b3a9
--- /dev/null
+++ b/FastVGGT/vggt/utils/pose_enc.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from .rotation import quat_to_mat, mat_to_quat
+
+
def extri_intri_to_pose_encoding(
    extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV"  # e.g., (256, 512)
):
    """Pack camera extrinsics and intrinsics into a compact 9-D pose encoding.

    Args:
        extrinsics (torch.Tensor): BxSx3x4 camera-from-world [R|t] matrices
            (OpenCV convention: x-right, y-down, z-forward), where B is the
            batch size and S the sequence length.
        intrinsics (torch.Tensor): BxSx3x3 pixel-space intrinsic matrices
            [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
        image_size_hw (tuple): (height, width) of the image in pixels, needed
            to convert focal lengths into fields of view. For example: (256, 512).
        pose_encoding_type (str): Only "absT_quaR_FoV" is supported
            (absolute translation, quaternion rotation, field of view).

    Returns:
        torch.Tensor: BxSx9 encoding laid out as
        [tx, ty, tz, qx, qy, qz, qw, fov_h, fov_w].
    """
    if pose_encoding_type != "absT_quaR_FoV":
        raise NotImplementedError

    rotation = extrinsics[:, :, :3, :3]  # BxSx3x3
    translation = extrinsics[:, :, :3, 3]  # BxSx3

    quaternion = mat_to_quat(rotation)

    # FoV from focal length: fov = 2 * atan((size / 2) / f).
    # Note that fy pairs with the image height and fx with the width.
    height, width = image_size_hw
    fov_h = 2 * torch.atan((height / 2) / intrinsics[..., 1, 1])
    fov_w = 2 * torch.atan((width / 2) / intrinsics[..., 0, 0])

    return torch.cat(
        [translation, quaternion, fov_h[..., None], fov_w[..., None]], dim=-1
    ).float()
+
+
def pose_encoding_to_extri_intri(
    pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True  # e.g., (256, 512)
):
    """Recover camera extrinsics (and optionally intrinsics) from a pose encoding.

    Inverse of `extri_intri_to_pose_encoding`.

    Args:
        pose_encoding (torch.Tensor): BxSx9 encoding laid out as
            [tx, ty, tz, qx, qy, qz, qw, fov_h, fov_w], where B is the batch
            size and S the sequence length.
        image_size_hw (tuple): (height, width) of the image in pixels, required
            when rebuilding intrinsics from the fields of view. E.g. (256, 512).
        pose_encoding_type (str): Only "absT_quaR_FoV" is supported.
        build_intrinsics (bool): When False, skip the intrinsics reconstruction
            and return None in its place.

    Returns:
        tuple: (extrinsics, intrinsics)
            - extrinsics (torch.Tensor): BxSx3x4 camera-from-world [R|t] matrices
              (OpenCV convention: x-right, y-down, z-forward).
            - intrinsics (torch.Tensor or None): BxSx3x3 pixel-space intrinsics
              with the principal point assumed at the image center (W/2, H/2),
              or None when build_intrinsics is False.
    """
    if pose_encoding_type != "absT_quaR_FoV":
        raise NotImplementedError

    translation = pose_encoding[..., :3]
    quaternion = pose_encoding[..., 3:7]
    fov_h = pose_encoding[..., 7]
    fov_w = pose_encoding[..., 8]

    rotation = quat_to_mat(quaternion)
    extrinsics = torch.cat([rotation, translation[..., None]], dim=-1)

    intrinsics = None
    if build_intrinsics:
        height, width = image_size_hw
        # Focal length from FoV: f = (size / 2) / tan(fov / 2)
        fy = (height / 2.0) / torch.tan(fov_h / 2.0)
        fx = (width / 2.0) / torch.tan(fov_w / 2.0)
        intrinsics = torch.zeros(
            pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device
        )
        intrinsics[..., 0, 0] = fx
        intrinsics[..., 1, 1] = fy
        intrinsics[..., 0, 2] = width / 2
        intrinsics[..., 1, 2] = height / 2
        intrinsics[..., 2, 2] = 1.0  # homogeneous coordinate

    return extrinsics, intrinsics
diff --git a/FastVGGT/vggt/utils/rotation.py b/FastVGGT/vggt/utils/rotation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f972afd8414c82fa1e9ed231725fd3f9f6ebde77
--- /dev/null
+++ b/FastVGGT/vggt/utils/rotation.py
@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Quaternion Order: XYZW or say ijkr, scalar-last

    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    x, y, z, w = torch.unbind(quaternions, dim=-1)
    # 2 / |q|^2 — folds normalization of a non-unit quaternion into the formula
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    # Row-major entries of the standard quaternion-to-rotation formula
    entries = [
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    ]
    flat = torch.stack(entries, dim=-1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
+
+
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last

    Raises:
        ValueError: If the trailing two dimensions of `matrix` are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    # Flatten each 3x3 matrix into its nine entries m00..m22 (row-major)
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)

    # q_abs[..., c] = 2 * |q_c| for q = (r, i, j, k), scalar-first here,
    # derived from the diagonal/trace identities of the rotation matrix
    q_abs = _sqrt_positive_part(
        torch.stack(
            [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))

    # Convert from rijk (scalar-first) to ijkr (scalar-last)
    out = out[..., [1, 2, 3, 0]]

    # Flip sign so the real part is non-negative (canonical form)
    out = standardize_quaternion(out)

    return out
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Since q and -q encode the same rotation, flipping the sign wherever the
    scalar (last) component is negative yields a canonical representative.

    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    negative_real = quaternions[..., 3:4] < 0
    return torch.where(negative_real, -quaternions, quaternions)
diff --git a/FastVGGT/vggt/utils/visual_track.py b/FastVGGT/vggt/utils/visual_track.py
new file mode 100644
index 0000000000000000000000000000000000000000..796c114ccba00b5f7850e04b9444a6cd5c44b154
--- /dev/null
+++ b/FastVGGT/vggt/utils/visual_track.py
@@ -0,0 +1,239 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import cv2
+import torch
+import numpy as np
+import os
+
+
def color_from_xy(x, y, W, H, cmap_name="hsv"):
    """
    Map (x, y) -> color in (R, G, B).
    1) Normalize x,y to [0,1].
    2) Combine them into a single scalar c in [0,1].
    3) Use matplotlib's colormap to convert c -> (R,G,B).

    You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).

    Args:
        x, y: pixel coordinates.
        W, H: image width/height used for normalization.
        cmap_name: matplotlib colormap name (e.g. 'hsv', 'rainbow', 'jet').

    Returns:
        (r, g, b) tuple of floats in [0, 1], RGB order.
    """
    import matplotlib

    # max(..., 1) guards against division by zero for 1-pixel-wide/tall images
    x_norm = x / max(W - 1, 1)
    y_norm = y / max(H - 1, 1)
    # Simple combination:
    c = (x_norm + y_norm) / 2.0

    try:
        # Preferred registry lookup (matplotlib >= 3.5);
        # matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
        cmap = matplotlib.colormaps[cmap_name]
    except AttributeError:
        import matplotlib.cm

        cmap = matplotlib.cm.get_cmap(cmap_name)

    # cmap(c) -> (r,g,b,a) in [0,1]
    rgba = cmap(c)
    r, g, b = rgba[0], rgba[1], rgba[2]
    return (r, g, b)  # in [0,1], RGB order
+
+
def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
    """
    Assign each track an RGB color (uint8, in [0, 255]) based on its (x, y)
    position in the first frame where the track is visible.

    Args:
        tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
        vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
        image_width, image_height: used for normalizing (x, y).
        cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').

    Returns:
        track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
    """
    num_frames, num_tracks, _ = tracks_b.shape
    colors = np.zeros((num_tracks, 3), dtype=np.uint8)

    # Treat every track as visible in every frame when no mask is given
    if vis_mask_b is None:
        vis_mask_b = torch.ones(num_frames, num_tracks, dtype=torch.bool, device=tracks_b.device)

    for track_idx in range(num_tracks):
        # Frames in which this track is visible
        frames_visible = torch.where(vis_mask_b[:, track_idx])[0]
        if len(frames_visible) == 0:
            # Never visible: leave the color black
            continue

        first_frame = int(frames_visible[0].item())
        x, y = tracks_b[first_frame, track_idx].tolist()

        # (x, y) -> (R, G, B) in [0, 1], then scale up to [0, 255]
        r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
        colors[track_idx] = (int(r * 255), int(g * 255), int(b * 255))

    return colors
+
+
def visualize_tracks_on_images(
    images,
    tracks,
    track_vis_mask=None,
    out_dir="track_visuals_concat_by_xy",
    image_format="CHW",  # "CHW" or "HWC"
    normalize_mode="[0,1]",
    cmap_name="hsv",  # e.g. "hsv", "rainbow", "jet"
    frames_per_row=4,  # New parameter for grid layout
    save_grid=True,  # Flag to control whether to save the grid image
):
    """
    Visualizes frames in a grid layout with specified frames per row.
    Each track's color is determined by its (x,y) position
    in the first visible frame (or frame 0 if always visible).
    Finally convert the BGR result to RGB before saving.
    Also saves each individual frame as a separate PNG file.

    Args:
        images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
        tracks: torch.Tensor (S, N, 2), last dim = (x, y).
        track_vis_mask: torch.Tensor (S, N) or None.
        out_dir: folder to save visualizations.
        image_format: "CHW" or "HWC".
        normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
        cmap_name: a matplotlib colormap name for color_from_xy.
        frames_per_row: number of frames to display in each row of the grid.
        save_grid: whether to save all frames in one grid image.

    Returns:
        None (saves images in out_dir).
    """

    # Accept an optional leading batch dimension of size 1 and drop it
    if len(tracks.shape) == 4:
        tracks = tracks.squeeze(0)
        images = images.squeeze(0)
        if track_vis_mask is not None:
            track_vis_mask = track_vis_mask.squeeze(0)

    import matplotlib

    matplotlib.use("Agg")  # for non-interactive (optional)

    os.makedirs(out_dir, exist_ok=True)

    S = images.shape[0]
    _, N, _ = tracks.shape  # (S, N, 2)

    # Move to CPU
    images = images.cpu().clone()
    tracks = tracks.cpu().clone()
    if track_vis_mask is not None:
        track_vis_mask = track_vis_mask.cpu().clone()

    # Infer H, W from images shape
    if image_format == "CHW":
        # e.g. images[s].shape = (3, H, W)
        H, W = images.shape[2], images.shape[3]
    else:
        # e.g. images[s].shape = (H, W, 3)
        H, W = images.shape[1], images.shape[2]

    # Pre-compute the color for each track i based on first visible position
    track_colors_rgb = get_track_colors_by_position(
        tracks,  # shape (S, N, 2)
        vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
        image_width=W,
        image_height=H,
        cmap_name=cmap_name,
    )

    # We'll accumulate each frame's drawn image in a list
    frame_images = []

    for s in range(S):
        # shape => either (3, H, W) or (H, W, 3)
        img = images[s]

        # Convert to (H, W, 3)
        if image_format == "CHW":
            img = img.permute(1, 2, 0)  # (H, W, 3)
        # else "HWC", do nothing

        img = img.numpy().astype(np.float32)

        # Scale to [0,255] if needed
        if normalize_mode == "[0,1]":
            img = np.clip(img, 0, 1) * 255.0
        elif normalize_mode == "[-1,1]":
            img = (img + 1.0) * 0.5 * 255.0
            img = np.clip(img, 0, 255.0)
        # else no normalization

        # Convert to uint8
        img = img.astype(np.uint8)

        # For drawing in OpenCV, convert to BGR
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # Draw each visible track
        cur_tracks = tracks[s]  # shape (N, 2)
        if track_vis_mask is not None:
            valid_indices = torch.where(track_vis_mask[s])[0]
        else:
            valid_indices = range(N)

        cur_tracks_np = cur_tracks.numpy()
        for i in valid_indices:
            x, y = cur_tracks_np[i]
            pt = (int(round(x)), int(round(y)))

            # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
            R, G, B = track_colors_rgb[i]
            color_bgr = (int(B), int(G), int(R))
            cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)

        # Convert back to RGB for consistent final saving:
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        # Save individual frame
        frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
        # Convert to BGR for OpenCV imwrite
        frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        cv2.imwrite(frame_path, frame_bgr)

        frame_images.append(img_rgb)

    # Only create and save the grid image if save_grid is True
    if save_grid:
        # Calculate grid dimensions
        num_rows = (S + frames_per_row - 1) // frames_per_row  # Ceiling division

        # Create a grid of images, one row of frames at a time
        grid_img = None
        for row in range(num_rows):
            start_idx = row * frames_per_row
            end_idx = min(start_idx + frames_per_row, S)

            # Concatenate this row horizontally
            row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)

            # If this row has fewer than frames_per_row images, pad with black
            if end_idx - start_idx < frames_per_row:
                padding_width = (frames_per_row - (end_idx - start_idx)) * W
                padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
                row_img = np.concatenate([row_img, padding], axis=1)

            # Add this row to the grid
            if grid_img is None:
                grid_img = row_img
            else:
                grid_img = np.concatenate([grid_img, row_img], axis=0)

        out_path = os.path.join(out_dir, "tracks_grid.png")
        # Convert back to BGR for OpenCV imwrite
        grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(out_path, grid_img_bgr)
        print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")

    print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")