diff --git a/FastVGGT/.gitignore b/FastVGGT/.gitignore deleted file mode 100644 index fc1e52d1adf9afe09ea2ead132f8fa4babf666cc..0000000000000000000000000000000000000000 --- a/FastVGGT/.gitignore +++ /dev/null @@ -1,160 +0,0 @@ -.hydra/ -output/ -ckpt/ -.vscode/ -dependency/ -# Byte-compiled / optimized / DLL files -__pycache__/ -**/__pycache__/ -*.py[cod] -*$py.class -test_logs/ -quick_start_logs/ -logs/ -*.pth -/data/ -*.png -eval_results/ -.vscode/ -.curosr/ - -# C extensions -*.so -LightGlue/ -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Profiling data -.prof - -# Folder specific to your needs -**/tmp/ -**/outputs/skyseg.onnx -skyseg.onnx - -# pixi environments -.pixi -*.egg-info diff --git a/FastVGGT/.vscode/launch.json b/FastVGGT/.vscode/launch.json deleted file mode 100644 index 218c5dcf94c00a2c11d02a0a1deeab8f37807ebf..0000000000000000000000000000000000000000 --- a/FastVGGT/.vscode/launch.json +++ /dev/null @@ -1,85 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - - { - "name": "launch", - "type": "debugpy", - "request": "launch", - "program": "/home/sy/code/vggt_0625/training/launch.py", - "console": "integratedTerminal", - "args": "${command:pickArgs}", - "env": { - "CUDA_VISIBLE_DEVICES": "3", - }, - "cwd": "/home/sy/code/vggt_0625/training", - "justMyCode": true, - "python": "/home/sy/anaconda3/envs/vggt/bin/python" - } - ,{ - "name": "train_scannet", - "type": "debugpy", - "request": "launch", - "program": "/home/sy/code/vggt_0625/training/launch_scannet.py", - "console": "integratedTerminal", - "args": [ - // "--config_name", "scannet", - // "--exp_name", "scannet_exp001", - // "--resume_checkpoint_path", "/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt" - ], - "env": { - "CUDA_VISIBLE_DEVICES": "7", - "WORLD_SIZE": "1", - "RANK": "0", - "MASTER_ADDR": "localhost", - "MASTER_PORT": "12345" - }, - "cwd": "/home/sy/code/vggt_0625/training", - "justMyCode": true, - "python": "/home/sy/anaconda3/envs/vggt/bin/python" - } - ,{ - "name": "eval_scannet", - "type": "debugpy", - "request": "launch", - "program": "/home/sy/code/FastVGGT/eval/eval_scannet.py", - "console": "integratedTerminal", - "args": [ - "--data_dir","/data/sy/scannetv2/process_scannet", - "--gt_ply_dir","/data/sy/scannetv2/OpenDataLab___ScanNet_v2/raw/scans", - "--output_path", "/home/sy/code/FastVGGT/eval_results", - "--merging", "0", - "--ckpt_path","/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt", - "--vis_attn_map" - ], - "env": { - "CUDA_VISIBLE_DEVICES": "2" - }, - "justMyCode": true, - "python": "/home/sy/anaconda3/envs/fastvggt/bin/python" - }, - { - "name": "eval_cd", - "type": "debugpy", - "request": "launch", - "program": "/home/sy/code/FastVGGT/eval/eval_custom.py", - "console": "integratedTerminal", - "args": [ - "--merging", "0", - // "--kf","10", - // "--output_dir","/home/sy/code/vggt_0625/eval_results_cd", - "--data_path","/data/sy/segment-102751/", - "--vis_attn_map" - ], - "env": { - "CUDA_VISIBLE_DEVICES": "3" - }, - "justMyCode": true, - // "python": "/home/sy/anaconda3/envs/fastvggt/bin/python" - } - - ] -} \ No newline at end of file diff --git a/FastVGGT/README.md b/FastVGGT/README.md deleted file mode 100644 index 274ea91f0e398fdf0e94b921fb5bb7c48fe799fd..0000000000000000000000000000000000000000 --- a/FastVGGT/README.md +++ /dev/null @@ -1,163 +0,0 @@ -
-

⚡️ FastVGGT: Training-Free Acceleration of Visual Geometry Transformer

- -

- Paper PDF - Project Page -

- -Maclab Logo -Autolab Logo - - -**[Media Analytics & Computing Laboratory](https://mac.xmu.edu.cn/)**; **[AUTOLAB](https://zhipengzhang.cn/)** - - -[You Shen](https://mystorm16.github.io/), [Zhipeng Zhang](https://zhipengzhang.cn/), [Yansong Qu](https://quyans.github.io/), [Liujuan Cao](https://mac.xmu.edu.cn/ljcao/) -
- - -## 📰 News -- [Sep 8, 2025] Added custom dataset evaluation. -- [Sep 3, 2025] Paper release. -- [Sep 2, 2025] Code release. - -## 🔭 Overview - -FastVGGT observes **strong similarity** in attention maps and leverages it to design a training-free acceleration method for long-sequence 3D reconstruction, **achieving up to 4× faster inference without sacrificing accuracy.** - -Autolab Logo - - -## ⚙️ Environment Setup -First, create a virtual environment using Conda, clone this repository to your local machine, and install the required dependencies. - - -```bash -conda create -n fastvggt python=3.10 -conda activate fastvggt -git clone git@github.com:mystorm16/FastVGGT.git -cd FastVGGT -pip install -r requirements.txt -``` - -Next, prepare the ScanNet dataset: http://www.scan-net.org/ScanNet/ - -Then, download the VGGT checkpoint (we use the checkpoint link provided in https://github.com/facebookresearch/vggt/tree/evaluation/evaluation): -```bash -wget https://huggingface.co/facebook/VGGT_tracker_fixed/resolve/main/model_tracker_fixed_e20.pt -``` - -Finally, configure the dataset path and VGGT checkpoint path. For example: -```bash - parser.add_argument( - "--data_dir", type=Path, default="/data/scannetv2/process_scannet" - ) - parser.add_argument( - "--gt_ply_dir", - type=Path, - default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans", - ) - parser.add_argument( - "--ckpt_path", - type=str, - default="./ckpt/model_tracker_fixed_e20.pt", - ) -``` - - -## 💎 Observation - -Note: A large number of input_frames may significantly slow down saving the visualization results. Please try using a smaller number first. -```bash -python eval/eval_scannet.py --input_frame 30 --vis_attn_map --merging 0 -``` - -We observe that many token-level attention maps are highly similar in each block, motivating our optimization of the Global Attention module. - -Autolab Logo - - - -## 🏀 Evaluation -### Custom Dataset -Please organize the data according to the following directory: -``` -/ -├── images/ -│ ├── 000000.jpg -│ ├── 000001.jpg -│ └── ... -├── pose/ # Optional: Camera poses -│ ├── 000000.txt -│ ├── 000001.txt -│ └── ... -└── gt_ply/ # Optional: GT point cloud - └── scene_xxx.ply -``` -- Required: `images/` -- Additionally required when `--enable_evaluation` is enabled: `pose/` and `gt_ply/` - -Inference only: - -```bash -python eval/eval_custom.py \ - --data_path /path/to/your_dataset \ - --output_path ./eval_results_custom \ - --plot -``` - -Inference + Evaluation (requires `pose/` and `gt_ply/`): - -```bash -python eval/eval_custom.py \ - --data_path /path/to/your_dataset \ - --enable_evaluation \ - --output_path ./eval_results_custom \ - --plot -``` - -### ScanNet -Evaluate FastVGGT on the ScanNet dataset with 1,000 input images. The **--merging** parameter specifies the block index at which the merging strategy is applied: - -```bash -python eval/eval_scannet.py --input_frame 1000 --merging 0 -``` - -Evaluate Baseline VGGT on the ScanNet dataset with 1,000 input images: -```bash -python eval/eval_scannet.py --input_frame 1000 -``` -Autolab Logo - -### 7 Scenes & NRGBD -Evaluate across two datasets, sampling keyframes every 10 frames: -```bash -python eval/eval_7andN.py --kf 10 -``` - -## 🍺 Acknowledgements - -- Thanks to these great repositories: [VGGT](https://github.com/facebookresearch/vggt), [Dust3r](https://github.com/naver/dust3r), [Fast3R](https://github.com/facebookresearch/fast3r), [CUT3R](https://github.com/CUT3R/CUT3R), [MV-DUSt3R+](https://github.com/facebookresearch/mvdust3r), [StreamVGGT](https://github.com/wzzheng/StreamVGGT), [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long), [ToMeSD](https://github.com/dbolya/tomesd) and many other inspiring works in the community. - -- Special thanks to [Jianyuan Wang](https://jytime.github.io/) for his valuable discussions and suggestions on this work. - - - - -## ⚖️ License -See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available. - -## Citation - -If you find this project helpful, please consider citing the following paper: -``` -@article{shen2025fastvggt, - title={FastVGGT: Training-Free Acceleration of Visual Geometry Transformer}, - author={Shen, You and Zhang, Zhipeng and Qu, Yansong and Cao, Liujuan}, - journal={arXiv preprint arXiv:2509.02560}, - year={2025} -} -``` \ No newline at end of file diff --git a/FastVGGT/assets/attn_map.png b/FastVGGT/assets/attn_map.png deleted file mode 100644 index 98d6f2be8cdb49e5bccee7f902af292327fd0023..0000000000000000000000000000000000000000 --- a/FastVGGT/assets/attn_map.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8477957f593c203bcf41df91ac3ed0d22329e22250fdab9f8f8674340964242c -size 2369408 diff --git a/FastVGGT/assets/autolab_logo.png b/FastVGGT/assets/autolab_logo.png deleted file mode 100644 index c6ca9a41f5c05a1b2ae66368461672f46e20bba1..0000000000000000000000000000000000000000 --- a/FastVGGT/assets/autolab_logo.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4fcead3160cbf561c4385cc8b938a17a94652e3d849da6497f053d32d1245596 -size 5126003 diff --git a/FastVGGT/assets/maclab_logo.png b/FastVGGT/assets/maclab_logo.png deleted file mode 100644 index b96660decf945d3d786f41b9b42cb92798f9191f..0000000000000000000000000000000000000000 Binary files a/FastVGGT/assets/maclab_logo.png and /dev/null differ diff --git a/FastVGGT/assets/main.png b/FastVGGT/assets/main.png deleted file mode 100644 index 3bcab2d519689dfba3283e606cd55cb84175ede5..0000000000000000000000000000000000000000 --- a/FastVGGT/assets/main.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:eecacb414647f01dc8a52b4aba5ff2556733f46d1b9129613e3f59aceff69685 -size 883717 diff --git a/FastVGGT/assets/vs.png b/FastVGGT/assets/vs.png deleted file mode 100644 index 3729d9ba86000b81d0e9c878b2db6f16c7f9c48a..0000000000000000000000000000000000000000 --- a/FastVGGT/assets/vs.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:dac1ea5397b9f985c6fb86fc5f321313cc5a372d0917b7c1c1c7dd2cea6bec5f -size 230157 diff --git a/FastVGGT/eval/__pycache__/base.cpython-310.pyc b/FastVGGT/eval/__pycache__/base.cpython-310.pyc deleted file mode 100644 index 49170288c162684ebafa68e5242efd320f0f35b8..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/__pycache__/base.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc b/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc deleted file mode 100644 index 6e6e0478775fe98dc22180bc43bad8073c43c907..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/eval/__pycache__/data.cpython-310.pyc b/FastVGGT/eval/__pycache__/data.cpython-310.pyc deleted file mode 100644 index 4d8ec6f8e1828fa87b02a4511490ad858de2093a..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/__pycache__/data.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/eval/__pycache__/data.cpython-37.pyc b/FastVGGT/eval/__pycache__/data.cpython-37.pyc deleted file mode 100644 index d76f2c95748956a04f009e99f39cefd9cc6ec4e0..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/__pycache__/data.cpython-37.pyc and /dev/null differ diff --git a/FastVGGT/eval/__pycache__/utils.cpython-310.pyc b/FastVGGT/eval/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index 638db49f10789e53dcf0ff33c9aaa28b814e3e72..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/eval/__pycache__/utils.cpython-37.pyc b/FastVGGT/eval/__pycache__/utils.cpython-37.pyc deleted file mode 100644 index fd73966947cb74c319d66d7bf221d73b5dde0ce0..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/__pycache__/utils.cpython-37.pyc and /dev/null differ diff --git a/FastVGGT/eval/base.py b/FastVGGT/eval/base.py deleted file mode 100644 index 4a716449a71f552ea408df0eb37e854cf4e92da6..0000000000000000000000000000000000000000 --- a/FastVGGT/eval/base.py +++ /dev/null @@ -1,273 +0,0 @@ -# Copyright (C) 2024-present Naver Corporation. All rights reserved. -# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). -# -# -------------------------------------------------------- -# base class for implementing datasets -# -------------------------------------------------------- -import PIL -import numpy as np -import torch - -from dataset_utils.transforms import ImgNorm -import dataset_utils.cropping as cropping -from utils import depthmap_to_absolute_camera_coordinates - - -class BaseStereoViewDataset: - """Define all basic options. - - Usage: - class MyDataset (BaseStereoViewDataset): - def _get_views(self, idx, rng): - # overload here - views = [] - views.append(dict(img=, ...)) - return views - """ - - def __init__( - self, - *, # only keyword arguments - split=None, - resolution=None, # square_size or (width, height) or list of [(width,height), ...] - transform=ImgNorm, - aug_crop=False, - seed=None, - ): - self.num_views = 2 - self.split = split - self._set_resolutions(resolution) - - self.transform = transform - if isinstance(transform, str): - transform = eval(transform) - - self.aug_crop = aug_crop - self.seed = seed - - def __len__(self): - return len(self.scenes) - - def get_stats(self): - return f"{len(self)} pairs" - - def __repr__(self): - resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]" - return ( - f"""{type(self).__name__}({self.get_stats()}, - {self.split=}, - {self.seed=}, - resolutions={resolutions_str}, - {self.transform=})""".replace( - "self.", "" - ) - .replace("\n", "") - .replace(" ", "") - ) - - def _get_views(self, idx, resolution, rng): - raise NotImplementedError() - - def __getitem__(self, idx): - if isinstance(idx, tuple): - # the idx is specifying the aspect-ratio - idx, ar_idx = idx - else: - assert len(self._resolutions) == 1 - ar_idx = 0 - - # set-up the rng - if self.seed: # reseed for each __getitem__ - self._rng = np.random.default_rng(seed=self.seed + idx) - elif not hasattr(self, "_rng"): - seed = torch.initial_seed() # this is different for each dataloader process - self._rng = np.random.default_rng(seed=seed) - - # over-loaded code - resolution = self._resolutions[ - ar_idx - ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) - views = self._get_views(idx, resolution, self._rng) - - # check data-types - for v, view in enumerate(views): - assert ( - "pts3d" not in view - ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" - view["idx"] = v - - # encode the image - width, height = view["img"].size - view["true_shape"] = np.int32((height, width)) - view["img"] = self.transform(view["img"]) - - assert "camera_intrinsics" in view - if "camera_pose" not in view: - view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32) - else: - assert np.isfinite( - view["camera_pose"] - ).all(), f"NaN in camera pose for view {view_name(view)}" - assert "pts3d" not in view - assert "valid_mask" not in view - assert np.isfinite( - view["depthmap"] - ).all(), f"NaN in depthmap for view {view_name(view)}" - pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view) - - view["pts3d"] = pts3d - view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1) - - # check all datatypes - for key, val in view.items(): - res, err_msg = is_good_type(key, val) - assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" - K = view["camera_intrinsics"] - view["img_mask"] = True - view["ray_mask"] = False - view["ray_map"] = torch.full( - (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan - ) - view["update"] = True - view["reset"] = False - - # last thing done! - for view in views: - # transpose to make sure all views are the same size - transpose_to_landscape(view) - # this allows to check whether the RNG is is the same state each time - view["rng"] = int.from_bytes(self._rng.bytes(4), "big") - return views - - def _set_resolutions(self, resolutions): - """Set the resolution(s) of the dataset. - Params: - - resolutions: int or tuple or list of tuples - """ - assert resolutions is not None, "undefined resolution" - - if not isinstance(resolutions, list): - resolutions = [resolutions] - - self._resolutions = [] - for resolution in resolutions: - if isinstance(resolution, int): - width = height = resolution - else: - width, height = resolution - assert isinstance( - width, int - ), f"Bad type for {width=} {type(width)=}, should be int" - assert isinstance( - height, int - ), f"Bad type for {height=} {type(height)=}, should be int" - assert width >= height - self._resolutions.append((width, height)) - - def _crop_resize_if_necessary( - self, image, depthmap, intrinsics, resolution, rng=None, info=None - ): - """This function: - - first downsizes the image with LANCZOS inteprolation, - which is better than bilinear interpolation in - """ - if not isinstance(image, PIL.Image.Image): - image = PIL.Image.fromarray(image) - - # downscale with lanczos interpolation so that image.size == resolution - # cropping centered on the principal point - W, H = image.size - cx, cy = intrinsics[:2, 2].round().astype(int) - - # calculate min distance to margin - min_margin_x = min(cx, W - cx) - min_margin_y = min(cy, H - cy) - assert min_margin_x > W / 5, f"Bad principal point in view={info}" - assert min_margin_y > H / 5, f"Bad principal point in view={info}" - - ## Center crop - # Crop on the principal point, make it always centered - # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) - l, t = cx - min_margin_x, cy - min_margin_y - r, b = cx + min_margin_x, cy + min_margin_y - crop_bbox = (l, t, r, b) - - image, depthmap, intrinsics = cropping.crop_image_depthmap( - image, depthmap, intrinsics, crop_bbox - ) - - # # transpose the resolution if necessary - W, H = image.size # new size - assert resolution[0] >= resolution[1] - if H > 1.1 * W: - # image is portrait mode - resolution = resolution[::-1] - elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]: - # image is square, so we chose (portrait, landscape) randomly - if rng.integers(2): - resolution = resolution[::-1] - - # high-quality Lanczos down-scaling - target_resolution = np.array(resolution) - # # if self.aug_crop > 1: - # # target_resolution += rng.integers(0, self.aug_crop) - # if resolution != (224, 224): - # halfw, halfh = ((2*(W//2))//16)*8, ((2*(H//2))//16)*8 - # ## Recale with max factor, so one of width or height might be larger than target_resolution - # image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, (2*halfw, 2*halfh)) - # else: - image, depthmap, intrinsics = cropping.rescale_image_depthmap( - image, depthmap, intrinsics, target_resolution - ) - # actual cropping (if necessary) with bilinear interpolation - # if resolution == (224, 224): - intrinsics2 = cropping.camera_matrix_of_crop( - intrinsics, image.size, resolution, offset_factor=0.5 - ) - crop_bbox = cropping.bbox_from_intrinsics_in_out( - intrinsics, intrinsics2, resolution - ) - image, depthmap, intrinsics = cropping.crop_image_depthmap( - image, depthmap, intrinsics, crop_bbox - ) - return image, depthmap, intrinsics - - -def is_good_type(key, v): - """returns (is_good, err_msg)""" - if isinstance(v, (str, int, tuple)): - return True, None - if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): - return False, f"bad {v.dtype=}" - return True, None - - -def view_name(view, batch_index=None): - def sel(x): - return x[batch_index] if batch_index not in (None, slice(None)) else x - - db = sel(view["dataset"]) - label = sel(view["label"]) - instance = sel(view["instance"]) - return f"{db}/{label}/{instance}" - - -def transpose_to_landscape(view): - height, width = view["true_shape"] - - if width < height: - # rectify portrait to landscape - assert view["img"].shape == (3, height, width) - view["img"] = view["img"].swapaxes(1, 2) - - assert view["valid_mask"].shape == (height, width) - view["valid_mask"] = view["valid_mask"].swapaxes(0, 1) - - assert view["depthmap"].shape == (height, width) - view["depthmap"] = view["depthmap"].swapaxes(0, 1) - - assert view["pts3d"].shape == (height, width, 3) - view["pts3d"] = view["pts3d"].swapaxes(0, 1) - - # transpose x and y pixels - view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]] diff --git a/FastVGGT/eval/criterion.py b/FastVGGT/eval/criterion.py deleted file mode 100644 index a63c991b5fd5b61325b56055d7d380a27a336971..0000000000000000000000000000000000000000 --- a/FastVGGT/eval/criterion.py +++ /dev/null @@ -1,534 +0,0 @@ -import torch -import torch.nn as nn -from copy import copy, deepcopy - -from eval.dataset_utils.corr import geotrf, inv - - -def invalid_to_nans(arr, valid_mask, ndim=999): - if valid_mask is not None: - arr = arr.clone() - arr[~valid_mask] = float("nan") - if arr.ndim > ndim: - arr = arr.flatten(-2 - (arr.ndim - ndim), -2) - return arr - - -def invalid_to_zeros(arr, valid_mask, ndim=999): - if valid_mask is not None: - arr = arr.clone() - arr[~valid_mask] = 0 - nnz = valid_mask.view(len(valid_mask), -1).sum(1) - else: - nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image - if arr.ndim > ndim: - arr = arr.flatten(-2 - (arr.ndim - ndim), -2) - return arr, nnz - - -class BaseCriterion(nn.Module): - def __init__(self, reduction="mean"): - super().__init__() - self.reduction = reduction - - -class Criterion(nn.Module): - def __init__(self, criterion=None): - super().__init__() - assert isinstance( - criterion, BaseCriterion - ), f"{criterion} is not a proper criterion!" - self.criterion = copy(criterion) - - def get_name(self): - return f"{type(self).__name__}({self.criterion})" - - def with_reduction(self, mode="none"): - res = loss = deepcopy(self) - while loss is not None: - assert isinstance(loss, Criterion) - loss.criterion.reduction = mode # make it return the loss for each sample - loss = loss._loss2 # we assume loss is a Multiloss - return res - - -class MultiLoss(nn.Module): - """Easily combinable losses (also keep track of individual loss values): - loss = MyLoss1() + 0.1*MyLoss2() - Usage: - Inherit from this class and override get_name() and compute_loss() - """ - - def __init__(self): - super().__init__() - self._alpha = 1 - self._loss2 = None - - def compute_loss(self, *args, **kwargs): - raise NotImplementedError() - - def get_name(self): - raise NotImplementedError() - - def __mul__(self, alpha): - assert isinstance(alpha, (int, float)) - res = copy(self) - res._alpha = alpha - return res - - __rmul__ = __mul__ # same - - def __add__(self, loss2): - assert isinstance(loss2, MultiLoss) - res = cur = copy(self) - - while cur._loss2 is not None: - cur = cur._loss2 - cur._loss2 = loss2 - return res - - def __repr__(self): - name = self.get_name() - if self._alpha != 1: - name = f"{self._alpha:g}*{name}" - if self._loss2: - name = f"{name} + {self._loss2}" - return name - - def forward(self, *args, **kwargs): - loss = self.compute_loss(*args, **kwargs) - if isinstance(loss, tuple): - loss, details = loss - elif loss.ndim == 0: - details = {self.get_name(): float(loss)} - else: - details = {} - loss = loss * self._alpha - - if self._loss2: - loss2, details2 = self._loss2(*args, **kwargs) - loss = loss + loss2 - details |= details2 - - return loss, details - - -class LLoss(BaseCriterion): - """L-norm loss""" - - def forward(self, a, b): - assert ( - a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3 - ), f"Bad shape = {a.shape}" - dist = self.distance(a, b) - - if self.reduction == "none": - return dist - if self.reduction == "sum": - return dist.sum() - if self.reduction == "mean": - return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) - raise ValueError(f"bad {self.reduction=} mode") - - def distance(self, a, b): - raise NotImplementedError() - - -class L21Loss(LLoss): - """Euclidean distance between 3d points""" - - def distance(self, a, b): - return torch.norm(a - b, dim=-1) # normalized L2 distance - - -L21 = L21Loss() - - -def get_pred_pts3d(gt, pred, use_pose=False): - assert use_pose is True - return pred["pts3d_in_other_view"] # return! - - -def Sum(losses, masks, conf=None): - loss, mask = losses[0], masks[0] - if loss.ndim > 0: - # we are actually returning the loss for every pixels - if conf is not None: - return losses, masks, conf - return losses, masks - else: - # we are returning the global loss - for loss2 in losses[1:]: - loss = loss + loss2 - return loss - - -def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True): - assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3 - assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3) - norm_mode, dis_mode = norm_mode.split("_") - - nan_pts = [] - nnzs = [] - - if norm_mode == "avg": - # gather all points together (joint normalization) - - for i, pt in enumerate(pts): - nan_pt, nnz = invalid_to_zeros(pt, valids[i], ndim=3) - nan_pts.append(nan_pt) - nnzs.append(nnz) - - if fix_first: - break - all_pts = torch.cat(nan_pts, dim=1) - - # compute distance to origin - all_dis = all_pts.norm(dim=-1) - if dis_mode == "dis": - pass # do nothing - elif dis_mode == "log1p": - all_dis = torch.log1p(all_dis) - else: - raise ValueError(f"bad {dis_mode=}") - - norm_factor = all_dis.sum(dim=1) / (torch.cat(nnzs).sum() + 1e-8) - else: - raise ValueError(f"Not implemented {norm_mode=}") - - norm_factor = norm_factor.clip(min=1e-8) - while norm_factor.ndim < pts[0].ndim: - norm_factor.unsqueeze_(-1) - - return norm_factor - - -def normalize_pointcloud_t( - pts, norm_mode="avg_dis", valids=None, fix_first=True, gt=False -): - if gt: - norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first) - res = [] - - for i, pt in enumerate(pts): - res.append(pt / norm_factor) - - else: - # pts_l, pts_r = pts - # use pts_l and pts_r[-1] as pts to normalize - norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first) - - res = [] - - for i in range(len(pts)): - res.append(pts[i] / norm_factor) - # res_r.append(pts_r[i] / norm_factor) - - # res = [res_l, res_r] - - return res, norm_factor - - -@torch.no_grad() -def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5): - # set invalid points to NaN - _zs = [] - for i in range(len(zs)): - valid_mask = valid_masks[i] if valid_masks is not None else None - _z = invalid_to_nans(zs[i], valid_mask).reshape(len(zs[i]), -1) - _zs.append(_z) - - _zs = torch.cat(_zs, dim=-1) - - # compute median depth overall (ignoring nans) - if quantile == 0.5: - shift_z = torch.nanmedian(_zs, dim=-1).values - else: - shift_z = torch.nanquantile(_zs, quantile, dim=-1) - return shift_z # (B,) - - -@torch.no_grad() -def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True): - # set invalid points to NaN - - _pts = [] - for i in range(len(pts)): - valid_mask = valid_masks[i] if valid_masks is not None else None - _pt = invalid_to_nans(pts[i], valid_mask).reshape(len(pts[i]), -1, 3) - _pts.append(_pt) - - _pts = torch.cat(_pts, dim=1) - - # compute median center - _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3) - if z_only: - _center[..., :2] = 0 # do not center X and Y - - # compute median norm - _norm = ((_pts - _center) if center else _pts).norm(dim=-1) - scale = torch.nanmedian(_norm, dim=1).values - return _center[:, None, :, :], scale[:, None, None, None] - - -class Regr3D_t(Criterion, MultiLoss): - def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True): - super().__init__(criterion) - self.norm_mode = norm_mode - self.gt_scale = gt_scale - self.fix_first = fix_first - - def get_all_pts3d_t(self, gts, preds, dist_clip=None): - # everything is normalized w.r.t. camera of view1 - in_camera1 = inv(gts[0]["camera_pose"]) - - gt_pts = [] - valids = [] - pr_pts = [] - - for i, gt in enumerate(gts): - # in_camera1: Bs, 4, 4 gt['pts3d']: Bs, H, W, 3 - gt_pts.append(geotrf(in_camera1, gt["pts3d"])) - valid = gt["valid_mask"].clone() - - if dist_clip is not None: - # points that are too far-away == invalid - dis = gt["pts3d"].norm(dim=-1) - valid = valid & (dis <= dist_clip) - - valids.append(valid) - pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True)) - # if i != len(gts)-1: - # pr_pts_l.append(get_pred_pts3d(gt, preds[i][0], use_pose=(i!=0))) - - # if i != 0: - # pr_pts_r.append(get_pred_pts3d(gt, preds[i-1][1], use_pose=(i!=0))) - - # pr_pts = (pr_pts_l, pr_pts_r) - - if self.norm_mode: - pr_pts, pr_factor = normalize_pointcloud_t( - pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False - ) - else: - pr_factor = None - - if self.norm_mode and not self.gt_scale: - gt_pts, gt_factor = normalize_pointcloud_t( - gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True - ) - else: - gt_factor = None - - return gt_pts, pr_pts, gt_factor, pr_factor, valids, {} - - def compute_frame_loss(self, gts, preds, **kw): - gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( - self.get_all_pts3d_t(gts, preds, **kw) - ) - - pred_pts_l, pred_pts_r = pred_pts - - loss_all = [] - mask_all = [] - conf_all = [] - - loss_left = 0 - loss_right = 0 - pred_conf_l = 0 - pred_conf_r = 0 - - for i in range(len(gt_pts)): - - # Left (Reference) - if i != len(gt_pts) - 1: - frame_loss = self.criterion( - pred_pts_l[i][masks[i]], gt_pts[i][masks[i]] - ) - - loss_all.append(frame_loss) - mask_all.append(masks[i]) - conf_all.append(preds[i][0]["conf"]) - - # To compare target/reference loss - if i != 0: - loss_left += frame_loss.cpu().detach().numpy().mean() - pred_conf_l += preds[i][0]["conf"].cpu().detach().numpy().mean() - - # Right (Target) - if i != 0: - frame_loss = self.criterion( - pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]] - ) - - loss_all.append(frame_loss) - mask_all.append(masks[i]) - conf_all.append(preds[i - 1][1]["conf"]) - - # To compare target/reference loss - if i != len(gt_pts) - 1: - loss_right += frame_loss.cpu().detach().numpy().mean() - pred_conf_r += preds[i - 1][1]["conf"].cpu().detach().numpy().mean() - - if pr_factor is not None and gt_factor is not None: - filter_factor = pr_factor[pr_factor > gt_factor] - else: - filter_factor = [] - - if len(filter_factor) > 0: - factor_loss = (filter_factor - gt_factor).abs().mean() - else: - factor_loss = 0.0 - - self_name = type(self).__name__ - details = { - self_name + "_pts3d_1": float(loss_all[0].mean()), - self_name + "_pts3d_2": float(loss_all[1].mean()), - self_name + "loss_left": float(loss_left), - self_name + "loss_right": float(loss_right), - self_name + "conf_left": float(pred_conf_l), - self_name + "conf_right": float(pred_conf_r), - } - - return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss - - -class ConfLoss_t(MultiLoss): - """Weighted regression by learned confidence. - Assuming the input pixel_loss is a pixel-level regression loss. - - Principle: - high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10) - low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10) - - alpha: hyperparameter - """ - - def __init__(self, pixel_loss, alpha=1): - super().__init__() - assert alpha > 0 - self.alpha = alpha - self.pixel_loss = pixel_loss.with_reduction("none") - - def get_name(self): - return f"ConfLoss({self.pixel_loss})" - - def get_conf_log(self, x): - return x, torch.log(x) - - def compute_frame_loss(self, gts, preds, **kw): - # compute per-pixel loss - (losses, masks, confs), details, loss_factor = ( - self.pixel_loss.compute_frame_loss(gts, preds, **kw) - ) - - # weight by confidence - conf_losses = [] - conf_sum = 0 - for i in range(len(losses)): - conf, log_conf = self.get_conf_log(confs[i][masks[i]]) - conf_sum += conf.mean() - conf_loss = losses[i] * conf - self.alpha * log_conf - conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0 - conf_losses.append(conf_loss) - - conf_losses = torch.stack(conf_losses) * 2.0 - conf_loss_mean = conf_losses.mean() - - return ( - conf_loss_mean, - dict( - conf_loss_1=float(conf_losses[0]), - conf_loss2=float(conf_losses[1]), - conf_mean=conf_sum / len(losses), - **details, - ), - loss_factor, - ) - - -class Regr3D_t_ShiftInv(Regr3D_t): - """Same than Regr3D but invariant to depth shift.""" - - def get_all_pts3d_t(self, gts, preds): - # compute unnormalized points - gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( - super().get_all_pts3d_t(gts, preds) - ) - - # pred_pts_l, pred_pts_r = pred_pts - gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts] - - pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts] - # pred_zs.append(pred_pts_r[-1][..., 2]) - - # compute median depth - gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None] - pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, None] - - # subtract the median depth - for i in range(len(gt_pts)): - gt_pts[i][..., 2] -= gt_shift_z - - for i in range(len(pred_pts)): - # for j in range(len(pred_pts[i])): - pred_pts[i][..., 2] -= pred_shift_z - - monitoring = dict( - monitoring, - gt_shift_z=gt_shift_z.mean().detach(), - pred_shift_z=pred_shift_z.mean().detach(), - ) - return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring - - -class Regr3D_t_ScaleInv(Regr3D_t): - """Same than Regr3D but invariant to depth shift. - if gt_scale == True: enforce the prediction to take the same scale than GT - """ - - def get_all_pts3d_t(self, gts, preds): - # compute depth-normalized points - gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( - super().get_all_pts3d_t(gts, preds) - ) - - # measure scene scale - - # pred_pts_l, pred_pts_r = pred_pts - - pred_pts_all = [ - x.clone() for x in pred_pts - ] # [pred_pt for pred_pt in pred_pts_l] - # pred_pts_all.append(pred_pts_r[-1]) - - _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks) - _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks) - - # prevent predictions to be in a ridiculous range - pred_scale = pred_scale.clip(min=1e-3, max=1e3) - - # subtract the median depth - if self.gt_scale: - for i in range(len(pred_pts)): - # for j in range(len(pred_pts[i])): - pred_pts[i] *= gt_scale / pred_scale - - else: - for i in range(len(pred_pts)): - # for j in range(len(pred_pts[i])): - pred_pts[i] *= pred_scale / gt_scale - - for i in range(len(gt_pts)): - gt_pts[i] *= gt_scale / pred_scale - - monitoring = dict( - monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach() - ) - - return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring - - -class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv): - # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv - pass diff --git a/FastVGGT/eval/data.py b/FastVGGT/eval/data.py deleted file mode 100644 index 301788e3ae6b67091fb031b068ecbecff1eae48d..0000000000000000000000000000000000000000 --- a/FastVGGT/eval/data.py +++ /dev/null @@ -1,338 +0,0 @@ -import os -import cv2 -import numpy as np -import os.path as osp -from collections import deque -from base import BaseStereoViewDataset -import dataset_utils.cropping as cropping -from vggt.utils.eval_utils import imread_cv2, shuffle_deque - - -class SevenScenes(BaseStereoViewDataset): - def __init__( - self, - num_seq=1, - num_frames=5, - min_thresh=10, - max_thresh=100, - test_id=None, - full_video=False, - tuple_list=None, - seq_id=None, - rebuttal=False, - shuffle_seed=-1, - kf_every=1, - *args, - ROOT, - **kwargs, - ): - self.ROOT = ROOT - super().__init__(*args, **kwargs) - self.num_seq = num_seq - self.num_frames = num_frames - self.max_thresh = max_thresh - self.min_thresh = min_thresh - self.test_id = test_id - self.full_video = full_video - self.kf_every = kf_every - self.seq_id = seq_id - self.rebuttal = rebuttal - self.shuffle_seed = shuffle_seed - - # load all scenes - self.load_all_tuples(tuple_list) - self.load_all_scenes(ROOT) - - def __len__(self): - if self.tuple_list is not None: - return len(self.tuple_list) - return len(self.scene_list) * self.num_seq - - def load_all_tuples(self, tuple_list): - if tuple_list is not None: - self.tuple_list = tuple_list - # with open(tuple_path) as f: - # self.tuple_list = f.read().splitlines() - - else: - self.tuple_list = None - - def load_all_scenes(self, base_dir): - - if self.tuple_list is not None: - # Use pre-defined simplerecon scene_ids - self.scene_list = [ - "stairs/seq-06", - "stairs/seq-02", - "pumpkin/seq-06", - "chess/seq-01", - "heads/seq-02", - "fire/seq-02", - "office/seq-03", - "pumpkin/seq-03", - "redkitchen/seq-07", - "chess/seq-02", - "office/seq-01", - "redkitchen/seq-01", - "fire/seq-01", - ] - print(f"Found {len(self.scene_list)} sequences in split {self.split}") - return - - scenes = os.listdir(base_dir) - - file_split = {"train": "TrainSplit.txt", "test": "TestSplit.txt"}[self.split] - - self.scene_list = [] - for scene in scenes: - if self.test_id is not None and scene != self.test_id: - continue - # read file split - with open(osp.join(base_dir, scene, file_split)) as f: - seq_ids = f.read().splitlines() - - for seq_id in seq_ids: - # seq is string, take the int part and make it 01, 02, 03 - # seq_id = 'seq-{:2d}'.format(int(seq_id)) - num_part = "".join(filter(str.isdigit, seq_id)) - seq_id = f"seq-{num_part.zfill(2)}" - if self.seq_id is not None and seq_id != self.seq_id: - continue - self.scene_list.append(f"{scene}/{seq_id}") - - print(f"Found {len(self.scene_list)} sequences in split {self.split}") - - def _get_views(self, idx, resolution, rng): - - if self.tuple_list is not None: - line = self.tuple_list[idx].split(" ") - scene_id = line[0] - img_idxs = line[1:] - - else: - scene_id = self.scene_list[idx // self.num_seq] - seq_id = idx % self.num_seq - - data_path = osp.join(self.ROOT, scene_id) - num_files = len([name for name in os.listdir(data_path) if "color" in name]) - img_idxs = [f"{i:06d}" for i in range(num_files)] - img_idxs = img_idxs[:: self.kf_every] - - # Intrinsics used in SimpleRecon - fx, fy, cx, cy = 525, 525, 320, 240 - intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) - - views = [] - imgs_idxs = deque(img_idxs) - if self.shuffle_seed >= 0: - imgs_idxs = shuffle_deque(imgs_idxs) - - while len(imgs_idxs) > 0: - im_idx = imgs_idxs.popleft() - impath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.color.png") - depthpath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.depth.proj.png") - posepath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.pose.txt") - - rgb_image = imread_cv2(impath) - - depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) - rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0])) - - depthmap[depthmap == 65535] = 0 - depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0 - - depthmap[depthmap > 10] = 0 - depthmap[depthmap < 1e-3] = 0 - - camera_pose = np.loadtxt(posepath).astype(np.float32) - - if resolution != (224, 224) or self.rebuttal: - rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( - rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath - ) - else: - rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( - rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath - ) - W, H = rgb_image.size - cx = W // 2 - cy = H // 2 - l, t = cx - 112, cy - 112 - r, b = cx + 112, cy + 112 - crop_bbox = (l, t, r, b) - rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap( - rgb_image, depthmap, intrinsics, crop_bbox - ) - - views.append( - dict( - img=rgb_image, - depthmap=depthmap, - camera_pose=camera_pose, - camera_intrinsics=intrinsics, - dataset="7scenes", - label=osp.join(scene_id, im_idx), - instance=impath, - ) - ) - return views - - -class NRGBD(BaseStereoViewDataset): - def __init__( - self, - num_seq=1, - num_frames=5, - min_thresh=10, - max_thresh=100, - test_id=None, - full_video=False, - tuple_list=None, - seq_id=None, - rebuttal=False, - shuffle_seed=-1, - kf_every=1, - *args, - ROOT, - **kwargs, - ): - - self.ROOT = ROOT - super().__init__(*args, **kwargs) - self.num_seq = num_seq - self.num_frames = num_frames - self.max_thresh = max_thresh - self.min_thresh = min_thresh - self.test_id = test_id - self.full_video = full_video - self.kf_every = kf_every - self.seq_id = seq_id - self.rebuttal = rebuttal - self.shuffle_seed = shuffle_seed - - # load all scenes - self.load_all_tuples(tuple_list) - self.load_all_scenes(ROOT) - - def __len__(self): - if self.tuple_list is not None: - return len(self.tuple_list) - return len(self.scene_list) * self.num_seq - - def load_all_tuples(self, tuple_list): - if tuple_list is not None: - self.tuple_list = tuple_list - # with open(tuple_path) as f: - # self.tuple_list = f.read().splitlines() - - else: - self.tuple_list = None - - def load_all_scenes(self, base_dir): - - scenes = [ - d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) - ] - - if self.test_id is not None: - self.scene_list = [self.test_id] - - else: - self.scene_list = scenes - - print(f"Found {len(self.scene_list)} sequences in split {self.split}") - - def load_poses(self, path): - file = open(path, "r") - lines = file.readlines() - file.close() - poses = [] - valid = [] - lines_per_matrix = 4 - for i in range(0, len(lines), lines_per_matrix): - if "nan" in lines[i]: - valid.append(False) - poses.append(np.eye(4, 4, dtype=np.float32).tolist()) - else: - valid.append(True) - pose_floats = [ - [float(x) for x in line.split()] - for line in lines[i : i + lines_per_matrix] - ] - poses.append(pose_floats) - - return np.array(poses, dtype=np.float32), valid - - def _get_views(self, idx, resolution, rng): - - if self.tuple_list is not None: - line = self.tuple_list[idx].split(" ") - scene_id = line[0] - img_idxs = line[1:] - - else: - scene_id = self.scene_list[idx // self.num_seq] - - num_files = len(os.listdir(os.path.join(self.ROOT, scene_id, "images"))) - img_idxs = [f"{i}" for i in range(num_files)] - img_idxs = img_idxs[:: min(self.kf_every, len(img_idxs) // 2)] - - fx, fy, cx, cy = 554.2562584220408, 554.2562584220408, 320, 240 - intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) - - posepath = osp.join(self.ROOT, scene_id, f"poses.txt") - camera_poses, valids = self.load_poses(posepath) - - imgs_idxs = deque(img_idxs) - if self.shuffle_seed >= 0: - imgs_idxs = shuffle_deque(imgs_idxs) - views = [] - - while len(imgs_idxs) > 0: - im_idx = imgs_idxs.popleft() - - impath = osp.join(self.ROOT, scene_id, "images", f"img{im_idx}.png") - depthpath = osp.join(self.ROOT, scene_id, "depth", f"depth{im_idx}.png") - - rgb_image = imread_cv2(impath) - depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) - depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0 - depthmap[depthmap > 10] = 0 - depthmap[depthmap < 1e-3] = 0 - - rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0])) - - camera_pose = camera_poses[int(im_idx)] - # gl to cv - camera_pose[:, 1:3] *= -1.0 - if resolution != (224, 224) or self.rebuttal: - rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( - rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath - ) - else: - rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( - rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath - ) - W, H = rgb_image.size - cx = W // 2 - cy = H // 2 - l, t = cx - 112, cy - 112 - r, b = cx + 112, cy + 112 - crop_bbox = (l, t, r, b) - rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap( - rgb_image, depthmap, intrinsics, crop_bbox - ) - - views.append( - dict( - img=rgb_image, - depthmap=depthmap, - camera_pose=camera_pose, - camera_intrinsics=intrinsics, - dataset="nrgbd", - label=osp.join(scene_id, im_idx), - instance=impath, - ) - ) - - return views diff --git a/FastVGGT/eval/dataset_utils/__init__.py b/FastVGGT/eval/dataset_utils/__init__.py deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/FastVGGT/eval/dataset_utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index d00b3e8b6f5fd9ef7d1428ff771243e1557ddf9f..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 9ebd5d22d44ddce633e9a05a7467e9ae0c3f9032..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc and /dev/null differ diff --git a/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc deleted file mode 100644 index fbd10218eddf65486b375109c23cf8cecd7c3ba5..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc deleted file mode 100644 index 860c8970fc49d65b2a4101b77cc313a6aeac618a..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc deleted file mode 100644 index 0fe7870fb079d9dc08ca647866ff8c77fd0f2ec0..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc and /dev/null differ diff --git a/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc deleted file mode 100644 index 1b4530be10547e3e7bf79e6cef6470880cb710a4..0000000000000000000000000000000000000000 Binary files a/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/eval/dataset_utils/corr.py b/FastVGGT/eval/dataset_utils/corr.py deleted file mode 100644 index fbf5de18b7388e38e45c3f313487957811abc483..0000000000000000000000000000000000000000 --- a/FastVGGT/eval/dataset_utils/corr.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright (C) 2024-present Naver Corporation. All rights reserved. -# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). -# -# -------------------------------------------------------- - -import numpy as np -import torch - - -def todevice(batch, device, callback=None, non_blocking=False): - """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). - - batch: list, tuple, dict of tensors or other things - device: pytorch device or 'numpy' - callback: function that would be called on every sub-elements. - """ - if callback: - batch = callback(batch) - - if isinstance(batch, dict): - return {k: todevice(v, device) for k, v in batch.items()} - - if isinstance(batch, (tuple, list)): - return type(batch)(todevice(x, device) for x in batch) - - x = batch - if device == "numpy": - if isinstance(x, torch.Tensor): - x = x.detach().cpu().numpy() - elif x is not None: - if isinstance(x, np.ndarray): - x = torch.from_numpy(x) - if torch.is_tensor(x): - x = x.to(device, non_blocking=non_blocking) - return x - - -to_device = todevice # alias - - -def to_numpy(x): - return todevice(x, "numpy") - - -def geotrf(Trf, pts, ncol=None, norm=False): - """Apply a geometric transformation to a list of 3-D points. - - H: 3x3 or 4x4 projection matrix (typically a Homography) - p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) - - ncol: int. number of columns of the result (2 or 3) - norm: float. if != 0, the resut is projected on the z=norm plane. - - Returns an array of projected 2d points. - """ - assert Trf.ndim >= 2 - if isinstance(Trf, np.ndarray): - pts = np.asarray(pts) - elif isinstance(Trf, torch.Tensor): - pts = torch.as_tensor(pts, dtype=Trf.dtype) - - output_reshape = pts.shape[:-1] - ncol = ncol or pts.shape[-1] - - if ( - isinstance(Trf, torch.Tensor) - and isinstance(pts, torch.Tensor) - and Trf.ndim == 3 - and pts.ndim == 4 - ): - d = pts.shape[3] - if Trf.shape[-1] == d: - pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) - elif Trf.shape[-1] == d + 1: - pts = ( - torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) - + Trf[:, None, None, :d, d] - ) - else: - raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}") - else: - if Trf.ndim >= 3: - n = Trf.ndim - 2 - assert Trf.shape[:n] == pts.shape[:n], "batch size does not match" - Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) - - if pts.ndim > Trf.ndim: - - pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) - elif pts.ndim == 2: - - pts = pts[:, None, :] - - if pts.shape[-1] + 1 == Trf.shape[-1]: - Trf = Trf.swapaxes(-1, -2) # transpose Trf - pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] - elif pts.shape[-1] == Trf.shape[-1]: - Trf = Trf.swapaxes(-1, -2) # transpose Trf - pts = pts @ Trf - else: - pts = Trf @ pts.T - if pts.ndim >= 2: - pts = pts.swapaxes(-1, -2) - - if norm: - pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG - if norm != 1: - pts *= norm - - res = pts[..., :ncol].reshape(*output_reshape, ncol) - return res - - -def inv(mat): - """Invert a torch or numpy matrix""" - if isinstance(mat, torch.Tensor): - return torch.linalg.inv(mat) - if isinstance(mat, np.ndarray): - return np.linalg.inv(mat) - raise ValueError(f"bad matrix type = {type(mat)}") - - -def reproject_view(pts3d, view2): - shape = view2["pts3d"].shape[:2] - return reproject( - pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape - ) - - -def reproject(pts3d, K, world2cam, shape): - H, W, THREE = pts3d.shape - assert THREE == 3 - - with np.errstate(divide="ignore", invalid="ignore"): - pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2) - - return (H, W), ravel_xy(pos, shape) - - -def ravel_xy(pos, shape): - H, W = shape - with np.errstate(invalid="ignore"): - qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T - quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip( - min=0, max=H - 1, out=qy - ) - return quantized_pos - - -def unravel_xy(pos, shape): - - return np.unravel_index(pos, shape)[0].base[:, ::-1].copy() - - -def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False): - is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2)) - pos1 = is_reciprocal1.nonzero()[0] - pos2 = corres_1_to_2[pos1] - if ret_recip: - return is_reciprocal1, pos1, pos2 - return pos1, pos2 - - -def extract_correspondences_from_pts3d( - view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0 -): - view1, view2 = to_numpy((view1, view2)) - - shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2) - shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1) - - is_reciprocal1, pos1, pos2 = reciprocal_1d( - corres1_to_2, corres2_to_1, ret_recip=True - ) - is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1)) - - if target_n_corres is None: - if ret_xy: - pos1 = unravel_xy(pos1, shape1) - pos2 = unravel_xy(pos2, shape2) - return pos1, pos2 - - available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum()) - target_n_positives = int(target_n_corres * (1 - nneg)) - n_positives = min(len(pos1), target_n_positives) - n_negatives = min(target_n_corres - n_positives, available_negatives) - - if n_negatives + n_positives != target_n_corres: - - n_positives = target_n_corres - n_negatives - assert n_positives <= len(pos1) - - assert n_positives <= len(pos1) - assert n_positives <= len(pos2) - assert n_negatives <= (~is_reciprocal1).sum() - assert n_negatives <= (~is_reciprocal2).sum() - assert n_positives + n_negatives == target_n_corres - - valid = np.ones(n_positives, dtype=bool) - if n_positives < len(pos1): - - perm = rng.permutation(len(pos1))[:n_positives] - pos1 = pos1[perm] - pos2 = pos2[perm] - - if n_negatives > 0: - - def norm(p): - return p / p.sum() - - pos1 = np.r_[ - pos1, - rng.choice( - shape1[0] * shape1[1], - size=n_negatives, - replace=False, - p=norm(~is_reciprocal1), - ), - ] - pos2 = np.r_[ - pos2, - rng.choice( - shape2[0] * shape2[1], - size=n_negatives, - replace=False, - p=norm(~is_reciprocal2), - ), - ] - valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)] - - if ret_xy: - pos1 = unravel_xy(pos1, shape1) - pos2 = unravel_xy(pos2, shape2) - return pos1, pos2, valid diff --git a/FastVGGT/eval/dataset_utils/cropping.py b/FastVGGT/eval/dataset_utils/cropping.py deleted file mode 100644 index 30a9eac18d241b71538957cf7ba4767ebc323b43..0000000000000000000000000000000000000000 --- a/FastVGGT/eval/dataset_utils/cropping.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (C) 2024-present Naver Corporation. All rights reserved. -# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). -# -# -------------------------------------------------------- - -import PIL.Image -import os - -from utils import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics - -os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" -import cv2 # noqa -import numpy as np # noqa - -try: - lanczos = PIL.Image.Resampling.LANCZOS - bicubic = PIL.Image.Resampling.BICUBIC -except AttributeError: - lanczos = PIL.Image.LANCZOS - bicubic = PIL.Image.BICUBIC - - -class ImageList: - """Convenience class to aply the same operation to a whole set of images.""" - - def __init__(self, images): - if not isinstance(images, (tuple, list, set)): - images = [images] - self.images = [] - for image in images: - if not isinstance(image, PIL.Image.Image): - image = PIL.Image.fromarray(image) - self.images.append(image) - - def __len__(self): - return len(self.images) - - def to_pil(self): - return tuple(self.images) if len(self.images) > 1 else self.images[0] - - @property - def size(self): - sizes = [im.size for im in self.images] - assert all(sizes[0] == s for s in sizes) - return sizes[0] - - def resize(self, *args, **kwargs): - return ImageList(self._dispatch("resize", *args, **kwargs)) - - def crop(self, *args, **kwargs): - return ImageList(self._dispatch("crop", *args, **kwargs)) - - def _dispatch(self, func, *args, **kwargs): - return [getattr(im, func)(*args, **kwargs) for im in self.images] - - -def rescale_image_depthmap( - image, depthmap, camera_intrinsics, output_resolution, force=True -): - """Jointly rescale a (image, depthmap) - so that (out_width, out_height) >= output_res - """ - image = ImageList(image) - input_resolution = np.array(image.size) # (W,H) - output_resolution = np.array(output_resolution) - if depthmap is not None: - - assert tuple(depthmap.shape[:2]) == image.size[::-1] - - assert output_resolution.shape == (2,) - scale_final = max(output_resolution / image.size) + 1e-8 - if scale_final >= 1 and not force: # image is already smaller than what is asked - return (image.to_pil(), depthmap, camera_intrinsics) - output_resolution = np.floor(input_resolution * scale_final).astype(int) - - image = image.resize( - output_resolution, resample=lanczos if scale_final < 1 else bicubic - ) - if depthmap is not None: - depthmap = cv2.resize( - depthmap, - output_resolution, - fx=scale_final, - fy=scale_final, - interpolation=cv2.INTER_NEAREST, - ) - - camera_intrinsics = camera_matrix_of_crop( - camera_intrinsics, input_resolution, output_resolution, scaling=scale_final - ) - - return image.to_pil(), depthmap, camera_intrinsics - - -def camera_matrix_of_crop( - input_camera_matrix, - input_resolution, - output_resolution, - scaling=1, - offset_factor=0.5, - offset=None, -): - - margins = np.asarray(input_resolution) * scaling - output_resolution - assert np.all(margins >= 0.0) - if offset is None: - offset = offset_factor * margins - - output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) - output_camera_matrix_colmap[:2, :] *= scaling - output_camera_matrix_colmap[:2, 2] -= offset - output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) - - return output_camera_matrix - - -def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox): - """ - Return a crop of the input view. - """ - image = ImageList(image) - l, t, r, b = crop_bbox - - image = image.crop((l, t, r, b)) - depthmap = depthmap[t:b, l:r] - - camera_intrinsics = camera_intrinsics.copy() - camera_intrinsics[0, 2] -= l - camera_intrinsics[1, 2] -= t - - return image.to_pil(), depthmap, camera_intrinsics - - -def bbox_from_intrinsics_in_out( - input_camera_matrix, output_camera_matrix, output_resolution -): - out_width, out_height = output_resolution - l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) - crop_bbox = (l, t, l + out_width, t + out_height) - return crop_bbox diff --git a/FastVGGT/eval/dataset_utils/transforms.py b/FastVGGT/eval/dataset_utils/transforms.py deleted file mode 100644 index cec2d144e1b97cc99c191d15cdcaf20796cae94b..0000000000000000000000000000000000000000 --- a/FastVGGT/eval/dataset_utils/transforms.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (C) 2024-present Naver Corporation. All rights reserved. -# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). -# -# -------------------------------------------------------- - -import torchvision.transforms as tvf - -ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) - - -ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) - - -def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True): - if isinstance(value, (int, float)): - if value < 0: - raise ValueError(f"If is a single number, it must be non negative.") - value = [center - float(value), center + float(value)] - if clip_first_on_zero: - value[0] = max(value[0], 0.0) - elif isinstance(value, (tuple, list)) and len(value) == 2: - value = [float(value[0]), float(value[1])] - else: - raise TypeError(f"should be a single number or a list/tuple with length 2.") - - if not bound[0] <= value[0] <= value[1] <= bound[1]: - raise ValueError(f"values should be between {bound}, but got {value}.") - - if value[0] == value[1] == center: - return None - else: - return tuple(value) - - -import torch -import torchvision.transforms.functional as F - - -def SeqColorJitter(): - """ - Return a color jitter transform with same random parameters - """ - brightness = _check_input(0.5) - contrast = _check_input(0.5) - saturation = _check_input(0.5) - hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) - - fn_idx = torch.randperm(4) - brightness_factor = ( - None - if brightness is None - else float(torch.empty(1).uniform_(brightness[0], brightness[1])) - ) - contrast_factor = ( - None - if contrast is None - else float(torch.empty(1).uniform_(contrast[0], contrast[1])) - ) - saturation_factor = ( - None - if saturation is None - else float(torch.empty(1).uniform_(saturation[0], saturation[1])) - ) - hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) - - def _color_jitter(img): - for fn_id in fn_idx: - if fn_id == 0 and brightness_factor is not None: - img = F.adjust_brightness(img, brightness_factor) - elif fn_id == 1 and contrast_factor is not None: - img = F.adjust_contrast(img, contrast_factor) - elif fn_id == 2 and saturation_factor is not None: - img = F.adjust_saturation(img, saturation_factor) - elif fn_id == 3 and hue_factor is not None: - img = F.adjust_hue(img, hue_factor) - return ImgNorm(img) - - return _color_jitter diff --git a/FastVGGT/eval/eval_7andN.py b/FastVGGT/eval/eval_7andN.py deleted file mode 100644 index 8d6d5fbf6c159ea18bfd6fbcee0a5d23e63c0cf2..0000000000000000000000000000000000000000 --- a/FastVGGT/eval/eval_7andN.py +++ /dev/null @@ -1,497 +0,0 @@ -import os -import sys - -# Ensure project root is on sys.path for absolute imports like `vggt.*` -ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) -if ROOT_DIR not in sys.path: - sys.path.insert(0, ROOT_DIR) - -import time -import torch -import argparse -import numpy as np -import open3d as o3d -import os.path as osp -from torch.utils.data import DataLoader -from torch.utils.data._utils.collate import default_collate -from tqdm import tqdm -from collections import defaultdict -import torchvision.transforms as transforms - - -def get_args_parser(): - parser = argparse.ArgumentParser("3D Reconstruction evaluation", add_help=False) - parser.add_argument( - "--ckpt_path", - type=str, - default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt", - help="ckpt name", - ) - parser.add_argument("--device", type=str, default="cuda:0", help="device") - parser.add_argument("--model_name", type=str, default="VGGT") - parser.add_argument( - "--conf_thresh", type=float, default=0.0, help="confidence threshold" - ) - parser.add_argument( - "--output_dir", - type=str, - default="/home/sy/code/FastVGGT/eval_results", - help="value for outdir", - ) - parser.add_argument("--size", type=int, default=518) - parser.add_argument("--revisit", type=int, default=1, help="revisit times") - parser.add_argument("--freeze", action="store_true") - parser.add_argument("--use_proj", action="store_true") - parser.add_argument( - "--merging", type=int, default=0, help="VGGT aggregator merging steps" - ) - parser.add_argument("--kf", type=int, default=2, help="key frame") - return parser - - -def main(args): - from data import SevenScenes, NRGBD - from utils import accuracy, completion - - if args.size == 512: - resolution = (512, 384) - elif args.size == 224: - resolution = 224 - elif args.size == 518: - resolution = (518, 392) - else: - raise NotImplementedError - datasets_all = { - "7scenes": SevenScenes( - split="test", - ROOT="/data/sy/7scenes", - resolution=resolution, - num_seq=1, - full_video=True, - kf_every=args.kf, - ), # 20), - "NRGBD": NRGBD( - split="test", - ROOT="/data/sy/neural_rgbd_data", - resolution=resolution, - num_seq=1, - full_video=True, - kf_every=args.kf, - ), - } - - device = args.device - model_name = args.model_name - - from vggt.models.vggt import VGGT - from vggt.utils.pose_enc import pose_encoding_to_extri_intri - from vggt.utils.geometry import unproject_depth_map_to_point_map - from criterion import Regr3D_t_ScaleShiftInv, L21 - - # Force use of bf16 data type - dtype = torch.bfloat16 - # Load VGGT model - model = VGGT(merging=args.merging, enable_point=True) - ckpt = torch.load(args.ckpt_path, map_location="cpu") - - # ✅ Fix: load pre-trained weights - model.load_state_dict( - ckpt, strict=False - ) # Use strict=False due to enable_point=True difference - - model = model.cuda().eval() - model = model.to(torch.bfloat16) - - del ckpt - os.makedirs(osp.join(args.output_dir, f"{args.kf}"), exist_ok=True) - - criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True) - - with torch.no_grad(): - for name_data, dataset in datasets_all.items(): - save_path = osp.join(osp.join(args.output_dir, f"{args.kf}"), name_data) - os.makedirs(save_path, exist_ok=True) - log_file = osp.join(save_path, "logs.txt") - - acc_all = 0 - acc_all_med = 0 - comp_all = 0 - comp_all_med = 0 - nc1_all = 0 - nc1_all_med = 0 - nc2_all = 0 - nc2_all_med = 0 - scene_infer_times = defaultdict(list) - - for data_idx in tqdm(range(len(dataset))): - batch = default_collate([dataset[data_idx]]) - ignore_keys = set( - [ - "depthmap", - "dataset", - "label", - "instance", - "idx", - "true_shape", - "rng", - ] - ) - for view in batch: - for name in view.keys(): # pseudo_focal - if name in ignore_keys: - continue - if isinstance(view[name], tuple) or isinstance( - view[name], list - ): - view[name] = [ - x.to(device, non_blocking=True) for x in view[name] - ] - else: - view[name] = view[name].to(device, non_blocking=True) - - pts_all = [] - pts_gt_all = [] - images_all = [] - masks_all = [] - conf_all = [] - in_camera1 = None - - dtype = ( - torch.bfloat16 - if torch.cuda.get_device_capability()[0] >= 8 - else torch.float16 - ) - with torch.cuda.amp.autocast(dtype=dtype): - if isinstance(batch, dict) and "img" in batch: - batch["img"] = (batch["img"] + 1.0) / 2.0 - elif isinstance(batch, list) and all( - isinstance(v, dict) and "img" in v for v in batch - ): - for view in batch: - view["img"] = (view["img"] + 1.0) / 2.0 - # Gather all `img` tensors into a single tensor of shape [N, C, H, W] - imgs_tensor = torch.cat([v["img"] for v in batch], dim=0) - - with torch.cuda.amp.autocast(dtype=dtype): - with torch.no_grad(): - torch.cuda.synchronize() - start = time.time() - preds = model(imgs_tensor) - torch.cuda.synchronize() - end = time.time() - inference_time_ms = (end - start) * 1000 - print(f"Inference time: {inference_time_ms:.2f}ms") - - # Wrap model outputs per-view to align with batch later - predictions = preds - views = batch # list[dict] - if "pose_enc" in predictions: - B, S = predictions["pose_enc"].shape[:2] - elif "world_points" in predictions: - B, S = predictions["world_points"].shape[:2] - else: - raise KeyError( - "predictions is missing a key to infer sequence length" - ) - - ress = [] - for s in range(S): - res = { - "pts3d_in_other_view": predictions["world_points"][:, s], - "conf": predictions["world_points_conf"][:, s], - "depth": predictions["depth"][:, s], - "depth_conf": predictions["depth_conf"][:, s], - "camera_pose": predictions["pose_enc"][:, s, :], - } - if ( - isinstance(views, list) - and s < len(views) - and "valid_mask" in views[s] - ): - res["valid_mask"] = views[s]["valid_mask"] - if "track" in predictions: - res.update( - { - "track": predictions["track"][:, s], - "vis": ( - predictions.get("vis", None)[:, s] - if "vis" in predictions - else None - ), - "track_conf": ( - predictions.get("conf", None)[:, s] - if "conf" in predictions - else None - ), - } - ) - ress.append(res) - - preds = ress - - valid_length = len(preds) // args.revisit - if args.revisit > 1: - preds = preds[-valid_length:] - batch = batch[-valid_length:] - - # Evaluation - print(f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}") - gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( - criterion.get_all_pts3d_t(batch, preds) - ) - - in_camera1 = None - pts_all = [] - pts_gt_all = [] - images_all = [] - masks_all = [] - conf_all = [] - - for j, view in enumerate(batch): - if in_camera1 is None: - in_camera1 = view["camera_pose"][0].cpu() - - image = view["img"].permute(0, 2, 3, 1).cpu().numpy()[0] - mask = view["valid_mask"].cpu().numpy()[0] - - pts = pred_pts[j].cpu().numpy()[0] - conf = preds[j]["conf"].cpu().data.numpy()[0] - - # mask = mask & (conf > 1.8) - - pts_gt = gt_pts[j].detach().cpu().numpy()[0] - - H, W = image.shape[:2] - cx = W // 2 - cy = H // 2 - l, t = cx - 112, cy - 112 - r, b = cx + 112, cy + 112 - image = image[t:b, l:r] - mask = mask[t:b, l:r] - pts = pts[t:b, l:r] - pts_gt = pts_gt[t:b, l:r] - - images_all.append(image[None, ...]) - pts_all.append(pts[None, ...]) - pts_gt_all.append(pts_gt[None, ...]) - masks_all.append(mask[None, ...]) - conf_all.append(conf[None, ...]) - - images_all = np.concatenate(images_all, axis=0) - pts_all = np.concatenate(pts_all, axis=0) - pts_gt_all = np.concatenate(pts_gt_all, axis=0) - masks_all = np.concatenate(masks_all, axis=0) - - scene_id = view["label"][0].rsplit("/", 1)[0] - # Record average inference time per scene - try: - scene_infer_times[scene_id].append(float(inference_time_ms)) - except Exception: - pass - - save_params = {} - - save_params["images_all"] = images_all - save_params["pts_all"] = pts_all - save_params["pts_gt_all"] = pts_gt_all - save_params["masks_all"] = masks_all - - pts_all_masked = pts_all[masks_all > 0] - pts_gt_all_masked = pts_gt_all[masks_all > 0] - images_all_masked = images_all[masks_all > 0] - - mask = np.isfinite(pts_all_masked) - pts_all_masked = pts_all_masked[mask] - - mask_gt = np.isfinite(pts_gt_all_masked) - pts_gt_all_masked = pts_gt_all_masked[mask_gt] - images_all_masked = images_all_masked[mask] - - # Reshape to point cloud (N, 3) before sampling - pts_all_masked = pts_all_masked.reshape(-1, 3) - pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3) - images_all_masked = images_all_masked.reshape(-1, 3) - - # If number of points exceeds threshold, sample by points - if pts_all_masked.shape[0] > 999999: - sample_indices = np.random.choice( - pts_all_masked.shape[0], 999999, replace=False - ) - pts_all_masked = pts_all_masked[sample_indices] - images_all_masked = images_all_masked[sample_indices] - - # Apply the same sampling to GT point cloud - if pts_gt_all_masked.shape[0] > 999999: - sample_indices_gt = np.random.choice( - pts_gt_all_masked.shape[0], 999999, replace=False - ) - pts_gt_all_masked = pts_gt_all_masked[sample_indices_gt] - - if args.use_proj: - - def umeyama_alignment( - src: np.ndarray, dst: np.ndarray, with_scale: bool = True - ): - assert src.shape == dst.shape - N, dim = src.shape - - mu_src = src.mean(axis=0) - mu_dst = dst.mean(axis=0) - src_c = src - mu_src - dst_c = dst - mu_dst - - Sigma = dst_c.T @ src_c / N # (3,3) - - U, D, Vt = np.linalg.svd(Sigma) - - S = np.eye(dim) - if np.linalg.det(U) * np.linalg.det(Vt) < 0: - S[-1, -1] = -1 - - R = U @ S @ Vt - - if with_scale: - var_src = (src_c**2).sum() / N - s = (D * S.diagonal()).sum() / var_src - else: - s = 1.0 - - t = mu_dst - s * R @ mu_src - - return s, R, t - - pts_all_masked = pts_all_masked.reshape(-1, 3) - pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3) - s, R, t = umeyama_alignment( - pts_all_masked, pts_gt_all_masked, with_scale=True - ) - pts_all_aligned = (s * (R @ pts_all_masked.T)).T + t # (N,3) - pts_all_masked = pts_all_aligned - - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(pts_all_masked) - pcd.colors = o3d.utility.Vector3dVector(images_all_masked) - - pcd_gt = o3d.geometry.PointCloud() - pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked) - pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked) - - trans_init = np.eye(4) - - threshold = 0.1 - reg_p2p = o3d.pipelines.registration.registration_icp( - pcd, - pcd_gt, - threshold, - trans_init, - o3d.pipelines.registration.TransformationEstimationPointToPoint(), - ) - - transformation = reg_p2p.transformation - - pcd = pcd.transform(transformation) - pcd.estimate_normals() - pcd_gt.estimate_normals() - - gt_normal = np.asarray(pcd_gt.normals) - pred_normal = np.asarray(pcd.normals) - - acc, acc_med, nc1, nc1_med = accuracy( - pcd_gt.points, pcd.points, gt_normal, pred_normal - ) - comp, comp_med, nc2, nc2_med = completion( - pcd_gt.points, pcd.points, gt_normal, pred_normal - ) - print( - f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}" - ) - print( - f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}", - file=open(log_file, "a"), - ) - - acc_all += acc - comp_all += comp - nc1_all += nc1 - nc2_all += nc2 - - acc_all_med += acc_med - comp_all_med += comp_med - nc1_all_med += nc1_med - nc2_all_med += nc2_med - - # release cuda memory - torch.cuda.empty_cache() - - # Get depth from pcd and run TSDFusion - to_write = "" - # Read the log file - if os.path.exists(osp.join(save_path, "logs.txt")): - with open(osp.join(save_path, "logs.txt"), "r") as f_sub: - to_write += f_sub.read() - - with open(osp.join(save_path, f"logs_all.txt"), "w") as f: - log_data = to_write - metrics = defaultdict(list) - for line in log_data.strip().split("\n"): - match = regex.match(line) - if match: - data = match.groupdict() - # Exclude 'scene_id' from metrics as it's an identifier - for key, value in data.items(): - if key != "scene_id": - metrics[key].append(float(value)) - metrics["nc"].append( - (float(data["nc1"]) + float(data["nc2"])) / 2 - ) - metrics["nc_med"].append( - (float(data["nc1_med"]) + float(data["nc2_med"])) / 2 - ) - mean_metrics = { - metric: sum(values) / len(values) - for metric, values in metrics.items() - } - - c_name = "mean" - print_str = f"{c_name.ljust(20)}: " - for m_name in mean_metrics: - print_num = np.mean(mean_metrics[m_name]) - print_str = print_str + f"{m_name}: {print_num:.3f} | " - print_str = print_str + "\n" - # Summarize per-scene average inference time - time_lines = [] - for sid, times in scene_infer_times.items(): - if len(times) > 0: - time_lines.append( - f"Idx: {sid}, Time_avg_ms: {np.mean(times):.2f}" - ) - time_block = "\n".join(time_lines) + ( - "\n" if len(time_lines) > 0 else "" - ) - - f.write(to_write + time_block + print_str) - - -from collections import defaultdict -import re - -pattern = r""" - Idx:\s*(?P[^,]+),\s* - Acc:\s*(?P[^,]+),\s* - Comp:\s*(?P[^,]+),\s* - NC1:\s*(?P[^,]+),\s* - NC2:\s*(?P[^,]+)\s*-\s* - Acc_med:\s*(?P[^,]+),\s* - Compc_med:\s*(?P[^,]+),\s* - NC1c_med:\s*(?P[^,]+),\s* - NC2c_med:\s*(?P[^,]+) -""" - -regex = re.compile(pattern, re.VERBOSE) - - -if __name__ == "__main__": - parser = get_args_parser() - args = parser.parse_args() - - main(args) diff --git a/FastVGGT/eval/eval_custom.py b/FastVGGT/eval/eval_custom.py deleted file mode 100644 index a9e4f7ccdd6c22c2c3bef6421c52ace01d362a83..0000000000000000000000000000000000000000 --- a/FastVGGT/eval/eval_custom.py +++ /dev/null @@ -1,467 +0,0 @@ -import argparse -from pathlib import Path -import numpy as np -import torch -import os -import sys -import matplotlib.pyplot as plt -from scipy.spatial.transform import Rotation - -# Ensure project root is in sys.path for absolute imports like `vggt.*` -ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) -if ROOT_DIR not in sys.path: - sys.path.insert(0, ROOT_DIR) - -from vggt.models.vggt import VGGT -from vggt.utils.eval_utils import ( - load_poses, - get_vgg_input_imgs, - get_sorted_image_paths, - build_frame_selection, - load_images_rgb, - infer_vggt_and_reconstruct, - evaluate_scene_and_save, -) - -# Import pose visualization libraries (optional EVO support) -try: - from evo.core.trajectory import PoseTrajectory3D - import evo.tools.plot as plot - - EVO_AVAILABLE = True -except ImportError: - # EVO is optional; we have a matplotlib-based fallback - EVO_AVAILABLE = False - - -def visualize_predicted_poses( - all_cam_to_world_mat, frame_ids, output_scene_dir, scene_name="custom_dataset" -): - """ - Visualize the predicted camera pose trajectory (no GT comparison required). - - Args: - all_cam_to_world_mat: List of camera-to-world transform matrices - frame_ids: List of frame IDs - output_scene_dir: Output directory - scene_name: Scene name - """ - # Provide basic pose visualization even without EVO - if not EVO_AVAILABLE: - print("⚠️ EVO not installed; using basic matplotlib visualization") - - try: - # Convert to numpy array - poses_est = np.array(all_cam_to_world_mat) - - if len(poses_est) < 2: - print("⚠️ Not enough poses to generate trajectory plot") - return - - print(f"🎨 Generating pose trajectory visualization...") - - # Extract translation part - positions = poses_est[:, :3, 3] # shape: (N, 3) - - # Create figure - show XZ-plane projection only - fig, ax = plt.subplots(1, 1, figsize=(10, 8)) - - # XZ-plane projection - ax.plot( - positions[:, 0], - positions[:, 2], - "b-", - linewidth=2, - label="Predicted Trajectory", - ) - ax.scatter( - positions[0, 0], positions[0, 2], color="green", s=100, label="Start" - ) - ax.scatter(positions[-1, 0], positions[-1, 2], color="red", s=100, label="End") - ax.set_xlabel("X (m)") - ax.set_ylabel("Z (m)") - ax.set_title(f"{scene_name} - XZ-plane projection") - ax.legend() - ax.grid(True, alpha=0.3) - - # Save image - pose_plot_path = output_scene_dir / "predicted_trajectory.png" - plt.savefig(pose_plot_path, dpi=300, bbox_inches="tight") - plt.close() - - print(f"📊 Trajectory visualization saved: {pose_plot_path}") - - except Exception as e: - print(f"⚠️ Failed to generate pose visualization: {e}") - import traceback - - traceback.print_exc() - - -def main(): - """ - Evaluation script for a Custom Dataset. - Supports optional evaluation and custom dataset structure. - """ - parser = argparse.ArgumentParser( - description="Run FastVGGT evaluation on a Custom Dataset" - ) - - # Required: dataset path - parser.add_argument( - "--data_path", - type=Path, - required=True, - help="Dataset path containing subfolders: color, depth, gt_ply, pose", - ) - - # Optional: enable evaluation - parser.add_argument( - "--enable_evaluation", - action="store_true", - help="Enable evaluation (requires pose and ply data)", - ) - - # Output path - parser.add_argument( - "--output_path", - type=Path, - default="./eval_results_custom", - help="Output path for evaluation results", - ) - - # Model parameters - parser.add_argument( - "--ckpt_path", - type=str, - default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt", - help="Model checkpoint file path", - ) - - parser.add_argument("--merging", type=int, default=0, help="Merging parameter") - - # Processing parameters - parser.add_argument( - "--input_frame", - type=int, - default=200, - help="Maximum number of frames to process per scene", - ) - - parser.add_argument( - "--depth_conf_thresh", - type=float, - default=3.0, - help="Depth confidence threshold to filter low-confidence depth values", - ) - - # Evaluation parameters (only used when evaluation is enabled) - parser.add_argument( - "--chamfer_max_dist", - type=float, - default=0.5, - help="Maximum distance threshold used in Chamfer Distance computation", - ) - - parser.add_argument("--plot", action="store_true", help="Whether to generate plots") - - parser.add_argument( - "--vis_attn_map", - action="store_true", - help="Visualize attention maps during inference", - ) - - args = parser.parse_args() - torch.manual_seed(33) - - # Check data path exists - if not args.data_path.exists(): - print(f"❌ Error: Data path does not exist: {args.data_path}") - return - - # Check required subdirectories - color_dir = args.data_path / "images" - pose_dir = args.data_path / "pose" - - if not color_dir.exists(): - print(f"❌ Error: color directory does not exist: {color_dir}") - return - - print(f"📁 Dataset path: {args.data_path}") - # print(f"🔧 Enable evaluation: {'Yes' if args.enable_evaluation else 'No'}") - - # If evaluation is enabled, check pose and gt_ply directories - if args.enable_evaluation: - if not pose_dir.exists(): - print(f"❌ Error: Evaluation requires pose directory: {pose_dir}") - return - - gt_ply_dir = args.data_path / "gt_ply" - if not gt_ply_dir.exists(): - print(f"❌ Error: Evaluation requires gt_ply directory: {gt_ply_dir}") - return - print(f"📊 Evaluation will use Ground Truth") - else: - print(f"🏃 Inference only, no evaluation") - - # Create output directory - args.output_path.mkdir(parents=True, exist_ok=True) - output_scene_dir = args.output_path / "custom_dataset" - - # Check if already processed - if (output_scene_dir / "metrics.json").exists() and args.enable_evaluation: - print( - f"⚠️ Results already exist, skipping: {output_scene_dir / 'metrics.json'}" - ) - return - - # Force use of bf16 dtype - dtype = torch.bfloat16 - - # Load VGGT model - print(f"🔄 Loading model: {args.ckpt_path}") - model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map) - ckpt = torch.load(args.ckpt_path, map_location="cpu") - incompat = model.load_state_dict(ckpt, strict=False) - # if incompat.missing_keys or incompat.unexpected_keys: - # print(f"⚠️ Partially incompatible keys when loading model: {incompat}") - model = model.cuda().eval() - model = model.to(torch.bfloat16) - print(f"✅ Model loaded") - - # Load scene data - image_paths = get_sorted_image_paths(color_dir) - if len(image_paths) == 0: - print(f"❌ Error: No images found in {color_dir}") - return - - print(f"🖼️ Found {len(image_paths)} images") - - # Process pose data (if evaluation is enabled) - poses_gt = None - first_gt_pose = None - available_pose_frame_ids = None - c2ws = None - - if args.enable_evaluation: - poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_dir) - if ( - poses_gt is None - or first_gt_pose is None - or available_pose_frame_ids is None - ): - print(f"❌ Error: Failed to load pose data") - return - print(f"📐 Loaded {len(poses_gt)} poses") - - # Frame selection - if args.enable_evaluation and available_pose_frame_ids is not None: - # Use pose data for frame selection - selected_frame_ids, selected_image_paths, selected_pose_indices = ( - build_frame_selection( - image_paths, available_pose_frame_ids, args.input_frame - ) - ) - c2ws = poses_gt[selected_pose_indices] - image_paths = selected_image_paths - else: - # Simply take the first N frames - num_frames = min(len(image_paths), args.input_frame) - selected_frame_ids = list(range(num_frames)) - image_paths = image_paths[:num_frames] - - print(f"📋 Selected {len(image_paths)} frames for processing") - - try: - # Load images - print(f"🔄 Loading images...") - images = load_images_rgb(image_paths) - - if not images or len(images) < 3: - print(f"❌ Error: Not enough valid images (need at least 3)") - return - - frame_ids = selected_frame_ids - images_array = np.stack(images) - vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array) - print(f"📐 Image patch dimensions: {patch_width}x{patch_height}") - - # Update attention layer patch dimensions in the model - model.update_patch_dimensions(patch_width, patch_height) - - # Inference + Reconstruction - print(f"🚀 Start inference and reconstruction...") - ( - extrinsic_np, - intrinsic_np, - all_world_points, - all_point_colors, - all_cam_to_world_mat, - inference_time_ms, - ) = infer_vggt_and_reconstruct( - model, vgg_input, dtype, args.depth_conf_thresh, image_paths - ) - print(f"⏱️ Inference time: {inference_time_ms:.2f}ms") - - # Check results - if not all_cam_to_world_mat or not all_world_points: - print(f"❌ Error: Failed to obtain valid camera poses or point clouds") - return - - # print(f"✅ Inference done, obtained {len(all_world_points)} point sets") - - # Evaluation and saving - if args.enable_evaluation: - print(f"📊 Start evaluation...") - gt_ply_dir = args.data_path / "gt_ply" - metrics = evaluate_scene_and_save( - "custom_dataset", - c2ws, - first_gt_pose, - frame_ids, - all_cam_to_world_mat, - all_world_points, - output_scene_dir, - gt_ply_dir, - args.chamfer_max_dist, - inference_time_ms, - args.plot, - ) - if metrics is not None: - print("📈 Evaluation results:") - for key, value in metrics.items(): - if key in [ - "chamfer_distance", - "ate", - "are", - "rpe_rot", - "rpe_trans", - "inference_time_ms", - ]: - print(f" {key}: {float(value):.4f}") - - # Also visualize predicted poses in evaluation branch - if args.plot: - visualize_predicted_poses( - all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset" - ) - else: - # Save reconstruction only, no evaluation - print(f"💾 Saving reconstruction...") - output_scene_dir.mkdir(parents=True, exist_ok=True) - - # Save camera poses - poses_output_path = output_scene_dir / "estimated_poses.txt" - with open(poses_output_path, "w") as f: - for i, pose in enumerate(all_cam_to_world_mat): - f.write(f"# Frame {frame_ids[i]}\n") - for row in pose: - f.write(" ".join(map(str, row)) + "\n") - f.write("\n") - - # Save point cloud - if all_world_points: - points_output_path = output_scene_dir / "reconstructed_points.ply" - - # Merge all frames' point clouds and colors - try: - merged_point_cloud = np.vstack(all_world_points) - merged_colors = ( - np.vstack(all_point_colors).astype(np.uint8) - if all_point_colors is not None and len(all_point_colors) > 0 - else None - ) - print( - f"📊 Merged point clouds: {len(all_world_points)} frames, total {len(merged_point_cloud)} points" - ) - - # If too many points, randomly sample 100000 points - max_points = 100000 - if len(merged_point_cloud) > max_points: - print( - f"🔽 Too many points, randomly sampling {max_points} points..." - ) - # Randomly choose indices - indices = np.random.choice( - len(merged_point_cloud), size=max_points, replace=False - ) - merged_point_cloud = merged_point_cloud[indices] - if merged_colors is not None: - merged_colors = merged_colors[indices] - print( - f"✅ Sampling done, kept {len(merged_point_cloud)} points" - ) - - # Save as PLY (with color) - with open(points_output_path, "w") as f: - f.write("ply\n") - f.write("format ascii 1.0\n") - f.write(f"element vertex {len(merged_point_cloud)}\n") - f.write("property float x\n") - f.write("property float y\n") - f.write("property float z\n") - if merged_colors is not None: - f.write("property uchar red\n") - f.write("property uchar green\n") - f.write("property uchar blue\n") - f.write("end_header\n") - if merged_colors is None: - for point in merged_point_cloud: - if not (np.isnan(point).any() or np.isinf(point).any()): - f.write( - f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n" - ) - else: - for point, color in zip(merged_point_cloud, merged_colors): - # Check point validity - if not (np.isnan(point).any() or np.isinf(point).any()): - r = int(np.clip(color[0], 0, 255)) - g = int(np.clip(color[1], 0, 255)) - b = int(np.clip(color[2], 0, 255)) - f.write( - f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f} {r} {g} {b}\n" - ) - - print(f"💾 Point cloud saved to: {points_output_path}") - - except Exception as e: - print(f"⚠️ Error saving point cloud: {e}") - # If merge fails, try to log per-frame info - print(f"🔍 Point cloud debug info:") - for i, frame_points in enumerate(all_world_points): - print( - f" Frame {i}: {frame_points.shape if hasattr(frame_points, 'shape') else type(frame_points)}" - ) - if ( - hasattr(frame_points, "shape") - and len(frame_points.shape) >= 2 - ): - print( - f" Shape: {frame_points.shape}, Dtype: {frame_points.dtype}" - ) - if frame_points.shape[0] > 0: - print( - f" Range: x[{np.min(frame_points[:, 0]):.3f}, {np.max(frame_points[:, 0]):.3f}] " - f"y[{np.min(frame_points[:, 1]):.3f}, {np.max(frame_points[:, 1]):.3f}] " - f"z[{np.min(frame_points[:, 2]):.3f}, {np.max(frame_points[:, 2]):.3f}]" - ) - - print(f"📁 Results saved to: {output_scene_dir}") - - # Visualize predicted pose trajectory - if args.plot: - visualize_predicted_poses( - all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset" - ) - - print(f"🎉 Done!") - - except Exception as e: - print(f"❌ Error occurred during processing: {e}") - import traceback - - traceback.print_exc() - - -if __name__ == "__main__": - main() diff --git a/FastVGGT/eval/eval_scannet.py b/FastVGGT/eval/eval_scannet.py deleted file mode 100644 index 332132c45ab4fc3f7e768dbf3845d9cdb3ccc4eb..0000000000000000000000000000000000000000 --- a/FastVGGT/eval/eval_scannet.py +++ /dev/null @@ -1,208 +0,0 @@ -import argparse -from pathlib import Path -import numpy as np -import torch -import os -import sys - -# Ensure project root is in sys.path for absolute imports like `vggt.*` -ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) -if ROOT_DIR not in sys.path: - sys.path.insert(0, ROOT_DIR) - -from vggt.models.vggt import VGGT -from vggt.utils.eval_utils import ( - load_poses, - get_vgg_input_imgs, - get_sorted_image_paths, - get_all_scenes, - build_frame_selection, - load_images_rgb, - infer_vggt_and_reconstruct, - evaluate_scene_and_save, - compute_average_metrics_and_save, -) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--data_dir", type=Path, default="/data/scannetv2/process_scannet" - ) - parser.add_argument( - "--gt_ply_dir", - type=Path, - default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans", - ) - parser.add_argument("--output_path", type=Path, default="./eval_results") - parser.add_argument("--merging", type=int, default=None) - parser.add_argument("--plot", type=bool, default=True) - parser.add_argument( - "--depth_conf_thresh", - type=float, - default=3.0, - help="Depth confidence threshold for filtering low confidence depth values", - ) - parser.add_argument( - "--chamfer_max_dist", - type=float, - default=0.5, - help="Maximum distance threshold in Chamfer Distance computation, distances exceeding this value will be clipped", - ) - parser.add_argument( - "--input_frame", - type=int, - default=200, - help="Maximum number of frames selected for processing per scene", - ) - parser.add_argument( - "--num_scenes", - type=int, - default=50, - help="Maximum number of scenes to evaluate", - ) - parser.add_argument( - "--ckpt_path", - type=str, - default="./ckpt/model_tracker_fixed_e20.pt", - help="Path to the model checkpoint file", - ) - parser.add_argument( - "--vis_attn_map", - action="store_true", - help="Whether to visualize attention maps during inference", - ) - args = parser.parse_args() - torch.manual_seed(33) - - # Scene sampling - scannet_scenes = get_all_scenes(args.data_dir, args.num_scenes) - print(f"Evaluate {len(scannet_scenes)} scenes") - - all_scenes_metrics = {"scenes": {}, "average": {}} - # Force use of bf16 data type - dtype = torch.bfloat16 - # Load VGGT model - model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map) - ckpt = torch.load(args.ckpt_path, map_location="cpu") - incompat = model.load_state_dict(ckpt, strict=False) - model = model.cuda().eval() - model = model.to(torch.bfloat16) - - # Process each scene - for scene in scannet_scenes: - scene_dir = args.data_dir / f"{scene}" - output_scene_dir = args.output_path / f"input_frame_{args.input_frame}" / scene - if (output_scene_dir / "metrics.json").exists(): - continue - - # Load scene data - images_dir = scene_dir / "color" - pose_path = scene_dir / "pose" - image_paths = get_sorted_image_paths(images_dir) - poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_path) - if ( - poses_gt is None - or first_gt_pose is None - or available_pose_frame_ids is None - ): - print(f"Skipping scene {scene}: no pose data") - continue - - # Frame filtering - selected_frame_ids, selected_image_paths, selected_pose_indices = ( - build_frame_selection( - image_paths, available_pose_frame_ids, args.input_frame - ) - ) - - # Get corresponding poses - c2ws = poses_gt[selected_pose_indices] - image_paths = selected_image_paths - - if len(image_paths) == 0: - print(f"No images found in {images_dir}") - continue - - print("🚩Processing", scene, f"Found {len(image_paths)} images") - all_cam_to_world_mat = [] - all_world_points = [] - - try: - # Load images - images = load_images_rgb(image_paths) - - if not images or len(images) < 3: - print(f"Skipping {scene}: insufficient valid images") - continue - - frame_ids = selected_frame_ids - images_array = np.stack(images) - vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array) - print(f"Patch dimensions: {patch_width}x{patch_height}") - - # Update model attention layers with dynamic patch dimensions - model.update_patch_dimensions(patch_width, patch_height) - - # Inference + Reconstruction - ( - extrinsic_np, - intrinsic_np, - all_world_points, - all_point_colors, - all_cam_to_world_mat, - inference_time_ms, - ) = infer_vggt_and_reconstruct( - model, vgg_input, dtype, args.depth_conf_thresh, image_paths - ) - print(f"Inference time: {inference_time_ms:.2f}ms") - - # Process results - if not all_cam_to_world_mat or not all_world_points: - print( - f"Skipping {scene}: failed to obtain valid camera poses or point clouds" - ) - continue - - # Evaluate and save - metrics = evaluate_scene_and_save( - scene, - c2ws, - first_gt_pose, - frame_ids, - all_cam_to_world_mat, - all_world_points, - output_scene_dir, - args.gt_ply_dir, - args.chamfer_max_dist, - inference_time_ms, - args.plot, - ) - if metrics is not None: - all_scenes_metrics["scenes"][scene] = { - key: float(value) - for key, value in metrics.items() - if key - in [ - "chamfer_distance", - "ate", - "are", - "rpe_rot", - "rpe_trans", - "inference_time_ms", - ] - } - print("Complete metrics", all_scenes_metrics["scenes"][scene]) - - except Exception as e: - print(f"Error processing scene {scene}: {e}") - import traceback - - traceback.print_exc() - - # Summarize average metrics and save - compute_average_metrics_and_save( - all_scenes_metrics, - args.output_path, - args.input_frame, - ) diff --git a/FastVGGT/eval/utils.py b/FastVGGT/eval/utils.py deleted file mode 100644 index e8d5606560e1fc82a7b7b81df8b3b9f3b9ec8662..0000000000000000000000000000000000000000 --- a/FastVGGT/eval/utils.py +++ /dev/null @@ -1,142 +0,0 @@ -import numpy as np -from scipy.spatial import cKDTree as KDTree - - -def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): - """ - Args: - - depthmap (HxW array): - - camera_intrinsics: a 3x3 matrix - Returns: - pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. - """ - camera_intrinsics = np.float32(camera_intrinsics) - H, W = depthmap.shape - - assert camera_intrinsics[0, 1] == 0.0 - assert camera_intrinsics[1, 0] == 0.0 - if pseudo_focal is None: - fu = camera_intrinsics[0, 0] - fv = camera_intrinsics[1, 1] - else: - assert pseudo_focal.shape == (H, W) - fu = fv = pseudo_focal - cu = camera_intrinsics[0, 2] - cv = camera_intrinsics[1, 2] - - u, v = np.meshgrid(np.arange(W), np.arange(H)) - z_cam = depthmap - x_cam = (u - cu) * z_cam / fu - y_cam = (v - cv) * z_cam / fv - X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) - - valid_mask = depthmap > 0.0 - return X_cam, valid_mask - - -def depthmap_to_absolute_camera_coordinates( - depthmap, camera_intrinsics, camera_pose, **kw -): - """ - Args: - - depthmap (HxW array): - - camera_intrinsics: a 3x3 matrix - - camera_pose: a 4x3 or 4x4 cam2world matrix - Returns: - pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. - """ - X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics) - - X_world = X_cam # default - if camera_pose is not None: - - R_cam2world = camera_pose[:3, :3] - t_cam2world = camera_pose[:3, 3] - - X_world = ( - np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] - ) - - return X_world, valid_mask - - -def completion_ratio(gt_points, rec_points, dist_th=0.05): - gen_points_kd_tree = KDTree(rec_points) - distances, _ = gen_points_kd_tree.query(gt_points) - comp_ratio = np.mean((distances < dist_th).astype(np.float32)) - return comp_ratio - - -def accuracy(gt_points, rec_points, gt_normals=None, rec_normals=None): - gt_points_kd_tree = KDTree(gt_points) - distances, idx = gt_points_kd_tree.query(rec_points, workers=-1) - acc = np.mean(distances) - - acc_median = np.median(distances) - - if gt_normals is not None and rec_normals is not None: - normal_dot = np.sum(gt_normals[idx] * rec_normals, axis=-1) - normal_dot = np.abs(normal_dot) - - return acc, acc_median, np.mean(normal_dot), np.median(normal_dot) - - return acc, acc_median - - -def completion(gt_points, rec_points, gt_normals=None, rec_normals=None): - gt_points_kd_tree = KDTree(rec_points) - distances, idx = gt_points_kd_tree.query(gt_points, workers=-1) - comp = np.mean(distances) - comp_median = np.median(distances) - - if gt_normals is not None and rec_normals is not None: - normal_dot = np.sum(gt_normals * rec_normals[idx], axis=-1) - normal_dot = np.abs(normal_dot) - - return comp, comp_median, np.mean(normal_dot), np.median(normal_dot) - - return comp, comp_median - - -def compute_iou(pred_vox, target_vox): - # Get voxel indices - v_pred_indices = [voxel.grid_index for voxel in pred_vox.get_voxels()] - v_target_indices = [voxel.grid_index for voxel in target_vox.get_voxels()] - - # Convert to sets for set operations - v_pred_filled = set(tuple(np.round(x, 4)) for x in v_pred_indices) - v_target_filled = set(tuple(np.round(x, 4)) for x in v_target_indices) - - # Compute intersection and union - intersection = v_pred_filled & v_target_filled - union = v_pred_filled | v_target_filled - - # Compute IoU - iou = len(intersection) / len(union) - return iou - - -def colmap_to_opencv_intrinsics(K): - """ - Modify camera intrinsics to follow a different convention. - Coordinates of the center of the top-left pixels are by default: - - (0.5, 0.5) in Colmap - - (0,0) in OpenCV - """ - K = K.copy() - K[0, 2] -= 0.5 - K[1, 2] -= 0.5 - return K - - -def opencv_to_colmap_intrinsics(K): - """ - Modify camera intrinsics to follow a different convention. - Coordinates of the center of the top-left pixels are by default: - - (0.5, 0.5) in Colmap - - (0,0) in OpenCV - """ - K = K.copy() - K[0, 2] += 0.5 - K[1, 2] += 0.5 - return K diff --git a/FastVGGT/merging/__init__.py b/FastVGGT/merging/__init__.py deleted file mode 100644 index 156ac80ee0ef4270bc02b0954f8473ff55380cce..0000000000000000000000000000000000000000 --- a/FastVGGT/merging/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . import merge - -__all__ = ["merge"] diff --git a/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc b/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 66a4e39454e637a8e8e0fa63f6708b1eb670aae9..0000000000000000000000000000000000000000 Binary files a/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/merging/__pycache__/merge.cpython-310.pyc b/FastVGGT/merging/__pycache__/merge.cpython-310.pyc deleted file mode 100644 index 9fa7eff7fec3ec548a7532b57636ea9d7db57e16..0000000000000000000000000000000000000000 Binary files a/FastVGGT/merging/__pycache__/merge.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/merging/merge.py b/FastVGGT/merging/merge.py deleted file mode 100644 index e2094b657cba6c9772a85f2ebe0240367fcec09c..0000000000000000000000000000000000000000 --- a/FastVGGT/merging/merge.py +++ /dev/null @@ -1,370 +0,0 @@ -import torch -from typing import Tuple, Callable, Optional, Union - - -@torch.jit.script -def fast_similarity_chunks( - a: torch.Tensor, b_transposed: torch.Tensor, chunk_size: int -) -> Tuple[torch.Tensor, torch.Tensor]: - - B, num_src, C = a.shape - original_dtype = a.dtype - - # Convert to bf16 for computation to improve performance and reduce memory usage - a_bf16 = a.to(torch.bfloat16) - b_transposed_bf16 = b_transposed.to(torch.bfloat16) - node_max = torch.empty(B, num_src, device=a.device, dtype=original_dtype) - node_idx = torch.empty(B, num_src, device=a.device, dtype=torch.long) - - # Process in chunks - for i in range(0, num_src, chunk_size): - end_i = min(i + chunk_size, num_src) - a_chunk = a_bf16[:, i:end_i, :] # [B, chunk_size, C] - scores_chunk = torch.bmm(a_chunk, b_transposed_bf16) - chunk_max_bf16, chunk_idx = torch.max(scores_chunk, dim=2) - chunk_max = chunk_max_bf16.to(original_dtype) - node_max[:, i:end_i] = chunk_max - node_idx[:, i:end_i] = chunk_idx - return node_max, node_idx - - -def do_nothing( - x: torch.Tensor, - extra_tensors=None, - extra_tensors_2=None, -) -> Union[ - torch.Tensor, - Tuple[torch.Tensor, torch.Tensor], - Tuple[torch.Tensor, torch.Tensor, torch.Tensor], -]: - if extra_tensors is not None and extra_tensors_2 is not None: - return x, extra_tensors, extra_tensors_2 - elif extra_tensors is not None: - return x, extra_tensors - else: - return x - - -def token_merge_bipartite2d( - metric: torch.Tensor, - w: int, - h: int, - sx: int, - sy: int, - r: int, - no_rand: bool = False, - generator: Optional[torch.Generator] = None, - enable_protection: bool = False, -) -> Tuple[Callable, Callable]: - """ - Divide tokens into source (src) and destination (dst) groups, and merge r tokens from src to dst. - dst tokens are selected by randomly choosing one token from each (sx, sy) region. - Optionally protect the top 10% of tokens from merging based on importance scores. - - Args: - - metric [B, N, C]: Tensor for similarity computation, B=batch size, N=token count, C=feature dimension - - w: Image width in tokens - - h: Image height in tokens - - sx: dst stride in x dimension, must divide w evenly - - sy: dst stride in y dimension, must divide h evenly - - r: Number of tokens to remove through merging - - no_rand: If True, disable randomness (use only top-left token) - - generator: Random number generator if no_rand is False and not None - - enable_protection: If True, enable importance protection feature - - Returns: - - (merge, unmerge): Two functions for merging tokens and restoring pre-merge state - """ - B, N, _ = metric.shape # Batch size B, total tokens N - if r <= 0: - return do_nothing, do_nothing - - gather = torch.gather - - tokens_per_img = w * h + 5 - num_imgs = N // tokens_per_img - assert tokens_per_img * num_imgs == N, "Token count doesn't match (w*h+5)*num_imgs" - - with torch.no_grad(): - # Determine whether to compute importance scores based on enable_protection - if enable_protection: - num_protected = int(N * 0.1) - step = max(1, N // num_protected) - protected_indices = torch.arange(0, N, step, device=metric.device)[ - :num_protected - ] - else: - protected_indices = None - num_protected = 0 - - # Global idx_buffer_seq of length N; -1 indicates dst, 0 indicates src (maintain original logic) - idx_buffer_seq = torch.zeros(N, device=metric.device, dtype=torch.int64) - hsy, wsx = h // sy, w // sx # Number of blocks within each image - - # Mark first image entirely as dst - if num_imgs > 0: - idx_buffer_seq[:tokens_per_img] = -1 - - # Process other images - fully vectorized batch operations - if num_imgs > 1: - cls_indices = ( - torch.arange(1, num_imgs, device=metric.device) * tokens_per_img - ) - cls_indices = cls_indices[:, None] + torch.arange(5, device=metric.device) - idx_buffer_seq[cls_indices.flatten()] = -1 - effective_h = min(hsy * sy, h) - effective_w = min(wsx * sx, w) - effective_grid_size = effective_h * effective_w - - if no_rand: - base_pattern = torch.zeros( - effective_grid_size, device=metric.device, dtype=torch.int64 - ) - grid_starts = ( - torch.arange(1, num_imgs, device=metric.device) * tokens_per_img + 5 - ) - grid_indices = grid_starts[:, None] + torch.arange( - effective_grid_size, device=metric.device - ) - idx_buffer_seq[grid_indices.flatten()] = base_pattern.repeat( - num_imgs - 1 - ) - else: - total_other_imgs = num_imgs - 1 - all_rand_idx = torch.randint( - sy * sx, - size=(total_other_imgs, hsy, wsx), - device=metric.device, - generator=generator, - ) - - scatter_src = -torch.ones( - total_other_imgs, hsy, wsx, device=metric.device, dtype=torch.int64 - ) - - idx_buffer_batch = torch.zeros( - total_other_imgs, - hsy, - wsx, - sy * sx, - device=metric.device, - dtype=torch.int64, - ) - idx_buffer_batch.scatter_( - dim=3, - index=all_rand_idx.unsqueeze(-1), - src=scatter_src.unsqueeze(-1), - ) - - idx_buffer_batch = ( - idx_buffer_batch.view(total_other_imgs, hsy, wsx, sy, sx) - .transpose(2, 3) - .reshape(total_other_imgs, hsy * sy, wsx * sx) - ) - - # Batch fill to target positions - still needs a small loop here, but operations are greatly reduced - for i in range(total_other_imgs): - img_idx = i + 1 - grid_start = img_idx * tokens_per_img + 5 - flat_view = idx_buffer_batch[ - i, :effective_h, :effective_w - ].flatten() - idx_buffer_seq[grid_start : grid_start + effective_grid_size] = ( - flat_view - ) - - rand_idx = idx_buffer_seq.reshape(1, -1, 1).argsort(dim=1) - num_dst_orig = int((idx_buffer_seq == -1).sum()) - - # Original src and dst indices - a_idx_orig = rand_idx[:, num_dst_orig:, :] - b_idx_orig = rand_idx[:, :num_dst_orig, :] - a_idx = a_idx_orig - b_idx = b_idx_orig - - if enable_protection: - protected_idx = protected_indices.unsqueeze(0).unsqueeze(-1) - num_protected_actual = protected_idx.shape[1] - else: - protected_idx = None - num_protected_actual = 0 - - num_src = a_idx.shape[1] - num_dst = b_idx.shape[1] - - # Define an internal function to separate src, dst, and protected tokens - def split(x): - C = x.shape[-1] - - if enable_protection: - src = gather(x, dim=1, index=a_idx.expand(B, num_src, C)) - dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C)) - protected = gather( - x, dim=1, index=protected_idx.expand(B, num_protected_actual, C) - ) - return src, dst, protected - else: - src = gather(x, dim=1, index=a_idx.expand(B, num_src, C)) - dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C)) - return src, dst - - # Compute cosine similarity (normalize first then dot product) - metric = metric / metric.norm(dim=-1, keepdim=True) - if enable_protection: - a, b, protected = split(metric) - else: - a, b = split(metric) - - r = min(a.shape[1], r) - num_src_actual = a.shape[1] - chunk_size = min(5000, num_src_actual) - - node_max = torch.empty(B, num_src_actual, device=a.device, dtype=a.dtype) - node_idx = torch.empty(B, num_src_actual, device=a.device, dtype=torch.long) - - b_transposed = b.transpose(-1, -2) - node_max, node_idx = fast_similarity_chunks(a, b_transposed, chunk_size) - edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] - - # If protection is enabled, filter out protected tokens to ensure they are not merged - if enable_protection: - src_indices = a_idx[0, :, 0] - protected_mask_src = torch.isin(src_indices, protected_indices) - edge_flat = edge_idx[0, :, 0] - valid_mask = ~protected_mask_src[edge_flat] - valid_edges = edge_flat[valid_mask] - - valid_count = valid_edges.shape[0] - r_actual = min(r, valid_count) - - unm_idx = valid_edges[r_actual:].unsqueeze(0).unsqueeze(-1) - src_idx = valid_edges[:r_actual].unsqueeze(0).unsqueeze(-1) - else: - unm_idx = edge_idx[..., r:, :] - src_idx = edge_idx[..., :r, :] - r_actual = r - - # Get dst token indices corresponding to each src token to be merged - dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) - r = r_actual - - # Define merge function to merge selected src tokens to corresponding dst tokens - def merge( - x: torch.Tensor, - mode: str = "mean", - extra_tensors=None, - extra_tensors_2=None, - ) -> Union[ - torch.Tensor, - Tuple[torch.Tensor, torch.Tensor], - Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - ]: - if enable_protection: - src, dst, protected = split(x) - else: - src, dst = split(x) - - n, t1, c = src.shape - - # Extract unmerged src tokens - using actual unm_idx size - unm_len = unm_idx.shape[1] - unm = gather(src, dim=-2, index=unm_idx.expand(n, unm_len, c)) - src_len = src_idx.shape[1] - src = gather(src, dim=-2, index=src_idx.expand(n, src_len, c)) - dst = dst.scatter_reduce(-2, dst_idx.expand(n, src_len, c), src, reduce=mode) - - # ---------------- Extra tensor processing ---------------- - merged_extra_1 = None - merged_extra_2 = None - if extra_tensors is not None: - E_dim = extra_tensors.shape[-1] - if enable_protection: - src_e, dst_e, protected_e = split(extra_tensors) - else: - src_e, dst_e = split(extra_tensors) - - # Consistent with main tensor, only select r src tokens to be merged - src_e_r = gather(src_e, dim=-2, index=src_idx.expand(n, src_len, E_dim)) - unm_e = gather(src_e, dim=-2, index=unm_idx.expand(n, unm_len, E_dim)) - - dst_e = dst_e.scatter_reduce( - -2, dst_idx.expand(n, src_len, E_dim), src_e_r, reduce=mode - ) - if enable_protection: - merged_extra_1 = torch.cat([unm_e, dst_e, protected_e], dim=1) - else: - merged_extra_1 = torch.cat([unm_e, dst_e], dim=1) - - if extra_tensors_2 is not None: - E_dim_2 = extra_tensors_2.shape[-1] - if enable_protection: - src_e2, dst_e2, protected_e2 = split(extra_tensors_2) - else: - src_e2, dst_e2 = split(extra_tensors_2) - - src_e2_r = gather(src_e2, dim=-2, index=src_idx.expand(n, src_len, E_dim_2)) - unm_e2 = gather(src_e2, dim=-2, index=unm_idx.expand(n, unm_len, E_dim_2)) - - dst_e2 = dst_e2.scatter_reduce( - -2, dst_idx.expand(n, src_len, E_dim_2), src_e2_r, reduce=mode - ) - if enable_protection: - merged_extra_2 = torch.cat([unm_e2, dst_e2, protected_e2], dim=1) - else: - merged_extra_2 = torch.cat([unm_e2, dst_e2], dim=1) - - if enable_protection: - main_result = torch.cat([unm, dst, protected], dim=1) - else: - main_result = torch.cat([unm, dst], dim=1) - - if merged_extra_1 is not None and merged_extra_2 is not None: - return main_result, merged_extra_1, merged_extra_2 - elif merged_extra_1 is not None: - return main_result, merged_extra_1 - else: - return main_result - - # Define unmerge function to restore pre-merge state (for decoder) - def unmerge(x: torch.Tensor) -> torch.Tensor: - unm_len = unm_idx.shape[1] - dst_len = num_dst - src_len = src_idx.shape[1] - unm = x[..., :unm_len, :] - dst = x[..., unm_len : unm_len + dst_len, :] - - if enable_protection: - protected = x[ - ..., unm_len + dst_len : unm_len + dst_len + num_protected_actual, : - ] - - _, _, c = unm.shape - src = gather(dst, dim=-2, index=dst_idx.expand(B, src_len, c)) - out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) - out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) - out.scatter_( - dim=-2, - index=gather( - a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx - ).expand(B, unm_len, c), - src=unm, - ) - - out.scatter_( - dim=-2, - index=gather( - a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx - ).expand(B, src_len, c), - src=src, - ) - - if enable_protection: - out.scatter_( - dim=-2, - index=protected_idx.expand(B, num_protected_actual, c), - src=protected, - ) - - return out - - return merge, unmerge diff --git a/FastVGGT/requirements.txt b/FastVGGT/requirements.txt deleted file mode 100644 index e343637f2d9a1550f55f14f8890c229bd467cdb2..0000000000000000000000000000000000000000 --- a/FastVGGT/requirements.txt +++ /dev/null @@ -1,15 +0,0 @@ -torch==2.3.1 -torchvision==0.18.1 -numpy==1.26.1 -Pillow -huggingface_hub -einops -safetensors -evo -open3d -matplotlib -scipy -opencv-python -scikit-image -tqdm - diff --git a/FastVGGT/vggt/__init__.py b/FastVGGT/vggt/__init__.py deleted file mode 100644 index 3a2958f93a114495631a3ce270b99ee8ba8443f1..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index c5f3bfe676b899c95cba2c034de5292d685e03ce..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/dependency/__init__.py b/FastVGGT/vggt/dependency/__init__.py deleted file mode 100644 index 3a2958f93a114495631a3ce270b99ee8ba8443f1..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/dependency/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index cd09b8397b7a4717d49afb51e64d490fa11fc205..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc b/FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc deleted file mode 100644 index 389177102f318fe4bcfd46f6ab4a04e1a8f39196..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/dependency/distortion.py b/FastVGGT/vggt/dependency/distortion.py deleted file mode 100644 index 375b747086478050d601676ffaea3b25e80690b4..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/dependency/distortion.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np -import torch - - -def apply_distortion(points, distortion_params): - """ - Apply distortion to normalized camera coordinates. - - Args: - points: Array of normalized camera coordinates - distortion_params: Distortion parameters - - Returns: - Distorted coordinates - """ - # Simple passthrough for now - implement actual distortion if needed - return points - - -def iterative_undistortion(points, distortion_params, max_iter=10): - """ - Remove distortion from normalized camera coordinates using iterative method. - - Args: - points: Array of distorted normalized camera coordinates - distortion_params: Distortion parameters - max_iter: Maximum number of iterations - - Returns: - Undistorted coordinates - """ - # Simple passthrough for now - implement actual undistortion if needed - return points - - -def single_undistortion(points, distortion_params): - """ - Remove distortion from normalized camera coordinates using single step. - - Args: - points: Array of distorted normalized camera coordinates - distortion_params: Distortion parameters - - Returns: - Undistorted coordinates - """ - # Simple passthrough for now - implement actual undistortion if needed - return points \ No newline at end of file diff --git a/FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc deleted file mode 100644 index 292dc30328d06ba3951a9eee0e99bc569c60a1d3..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc deleted file mode 100644 index 838db11a0862e9907682572d8ebe3b7c1665e147..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc deleted file mode 100644 index 6d6cd93db3a5987c35085d487d21f5513ea3d1fd..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc deleted file mode 100644 index a3244c2b0364f9b533aab03446edb787dbed59d5..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index fc2e8df264f6accc5412e80b8580c8c48195aaf5..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/heads/camera_head.py b/FastVGGT/vggt/heads/camera_head.py deleted file mode 100644 index 215adf39de23abd4975479d332250fcc3e2b54b9..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/heads/camera_head.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import math -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from vggt.layers import Mlp -from vggt.layers.block import Block -from vggt.heads.head_act import activate_pose - - -class CameraHead(nn.Module): - """ - CameraHead predicts camera parameters from token representations using iterative refinement. - - It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. - """ - - def __init__( - self, - dim_in: int = 2048, - trunk_depth: int = 4, - pose_encoding_type: str = "absT_quaR_FoV", - num_heads: int = 16, - mlp_ratio: int = 4, - init_values: float = 0.01, - trans_act: str = "linear", - quat_act: str = "linear", - fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. - ): - super().__init__() - - if pose_encoding_type == "absT_quaR_FoV": - self.target_dim = 9 - else: - raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") - - self.trans_act = trans_act - self.quat_act = quat_act - self.fl_act = fl_act - self.trunk_depth = trunk_depth - - # Build the trunk using a sequence of transformer blocks. - self.trunk = nn.Sequential( - *[ - Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values) - for _ in range(trunk_depth) - ] - ) - - # Normalizations for camera token and trunk output. - self.token_norm = nn.LayerNorm(dim_in) - self.trunk_norm = nn.LayerNorm(dim_in) - - # Learnable empty camera pose token. - self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) - self.embed_pose = nn.Linear(self.target_dim, dim_in) - - # Module for producing modulation parameters: shift, scale, and a gate. - self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) - - # Adaptive layer normalization without affine parameters. - self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) - self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0) - - def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list: - """ - Forward pass to predict camera parameters. - - Args: - aggregated_tokens_list (list): List of token tensors from the network; - the last tensor is used for prediction. - num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. - - Returns: - list: A list of predicted camera encodings (post-activation) from each iteration. - """ - # Use tokens from the last block for camera prediction. - tokens = aggregated_tokens_list[-1] - - # Extract the camera tokens - pose_tokens = tokens[:, :, 0] - pose_tokens = self.token_norm(pose_tokens) - - pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) - return pred_pose_enc_list - - def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: - """ - Iteratively refine camera pose predictions. - - Args: - pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C]. - num_iterations (int): Number of refinement iterations. - - Returns: - list: List of activated camera encodings from each iteration. - """ - B, S, C = pose_tokens.shape # S is expected to be 1. - pred_pose_enc = None - pred_pose_enc_list = [] - - for _ in range(num_iterations): - # Use a learned empty pose for the first iteration. - if pred_pose_enc is None: - module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) - else: - # Detach the previous prediction to avoid backprop through time. - pred_pose_enc = pred_pose_enc.detach() - module_input = self.embed_pose(pred_pose_enc) - - # Generate modulation parameters and split them into shift, scale, and gate components. - shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) - - # Adaptive layer normalization and modulation. - pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) - pose_tokens_modulated = pose_tokens_modulated + pose_tokens - - pose_tokens_modulated = self.trunk(pose_tokens_modulated) - # Compute the delta update for the pose encoding. - pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) - - if pred_pose_enc is None: - pred_pose_enc = pred_pose_enc_delta - else: - pred_pose_enc = pred_pose_enc + pred_pose_enc_delta - - # Apply final activation functions for translation, quaternion, and field-of-view. - activated_pose = activate_pose( - pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act - ) - pred_pose_enc_list.append(activated_pose) - - return pred_pose_enc_list - - -def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: - """ - Modulate the input tensor using scaling and shifting parameters. - """ - # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 - return x * (1 + scale) + shift diff --git a/FastVGGT/vggt/heads/dpt_head.py b/FastVGGT/vggt/heads/dpt_head.py deleted file mode 100644 index e20d65ef5bfeb23cf83ca748aedff738840b4ffd..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/heads/dpt_head.py +++ /dev/null @@ -1,598 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 - - -import os -from typing import List, Dict, Tuple, Union, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from .head_act import activate_head -from .utils import create_uv_grid, position_grid_to_embed - - -class DPTHead(nn.Module): - """ - DPT Head for dense prediction tasks. - - This implementation follows the architecture described in "Vision Transformers for Dense Prediction" - (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer - backbone and produces dense predictions by fusing multi-scale features. - - Args: - dim_in (int): Input dimension (channels). - patch_size (int, optional): Patch size. Default is 14. - output_dim (int, optional): Number of output channels. Default is 4. - activation (str, optional): Activation type. Default is "inv_log". - conf_activation (str, optional): Confidence activation type. Default is "expp1". - features (int, optional): Feature channels for intermediate representations. Default is 256. - out_channels (List[int], optional): Output channels for each intermediate layer. - intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. - pos_embed (bool, optional): Whether to use positional embedding. Default is True. - feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. - down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. - """ - - def __init__( - self, - dim_in: int, - patch_size: int = 14, - output_dim: int = 4, - activation: str = "inv_log", - conf_activation: str = "expp1", - features: int = 256, - out_channels: List[int] = [256, 512, 1024, 1024], - intermediate_layer_idx: List[int] = [0, 1, 2, 3], - pos_embed: bool = True, - feature_only: bool = False, - down_ratio: int = 1, - ) -> None: - super(DPTHead, self).__init__() - self.patch_size = patch_size - self.activation = activation - self.conf_activation = conf_activation - self.pos_embed = pos_embed - self.feature_only = feature_only - self.down_ratio = down_ratio - self.intermediate_layer_idx = intermediate_layer_idx - - self.norm = nn.LayerNorm(dim_in) - - # Projection layers for each output channel from tokens. - self.projects = nn.ModuleList( - [ - nn.Conv2d( - in_channels=dim_in, - out_channels=oc, - kernel_size=1, - stride=1, - padding=0, - ) - for oc in out_channels - ] - ) - - # Resize layers for upsampling feature maps. - self.resize_layers = nn.ModuleList( - [ - nn.ConvTranspose2d( - in_channels=out_channels[0], - out_channels=out_channels[0], - kernel_size=4, - stride=4, - padding=0, - ), - nn.ConvTranspose2d( - in_channels=out_channels[1], - out_channels=out_channels[1], - kernel_size=2, - stride=2, - padding=0, - ), - nn.Identity(), - nn.Conv2d( - in_channels=out_channels[3], - out_channels=out_channels[3], - kernel_size=3, - stride=2, - padding=1, - ), - ] - ) - - self.scratch = _make_scratch(out_channels, features, expand=False) - - # Attach additional modules to scratch. - self.scratch.stem_transpose = nn.Identity() # Use Identity instead of None - self.scratch.refinenet1 = _make_fusion_block(features) - self.scratch.refinenet2 = _make_fusion_block(features) - self.scratch.refinenet3 = _make_fusion_block(features) - self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) - - head_features_1 = features - head_features_2 = 32 - - if feature_only: - self.scratch.output_conv1 = nn.Conv2d( - head_features_1, head_features_1, kernel_size=3, stride=1, padding=1 - ) - else: - self.scratch.output_conv1 = nn.Conv2d( - head_features_1, - head_features_1 // 2, - kernel_size=3, - stride=1, - padding=1, - ) - conv2_in_channels = head_features_1 // 2 - - self.scratch.output_conv2 = nn.Sequential( - nn.Conv2d( - conv2_in_channels, - head_features_2, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.Conv2d( - head_features_2, output_dim, kernel_size=1, stride=1, padding=0 - ), - ) - - def forward( - self, - aggregated_tokens_list: List[torch.Tensor], - images: torch.Tensor, - patch_start_idx: int, - frames_chunk_size: int = 8, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Forward pass through the DPT head, supports processing by chunking frames. - Args: - aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. - images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. - patch_start_idx (int): Starting index for patch tokens in the token sequence. - Used to separate patch tokens from other tokens (e.g., camera or register tokens). - frames_chunk_size (int, optional): Number of frames to process in each chunk. - If None or larger than S, all frames are processed at once. Default: 8. - - Returns: - Tensor or Tuple[Tensor, Tensor]: - - If feature_only=True: Feature maps with shape [B, S, C, H, W] - - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] - """ - B, S, _, H, W = images.shape - - # If frames_chunk_size is not specified or greater than S, process all frames at once - if frames_chunk_size is None or frames_chunk_size >= S: - return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) - - # Otherwise, process frames in chunks to manage memory usage - assert frames_chunk_size > 0 - - # Process frames in batches - all_preds = [] - all_conf = [] - - for frames_start_idx in range(0, S, frames_chunk_size): - frames_end_idx = min(frames_start_idx + frames_chunk_size, S) - - # Process batch of frames - if self.feature_only: - chunk_output = self._forward_impl( - aggregated_tokens_list, - images, - patch_start_idx, - frames_start_idx, - frames_end_idx, - ) - all_preds.append(chunk_output) - else: - chunk_preds, chunk_conf = self._forward_impl( - aggregated_tokens_list, - images, - patch_start_idx, - frames_start_idx, - frames_end_idx, - ) - all_preds.append(chunk_preds) - all_conf.append(chunk_conf) - - # Concatenate results along the sequence dimension - if self.feature_only: - return torch.cat(all_preds, dim=1) - else: - return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) - - def _forward_impl( - self, - aggregated_tokens_list: List[torch.Tensor], - images: torch.Tensor, - patch_start_idx: int, - frames_start_idx: Optional[int] = None, - frames_end_idx: Optional[int] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Implementation of the forward pass through the DPT head. - - This method processes a specific chunk of frames from the sequence. - - Args: - aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. - images (Tensor): Input images with shape [B, S, 3, H, W]. - patch_start_idx (int): Starting index for patch tokens. - frames_start_idx (int, optional): Starting index for frames to process. - frames_end_idx (int, optional): Ending index for frames to process. - - Returns: - Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). - """ - if frames_start_idx is not None and frames_end_idx is not None: - images = images[:, frames_start_idx:frames_end_idx].contiguous() - - B, S, _, H, W = images.shape - - patch_h, patch_w = H // self.patch_size, W // self.patch_size - - out = [] - dpt_idx = 0 - - for layer_idx in self.intermediate_layer_idx: - x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] - - # Select frames if processing a chunk - if frames_start_idx is not None and frames_end_idx is not None: - x = x[:, frames_start_idx:frames_end_idx] - - x = x.reshape(B * S, -1, x.shape[-1]) - - x = self.norm(x) - - x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) - - x = self.projects[dpt_idx](x) - if self.pos_embed: - x = self._apply_pos_embed(x, W, H) - x = self.resize_layers[dpt_idx](x) - - out.append(x) - dpt_idx += 1 - - # Fuse features from multiple layers. - out = self.scratch_forward(out) - # Interpolate fused output to match target image resolution. - out = custom_interpolate( - out, - ( - int(patch_h * self.patch_size / self.down_ratio), - int(patch_w * self.patch_size / self.down_ratio), - ), - mode="bilinear", - align_corners=True, - ) - - if self.pos_embed: - out = self._apply_pos_embed(out, W, H) - - if self.feature_only: - return out.view(B, S, *out.shape[1:]) - - out = self.scratch.output_conv2(out) - preds, conf = activate_head( - out, activation=self.activation, conf_activation=self.conf_activation - ) - - preds = preds.view(B, S, *preds.shape[1:]) - conf = conf.view(B, S, *conf.shape[1:]) - return preds, conf - - def _apply_pos_embed( - self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1 - ) -> torch.Tensor: - """ - Apply positional embedding to tensor x. - """ - patch_w = x.shape[-1] - patch_h = x.shape[-2] - pos_embed = create_uv_grid( - patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device - ) - pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) - pos_embed = pos_embed * ratio - pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) - return x + pos_embed - - def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: - """ - Forward pass through the fusion blocks. - - Args: - features (List[Tensor]): List of feature maps from different layers. - - Returns: - Tensor: Fused feature map. - """ - layer_1, layer_2, layer_3, layer_4 = features - - layer_1_rn = self.scratch.layer1_rn(layer_1) - layer_2_rn = self.scratch.layer2_rn(layer_2) - layer_3_rn = self.scratch.layer3_rn(layer_3) - layer_4_rn = self.scratch.layer4_rn(layer_4) - - out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) - del layer_4_rn, layer_4 - - out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) - del layer_3_rn, layer_3 - - out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) - del layer_2_rn, layer_2 - - out = self.scratch.refinenet1(out, layer_1_rn) - del layer_1_rn, layer_1 - - out = self.scratch.output_conv1(out) - return out - - -################################################################################ -# Modules -################################################################################ - - -def _make_fusion_block( - features: int, - size: Optional[int] = None, - has_residual: bool = True, - groups: int = 1, -) -> nn.Module: - return FeatureFusionBlock( - features, - nn.ReLU(inplace=True), - deconv=False, - bn=False, - expand=False, - align_corners=True, - size=size, - has_residual=has_residual, - groups=groups, - ) - - -def _make_scratch( - in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False -) -> nn.Module: - scratch = nn.Module() - out_shape1 = out_shape - out_shape2 = out_shape - out_shape3 = out_shape - if len(in_shape) >= 4: - out_shape4 = out_shape - - if expand: - out_shape1 = out_shape - out_shape2 = out_shape * 2 - out_shape3 = out_shape * 4 - if len(in_shape) >= 4: - out_shape4 = out_shape * 8 - - scratch.layer1_rn = nn.Conv2d( - in_shape[0], - out_shape1, - kernel_size=3, - stride=1, - padding=1, - bias=False, - groups=groups, - ) - scratch.layer2_rn = nn.Conv2d( - in_shape[1], - out_shape2, - kernel_size=3, - stride=1, - padding=1, - bias=False, - groups=groups, - ) - scratch.layer3_rn = nn.Conv2d( - in_shape[2], - out_shape3, - kernel_size=3, - stride=1, - padding=1, - bias=False, - groups=groups, - ) - if len(in_shape) >= 4: - scratch.layer4_rn = nn.Conv2d( - in_shape[3], - out_shape4, - kernel_size=3, - stride=1, - padding=1, - bias=False, - groups=groups, - ) - return scratch - - -class ResidualConvUnit(nn.Module): - """Residual convolution module.""" - - def __init__(self, features, activation, bn, groups=1): - """Init. - - Args: - features (int): number of features - """ - super().__init__() - - self.bn = bn - self.groups = groups - self.conv1 = nn.Conv2d( - features, - features, - kernel_size=3, - stride=1, - padding=1, - bias=True, - groups=self.groups, - ) - self.conv2 = nn.Conv2d( - features, - features, - kernel_size=3, - stride=1, - padding=1, - bias=True, - groups=self.groups, - ) - - self.norm1 = None - self.norm2 = None - - self.activation = activation - - def forward(self, x): - """Forward pass. - - Args: - x (tensor): input - - Returns: - tensor: output - """ - - out = self.activation(x) - out = self.conv1(out) - if self.norm1 is not None: - out = self.norm1(out) - - out = self.activation(out) - out = self.conv2(out) - if self.norm2 is not None: - out = self.norm2(out) - - return out + x - - -class FeatureFusionBlock(nn.Module): - """Feature fusion block.""" - - def __init__( - self, - features, - activation, - deconv=False, - bn=False, - expand=False, - align_corners=True, - size=None, - has_residual=True, - groups=1, - ): - """Init. - - Args: - features (int): number of features - """ - super(FeatureFusionBlock, self).__init__() - - self.deconv = deconv - self.align_corners = align_corners - self.groups = groups - self.expand = expand - out_features = features - if self.expand == True: - out_features = features // 2 - - self.out_conv = nn.Conv2d( - features, - out_features, - kernel_size=1, - stride=1, - padding=0, - bias=True, - groups=self.groups, - ) - - if has_residual: - self.resConfUnit1 = ResidualConvUnit( - features, activation, bn, groups=self.groups - ) - - self.has_residual = has_residual - self.resConfUnit2 = ResidualConvUnit( - features, activation, bn, groups=self.groups - ) - - self.size = size - - def forward(self, *xs, size=None): - """Forward pass. - - Returns: - tensor: output - """ - output = xs[0] - - if self.has_residual: - res = self.resConfUnit1(xs[1]) - output = output + res - - output = self.resConfUnit2(output) - - if (size is None) and (self.size is None): - modifier = {"scale_factor": 2} - elif size is None: - modifier = {"size": self.size} - else: - modifier = {"size": size} - - output = custom_interpolate( - output, **modifier, mode="bilinear", align_corners=self.align_corners - ) - output = self.out_conv(output) - - return output - - -def custom_interpolate( - x: torch.Tensor, - size: Optional[Tuple[int, int]] = None, - scale_factor: Optional[float] = None, - mode: str = "bilinear", - align_corners: bool = True, -) -> torch.Tensor: - """ - Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. - """ - if size is None: - size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) - - INT_MAX = 1610612736 - - input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] - - if input_elements > INT_MAX: - chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) - interpolated_chunks = [ - nn.functional.interpolate( - chunk, size=size, mode=mode, align_corners=align_corners - ) - for chunk in chunks - ] - x = torch.cat(interpolated_chunks, dim=0) - return x.contiguous() - else: - return nn.functional.interpolate( - x, size=size, mode=mode, align_corners=align_corners - ) diff --git a/FastVGGT/vggt/heads/head_act.py b/FastVGGT/vggt/heads/head_act.py deleted file mode 100644 index 2dedfcf1180a653dddc99623e60df625e5897489..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/heads/head_act.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -import torch -import torch.nn.functional as F - - -def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): - """ - Activate pose parameters with specified activation functions. - - Args: - pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] - trans_act: Activation type for translation component - quat_act: Activation type for quaternion component - fl_act: Activation type for focal length component - - Returns: - Activated pose parameters tensor - """ - T = pred_pose_enc[..., :3] - quat = pred_pose_enc[..., 3:7] - fl = pred_pose_enc[..., 7:] # or fov - - T = base_pose_act(T, trans_act) - quat = base_pose_act(quat, quat_act) - fl = base_pose_act(fl, fl_act) # or fov - - pred_pose_enc = torch.cat([T, quat, fl], dim=-1) - - return pred_pose_enc - - -def base_pose_act(pose_enc, act_type="linear"): - """ - Apply basic activation function to pose parameters. - - Args: - pose_enc: Tensor containing encoded pose parameters - act_type: Activation type ("linear", "inv_log", "exp", "relu") - - Returns: - Activated pose parameters - """ - if act_type == "linear": - return pose_enc - elif act_type == "inv_log": - return inverse_log_transform(pose_enc) - elif act_type == "exp": - return torch.exp(pose_enc) - elif act_type == "relu": - return F.relu(pose_enc) - else: - raise ValueError(f"Unknown act_type: {act_type}") - - -def activate_head(out, activation="norm_exp", conf_activation="expp1"): - """ - Process network output to extract 3D points and confidence values. - - Args: - out: Network output tensor (B, C, H, W) - activation: Activation type for 3D points - conf_activation: Activation type for confidence values - - Returns: - Tuple of (3D points tensor, confidence tensor) - """ - # Move channels from last dim to the 4th dimension => (B, H, W, C) - fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected - - # Split into xyz (first C-1 channels) and confidence (last channel) - xyz = fmap[:, :, :, :-1] - conf = fmap[:, :, :, -1] - - if activation == "norm_exp": - d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) - xyz_normed = xyz / d - pts3d = xyz_normed * torch.expm1(d) - elif activation == "norm": - pts3d = xyz / xyz.norm(dim=-1, keepdim=True) - elif activation == "exp": - pts3d = torch.exp(xyz) - elif activation == "relu": - pts3d = F.relu(xyz) - elif activation == "inv_log": - pts3d = inverse_log_transform(xyz) - elif activation == "xy_inv_log": - xy, z = xyz.split([2, 1], dim=-1) - z = inverse_log_transform(z) - pts3d = torch.cat([xy * z, z], dim=-1) - elif activation == "sigmoid": - pts3d = torch.sigmoid(xyz) - elif activation == "linear": - pts3d = xyz - else: - raise ValueError(f"Unknown activation: {activation}") - - if conf_activation == "expp1": - conf_out = 1 + conf.exp() - elif conf_activation == "expp0": - conf_out = conf.exp() - elif conf_activation == "sigmoid": - conf_out = torch.sigmoid(conf) - else: - raise ValueError(f"Unknown conf_activation: {conf_activation}") - - return pts3d, conf_out - - -def inverse_log_transform(y): - """ - Apply inverse log transform: sign(y) * (exp(|y|) - 1) - - Args: - y: Input tensor - - Returns: - Transformed tensor - """ - return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/FastVGGT/vggt/heads/track_head.py b/FastVGGT/vggt/heads/track_head.py deleted file mode 100644 index a4f1d9bd83cca1f74f97a644a02b984904f84706..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/heads/track_head.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch.nn as nn -from .dpt_head import DPTHead -from .track_modules.base_track_predictor import BaseTrackerPredictor - - -class TrackHead(nn.Module): - """ - Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking. - The tracking is performed iteratively, refining predictions over multiple iterations. - """ - - def __init__( - self, - dim_in, - patch_size=14, - features=128, - iters=4, - predict_conf=True, - stride=2, - corr_levels=7, - corr_radius=4, - hidden_size=384, - ): - """ - Initialize the TrackHead module. - - Args: - dim_in (int): Input dimension of tokens from the backbone. - patch_size (int): Size of image patches used in the vision transformer. - features (int): Number of feature channels in the feature extractor output. - iters (int): Number of refinement iterations for tracking predictions. - predict_conf (bool): Whether to predict confidence scores for tracked points. - stride (int): Stride value for the tracker predictor. - corr_levels (int): Number of correlation pyramid levels - corr_radius (int): Radius for correlation computation, controlling the search area. - hidden_size (int): Size of hidden layers in the tracker network. - """ - super().__init__() - - self.patch_size = patch_size - - # Feature extractor based on DPT architecture - # Processes tokens into feature maps for tracking - self.feature_extractor = DPTHead( - dim_in=dim_in, - patch_size=patch_size, - features=features, - feature_only=True, # Only output features, no activation - down_ratio=2, # Reduces spatial dimensions by factor of 2 - pos_embed=False, - ) - - # Tracker module that predicts point trajectories - # Takes feature maps and predicts coordinates and visibility - self.tracker = BaseTrackerPredictor( - latent_dim=features, # Match the output_dim of feature extractor - predict_conf=predict_conf, - stride=stride, - corr_levels=corr_levels, - corr_radius=corr_radius, - hidden_size=hidden_size, - ) - - self.iters = iters - - def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None): - """ - Forward pass of the TrackHead. - - Args: - aggregated_tokens_list (list): List of aggregated tokens from the backbone. - images (torch.Tensor): Input images of shape (B, S, C, H, W) where: - B = batch size, S = sequence length. - patch_start_idx (int): Starting index for patch tokens. - query_points (torch.Tensor, optional): Initial query points to track. - If None, points are initialized by the tracker. - iters (int, optional): Number of refinement iterations. If None, uses self.iters. - - Returns: - tuple: - - coord_preds (torch.Tensor): Predicted coordinates for tracked points. - - vis_scores (torch.Tensor): Visibility scores for tracked points. - - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True). - """ - B, S, _, H, W = images.shape - - # Extract features from tokens - # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2 - feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx) - - # Use default iterations if not specified - if iters is None: - iters = self.iters - - # Perform tracking using the extracted features - coord_preds, vis_scores, conf_scores = self.tracker(query_points=query_points, fmaps=feature_maps, iters=iters) - - return coord_preds, vis_scores, conf_scores diff --git a/FastVGGT/vggt/heads/track_modules/__init__.py b/FastVGGT/vggt/heads/track_modules/__init__.py deleted file mode 100644 index 0952fcc3f57e34b3747962e9ebd6fc57aeea63fa..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/heads/track_modules/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 5ca81dbf0ca49b601ca2662314a3391c8ba00890..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc deleted file mode 100644 index 350f8ca52c0aca644125fea47f83d49337478bcb..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc deleted file mode 100644 index 81e9466ff91b16d265807533eeada2e8e674d7e0..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc deleted file mode 100644 index d3fe19d5e60445f0fa06866bffaa2b8b2e60a6f8..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc deleted file mode 100644 index cac348bf405bc00bc6578760ba8ad57185f33b81..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/heads/track_modules/base_track_predictor.py b/FastVGGT/vggt/heads/track_modules/base_track_predictor.py deleted file mode 100644 index 3ce8ec4b66fff236e015d1bcaf85c8237a52be7a..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/heads/track_modules/base_track_predictor.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn -from einops import rearrange, repeat - - -from .blocks import EfficientUpdateFormer, CorrBlock -from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed -from .modules import Mlp - - -class BaseTrackerPredictor(nn.Module): - def __init__( - self, - stride=1, - corr_levels=5, - corr_radius=4, - latent_dim=128, - hidden_size=384, - use_spaceatt=True, - depth=6, - max_scale=518, - predict_conf=True, - ): - super(BaseTrackerPredictor, self).__init__() - """ - The base template to create a track predictor - - Modified from https://github.com/facebookresearch/co-tracker/ - and https://github.com/facebookresearch/vggsfm - """ - - self.stride = stride - self.latent_dim = latent_dim - self.corr_levels = corr_levels - self.corr_radius = corr_radius - self.hidden_size = hidden_size - self.max_scale = max_scale - self.predict_conf = predict_conf - - self.flows_emb_dim = latent_dim // 2 - - self.corr_mlp = Mlp( - in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2, - hidden_features=self.hidden_size, - out_features=self.latent_dim, - ) - - self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4 - - self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim)) - - space_depth = depth if use_spaceatt else 0 - time_depth = depth - - self.updateformer = EfficientUpdateFormer( - space_depth=space_depth, - time_depth=time_depth, - input_dim=self.transformer_dim, - hidden_size=self.hidden_size, - output_dim=self.latent_dim + 2, - mlp_ratio=4.0, - add_space_attn=use_spaceatt, - ) - - self.fmap_norm = nn.LayerNorm(self.latent_dim) - self.ffeat_norm = nn.GroupNorm(1, self.latent_dim) - - # A linear layer to update track feats at each iteration - self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()) - - self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) - - if predict_conf: - self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1)) - - def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True): - """ - query_points: B x N x 2, the number of batches, tracks, and xy - fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension. - note HH and WW is the size of feature maps instead of original images - """ - B, N, D = query_points.shape - B, S, C, HH, WW = fmaps.shape - - assert D == 2, "Input points must be 2D coordinates" - - # apply a layernorm to fmaps here - fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2)) - fmaps = fmaps.permute(0, 1, 4, 2, 3) - - # Scale the input query_points because we may downsample the images - # by down_ratio or self.stride - # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map - # its query_points should be query_points/4 - if down_ratio > 1: - query_points = query_points / float(down_ratio) - - query_points = query_points / float(self.stride) - - # Init with coords as the query points - # It means the search will start from the position of query points at the reference frames - coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1) - - # Sample/extract the features of the query points in the query frame - query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0]) - - # init track feats by query feats - track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C - # back up the init coords - coords_backup = coords.clone() - - fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius) - - coord_preds = [] - - # Iterative Refinement - for _ in range(iters): - # Detach the gradients from the last iteration - # (in my experience, not very important for performance) - coords = coords.detach() - - fcorrs = fcorr_fn.corr_sample(track_feats, coords) - - corr_dim = fcorrs.shape[3] - fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim) - fcorrs_ = self.corr_mlp(fcorrs_) - - # Movement of current coords relative to query points - flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2) - - flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False) - - # (In my trials, it is also okay to just add the flows_emb instead of concat) - flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1) - - track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim) - - # Concatenate them as the input for the transformers - transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2) - - # 2D positional embed - # TODO: this can be much simplified - pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device) - sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0]) - - sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1) - - x = transformer_input + sampled_pos_emb - - # Add the query ref token to the track feats - query_ref_token = torch.cat( - [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1 - ) - x = x + query_ref_token.to(x.device).to(x.dtype) - - # B, N, S, C - x = rearrange(x, "(b n) s d -> b n s d", b=B) - - # Compute the delta coordinates and delta track features - delta, _ = self.updateformer(x) - - # BN, S, C - delta = rearrange(delta, " b n s d -> (b n) s d", b=B) - delta_coords_ = delta[:, :, :2] - delta_feats_ = delta[:, :, 2:] - - track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim) - delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim) - - # Update the track features - track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_ - - track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC - - # B x S x N x 2 - coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3) - - # Force coord0 as query - # because we assume the query points should not be changed - coords[:, 0] = coords_backup[:, 0] - - # The predicted tracks are in the original image scale - if down_ratio > 1: - coord_preds.append(coords * self.stride * down_ratio) - else: - coord_preds.append(coords * self.stride) - - # B, S, N - vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) - if apply_sigmoid: - vis_e = torch.sigmoid(vis_e) - - if self.predict_conf: - conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N) - if apply_sigmoid: - conf_e = torch.sigmoid(conf_e) - else: - conf_e = None - - if return_feat: - return coord_preds, vis_e, track_feats, query_track_feat, conf_e - else: - return coord_preds, vis_e, conf_e diff --git a/FastVGGT/vggt/heads/track_modules/blocks.py b/FastVGGT/vggt/heads/track_modules/blocks.py deleted file mode 100644 index 15c161c89ef99742b0f2c6f397c9121fe9301e08..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/heads/track_modules/blocks.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -# Modified from https://github.com/facebookresearch/co-tracker/ - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .utils import bilinear_sampler -from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock - - -class EfficientUpdateFormer(nn.Module): - """ - Transformer model that updates track estimates. - """ - - def __init__( - self, - space_depth=6, - time_depth=6, - input_dim=320, - hidden_size=384, - num_heads=8, - output_dim=130, - mlp_ratio=4.0, - add_space_attn=True, - num_virtual_tracks=64, - ): - super().__init__() - - self.out_channels = 2 - self.num_heads = num_heads - self.hidden_size = hidden_size - self.add_space_attn = add_space_attn - - # Add input LayerNorm before linear projection - self.input_norm = nn.LayerNorm(input_dim) - self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) - - # Add output LayerNorm before final projection - self.output_norm = nn.LayerNorm(hidden_size) - self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) - self.num_virtual_tracks = num_virtual_tracks - - if self.add_space_attn: - self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) - else: - self.virual_tracks = None - - self.time_blocks = nn.ModuleList( - [ - AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) - for _ in range(time_depth) - ] - ) - - if add_space_attn: - self.space_virtual_blocks = nn.ModuleList( - [ - AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) - for _ in range(space_depth) - ] - ) - self.space_point2virtual_blocks = nn.ModuleList( - [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] - ) - self.space_virtual2point_blocks = nn.ModuleList( - [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] - ) - assert len(self.time_blocks) >= len(self.space_virtual2point_blocks) - self.initialize_weights() - - def initialize_weights(self): - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) - - self.apply(_basic_init) - - def forward(self, input_tensor, mask=None): - # Apply input LayerNorm - input_tensor = self.input_norm(input_tensor) - tokens = self.input_transform(input_tensor) - - init_tokens = tokens - - B, _, T, _ = tokens.shape - - if self.add_space_attn: - virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) - tokens = torch.cat([tokens, virtual_tokens], dim=1) - - _, N, _, _ = tokens.shape - - j = 0 - for i in range(len(self.time_blocks)): - time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C - - time_tokens = self.time_blocks[i](time_tokens) - - tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C - if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): - space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C - point_tokens = space_tokens[:, : N - self.num_virtual_tracks] - virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] - - virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) - virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) - point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) - - space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) - tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C - j += 1 - - if self.add_space_attn: - tokens = tokens[:, : N - self.num_virtual_tracks] - - tokens = tokens + init_tokens - - # Apply output LayerNorm before final projection - tokens = self.output_norm(tokens) - flow = self.flow_head(tokens) - - return flow, None - - -class CorrBlock: - def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"): - """ - Build a pyramid of feature maps from the input. - - fmaps: Tensor (B, S, C, H, W) - num_levels: number of pyramid levels (each downsampled by factor 2) - radius: search radius for sampling correlation - multiple_track_feats: if True, split the target features per pyramid level - padding_mode: passed to grid_sample / bilinear_sampler - """ - B, S, C, H, W = fmaps.shape - self.S, self.C, self.H, self.W = S, C, H, W - self.num_levels = num_levels - self.radius = radius - self.padding_mode = padding_mode - self.multiple_track_feats = multiple_track_feats - - # Build pyramid: each level is half the spatial resolution of the previous - self.fmaps_pyramid = [fmaps] # level 0 is full resolution - current_fmaps = fmaps - for i in range(num_levels - 1): - B, S, C, H, W = current_fmaps.shape - # Merge batch & sequence dimensions - current_fmaps = current_fmaps.reshape(B * S, C, H, W) - # Avg pool down by factor 2 - current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2) - _, _, H_new, W_new = current_fmaps.shape - current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new) - self.fmaps_pyramid.append(current_fmaps) - - # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling. - # This grid is added to the (scaled) coordinate centroids. - r = self.radius - dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) - dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype) - # delta: for every (dy,dx) displacement (i.e. Δx, Δy) - self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2) - - def corr_sample(self, targets, coords): - """ - Instead of storing the entire correlation pyramid, we compute each level's correlation - volume, sample it immediately, then discard it. This saves GPU memory. - - Args: - targets: Tensor (B, S, N, C) — features for the current targets. - coords: Tensor (B, S, N, 2) — coordinates at full resolution. - - Returns: - Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations) - """ - B, S, N, C = targets.shape - - # If you have multiple track features, split them per level. - if self.multiple_track_feats: - targets_split = torch.split(targets, C // self.num_levels, dim=-1) - - out_pyramid = [] - for i, fmaps in enumerate(self.fmaps_pyramid): - # Get current spatial resolution H, W for this pyramid level. - B, S, C, H, W = fmaps.shape - # Reshape feature maps for correlation computation: - # fmap2s: (B, S, C, H*W) - fmap2s = fmaps.view(B, S, C, H * W) - # Choose appropriate target features. - fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C) - - # Compute correlation directly - corrs = compute_corr_level(fmap1, fmap2s, C) - corrs = corrs.view(B, S, N, H, W) - - # Prepare sampling grid: - # Scale down the coordinates for the current level. - centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i) - # Make sure our precomputed delta grid is on the same device/dtype. - delta_lvl = self.delta.to(coords.device).to(coords.dtype) - # Now the grid for grid_sample is: - # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid) - coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2) - - # Sample from the correlation volume using bilinear interpolation. - # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target. - corrs_sampled = bilinear_sampler( - corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode - ) - # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims. - corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2) - out_pyramid.append(corrs_sampled) - - # Concatenate all levels along the last dimension. - out = torch.cat(out_pyramid, dim=-1).contiguous() - return out - - -def compute_corr_level(fmap1, fmap2s, C): - # fmap1: (B, S, N, C) - # fmap2s: (B, S, C, H*W) - corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W) - corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W) - return corrs / math.sqrt(C) diff --git a/FastVGGT/vggt/heads/track_modules/modules.py b/FastVGGT/vggt/heads/track_modules/modules.py deleted file mode 100644 index 12de4f1ad76364d4665e53ac80e1037fadf98d08..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/heads/track_modules/modules.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial -from typing import Callable -import collections -from torch import Tensor -from itertools import repeat - - -# From PyTorch internals -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): - return tuple(x) - return tuple(repeat(x, n)) - - return parse - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - -to_2tuple = _ntuple(2) - - -class ResidualBlock(nn.Module): - """ - ResidualBlock: construct a block of two conv layers with residual connections - """ - - def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3): - super(ResidualBlock, self).__init__() - - self.conv1 = nn.Conv2d( - in_planes, planes, kernel_size=kernel_size, padding=1, stride=stride, padding_mode="zeros" - ) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros") - self.relu = nn.ReLU(inplace=True) - - num_groups = planes // 8 - - if norm_fn == "group": - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - if not stride == 1: - self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == "batch": - self.norm1 = nn.BatchNorm2d(planes) - self.norm2 = nn.BatchNorm2d(planes) - if not stride == 1: - self.norm3 = nn.BatchNorm2d(planes) - - elif norm_fn == "instance": - self.norm1 = nn.InstanceNorm2d(planes) - self.norm2 = nn.InstanceNorm2d(planes) - if not stride == 1: - self.norm3 = nn.InstanceNorm2d(planes) - - elif norm_fn == "none": - self.norm1 = nn.Sequential() - self.norm2 = nn.Sequential() - if not stride == 1: - self.norm3 = nn.Sequential() - else: - raise NotImplementedError - - if stride == 1: - self.downsample = None - else: - self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) - - def forward(self, x): - y = x - y = self.relu(self.norm1(self.conv1(y))) - y = self.relu(self.norm2(self.conv2(y))) - - if self.downsample is not None: - x = self.downsample(x) - - return self.relu(x + y) - - -class Mlp(nn.Module): - """MLP as used in Vision Transformer, MLP-Mixer and related networks""" - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - norm_layer=None, - bias=True, - drop=0.0, - use_conv=False, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - bias = to_2tuple(bias) - drop_probs = to_2tuple(drop) - linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear - - self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) - self.act = act_layer() - self.drop1 = nn.Dropout(drop_probs[0]) - self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) - self.drop2 = nn.Dropout(drop_probs[1]) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop1(x) - x = self.fc2(x) - x = self.drop2(x) - return x - - -class AttnBlock(nn.Module): - def __init__( - self, - hidden_size, - num_heads, - attn_class: Callable[..., nn.Module] = nn.MultiheadAttention, - mlp_ratio=4.0, - **block_kwargs, - ): - """ - Self attention block - """ - super().__init__() - - self.norm1 = nn.LayerNorm(hidden_size) - self.norm2 = nn.LayerNorm(hidden_size) - - self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs) - - mlp_hidden_dim = int(hidden_size * mlp_ratio) - - self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) - - def forward(self, x, mask=None): - # Prepare the mask for PyTorch's attention (it expects a different format) - # attn_mask = mask if mask is not None else None - # Normalize before attention - x = self.norm1(x) - - # PyTorch's MultiheadAttention returns attn_output, attn_output_weights - # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask) - - attn_output, _ = self.attn(x, x, x) - - # Add & Norm - x = x + attn_output - x = x + self.mlp(self.norm2(x)) - return x - - -class CrossAttnBlock(nn.Module): - def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs): - """ - Cross attention block - """ - super().__init__() - - self.norm1 = nn.LayerNorm(hidden_size) - self.norm_context = nn.LayerNorm(hidden_size) - self.norm2 = nn.LayerNorm(hidden_size) - - self.cross_attn = nn.MultiheadAttention( - embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs - ) - - mlp_hidden_dim = int(hidden_size * mlp_ratio) - - self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0) - - def forward(self, x, context, mask=None): - # Normalize inputs - x = self.norm1(x) - context = self.norm_context(context) - - # Apply cross attention - # Note: nn.MultiheadAttention returns attn_output, attn_output_weights - attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask) - - # Add & Norm - x = x + attn_output - x = x + self.mlp(self.norm2(x)) - return x diff --git a/FastVGGT/vggt/heads/track_modules/utils.py b/FastVGGT/vggt/heads/track_modules/utils.py deleted file mode 100644 index 3f1fffeaedd33c7f1c2ef54220e24a2a0e5a57b2..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/heads/track_modules/utils.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# Modified from https://github.com/facebookresearch/vggsfm -# and https://github.com/facebookresearch/co-tracker/tree/main - - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from typing import Optional, Tuple, Union - - -def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor: - """ - This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. - It is a wrapper of get_2d_sincos_pos_embed_from_grid. - Args: - - embed_dim: The embedding dimension. - - grid_size: The grid size. - Returns: - - pos_embed: The generated 2D positional embedding. - """ - if isinstance(grid_size, tuple): - grid_size_h, grid_size_w = grid_size - else: - grid_size_h = grid_size_w = grid_size - grid_h = torch.arange(grid_size_h, dtype=torch.float) - grid_w = torch.arange(grid_size_w, dtype=torch.float) - grid = torch.meshgrid(grid_w, grid_h, indexing="xy") - grid = torch.stack(grid, dim=0) - grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if return_grid: - return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid) - return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) - - -def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: - """ - This function generates a 2D positional embedding from a given grid using sine and cosine functions. - - Args: - - embed_dim: The embedding dimension. - - grid: The grid to generate the embedding from. - - Returns: - - emb: The generated 2D positional embedding. - """ - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: - """ - This function generates a 1D positional embedding from a given grid using sine and cosine functions. - - Args: - - embed_dim: The embedding dimension. - - pos: The position to generate the embedding from. - - Returns: - - emb: The generated 1D positional embedding. - """ - assert embed_dim % 2 == 0 - omega = torch.arange(embed_dim // 2, dtype=torch.double) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = torch.sin(out) # (M, D/2) - emb_cos = torch.cos(out) # (M, D/2) - - emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) - return emb[None].float() - - -def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: - """ - This function generates a 2D positional embedding from given coordinates using sine and cosine functions. - - Args: - - xy: The coordinates to generate the embedding from. - - C: The size of the embedding. - - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. - - Returns: - - pe: The generated 2D positional embedding. - """ - B, N, D = xy.shape - assert D == 2 - - x = xy[:, :, 0:1] - y = xy[:, :, 1:2] - div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) - - pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) - pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) - - pe_x[:, :, 0::2] = torch.sin(x * div_term) - pe_x[:, :, 1::2] = torch.cos(x * div_term) - - pe_y[:, :, 0::2] = torch.sin(y * div_term) - pe_y[:, :, 1::2] = torch.cos(y * div_term) - - pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) - if cat_coords: - pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) - return pe - - -def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): - r"""Sample a tensor using bilinear interpolation - - `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at - coordinates :attr:`coords` using bilinear interpolation. It is the same - as `torch.nn.functional.grid_sample()` but with a different coordinate - convention. - - The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where - :math:`B` is the batch size, :math:`C` is the number of channels, - :math:`H` is the height of the image, and :math:`W` is the width of the - image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is - interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. - - Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, - in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note - that in this case the order of the components is slightly different - from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. - - If `align_corners` is `True`, the coordinate :math:`x` is assumed to be - in the range :math:`[0,W-1]`, with 0 corresponding to the center of the - left-most image pixel :math:`W-1` to the center of the right-most - pixel. - - If `align_corners` is `False`, the coordinate :math:`x` is assumed to - be in the range :math:`[0,W]`, with 0 corresponding to the left edge of - the left-most pixel :math:`W` to the right edge of the right-most - pixel. - - Similar conventions apply to the :math:`y` for the range - :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range - :math:`[0,T-1]` and :math:`[0,T]`. - - Args: - input (Tensor): batch of input images. - coords (Tensor): batch of coordinates. - align_corners (bool, optional): Coordinate convention. Defaults to `True`. - padding_mode (str, optional): Padding mode. Defaults to `"border"`. - - Returns: - Tensor: sampled points. - """ - coords = coords.detach().clone() - ############################################################ - # IMPORTANT: - coords = coords.to(input.device).to(input.dtype) - ############################################################ - - sizes = input.shape[2:] - - assert len(sizes) in [2, 3] - - if len(sizes) == 3: - # t x y -> x y t to match dimensions T H W in grid_sample - coords = coords[..., [1, 2, 0]] - - if align_corners: - scale = torch.tensor( - [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype - ) - else: - scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype) - - coords.mul_(scale) # coords = coords * scale - coords.sub_(1) # coords = coords - 1 - - return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) - - -def sample_features4d(input, coords): - r"""Sample spatial features - - `sample_features4d(input, coords)` samples the spatial features - :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. - - The field is sampled at coordinates :attr:`coords` using bilinear - interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, - 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the - same convention as :func:`bilinear_sampler` with `align_corners=True`. - - The output tensor has one feature per point, and has shape :math:`(B, - R, C)`. - - Args: - input (Tensor): spatial features. - coords (Tensor): points. - - Returns: - Tensor: sampled features. - """ - - B, _, _, _ = input.shape - - # B R 2 -> B R 1 2 - coords = coords.unsqueeze(2) - - # B C R 1 - feats = bilinear_sampler(input, coords) - - return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/FastVGGT/vggt/heads/utils.py b/FastVGGT/vggt/heads/utils.py deleted file mode 100644 index 533fc8ae67a75cd0a94d5ca96dc5a0513446c64f..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/heads/utils.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn - - -def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: - """ - Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) - - Args: - pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates - embed_dim: Output channel dimension for embeddings - - Returns: - Tensor of shape (H, W, embed_dim) with positional embeddings - """ - H, W, grid_dim = pos_grid.shape - assert grid_dim == 2 - pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) - - # Process x and y coordinates separately - emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] - emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] - - # Combine and reshape - emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] - - return emb.view(H, W, embed_dim) # [H, W, D] - - -def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: - """ - This function generates a 1D positional embedding from a given grid using sine and cosine functions. - - Args: - - embed_dim: The embedding dimension. - - pos: The position to generate the embedding from. - - Returns: - - emb: The generated 1D positional embedding. - """ - assert embed_dim % 2 == 0 - device = pos.device - omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device) - omega /= embed_dim / 2.0 - omega = 1.0 / omega_0**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product - - emb_sin = torch.sin(out) # (M, D/2) - emb_cos = torch.cos(out) # (M, D/2) - - emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) - return emb.float() - - -# Inspired by https://github.com/microsoft/moge - - -def create_uv_grid( - width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None -) -> torch.Tensor: - """ - Create a normalized UV grid of shape (width, height, 2). - - The grid spans horizontally and vertically according to an aspect ratio, - ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right - corner is at (x_span, y_span), normalized by the diagonal of the plane. - - Args: - width (int): Number of points horizontally. - height (int): Number of points vertically. - aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. - dtype (torch.dtype, optional): Data type of the resulting tensor. - device (torch.device, optional): Device on which the tensor is created. - - Returns: - torch.Tensor: A (width, height, 2) tensor of UV coordinates. - """ - # Derive aspect ratio if not explicitly provided - if aspect_ratio is None: - aspect_ratio = float(width) / float(height) - - # Compute normalized spans for X and Y - diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 - span_x = aspect_ratio / diag_factor - span_y = 1.0 / diag_factor - - # Establish the linspace boundaries - left_x = -span_x * (width - 1) / width - right_x = span_x * (width - 1) / width - top_y = -span_y * (height - 1) / height - bottom_y = span_y * (height - 1) / height - - # Generate 1D coordinates - x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) - y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) - - # Create 2D meshgrid (width x height) and stack into UV - uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") - uv_grid = torch.stack((uu, vv), dim=-1) - - return uv_grid diff --git a/FastVGGT/vggt/layers/__init__.py b/FastVGGT/vggt/layers/__init__.py deleted file mode 100644 index 8120f4bc83066cb3f825ce32daa3b437f88486f1..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/layers/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from .mlp import Mlp -from .patch_embed import PatchEmbed -from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused -from .block import NestedTensorBlock -from .attention import MemEffAttention diff --git a/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 2840683edfd99fcda3d1d5c9e18d699dba029a86..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc deleted file mode 100644 index 42e56e451e90459b8b74d99bd0ac9049e702e1bd..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc deleted file mode 100644 index 069b3ad92a285342f200297a7358dbf39c0ef372..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc deleted file mode 100644 index e0bc9cb8a27db08fc0921a2d8a54a60b2b126acd..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc deleted file mode 100644 index a416380412c8942981ab0e3f5369cebe05f9496c..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc deleted file mode 100644 index 053bf054f2a74eefeea0536eeddd23daacd28cbe..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc deleted file mode 100644 index 20fa7427ef4c285abd1d9b48743f4c3360953591..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc deleted file mode 100644 index af9854cf6446fb30404f4529d63b20c20f8e63e8..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc deleted file mode 100644 index c9b1fc5deb392addff7c8620b095e0ba4d90a8de..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc deleted file mode 100644 index d36ad71ca8dacfb3991b5b2560bc89cdfdd15c52..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/layers/attention.py b/FastVGGT/vggt/layers/attention.py deleted file mode 100644 index aef68a4e2c628ad257a6b77b46447c8421c5b11a..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/layers/attention.py +++ /dev/null @@ -1,257 +0,0 @@ -import os -from pathlib import Path -import time -from PIL import Image -import torch -from torch import Tensor -from torch import nn -import torch.nn.functional as F -from tqdm.std import tqdm -from merging.merge import ( - token_merge_bipartite2d, -) -import matplotlib.pyplot as plt -import numpy as np -from PIL import Image - -XFORMERS_AVAILABLE = False - -# Global variables for attention visualization -vis_attn_map = False -current_images = [] -attention_map = None - - -class Attention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = True, - proj_bias: bool = True, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - norm_layer: nn.Module = nn.LayerNorm, - qk_norm: bool = False, - kv_group_size: int = 1, - fused_attn: bool = True, - rope=None, - global_merging=None, - patch_width: int = 37, - patch_height: int = 28, - ) -> None: - super().__init__() - assert dim % num_heads == 0, "dim should be divisible by num_heads" - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim**-0.5 - self.patch_width = patch_width - self.patch_height = patch_height - self.fused_attn = fused_attn - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim, bias=proj_bias) - self.proj_drop = nn.Dropout(proj_drop) - self.rope = rope - self.kv_group_size = kv_group_size - - def forward(self, x: Tensor, pos=None, global_merging=None) -> Tensor: - merge_num = list(range(24)) - - B, N, C = x.shape - qkv = ( - self.qkv(x) - .reshape(B, N, 3, self.num_heads, self.head_dim) - .permute(2, 0, 3, 1, 4) - ) - q, k, v = qkv.unbind(0) - q, k = self.q_norm(q), self.k_norm(k) - - if self.rope is not None: - q = self.rope(q, pos) - k = self.rope(k, pos) - - if vis_attn_map and global_merging is not None: - # s1 Chunk computation of q@k - chunk_size = 4096 - attn_maps = [] - total_chunks = ( - q.size(-2) + chunk_size - 1 - ) // chunk_size # Calculate total number of chunks - start_time_total = time.time() - with torch.no_grad(): - for start in tqdm( - range(0, q.size(-2), chunk_size), - total=total_chunks, - desc="Processing chunks", - ): - end = min(start + chunk_size, q.size(-2)) - q_chunk = q[:, :, start:end, :] # (1, 16, chunk_size, 64) - start_time_chunk = time.time() - attn_chunk = q_chunk @ k.transpose( - -2, -1 - ) # (1, 16, chunk_size, 34353) - attn_maps.append(attn_chunk.cpu()) - end_time_chunk = time.time() - print( - f"Chunk {start}:{end} processed in {end_time_chunk - start_time_chunk:.4f} seconds" - ) - end_time_total = time.time() - print( - f"\nTotal processing time: {end_time_total - start_time_total:.4f} seconds" - ) - - attn_map = torch.cat(attn_maps, dim=-2) - attn = attn_map[0].mean(0) - frame_token_num = self.patch_height * self.patch_width + 5 - for target_token_idx in [ - 0, - self.patch_height * self.patch_width, - self.patch_height * self.patch_width * 10, - ]: # Iterate through each image's target_token - for image_idx in range( - len(current_images) - ): # Corresponding to which image to visualize - target_attn = attn[ - target_token_idx, - image_idx * frame_token_num : (image_idx + 1) * frame_token_num, - ] - target_attn_map = target_attn[5:].reshape( - self.patch_height, self.patch_width - ) - # 1) Read original image to get true size (H, W) - image_path = current_images[image_idx] - p = Path(image_path) - parts = p.parts - scene_name = parts[-4] - - image = Image.open(image_path).convert("RGB") - img_width, img_height = image.size # PIL size: (W, H) - - # Upsample attention map to the original image size - target_attn_map = F.interpolate( - target_attn_map.unsqueeze(0).unsqueeze(0), # (1,1,h,w) - size=(img_height, img_width), - mode="bilinear", - ) - target_attn_map = target_attn_map.squeeze() - - # Convert image to numpy for blending - img_np = np.array(image) / 255.0 - - # 2. Normalize attention map - target_attn_map = (target_attn_map - target_attn_map.min()) / ( - target_attn_map.max() - target_attn_map.min() - ) - - # 3. Color attention map - cmap = plt.get_cmap("jet") - attn_color = cmap(target_attn_map.cpu().float().numpy()) - attn_color = attn_color[:, :, :3] - - # 4. Blend attention and original image - overlay = img_np * 0.5 + attn_color * 0.5 - - plt.imshow(overlay, cmap="viridis") - output_dir = f"attention_map/{scene_name}/block_{attention_map}/token_{target_token_idx}" - os.makedirs(output_dir, exist_ok=True) - output_path = os.path.join(output_dir, f"color_{image_idx}.png") - plt.savefig(output_path) - plt.close() - - if global_merging is not None and global_merging in merge_num: - generator = torch.Generator(device=x.device) - generator.manual_seed(33) - - merge_ratio = 0.9 - r = int(x.shape[1] * merge_ratio) - - m, u = token_merge_bipartite2d( - x, - self.patch_width, - self.patch_height, - 2, - 2, - r, - False, - generator, - enable_protection=True, - ) - - m_a, u_a = (m, u) - - B_q, H_q, N_q, D_q = q.shape - - q_merge_in = q.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q) - k_merge_in = k.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q) - v_merge_in = v.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q) - - q_out, k_out, v_out = m_a( - q_merge_in, - mode="mean", - extra_tensors=k_merge_in, - extra_tensors_2=v_merge_in, - ) - - del q_merge_in, k_merge_in, v_merge_in - - N_m = q_out.shape[1] - q = q_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3) - k = k_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3) - v = v_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3) - - del q_out, k_out, v_out - - N = N_m - - x = F.scaled_dot_product_attention( - q, - k, - v, - dropout_p=self.attn_drop.p if self.training else 0.0, - ) - del q, k, v - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - if global_merging is not None and global_merging in merge_num: - x = u_a(x) - return x - - -class MemEffAttention(Attention): - def forward( - self, x: Tensor, attn_bias=None, pos=None, global_merging=None - ) -> Tensor: - assert ( - pos is None or self.rope is not None - ), "Position encoding is only supported with RoPE" - if not XFORMERS_AVAILABLE: - if attn_bias is not None: - raise AssertionError("xFormers is required for using nested tensors") - return super().forward(x, pos=pos, global_merging=global_merging) - - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) - - q, k, v = qkv.unbind(2) - - if self.rope is not None and pos is not None: - q = self.rope(q, pos) - k = self.rope(k, pos) - - # Use scaled dot-product attention - attn = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim**-0.5) - if attn_bias is not None: - attn = attn + attn_bias - attn = F.softmax(attn, dim=-1) - x = torch.matmul(attn, v) - x = x.reshape([B, N, C]) - - x = self.proj(x) - x = self.proj_drop(x) - return x diff --git a/FastVGGT/vggt/layers/block.py b/FastVGGT/vggt/layers/block.py deleted file mode 100644 index ffa3c18aaf79fc755630d24c24c1f4c7f75624e1..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/layers/block.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py - - -from typing import Callable, List, Any, Tuple, Dict -import torch -from torch import nn, Tensor -from .attention import Attention -from .drop_path import DropPath -from .layer_scale import LayerScale -from .mlp import Mlp - - -XFORMERS_AVAILABLE = False - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float = 4.0, - qkv_bias: bool = True, - proj_bias: bool = True, - ffn_bias: bool = True, - drop: float = 0.0, - attn_drop: float = 0.0, - init_values=None, - drop_path: float = 0.0, - act_layer: Callable[..., nn.Module] = nn.GELU, - norm_layer: Callable[..., nn.Module] = nn.LayerNorm, - attn_class: Callable[..., nn.Module] = Attention, - ffn_layer: Callable[..., nn.Module] = Mlp, - qk_norm: bool = False, - fused_attn: bool = True, - rope=None, - merging=0, - ) -> None: - super().__init__() - - self.norm1 = norm_layer(dim) - - self.attn = attn_class( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - attn_drop=attn_drop, - proj_drop=drop, - qk_norm=qk_norm, - fused_attn=fused_attn, - rope=rope, - ) - - self.ls1 = ( - LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - ) - self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = ffn_layer( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop, - bias=ffn_bias, - ) - self.ls2 = ( - LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - ) - self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - - self.sample_drop_ratio = drop_path - - def forward(self, x: Tensor, pos=None, global_merging=None) -> Tensor: - norm1_output = self.norm1(x) - attn_output = self.attn(norm1_output, pos=pos, global_merging=global_merging) - del norm1_output - x = x + self.ls1(attn_output) - del attn_output - - norm2_output = self.norm2(x) - mlp_output = self.mlp(norm2_output) - del norm2_output - x = x + self.ls2(mlp_output) - del mlp_output - return x - - -def drop_add_residual_stochastic_depth( - x: Tensor, - residual_func: Callable[[Tensor], Tensor], - sample_drop_ratio: float = 0.0, - pos=None, -) -> Tensor: - # 1) extract subset using permutation - b, n, d = x.shape - sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) - brange = (torch.randperm(b, device=x.device))[:sample_subset_size] - x_subset = x[brange] - - # 2) apply residual_func to get residual - if pos is not None: - # if necessary, apply rope to the subset - pos = pos[brange] - residual = residual_func(x_subset, pos=pos) - else: - residual = residual_func(x_subset) - - x_flat = x.flatten(1) - residual = residual.flatten(1) - - residual_scale_factor = b / sample_subset_size - - # 3) add the residual - x_plus_residual = torch.index_add( - x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor - ) - return x_plus_residual.view_as(x) - - -def get_branges_scales(x, sample_drop_ratio=0.0): - b, n, d = x.shape - sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) - brange = (torch.randperm(b, device=x.device))[:sample_subset_size] - residual_scale_factor = b / sample_subset_size - return brange, residual_scale_factor - - -def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): - if scaling_vector is None: - x_flat = x.flatten(1) - residual = residual.flatten(1) - x_plus_residual = torch.index_add( - x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor - ) - else: - x_plus_residual = scaled_index_add( - x, - brange, - residual.to(dtype=x.dtype), - scaling=scaling_vector, - alpha=residual_scale_factor, - ) - return x_plus_residual - - -attn_bias_cache: Dict[Tuple, Any] = {} - - -def get_attn_bias_and_cat(x_list, branges=None): - """ - this will perform the index select, cat the tensors, and provide the attn_bias from cache - """ - batch_sizes = ( - [b.shape[0] for b in branges] - if branges is not None - else [x.shape[0] for x in x_list] - ) - all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) - if all_shapes not in attn_bias_cache.keys(): - seqlens = [] - for b, x in zip(batch_sizes, x_list): - for _ in range(b): - seqlens.append(x.shape[1]) - attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) - attn_bias._batch_sizes = batch_sizes - attn_bias_cache[all_shapes] = attn_bias - - if branges is not None: - cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( - 1, -1, x_list[0].shape[-1] - ) - else: - tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) - cat_tensors = torch.cat(tensors_bs1, dim=1) - - return attn_bias_cache[all_shapes], cat_tensors - - -def drop_add_residual_stochastic_depth_list( - x_list: List[Tensor], - residual_func: Callable[[Tensor, Any], Tensor], - sample_drop_ratio: float = 0.0, - scaling_vector=None, -) -> Tensor: - # 1) generate random set of indices for dropping samples in the batch - branges_scales = [ - get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list - ] - branges = [s[0] for s in branges_scales] - residual_scale_factors = [s[1] for s in branges_scales] - - # 2) get attention bias and index+concat the tensors - attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) - - # 3) apply residual_func to get residual, and split the result - residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore - - outputs = [] - for x, brange, residual, residual_scale_factor in zip( - x_list, branges, residual_list, residual_scale_factors - ): - outputs.append( - add_residual( - x, brange, residual, residual_scale_factor, scaling_vector - ).view_as(x) - ) - return outputs - - -class NestedTensorBlock(Block): - def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: - """ - x_list contains a list of tensors to nest together and run - """ - assert isinstance(self.attn, MemEffAttention) - - if self.training and self.sample_drop_ratio > 0.0: - - def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.attn(self.norm1(x), attn_bias=attn_bias) - - def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.mlp(self.norm2(x)) - - x_list = drop_add_residual_stochastic_depth_list( - x_list, - residual_func=attn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - scaling_vector=( - self.ls1.gamma if isinstance(self.ls1, LayerScale) else None - ), - ) - x_list = drop_add_residual_stochastic_depth_list( - x_list, - residual_func=ffn_residual_func, - sample_drop_ratio=self.sample_drop_ratio, - scaling_vector=( - self.ls2.gamma if isinstance(self.ls1, LayerScale) else None - ), - ) - return x_list - else: - - def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) - - def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: - return self.ls2(self.mlp(self.norm2(x))) - - attn_bias, x = get_attn_bias_and_cat(x_list) - x = x + attn_residual_func(x, attn_bias=attn_bias) - x = x + ffn_residual_func(x) - return attn_bias.split(x) - - def forward(self, x_or_x_list): - if isinstance(x_or_x_list, Tensor): - return super().forward(x_or_x_list) - elif isinstance(x_or_x_list, list): - if not XFORMERS_AVAILABLE: - raise AssertionError("xFormers is required for using nested tensors") - return self.forward_nested(x_or_x_list) - else: - raise AssertionError diff --git a/FastVGGT/vggt/layers/drop_path.py b/FastVGGT/vggt/layers/drop_path.py deleted file mode 100644 index 1d640e0b969b8dcba96260243473700b4e5b24b5..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/layers/drop_path.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py - - -from torch import nn - - -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = x.new_empty(shape).bernoulli_(keep_prob) - if keep_prob > 0.0: - random_tensor.div_(keep_prob) - output = x * random_tensor - return output - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) diff --git a/FastVGGT/vggt/layers/layer_scale.py b/FastVGGT/vggt/layers/layer_scale.py deleted file mode 100644 index 4ddfc51c3d87370d50175f5b4e649dac1c614ff9..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/layers/layer_scale.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 - -from typing import Union - -import torch -from torch import Tensor -from torch import nn - - -class LayerScale(nn.Module): - def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None: - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x: Tensor) -> Tensor: - return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/FastVGGT/vggt/layers/mlp.py b/FastVGGT/vggt/layers/mlp.py deleted file mode 100644 index bbf9432aae9258612caeae910a7bde17999e328e..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/layers/mlp.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py - - -from typing import Callable, Optional - -from torch import Tensor, nn - - -class Mlp(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = nn.GELU, - drop: float = 0.0, - bias: bool = True, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) - self.drop = nn.Dropout(drop) - - def forward(self, x: Tensor) -> Tensor: - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x diff --git a/FastVGGT/vggt/layers/patch_embed.py b/FastVGGT/vggt/layers/patch_embed.py deleted file mode 100644 index bc19605e4d6e88d06355ae3b1afddc76f595aafe..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/layers/patch_embed.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py - -from typing import Callable, Optional, Tuple, Union - -from torch import Tensor -import torch.nn as nn - - -def make_2tuple(x): - if isinstance(x, tuple): - assert len(x) == 2 - return x - - assert isinstance(x, int) - return (x, x) - - -class PatchEmbed(nn.Module): - """ - 2D image to patch embedding: (B,C,H,W) -> (B,N,D) - - Args: - img_size: Image size. - patch_size: Patch token size. - in_chans: Number of input image channels. - embed_dim: Number of linear projection output channels. - norm_layer: Normalization layer. - """ - - def __init__( - self, - img_size: Union[int, Tuple[int, int]] = 224, - patch_size: Union[int, Tuple[int, int]] = 16, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer: Optional[Callable] = None, - flatten_embedding: bool = True, - ) -> None: - super().__init__() - - image_HW = make_2tuple(img_size) - patch_HW = make_2tuple(patch_size) - patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1]) - - self.img_size = image_HW - self.patch_size = patch_HW - self.patches_resolution = patch_grid_size - self.num_patches = patch_grid_size[0] * patch_grid_size[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.flatten_embedding = flatten_embedding - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x: Tensor) -> Tensor: - _, _, H, W = x.shape - patch_H, patch_W = self.patch_size - - assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" - assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" - - x = self.proj(x) # B C H W - H, W = x.size(2), x.size(3) - x = x.flatten(2).transpose(1, 2) # B HW C - x = self.norm(x) - if not self.flatten_embedding: - x = x.reshape(-1, H, W, self.embed_dim) # B H W C - return x - - def flops(self) -> float: - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops diff --git a/FastVGGT/vggt/layers/rope.py b/FastVGGT/vggt/layers/rope.py deleted file mode 100644 index 84625de468ed89e69dd9e1579d541de71f2ebf37..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/layers/rope.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - - -# Implementation of 2D Rotary Position Embeddings (RoPE). - -# This module provides a clean implementation of 2D Rotary Position Embeddings, -# which extends the original RoPE concept to handle 2D spatial positions. - -# Inspired by: -# https://github.com/meta-llama/codellama/blob/main/llama/model.py -# https://github.com/naver-ai/rope-vit - - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from typing import Dict, Tuple - - -class PositionGetter: - """Generates and caches 2D spatial positions for patches in a grid. - - This class efficiently manages the generation of spatial coordinates for patches - in a 2D grid, caching results to avoid redundant computations. - - Attributes: - position_cache: Dictionary storing precomputed position tensors for different - grid dimensions. - """ - - def __init__(self): - """Initializes the position generator with an empty cache.""" - self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} - - def __call__( - self, batch_size: int, height: int, width: int, device: torch.device - ) -> torch.Tensor: - """Generates spatial positions for a batch of patches. - - Args: - batch_size: Number of samples in the batch. - height: Height of the grid in patches. - width: Width of the grid in patches. - device: Target device for the position tensor. - - Returns: - Tensor of shape (batch_size, height*width, 2) containing y,x coordinates - for each position in the grid, repeated for each batch item. - """ - if (height, width) not in self.position_cache: - y_coords = torch.arange(height, device=device) - x_coords = torch.arange(width, device=device) - positions = torch.cartesian_prod(y_coords, x_coords) - self.position_cache[height, width] = positions - - cached_positions = self.position_cache[height, width] - return ( - cached_positions.view(1, height * width, 2) - .expand(batch_size, -1, -1) - .clone() - ) - - -class RotaryPositionEmbedding2D(nn.Module): - """2D Rotary Position Embedding implementation. - - This module applies rotary position embeddings to input tokens based on their - 2D spatial positions. It handles the position-dependent rotation of features - separately for vertical and horizontal dimensions. - - Args: - frequency: Base frequency for the position embeddings. Default: 100.0 - scaling_factor: Scaling factor for frequency computation. Default: 1.0 - - Attributes: - base_frequency: Base frequency for computing position embeddings. - scaling_factor: Factor to scale the computed frequencies. - frequency_cache: Cache for storing precomputed frequency components. - """ - - def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): - """Initializes the 2D RoPE module.""" - super().__init__() - self.base_frequency = frequency - self.scaling_factor = scaling_factor - self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} - - def _compute_frequency_components( - self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Computes frequency components for rotary embeddings. - - Args: - dim: Feature dimension (must be even). - seq_len: Maximum sequence length. - device: Target device for computations. - dtype: Data type for the computed tensors. - - Returns: - Tuple of (cosine, sine) tensors for frequency components. - """ - cache_key = (dim, seq_len, device, dtype) - if cache_key not in self.frequency_cache: - # Compute frequency bands - exponents = torch.arange(0, dim, 2, device=device) / dim - inv_freq = 1.0 / (self.base_frequency**exponents) - - # Generate position-dependent frequencies - positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - angles = torch.einsum("i,j->ij", positions, inv_freq) - - # Compute and cache frequency components - angles = angles.to(dtype) - angles = torch.cat((angles, angles), dim=-1) - cos_components = angles.cos().to(dtype) - sin_components = angles.sin().to(dtype) - self.frequency_cache[cache_key] = (cos_components, sin_components) - - return self.frequency_cache[cache_key] - - @staticmethod - def _rotate_features(x: torch.Tensor) -> torch.Tensor: - """Performs feature rotation by splitting and recombining feature dimensions. - - Args: - x: Input tensor to rotate. - - Returns: - Rotated feature tensor. - """ - feature_dim = x.shape[-1] - x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def _apply_1d_rope( - self, - tokens: torch.Tensor, - positions: torch.Tensor, - cos_comp: torch.Tensor, - sin_comp: torch.Tensor, - ) -> torch.Tensor: - """Applies 1D rotary position embeddings along one dimension. - - Args: - tokens: Input token features. - positions: Position indices. - cos_comp: Cosine components for rotation. - sin_comp: Sine components for rotation. - - Returns: - Tokens with applied rotary position embeddings. - """ - if positions.dtype != torch.long: - positions = positions.long() - - # Embed positions with frequency components - cos = F.embedding(positions, cos_comp)[:, None, :, :] - sin = F.embedding(positions, sin_comp)[:, None, :, :] - - # Apply rotation - return (tokens * cos) + (self._rotate_features(tokens) * sin) - - def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: - """Applies 2D rotary position embeddings to input tokens. - - Args: - tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). - The feature dimension (dim) must be divisible by 4. - positions: Position tensor of shape (batch_size, n_tokens, 2) containing - the y and x coordinates for each token. - - Returns: - Tensor of same shape as input with applied 2D rotary position embeddings. - - Raises: - AssertionError: If input dimensions are invalid or positions are malformed. - """ - # Validate inputs - assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" - assert ( - positions.ndim == 3 and positions.shape[-1] == 2 - ), "Positions must have shape (batch_size, n_tokens, 2)" - - # Compute feature dimension for each spatial direction - feature_dim = tokens.size(-1) // 2 - - # Get frequency components - max_position = int(positions.max()) + 1 - cos_comp, sin_comp = self._compute_frequency_components( - feature_dim, max_position, tokens.device, tokens.dtype - ) - - # Split features for vertical and horizontal processing - vertical_features, horizontal_features = tokens.chunk(2, dim=-1) - - # Apply RoPE separately for each dimension - vertical_features = self._apply_1d_rope( - vertical_features, positions[..., 0], cos_comp, sin_comp - ) - horizontal_features = self._apply_1d_rope( - horizontal_features, positions[..., 1], cos_comp, sin_comp - ) - - # Combine processed features - return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/FastVGGT/vggt/layers/swiglu_ffn.py b/FastVGGT/vggt/layers/swiglu_ffn.py deleted file mode 100644 index 1dd991e1deb87141ccd282098d4b9d38fed6ef25..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/layers/swiglu_ffn.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -import os -from typing import Callable, Optional -import warnings - -from torch import Tensor, nn -import torch.nn.functional as F - - -class SwiGLUFFN(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = None, - drop: float = 0.0, - bias: bool = True, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) - self.w3 = nn.Linear(hidden_features, out_features, bias=bias) - - def forward(self, x: Tensor) -> Tensor: - x12 = self.w12(x) - x1, x2 = x12.chunk(2, dim=-1) - hidden = F.silu(x1) * x2 - return self.w3(hidden) - - -XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None -# try: -# if XFORMERS_ENABLED: -# from xformers.ops import SwiGLU - -# XFORMERS_AVAILABLE = True -# warnings.warn("xFormers is available (SwiGLU)") -# else: -# warnings.warn("xFormers is disabled (SwiGLU)") -# raise ImportError -# except ImportError: -SwiGLU = SwiGLUFFN -XFORMERS_AVAILABLE = False - -# warnings.warn("xFormers is not available (SwiGLU)") - - -class SwiGLUFFNFused(SwiGLU): - def __init__( - self, - in_features: int, - hidden_features: Optional[int] = None, - out_features: Optional[int] = None, - act_layer: Callable[..., nn.Module] = None, - drop: float = 0.0, - bias: bool = True, - ) -> None: - out_features = out_features or in_features - hidden_features = hidden_features or in_features - hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 - super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias) diff --git a/FastVGGT/vggt/layers/vision_transformer.py b/FastVGGT/vggt/layers/vision_transformer.py deleted file mode 100644 index 2e1aee388a9c38168657f78fa9436685582068c6..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/layers/vision_transformer.py +++ /dev/null @@ -1,446 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the Apache License, Version 2.0 -# found in the LICENSE file in the root directory of this source tree. - -# References: -# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py -# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py - -from functools import partial -import math -import logging -from typing import Sequence, Tuple, Union, Callable - -import torch -import torch.nn as nn -from torch.utils.checkpoint import checkpoint -from torch.nn.init import trunc_normal_ -from . import ( - Mlp, - PatchEmbed, - SwiGLUFFNFused, - MemEffAttention, - NestedTensorBlock as Block, -) - -logger = logging.getLogger("dinov2") - - -def named_apply( - fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False -) -> nn.Module: - if not depth_first and include_root: - fn(module=module, name=name) - for child_name, child_module in module.named_children(): - child_name = ".".join((name, child_name)) if name else child_name - named_apply( - fn=fn, - module=child_module, - name=child_name, - depth_first=depth_first, - include_root=True, - ) - if depth_first and include_root: - fn(module=module, name=name) - return module - - -class BlockChunk(nn.ModuleList): - def forward(self, x): - for b in self: - x = b(x) - return x - - -class DinoVisionTransformer(nn.Module): - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=True, - ffn_bias=True, - proj_bias=True, - drop_path_rate=0.0, - drop_path_uniform=False, - init_values=None, # for layerscale: None or 0 => no layerscale - embed_layer=PatchEmbed, - act_layer=nn.GELU, - block_fn=Block, - ffn_layer="mlp", - block_chunks=1, - num_register_tokens=0, - interpolate_antialias=False, - interpolate_offset=0.1, - qk_norm=False, - ): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - embed_dim (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - proj_bias (bool): enable bias for proj in attn if True - ffn_bias (bool): enable bias for ffn if True - drop_path_rate (float): stochastic depth rate - drop_path_uniform (bool): apply uniform drop rate across blocks - weight_init (str): weight init scheme - init_values (float): layer-scale init values - embed_layer (nn.Module): patch embedding layer - act_layer (nn.Module): MLP activation layer - block_fn (nn.Module): transformer block class - ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" - block_chunks: (int) split block sequence into block_chunks units for FSDP wrap - num_register_tokens: (int) number of extra cls tokens (so-called "registers") - interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings - interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings - """ - super().__init__() - norm_layer = partial(nn.LayerNorm, eps=1e-6) - - self.num_features = self.embed_dim = ( - embed_dim # num_features for consistency with other models - ) - self.num_tokens = 1 - self.n_blocks = depth - self.num_heads = num_heads - self.patch_size = patch_size - self.num_register_tokens = num_register_tokens - self.interpolate_antialias = interpolate_antialias - self.interpolate_offset = interpolate_offset - self.use_reentrant = False # hardcoded to False - - self.patch_embed = embed_layer( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - ) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter( - torch.zeros(1, num_patches + self.num_tokens, embed_dim) - ) - assert num_register_tokens >= 0 - self.register_tokens = ( - nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) - if num_register_tokens - else None - ) - - if drop_path_uniform is True: - dpr = [drop_path_rate] * depth - else: - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, depth) - ] # stochastic depth decay rule - - if ffn_layer == "mlp": - logger.info("using MLP layer as FFN") - ffn_layer = Mlp - elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": - logger.info("using SwiGLU layer as FFN") - ffn_layer = SwiGLUFFNFused - elif ffn_layer == "identity": - logger.info("using Identity layer as FFN") - - def f(*args, **kwargs): - return nn.Identity() - - ffn_layer = f - else: - raise NotImplementedError - - blocks_list = [ - block_fn( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - ffn_bias=ffn_bias, - drop_path=dpr[i], - norm_layer=norm_layer, - act_layer=act_layer, - ffn_layer=ffn_layer, - init_values=init_values, - qk_norm=qk_norm, - ) - for i in range(depth) - ] - if block_chunks > 0: - self.chunked_blocks = True - chunked_blocks = [] - chunksize = depth // block_chunks - for i in range(0, depth, chunksize): - # this is to keep the block index consistent if we chunk the block list - chunked_blocks.append( - [nn.Identity()] * i + blocks_list[i : i + chunksize] - ) - self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) - else: - self.chunked_blocks = False - self.blocks = nn.ModuleList(blocks_list) - - self.norm = norm_layer(embed_dim) - self.head = nn.Identity() - - self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) - - self.init_weights() - - def init_weights(self): - trunc_normal_(self.pos_embed, std=0.02) - nn.init.normal_(self.cls_token, std=1e-6) - if self.register_tokens is not None: - nn.init.normal_(self.register_tokens, std=1e-6) - named_apply(init_weights_vit_timm, self) - - def interpolate_pos_encoding(self, x, w, h): - previous_dtype = x.dtype - npatch = x.shape[1] - 1 - N = self.pos_embed.shape[1] - 1 - if npatch == N and w == h: - return self.pos_embed - pos_embed = self.pos_embed - class_pos_embed = pos_embed[:, 0] - patch_pos_embed = pos_embed[:, 1:] - dim = x.shape[-1] - w0 = w // self.patch_size - h0 = h // self.patch_size - M = int(math.sqrt(N)) # Recover the number of patches in each dimension - assert N == M * M - kwargs = {} - if self.interpolate_offset: - # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 - # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors - sx = float(w0 + self.interpolate_offset) / M - sy = float(h0 + self.interpolate_offset) / M - kwargs["scale_factor"] = (sx, sy) - else: - # Simply specify an output size instead of a scale factor - kwargs["size"] = (w0, h0) - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), - mode="bicubic", - antialias=self.interpolate_antialias, - **kwargs, - ) - assert (w0, h0) == patch_pos_embed.shape[-2:] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( - previous_dtype - ) - - def prepare_tokens_with_masks(self, x, masks=None): - B, nc, w, h = x.shape - x = self.patch_embed(x) - if masks is not None: - x = torch.where( - masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x - ) - - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = x + self.interpolate_pos_encoding(x, w, h) - - if self.register_tokens is not None: - x = torch.cat( - (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), - dim=1, - ) - - return x - - def forward_features_list(self, x_list, masks_list): - x = [ - self.prepare_tokens_with_masks(x, masks) - for x, masks in zip(x_list, masks_list) - ] - - for blk in self.blocks: - if self.training: - x = checkpoint(blk, x, use_reentrant=self.use_reentrant) - else: - x = blk(x) - - all_x = x - output = [] - for x, masks in zip(all_x, masks_list): - x_norm = self.norm(x) - output.append( - { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - "x_prenorm": x, - "masks": masks, - } - ) - return output - - def forward_features(self, x, masks=None): - if isinstance(x, list): - return self.forward_features_list(x, masks) - - x = self.prepare_tokens_with_masks(x, masks) - - for blk in self.blocks: - if self.training: - x = checkpoint(blk, x, use_reentrant=self.use_reentrant) - else: - x = blk(x) - - x_norm = self.norm(x) - return { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - "x_prenorm": x, - "masks": masks, - } - - def _get_intermediate_layers_not_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - # If n is an int, take the n last blocks. If it's a list, take them - output, total_block_len = [], len(self.blocks) - blocks_to_take = ( - range(total_block_len - n, total_block_len) if isinstance(n, int) else n - ) - for i, blk in enumerate(self.blocks): - x = blk(x) - if i in blocks_to_take: - output.append(x) - assert len(output) == len( - blocks_to_take - ), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def _get_intermediate_layers_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - output, i, total_block_len = [], 0, len(self.blocks[-1]) - # If n is an int, take the n last blocks. If it's a list, take them - blocks_to_take = ( - range(total_block_len - n, total_block_len) if isinstance(n, int) else n - ) - for block_chunk in self.blocks: - for blk in block_chunk[i:]: # Passing the nn.Identity() - x = blk(x) - if i in blocks_to_take: - output.append(x) - i += 1 - assert len(output) == len( - blocks_to_take - ), f"only {len(output)} / {len(blocks_to_take)} blocks found" - return output - - def get_intermediate_layers( - self, - x: torch.Tensor, - n: Union[int, Sequence] = 1, # Layers or n last layers to take - reshape: bool = False, - return_class_token: bool = False, - norm=True, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: - if self.chunked_blocks: - outputs = self._get_intermediate_layers_chunked(x, n) - else: - outputs = self._get_intermediate_layers_not_chunked(x, n) - if norm: - outputs = [self.norm(out) for out in outputs] - class_tokens = [out[:, 0] for out in outputs] - outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] - if reshape: - B, _, w, h = x.shape - outputs = [ - out.reshape(B, w // self.patch_size, h // self.patch_size, -1) - .permute(0, 3, 1, 2) - .contiguous() - for out in outputs - ] - if return_class_token: - return tuple(zip(outputs, class_tokens)) - return tuple(outputs) - - def forward(self, *args, is_training=True, **kwargs): - ret = self.forward_features(*args, **kwargs) - if is_training: - return ret - else: - return self.head(ret["x_norm_clstoken"]) - - -def init_weights_vit_timm(module: nn.Module, name: str = ""): - """ViT weight initialization, original timm impl (for reproducibility)""" - if isinstance(module, nn.Linear): - trunc_normal_(module.weight, std=0.02) - if module.bias is not None: - nn.init.zeros_(module.bias) - - -def vit_small(patch_size=16, num_register_tokens=0, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=384, - depth=12, - num_heads=6, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - return model - - -def vit_base(patch_size=16, num_register_tokens=0, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - return model - - -def vit_large(patch_size=16, num_register_tokens=0, **kwargs): - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=1024, - depth=24, - num_heads=16, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - return model - - -def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): - """ - Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 - """ - model = DinoVisionTransformer( - patch_size=patch_size, - embed_dim=1536, - depth=40, - num_heads=24, - mlp_ratio=4, - block_fn=partial(Block, attn_class=MemEffAttention), - num_register_tokens=num_register_tokens, - **kwargs, - ) - return model diff --git a/FastVGGT/vggt/models/__pycache__/aggregator.cpython-310.pyc b/FastVGGT/vggt/models/__pycache__/aggregator.cpython-310.pyc deleted file mode 100644 index ada2738eee9604acf0497bcafdfc52aa6f21119f..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/models/__pycache__/aggregator.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/models/__pycache__/vggt.cpython-310.pyc b/FastVGGT/vggt/models/__pycache__/vggt.cpython-310.pyc deleted file mode 100644 index 67f693db1cd59d1a8d87e8f5c1c77c57c9de47d2..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/models/__pycache__/vggt.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/models/aggregator.py b/FastVGGT/vggt/models/aggregator.py deleted file mode 100644 index 9e45bb204729f25a963fb083c158142ed01515c2..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/models/aggregator.py +++ /dev/null @@ -1,492 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from numpy import block -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.checkpoint import checkpoint -from typing import Optional, Tuple, Union, List, Dict, Any - -from vggt.layers import PatchEmbed -from vggt.layers.block import Block -from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter -from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 -import time - -logger = logging.getLogger(__name__) - -_RESNET_MEAN = [0.485, 0.456, 0.406] -_RESNET_STD = [0.229, 0.224, 0.225] - - -class Aggregator(nn.Module): - """ - The Aggregator applies alternating-attention over input frames, - as described in VGGT: Visual Geometry Grounded Transformer. - - Remember to set model.train() to enable gradient checkpointing to reduce memory usage. - - Args: - img_size (int): Image size in pixels. - patch_size (int): Size of each patch for PatchEmbed. - embed_dim (int): Dimension of the token embeddings. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. - num_register_tokens (int): Number of register tokens. - block_fn (nn.Module): The block type used for attention (Block by default). - qkv_bias (bool): Whether to include bias in QKV projections. - proj_bias (bool): Whether to include bias in the output projection. - ffn_bias (bool): Whether to include bias in MLP layers. - patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". - aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. - aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1. - qk_norm (bool): Whether to apply QK normalization. - rope_freq (int): Base frequency for rotary embedding. -1 to disable. - init_values (float): Init scale for layer scale. - """ - - def __init__( - self, - img_size=518, - patch_size=14, - embed_dim=1024, - depth=24, - num_heads=16, - mlp_ratio=4.0, - num_register_tokens=4, - block_fn=Block, - qkv_bias=True, - proj_bias=True, - ffn_bias=True, - patch_embed="dinov2_vitl14_reg", - aa_order=["frame", "global"], - aa_block_size=1, - qk_norm=True, - rope_freq=100, - init_values=0.01, - global_merging=True, - merging=0, - vis_attn_map=False, - ): - super().__init__() - - self.__build_patch_embed__( - patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim - ) - - # Initialize rotary position embedding if frequency > 0 - self.rope = ( - RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None - ) - self.position_getter = PositionGetter() if self.rope is not None else None - self.global_merging = global_merging - self.merging = merging - self.vis_attn_map = vis_attn_map - self.frame_blocks = nn.ModuleList( - [ - block_fn( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - ffn_bias=ffn_bias, - init_values=init_values, - qk_norm=qk_norm, - rope=self.rope, - ) - for _ in range(depth) - ] - ) - - self.global_blocks = nn.ModuleList( - [ - block_fn( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - ffn_bias=ffn_bias, - init_values=init_values, - qk_norm=qk_norm, - rope=self.rope, - ) - for _ in range(depth) - ] - ) - - self.depth = depth - self.aa_order = aa_order - self.patch_size = patch_size - self.aa_block_size = aa_block_size - - # Validate that depth is divisible by aa_block_size - if self.depth % self.aa_block_size != 0: - raise ValueError( - f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})" - ) - - self.aa_block_num = self.depth // self.aa_block_size - - # Note: We have two camera tokens, one for the first frame and one for the rest - # The same applies for register tokens - self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) - self.register_token = nn.Parameter( - torch.randn(1, 2, num_register_tokens, embed_dim) - ) - - # The patch tokens start after the camera and register tokens - self.patch_start_idx = 1 + num_register_tokens - - # Initialize parameters with small values - nn.init.normal_(self.camera_token, std=1e-6) - nn.init.normal_(self.register_token, std=1e-6) - - # Register normalization constants as buffers (use bf16-compatible tensor) - for name, value in ( - ("_resnet_mean", _RESNET_MEAN), - ("_resnet_std", _RESNET_STD), - ): - self.register_buffer( - name, - torch.tensor(value, dtype=torch.bfloat16).view(1, 1, 3, 1, 1), - persistent=False, - ) - - self.use_reentrant = False # hardcoded to False - - def __build_patch_embed__( - self, - patch_embed, - img_size, - patch_size, - num_register_tokens, - interpolate_antialias=True, - interpolate_offset=0.0, - block_chunks=0, - init_values=1.0, - embed_dim=1024, - ): - """ - Build the patch embed layer. If 'conv', we use a - simple PatchEmbed conv layer. Otherwise, we use a vision transformer. - """ - - if "conv" in patch_embed: - self.patch_embed = PatchEmbed( - img_size=img_size, - patch_size=patch_size, - in_chans=3, - embed_dim=embed_dim, - ) - else: - vit_models = { - "dinov2_vitl14_reg": vit_large, - "dinov2_vitb14_reg": vit_base, - "dinov2_vits14_reg": vit_small, - "dinov2_vitg2_reg": vit_giant2, - } - - self.patch_embed = vit_models[patch_embed]( - img_size=img_size, - patch_size=patch_size, - num_register_tokens=num_register_tokens, - interpolate_antialias=interpolate_antialias, - interpolate_offset=interpolate_offset, - block_chunks=block_chunks, - init_values=init_values, - ) - - # Disable gradient updates for mask token - if hasattr(self.patch_embed, "mask_token"): - self.patch_embed.mask_token.requires_grad_(False) - - def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]: - """ - Args: - images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. - B: batch size, S: sequence length, 3: RGB channels, H: height, W: width - - Returns: - (list[torch.Tensor], int): - The list of outputs from the attention blocks, - and the patch_start_idx indicating where patch tokens begin. - """ - B, S, C_in, H, W = images.shape - - if C_in != 3: - raise ValueError(f"Expected 3 input channels, got {C_in}") - - # Normalize images and reshape for patch embed - ensure bf16 computation - images = images.to(torch.bfloat16) - images = (images - self._resnet_mean) / self._resnet_std - - images = images.view(B * S, C_in, H, W) - patch_tokens = self.patch_embed(images) - del images - - if isinstance(patch_tokens, dict): - patch_tokens = patch_tokens["x_norm_patchtokens"] - - patch_tokens = patch_tokens.to(torch.bfloat16) - - _, P, C = patch_tokens.shape - - # Expand camera and register tokens to match batch size and sequence length - camera_token = slice_expand_and_flatten( - self.camera_token.to(torch.bfloat16), B, S - ) - register_token = slice_expand_and_flatten( - self.register_token.to(torch.bfloat16), B, S - ) - - tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) - tokens = tokens.to(torch.bfloat16) - del camera_token, register_token, patch_tokens - # Explicitly clean up image data since patch embedding is complete - if "images_normalized" in locals(): - del images_normalized - - pos = None - if self.rope is not None: - pos = self.position_getter( - B * S, H // self.patch_size, W // self.patch_size, device="cuda" - ) - - if self.patch_start_idx > 0: - # do not use position embedding for special tokens (camera and register tokens) - # so set pos to 0 for the special tokens - pos_original = pos - pos = pos + 1 - pos_special = torch.zeros( - B * S, self.patch_start_idx, 2, device="cuda", dtype=torch.long - ) - pos = torch.cat([pos_special, pos], dim=1) - # Clean up temporary variables - del pos_special, pos_original - - # update P because we added special tokens - _, P, C = tokens.shape - - frame_idx = 0 - global_idx = 0 - output_list = [] - block4DPT_idx = [4, 11, 17, 23] - global_merging = None - - # Set global variables for attention visualization - if self.vis_attn_map: - import vggt.layers.attention as attn_module - - # Set the global variables that attention.py needs - attn_module.vis_attn_map = True - attn_module.current_images = self._load_image_paths() # Load from temp file - else: - import vggt.layers.attention as attn_module - - attn_module.vis_attn_map = False - - for block_num in range(self.aa_block_num): - torch.cuda.synchronize() - torch.cuda.empty_cache() - - need_intermediates = True if block_num in block4DPT_idx else False - if block_num % 1 == 0: - # Clean up RoPE cache to prevent accumulation - if hasattr(self, "rope") and self.rope is not None: - if hasattr(self.rope, "frequency_cache"): - self.rope.frequency_cache.clear() - # Clean up position cache - if ( - hasattr(self, "position_getter") - and self.position_getter is not None - ): - if hasattr(self.position_getter, "position_cache"): - # Keep only current size cache, clean up others - current_cache = self.position_getter.position_cache.copy() - if ( - len(current_cache) > 1 - ): # If there are multiple cache entries - self.position_getter.position_cache.clear() - # Keep only the most recently used one - if current_cache: - key = list(current_cache.keys())[-1] - self.position_getter.position_cache[key] = ( - current_cache[key] - ) - # Avoid saving block_num to instance variable to reduce references - for attn_type in self.aa_order: - if attn_type == "frame": - tokens, frame_idx, frame_intermediates = ( - self._process_frame_attention( - tokens, - B, - S, - P, - C, - frame_idx, - pos=pos, - need_intermediates=need_intermediates, - ) - ) - elif attn_type == "global": - if self.merging is None: - global_merging = None - elif self.global_merging and block_num >= self.merging: - global_merging = block_num - # Set attention_map for visualization - if self.vis_attn_map: - import vggt.layers.attention as attn_module - - attn_module.attention_map = block_num - tokens, global_idx, global_intermediates = ( - self._process_global_attention( - tokens, - B, - S, - P, - C, - global_idx, - pos=pos, - global_merging=global_merging, - need_intermediates=need_intermediates, - ) - ) - else: - raise ValueError(f"Unknown attention type: {attn_type}") - - if block_num not in block4DPT_idx: - if "frame_intermediates" in locals(): - del frame_intermediates - if "global_intermediates" in locals(): - del global_intermediates - else: - concat_inter = torch.cat( - [frame_intermediates[0].detach(), global_intermediates[0].detach()], - dim=-1, - ) - if concat_inter.dtype != torch.bfloat16: - concat_inter = concat_inter.to(torch.bfloat16) - output_list.append(concat_inter) - del concat_inter, frame_intermediates, global_intermediates - - # Do final cleanup before returning - del tokens, pos - if "pos_special" in locals(): - del pos_special - if "pos_original" in locals(): - del pos_original - torch.cuda.empty_cache() # Final cleanup - - return output_list, self.patch_start_idx - - def _process_frame_attention( - self, tokens, B, S, P, C, frame_idx, pos=None, need_intermediates=False - ): - """ - Process frame attention blocks. We keep tokens in shape (B*S, P, C). - """ - # If needed, reshape tokens or positions: - if tokens.shape != (B * S, P, C): - tokens = tokens.view(B, S, P, C).view(B * S, P, C) - - if pos is not None and pos.shape != (B * S, P, 2): - pos = pos.view(B, S, P, 2).view(B * S, P, 2) - - intermediates = [] if need_intermediates else None - - # by default, self.aa_block_size=1, which processes one block at a time - for _ in range(self.aa_block_size): - tokens = self.frame_blocks[frame_idx](tokens, pos=pos) - frame_idx += 1 - if need_intermediates: - intermediates.append(tokens.view(B, S, P, C)) - - return tokens, frame_idx, intermediates - - def _process_global_attention( - self, - tokens, - B, - S, - P, - C, - global_idx, - pos=None, - global_merging=None, - need_intermediates=False, - ): - """ - Process global attention blocks. We keep tokens in shape (B, S*P, C). - """ - if tokens.shape != (B, S * P, C): - tokens = tokens.view(B, S, P, C).view(B, S * P, C) - - if pos is not None and pos.shape != (B, S * P, 2): - pos = pos.view(B, S, P, 2).view(B, S * P, 2) - - intermediates = [] if need_intermediates else None - - # by default, self.aa_block_size=1, which processes one block at a time - for _ in range(self.aa_block_size): - tokens = self.global_blocks[global_idx]( - tokens, - pos=pos, - global_merging=global_merging, - ) - global_idx += 1 - if need_intermediates: - intermediates.append(tokens.view(B, S, P, C)) - - return tokens, global_idx, intermediates - - def _load_image_paths(self): - """Load image paths from temporary file for visualization""" - try: - import os - import tempfile - import pickle - - temp_dir = tempfile.gettempdir() - image_paths_file = os.path.join(temp_dir, "vggt_image_paths.pkl") - - if os.path.exists(image_paths_file): - with open(image_paths_file, "rb") as f: - return pickle.load(f) - else: - return [] - except Exception as e: - print(f"Warning: Could not load image paths for visualization: {e}") - return [] - - -def slice_expand_and_flatten(token_tensor, B, S): - """ - Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: - 1) Uses the first position (index=0) for the first frame only - 2) Uses the second position (index=1) for all remaining frames (S-1 frames) - 3) Expands both to match batch size B - 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token - followed by (S-1) second-position tokens - 5) Flattens to (B*S, X, C) for processing - - Returns: - torch.Tensor: Processed tokens with shape (B*S, X, C) - """ - - query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) - # Slice out the "other" tokens => shape (1, S-1, ...) - others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) - # Concatenate => shape (B, S, ...) - combined = torch.cat([query, others], dim=1) - - # Finally flatten => shape (B*S, ...) - combined = combined.view(B * S, *combined.shape[2:]) - return combined diff --git a/FastVGGT/vggt/models/vggt.py b/FastVGGT/vggt/models/vggt.py deleted file mode 100644 index 2efeb5271de9916d05fd231a0a5ec1694f999bd0..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/models/vggt.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn -from huggingface_hub import PyTorchModelHubMixin # used for model hub - -from vggt.models.aggregator import Aggregator -from vggt.heads.camera_head import CameraHead -from vggt.heads.dpt_head import DPTHead -from vggt.heads.track_head import TrackHead - - -class VGGT(nn.Module, PyTorchModelHubMixin): - def __init__( - self, - img_size=518, - patch_size=14, - embed_dim=1024, - enable_camera=True, - enable_point=False, - enable_depth=True, - enable_track=False, - merging=0, - vis_attn_map=False, - ): - super().__init__() - - self.vis_attn_map = vis_attn_map - - self.aggregator = Aggregator( - img_size=img_size, - patch_size=patch_size, - embed_dim=embed_dim, - merging=merging, - vis_attn_map=vis_attn_map, - ) - - self.camera_head = CameraHead(dim_in=2 * embed_dim) if enable_camera else None - self.point_head = ( - DPTHead( - dim_in=2 * embed_dim, - output_dim=4, - activation="inv_log", - conf_activation="expp1", - ) - if enable_point - else None - ) - self.depth_head = ( - DPTHead( - dim_in=2 * embed_dim, - output_dim=2, - activation="exp", - conf_activation="expp1", - ) - if enable_depth - else None - ) - self.track_head = ( - TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) - if enable_track - else None - ) - - def update_patch_dimensions(self, patch_width: int, patch_height: int): - """ - Update patch dimensions for all attention layers in the model - - Args: - patch_width: Patch width (typically 37) - patch_height: Patch height (typically 28) - """ - - def update_attention_in_module(module): - for name, child in module.named_children(): - # Recursively update submodules - update_attention_in_module(child) - # If it is an attention layer, update its patch dimensions - if hasattr(child, "patch_width") and hasattr(child, "patch_height"): - child.patch_width = patch_width - child.patch_height = patch_height - - # Update all attention layers in the aggregator - update_attention_in_module(self.aggregator) - - # print( - # f"🔧 Updated model attention layer patch dimensions: {patch_width}x{patch_height}" - # ) - - def forward( - self, - images: torch.Tensor, - query_points: torch.Tensor = None, - image_paths: list = None, - ): - """ - Forward pass of the VGGT model. - - Args: - images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. - B: batch size, S: sequence length, 3: RGB channels, H: height, W: width - query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates. - Shape: [N, 2] or [B, N, 2], where N is the number of query points. - Default: None - image_paths (list, optional): List of image file paths for attention visualization. - Only used when vis_attn_map=True. Default: None - - Returns: - dict: A dictionary containing the following predictions: - - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) - - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] - - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] - - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] - - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] - - images (torch.Tensor): Original input images, preserved for visualization - - If query_points is provided, also includes: - - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates - - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N] - - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N] - """ - # If without batch dimension, add it - if len(images.shape) == 4: - images = images.unsqueeze(0) - - if query_points is not None and len(query_points.shape) == 2: - query_points = query_points.unsqueeze(0) - - # Save image paths globally for attention visualization - if self.vis_attn_map and image_paths is not None: - import os - import tempfile - import pickle - - # Create a temporary file to store image paths - temp_dir = tempfile.gettempdir() - image_paths_file = os.path.join(temp_dir, "vggt_image_paths.pkl") - with open(image_paths_file, "wb") as f: - pickle.dump(image_paths, f) - - aggregated_tokens_list, patch_start_idx = self.aggregator(images) - - predictions = {} - - if self.camera_head is not None: - pose_enc_list = self.camera_head(aggregated_tokens_list) - predictions["pose_enc"] = pose_enc_list[ - -1 - ] # pose encoding of the last iteration - predictions["pose_enc_list"] = pose_enc_list - - if self.depth_head is not None: - depth, depth_conf = self.depth_head( - aggregated_tokens_list, - images=images, - patch_start_idx=patch_start_idx, - ) - predictions["depth"] = depth - predictions["depth_conf"] = depth_conf - - if self.point_head is not None: - pts3d, pts3d_conf = self.point_head( - aggregated_tokens_list, - images=images, - patch_start_idx=patch_start_idx, - ) - predictions["world_points"] = pts3d - predictions["world_points_conf"] = pts3d_conf - - if self.track_head is not None and query_points is not None: - track_list, vis, conf = self.track_head( - aggregated_tokens_list, - images=images, - patch_start_idx=patch_start_idx, - query_points=query_points, - ) - predictions["track"] = track_list[-1] # track of the last iteration - predictions["vis"] = vis - predictions["conf"] = conf - - if not self.training: - predictions["images"] = ( - images # store the images for visualization during inference - ) - - return predictions diff --git a/FastVGGT/vggt/utils/__init__.py b/FastVGGT/vggt/utils/__init__.py deleted file mode 100644 index 3a2958f93a114495631a3ce270b99ee8ba8443f1..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 95269f2111ee13cbadfcbab9cd10605555f25aee..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc deleted file mode 100644 index a3e08e180b2f41ce54c1912296117bf616196aae..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc deleted file mode 100644 index b7a92463e6f8f34289986fcaf2ab515c9598cf51..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc deleted file mode 100644 index 37978e9be49721f9e88beeb814a38e3dcb8754ad..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc deleted file mode 100644 index 93fcc5eddf5936461e8cbf8ca031e150105da24e..0000000000000000000000000000000000000000 Binary files a/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc and /dev/null differ diff --git a/FastVGGT/vggt/utils/eval_utils.py b/FastVGGT/vggt/utils/eval_utils.py deleted file mode 100644 index a81e2cab2b36c97e35b5ebb110b6e66fc814efa7..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/utils/eval_utils.py +++ /dev/null @@ -1,782 +0,0 @@ -from collections import deque -from PIL import Image -import cv2 -import io -import random -import time -from pathlib import Path -from typing import List, Tuple, Dict, Optional -from copy import deepcopy -import matplotlib.pyplot as plt -import evo.tools.plot as plot -import evo.main_ape as main_ape -import evo.main_rpe as main_rpe -import matplotlib.pyplot as plt -import numpy as np -import torch -from evo.core.metrics import PoseRelation, Unit -from evo.core.trajectory import PoseTrajectory3D -from scipy.linalg import svd -import open3d as o3d # for point cloud processing and Chamfer Distance computation - -from scipy.spatial.transform import Rotation -from torchvision import transforms as TF - - -from vggt.utils.geometry import unproject_depth_map_to_point_map -from vggt.utils.pose_enc import pose_encoding_to_extri_intri - - -def shuffle_deque(dq, seed=None): - # Set the random seed for reproducibility - if seed is not None: - random.seed(seed) - - # Convert deque to list, shuffle, and convert back - shuffled_list = list(dq) - random.shuffle(shuffled_list) - return deque(shuffled_list) - - -def imread_cv2(path, options=cv2.IMREAD_COLOR): - """Open an image or a depthmap with opencv-python.""" - if path.endswith((".exr", "EXR")): - options = cv2.IMREAD_ANYDEPTH - img = cv2.imread(path, options) - if img is None: - raise IOError(f"Could not load image={path} with {options=}") - if img.ndim == 3: - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - return img - - -# Import umeyama_alignment for internal use in eval_trajectory -def umeyama_alignment(src, dst, estimate_scale=True): - # Ensure inputs have correct shape - assert ( - src.shape == dst.shape - ), f"Input shapes don't match: src {src.shape}, dst {dst.shape}" - assert src.shape[0] == 3, f"Expected point cloud dimension (3,N), got {src.shape}" - - # Compute centroids - src_mean = src.mean(axis=1, keepdims=True) - dst_mean = dst.mean(axis=1, keepdims=True) - - # Center the point clouds - src_centered = src - src_mean - dst_centered = dst - dst_mean - - # Compute covariance matrix - cov = dst_centered @ src_centered.T - - try: - # Singular Value Decomposition - U, D, Vt = svd(cov) - V = Vt.T - - # Handle reflection case - det_UV = np.linalg.det(U @ V.T) - S = np.eye(3) - if det_UV < 0: - S[2, 2] = -1 - - # Compute rotation matrix - R = U @ S @ V.T - - if estimate_scale: - # Compute scale factor - fix dimension issue - src_var = np.sum(src_centered * src_centered) - if src_var < 1e-10: - print( - "Warning: Source point cloud variance close to zero, setting scale factor to 1.0" - ) - scale = 1.0 - else: - # Fix potential dimension issue with np.diag(S) - # Use diagonal elements directly - scale = np.sum(D * np.diag(S)) / src_var - else: - scale = 1.0 - - # Compute translation vector - t = dst_mean.ravel() - scale * (R @ src_mean).ravel() - - return scale, R, t - - except Exception as e: - print(f"Error in umeyama_alignment computation: {e}") - print( - "Returning default transformation: scale=1.0, rotation=identity matrix, translation=centroid difference" - ) - # Return default transformation - scale = 1.0 - R = np.eye(3) - t = (dst_mean - src_mean).ravel() - return scale, R, t - - -def compute_chamfer_distance(points_pred, points_gt, max_dist=1.0): - # Ensure point cloud size is not too large, which would cause slow computation - MAX_POINTS = 100000 - if points_pred.shape[0] > MAX_POINTS: - np.random.seed(33) # Fix random seed - indices = np.random.choice(points_pred.shape[0], MAX_POINTS, replace=False) - points_pred = points_pred[indices] - - if points_gt.shape[0] > MAX_POINTS: - np.random.seed(33) # Fix random seed - indices = np.random.choice(points_gt.shape[0], MAX_POINTS, replace=False) - points_gt = points_gt[indices] - - # Convert numpy point clouds to open3d point cloud objects - pcd_pred = o3d.geometry.PointCloud() - pcd_gt = o3d.geometry.PointCloud() - pcd_pred.points = o3d.utility.Vector3dVector(points_pred) - pcd_gt.points = o3d.utility.Vector3dVector(points_gt) - - # Downsample point clouds to accelerate computation - voxel_size = 0.05 # 5cm voxel size - pcd_pred = pcd_pred.voxel_down_sample(voxel_size) - pcd_gt = pcd_gt.voxel_down_sample(voxel_size) - - # Compute distances from predicted point cloud to GT point cloud - distances1 = np.asarray(pcd_pred.compute_point_cloud_distance(pcd_gt)) - # Compute distances from GT point cloud to predicted point cloud - distances2 = np.asarray(pcd_gt.compute_point_cloud_distance(pcd_pred)) - - # Apply distance clipping - distances1 = np.clip(distances1, 0, max_dist) - distances2 = np.clip(distances2, 0, max_dist) - - # Chamfer Distance is the sum of mean distances in both directions - chamfer_dist = np.mean(distances1) + np.mean(distances2) - - return chamfer_dist - - -def load_gt_pointcloud(scene_id, gt_ply_dir): - scene_dir = gt_ply_dir / scene_id - ply_path = scene_dir / (scene_id + "_vh_clean_2.ply") - pcd = o3d.io.read_point_cloud(str(ply_path)) - - # Convert to numpy arrays - points = np.asarray(pcd.points) - colors = None - try: - if pcd.has_colors(): - colors = np.asarray(pcd.colors) - except Exception: - colors = None - - return points, colors - - -def eval_trajectory(poses_est, poses_gt, frame_ids, align=False): - # Build reference trajectory object - traj_ref = PoseTrajectory3D( - positions_xyz=poses_gt[:, :3, 3], # Extract translation part - orientations_quat_wxyz=Rotation.from_matrix(poses_gt[:, :3, :3]).as_quat( - scalar_first=True - ), # Convert rotation matrix to quaternion - timestamps=frame_ids, - ) - - # Build estimated trajectory object - traj_est = PoseTrajectory3D( - positions_xyz=poses_est[:, :3, 3], - orientations_quat_wxyz=Rotation.from_matrix(poses_est[:, :3, :3]).as_quat( - scalar_first=True - ), - timestamps=frame_ids, - ) - - # Calculate Absolute Trajectory Error (ATE) - ate_result = main_ape.ape( - deepcopy(traj_ref), - deepcopy(traj_est), - est_name="traj", - pose_relation=PoseRelation.translation_part, - align=align, - correct_scale=align, - align_origin=True, - ) - ate = ate_result.stats["rmse"] - - # Get alignment transformation matrix - transform = np.eye(4) - if align: - try: - # Use umeyama algorithm to compute optimal rigid transformation (including rotation, translation and scaling) - aligned_xyz = ate_result.trajectories["traj"].positions_xyz - original_xyz = traj_est.positions_xyz - - # At least 3 points needed to compute reliable transformation - if len(aligned_xyz) >= 3 and len(original_xyz) >= 3: - # Ensure point count matches - min_points = min(len(aligned_xyz), len(original_xyz)) - aligned_xyz = aligned_xyz[:min_points] - original_xyz = original_xyz[:min_points] - - # Compute transformation matrix (scaling, rotation and translation) - try: - s, R, t = umeyama_alignment( - original_xyz.T, # Source point cloud (3xN) - aligned_xyz.T, # Target point cloud (3xN) - True, # Whether to estimate scaling - ) - - # Build complete transformation matrix - transform = np.eye(4) - transform[:3, :3] = s * R # Scaling and rotation - transform[:3, 3] = t # Translation - - except Exception as e: - print(f"umeyama_alignment failed: {e}") - else: - print( - "Insufficient points, cannot reliably compute transformation matrix" - ) - except Exception as e: - print(f"Error computing transformation matrix: {e}") - - # If the above method fails, fallback to simple translation transformation - if np.array_equal(transform, np.eye(4)) and hasattr(ate_result, "trajectories"): - try: - # Get original and aligned first position - orig_pos = traj_est.positions_xyz[0] - aligned_pos = ate_result.trajectories["traj"].positions_xyz[0] - - # Calculate translation part - translation = aligned_pos - orig_pos - - # Update translation part of transformation matrix - transform[:3, 3] = translation - print(f"Fallback to simple translation transformation: {transform}") - except Exception as e: - print(f"Error building translation transformation: {e}") - print("Will use identity matrix") - - # Calculate Absolute Rotation Error (ARE) - are_result = main_ape.ape( - deepcopy(traj_ref), - deepcopy(traj_est), - est_name="traj", - pose_relation=PoseRelation.rotation_angle_deg, - align=align, - correct_scale=align, - align_origin=True, - ) - are = are_result.stats["rmse"] - - # Calculate Relative Pose Error (RPE) - rotation part - rpe_rots_result = main_rpe.rpe( - deepcopy(traj_ref), - deepcopy(traj_est), - est_name="traj", - pose_relation=PoseRelation.rotation_angle_deg, - align=align, - correct_scale=align, - delta=1, - delta_unit=Unit.frames, - rel_delta_tol=0.01, - all_pairs=True, - align_origin=True, - ) - rpe_rot = rpe_rots_result.stats["rmse"] - - # Calculate Relative Pose Error (RPE) - translation part - rpe_transs_result = main_rpe.rpe( - deepcopy(traj_ref), - deepcopy(traj_est), - est_name="traj", - pose_relation=PoseRelation.translation_part, - align=align, - correct_scale=align, - delta=1, - delta_unit=Unit.frames, - rel_delta_tol=0.01, - all_pairs=True, - align_origin=True, - ) - rpe_trans = rpe_transs_result.stats["rmse"] - - # Plot trajectory graph - plot_mode = plot.PlotMode.xz # Use correct PlotMode reference - fig = plt.figure() - ax = plot.prepare_axis(fig, plot_mode) - ax.set_title(f"ATE: {round(ate, 3)}, ARE: {round(are, 3)}") - - # Use reference trajectory (GT) for plotting - plot.traj(ax, plot_mode, traj_ref, "--", "gray", "gt") - - # Use aligned trajectory for visualization - if align: - traj_est_aligned = ate_result.trajectories["traj"] - plot.traj_colormap( - ax, - traj_est_aligned, - ate_result.np_arrays["error_array"], - plot_mode, - min_map=ate_result.stats["min"], - max_map=ate_result.stats["max"], - ) - else: - plot.traj_colormap( - ax, - traj_est, - ate_result.np_arrays["error_array"], - plot_mode, - min_map=ate_result.stats["min"], - max_map=ate_result.stats["max"], - ) - - ax.legend() - - # Save image to memory buffer - buffer = io.BytesIO() - plt.savefig(buffer, format="png", dpi=90) - buffer.seek(0) - - pillow_image = Image.open(buffer) - pillow_image.load() - buffer.close() - plt.close(fig) - - return ( - {"ate": ate, "are": are, "rpe_rot": rpe_rot, "rpe_trans": rpe_trans}, - pillow_image, - transform, - ) - - -def load_poses(path): - # Read all txt files from pose directory - pose_files = sorted( - path.glob("*.txt"), key=lambda x: int(x.stem) - ) # Sort by numerical order - - # Check if pose files exist - if len(pose_files) == 0: - print(f"Warning: No pose files (.txt) found in directory {path}") - return None, None, None - - c2ws = [] - available_frame_ids = [] - - for pose_file in pose_files: - try: - with open(pose_file, "r") as f: - # Each file contains 16 numbers representing a 4x4 transformation matrix - nums = [float(x) for x in f.read().strip().split()] - pose = np.array(nums).reshape(4, 4) - # Check if pose is valid (no infinite or NaN values) - if not (np.isinf(pose).any() or np.isnan(pose).any()): - c2ws.append(pose) - available_frame_ids.append(int(pose_file.stem)) - else: - continue - except Exception as e: - print(f"Error reading pose file {pose_file}: {e}") - continue - - if len(c2ws) == 0: - print(f"Warning: No valid pose files found in directory {path}") - return None, None, None - - c2ws = np.stack(c2ws) - available_frame_ids = np.array(available_frame_ids) - - # Transform all poses to first frame coordinate system - first_gt_pose = c2ws[0].copy() # Save original pose of first frame - c2ws = np.linalg.inv(c2ws[0]) @ c2ws - return c2ws, first_gt_pose, available_frame_ids - - -def get_vgg_input_imgs(images: np.ndarray): - to_tensor = TF.ToTensor() - vgg_input_images = [] - final_width = None - final_height = None - - for image in images: - img = Image.fromarray(image, mode="RGB") - width, height = img.size - # Resize image, maintain aspect ratio, ensure height is multiple of 14 - new_width = 518 - new_height = round(height * (new_width / width) / 14) * 14 - img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) - img = to_tensor(img) # Convert to tensor (0, 1) - - # If height exceeds 518, perform center cropping - if new_height > 518: - start_y = (new_height - 518) // 2 - img = img[:, start_y : start_y + 518, :] - final_height = 518 - else: - final_height = new_height - - final_width = new_width - vgg_input_images.append(img) - - vgg_input_images = torch.stack(vgg_input_images) - - # Calculate the patch dimensions (divided by 14 for patch size) - patch_width = final_width // 14 # 518 // 14 = 37 - patch_height = final_height // 14 # computed dynamically, typically 28 - - return vgg_input_images, patch_width, patch_height - - -def get_sorted_image_paths(images_dir): - image_paths = [] - for ext in ["*.jpg", "*.png", "*.jpeg"]: - image_paths.extend(sorted(images_dir.glob(ext))) - # image_paths.sort(key=lambda x: int(x.stem)) - return image_paths - - -def to_homogeneous(extrinsics): - n = extrinsics.shape[0] - homogeneous_extrinsics = np.eye(4)[None, :, :].repeat( - n, axis=0 - ) # Create identity matrix - homogeneous_extrinsics[:, :3, :4] = extrinsics # Copy [R|t] part - return homogeneous_extrinsics - - -def umeyama_alignment(src, dst, estimate_scale=True): - # Ensure inputs have correct shape - assert ( - src.shape == dst.shape - ), f"Input shapes don't match: src {src.shape}, dst {dst.shape}" - assert src.shape[0] == 3, f"Expected point cloud dimension (3,N), got {src.shape}" - - # Compute centroids - src_mean = src.mean(axis=1, keepdims=True) - dst_mean = dst.mean(axis=1, keepdims=True) - - # Center the point clouds - src_centered = src - src_mean - dst_centered = dst - dst_mean - - # Compute covariance matrix - cov = dst_centered @ src_centered.T - - try: - # Singular Value Decomposition - U, D, Vt = svd(cov) - V = Vt.T - - # Handle reflection case - det_UV = np.linalg.det(U @ V.T) - S = np.eye(3) - if det_UV < 0: - S[2, 2] = -1 - - # Compute rotation matrix - R = U @ S @ V.T - - if estimate_scale: - # Compute scale factor - fix dimension issue - src_var = np.sum(src_centered * src_centered) - if src_var < 1e-10: - print( - "Warning: Source point cloud variance close to zero, setting scale factor to 1.0" - ) - scale = 1.0 - else: - # Fix potential dimension issue with np.diag(S) - # Use diagonal elements directly - scale = np.sum(D * np.diag(S)) / src_var - else: - scale = 1.0 - - # Compute translation vector - t = dst_mean.ravel() - scale * (R @ src_mean).ravel() - - return scale, R, t - - except Exception as e: - print(f"Error in umeyama_alignment computation: {e}") - print( - "Returning default transformation: scale=1.0, rotation=identity matrix, translation=centroid difference" - ) - # Return default transformation - scale = 1.0 - R = np.eye(3) - t = (dst_mean - src_mean).ravel() - return scale, R, t - - -def align_point_clouds_scale(source_pc, target_pc): - # Compute bounding box sizes of point clouds - source_min = np.min(source_pc, axis=0) - source_max = np.max(source_pc, axis=0) - target_min = np.min(target_pc, axis=0) - target_max = np.max(target_pc, axis=0) - - source_size = source_max - source_min - target_size = target_max - target_min - - # Compute point cloud centers - source_center = (source_max + source_min) / 2 - target_center = (target_max + target_min) / 2 - - # Compute overall scale factor (using diagonal length) - source_diag = np.sqrt(np.sum(source_size**2)) - target_diag = np.sqrt(np.sum(target_size**2)) - - if source_diag < 1e-8: - print("Warning: Source point cloud size close to zero") - scale_factor = 1.0 - else: - scale_factor = target_diag / source_diag - - # Apply scaling (with source point cloud center as reference) - centered_source = source_pc - source_center - scaled_centered = centered_source * scale_factor - scaled_aligned_source = scaled_centered + target_center - - return scaled_aligned_source, scale_factor - - -def get_all_scenes(data_dir: Path, num_scenes: int) -> List[str]: - all_scenes = sorted([d.name for d in data_dir.iterdir() if d.is_dir()]) - if len(all_scenes) > num_scenes: - sample_interval = max(1, len(all_scenes) // num_scenes) - return all_scenes[::sample_interval][:num_scenes] - return all_scenes - - -def build_frame_selection( - image_paths: List[Path], - available_pose_frame_ids: np.ndarray, - input_frame: int, -) -> Tuple[List[int], List[Path], List[int]]: - all_image_frame_ids = [int(path.stem) for path in image_paths] - valid_frame_ids = sorted( - list(set(all_image_frame_ids) & set(available_pose_frame_ids)) - ) - if len(valid_frame_ids) > input_frame: - first_frame = valid_frame_ids[0] - remaining_frames = valid_frame_ids[1:] - step = max(1, len(remaining_frames) // (input_frame - 1)) - selected_remaining = remaining_frames[::step][: input_frame - 1] - selected_frame_ids = [first_frame] + selected_remaining - else: - selected_frame_ids = valid_frame_ids - - frame_id_to_path = {int(path.stem): path for path in image_paths} - selected_image_paths = [ - frame_id_to_path[fid] for fid in selected_frame_ids if fid in frame_id_to_path - ] - - pose_frame_to_idx = {fid: idx for idx, fid in enumerate(available_pose_frame_ids)} - selected_pose_indices = [ - pose_frame_to_idx[fid] for fid in selected_frame_ids if fid in pose_frame_to_idx - ] - - return selected_frame_ids, selected_image_paths, selected_pose_indices - - -def load_images_rgb(image_paths: List[Path]) -> List[np.ndarray]: - images: List[np.ndarray] = [] - for image_path in image_paths: - img = cv2.imread(str(image_path)) - if img is None: - continue - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - images.append(img) - return images - - -@torch.no_grad() -def infer_vggt_and_reconstruct( - model: torch.nn.Module, - vgg_input: torch.Tensor, - dtype: torch.dtype, - depth_conf_thresh: float, - image_paths: list = None, -) -> Tuple[ - np.ndarray, - np.ndarray, - List[np.ndarray], - List[np.ndarray], - List[np.ndarray], - float, -]: - torch.cuda.synchronize() - start = time.time() - with torch.cuda.amp.autocast(dtype=dtype): - vgg_input_cuda = vgg_input.cuda().to(torch.bfloat16) - predictions = model(vgg_input_cuda, image_paths=image_paths) - torch.cuda.synchronize() - end = time.time() - inference_time_ms = (end - start) * 1000.0 - - extrinsic, intrinsic = pose_encoding_to_extri_intri( - predictions["pose_enc"], (vgg_input.shape[2], vgg_input.shape[3]) - ) - - depth_tensor = predictions["depth"] - depth_conf = predictions["depth_conf"] - depth_conf_np = depth_conf[0].detach().float().cpu().numpy() - depth_mask = depth_conf_np >= depth_conf_thresh - depth_filtered = depth_tensor[0].detach().float().cpu().numpy() - depth_filtered[~depth_mask] = np.nan - depth_np = depth_filtered - - extrinsic_np = extrinsic[0].detach().float().cpu().numpy() - intrinsic_np = intrinsic[0].detach().float().cpu().numpy() - - world_points = unproject_depth_map_to_point_map( - depth_np, extrinsic_np, intrinsic_np - ) - all_points: List[np.ndarray] = [] - all_colors: List[np.ndarray] = [] - - # Prepare RGB images aligned with vgg_input for coloring point clouds (0-255, uint8) - vgg_np = vgg_input.detach().float().cpu().numpy() # [S, 3, H, W] in [0,1] - - for frame_idx in range(world_points.shape[0]): - points = world_points[frame_idx].reshape(-1, 3) - valid_mask = ~np.isnan(points).any(axis=1) & ~np.isinf(points).any(axis=1) - valid_points = points[valid_mask] - if len(valid_points) > 0: - all_points.append(valid_points) - - # Generate corresponding colors - img_chw = vgg_np[frame_idx] # [3, H, W] - img_hwc = ( - (np.transpose(img_chw, (1, 2, 0)) * 255.0).clip(0, 255).astype(np.uint8) - ) # [H, W, 3] uint8 - rgb_flat = img_hwc.reshape(-1, 3) - valid_colors = rgb_flat[valid_mask] - all_colors.append(valid_colors) - - camera_poses = to_homogeneous(extrinsic_np) - all_cam_to_world_mat = list(camera_poses) - - return ( - extrinsic_np, - intrinsic_np, - all_points, - all_colors, - all_cam_to_world_mat, - inference_time_ms, - ) - - -def evaluate_scene_and_save( - scene: str, - c2ws: np.ndarray, - first_gt_pose: np.ndarray, - frame_ids: List[int], - all_cam_to_world_mat: List[np.ndarray], - all_world_points: List[np.ndarray], - output_scene_dir: Path, - gt_ply_dir: Path, - chamfer_max_dist: float, - inference_time_ms: float, - plot_flag: bool, -) -> Optional[Dict[str, float]]: - if not all_cam_to_world_mat or not all_world_points: - print(f"Skipping {scene}: failed to obtain valid camera poses or point clouds") - return None - - output_scene_dir.mkdir(parents=True, exist_ok=True) - - poses_gt = c2ws - w2cs = np.linalg.inv(poses_gt) - traj_est_poses = np.array(all_cam_to_world_mat) - n = min(len(traj_est_poses), len(w2cs)) - timestamps = frame_ids[:n] - stats_aligned, traj_plot, _ = eval_trajectory( - traj_est_poses[:n], w2cs[:n], timestamps, align=True - ) - - try: - merged_point_cloud = np.vstack(all_world_points) - gt_point_cloud, _ = load_gt_pointcloud(scene, gt_ply_dir) - if gt_point_cloud is not None: - homogeneous_points = np.hstack( - [merged_point_cloud, np.ones((merged_point_cloud.shape[0], 1))] - ) - world_points_raw = np.dot(homogeneous_points, first_gt_pose.T)[:, :3] - - world_points_scaled, scale_factor = align_point_clouds_scale( - world_points_raw, gt_point_cloud - ) - - cd_value = compute_chamfer_distance( - world_points_scaled, gt_point_cloud, max_dist=chamfer_max_dist - ) - stats_aligned["chamfer_distance"] = float(cd_value) - stats_aligned["scale_factor"] = float(scale_factor) - except Exception as e: - print(f"Error computing Chamfer Distance for {scene}: {e}") - - all_metrics: Dict[str, float] = deepcopy(stats_aligned) - for metric_name, metric_value in list(stats_aligned.items()): - all_metrics[f"aligned_{metric_name}"] = metric_value - all_metrics["inference_time_ms"] = float(inference_time_ms) - - with open(output_scene_dir / "metrics.json", "w") as f: - import json - - json.dump(all_metrics, f, indent=4) - if plot_flag: - try: - traj_plot.save(output_scene_dir / "plot.png") - except Exception: - pass - - return all_metrics - - -def compute_average_metrics_and_save( - all_scenes_metrics: Dict[str, Dict[str, Dict[str, float]]], - output_path: Path, - input_frame: int, -) -> Dict[str, float]: - metric_names = [ - "chamfer_distance", - "ate", - "are", - "rpe_rot", - "rpe_trans", - "inference_time_ms", - ] - average_metrics_list: Dict[str, List[float]] = { - metric: [] for metric in metric_names - } - for _, metrics in all_scenes_metrics["scenes"].items(): - for metric_name, metric_value in metrics.items(): - if metric_name in average_metrics_list: - average_metrics_list[metric_name].append(float(metric_value)) - - average_metrics: Dict[str, float] = {} - for metric_name, values in average_metrics_list.items(): - average_metrics[metric_name] = float(np.mean(values)) if values else 0.0 - - all_scenes_metrics["average"] = average_metrics - output_path.mkdir(parents=True, exist_ok=True) - - input_frame_dir = output_path / f"input_frame_{input_frame}" - input_frame_dir.mkdir(parents=True, exist_ok=True) - - with open(input_frame_dir / "all_scenes_metrics.json", "w") as f: - import json - - json.dump(all_scenes_metrics, f, indent=4) - - with open(input_frame_dir / "average_metrics.json", "w") as f: - import json - - json.dump(average_metrics, f, indent=4) - - print("\nAverage metrics:") - for metric_name, value in average_metrics.items(): - print(f"{metric_name}: {value:.6f}") - - return average_metrics diff --git a/FastVGGT/vggt/utils/geometry.py b/FastVGGT/vggt/utils/geometry.py deleted file mode 100644 index f555516dbc8a7dd8c7b15e6fbc928a5bfe8f740b..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/utils/geometry.py +++ /dev/null @@ -1,324 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import os -import torch -import numpy as np - - -from vggt.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion - - -def unproject_depth_map_to_point_map( - depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray -) -> np.ndarray: - """ - Unproject a batch of depth maps to 3D world coordinates. - - Args: - depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) - extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) - intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) - - Returns: - np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) - """ - if isinstance(depth_map, torch.Tensor): - depth_map = depth_map.cpu().numpy() - if isinstance(extrinsics_cam, torch.Tensor): - extrinsics_cam = extrinsics_cam.cpu().numpy() - if isinstance(intrinsics_cam, torch.Tensor): - intrinsics_cam = intrinsics_cam.cpu().numpy() - - world_points_list = [] - for frame_idx in range(depth_map.shape[0]): - cur_world_points, _, _ = depth_to_world_coords_points( - depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] - ) - world_points_list.append(cur_world_points) - world_points_array = np.stack(world_points_list, axis=0) - - return world_points_array - - -def depth_to_world_coords_points( - depth_map: np.ndarray, - extrinsic: np.ndarray, - intrinsic: np.ndarray, - eps=1e-8, -) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Convert a depth map to world coordinates. - - Args: - depth_map (np.ndarray): Depth map of shape (H, W). - intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). - extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. - - Returns: - tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). - """ - if depth_map is None: - return None, None, None - - # Valid depth mask - point_mask = depth_map > eps - - # Convert depth map to camera coordinates - cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) - - # Multiply with the inverse of extrinsic matrix to transform to world coordinates - # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) - cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] - - R_cam_to_world = cam_to_world_extrinsic[:3, :3] - t_cam_to_world = cam_to_world_extrinsic[:3, 3] - - # Apply the rotation and translation to the camera coordinates - world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 - # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world - - return world_coords_points, cam_coords_points, point_mask - - -def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - """ - Convert a depth map to camera coordinates. - - Args: - depth_map (np.ndarray): Depth map of shape (H, W). - intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). - - Returns: - tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) - """ - H, W = depth_map.shape - assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" - assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" - - # Intrinsic parameters - fu, fv = intrinsic[0, 0], intrinsic[1, 1] - cu, cv = intrinsic[0, 2], intrinsic[1, 2] - - # Generate grid of pixel coordinates - u, v = np.meshgrid(np.arange(W), np.arange(H)) - - # Unproject to camera coordinates - x_cam = (u - cu) * depth_map / fu - y_cam = (v - cv) * depth_map / fv - z_cam = depth_map - - # Stack to form camera coordinates - cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) - - return cam_coords - - -def closed_form_inverse_se3(se3, R=None, T=None): - """ - Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. - - If `R` and `T` are provided, they must correspond to the rotation and translation - components of `se3`. Otherwise, they will be extracted from `se3`. - - Args: - se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. - R (optional): Nx3x3 array or tensor of rotation matrices. - T (optional): Nx3x1 array or tensor of translation vectors. - - Returns: - Inverted SE3 matrices with the same type and device as `se3`. - - Shapes: - se3: (N, 4, 4) - R: (N, 3, 3) - T: (N, 3, 1) - """ - # Check if se3 is a numpy array or a torch tensor - is_numpy = isinstance(se3, np.ndarray) - - # Validate shapes - if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): - raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") - - # Extract R and T if not provided - if R is None: - R = se3[:, :3, :3] # (N,3,3) - if T is None: - T = se3[:, :3, 3:] # (N,3,1) - - # Transpose R - if is_numpy: - # Compute the transpose of the rotation for NumPy - R_transposed = np.transpose(R, (0, 2, 1)) - # -R^T t for NumPy - top_right = -np.matmul(R_transposed, T) - inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) - else: - R_transposed = R.transpose(1, 2) # (N,3,3) - top_right = -torch.bmm(R_transposed, T) # (N,3,1) - inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) - inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) - - inverted_matrix[:, :3, :3] = R_transposed - inverted_matrix[:, :3, 3:] = top_right - - return inverted_matrix - - -# TODO: this code can be further cleaned up - - -def project_world_points_to_camera_points_batch(world_points, cam_extrinsics): - """ - Transforms 3D points to 2D using extrinsic and intrinsic parameters. - Args: - world_points (torch.Tensor): 3D points of shape BxSxHxWx3. - cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4. - Returns: - """ - # TODO: merge this into project_world_points_to_cam - - # device = world_points.device - # with torch.autocast(device_type=device.type, enabled=False): - ones = torch.ones_like(world_points[..., :1]) # shape: (B, S, H, W, 1) - world_points_h = torch.cat([world_points, ones], dim=-1) # shape: (B, S, H, W, 4) - - # extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4) - extrinsics_exp = cam_extrinsics.unsqueeze(2).unsqueeze(3) - - # world_points_h: (B, S, H, W, 4) -> (B, S, H, W, 4, 1) - world_points_h_exp = world_points_h.unsqueeze(-1) - - # Now perform the matrix multiplication - # (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) broadcasts to (B, S, H, W, 3, 1) - camera_points = torch.matmul(extrinsics_exp, world_points_h_exp).squeeze(-1) - - return camera_points - - - -def project_world_points_to_cam( - world_points, - cam_extrinsics, - cam_intrinsics=None, - distortion_params=None, - default=0, - only_points_cam=False, -): - """ - Transforms 3D points to 2D using extrinsic and intrinsic parameters. - Args: - world_points (torch.Tensor): 3D points of shape Px3. - cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. - cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. - distortion_params (torch.Tensor): Extra parameters of shape BxN, which is used for radial distortion. - Returns: - torch.Tensor: Transformed 2D points of shape BxNx2. - """ - device = world_points.device - # with torch.autocast(device_type=device.type, dtype=torch.double): - with torch.autocast(device_type=device.type, enabled=False): - N = world_points.shape[0] # Number of points - B = cam_extrinsics.shape[0] # Batch size, i.e., number of cameras - world_points_homogeneous = torch.cat( - [world_points, torch.ones_like(world_points[..., 0:1])], dim=1 - ) # Nx4 - # Reshape for batch processing - world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand( - B, -1, -1 - ) # BxNx4 - - # Step 1: Apply extrinsic parameters - # Transform 3D points to camera coordinate system for all cameras - cam_points = torch.bmm( - cam_extrinsics, world_points_homogeneous.transpose(-1, -2) - ) - - if only_points_cam: - return None, cam_points - - # Step 2: Apply intrinsic parameters and (optional) distortion - image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default) - - return image_points, cam_points - - - -def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0): - """ - Applies intrinsic parameters and optional distortion to the given 3D points. - - Args: - cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. - cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. - distortion_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. - default (float, optional): Default value to replace NaNs in the output. - - Returns: - pixel_coords (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. - """ - - # Normalized device coordinates (NDC) - cam_points = cam_points / cam_points[:, 2:3, :] - ndc_xy = cam_points[:, :2, :] - - # Apply distortion if distortion_params are provided - if distortion_params is not None: - x_distorted, y_distorted = apply_distortion(distortion_params, ndc_xy[:, 0], ndc_xy[:, 1]) - distorted_xy = torch.stack([x_distorted, y_distorted], dim=1) - else: - distorted_xy = ndc_xy - - # Prepare cam_points for batch matrix multiplication - cam_coords_homo = torch.cat( - (distorted_xy, torch.ones_like(distorted_xy[:, :1, :])), dim=1 - ) # Bx3xN - # Apply intrinsic parameters using batch matrix multiplication - pixel_coords = torch.bmm(cam_intrinsics, cam_coords_homo) # Bx3xN - - # Extract x and y coordinates - pixel_coords = pixel_coords[:, :2, :] # Bx2xN - - # Replace NaNs with default value - pixel_coords = torch.nan_to_num(pixel_coords, nan=default) - - return pixel_coords.transpose(1, 2) # BxNx2 - - - - -def cam_from_img(pred_tracks, intrinsics, extra_params=None): - """ - Normalize predicted tracks based on camera intrinsics. - Args: - intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3]. - pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2]. - extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. - Returns: - torch.Tensor: Normalized tracks tensor. - """ - - # We don't want to do intrinsics_inv = torch.inverse(intrinsics) here - # otherwise we can use something like - # tracks_normalized_homo = torch.bmm(pred_tracks_homo, intrinsics_inv.transpose(1, 2)) - - principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2) - focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2) - tracks_normalized = (pred_tracks - principal_point) / focal_length - - if extra_params is not None: - # Apply iterative undistortion - try: - tracks_normalized = iterative_undistortion( - extra_params, tracks_normalized - ) - except: - tracks_normalized = single_undistortion( - extra_params, tracks_normalized - ) - - return tracks_normalized \ No newline at end of file diff --git a/FastVGGT/vggt/utils/helper.py b/FastVGGT/vggt/utils/helper.py deleted file mode 100644 index 7b019189c85ff86645a4cf3756632aa8d4500649..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/utils/helper.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np - - -def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray: - """ - If mask has more than max_trues True values, - randomly keep only max_trues of them and set the rest to False. - """ - # 1D positions of all True entries - true_indices = np.flatnonzero(mask) # shape = (N_true,) - - # if already within budget, return as-is - if true_indices.size <= max_trues: - return mask - - # randomly pick which True positions to keep - sampled_indices = np.random.choice(true_indices, size=max_trues, replace=False) # shape = (max_trues,) - - # build new flat mask: True only at sampled positions - limited_flat_mask = np.zeros(mask.size, dtype=bool) - limited_flat_mask[sampled_indices] = True - - # restore original shape - return limited_flat_mask.reshape(mask.shape) - - -def create_pixel_coordinate_grid(num_frames, height, width): - """ - Creates a grid of pixel coordinates and frame indices for all frames. - Returns: - tuple: A tuple containing: - - points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3) - with x, y coordinates and frame indices - - y_coords (numpy.ndarray): Array of y coordinates for all frames - - x_coords (numpy.ndarray): Array of x coordinates for all frames - - f_coords (numpy.ndarray): Array of frame indices for all frames - """ - # Create coordinate grids for a single frame - y_grid, x_grid = np.indices((height, width), dtype=np.float32) - x_grid = x_grid[np.newaxis, :, :] - y_grid = y_grid[np.newaxis, :, :] - - # Broadcast to all frames - x_coords = np.broadcast_to(x_grid, (num_frames, height, width)) - y_coords = np.broadcast_to(y_grid, (num_frames, height, width)) - - # Create frame indices and broadcast - f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis] - f_coords = np.broadcast_to(f_idx, (num_frames, height, width)) - - # Stack coordinates and frame indices - points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1) - - return points_xyf diff --git a/FastVGGT/vggt/utils/load_fn.py b/FastVGGT/vggt/utils/load_fn.py deleted file mode 100644 index 3ae01ba2601f7b5039220bbae94725646dd66a41..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/utils/load_fn.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from PIL import Image -from torchvision import transforms as TF -import numpy as np - - -def load_and_preprocess_images_square(image_path_list, target_size=1024): - """ - Load and preprocess images by center padding to square and resizing to target size. - Also returns the position information of original pixels after transformation. - - Args: - image_path_list (list): List of paths to image files - target_size (int, optional): Target size for both width and height. Defaults to 518. - - Returns: - tuple: ( - torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size), - torch.Tensor: Array of shape (N, 5) containing [x1, y1, x2, y2, width, height] for each image - ) - - Raises: - ValueError: If the input list is empty - """ - # Check for empty list - if len(image_path_list) == 0: - raise ValueError("At least 1 image is required") - - images = [] - original_coords = [] # Renamed from position_info to be more descriptive - to_tensor = TF.ToTensor() - - for image_path in image_path_list: - # Open image - img = Image.open(image_path) - - # If there's an alpha channel, blend onto white background - if img.mode == "RGBA": - background = Image.new("RGBA", img.size, (255, 255, 255, 255)) - img = Image.alpha_composite(background, img) - - # Convert to RGB - img = img.convert("RGB") - - # Get original dimensions - width, height = img.size - - # Make the image square by padding the shorter dimension - max_dim = max(width, height) - - # Calculate padding - left = (max_dim - width) // 2 - top = (max_dim - height) // 2 - - # Calculate scale factor for resizing - scale = target_size / max_dim - - # Calculate final coordinates of original image in target space - x1 = left * scale - y1 = top * scale - x2 = (left + width) * scale - y2 = (top + height) * scale - - # Store original image coordinates and scale - original_coords.append(np.array([x1, y1, x2, y2, width, height])) - - # Create a new black square image and paste original - square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0)) - square_img.paste(img, (left, top)) - - # Resize to target size - square_img = square_img.resize( - (target_size, target_size), Image.Resampling.BICUBIC - ) - - # Convert to tensor - img_tensor = to_tensor(square_img) - images.append(img_tensor) - - # Stack all images - images = torch.stack(images) - original_coords = torch.from_numpy(np.array(original_coords)).float() - - # Add additional dimension if single image to ensure correct shape - if len(image_path_list) == 1: - if images.dim() == 3: - images = images.unsqueeze(0) - original_coords = original_coords.unsqueeze(0) - - return images, original_coords - - -def load_and_preprocess_images(image_path_list, mode="crop"): - """ - A quick start function to load and preprocess images for model input. - This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. - - Args: - image_path_list (list): List of paths to image files - mode (str, optional): Preprocessing mode, either "crop" or "pad". - - "crop" (default): Sets width to 518px and center crops height if needed. - - "pad": Preserves all pixels by making the largest dimension 518px - and padding the smaller dimension to reach a square shape. - - Returns: - torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) - - Raises: - ValueError: If the input list is empty or if mode is invalid - - Notes: - - Images with different dimensions will be padded with white (value=1.0) - - A warning is printed when images have different shapes - - When mode="crop": The function ensures width=518px while maintaining aspect ratio - and height is center-cropped if larger than 518px - - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio - and the smaller dimension is padded to reach a square shape (518x518) - - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements - """ - # Check for empty list - if len(image_path_list) == 0: - raise ValueError("At least 1 image is required") - - # Validate mode - if mode not in ["crop", "pad"]: - raise ValueError("Mode must be either 'crop' or 'pad'") - - images = [] - shapes = set() - to_tensor = TF.ToTensor() - target_size = 518 - - # First process all images and collect their shapes - for image_path in image_path_list: - # Open image - img = Image.open(image_path) - - # If there's an alpha channel, blend onto white background: - if img.mode == "RGBA": - # Create white background - background = Image.new("RGBA", img.size, (255, 255, 255, 255)) - # Alpha composite onto the white background - img = Image.alpha_composite(background, img) - - # Now convert to "RGB" (this step assigns white for transparent areas) - img = img.convert("RGB") - - width, height = img.size - - if mode == "pad": - # Make the largest dimension 518px while maintaining aspect ratio - if width >= height: - new_width = target_size - new_height = ( - round(height * (new_width / width) / 14) * 14 - ) # Make divisible by 14 - else: - new_height = target_size - new_width = ( - round(width * (new_height / height) / 14) * 14 - ) # Make divisible by 14 - else: # mode == "crop" - # Original behavior: set width to 518px - new_width = target_size - # Calculate height maintaining aspect ratio, divisible by 14 - new_height = round(height * (new_width / width) / 14) * 14 - - # Resize with new dimensions (width, height) - img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) - img = to_tensor(img) # Convert to tensor (0, 1) - - # Center crop height if it's larger than 518 (only in crop mode) - if mode == "crop" and new_height > target_size: - start_y = (new_height - target_size) // 2 - img = img[:, start_y : start_y + target_size, :] - - # For pad mode, pad to make a square of target_size x target_size - if mode == "pad": - h_padding = target_size - img.shape[1] - w_padding = target_size - img.shape[2] - - if h_padding > 0 or w_padding > 0: - pad_top = h_padding // 2 - pad_bottom = h_padding - pad_top - pad_left = w_padding // 2 - pad_right = w_padding - pad_left - - # Pad with white (value=1.0) - img = torch.nn.functional.pad( - img, - (pad_left, pad_right, pad_top, pad_bottom), - mode="constant", - value=1.0, - ) - - shapes.add((img.shape[1], img.shape[2])) - images.append(img) - - # Check if we have different shapes - # In theory our model can also work well with different shapes - if len(shapes) > 1: - print(f"Warning: Found images with different shapes: {shapes}") - # Find maximum dimensions - max_height = max(shape[0] for shape in shapes) - max_width = max(shape[1] for shape in shapes) - - # Pad images if necessary - padded_images = [] - for img in images: - h_padding = max_height - img.shape[1] - w_padding = max_width - img.shape[2] - - if h_padding > 0 or w_padding > 0: - pad_top = h_padding // 2 - pad_bottom = h_padding - pad_top - pad_left = w_padding // 2 - pad_right = w_padding - pad_left - - img = torch.nn.functional.pad( - img, - (pad_left, pad_right, pad_top, pad_bottom), - mode="constant", - value=1.0, - ) - padded_images.append(img) - images = padded_images - - images = torch.stack(images) # concatenate images - - # Ensure correct shape when single image - if len(image_path_list) == 1: - # Verify shape is (1, C, H, W) - if images.dim() == 3: - images = images.unsqueeze(0) - - return images diff --git a/FastVGGT/vggt/utils/pose_enc.py b/FastVGGT/vggt/utils/pose_enc.py deleted file mode 100644 index 9d3b964330af0e62f4d36d332317ae00cb99b3a9..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/utils/pose_enc.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from .rotation import quat_to_mat, mat_to_quat - - -def extri_intri_to_pose_encoding( - extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV" # e.g., (256, 512) -): - """Convert camera extrinsics and intrinsics to a compact pose encoding. - - This function transforms camera parameters into a unified pose encoding format, - which can be used for various downstream tasks like pose prediction or representation. - - Args: - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, - where B is batch size and S is sequence length. - In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. - The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. - intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. - Defined in pixels, with format: - [[fx, 0, cx], - [0, fy, cy], - [0, 0, 1]] - where fx, fy are focal lengths and (cx, cy) is the principal point - image_size_hw (tuple): Tuple of (height, width) of the image in pixels. - Required for computing field of view values. For example: (256, 512). - pose_encoding_type (str): Type of pose encoding to use. Currently only - supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). - - Returns: - torch.Tensor: Encoded camera pose parameters with shape BxSx9. - For "absT_quaR_FoV" type, the 9 dimensions are: - - [:3] = absolute translation vector T (3D) - - [3:7] = rotation as quaternion quat (4D) - - [7:] = field of view (2D) - """ - - # extrinsics: BxSx3x4 - # intrinsics: BxSx3x3 - - if pose_encoding_type == "absT_quaR_FoV": - R = extrinsics[:, :, :3, :3] # BxSx3x3 - T = extrinsics[:, :, :3, 3] # BxSx3 - - quat = mat_to_quat(R) - # Note the order of h and w here - H, W = image_size_hw - fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) - fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) - pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() - else: - raise NotImplementedError - - return pose_encoding - - -def pose_encoding_to_extri_intri( - pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512) -): - """Convert a pose encoding back to camera extrinsics and intrinsics. - - This function performs the inverse operation of extri_intri_to_pose_encoding, - reconstructing the full camera parameters from the compact encoding. - - Args: - pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, - where B is batch size and S is sequence length. - For "absT_quaR_FoV" type, the 9 dimensions are: - - [:3] = absolute translation vector T (3D) - - [3:7] = rotation as quaternion quat (4D) - - [7:] = field of view (2D) - image_size_hw (tuple): Tuple of (height, width) of the image in pixels. - Required for reconstructing intrinsics from field of view values. - For example: (256, 512). - pose_encoding_type (str): Type of pose encoding used. Currently only - supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). - build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. - If False, only extrinsics are returned and intrinsics will be None. - - Returns: - tuple: (extrinsics, intrinsics) - - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. - In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world - transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is - a 3x1 translation vector. - - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, - or None if build_intrinsics is False. Defined in pixels, with format: - [[fx, 0, cx], - [0, fy, cy], - [0, 0, 1]] - where fx, fy are focal lengths and (cx, cy) is the principal point, - assumed to be at the center of the image (W/2, H/2). - """ - - intrinsics = None - - if pose_encoding_type == "absT_quaR_FoV": - T = pose_encoding[..., :3] - quat = pose_encoding[..., 3:7] - fov_h = pose_encoding[..., 7] - fov_w = pose_encoding[..., 8] - - R = quat_to_mat(quat) - extrinsics = torch.cat([R, T[..., None]], dim=-1) - - if build_intrinsics: - H, W = image_size_hw - fy = (H / 2.0) / torch.tan(fov_h / 2.0) - fx = (W / 2.0) / torch.tan(fov_w / 2.0) - intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) - intrinsics[..., 0, 0] = fx - intrinsics[..., 1, 1] = fy - intrinsics[..., 0, 2] = W / 2 - intrinsics[..., 1, 2] = H / 2 - intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 - else: - raise NotImplementedError - - return extrinsics, intrinsics diff --git a/FastVGGT/vggt/utils/rotation.py b/FastVGGT/vggt/utils/rotation.py deleted file mode 100644 index f972afd8414c82fa1e9ed231725fd3f9f6ebde77..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/utils/rotation.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d - -import torch -import numpy as np -import torch.nn.functional as F - - -def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: - """ - Quaternion Order: XYZW or say ijkr, scalar-last - - Convert rotations given as quaternions to rotation matrices. - Args: - quaternions: quaternions with real part last, - as tensor of shape (..., 4). - - Returns: - Rotation matrices as tensor of shape (..., 3, 3). - """ - i, j, k, r = torch.unbind(quaternions, -1) - # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. - two_s = 2.0 / (quaternions * quaternions).sum(-1) - - o = torch.stack( - ( - 1 - two_s * (j * j + k * k), - two_s * (i * j - k * r), - two_s * (i * k + j * r), - two_s * (i * j + k * r), - 1 - two_s * (i * i + k * k), - two_s * (j * k - i * r), - two_s * (i * k - j * r), - two_s * (j * k + i * r), - 1 - two_s * (i * i + j * j), - ), - -1, - ) - return o.reshape(quaternions.shape[:-1] + (3, 3)) - - -def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: - """ - Convert rotations given as rotation matrices to quaternions. - - Args: - matrix: Rotation matrices as tensor of shape (..., 3, 3). - - Returns: - quaternions with real part last, as tensor of shape (..., 4). - Quaternion Order: XYZW or say ijkr, scalar-last - """ - if matrix.size(-1) != 3 or matrix.size(-2) != 3: - raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") - - batch_dim = matrix.shape[:-2] - m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) - - q_abs = _sqrt_positive_part( - torch.stack( - [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1 - ) - ) - - # we produce the desired quaternion multiplied by each of r, i, j, k - quat_by_rijk = torch.stack( - [ - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), - # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and - # `int`. - torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), - ], - dim=-2, - ) - - # We floor here at 0.1 but the exact level is not important; if q_abs is small, - # the candidate won't be picked. - flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) - quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) - - # if not for numerical problems, quat_candidates[i] should be same (up to a sign), - # forall i; we pick the best-conditioned one (with the largest denominator) - out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) - - # Convert from rijk to ijkr - out = out[..., [1, 2, 3, 0]] - - out = standardize_quaternion(out) - - return out - - -def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: - """ - Returns torch.sqrt(torch.max(0, x)) - but with a zero subgradient where x is 0. - """ - ret = torch.zeros_like(x) - positive_mask = x > 0 - if torch.is_grad_enabled(): - ret[positive_mask] = torch.sqrt(x[positive_mask]) - else: - ret = torch.where(positive_mask, torch.sqrt(x), ret) - return ret - - -def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: - """ - Convert a unit quaternion to a standard form: one in which the real - part is non negative. - - Args: - quaternions: Quaternions with real part last, - as tensor of shape (..., 4). - - Returns: - Standardized quaternions as tensor of shape (..., 4). - """ - return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) diff --git a/FastVGGT/vggt/utils/visual_track.py b/FastVGGT/vggt/utils/visual_track.py deleted file mode 100644 index 796c114ccba00b5f7850e04b9444a6cd5c44b154..0000000000000000000000000000000000000000 --- a/FastVGGT/vggt/utils/visual_track.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import cv2 -import torch -import numpy as np -import os - - -def color_from_xy(x, y, W, H, cmap_name="hsv"): - """ - Map (x, y) -> color in (R, G, B). - 1) Normalize x,y to [0,1]. - 2) Combine them into a single scalar c in [0,1]. - 3) Use matplotlib's colormap to convert c -> (R,G,B). - - You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). - """ - import matplotlib.cm - import matplotlib.colors - - x_norm = x / max(W - 1, 1) - y_norm = y / max(H - 1, 1) - # Simple combination: - c = (x_norm + y_norm) / 2.0 - - cmap = matplotlib.cm.get_cmap(cmap_name) - # cmap(c) -> (r,g,b,a) in [0,1] - rgba = cmap(c) - r, g, b = rgba[0], rgba[1], rgba[2] - return (r, g, b) # in [0,1], RGB order - - -def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"): - """ - Given all tracks in one sample (b), compute a (N,3) array of RGB color values - in [0,255]. The color is determined by the (x,y) position in the first - visible frame for each track. - - Args: - tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. - vis_mask_b: (S, N) boolean mask; if None, assume all are visible. - image_width, image_height: used for normalizing (x, y). - cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). - - Returns: - track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. - """ - S, N, _ = tracks_b.shape - track_colors = np.zeros((N, 3), dtype=np.uint8) - - if vis_mask_b is None: - # treat all as visible - vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) - - for i in range(N): - # Find first visible frame for track i - visible_frames = torch.where(vis_mask_b[:, i])[0] - if len(visible_frames) == 0: - # track is never visible; just assign black or something - track_colors[i] = (0, 0, 0) - continue - - first_s = int(visible_frames[0].item()) - # use that frame's (x,y) - x, y = tracks_b[first_s, i].tolist() - - # map (x,y) -> (R,G,B) in [0,1] - r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name) - # scale to [0,255] - r, g, b = int(r * 255), int(g * 255), int(b * 255) - track_colors[i] = (r, g, b) - - return track_colors - - -def visualize_tracks_on_images( - images, - tracks, - track_vis_mask=None, - out_dir="track_visuals_concat_by_xy", - image_format="CHW", # "CHW" or "HWC" - normalize_mode="[0,1]", - cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" - frames_per_row=4, # New parameter for grid layout - save_grid=True, # Flag to control whether to save the grid image -): - """ - Visualizes frames in a grid layout with specified frames per row. - Each track's color is determined by its (x,y) position - in the first visible frame (or frame 0 if always visible). - Finally convert the BGR result to RGB before saving. - Also saves each individual frame as a separate PNG file. - - Args: - images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. - tracks: torch.Tensor (S, N, 2), last dim = (x, y). - track_vis_mask: torch.Tensor (S, N) or None. - out_dir: folder to save visualizations. - image_format: "CHW" or "HWC". - normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 - cmap_name: a matplotlib colormap name for color_from_xy. - frames_per_row: number of frames to display in each row of the grid. - save_grid: whether to save all frames in one grid image. - - Returns: - None (saves images in out_dir). - """ - - if len(tracks.shape) == 4: - tracks = tracks.squeeze(0) - images = images.squeeze(0) - if track_vis_mask is not None: - track_vis_mask = track_vis_mask.squeeze(0) - - import matplotlib - - matplotlib.use("Agg") # for non-interactive (optional) - - os.makedirs(out_dir, exist_ok=True) - - S = images.shape[0] - _, N, _ = tracks.shape # (S, N, 2) - - # Move to CPU - images = images.cpu().clone() - tracks = tracks.cpu().clone() - if track_vis_mask is not None: - track_vis_mask = track_vis_mask.cpu().clone() - - # Infer H, W from images shape - if image_format == "CHW": - # e.g. images[s].shape = (3, H, W) - H, W = images.shape[2], images.shape[3] - else: - # e.g. images[s].shape = (H, W, 3) - H, W = images.shape[1], images.shape[2] - - # Pre-compute the color for each track i based on first visible position - track_colors_rgb = get_track_colors_by_position( - tracks, # shape (S, N, 2) - vis_mask_b=track_vis_mask if track_vis_mask is not None else None, - image_width=W, - image_height=H, - cmap_name=cmap_name, - ) - - # We'll accumulate each frame's drawn image in a list - frame_images = [] - - for s in range(S): - # shape => either (3, H, W) or (H, W, 3) - img = images[s] - - # Convert to (H, W, 3) - if image_format == "CHW": - img = img.permute(1, 2, 0) # (H, W, 3) - # else "HWC", do nothing - - img = img.numpy().astype(np.float32) - - # Scale to [0,255] if needed - if normalize_mode == "[0,1]": - img = np.clip(img, 0, 1) * 255.0 - elif normalize_mode == "[-1,1]": - img = (img + 1.0) * 0.5 * 255.0 - img = np.clip(img, 0, 255.0) - # else no normalization - - # Convert to uint8 - img = img.astype(np.uint8) - - # For drawing in OpenCV, convert to BGR - img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - - # Draw each visible track - cur_tracks = tracks[s] # shape (N, 2) - if track_vis_mask is not None: - valid_indices = torch.where(track_vis_mask[s])[0] - else: - valid_indices = range(N) - - cur_tracks_np = cur_tracks.numpy() - for i in valid_indices: - x, y = cur_tracks_np[i] - pt = (int(round(x)), int(round(y))) - - # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR - R, G, B = track_colors_rgb[i] - color_bgr = (int(B), int(G), int(R)) - cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1) - - # Convert back to RGB for consistent final saving: - img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) - - # Save individual frame - frame_path = os.path.join(out_dir, f"frame_{s:04d}.png") - # Convert to BGR for OpenCV imwrite - frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) - cv2.imwrite(frame_path, frame_bgr) - - frame_images.append(img_rgb) - - # Only create and save the grid image if save_grid is True - if save_grid: - # Calculate grid dimensions - num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division - - # Create a grid of images - grid_img = None - for row in range(num_rows): - start_idx = row * frames_per_row - end_idx = min(start_idx + frames_per_row, S) - - # Concatenate this row horizontally - row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1) - - # If this row has fewer than frames_per_row images, pad with black - if end_idx - start_idx < frames_per_row: - padding_width = (frames_per_row - (end_idx - start_idx)) * W - padding = np.zeros((H, padding_width, 3), dtype=np.uint8) - row_img = np.concatenate([row_img, padding], axis=1) - - # Add this row to the grid - if grid_img is None: - grid_img = row_img - else: - grid_img = np.concatenate([grid_img, row_img], axis=0) - - out_path = os.path.join(out_dir, "tracks_grid.png") - # Convert back to BGR for OpenCV imwrite - grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR) - cv2.imwrite(out_path, grid_img_bgr) - print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}") - - print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")