diff --git a/FastVGGT/.gitignore b/FastVGGT/.gitignore
deleted file mode 100644
index fc1e52d1adf9afe09ea2ead132f8fa4babf666cc..0000000000000000000000000000000000000000
--- a/FastVGGT/.gitignore
+++ /dev/null
@@ -1,160 +0,0 @@
-.hydra/
-output/
-ckpt/
-.vscode/
-dependency/
-# Byte-compiled / optimized / DLL files
-__pycache__/
-**/__pycache__/
-*.py[cod]
-*$py.class
-test_logs/
-quick_start_logs/
-logs/
-*.pth
-/data/
-*.png
-eval_results/
-.vscode/
-.curosr/
-
-# C extensions
-*.so
-LightGlue/
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-pip-wheel-metadata/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-.python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Profiling data
-.prof
-
-# Folder specific to your needs
-**/tmp/
-**/outputs/skyseg.onnx
-skyseg.onnx
-
-# pixi environments
-.pixi
-*.egg-info
diff --git a/FastVGGT/.vscode/launch.json b/FastVGGT/.vscode/launch.json
deleted file mode 100644
index 218c5dcf94c00a2c11d02a0a1deeab8f37807ebf..0000000000000000000000000000000000000000
--- a/FastVGGT/.vscode/launch.json
+++ /dev/null
@@ -1,85 +0,0 @@
-{
- // Use IntelliSense to learn about possible attributes.
- // Hover to view descriptions of existing attributes.
- // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
- "version": "0.2.0",
- "configurations": [
-
- {
- "name": "launch",
- "type": "debugpy",
- "request": "launch",
- "program": "/home/sy/code/vggt_0625/training/launch.py",
- "console": "integratedTerminal",
- "args": "${command:pickArgs}",
- "env": {
- "CUDA_VISIBLE_DEVICES": "3",
- },
- "cwd": "/home/sy/code/vggt_0625/training",
- "justMyCode": true,
- "python": "/home/sy/anaconda3/envs/vggt/bin/python"
- }
- ,{
- "name": "train_scannet",
- "type": "debugpy",
- "request": "launch",
- "program": "/home/sy/code/vggt_0625/training/launch_scannet.py",
- "console": "integratedTerminal",
- "args": [
- // "--config_name", "scannet",
- // "--exp_name", "scannet_exp001",
- // "--resume_checkpoint_path", "/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt"
- ],
- "env": {
- "CUDA_VISIBLE_DEVICES": "7",
- "WORLD_SIZE": "1",
- "RANK": "0",
- "MASTER_ADDR": "localhost",
- "MASTER_PORT": "12345"
- },
- "cwd": "/home/sy/code/vggt_0625/training",
- "justMyCode": true,
- "python": "/home/sy/anaconda3/envs/vggt/bin/python"
- }
- ,{
- "name": "eval_scannet",
- "type": "debugpy",
- "request": "launch",
- "program": "/home/sy/code/FastVGGT/eval/eval_scannet.py",
- "console": "integratedTerminal",
- "args": [
- "--data_dir","/data/sy/scannetv2/process_scannet",
- "--gt_ply_dir","/data/sy/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
- "--output_path", "/home/sy/code/FastVGGT/eval_results",
- "--merging", "0",
- "--ckpt_path","/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt",
- "--vis_attn_map"
- ],
- "env": {
- "CUDA_VISIBLE_DEVICES": "2"
- },
- "justMyCode": true,
- "python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
- },
- {
- "name": "eval_cd",
- "type": "debugpy",
- "request": "launch",
- "program": "/home/sy/code/FastVGGT/eval/eval_custom.py",
- "console": "integratedTerminal",
- "args": [
- "--merging", "0",
- // "--kf","10",
- // "--output_dir","/home/sy/code/vggt_0625/eval_results_cd",
- "--data_path","/data/sy/segment-102751/",
- "--vis_attn_map"
- ],
- "env": {
- "CUDA_VISIBLE_DEVICES": "3"
- },
- "justMyCode": true,
- // "python": "/home/sy/anaconda3/envs/fastvggt/bin/python"
- }
-
- ]
-}
\ No newline at end of file
diff --git a/FastVGGT/README.md b/FastVGGT/README.md
deleted file mode 100644
index 274ea91f0e398fdf0e94b921fb5bb7c48fe799fd..0000000000000000000000000000000000000000
--- a/FastVGGT/README.md
+++ /dev/null
@@ -1,163 +0,0 @@
-
-
⚡️ FastVGGT: Training-Free Acceleration of Visual Geometry Transformer
-
-
-
-
-
-
-

-

-
-
-**[Media Analytics & Computing Laboratory](https://mac.xmu.edu.cn/)**; **[AUTOLAB](https://zhipengzhang.cn/)**
-
-
-[You Shen](https://mystorm16.github.io/), [Zhipeng Zhang](https://zhipengzhang.cn/), [Yansong Qu](https://quyans.github.io/), [Liujuan Cao](https://mac.xmu.edu.cn/ljcao/)
-
-
-
-## 📰 News
-- [Sep 8, 2025] Added custom dataset evaluation.
-- [Sep 3, 2025] Paper release.
-- [Sep 2, 2025] Code release.
-
-## 🔭 Overview
-
-FastVGGT observes **strong similarity** in attention maps and leverages it to design a training-free acceleration method for long-sequence 3D reconstruction, **achieving up to 4× faster inference without sacrificing accuracy.**
-
-
-
-
-## ⚙️ Environment Setup
-First, create a virtual environment using Conda, clone this repository to your local machine, and install the required dependencies.
-
-
-```bash
-conda create -n fastvggt python=3.10
-conda activate fastvggt
-git clone git@github.com:mystorm16/FastVGGT.git
-cd FastVGGT
-pip install -r requirements.txt
-```
-
-Next, prepare the ScanNet dataset: http://www.scan-net.org/ScanNet/
-
-Then, download the VGGT checkpoint (we use the checkpoint link provided in https://github.com/facebookresearch/vggt/tree/evaluation/evaluation):
-```bash
-wget https://huggingface.co/facebook/VGGT_tracker_fixed/resolve/main/model_tracker_fixed_e20.pt
-```
-
-Finally, configure the dataset path and VGGT checkpoint path. For example:
-```bash
- parser.add_argument(
- "--data_dir", type=Path, default="/data/scannetv2/process_scannet"
- )
- parser.add_argument(
- "--gt_ply_dir",
- type=Path,
- default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
- )
- parser.add_argument(
- "--ckpt_path",
- type=str,
- default="./ckpt/model_tracker_fixed_e20.pt",
- )
-```
-
-
-## 💎 Observation
-
-Note: A large number of input_frames may significantly slow down saving the visualization results. Please try using a smaller number first.
-```bash
-python eval/eval_scannet.py --input_frame 30 --vis_attn_map --merging 0
-```
-
-We observe that many token-level attention maps are highly similar in each block, motivating our optimization of the Global Attention module.
-
-
-
-
-
-## 🏀 Evaluation
-### Custom Dataset
-Please organize the data according to the following directory:
-```
-/
-├── images/
-│ ├── 000000.jpg
-│ ├── 000001.jpg
-│ └── ...
-├── pose/ # Optional: Camera poses
-│ ├── 000000.txt
-│ ├── 000001.txt
-│ └── ...
-└── gt_ply/ # Optional: GT point cloud
- └── scene_xxx.ply
-```
-- Required: `images/`
-- Additionally required when `--enable_evaluation` is enabled: `pose/` and `gt_ply/`
-
-Inference only:
-
-```bash
-python eval/eval_custom.py \
- --data_path /path/to/your_dataset \
- --output_path ./eval_results_custom \
- --plot
-```
-
-Inference + Evaluation (requires `pose/` and `gt_ply/`):
-
-```bash
-python eval/eval_custom.py \
- --data_path /path/to/your_dataset \
- --enable_evaluation \
- --output_path ./eval_results_custom \
- --plot
-```
-
-### ScanNet
-Evaluate FastVGGT on the ScanNet dataset with 1,000 input images. The **--merging** parameter specifies the block index at which the merging strategy is applied:
-
-```bash
-python eval/eval_scannet.py --input_frame 1000 --merging 0
-```
-
-Evaluate Baseline VGGT on the ScanNet dataset with 1,000 input images:
-```bash
-python eval/eval_scannet.py --input_frame 1000
-```
-
-
-### 7 Scenes & NRGBD
-Evaluate across two datasets, sampling keyframes every 10 frames:
-```bash
-python eval/eval_7andN.py --kf 10
-```
-
-## 🍺 Acknowledgements
-
-- Thanks to these great repositories: [VGGT](https://github.com/facebookresearch/vggt), [Dust3r](https://github.com/naver/dust3r), [Fast3R](https://github.com/facebookresearch/fast3r), [CUT3R](https://github.com/CUT3R/CUT3R), [MV-DUSt3R+](https://github.com/facebookresearch/mvdust3r), [StreamVGGT](https://github.com/wzzheng/StreamVGGT), [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long), [ToMeSD](https://github.com/dbolya/tomesd) and many other inspiring works in the community.
-
-- Special thanks to [Jianyuan Wang](https://jytime.github.io/) for his valuable discussions and suggestions on this work.
-
-
-
-
-## ⚖️ License
-See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available.
-
-## Citation
-
-If you find this project helpful, please consider citing the following paper:
-```
-@article{shen2025fastvggt,
- title={FastVGGT: Training-Free Acceleration of Visual Geometry Transformer},
- author={Shen, You and Zhang, Zhipeng and Qu, Yansong and Cao, Liujuan},
- journal={arXiv preprint arXiv:2509.02560},
- year={2025}
-}
-```
\ No newline at end of file
diff --git a/FastVGGT/assets/attn_map.png b/FastVGGT/assets/attn_map.png
deleted file mode 100644
index 98d6f2be8cdb49e5bccee7f902af292327fd0023..0000000000000000000000000000000000000000
--- a/FastVGGT/assets/attn_map.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8477957f593c203bcf41df91ac3ed0d22329e22250fdab9f8f8674340964242c
-size 2369408
diff --git a/FastVGGT/assets/autolab_logo.png b/FastVGGT/assets/autolab_logo.png
deleted file mode 100644
index c6ca9a41f5c05a1b2ae66368461672f46e20bba1..0000000000000000000000000000000000000000
--- a/FastVGGT/assets/autolab_logo.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4fcead3160cbf561c4385cc8b938a17a94652e3d849da6497f053d32d1245596
-size 5126003
diff --git a/FastVGGT/assets/maclab_logo.png b/FastVGGT/assets/maclab_logo.png
deleted file mode 100644
index b96660decf945d3d786f41b9b42cb92798f9191f..0000000000000000000000000000000000000000
Binary files a/FastVGGT/assets/maclab_logo.png and /dev/null differ
diff --git a/FastVGGT/assets/main.png b/FastVGGT/assets/main.png
deleted file mode 100644
index 3bcab2d519689dfba3283e606cd55cb84175ede5..0000000000000000000000000000000000000000
--- a/FastVGGT/assets/main.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:eecacb414647f01dc8a52b4aba5ff2556733f46d1b9129613e3f59aceff69685
-size 883717
diff --git a/FastVGGT/assets/vs.png b/FastVGGT/assets/vs.png
deleted file mode 100644
index 3729d9ba86000b81d0e9c878b2db6f16c7f9c48a..0000000000000000000000000000000000000000
--- a/FastVGGT/assets/vs.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dac1ea5397b9f985c6fb86fc5f321313cc5a372d0917b7c1c1c7dd2cea6bec5f
-size 230157
diff --git a/FastVGGT/eval/__pycache__/base.cpython-310.pyc b/FastVGGT/eval/__pycache__/base.cpython-310.pyc
deleted file mode 100644
index 49170288c162684ebafa68e5242efd320f0f35b8..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/__pycache__/base.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc b/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc
deleted file mode 100644
index 6e6e0478775fe98dc22180bc43bad8073c43c907..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/eval/__pycache__/data.cpython-310.pyc b/FastVGGT/eval/__pycache__/data.cpython-310.pyc
deleted file mode 100644
index 4d8ec6f8e1828fa87b02a4511490ad858de2093a..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/__pycache__/data.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/eval/__pycache__/data.cpython-37.pyc b/FastVGGT/eval/__pycache__/data.cpython-37.pyc
deleted file mode 100644
index d76f2c95748956a04f009e99f39cefd9cc6ec4e0..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/__pycache__/data.cpython-37.pyc and /dev/null differ
diff --git a/FastVGGT/eval/__pycache__/utils.cpython-310.pyc b/FastVGGT/eval/__pycache__/utils.cpython-310.pyc
deleted file mode 100644
index 638db49f10789e53dcf0ff33c9aaa28b814e3e72..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/__pycache__/utils.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/eval/__pycache__/utils.cpython-37.pyc b/FastVGGT/eval/__pycache__/utils.cpython-37.pyc
deleted file mode 100644
index fd73966947cb74c319d66d7bf221d73b5dde0ce0..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/__pycache__/utils.cpython-37.pyc and /dev/null differ
diff --git a/FastVGGT/eval/base.py b/FastVGGT/eval/base.py
deleted file mode 100644
index 4a716449a71f552ea408df0eb37e854cf4e92da6..0000000000000000000000000000000000000000
--- a/FastVGGT/eval/base.py
+++ /dev/null
@@ -1,273 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# base class for implementing datasets
-# --------------------------------------------------------
-import PIL
-import numpy as np
-import torch
-
-from dataset_utils.transforms import ImgNorm
-import dataset_utils.cropping as cropping
-from utils import depthmap_to_absolute_camera_coordinates
-
-
-class BaseStereoViewDataset:
- """Define all basic options.
-
- Usage:
- class MyDataset (BaseStereoViewDataset):
- def _get_views(self, idx, rng):
- # overload here
- views = []
- views.append(dict(img=, ...))
- return views
- """
-
- def __init__(
- self,
- *, # only keyword arguments
- split=None,
- resolution=None, # square_size or (width, height) or list of [(width,height), ...]
- transform=ImgNorm,
- aug_crop=False,
- seed=None,
- ):
- self.num_views = 2
- self.split = split
- self._set_resolutions(resolution)
-
- self.transform = transform
- if isinstance(transform, str):
- transform = eval(transform)
-
- self.aug_crop = aug_crop
- self.seed = seed
-
- def __len__(self):
- return len(self.scenes)
-
- def get_stats(self):
- return f"{len(self)} pairs"
-
- def __repr__(self):
- resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
- return (
- f"""{type(self).__name__}({self.get_stats()},
- {self.split=},
- {self.seed=},
- resolutions={resolutions_str},
- {self.transform=})""".replace(
- "self.", ""
- )
- .replace("\n", "")
- .replace(" ", "")
- )
-
- def _get_views(self, idx, resolution, rng):
- raise NotImplementedError()
-
- def __getitem__(self, idx):
- if isinstance(idx, tuple):
- # the idx is specifying the aspect-ratio
- idx, ar_idx = idx
- else:
- assert len(self._resolutions) == 1
- ar_idx = 0
-
- # set-up the rng
- if self.seed: # reseed for each __getitem__
- self._rng = np.random.default_rng(seed=self.seed + idx)
- elif not hasattr(self, "_rng"):
- seed = torch.initial_seed() # this is different for each dataloader process
- self._rng = np.random.default_rng(seed=seed)
-
- # over-loaded code
- resolution = self._resolutions[
- ar_idx
- ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
- views = self._get_views(idx, resolution, self._rng)
-
- # check data-types
- for v, view in enumerate(views):
- assert (
- "pts3d" not in view
- ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
- view["idx"] = v
-
- # encode the image
- width, height = view["img"].size
- view["true_shape"] = np.int32((height, width))
- view["img"] = self.transform(view["img"])
-
- assert "camera_intrinsics" in view
- if "camera_pose" not in view:
- view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
- else:
- assert np.isfinite(
- view["camera_pose"]
- ).all(), f"NaN in camera pose for view {view_name(view)}"
- assert "pts3d" not in view
- assert "valid_mask" not in view
- assert np.isfinite(
- view["depthmap"]
- ).all(), f"NaN in depthmap for view {view_name(view)}"
- pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
-
- view["pts3d"] = pts3d
- view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
-
- # check all datatypes
- for key, val in view.items():
- res, err_msg = is_good_type(key, val)
- assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
- K = view["camera_intrinsics"]
- view["img_mask"] = True
- view["ray_mask"] = False
- view["ray_map"] = torch.full(
- (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
- )
- view["update"] = True
- view["reset"] = False
-
- # last thing done!
- for view in views:
- # transpose to make sure all views are the same size
- transpose_to_landscape(view)
- # this allows to check whether the RNG is is the same state each time
- view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
- return views
-
- def _set_resolutions(self, resolutions):
- """Set the resolution(s) of the dataset.
- Params:
- - resolutions: int or tuple or list of tuples
- """
- assert resolutions is not None, "undefined resolution"
-
- if not isinstance(resolutions, list):
- resolutions = [resolutions]
-
- self._resolutions = []
- for resolution in resolutions:
- if isinstance(resolution, int):
- width = height = resolution
- else:
- width, height = resolution
- assert isinstance(
- width, int
- ), f"Bad type for {width=} {type(width)=}, should be int"
- assert isinstance(
- height, int
- ), f"Bad type for {height=} {type(height)=}, should be int"
- assert width >= height
- self._resolutions.append((width, height))
-
- def _crop_resize_if_necessary(
- self, image, depthmap, intrinsics, resolution, rng=None, info=None
- ):
- """This function:
- - first downsizes the image with LANCZOS inteprolation,
- which is better than bilinear interpolation in
- """
- if not isinstance(image, PIL.Image.Image):
- image = PIL.Image.fromarray(image)
-
- # downscale with lanczos interpolation so that image.size == resolution
- # cropping centered on the principal point
- W, H = image.size
- cx, cy = intrinsics[:2, 2].round().astype(int)
-
- # calculate min distance to margin
- min_margin_x = min(cx, W - cx)
- min_margin_y = min(cy, H - cy)
- assert min_margin_x > W / 5, f"Bad principal point in view={info}"
- assert min_margin_y > H / 5, f"Bad principal point in view={info}"
-
- ## Center crop
- # Crop on the principal point, make it always centered
- # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
- l, t = cx - min_margin_x, cy - min_margin_y
- r, b = cx + min_margin_x, cy + min_margin_y
- crop_bbox = (l, t, r, b)
-
- image, depthmap, intrinsics = cropping.crop_image_depthmap(
- image, depthmap, intrinsics, crop_bbox
- )
-
- # # transpose the resolution if necessary
- W, H = image.size # new size
- assert resolution[0] >= resolution[1]
- if H > 1.1 * W:
- # image is portrait mode
- resolution = resolution[::-1]
- elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
- # image is square, so we chose (portrait, landscape) randomly
- if rng.integers(2):
- resolution = resolution[::-1]
-
- # high-quality Lanczos down-scaling
- target_resolution = np.array(resolution)
- # # if self.aug_crop > 1:
- # # target_resolution += rng.integers(0, self.aug_crop)
- # if resolution != (224, 224):
- # halfw, halfh = ((2*(W//2))//16)*8, ((2*(H//2))//16)*8
- # ## Recale with max factor, so one of width or height might be larger than target_resolution
- # image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, (2*halfw, 2*halfh))
- # else:
- image, depthmap, intrinsics = cropping.rescale_image_depthmap(
- image, depthmap, intrinsics, target_resolution
- )
- # actual cropping (if necessary) with bilinear interpolation
- # if resolution == (224, 224):
- intrinsics2 = cropping.camera_matrix_of_crop(
- intrinsics, image.size, resolution, offset_factor=0.5
- )
- crop_bbox = cropping.bbox_from_intrinsics_in_out(
- intrinsics, intrinsics2, resolution
- )
- image, depthmap, intrinsics = cropping.crop_image_depthmap(
- image, depthmap, intrinsics, crop_bbox
- )
- return image, depthmap, intrinsics
-
-
-def is_good_type(key, v):
- """returns (is_good, err_msg)"""
- if isinstance(v, (str, int, tuple)):
- return True, None
- if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
- return False, f"bad {v.dtype=}"
- return True, None
-
-
-def view_name(view, batch_index=None):
- def sel(x):
- return x[batch_index] if batch_index not in (None, slice(None)) else x
-
- db = sel(view["dataset"])
- label = sel(view["label"])
- instance = sel(view["instance"])
- return f"{db}/{label}/{instance}"
-
-
-def transpose_to_landscape(view):
- height, width = view["true_shape"]
-
- if width < height:
- # rectify portrait to landscape
- assert view["img"].shape == (3, height, width)
- view["img"] = view["img"].swapaxes(1, 2)
-
- assert view["valid_mask"].shape == (height, width)
- view["valid_mask"] = view["valid_mask"].swapaxes(0, 1)
-
- assert view["depthmap"].shape == (height, width)
- view["depthmap"] = view["depthmap"].swapaxes(0, 1)
-
- assert view["pts3d"].shape == (height, width, 3)
- view["pts3d"] = view["pts3d"].swapaxes(0, 1)
-
- # transpose x and y pixels
- view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]
diff --git a/FastVGGT/eval/criterion.py b/FastVGGT/eval/criterion.py
deleted file mode 100644
index a63c991b5fd5b61325b56055d7d380a27a336971..0000000000000000000000000000000000000000
--- a/FastVGGT/eval/criterion.py
+++ /dev/null
@@ -1,534 +0,0 @@
-import torch
-import torch.nn as nn
-from copy import copy, deepcopy
-
-from eval.dataset_utils.corr import geotrf, inv
-
-
-def invalid_to_nans(arr, valid_mask, ndim=999):
- if valid_mask is not None:
- arr = arr.clone()
- arr[~valid_mask] = float("nan")
- if arr.ndim > ndim:
- arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
- return arr
-
-
-def invalid_to_zeros(arr, valid_mask, ndim=999):
- if valid_mask is not None:
- arr = arr.clone()
- arr[~valid_mask] = 0
- nnz = valid_mask.view(len(valid_mask), -1).sum(1)
- else:
- nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
- if arr.ndim > ndim:
- arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
- return arr, nnz
-
-
-class BaseCriterion(nn.Module):
- def __init__(self, reduction="mean"):
- super().__init__()
- self.reduction = reduction
-
-
-class Criterion(nn.Module):
- def __init__(self, criterion=None):
- super().__init__()
- assert isinstance(
- criterion, BaseCriterion
- ), f"{criterion} is not a proper criterion!"
- self.criterion = copy(criterion)
-
- def get_name(self):
- return f"{type(self).__name__}({self.criterion})"
-
- def with_reduction(self, mode="none"):
- res = loss = deepcopy(self)
- while loss is not None:
- assert isinstance(loss, Criterion)
- loss.criterion.reduction = mode # make it return the loss for each sample
- loss = loss._loss2 # we assume loss is a Multiloss
- return res
-
-
-class MultiLoss(nn.Module):
- """Easily combinable losses (also keep track of individual loss values):
- loss = MyLoss1() + 0.1*MyLoss2()
- Usage:
- Inherit from this class and override get_name() and compute_loss()
- """
-
- def __init__(self):
- super().__init__()
- self._alpha = 1
- self._loss2 = None
-
- def compute_loss(self, *args, **kwargs):
- raise NotImplementedError()
-
- def get_name(self):
- raise NotImplementedError()
-
- def __mul__(self, alpha):
- assert isinstance(alpha, (int, float))
- res = copy(self)
- res._alpha = alpha
- return res
-
- __rmul__ = __mul__ # same
-
- def __add__(self, loss2):
- assert isinstance(loss2, MultiLoss)
- res = cur = copy(self)
-
- while cur._loss2 is not None:
- cur = cur._loss2
- cur._loss2 = loss2
- return res
-
- def __repr__(self):
- name = self.get_name()
- if self._alpha != 1:
- name = f"{self._alpha:g}*{name}"
- if self._loss2:
- name = f"{name} + {self._loss2}"
- return name
-
- def forward(self, *args, **kwargs):
- loss = self.compute_loss(*args, **kwargs)
- if isinstance(loss, tuple):
- loss, details = loss
- elif loss.ndim == 0:
- details = {self.get_name(): float(loss)}
- else:
- details = {}
- loss = loss * self._alpha
-
- if self._loss2:
- loss2, details2 = self._loss2(*args, **kwargs)
- loss = loss + loss2
- details |= details2
-
- return loss, details
-
-
-class LLoss(BaseCriterion):
- """L-norm loss"""
-
- def forward(self, a, b):
- assert (
- a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3
- ), f"Bad shape = {a.shape}"
- dist = self.distance(a, b)
-
- if self.reduction == "none":
- return dist
- if self.reduction == "sum":
- return dist.sum()
- if self.reduction == "mean":
- return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
- raise ValueError(f"bad {self.reduction=} mode")
-
- def distance(self, a, b):
- raise NotImplementedError()
-
-
-class L21Loss(LLoss):
- """Euclidean distance between 3d points"""
-
- def distance(self, a, b):
- return torch.norm(a - b, dim=-1) # normalized L2 distance
-
-
-L21 = L21Loss()
-
-
-def get_pred_pts3d(gt, pred, use_pose=False):
- assert use_pose is True
- return pred["pts3d_in_other_view"] # return!
-
-
-def Sum(losses, masks, conf=None):
- loss, mask = losses[0], masks[0]
- if loss.ndim > 0:
- # we are actually returning the loss for every pixels
- if conf is not None:
- return losses, masks, conf
- return losses, masks
- else:
- # we are returning the global loss
- for loss2 in losses[1:]:
- loss = loss + loss2
- return loss
-
-
-def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True):
- assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3
- assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3)
- norm_mode, dis_mode = norm_mode.split("_")
-
- nan_pts = []
- nnzs = []
-
- if norm_mode == "avg":
- # gather all points together (joint normalization)
-
- for i, pt in enumerate(pts):
- nan_pt, nnz = invalid_to_zeros(pt, valids[i], ndim=3)
- nan_pts.append(nan_pt)
- nnzs.append(nnz)
-
- if fix_first:
- break
- all_pts = torch.cat(nan_pts, dim=1)
-
- # compute distance to origin
- all_dis = all_pts.norm(dim=-1)
- if dis_mode == "dis":
- pass # do nothing
- elif dis_mode == "log1p":
- all_dis = torch.log1p(all_dis)
- else:
- raise ValueError(f"bad {dis_mode=}")
-
- norm_factor = all_dis.sum(dim=1) / (torch.cat(nnzs).sum() + 1e-8)
- else:
- raise ValueError(f"Not implemented {norm_mode=}")
-
- norm_factor = norm_factor.clip(min=1e-8)
- while norm_factor.ndim < pts[0].ndim:
- norm_factor.unsqueeze_(-1)
-
- return norm_factor
-
-
-def normalize_pointcloud_t(
- pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False
-):
- if gt:
- norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
- res = []
-
- for i, pt in enumerate(pts):
- res.append(pt / norm_factor)
-
- else:
- # pts_l, pts_r = pts
- # use pts_l and pts_r[-1] as pts to normalize
- norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first)
-
- res = []
-
- for i in range(len(pts)):
- res.append(pts[i] / norm_factor)
- # res_r.append(pts_r[i] / norm_factor)
-
- # res = [res_l, res_r]
-
- return res, norm_factor
-
-
-@torch.no_grad()
-def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5):
- # set invalid points to NaN
- _zs = []
- for i in range(len(zs)):
- valid_mask = valid_masks[i] if valid_masks is not None else None
- _z = invalid_to_nans(zs[i], valid_mask).reshape(len(zs[i]), -1)
- _zs.append(_z)
-
- _zs = torch.cat(_zs, dim=-1)
-
- # compute median depth overall (ignoring nans)
- if quantile == 0.5:
- shift_z = torch.nanmedian(_zs, dim=-1).values
- else:
- shift_z = torch.nanquantile(_zs, quantile, dim=-1)
- return shift_z # (B,)
-
-
-@torch.no_grad()
-def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True):
- # set invalid points to NaN
-
- _pts = []
- for i in range(len(pts)):
- valid_mask = valid_masks[i] if valid_masks is not None else None
- _pt = invalid_to_nans(pts[i], valid_mask).reshape(len(pts[i]), -1, 3)
- _pts.append(_pt)
-
- _pts = torch.cat(_pts, dim=1)
-
- # compute median center
- _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
- if z_only:
- _center[..., :2] = 0 # do not center X and Y
-
- # compute median norm
- _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
- scale = torch.nanmedian(_norm, dim=1).values
- return _center[:, None, :, :], scale[:, None, None, None]
-
-
-class Regr3D_t(Criterion, MultiLoss):
- def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True):
- super().__init__(criterion)
- self.norm_mode = norm_mode
- self.gt_scale = gt_scale
- self.fix_first = fix_first
-
- def get_all_pts3d_t(self, gts, preds, dist_clip=None):
- # everything is normalized w.r.t. camera of view1
- in_camera1 = inv(gts[0]["camera_pose"])
-
- gt_pts = []
- valids = []
- pr_pts = []
-
- for i, gt in enumerate(gts):
- # in_camera1: Bs, 4, 4 gt['pts3d']: Bs, H, W, 3
- gt_pts.append(geotrf(in_camera1, gt["pts3d"]))
- valid = gt["valid_mask"].clone()
-
- if dist_clip is not None:
- # points that are too far-away == invalid
- dis = gt["pts3d"].norm(dim=-1)
- valid = valid & (dis <= dist_clip)
-
- valids.append(valid)
- pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True))
- # if i != len(gts)-1:
- # pr_pts_l.append(get_pred_pts3d(gt, preds[i][0], use_pose=(i!=0)))
-
- # if i != 0:
- # pr_pts_r.append(get_pred_pts3d(gt, preds[i-1][1], use_pose=(i!=0)))
-
- # pr_pts = (pr_pts_l, pr_pts_r)
-
- if self.norm_mode:
- pr_pts, pr_factor = normalize_pointcloud_t(
- pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False
- )
- else:
- pr_factor = None
-
- if self.norm_mode and not self.gt_scale:
- gt_pts, gt_factor = normalize_pointcloud_t(
- gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True
- )
- else:
- gt_factor = None
-
- return gt_pts, pr_pts, gt_factor, pr_factor, valids, {}
-
- def compute_frame_loss(self, gts, preds, **kw):
- gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
- self.get_all_pts3d_t(gts, preds, **kw)
- )
-
- pred_pts_l, pred_pts_r = pred_pts
-
- loss_all = []
- mask_all = []
- conf_all = []
-
- loss_left = 0
- loss_right = 0
- pred_conf_l = 0
- pred_conf_r = 0
-
- for i in range(len(gt_pts)):
-
- # Left (Reference)
- if i != len(gt_pts) - 1:
- frame_loss = self.criterion(
- pred_pts_l[i][masks[i]], gt_pts[i][masks[i]]
- )
-
- loss_all.append(frame_loss)
- mask_all.append(masks[i])
- conf_all.append(preds[i][0]["conf"])
-
- # To compare target/reference loss
- if i != 0:
- loss_left += frame_loss.cpu().detach().numpy().mean()
- pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy().mean()
-
- # Right (Target)
- if i != 0:
- frame_loss = self.criterion(
- pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]]
- )
-
- loss_all.append(frame_loss)
- mask_all.append(masks[i])
- conf_all.append(preds[i - 1][1]["conf"])
-
- # To compare target/reference loss
- if i != len(gt_pts) - 1:
- loss_right += frame_loss.cpu().detach().numpy().mean()
- pred_conf_r += preds[i - 1][1]["conf"].cpu().detach().numpy().mean()
-
- if pr_factor is not None and gt_factor is not None:
- filter_factor = pr_factor[pr_factor > gt_factor]
- else:
- filter_factor = []
-
- if len(filter_factor) > 0:
- factor_loss = (filter_factor - gt_factor).abs().mean()
- else:
- factor_loss = 0.0
-
- self_name = type(self).__name__
- details = {
- self_name + "_pts3d_1": float(loss_all[0].mean()),
- self_name + "_pts3d_2": float(loss_all[1].mean()),
- self_name + "loss_left": float(loss_left),
- self_name + "loss_right": float(loss_right),
- self_name + "conf_left": float(pred_conf_l),
- self_name + "conf_right": float(pred_conf_r),
- }
-
- return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss
-
-
-class ConfLoss_t(MultiLoss):
- """Weighted regression by learned confidence.
- Assuming the input pixel_loss is a pixel-level regression loss.
-
- Principle:
- high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
- low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10)
-
- alpha: hyperparameter
- """
-
- def __init__(self, pixel_loss, alpha=1):
- super().__init__()
- assert alpha > 0
- self.alpha = alpha
- self.pixel_loss = pixel_loss.with_reduction("none")
-
- def get_name(self):
- return f"ConfLoss({self.pixel_loss})"
-
- def get_conf_log(self, x):
- return x, torch.log(x)
-
- def compute_frame_loss(self, gts, preds, **kw):
- # compute per-pixel loss
- (losses, masks, confs), details, loss_factor = (
- self.pixel_loss.compute_frame_loss(gts, preds, **kw)
- )
-
- # weight by confidence
- conf_losses = []
- conf_sum = 0
- for i in range(len(losses)):
- conf, log_conf = self.get_conf_log(confs[i][masks[i]])
- conf_sum += conf.mean()
- conf_loss = losses[i] * conf - self.alpha * log_conf
- conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0
- conf_losses.append(conf_loss)
-
- conf_losses = torch.stack(conf_losses) * 2.0
- conf_loss_mean = conf_losses.mean()
-
- return (
- conf_loss_mean,
- dict(
- conf_loss_1=float(conf_losses[0]),
- conf_loss2=float(conf_losses[1]),
- conf_mean=conf_sum / len(losses),
- **details,
- ),
- loss_factor,
- )
-
-
-class Regr3D_t_ShiftInv(Regr3D_t):
- """Same than Regr3D but invariant to depth shift."""
-
- def get_all_pts3d_t(self, gts, preds):
- # compute unnormalized points
- gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
- super().get_all_pts3d_t(gts, preds)
- )
-
- # pred_pts_l, pred_pts_r = pred_pts
- gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts]
-
- pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts]
- # pred_zs.append(pred_pts_r[-1][..., 2])
-
- # compute median depth
- gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None]
- pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, None]
-
- # subtract the median depth
- for i in range(len(gt_pts)):
- gt_pts[i][..., 2] -= gt_shift_z
-
- for i in range(len(pred_pts)):
- # for j in range(len(pred_pts[i])):
- pred_pts[i][..., 2] -= pred_shift_z
-
- monitoring = dict(
- monitoring,
- gt_shift_z=gt_shift_z.mean().detach(),
- pred_shift_z=pred_shift_z.mean().detach(),
- )
- return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
-
-
-class Regr3D_t_ScaleInv(Regr3D_t):
- """Same than Regr3D but invariant to depth shift.
- if gt_scale == True: enforce the prediction to take the same scale than GT
- """
-
- def get_all_pts3d_t(self, gts, preds):
- # compute depth-normalized points
- gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
- super().get_all_pts3d_t(gts, preds)
- )
-
- # measure scene scale
-
- # pred_pts_l, pred_pts_r = pred_pts
-
- pred_pts_all = [
- x.clone() for x in pred_pts
- ] # [pred_pt for pred_pt in pred_pts_l]
- # pred_pts_all.append(pred_pts_r[-1])
-
- _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks)
- _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks)
-
- # prevent predictions to be in a ridiculous range
- pred_scale = pred_scale.clip(min=1e-3, max=1e3)
-
- # subtract the median depth
- if self.gt_scale:
- for i in range(len(pred_pts)):
- # for j in range(len(pred_pts[i])):
- pred_pts[i] *= gt_scale / pred_scale
-
- else:
- for i in range(len(pred_pts)):
- # for j in range(len(pred_pts[i])):
- pred_pts[i] *= pred_scale / gt_scale
-
- for i in range(len(gt_pts)):
- gt_pts[i] *= gt_scale / pred_scale
-
- monitoring = dict(
- monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach()
- )
-
- return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring
-
-
-class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv):
- # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv
- pass
diff --git a/FastVGGT/eval/data.py b/FastVGGT/eval/data.py
deleted file mode 100644
index 301788e3ae6b67091fb031b068ecbecff1eae48d..0000000000000000000000000000000000000000
--- a/FastVGGT/eval/data.py
+++ /dev/null
@@ -1,338 +0,0 @@
-import os
-import cv2
-import numpy as np
-import os.path as osp
-from collections import deque
-from base import BaseStereoViewDataset
-import dataset_utils.cropping as cropping
-from vggt.utils.eval_utils import imread_cv2, shuffle_deque
-
-
-class SevenScenes(BaseStereoViewDataset):
- def __init__(
- self,
- num_seq=1,
- num_frames=5,
- min_thresh=10,
- max_thresh=100,
- test_id=None,
- full_video=False,
- tuple_list=None,
- seq_id=None,
- rebuttal=False,
- shuffle_seed=-1,
- kf_every=1,
- *args,
- ROOT,
- **kwargs,
- ):
- self.ROOT = ROOT
- super().__init__(*args, **kwargs)
- self.num_seq = num_seq
- self.num_frames = num_frames
- self.max_thresh = max_thresh
- self.min_thresh = min_thresh
- self.test_id = test_id
- self.full_video = full_video
- self.kf_every = kf_every
- self.seq_id = seq_id
- self.rebuttal = rebuttal
- self.shuffle_seed = shuffle_seed
-
- # load all scenes
- self.load_all_tuples(tuple_list)
- self.load_all_scenes(ROOT)
-
- def __len__(self):
- if self.tuple_list is not None:
- return len(self.tuple_list)
- return len(self.scene_list) * self.num_seq
-
- def load_all_tuples(self, tuple_list):
- if tuple_list is not None:
- self.tuple_list = tuple_list
- # with open(tuple_path) as f:
- # self.tuple_list = f.read().splitlines()
-
- else:
- self.tuple_list = None
-
- def load_all_scenes(self, base_dir):
-
- if self.tuple_list is not None:
- # Use pre-defined simplerecon scene_ids
- self.scene_list = [
- "stairs/seq-06",
- "stairs/seq-02",
- "pumpkin/seq-06",
- "chess/seq-01",
- "heads/seq-02",
- "fire/seq-02",
- "office/seq-03",
- "pumpkin/seq-03",
- "redkitchen/seq-07",
- "chess/seq-02",
- "office/seq-01",
- "redkitchen/seq-01",
- "fire/seq-01",
- ]
- print(f"Found {len(self.scene_list)} sequences in split {self.split}")
- return
-
- scenes = os.listdir(base_dir)
-
- file_split = {"train": "TrainSplit.txt", "test": "TestSplit.txt"}[self.split]
-
- self.scene_list = []
- for scene in scenes:
- if self.test_id is not None and scene != self.test_id:
- continue
- # read file split
- with open(osp.join(base_dir, scene, file_split)) as f:
- seq_ids = f.read().splitlines()
-
- for seq_id in seq_ids:
- # seq is string, take the int part and make it 01, 02, 03
- # seq_id = 'seq-{:2d}'.format(int(seq_id))
- num_part = "".join(filter(str.isdigit, seq_id))
- seq_id = f"seq-{num_part.zfill(2)}"
- if self.seq_id is not None and seq_id != self.seq_id:
- continue
- self.scene_list.append(f"{scene}/{seq_id}")
-
- print(f"Found {len(self.scene_list)} sequences in split {self.split}")
-
- def _get_views(self, idx, resolution, rng):
-
- if self.tuple_list is not None:
- line = self.tuple_list[idx].split(" ")
- scene_id = line[0]
- img_idxs = line[1:]
-
- else:
- scene_id = self.scene_list[idx // self.num_seq]
- seq_id = idx % self.num_seq
-
- data_path = osp.join(self.ROOT, scene_id)
- num_files = len([name for name in os.listdir(data_path) if "color" in name])
- img_idxs = [f"{i:06d}" for i in range(num_files)]
- img_idxs = img_idxs[:: self.kf_every]
-
- # Intrinsics used in SimpleRecon
- fx, fy, cx, cy = 525, 525, 320, 240
- intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
-
- views = []
- imgs_idxs = deque(img_idxs)
- if self.shuffle_seed >= 0:
- imgs_idxs = shuffle_deque(imgs_idxs)
-
- while len(imgs_idxs) > 0:
- im_idx = imgs_idxs.popleft()
- impath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.color.png")
- depthpath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.depth.proj.png")
- posepath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.pose.txt")
-
- rgb_image = imread_cv2(impath)
-
- depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
- rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))
-
- depthmap[depthmap == 65535] = 0
- depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0
-
- depthmap[depthmap > 10] = 0
- depthmap[depthmap < 1e-3] = 0
-
- camera_pose = np.loadtxt(posepath).astype(np.float32)
-
- if resolution != (224, 224) or self.rebuttal:
- rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
- rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
- )
- else:
- rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
- rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
- )
- W, H = rgb_image.size
- cx = W // 2
- cy = H // 2
- l, t = cx - 112, cy - 112
- r, b = cx + 112, cy + 112
- crop_bbox = (l, t, r, b)
- rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
- rgb_image, depthmap, intrinsics, crop_bbox
- )
-
- views.append(
- dict(
- img=rgb_image,
- depthmap=depthmap,
- camera_pose=camera_pose,
- camera_intrinsics=intrinsics,
- dataset="7scenes",
- label=osp.join(scene_id, im_idx),
- instance=impath,
- )
- )
- return views
-
-
-class NRGBD(BaseStereoViewDataset):
- def __init__(
- self,
- num_seq=1,
- num_frames=5,
- min_thresh=10,
- max_thresh=100,
- test_id=None,
- full_video=False,
- tuple_list=None,
- seq_id=None,
- rebuttal=False,
- shuffle_seed=-1,
- kf_every=1,
- *args,
- ROOT,
- **kwargs,
- ):
-
- self.ROOT = ROOT
- super().__init__(*args, **kwargs)
- self.num_seq = num_seq
- self.num_frames = num_frames
- self.max_thresh = max_thresh
- self.min_thresh = min_thresh
- self.test_id = test_id
- self.full_video = full_video
- self.kf_every = kf_every
- self.seq_id = seq_id
- self.rebuttal = rebuttal
- self.shuffle_seed = shuffle_seed
-
- # load all scenes
- self.load_all_tuples(tuple_list)
- self.load_all_scenes(ROOT)
-
- def __len__(self):
- if self.tuple_list is not None:
- return len(self.tuple_list)
- return len(self.scene_list) * self.num_seq
-
- def load_all_tuples(self, tuple_list):
- if tuple_list is not None:
- self.tuple_list = tuple_list
- # with open(tuple_path) as f:
- # self.tuple_list = f.read().splitlines()
-
- else:
- self.tuple_list = None
-
- def load_all_scenes(self, base_dir):
-
- scenes = [
- d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
- ]
-
- if self.test_id is not None:
- self.scene_list = [self.test_id]
-
- else:
- self.scene_list = scenes
-
- print(f"Found {len(self.scene_list)} sequences in split {self.split}")
-
- def load_poses(self, path):
- file = open(path, "r")
- lines = file.readlines()
- file.close()
- poses = []
- valid = []
- lines_per_matrix = 4
- for i in range(0, len(lines), lines_per_matrix):
- if "nan" in lines[i]:
- valid.append(False)
- poses.append(np.eye(4, 4, dtype=np.float32).tolist())
- else:
- valid.append(True)
- pose_floats = [
- [float(x) for x in line.split()]
- for line in lines[i : i + lines_per_matrix]
- ]
- poses.append(pose_floats)
-
- return np.array(poses, dtype=np.float32), valid
-
- def _get_views(self, idx, resolution, rng):
-
- if self.tuple_list is not None:
- line = self.tuple_list[idx].split(" ")
- scene_id = line[0]
- img_idxs = line[1:]
-
- else:
- scene_id = self.scene_list[idx // self.num_seq]
-
- num_files = len(os.listdir(os.path.join(self.ROOT, scene_id, "images")))
- img_idxs = [f"{i}" for i in range(num_files)]
- img_idxs = img_idxs[:: min(self.kf_every, len(img_idxs) // 2)]
-
- fx, fy, cx, cy = 554.2562584220408, 554.2562584220408, 320, 240
- intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
-
- posepath = osp.join(self.ROOT, scene_id, f"poses.txt")
- camera_poses, valids = self.load_poses(posepath)
-
- imgs_idxs = deque(img_idxs)
- if self.shuffle_seed >= 0:
- imgs_idxs = shuffle_deque(imgs_idxs)
- views = []
-
- while len(imgs_idxs) > 0:
- im_idx = imgs_idxs.popleft()
-
- impath = osp.join(self.ROOT, scene_id, "images", f"img{im_idx}.png")
- depthpath = osp.join(self.ROOT, scene_id, "depth", f"depth{im_idx}.png")
-
- rgb_image = imread_cv2(impath)
- depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
- depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0
- depthmap[depthmap > 10] = 0
- depthmap[depthmap < 1e-3] = 0
-
- rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0]))
-
- camera_pose = camera_poses[int(im_idx)]
- # gl to cv
- camera_pose[:, 1:3] *= -1.0
- if resolution != (224, 224) or self.rebuttal:
- rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
- rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath
- )
- else:
- rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
- rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath
- )
- W, H = rgb_image.size
- cx = W // 2
- cy = H // 2
- l, t = cx - 112, cy - 112
- r, b = cx + 112, cy + 112
- crop_bbox = (l, t, r, b)
- rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap(
- rgb_image, depthmap, intrinsics, crop_bbox
- )
-
- views.append(
- dict(
- img=rgb_image,
- depthmap=depthmap,
- camera_pose=camera_pose,
- camera_intrinsics=intrinsics,
- dataset="nrgbd",
- label=osp.join(scene_id, im_idx),
- instance=impath,
- )
- )
-
- return views
diff --git a/FastVGGT/eval/dataset_utils/__init__.py b/FastVGGT/eval/dataset_utils/__init__.py
deleted file mode 100644
index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000
--- a/FastVGGT/eval/dataset_utils/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index d00b3e8b6f5fd9ef7d1428ff771243e1557ddf9f..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc
deleted file mode 100644
index 9ebd5d22d44ddce633e9a05a7467e9ae0c3f9032..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc and /dev/null differ
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc
deleted file mode 100644
index fbd10218eddf65486b375109c23cf8cecd7c3ba5..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc
deleted file mode 100644
index 860c8970fc49d65b2a4101b77cc313a6aeac618a..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc
deleted file mode 100644
index 0fe7870fb079d9dc08ca647866ff8c77fd0f2ec0..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc and /dev/null differ
diff --git a/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc
deleted file mode 100644
index 1b4530be10547e3e7bf79e6cef6470880cb710a4..0000000000000000000000000000000000000000
Binary files a/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/eval/dataset_utils/corr.py b/FastVGGT/eval/dataset_utils/corr.py
deleted file mode 100644
index fbf5de18b7388e38e45c3f313487957811abc483..0000000000000000000000000000000000000000
--- a/FastVGGT/eval/dataset_utils/corr.py
+++ /dev/null
@@ -1,234 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-
-import numpy as np
-import torch
-
-
-def todevice(batch, device, callback=None, non_blocking=False):
- """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
-
- batch: list, tuple, dict of tensors or other things
- device: pytorch device or 'numpy'
- callback: function that would be called on every sub-elements.
- """
- if callback:
- batch = callback(batch)
-
- if isinstance(batch, dict):
- return {k: todevice(v, device) for k, v in batch.items()}
-
- if isinstance(batch, (tuple, list)):
- return type(batch)(todevice(x, device) for x in batch)
-
- x = batch
- if device == "numpy":
- if isinstance(x, torch.Tensor):
- x = x.detach().cpu().numpy()
- elif x is not None:
- if isinstance(x, np.ndarray):
- x = torch.from_numpy(x)
- if torch.is_tensor(x):
- x = x.to(device, non_blocking=non_blocking)
- return x
-
-
-to_device = todevice # alias
-
-
-def to_numpy(x):
- return todevice(x, "numpy")
-
-
-def geotrf(Trf, pts, ncol=None, norm=False):
- """Apply a geometric transformation to a list of 3-D points.
-
- H: 3x3 or 4x4 projection matrix (typically a Homography)
- p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
-
- ncol: int. number of columns of the result (2 or 3)
- norm: float. if != 0, the resut is projected on the z=norm plane.
-
- Returns an array of projected 2d points.
- """
- assert Trf.ndim >= 2
- if isinstance(Trf, np.ndarray):
- pts = np.asarray(pts)
- elif isinstance(Trf, torch.Tensor):
- pts = torch.as_tensor(pts, dtype=Trf.dtype)
-
- output_reshape = pts.shape[:-1]
- ncol = ncol or pts.shape[-1]
-
- if (
- isinstance(Trf, torch.Tensor)
- and isinstance(pts, torch.Tensor)
- and Trf.ndim == 3
- and pts.ndim == 4
- ):
- d = pts.shape[3]
- if Trf.shape[-1] == d:
- pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
- elif Trf.shape[-1] == d + 1:
- pts = (
- torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
- + Trf[:, None, None, :d, d]
- )
- else:
- raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
- else:
- if Trf.ndim >= 3:
- n = Trf.ndim - 2
- assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
- Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
-
- if pts.ndim > Trf.ndim:
-
- pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
- elif pts.ndim == 2:
-
- pts = pts[:, None, :]
-
- if pts.shape[-1] + 1 == Trf.shape[-1]:
- Trf = Trf.swapaxes(-1, -2) # transpose Trf
- pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
- elif pts.shape[-1] == Trf.shape[-1]:
- Trf = Trf.swapaxes(-1, -2) # transpose Trf
- pts = pts @ Trf
- else:
- pts = Trf @ pts.T
- if pts.ndim >= 2:
- pts = pts.swapaxes(-1, -2)
-
- if norm:
- pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
- if norm != 1:
- pts *= norm
-
- res = pts[..., :ncol].reshape(*output_reshape, ncol)
- return res
-
-
-def inv(mat):
- """Invert a torch or numpy matrix"""
- if isinstance(mat, torch.Tensor):
- return torch.linalg.inv(mat)
- if isinstance(mat, np.ndarray):
- return np.linalg.inv(mat)
- raise ValueError(f"bad matrix type = {type(mat)}")
-
-
-def reproject_view(pts3d, view2):
- shape = view2["pts3d"].shape[:2]
- return reproject(
- pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape
- )
-
-
-def reproject(pts3d, K, world2cam, shape):
- H, W, THREE = pts3d.shape
- assert THREE == 3
-
- with np.errstate(divide="ignore", invalid="ignore"):
- pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)
-
- return (H, W), ravel_xy(pos, shape)
-
-
-def ravel_xy(pos, shape):
- H, W = shape
- with np.errstate(invalid="ignore"):
- qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
- quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(
- min=0, max=H - 1, out=qy
- )
- return quantized_pos
-
-
-def unravel_xy(pos, shape):
-
- return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
-
-
-def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
- is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2))
- pos1 = is_reciprocal1.nonzero()[0]
- pos2 = corres_1_to_2[pos1]
- if ret_recip:
- return is_reciprocal1, pos1, pos2
- return pos1, pos2
-
-
-def extract_correspondences_from_pts3d(
- view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
-):
- view1, view2 = to_numpy((view1, view2))
-
- shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
- shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)
-
- is_reciprocal1, pos1, pos2 = reciprocal_1d(
- corres1_to_2, corres2_to_1, ret_recip=True
- )
- is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))
-
- if target_n_corres is None:
- if ret_xy:
- pos1 = unravel_xy(pos1, shape1)
- pos2 = unravel_xy(pos2, shape2)
- return pos1, pos2
-
- available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
- target_n_positives = int(target_n_corres * (1 - nneg))
- n_positives = min(len(pos1), target_n_positives)
- n_negatives = min(target_n_corres - n_positives, available_negatives)
-
- if n_negatives + n_positives != target_n_corres:
-
- n_positives = target_n_corres - n_negatives
- assert n_positives <= len(pos1)
-
- assert n_positives <= len(pos1)
- assert n_positives <= len(pos2)
- assert n_negatives <= (~is_reciprocal1).sum()
- assert n_negatives <= (~is_reciprocal2).sum()
- assert n_positives + n_negatives == target_n_corres
-
- valid = np.ones(n_positives, dtype=bool)
- if n_positives < len(pos1):
-
- perm = rng.permutation(len(pos1))[:n_positives]
- pos1 = pos1[perm]
- pos2 = pos2[perm]
-
- if n_negatives > 0:
-
- def norm(p):
- return p / p.sum()
-
- pos1 = np.r_[
- pos1,
- rng.choice(
- shape1[0] * shape1[1],
- size=n_negatives,
- replace=False,
- p=norm(~is_reciprocal1),
- ),
- ]
- pos2 = np.r_[
- pos2,
- rng.choice(
- shape2[0] * shape2[1],
- size=n_negatives,
- replace=False,
- p=norm(~is_reciprocal2),
- ),
- ]
- valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]
-
- if ret_xy:
- pos1 = unravel_xy(pos1, shape1)
- pos2 = unravel_xy(pos2, shape2)
- return pos1, pos2, valid
diff --git a/FastVGGT/eval/dataset_utils/cropping.py b/FastVGGT/eval/dataset_utils/cropping.py
deleted file mode 100644
index 30a9eac18d241b71538957cf7ba4767ebc323b43..0000000000000000000000000000000000000000
--- a/FastVGGT/eval/dataset_utils/cropping.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-
-import PIL.Image
-import os
-
-from utils import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics
-
-os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
-import cv2 # noqa
-import numpy as np # noqa
-
-try:
- lanczos = PIL.Image.Resampling.LANCZOS
- bicubic = PIL.Image.Resampling.BICUBIC
-except AttributeError:
- lanczos = PIL.Image.LANCZOS
- bicubic = PIL.Image.BICUBIC
-
-
-class ImageList:
- """Convenience class to aply the same operation to a whole set of images."""
-
- def __init__(self, images):
- if not isinstance(images, (tuple, list, set)):
- images = [images]
- self.images = []
- for image in images:
- if not isinstance(image, PIL.Image.Image):
- image = PIL.Image.fromarray(image)
- self.images.append(image)
-
- def __len__(self):
- return len(self.images)
-
- def to_pil(self):
- return tuple(self.images) if len(self.images) > 1 else self.images[0]
-
- @property
- def size(self):
- sizes = [im.size for im in self.images]
- assert all(sizes[0] == s for s in sizes)
- return sizes[0]
-
- def resize(self, *args, **kwargs):
- return ImageList(self._dispatch("resize", *args, **kwargs))
-
- def crop(self, *args, **kwargs):
- return ImageList(self._dispatch("crop", *args, **kwargs))
-
- def _dispatch(self, func, *args, **kwargs):
- return [getattr(im, func)(*args, **kwargs) for im in self.images]
-
-
-def rescale_image_depthmap(
- image, depthmap, camera_intrinsics, output_resolution, force=True
-):
- """Jointly rescale a (image, depthmap)
- so that (out_width, out_height) >= output_res
- """
- image = ImageList(image)
- input_resolution = np.array(image.size) # (W,H)
- output_resolution = np.array(output_resolution)
- if depthmap is not None:
-
- assert tuple(depthmap.shape[:2]) == image.size[::-1]
-
- assert output_resolution.shape == (2,)
- scale_final = max(output_resolution / image.size) + 1e-8
- if scale_final >= 1 and not force: # image is already smaller than what is asked
- return (image.to_pil(), depthmap, camera_intrinsics)
- output_resolution = np.floor(input_resolution * scale_final).astype(int)
-
- image = image.resize(
- output_resolution, resample=lanczos if scale_final < 1 else bicubic
- )
- if depthmap is not None:
- depthmap = cv2.resize(
- depthmap,
- output_resolution,
- fx=scale_final,
- fy=scale_final,
- interpolation=cv2.INTER_NEAREST,
- )
-
- camera_intrinsics = camera_matrix_of_crop(
- camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
- )
-
- return image.to_pil(), depthmap, camera_intrinsics
-
-
-def camera_matrix_of_crop(
- input_camera_matrix,
- input_resolution,
- output_resolution,
- scaling=1,
- offset_factor=0.5,
- offset=None,
-):
-
- margins = np.asarray(input_resolution) * scaling - output_resolution
- assert np.all(margins >= 0.0)
- if offset is None:
- offset = offset_factor * margins
-
- output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
- output_camera_matrix_colmap[:2, :] *= scaling
- output_camera_matrix_colmap[:2, 2] -= offset
- output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
-
- return output_camera_matrix
-
-
-def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
- """
- Return a crop of the input view.
- """
- image = ImageList(image)
- l, t, r, b = crop_bbox
-
- image = image.crop((l, t, r, b))
- depthmap = depthmap[t:b, l:r]
-
- camera_intrinsics = camera_intrinsics.copy()
- camera_intrinsics[0, 2] -= l
- camera_intrinsics[1, 2] -= t
-
- return image.to_pil(), depthmap, camera_intrinsics
-
-
-def bbox_from_intrinsics_in_out(
- input_camera_matrix, output_camera_matrix, output_resolution
-):
- out_width, out_height = output_resolution
- l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
- crop_bbox = (l, t, l + out_width, t + out_height)
- return crop_bbox
diff --git a/FastVGGT/eval/dataset_utils/transforms.py b/FastVGGT/eval/dataset_utils/transforms.py
deleted file mode 100644
index cec2d144e1b97cc99c191d15cdcaf20796cae94b..0000000000000000000000000000000000000000
--- a/FastVGGT/eval/dataset_utils/transforms.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-
-import torchvision.transforms as tvf
-
-ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
-
-
-ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
-
-
-def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
- if isinstance(value, (int, float)):
- if value < 0:
- raise ValueError(f"If is a single number, it must be non negative.")
- value = [center - float(value), center + float(value)]
- if clip_first_on_zero:
- value[0] = max(value[0], 0.0)
- elif isinstance(value, (tuple, list)) and len(value) == 2:
- value = [float(value[0]), float(value[1])]
- else:
- raise TypeError(f"should be a single number or a list/tuple with length 2.")
-
- if not bound[0] <= value[0] <= value[1] <= bound[1]:
- raise ValueError(f"values should be between {bound}, but got {value}.")
-
- if value[0] == value[1] == center:
- return None
- else:
- return tuple(value)
-
-
-import torch
-import torchvision.transforms.functional as F
-
-
-def SeqColorJitter():
- """
- Return a color jitter transform with same random parameters
- """
- brightness = _check_input(0.5)
- contrast = _check_input(0.5)
- saturation = _check_input(0.5)
- hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)
-
- fn_idx = torch.randperm(4)
- brightness_factor = (
- None
- if brightness is None
- else float(torch.empty(1).uniform_(brightness[0], brightness[1]))
- )
- contrast_factor = (
- None
- if contrast is None
- else float(torch.empty(1).uniform_(contrast[0], contrast[1]))
- )
- saturation_factor = (
- None
- if saturation is None
- else float(torch.empty(1).uniform_(saturation[0], saturation[1]))
- )
- hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1]))
-
- def _color_jitter(img):
- for fn_id in fn_idx:
- if fn_id == 0 and brightness_factor is not None:
- img = F.adjust_brightness(img, brightness_factor)
- elif fn_id == 1 and contrast_factor is not None:
- img = F.adjust_contrast(img, contrast_factor)
- elif fn_id == 2 and saturation_factor is not None:
- img = F.adjust_saturation(img, saturation_factor)
- elif fn_id == 3 and hue_factor is not None:
- img = F.adjust_hue(img, hue_factor)
- return ImgNorm(img)
-
- return _color_jitter
diff --git a/FastVGGT/eval/eval_7andN.py b/FastVGGT/eval/eval_7andN.py
deleted file mode 100644
index 8d6d5fbf6c159ea18bfd6fbcee0a5d23e63c0cf2..0000000000000000000000000000000000000000
--- a/FastVGGT/eval/eval_7andN.py
+++ /dev/null
@@ -1,497 +0,0 @@
-import os
-import sys
-
-# Ensure project root is on sys.path for absolute imports like `vggt.*`
-ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
-if ROOT_DIR not in sys.path:
- sys.path.insert(0, ROOT_DIR)
-
-import time
-import torch
-import argparse
-import numpy as np
-import open3d as o3d
-import os.path as osp
-from torch.utils.data import DataLoader
-from torch.utils.data._utils.collate import default_collate
-from tqdm import tqdm
-from collections import defaultdict
-import torchvision.transforms as transforms
-
-
-def get_args_parser():
- parser = argparse.ArgumentParser("3D Reconstruction evaluation", add_help=False)
- parser.add_argument(
- "--ckpt_path",
- type=str,
- default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
- help="ckpt name",
- )
- parser.add_argument("--device", type=str, default="cuda:0", help="device")
- parser.add_argument("--model_name", type=str, default="VGGT")
- parser.add_argument(
- "--conf_thresh", type=float, default=0.0, help="confidence threshold"
- )
- parser.add_argument(
- "--output_dir",
- type=str,
- default="/home/sy/code/FastVGGT/eval_results",
- help="value for outdir",
- )
- parser.add_argument("--size", type=int, default=518)
- parser.add_argument("--revisit", type=int, default=1, help="revisit times")
- parser.add_argument("--freeze", action="store_true")
- parser.add_argument("--use_proj", action="store_true")
- parser.add_argument(
- "--merging", type=int, default=0, help="VGGT aggregator merging steps"
- )
- parser.add_argument("--kf", type=int, default=2, help="key frame")
- return parser
-
-
-def main(args):
- from data import SevenScenes, NRGBD
- from utils import accuracy, completion
-
- if args.size == 512:
- resolution = (512, 384)
- elif args.size == 224:
- resolution = 224
- elif args.size == 518:
- resolution = (518, 392)
- else:
- raise NotImplementedError
- datasets_all = {
- "7scenes": SevenScenes(
- split="test",
- ROOT="/data/sy/7scenes",
- resolution=resolution,
- num_seq=1,
- full_video=True,
- kf_every=args.kf,
- ), # 20),
- "NRGBD": NRGBD(
- split="test",
- ROOT="/data/sy/neural_rgbd_data",
- resolution=resolution,
- num_seq=1,
- full_video=True,
- kf_every=args.kf,
- ),
- }
-
- device = args.device
- model_name = args.model_name
-
- from vggt.models.vggt import VGGT
- from vggt.utils.pose_enc import pose_encoding_to_extri_intri
- from vggt.utils.geometry import unproject_depth_map_to_point_map
- from criterion import Regr3D_t_ScaleShiftInv, L21
-
- # Force use of bf16 data type
- dtype = torch.bfloat16
- # Load VGGT model
- model = VGGT(merging=args.merging, enable_point=True)
- ckpt = torch.load(args.ckpt_path, map_location="cpu")
-
- # ✅ Fix: load pre-trained weights
- model.load_state_dict(
- ckpt, strict=False
- ) # Use strict=False due to enable_point=True difference
-
- model = model.cuda().eval()
- model = model.to(torch.bfloat16)
-
- del ckpt
- os.makedirs(osp.join(args.output_dir, f"{args.kf}"), exist_ok=True)
-
- criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True)
-
- with torch.no_grad():
- for name_data, dataset in datasets_all.items():
- save_path = osp.join(osp.join(args.output_dir, f"{args.kf}"), name_data)
- os.makedirs(save_path, exist_ok=True)
- log_file = osp.join(save_path, "logs.txt")
-
- acc_all = 0
- acc_all_med = 0
- comp_all = 0
- comp_all_med = 0
- nc1_all = 0
- nc1_all_med = 0
- nc2_all = 0
- nc2_all_med = 0
- scene_infer_times = defaultdict(list)
-
- for data_idx in tqdm(range(len(dataset))):
- batch = default_collate([dataset[data_idx]])
- ignore_keys = set(
- [
- "depthmap",
- "dataset",
- "label",
- "instance",
- "idx",
- "true_shape",
- "rng",
- ]
- )
- for view in batch:
- for name in view.keys(): # pseudo_focal
- if name in ignore_keys:
- continue
- if isinstance(view[name], tuple) or isinstance(
- view[name], list
- ):
- view[name] = [
- x.to(device, non_blocking=True) for x in view[name]
- ]
- else:
- view[name] = view[name].to(device, non_blocking=True)
-
- pts_all = []
- pts_gt_all = []
- images_all = []
- masks_all = []
- conf_all = []
- in_camera1 = None
-
- dtype = (
- torch.bfloat16
- if torch.cuda.get_device_capability()[0] >= 8
- else torch.float16
- )
- with torch.cuda.amp.autocast(dtype=dtype):
- if isinstance(batch, dict) and "img" in batch:
- batch["img"] = (batch["img"] + 1.0) / 2.0
- elif isinstance(batch, list) and all(
- isinstance(v, dict) and "img" in v for v in batch
- ):
- for view in batch:
- view["img"] = (view["img"] + 1.0) / 2.0
- # Gather all `img` tensors into a single tensor of shape [N, C, H, W]
- imgs_tensor = torch.cat([v["img"] for v in batch], dim=0)
-
- with torch.cuda.amp.autocast(dtype=dtype):
- with torch.no_grad():
- torch.cuda.synchronize()
- start = time.time()
- preds = model(imgs_tensor)
- torch.cuda.synchronize()
- end = time.time()
- inference_time_ms = (end - start) * 1000
- print(f"Inference time: {inference_time_ms:.2f}ms")
-
- # Wrap model outputs per-view to align with batch later
- predictions = preds
- views = batch # list[dict]
- if "pose_enc" in predictions:
- B, S = predictions["pose_enc"].shape[:2]
- elif "world_points" in predictions:
- B, S = predictions["world_points"].shape[:2]
- else:
- raise KeyError(
- "predictions is missing a key to infer sequence length"
- )
-
- ress = []
- for s in range(S):
- res = {
- "pts3d_in_other_view": predictions["world_points"][:, s],
- "conf": predictions["world_points_conf"][:, s],
- "depth": predictions["depth"][:, s],
- "depth_conf": predictions["depth_conf"][:, s],
- "camera_pose": predictions["pose_enc"][:, s, :],
- }
- if (
- isinstance(views, list)
- and s < len(views)
- and "valid_mask" in views[s]
- ):
- res["valid_mask"] = views[s]["valid_mask"]
- if "track" in predictions:
- res.update(
- {
- "track": predictions["track"][:, s],
- "vis": (
- predictions.get("vis", None)[:, s]
- if "vis" in predictions
- else None
- ),
- "track_conf": (
- predictions.get("conf", None)[:, s]
- if "conf" in predictions
- else None
- ),
- }
- )
- ress.append(res)
-
- preds = ress
-
- valid_length = len(preds) // args.revisit
- if args.revisit > 1:
- preds = preds[-valid_length:]
- batch = batch[-valid_length:]
-
- # Evaluation
- print(f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}")
- gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = (
- criterion.get_all_pts3d_t(batch, preds)
- )
-
- in_camera1 = None
- pts_all = []
- pts_gt_all = []
- images_all = []
- masks_all = []
- conf_all = []
-
- for j, view in enumerate(batch):
- if in_camera1 is None:
- in_camera1 = view["camera_pose"][0].cpu()
-
- image = view["img"].permute(0, 2, 3, 1).cpu().numpy()[0]
- mask = view["valid_mask"].cpu().numpy()[0]
-
- pts = pred_pts[j].cpu().numpy()[0]
- conf = preds[j]["conf"].cpu().data.numpy()[0]
-
- # mask = mask & (conf > 1.8)
-
- pts_gt = gt_pts[j].detach().cpu().numpy()[0]
-
- H, W = image.shape[:2]
- cx = W // 2
- cy = H // 2
- l, t = cx - 112, cy - 112
- r, b = cx + 112, cy + 112
- image = image[t:b, l:r]
- mask = mask[t:b, l:r]
- pts = pts[t:b, l:r]
- pts_gt = pts_gt[t:b, l:r]
-
- images_all.append(image[None, ...])
- pts_all.append(pts[None, ...])
- pts_gt_all.append(pts_gt[None, ...])
- masks_all.append(mask[None, ...])
- conf_all.append(conf[None, ...])
-
- images_all = np.concatenate(images_all, axis=0)
- pts_all = np.concatenate(pts_all, axis=0)
- pts_gt_all = np.concatenate(pts_gt_all, axis=0)
- masks_all = np.concatenate(masks_all, axis=0)
-
- scene_id = view["label"][0].rsplit("/", 1)[0]
- # Record average inference time per scene
- try:
- scene_infer_times[scene_id].append(float(inference_time_ms))
- except Exception:
- pass
-
- save_params = {}
-
- save_params["images_all"] = images_all
- save_params["pts_all"] = pts_all
- save_params["pts_gt_all"] = pts_gt_all
- save_params["masks_all"] = masks_all
-
- pts_all_masked = pts_all[masks_all > 0]
- pts_gt_all_masked = pts_gt_all[masks_all > 0]
- images_all_masked = images_all[masks_all > 0]
-
- mask = np.isfinite(pts_all_masked)
- pts_all_masked = pts_all_masked[mask]
-
- mask_gt = np.isfinite(pts_gt_all_masked)
- pts_gt_all_masked = pts_gt_all_masked[mask_gt]
- images_all_masked = images_all_masked[mask]
-
- # Reshape to point cloud (N, 3) before sampling
- pts_all_masked = pts_all_masked.reshape(-1, 3)
- pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3)
- images_all_masked = images_all_masked.reshape(-1, 3)
-
- # If number of points exceeds threshold, sample by points
- if pts_all_masked.shape[0] > 999999:
- sample_indices = np.random.choice(
- pts_all_masked.shape[0], 999999, replace=False
- )
- pts_all_masked = pts_all_masked[sample_indices]
- images_all_masked = images_all_masked[sample_indices]
-
- # Apply the same sampling to GT point cloud
- if pts_gt_all_masked.shape[0] > 999999:
- sample_indices_gt = np.random.choice(
- pts_gt_all_masked.shape[0], 999999, replace=False
- )
- pts_gt_all_masked = pts_gt_all_masked[sample_indices_gt]
-
- if args.use_proj:
-
- def umeyama_alignment(
- src: np.ndarray, dst: np.ndarray, with_scale: bool = True
- ):
- assert src.shape == dst.shape
- N, dim = src.shape
-
- mu_src = src.mean(axis=0)
- mu_dst = dst.mean(axis=0)
- src_c = src - mu_src
- dst_c = dst - mu_dst
-
- Sigma = dst_c.T @ src_c / N # (3,3)
-
- U, D, Vt = np.linalg.svd(Sigma)
-
- S = np.eye(dim)
- if np.linalg.det(U) * np.linalg.det(Vt) < 0:
- S[-1, -1] = -1
-
- R = U @ S @ Vt
-
- if with_scale:
- var_src = (src_c**2).sum() / N
- s = (D * S.diagonal()).sum() / var_src
- else:
- s = 1.0
-
- t = mu_dst - s * R @ mu_src
-
- return s, R, t
-
- pts_all_masked = pts_all_masked.reshape(-1, 3)
- pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3)
- s, R, t = umeyama_alignment(
- pts_all_masked, pts_gt_all_masked, with_scale=True
- )
- pts_all_aligned = (s * (R @ pts_all_masked.T)).T + t # (N,3)
- pts_all_masked = pts_all_aligned
-
- pcd = o3d.geometry.PointCloud()
- pcd.points = o3d.utility.Vector3dVector(pts_all_masked)
- pcd.colors = o3d.utility.Vector3dVector(images_all_masked)
-
- pcd_gt = o3d.geometry.PointCloud()
- pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked)
- pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked)
-
- trans_init = np.eye(4)
-
- threshold = 0.1
- reg_p2p = o3d.pipelines.registration.registration_icp(
- pcd,
- pcd_gt,
- threshold,
- trans_init,
- o3d.pipelines.registration.TransformationEstimationPointToPoint(),
- )
-
- transformation = reg_p2p.transformation
-
- pcd = pcd.transform(transformation)
- pcd.estimate_normals()
- pcd_gt.estimate_normals()
-
- gt_normal = np.asarray(pcd_gt.normals)
- pred_normal = np.asarray(pcd.normals)
-
- acc, acc_med, nc1, nc1_med = accuracy(
- pcd_gt.points, pcd.points, gt_normal, pred_normal
- )
- comp, comp_med, nc2, nc2_med = completion(
- pcd_gt.points, pcd.points, gt_normal, pred_normal
- )
- print(
- f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}"
- )
- print(
- f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}",
- file=open(log_file, "a"),
- )
-
- acc_all += acc
- comp_all += comp
- nc1_all += nc1
- nc2_all += nc2
-
- acc_all_med += acc_med
- comp_all_med += comp_med
- nc1_all_med += nc1_med
- nc2_all_med += nc2_med
-
- # release cuda memory
- torch.cuda.empty_cache()
-
- # Get depth from pcd and run TSDFusion
- to_write = ""
- # Read the log file
- if os.path.exists(osp.join(save_path, "logs.txt")):
- with open(osp.join(save_path, "logs.txt"), "r") as f_sub:
- to_write += f_sub.read()
-
- with open(osp.join(save_path, f"logs_all.txt"), "w") as f:
- log_data = to_write
- metrics = defaultdict(list)
- for line in log_data.strip().split("\n"):
- match = regex.match(line)
- if match:
- data = match.groupdict()
- # Exclude 'scene_id' from metrics as it's an identifier
- for key, value in data.items():
- if key != "scene_id":
- metrics[key].append(float(value))
- metrics["nc"].append(
- (float(data["nc1"]) + float(data["nc2"])) / 2
- )
- metrics["nc_med"].append(
- (float(data["nc1_med"]) + float(data["nc2_med"])) / 2
- )
- mean_metrics = {
- metric: sum(values) / len(values)
- for metric, values in metrics.items()
- }
-
- c_name = "mean"
- print_str = f"{c_name.ljust(20)}: "
- for m_name in mean_metrics:
- print_num = np.mean(mean_metrics[m_name])
- print_str = print_str + f"{m_name}: {print_num:.3f} | "
- print_str = print_str + "\n"
- # Summarize per-scene average inference time
- time_lines = []
- for sid, times in scene_infer_times.items():
- if len(times) > 0:
- time_lines.append(
- f"Idx: {sid}, Time_avg_ms: {np.mean(times):.2f}"
- )
- time_block = "\n".join(time_lines) + (
- "\n" if len(time_lines) > 0 else ""
- )
-
- f.write(to_write + time_block + print_str)
-
-
-from collections import defaultdict
-import re
-
-pattern = r"""
- Idx:\s*(?P[^,]+),\s*
- Acc:\s*(?P[^,]+),\s*
- Comp:\s*(?P[^,]+),\s*
- NC1:\s*(?P[^,]+),\s*
- NC2:\s*(?P[^,]+)\s*-\s*
- Acc_med:\s*(?P[^,]+),\s*
- Compc_med:\s*(?P[^,]+),\s*
- NC1c_med:\s*(?P[^,]+),\s*
- NC2c_med:\s*(?P[^,]+)
-"""
-
-regex = re.compile(pattern, re.VERBOSE)
-
-
-if __name__ == "__main__":
- parser = get_args_parser()
- args = parser.parse_args()
-
- main(args)
diff --git a/FastVGGT/eval/eval_custom.py b/FastVGGT/eval/eval_custom.py
deleted file mode 100644
index a9e4f7ccdd6c22c2c3bef6421c52ace01d362a83..0000000000000000000000000000000000000000
--- a/FastVGGT/eval/eval_custom.py
+++ /dev/null
@@ -1,467 +0,0 @@
-import argparse
-from pathlib import Path
-import numpy as np
-import torch
-import os
-import sys
-import matplotlib.pyplot as plt
-from scipy.spatial.transform import Rotation
-
-# Ensure project root is in sys.path for absolute imports like `vggt.*`
-ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
-if ROOT_DIR not in sys.path:
- sys.path.insert(0, ROOT_DIR)
-
-from vggt.models.vggt import VGGT
-from vggt.utils.eval_utils import (
- load_poses,
- get_vgg_input_imgs,
- get_sorted_image_paths,
- build_frame_selection,
- load_images_rgb,
- infer_vggt_and_reconstruct,
- evaluate_scene_and_save,
-)
-
-# Import pose visualization libraries (optional EVO support)
-try:
- from evo.core.trajectory import PoseTrajectory3D
- import evo.tools.plot as plot
-
- EVO_AVAILABLE = True
-except ImportError:
- # EVO is optional; we have a matplotlib-based fallback
- EVO_AVAILABLE = False
-
-
-def visualize_predicted_poses(
- all_cam_to_world_mat, frame_ids, output_scene_dir, scene_name="custom_dataset"
-):
- """
- Visualize the predicted camera pose trajectory (no GT comparison required).
-
- Args:
- all_cam_to_world_mat: List of camera-to-world transform matrices
- frame_ids: List of frame IDs
- output_scene_dir: Output directory
- scene_name: Scene name
- """
- # Provide basic pose visualization even without EVO
- if not EVO_AVAILABLE:
- print("⚠️ EVO not installed; using basic matplotlib visualization")
-
- try:
- # Convert to numpy array
- poses_est = np.array(all_cam_to_world_mat)
-
- if len(poses_est) < 2:
- print("⚠️ Not enough poses to generate trajectory plot")
- return
-
- print(f"🎨 Generating pose trajectory visualization...")
-
- # Extract translation part
- positions = poses_est[:, :3, 3] # shape: (N, 3)
-
- # Create figure - show XZ-plane projection only
- fig, ax = plt.subplots(1, 1, figsize=(10, 8))
-
- # XZ-plane projection
- ax.plot(
- positions[:, 0],
- positions[:, 2],
- "b-",
- linewidth=2,
- label="Predicted Trajectory",
- )
- ax.scatter(
- positions[0, 0], positions[0, 2], color="green", s=100, label="Start"
- )
- ax.scatter(positions[-1, 0], positions[-1, 2], color="red", s=100, label="End")
- ax.set_xlabel("X (m)")
- ax.set_ylabel("Z (m)")
- ax.set_title(f"{scene_name} - XZ-plane projection")
- ax.legend()
- ax.grid(True, alpha=0.3)
-
- # Save image
- pose_plot_path = output_scene_dir / "predicted_trajectory.png"
- plt.savefig(pose_plot_path, dpi=300, bbox_inches="tight")
- plt.close()
-
- print(f"📊 Trajectory visualization saved: {pose_plot_path}")
-
- except Exception as e:
- print(f"⚠️ Failed to generate pose visualization: {e}")
- import traceback
-
- traceback.print_exc()
-
-
-def main():
- """
- Evaluation script for a Custom Dataset.
- Supports optional evaluation and custom dataset structure.
- """
- parser = argparse.ArgumentParser(
- description="Run FastVGGT evaluation on a Custom Dataset"
- )
-
- # Required: dataset path
- parser.add_argument(
- "--data_path",
- type=Path,
- required=True,
- help="Dataset path containing subfolders: color, depth, gt_ply, pose",
- )
-
- # Optional: enable evaluation
- parser.add_argument(
- "--enable_evaluation",
- action="store_true",
- help="Enable evaluation (requires pose and ply data)",
- )
-
- # Output path
- parser.add_argument(
- "--output_path",
- type=Path,
- default="./eval_results_custom",
- help="Output path for evaluation results",
- )
-
- # Model parameters
- parser.add_argument(
- "--ckpt_path",
- type=str,
- default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt",
- help="Model checkpoint file path",
- )
-
- parser.add_argument("--merging", type=int, default=0, help="Merging parameter")
-
- # Processing parameters
- parser.add_argument(
- "--input_frame",
- type=int,
- default=200,
- help="Maximum number of frames to process per scene",
- )
-
- parser.add_argument(
- "--depth_conf_thresh",
- type=float,
- default=3.0,
- help="Depth confidence threshold to filter low-confidence depth values",
- )
-
- # Evaluation parameters (only used when evaluation is enabled)
- parser.add_argument(
- "--chamfer_max_dist",
- type=float,
- default=0.5,
- help="Maximum distance threshold used in Chamfer Distance computation",
- )
-
- parser.add_argument("--plot", action="store_true", help="Whether to generate plots")
-
- parser.add_argument(
- "--vis_attn_map",
- action="store_true",
- help="Visualize attention maps during inference",
- )
-
- args = parser.parse_args()
- torch.manual_seed(33)
-
- # Check data path exists
- if not args.data_path.exists():
- print(f"❌ Error: Data path does not exist: {args.data_path}")
- return
-
- # Check required subdirectories
- color_dir = args.data_path / "images"
- pose_dir = args.data_path / "pose"
-
- if not color_dir.exists():
- print(f"❌ Error: color directory does not exist: {color_dir}")
- return
-
- print(f"📁 Dataset path: {args.data_path}")
- # print(f"🔧 Enable evaluation: {'Yes' if args.enable_evaluation else 'No'}")
-
- # If evaluation is enabled, check pose and gt_ply directories
- if args.enable_evaluation:
- if not pose_dir.exists():
- print(f"❌ Error: Evaluation requires pose directory: {pose_dir}")
- return
-
- gt_ply_dir = args.data_path / "gt_ply"
- if not gt_ply_dir.exists():
- print(f"❌ Error: Evaluation requires gt_ply directory: {gt_ply_dir}")
- return
- print(f"📊 Evaluation will use Ground Truth")
- else:
- print(f"🏃 Inference only, no evaluation")
-
- # Create output directory
- args.output_path.mkdir(parents=True, exist_ok=True)
- output_scene_dir = args.output_path / "custom_dataset"
-
- # Check if already processed
- if (output_scene_dir / "metrics.json").exists() and args.enable_evaluation:
- print(
- f"⚠️ Results already exist, skipping: {output_scene_dir / 'metrics.json'}"
- )
- return
-
- # Force use of bf16 dtype
- dtype = torch.bfloat16
-
- # Load VGGT model
- print(f"🔄 Loading model: {args.ckpt_path}")
- model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
- ckpt = torch.load(args.ckpt_path, map_location="cpu")
- incompat = model.load_state_dict(ckpt, strict=False)
- # if incompat.missing_keys or incompat.unexpected_keys:
- # print(f"⚠️ Partially incompatible keys when loading model: {incompat}")
- model = model.cuda().eval()
- model = model.to(torch.bfloat16)
- print(f"✅ Model loaded")
-
- # Load scene data
- image_paths = get_sorted_image_paths(color_dir)
- if len(image_paths) == 0:
- print(f"❌ Error: No images found in {color_dir}")
- return
-
- print(f"🖼️ Found {len(image_paths)} images")
-
- # Process pose data (if evaluation is enabled)
- poses_gt = None
- first_gt_pose = None
- available_pose_frame_ids = None
- c2ws = None
-
- if args.enable_evaluation:
- poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_dir)
- if (
- poses_gt is None
- or first_gt_pose is None
- or available_pose_frame_ids is None
- ):
- print(f"❌ Error: Failed to load pose data")
- return
- print(f"📐 Loaded {len(poses_gt)} poses")
-
- # Frame selection
- if args.enable_evaluation and available_pose_frame_ids is not None:
- # Use pose data for frame selection
- selected_frame_ids, selected_image_paths, selected_pose_indices = (
- build_frame_selection(
- image_paths, available_pose_frame_ids, args.input_frame
- )
- )
- c2ws = poses_gt[selected_pose_indices]
- image_paths = selected_image_paths
- else:
- # Simply take the first N frames
- num_frames = min(len(image_paths), args.input_frame)
- selected_frame_ids = list(range(num_frames))
- image_paths = image_paths[:num_frames]
-
- print(f"📋 Selected {len(image_paths)} frames for processing")
-
- try:
- # Load images
- print(f"🔄 Loading images...")
- images = load_images_rgb(image_paths)
-
- if not images or len(images) < 3:
- print(f"❌ Error: Not enough valid images (need at least 3)")
- return
-
- frame_ids = selected_frame_ids
- images_array = np.stack(images)
- vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
- print(f"📐 Image patch dimensions: {patch_width}x{patch_height}")
-
- # Update attention layer patch dimensions in the model
- model.update_patch_dimensions(patch_width, patch_height)
-
- # Inference + Reconstruction
- print(f"🚀 Start inference and reconstruction...")
- (
- extrinsic_np,
- intrinsic_np,
- all_world_points,
- all_point_colors,
- all_cam_to_world_mat,
- inference_time_ms,
- ) = infer_vggt_and_reconstruct(
- model, vgg_input, dtype, args.depth_conf_thresh, image_paths
- )
- print(f"⏱️ Inference time: {inference_time_ms:.2f}ms")
-
- # Check results
- if not all_cam_to_world_mat or not all_world_points:
- print(f"❌ Error: Failed to obtain valid camera poses or point clouds")
- return
-
- # print(f"✅ Inference done, obtained {len(all_world_points)} point sets")
-
- # Evaluation and saving
- if args.enable_evaluation:
- print(f"📊 Start evaluation...")
- gt_ply_dir = args.data_path / "gt_ply"
- metrics = evaluate_scene_and_save(
- "custom_dataset",
- c2ws,
- first_gt_pose,
- frame_ids,
- all_cam_to_world_mat,
- all_world_points,
- output_scene_dir,
- gt_ply_dir,
- args.chamfer_max_dist,
- inference_time_ms,
- args.plot,
- )
- if metrics is not None:
- print("📈 Evaluation results:")
- for key, value in metrics.items():
- if key in [
- "chamfer_distance",
- "ate",
- "are",
- "rpe_rot",
- "rpe_trans",
- "inference_time_ms",
- ]:
- print(f" {key}: {float(value):.4f}")
-
- # Also visualize predicted poses in evaluation branch
- if args.plot:
- visualize_predicted_poses(
- all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset"
- )
- else:
- # Save reconstruction only, no evaluation
- print(f"💾 Saving reconstruction...")
- output_scene_dir.mkdir(parents=True, exist_ok=True)
-
- # Save camera poses
- poses_output_path = output_scene_dir / "estimated_poses.txt"
- with open(poses_output_path, "w") as f:
- for i, pose in enumerate(all_cam_to_world_mat):
- f.write(f"# Frame {frame_ids[i]}\n")
- for row in pose:
- f.write(" ".join(map(str, row)) + "\n")
- f.write("\n")
-
- # Save point cloud
- if all_world_points:
- points_output_path = output_scene_dir / "reconstructed_points.ply"
-
- # Merge all frames' point clouds and colors
- try:
- merged_point_cloud = np.vstack(all_world_points)
- merged_colors = (
- np.vstack(all_point_colors).astype(np.uint8)
- if all_point_colors is not None and len(all_point_colors) > 0
- else None
- )
- print(
- f"📊 Merged point clouds: {len(all_world_points)} frames, total {len(merged_point_cloud)} points"
- )
-
- # If too many points, randomly sample 100000 points
- max_points = 100000
- if len(merged_point_cloud) > max_points:
- print(
- f"🔽 Too many points, randomly sampling {max_points} points..."
- )
- # Randomly choose indices
- indices = np.random.choice(
- len(merged_point_cloud), size=max_points, replace=False
- )
- merged_point_cloud = merged_point_cloud[indices]
- if merged_colors is not None:
- merged_colors = merged_colors[indices]
- print(
- f"✅ Sampling done, kept {len(merged_point_cloud)} points"
- )
-
- # Save as PLY (with color)
- with open(points_output_path, "w") as f:
- f.write("ply\n")
- f.write("format ascii 1.0\n")
- f.write(f"element vertex {len(merged_point_cloud)}\n")
- f.write("property float x\n")
- f.write("property float y\n")
- f.write("property float z\n")
- if merged_colors is not None:
- f.write("property uchar red\n")
- f.write("property uchar green\n")
- f.write("property uchar blue\n")
- f.write("end_header\n")
- if merged_colors is None:
- for point in merged_point_cloud:
- if not (np.isnan(point).any() or np.isinf(point).any()):
- f.write(
- f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n"
- )
- else:
- for point, color in zip(merged_point_cloud, merged_colors):
- # Check point validity
- if not (np.isnan(point).any() or np.isinf(point).any()):
- r = int(np.clip(color[0], 0, 255))
- g = int(np.clip(color[1], 0, 255))
- b = int(np.clip(color[2], 0, 255))
- f.write(
- f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f} {r} {g} {b}\n"
- )
-
- print(f"💾 Point cloud saved to: {points_output_path}")
-
- except Exception as e:
- print(f"⚠️ Error saving point cloud: {e}")
- # If merge fails, try to log per-frame info
- print(f"🔍 Point cloud debug info:")
- for i, frame_points in enumerate(all_world_points):
- print(
- f" Frame {i}: {frame_points.shape if hasattr(frame_points, 'shape') else type(frame_points)}"
- )
- if (
- hasattr(frame_points, "shape")
- and len(frame_points.shape) >= 2
- ):
- print(
- f" Shape: {frame_points.shape}, Dtype: {frame_points.dtype}"
- )
- if frame_points.shape[0] > 0:
- print(
- f" Range: x[{np.min(frame_points[:, 0]):.3f}, {np.max(frame_points[:, 0]):.3f}] "
- f"y[{np.min(frame_points[:, 1]):.3f}, {np.max(frame_points[:, 1]):.3f}] "
- f"z[{np.min(frame_points[:, 2]):.3f}, {np.max(frame_points[:, 2]):.3f}]"
- )
-
- print(f"📁 Results saved to: {output_scene_dir}")
-
- # Visualize predicted pose trajectory
- if args.plot:
- visualize_predicted_poses(
- all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset"
- )
-
- print(f"🎉 Done!")
-
- except Exception as e:
- print(f"❌ Error occurred during processing: {e}")
- import traceback
-
- traceback.print_exc()
-
-
-if __name__ == "__main__":
- main()
diff --git a/FastVGGT/eval/eval_scannet.py b/FastVGGT/eval/eval_scannet.py
deleted file mode 100644
index 332132c45ab4fc3f7e768dbf3845d9cdb3ccc4eb..0000000000000000000000000000000000000000
--- a/FastVGGT/eval/eval_scannet.py
+++ /dev/null
@@ -1,208 +0,0 @@
-import argparse
-from pathlib import Path
-import numpy as np
-import torch
-import os
-import sys
-
-# Ensure project root is in sys.path for absolute imports like `vggt.*`
-ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
-if ROOT_DIR not in sys.path:
- sys.path.insert(0, ROOT_DIR)
-
-from vggt.models.vggt import VGGT
-from vggt.utils.eval_utils import (
- load_poses,
- get_vgg_input_imgs,
- get_sorted_image_paths,
- get_all_scenes,
- build_frame_selection,
- load_images_rgb,
- infer_vggt_and_reconstruct,
- evaluate_scene_and_save,
- compute_average_metrics_and_save,
-)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--data_dir", type=Path, default="/data/scannetv2/process_scannet"
- )
- parser.add_argument(
- "--gt_ply_dir",
- type=Path,
- default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans",
- )
- parser.add_argument("--output_path", type=Path, default="./eval_results")
- parser.add_argument("--merging", type=int, default=None)
- parser.add_argument("--plot", type=bool, default=True)
- parser.add_argument(
- "--depth_conf_thresh",
- type=float,
- default=3.0,
- help="Depth confidence threshold for filtering low confidence depth values",
- )
- parser.add_argument(
- "--chamfer_max_dist",
- type=float,
- default=0.5,
- help="Maximum distance threshold in Chamfer Distance computation, distances exceeding this value will be clipped",
- )
- parser.add_argument(
- "--input_frame",
- type=int,
- default=200,
- help="Maximum number of frames selected for processing per scene",
- )
- parser.add_argument(
- "--num_scenes",
- type=int,
- default=50,
- help="Maximum number of scenes to evaluate",
- )
- parser.add_argument(
- "--ckpt_path",
- type=str,
- default="./ckpt/model_tracker_fixed_e20.pt",
- help="Path to the model checkpoint file",
- )
- parser.add_argument(
- "--vis_attn_map",
- action="store_true",
- help="Whether to visualize attention maps during inference",
- )
- args = parser.parse_args()
- torch.manual_seed(33)
-
- # Scene sampling
- scannet_scenes = get_all_scenes(args.data_dir, args.num_scenes)
- print(f"Evaluate {len(scannet_scenes)} scenes")
-
- all_scenes_metrics = {"scenes": {}, "average": {}}
- # Force use of bf16 data type
- dtype = torch.bfloat16
- # Load VGGT model
- model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map)
- ckpt = torch.load(args.ckpt_path, map_location="cpu")
- incompat = model.load_state_dict(ckpt, strict=False)
- model = model.cuda().eval()
- model = model.to(torch.bfloat16)
-
- # Process each scene
- for scene in scannet_scenes:
- scene_dir = args.data_dir / f"{scene}"
- output_scene_dir = args.output_path / f"input_frame_{args.input_frame}" / scene
- if (output_scene_dir / "metrics.json").exists():
- continue
-
- # Load scene data
- images_dir = scene_dir / "color"
- pose_path = scene_dir / "pose"
- image_paths = get_sorted_image_paths(images_dir)
- poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_path)
- if (
- poses_gt is None
- or first_gt_pose is None
- or available_pose_frame_ids is None
- ):
- print(f"Skipping scene {scene}: no pose data")
- continue
-
- # Frame filtering
- selected_frame_ids, selected_image_paths, selected_pose_indices = (
- build_frame_selection(
- image_paths, available_pose_frame_ids, args.input_frame
- )
- )
-
- # Get corresponding poses
- c2ws = poses_gt[selected_pose_indices]
- image_paths = selected_image_paths
-
- if len(image_paths) == 0:
- print(f"No images found in {images_dir}")
- continue
-
- print("🚩Processing", scene, f"Found {len(image_paths)} images")
- all_cam_to_world_mat = []
- all_world_points = []
-
- try:
- # Load images
- images = load_images_rgb(image_paths)
-
- if not images or len(images) < 3:
- print(f"Skipping {scene}: insufficient valid images")
- continue
-
- frame_ids = selected_frame_ids
- images_array = np.stack(images)
- vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array)
- print(f"Patch dimensions: {patch_width}x{patch_height}")
-
- # Update model attention layers with dynamic patch dimensions
- model.update_patch_dimensions(patch_width, patch_height)
-
- # Inference + Reconstruction
- (
- extrinsic_np,
- intrinsic_np,
- all_world_points,
- all_point_colors,
- all_cam_to_world_mat,
- inference_time_ms,
- ) = infer_vggt_and_reconstruct(
- model, vgg_input, dtype, args.depth_conf_thresh, image_paths
- )
- print(f"Inference time: {inference_time_ms:.2f}ms")
-
- # Process results
- if not all_cam_to_world_mat or not all_world_points:
- print(
- f"Skipping {scene}: failed to obtain valid camera poses or point clouds"
- )
- continue
-
- # Evaluate and save
- metrics = evaluate_scene_and_save(
- scene,
- c2ws,
- first_gt_pose,
- frame_ids,
- all_cam_to_world_mat,
- all_world_points,
- output_scene_dir,
- args.gt_ply_dir,
- args.chamfer_max_dist,
- inference_time_ms,
- args.plot,
- )
- if metrics is not None:
- all_scenes_metrics["scenes"][scene] = {
- key: float(value)
- for key, value in metrics.items()
- if key
- in [
- "chamfer_distance",
- "ate",
- "are",
- "rpe_rot",
- "rpe_trans",
- "inference_time_ms",
- ]
- }
- print("Complete metrics", all_scenes_metrics["scenes"][scene])
-
- except Exception as e:
- print(f"Error processing scene {scene}: {e}")
- import traceback
-
- traceback.print_exc()
-
- # Summarize average metrics and save
- compute_average_metrics_and_save(
- all_scenes_metrics,
- args.output_path,
- args.input_frame,
- )
diff --git a/FastVGGT/eval/utils.py b/FastVGGT/eval/utils.py
deleted file mode 100644
index e8d5606560e1fc82a7b7b81df8b3b9f3b9ec8662..0000000000000000000000000000000000000000
--- a/FastVGGT/eval/utils.py
+++ /dev/null
@@ -1,142 +0,0 @@
-import numpy as np
-from scipy.spatial import cKDTree as KDTree
-
-
-def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
- """
- Args:
- - depthmap (HxW array):
- - camera_intrinsics: a 3x3 matrix
- Returns:
- pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
- """
- camera_intrinsics = np.float32(camera_intrinsics)
- H, W = depthmap.shape
-
- assert camera_intrinsics[0, 1] == 0.0
- assert camera_intrinsics[1, 0] == 0.0
- if pseudo_focal is None:
- fu = camera_intrinsics[0, 0]
- fv = camera_intrinsics[1, 1]
- else:
- assert pseudo_focal.shape == (H, W)
- fu = fv = pseudo_focal
- cu = camera_intrinsics[0, 2]
- cv = camera_intrinsics[1, 2]
-
- u, v = np.meshgrid(np.arange(W), np.arange(H))
- z_cam = depthmap
- x_cam = (u - cu) * z_cam / fu
- y_cam = (v - cv) * z_cam / fv
- X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
-
- valid_mask = depthmap > 0.0
- return X_cam, valid_mask
-
-
-def depthmap_to_absolute_camera_coordinates(
- depthmap, camera_intrinsics, camera_pose, **kw
-):
- """
- Args:
- - depthmap (HxW array):
- - camera_intrinsics: a 3x3 matrix
- - camera_pose: a 4x3 or 4x4 cam2world matrix
- Returns:
- pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
- """
- X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
-
- X_world = X_cam # default
- if camera_pose is not None:
-
- R_cam2world = camera_pose[:3, :3]
- t_cam2world = camera_pose[:3, 3]
-
- X_world = (
- np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
- )
-
- return X_world, valid_mask
-
-
-def completion_ratio(gt_points, rec_points, dist_th=0.05):
- gen_points_kd_tree = KDTree(rec_points)
- distances, _ = gen_points_kd_tree.query(gt_points)
- comp_ratio = np.mean((distances < dist_th).astype(np.float32))
- return comp_ratio
-
-
-def accuracy(gt_points, rec_points, gt_normals=None, rec_normals=None):
- gt_points_kd_tree = KDTree(gt_points)
- distances, idx = gt_points_kd_tree.query(rec_points, workers=-1)
- acc = np.mean(distances)
-
- acc_median = np.median(distances)
-
- if gt_normals is not None and rec_normals is not None:
- normal_dot = np.sum(gt_normals[idx] * rec_normals, axis=-1)
- normal_dot = np.abs(normal_dot)
-
- return acc, acc_median, np.mean(normal_dot), np.median(normal_dot)
-
- return acc, acc_median
-
-
-def completion(gt_points, rec_points, gt_normals=None, rec_normals=None):
- gt_points_kd_tree = KDTree(rec_points)
- distances, idx = gt_points_kd_tree.query(gt_points, workers=-1)
- comp = np.mean(distances)
- comp_median = np.median(distances)
-
- if gt_normals is not None and rec_normals is not None:
- normal_dot = np.sum(gt_normals * rec_normals[idx], axis=-1)
- normal_dot = np.abs(normal_dot)
-
- return comp, comp_median, np.mean(normal_dot), np.median(normal_dot)
-
- return comp, comp_median
-
-
-def compute_iou(pred_vox, target_vox):
- # Get voxel indices
- v_pred_indices = [voxel.grid_index for voxel in pred_vox.get_voxels()]
- v_target_indices = [voxel.grid_index for voxel in target_vox.get_voxels()]
-
- # Convert to sets for set operations
- v_pred_filled = set(tuple(np.round(x, 4)) for x in v_pred_indices)
- v_target_filled = set(tuple(np.round(x, 4)) for x in v_target_indices)
-
- # Compute intersection and union
- intersection = v_pred_filled & v_target_filled
- union = v_pred_filled | v_target_filled
-
- # Compute IoU
- iou = len(intersection) / len(union)
- return iou
-
-
-def colmap_to_opencv_intrinsics(K):
- """
- Modify camera intrinsics to follow a different convention.
- Coordinates of the center of the top-left pixels are by default:
- - (0.5, 0.5) in Colmap
- - (0,0) in OpenCV
- """
- K = K.copy()
- K[0, 2] -= 0.5
- K[1, 2] -= 0.5
- return K
-
-
-def opencv_to_colmap_intrinsics(K):
- """
- Modify camera intrinsics to follow a different convention.
- Coordinates of the center of the top-left pixels are by default:
- - (0.5, 0.5) in Colmap
- - (0,0) in OpenCV
- """
- K = K.copy()
- K[0, 2] += 0.5
- K[1, 2] += 0.5
- return K
diff --git a/FastVGGT/merging/__init__.py b/FastVGGT/merging/__init__.py
deleted file mode 100644
index 156ac80ee0ef4270bc02b0954f8473ff55380cce..0000000000000000000000000000000000000000
--- a/FastVGGT/merging/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from . import merge
-
-__all__ = ["merge"]
diff --git a/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc b/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 66a4e39454e637a8e8e0fa63f6708b1eb670aae9..0000000000000000000000000000000000000000
Binary files a/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/merging/__pycache__/merge.cpython-310.pyc b/FastVGGT/merging/__pycache__/merge.cpython-310.pyc
deleted file mode 100644
index 9fa7eff7fec3ec548a7532b57636ea9d7db57e16..0000000000000000000000000000000000000000
Binary files a/FastVGGT/merging/__pycache__/merge.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/merging/merge.py b/FastVGGT/merging/merge.py
deleted file mode 100644
index e2094b657cba6c9772a85f2ebe0240367fcec09c..0000000000000000000000000000000000000000
--- a/FastVGGT/merging/merge.py
+++ /dev/null
@@ -1,370 +0,0 @@
-import torch
-from typing import Tuple, Callable, Optional, Union
-
-
-@torch.jit.script
-def fast_similarity_chunks(
- a: torch.Tensor, b_transposed: torch.Tensor, chunk_size: int
-) -> Tuple[torch.Tensor, torch.Tensor]:
-
- B, num_src, C = a.shape
- original_dtype = a.dtype
-
- # Convert to bf16 for computation to improve performance and reduce memory usage
- a_bf16 = a.to(torch.bfloat16)
- b_transposed_bf16 = b_transposed.to(torch.bfloat16)
- node_max = torch.empty(B, num_src, device=a.device, dtype=original_dtype)
- node_idx = torch.empty(B, num_src, device=a.device, dtype=torch.long)
-
- # Process in chunks
- for i in range(0, num_src, chunk_size):
- end_i = min(i + chunk_size, num_src)
- a_chunk = a_bf16[:, i:end_i, :] # [B, chunk_size, C]
- scores_chunk = torch.bmm(a_chunk, b_transposed_bf16)
- chunk_max_bf16, chunk_idx = torch.max(scores_chunk, dim=2)
- chunk_max = chunk_max_bf16.to(original_dtype)
- node_max[:, i:end_i] = chunk_max
- node_idx[:, i:end_i] = chunk_idx
- return node_max, node_idx
-
-
-def do_nothing(
- x: torch.Tensor,
- extra_tensors=None,
- extra_tensors_2=None,
-) -> Union[
- torch.Tensor,
- Tuple[torch.Tensor, torch.Tensor],
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-]:
- if extra_tensors is not None and extra_tensors_2 is not None:
- return x, extra_tensors, extra_tensors_2
- elif extra_tensors is not None:
- return x, extra_tensors
- else:
- return x
-
-
-def token_merge_bipartite2d(
- metric: torch.Tensor,
- w: int,
- h: int,
- sx: int,
- sy: int,
- r: int,
- no_rand: bool = False,
- generator: Optional[torch.Generator] = None,
- enable_protection: bool = False,
-) -> Tuple[Callable, Callable]:
- """
- Divide tokens into source (src) and destination (dst) groups, and merge r tokens from src to dst.
- dst tokens are selected by randomly choosing one token from each (sx, sy) region.
- Optionally protect the top 10% of tokens from merging based on importance scores.
-
- Args:
- - metric [B, N, C]: Tensor for similarity computation, B=batch size, N=token count, C=feature dimension
- - w: Image width in tokens
- - h: Image height in tokens
- - sx: dst stride in x dimension, must divide w evenly
- - sy: dst stride in y dimension, must divide h evenly
- - r: Number of tokens to remove through merging
- - no_rand: If True, disable randomness (use only top-left token)
- - generator: Random number generator if no_rand is False and not None
- - enable_protection: If True, enable importance protection feature
-
- Returns:
- - (merge, unmerge): Two functions for merging tokens and restoring pre-merge state
- """
- B, N, _ = metric.shape # Batch size B, total tokens N
- if r <= 0:
- return do_nothing, do_nothing
-
- gather = torch.gather
-
- tokens_per_img = w * h + 5
- num_imgs = N // tokens_per_img
- assert tokens_per_img * num_imgs == N, "Token count doesn't match (w*h+5)*num_imgs"
-
- with torch.no_grad():
- # Determine whether to compute importance scores based on enable_protection
- if enable_protection:
- num_protected = int(N * 0.1)
- step = max(1, N // num_protected)
- protected_indices = torch.arange(0, N, step, device=metric.device)[
- :num_protected
- ]
- else:
- protected_indices = None
- num_protected = 0
-
- # Global idx_buffer_seq of length N; -1 indicates dst, 0 indicates src (maintain original logic)
- idx_buffer_seq = torch.zeros(N, device=metric.device, dtype=torch.int64)
- hsy, wsx = h // sy, w // sx # Number of blocks within each image
-
- # Mark first image entirely as dst
- if num_imgs > 0:
- idx_buffer_seq[:tokens_per_img] = -1
-
- # Process other images - fully vectorized batch operations
- if num_imgs > 1:
- cls_indices = (
- torch.arange(1, num_imgs, device=metric.device) * tokens_per_img
- )
- cls_indices = cls_indices[:, None] + torch.arange(5, device=metric.device)
- idx_buffer_seq[cls_indices.flatten()] = -1
- effective_h = min(hsy * sy, h)
- effective_w = min(wsx * sx, w)
- effective_grid_size = effective_h * effective_w
-
- if no_rand:
- base_pattern = torch.zeros(
- effective_grid_size, device=metric.device, dtype=torch.int64
- )
- grid_starts = (
- torch.arange(1, num_imgs, device=metric.device) * tokens_per_img + 5
- )
- grid_indices = grid_starts[:, None] + torch.arange(
- effective_grid_size, device=metric.device
- )
- idx_buffer_seq[grid_indices.flatten()] = base_pattern.repeat(
- num_imgs - 1
- )
- else:
- total_other_imgs = num_imgs - 1
- all_rand_idx = torch.randint(
- sy * sx,
- size=(total_other_imgs, hsy, wsx),
- device=metric.device,
- generator=generator,
- )
-
- scatter_src = -torch.ones(
- total_other_imgs, hsy, wsx, device=metric.device, dtype=torch.int64
- )
-
- idx_buffer_batch = torch.zeros(
- total_other_imgs,
- hsy,
- wsx,
- sy * sx,
- device=metric.device,
- dtype=torch.int64,
- )
- idx_buffer_batch.scatter_(
- dim=3,
- index=all_rand_idx.unsqueeze(-1),
- src=scatter_src.unsqueeze(-1),
- )
-
- idx_buffer_batch = (
- idx_buffer_batch.view(total_other_imgs, hsy, wsx, sy, sx)
- .transpose(2, 3)
- .reshape(total_other_imgs, hsy * sy, wsx * sx)
- )
-
- # Batch fill to target positions - still needs a small loop here, but operations are greatly reduced
- for i in range(total_other_imgs):
- img_idx = i + 1
- grid_start = img_idx * tokens_per_img + 5
- flat_view = idx_buffer_batch[
- i, :effective_h, :effective_w
- ].flatten()
- idx_buffer_seq[grid_start : grid_start + effective_grid_size] = (
- flat_view
- )
-
- rand_idx = idx_buffer_seq.reshape(1, -1, 1).argsort(dim=1)
- num_dst_orig = int((idx_buffer_seq == -1).sum())
-
- # Original src and dst indices
- a_idx_orig = rand_idx[:, num_dst_orig:, :]
- b_idx_orig = rand_idx[:, :num_dst_orig, :]
- a_idx = a_idx_orig
- b_idx = b_idx_orig
-
- if enable_protection:
- protected_idx = protected_indices.unsqueeze(0).unsqueeze(-1)
- num_protected_actual = protected_idx.shape[1]
- else:
- protected_idx = None
- num_protected_actual = 0
-
- num_src = a_idx.shape[1]
- num_dst = b_idx.shape[1]
-
- # Define an internal function to separate src, dst, and protected tokens
- def split(x):
- C = x.shape[-1]
-
- if enable_protection:
- src = gather(x, dim=1, index=a_idx.expand(B, num_src, C))
- dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
- protected = gather(
- x, dim=1, index=protected_idx.expand(B, num_protected_actual, C)
- )
- return src, dst, protected
- else:
- src = gather(x, dim=1, index=a_idx.expand(B, num_src, C))
- dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
- return src, dst
-
- # Compute cosine similarity (normalize first then dot product)
- metric = metric / metric.norm(dim=-1, keepdim=True)
- if enable_protection:
- a, b, protected = split(metric)
- else:
- a, b = split(metric)
-
- r = min(a.shape[1], r)
- num_src_actual = a.shape[1]
- chunk_size = min(5000, num_src_actual)
-
- node_max = torch.empty(B, num_src_actual, device=a.device, dtype=a.dtype)
- node_idx = torch.empty(B, num_src_actual, device=a.device, dtype=torch.long)
-
- b_transposed = b.transpose(-1, -2)
- node_max, node_idx = fast_similarity_chunks(a, b_transposed, chunk_size)
- edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
-
- # If protection is enabled, filter out protected tokens to ensure they are not merged
- if enable_protection:
- src_indices = a_idx[0, :, 0]
- protected_mask_src = torch.isin(src_indices, protected_indices)
- edge_flat = edge_idx[0, :, 0]
- valid_mask = ~protected_mask_src[edge_flat]
- valid_edges = edge_flat[valid_mask]
-
- valid_count = valid_edges.shape[0]
- r_actual = min(r, valid_count)
-
- unm_idx = valid_edges[r_actual:].unsqueeze(0).unsqueeze(-1)
- src_idx = valid_edges[:r_actual].unsqueeze(0).unsqueeze(-1)
- else:
- unm_idx = edge_idx[..., r:, :]
- src_idx = edge_idx[..., :r, :]
- r_actual = r
-
- # Get dst token indices corresponding to each src token to be merged
- dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
- r = r_actual
-
- # Define merge function to merge selected src tokens to corresponding dst tokens
- def merge(
- x: torch.Tensor,
- mode: str = "mean",
- extra_tensors=None,
- extra_tensors_2=None,
- ) -> Union[
- torch.Tensor,
- Tuple[torch.Tensor, torch.Tensor],
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
- ]:
- if enable_protection:
- src, dst, protected = split(x)
- else:
- src, dst = split(x)
-
- n, t1, c = src.shape
-
- # Extract unmerged src tokens - using actual unm_idx size
- unm_len = unm_idx.shape[1]
- unm = gather(src, dim=-2, index=unm_idx.expand(n, unm_len, c))
- src_len = src_idx.shape[1]
- src = gather(src, dim=-2, index=src_idx.expand(n, src_len, c))
- dst = dst.scatter_reduce(-2, dst_idx.expand(n, src_len, c), src, reduce=mode)
-
- # ---------------- Extra tensor processing ----------------
- merged_extra_1 = None
- merged_extra_2 = None
- if extra_tensors is not None:
- E_dim = extra_tensors.shape[-1]
- if enable_protection:
- src_e, dst_e, protected_e = split(extra_tensors)
- else:
- src_e, dst_e = split(extra_tensors)
-
- # Consistent with main tensor, only select r src tokens to be merged
- src_e_r = gather(src_e, dim=-2, index=src_idx.expand(n, src_len, E_dim))
- unm_e = gather(src_e, dim=-2, index=unm_idx.expand(n, unm_len, E_dim))
-
- dst_e = dst_e.scatter_reduce(
- -2, dst_idx.expand(n, src_len, E_dim), src_e_r, reduce=mode
- )
- if enable_protection:
- merged_extra_1 = torch.cat([unm_e, dst_e, protected_e], dim=1)
- else:
- merged_extra_1 = torch.cat([unm_e, dst_e], dim=1)
-
- if extra_tensors_2 is not None:
- E_dim_2 = extra_tensors_2.shape[-1]
- if enable_protection:
- src_e2, dst_e2, protected_e2 = split(extra_tensors_2)
- else:
- src_e2, dst_e2 = split(extra_tensors_2)
-
- src_e2_r = gather(src_e2, dim=-2, index=src_idx.expand(n, src_len, E_dim_2))
- unm_e2 = gather(src_e2, dim=-2, index=unm_idx.expand(n, unm_len, E_dim_2))
-
- dst_e2 = dst_e2.scatter_reduce(
- -2, dst_idx.expand(n, src_len, E_dim_2), src_e2_r, reduce=mode
- )
- if enable_protection:
- merged_extra_2 = torch.cat([unm_e2, dst_e2, protected_e2], dim=1)
- else:
- merged_extra_2 = torch.cat([unm_e2, dst_e2], dim=1)
-
- if enable_protection:
- main_result = torch.cat([unm, dst, protected], dim=1)
- else:
- main_result = torch.cat([unm, dst], dim=1)
-
- if merged_extra_1 is not None and merged_extra_2 is not None:
- return main_result, merged_extra_1, merged_extra_2
- elif merged_extra_1 is not None:
- return main_result, merged_extra_1
- else:
- return main_result
-
- # Define unmerge function to restore pre-merge state (for decoder)
- def unmerge(x: torch.Tensor) -> torch.Tensor:
- unm_len = unm_idx.shape[1]
- dst_len = num_dst
- src_len = src_idx.shape[1]
- unm = x[..., :unm_len, :]
- dst = x[..., unm_len : unm_len + dst_len, :]
-
- if enable_protection:
- protected = x[
- ..., unm_len + dst_len : unm_len + dst_len + num_protected_actual, :
- ]
-
- _, _, c = unm.shape
- src = gather(dst, dim=-2, index=dst_idx.expand(B, src_len, c))
- out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
- out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
- out.scatter_(
- dim=-2,
- index=gather(
- a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx
- ).expand(B, unm_len, c),
- src=unm,
- )
-
- out.scatter_(
- dim=-2,
- index=gather(
- a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx
- ).expand(B, src_len, c),
- src=src,
- )
-
- if enable_protection:
- out.scatter_(
- dim=-2,
- index=protected_idx.expand(B, num_protected_actual, c),
- src=protected,
- )
-
- return out
-
- return merge, unmerge
diff --git a/FastVGGT/requirements.txt b/FastVGGT/requirements.txt
deleted file mode 100644
index e343637f2d9a1550f55f14f8890c229bd467cdb2..0000000000000000000000000000000000000000
--- a/FastVGGT/requirements.txt
+++ /dev/null
@@ -1,15 +0,0 @@
-torch==2.3.1
-torchvision==0.18.1
-numpy==1.26.1
-Pillow
-huggingface_hub
-einops
-safetensors
-evo
-open3d
-matplotlib
-scipy
-opencv-python
-scikit-image
-tqdm
-
diff --git a/FastVGGT/vggt/__init__.py b/FastVGGT/vggt/__init__.py
deleted file mode 100644
index 3a2958f93a114495631a3ce270b99ee8ba8443f1..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
\ No newline at end of file
diff --git a/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index c5f3bfe676b899c95cba2c034de5292d685e03ce..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/dependency/__init__.py b/FastVGGT/vggt/dependency/__init__.py
deleted file mode 100644
index 3a2958f93a114495631a3ce270b99ee8ba8443f1..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/dependency/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
\ No newline at end of file
diff --git a/FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index cd09b8397b7a4717d49afb51e64d490fa11fc205..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc b/FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc
deleted file mode 100644
index 389177102f318fe4bcfd46f6ab4a04e1a8f39196..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/dependency/distortion.py b/FastVGGT/vggt/dependency/distortion.py
deleted file mode 100644
index 375b747086478050d601676ffaea3b25e80690b4..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/dependency/distortion.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import numpy as np
-import torch
-
-
-def apply_distortion(points, distortion_params):
- """
- Apply distortion to normalized camera coordinates.
-
- Args:
- points: Array of normalized camera coordinates
- distortion_params: Distortion parameters
-
- Returns:
- Distorted coordinates
- """
- # Simple passthrough for now - implement actual distortion if needed
- return points
-
-
-def iterative_undistortion(points, distortion_params, max_iter=10):
- """
- Remove distortion from normalized camera coordinates using iterative method.
-
- Args:
- points: Array of distorted normalized camera coordinates
- distortion_params: Distortion parameters
- max_iter: Maximum number of iterations
-
- Returns:
- Undistorted coordinates
- """
- # Simple passthrough for now - implement actual undistortion if needed
- return points
-
-
-def single_undistortion(points, distortion_params):
- """
- Remove distortion from normalized camera coordinates using single step.
-
- Args:
- points: Array of distorted normalized camera coordinates
- distortion_params: Distortion parameters
-
- Returns:
- Undistorted coordinates
- """
- # Simple passthrough for now - implement actual undistortion if needed
- return points
\ No newline at end of file
diff --git a/FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc
deleted file mode 100644
index 292dc30328d06ba3951a9eee0e99bc569c60a1d3..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc
deleted file mode 100644
index 838db11a0862e9907682572d8ebe3b7c1665e147..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc
deleted file mode 100644
index 6d6cd93db3a5987c35085d487d21f5513ea3d1fd..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc
deleted file mode 100644
index a3244c2b0364f9b533aab03446edb787dbed59d5..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc
deleted file mode 100644
index fc2e8df264f6accc5412e80b8580c8c48195aaf5..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/heads/camera_head.py b/FastVGGT/vggt/heads/camera_head.py
deleted file mode 100644
index 215adf39de23abd4975479d332250fcc3e2b54b9..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/heads/camera_head.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import math
-import numpy as np
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from vggt.layers import Mlp
-from vggt.layers.block import Block
-from vggt.heads.head_act import activate_pose
-
-
-class CameraHead(nn.Module):
- """
- CameraHead predicts camera parameters from token representations using iterative refinement.
-
- It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
- """
-
- def __init__(
- self,
- dim_in: int = 2048,
- trunk_depth: int = 4,
- pose_encoding_type: str = "absT_quaR_FoV",
- num_heads: int = 16,
- mlp_ratio: int = 4,
- init_values: float = 0.01,
- trans_act: str = "linear",
- quat_act: str = "linear",
- fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
- ):
- super().__init__()
-
- if pose_encoding_type == "absT_quaR_FoV":
- self.target_dim = 9
- else:
- raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
-
- self.trans_act = trans_act
- self.quat_act = quat_act
- self.fl_act = fl_act
- self.trunk_depth = trunk_depth
-
- # Build the trunk using a sequence of transformer blocks.
- self.trunk = nn.Sequential(
- *[
- Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
- for _ in range(trunk_depth)
- ]
- )
-
- # Normalizations for camera token and trunk output.
- self.token_norm = nn.LayerNorm(dim_in)
- self.trunk_norm = nn.LayerNorm(dim_in)
-
- # Learnable empty camera pose token.
- self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
- self.embed_pose = nn.Linear(self.target_dim, dim_in)
-
- # Module for producing modulation parameters: shift, scale, and a gate.
- self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
-
- # Adaptive layer normalization without affine parameters.
- self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
- self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
-
- def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
- """
- Forward pass to predict camera parameters.
-
- Args:
- aggregated_tokens_list (list): List of token tensors from the network;
- the last tensor is used for prediction.
- num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
-
- Returns:
- list: A list of predicted camera encodings (post-activation) from each iteration.
- """
- # Use tokens from the last block for camera prediction.
- tokens = aggregated_tokens_list[-1]
-
- # Extract the camera tokens
- pose_tokens = tokens[:, :, 0]
- pose_tokens = self.token_norm(pose_tokens)
-
- pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
- return pred_pose_enc_list
-
- def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
- """
- Iteratively refine camera pose predictions.
-
- Args:
- pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
- num_iterations (int): Number of refinement iterations.
-
- Returns:
- list: List of activated camera encodings from each iteration.
- """
- B, S, C = pose_tokens.shape # S is expected to be 1.
- pred_pose_enc = None
- pred_pose_enc_list = []
-
- for _ in range(num_iterations):
- # Use a learned empty pose for the first iteration.
- if pred_pose_enc is None:
- module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
- else:
- # Detach the previous prediction to avoid backprop through time.
- pred_pose_enc = pred_pose_enc.detach()
- module_input = self.embed_pose(pred_pose_enc)
-
- # Generate modulation parameters and split them into shift, scale, and gate components.
- shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
-
- # Adaptive layer normalization and modulation.
- pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
- pose_tokens_modulated = pose_tokens_modulated + pose_tokens
-
- pose_tokens_modulated = self.trunk(pose_tokens_modulated)
- # Compute the delta update for the pose encoding.
- pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
-
- if pred_pose_enc is None:
- pred_pose_enc = pred_pose_enc_delta
- else:
- pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
-
- # Apply final activation functions for translation, quaternion, and field-of-view.
- activated_pose = activate_pose(
- pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
- )
- pred_pose_enc_list.append(activated_pose)
-
- return pred_pose_enc_list
-
-
-def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
- """
- Modulate the input tensor using scaling and shifting parameters.
- """
- # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
- return x * (1 + scale) + shift
diff --git a/FastVGGT/vggt/heads/dpt_head.py b/FastVGGT/vggt/heads/dpt_head.py
deleted file mode 100644
index e20d65ef5bfeb23cf83ca748aedff738840b4ffd..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/heads/dpt_head.py
+++ /dev/null
@@ -1,598 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
-
-
-import os
-from typing import List, Dict, Tuple, Union, Optional
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from .head_act import activate_head
-from .utils import create_uv_grid, position_grid_to_embed
-
-
-class DPTHead(nn.Module):
- """
- DPT Head for dense prediction tasks.
-
- This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
- (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
- backbone and produces dense predictions by fusing multi-scale features.
-
- Args:
- dim_in (int): Input dimension (channels).
- patch_size (int, optional): Patch size. Default is 14.
- output_dim (int, optional): Number of output channels. Default is 4.
- activation (str, optional): Activation type. Default is "inv_log".
- conf_activation (str, optional): Confidence activation type. Default is "expp1".
- features (int, optional): Feature channels for intermediate representations. Default is 256.
- out_channels (List[int], optional): Output channels for each intermediate layer.
- intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
- pos_embed (bool, optional): Whether to use positional embedding. Default is True.
- feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
- down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
- """
-
- def __init__(
- self,
- dim_in: int,
- patch_size: int = 14,
- output_dim: int = 4,
- activation: str = "inv_log",
- conf_activation: str = "expp1",
- features: int = 256,
- out_channels: List[int] = [256, 512, 1024, 1024],
- intermediate_layer_idx: List[int] = [0, 1, 2, 3],
- pos_embed: bool = True,
- feature_only: bool = False,
- down_ratio: int = 1,
- ) -> None:
- super(DPTHead, self).__init__()
- self.patch_size = patch_size
- self.activation = activation
- self.conf_activation = conf_activation
- self.pos_embed = pos_embed
- self.feature_only = feature_only
- self.down_ratio = down_ratio
- self.intermediate_layer_idx = intermediate_layer_idx
-
- self.norm = nn.LayerNorm(dim_in)
-
- # Projection layers for each output channel from tokens.
- self.projects = nn.ModuleList(
- [
- nn.Conv2d(
- in_channels=dim_in,
- out_channels=oc,
- kernel_size=1,
- stride=1,
- padding=0,
- )
- for oc in out_channels
- ]
- )
-
- # Resize layers for upsampling feature maps.
- self.resize_layers = nn.ModuleList(
- [
- nn.ConvTranspose2d(
- in_channels=out_channels[0],
- out_channels=out_channels[0],
- kernel_size=4,
- stride=4,
- padding=0,
- ),
- nn.ConvTranspose2d(
- in_channels=out_channels[1],
- out_channels=out_channels[1],
- kernel_size=2,
- stride=2,
- padding=0,
- ),
- nn.Identity(),
- nn.Conv2d(
- in_channels=out_channels[3],
- out_channels=out_channels[3],
- kernel_size=3,
- stride=2,
- padding=1,
- ),
- ]
- )
-
- self.scratch = _make_scratch(out_channels, features, expand=False)
-
- # Attach additional modules to scratch.
- self.scratch.stem_transpose = nn.Identity() # Use Identity instead of None
- self.scratch.refinenet1 = _make_fusion_block(features)
- self.scratch.refinenet2 = _make_fusion_block(features)
- self.scratch.refinenet3 = _make_fusion_block(features)
- self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
-
- head_features_1 = features
- head_features_2 = 32
-
- if feature_only:
- self.scratch.output_conv1 = nn.Conv2d(
- head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
- )
- else:
- self.scratch.output_conv1 = nn.Conv2d(
- head_features_1,
- head_features_1 // 2,
- kernel_size=3,
- stride=1,
- padding=1,
- )
- conv2_in_channels = head_features_1 // 2
-
- self.scratch.output_conv2 = nn.Sequential(
- nn.Conv2d(
- conv2_in_channels,
- head_features_2,
- kernel_size=3,
- stride=1,
- padding=1,
- ),
- nn.ReLU(inplace=True),
- nn.Conv2d(
- head_features_2, output_dim, kernel_size=1, stride=1, padding=0
- ),
- )
-
- def forward(
- self,
- aggregated_tokens_list: List[torch.Tensor],
- images: torch.Tensor,
- patch_start_idx: int,
- frames_chunk_size: int = 8,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- """
- Forward pass through the DPT head, supports processing by chunking frames.
- Args:
- aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
- images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
- patch_start_idx (int): Starting index for patch tokens in the token sequence.
- Used to separate patch tokens from other tokens (e.g., camera or register tokens).
- frames_chunk_size (int, optional): Number of frames to process in each chunk.
- If None or larger than S, all frames are processed at once. Default: 8.
-
- Returns:
- Tensor or Tuple[Tensor, Tensor]:
- - If feature_only=True: Feature maps with shape [B, S, C, H, W]
- - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
- """
- B, S, _, H, W = images.shape
-
- # If frames_chunk_size is not specified or greater than S, process all frames at once
- if frames_chunk_size is None or frames_chunk_size >= S:
- return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
-
- # Otherwise, process frames in chunks to manage memory usage
- assert frames_chunk_size > 0
-
- # Process frames in batches
- all_preds = []
- all_conf = []
-
- for frames_start_idx in range(0, S, frames_chunk_size):
- frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
-
- # Process batch of frames
- if self.feature_only:
- chunk_output = self._forward_impl(
- aggregated_tokens_list,
- images,
- patch_start_idx,
- frames_start_idx,
- frames_end_idx,
- )
- all_preds.append(chunk_output)
- else:
- chunk_preds, chunk_conf = self._forward_impl(
- aggregated_tokens_list,
- images,
- patch_start_idx,
- frames_start_idx,
- frames_end_idx,
- )
- all_preds.append(chunk_preds)
- all_conf.append(chunk_conf)
-
- # Concatenate results along the sequence dimension
- if self.feature_only:
- return torch.cat(all_preds, dim=1)
- else:
- return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
-
- def _forward_impl(
- self,
- aggregated_tokens_list: List[torch.Tensor],
- images: torch.Tensor,
- patch_start_idx: int,
- frames_start_idx: Optional[int] = None,
- frames_end_idx: Optional[int] = None,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- """
- Implementation of the forward pass through the DPT head.
-
- This method processes a specific chunk of frames from the sequence.
-
- Args:
- aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
- images (Tensor): Input images with shape [B, S, 3, H, W].
- patch_start_idx (int): Starting index for patch tokens.
- frames_start_idx (int, optional): Starting index for frames to process.
- frames_end_idx (int, optional): Ending index for frames to process.
-
- Returns:
- Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
- """
- if frames_start_idx is not None and frames_end_idx is not None:
- images = images[:, frames_start_idx:frames_end_idx].contiguous()
-
- B, S, _, H, W = images.shape
-
- patch_h, patch_w = H // self.patch_size, W // self.patch_size
-
- out = []
- dpt_idx = 0
-
- for layer_idx in self.intermediate_layer_idx:
- x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
-
- # Select frames if processing a chunk
- if frames_start_idx is not None and frames_end_idx is not None:
- x = x[:, frames_start_idx:frames_end_idx]
-
- x = x.reshape(B * S, -1, x.shape[-1])
-
- x = self.norm(x)
-
- x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
-
- x = self.projects[dpt_idx](x)
- if self.pos_embed:
- x = self._apply_pos_embed(x, W, H)
- x = self.resize_layers[dpt_idx](x)
-
- out.append(x)
- dpt_idx += 1
-
- # Fuse features from multiple layers.
- out = self.scratch_forward(out)
- # Interpolate fused output to match target image resolution.
- out = custom_interpolate(
- out,
- (
- int(patch_h * self.patch_size / self.down_ratio),
- int(patch_w * self.patch_size / self.down_ratio),
- ),
- mode="bilinear",
- align_corners=True,
- )
-
- if self.pos_embed:
- out = self._apply_pos_embed(out, W, H)
-
- if self.feature_only:
- return out.view(B, S, *out.shape[1:])
-
- out = self.scratch.output_conv2(out)
- preds, conf = activate_head(
- out, activation=self.activation, conf_activation=self.conf_activation
- )
-
- preds = preds.view(B, S, *preds.shape[1:])
- conf = conf.view(B, S, *conf.shape[1:])
- return preds, conf
-
- def _apply_pos_embed(
- self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
- ) -> torch.Tensor:
- """
- Apply positional embedding to tensor x.
- """
- patch_w = x.shape[-1]
- patch_h = x.shape[-2]
- pos_embed = create_uv_grid(
- patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
- )
- pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
- pos_embed = pos_embed * ratio
- pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
- return x + pos_embed
-
- def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
- """
- Forward pass through the fusion blocks.
-
- Args:
- features (List[Tensor]): List of feature maps from different layers.
-
- Returns:
- Tensor: Fused feature map.
- """
- layer_1, layer_2, layer_3, layer_4 = features
-
- layer_1_rn = self.scratch.layer1_rn(layer_1)
- layer_2_rn = self.scratch.layer2_rn(layer_2)
- layer_3_rn = self.scratch.layer3_rn(layer_3)
- layer_4_rn = self.scratch.layer4_rn(layer_4)
-
- out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
- del layer_4_rn, layer_4
-
- out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
- del layer_3_rn, layer_3
-
- out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
- del layer_2_rn, layer_2
-
- out = self.scratch.refinenet1(out, layer_1_rn)
- del layer_1_rn, layer_1
-
- out = self.scratch.output_conv1(out)
- return out
-
-
-################################################################################
-# Modules
-################################################################################
-
-
-def _make_fusion_block(
- features: int,
- size: Optional[int] = None,
- has_residual: bool = True,
- groups: int = 1,
-) -> nn.Module:
- return FeatureFusionBlock(
- features,
- nn.ReLU(inplace=True),
- deconv=False,
- bn=False,
- expand=False,
- align_corners=True,
- size=size,
- has_residual=has_residual,
- groups=groups,
- )
-
-
-def _make_scratch(
- in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
-) -> nn.Module:
- scratch = nn.Module()
- out_shape1 = out_shape
- out_shape2 = out_shape
- out_shape3 = out_shape
- if len(in_shape) >= 4:
- out_shape4 = out_shape
-
- if expand:
- out_shape1 = out_shape
- out_shape2 = out_shape * 2
- out_shape3 = out_shape * 4
- if len(in_shape) >= 4:
- out_shape4 = out_shape * 8
-
- scratch.layer1_rn = nn.Conv2d(
- in_shape[0],
- out_shape1,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=False,
- groups=groups,
- )
- scratch.layer2_rn = nn.Conv2d(
- in_shape[1],
- out_shape2,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=False,
- groups=groups,
- )
- scratch.layer3_rn = nn.Conv2d(
- in_shape[2],
- out_shape3,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=False,
- groups=groups,
- )
- if len(in_shape) >= 4:
- scratch.layer4_rn = nn.Conv2d(
- in_shape[3],
- out_shape4,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=False,
- groups=groups,
- )
- return scratch
-
-
-class ResidualConvUnit(nn.Module):
- """Residual convolution module."""
-
- def __init__(self, features, activation, bn, groups=1):
- """Init.
-
- Args:
- features (int): number of features
- """
- super().__init__()
-
- self.bn = bn
- self.groups = groups
- self.conv1 = nn.Conv2d(
- features,
- features,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=True,
- groups=self.groups,
- )
- self.conv2 = nn.Conv2d(
- features,
- features,
- kernel_size=3,
- stride=1,
- padding=1,
- bias=True,
- groups=self.groups,
- )
-
- self.norm1 = None
- self.norm2 = None
-
- self.activation = activation
-
- def forward(self, x):
- """Forward pass.
-
- Args:
- x (tensor): input
-
- Returns:
- tensor: output
- """
-
- out = self.activation(x)
- out = self.conv1(out)
- if self.norm1 is not None:
- out = self.norm1(out)
-
- out = self.activation(out)
- out = self.conv2(out)
- if self.norm2 is not None:
- out = self.norm2(out)
-
- return out + x
-
-
-class FeatureFusionBlock(nn.Module):
- """Feature fusion block."""
-
- def __init__(
- self,
- features,
- activation,
- deconv=False,
- bn=False,
- expand=False,
- align_corners=True,
- size=None,
- has_residual=True,
- groups=1,
- ):
- """Init.
-
- Args:
- features (int): number of features
- """
- super(FeatureFusionBlock, self).__init__()
-
- self.deconv = deconv
- self.align_corners = align_corners
- self.groups = groups
- self.expand = expand
- out_features = features
- if self.expand == True:
- out_features = features // 2
-
- self.out_conv = nn.Conv2d(
- features,
- out_features,
- kernel_size=1,
- stride=1,
- padding=0,
- bias=True,
- groups=self.groups,
- )
-
- if has_residual:
- self.resConfUnit1 = ResidualConvUnit(
- features, activation, bn, groups=self.groups
- )
-
- self.has_residual = has_residual
- self.resConfUnit2 = ResidualConvUnit(
- features, activation, bn, groups=self.groups
- )
-
- self.size = size
-
- def forward(self, *xs, size=None):
- """Forward pass.
-
- Returns:
- tensor: output
- """
- output = xs[0]
-
- if self.has_residual:
- res = self.resConfUnit1(xs[1])
- output = output + res
-
- output = self.resConfUnit2(output)
-
- if (size is None) and (self.size is None):
- modifier = {"scale_factor": 2}
- elif size is None:
- modifier = {"size": self.size}
- else:
- modifier = {"size": size}
-
- output = custom_interpolate(
- output, **modifier, mode="bilinear", align_corners=self.align_corners
- )
- output = self.out_conv(output)
-
- return output
-
-
-def custom_interpolate(
- x: torch.Tensor,
- size: Optional[Tuple[int, int]] = None,
- scale_factor: Optional[float] = None,
- mode: str = "bilinear",
- align_corners: bool = True,
-) -> torch.Tensor:
- """
- Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
- """
- if size is None:
- size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
-
- INT_MAX = 1610612736
-
- input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
-
- if input_elements > INT_MAX:
- chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
- interpolated_chunks = [
- nn.functional.interpolate(
- chunk, size=size, mode=mode, align_corners=align_corners
- )
- for chunk in chunks
- ]
- x = torch.cat(interpolated_chunks, dim=0)
- return x.contiguous()
- else:
- return nn.functional.interpolate(
- x, size=size, mode=mode, align_corners=align_corners
- )
diff --git a/FastVGGT/vggt/heads/head_act.py b/FastVGGT/vggt/heads/head_act.py
deleted file mode 100644
index 2dedfcf1180a653dddc99623e60df625e5897489..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/heads/head_act.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import torch
-import torch.nn.functional as F
-
-
-def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
- """
- Activate pose parameters with specified activation functions.
-
- Args:
- pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
- trans_act: Activation type for translation component
- quat_act: Activation type for quaternion component
- fl_act: Activation type for focal length component
-
- Returns:
- Activated pose parameters tensor
- """
- T = pred_pose_enc[..., :3]
- quat = pred_pose_enc[..., 3:7]
- fl = pred_pose_enc[..., 7:] # or fov
-
- T = base_pose_act(T, trans_act)
- quat = base_pose_act(quat, quat_act)
- fl = base_pose_act(fl, fl_act) # or fov
-
- pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
-
- return pred_pose_enc
-
-
-def base_pose_act(pose_enc, act_type="linear"):
- """
- Apply basic activation function to pose parameters.
-
- Args:
- pose_enc: Tensor containing encoded pose parameters
- act_type: Activation type ("linear", "inv_log", "exp", "relu")
-
- Returns:
- Activated pose parameters
- """
- if act_type == "linear":
- return pose_enc
- elif act_type == "inv_log":
- return inverse_log_transform(pose_enc)
- elif act_type == "exp":
- return torch.exp(pose_enc)
- elif act_type == "relu":
- return F.relu(pose_enc)
- else:
- raise ValueError(f"Unknown act_type: {act_type}")
-
-
-def activate_head(out, activation="norm_exp", conf_activation="expp1"):
- """
- Process network output to extract 3D points and confidence values.
-
- Args:
- out: Network output tensor (B, C, H, W)
- activation: Activation type for 3D points
- conf_activation: Activation type for confidence values
-
- Returns:
- Tuple of (3D points tensor, confidence tensor)
- """
- # Move channels from last dim to the 4th dimension => (B, H, W, C)
- fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
-
- # Split into xyz (first C-1 channels) and confidence (last channel)
- xyz = fmap[:, :, :, :-1]
- conf = fmap[:, :, :, -1]
-
- if activation == "norm_exp":
- d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
- xyz_normed = xyz / d
- pts3d = xyz_normed * torch.expm1(d)
- elif activation == "norm":
- pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
- elif activation == "exp":
- pts3d = torch.exp(xyz)
- elif activation == "relu":
- pts3d = F.relu(xyz)
- elif activation == "inv_log":
- pts3d = inverse_log_transform(xyz)
- elif activation == "xy_inv_log":
- xy, z = xyz.split([2, 1], dim=-1)
- z = inverse_log_transform(z)
- pts3d = torch.cat([xy * z, z], dim=-1)
- elif activation == "sigmoid":
- pts3d = torch.sigmoid(xyz)
- elif activation == "linear":
- pts3d = xyz
- else:
- raise ValueError(f"Unknown activation: {activation}")
-
- if conf_activation == "expp1":
- conf_out = 1 + conf.exp()
- elif conf_activation == "expp0":
- conf_out = conf.exp()
- elif conf_activation == "sigmoid":
- conf_out = torch.sigmoid(conf)
- else:
- raise ValueError(f"Unknown conf_activation: {conf_activation}")
-
- return pts3d, conf_out
-
-
-def inverse_log_transform(y):
- """
- Apply inverse log transform: sign(y) * (exp(|y|) - 1)
-
- Args:
- y: Input tensor
-
- Returns:
- Transformed tensor
- """
- return torch.sign(y) * (torch.expm1(torch.abs(y)))
diff --git a/FastVGGT/vggt/heads/track_head.py b/FastVGGT/vggt/heads/track_head.py
deleted file mode 100644
index a4f1d9bd83cca1f74f97a644a02b984904f84706..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/heads/track_head.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch.nn as nn
-from .dpt_head import DPTHead
-from .track_modules.base_track_predictor import BaseTrackerPredictor
-
-
-class TrackHead(nn.Module):
- """
- Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
- The tracking is performed iteratively, refining predictions over multiple iterations.
- """
-
- def __init__(
- self,
- dim_in,
- patch_size=14,
- features=128,
- iters=4,
- predict_conf=True,
- stride=2,
- corr_levels=7,
- corr_radius=4,
- hidden_size=384,
- ):
- """
- Initialize the TrackHead module.
-
- Args:
- dim_in (int): Input dimension of tokens from the backbone.
- patch_size (int): Size of image patches used in the vision transformer.
- features (int): Number of feature channels in the feature extractor output.
- iters (int): Number of refinement iterations for tracking predictions.
- predict_conf (bool): Whether to predict confidence scores for tracked points.
- stride (int): Stride value for the tracker predictor.
- corr_levels (int): Number of correlation pyramid levels
- corr_radius (int): Radius for correlation computation, controlling the search area.
- hidden_size (int): Size of hidden layers in the tracker network.
- """
- super().__init__()
-
- self.patch_size = patch_size
-
- # Feature extractor based on DPT architecture
- # Processes tokens into feature maps for tracking
- self.feature_extractor = DPTHead(
- dim_in=dim_in,
- patch_size=patch_size,
- features=features,
- feature_only=True, # Only output features, no activation
- down_ratio=2, # Reduces spatial dimensions by factor of 2
- pos_embed=False,
- )
-
- # Tracker module that predicts point trajectories
- # Takes feature maps and predicts coordinates and visibility
- self.tracker = BaseTrackerPredictor(
- latent_dim=features, # Match the output_dim of feature extractor
- predict_conf=predict_conf,
- stride=stride,
- corr_levels=corr_levels,
- corr_radius=corr_radius,
- hidden_size=hidden_size,
- )
-
- self.iters = iters
-
- def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
- """
- Forward pass of the TrackHead.
-
- Args:
- aggregated_tokens_list (list): List of aggregated tokens from the backbone.
- images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
- B = batch size, S = sequence length.
- patch_start_idx (int): Starting index for patch tokens.
- query_points (torch.Tensor, optional): Initial query points to track.
- If None, points are initialized by the tracker.
- iters (int, optional): Number of refinement iterations. If None, uses self.iters.
-
- Returns:
- tuple:
- - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
- - vis_scores (torch.Tensor): Visibility scores for tracked points.
- - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
- """
- B, S, _, H, W = images.shape
-
- # Extract features from tokens
- # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
- feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
-
- # Use default iterations if not specified
- if iters is None:
- iters = self.iters
-
- # Perform tracking using the extracted features
- coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters)
-
- return coord_preds, vis_scores, conf_scores
diff --git a/FastVGGT/vggt/heads/track_modules/__init__.py b/FastVGGT/vggt/heads/track_modules/__init__.py
deleted file mode 100644
index 0952fcc3f57e34b3747962e9ebd6fc57aeea63fa..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/heads/track_modules/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 5ca81dbf0ca49b601ca2662314a3391c8ba00890..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc
deleted file mode 100644
index 350f8ca52c0aca644125fea47f83d49337478bcb..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc
deleted file mode 100644
index 81e9466ff91b16d265807533eeada2e8e674d7e0..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc
deleted file mode 100644
index d3fe19d5e60445f0fa06866bffaa2b8b2e60a6f8..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc
deleted file mode 100644
index cac348bf405bc00bc6578760ba8ad57185f33b81..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/heads/track_modules/base_track_predictor.py b/FastVGGT/vggt/heads/track_modules/base_track_predictor.py
deleted file mode 100644
index 3ce8ec4b66fff236e015d1bcaf85c8237a52be7a..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/heads/track_modules/base_track_predictor.py
+++ /dev/null
@@ -1,209 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn as nn
-from einops import rearrange, repeat
-
-
-from .blocks import EfficientUpdateFormer, CorrBlock
-from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
-from .modules import Mlp
-
-
-class BaseTrackerPredictor(nn.Module):
- def __init__(
- self,
- stride=1,
- corr_levels=5,
- corr_radius=4,
- latent_dim=128,
- hidden_size=384,
- use_spaceatt=True,
- depth=6,
- max_scale=518,
- predict_conf=True,
- ):
- super(BaseTrackerPredictor, self).__init__()
- """
- The base template to create a track predictor
-
- Modified from https://github.com/facebookresearch/co-tracker/
- and https://github.com/facebookresearch/vggsfm
- """
-
- self.stride = stride
- self.latent_dim = latent_dim
- self.corr_levels = corr_levels
- self.corr_radius = corr_radius
- self.hidden_size = hidden_size
- self.max_scale = max_scale
- self.predict_conf = predict_conf
-
- self.flows_emb_dim = latent_dim // 2
-
- self.corr_mlp = Mlp(
- in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
- hidden_features=self.hidden_size,
- out_features=self.latent_dim,
- )
-
- self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
-
- self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
-
- space_depth = depth if use_spaceatt else 0
- time_depth = depth
-
- self.updateformer = EfficientUpdateFormer(
- space_depth=space_depth,
- time_depth=time_depth,
- input_dim=self.transformer_dim,
- hidden_size=self.hidden_size,
- output_dim=self.latent_dim + 2,
- mlp_ratio=4.0,
- add_space_attn=use_spaceatt,
- )
-
- self.fmap_norm = nn.LayerNorm(self.latent_dim)
- self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
-
- # A linear layer to update track feats at each iteration
- self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
-
- self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
-
- if predict_conf:
- self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
-
- def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
- """
- query_points: B x N x 2, the number of batches, tracks, and xy
- fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
- note HH and WW is the size of feature maps instead of original images
- """
- B, N, D = query_points.shape
- B, S, C, HH, WW = fmaps.shape
-
- assert D == 2, "Input points must be 2D coordinates"
-
- # apply a layernorm to fmaps here
- fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
- fmaps = fmaps.permute(0, 1, 4, 2, 3)
-
- # Scale the input query_points because we may downsample the images
- # by down_ratio or self.stride
- # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
- # its query_points should be query_points/4
- if down_ratio > 1:
- query_points = query_points / float(down_ratio)
-
- query_points = query_points / float(self.stride)
-
- # Init with coords as the query points
- # It means the search will start from the position of query points at the reference frames
- coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
-
- # Sample/extract the features of the query points in the query frame
- query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
-
- # init track feats by query feats
- track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
- # back up the init coords
- coords_backup = coords.clone()
-
- fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
-
- coord_preds = []
-
- # Iterative Refinement
- for _ in range(iters):
- # Detach the gradients from the last iteration
- # (in my experience, not very important for performance)
- coords = coords.detach()
-
- fcorrs = fcorr_fn.corr_sample(track_feats, coords)
-
- corr_dim = fcorrs.shape[3]
- fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
- fcorrs_ = self.corr_mlp(fcorrs_)
-
- # Movement of current coords relative to query points
- flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
-
- flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
-
- # (In my trials, it is also okay to just add the flows_emb instead of concat)
- flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
-
- track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
-
- # Concatenate them as the input for the transformers
- transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
-
- # 2D positional embed
- # TODO: this can be much simplified
- pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
- sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
-
- sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
-
- x = transformer_input + sampled_pos_emb
-
- # Add the query ref token to the track feats
- query_ref_token = torch.cat(
- [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
- )
- x = x + query_ref_token.to(x.device).to(x.dtype)
-
- # B, N, S, C
- x = rearrange(x, "(b n) s d -> b n s d", b=B)
-
- # Compute the delta coordinates and delta track features
- delta, _ = self.updateformer(x)
-
- # BN, S, C
- delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
- delta_coords_ = delta[:, :, :2]
- delta_feats_ = delta[:, :, 2:]
-
- track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
- delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
-
- # Update the track features
- track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
-
- track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
-
- # B x S x N x 2
- coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
-
- # Force coord0 as query
- # because we assume the query points should not be changed
- coords[:, 0] = coords_backup[:, 0]
-
- # The predicted tracks are in the original image scale
- if down_ratio > 1:
- coord_preds.append(coords * self.stride * down_ratio)
- else:
- coord_preds.append(coords * self.stride)
-
- # B, S, N
- vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
- if apply_sigmoid:
- vis_e = torch.sigmoid(vis_e)
-
- if self.predict_conf:
- conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
- if apply_sigmoid:
- conf_e = torch.sigmoid(conf_e)
- else:
- conf_e = None
-
- if return_feat:
- return coord_preds, vis_e, track_feats, query_track_feat, conf_e
- else:
- return coord_preds, vis_e, conf_e
diff --git a/FastVGGT/vggt/heads/track_modules/blocks.py b/FastVGGT/vggt/heads/track_modules/blocks.py
deleted file mode 100644
index 15c161c89ef99742b0f2c6f397c9121fe9301e08..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/heads/track_modules/blocks.py
+++ /dev/null
@@ -1,236 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-# Modified from https://github.com/facebookresearch/co-tracker/
-
-import math
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .utils import bilinear_sampler
-from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
-
-
-class EfficientUpdateFormer(nn.Module):
- """
- Transformer model that updates track estimates.
- """
-
- def __init__(
- self,
- space_depth=6,
- time_depth=6,
- input_dim=320,
- hidden_size=384,
- num_heads=8,
- output_dim=130,
- mlp_ratio=4.0,
- add_space_attn=True,
- num_virtual_tracks=64,
- ):
- super().__init__()
-
- self.out_channels = 2
- self.num_heads = num_heads
- self.hidden_size = hidden_size
- self.add_space_attn = add_space_attn
-
- # Add input LayerNorm before linear projection
- self.input_norm = nn.LayerNorm(input_dim)
- self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
-
- # Add output LayerNorm before final projection
- self.output_norm = nn.LayerNorm(hidden_size)
- self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
- self.num_virtual_tracks = num_virtual_tracks
-
- if self.add_space_attn:
- self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
- else:
- self.virual_tracks = None
-
- self.time_blocks = nn.ModuleList(
- [
- AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
- for _ in range(time_depth)
- ]
- )
-
- if add_space_attn:
- self.space_virtual_blocks = nn.ModuleList(
- [
- AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
- for _ in range(space_depth)
- ]
- )
- self.space_point2virtual_blocks = nn.ModuleList(
- [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
- )
- self.space_virtual2point_blocks = nn.ModuleList(
- [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
- )
- assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
- self.initialize_weights()
-
- def initialize_weights(self):
- def _basic_init(module):
- if isinstance(module, nn.Linear):
- torch.nn.init.xavier_uniform_(module.weight)
- if module.bias is not None:
- nn.init.constant_(module.bias, 0)
- torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
-
- self.apply(_basic_init)
-
- def forward(self, input_tensor, mask=None):
- # Apply input LayerNorm
- input_tensor = self.input_norm(input_tensor)
- tokens = self.input_transform(input_tensor)
-
- init_tokens = tokens
-
- B, _, T, _ = tokens.shape
-
- if self.add_space_attn:
- virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
- tokens = torch.cat([tokens, virtual_tokens], dim=1)
-
- _, N, _, _ = tokens.shape
-
- j = 0
- for i in range(len(self.time_blocks)):
- time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
-
- time_tokens = self.time_blocks[i](time_tokens)
-
- tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
- if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
- space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
- point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
- virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
-
- virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
- virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
- point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
-
- space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
- tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
- j += 1
-
- if self.add_space_attn:
- tokens = tokens[:, : N - self.num_virtual_tracks]
-
- tokens = tokens + init_tokens
-
- # Apply output LayerNorm before final projection
- tokens = self.output_norm(tokens)
- flow = self.flow_head(tokens)
-
- return flow, None
-
-
-class CorrBlock:
- def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
- """
- Build a pyramid of feature maps from the input.
-
- fmaps: Tensor (B, S, C, H, W)
- num_levels: number of pyramid levels (each downsampled by factor 2)
- radius: search radius for sampling correlation
- multiple_track_feats: if True, split the target features per pyramid level
- padding_mode: passed to grid_sample / bilinear_sampler
- """
- B, S, C, H, W = fmaps.shape
- self.S, self.C, self.H, self.W = S, C, H, W
- self.num_levels = num_levels
- self.radius = radius
- self.padding_mode = padding_mode
- self.multiple_track_feats = multiple_track_feats
-
- # Build pyramid: each level is half the spatial resolution of the previous
- self.fmaps_pyramid = [fmaps] # level 0 is full resolution
- current_fmaps = fmaps
- for i in range(num_levels - 1):
- B, S, C, H, W = current_fmaps.shape
- # Merge batch & sequence dimensions
- current_fmaps = current_fmaps.reshape(B * S, C, H, W)
- # Avg pool down by factor 2
- current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
- _, _, H_new, W_new = current_fmaps.shape
- current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
- self.fmaps_pyramid.append(current_fmaps)
-
- # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
- # This grid is added to the (scaled) coordinate centroids.
- r = self.radius
- dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
- dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
- # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
- self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
-
- def corr_sample(self, targets, coords):
- """
- Instead of storing the entire correlation pyramid, we compute each level's correlation
- volume, sample it immediately, then discard it. This saves GPU memory.
-
- Args:
- targets: Tensor (B, S, N, C) — features for the current targets.
- coords: Tensor (B, S, N, 2) — coordinates at full resolution.
-
- Returns:
- Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
- """
- B, S, N, C = targets.shape
-
- # If you have multiple track features, split them per level.
- if self.multiple_track_feats:
- targets_split = torch.split(targets, C // self.num_levels, dim=-1)
-
- out_pyramid = []
- for i, fmaps in enumerate(self.fmaps_pyramid):
- # Get current spatial resolution H, W for this pyramid level.
- B, S, C, H, W = fmaps.shape
- # Reshape feature maps for correlation computation:
- # fmap2s: (B, S, C, H*W)
- fmap2s = fmaps.view(B, S, C, H * W)
- # Choose appropriate target features.
- fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
-
- # Compute correlation directly
- corrs = compute_corr_level(fmap1, fmap2s, C)
- corrs = corrs.view(B, S, N, H, W)
-
- # Prepare sampling grid:
- # Scale down the coordinates for the current level.
- centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
- # Make sure our precomputed delta grid is on the same device/dtype.
- delta_lvl = self.delta.to(coords.device).to(coords.dtype)
- # Now the grid for grid_sample is:
- # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
- coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
-
- # Sample from the correlation volume using bilinear interpolation.
- # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
- corrs_sampled = bilinear_sampler(
- corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
- )
- # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
- corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
- out_pyramid.append(corrs_sampled)
-
- # Concatenate all levels along the last dimension.
- out = torch.cat(out_pyramid, dim=-1).contiguous()
- return out
-
-
-def compute_corr_level(fmap1, fmap2s, C):
- # fmap1: (B, S, N, C)
- # fmap2s: (B, S, C, H*W)
- corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
- corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
- return corrs / math.sqrt(C)
diff --git a/FastVGGT/vggt/heads/track_modules/modules.py b/FastVGGT/vggt/heads/track_modules/modules.py
deleted file mode 100644
index 12de4f1ad76364d4665e53ac80e1037fadf98d08..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/heads/track_modules/modules.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from functools import partial
-from typing import Callable
-import collections
-from torch import Tensor
-from itertools import repeat
-
-
-# From PyTorch internals
-def _ntuple(n):
- def parse(x):
- if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
- return tuple(x)
- return tuple(repeat(x, n))
-
- return parse
-
-
-def exists(val):
- return val is not None
-
-
-def default(val, d):
- return val if exists(val) else d
-
-
-to_2tuple = _ntuple(2)
-
-
-class ResidualBlock(nn.Module):
- """
- ResidualBlock: construct a block of two conv layers with residual connections
- """
-
- def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
- super(ResidualBlock, self).__init__()
-
- self.conv1 = nn.Conv2d(
- in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros"
- )
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros")
- self.relu = nn.ReLU(inplace=True)
-
- num_groups = planes // 8
-
- if norm_fn == "group":
- self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
- self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
- if not stride == 1:
- self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
-
- elif norm_fn == "batch":
- self.norm1 = nn.BatchNorm2d(planes)
- self.norm2 = nn.BatchNorm2d(planes)
- if not stride == 1:
- self.norm3 = nn.BatchNorm2d(planes)
-
- elif norm_fn == "instance":
- self.norm1 = nn.InstanceNorm2d(planes)
- self.norm2 = nn.InstanceNorm2d(planes)
- if not stride == 1:
- self.norm3 = nn.InstanceNorm2d(planes)
-
- elif norm_fn == "none":
- self.norm1 = nn.Sequential()
- self.norm2 = nn.Sequential()
- if not stride == 1:
- self.norm3 = nn.Sequential()
- else:
- raise NotImplementedError
-
- if stride == 1:
- self.downsample = None
- else:
- self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
-
- def forward(self, x):
- y = x
- y = self.relu(self.norm1(self.conv1(y)))
- y = self.relu(self.norm2(self.conv2(y)))
-
- if self.downsample is not None:
- x = self.downsample(x)
-
- return self.relu(x + y)
-
-
-class Mlp(nn.Module):
- """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
-
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- norm_layer=None,
- bias=True,
- drop=0.0,
- use_conv=False,
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- bias = to_2tuple(bias)
- drop_probs = to_2tuple(drop)
- linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
-
- self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
- self.act = act_layer()
- self.drop1 = nn.Dropout(drop_probs[0])
- self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
- self.drop2 = nn.Dropout(drop_probs[1])
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop1(x)
- x = self.fc2(x)
- x = self.drop2(x)
- return x
-
-
-class AttnBlock(nn.Module):
- def __init__(
- self,
- hidden_size,
- num_heads,
- attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
- mlp_ratio=4.0,
- **block_kwargs,
- ):
- """
- Self attention block
- """
- super().__init__()
-
- self.norm1 = nn.LayerNorm(hidden_size)
- self.norm2 = nn.LayerNorm(hidden_size)
-
- self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
-
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
-
- self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
-
- def forward(self, x, mask=None):
- # Prepare the mask for PyTorch's attention (it expects a different format)
- # attn_mask = mask if mask is not None else None
- # Normalize before attention
- x = self.norm1(x)
-
- # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
- # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
-
- attn_output, _ = self.attn(x, x, x)
-
- # Add & Norm
- x = x + attn_output
- x = x + self.mlp(self.norm2(x))
- return x
-
-
-class CrossAttnBlock(nn.Module):
- def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
- """
- Cross attention block
- """
- super().__init__()
-
- self.norm1 = nn.LayerNorm(hidden_size)
- self.norm_context = nn.LayerNorm(hidden_size)
- self.norm2 = nn.LayerNorm(hidden_size)
-
- self.cross_attn = nn.MultiheadAttention(
- embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
- )
-
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
-
- self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
-
- def forward(self, x, context, mask=None):
- # Normalize inputs
- x = self.norm1(x)
- context = self.norm_context(context)
-
- # Apply cross attention
- # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
- attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
-
- # Add & Norm
- x = x + attn_output
- x = x + self.mlp(self.norm2(x))
- return x
diff --git a/FastVGGT/vggt/heads/track_modules/utils.py b/FastVGGT/vggt/heads/track_modules/utils.py
deleted file mode 100644
index 3f1fffeaedd33c7f1c2ef54220e24a2a0e5a57b2..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/heads/track_modules/utils.py
+++ /dev/null
@@ -1,223 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-# Modified from https://github.com/facebookresearch/vggsfm
-# and https://github.com/facebookresearch/co-tracker/tree/main
-
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from typing import Optional, Tuple, Union
-
-
-def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
- """
- This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
- It is a wrapper of get_2d_sincos_pos_embed_from_grid.
- Args:
- - embed_dim: The embedding dimension.
- - grid_size: The grid size.
- Returns:
- - pos_embed: The generated 2D positional embedding.
- """
- if isinstance(grid_size, tuple):
- grid_size_h, grid_size_w = grid_size
- else:
- grid_size_h = grid_size_w = grid_size
- grid_h = torch.arange(grid_size_h, dtype=torch.float)
- grid_w = torch.arange(grid_size_w, dtype=torch.float)
- grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
- grid = torch.stack(grid, dim=0)
- grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- if return_grid:
- return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid)
- return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
-
-
-def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
- """
- This function generates a 2D positional embedding from a given grid using sine and cosine functions.
-
- Args:
- - embed_dim: The embedding dimension.
- - grid: The grid to generate the embedding from.
-
- Returns:
- - emb: The generated 2D positional embedding.
- """
- assert embed_dim % 2 == 0
-
- # use half of dimensions to encode grid_h
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
-
- emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
- return emb
-
-
-def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
- """
- This function generates a 1D positional embedding from a given grid using sine and cosine functions.
-
- Args:
- - embed_dim: The embedding dimension.
- - pos: The position to generate the embedding from.
-
- Returns:
- - emb: The generated 1D positional embedding.
- """
- assert embed_dim % 2 == 0
- omega = torch.arange(embed_dim // 2, dtype=torch.double)
- omega /= embed_dim / 2.0
- omega = 1.0 / 10000**omega # (D/2,)
-
- pos = pos.reshape(-1) # (M,)
- out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
-
- emb_sin = torch.sin(out) # (M, D/2)
- emb_cos = torch.cos(out) # (M, D/2)
-
- emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
- return emb[None].float()
-
-
-def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
- """
- This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
-
- Args:
- - xy: The coordinates to generate the embedding from.
- - C: The size of the embedding.
- - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
-
- Returns:
- - pe: The generated 2D positional embedding.
- """
- B, N, D = xy.shape
- assert D == 2
-
- x = xy[:, :, 0:1]
- y = xy[:, :, 1:2]
- div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
-
- pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
- pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
-
- pe_x[:, :, 0::2] = torch.sin(x * div_term)
- pe_x[:, :, 1::2] = torch.cos(x * div_term)
-
- pe_y[:, :, 0::2] = torch.sin(y * div_term)
- pe_y[:, :, 1::2] = torch.cos(y * div_term)
-
- pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
- if cat_coords:
- pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
- return pe
-
-
-def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
- r"""Sample a tensor using bilinear interpolation
-
- `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
- coordinates :attr:`coords` using bilinear interpolation. It is the same
- as `torch.nn.functional.grid_sample()` but with a different coordinate
- convention.
-
- The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
- :math:`B` is the batch size, :math:`C` is the number of channels,
- :math:`H` is the height of the image, and :math:`W` is the width of the
- image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
- interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
-
- Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
- in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
- that in this case the order of the components is slightly different
- from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
-
- If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
- in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
- left-most image pixel :math:`W-1` to the center of the right-most
- pixel.
-
- If `align_corners` is `False`, the coordinate :math:`x` is assumed to
- be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
- the left-most pixel :math:`W` to the right edge of the right-most
- pixel.
-
- Similar conventions apply to the :math:`y` for the range
- :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
- :math:`[0,T-1]` and :math:`[0,T]`.
-
- Args:
- input (Tensor): batch of input images.
- coords (Tensor): batch of coordinates.
- align_corners (bool, optional): Coordinate convention. Defaults to `True`.
- padding_mode (str, optional): Padding mode. Defaults to `"border"`.
-
- Returns:
- Tensor: sampled points.
- """
- coords = coords.detach().clone()
- ############################################################
- # IMPORTANT:
- coords = coords.to(input.device).to(input.dtype)
- ############################################################
-
- sizes = input.shape[2:]
-
- assert len(sizes) in [2, 3]
-
- if len(sizes) == 3:
- # t x y -> x y t to match dimensions T H W in grid_sample
- coords = coords[..., [1, 2, 0]]
-
- if align_corners:
- scale = torch.tensor(
- [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
- )
- else:
- scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
-
- coords.mul_(scale) # coords = coords * scale
- coords.sub_(1) # coords = coords - 1
-
- return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
-
-
-def sample_features4d(input, coords):
- r"""Sample spatial features
-
- `sample_features4d(input, coords)` samples the spatial features
- :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
-
- The field is sampled at coordinates :attr:`coords` using bilinear
- interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
- 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
- same convention as :func:`bilinear_sampler` with `align_corners=True`.
-
- The output tensor has one feature per point, and has shape :math:`(B,
- R, C)`.
-
- Args:
- input (Tensor): spatial features.
- coords (Tensor): points.
-
- Returns:
- Tensor: sampled features.
- """
-
- B, _, _, _ = input.shape
-
- # B R 2 -> B R 1 2
- coords = coords.unsqueeze(2)
-
- # B C R 1
- feats = bilinear_sampler(input, coords)
-
- return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
diff --git a/FastVGGT/vggt/heads/utils.py b/FastVGGT/vggt/heads/utils.py
deleted file mode 100644
index 533fc8ae67a75cd0a94d5ca96dc5a0513446c64f..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/heads/utils.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn as nn
-
-
-def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
- """
- Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
-
- Args:
- pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
- embed_dim: Output channel dimension for embeddings
-
- Returns:
- Tensor of shape (H, W, embed_dim) with positional embeddings
- """
- H, W, grid_dim = pos_grid.shape
- assert grid_dim == 2
- pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
-
- # Process x and y coordinates separately
- emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
- emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
-
- # Combine and reshape
- emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
-
- return emb.view(H, W, embed_dim) # [H, W, D]
-
-
-def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
- """
- This function generates a 1D positional embedding from a given grid using sine and cosine functions.
-
- Args:
- - embed_dim: The embedding dimension.
- - pos: The position to generate the embedding from.
-
- Returns:
- - emb: The generated 1D positional embedding.
- """
- assert embed_dim % 2 == 0
- device = pos.device
- omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
- omega /= embed_dim / 2.0
- omega = 1.0 / omega_0**omega # (D/2,)
-
- pos = pos.reshape(-1) # (M,)
- out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
-
- emb_sin = torch.sin(out) # (M, D/2)
- emb_cos = torch.cos(out) # (M, D/2)
-
- emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
- return emb.float()
-
-
-# Inspired by https://github.com/microsoft/moge
-
-
-def create_uv_grid(
- width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
-) -> torch.Tensor:
- """
- Create a normalized UV grid of shape (width, height, 2).
-
- The grid spans horizontally and vertically according to an aspect ratio,
- ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
- corner is at (x_span, y_span), normalized by the diagonal of the plane.
-
- Args:
- width (int): Number of points horizontally.
- height (int): Number of points vertically.
- aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
- dtype (torch.dtype, optional): Data type of the resulting tensor.
- device (torch.device, optional): Device on which the tensor is created.
-
- Returns:
- torch.Tensor: A (width, height, 2) tensor of UV coordinates.
- """
- # Derive aspect ratio if not explicitly provided
- if aspect_ratio is None:
- aspect_ratio = float(width) / float(height)
-
- # Compute normalized spans for X and Y
- diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
- span_x = aspect_ratio / diag_factor
- span_y = 1.0 / diag_factor
-
- # Establish the linspace boundaries
- left_x = -span_x * (width - 1) / width
- right_x = span_x * (width - 1) / width
- top_y = -span_y * (height - 1) / height
- bottom_y = span_y * (height - 1) / height
-
- # Generate 1D coordinates
- x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
- y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
-
- # Create 2D meshgrid (width x height) and stack into UV
- uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
- uv_grid = torch.stack((uu, vv), dim=-1)
-
- return uv_grid
diff --git a/FastVGGT/vggt/layers/__init__.py b/FastVGGT/vggt/layers/__init__.py
deleted file mode 100644
index 8120f4bc83066cb3f825ce32daa3b437f88486f1..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/layers/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from .mlp import Mlp
-from .patch_embed import PatchEmbed
-from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
-from .block import NestedTensorBlock
-from .attention import MemEffAttention
diff --git a/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 2840683edfd99fcda3d1d5c9e18d699dba029a86..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc
deleted file mode 100644
index 42e56e451e90459b8b74d99bd0ac9049e702e1bd..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc
deleted file mode 100644
index 069b3ad92a285342f200297a7358dbf39c0ef372..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc
deleted file mode 100644
index e0bc9cb8a27db08fc0921a2d8a54a60b2b126acd..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc
deleted file mode 100644
index a416380412c8942981ab0e3f5369cebe05f9496c..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc
deleted file mode 100644
index 053bf054f2a74eefeea0536eeddd23daacd28cbe..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc
deleted file mode 100644
index 20fa7427ef4c285abd1d9b48743f4c3360953591..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc
deleted file mode 100644
index af9854cf6446fb30404f4529d63b20c20f8e63e8..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc
deleted file mode 100644
index c9b1fc5deb392addff7c8620b095e0ba4d90a8de..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc
deleted file mode 100644
index d36ad71ca8dacfb3991b5b2560bc89cdfdd15c52..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/layers/attention.py b/FastVGGT/vggt/layers/attention.py
deleted file mode 100644
index aef68a4e2c628ad257a6b77b46447c8421c5b11a..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/layers/attention.py
+++ /dev/null
@@ -1,257 +0,0 @@
-import os
-from pathlib import Path
-import time
-from PIL import Image
-import torch
-from torch import Tensor
-from torch import nn
-import torch.nn.functional as F
-from tqdm.std import tqdm
-from merging.merge import (
- token_merge_bipartite2d,
-)
-import matplotlib.pyplot as plt
-import numpy as np
-from PIL import Image
-
-XFORMERS_AVAILABLE = False
-
-# Global variables for attention visualization
-vis_attn_map = False
-current_images = []
-attention_map = None
-
-
-class Attention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = True,
- proj_bias: bool = True,
- attn_drop: float = 0.0,
- proj_drop: float = 0.0,
- norm_layer: nn.Module = nn.LayerNorm,
- qk_norm: bool = False,
- kv_group_size: int = 1,
- fused_attn: bool = True,
- rope=None,
- global_merging=None,
- patch_width: int = 37,
- patch_height: int = 28,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, "dim should be divisible by num_heads"
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim**-0.5
- self.patch_width = patch_width
- self.patch_height = patch_height
- self.fused_attn = fused_attn
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim, bias=proj_bias)
- self.proj_drop = nn.Dropout(proj_drop)
- self.rope = rope
- self.kv_group_size = kv_group_size
-
- def forward(self, x: Tensor, pos=None, global_merging=None) -> Tensor:
- merge_num = list(range(24))
-
- B, N, C = x.shape
- qkv = (
- self.qkv(x)
- .reshape(B, N, 3, self.num_heads, self.head_dim)
- .permute(2, 0, 3, 1, 4)
- )
- q, k, v = qkv.unbind(0)
- q, k = self.q_norm(q), self.k_norm(k)
-
- if self.rope is not None:
- q = self.rope(q, pos)
- k = self.rope(k, pos)
-
- if vis_attn_map and global_merging is not None:
- # s1 Chunk computation of q@k
- chunk_size = 4096
- attn_maps = []
- total_chunks = (
- q.size(-2) + chunk_size - 1
- ) // chunk_size # Calculate total number of chunks
- start_time_total = time.time()
- with torch.no_grad():
- for start in tqdm(
- range(0, q.size(-2), chunk_size),
- total=total_chunks,
- desc="Processing chunks",
- ):
- end = min(start + chunk_size, q.size(-2))
- q_chunk = q[:, :, start:end, :] # (1, 16, chunk_size, 64)
- start_time_chunk = time.time()
- attn_chunk = q_chunk @ k.transpose(
- -2, -1
- ) # (1, 16, chunk_size, 34353)
- attn_maps.append(attn_chunk.cpu())
- end_time_chunk = time.time()
- print(
- f"Chunk {start}:{end} processed in {end_time_chunk - start_time_chunk:.4f} seconds"
- )
- end_time_total = time.time()
- print(
- f"\nTotal processing time: {end_time_total - start_time_total:.4f} seconds"
- )
-
- attn_map = torch.cat(attn_maps, dim=-2)
- attn = attn_map[0].mean(0)
- frame_token_num = self.patch_height * self.patch_width + 5
- for target_token_idx in [
- 0,
- self.patch_height * self.patch_width,
- self.patch_height * self.patch_width * 10,
- ]: # Iterate through each image's target_token
- for image_idx in range(
- len(current_images)
- ): # Corresponding to which image to visualize
- target_attn = attn[
- target_token_idx,
- image_idx * frame_token_num : (image_idx + 1) * frame_token_num,
- ]
- target_attn_map = target_attn[5:].reshape(
- self.patch_height, self.patch_width
- )
- # 1) Read original image to get true size (H, W)
- image_path = current_images[image_idx]
- p = Path(image_path)
- parts = p.parts
- scene_name = parts[-4]
-
- image = Image.open(image_path).convert("RGB")
- img_width, img_height = image.size # PIL size: (W, H)
-
- # Upsample attention map to the original image size
- target_attn_map = F.interpolate(
- target_attn_map.unsqueeze(0).unsqueeze(0), # (1,1,h,w)
- size=(img_height, img_width),
- mode="bilinear",
- )
- target_attn_map = target_attn_map.squeeze()
-
- # Convert image to numpy for blending
- img_np = np.array(image) / 255.0
-
- # 2. Normalize attention map
- target_attn_map = (target_attn_map - target_attn_map.min()) / (
- target_attn_map.max() - target_attn_map.min()
- )
-
- # 3. Color attention map
- cmap = plt.get_cmap("jet")
- attn_color = cmap(target_attn_map.cpu().float().numpy())
- attn_color = attn_color[:, :, :3]
-
- # 4. Blend attention and original image
- overlay = img_np * 0.5 + attn_color * 0.5
-
- plt.imshow(overlay, cmap="viridis")
- output_dir = f"attention_map/{scene_name}/block_{attention_map}/token_{target_token_idx}"
- os.makedirs(output_dir, exist_ok=True)
- output_path = os.path.join(output_dir, f"color_{image_idx}.png")
- plt.savefig(output_path)
- plt.close()
-
- if global_merging is not None and global_merging in merge_num:
- generator = torch.Generator(device=x.device)
- generator.manual_seed(33)
-
- merge_ratio = 0.9
- r = int(x.shape[1] * merge_ratio)
-
- m, u = token_merge_bipartite2d(
- x,
- self.patch_width,
- self.patch_height,
- 2,
- 2,
- r,
- False,
- generator,
- enable_protection=True,
- )
-
- m_a, u_a = (m, u)
-
- B_q, H_q, N_q, D_q = q.shape
-
- q_merge_in = q.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q)
- k_merge_in = k.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q)
- v_merge_in = v.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q)
-
- q_out, k_out, v_out = m_a(
- q_merge_in,
- mode="mean",
- extra_tensors=k_merge_in,
- extra_tensors_2=v_merge_in,
- )
-
- del q_merge_in, k_merge_in, v_merge_in
-
- N_m = q_out.shape[1]
- q = q_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3)
- k = k_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3)
- v = v_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3)
-
- del q_out, k_out, v_out
-
- N = N_m
-
- x = F.scaled_dot_product_attention(
- q,
- k,
- v,
- dropout_p=self.attn_drop.p if self.training else 0.0,
- )
- del q, k, v
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- if global_merging is not None and global_merging in merge_num:
- x = u_a(x)
- return x
-
-
-class MemEffAttention(Attention):
- def forward(
- self, x: Tensor, attn_bias=None, pos=None, global_merging=None
- ) -> Tensor:
- assert (
- pos is None or self.rope is not None
- ), "Position encoding is only supported with RoPE"
- if not XFORMERS_AVAILABLE:
- if attn_bias is not None:
- raise AssertionError("xFormers is required for using nested tensors")
- return super().forward(x, pos=pos, global_merging=global_merging)
-
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
-
- q, k, v = qkv.unbind(2)
-
- if self.rope is not None and pos is not None:
- q = self.rope(q, pos)
- k = self.rope(k, pos)
-
- # Use scaled dot-product attention
- attn = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim**-0.5)
- if attn_bias is not None:
- attn = attn + attn_bias
- attn = F.softmax(attn, dim=-1)
- x = torch.matmul(attn, v)
- x = x.reshape([B, N, C])
-
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
diff --git a/FastVGGT/vggt/layers/block.py b/FastVGGT/vggt/layers/block.py
deleted file mode 100644
index ffa3c18aaf79fc755630d24c24c1f4c7f75624e1..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/layers/block.py
+++ /dev/null
@@ -1,272 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# References:
-# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
-# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
-
-
-from typing import Callable, List, Any, Tuple, Dict
-import torch
-from torch import nn, Tensor
-from .attention import Attention
-from .drop_path import DropPath
-from .layer_scale import LayerScale
-from .mlp import Mlp
-
-
-XFORMERS_AVAILABLE = False
-
-
-class Block(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int,
- mlp_ratio: float = 4.0,
- qkv_bias: bool = True,
- proj_bias: bool = True,
- ffn_bias: bool = True,
- drop: float = 0.0,
- attn_drop: float = 0.0,
- init_values=None,
- drop_path: float = 0.0,
- act_layer: Callable[..., nn.Module] = nn.GELU,
- norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
- attn_class: Callable[..., nn.Module] = Attention,
- ffn_layer: Callable[..., nn.Module] = Mlp,
- qk_norm: bool = False,
- fused_attn: bool = True,
- rope=None,
- merging=0,
- ) -> None:
- super().__init__()
-
- self.norm1 = norm_layer(dim)
-
- self.attn = attn_class(
- dim,
- num_heads=num_heads,
- qkv_bias=qkv_bias,
- proj_bias=proj_bias,
- attn_drop=attn_drop,
- proj_drop=drop,
- qk_norm=qk_norm,
- fused_attn=fused_attn,
- rope=rope,
- )
-
- self.ls1 = (
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- )
- self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = ffn_layer(
- in_features=dim,
- hidden_features=mlp_hidden_dim,
- act_layer=act_layer,
- drop=drop,
- bias=ffn_bias,
- )
- self.ls2 = (
- LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
- )
- self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
- self.sample_drop_ratio = drop_path
-
- def forward(self, x: Tensor, pos=None, global_merging=None) -> Tensor:
- norm1_output = self.norm1(x)
- attn_output = self.attn(norm1_output, pos=pos, global_merging=global_merging)
- del norm1_output
- x = x + self.ls1(attn_output)
- del attn_output
-
- norm2_output = self.norm2(x)
- mlp_output = self.mlp(norm2_output)
- del norm2_output
- x = x + self.ls2(mlp_output)
- del mlp_output
- return x
-
-
-def drop_add_residual_stochastic_depth(
- x: Tensor,
- residual_func: Callable[[Tensor], Tensor],
- sample_drop_ratio: float = 0.0,
- pos=None,
-) -> Tensor:
- # 1) extract subset using permutation
- b, n, d = x.shape
- sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
- brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
- x_subset = x[brange]
-
- # 2) apply residual_func to get residual
- if pos is not None:
- # if necessary, apply rope to the subset
- pos = pos[brange]
- residual = residual_func(x_subset, pos=pos)
- else:
- residual = residual_func(x_subset)
-
- x_flat = x.flatten(1)
- residual = residual.flatten(1)
-
- residual_scale_factor = b / sample_subset_size
-
- # 3) add the residual
- x_plus_residual = torch.index_add(
- x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
- )
- return x_plus_residual.view_as(x)
-
-
-def get_branges_scales(x, sample_drop_ratio=0.0):
- b, n, d = x.shape
- sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
- brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
- residual_scale_factor = b / sample_subset_size
- return brange, residual_scale_factor
-
-
-def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
- if scaling_vector is None:
- x_flat = x.flatten(1)
- residual = residual.flatten(1)
- x_plus_residual = torch.index_add(
- x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
- )
- else:
- x_plus_residual = scaled_index_add(
- x,
- brange,
- residual.to(dtype=x.dtype),
- scaling=scaling_vector,
- alpha=residual_scale_factor,
- )
- return x_plus_residual
-
-
-attn_bias_cache: Dict[Tuple, Any] = {}
-
-
-def get_attn_bias_and_cat(x_list, branges=None):
- """
- this will perform the index select, cat the tensors, and provide the attn_bias from cache
- """
- batch_sizes = (
- [b.shape[0] for b in branges]
- if branges is not None
- else [x.shape[0] for x in x_list]
- )
- all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
- if all_shapes not in attn_bias_cache.keys():
- seqlens = []
- for b, x in zip(batch_sizes, x_list):
- for _ in range(b):
- seqlens.append(x.shape[1])
- attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
- attn_bias._batch_sizes = batch_sizes
- attn_bias_cache[all_shapes] = attn_bias
-
- if branges is not None:
- cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
- 1, -1, x_list[0].shape[-1]
- )
- else:
- tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
- cat_tensors = torch.cat(tensors_bs1, dim=1)
-
- return attn_bias_cache[all_shapes], cat_tensors
-
-
-def drop_add_residual_stochastic_depth_list(
- x_list: List[Tensor],
- residual_func: Callable[[Tensor, Any], Tensor],
- sample_drop_ratio: float = 0.0,
- scaling_vector=None,
-) -> Tensor:
- # 1) generate random set of indices for dropping samples in the batch
- branges_scales = [
- get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
- ]
- branges = [s[0] for s in branges_scales]
- residual_scale_factors = [s[1] for s in branges_scales]
-
- # 2) get attention bias and index+concat the tensors
- attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
-
- # 3) apply residual_func to get residual, and split the result
- residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
-
- outputs = []
- for x, brange, residual, residual_scale_factor in zip(
- x_list, branges, residual_list, residual_scale_factors
- ):
- outputs.append(
- add_residual(
- x, brange, residual, residual_scale_factor, scaling_vector
- ).view_as(x)
- )
- return outputs
-
-
-class NestedTensorBlock(Block):
- def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
- """
- x_list contains a list of tensors to nest together and run
- """
- assert isinstance(self.attn, MemEffAttention)
-
- if self.training and self.sample_drop_ratio > 0.0:
-
- def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
- return self.attn(self.norm1(x), attn_bias=attn_bias)
-
- def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
- return self.mlp(self.norm2(x))
-
- x_list = drop_add_residual_stochastic_depth_list(
- x_list,
- residual_func=attn_residual_func,
- sample_drop_ratio=self.sample_drop_ratio,
- scaling_vector=(
- self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
- ),
- )
- x_list = drop_add_residual_stochastic_depth_list(
- x_list,
- residual_func=ffn_residual_func,
- sample_drop_ratio=self.sample_drop_ratio,
- scaling_vector=(
- self.ls2.gamma if isinstance(self.ls1, LayerScale) else None
- ),
- )
- return x_list
- else:
-
- def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
- return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
-
- def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
- return self.ls2(self.mlp(self.norm2(x)))
-
- attn_bias, x = get_attn_bias_and_cat(x_list)
- x = x + attn_residual_func(x, attn_bias=attn_bias)
- x = x + ffn_residual_func(x)
- return attn_bias.split(x)
-
- def forward(self, x_or_x_list):
- if isinstance(x_or_x_list, Tensor):
- return super().forward(x_or_x_list)
- elif isinstance(x_or_x_list, list):
- if not XFORMERS_AVAILABLE:
- raise AssertionError("xFormers is required for using nested tensors")
- return self.forward_nested(x_or_x_list)
- else:
- raise AssertionError
diff --git a/FastVGGT/vggt/layers/drop_path.py b/FastVGGT/vggt/layers/drop_path.py
deleted file mode 100644
index 1d640e0b969b8dcba96260243473700b4e5b24b5..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/layers/drop_path.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# References:
-# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
-# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
-
-
-from torch import nn
-
-
-def drop_path(x, drop_prob: float = 0.0, training: bool = False):
- if drop_prob == 0.0 or not training:
- return x
- keep_prob = 1 - drop_prob
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
- random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
- if keep_prob > 0.0:
- random_tensor.div_(keep_prob)
- output = x * random_tensor
- return output
-
-
-class DropPath(nn.Module):
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
-
- def __init__(self, drop_prob=None):
- super(DropPath, self).__init__()
- self.drop_prob = drop_prob
-
- def forward(self, x):
- return drop_path(x, self.drop_prob, self.training)
diff --git a/FastVGGT/vggt/layers/layer_scale.py b/FastVGGT/vggt/layers/layer_scale.py
deleted file mode 100644
index 4ddfc51c3d87370d50175f5b4e649dac1c614ff9..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/layers/layer_scale.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
-
-from typing import Union
-
-import torch
-from torch import Tensor
-from torch import nn
-
-
-class LayerScale(nn.Module):
- def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
- super().__init__()
- self.inplace = inplace
- self.gamma = nn.Parameter(init_values * torch.ones(dim))
-
- def forward(self, x: Tensor) -> Tensor:
- return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/FastVGGT/vggt/layers/mlp.py b/FastVGGT/vggt/layers/mlp.py
deleted file mode 100644
index bbf9432aae9258612caeae910a7bde17999e328e..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/layers/mlp.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# References:
-# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
-# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
-
-
-from typing import Callable, Optional
-
-from torch import Tensor, nn
-
-
-class Mlp(nn.Module):
- def __init__(
- self,
- in_features: int,
- hidden_features: Optional[int] = None,
- out_features: Optional[int] = None,
- act_layer: Callable[..., nn.Module] = nn.GELU,
- drop: float = 0.0,
- bias: bool = True,
- ) -> None:
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x: Tensor) -> Tensor:
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
diff --git a/FastVGGT/vggt/layers/patch_embed.py b/FastVGGT/vggt/layers/patch_embed.py
deleted file mode 100644
index bc19605e4d6e88d06355ae3b1afddc76f595aafe..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/layers/patch_embed.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# References:
-# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
-# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
-
-from typing import Callable, Optional, Tuple, Union
-
-from torch import Tensor
-import torch.nn as nn
-
-
-def make_2tuple(x):
- if isinstance(x, tuple):
- assert len(x) == 2
- return x
-
- assert isinstance(x, int)
- return (x, x)
-
-
-class PatchEmbed(nn.Module):
- """
- 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
-
- Args:
- img_size: Image size.
- patch_size: Patch token size.
- in_chans: Number of input image channels.
- embed_dim: Number of linear projection output channels.
- norm_layer: Normalization layer.
- """
-
- def __init__(
- self,
- img_size: Union[int, Tuple[int, int]] = 224,
- patch_size: Union[int, Tuple[int, int]] = 16,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer: Optional[Callable] = None,
- flatten_embedding: bool = True,
- ) -> None:
- super().__init__()
-
- image_HW = make_2tuple(img_size)
- patch_HW = make_2tuple(patch_size)
- patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
-
- self.img_size = image_HW
- self.patch_size = patch_HW
- self.patches_resolution = patch_grid_size
- self.num_patches = patch_grid_size[0] * patch_grid_size[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.flatten_embedding = flatten_embedding
-
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
-
- def forward(self, x: Tensor) -> Tensor:
- _, _, H, W = x.shape
- patch_H, patch_W = self.patch_size
-
- assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
- assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
-
- x = self.proj(x) # B C H W
- H, W = x.size(2), x.size(3)
- x = x.flatten(2).transpose(1, 2) # B HW C
- x = self.norm(x)
- if not self.flatten_embedding:
- x = x.reshape(-1, H, W, self.embed_dim) # B H W C
- return x
-
- def flops(self) -> float:
- Ho, Wo = self.patches_resolution
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
- if self.norm is not None:
- flops += Ho * Wo * self.embed_dim
- return flops
diff --git a/FastVGGT/vggt/layers/rope.py b/FastVGGT/vggt/layers/rope.py
deleted file mode 100644
index 84625de468ed89e69dd9e1579d541de71f2ebf37..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/layers/rope.py
+++ /dev/null
@@ -1,209 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-
-# Implementation of 2D Rotary Position Embeddings (RoPE).
-
-# This module provides a clean implementation of 2D Rotary Position Embeddings,
-# which extends the original RoPE concept to handle 2D spatial positions.
-
-# Inspired by:
-# https://github.com/meta-llama/codellama/blob/main/llama/model.py
-# https://github.com/naver-ai/rope-vit
-
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import Dict, Tuple
-
-
-class PositionGetter:
- """Generates and caches 2D spatial positions for patches in a grid.
-
- This class efficiently manages the generation of spatial coordinates for patches
- in a 2D grid, caching results to avoid redundant computations.
-
- Attributes:
- position_cache: Dictionary storing precomputed position tensors for different
- grid dimensions.
- """
-
- def __init__(self):
- """Initializes the position generator with an empty cache."""
- self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
-
- def __call__(
- self, batch_size: int, height: int, width: int, device: torch.device
- ) -> torch.Tensor:
- """Generates spatial positions for a batch of patches.
-
- Args:
- batch_size: Number of samples in the batch.
- height: Height of the grid in patches.
- width: Width of the grid in patches.
- device: Target device for the position tensor.
-
- Returns:
- Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
- for each position in the grid, repeated for each batch item.
- """
- if (height, width) not in self.position_cache:
- y_coords = torch.arange(height, device=device)
- x_coords = torch.arange(width, device=device)
- positions = torch.cartesian_prod(y_coords, x_coords)
- self.position_cache[height, width] = positions
-
- cached_positions = self.position_cache[height, width]
- return (
- cached_positions.view(1, height * width, 2)
- .expand(batch_size, -1, -1)
- .clone()
- )
-
-
-class RotaryPositionEmbedding2D(nn.Module):
- """2D Rotary Position Embedding implementation.
-
- This module applies rotary position embeddings to input tokens based on their
- 2D spatial positions. It handles the position-dependent rotation of features
- separately for vertical and horizontal dimensions.
-
- Args:
- frequency: Base frequency for the position embeddings. Default: 100.0
- scaling_factor: Scaling factor for frequency computation. Default: 1.0
-
- Attributes:
- base_frequency: Base frequency for computing position embeddings.
- scaling_factor: Factor to scale the computed frequencies.
- frequency_cache: Cache for storing precomputed frequency components.
- """
-
- def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
- """Initializes the 2D RoPE module."""
- super().__init__()
- self.base_frequency = frequency
- self.scaling_factor = scaling_factor
- self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
-
- def _compute_frequency_components(
- self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Computes frequency components for rotary embeddings.
-
- Args:
- dim: Feature dimension (must be even).
- seq_len: Maximum sequence length.
- device: Target device for computations.
- dtype: Data type for the computed tensors.
-
- Returns:
- Tuple of (cosine, sine) tensors for frequency components.
- """
- cache_key = (dim, seq_len, device, dtype)
- if cache_key not in self.frequency_cache:
- # Compute frequency bands
- exponents = torch.arange(0, dim, 2, device=device) / dim
- inv_freq = 1.0 / (self.base_frequency**exponents)
-
- # Generate position-dependent frequencies
- positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
- angles = torch.einsum("i,j->ij", positions, inv_freq)
-
- # Compute and cache frequency components
- angles = angles.to(dtype)
- angles = torch.cat((angles, angles), dim=-1)
- cos_components = angles.cos().to(dtype)
- sin_components = angles.sin().to(dtype)
- self.frequency_cache[cache_key] = (cos_components, sin_components)
-
- return self.frequency_cache[cache_key]
-
- @staticmethod
- def _rotate_features(x: torch.Tensor) -> torch.Tensor:
- """Performs feature rotation by splitting and recombining feature dimensions.
-
- Args:
- x: Input tensor to rotate.
-
- Returns:
- Rotated feature tensor.
- """
- feature_dim = x.shape[-1]
- x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
- def _apply_1d_rope(
- self,
- tokens: torch.Tensor,
- positions: torch.Tensor,
- cos_comp: torch.Tensor,
- sin_comp: torch.Tensor,
- ) -> torch.Tensor:
- """Applies 1D rotary position embeddings along one dimension.
-
- Args:
- tokens: Input token features.
- positions: Position indices.
- cos_comp: Cosine components for rotation.
- sin_comp: Sine components for rotation.
-
- Returns:
- Tokens with applied rotary position embeddings.
- """
- if positions.dtype != torch.long:
- positions = positions.long()
-
- # Embed positions with frequency components
- cos = F.embedding(positions, cos_comp)[:, None, :, :]
- sin = F.embedding(positions, sin_comp)[:, None, :, :]
-
- # Apply rotation
- return (tokens * cos) + (self._rotate_features(tokens) * sin)
-
- def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
- """Applies 2D rotary position embeddings to input tokens.
-
- Args:
- tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
- The feature dimension (dim) must be divisible by 4.
- positions: Position tensor of shape (batch_size, n_tokens, 2) containing
- the y and x coordinates for each token.
-
- Returns:
- Tensor of same shape as input with applied 2D rotary position embeddings.
-
- Raises:
- AssertionError: If input dimensions are invalid or positions are malformed.
- """
- # Validate inputs
- assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
- assert (
- positions.ndim == 3 and positions.shape[-1] == 2
- ), "Positions must have shape (batch_size, n_tokens, 2)"
-
- # Compute feature dimension for each spatial direction
- feature_dim = tokens.size(-1) // 2
-
- # Get frequency components
- max_position = int(positions.max()) + 1
- cos_comp, sin_comp = self._compute_frequency_components(
- feature_dim, max_position, tokens.device, tokens.dtype
- )
-
- # Split features for vertical and horizontal processing
- vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
-
- # Apply RoPE separately for each dimension
- vertical_features = self._apply_1d_rope(
- vertical_features, positions[..., 0], cos_comp, sin_comp
- )
- horizontal_features = self._apply_1d_rope(
- horizontal_features, positions[..., 1], cos_comp, sin_comp
- )
-
- # Combine processed features
- return torch.cat((vertical_features, horizontal_features), dim=-1)
diff --git a/FastVGGT/vggt/layers/swiglu_ffn.py b/FastVGGT/vggt/layers/swiglu_ffn.py
deleted file mode 100644
index 1dd991e1deb87141ccd282098d4b9d38fed6ef25..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/layers/swiglu_ffn.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-import os
-from typing import Callable, Optional
-import warnings
-
-from torch import Tensor, nn
-import torch.nn.functional as F
-
-
-class SwiGLUFFN(nn.Module):
- def __init__(
- self,
- in_features: int,
- hidden_features: Optional[int] = None,
- out_features: Optional[int] = None,
- act_layer: Callable[..., nn.Module] = None,
- drop: float = 0.0,
- bias: bool = True,
- ) -> None:
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
- self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
-
- def forward(self, x: Tensor) -> Tensor:
- x12 = self.w12(x)
- x1, x2 = x12.chunk(2, dim=-1)
- hidden = F.silu(x1) * x2
- return self.w3(hidden)
-
-
-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-# try:
-# if XFORMERS_ENABLED:
-# from xformers.ops import SwiGLU
-
-# XFORMERS_AVAILABLE = True
-# warnings.warn("xFormers is available (SwiGLU)")
-# else:
-# warnings.warn("xFormers is disabled (SwiGLU)")
-# raise ImportError
-# except ImportError:
-SwiGLU = SwiGLUFFN
-XFORMERS_AVAILABLE = False
-
-# warnings.warn("xFormers is not available (SwiGLU)")
-
-
-class SwiGLUFFNFused(SwiGLU):
- def __init__(
- self,
- in_features: int,
- hidden_features: Optional[int] = None,
- out_features: Optional[int] = None,
- act_layer: Callable[..., nn.Module] = None,
- drop: float = 0.0,
- bias: bool = True,
- ) -> None:
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
- super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
diff --git a/FastVGGT/vggt/layers/vision_transformer.py b/FastVGGT/vggt/layers/vision_transformer.py
deleted file mode 100644
index 2e1aee388a9c38168657f78fa9436685582068c6..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/layers/vision_transformer.py
+++ /dev/null
@@ -1,446 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
-# References:
-# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
-# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
-
-from functools import partial
-import math
-import logging
-from typing import Sequence, Tuple, Union, Callable
-
-import torch
-import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
-from torch.nn.init import trunc_normal_
-from . import (
- Mlp,
- PatchEmbed,
- SwiGLUFFNFused,
- MemEffAttention,
- NestedTensorBlock as Block,
-)
-
-logger = logging.getLogger("dinov2")
-
-
-def named_apply(
- fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
-) -> nn.Module:
- if not depth_first and include_root:
- fn(module=module, name=name)
- for child_name, child_module in module.named_children():
- child_name = ".".join((name, child_name)) if name else child_name
- named_apply(
- fn=fn,
- module=child_module,
- name=child_name,
- depth_first=depth_first,
- include_root=True,
- )
- if depth_first and include_root:
- fn(module=module, name=name)
- return module
-
-
-class BlockChunk(nn.ModuleList):
- def forward(self, x):
- for b in self:
- x = b(x)
- return x
-
-
-class DinoVisionTransformer(nn.Module):
- def __init__(
- self,
- img_size=224,
- patch_size=16,
- in_chans=3,
- embed_dim=768,
- depth=12,
- num_heads=12,
- mlp_ratio=4.0,
- qkv_bias=True,
- ffn_bias=True,
- proj_bias=True,
- drop_path_rate=0.0,
- drop_path_uniform=False,
- init_values=None, # for layerscale: None or 0 => no layerscale
- embed_layer=PatchEmbed,
- act_layer=nn.GELU,
- block_fn=Block,
- ffn_layer="mlp",
- block_chunks=1,
- num_register_tokens=0,
- interpolate_antialias=False,
- interpolate_offset=0.1,
- qk_norm=False,
- ):
- """
- Args:
- img_size (int, tuple): input image size
- patch_size (int, tuple): patch size
- in_chans (int): number of input channels
- embed_dim (int): embedding dimension
- depth (int): depth of transformer
- num_heads (int): number of attention heads
- mlp_ratio (int): ratio of mlp hidden dim to embedding dim
- qkv_bias (bool): enable bias for qkv if True
- proj_bias (bool): enable bias for proj in attn if True
- ffn_bias (bool): enable bias for ffn if True
- drop_path_rate (float): stochastic depth rate
- drop_path_uniform (bool): apply uniform drop rate across blocks
- weight_init (str): weight init scheme
- init_values (float): layer-scale init values
- embed_layer (nn.Module): patch embedding layer
- act_layer (nn.Module): MLP activation layer
- block_fn (nn.Module): transformer block class
- ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
- block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
- num_register_tokens: (int) number of extra cls tokens (so-called "registers")
- interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
- interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
- """
- super().__init__()
- norm_layer = partial(nn.LayerNorm, eps=1e-6)
-
- self.num_features = self.embed_dim = (
- embed_dim # num_features for consistency with other models
- )
- self.num_tokens = 1
- self.n_blocks = depth
- self.num_heads = num_heads
- self.patch_size = patch_size
- self.num_register_tokens = num_register_tokens
- self.interpolate_antialias = interpolate_antialias
- self.interpolate_offset = interpolate_offset
- self.use_reentrant = False # hardcoded to False
-
- self.patch_embed = embed_layer(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=in_chans,
- embed_dim=embed_dim,
- )
- num_patches = self.patch_embed.num_patches
-
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
- self.pos_embed = nn.Parameter(
- torch.zeros(1, num_patches + self.num_tokens, embed_dim)
- )
- assert num_register_tokens >= 0
- self.register_tokens = (
- nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
- if num_register_tokens
- else None
- )
-
- if drop_path_uniform is True:
- dpr = [drop_path_rate] * depth
- else:
- dpr = [
- x.item() for x in torch.linspace(0, drop_path_rate, depth)
- ] # stochastic depth decay rule
-
- if ffn_layer == "mlp":
- logger.info("using MLP layer as FFN")
- ffn_layer = Mlp
- elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
- logger.info("using SwiGLU layer as FFN")
- ffn_layer = SwiGLUFFNFused
- elif ffn_layer == "identity":
- logger.info("using Identity layer as FFN")
-
- def f(*args, **kwargs):
- return nn.Identity()
-
- ffn_layer = f
- else:
- raise NotImplementedError
-
- blocks_list = [
- block_fn(
- dim=embed_dim,
- num_heads=num_heads,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- proj_bias=proj_bias,
- ffn_bias=ffn_bias,
- drop_path=dpr[i],
- norm_layer=norm_layer,
- act_layer=act_layer,
- ffn_layer=ffn_layer,
- init_values=init_values,
- qk_norm=qk_norm,
- )
- for i in range(depth)
- ]
- if block_chunks > 0:
- self.chunked_blocks = True
- chunked_blocks = []
- chunksize = depth // block_chunks
- for i in range(0, depth, chunksize):
- # this is to keep the block index consistent if we chunk the block list
- chunked_blocks.append(
- [nn.Identity()] * i + blocks_list[i : i + chunksize]
- )
- self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
- else:
- self.chunked_blocks = False
- self.blocks = nn.ModuleList(blocks_list)
-
- self.norm = norm_layer(embed_dim)
- self.head = nn.Identity()
-
- self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
-
- self.init_weights()
-
- def init_weights(self):
- trunc_normal_(self.pos_embed, std=0.02)
- nn.init.normal_(self.cls_token, std=1e-6)
- if self.register_tokens is not None:
- nn.init.normal_(self.register_tokens, std=1e-6)
- named_apply(init_weights_vit_timm, self)
-
- def interpolate_pos_encoding(self, x, w, h):
- previous_dtype = x.dtype
- npatch = x.shape[1] - 1
- N = self.pos_embed.shape[1] - 1
- if npatch == N and w == h:
- return self.pos_embed
- pos_embed = self.pos_embed
- class_pos_embed = pos_embed[:, 0]
- patch_pos_embed = pos_embed[:, 1:]
- dim = x.shape[-1]
- w0 = w // self.patch_size
- h0 = h // self.patch_size
- M = int(math.sqrt(N)) # Recover the number of patches in each dimension
- assert N == M * M
- kwargs = {}
- if self.interpolate_offset:
- # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
- # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
- sx = float(w0 + self.interpolate_offset) / M
- sy = float(h0 + self.interpolate_offset) / M
- kwargs["scale_factor"] = (sx, sy)
- else:
- # Simply specify an output size instead of a scale factor
- kwargs["size"] = (w0, h0)
- patch_pos_embed = nn.functional.interpolate(
- patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
- mode="bicubic",
- antialias=self.interpolate_antialias,
- **kwargs,
- )
- assert (w0, h0) == patch_pos_embed.shape[-2:]
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
- return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
- previous_dtype
- )
-
- def prepare_tokens_with_masks(self, x, masks=None):
- B, nc, w, h = x.shape
- x = self.patch_embed(x)
- if masks is not None:
- x = torch.where(
- masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
- )
-
- x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
- x = x + self.interpolate_pos_encoding(x, w, h)
-
- if self.register_tokens is not None:
- x = torch.cat(
- (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]),
- dim=1,
- )
-
- return x
-
- def forward_features_list(self, x_list, masks_list):
- x = [
- self.prepare_tokens_with_masks(x, masks)
- for x, masks in zip(x_list, masks_list)
- ]
-
- for blk in self.blocks:
- if self.training:
- x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
- else:
- x = blk(x)
-
- all_x = x
- output = []
- for x, masks in zip(all_x, masks_list):
- x_norm = self.norm(x)
- output.append(
- {
- "x_norm_clstoken": x_norm[:, 0],
- "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
- "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
- "x_prenorm": x,
- "masks": masks,
- }
- )
- return output
-
- def forward_features(self, x, masks=None):
- if isinstance(x, list):
- return self.forward_features_list(x, masks)
-
- x = self.prepare_tokens_with_masks(x, masks)
-
- for blk in self.blocks:
- if self.training:
- x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
- else:
- x = blk(x)
-
- x_norm = self.norm(x)
- return {
- "x_norm_clstoken": x_norm[:, 0],
- "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
- "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
- "x_prenorm": x,
- "masks": masks,
- }
-
- def _get_intermediate_layers_not_chunked(self, x, n=1):
- x = self.prepare_tokens_with_masks(x)
- # If n is an int, take the n last blocks. If it's a list, take them
- output, total_block_len = [], len(self.blocks)
- blocks_to_take = (
- range(total_block_len - n, total_block_len) if isinstance(n, int) else n
- )
- for i, blk in enumerate(self.blocks):
- x = blk(x)
- if i in blocks_to_take:
- output.append(x)
- assert len(output) == len(
- blocks_to_take
- ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
- return output
-
- def _get_intermediate_layers_chunked(self, x, n=1):
- x = self.prepare_tokens_with_masks(x)
- output, i, total_block_len = [], 0, len(self.blocks[-1])
- # If n is an int, take the n last blocks. If it's a list, take them
- blocks_to_take = (
- range(total_block_len - n, total_block_len) if isinstance(n, int) else n
- )
- for block_chunk in self.blocks:
- for blk in block_chunk[i:]: # Passing the nn.Identity()
- x = blk(x)
- if i in blocks_to_take:
- output.append(x)
- i += 1
- assert len(output) == len(
- blocks_to_take
- ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
- return output
-
- def get_intermediate_layers(
- self,
- x: torch.Tensor,
- n: Union[int, Sequence] = 1, # Layers or n last layers to take
- reshape: bool = False,
- return_class_token: bool = False,
- norm=True,
- ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
- if self.chunked_blocks:
- outputs = self._get_intermediate_layers_chunked(x, n)
- else:
- outputs = self._get_intermediate_layers_not_chunked(x, n)
- if norm:
- outputs = [self.norm(out) for out in outputs]
- class_tokens = [out[:, 0] for out in outputs]
- outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
- if reshape:
- B, _, w, h = x.shape
- outputs = [
- out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
- .permute(0, 3, 1, 2)
- .contiguous()
- for out in outputs
- ]
- if return_class_token:
- return tuple(zip(outputs, class_tokens))
- return tuple(outputs)
-
- def forward(self, *args, is_training=True, **kwargs):
- ret = self.forward_features(*args, **kwargs)
- if is_training:
- return ret
- else:
- return self.head(ret["x_norm_clstoken"])
-
-
-def init_weights_vit_timm(module: nn.Module, name: str = ""):
- """ViT weight initialization, original timm impl (for reproducibility)"""
- if isinstance(module, nn.Linear):
- trunc_normal_(module.weight, std=0.02)
- if module.bias is not None:
- nn.init.zeros_(module.bias)
-
-
-def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
- model = DinoVisionTransformer(
- patch_size=patch_size,
- embed_dim=384,
- depth=12,
- num_heads=6,
- mlp_ratio=4,
- block_fn=partial(Block, attn_class=MemEffAttention),
- num_register_tokens=num_register_tokens,
- **kwargs,
- )
- return model
-
-
-def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
- model = DinoVisionTransformer(
- patch_size=patch_size,
- embed_dim=768,
- depth=12,
- num_heads=12,
- mlp_ratio=4,
- block_fn=partial(Block, attn_class=MemEffAttention),
- num_register_tokens=num_register_tokens,
- **kwargs,
- )
- return model
-
-
-def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
- model = DinoVisionTransformer(
- patch_size=patch_size,
- embed_dim=1024,
- depth=24,
- num_heads=16,
- mlp_ratio=4,
- block_fn=partial(Block, attn_class=MemEffAttention),
- num_register_tokens=num_register_tokens,
- **kwargs,
- )
- return model
-
-
-def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
- """
- Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
- """
- model = DinoVisionTransformer(
- patch_size=patch_size,
- embed_dim=1536,
- depth=40,
- num_heads=24,
- mlp_ratio=4,
- block_fn=partial(Block, attn_class=MemEffAttention),
- num_register_tokens=num_register_tokens,
- **kwargs,
- )
- return model
diff --git a/FastVGGT/vggt/models/__pycache__/aggregator.cpython-310.pyc b/FastVGGT/vggt/models/__pycache__/aggregator.cpython-310.pyc
deleted file mode 100644
index ada2738eee9604acf0497bcafdfc52aa6f21119f..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/models/__pycache__/aggregator.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/models/__pycache__/vggt.cpython-310.pyc b/FastVGGT/vggt/models/__pycache__/vggt.cpython-310.pyc
deleted file mode 100644
index 67f693db1cd59d1a8d87e8f5c1c77c57c9de47d2..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/models/__pycache__/vggt.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/models/aggregator.py b/FastVGGT/vggt/models/aggregator.py
deleted file mode 100644
index 9e45bb204729f25a963fb083c158142ed01515c2..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/models/aggregator.py
+++ /dev/null
@@ -1,492 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import logging
-from numpy import block
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.checkpoint import checkpoint
-from typing import Optional, Tuple, Union, List, Dict, Any
-
-from vggt.layers import PatchEmbed
-from vggt.layers.block import Block
-from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
-from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
-import time
-
-logger = logging.getLogger(__name__)
-
-_RESNET_MEAN = [0.485, 0.456, 0.406]
-_RESNET_STD = [0.229, 0.224, 0.225]
-
-
-class Aggregator(nn.Module):
- """
- The Aggregator applies alternating-attention over input frames,
- as described in VGGT: Visual Geometry Grounded Transformer.
-
- Remember to set model.train() to enable gradient checkpointing to reduce memory usage.
-
- Args:
- img_size (int): Image size in pixels.
- patch_size (int): Size of each patch for PatchEmbed.
- embed_dim (int): Dimension of the token embeddings.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
- num_register_tokens (int): Number of register tokens.
- block_fn (nn.Module): The block type used for attention (Block by default).
- qkv_bias (bool): Whether to include bias in QKV projections.
- proj_bias (bool): Whether to include bias in the output projection.
- ffn_bias (bool): Whether to include bias in MLP layers.
- patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
- aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
- aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
- qk_norm (bool): Whether to apply QK normalization.
- rope_freq (int): Base frequency for rotary embedding. -1 to disable.
- init_values (float): Init scale for layer scale.
- """
-
- def __init__(
- self,
- img_size=518,
- patch_size=14,
- embed_dim=1024,
- depth=24,
- num_heads=16,
- mlp_ratio=4.0,
- num_register_tokens=4,
- block_fn=Block,
- qkv_bias=True,
- proj_bias=True,
- ffn_bias=True,
- patch_embed="dinov2_vitl14_reg",
- aa_order=["frame", "global"],
- aa_block_size=1,
- qk_norm=True,
- rope_freq=100,
- init_values=0.01,
- global_merging=True,
- merging=0,
- vis_attn_map=False,
- ):
- super().__init__()
-
- self.__build_patch_embed__(
- patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim
- )
-
- # Initialize rotary position embedding if frequency > 0
- self.rope = (
- RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
- )
- self.position_getter = PositionGetter() if self.rope is not None else None
- self.global_merging = global_merging
- self.merging = merging
- self.vis_attn_map = vis_attn_map
- self.frame_blocks = nn.ModuleList(
- [
- block_fn(
- dim=embed_dim,
- num_heads=num_heads,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- proj_bias=proj_bias,
- ffn_bias=ffn_bias,
- init_values=init_values,
- qk_norm=qk_norm,
- rope=self.rope,
- )
- for _ in range(depth)
- ]
- )
-
- self.global_blocks = nn.ModuleList(
- [
- block_fn(
- dim=embed_dim,
- num_heads=num_heads,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- proj_bias=proj_bias,
- ffn_bias=ffn_bias,
- init_values=init_values,
- qk_norm=qk_norm,
- rope=self.rope,
- )
- for _ in range(depth)
- ]
- )
-
- self.depth = depth
- self.aa_order = aa_order
- self.patch_size = patch_size
- self.aa_block_size = aa_block_size
-
- # Validate that depth is divisible by aa_block_size
- if self.depth % self.aa_block_size != 0:
- raise ValueError(
- f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})"
- )
-
- self.aa_block_num = self.depth // self.aa_block_size
-
- # Note: We have two camera tokens, one for the first frame and one for the rest
- # The same applies for register tokens
- self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
- self.register_token = nn.Parameter(
- torch.randn(1, 2, num_register_tokens, embed_dim)
- )
-
- # The patch tokens start after the camera and register tokens
- self.patch_start_idx = 1 + num_register_tokens
-
- # Initialize parameters with small values
- nn.init.normal_(self.camera_token, std=1e-6)
- nn.init.normal_(self.register_token, std=1e-6)
-
- # Register normalization constants as buffers (use bf16-compatible tensor)
- for name, value in (
- ("_resnet_mean", _RESNET_MEAN),
- ("_resnet_std", _RESNET_STD),
- ):
- self.register_buffer(
- name,
- torch.tensor(value, dtype=torch.bfloat16).view(1, 1, 3, 1, 1),
- persistent=False,
- )
-
- self.use_reentrant = False # hardcoded to False
-
- def __build_patch_embed__(
- self,
- patch_embed,
- img_size,
- patch_size,
- num_register_tokens,
- interpolate_antialias=True,
- interpolate_offset=0.0,
- block_chunks=0,
- init_values=1.0,
- embed_dim=1024,
- ):
- """
- Build the patch embed layer. If 'conv', we use a
- simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
- """
-
- if "conv" in patch_embed:
- self.patch_embed = PatchEmbed(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=3,
- embed_dim=embed_dim,
- )
- else:
- vit_models = {
- "dinov2_vitl14_reg": vit_large,
- "dinov2_vitb14_reg": vit_base,
- "dinov2_vits14_reg": vit_small,
- "dinov2_vitg2_reg": vit_giant2,
- }
-
- self.patch_embed = vit_models[patch_embed](
- img_size=img_size,
- patch_size=patch_size,
- num_register_tokens=num_register_tokens,
- interpolate_antialias=interpolate_antialias,
- interpolate_offset=interpolate_offset,
- block_chunks=block_chunks,
- init_values=init_values,
- )
-
- # Disable gradient updates for mask token
- if hasattr(self.patch_embed, "mask_token"):
- self.patch_embed.mask_token.requires_grad_(False)
-
- def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]:
- """
- Args:
- images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
- B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
-
- Returns:
- (list[torch.Tensor], int):
- The list of outputs from the attention blocks,
- and the patch_start_idx indicating where patch tokens begin.
- """
- B, S, C_in, H, W = images.shape
-
- if C_in != 3:
- raise ValueError(f"Expected 3 input channels, got {C_in}")
-
- # Normalize images and reshape for patch embed - ensure bf16 computation
- images = images.to(torch.bfloat16)
- images = (images - self._resnet_mean) / self._resnet_std
-
- images = images.view(B * S, C_in, H, W)
- patch_tokens = self.patch_embed(images)
- del images
-
- if isinstance(patch_tokens, dict):
- patch_tokens = patch_tokens["x_norm_patchtokens"]
-
- patch_tokens = patch_tokens.to(torch.bfloat16)
-
- _, P, C = patch_tokens.shape
-
- # Expand camera and register tokens to match batch size and sequence length
- camera_token = slice_expand_and_flatten(
- self.camera_token.to(torch.bfloat16), B, S
- )
- register_token = slice_expand_and_flatten(
- self.register_token.to(torch.bfloat16), B, S
- )
-
- tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
- tokens = tokens.to(torch.bfloat16)
- del camera_token, register_token, patch_tokens
- # Explicitly clean up image data since patch embedding is complete
- if "images_normalized" in locals():
- del images_normalized
-
- pos = None
- if self.rope is not None:
- pos = self.position_getter(
- B * S, H // self.patch_size, W // self.patch_size, device="cuda"
- )
-
- if self.patch_start_idx > 0:
- # do not use position embedding for special tokens (camera and register tokens)
- # so set pos to 0 for the special tokens
- pos_original = pos
- pos = pos + 1
- pos_special = torch.zeros(
- B * S, self.patch_start_idx, 2, device="cuda", dtype=torch.long
- )
- pos = torch.cat([pos_special, pos], dim=1)
- # Clean up temporary variables
- del pos_special, pos_original
-
- # update P because we added special tokens
- _, P, C = tokens.shape
-
- frame_idx = 0
- global_idx = 0
- output_list = []
- block4DPT_idx = [4, 11, 17, 23]
- global_merging = None
-
- # Set global variables for attention visualization
- if self.vis_attn_map:
- import vggt.layers.attention as attn_module
-
- # Set the global variables that attention.py needs
- attn_module.vis_attn_map = True
- attn_module.current_images = self._load_image_paths() # Load from temp file
- else:
- import vggt.layers.attention as attn_module
-
- attn_module.vis_attn_map = False
-
- for block_num in range(self.aa_block_num):
- torch.cuda.synchronize()
- torch.cuda.empty_cache()
-
- need_intermediates = True if block_num in block4DPT_idx else False
- if block_num % 1 == 0:
- # Clean up RoPE cache to prevent accumulation
- if hasattr(self, "rope") and self.rope is not None:
- if hasattr(self.rope, "frequency_cache"):
- self.rope.frequency_cache.clear()
- # Clean up position cache
- if (
- hasattr(self, "position_getter")
- and self.position_getter is not None
- ):
- if hasattr(self.position_getter, "position_cache"):
- # Keep only current size cache, clean up others
- current_cache = self.position_getter.position_cache.copy()
- if (
- len(current_cache) > 1
- ): # If there are multiple cache entries
- self.position_getter.position_cache.clear()
- # Keep only the most recently used one
- if current_cache:
- key = list(current_cache.keys())[-1]
- self.position_getter.position_cache[key] = (
- current_cache[key]
- )
- # Avoid saving block_num to instance variable to reduce references
- for attn_type in self.aa_order:
- if attn_type == "frame":
- tokens, frame_idx, frame_intermediates = (
- self._process_frame_attention(
- tokens,
- B,
- S,
- P,
- C,
- frame_idx,
- pos=pos,
- need_intermediates=need_intermediates,
- )
- )
- elif attn_type == "global":
- if self.merging is None:
- global_merging = None
- elif self.global_merging and block_num >= self.merging:
- global_merging = block_num
- # Set attention_map for visualization
- if self.vis_attn_map:
- import vggt.layers.attention as attn_module
-
- attn_module.attention_map = block_num
- tokens, global_idx, global_intermediates = (
- self._process_global_attention(
- tokens,
- B,
- S,
- P,
- C,
- global_idx,
- pos=pos,
- global_merging=global_merging,
- need_intermediates=need_intermediates,
- )
- )
- else:
- raise ValueError(f"Unknown attention type: {attn_type}")
-
- if block_num not in block4DPT_idx:
- if "frame_intermediates" in locals():
- del frame_intermediates
- if "global_intermediates" in locals():
- del global_intermediates
- else:
- concat_inter = torch.cat(
- [frame_intermediates[0].detach(), global_intermediates[0].detach()],
- dim=-1,
- )
- if concat_inter.dtype != torch.bfloat16:
- concat_inter = concat_inter.to(torch.bfloat16)
- output_list.append(concat_inter)
- del concat_inter, frame_intermediates, global_intermediates
-
- # Do final cleanup before returning
- del tokens, pos
- if "pos_special" in locals():
- del pos_special
- if "pos_original" in locals():
- del pos_original
- torch.cuda.empty_cache() # Final cleanup
-
- return output_list, self.patch_start_idx
-
- def _process_frame_attention(
- self, tokens, B, S, P, C, frame_idx, pos=None, need_intermediates=False
- ):
- """
- Process frame attention blocks. We keep tokens in shape (B*S, P, C).
- """
- # If needed, reshape tokens or positions:
- if tokens.shape != (B * S, P, C):
- tokens = tokens.view(B, S, P, C).view(B * S, P, C)
-
- if pos is not None and pos.shape != (B * S, P, 2):
- pos = pos.view(B, S, P, 2).view(B * S, P, 2)
-
- intermediates = [] if need_intermediates else None
-
- # by default, self.aa_block_size=1, which processes one block at a time
- for _ in range(self.aa_block_size):
- tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
- frame_idx += 1
- if need_intermediates:
- intermediates.append(tokens.view(B, S, P, C))
-
- return tokens, frame_idx, intermediates
-
- def _process_global_attention(
- self,
- tokens,
- B,
- S,
- P,
- C,
- global_idx,
- pos=None,
- global_merging=None,
- need_intermediates=False,
- ):
- """
- Process global attention blocks. We keep tokens in shape (B, S*P, C).
- """
- if tokens.shape != (B, S * P, C):
- tokens = tokens.view(B, S, P, C).view(B, S * P, C)
-
- if pos is not None and pos.shape != (B, S * P, 2):
- pos = pos.view(B, S, P, 2).view(B, S * P, 2)
-
- intermediates = [] if need_intermediates else None
-
- # by default, self.aa_block_size=1, which processes one block at a time
- for _ in range(self.aa_block_size):
- tokens = self.global_blocks[global_idx](
- tokens,
- pos=pos,
- global_merging=global_merging,
- )
- global_idx += 1
- if need_intermediates:
- intermediates.append(tokens.view(B, S, P, C))
-
- return tokens, global_idx, intermediates
-
- def _load_image_paths(self):
- """Load image paths from temporary file for visualization"""
- try:
- import os
- import tempfile
- import pickle
-
- temp_dir = tempfile.gettempdir()
- image_paths_file = os.path.join(temp_dir, "vggt_image_paths.pkl")
-
- if os.path.exists(image_paths_file):
- with open(image_paths_file, "rb") as f:
- return pickle.load(f)
- else:
- return []
- except Exception as e:
- print(f"Warning: Could not load image paths for visualization: {e}")
- return []
-
-
-def slice_expand_and_flatten(token_tensor, B, S):
- """
- Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
- 1) Uses the first position (index=0) for the first frame only
- 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
- 3) Expands both to match batch size B
- 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
- followed by (S-1) second-position tokens
- 5) Flattens to (B*S, X, C) for processing
-
- Returns:
- torch.Tensor: Processed tokens with shape (B*S, X, C)
- """
-
- query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
- # Slice out the "other" tokens => shape (1, S-1, ...)
- others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
- # Concatenate => shape (B, S, ...)
- combined = torch.cat([query, others], dim=1)
-
- # Finally flatten => shape (B*S, ...)
- combined = combined.view(B * S, *combined.shape[2:])
- return combined
diff --git a/FastVGGT/vggt/models/vggt.py b/FastVGGT/vggt/models/vggt.py
deleted file mode 100644
index 2efeb5271de9916d05fd231a0a5ec1694f999bd0..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/models/vggt.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn as nn
-from huggingface_hub import PyTorchModelHubMixin # used for model hub
-
-from vggt.models.aggregator import Aggregator
-from vggt.heads.camera_head import CameraHead
-from vggt.heads.dpt_head import DPTHead
-from vggt.heads.track_head import TrackHead
-
-
-class VGGT(nn.Module, PyTorchModelHubMixin):
- def __init__(
- self,
- img_size=518,
- patch_size=14,
- embed_dim=1024,
- enable_camera=True,
- enable_point=False,
- enable_depth=True,
- enable_track=False,
- merging=0,
- vis_attn_map=False,
- ):
- super().__init__()
-
- self.vis_attn_map = vis_attn_map
-
- self.aggregator = Aggregator(
- img_size=img_size,
- patch_size=patch_size,
- embed_dim=embed_dim,
- merging=merging,
- vis_attn_map=vis_attn_map,
- )
-
- self.camera_head = CameraHead(dim_in=2 * embed_dim) if enable_camera else None
- self.point_head = (
- DPTHead(
- dim_in=2 * embed_dim,
- output_dim=4,
- activation="inv_log",
- conf_activation="expp1",
- )
- if enable_point
- else None
- )
- self.depth_head = (
- DPTHead(
- dim_in=2 * embed_dim,
- output_dim=2,
- activation="exp",
- conf_activation="expp1",
- )
- if enable_depth
- else None
- )
- self.track_head = (
- TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
- if enable_track
- else None
- )
-
- def update_patch_dimensions(self, patch_width: int, patch_height: int):
- """
- Update patch dimensions for all attention layers in the model
-
- Args:
- patch_width: Patch width (typically 37)
- patch_height: Patch height (typically 28)
- """
-
- def update_attention_in_module(module):
- for name, child in module.named_children():
- # Recursively update submodules
- update_attention_in_module(child)
- # If it is an attention layer, update its patch dimensions
- if hasattr(child, "patch_width") and hasattr(child, "patch_height"):
- child.patch_width = patch_width
- child.patch_height = patch_height
-
- # Update all attention layers in the aggregator
- update_attention_in_module(self.aggregator)
-
- # print(
- # f"🔧 Updated model attention layer patch dimensions: {patch_width}x{patch_height}"
- # )
-
- def forward(
- self,
- images: torch.Tensor,
- query_points: torch.Tensor = None,
- image_paths: list = None,
- ):
- """
- Forward pass of the VGGT model.
-
- Args:
- images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
- B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
- query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
- Shape: [N, 2] or [B, N, 2], where N is the number of query points.
- Default: None
- image_paths (list, optional): List of image file paths for attention visualization.
- Only used when vis_attn_map=True. Default: None
-
- Returns:
- dict: A dictionary containing the following predictions:
- - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration)
- - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1]
- - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W]
- - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3]
- - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W]
- - images (torch.Tensor): Original input images, preserved for visualization
-
- If query_points is provided, also includes:
- - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates
- - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N]
- - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N]
- """
- # If without batch dimension, add it
- if len(images.shape) == 4:
- images = images.unsqueeze(0)
-
- if query_points is not None and len(query_points.shape) == 2:
- query_points = query_points.unsqueeze(0)
-
- # Save image paths globally for attention visualization
- if self.vis_attn_map and image_paths is not None:
- import os
- import tempfile
- import pickle
-
- # Create a temporary file to store image paths
- temp_dir = tempfile.gettempdir()
- image_paths_file = os.path.join(temp_dir, "vggt_image_paths.pkl")
- with open(image_paths_file, "wb") as f:
- pickle.dump(image_paths, f)
-
- aggregated_tokens_list, patch_start_idx = self.aggregator(images)
-
- predictions = {}
-
- if self.camera_head is not None:
- pose_enc_list = self.camera_head(aggregated_tokens_list)
- predictions["pose_enc"] = pose_enc_list[
- -1
- ] # pose encoding of the last iteration
- predictions["pose_enc_list"] = pose_enc_list
-
- if self.depth_head is not None:
- depth, depth_conf = self.depth_head(
- aggregated_tokens_list,
- images=images,
- patch_start_idx=patch_start_idx,
- )
- predictions["depth"] = depth
- predictions["depth_conf"] = depth_conf
-
- if self.point_head is not None:
- pts3d, pts3d_conf = self.point_head(
- aggregated_tokens_list,
- images=images,
- patch_start_idx=patch_start_idx,
- )
- predictions["world_points"] = pts3d
- predictions["world_points_conf"] = pts3d_conf
-
- if self.track_head is not None and query_points is not None:
- track_list, vis, conf = self.track_head(
- aggregated_tokens_list,
- images=images,
- patch_start_idx=patch_start_idx,
- query_points=query_points,
- )
- predictions["track"] = track_list[-1] # track of the last iteration
- predictions["vis"] = vis
- predictions["conf"] = conf
-
- if not self.training:
- predictions["images"] = (
- images # store the images for visualization during inference
- )
-
- return predictions
diff --git a/FastVGGT/vggt/utils/__init__.py b/FastVGGT/vggt/utils/__init__.py
deleted file mode 100644
index 3a2958f93a114495631a3ce270b99ee8ba8443f1..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/utils/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
\ No newline at end of file
diff --git a/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index 95269f2111ee13cbadfcbab9cd10605555f25aee..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc
deleted file mode 100644
index a3e08e180b2f41ce54c1912296117bf616196aae..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc
deleted file mode 100644
index b7a92463e6f8f34289986fcaf2ab515c9598cf51..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc
deleted file mode 100644
index 37978e9be49721f9e88beeb814a38e3dcb8754ad..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc
deleted file mode 100644
index 93fcc5eddf5936461e8cbf8ca031e150105da24e..0000000000000000000000000000000000000000
Binary files a/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc and /dev/null differ
diff --git a/FastVGGT/vggt/utils/eval_utils.py b/FastVGGT/vggt/utils/eval_utils.py
deleted file mode 100644
index a81e2cab2b36c97e35b5ebb110b6e66fc814efa7..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/utils/eval_utils.py
+++ /dev/null
@@ -1,782 +0,0 @@
-from collections import deque
-from PIL import Image
-import cv2
-import io
-import random
-import time
-from pathlib import Path
-from typing import List, Tuple, Dict, Optional
-from copy import deepcopy
-import matplotlib.pyplot as plt
-import evo.tools.plot as plot
-import evo.main_ape as main_ape
-import evo.main_rpe as main_rpe
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-from evo.core.metrics import PoseRelation, Unit
-from evo.core.trajectory import PoseTrajectory3D
-from scipy.linalg import svd
-import open3d as o3d # for point cloud processing and Chamfer Distance computation
-
-from scipy.spatial.transform import Rotation
-from torchvision import transforms as TF
-
-
-from vggt.utils.geometry import unproject_depth_map_to_point_map
-from vggt.utils.pose_enc import pose_encoding_to_extri_intri
-
-
-def shuffle_deque(dq, seed=None):
- # Set the random seed for reproducibility
- if seed is not None:
- random.seed(seed)
-
- # Convert deque to list, shuffle, and convert back
- shuffled_list = list(dq)
- random.shuffle(shuffled_list)
- return deque(shuffled_list)
-
-
-def imread_cv2(path, options=cv2.IMREAD_COLOR):
- """Open an image or a depthmap with opencv-python."""
- if path.endswith((".exr", "EXR")):
- options = cv2.IMREAD_ANYDEPTH
- img = cv2.imread(path, options)
- if img is None:
- raise IOError(f"Could not load image={path} with {options=}")
- if img.ndim == 3:
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- return img
-
-
-# Import umeyama_alignment for internal use in eval_trajectory
-def umeyama_alignment(src, dst, estimate_scale=True):
- # Ensure inputs have correct shape
- assert (
- src.shape == dst.shape
- ), f"Input shapes don't match: src {src.shape}, dst {dst.shape}"
- assert src.shape[0] == 3, f"Expected point cloud dimension (3,N), got {src.shape}"
-
- # Compute centroids
- src_mean = src.mean(axis=1, keepdims=True)
- dst_mean = dst.mean(axis=1, keepdims=True)
-
- # Center the point clouds
- src_centered = src - src_mean
- dst_centered = dst - dst_mean
-
- # Compute covariance matrix
- cov = dst_centered @ src_centered.T
-
- try:
- # Singular Value Decomposition
- U, D, Vt = svd(cov)
- V = Vt.T
-
- # Handle reflection case
- det_UV = np.linalg.det(U @ V.T)
- S = np.eye(3)
- if det_UV < 0:
- S[2, 2] = -1
-
- # Compute rotation matrix
- R = U @ S @ V.T
-
- if estimate_scale:
- # Compute scale factor - fix dimension issue
- src_var = np.sum(src_centered * src_centered)
- if src_var < 1e-10:
- print(
- "Warning: Source point cloud variance close to zero, setting scale factor to 1.0"
- )
- scale = 1.0
- else:
- # Fix potential dimension issue with np.diag(S)
- # Use diagonal elements directly
- scale = np.sum(D * np.diag(S)) / src_var
- else:
- scale = 1.0
-
- # Compute translation vector
- t = dst_mean.ravel() - scale * (R @ src_mean).ravel()
-
- return scale, R, t
-
- except Exception as e:
- print(f"Error in umeyama_alignment computation: {e}")
- print(
- "Returning default transformation: scale=1.0, rotation=identity matrix, translation=centroid difference"
- )
- # Return default transformation
- scale = 1.0
- R = np.eye(3)
- t = (dst_mean - src_mean).ravel()
- return scale, R, t
-
-
-def compute_chamfer_distance(points_pred, points_gt, max_dist=1.0):
- # Ensure point cloud size is not too large, which would cause slow computation
- MAX_POINTS = 100000
- if points_pred.shape[0] > MAX_POINTS:
- np.random.seed(33) # Fix random seed
- indices = np.random.choice(points_pred.shape[0], MAX_POINTS, replace=False)
- points_pred = points_pred[indices]
-
- if points_gt.shape[0] > MAX_POINTS:
- np.random.seed(33) # Fix random seed
- indices = np.random.choice(points_gt.shape[0], MAX_POINTS, replace=False)
- points_gt = points_gt[indices]
-
- # Convert numpy point clouds to open3d point cloud objects
- pcd_pred = o3d.geometry.PointCloud()
- pcd_gt = o3d.geometry.PointCloud()
- pcd_pred.points = o3d.utility.Vector3dVector(points_pred)
- pcd_gt.points = o3d.utility.Vector3dVector(points_gt)
-
- # Downsample point clouds to accelerate computation
- voxel_size = 0.05 # 5cm voxel size
- pcd_pred = pcd_pred.voxel_down_sample(voxel_size)
- pcd_gt = pcd_gt.voxel_down_sample(voxel_size)
-
- # Compute distances from predicted point cloud to GT point cloud
- distances1 = np.asarray(pcd_pred.compute_point_cloud_distance(pcd_gt))
- # Compute distances from GT point cloud to predicted point cloud
- distances2 = np.asarray(pcd_gt.compute_point_cloud_distance(pcd_pred))
-
- # Apply distance clipping
- distances1 = np.clip(distances1, 0, max_dist)
- distances2 = np.clip(distances2, 0, max_dist)
-
- # Chamfer Distance is the sum of mean distances in both directions
- chamfer_dist = np.mean(distances1) + np.mean(distances2)
-
- return chamfer_dist
-
-
-def load_gt_pointcloud(scene_id, gt_ply_dir):
- scene_dir = gt_ply_dir / scene_id
- ply_path = scene_dir / (scene_id + "_vh_clean_2.ply")
- pcd = o3d.io.read_point_cloud(str(ply_path))
-
- # Convert to numpy arrays
- points = np.asarray(pcd.points)
- colors = None
- try:
- if pcd.has_colors():
- colors = np.asarray(pcd.colors)
- except Exception:
- colors = None
-
- return points, colors
-
-
-def eval_trajectory(poses_est, poses_gt, frame_ids, align=False):
- # Build reference trajectory object
- traj_ref = PoseTrajectory3D(
- positions_xyz=poses_gt[:, :3, 3], # Extract translation part
- orientations_quat_wxyz=Rotation.from_matrix(poses_gt[:, :3, :3]).as_quat(
- scalar_first=True
- ), # Convert rotation matrix to quaternion
- timestamps=frame_ids,
- )
-
- # Build estimated trajectory object
- traj_est = PoseTrajectory3D(
- positions_xyz=poses_est[:, :3, 3],
- orientations_quat_wxyz=Rotation.from_matrix(poses_est[:, :3, :3]).as_quat(
- scalar_first=True
- ),
- timestamps=frame_ids,
- )
-
- # Calculate Absolute Trajectory Error (ATE)
- ate_result = main_ape.ape(
- deepcopy(traj_ref),
- deepcopy(traj_est),
- est_name="traj",
- pose_relation=PoseRelation.translation_part,
- align=align,
- correct_scale=align,
- align_origin=True,
- )
- ate = ate_result.stats["rmse"]
-
- # Get alignment transformation matrix
- transform = np.eye(4)
- if align:
- try:
- # Use umeyama algorithm to compute optimal rigid transformation (including rotation, translation and scaling)
- aligned_xyz = ate_result.trajectories["traj"].positions_xyz
- original_xyz = traj_est.positions_xyz
-
- # At least 3 points needed to compute reliable transformation
- if len(aligned_xyz) >= 3 and len(original_xyz) >= 3:
- # Ensure point count matches
- min_points = min(len(aligned_xyz), len(original_xyz))
- aligned_xyz = aligned_xyz[:min_points]
- original_xyz = original_xyz[:min_points]
-
- # Compute transformation matrix (scaling, rotation and translation)
- try:
- s, R, t = umeyama_alignment(
- original_xyz.T, # Source point cloud (3xN)
- aligned_xyz.T, # Target point cloud (3xN)
- True, # Whether to estimate scaling
- )
-
- # Build complete transformation matrix
- transform = np.eye(4)
- transform[:3, :3] = s * R # Scaling and rotation
- transform[:3, 3] = t # Translation
-
- except Exception as e:
- print(f"umeyama_alignment failed: {e}")
- else:
- print(
- "Insufficient points, cannot reliably compute transformation matrix"
- )
- except Exception as e:
- print(f"Error computing transformation matrix: {e}")
-
- # If the above method fails, fallback to simple translation transformation
- if np.array_equal(transform, np.eye(4)) and hasattr(ate_result, "trajectories"):
- try:
- # Get original and aligned first position
- orig_pos = traj_est.positions_xyz[0]
- aligned_pos = ate_result.trajectories["traj"].positions_xyz[0]
-
- # Calculate translation part
- translation = aligned_pos - orig_pos
-
- # Update translation part of transformation matrix
- transform[:3, 3] = translation
- print(f"Fallback to simple translation transformation: {transform}")
- except Exception as e:
- print(f"Error building translation transformation: {e}")
- print("Will use identity matrix")
-
- # Calculate Absolute Rotation Error (ARE)
- are_result = main_ape.ape(
- deepcopy(traj_ref),
- deepcopy(traj_est),
- est_name="traj",
- pose_relation=PoseRelation.rotation_angle_deg,
- align=align,
- correct_scale=align,
- align_origin=True,
- )
- are = are_result.stats["rmse"]
-
- # Calculate Relative Pose Error (RPE) - rotation part
- rpe_rots_result = main_rpe.rpe(
- deepcopy(traj_ref),
- deepcopy(traj_est),
- est_name="traj",
- pose_relation=PoseRelation.rotation_angle_deg,
- align=align,
- correct_scale=align,
- delta=1,
- delta_unit=Unit.frames,
- rel_delta_tol=0.01,
- all_pairs=True,
- align_origin=True,
- )
- rpe_rot = rpe_rots_result.stats["rmse"]
-
- # Calculate Relative Pose Error (RPE) - translation part
- rpe_transs_result = main_rpe.rpe(
- deepcopy(traj_ref),
- deepcopy(traj_est),
- est_name="traj",
- pose_relation=PoseRelation.translation_part,
- align=align,
- correct_scale=align,
- delta=1,
- delta_unit=Unit.frames,
- rel_delta_tol=0.01,
- all_pairs=True,
- align_origin=True,
- )
- rpe_trans = rpe_transs_result.stats["rmse"]
-
- # Plot trajectory graph
- plot_mode = plot.PlotMode.xz # Use correct PlotMode reference
- fig = plt.figure()
- ax = plot.prepare_axis(fig, plot_mode)
- ax.set_title(f"ATE: {round(ate, 3)}, ARE: {round(are, 3)}")
-
- # Use reference trajectory (GT) for plotting
- plot.traj(ax, plot_mode, traj_ref, "--", "gray", "gt")
-
- # Use aligned trajectory for visualization
- if align:
- traj_est_aligned = ate_result.trajectories["traj"]
- plot.traj_colormap(
- ax,
- traj_est_aligned,
- ate_result.np_arrays["error_array"],
- plot_mode,
- min_map=ate_result.stats["min"],
- max_map=ate_result.stats["max"],
- )
- else:
- plot.traj_colormap(
- ax,
- traj_est,
- ate_result.np_arrays["error_array"],
- plot_mode,
- min_map=ate_result.stats["min"],
- max_map=ate_result.stats["max"],
- )
-
- ax.legend()
-
- # Save image to memory buffer
- buffer = io.BytesIO()
- plt.savefig(buffer, format="png", dpi=90)
- buffer.seek(0)
-
- pillow_image = Image.open(buffer)
- pillow_image.load()
- buffer.close()
- plt.close(fig)
-
- return (
- {"ate": ate, "are": are, "rpe_rot": rpe_rot, "rpe_trans": rpe_trans},
- pillow_image,
- transform,
- )
-
-
-def load_poses(path):
- # Read all txt files from pose directory
- pose_files = sorted(
- path.glob("*.txt"), key=lambda x: int(x.stem)
- ) # Sort by numerical order
-
- # Check if pose files exist
- if len(pose_files) == 0:
- print(f"Warning: No pose files (.txt) found in directory {path}")
- return None, None, None
-
- c2ws = []
- available_frame_ids = []
-
- for pose_file in pose_files:
- try:
- with open(pose_file, "r") as f:
- # Each file contains 16 numbers representing a 4x4 transformation matrix
- nums = [float(x) for x in f.read().strip().split()]
- pose = np.array(nums).reshape(4, 4)
- # Check if pose is valid (no infinite or NaN values)
- if not (np.isinf(pose).any() or np.isnan(pose).any()):
- c2ws.append(pose)
- available_frame_ids.append(int(pose_file.stem))
- else:
- continue
- except Exception as e:
- print(f"Error reading pose file {pose_file}: {e}")
- continue
-
- if len(c2ws) == 0:
- print(f"Warning: No valid pose files found in directory {path}")
- return None, None, None
-
- c2ws = np.stack(c2ws)
- available_frame_ids = np.array(available_frame_ids)
-
- # Transform all poses to first frame coordinate system
- first_gt_pose = c2ws[0].copy() # Save original pose of first frame
- c2ws = np.linalg.inv(c2ws[0]) @ c2ws
- return c2ws, first_gt_pose, available_frame_ids
-
-
-def get_vgg_input_imgs(images: np.ndarray):
- to_tensor = TF.ToTensor()
- vgg_input_images = []
- final_width = None
- final_height = None
-
- for image in images:
- img = Image.fromarray(image, mode="RGB")
- width, height = img.size
- # Resize image, maintain aspect ratio, ensure height is multiple of 14
- new_width = 518
- new_height = round(height * (new_width / width) / 14) * 14
- img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
- img = to_tensor(img) # Convert to tensor (0, 1)
-
- # If height exceeds 518, perform center cropping
- if new_height > 518:
- start_y = (new_height - 518) // 2
- img = img[:, start_y : start_y + 518, :]
- final_height = 518
- else:
- final_height = new_height
-
- final_width = new_width
- vgg_input_images.append(img)
-
- vgg_input_images = torch.stack(vgg_input_images)
-
- # Calculate the patch dimensions (divided by 14 for patch size)
- patch_width = final_width // 14 # 518 // 14 = 37
- patch_height = final_height // 14 # computed dynamically, typically 28
-
- return vgg_input_images, patch_width, patch_height
-
-
-def get_sorted_image_paths(images_dir):
- image_paths = []
- for ext in ["*.jpg", "*.png", "*.jpeg"]:
- image_paths.extend(sorted(images_dir.glob(ext)))
- # image_paths.sort(key=lambda x: int(x.stem))
- return image_paths
-
-
-def to_homogeneous(extrinsics):
- n = extrinsics.shape[0]
- homogeneous_extrinsics = np.eye(4)[None, :, :].repeat(
- n, axis=0
- ) # Create identity matrix
- homogeneous_extrinsics[:, :3, :4] = extrinsics # Copy [R|t] part
- return homogeneous_extrinsics
-
-
-def umeyama_alignment(src, dst, estimate_scale=True):
- # Ensure inputs have correct shape
- assert (
- src.shape == dst.shape
- ), f"Input shapes don't match: src {src.shape}, dst {dst.shape}"
- assert src.shape[0] == 3, f"Expected point cloud dimension (3,N), got {src.shape}"
-
- # Compute centroids
- src_mean = src.mean(axis=1, keepdims=True)
- dst_mean = dst.mean(axis=1, keepdims=True)
-
- # Center the point clouds
- src_centered = src - src_mean
- dst_centered = dst - dst_mean
-
- # Compute covariance matrix
- cov = dst_centered @ src_centered.T
-
- try:
- # Singular Value Decomposition
- U, D, Vt = svd(cov)
- V = Vt.T
-
- # Handle reflection case
- det_UV = np.linalg.det(U @ V.T)
- S = np.eye(3)
- if det_UV < 0:
- S[2, 2] = -1
-
- # Compute rotation matrix
- R = U @ S @ V.T
-
- if estimate_scale:
- # Compute scale factor - fix dimension issue
- src_var = np.sum(src_centered * src_centered)
- if src_var < 1e-10:
- print(
- "Warning: Source point cloud variance close to zero, setting scale factor to 1.0"
- )
- scale = 1.0
- else:
- # Fix potential dimension issue with np.diag(S)
- # Use diagonal elements directly
- scale = np.sum(D * np.diag(S)) / src_var
- else:
- scale = 1.0
-
- # Compute translation vector
- t = dst_mean.ravel() - scale * (R @ src_mean).ravel()
-
- return scale, R, t
-
- except Exception as e:
- print(f"Error in umeyama_alignment computation: {e}")
- print(
- "Returning default transformation: scale=1.0, rotation=identity matrix, translation=centroid difference"
- )
- # Return default transformation
- scale = 1.0
- R = np.eye(3)
- t = (dst_mean - src_mean).ravel()
- return scale, R, t
-
-
-def align_point_clouds_scale(source_pc, target_pc):
- # Compute bounding box sizes of point clouds
- source_min = np.min(source_pc, axis=0)
- source_max = np.max(source_pc, axis=0)
- target_min = np.min(target_pc, axis=0)
- target_max = np.max(target_pc, axis=0)
-
- source_size = source_max - source_min
- target_size = target_max - target_min
-
- # Compute point cloud centers
- source_center = (source_max + source_min) / 2
- target_center = (target_max + target_min) / 2
-
- # Compute overall scale factor (using diagonal length)
- source_diag = np.sqrt(np.sum(source_size**2))
- target_diag = np.sqrt(np.sum(target_size**2))
-
- if source_diag < 1e-8:
- print("Warning: Source point cloud size close to zero")
- scale_factor = 1.0
- else:
- scale_factor = target_diag / source_diag
-
- # Apply scaling (with source point cloud center as reference)
- centered_source = source_pc - source_center
- scaled_centered = centered_source * scale_factor
- scaled_aligned_source = scaled_centered + target_center
-
- return scaled_aligned_source, scale_factor
-
-
-def get_all_scenes(data_dir: Path, num_scenes: int) -> List[str]:
- all_scenes = sorted([d.name for d in data_dir.iterdir() if d.is_dir()])
- if len(all_scenes) > num_scenes:
- sample_interval = max(1, len(all_scenes) // num_scenes)
- return all_scenes[::sample_interval][:num_scenes]
- return all_scenes
-
-
-def build_frame_selection(
- image_paths: List[Path],
- available_pose_frame_ids: np.ndarray,
- input_frame: int,
-) -> Tuple[List[int], List[Path], List[int]]:
- all_image_frame_ids = [int(path.stem) for path in image_paths]
- valid_frame_ids = sorted(
- list(set(all_image_frame_ids) & set(available_pose_frame_ids))
- )
- if len(valid_frame_ids) > input_frame:
- first_frame = valid_frame_ids[0]
- remaining_frames = valid_frame_ids[1:]
- step = max(1, len(remaining_frames) // (input_frame - 1))
- selected_remaining = remaining_frames[::step][: input_frame - 1]
- selected_frame_ids = [first_frame] + selected_remaining
- else:
- selected_frame_ids = valid_frame_ids
-
- frame_id_to_path = {int(path.stem): path for path in image_paths}
- selected_image_paths = [
- frame_id_to_path[fid] for fid in selected_frame_ids if fid in frame_id_to_path
- ]
-
- pose_frame_to_idx = {fid: idx for idx, fid in enumerate(available_pose_frame_ids)}
- selected_pose_indices = [
- pose_frame_to_idx[fid] for fid in selected_frame_ids if fid in pose_frame_to_idx
- ]
-
- return selected_frame_ids, selected_image_paths, selected_pose_indices
-
-
-def load_images_rgb(image_paths: List[Path]) -> List[np.ndarray]:
- images: List[np.ndarray] = []
- for image_path in image_paths:
- img = cv2.imread(str(image_path))
- if img is None:
- continue
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- images.append(img)
- return images
-
-
-@torch.no_grad()
-def infer_vggt_and_reconstruct(
- model: torch.nn.Module,
- vgg_input: torch.Tensor,
- dtype: torch.dtype,
- depth_conf_thresh: float,
- image_paths: list = None,
-) -> Tuple[
- np.ndarray,
- np.ndarray,
- List[np.ndarray],
- List[np.ndarray],
- List[np.ndarray],
- float,
-]:
- torch.cuda.synchronize()
- start = time.time()
- with torch.cuda.amp.autocast(dtype=dtype):
- vgg_input_cuda = vgg_input.cuda().to(torch.bfloat16)
- predictions = model(vgg_input_cuda, image_paths=image_paths)
- torch.cuda.synchronize()
- end = time.time()
- inference_time_ms = (end - start) * 1000.0
-
- extrinsic, intrinsic = pose_encoding_to_extri_intri(
- predictions["pose_enc"], (vgg_input.shape[2], vgg_input.shape[3])
- )
-
- depth_tensor = predictions["depth"]
- depth_conf = predictions["depth_conf"]
- depth_conf_np = depth_conf[0].detach().float().cpu().numpy()
- depth_mask = depth_conf_np >= depth_conf_thresh
- depth_filtered = depth_tensor[0].detach().float().cpu().numpy()
- depth_filtered[~depth_mask] = np.nan
- depth_np = depth_filtered
-
- extrinsic_np = extrinsic[0].detach().float().cpu().numpy()
- intrinsic_np = intrinsic[0].detach().float().cpu().numpy()
-
- world_points = unproject_depth_map_to_point_map(
- depth_np, extrinsic_np, intrinsic_np
- )
- all_points: List[np.ndarray] = []
- all_colors: List[np.ndarray] = []
-
- # Prepare RGB images aligned with vgg_input for coloring point clouds (0-255, uint8)
- vgg_np = vgg_input.detach().float().cpu().numpy() # [S, 3, H, W] in [0,1]
-
- for frame_idx in range(world_points.shape[0]):
- points = world_points[frame_idx].reshape(-1, 3)
- valid_mask = ~np.isnan(points).any(axis=1) & ~np.isinf(points).any(axis=1)
- valid_points = points[valid_mask]
- if len(valid_points) > 0:
- all_points.append(valid_points)
-
- # Generate corresponding colors
- img_chw = vgg_np[frame_idx] # [3, H, W]
- img_hwc = (
- (np.transpose(img_chw, (1, 2, 0)) * 255.0).clip(0, 255).astype(np.uint8)
- ) # [H, W, 3] uint8
- rgb_flat = img_hwc.reshape(-1, 3)
- valid_colors = rgb_flat[valid_mask]
- all_colors.append(valid_colors)
-
- camera_poses = to_homogeneous(extrinsic_np)
- all_cam_to_world_mat = list(camera_poses)
-
- return (
- extrinsic_np,
- intrinsic_np,
- all_points,
- all_colors,
- all_cam_to_world_mat,
- inference_time_ms,
- )
-
-
-def evaluate_scene_and_save(
- scene: str,
- c2ws: np.ndarray,
- first_gt_pose: np.ndarray,
- frame_ids: List[int],
- all_cam_to_world_mat: List[np.ndarray],
- all_world_points: List[np.ndarray],
- output_scene_dir: Path,
- gt_ply_dir: Path,
- chamfer_max_dist: float,
- inference_time_ms: float,
- plot_flag: bool,
-) -> Optional[Dict[str, float]]:
- if not all_cam_to_world_mat or not all_world_points:
- print(f"Skipping {scene}: failed to obtain valid camera poses or point clouds")
- return None
-
- output_scene_dir.mkdir(parents=True, exist_ok=True)
-
- poses_gt = c2ws
- w2cs = np.linalg.inv(poses_gt)
- traj_est_poses = np.array(all_cam_to_world_mat)
- n = min(len(traj_est_poses), len(w2cs))
- timestamps = frame_ids[:n]
- stats_aligned, traj_plot, _ = eval_trajectory(
- traj_est_poses[:n], w2cs[:n], timestamps, align=True
- )
-
- try:
- merged_point_cloud = np.vstack(all_world_points)
- gt_point_cloud, _ = load_gt_pointcloud(scene, gt_ply_dir)
- if gt_point_cloud is not None:
- homogeneous_points = np.hstack(
- [merged_point_cloud, np.ones((merged_point_cloud.shape[0], 1))]
- )
- world_points_raw = np.dot(homogeneous_points, first_gt_pose.T)[:, :3]
-
- world_points_scaled, scale_factor = align_point_clouds_scale(
- world_points_raw, gt_point_cloud
- )
-
- cd_value = compute_chamfer_distance(
- world_points_scaled, gt_point_cloud, max_dist=chamfer_max_dist
- )
- stats_aligned["chamfer_distance"] = float(cd_value)
- stats_aligned["scale_factor"] = float(scale_factor)
- except Exception as e:
- print(f"Error computing Chamfer Distance for {scene}: {e}")
-
- all_metrics: Dict[str, float] = deepcopy(stats_aligned)
- for metric_name, metric_value in list(stats_aligned.items()):
- all_metrics[f"aligned_{metric_name}"] = metric_value
- all_metrics["inference_time_ms"] = float(inference_time_ms)
-
- with open(output_scene_dir / "metrics.json", "w") as f:
- import json
-
- json.dump(all_metrics, f, indent=4)
- if plot_flag:
- try:
- traj_plot.save(output_scene_dir / "plot.png")
- except Exception:
- pass
-
- return all_metrics
-
-
-def compute_average_metrics_and_save(
- all_scenes_metrics: Dict[str, Dict[str, Dict[str, float]]],
- output_path: Path,
- input_frame: int,
-) -> Dict[str, float]:
- metric_names = [
- "chamfer_distance",
- "ate",
- "are",
- "rpe_rot",
- "rpe_trans",
- "inference_time_ms",
- ]
- average_metrics_list: Dict[str, List[float]] = {
- metric: [] for metric in metric_names
- }
- for _, metrics in all_scenes_metrics["scenes"].items():
- for metric_name, metric_value in metrics.items():
- if metric_name in average_metrics_list:
- average_metrics_list[metric_name].append(float(metric_value))
-
- average_metrics: Dict[str, float] = {}
- for metric_name, values in average_metrics_list.items():
- average_metrics[metric_name] = float(np.mean(values)) if values else 0.0
-
- all_scenes_metrics["average"] = average_metrics
- output_path.mkdir(parents=True, exist_ok=True)
-
- input_frame_dir = output_path / f"input_frame_{input_frame}"
- input_frame_dir.mkdir(parents=True, exist_ok=True)
-
- with open(input_frame_dir / "all_scenes_metrics.json", "w") as f:
- import json
-
- json.dump(all_scenes_metrics, f, indent=4)
-
- with open(input_frame_dir / "average_metrics.json", "w") as f:
- import json
-
- json.dump(average_metrics, f, indent=4)
-
- print("\nAverage metrics:")
- for metric_name, value in average_metrics.items():
- print(f"{metric_name}: {value:.6f}")
-
- return average_metrics
diff --git a/FastVGGT/vggt/utils/geometry.py b/FastVGGT/vggt/utils/geometry.py
deleted file mode 100644
index f555516dbc8a7dd8c7b15e6fbc928a5bfe8f740b..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/utils/geometry.py
+++ /dev/null
@@ -1,324 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-import torch
-import numpy as np
-
-
-from vggt.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion
-
-
-def unproject_depth_map_to_point_map(
- depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
-) -> np.ndarray:
- """
- Unproject a batch of depth maps to 3D world coordinates.
-
- Args:
- depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
- extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
- intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
-
- Returns:
- np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
- """
- if isinstance(depth_map, torch.Tensor):
- depth_map = depth_map.cpu().numpy()
- if isinstance(extrinsics_cam, torch.Tensor):
- extrinsics_cam = extrinsics_cam.cpu().numpy()
- if isinstance(intrinsics_cam, torch.Tensor):
- intrinsics_cam = intrinsics_cam.cpu().numpy()
-
- world_points_list = []
- for frame_idx in range(depth_map.shape[0]):
- cur_world_points, _, _ = depth_to_world_coords_points(
- depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
- )
- world_points_list.append(cur_world_points)
- world_points_array = np.stack(world_points_list, axis=0)
-
- return world_points_array
-
-
-def depth_to_world_coords_points(
- depth_map: np.ndarray,
- extrinsic: np.ndarray,
- intrinsic: np.ndarray,
- eps=1e-8,
-) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
- """
- Convert a depth map to world coordinates.
-
- Args:
- depth_map (np.ndarray): Depth map of shape (H, W).
- intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
- extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
-
- Returns:
- tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).
- """
- if depth_map is None:
- return None, None, None
-
- # Valid depth mask
- point_mask = depth_map > eps
-
- # Convert depth map to camera coordinates
- cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
-
- # Multiply with the inverse of extrinsic matrix to transform to world coordinates
- # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
- cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
-
- R_cam_to_world = cam_to_world_extrinsic[:3, :3]
- t_cam_to_world = cam_to_world_extrinsic[:3, 3]
-
- # Apply the rotation and translation to the camera coordinates
- world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
- # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
-
- return world_coords_points, cam_coords_points, point_mask
-
-
-def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
- """
- Convert a depth map to camera coordinates.
-
- Args:
- depth_map (np.ndarray): Depth map of shape (H, W).
- intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
-
- Returns:
- tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
- """
- H, W = depth_map.shape
- assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
- assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
-
- # Intrinsic parameters
- fu, fv = intrinsic[0, 0], intrinsic[1, 1]
- cu, cv = intrinsic[0, 2], intrinsic[1, 2]
-
- # Generate grid of pixel coordinates
- u, v = np.meshgrid(np.arange(W), np.arange(H))
-
- # Unproject to camera coordinates
- x_cam = (u - cu) * depth_map / fu
- y_cam = (v - cv) * depth_map / fv
- z_cam = depth_map
-
- # Stack to form camera coordinates
- cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
-
- return cam_coords
-
-
-def closed_form_inverse_se3(se3, R=None, T=None):
- """
- Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
-
- If `R` and `T` are provided, they must correspond to the rotation and translation
- components of `se3`. Otherwise, they will be extracted from `se3`.
-
- Args:
- se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
- R (optional): Nx3x3 array or tensor of rotation matrices.
- T (optional): Nx3x1 array or tensor of translation vectors.
-
- Returns:
- Inverted SE3 matrices with the same type and device as `se3`.
-
- Shapes:
- se3: (N, 4, 4)
- R: (N, 3, 3)
- T: (N, 3, 1)
- """
- # Check if se3 is a numpy array or a torch tensor
- is_numpy = isinstance(se3, np.ndarray)
-
- # Validate shapes
- if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
- raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
-
- # Extract R and T if not provided
- if R is None:
- R = se3[:, :3, :3] # (N,3,3)
- if T is None:
- T = se3[:, :3, 3:] # (N,3,1)
-
- # Transpose R
- if is_numpy:
- # Compute the transpose of the rotation for NumPy
- R_transposed = np.transpose(R, (0, 2, 1))
- # -R^T t for NumPy
- top_right = -np.matmul(R_transposed, T)
- inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
- else:
- R_transposed = R.transpose(1, 2) # (N,3,3)
- top_right = -torch.bmm(R_transposed, T) # (N,3,1)
- inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
- inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
-
- inverted_matrix[:, :3, :3] = R_transposed
- inverted_matrix[:, :3, 3:] = top_right
-
- return inverted_matrix
-
-
-# TODO: this code can be further cleaned up
-
-
-def project_world_points_to_camera_points_batch(world_points, cam_extrinsics):
- """
- Transforms 3D points to 2D using extrinsic and intrinsic parameters.
- Args:
- world_points (torch.Tensor): 3D points of shape BxSxHxWx3.
- cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4.
- Returns:
- """
- # TODO: merge this into project_world_points_to_cam
-
- # device = world_points.device
- # with torch.autocast(device_type=device.type, enabled=False):
- ones = torch.ones_like(world_points[..., :1]) # shape: (B, S, H, W, 1)
- world_points_h = torch.cat([world_points, ones], dim=-1) # shape: (B, S, H, W, 4)
-
- # extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4)
- extrinsics_exp = cam_extrinsics.unsqueeze(2).unsqueeze(3)
-
- # world_points_h: (B, S, H, W, 4) -> (B, S, H, W, 4, 1)
- world_points_h_exp = world_points_h.unsqueeze(-1)
-
- # Now perform the matrix multiplication
- # (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) broadcasts to (B, S, H, W, 3, 1)
- camera_points = torch.matmul(extrinsics_exp, world_points_h_exp).squeeze(-1)
-
- return camera_points
-
-
-
-def project_world_points_to_cam(
- world_points,
- cam_extrinsics,
- cam_intrinsics=None,
- distortion_params=None,
- default=0,
- only_points_cam=False,
-):
- """
- Transforms 3D points to 2D using extrinsic and intrinsic parameters.
- Args:
- world_points (torch.Tensor): 3D points of shape Px3.
- cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
- cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
- distortion_params (torch.Tensor): Extra parameters of shape BxN, which is used for radial distortion.
- Returns:
- torch.Tensor: Transformed 2D points of shape BxNx2.
- """
- device = world_points.device
- # with torch.autocast(device_type=device.type, dtype=torch.double):
- with torch.autocast(device_type=device.type, enabled=False):
- N = world_points.shape[0] # Number of points
- B = cam_extrinsics.shape[0] # Batch size, i.e., number of cameras
- world_points_homogeneous = torch.cat(
- [world_points, torch.ones_like(world_points[..., 0:1])], dim=1
- ) # Nx4
- # Reshape for batch processing
- world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand(
- B, -1, -1
- ) # BxNx4
-
- # Step 1: Apply extrinsic parameters
- # Transform 3D points to camera coordinate system for all cameras
- cam_points = torch.bmm(
- cam_extrinsics, world_points_homogeneous.transpose(-1, -2)
- )
-
- if only_points_cam:
- return None, cam_points
-
- # Step 2: Apply intrinsic parameters and (optional) distortion
- image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default)
-
- return image_points, cam_points
-
-
-
-def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0):
- """
- Applies intrinsic parameters and optional distortion to the given 3D points.
-
- Args:
- cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
- cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
- distortion_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
- default (float, optional): Default value to replace NaNs in the output.
-
- Returns:
- pixel_coords (torch.Tensor): 2D points in pixel coordinates of shape BxNx2.
- """
-
- # Normalized device coordinates (NDC)
- cam_points = cam_points / cam_points[:, 2:3, :]
- ndc_xy = cam_points[:, :2, :]
-
- # Apply distortion if distortion_params are provided
- if distortion_params is not None:
- x_distorted, y_distorted = apply_distortion(distortion_params, ndc_xy[:, 0], ndc_xy[:, 1])
- distorted_xy = torch.stack([x_distorted, y_distorted], dim=1)
- else:
- distorted_xy = ndc_xy
-
- # Prepare cam_points for batch matrix multiplication
- cam_coords_homo = torch.cat(
- (distorted_xy, torch.ones_like(distorted_xy[:, :1, :])), dim=1
- ) # Bx3xN
- # Apply intrinsic parameters using batch matrix multiplication
- pixel_coords = torch.bmm(cam_intrinsics, cam_coords_homo) # Bx3xN
-
- # Extract x and y coordinates
- pixel_coords = pixel_coords[:, :2, :] # Bx2xN
-
- # Replace NaNs with default value
- pixel_coords = torch.nan_to_num(pixel_coords, nan=default)
-
- return pixel_coords.transpose(1, 2) # BxNx2
-
-
-
-
-def cam_from_img(pred_tracks, intrinsics, extra_params=None):
- """
- Normalize predicted tracks based on camera intrinsics.
- Args:
- intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3].
- pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2].
- extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
- Returns:
- torch.Tensor: Normalized tracks tensor.
- """
-
- # We don't want to do intrinsics_inv = torch.inverse(intrinsics) here
- # otherwise we can use something like
- # tracks_normalized_homo = torch.bmm(pred_tracks_homo, intrinsics_inv.transpose(1, 2))
-
- principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2)
- focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2)
- tracks_normalized = (pred_tracks - principal_point) / focal_length
-
- if extra_params is not None:
- # Apply iterative undistortion
- try:
- tracks_normalized = iterative_undistortion(
- extra_params, tracks_normalized
- )
- except:
- tracks_normalized = single_undistortion(
- extra_params, tracks_normalized
- )
-
- return tracks_normalized
\ No newline at end of file
diff --git a/FastVGGT/vggt/utils/helper.py b/FastVGGT/vggt/utils/helper.py
deleted file mode 100644
index 7b019189c85ff86645a4cf3756632aa8d4500649..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/utils/helper.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import numpy as np
-
-
-def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray:
- """
- If mask has more than max_trues True values,
- randomly keep only max_trues of them and set the rest to False.
- """
- # 1D positions of all True entries
- true_indices = np.flatnonzero(mask) # shape = (N_true,)
-
- # if already within budget, return as-is
- if true_indices.size <= max_trues:
- return mask
-
- # randomly pick which True positions to keep
- sampled_indices = np.random.choice(true_indices, size=max_trues, replace=False) # shape = (max_trues,)
-
- # build new flat mask: True only at sampled positions
- limited_flat_mask = np.zeros(mask.size, dtype=bool)
- limited_flat_mask[sampled_indices] = True
-
- # restore original shape
- return limited_flat_mask.reshape(mask.shape)
-
-
-def create_pixel_coordinate_grid(num_frames, height, width):
- """
- Creates a grid of pixel coordinates and frame indices for all frames.
- Returns:
- tuple: A tuple containing:
- - points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3)
- with x, y coordinates and frame indices
- - y_coords (numpy.ndarray): Array of y coordinates for all frames
- - x_coords (numpy.ndarray): Array of x coordinates for all frames
- - f_coords (numpy.ndarray): Array of frame indices for all frames
- """
- # Create coordinate grids for a single frame
- y_grid, x_grid = np.indices((height, width), dtype=np.float32)
- x_grid = x_grid[np.newaxis, :, :]
- y_grid = y_grid[np.newaxis, :, :]
-
- # Broadcast to all frames
- x_coords = np.broadcast_to(x_grid, (num_frames, height, width))
- y_coords = np.broadcast_to(y_grid, (num_frames, height, width))
-
- # Create frame indices and broadcast
- f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis]
- f_coords = np.broadcast_to(f_idx, (num_frames, height, width))
-
- # Stack coordinates and frame indices
- points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)
-
- return points_xyf
diff --git a/FastVGGT/vggt/utils/load_fn.py b/FastVGGT/vggt/utils/load_fn.py
deleted file mode 100644
index 3ae01ba2601f7b5039220bbae94725646dd66a41..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/utils/load_fn.py
+++ /dev/null
@@ -1,242 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-from PIL import Image
-from torchvision import transforms as TF
-import numpy as np
-
-
-def load_and_preprocess_images_square(image_path_list, target_size=1024):
- """
- Load and preprocess images by center padding to square and resizing to target size.
- Also returns the position information of original pixels after transformation.
-
- Args:
- image_path_list (list): List of paths to image files
- target_size (int, optional): Target size for both width and height. Defaults to 518.
-
- Returns:
- tuple: (
- torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
- torch.Tensor: Array of shape (N, 5) containing [x1, y1, x2, y2, width, height] for each image
- )
-
- Raises:
- ValueError: If the input list is empty
- """
- # Check for empty list
- if len(image_path_list) == 0:
- raise ValueError("At least 1 image is required")
-
- images = []
- original_coords = [] # Renamed from position_info to be more descriptive
- to_tensor = TF.ToTensor()
-
- for image_path in image_path_list:
- # Open image
- img = Image.open(image_path)
-
- # If there's an alpha channel, blend onto white background
- if img.mode == "RGBA":
- background = Image.new("RGBA", img.size, (255, 255, 255, 255))
- img = Image.alpha_composite(background, img)
-
- # Convert to RGB
- img = img.convert("RGB")
-
- # Get original dimensions
- width, height = img.size
-
- # Make the image square by padding the shorter dimension
- max_dim = max(width, height)
-
- # Calculate padding
- left = (max_dim - width) // 2
- top = (max_dim - height) // 2
-
- # Calculate scale factor for resizing
- scale = target_size / max_dim
-
- # Calculate final coordinates of original image in target space
- x1 = left * scale
- y1 = top * scale
- x2 = (left + width) * scale
- y2 = (top + height) * scale
-
- # Store original image coordinates and scale
- original_coords.append(np.array([x1, y1, x2, y2, width, height]))
-
- # Create a new black square image and paste original
- square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
- square_img.paste(img, (left, top))
-
- # Resize to target size
- square_img = square_img.resize(
- (target_size, target_size), Image.Resampling.BICUBIC
- )
-
- # Convert to tensor
- img_tensor = to_tensor(square_img)
- images.append(img_tensor)
-
- # Stack all images
- images = torch.stack(images)
- original_coords = torch.from_numpy(np.array(original_coords)).float()
-
- # Add additional dimension if single image to ensure correct shape
- if len(image_path_list) == 1:
- if images.dim() == 3:
- images = images.unsqueeze(0)
- original_coords = original_coords.unsqueeze(0)
-
- return images, original_coords
-
-
-def load_and_preprocess_images(image_path_list, mode="crop"):
- """
- A quick start function to load and preprocess images for model input.
- This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
-
- Args:
- image_path_list (list): List of paths to image files
- mode (str, optional): Preprocessing mode, either "crop" or "pad".
- - "crop" (default): Sets width to 518px and center crops height if needed.
- - "pad": Preserves all pixels by making the largest dimension 518px
- and padding the smaller dimension to reach a square shape.
-
- Returns:
- torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
-
- Raises:
- ValueError: If the input list is empty or if mode is invalid
-
- Notes:
- - Images with different dimensions will be padded with white (value=1.0)
- - A warning is printed when images have different shapes
- - When mode="crop": The function ensures width=518px while maintaining aspect ratio
- and height is center-cropped if larger than 518px
- - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
- and the smaller dimension is padded to reach a square shape (518x518)
- - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
- """
- # Check for empty list
- if len(image_path_list) == 0:
- raise ValueError("At least 1 image is required")
-
- # Validate mode
- if mode not in ["crop", "pad"]:
- raise ValueError("Mode must be either 'crop' or 'pad'")
-
- images = []
- shapes = set()
- to_tensor = TF.ToTensor()
- target_size = 518
-
- # First process all images and collect their shapes
- for image_path in image_path_list:
- # Open image
- img = Image.open(image_path)
-
- # If there's an alpha channel, blend onto white background:
- if img.mode == "RGBA":
- # Create white background
- background = Image.new("RGBA", img.size, (255, 255, 255, 255))
- # Alpha composite onto the white background
- img = Image.alpha_composite(background, img)
-
- # Now convert to "RGB" (this step assigns white for transparent areas)
- img = img.convert("RGB")
-
- width, height = img.size
-
- if mode == "pad":
- # Make the largest dimension 518px while maintaining aspect ratio
- if width >= height:
- new_width = target_size
- new_height = (
- round(height * (new_width / width) / 14) * 14
- ) # Make divisible by 14
- else:
- new_height = target_size
- new_width = (
- round(width * (new_height / height) / 14) * 14
- ) # Make divisible by 14
- else: # mode == "crop"
- # Original behavior: set width to 518px
- new_width = target_size
- # Calculate height maintaining aspect ratio, divisible by 14
- new_height = round(height * (new_width / width) / 14) * 14
-
- # Resize with new dimensions (width, height)
- img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
- img = to_tensor(img) # Convert to tensor (0, 1)
-
- # Center crop height if it's larger than 518 (only in crop mode)
- if mode == "crop" and new_height > target_size:
- start_y = (new_height - target_size) // 2
- img = img[:, start_y : start_y + target_size, :]
-
- # For pad mode, pad to make a square of target_size x target_size
- if mode == "pad":
- h_padding = target_size - img.shape[1]
- w_padding = target_size - img.shape[2]
-
- if h_padding > 0 or w_padding > 0:
- pad_top = h_padding // 2
- pad_bottom = h_padding - pad_top
- pad_left = w_padding // 2
- pad_right = w_padding - pad_left
-
- # Pad with white (value=1.0)
- img = torch.nn.functional.pad(
- img,
- (pad_left, pad_right, pad_top, pad_bottom),
- mode="constant",
- value=1.0,
- )
-
- shapes.add((img.shape[1], img.shape[2]))
- images.append(img)
-
- # Check if we have different shapes
- # In theory our model can also work well with different shapes
- if len(shapes) > 1:
- print(f"Warning: Found images with different shapes: {shapes}")
- # Find maximum dimensions
- max_height = max(shape[0] for shape in shapes)
- max_width = max(shape[1] for shape in shapes)
-
- # Pad images if necessary
- padded_images = []
- for img in images:
- h_padding = max_height - img.shape[1]
- w_padding = max_width - img.shape[2]
-
- if h_padding > 0 or w_padding > 0:
- pad_top = h_padding // 2
- pad_bottom = h_padding - pad_top
- pad_left = w_padding // 2
- pad_right = w_padding - pad_left
-
- img = torch.nn.functional.pad(
- img,
- (pad_left, pad_right, pad_top, pad_bottom),
- mode="constant",
- value=1.0,
- )
- padded_images.append(img)
- images = padded_images
-
- images = torch.stack(images) # concatenate images
-
- # Ensure correct shape when single image
- if len(image_path_list) == 1:
- # Verify shape is (1, C, H, W)
- if images.dim() == 3:
- images = images.unsqueeze(0)
-
- return images
diff --git a/FastVGGT/vggt/utils/pose_enc.py b/FastVGGT/vggt/utils/pose_enc.py
deleted file mode 100644
index 9d3b964330af0e62f4d36d332317ae00cb99b3a9..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/utils/pose_enc.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-from .rotation import quat_to_mat, mat_to_quat
-
-
-def extri_intri_to_pose_encoding(
- extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV" # e.g., (256, 512)
-):
- """Convert camera extrinsics and intrinsics to a compact pose encoding.
-
- This function transforms camera parameters into a unified pose encoding format,
- which can be used for various downstream tasks like pose prediction or representation.
-
- Args:
- extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
- where B is batch size and S is sequence length.
- In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
- The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
- intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
- Defined in pixels, with format:
- [[fx, 0, cx],
- [0, fy, cy],
- [0, 0, 1]]
- where fx, fy are focal lengths and (cx, cy) is the principal point
- image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
- Required for computing field of view values. For example: (256, 512).
- pose_encoding_type (str): Type of pose encoding to use. Currently only
- supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
-
- Returns:
- torch.Tensor: Encoded camera pose parameters with shape BxSx9.
- For "absT_quaR_FoV" type, the 9 dimensions are:
- - [:3] = absolute translation vector T (3D)
- - [3:7] = rotation as quaternion quat (4D)
- - [7:] = field of view (2D)
- """
-
- # extrinsics: BxSx3x4
- # intrinsics: BxSx3x3
-
- if pose_encoding_type == "absT_quaR_FoV":
- R = extrinsics[:, :, :3, :3] # BxSx3x3
- T = extrinsics[:, :, :3, 3] # BxSx3
-
- quat = mat_to_quat(R)
- # Note the order of h and w here
- H, W = image_size_hw
- fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
- fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
- pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
- else:
- raise NotImplementedError
-
- return pose_encoding
-
-
-def pose_encoding_to_extri_intri(
- pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512)
-):
- """Convert a pose encoding back to camera extrinsics and intrinsics.
-
- This function performs the inverse operation of extri_intri_to_pose_encoding,
- reconstructing the full camera parameters from the compact encoding.
-
- Args:
- pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
- where B is batch size and S is sequence length.
- For "absT_quaR_FoV" type, the 9 dimensions are:
- - [:3] = absolute translation vector T (3D)
- - [3:7] = rotation as quaternion quat (4D)
- - [7:] = field of view (2D)
- image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
- Required for reconstructing intrinsics from field of view values.
- For example: (256, 512).
- pose_encoding_type (str): Type of pose encoding used. Currently only
- supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
- build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
- If False, only extrinsics are returned and intrinsics will be None.
-
- Returns:
- tuple: (extrinsics, intrinsics)
- - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
- In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
- transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
- a 3x1 translation vector.
- - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
- or None if build_intrinsics is False. Defined in pixels, with format:
- [[fx, 0, cx],
- [0, fy, cy],
- [0, 0, 1]]
- where fx, fy are focal lengths and (cx, cy) is the principal point,
- assumed to be at the center of the image (W/2, H/2).
- """
-
- intrinsics = None
-
- if pose_encoding_type == "absT_quaR_FoV":
- T = pose_encoding[..., :3]
- quat = pose_encoding[..., 3:7]
- fov_h = pose_encoding[..., 7]
- fov_w = pose_encoding[..., 8]
-
- R = quat_to_mat(quat)
- extrinsics = torch.cat([R, T[..., None]], dim=-1)
-
- if build_intrinsics:
- H, W = image_size_hw
- fy = (H / 2.0) / torch.tan(fov_h / 2.0)
- fx = (W / 2.0) / torch.tan(fov_w / 2.0)
- intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
- intrinsics[..., 0, 0] = fx
- intrinsics[..., 1, 1] = fy
- intrinsics[..., 0, 2] = W / 2
- intrinsics[..., 1, 2] = H / 2
- intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
- else:
- raise NotImplementedError
-
- return extrinsics, intrinsics
diff --git a/FastVGGT/vggt/utils/rotation.py b/FastVGGT/vggt/utils/rotation.py
deleted file mode 100644
index f972afd8414c82fa1e9ed231725fd3f9f6ebde77..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/utils/rotation.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
-
-import torch
-import numpy as np
-import torch.nn.functional as F
-
-
-def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
- """
- Quaternion Order: XYZW or say ijkr, scalar-last
-
- Convert rotations given as quaternions to rotation matrices.
- Args:
- quaternions: quaternions with real part last,
- as tensor of shape (..., 4).
-
- Returns:
- Rotation matrices as tensor of shape (..., 3, 3).
- """
- i, j, k, r = torch.unbind(quaternions, -1)
- # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
- two_s = 2.0 / (quaternions * quaternions).sum(-1)
-
- o = torch.stack(
- (
- 1 - two_s * (j * j + k * k),
- two_s * (i * j - k * r),
- two_s * (i * k + j * r),
- two_s * (i * j + k * r),
- 1 - two_s * (i * i + k * k),
- two_s * (j * k - i * r),
- two_s * (i * k - j * r),
- two_s * (j * k + i * r),
- 1 - two_s * (i * i + j * j),
- ),
- -1,
- )
- return o.reshape(quaternions.shape[:-1] + (3, 3))
-
-
-def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
- """
- Convert rotations given as rotation matrices to quaternions.
-
- Args:
- matrix: Rotation matrices as tensor of shape (..., 3, 3).
-
- Returns:
- quaternions with real part last, as tensor of shape (..., 4).
- Quaternion Order: XYZW or say ijkr, scalar-last
- """
- if matrix.size(-1) != 3 or matrix.size(-2) != 3:
- raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
-
- batch_dim = matrix.shape[:-2]
- m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
-
- q_abs = _sqrt_positive_part(
- torch.stack(
- [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
- )
- )
-
- # we produce the desired quaternion multiplied by each of r, i, j, k
- quat_by_rijk = torch.stack(
- [
- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
- # `int`.
- torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
- # `int`.
- torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
- # `int`.
- torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
- # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
- # `int`.
- torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
- ],
- dim=-2,
- )
-
- # We floor here at 0.1 but the exact level is not important; if q_abs is small,
- # the candidate won't be picked.
- flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
- quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
-
- # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
- # forall i; we pick the best-conditioned one (with the largest denominator)
- out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
-
- # Convert from rijk to ijkr
- out = out[..., [1, 2, 3, 0]]
-
- out = standardize_quaternion(out)
-
- return out
-
-
-def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
- """
- Returns torch.sqrt(torch.max(0, x))
- but with a zero subgradient where x is 0.
- """
- ret = torch.zeros_like(x)
- positive_mask = x > 0
- if torch.is_grad_enabled():
- ret[positive_mask] = torch.sqrt(x[positive_mask])
- else:
- ret = torch.where(positive_mask, torch.sqrt(x), ret)
- return ret
-
-
-def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
- """
- Convert a unit quaternion to a standard form: one in which the real
- part is non negative.
-
- Args:
- quaternions: Quaternions with real part last,
- as tensor of shape (..., 4).
-
- Returns:
- Standardized quaternions as tensor of shape (..., 4).
- """
- return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
diff --git a/FastVGGT/vggt/utils/visual_track.py b/FastVGGT/vggt/utils/visual_track.py
deleted file mode 100644
index 796c114ccba00b5f7850e04b9444a6cd5c44b154..0000000000000000000000000000000000000000
--- a/FastVGGT/vggt/utils/visual_track.py
+++ /dev/null
@@ -1,239 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import cv2
-import torch
-import numpy as np
-import os
-
-
-def color_from_xy(x, y, W, H, cmap_name="hsv"):
- """
- Map (x, y) -> color in (R, G, B).
- 1) Normalize x,y to [0,1].
- 2) Combine them into a single scalar c in [0,1].
- 3) Use matplotlib's colormap to convert c -> (R,G,B).
-
- You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
- """
- import matplotlib.cm
- import matplotlib.colors
-
- x_norm = x / max(W - 1, 1)
- y_norm = y / max(H - 1, 1)
- # Simple combination:
- c = (x_norm + y_norm) / 2.0
-
- cmap = matplotlib.cm.get_cmap(cmap_name)
- # cmap(c) -> (r,g,b,a) in [0,1]
- rgba = cmap(c)
- r, g, b = rgba[0], rgba[1], rgba[2]
- return (r, g, b) # in [0,1], RGB order
-
-
-def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
- """
- Given all tracks in one sample (b), compute a (N,3) array of RGB color values
- in [0,255]. The color is determined by the (x,y) position in the first
- visible frame for each track.
-
- Args:
- tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
- vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
- image_width, image_height: used for normalizing (x, y).
- cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
-
- Returns:
- track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
- """
- S, N, _ = tracks_b.shape
- track_colors = np.zeros((N, 3), dtype=np.uint8)
-
- if vis_mask_b is None:
- # treat all as visible
- vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
-
- for i in range(N):
- # Find first visible frame for track i
- visible_frames = torch.where(vis_mask_b[:, i])[0]
- if len(visible_frames) == 0:
- # track is never visible; just assign black or something
- track_colors[i] = (0, 0, 0)
- continue
-
- first_s = int(visible_frames[0].item())
- # use that frame's (x,y)
- x, y = tracks_b[first_s, i].tolist()
-
- # map (x,y) -> (R,G,B) in [0,1]
- r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
- # scale to [0,255]
- r, g, b = int(r * 255), int(g * 255), int(b * 255)
- track_colors[i] = (r, g, b)
-
- return track_colors
-
-
-def visualize_tracks_on_images(
- images,
- tracks,
- track_vis_mask=None,
- out_dir="track_visuals_concat_by_xy",
- image_format="CHW", # "CHW" or "HWC"
- normalize_mode="[0,1]",
- cmap_name="hsv", # e.g. "hsv", "rainbow", "jet"
- frames_per_row=4, # New parameter for grid layout
- save_grid=True, # Flag to control whether to save the grid image
-):
- """
- Visualizes frames in a grid layout with specified frames per row.
- Each track's color is determined by its (x,y) position
- in the first visible frame (or frame 0 if always visible).
- Finally convert the BGR result to RGB before saving.
- Also saves each individual frame as a separate PNG file.
-
- Args:
- images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
- tracks: torch.Tensor (S, N, 2), last dim = (x, y).
- track_vis_mask: torch.Tensor (S, N) or None.
- out_dir: folder to save visualizations.
- image_format: "CHW" or "HWC".
- normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
- cmap_name: a matplotlib colormap name for color_from_xy.
- frames_per_row: number of frames to display in each row of the grid.
- save_grid: whether to save all frames in one grid image.
-
- Returns:
- None (saves images in out_dir).
- """
-
- if len(tracks.shape) == 4:
- tracks = tracks.squeeze(0)
- images = images.squeeze(0)
- if track_vis_mask is not None:
- track_vis_mask = track_vis_mask.squeeze(0)
-
- import matplotlib
-
- matplotlib.use("Agg") # for non-interactive (optional)
-
- os.makedirs(out_dir, exist_ok=True)
-
- S = images.shape[0]
- _, N, _ = tracks.shape # (S, N, 2)
-
- # Move to CPU
- images = images.cpu().clone()
- tracks = tracks.cpu().clone()
- if track_vis_mask is not None:
- track_vis_mask = track_vis_mask.cpu().clone()
-
- # Infer H, W from images shape
- if image_format == "CHW":
- # e.g. images[s].shape = (3, H, W)
- H, W = images.shape[2], images.shape[3]
- else:
- # e.g. images[s].shape = (H, W, 3)
- H, W = images.shape[1], images.shape[2]
-
- # Pre-compute the color for each track i based on first visible position
- track_colors_rgb = get_track_colors_by_position(
- tracks, # shape (S, N, 2)
- vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
- image_width=W,
- image_height=H,
- cmap_name=cmap_name,
- )
-
- # We'll accumulate each frame's drawn image in a list
- frame_images = []
-
- for s in range(S):
- # shape => either (3, H, W) or (H, W, 3)
- img = images[s]
-
- # Convert to (H, W, 3)
- if image_format == "CHW":
- img = img.permute(1, 2, 0) # (H, W, 3)
- # else "HWC", do nothing
-
- img = img.numpy().astype(np.float32)
-
- # Scale to [0,255] if needed
- if normalize_mode == "[0,1]":
- img = np.clip(img, 0, 1) * 255.0
- elif normalize_mode == "[-1,1]":
- img = (img + 1.0) * 0.5 * 255.0
- img = np.clip(img, 0, 255.0)
- # else no normalization
-
- # Convert to uint8
- img = img.astype(np.uint8)
-
- # For drawing in OpenCV, convert to BGR
- img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-
- # Draw each visible track
- cur_tracks = tracks[s] # shape (N, 2)
- if track_vis_mask is not None:
- valid_indices = torch.where(track_vis_mask[s])[0]
- else:
- valid_indices = range(N)
-
- cur_tracks_np = cur_tracks.numpy()
- for i in valid_indices:
- x, y = cur_tracks_np[i]
- pt = (int(round(x)), int(round(y)))
-
- # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
- R, G, B = track_colors_rgb[i]
- color_bgr = (int(B), int(G), int(R))
- cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
-
- # Convert back to RGB for consistent final saving:
- img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
-
- # Save individual frame
- frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
- # Convert to BGR for OpenCV imwrite
- frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
- cv2.imwrite(frame_path, frame_bgr)
-
- frame_images.append(img_rgb)
-
- # Only create and save the grid image if save_grid is True
- if save_grid:
- # Calculate grid dimensions
- num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division
-
- # Create a grid of images
- grid_img = None
- for row in range(num_rows):
- start_idx = row * frames_per_row
- end_idx = min(start_idx + frames_per_row, S)
-
- # Concatenate this row horizontally
- row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
-
- # If this row has fewer than frames_per_row images, pad with black
- if end_idx - start_idx < frames_per_row:
- padding_width = (frames_per_row - (end_idx - start_idx)) * W
- padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
- row_img = np.concatenate([row_img, padding], axis=1)
-
- # Add this row to the grid
- if grid_img is None:
- grid_img = row_img
- else:
- grid_img = np.concatenate([grid_img, row_img], axis=0)
-
- out_path = os.path.join(out_dir, "tracks_grid.png")
- # Convert back to BGR for OpenCV imwrite
- grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
- cv2.imwrite(out_path, grid_img_bgr)
- print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
-
- print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")