diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..6a2e9051cbd9c691c50f08e0bdc2a2849cc21148 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +FastVGGT/assets/attn_map.png filter=lfs diff=lfs merge=lfs -text +FastVGGT/assets/autolab_logo.png filter=lfs diff=lfs merge=lfs -text +FastVGGT/assets/main.png filter=lfs diff=lfs merge=lfs -text +FastVGGT/assets/vs.png filter=lfs diff=lfs merge=lfs -text diff --git a/FastVGGT/.gitignore b/FastVGGT/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..fc1e52d1adf9afe09ea2ead132f8fa4babf666cc --- /dev/null +++ b/FastVGGT/.gitignore @@ -0,0 +1,160 @@ +.hydra/ +output/ +ckpt/ +.vscode/ +dependency/ +# Byte-compiled / optimized / DLL files +__pycache__/ +**/__pycache__/ +*.py[cod] +*$py.class +test_logs/ +quick_start_logs/ +logs/ +*.pth +/data/ +*.png +eval_results/ +.vscode/ +.curosr/ + +# C extensions +*.so +LightGlue/ +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Profiling data +.prof + +# Folder specific to your needs +**/tmp/ +**/outputs/skyseg.onnx +skyseg.onnx + +# pixi environments +.pixi +*.egg-info diff --git a/FastVGGT/.vscode/launch.json b/FastVGGT/.vscode/launch.json new file mode 100644 index 0000000000000000000000000000000000000000..218c5dcf94c00a2c11d02a0a1deeab8f37807ebf --- /dev/null +++ b/FastVGGT/.vscode/launch.json @@ -0,0 +1,85 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + + { + "name": "launch", + "type": "debugpy", + "request": "launch", + "program": "/home/sy/code/vggt_0625/training/launch.py", + "console": "integratedTerminal", + "args": "${command:pickArgs}", + "env": { + "CUDA_VISIBLE_DEVICES": "3", + }, + "cwd": "/home/sy/code/vggt_0625/training", + "justMyCode": true, + "python": "/home/sy/anaconda3/envs/vggt/bin/python" + } + ,{ + "name": "train_scannet", + "type": "debugpy", + "request": "launch", + "program": "/home/sy/code/vggt_0625/training/launch_scannet.py", + "console": "integratedTerminal", + "args": [ + // "--config_name", "scannet", + // "--exp_name", "scannet_exp001", + // "--resume_checkpoint_path", "/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "7", + "WORLD_SIZE": "1", + "RANK": "0", + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345" + }, + "cwd": "/home/sy/code/vggt_0625/training", + "justMyCode": true, + "python": "/home/sy/anaconda3/envs/vggt/bin/python" + } + ,{ + "name": "eval_scannet", + "type": "debugpy", + "request": "launch", + "program": "/home/sy/code/FastVGGT/eval/eval_scannet.py", + "console": "integratedTerminal", + "args": [ + "--data_dir","/data/sy/scannetv2/process_scannet", + "--gt_ply_dir","/data/sy/scannetv2/OpenDataLab___ScanNet_v2/raw/scans", + "--output_path", "/home/sy/code/FastVGGT/eval_results", + "--merging", "0", + "--ckpt_path","/home/sy/code/vggt_0625/ckpt/model_tracker_fixed_e20.pt", + "--vis_attn_map" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "2" + }, + "justMyCode": true, + "python": "/home/sy/anaconda3/envs/fastvggt/bin/python" + }, + { + "name": "eval_cd", + "type": "debugpy", + "request": "launch", + "program": "/home/sy/code/FastVGGT/eval/eval_custom.py", + "console": "integratedTerminal", + "args": [ + "--merging", "0", + // "--kf","10", + // 
"--output_dir","/home/sy/code/vggt_0625/eval_results_cd", + "--data_path","/data/sy/segment-102751/", + "--vis_attn_map" + ], + "env": { + "CUDA_VISIBLE_DEVICES": "3" + }, + "justMyCode": true, + // "python": "/home/sy/anaconda3/envs/fastvggt/bin/python" + } + + ] +} \ No newline at end of file diff --git a/FastVGGT/README.md b/FastVGGT/README.md new file mode 100644 index 0000000000000000000000000000000000000000..274ea91f0e398fdf0e94b921fb5bb7c48fe799fd --- /dev/null +++ b/FastVGGT/README.md @@ -0,0 +1,163 @@ +
+

⚡️ FastVGGT: Training-Free Acceleration of Visual Geometry Transformer

+ +

+ Paper PDF + Project Page +

+ +Maclab Logo +Autolab Logo + + +**[Media Analytics & Computing Laboratory](https://mac.xmu.edu.cn/)**; **[AUTOLAB](https://zhipengzhang.cn/)** + + +[You Shen](https://mystorm16.github.io/), [Zhipeng Zhang](https://zhipengzhang.cn/), [Yansong Qu](https://quyans.github.io/), [Liujuan Cao](https://mac.xmu.edu.cn/ljcao/) +
+ + +## 📰 News +- [Sep 8, 2025] Added custom dataset evaluation. +- [Sep 3, 2025] Paper release. +- [Sep 2, 2025] Code release. + +## 🔭 Overview + +FastVGGT observes **strong similarity** in attention maps and leverages it to design a training-free acceleration method for long-sequence 3D reconstruction, **achieving up to 4× faster inference without sacrificing accuracy.** + +Autolab Logo + + +## ⚙️ Environment Setup +First, create a virtual environment using Conda, clone this repository to your local machine, and install the required dependencies. + + +```bash +conda create -n fastvggt python=3.10 +conda activate fastvggt +git clone git@github.com:mystorm16/FastVGGT.git +cd FastVGGT +pip install -r requirements.txt +``` + +Next, prepare the ScanNet dataset: http://www.scan-net.org/ScanNet/ + +Then, download the VGGT checkpoint (we use the checkpoint link provided in https://github.com/facebookresearch/vggt/tree/evaluation/evaluation): +```bash +wget https://huggingface.co/facebook/VGGT_tracker_fixed/resolve/main/model_tracker_fixed_e20.pt +``` + +Finally, configure the dataset path and VGGT checkpoint path. For example: +```bash + parser.add_argument( + "--data_dir", type=Path, default="/data/scannetv2/process_scannet" + ) + parser.add_argument( + "--gt_ply_dir", + type=Path, + default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans", + ) + parser.add_argument( + "--ckpt_path", + type=str, + default="./ckpt/model_tracker_fixed_e20.pt", + ) +``` + + +## 💎 Observation + +Note: A large number of input_frames may significantly slow down saving the visualization results. Please try using a smaller number first. +```bash +python eval/eval_scannet.py --input_frame 30 --vis_attn_map --merging 0 +``` + +We observe that many token-level attention maps are highly similar in each block, motivating our optimization of the Global Attention module. 
+ +Autolab Logo + + + +## 🏀 Evaluation +### Custom Dataset +Please organize the data according to the following directory: +``` +/ +├── images/ +│ ├── 000000.jpg +│ ├── 000001.jpg +│ └── ... +├── pose/ # Optional: Camera poses +│ ├── 000000.txt +│ ├── 000001.txt +│ └── ... +└── gt_ply/ # Optional: GT point cloud + └── scene_xxx.ply +``` +- Required: `images/` +- Additionally required when `--enable_evaluation` is enabled: `pose/` and `gt_ply/` + +Inference only: + +```bash +python eval/eval_custom.py \ + --data_path /path/to/your_dataset \ + --output_path ./eval_results_custom \ + --plot +``` + +Inference + Evaluation (requires `pose/` and `gt_ply/`): + +```bash +python eval/eval_custom.py \ + --data_path /path/to/your_dataset \ + --enable_evaluation \ + --output_path ./eval_results_custom \ + --plot +``` + +### ScanNet +Evaluate FastVGGT on the ScanNet dataset with 1,000 input images. The **--merging** parameter specifies the block index at which the merging strategy is applied: + +```bash +python eval/eval_scannet.py --input_frame 1000 --merging 0 +``` + +Evaluate Baseline VGGT on the ScanNet dataset with 1,000 input images: +```bash +python eval/eval_scannet.py --input_frame 1000 +``` +Autolab Logo + +### 7 Scenes & NRGBD +Evaluate across two datasets, sampling keyframes every 10 frames: +```bash +python eval/eval_7andN.py --kf 10 +``` + +## 🍺 Acknowledgements + +- Thanks to these great repositories: [VGGT](https://github.com/facebookresearch/vggt), [Dust3r](https://github.com/naver/dust3r), [Fast3R](https://github.com/facebookresearch/fast3r), [CUT3R](https://github.com/CUT3R/CUT3R), [MV-DUSt3R+](https://github.com/facebookresearch/mvdust3r), [StreamVGGT](https://github.com/wzzheng/StreamVGGT), [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long), [ToMeSD](https://github.com/dbolya/tomesd) and many other inspiring works in the community. + +- Special thanks to [Jianyuan Wang](https://jytime.github.io/) for his valuable discussions and suggestions on this work. 
+ + + + +## ⚖️ License +See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available. + +## Citation + +If you find this project helpful, please consider citing the following paper: +``` +@article{shen2025fastvggt, + title={FastVGGT: Training-Free Acceleration of Visual Geometry Transformer}, + author={Shen, You and Zhang, Zhipeng and Qu, Yansong and Cao, Liujuan}, + journal={arXiv preprint arXiv:2509.02560}, + year={2025} +} +``` \ No newline at end of file diff --git a/FastVGGT/assets/attn_map.png b/FastVGGT/assets/attn_map.png new file mode 100644 index 0000000000000000000000000000000000000000..98d6f2be8cdb49e5bccee7f902af292327fd0023 --- /dev/null +++ b/FastVGGT/assets/attn_map.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8477957f593c203bcf41df91ac3ed0d22329e22250fdab9f8f8674340964242c +size 2369408 diff --git a/FastVGGT/assets/autolab_logo.png b/FastVGGT/assets/autolab_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..c6ca9a41f5c05a1b2ae66368461672f46e20bba1 --- /dev/null +++ b/FastVGGT/assets/autolab_logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4fcead3160cbf561c4385cc8b938a17a94652e3d849da6497f053d32d1245596 +size 5126003 diff --git a/FastVGGT/assets/maclab_logo.png b/FastVGGT/assets/maclab_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..b96660decf945d3d786f41b9b42cb92798f9191f Binary files /dev/null and b/FastVGGT/assets/maclab_logo.png differ diff --git a/FastVGGT/assets/main.png b/FastVGGT/assets/main.png new file mode 100644 index 0000000000000000000000000000000000000000..3bcab2d519689dfba3283e606cd55cb84175ede5 --- /dev/null +++ b/FastVGGT/assets/main.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eecacb414647f01dc8a52b4aba5ff2556733f46d1b9129613e3f59aceff69685 +size 883717 diff --git a/FastVGGT/assets/vs.png b/FastVGGT/assets/vs.png new file mode 100644 
index 0000000000000000000000000000000000000000..3729d9ba86000b81d0e9c878b2db6f16c7f9c48a --- /dev/null +++ b/FastVGGT/assets/vs.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dac1ea5397b9f985c6fb86fc5f321313cc5a372d0917b7c1c1c7dd2cea6bec5f +size 230157 diff --git a/FastVGGT/eval/__pycache__/base.cpython-310.pyc b/FastVGGT/eval/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49170288c162684ebafa68e5242efd320f0f35b8 Binary files /dev/null and b/FastVGGT/eval/__pycache__/base.cpython-310.pyc differ diff --git a/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc b/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e6e0478775fe98dc22180bc43bad8073c43c907 Binary files /dev/null and b/FastVGGT/eval/__pycache__/criterion.cpython-310.pyc differ diff --git a/FastVGGT/eval/__pycache__/data.cpython-310.pyc b/FastVGGT/eval/__pycache__/data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d8ec6f8e1828fa87b02a4511490ad858de2093a Binary files /dev/null and b/FastVGGT/eval/__pycache__/data.cpython-310.pyc differ diff --git a/FastVGGT/eval/__pycache__/data.cpython-37.pyc b/FastVGGT/eval/__pycache__/data.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d76f2c95748956a04f009e99f39cefd9cc6ec4e0 Binary files /dev/null and b/FastVGGT/eval/__pycache__/data.cpython-37.pyc differ diff --git a/FastVGGT/eval/__pycache__/utils.cpython-310.pyc b/FastVGGT/eval/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..638db49f10789e53dcf0ff33c9aaa28b814e3e72 Binary files /dev/null and b/FastVGGT/eval/__pycache__/utils.cpython-310.pyc differ diff --git a/FastVGGT/eval/__pycache__/utils.cpython-37.pyc b/FastVGGT/eval/__pycache__/utils.cpython-37.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..fd73966947cb74c319d66d7bf221d73b5dde0ce0 Binary files /dev/null and b/FastVGGT/eval/__pycache__/utils.cpython-37.pyc differ diff --git a/FastVGGT/eval/base.py b/FastVGGT/eval/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4a716449a71f552ea408df0eb37e854cf4e92da6 --- /dev/null +++ b/FastVGGT/eval/base.py @@ -0,0 +1,273 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- +# base class for implementing datasets +# -------------------------------------------------------- +import PIL +import numpy as np +import torch + +from dataset_utils.transforms import ImgNorm +import dataset_utils.cropping as cropping +from utils import depthmap_to_absolute_camera_coordinates + + +class BaseStereoViewDataset: + """Define all basic options. + + Usage: + class MyDataset (BaseStereoViewDataset): + def _get_views(self, idx, rng): + # overload here + views = [] + views.append(dict(img=, ...)) + return views + """ + + def __init__( + self, + *, # only keyword arguments + split=None, + resolution=None, # square_size or (width, height) or list of [(width,height), ...] 
+ transform=ImgNorm, + aug_crop=False, + seed=None, + ): + self.num_views = 2 + self.split = split + self._set_resolutions(resolution) + + self.transform = transform + if isinstance(transform, str): + transform = eval(transform) + + self.aug_crop = aug_crop + self.seed = seed + + def __len__(self): + return len(self.scenes) + + def get_stats(self): + return f"{len(self)} pairs" + + def __repr__(self): + resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]" + return ( + f"""{type(self).__name__}({self.get_stats()}, + {self.split=}, + {self.seed=}, + resolutions={resolutions_str}, + {self.transform=})""".replace( + "self.", "" + ) + .replace("\n", "") + .replace(" ", "") + ) + + def _get_views(self, idx, resolution, rng): + raise NotImplementedError() + + def __getitem__(self, idx): + if isinstance(idx, tuple): + # the idx is specifying the aspect-ratio + idx, ar_idx = idx + else: + assert len(self._resolutions) == 1 + ar_idx = 0 + + # set-up the rng + if self.seed: # reseed for each __getitem__ + self._rng = np.random.default_rng(seed=self.seed + idx) + elif not hasattr(self, "_rng"): + seed = torch.initial_seed() # this is different for each dataloader process + self._rng = np.random.default_rng(seed=seed) + + # over-loaded code + resolution = self._resolutions[ + ar_idx + ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) + views = self._get_views(idx, resolution, self._rng) + + # check data-types + for v, view in enumerate(views): + assert ( + "pts3d" not in view + ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" + view["idx"] = v + + # encode the image + width, height = view["img"].size + view["true_shape"] = np.int32((height, width)) + view["img"] = self.transform(view["img"]) + + assert "camera_intrinsics" in view + if "camera_pose" not in view: + view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32) + else: + assert np.isfinite( + 
view["camera_pose"] + ).all(), f"NaN in camera pose for view {view_name(view)}" + assert "pts3d" not in view + assert "valid_mask" not in view + assert np.isfinite( + view["depthmap"] + ).all(), f"NaN in depthmap for view {view_name(view)}" + pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view) + + view["pts3d"] = pts3d + view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1) + + # check all datatypes + for key, val in view.items(): + res, err_msg = is_good_type(key, val) + assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" + K = view["camera_intrinsics"] + view["img_mask"] = True + view["ray_mask"] = False + view["ray_map"] = torch.full( + (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan + ) + view["update"] = True + view["reset"] = False + + # last thing done! + for view in views: + # transpose to make sure all views are the same size + transpose_to_landscape(view) + # this allows to check whether the RNG is is the same state each time + view["rng"] = int.from_bytes(self._rng.bytes(4), "big") + return views + + def _set_resolutions(self, resolutions): + """Set the resolution(s) of the dataset. 
+ Params: + - resolutions: int or tuple or list of tuples + """ + assert resolutions is not None, "undefined resolution" + + if not isinstance(resolutions, list): + resolutions = [resolutions] + + self._resolutions = [] + for resolution in resolutions: + if isinstance(resolution, int): + width = height = resolution + else: + width, height = resolution + assert isinstance( + width, int + ), f"Bad type for {width=} {type(width)=}, should be int" + assert isinstance( + height, int + ), f"Bad type for {height=} {type(height)=}, should be int" + assert width >= height + self._resolutions.append((width, height)) + + def _crop_resize_if_necessary( + self, image, depthmap, intrinsics, resolution, rng=None, info=None + ): + """This function: + - first downsizes the image with LANCZOS inteprolation, + which is better than bilinear interpolation in + """ + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + + # downscale with lanczos interpolation so that image.size == resolution + # cropping centered on the principal point + W, H = image.size + cx, cy = intrinsics[:2, 2].round().astype(int) + + # calculate min distance to margin + min_margin_x = min(cx, W - cx) + min_margin_y = min(cy, H - cy) + assert min_margin_x > W / 5, f"Bad principal point in view={info}" + assert min_margin_y > H / 5, f"Bad principal point in view={info}" + + ## Center crop + # Crop on the principal point, make it always centered + # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) + l, t = cx - min_margin_x, cy - min_margin_y + r, b = cx + min_margin_x, cy + min_margin_y + crop_bbox = (l, t, r, b) + + image, depthmap, intrinsics = cropping.crop_image_depthmap( + image, depthmap, intrinsics, crop_bbox + ) + + # # transpose the resolution if necessary + W, H = image.size # new size + assert resolution[0] >= resolution[1] + if H > 1.1 * W: + # image is portrait mode + resolution = resolution[::-1] + elif 0.9 < H / W < 1.1 and 
resolution[0] != resolution[1]: + # image is square, so we chose (portrait, landscape) randomly + if rng.integers(2): + resolution = resolution[::-1] + + # high-quality Lanczos down-scaling + target_resolution = np.array(resolution) + # # if self.aug_crop > 1: + # # target_resolution += rng.integers(0, self.aug_crop) + # if resolution != (224, 224): + # halfw, halfh = ((2*(W//2))//16)*8, ((2*(H//2))//16)*8 + # ## Recale with max factor, so one of width or height might be larger than target_resolution + # image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, (2*halfw, 2*halfh)) + # else: + image, depthmap, intrinsics = cropping.rescale_image_depthmap( + image, depthmap, intrinsics, target_resolution + ) + # actual cropping (if necessary) with bilinear interpolation + # if resolution == (224, 224): + intrinsics2 = cropping.camera_matrix_of_crop( + intrinsics, image.size, resolution, offset_factor=0.5 + ) + crop_bbox = cropping.bbox_from_intrinsics_in_out( + intrinsics, intrinsics2, resolution + ) + image, depthmap, intrinsics = cropping.crop_image_depthmap( + image, depthmap, intrinsics, crop_bbox + ) + return image, depthmap, intrinsics + + +def is_good_type(key, v): + """returns (is_good, err_msg)""" + if isinstance(v, (str, int, tuple)): + return True, None + if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): + return False, f"bad {v.dtype=}" + return True, None + + +def view_name(view, batch_index=None): + def sel(x): + return x[batch_index] if batch_index not in (None, slice(None)) else x + + db = sel(view["dataset"]) + label = sel(view["label"]) + instance = sel(view["instance"]) + return f"{db}/{label}/{instance}" + + +def transpose_to_landscape(view): + height, width = view["true_shape"] + + if width < height: + # rectify portrait to landscape + assert view["img"].shape == (3, height, width) + view["img"] = view["img"].swapaxes(1, 2) + + assert view["valid_mask"].shape == (height, width) + 
view["valid_mask"] = view["valid_mask"].swapaxes(0, 1) + + assert view["depthmap"].shape == (height, width) + view["depthmap"] = view["depthmap"].swapaxes(0, 1) + + assert view["pts3d"].shape == (height, width, 3) + view["pts3d"] = view["pts3d"].swapaxes(0, 1) + + # transpose x and y pixels + view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]] diff --git a/FastVGGT/eval/criterion.py b/FastVGGT/eval/criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..a63c991b5fd5b61325b56055d7d380a27a336971 --- /dev/null +++ b/FastVGGT/eval/criterion.py @@ -0,0 +1,534 @@ +import torch +import torch.nn as nn +from copy import copy, deepcopy + +from eval.dataset_utils.corr import geotrf, inv + + +def invalid_to_nans(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = float("nan") + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr + + +def invalid_to_zeros(arr, valid_mask, ndim=999): + if valid_mask is not None: + arr = arr.clone() + arr[~valid_mask] = 0 + nnz = valid_mask.view(len(valid_mask), -1).sum(1) + else: + nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image + if arr.ndim > ndim: + arr = arr.flatten(-2 - (arr.ndim - ndim), -2) + return arr, nnz + + +class BaseCriterion(nn.Module): + def __init__(self, reduction="mean"): + super().__init__() + self.reduction = reduction + + +class Criterion(nn.Module): + def __init__(self, criterion=None): + super().__init__() + assert isinstance( + criterion, BaseCriterion + ), f"{criterion} is not a proper criterion!" 
+ self.criterion = copy(criterion) + + def get_name(self): + return f"{type(self).__name__}({self.criterion})" + + def with_reduction(self, mode="none"): + res = loss = deepcopy(self) + while loss is not None: + assert isinstance(loss, Criterion) + loss.criterion.reduction = mode # make it return the loss for each sample + loss = loss._loss2 # we assume loss is a Multiloss + return res + + +class MultiLoss(nn.Module): + """Easily combinable losses (also keep track of individual loss values): + loss = MyLoss1() + 0.1*MyLoss2() + Usage: + Inherit from this class and override get_name() and compute_loss() + """ + + def __init__(self): + super().__init__() + self._alpha = 1 + self._loss2 = None + + def compute_loss(self, *args, **kwargs): + raise NotImplementedError() + + def get_name(self): + raise NotImplementedError() + + def __mul__(self, alpha): + assert isinstance(alpha, (int, float)) + res = copy(self) + res._alpha = alpha + return res + + __rmul__ = __mul__ # same + + def __add__(self, loss2): + assert isinstance(loss2, MultiLoss) + res = cur = copy(self) + + while cur._loss2 is not None: + cur = cur._loss2 + cur._loss2 = loss2 + return res + + def __repr__(self): + name = self.get_name() + if self._alpha != 1: + name = f"{self._alpha:g}*{name}" + if self._loss2: + name = f"{name} + {self._loss2}" + return name + + def forward(self, *args, **kwargs): + loss = self.compute_loss(*args, **kwargs) + if isinstance(loss, tuple): + loss, details = loss + elif loss.ndim == 0: + details = {self.get_name(): float(loss)} + else: + details = {} + loss = loss * self._alpha + + if self._loss2: + loss2, details2 = self._loss2(*args, **kwargs) + loss = loss + loss2 + details |= details2 + + return loss, details + + +class LLoss(BaseCriterion): + """L-norm loss""" + + def forward(self, a, b): + assert ( + a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3 + ), f"Bad shape = {a.shape}" + dist = self.distance(a, b) + + if self.reduction == "none": + return dist + if 
self.reduction == "sum": + return dist.sum() + if self.reduction == "mean": + return dist.mean() if dist.numel() > 0 else dist.new_zeros(()) + raise ValueError(f"bad {self.reduction=} mode") + + def distance(self, a, b): + raise NotImplementedError() + + +class L21Loss(LLoss): + """Euclidean distance between 3d points""" + + def distance(self, a, b): + return torch.norm(a - b, dim=-1) # normalized L2 distance + + +L21 = L21Loss() + + +def get_pred_pts3d(gt, pred, use_pose=False): + assert use_pose is True + return pred["pts3d_in_other_view"] # return! + + +def Sum(losses, masks, conf=None): + loss, mask = losses[0], masks[0] + if loss.ndim > 0: + # we are actually returning the loss for every pixels + if conf is not None: + return losses, masks, conf + return losses, masks + else: + # we are returning the global loss + for loss2 in losses[1:]: + loss = loss + loss2 + return loss + + +def get_norm_factor(pts, norm_mode="avg_dis", valids=None, fix_first=True): + assert pts[0].ndim >= 3 and pts[0].shape[-1] == 3 + assert pts[1] is None or (pts[1].ndim >= 3 and pts[1].shape[-1] == 3) + norm_mode, dis_mode = norm_mode.split("_") + + nan_pts = [] + nnzs = [] + + if norm_mode == "avg": + # gather all points together (joint normalization) + + for i, pt in enumerate(pts): + nan_pt, nnz = invalid_to_zeros(pt, valids[i], ndim=3) + nan_pts.append(nan_pt) + nnzs.append(nnz) + + if fix_first: + break + all_pts = torch.cat(nan_pts, dim=1) + + # compute distance to origin + all_dis = all_pts.norm(dim=-1) + if dis_mode == "dis": + pass # do nothing + elif dis_mode == "log1p": + all_dis = torch.log1p(all_dis) + else: + raise ValueError(f"bad {dis_mode=}") + + norm_factor = all_dis.sum(dim=1) / (torch.cat(nnzs).sum() + 1e-8) + else: + raise ValueError(f"Not implemented {norm_mode=}") + + norm_factor = norm_factor.clip(min=1e-8) + while norm_factor.ndim < pts[0].ndim: + norm_factor.unsqueeze_(-1) + + return norm_factor + + +def normalize_pointcloud_t( + pts, norm_mode="avg_dis", 
valids=None, fix_first=True, gt=False +): + if gt: + norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first) + res = [] + + for i, pt in enumerate(pts): + res.append(pt / norm_factor) + + else: + # pts_l, pts_r = pts + # use pts_l and pts_r[-1] as pts to normalize + norm_factor = get_norm_factor(pts, norm_mode, valids, fix_first) + + res = [] + + for i in range(len(pts)): + res.append(pts[i] / norm_factor) + # res_r.append(pts_r[i] / norm_factor) + + # res = [res_l, res_r] + + return res, norm_factor + + +@torch.no_grad() +def get_joint_pointcloud_depth(zs, valid_masks=None, quantile=0.5): + # set invalid points to NaN + _zs = [] + for i in range(len(zs)): + valid_mask = valid_masks[i] if valid_masks is not None else None + _z = invalid_to_nans(zs[i], valid_mask).reshape(len(zs[i]), -1) + _zs.append(_z) + + _zs = torch.cat(_zs, dim=-1) + + # compute median depth overall (ignoring nans) + if quantile == 0.5: + shift_z = torch.nanmedian(_zs, dim=-1).values + else: + shift_z = torch.nanquantile(_zs, quantile, dim=-1) + return shift_z # (B,) + + +@torch.no_grad() +def get_joint_pointcloud_center_scale(pts, valid_masks=None, z_only=False, center=True): + # set invalid points to NaN + + _pts = [] + for i in range(len(pts)): + valid_mask = valid_masks[i] if valid_masks is not None else None + _pt = invalid_to_nans(pts[i], valid_mask).reshape(len(pts[i]), -1, 3) + _pts.append(_pt) + + _pts = torch.cat(_pts, dim=1) + + # compute median center + _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3) + if z_only: + _center[..., :2] = 0 # do not center X and Y + + # compute median norm + _norm = ((_pts - _center) if center else _pts).norm(dim=-1) + scale = torch.nanmedian(_norm, dim=1).values + return _center[:, None, :, :], scale[:, None, None, None] + + +class Regr3D_t(Criterion, MultiLoss): + def __init__(self, criterion, norm_mode="avg_dis", gt_scale=False, fix_first=True): + super().__init__(criterion) + self.norm_mode = norm_mode + self.gt_scale = 
gt_scale + self.fix_first = fix_first + + def get_all_pts3d_t(self, gts, preds, dist_clip=None): + # everything is normalized w.r.t. camera of view1 + in_camera1 = inv(gts[0]["camera_pose"]) + + gt_pts = [] + valids = [] + pr_pts = [] + + for i, gt in enumerate(gts): + # in_camera1: Bs, 4, 4 gt['pts3d']: Bs, H, W, 3 + gt_pts.append(geotrf(in_camera1, gt["pts3d"])) + valid = gt["valid_mask"].clone() + + if dist_clip is not None: + # points that are too far-away == invalid + dis = gt["pts3d"].norm(dim=-1) + valid = valid & (dis <= dist_clip) + + valids.append(valid) + pr_pts.append(get_pred_pts3d(gt, preds[i], use_pose=True)) + # if i != len(gts)-1: + # pr_pts_l.append(get_pred_pts3d(gt, preds[i][0], use_pose=(i!=0))) + + # if i != 0: + # pr_pts_r.append(get_pred_pts3d(gt, preds[i-1][1], use_pose=(i!=0))) + + # pr_pts = (pr_pts_l, pr_pts_r) + + if self.norm_mode: + pr_pts, pr_factor = normalize_pointcloud_t( + pr_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=False + ) + else: + pr_factor = None + + if self.norm_mode and not self.gt_scale: + gt_pts, gt_factor = normalize_pointcloud_t( + gt_pts, self.norm_mode, valids, fix_first=self.fix_first, gt=True + ) + else: + gt_factor = None + + return gt_pts, pr_pts, gt_factor, pr_factor, valids, {} + + def compute_frame_loss(self, gts, preds, **kw): + gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( + self.get_all_pts3d_t(gts, preds, **kw) + ) + + pred_pts_l, pred_pts_r = pred_pts + + loss_all = [] + mask_all = [] + conf_all = [] + + loss_left = 0 + loss_right = 0 + pred_conf_l = 0 + pred_conf_r = 0 + + for i in range(len(gt_pts)): + + # Left (Reference) + if i != len(gt_pts) - 1: + frame_loss = self.criterion( + pred_pts_l[i][masks[i]], gt_pts[i][masks[i]] + ) + + loss_all.append(frame_loss) + mask_all.append(masks[i]) + conf_all.append(preds[i][0]["conf"]) + + # To compare target/reference loss + if i != 0: + loss_left += frame_loss.cpu().detach().numpy().mean() + pred_conf_l += 
preds[i][0]["conf"].cpu().detach().numpy().mean() + + # Right (Target) + if i != 0: + frame_loss = self.criterion( + pred_pts_r[i - 1][masks[i]], gt_pts[i][masks[i]] + ) + + loss_all.append(frame_loss) + mask_all.append(masks[i]) + conf_all.append(preds[i - 1][1]["conf"]) + + # To compare target/reference loss + if i != len(gt_pts) - 1: + loss_right += frame_loss.cpu().detach().numpy().mean() + pred_conf_r += preds[i - 1][1]["conf"].cpu().detach().numpy().mean() + + if pr_factor is not None and gt_factor is not None: + filter_factor = pr_factor[pr_factor > gt_factor] + else: + filter_factor = [] + + if len(filter_factor) > 0: + factor_loss = (filter_factor - gt_factor).abs().mean() + else: + factor_loss = 0.0 + + self_name = type(self).__name__ + details = { + self_name + "_pts3d_1": float(loss_all[0].mean()), + self_name + "_pts3d_2": float(loss_all[1].mean()), + self_name + "loss_left": float(loss_left), + self_name + "loss_right": float(loss_right), + self_name + "conf_left": float(pred_conf_l), + self_name + "conf_right": float(pred_conf_r), + } + + return Sum(loss_all, mask_all, conf_all), (details | monitoring), factor_loss + + +class ConfLoss_t(MultiLoss): + """Weighted regression by learned confidence. + Assuming the input pixel_loss is a pixel-level regression loss. 
+ + Principle: + high-confidence means high conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10) + low confidence means low conf = 10 ==> conf_loss = x * 10 - alpha*log(10) + + alpha: hyperparameter + """ + + def __init__(self, pixel_loss, alpha=1): + super().__init__() + assert alpha > 0 + self.alpha = alpha + self.pixel_loss = pixel_loss.with_reduction("none") + + def get_name(self): + return f"ConfLoss({self.pixel_loss})" + + def get_conf_log(self, x): + return x, torch.log(x) + + def compute_frame_loss(self, gts, preds, **kw): + # compute per-pixel loss + (losses, masks, confs), details, loss_factor = ( + self.pixel_loss.compute_frame_loss(gts, preds, **kw) + ) + + # weight by confidence + conf_losses = [] + conf_sum = 0 + for i in range(len(losses)): + conf, log_conf = self.get_conf_log(confs[i][masks[i]]) + conf_sum += conf.mean() + conf_loss = losses[i] * conf - self.alpha * log_conf + conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0 + conf_losses.append(conf_loss) + + conf_losses = torch.stack(conf_losses) * 2.0 + conf_loss_mean = conf_losses.mean() + + return ( + conf_loss_mean, + dict( + conf_loss_1=float(conf_losses[0]), + conf_loss2=float(conf_losses[1]), + conf_mean=conf_sum / len(losses), + **details, + ), + loss_factor, + ) + + +class Regr3D_t_ShiftInv(Regr3D_t): + """Same than Regr3D but invariant to depth shift.""" + + def get_all_pts3d_t(self, gts, preds): + # compute unnormalized points + gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( + super().get_all_pts3d_t(gts, preds) + ) + + # pred_pts_l, pred_pts_r = pred_pts + gt_zs = [gt_pt[..., 2] for gt_pt in gt_pts] + + pred_zs = [pred_pt[..., 2] for pred_pt in pred_pts] + # pred_zs.append(pred_pts_r[-1][..., 2]) + + # compute median depth + gt_shift_z = get_joint_pointcloud_depth(gt_zs, masks)[:, None, None] + pred_shift_z = get_joint_pointcloud_depth(pred_zs, masks)[:, None, None] + + # subtract the median depth + for i in range(len(gt_pts)): + gt_pts[i][..., 2] -= gt_shift_z 
+ + for i in range(len(pred_pts)): + # for j in range(len(pred_pts[i])): + pred_pts[i][..., 2] -= pred_shift_z + + monitoring = dict( + monitoring, + gt_shift_z=gt_shift_z.mean().detach(), + pred_shift_z=pred_shift_z.mean().detach(), + ) + return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring + + +class Regr3D_t_ScaleInv(Regr3D_t): + """Same than Regr3D but invariant to depth shift. + if gt_scale == True: enforce the prediction to take the same scale than GT + """ + + def get_all_pts3d_t(self, gts, preds): + # compute depth-normalized points + gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( + super().get_all_pts3d_t(gts, preds) + ) + + # measure scene scale + + # pred_pts_l, pred_pts_r = pred_pts + + pred_pts_all = [ + x.clone() for x in pred_pts + ] # [pred_pt for pred_pt in pred_pts_l] + # pred_pts_all.append(pred_pts_r[-1]) + + _, gt_scale = get_joint_pointcloud_center_scale(gt_pts, masks) + _, pred_scale = get_joint_pointcloud_center_scale(pred_pts_all, masks) + + # prevent predictions to be in a ridiculous range + pred_scale = pred_scale.clip(min=1e-3, max=1e3) + + # subtract the median depth + if self.gt_scale: + for i in range(len(pred_pts)): + # for j in range(len(pred_pts[i])): + pred_pts[i] *= gt_scale / pred_scale + + else: + for i in range(len(pred_pts)): + # for j in range(len(pred_pts[i])): + pred_pts[i] *= pred_scale / gt_scale + + for i in range(len(gt_pts)): + gt_pts[i] *= gt_scale / pred_scale + + monitoring = dict( + monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach() + ) + + return gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring + + +class Regr3D_t_ScaleShiftInv(Regr3D_t_ScaleInv, Regr3D_t_ShiftInv): + # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv + pass diff --git a/FastVGGT/eval/data.py b/FastVGGT/eval/data.py new file mode 100644 index 0000000000000000000000000000000000000000..301788e3ae6b67091fb031b068ecbecff1eae48d --- /dev/null +++ b/FastVGGT/eval/data.py @@ -0,0 +1,338 @@ 
+import os +import cv2 +import numpy as np +import os.path as osp +from collections import deque +from base import BaseStereoViewDataset +import dataset_utils.cropping as cropping +from vggt.utils.eval_utils import imread_cv2, shuffle_deque + + +class SevenScenes(BaseStereoViewDataset): + def __init__( + self, + num_seq=1, + num_frames=5, + min_thresh=10, + max_thresh=100, + test_id=None, + full_video=False, + tuple_list=None, + seq_id=None, + rebuttal=False, + shuffle_seed=-1, + kf_every=1, + *args, + ROOT, + **kwargs, + ): + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self.num_seq = num_seq + self.num_frames = num_frames + self.max_thresh = max_thresh + self.min_thresh = min_thresh + self.test_id = test_id + self.full_video = full_video + self.kf_every = kf_every + self.seq_id = seq_id + self.rebuttal = rebuttal + self.shuffle_seed = shuffle_seed + + # load all scenes + self.load_all_tuples(tuple_list) + self.load_all_scenes(ROOT) + + def __len__(self): + if self.tuple_list is not None: + return len(self.tuple_list) + return len(self.scene_list) * self.num_seq + + def load_all_tuples(self, tuple_list): + if tuple_list is not None: + self.tuple_list = tuple_list + # with open(tuple_path) as f: + # self.tuple_list = f.read().splitlines() + + else: + self.tuple_list = None + + def load_all_scenes(self, base_dir): + + if self.tuple_list is not None: + # Use pre-defined simplerecon scene_ids + self.scene_list = [ + "stairs/seq-06", + "stairs/seq-02", + "pumpkin/seq-06", + "chess/seq-01", + "heads/seq-02", + "fire/seq-02", + "office/seq-03", + "pumpkin/seq-03", + "redkitchen/seq-07", + "chess/seq-02", + "office/seq-01", + "redkitchen/seq-01", + "fire/seq-01", + ] + print(f"Found {len(self.scene_list)} sequences in split {self.split}") + return + + scenes = os.listdir(base_dir) + + file_split = {"train": "TrainSplit.txt", "test": "TestSplit.txt"}[self.split] + + self.scene_list = [] + for scene in scenes: + if self.test_id is not None and scene != 
self.test_id: + continue + # read file split + with open(osp.join(base_dir, scene, file_split)) as f: + seq_ids = f.read().splitlines() + + for seq_id in seq_ids: + # seq is string, take the int part and make it 01, 02, 03 + # seq_id = 'seq-{:2d}'.format(int(seq_id)) + num_part = "".join(filter(str.isdigit, seq_id)) + seq_id = f"seq-{num_part.zfill(2)}" + if self.seq_id is not None and seq_id != self.seq_id: + continue + self.scene_list.append(f"{scene}/{seq_id}") + + print(f"Found {len(self.scene_list)} sequences in split {self.split}") + + def _get_views(self, idx, resolution, rng): + + if self.tuple_list is not None: + line = self.tuple_list[idx].split(" ") + scene_id = line[0] + img_idxs = line[1:] + + else: + scene_id = self.scene_list[idx // self.num_seq] + seq_id = idx % self.num_seq + + data_path = osp.join(self.ROOT, scene_id) + num_files = len([name for name in os.listdir(data_path) if "color" in name]) + img_idxs = [f"{i:06d}" for i in range(num_files)] + img_idxs = img_idxs[:: self.kf_every] + + # Intrinsics used in SimpleRecon + fx, fy, cx, cy = 525, 525, 320, 240 + intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + + views = [] + imgs_idxs = deque(img_idxs) + if self.shuffle_seed >= 0: + imgs_idxs = shuffle_deque(imgs_idxs) + + while len(imgs_idxs) > 0: + im_idx = imgs_idxs.popleft() + impath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.color.png") + depthpath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.depth.proj.png") + posepath = osp.join(self.ROOT, scene_id, f"frame-{im_idx}.pose.txt") + + rgb_image = imread_cv2(impath) + + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0])) + + depthmap[depthmap == 65535] = 0 + depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0 + + depthmap[depthmap > 10] = 0 + depthmap[depthmap < 1e-3] = 0 + + camera_pose = np.loadtxt(posepath).astype(np.float32) + + if resolution != (224, 224) 
or self.rebuttal: + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath + ) + else: + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath + ) + W, H = rgb_image.size + cx = W // 2 + cy = H // 2 + l, t = cx - 112, cy - 112 + r, b = cx + 112, cy + 112 + crop_bbox = (l, t, r, b) + rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap( + rgb_image, depthmap, intrinsics, crop_bbox + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset="7scenes", + label=osp.join(scene_id, im_idx), + instance=impath, + ) + ) + return views + + +class NRGBD(BaseStereoViewDataset): + def __init__( + self, + num_seq=1, + num_frames=5, + min_thresh=10, + max_thresh=100, + test_id=None, + full_video=False, + tuple_list=None, + seq_id=None, + rebuttal=False, + shuffle_seed=-1, + kf_every=1, + *args, + ROOT, + **kwargs, + ): + + self.ROOT = ROOT + super().__init__(*args, **kwargs) + self.num_seq = num_seq + self.num_frames = num_frames + self.max_thresh = max_thresh + self.min_thresh = min_thresh + self.test_id = test_id + self.full_video = full_video + self.kf_every = kf_every + self.seq_id = seq_id + self.rebuttal = rebuttal + self.shuffle_seed = shuffle_seed + + # load all scenes + self.load_all_tuples(tuple_list) + self.load_all_scenes(ROOT) + + def __len__(self): + if self.tuple_list is not None: + return len(self.tuple_list) + return len(self.scene_list) * self.num_seq + + def load_all_tuples(self, tuple_list): + if tuple_list is not None: + self.tuple_list = tuple_list + # with open(tuple_path) as f: + # self.tuple_list = f.read().splitlines() + + else: + self.tuple_list = None + + def load_all_scenes(self, base_dir): + + scenes = [ + d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) + ] + + if self.test_id is 
not None: + self.scene_list = [self.test_id] + + else: + self.scene_list = scenes + + print(f"Found {len(self.scene_list)} sequences in split {self.split}") + + def load_poses(self, path): + file = open(path, "r") + lines = file.readlines() + file.close() + poses = [] + valid = [] + lines_per_matrix = 4 + for i in range(0, len(lines), lines_per_matrix): + if "nan" in lines[i]: + valid.append(False) + poses.append(np.eye(4, 4, dtype=np.float32).tolist()) + else: + valid.append(True) + pose_floats = [ + [float(x) for x in line.split()] + for line in lines[i : i + lines_per_matrix] + ] + poses.append(pose_floats) + + return np.array(poses, dtype=np.float32), valid + + def _get_views(self, idx, resolution, rng): + + if self.tuple_list is not None: + line = self.tuple_list[idx].split(" ") + scene_id = line[0] + img_idxs = line[1:] + + else: + scene_id = self.scene_list[idx // self.num_seq] + + num_files = len(os.listdir(os.path.join(self.ROOT, scene_id, "images"))) + img_idxs = [f"{i}" for i in range(num_files)] + img_idxs = img_idxs[:: min(self.kf_every, len(img_idxs) // 2)] + + fx, fy, cx, cy = 554.2562584220408, 554.2562584220408, 320, 240 + intrinsics_ = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + + posepath = osp.join(self.ROOT, scene_id, f"poses.txt") + camera_poses, valids = self.load_poses(posepath) + + imgs_idxs = deque(img_idxs) + if self.shuffle_seed >= 0: + imgs_idxs = shuffle_deque(imgs_idxs) + views = [] + + while len(imgs_idxs) > 0: + im_idx = imgs_idxs.popleft() + + impath = osp.join(self.ROOT, scene_id, "images", f"img{im_idx}.png") + depthpath = osp.join(self.ROOT, scene_id, "depth", f"depth{im_idx}.png") + + rgb_image = imread_cv2(impath) + depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED) + depthmap = np.nan_to_num(depthmap.astype(np.float32), 0.0) / 1000.0 + depthmap[depthmap > 10] = 0 + depthmap[depthmap < 1e-3] = 0 + + rgb_image = cv2.resize(rgb_image, (depthmap.shape[1], depthmap.shape[0])) + + camera_pose = 
camera_poses[int(im_idx)] + # gl to cv + camera_pose[:, 1:3] *= -1.0 + if resolution != (224, 224) or self.rebuttal: + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics_, resolution, rng=rng, info=impath + ) + else: + rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( + rgb_image, depthmap, intrinsics_, (512, 384), rng=rng, info=impath + ) + W, H = rgb_image.size + cx = W // 2 + cy = H // 2 + l, t = cx - 112, cy - 112 + r, b = cx + 112, cy + 112 + crop_bbox = (l, t, r, b) + rgb_image, depthmap, intrinsics = cropping.crop_image_depthmap( + rgb_image, depthmap, intrinsics, crop_bbox + ) + + views.append( + dict( + img=rgb_image, + depthmap=depthmap, + camera_pose=camera_pose, + camera_intrinsics=intrinsics, + dataset="nrgbd", + label=osp.join(scene_id, im_idx), + instance=impath, + ) + ) + + return views diff --git a/FastVGGT/eval/dataset_utils/__init__.py b/FastVGGT/eval/dataset_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/FastVGGT/eval/dataset_utils/__init__.py @@ -0,0 +1 @@ + diff --git a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d00b3e8b6f5fd9ef7d1428ff771243e1557ddf9f Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ebd5d22d44ddce633e9a05a7467e9ae0c3f9032 Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/__init__.cpython-37.pyc differ diff --git a/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc 
b/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbd10218eddf65486b375109c23cf8cecd7c3ba5 Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/corr.cpython-310.pyc differ diff --git a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..860c8970fc49d65b2a4101b77cc313a6aeac618a Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-310.pyc differ diff --git a/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fe7870fb079d9dc08ca647866ff8c77fd0f2ec0 Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/cropping.cpython-37.pyc differ diff --git a/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc b/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b4530be10547e3e7bf79e6cef6470880cb710a4 Binary files /dev/null and b/FastVGGT/eval/dataset_utils/__pycache__/transforms.cpython-310.pyc differ diff --git a/FastVGGT/eval/dataset_utils/corr.py b/FastVGGT/eval/dataset_utils/corr.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf5de18b7388e38e45c3f313487957811abc483 --- /dev/null +++ b/FastVGGT/eval/dataset_utils/corr.py @@ -0,0 +1,234 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# +# -------------------------------------------------------- + +import numpy as np +import torch + + +def todevice(batch, device, callback=None, non_blocking=False): + """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). 
+ + batch: list, tuple, dict of tensors or other things + device: pytorch device or 'numpy' + callback: function that would be called on every sub-elements. + """ + if callback: + batch = callback(batch) + + if isinstance(batch, dict): + return {k: todevice(v, device) for k, v in batch.items()} + + if isinstance(batch, (tuple, list)): + return type(batch)(todevice(x, device) for x in batch) + + x = batch + if device == "numpy": + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + elif x is not None: + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if torch.is_tensor(x): + x = x.to(device, non_blocking=non_blocking) + return x + + +to_device = todevice # alias + + +def to_numpy(x): + return todevice(x, "numpy") + + +def geotrf(Trf, pts, ncol=None, norm=False): + """Apply a geometric transformation to a list of 3-D points. + + H: 3x3 or 4x4 projection matrix (typically a Homography) + p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) + + ncol: int. number of columns of the result (2 or 3) + norm: float. if != 0, the resut is projected on the z=norm plane. + + Returns an array of projected 2d points. 
+ """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + if ( + isinstance(Trf, torch.Tensor) + and isinstance(pts, torch.Tensor) + and Trf.ndim == 3 + and pts.ndim == 4 + ): + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d + 1: + pts = ( + torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + + Trf[:, None, None, :d, d] + ) + else: + raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}") + else: + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], "batch size does not match" + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + return res + + +def inv(mat): + """Invert a torch or numpy matrix""" + if isinstance(mat, torch.Tensor): + return torch.linalg.inv(mat) + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f"bad matrix type = {type(mat)}") + + +def reproject_view(pts3d, view2): + shape = view2["pts3d"].shape[:2] + return reproject( + pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape + ) + + +def reproject(pts3d, K, world2cam, shape): + H, W, THREE = pts3d.shape + assert THREE == 3 + + with 
np.errstate(divide="ignore", invalid="ignore"): + pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2) + + return (H, W), ravel_xy(pos, shape) + + +def ravel_xy(pos, shape): + H, W = shape + with np.errstate(invalid="ignore"): + qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T + quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip( + min=0, max=H - 1, out=qy + ) + return quantized_pos + + +def unravel_xy(pos, shape): + + return np.unravel_index(pos, shape)[0].base[:, ::-1].copy() + + +def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False): + is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2)) + pos1 = is_reciprocal1.nonzero()[0] + pos2 = corres_1_to_2[pos1] + if ret_recip: + return is_reciprocal1, pos1, pos2 + return pos1, pos2 + + +def extract_correspondences_from_pts3d( + view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0 +): + view1, view2 = to_numpy((view1, view2)) + + shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2) + shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1) + + is_reciprocal1, pos1, pos2 = reciprocal_1d( + corres1_to_2, corres2_to_1, ret_recip=True + ) + is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1)) + + if target_n_corres is None: + if ret_xy: + pos1 = unravel_xy(pos1, shape1) + pos2 = unravel_xy(pos2, shape2) + return pos1, pos2 + + available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum()) + target_n_positives = int(target_n_corres * (1 - nneg)) + n_positives = min(len(pos1), target_n_positives) + n_negatives = min(target_n_corres - n_positives, available_negatives) + + if n_negatives + n_positives != target_n_corres: + + n_positives = target_n_corres - n_negatives + assert n_positives <= len(pos1) + + assert n_positives <= len(pos1) + assert n_positives <= len(pos2) + assert n_negatives <= (~is_reciprocal1).sum() + assert n_negatives <= (~is_reciprocal2).sum() + assert n_positives + n_negatives == 
target_n_corres + + valid = np.ones(n_positives, dtype=bool) + if n_positives < len(pos1): + + perm = rng.permutation(len(pos1))[:n_positives] + pos1 = pos1[perm] + pos2 = pos2[perm] + + if n_negatives > 0: + + def norm(p): + return p / p.sum() + + pos1 = np.r_[ + pos1, + rng.choice( + shape1[0] * shape1[1], + size=n_negatives, + replace=False, + p=norm(~is_reciprocal1), + ), + ] + pos2 = np.r_[ + pos2, + rng.choice( + shape2[0] * shape2[1], + size=n_negatives, + replace=False, + p=norm(~is_reciprocal2), + ), + ] + valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)] + + if ret_xy: + pos1 = unravel_xy(pos1, shape1) + pos2 = unravel_xy(pos2, shape2) + return pos1, pos2, valid diff --git a/FastVGGT/eval/dataset_utils/cropping.py b/FastVGGT/eval/dataset_utils/cropping.py new file mode 100644 index 0000000000000000000000000000000000000000..30a9eac18d241b71538957cf7ba4767ebc323b43 --- /dev/null +++ b/FastVGGT/eval/dataset_utils/cropping.py @@ -0,0 +1,140 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- + +import PIL.Image +import os + +from utils import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import cv2 # noqa +import numpy as np # noqa + +try: + lanczos = PIL.Image.Resampling.LANCZOS + bicubic = PIL.Image.Resampling.BICUBIC +except AttributeError: + lanczos = PIL.Image.LANCZOS + bicubic = PIL.Image.BICUBIC + + +class ImageList: + """Convenience class to aply the same operation to a whole set of images.""" + + def __init__(self, images): + if not isinstance(images, (tuple, list, set)): + images = [images] + self.images = [] + for image in images: + if not isinstance(image, PIL.Image.Image): + image = PIL.Image.fromarray(image) + self.images.append(image) + + def __len__(self): + return len(self.images) + + def to_pil(self): + return tuple(self.images) if len(self.images) > 1 else self.images[0] + + @property + def size(self): + sizes = [im.size for im in self.images] + assert all(sizes[0] == s for s in sizes) + return sizes[0] + + def resize(self, *args, **kwargs): + return ImageList(self._dispatch("resize", *args, **kwargs)) + + def crop(self, *args, **kwargs): + return ImageList(self._dispatch("crop", *args, **kwargs)) + + def _dispatch(self, func, *args, **kwargs): + return [getattr(im, func)(*args, **kwargs) for im in self.images] + + +def rescale_image_depthmap( + image, depthmap, camera_intrinsics, output_resolution, force=True +): + """Jointly rescale a (image, depthmap) + so that (out_width, out_height) >= output_res + """ + image = ImageList(image) + input_resolution = np.array(image.size) # (W,H) + output_resolution = np.array(output_resolution) + if depthmap is not None: + + assert tuple(depthmap.shape[:2]) == image.size[::-1] + + assert output_resolution.shape == (2,) + scale_final = max(output_resolution / image.size) + 1e-8 + if scale_final >= 1 and not force: # image is already smaller than what is asked + return 
(image.to_pil(), depthmap, camera_intrinsics) + output_resolution = np.floor(input_resolution * scale_final).astype(int) + + image = image.resize( + output_resolution, resample=lanczos if scale_final < 1 else bicubic + ) + if depthmap is not None: + depthmap = cv2.resize( + depthmap, + output_resolution, + fx=scale_final, + fy=scale_final, + interpolation=cv2.INTER_NEAREST, + ) + + camera_intrinsics = camera_matrix_of_crop( + camera_intrinsics, input_resolution, output_resolution, scaling=scale_final + ) + + return image.to_pil(), depthmap, camera_intrinsics + + +def camera_matrix_of_crop( + input_camera_matrix, + input_resolution, + output_resolution, + scaling=1, + offset_factor=0.5, + offset=None, +): + + margins = np.asarray(input_resolution) * scaling - output_resolution + assert np.all(margins >= 0.0) + if offset is None: + offset = offset_factor * margins + + output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) + output_camera_matrix_colmap[:2, :] *= scaling + output_camera_matrix_colmap[:2, 2] -= offset + output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) + + return output_camera_matrix + + +def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox): + """ + Return a crop of the input view. 
+ """ + image = ImageList(image) + l, t, r, b = crop_bbox + + image = image.crop((l, t, r, b)) + depthmap = depthmap[t:b, l:r] + + camera_intrinsics = camera_intrinsics.copy() + camera_intrinsics[0, 2] -= l + camera_intrinsics[1, 2] -= t + + return image.to_pil(), depthmap, camera_intrinsics + + +def bbox_from_intrinsics_in_out( + input_camera_matrix, output_camera_matrix, output_resolution +): + out_width, out_height = output_resolution + l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) + crop_bbox = (l, t, l + out_width, t + out_height) + return crop_bbox diff --git a/FastVGGT/eval/dataset_utils/transforms.py b/FastVGGT/eval/dataset_utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..cec2d144e1b97cc99c191d15cdcaf20796cae94b --- /dev/null +++ b/FastVGGT/eval/dataset_utils/transforms.py @@ -0,0 +1,78 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
+# +# -------------------------------------------------------- + +import torchvision.transforms as tvf + +ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + +ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) + + +def _check_input(value, center=1, bound=(0, float("inf")), clip_first_on_zero=True): + if isinstance(value, (int, float)): + if value < 0: + raise ValueError(f"If is a single number, it must be non negative.") + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + value = [float(value[0]), float(value[1])] + else: + raise TypeError(f"should be a single number or a list/tuple with length 2.") + + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"values should be between {bound}, but got {value}.") + + if value[0] == value[1] == center: + return None + else: + return tuple(value) + + +import torch +import torchvision.transforms.functional as F + + +def SeqColorJitter(): + """ + Return a color jitter transform with same random parameters + """ + brightness = _check_input(0.5) + contrast = _check_input(0.5) + saturation = _check_input(0.5) + hue = _check_input(0.1, center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + + fn_idx = torch.randperm(4) + brightness_factor = ( + None + if brightness is None + else float(torch.empty(1).uniform_(brightness[0], brightness[1])) + ) + contrast_factor = ( + None + if contrast is None + else float(torch.empty(1).uniform_(contrast[0], contrast[1])) + ) + saturation_factor = ( + None + if saturation is None + else float(torch.empty(1).uniform_(saturation[0], saturation[1])) + ) + hue_factor = None if hue is None else float(torch.empty(1).uniform_(hue[0], hue[1])) + + def _color_jitter(img): + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + img = F.adjust_brightness(img, brightness_factor) 
+ elif fn_id == 1 and contrast_factor is not None: + img = F.adjust_contrast(img, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + img = F.adjust_saturation(img, saturation_factor) + elif fn_id == 3 and hue_factor is not None: + img = F.adjust_hue(img, hue_factor) + return ImgNorm(img) + + return _color_jitter diff --git a/FastVGGT/eval/eval_7andN.py b/FastVGGT/eval/eval_7andN.py new file mode 100644 index 0000000000000000000000000000000000000000..8d6d5fbf6c159ea18bfd6fbcee0a5d23e63c0cf2 --- /dev/null +++ b/FastVGGT/eval/eval_7andN.py @@ -0,0 +1,497 @@ +import os +import sys + +# Ensure project root is on sys.path for absolute imports like `vggt.*` +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) +if ROOT_DIR not in sys.path: + sys.path.insert(0, ROOT_DIR) + +import time +import torch +import argparse +import numpy as np +import open3d as o3d +import os.path as osp +from torch.utils.data import DataLoader +from torch.utils.data._utils.collate import default_collate +from tqdm import tqdm +from collections import defaultdict +import torchvision.transforms as transforms + + +def get_args_parser(): + parser = argparse.ArgumentParser("3D Reconstruction evaluation", add_help=False) + parser.add_argument( + "--ckpt_path", + type=str, + default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt", + help="ckpt name", + ) + parser.add_argument("--device", type=str, default="cuda:0", help="device") + parser.add_argument("--model_name", type=str, default="VGGT") + parser.add_argument( + "--conf_thresh", type=float, default=0.0, help="confidence threshold" + ) + parser.add_argument( + "--output_dir", + type=str, + default="/home/sy/code/FastVGGT/eval_results", + help="value for outdir", + ) + parser.add_argument("--size", type=int, default=518) + parser.add_argument("--revisit", type=int, default=1, help="revisit times") + parser.add_argument("--freeze", action="store_true") + parser.add_argument("--use_proj", 
action="store_true") + parser.add_argument( + "--merging", type=int, default=0, help="VGGT aggregator merging steps" + ) + parser.add_argument("--kf", type=int, default=2, help="key frame") + return parser + + +def main(args): + from data import SevenScenes, NRGBD + from utils import accuracy, completion + + if args.size == 512: + resolution = (512, 384) + elif args.size == 224: + resolution = 224 + elif args.size == 518: + resolution = (518, 392) + else: + raise NotImplementedError + datasets_all = { + "7scenes": SevenScenes( + split="test", + ROOT="/data/sy/7scenes", + resolution=resolution, + num_seq=1, + full_video=True, + kf_every=args.kf, + ), # 20), + "NRGBD": NRGBD( + split="test", + ROOT="/data/sy/neural_rgbd_data", + resolution=resolution, + num_seq=1, + full_video=True, + kf_every=args.kf, + ), + } + + device = args.device + model_name = args.model_name + + from vggt.models.vggt import VGGT + from vggt.utils.pose_enc import pose_encoding_to_extri_intri + from vggt.utils.geometry import unproject_depth_map_to_point_map + from criterion import Regr3D_t_ScaleShiftInv, L21 + + # Force use of bf16 data type + dtype = torch.bfloat16 + # Load VGGT model + model = VGGT(merging=args.merging, enable_point=True) + ckpt = torch.load(args.ckpt_path, map_location="cpu") + + # ✅ Fix: load pre-trained weights + model.load_state_dict( + ckpt, strict=False + ) # Use strict=False due to enable_point=True difference + + model = model.cuda().eval() + model = model.to(torch.bfloat16) + + del ckpt + os.makedirs(osp.join(args.output_dir, f"{args.kf}"), exist_ok=True) + + criterion = Regr3D_t_ScaleShiftInv(L21, norm_mode=False, gt_scale=True) + + with torch.no_grad(): + for name_data, dataset in datasets_all.items(): + save_path = osp.join(osp.join(args.output_dir, f"{args.kf}"), name_data) + os.makedirs(save_path, exist_ok=True) + log_file = osp.join(save_path, "logs.txt") + + acc_all = 0 + acc_all_med = 0 + comp_all = 0 + comp_all_med = 0 + nc1_all = 0 + nc1_all_med = 0 + 
nc2_all = 0 + nc2_all_med = 0 + scene_infer_times = defaultdict(list) + + for data_idx in tqdm(range(len(dataset))): + batch = default_collate([dataset[data_idx]]) + ignore_keys = set( + [ + "depthmap", + "dataset", + "label", + "instance", + "idx", + "true_shape", + "rng", + ] + ) + for view in batch: + for name in view.keys(): # pseudo_focal + if name in ignore_keys: + continue + if isinstance(view[name], tuple) or isinstance( + view[name], list + ): + view[name] = [ + x.to(device, non_blocking=True) for x in view[name] + ] + else: + view[name] = view[name].to(device, non_blocking=True) + + pts_all = [] + pts_gt_all = [] + images_all = [] + masks_all = [] + conf_all = [] + in_camera1 = None + + dtype = ( + torch.bfloat16 + if torch.cuda.get_device_capability()[0] >= 8 + else torch.float16 + ) + with torch.cuda.amp.autocast(dtype=dtype): + if isinstance(batch, dict) and "img" in batch: + batch["img"] = (batch["img"] + 1.0) / 2.0 + elif isinstance(batch, list) and all( + isinstance(v, dict) and "img" in v for v in batch + ): + for view in batch: + view["img"] = (view["img"] + 1.0) / 2.0 + # Gather all `img` tensors into a single tensor of shape [N, C, H, W] + imgs_tensor = torch.cat([v["img"] for v in batch], dim=0) + + with torch.cuda.amp.autocast(dtype=dtype): + with torch.no_grad(): + torch.cuda.synchronize() + start = time.time() + preds = model(imgs_tensor) + torch.cuda.synchronize() + end = time.time() + inference_time_ms = (end - start) * 1000 + print(f"Inference time: {inference_time_ms:.2f}ms") + + # Wrap model outputs per-view to align with batch later + predictions = preds + views = batch # list[dict] + if "pose_enc" in predictions: + B, S = predictions["pose_enc"].shape[:2] + elif "world_points" in predictions: + B, S = predictions["world_points"].shape[:2] + else: + raise KeyError( + "predictions is missing a key to infer sequence length" + ) + + ress = [] + for s in range(S): + res = { + "pts3d_in_other_view": predictions["world_points"][:, s], + 
"conf": predictions["world_points_conf"][:, s], + "depth": predictions["depth"][:, s], + "depth_conf": predictions["depth_conf"][:, s], + "camera_pose": predictions["pose_enc"][:, s, :], + } + if ( + isinstance(views, list) + and s < len(views) + and "valid_mask" in views[s] + ): + res["valid_mask"] = views[s]["valid_mask"] + if "track" in predictions: + res.update( + { + "track": predictions["track"][:, s], + "vis": ( + predictions.get("vis", None)[:, s] + if "vis" in predictions + else None + ), + "track_conf": ( + predictions.get("conf", None)[:, s] + if "conf" in predictions + else None + ), + } + ) + ress.append(res) + + preds = ress + + valid_length = len(preds) // args.revisit + if args.revisit > 1: + preds = preds[-valid_length:] + batch = batch[-valid_length:] + + # Evaluation + print(f"Evaluation for {name_data} {data_idx+1}/{len(dataset)}") + gt_pts, pred_pts, gt_factor, pr_factor, masks, monitoring = ( + criterion.get_all_pts3d_t(batch, preds) + ) + + in_camera1 = None + pts_all = [] + pts_gt_all = [] + images_all = [] + masks_all = [] + conf_all = [] + + for j, view in enumerate(batch): + if in_camera1 is None: + in_camera1 = view["camera_pose"][0].cpu() + + image = view["img"].permute(0, 2, 3, 1).cpu().numpy()[0] + mask = view["valid_mask"].cpu().numpy()[0] + + pts = pred_pts[j].cpu().numpy()[0] + conf = preds[j]["conf"].cpu().data.numpy()[0] + + # mask = mask & (conf > 1.8) + + pts_gt = gt_pts[j].detach().cpu().numpy()[0] + + H, W = image.shape[:2] + cx = W // 2 + cy = H // 2 + l, t = cx - 112, cy - 112 + r, b = cx + 112, cy + 112 + image = image[t:b, l:r] + mask = mask[t:b, l:r] + pts = pts[t:b, l:r] + pts_gt = pts_gt[t:b, l:r] + + images_all.append(image[None, ...]) + pts_all.append(pts[None, ...]) + pts_gt_all.append(pts_gt[None, ...]) + masks_all.append(mask[None, ...]) + conf_all.append(conf[None, ...]) + + images_all = np.concatenate(images_all, axis=0) + pts_all = np.concatenate(pts_all, axis=0) + pts_gt_all = np.concatenate(pts_gt_all, 
axis=0) + masks_all = np.concatenate(masks_all, axis=0) + + scene_id = view["label"][0].rsplit("/", 1)[0] + # Record average inference time per scene + try: + scene_infer_times[scene_id].append(float(inference_time_ms)) + except Exception: + pass + + save_params = {} + + save_params["images_all"] = images_all + save_params["pts_all"] = pts_all + save_params["pts_gt_all"] = pts_gt_all + save_params["masks_all"] = masks_all + + pts_all_masked = pts_all[masks_all > 0] + pts_gt_all_masked = pts_gt_all[masks_all > 0] + images_all_masked = images_all[masks_all > 0] + + mask = np.isfinite(pts_all_masked) + pts_all_masked = pts_all_masked[mask] + + mask_gt = np.isfinite(pts_gt_all_masked) + pts_gt_all_masked = pts_gt_all_masked[mask_gt] + images_all_masked = images_all_masked[mask] + + # Reshape to point cloud (N, 3) before sampling + pts_all_masked = pts_all_masked.reshape(-1, 3) + pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3) + images_all_masked = images_all_masked.reshape(-1, 3) + + # If number of points exceeds threshold, sample by points + if pts_all_masked.shape[0] > 999999: + sample_indices = np.random.choice( + pts_all_masked.shape[0], 999999, replace=False + ) + pts_all_masked = pts_all_masked[sample_indices] + images_all_masked = images_all_masked[sample_indices] + + # Apply the same sampling to GT point cloud + if pts_gt_all_masked.shape[0] > 999999: + sample_indices_gt = np.random.choice( + pts_gt_all_masked.shape[0], 999999, replace=False + ) + pts_gt_all_masked = pts_gt_all_masked[sample_indices_gt] + + if args.use_proj: + + def umeyama_alignment( + src: np.ndarray, dst: np.ndarray, with_scale: bool = True + ): + assert src.shape == dst.shape + N, dim = src.shape + + mu_src = src.mean(axis=0) + mu_dst = dst.mean(axis=0) + src_c = src - mu_src + dst_c = dst - mu_dst + + Sigma = dst_c.T @ src_c / N # (3,3) + + U, D, Vt = np.linalg.svd(Sigma) + + S = np.eye(dim) + if np.linalg.det(U) * np.linalg.det(Vt) < 0: + S[-1, -1] = -1 + + R = U @ S @ Vt + + if 
with_scale: + var_src = (src_c**2).sum() / N + s = (D * S.diagonal()).sum() / var_src + else: + s = 1.0 + + t = mu_dst - s * R @ mu_src + + return s, R, t + + pts_all_masked = pts_all_masked.reshape(-1, 3) + pts_gt_all_masked = pts_gt_all_masked.reshape(-1, 3) + s, R, t = umeyama_alignment( + pts_all_masked, pts_gt_all_masked, with_scale=True + ) + pts_all_aligned = (s * (R @ pts_all_masked.T)).T + t # (N,3) + pts_all_masked = pts_all_aligned + + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(pts_all_masked) + pcd.colors = o3d.utility.Vector3dVector(images_all_masked) + + pcd_gt = o3d.geometry.PointCloud() + pcd_gt.points = o3d.utility.Vector3dVector(pts_gt_all_masked) + pcd_gt.colors = o3d.utility.Vector3dVector(images_all_masked) + + trans_init = np.eye(4) + + threshold = 0.1 + reg_p2p = o3d.pipelines.registration.registration_icp( + pcd, + pcd_gt, + threshold, + trans_init, + o3d.pipelines.registration.TransformationEstimationPointToPoint(), + ) + + transformation = reg_p2p.transformation + + pcd = pcd.transform(transformation) + pcd.estimate_normals() + pcd_gt.estimate_normals() + + gt_normal = np.asarray(pcd_gt.normals) + pred_normal = np.asarray(pcd.normals) + + acc, acc_med, nc1, nc1_med = accuracy( + pcd_gt.points, pcd.points, gt_normal, pred_normal + ) + comp, comp_med, nc2, nc2_med = completion( + pcd_gt.points, pcd.points, gt_normal, pred_normal + ) + print( + f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}" + ) + print( + f"Idx: {scene_id}, Acc: {acc}, Comp: {comp}, NC1: {nc1}, NC2: {nc2} - Acc_med: {acc_med}, Compc_med: {comp_med}, NC1c_med: {nc1_med}, NC2c_med: {nc2_med}", + file=open(log_file, "a"), + ) + + acc_all += acc + comp_all += comp + nc1_all += nc1 + nc2_all += nc2 + + acc_all_med += acc_med + comp_all_med += comp_med + nc1_all_med += nc1_med + nc2_all_med += nc2_med + + # release cuda memory + 
torch.cuda.empty_cache() + + # Get depth from pcd and run TSDFusion + to_write = "" + # Read the log file + if os.path.exists(osp.join(save_path, "logs.txt")): + with open(osp.join(save_path, "logs.txt"), "r") as f_sub: + to_write += f_sub.read() + + with open(osp.join(save_path, f"logs_all.txt"), "w") as f: + log_data = to_write + metrics = defaultdict(list) + for line in log_data.strip().split("\n"): + match = regex.match(line) + if match: + data = match.groupdict() + # Exclude 'scene_id' from metrics as it's an identifier + for key, value in data.items(): + if key != "scene_id": + metrics[key].append(float(value)) + metrics["nc"].append( + (float(data["nc1"]) + float(data["nc2"])) / 2 + ) + metrics["nc_med"].append( + (float(data["nc1_med"]) + float(data["nc2_med"])) / 2 + ) + mean_metrics = { + metric: sum(values) / len(values) + for metric, values in metrics.items() + } + + c_name = "mean" + print_str = f"{c_name.ljust(20)}: " + for m_name in mean_metrics: + print_num = np.mean(mean_metrics[m_name]) + print_str = print_str + f"{m_name}: {print_num:.3f} | " + print_str = print_str + "\n" + # Summarize per-scene average inference time + time_lines = [] + for sid, times in scene_infer_times.items(): + if len(times) > 0: + time_lines.append( + f"Idx: {sid}, Time_avg_ms: {np.mean(times):.2f}" + ) + time_block = "\n".join(time_lines) + ( + "\n" if len(time_lines) > 0 else "" + ) + + f.write(to_write + time_block + print_str) + + +from collections import defaultdict +import re + +pattern = r""" + Idx:\s*(?P[^,]+),\s* + Acc:\s*(?P[^,]+),\s* + Comp:\s*(?P[^,]+),\s* + NC1:\s*(?P[^,]+),\s* + NC2:\s*(?P[^,]+)\s*-\s* + Acc_med:\s*(?P[^,]+),\s* + Compc_med:\s*(?P[^,]+),\s* + NC1c_med:\s*(?P[^,]+),\s* + NC2c_med:\s*(?P[^,]+) +""" + +regex = re.compile(pattern, re.VERBOSE) + + +if __name__ == "__main__": + parser = get_args_parser() + args = parser.parse_args() + + main(args) diff --git a/FastVGGT/eval/eval_custom.py b/FastVGGT/eval/eval_custom.py new file mode 100644 index 
0000000000000000000000000000000000000000..a9e4f7ccdd6c22c2c3bef6421c52ace01d362a83 --- /dev/null +++ b/FastVGGT/eval/eval_custom.py @@ -0,0 +1,467 @@ +import argparse +from pathlib import Path +import numpy as np +import torch +import os +import sys +import matplotlib.pyplot as plt +from scipy.spatial.transform import Rotation + +# Ensure project root is in sys.path for absolute imports like `vggt.*` +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) +if ROOT_DIR not in sys.path: + sys.path.insert(0, ROOT_DIR) + +from vggt.models.vggt import VGGT +from vggt.utils.eval_utils import ( + load_poses, + get_vgg_input_imgs, + get_sorted_image_paths, + build_frame_selection, + load_images_rgb, + infer_vggt_and_reconstruct, + evaluate_scene_and_save, +) + +# Import pose visualization libraries (optional EVO support) +try: + from evo.core.trajectory import PoseTrajectory3D + import evo.tools.plot as plot + + EVO_AVAILABLE = True +except ImportError: + # EVO is optional; we have a matplotlib-based fallback + EVO_AVAILABLE = False + + +def visualize_predicted_poses( + all_cam_to_world_mat, frame_ids, output_scene_dir, scene_name="custom_dataset" +): + """ + Visualize the predicted camera pose trajectory (no GT comparison required). 
+ + Args: + all_cam_to_world_mat: List of camera-to-world transform matrices + frame_ids: List of frame IDs + output_scene_dir: Output directory + scene_name: Scene name + """ + # Provide basic pose visualization even without EVO + if not EVO_AVAILABLE: + print("⚠️ EVO not installed; using basic matplotlib visualization") + + try: + # Convert to numpy array + poses_est = np.array(all_cam_to_world_mat) + + if len(poses_est) < 2: + print("⚠️ Not enough poses to generate trajectory plot") + return + + print(f"🎨 Generating pose trajectory visualization...") + + # Extract translation part + positions = poses_est[:, :3, 3] # shape: (N, 3) + + # Create figure - show XZ-plane projection only + fig, ax = plt.subplots(1, 1, figsize=(10, 8)) + + # XZ-plane projection + ax.plot( + positions[:, 0], + positions[:, 2], + "b-", + linewidth=2, + label="Predicted Trajectory", + ) + ax.scatter( + positions[0, 0], positions[0, 2], color="green", s=100, label="Start" + ) + ax.scatter(positions[-1, 0], positions[-1, 2], color="red", s=100, label="End") + ax.set_xlabel("X (m)") + ax.set_ylabel("Z (m)") + ax.set_title(f"{scene_name} - XZ-plane projection") + ax.legend() + ax.grid(True, alpha=0.3) + + # Save image + pose_plot_path = output_scene_dir / "predicted_trajectory.png" + plt.savefig(pose_plot_path, dpi=300, bbox_inches="tight") + plt.close() + + print(f"📊 Trajectory visualization saved: {pose_plot_path}") + + except Exception as e: + print(f"⚠️ Failed to generate pose visualization: {e}") + import traceback + + traceback.print_exc() + + +def main(): + """ + Evaluation script for a Custom Dataset. + Supports optional evaluation and custom dataset structure. 
+ """ + parser = argparse.ArgumentParser( + description="Run FastVGGT evaluation on a Custom Dataset" + ) + + # Required: dataset path + parser.add_argument( + "--data_path", + type=Path, + required=True, + help="Dataset path containing subfolders: color, depth, gt_ply, pose", + ) + + # Optional: enable evaluation + parser.add_argument( + "--enable_evaluation", + action="store_true", + help="Enable evaluation (requires pose and ply data)", + ) + + # Output path + parser.add_argument( + "--output_path", + type=Path, + default="./eval_results_custom", + help="Output path for evaluation results", + ) + + # Model parameters + parser.add_argument( + "--ckpt_path", + type=str, + default="/home/sy/code/FastVGGT/ckpt/model_tracker_fixed_e20.pt", + help="Model checkpoint file path", + ) + + parser.add_argument("--merging", type=int, default=0, help="Merging parameter") + + # Processing parameters + parser.add_argument( + "--input_frame", + type=int, + default=200, + help="Maximum number of frames to process per scene", + ) + + parser.add_argument( + "--depth_conf_thresh", + type=float, + default=3.0, + help="Depth confidence threshold to filter low-confidence depth values", + ) + + # Evaluation parameters (only used when evaluation is enabled) + parser.add_argument( + "--chamfer_max_dist", + type=float, + default=0.5, + help="Maximum distance threshold used in Chamfer Distance computation", + ) + + parser.add_argument("--plot", action="store_true", help="Whether to generate plots") + + parser.add_argument( + "--vis_attn_map", + action="store_true", + help="Visualize attention maps during inference", + ) + + args = parser.parse_args() + torch.manual_seed(33) + + # Check data path exists + if not args.data_path.exists(): + print(f"❌ Error: Data path does not exist: {args.data_path}") + return + + # Check required subdirectories + color_dir = args.data_path / "images" + pose_dir = args.data_path / "pose" + + if not color_dir.exists(): + print(f"❌ Error: color directory does 
not exist: {color_dir}") + return + + print(f"📁 Dataset path: {args.data_path}") + # print(f"🔧 Enable evaluation: {'Yes' if args.enable_evaluation else 'No'}") + + # If evaluation is enabled, check pose and gt_ply directories + if args.enable_evaluation: + if not pose_dir.exists(): + print(f"❌ Error: Evaluation requires pose directory: {pose_dir}") + return + + gt_ply_dir = args.data_path / "gt_ply" + if not gt_ply_dir.exists(): + print(f"❌ Error: Evaluation requires gt_ply directory: {gt_ply_dir}") + return + print(f"📊 Evaluation will use Ground Truth") + else: + print(f"🏃 Inference only, no evaluation") + + # Create output directory + args.output_path.mkdir(parents=True, exist_ok=True) + output_scene_dir = args.output_path / "custom_dataset" + + # Check if already processed + if (output_scene_dir / "metrics.json").exists() and args.enable_evaluation: + print( + f"⚠️ Results already exist, skipping: {output_scene_dir / 'metrics.json'}" + ) + return + + # Force use of bf16 dtype + dtype = torch.bfloat16 + + # Load VGGT model + print(f"🔄 Loading model: {args.ckpt_path}") + model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map) + ckpt = torch.load(args.ckpt_path, map_location="cpu") + incompat = model.load_state_dict(ckpt, strict=False) + # if incompat.missing_keys or incompat.unexpected_keys: + # print(f"⚠️ Partially incompatible keys when loading model: {incompat}") + model = model.cuda().eval() + model = model.to(torch.bfloat16) + print(f"✅ Model loaded") + + # Load scene data + image_paths = get_sorted_image_paths(color_dir) + if len(image_paths) == 0: + print(f"❌ Error: No images found in {color_dir}") + return + + print(f"🖼️ Found {len(image_paths)} images") + + # Process pose data (if evaluation is enabled) + poses_gt = None + first_gt_pose = None + available_pose_frame_ids = None + c2ws = None + + if args.enable_evaluation: + poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_dir) + if ( + poses_gt is None + or first_gt_pose is 
None + or available_pose_frame_ids is None + ): + print(f"❌ Error: Failed to load pose data") + return + print(f"📐 Loaded {len(poses_gt)} poses") + + # Frame selection + if args.enable_evaluation and available_pose_frame_ids is not None: + # Use pose data for frame selection + selected_frame_ids, selected_image_paths, selected_pose_indices = ( + build_frame_selection( + image_paths, available_pose_frame_ids, args.input_frame + ) + ) + c2ws = poses_gt[selected_pose_indices] + image_paths = selected_image_paths + else: + # Simply take the first N frames + num_frames = min(len(image_paths), args.input_frame) + selected_frame_ids = list(range(num_frames)) + image_paths = image_paths[:num_frames] + + print(f"📋 Selected {len(image_paths)} frames for processing") + + try: + # Load images + print(f"🔄 Loading images...") + images = load_images_rgb(image_paths) + + if not images or len(images) < 3: + print(f"❌ Error: Not enough valid images (need at least 3)") + return + + frame_ids = selected_frame_ids + images_array = np.stack(images) + vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array) + print(f"📐 Image patch dimensions: {patch_width}x{patch_height}") + + # Update attention layer patch dimensions in the model + model.update_patch_dimensions(patch_width, patch_height) + + # Inference + Reconstruction + print(f"🚀 Start inference and reconstruction...") + ( + extrinsic_np, + intrinsic_np, + all_world_points, + all_point_colors, + all_cam_to_world_mat, + inference_time_ms, + ) = infer_vggt_and_reconstruct( + model, vgg_input, dtype, args.depth_conf_thresh, image_paths + ) + print(f"⏱️ Inference time: {inference_time_ms:.2f}ms") + + # Check results + if not all_cam_to_world_mat or not all_world_points: + print(f"❌ Error: Failed to obtain valid camera poses or point clouds") + return + + # print(f"✅ Inference done, obtained {len(all_world_points)} point sets") + + # Evaluation and saving + if args.enable_evaluation: + print(f"📊 Start evaluation...") + 
gt_ply_dir = args.data_path / "gt_ply" + metrics = evaluate_scene_and_save( + "custom_dataset", + c2ws, + first_gt_pose, + frame_ids, + all_cam_to_world_mat, + all_world_points, + output_scene_dir, + gt_ply_dir, + args.chamfer_max_dist, + inference_time_ms, + args.plot, + ) + if metrics is not None: + print("📈 Evaluation results:") + for key, value in metrics.items(): + if key in [ + "chamfer_distance", + "ate", + "are", + "rpe_rot", + "rpe_trans", + "inference_time_ms", + ]: + print(f" {key}: {float(value):.4f}") + + # Also visualize predicted poses in evaluation branch + if args.plot: + visualize_predicted_poses( + all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset" + ) + else: + # Save reconstruction only, no evaluation + print(f"💾 Saving reconstruction...") + output_scene_dir.mkdir(parents=True, exist_ok=True) + + # Save camera poses + poses_output_path = output_scene_dir / "estimated_poses.txt" + with open(poses_output_path, "w") as f: + for i, pose in enumerate(all_cam_to_world_mat): + f.write(f"# Frame {frame_ids[i]}\n") + for row in pose: + f.write(" ".join(map(str, row)) + "\n") + f.write("\n") + + # Save point cloud + if all_world_points: + points_output_path = output_scene_dir / "reconstructed_points.ply" + + # Merge all frames' point clouds and colors + try: + merged_point_cloud = np.vstack(all_world_points) + merged_colors = ( + np.vstack(all_point_colors).astype(np.uint8) + if all_point_colors is not None and len(all_point_colors) > 0 + else None + ) + print( + f"📊 Merged point clouds: {len(all_world_points)} frames, total {len(merged_point_cloud)} points" + ) + + # If too many points, randomly sample 100000 points + max_points = 100000 + if len(merged_point_cloud) > max_points: + print( + f"🔽 Too many points, randomly sampling {max_points} points..." 
+ ) + # Randomly choose indices + indices = np.random.choice( + len(merged_point_cloud), size=max_points, replace=False + ) + merged_point_cloud = merged_point_cloud[indices] + if merged_colors is not None: + merged_colors = merged_colors[indices] + print( + f"✅ Sampling done, kept {len(merged_point_cloud)} points" + ) + + # Save as PLY (with color) + with open(points_output_path, "w") as f: + f.write("ply\n") + f.write("format ascii 1.0\n") + f.write(f"element vertex {len(merged_point_cloud)}\n") + f.write("property float x\n") + f.write("property float y\n") + f.write("property float z\n") + if merged_colors is not None: + f.write("property uchar red\n") + f.write("property uchar green\n") + f.write("property uchar blue\n") + f.write("end_header\n") + if merged_colors is None: + for point in merged_point_cloud: + if not (np.isnan(point).any() or np.isinf(point).any()): + f.write( + f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f}\n" + ) + else: + for point, color in zip(merged_point_cloud, merged_colors): + # Check point validity + if not (np.isnan(point).any() or np.isinf(point).any()): + r = int(np.clip(color[0], 0, 255)) + g = int(np.clip(color[1], 0, 255)) + b = int(np.clip(color[2], 0, 255)) + f.write( + f"{point[0]:.6f} {point[1]:.6f} {point[2]:.6f} {r} {g} {b}\n" + ) + + print(f"💾 Point cloud saved to: {points_output_path}") + + except Exception as e: + print(f"⚠️ Error saving point cloud: {e}") + # If merge fails, try to log per-frame info + print(f"🔍 Point cloud debug info:") + for i, frame_points in enumerate(all_world_points): + print( + f" Frame {i}: {frame_points.shape if hasattr(frame_points, 'shape') else type(frame_points)}" + ) + if ( + hasattr(frame_points, "shape") + and len(frame_points.shape) >= 2 + ): + print( + f" Shape: {frame_points.shape}, Dtype: {frame_points.dtype}" + ) + if frame_points.shape[0] > 0: + print( + f" Range: x[{np.min(frame_points[:, 0]):.3f}, {np.max(frame_points[:, 0]):.3f}] " + f"y[{np.min(frame_points[:, 1]):.3f}, 
{np.max(frame_points[:, 1]):.3f}] " + f"z[{np.min(frame_points[:, 2]):.3f}, {np.max(frame_points[:, 2]):.3f}]" + ) + + print(f"📁 Results saved to: {output_scene_dir}") + + # Visualize predicted pose trajectory + if args.plot: + visualize_predicted_poses( + all_cam_to_world_mat, frame_ids, output_scene_dir, "custom_dataset" + ) + + print(f"🎉 Done!") + + except Exception as e: + print(f"❌ Error occurred during processing: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/FastVGGT/eval/eval_scannet.py b/FastVGGT/eval/eval_scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..332132c45ab4fc3f7e768dbf3845d9cdb3ccc4eb --- /dev/null +++ b/FastVGGT/eval/eval_scannet.py @@ -0,0 +1,208 @@ +import argparse +from pathlib import Path +import numpy as np +import torch +import os +import sys + +# Ensure project root is in sys.path for absolute imports like `vggt.*` +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) +if ROOT_DIR not in sys.path: + sys.path.insert(0, ROOT_DIR) + +from vggt.models.vggt import VGGT +from vggt.utils.eval_utils import ( + load_poses, + get_vgg_input_imgs, + get_sorted_image_paths, + get_all_scenes, + build_frame_selection, + load_images_rgb, + infer_vggt_and_reconstruct, + evaluate_scene_and_save, + compute_average_metrics_and_save, +) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", type=Path, default="/data/scannetv2/process_scannet" + ) + parser.add_argument( + "--gt_ply_dir", + type=Path, + default="/data/scannetv2/OpenDataLab___ScanNet_v2/raw/scans", + ) + parser.add_argument("--output_path", type=Path, default="./eval_results") + parser.add_argument("--merging", type=int, default=None) + parser.add_argument("--plot", type=bool, default=True) + parser.add_argument( + "--depth_conf_thresh", + type=float, + default=3.0, + help="Depth confidence threshold for filtering low confidence 
depth values", + ) + parser.add_argument( + "--chamfer_max_dist", + type=float, + default=0.5, + help="Maximum distance threshold in Chamfer Distance computation, distances exceeding this value will be clipped", + ) + parser.add_argument( + "--input_frame", + type=int, + default=200, + help="Maximum number of frames selected for processing per scene", + ) + parser.add_argument( + "--num_scenes", + type=int, + default=50, + help="Maximum number of scenes to evaluate", + ) + parser.add_argument( + "--ckpt_path", + type=str, + default="./ckpt/model_tracker_fixed_e20.pt", + help="Path to the model checkpoint file", + ) + parser.add_argument( + "--vis_attn_map", + action="store_true", + help="Whether to visualize attention maps during inference", + ) + args = parser.parse_args() + torch.manual_seed(33) + + # Scene sampling + scannet_scenes = get_all_scenes(args.data_dir, args.num_scenes) + print(f"Evaluate {len(scannet_scenes)} scenes") + + all_scenes_metrics = {"scenes": {}, "average": {}} + # Force use of bf16 data type + dtype = torch.bfloat16 + # Load VGGT model + model = VGGT(merging=args.merging, vis_attn_map=args.vis_attn_map) + ckpt = torch.load(args.ckpt_path, map_location="cpu") + incompat = model.load_state_dict(ckpt, strict=False) + model = model.cuda().eval() + model = model.to(torch.bfloat16) + + # Process each scene + for scene in scannet_scenes: + scene_dir = args.data_dir / f"{scene}" + output_scene_dir = args.output_path / f"input_frame_{args.input_frame}" / scene + if (output_scene_dir / "metrics.json").exists(): + continue + + # Load scene data + images_dir = scene_dir / "color" + pose_path = scene_dir / "pose" + image_paths = get_sorted_image_paths(images_dir) + poses_gt, first_gt_pose, available_pose_frame_ids = load_poses(pose_path) + if ( + poses_gt is None + or first_gt_pose is None + or available_pose_frame_ids is None + ): + print(f"Skipping scene {scene}: no pose data") + continue + + # Frame filtering + selected_frame_ids, 
selected_image_paths, selected_pose_indices = ( + build_frame_selection( + image_paths, available_pose_frame_ids, args.input_frame + ) + ) + + # Get corresponding poses + c2ws = poses_gt[selected_pose_indices] + image_paths = selected_image_paths + + if len(image_paths) == 0: + print(f"No images found in {images_dir}") + continue + + print("🚩Processing", scene, f"Found {len(image_paths)} images") + all_cam_to_world_mat = [] + all_world_points = [] + + try: + # Load images + images = load_images_rgb(image_paths) + + if not images or len(images) < 3: + print(f"Skipping {scene}: insufficient valid images") + continue + + frame_ids = selected_frame_ids + images_array = np.stack(images) + vgg_input, patch_width, patch_height = get_vgg_input_imgs(images_array) + print(f"Patch dimensions: {patch_width}x{patch_height}") + + # Update model attention layers with dynamic patch dimensions + model.update_patch_dimensions(patch_width, patch_height) + + # Inference + Reconstruction + ( + extrinsic_np, + intrinsic_np, + all_world_points, + all_point_colors, + all_cam_to_world_mat, + inference_time_ms, + ) = infer_vggt_and_reconstruct( + model, vgg_input, dtype, args.depth_conf_thresh, image_paths + ) + print(f"Inference time: {inference_time_ms:.2f}ms") + + # Process results + if not all_cam_to_world_mat or not all_world_points: + print( + f"Skipping {scene}: failed to obtain valid camera poses or point clouds" + ) + continue + + # Evaluate and save + metrics = evaluate_scene_and_save( + scene, + c2ws, + first_gt_pose, + frame_ids, + all_cam_to_world_mat, + all_world_points, + output_scene_dir, + args.gt_ply_dir, + args.chamfer_max_dist, + inference_time_ms, + args.plot, + ) + if metrics is not None: + all_scenes_metrics["scenes"][scene] = { + key: float(value) + for key, value in metrics.items() + if key + in [ + "chamfer_distance", + "ate", + "are", + "rpe_rot", + "rpe_trans", + "inference_time_ms", + ] + } + print("Complete metrics", all_scenes_metrics["scenes"][scene]) + + 
except Exception as e: + print(f"Error processing scene {scene}: {e}") + import traceback + + traceback.print_exc() + + # Summarize average metrics and save + compute_average_metrics_and_save( + all_scenes_metrics, + args.output_path, + args.input_frame, + ) diff --git a/FastVGGT/eval/utils.py b/FastVGGT/eval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d5606560e1fc82a7b7b81df8b3b9f3b9ec8662 --- /dev/null +++ b/FastVGGT/eval/utils.py @@ -0,0 +1,142 @@ +import numpy as np +from scipy.spatial import cKDTree as KDTree + + +def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + u, v = np.meshgrid(np.arange(W), np.arange(H)) + z_cam = depthmap + x_cam = (u - cu) * z_cam / fu + y_cam = (v - cv) * z_cam / fv + X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + valid_mask = depthmap > 0.0 + return X_cam, valid_mask + + +def depthmap_to_absolute_camera_coordinates( + depthmap, camera_intrinsics, camera_pose, **kw +): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. 
+    """
+    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
+
+    X_world = X_cam  # default
+    if camera_pose is not None:
+
+        R_cam2world = camera_pose[:3, :3]
+        t_cam2world = camera_pose[:3, 3]
+
+        X_world = (
+            np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
+        )
+
+    return X_world, valid_mask
+
+
+def completion_ratio(gt_points, rec_points, dist_th=0.05):
+    gen_points_kd_tree = KDTree(rec_points)
+    distances, _ = gen_points_kd_tree.query(gt_points)
+    comp_ratio = np.mean((distances < dist_th).astype(np.float32))
+    return comp_ratio
+
+
+def accuracy(gt_points, rec_points, gt_normals=None, rec_normals=None):
+    gt_points_kd_tree = KDTree(gt_points)
+    distances, idx = gt_points_kd_tree.query(rec_points, workers=-1)
+    acc = np.mean(distances)
+
+    acc_median = np.median(distances)
+
+    if gt_normals is not None and rec_normals is not None:
+        normal_dot = np.sum(gt_normals[idx] * rec_normals, axis=-1)
+        normal_dot = np.abs(normal_dot)
+
+        return acc, acc_median, np.mean(normal_dot), np.median(normal_dot)
+
+    return acc, acc_median
+
+
+def completion(gt_points, rec_points, gt_normals=None, rec_normals=None):
+    gt_points_kd_tree = KDTree(rec_points)
+    distances, idx = gt_points_kd_tree.query(gt_points, workers=-1)
+    comp = np.mean(distances)
+    comp_median = np.median(distances)
+
+    if gt_normals is not None and rec_normals is not None:
+        normal_dot = np.sum(gt_normals * rec_normals[idx], axis=-1)
+        normal_dot = np.abs(normal_dot)
+
+        return comp, comp_median, np.mean(normal_dot), np.median(normal_dot)
+
+    return comp, comp_median
+
+
+def compute_iou(pred_vox, target_vox):
+    # Get voxel indices
+    v_pred_indices = [voxel.grid_index for voxel in pred_vox.get_voxels()]
+    v_target_indices = [voxel.grid_index for voxel in target_vox.get_voxels()]
+
+    # Convert to sets for set operations
+    v_pred_filled = set(tuple(np.round(x, 4)) for x in v_pred_indices)
+    v_target_filled = set(tuple(np.round(x, 4)) for x in 
v_target_indices)
+
+    # Compute intersection and union
+    intersection = v_pred_filled & v_target_filled
+    union = v_pred_filled | v_target_filled
+
+    # Compute IoU
+    iou = len(intersection) / len(union)
+    return iou
+
+
+def colmap_to_opencv_intrinsics(K):
+    """
+    Modify camera intrinsics to follow a different convention.
+    Coordinates of the center of the top-left pixels are by default:
+    - (0.5, 0.5) in Colmap
+    - (0,0) in OpenCV
+    """
+    K = K.copy()
+    K[0, 2] -= 0.5
+    K[1, 2] -= 0.5
+    return K
+
+
+def opencv_to_colmap_intrinsics(K):
+    """
+    Modify camera intrinsics to follow a different convention.
+    Coordinates of the center of the top-left pixels are by default:
+    - (0.5, 0.5) in Colmap
+    - (0,0) in OpenCV
+    """
+    K = K.copy()
+    K[0, 2] += 0.5
+    K[1, 2] += 0.5
+    return K
diff --git a/FastVGGT/merging/__init__.py b/FastVGGT/merging/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..156ac80ee0ef4270bc02b0954f8473ff55380cce
--- /dev/null
+++ b/FastVGGT/merging/__init__.py
@@ -0,0 +1,3 @@
+from . 
import merge + +__all__ = ["merge"] diff --git a/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc b/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66a4e39454e637a8e8e0fa63f6708b1eb670aae9 Binary files /dev/null and b/FastVGGT/merging/__pycache__/__init__.cpython-310.pyc differ diff --git a/FastVGGT/merging/__pycache__/merge.cpython-310.pyc b/FastVGGT/merging/__pycache__/merge.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fa7eff7fec3ec548a7532b57636ea9d7db57e16 Binary files /dev/null and b/FastVGGT/merging/__pycache__/merge.cpython-310.pyc differ diff --git a/FastVGGT/merging/merge.py b/FastVGGT/merging/merge.py new file mode 100644 index 0000000000000000000000000000000000000000..e2094b657cba6c9772a85f2ebe0240367fcec09c --- /dev/null +++ b/FastVGGT/merging/merge.py @@ -0,0 +1,370 @@ +import torch +from typing import Tuple, Callable, Optional, Union + + +@torch.jit.script +def fast_similarity_chunks( + a: torch.Tensor, b_transposed: torch.Tensor, chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + + B, num_src, C = a.shape + original_dtype = a.dtype + + # Convert to bf16 for computation to improve performance and reduce memory usage + a_bf16 = a.to(torch.bfloat16) + b_transposed_bf16 = b_transposed.to(torch.bfloat16) + node_max = torch.empty(B, num_src, device=a.device, dtype=original_dtype) + node_idx = torch.empty(B, num_src, device=a.device, dtype=torch.long) + + # Process in chunks + for i in range(0, num_src, chunk_size): + end_i = min(i + chunk_size, num_src) + a_chunk = a_bf16[:, i:end_i, :] # [B, chunk_size, C] + scores_chunk = torch.bmm(a_chunk, b_transposed_bf16) + chunk_max_bf16, chunk_idx = torch.max(scores_chunk, dim=2) + chunk_max = chunk_max_bf16.to(original_dtype) + node_max[:, i:end_i] = chunk_max + node_idx[:, i:end_i] = chunk_idx + return node_max, node_idx + + +def do_nothing( + x: torch.Tensor, + extra_tensors=None, + 
extra_tensors_2=None, +) -> Union[ + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], +]: + if extra_tensors is not None and extra_tensors_2 is not None: + return x, extra_tensors, extra_tensors_2 + elif extra_tensors is not None: + return x, extra_tensors + else: + return x + + +def token_merge_bipartite2d( + metric: torch.Tensor, + w: int, + h: int, + sx: int, + sy: int, + r: int, + no_rand: bool = False, + generator: Optional[torch.Generator] = None, + enable_protection: bool = False, +) -> Tuple[Callable, Callable]: + """ + Divide tokens into source (src) and destination (dst) groups, and merge r tokens from src to dst. + dst tokens are selected by randomly choosing one token from each (sx, sy) region. + Optionally protect the top 10% of tokens from merging based on importance scores. + + Args: + - metric [B, N, C]: Tensor for similarity computation, B=batch size, N=token count, C=feature dimension + - w: Image width in tokens + - h: Image height in tokens + - sx: dst stride in x dimension, must divide w evenly + - sy: dst stride in y dimension, must divide h evenly + - r: Number of tokens to remove through merging + - no_rand: If True, disable randomness (use only top-left token) + - generator: Random number generator if no_rand is False and not None + - enable_protection: If True, enable importance protection feature + + Returns: + - (merge, unmerge): Two functions for merging tokens and restoring pre-merge state + """ + B, N, _ = metric.shape # Batch size B, total tokens N + if r <= 0: + return do_nothing, do_nothing + + gather = torch.gather + + tokens_per_img = w * h + 5 + num_imgs = N // tokens_per_img + assert tokens_per_img * num_imgs == N, "Token count doesn't match (w*h+5)*num_imgs" + + with torch.no_grad(): + # Determine whether to compute importance scores based on enable_protection + if enable_protection: + num_protected = int(N * 0.1) + step = max(1, N // num_protected) + protected_indices = 
torch.arange(0, N, step, device=metric.device)[ + :num_protected + ] + else: + protected_indices = None + num_protected = 0 + + # Global idx_buffer_seq of length N; -1 indicates dst, 0 indicates src (maintain original logic) + idx_buffer_seq = torch.zeros(N, device=metric.device, dtype=torch.int64) + hsy, wsx = h // sy, w // sx # Number of blocks within each image + + # Mark first image entirely as dst + if num_imgs > 0: + idx_buffer_seq[:tokens_per_img] = -1 + + # Process other images - fully vectorized batch operations + if num_imgs > 1: + cls_indices = ( + torch.arange(1, num_imgs, device=metric.device) * tokens_per_img + ) + cls_indices = cls_indices[:, None] + torch.arange(5, device=metric.device) + idx_buffer_seq[cls_indices.flatten()] = -1 + effective_h = min(hsy * sy, h) + effective_w = min(wsx * sx, w) + effective_grid_size = effective_h * effective_w + + if no_rand: + base_pattern = torch.zeros( + effective_grid_size, device=metric.device, dtype=torch.int64 + ) + grid_starts = ( + torch.arange(1, num_imgs, device=metric.device) * tokens_per_img + 5 + ) + grid_indices = grid_starts[:, None] + torch.arange( + effective_grid_size, device=metric.device + ) + idx_buffer_seq[grid_indices.flatten()] = base_pattern.repeat( + num_imgs - 1 + ) + else: + total_other_imgs = num_imgs - 1 + all_rand_idx = torch.randint( + sy * sx, + size=(total_other_imgs, hsy, wsx), + device=metric.device, + generator=generator, + ) + + scatter_src = -torch.ones( + total_other_imgs, hsy, wsx, device=metric.device, dtype=torch.int64 + ) + + idx_buffer_batch = torch.zeros( + total_other_imgs, + hsy, + wsx, + sy * sx, + device=metric.device, + dtype=torch.int64, + ) + idx_buffer_batch.scatter_( + dim=3, + index=all_rand_idx.unsqueeze(-1), + src=scatter_src.unsqueeze(-1), + ) + + idx_buffer_batch = ( + idx_buffer_batch.view(total_other_imgs, hsy, wsx, sy, sx) + .transpose(2, 3) + .reshape(total_other_imgs, hsy * sy, wsx * sx) + ) + + # Batch fill to target positions - still needs a small 
loop here, but operations are greatly reduced + for i in range(total_other_imgs): + img_idx = i + 1 + grid_start = img_idx * tokens_per_img + 5 + flat_view = idx_buffer_batch[ + i, :effective_h, :effective_w + ].flatten() + idx_buffer_seq[grid_start : grid_start + effective_grid_size] = ( + flat_view + ) + + rand_idx = idx_buffer_seq.reshape(1, -1, 1).argsort(dim=1) + num_dst_orig = int((idx_buffer_seq == -1).sum()) + + # Original src and dst indices + a_idx_orig = rand_idx[:, num_dst_orig:, :] + b_idx_orig = rand_idx[:, :num_dst_orig, :] + a_idx = a_idx_orig + b_idx = b_idx_orig + + if enable_protection: + protected_idx = protected_indices.unsqueeze(0).unsqueeze(-1) + num_protected_actual = protected_idx.shape[1] + else: + protected_idx = None + num_protected_actual = 0 + + num_src = a_idx.shape[1] + num_dst = b_idx.shape[1] + + # Define an internal function to separate src, dst, and protected tokens + def split(x): + C = x.shape[-1] + + if enable_protection: + src = gather(x, dim=1, index=a_idx.expand(B, num_src, C)) + dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C)) + protected = gather( + x, dim=1, index=protected_idx.expand(B, num_protected_actual, C) + ) + return src, dst, protected + else: + src = gather(x, dim=1, index=a_idx.expand(B, num_src, C)) + dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C)) + return src, dst + + # Compute cosine similarity (normalize first then dot product) + metric = metric / metric.norm(dim=-1, keepdim=True) + if enable_protection: + a, b, protected = split(metric) + else: + a, b = split(metric) + + r = min(a.shape[1], r) + num_src_actual = a.shape[1] + chunk_size = min(5000, num_src_actual) + + node_max = torch.empty(B, num_src_actual, device=a.device, dtype=a.dtype) + node_idx = torch.empty(B, num_src_actual, device=a.device, dtype=torch.long) + + b_transposed = b.transpose(-1, -2) + node_max, node_idx = fast_similarity_chunks(a, b_transposed, chunk_size) + edge_idx = node_max.argsort(dim=-1, 
descending=True)[..., None] + + # If protection is enabled, filter out protected tokens to ensure they are not merged + if enable_protection: + src_indices = a_idx[0, :, 0] + protected_mask_src = torch.isin(src_indices, protected_indices) + edge_flat = edge_idx[0, :, 0] + valid_mask = ~protected_mask_src[edge_flat] + valid_edges = edge_flat[valid_mask] + + valid_count = valid_edges.shape[0] + r_actual = min(r, valid_count) + + unm_idx = valid_edges[r_actual:].unsqueeze(0).unsqueeze(-1) + src_idx = valid_edges[:r_actual].unsqueeze(0).unsqueeze(-1) + else: + unm_idx = edge_idx[..., r:, :] + src_idx = edge_idx[..., :r, :] + r_actual = r + + # Get dst token indices corresponding to each src token to be merged + dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) + r = r_actual + + # Define merge function to merge selected src tokens to corresponding dst tokens + def merge( + x: torch.Tensor, + mode: str = "mean", + extra_tensors=None, + extra_tensors_2=None, + ) -> Union[ + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ]: + if enable_protection: + src, dst, protected = split(x) + else: + src, dst = split(x) + + n, t1, c = src.shape + + # Extract unmerged src tokens - using actual unm_idx size + unm_len = unm_idx.shape[1] + unm = gather(src, dim=-2, index=unm_idx.expand(n, unm_len, c)) + src_len = src_idx.shape[1] + src = gather(src, dim=-2, index=src_idx.expand(n, src_len, c)) + dst = dst.scatter_reduce(-2, dst_idx.expand(n, src_len, c), src, reduce=mode) + + # ---------------- Extra tensor processing ---------------- + merged_extra_1 = None + merged_extra_2 = None + if extra_tensors is not None: + E_dim = extra_tensors.shape[-1] + if enable_protection: + src_e, dst_e, protected_e = split(extra_tensors) + else: + src_e, dst_e = split(extra_tensors) + + # Consistent with main tensor, only select r src tokens to be merged + src_e_r = gather(src_e, dim=-2, index=src_idx.expand(n, src_len, E_dim)) + unm_e = 
gather(src_e, dim=-2, index=unm_idx.expand(n, unm_len, E_dim)) + + dst_e = dst_e.scatter_reduce( + -2, dst_idx.expand(n, src_len, E_dim), src_e_r, reduce=mode + ) + if enable_protection: + merged_extra_1 = torch.cat([unm_e, dst_e, protected_e], dim=1) + else: + merged_extra_1 = torch.cat([unm_e, dst_e], dim=1) + + if extra_tensors_2 is not None: + E_dim_2 = extra_tensors_2.shape[-1] + if enable_protection: + src_e2, dst_e2, protected_e2 = split(extra_tensors_2) + else: + src_e2, dst_e2 = split(extra_tensors_2) + + src_e2_r = gather(src_e2, dim=-2, index=src_idx.expand(n, src_len, E_dim_2)) + unm_e2 = gather(src_e2, dim=-2, index=unm_idx.expand(n, unm_len, E_dim_2)) + + dst_e2 = dst_e2.scatter_reduce( + -2, dst_idx.expand(n, src_len, E_dim_2), src_e2_r, reduce=mode + ) + if enable_protection: + merged_extra_2 = torch.cat([unm_e2, dst_e2, protected_e2], dim=1) + else: + merged_extra_2 = torch.cat([unm_e2, dst_e2], dim=1) + + if enable_protection: + main_result = torch.cat([unm, dst, protected], dim=1) + else: + main_result = torch.cat([unm, dst], dim=1) + + if merged_extra_1 is not None and merged_extra_2 is not None: + return main_result, merged_extra_1, merged_extra_2 + elif merged_extra_1 is not None: + return main_result, merged_extra_1 + else: + return main_result + + # Define unmerge function to restore pre-merge state (for decoder) + def unmerge(x: torch.Tensor) -> torch.Tensor: + unm_len = unm_idx.shape[1] + dst_len = num_dst + src_len = src_idx.shape[1] + unm = x[..., :unm_len, :] + dst = x[..., unm_len : unm_len + dst_len, :] + + if enable_protection: + protected = x[ + ..., unm_len + dst_len : unm_len + dst_len + num_protected_actual, : + ] + + _, _, c = unm.shape + src = gather(dst, dim=-2, index=dst_idx.expand(B, src_len, c)) + out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) + out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) + out.scatter_( + dim=-2, + index=gather( + a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx + 
).expand(B, unm_len, c), + src=unm, + ) + + out.scatter_( + dim=-2, + index=gather( + a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx + ).expand(B, src_len, c), + src=src, + ) + + if enable_protection: + out.scatter_( + dim=-2, + index=protected_idx.expand(B, num_protected_actual, c), + src=protected, + ) + + return out + + return merge, unmerge diff --git a/FastVGGT/requirements.txt b/FastVGGT/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e343637f2d9a1550f55f14f8890c229bd467cdb2 --- /dev/null +++ b/FastVGGT/requirements.txt @@ -0,0 +1,15 @@ +torch==2.3.1 +torchvision==0.18.1 +numpy==1.26.1 +Pillow +huggingface_hub +einops +safetensors +evo +open3d +matplotlib +scipy +opencv-python +scikit-image +tqdm + diff --git a/FastVGGT/vggt/__init__.py b/FastVGGT/vggt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a2958f93a114495631a3ce270b99ee8ba8443f1 --- /dev/null +++ b/FastVGGT/vggt/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5f3bfe676b899c95cba2c034de5292d685e03ce Binary files /dev/null and b/FastVGGT/vggt/__pycache__/__init__.cpython-310.pyc differ diff --git a/FastVGGT/vggt/dependency/__init__.py b/FastVGGT/vggt/dependency/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a2958f93a114495631a3ce270b99ee8ba8443f1 --- /dev/null +++ b/FastVGGT/vggt/dependency/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# (diff stubs, binary contents omitted:
#  FastVGGT/vggt/dependency/__pycache__/__init__.cpython-310.pyc,
#  FastVGGT/vggt/dependency/__pycache__/distortion.cpython-310.pyc)
# diff --git a/FastVGGT/vggt/dependency/distortion.py (new file, 54 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch


def apply_distortion(points, distortion_params):
    """Map normalized camera coordinates through a lens-distortion model.

    Args:
        points: Array of normalized camera coordinates.
        distortion_params: Distortion parameters (currently unused).

    Returns:
        The input `points`, unchanged.
    """
    # Placeholder: a real distortion model is intentionally not implemented yet.
    return points


def iterative_undistortion(points, distortion_params, max_iter=10):
    """Invert the distortion model via fixed-point iteration.

    Args:
        points: Array of distorted normalized camera coordinates.
        distortion_params: Distortion parameters (currently unused).
        max_iter: Iteration budget (currently unused).

    Returns:
        The input `points`, unchanged.
    """
    # Placeholder: a real iterative undistortion is intentionally not implemented yet.
    return points


def single_undistortion(points, distortion_params):
    """Invert the distortion model with a single correction step.

    Args:
        points: Array of distorted normalized camera coordinates.
        distortion_params: Distortion parameters (currently unused).

    Returns:
        The input `points`, unchanged.
    """
    # Placeholder: a real single-step undistortion is intentionally not implemented yet.
    return points
# (no newline at end of distortion.py)
# (diff stubs, binary contents omitted:
#  FastVGGT/vggt/heads/__pycache__/camera_head.cpython-310.pyc,
#  FastVGGT/vggt/heads/__pycache__/dpt_head.cpython-310.pyc,
#  FastVGGT/vggt/heads/__pycache__/head_act.cpython-310.pyc)
# diff --git FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc — Binary files /dev/null and
b/FastVGGT/vggt/heads/__pycache__/track_head.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc b/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fc2e8df264f6accc5412e80b8580c8c48195aaf5
Binary files /dev/null and b/FastVGGT/vggt/heads/__pycache__/utils.cpython-310.pyc differ
diff --git a/FastVGGT/vggt/heads/camera_head.py b/FastVGGT/vggt/heads/camera_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..215adf39de23abd4975479d332250fcc3e2b54b9
--- /dev/null
+++ b/FastVGGT/vggt/heads/camera_head.py
@@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from vggt.layers import Mlp
from vggt.layers.block import Block
from vggt.heads.head_act import activate_pose


class CameraHead(nn.Module):
    """
    CameraHead predicts camera parameters from token representations using iterative refinement.

    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            # 9 dims — presumably absT (3) + quaR (4) + FoV (2), per the encoding
            # name; confirm against activate_pose / the pose decoder.
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token: seed for the first refinement step.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens (token position 0 of each frame).
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape  # S is expected to be 1.
        pred_pose_enc = None
        pred_pose_enc_list = []

        for _ in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation, with a residual back
            # onto the unmodulated tokens.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            pose_tokens_modulated = self.trunk(pose_tokens_modulated)
            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            # Accumulate deltas across iterations (first iteration sets the base).
            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
            )
            pred_pose_enc_list.append(activated_pose)

        return pred_pose_enc_list


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Modulate the input tensor using scaling and shifting parameters.
    """
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    return x * (1 + scale) + shift
diff --git a/FastVGGT/vggt/heads/dpt_head.py b/FastVGGT/vggt/heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e20d65ef5bfeb23cf83ca748aedff738840b4ffd
--- /dev/null
+++ b/FastVGGT/vggt/heads/dpt_head.py
@@ -0,0 +1,598 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


# Inspired by https://github.com/DepthAnything/Depth-Anything-V2


import os
from typing import List, Dict, Tuple, Union, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from .head_act import activate_head
from .utils import create_uv_grid, position_grid_to_embed


class DPTHead(nn.Module):
    """
    DPT Head for dense prediction tasks.

    This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
    backbone and produces dense predictions by fusing multi-scale features.

    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [0, 1, 2, 3],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        # NOTE(review): the list defaults are mutable and shared across
        # instances; they are only read here, but verify callers never mutate
        # them in place.
        super(DPTHead, self).__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx

        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        # 1x1 convs map token dim -> per-scale channel counts.
        self.projects = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=dim_in,
                    out_channels=oc,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
                for oc in out_channels
            ]
        )

        # Resize layers for upsampling feature maps.
        # Per level: 4x up, 2x up, identity, 2x down (stride-2 conv) — builds a
        # feature pyramid out of same-resolution ViT features.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0],
                    out_channels=out_channels[0],
                    kernel_size=4,
                    stride=4,
                    padding=0,
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1],
                    out_channels=out_channels[1],
                    kernel_size=2,
                    stride=2,
                    padding=0,
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3],
                    out_channels=out_channels[3],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                ),
            ]
        )

        self.scratch = _make_scratch(out_channels, features, expand=False)

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = nn.Identity()  # Use Identity instead of None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        # Deepest level has no coarser input to fuse with.
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
            )
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1,
                head_features_1 // 2,
                kernel_size=3,
                stride=1,
                padding=1,
            )
            conv2_in_channels = head_features_1 // 2

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(
                    conv2_in_channels,
                    head_features_2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.ReLU(inplace=True),
                nn.Conv2d(
                    head_features_2, output_dim, kernel_size=1, stride=1, padding=0
                ),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.
        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or larger than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: Feature maps with shape [B, S, C, H, W]
                - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
        """
        B, S, _, H, W = images.shape

        # If frames_chunk_size is not specified or greater than S, process all frames at once
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Otherwise, process frames in chunks to manage memory usage
        assert frames_chunk_size > 0

        # Process frames in batches
        all_preds = []
        all_conf = []

        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)

            # Process batch of frames
            if self.feature_only:
                chunk_output = self._forward_impl(
                    aggregated_tokens_list,
                    images,
                    patch_start_idx,
                    frames_start_idx,
                    frames_end_idx,
                )
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = self._forward_impl(
                    aggregated_tokens_list,
                    images,
                    patch_start_idx,
                    frames_start_idx,
                    frames_end_idx,
                )
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)

        # Concatenate results along the sequence dimension
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        else:
            return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: Optional[int] = None,
        frames_end_idx: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Implementation of the forward pass through the DPT head.

        This method processes a specific chunk of frames from the sequence.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index for frames to process.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        if frames_start_idx is not None and frames_end_idx is not None:
            images = images[:, frames_start_idx:frames_end_idx].contiguous()

        B, S, _, H, W = images.shape

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        out = []
        dpt_idx = 0

        for layer_idx in self.intermediate_layer_idx:
            # Drop camera/register tokens; keep only patch tokens.
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            # Select frames if processing a chunk
            if frames_start_idx is not None and frames_end_idx is not None:
                x = x[:, frames_start_idx:frames_end_idx]

            # Fold batch and sequence dims so each frame is processed as one image.
            x = x.reshape(B * S, -1, x.shape[-1])

            x = self.norm(x)

            # Tokens -> [B*S, C, patch_h, patch_w] feature map.
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)
            dpt_idx += 1

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (
                int(patch_h * self.patch_size / self.down_ratio),
                int(patch_w * self.patch_size / self.down_ratio),
            ),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out.view(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(
            out, activation=self.activation, conf_activation=self.conf_activation
        )

        # Unfold B*S back into [B, S, ...].
        preds = preds.view(B, S, *preds.shape[1:])
        conf = conf.view(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(
        self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
    ) -> torch.Tensor:
        """
        Apply positional embedding to tensor x.

        A UV grid matching x's spatial size is embedded to x's channel count,
        scaled down by `ratio`, and added residually.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(
            patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
        )
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): List of feature maps from different layers.

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Coarse-to-fine refinement; del frees each level once consumed.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out


################################################################################
# Modules
################################################################################


def _make_fusion_block(
    features: int,
    size: Optional[int] = None,
    has_residual: bool = True,
    groups: int = 1,
) -> nn.Module:
    # Thin factory around FeatureFusionBlock with the DPT defaults used above.
    return FeatureFusionBlock(
        features,
        nn.ReLU(inplace=True),
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
    )


def _make_scratch(
    in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
) -> nn.Module:
    # Builds the "scratch" container holding per-level 3x3 projection convs.
    scratch = nn.Module()
    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    if len(in_shape) >= 4:
        out_shape4 = out_shape

    if expand:
        # Optionally widen channels with depth (unused by DPTHead: expand=False).
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        if len(in_shape) >= 4:
            out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0],
        out_shape1,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1],
        out_shape2,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2],
        out_shape3,
        kernel_size=3,
        stride=1,
        padding=1,
+ bias=False, + groups=groups, + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. 
+ + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=self.groups, + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit( + features, activation, bn, groups=self.groups + ) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit( + features, activation, bn, groups=self.groups + ) + + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = output + res + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Optional[Tuple[int, int]] = None, + scale_factor: Optional[float] = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. 
+ """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate( + chunk, size=size, mode=mode, align_corners=align_corners + ) + for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate( + x, size=size, mode=mode, align_corners=align_corners + ) diff --git a/FastVGGT/vggt/heads/head_act.py b/FastVGGT/vggt/heads/head_act.py new file mode 100644 index 0000000000000000000000000000000000000000..2dedfcf1180a653dddc99623e60df625e5897489 --- /dev/null +++ b/FastVGGT/vggt/heads/head_act.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): + """ + Activate pose parameters with specified activation functions. 
+ + Args: + pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] + trans_act: Activation type for translation component + quat_act: Activation type for quaternion component + fl_act: Activation type for focal length component + + Returns: + Activated pose parameters tensor + """ + T = pred_pose_enc[..., :3] + quat = pred_pose_enc[..., 3:7] + fl = pred_pose_enc[..., 7:] # or fov + + T = base_pose_act(T, trans_act) + quat = base_pose_act(quat, quat_act) + fl = base_pose_act(fl, fl_act) # or fov + + pred_pose_enc = torch.cat([T, quat, fl], dim=-1) + + return pred_pose_enc + + +def base_pose_act(pose_enc, act_type="linear"): + """ + Apply basic activation function to pose parameters. + + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. 
+ + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + # Move channels from last dim to the 4th dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/FastVGGT/vggt/heads/track_head.py b/FastVGGT/vggt/heads/track_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f1d9bd83cca1f74f97a644a02b984904f84706 --- /dev/null +++ b/FastVGGT/vggt/heads/track_head.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. 
import torch.nn as nn

# Project-local dependencies (resolved inside the vggt package):
#   from .dpt_head import DPTHead
#   from .track_modules.base_track_predictor import BaseTrackerPredictor


class TrackHead(nn.Module):
    """
    Point-tracking head: a DPT-style feature extractor followed by an
    iterative track predictor.

    Backbone tokens are first rendered into dense feature maps (at half the
    input resolution, since the DPT head is built with ``down_ratio=2``),
    then a ``BaseTrackerPredictor`` refines point coordinates over several
    iterations.
    """

    def __init__(
        self,
        dim_in,
        patch_size=14,
        features=128,
        iters=4,
        predict_conf=True,
        stride=2,
        corr_levels=7,
        corr_radius=4,
        hidden_size=384,
    ):
        """
        Args:
            dim_in (int): Channel dimension of the backbone tokens.
            patch_size (int): ViT patch size used by the feature extractor.
            features (int): Channels of the extracted feature maps; must equal
                the tracker's latent dimension, so it is forwarded as such.
            iters (int): Default number of refinement iterations.
            predict_conf (bool): Whether the tracker also emits per-point
                confidence scores.
            stride (int): Stride value handed to the tracker predictor.
            corr_levels (int): Number of correlation-pyramid levels.
            corr_radius (int): Search radius of the correlation lookup.
            hidden_size (int): Hidden width of the tracker's transformer.
        """
        super().__init__()

        self.patch_size = patch_size

        # DPT head configured to emit raw feature maps only
        # (feature_only=True) at half resolution (down_ratio=2),
        # without positional embedding.
        self.feature_extractor = DPTHead(
            dim_in=dim_in,
            patch_size=patch_size,
            features=features,
            feature_only=True,
            down_ratio=2,
            pos_embed=False,
        )

        # Iterative coordinate/visibility predictor; its latent_dim matches
        # the channel count produced by the feature extractor above.
        self.tracker = BaseTrackerPredictor(
            latent_dim=features,
            predict_conf=predict_conf,
            stride=stride,
            corr_levels=corr_levels,
            corr_radius=corr_radius,
            hidden_size=hidden_size,
        )

        self.iters = iters

    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
        """
        Run feature extraction and iterative tracking.

        Args:
            aggregated_tokens_list (list): Aggregated token tensors from the
                backbone, one per selected layer.
            images (torch.Tensor): Frames shaped (B, S, C, H, W).
            patch_start_idx (int): Index of the first patch token in the
                token sequence.
            query_points (torch.Tensor, optional): Initial points to track;
                if None the tracker initializes its own.
            iters (int, optional): Refinement iterations; falls back to the
                value configured at construction time.

        Returns:
            tuple: (coord_preds, vis_scores, conf_scores) as produced by the
            tracker (conf_scores is meaningful only when predict_conf=True).
        """
        # Unpack also acts as a rank check: images must be 5-D.
        B, S, _, H, W = images.shape

        # Dense per-frame features; (B, S, features, H//2, W//2) because the
        # extractor halves the spatial resolution (down_ratio=2).
        fmaps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)

        # Default to the iteration count chosen at construction.
        n_iters = self.iters if iters is None else iters

        coord_preds, vis_scores, conf_scores = self.tracker(
            query_points=query_points, fmaps=fmaps, iters=n_iters
        )
        return coord_preds, vis_scores, conf_scores
diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ca81dbf0ca49b601ca2662314a3391c8ba00890 Binary files /dev/null and b/FastVGGT/vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..350f8ca52c0aca644125fea47f83d49337478bcb Binary files /dev/null and b/FastVGGT/vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc differ diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81e9466ff91b16d265807533eeada2e8e674d7e0 Binary files /dev/null and b/FastVGGT/vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc differ diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3fe19d5e60445f0fa06866bffaa2b8b2e60a6f8 Binary files /dev/null and b/FastVGGT/vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc differ diff --git a/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc b/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cac348bf405bc00bc6578760ba8ad57185f33b81 Binary files /dev/null and b/FastVGGT/vggt/heads/track_modules/__pycache__/utils.cpython-310.pyc differ diff --git a/FastVGGT/vggt/heads/track_modules/base_track_predictor.py 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from einops import rearrange, repeat


from .blocks import EfficientUpdateFormer, CorrBlock
from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
from .modules import Mlp


class BaseTrackerPredictor(nn.Module):
    """
    Iterative point-track predictor.

    Given dense feature maps and 2-D query points, repeatedly (`iters` times)
    samples a correlation pyramid around the current coordinate estimates,
    feeds flow embeddings + correlations + track features through an
    EfficientUpdateFormer, and applies the predicted coordinate/feature
    deltas. Also predicts per-point visibility and (optionally) confidence.
    """

    def __init__(
        self,
        stride=1,
        corr_levels=5,
        corr_radius=4,
        latent_dim=128,
        hidden_size=384,
        use_spaceatt=True,
        depth=6,
        max_scale=518,
        predict_conf=True,
    ):
        # Args:
        #   stride (int): downsampling factor between input coordinates and
        #       the feature maps; query points are divided by it in forward().
        #   corr_levels (int): number of correlation-pyramid levels.
        #   corr_radius (int): sampling radius of the correlation lookup.
        #   latent_dim (int): channel width of track features (must match the
        #       feature maps handed to forward()).
        #   hidden_size (int): hidden width of the update transformer.
        #   use_spaceatt (bool): enable spatial attention blocks (space_depth
        #       is zeroed out when False).
        #   depth (int): transformer depth for both time and (if enabled)
        #       space attention.
        #   max_scale (int): normalization constant for flow magnitudes
        #       (presumably the max expected image side; see flows_emb below).
        #   predict_conf (bool): add a confidence prediction head.
        super(BaseTrackerPredictor, self).__init__()
        """
        The base template to create a track predictor

        Modified from https://github.com/facebookresearch/co-tracker/
        and https://github.com/facebookresearch/vggsfm
        """

        self.stride = stride
        self.latent_dim = latent_dim
        self.corr_levels = corr_levels
        self.corr_radius = corr_radius
        self.hidden_size = hidden_size
        self.max_scale = max_scale
        self.predict_conf = predict_conf

        # Half of latent_dim; passed to get_2d_embedding for flow encoding.
        self.flows_emb_dim = latent_dim // 2

        # Projects the flattened correlation pyramid
        # (corr_levels * (2r+1)^2 values per point/frame) into latent_dim.
        self.corr_mlp = Mlp(
            in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
            hidden_features=self.hidden_size,
            out_features=self.latent_dim,
        )

        # Transformer input = flow embedding block + correlation features +
        # track features; the "+ 4" matches the two extra flows/max_scale
        # copies (2 channels each) concatenated in forward().
        self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4

        # Two learned tokens: one for the query (reference) frame, one shared
        # by all other frames (expanded across S-1 frames in forward()).
        self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))

        space_depth = depth if use_spaceatt else 0
        time_depth = depth

        self.updateformer = EfficientUpdateFormer(
            space_depth=space_depth,
            time_depth=time_depth,
            input_dim=self.transformer_dim,
            hidden_size=self.hidden_size,
            output_dim=self.latent_dim + 2,  # 2 coord deltas + latent_dim feature deltas
            mlp_ratio=4.0,
            add_space_attn=use_spaceatt,
        )

        # LayerNorm applied to the input feature maps (channels-last).
        self.fmap_norm = nn.LayerNorm(self.latent_dim)
        # GroupNorm(1, C) == LayerNorm over channels for the feature deltas.
        self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)

        # A linear layer to update track feats at each iteration
        self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())

        # Per-point visibility logit head.
        self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

        if predict_conf:
            # Per-point confidence logit head (only built when requested).
            self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))

    def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
        """
        query_points: B x N x 2, the number of batches, tracks, and xy
        fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
            Note HH and WW are the size of the feature maps, not the original
            images. C must equal self.latent_dim. Although fmaps defaults to
            None, it is unpacked unconditionally below, so it is effectively
            required.
        iters: number of refinement iterations.
        return_feat: additionally return track/query features.
        down_ratio: extra image->fmap downsampling factor beyond self.stride.
        apply_sigmoid: squash visibility/confidence logits through sigmoid.

        Returns:
            (coord_preds, vis_e, conf_e) by default, where coord_preds is a
            list with one (B, S, N, 2) tensor per iteration in the original
            image scale, vis_e is (B, S, N), and conf_e is (B, S, N) or None
            when predict_conf is False. With return_feat=True, returns
            (coord_preds, vis_e, track_feats, query_track_feat, conf_e).
        """
        B, N, D = query_points.shape
        B, S, C, HH, WW = fmaps.shape

        assert D == 2, "Input points must be 2D coordinates"

        # apply a layernorm to fmaps here (channels-last for LayerNorm, then back)
        fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
        fmaps = fmaps.permute(0, 1, 4, 2, 3)

        # Scale the input query_points because we may downsample the images
        # by down_ratio or self.stride
        # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
        # its query_points should be query_points/4
        if down_ratio > 1:
            query_points = query_points / float(down_ratio)

        query_points = query_points / float(self.stride)

        # Init with coords as the query points
        # It means the search will start from the position of query points at the reference frames
        coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)

        # Sample/extract the features of the query points in the query frame
        query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])

        # init track feats by query feats
        track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C
        # back up the init coords
        coords_backup = coords.clone()

        # Correlation pyramid over the (normalized) feature maps.
        fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)

        coord_preds = []

        # Iterative Refinement
        for _ in range(iters):
            # Detach the gradients from the last iteration
            # (in my experience, not very important for performance)
            coords = coords.detach()

            # Sample correlations around the current coordinate estimates.
            fcorrs = fcorr_fn.corr_sample(track_feats, coords)

            corr_dim = fcorrs.shape[3]
            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
            fcorrs_ = self.corr_mlp(fcorrs_)

            # Movement of current coords relative to query points
            flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)

            flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)

            # (In my trials, it is also okay to just add the flows_emb instead of concat)
            # NOTE(review): flows/max_scale is concatenated twice (4 extra
            # channels) — this matches the "+ 4" in self.transformer_dim.
            flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)

            track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)

            # Concatenate them as the input for the transformers
            transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)

            # 2D positional embed
            # TODO: this can be much simplified
            pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
            sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])

            sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)

            x = transformer_input + sampled_pos_emb

            # Add the query ref token to the track feats:
            # token 0 marks the query frame, token 1 is shared by the others.
            query_ref_token = torch.cat(
                [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
            )
            x = x + query_ref_token.to(x.device).to(x.dtype)

            # B, N, S, C
            x = rearrange(x, "(b n) s d -> b n s d", b=B)

            # Compute the delta coordinates and delta track features
            delta, _ = self.updateformer(x)

            # BN, S, C
            delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
            delta_coords_ = delta[:, :, :2]
            delta_feats_ = delta[:, :, 2:]

            track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
            delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)

            # Update the track features (residual update through norm + MLP).
            track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_

            track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3)  # BxSxNxC

            # B x S x N x 2
            coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)

            # Force coord0 as query
            # because we assume the query points should not be changed
            coords[:, 0] = coords_backup[:, 0]

            # The predicted tracks are in the original image scale
            if down_ratio > 1:
                coord_preds.append(coords * self.stride * down_ratio)
            else:
                coord_preds.append(coords * self.stride)

        # B, S, N
        vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
        if apply_sigmoid:
            vis_e = torch.sigmoid(vis_e)

        if self.predict_conf:
            conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
            if apply_sigmoid:
                conf_e = torch.sigmoid(conf_e)
        else:
            conf_e = None

        if return_feat:
            return coord_preds, vis_e, track_feats, query_track_feat, conf_e
        else:
            return coord_preds, vis_e, conf_e
+ + +# Modified from https://github.com/facebookresearch/co-tracker/ + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import bilinear_sampler +from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock + + +class EfficientUpdateFormer(nn.Module): + """ + Transformer model that updates track estimates. + """ + + def __init__( + self, + space_depth=6, + time_depth=6, + input_dim=320, + hidden_size=384, + num_heads=8, + output_dim=130, + mlp_ratio=4.0, + add_space_attn=True, + num_virtual_tracks=64, + ): + super().__init__() + + self.out_channels = 2 + self.num_heads = num_heads + self.hidden_size = hidden_size + self.add_space_attn = add_space_attn + + # Add input LayerNorm before linear projection + self.input_norm = nn.LayerNorm(input_dim) + self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True) + + # Add output LayerNorm before final projection + self.output_norm = nn.LayerNorm(hidden_size) + self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True) + self.num_virtual_tracks = num_virtual_tracks + + if self.add_space_attn: + self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size)) + else: + self.virual_tracks = None + + self.time_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(time_depth) + ] + ) + + if add_space_attn: + self.space_virtual_blocks = nn.ModuleList( + [ + AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention) + for _ in range(space_depth) + ] + ) + self.space_point2virtual_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + self.space_virtual2point_blocks = nn.ModuleList( + [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)] + ) + assert len(self.time_blocks) >= 
len(self.space_virtual2point_blocks) + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001) + + self.apply(_basic_init) + + def forward(self, input_tensor, mask=None): + # Apply input LayerNorm + input_tensor = self.input_norm(input_tensor) + tokens = self.input_transform(input_tensor) + + init_tokens = tokens + + B, _, T, _ = tokens.shape + + if self.add_space_attn: + virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1) + tokens = torch.cat([tokens, virtual_tokens], dim=1) + + _, N, _, _ = tokens.shape + + j = 0 + for i in range(len(self.time_blocks)): + time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C + + time_tokens = self.time_blocks[i](time_tokens) + + tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C + if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0): + space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C + point_tokens = space_tokens[:, : N - self.num_virtual_tracks] + virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :] + + virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask) + virtual_tokens = self.space_virtual_blocks[j](virtual_tokens) + point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask) + + space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1) + tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C + j += 1 + + if self.add_space_attn: + tokens = tokens[:, : N - self.num_virtual_tracks] + + tokens = tokens + init_tokens + + # Apply output LayerNorm before final projection + tokens = self.output_norm(tokens) + flow = self.flow_head(tokens) + + return flow, 
class CorrBlock:
    """Multi-scale correlation lookup for point tracking.

    A feature pyramid is built once in ``__init__``; per-level correlation
    volumes are computed on the fly in :meth:`corr_sample` and discarded
    immediately after sampling, which bounds peak GPU memory.
    """

    def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
        """
        Build a pyramid of feature maps from the input.

        fmaps: Tensor (B, S, C, H, W)
        num_levels: number of pyramid levels (each downsampled by factor 2)
        radius: search radius for sampling correlation
        multiple_track_feats: if True, split the target features per pyramid level
        padding_mode: passed to grid_sample / bilinear_sampler
        """
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.num_levels = num_levels
        self.radius = radius
        self.padding_mode = padding_mode
        self.multiple_track_feats = multiple_track_feats

        # Build pyramid: each level is half the spatial resolution of the previous.
        self.fmaps_pyramid = [fmaps]  # level 0 is full resolution
        current_fmaps = fmaps
        for i in range(num_levels - 1):
            B, S, C, H, W = current_fmaps.shape
            # Merge batch & sequence dimensions so avg_pool2d sees a 4D tensor.
            current_fmaps = current_fmaps.reshape(B * S, C, H, W)
            # Avg pool down by factor 2.
            current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
            _, _, H_new, W_new = current_fmaps.shape
            current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
            self.fmaps_pyramid.append(current_fmaps)

        # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
        # This grid is added to the (scaled) coordinate centroids.
        r = self.radius
        dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
        dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
        # delta[i, j] = (dy[i], dx[j]); since dx and dy span the same symmetric
        # range, the displacement set is identical either way — the ordering only
        # affects the layout of the sampled window, not its contents.
        self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1)  # shape: (2r+1, 2r+1, 2)

    def corr_sample(self, targets, coords):
        """
        Instead of storing the entire correlation pyramid, we compute each level's correlation
        volume, sample it immediately, then discard it. This saves GPU memory.

        Args:
            targets: Tensor (B, S, N, C) — features for the current targets.
            coords: Tensor (B, S, N, 2) — coordinates at full resolution.

        Returns:
            Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
        """
        B, S, N, C = targets.shape

        # If you have multiple track features, split them per level.
        if self.multiple_track_feats:
            targets_split = torch.split(targets, C // self.num_levels, dim=-1)

        out_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            # Get current spatial resolution H, W for this pyramid level.
            B, S, C, H, W = fmaps.shape
            # Reshape feature maps for correlation computation:
            # fmap2s: (B, S, C, H*W)
            fmap2s = fmaps.view(B, S, C, H * W)
            # Choose appropriate target features.
            fmap1 = targets_split[i] if self.multiple_track_feats else targets  # shape: (B, S, N, C)

            # Compute correlation directly (scaled dot product; see compute_corr_level).
            corrs = compute_corr_level(fmap1, fmap2s, C)
            corrs = corrs.view(B, S, N, H, W)

            # Prepare sampling grid:
            # Scale down the coordinates for the current level (level i is 2**i smaller).
            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
            # Make sure our precomputed delta grid is on the same device/dtype.
            delta_lvl = self.delta.to(coords.device).to(coords.dtype)
            # Now the grid for grid_sample is:
            # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
            coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)

            # Sample from the correlation volume using bilinear interpolation.
            # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
            corrs_sampled = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
            )
            # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
            corrs_sampled = corrs_sampled.view(B, S, N, -1)  # Now shape: (B, S, N, (2r+1)^2)
            out_pyramid.append(corrs_sampled)

        # Concatenate all levels along the last dimension.
        out = torch.cat(out_pyramid, dim=-1).contiguous()
        return out


def compute_corr_level(fmap1, fmap2s, C):
    """Scaled dot-product correlation between target features and a flattened map.

    Args:
        fmap1: (B, S, N, C) target features.
        fmap2s: (B, S, C, H*W) flattened source feature map.
        C: channel count, used for 1/sqrt(C) scaling.

    Returns:
        (B, S, N, H*W) correlation scores.
    """
    corrs = torch.matmul(fmap1, fmap2s)  # (B, S, N, H*W)
    corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1)  # (B, S, N, H*W)
    return corrs / math.sqrt(C)
nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + else: + raise NotImplementedError + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = 
class AttnBlock(nn.Module):
    """Self-attention block: LayerNorm -> MHA -> residual -> LayerNorm -> MLP -> residual.

    NOTE: the first residual is added to the *normalized* input (``x`` is
    overwritten with ``norm1(x)`` before the skip); this matches the original
    implementation and is preserved here.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs,
    ):
        """Self attention block."""
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.attn = attn_class(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        hidden_mlp = int(hidden_size * mlp_ratio)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=hidden_mlp, drop=0)

    def forward(self, x, mask=None):
        # `mask` is accepted for interface compatibility but is not forwarded to
        # the attention call (as in the original implementation).
        normed = self.norm1(x)
        # nn.MultiheadAttention returns (attn_output, attn_output_weights).
        attn_out, _ = self.attn(normed, normed, normed)
        out = normed + attn_out
        out = out + self.mlp(self.norm2(out))
        return out


class CrossAttnBlock(nn.Module):
    """Cross-attention block: queries attend to a (normalized) context sequence.

    NOTE(review): `norm_context` is built with `hidden_size`, not `context_dim`,
    and `context_dim` is otherwise unused — callers must supply context whose
    channel dim equals `hidden_size`. Preserved as-is; confirm upstream intent.
    """

    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
        """Cross attention block."""
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm_context = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        hidden_mlp = int(hidden_size * mlp_ratio)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=hidden_mlp, drop=0)

    def forward(self, x, context, mask=None):
        queries = self.norm1(x)
        keys = self.norm_context(context)
        # nn.MultiheadAttention returns (attn_output, attn_output_weights).
        attn_out, _ = self.cross_attn(queries, keys, keys, attn_mask=mask)
        out = queries + attn_out
        out = out + self.mlp(self.norm2(out))
        return out
+ """ + if isinstance(grid_size, tuple): + grid_size_h, grid_size_w = grid_size + else: + grid_size_h = grid_size_w = grid_size + grid_h = torch.arange(grid_size_h, dtype=torch.float) + grid_w = torch.arange(grid_size_w, dtype=torch.float) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") + grid = torch.stack(grid, dim=0) + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if return_grid: + return (pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2), grid) + return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor: + """ + This function generates a 2D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - grid: The grid to generate the embedding from. + + Returns: + - emb: The generated 2D positional embedding. + """ + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. 
+ """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.double) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb[None].float() + + +def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: + """ + This function generates a 2D positional embedding from given coordinates using sine and cosine functions. + + Args: + - xy: The coordinates to generate the embedding from. + - C: The size of the embedding. + - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. + + Returns: + - pe: The generated 2D positional embedding. + """ + B, N, D = xy.shape + assert D == 2 + + x = xy[:, :, 0:1] + y = xy[:, :, 1:2] + div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2)) + + pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) + + pe_x[:, :, 0::2] = torch.sin(x * div_term) + pe_x[:, :, 1::2] = torch.cos(x * div_term) + + pe_y[:, :, 0::2] = torch.sin(y * div_term) + pe_y[:, :, 1::2] = torch.cos(y * div_term) + + pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3) + if cat_coords: + pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3) + return pe + + +def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): + r"""Sample a tensor using bilinear interpolation + + `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at + coordinates :attr:`coords` using bilinear interpolation. It is the same + as `torch.nn.functional.grid_sample()` but with a different coordinate + convention. 
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Sample a tensor with bilinear interpolation, in pixel coordinates.

    Wraps :func:`torch.nn.functional.grid_sample` but takes unnormalized
    coordinates: with ``align_corners=True`` the valid x range is ``[0, W-1]``
    (pixel centers); with ``align_corners=False`` it is ``[0, W]`` (pixel
    edges). Same for y over H, and t over T for 5D input.

    For 4D input ``(B, C, H, W)``, ``coords`` is ``(B, H_o, W_o, 2)`` holding
    ``(x, y)`` points. For 5D input ``(B, C, T, H, W)``, sample points are
    ``(t, x, y)`` triplets — note this differs from grid_sample's own
    ``(x, y, t)`` order.

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.

    Returns:
        Tensor: sampled points.
    """
    # Work on a detached copy, matched to the input's device and dtype.
    coords = coords.detach().clone().to(input.device).to(input.dtype)

    sizes = input.shape[2:]
    assert len(sizes) in [2, 3]

    if len(sizes) == 3:
        # Reorder (t, x, y) -> (x, y, t) so it lines up with grid_sample's
        # expected axis order for 5D input.
        coords = coords[..., [1, 2, 0]]

    # Map pixel coordinates onto grid_sample's [-1, 1] range, x fastest.
    if align_corners:
        denoms = [max(size - 1, 1) for size in reversed(sizes)]
    else:
        denoms = [size for size in reversed(sizes)]
    scale = torch.tensor([2 / d for d in denoms], device=coords.device, dtype=coords.dtype)
    coords = coords * scale - 1

    return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)


def sample_features4d(input, coords):
    r"""Sample spatial features at point locations.

    Samples a ``(B, C, H, W)`` feature map at ``(B, R, 2)`` points given as
    ``(x, y)`` pixel coordinates, using :func:`bilinear_sampler` with
    ``align_corners=True``. One feature vector per point.

    Args:
        input (Tensor): spatial features, shape (B, C, H, W).
        coords (Tensor): points, shape (B, R, 2).

    Returns:
        Tensor: sampled features, shape (B, R, C).
    """
    B, _, _, _ = input.shape

    grid = coords.unsqueeze(2)  # (B, R, 2) -> (B, R, 1, 2)
    feats = bilinear_sampler(input, grid)  # (B, C, R, 1)
    # (B, C, R, 1) -> (B, R, C)
    return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3])
+ """ + + B, _, _, _ = input.shape + + # B R 2 -> B R 1 2 + coords = coords.unsqueeze(2) + + # B C R 1 + feats = bilinear_sampler(input, coords) + + return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C diff --git a/FastVGGT/vggt/heads/utils.py b/FastVGGT/vggt/heads/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..533fc8ae67a75cd0a94d5ca96dc5a0513446c64f --- /dev/null +++ b/FastVGGT/vggt/heads/utils.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + + +def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. 
+ + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + device = pos.device + omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. 
+ """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid diff --git a/FastVGGT/vggt/layers/__init__.py b/FastVGGT/vggt/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1 --- /dev/null +++ b/FastVGGT/vggt/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2840683edfd99fcda3d1d5c9e18d699dba029a86 Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42e56e451e90459b8b74d99bd0ac9049e702e1bd Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/attention.cpython-310.pyc differ diff --git a/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..069b3ad92a285342f200297a7358dbf39c0ef372 Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/block.cpython-310.pyc differ diff --git a/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0bc9cb8a27db08fc0921a2d8a54a60b2b126acd Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/drop_path.cpython-310.pyc differ diff --git a/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a416380412c8942981ab0e3f5369cebe05f9496c Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/layer_scale.cpython-310.pyc differ diff --git a/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..053bf054f2a74eefeea0536eeddd23daacd28cbe Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/mlp.cpython-310.pyc differ diff --git a/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20fa7427ef4c285abd1d9b48743f4c3360953591 Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/patch_embed.cpython-310.pyc differ diff --git a/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af9854cf6446fb30404f4529d63b20c20f8e63e8 Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/rope.cpython-310.pyc differ diff --git a/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9b1fc5deb392addff7c8620b095e0ba4d90a8de Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc differ diff --git a/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc b/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d36ad71ca8dacfb3991b5b2560bc89cdfdd15c52 Binary files /dev/null and b/FastVGGT/vggt/layers/__pycache__/vision_transformer.cpython-310.pyc differ diff --git a/FastVGGT/vggt/layers/attention.py b/FastVGGT/vggt/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..aef68a4e2c628ad257a6b77b46447c8421c5b11a --- /dev/null +++ b/FastVGGT/vggt/layers/attention.py @@ -0,0 +1,257 @@ +import os +from pathlib import Path +import time +from PIL import Image +import torch +from torch import Tensor +from torch import nn +import torch.nn.functional as 
F +from tqdm.std import tqdm +from merging.merge import ( + token_merge_bipartite2d, +) +import matplotlib.pyplot as plt +import numpy as np +from PIL import Image + +XFORMERS_AVAILABLE = False + +# Global variables for attention visualization +vis_attn_map = False +current_images = [] +attention_map = None + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + kv_group_size: int = 1, + fused_attn: bool = True, + rope=None, + global_merging=None, + patch_width: int = 37, + patch_height: int = 28, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.patch_width = patch_width + self.patch_height = patch_height + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + self.kv_group_size = kv_group_size + + def forward(self, x: Tensor, pos=None, global_merging=None) -> Tensor: + merge_num = list(range(24)) + + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.rope is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + if vis_attn_map and global_merging is not None: + # s1 Chunk computation of q@k + chunk_size = 4096 + attn_maps = [] + total_chunks = ( + q.size(-2) + chunk_size - 1 + ) // chunk_size # Calculate total number of chunks + 
start_time_total = time.time() + with torch.no_grad(): + for start in tqdm( + range(0, q.size(-2), chunk_size), + total=total_chunks, + desc="Processing chunks", + ): + end = min(start + chunk_size, q.size(-2)) + q_chunk = q[:, :, start:end, :] # (1, 16, chunk_size, 64) + start_time_chunk = time.time() + attn_chunk = q_chunk @ k.transpose( + -2, -1 + ) # (1, 16, chunk_size, 34353) + attn_maps.append(attn_chunk.cpu()) + end_time_chunk = time.time() + print( + f"Chunk {start}:{end} processed in {end_time_chunk - start_time_chunk:.4f} seconds" + ) + end_time_total = time.time() + print( + f"\nTotal processing time: {end_time_total - start_time_total:.4f} seconds" + ) + + attn_map = torch.cat(attn_maps, dim=-2) + attn = attn_map[0].mean(0) + frame_token_num = self.patch_height * self.patch_width + 5 + for target_token_idx in [ + 0, + self.patch_height * self.patch_width, + self.patch_height * self.patch_width * 10, + ]: # Iterate through each image's target_token + for image_idx in range( + len(current_images) + ): # Corresponding to which image to visualize + target_attn = attn[ + target_token_idx, + image_idx * frame_token_num : (image_idx + 1) * frame_token_num, + ] + target_attn_map = target_attn[5:].reshape( + self.patch_height, self.patch_width + ) + # 1) Read original image to get true size (H, W) + image_path = current_images[image_idx] + p = Path(image_path) + parts = p.parts + scene_name = parts[-4] + + image = Image.open(image_path).convert("RGB") + img_width, img_height = image.size # PIL size: (W, H) + + # Upsample attention map to the original image size + target_attn_map = F.interpolate( + target_attn_map.unsqueeze(0).unsqueeze(0), # (1,1,h,w) + size=(img_height, img_width), + mode="bilinear", + ) + target_attn_map = target_attn_map.squeeze() + + # Convert image to numpy for blending + img_np = np.array(image) / 255.0 + + # 2. 
Normalize attention map + target_attn_map = (target_attn_map - target_attn_map.min()) / ( + target_attn_map.max() - target_attn_map.min() + ) + + # 3. Color attention map + cmap = plt.get_cmap("jet") + attn_color = cmap(target_attn_map.cpu().float().numpy()) + attn_color = attn_color[:, :, :3] + + # 4. Blend attention and original image + overlay = img_np * 0.5 + attn_color * 0.5 + + plt.imshow(overlay, cmap="viridis") + output_dir = f"attention_map/{scene_name}/block_{attention_map}/token_{target_token_idx}" + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, f"color_{image_idx}.png") + plt.savefig(output_path) + plt.close() + + if global_merging is not None and global_merging in merge_num: + generator = torch.Generator(device=x.device) + generator.manual_seed(33) + + merge_ratio = 0.9 + r = int(x.shape[1] * merge_ratio) + + m, u = token_merge_bipartite2d( + x, + self.patch_width, + self.patch_height, + 2, + 2, + r, + False, + generator, + enable_protection=True, + ) + + m_a, u_a = (m, u) + + B_q, H_q, N_q, D_q = q.shape + + q_merge_in = q.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q) + k_merge_in = k.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q) + v_merge_in = v.permute(0, 2, 1, 3).reshape(B_q, N_q, H_q * D_q) + + q_out, k_out, v_out = m_a( + q_merge_in, + mode="mean", + extra_tensors=k_merge_in, + extra_tensors_2=v_merge_in, + ) + + del q_merge_in, k_merge_in, v_merge_in + + N_m = q_out.shape[1] + q = q_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3) + k = k_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3) + v = v_out.reshape(B_q, N_m, H_q, D_q).permute(0, 2, 1, 3) + + del q_out, k_out, v_out + + N = N_m + + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + del q, k, v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + if global_merging is not None and global_merging in merge_num: + x = u_a(x) + return x + + +class 
MemEffAttention(Attention): + def forward( + self, x: Tensor, attn_bias=None, pos=None, global_merging=None + ) -> Tensor: + assert ( + pos is None or self.rope is not None + ), "Position encoding is only supported with RoPE" + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x, pos=pos, global_merging=global_merging) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = qkv.unbind(2) + + if self.rope is not None and pos is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + # Use scaled dot-product attention + attn = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim**-0.5) + if attn_bias is not None: + attn = attn + attn_bias + attn = F.softmax(attn, dim=-1) + x = torch.matmul(attn, v) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/FastVGGT/vggt/layers/block.py b/FastVGGT/vggt/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa3c18aaf79fc755630d24c24c1f4c7f75624e1 --- /dev/null +++ b/FastVGGT/vggt/layers/block.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
class Block(nn.Module):
    """Standard pre-norm transformer block (DINOv2 style) with optional
    LayerScale and a `global_merging` pass-through to the attention layer.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,
        rope=None,
        merging=0,
    ) -> None:
        """
        Args:
            dim: token embedding dimension.
            num_heads: attention heads.
            mlp_ratio: hidden width multiplier for the FFN.
            qkv_bias / proj_bias / ffn_bias: bias flags for the sub-layers.
            drop / attn_drop: projection and attention dropout probabilities.
            init_values: LayerScale init; falsy disables LayerScale entirely.
            drop_path: stochastic-depth rate (stored; see note in forward).
            act_layer / norm_layer: activation and norm constructors.
            attn_class / ffn_layer: attention and FFN constructors.
            qk_norm / fused_attn / rope: forwarded to the attention layer.
            merging: NOTE(review) — accepted but never used in this class;
                confirm whether it was meant to be forwarded to the attention.
        """
        super().__init__()

        self.norm1 = norm_layer(dim)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, global_merging=None) -> Tensor:
        """x + ls1(attn(norm1(x))) then x + ls2(mlp(norm2(x))).

        NOTE(review): drop_path1/drop_path2 are constructed but not applied in
        this forward — stochastic depth is effectively disabled here; confirm
        whether that is intentional.  The `del`s release intermediates early to
        reduce peak memory with long token sequences.
        """
        norm1_output = self.norm1(x)
        attn_output = self.attn(norm1_output, pos=pos, global_merging=global_merging)
        del norm1_output
        x = x + self.ls1(attn_output)
        del attn_output

        norm2_output = self.norm2(x)
        mlp_output = self.mlp(norm2_output)
        del norm2_output
        x = x + self.ls2(mlp_output)
        del mlp_output
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
    pos=None,
) -> Tensor:
    """Batch-wise stochastic depth: run `residual_func` on a random subset of
    the batch and add the (rescaled) residual back into the full batch.
    """
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    if pos is not None:
        # if necessary, apply rope to the subset
        pos = pos[brange]
        residual = residual_func(x_subset, pos=pos)
    else:
        residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    # Rescale so the expected residual magnitude matches full-batch execution.
    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(
        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
    )
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick a random batch subset; return its indices and the rescale factor."""
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add a per-subset residual back into x at indices `brange`.

    NOTE(review): the `scaling_vector` branch calls `scaled_index_add`, which is
    an xFormers helper not imported in this file (XFORMERS_AVAILABLE is False)
    — it would NameError if reached; presumably only used on the xFormers path.
    """
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(
            x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
        )
    else:
        x_plus_residual = scaled_index_add(
            x,
            brange,
            residual.to(dtype=x.dtype),
            scaling=scaling_vector,
            alpha=residual_scale_factor,
        )
    return x_plus_residual


# Cache of block-diagonal attention masks keyed by the (batch, seqlen) shapes
# of the input list; avoids rebuilding the mask for repeated shapes.
attn_bias_cache: Dict[Tuple, Any] = {}
get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + 1, -1, x_list[0].shape[-1] + ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): 
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=( + self.ls1.gamma if isinstance(self.ls1, LayerScale) else None + ), + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=( + self.ls2.gamma if isinstance(self.ls1, LayerScale) else None + ), + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/FastVGGT/vggt/layers/drop_path.py b/FastVGGT/vggt/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/FastVGGT/vggt/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Apply stochastic depth to ``x`` (per-sample residual-branch dropping).

    Each sample in the batch is independently zeroed with probability
    ``drop_prob``; survivors are rescaled by 1/keep_prob so the expectation
    is unchanged.  Identity at eval time or when ``drop_prob`` is 0.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # Broadcast mask over all non-batch dims: works for any tensor rank.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        # drop_prob may be None (historic default); treated as 0.0 below.
        self.drop_prob = drop_prob

    def forward(self, x):
        # BUGFIX: a default-constructed DropPath (drop_prob=None) used to raise
        # TypeError ("1 - None") in training mode, because None == 0.0 is False.
        # Coerce None to 0.0 so the module is a no-op, as callers expect.
        return drop_path(x, self.drop_prob or 0.0, self.training)
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch.

    Multiplies the last dimension of the input by a learnable vector
    ``gamma`` initialized to ``init_values``.  With ``inplace=True`` the
    input tensor itself is scaled and returned.
    """

    def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            # Mutates and returns the caller's tensor.
            return x.mul_(self.gamma)
        return x * self.gamma
# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    """Transformer feed-forward: Linear -> act -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when
    not given.  A single Dropout module is shared by both projections (the
    rate is identical for both).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # Submodules are created in the same order as the original so that
        # seeded parameter initialization stays reproducible.
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    """Normalize an int or 2-tuple to an (h, w) pair."""
    if isinstance(x, tuple):
        assert len(x) == 2
        return x
    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If False, forward returns a (B, H', W', D) grid
            instead of flattened (B, N, D) tokens.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        grid = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Non-overlapping patchify: conv with kernel == stride == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H' W'
        grid_h, grid_w = x.size(2), x.size(3)
        x = self.norm(x.flatten(2).transpose(1, 2))  # B N C
        if self.flatten_embedding:
            return x
        return x.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H' W' C

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # NOTE: self.norm is never None here (falsy norm_layer yields
        # nn.Identity), so this term is always added — preserved as-is.
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
# Implementation of 2D Rotary Position Embeddings (RoPE).
#
# Extends the original RoPE concept to 2D spatial positions by applying
# axial rotations to the two halves of the feature dimension.
#
# Inspired by:
#   https://github.com/meta-llama/codellama/blob/main/llama/model.py
#   https://github.com/naver-ai/rope-vit


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple


class PositionGetter:
    """Generates and caches (y, x) patch coordinates for a 2D grid.

    Positions are cached per (height, width) so repeated calls with the same
    grid shape skip recomputing the cartesian product.

    NOTE(review): the cache key ignores ``device`` — a cached grid is returned
    on whatever device it was first built on; confirm callers use one device.
    """

    def __init__(self):
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(
        self, batch_size: int, height: int, width: int, device: torch.device
    ) -> torch.Tensor:
        """Return a (batch_size, height*width, 2) tensor of y,x coordinates."""
        key = (height, width)
        if key not in self.position_cache:
            rows = torch.arange(height, device=device)
            cols = torch.arange(width, device=device)
            self.position_cache[key] = torch.cartesian_prod(rows, cols)

        grid = self.position_cache[key]
        # Clone so callers may mutate their copy without corrupting the cache.
        return grid.view(1, height * width, 2).expand(batch_size, -1, -1).clone()


class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding.

    The feature dimension is split in half: the first half is rotated by
    angles derived from the y coordinate, the second half from x.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Kept for API compatibility; not used in the frequency
            computation below.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return cached (cos, sin) tables of shape (seq_len, dim)."""
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Inverse frequencies over even indices, as in standard RoPE.
            exponents = torch.arange(0, dim, 2, device=device) / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # angle[p, j] = p * inv_freq[j]; duplicated so the table covers
            # the full ``dim`` after the rotate-half trick.
            steps = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", steps, inv_freq).to(dtype)
            angles = torch.cat((angles, angles), dim=-1)

            self.frequency_cache[cache_key] = (
                angles.cos().to(dtype),
                angles.sin().to(dtype),
            )

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Rotate-half trick: (a, b) -> (-b, a) on the two halves of the last dim."""
        half = x.shape[-1] // 2
        front, back = x[..., :half], x[..., half:]
        return torch.cat((-back, front), dim=-1)

    def _apply_1d_rope(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        cos_comp: torch.Tensor,
        sin_comp: torch.Tensor,
    ) -> torch.Tensor:
        """Rotate ``tokens`` by the angles looked up at integer ``positions``."""
        if positions.dtype != torch.long:
            positions = positions.long()

        # Table lookup, then broadcast over the heads dimension.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Apply 2D rotary position embeddings.

        Args:
            tokens: (batch_size, n_heads, n_tokens, dim); dim divisible by 4.
            positions: (batch_size, n_tokens, 2) with y,x coordinates.

        Returns:
            Tensor of the same shape as ``tokens``.
        """
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert (
            positions.ndim == 3 and positions.shape[-1] == 2
        ), "Positions must have shape (batch_size, n_tokens, 2)"

        half_dim = tokens.size(-1) // 2
        table_len = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(
            half_dim, table_len, tokens.device, tokens.dtype
        )

        # First half follows y, second half follows x.
        y_feats, x_feats = tokens.chunk(2, dim=-1)
        y_feats = self._apply_1d_rope(y_feats, positions[..., 0], cos_comp, sin_comp)
        x_feats = self._apply_1d_rope(x_feats, positions[..., 1], cos_comp, sin_comp)

        return torch.cat((y_feats, x_feats), dim=-1)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
from typing import Callable, Optional
import warnings

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: w3(silu(x W1) * (x W2)).

    ``act_layer`` and ``drop`` are accepted for interface parity with Mlp but
    are not used by this implementation.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # w12 packs both the gate and value projections into one Linear.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# The xFormers fused SwiGLU is disabled in this fork; always fall back to the
# pure-PyTorch implementation above.
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False


class SwiGLUFFNFused(SwiGLU):
    """SwiGLUFFN whose hidden size is rescaled by 2/3 and padded to a multiple of 8.

    The 2/3 factor keeps the parameter count comparable to a plain MLP with
    the same nominal ``hidden_features``; the multiple-of-8 padding targets
    tensor-core-friendly shapes.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.nn.init import trunc_normal_
from . import (
    Mlp,
    PatchEmbed,
    SwiGLUFFNFused,
    MemEffAttention,
    NestedTensorBlock as Block,
)

logger = logging.getLogger("dinov2")


def named_apply(
    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over a module tree.

    ``depth_first`` controls whether children are visited before the node
    itself; ``include_root`` controls whether ``fn`` is called on the root.
    Returns the (mutated) root module.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=child_name,
            depth_first=depth_first,
            include_root=True,
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    """A chunk of transformer blocks run sequentially (used for FSDP wrapping)."""

    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        qk_norm=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.num_tokens = 1  # the single CLS token
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset
        self.use_reentrant = False  # hardcoded to False

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Positional embedding covers CLS + patch tokens (registers get none).
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
        )
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
            if num_register_tokens
            else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [
                x.item() for x in torch.linspace(0, drop_path_rate, depth)
            ]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            # Factory that ignores the usual FFN kwargs and returns a no-op.
            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append(
                    [nn.Identity()] * i + blocks_list[i : i + chunksize]
                )
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # Learned token substituted at masked patch positions (see
        # prepare_tokens_with_masks).
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialize pos/cls/register tokens and all Linear layers (timm scheme)."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Bicubically resample the patch positional embedding to the current grid.

        Assumes the pretrained pos_embed was laid out on a square M x M grid
        (asserted below).  Returns [CLS pos] + resampled patch positions.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
            previous_dtype
        )

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patchify, optionally replace masked patches by the mask token, then
        prepend CLS (+ registers) and add interpolated positional embeddings.
        """
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(
                masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
            )

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        # Registers are inserted AFTER pos-embed addition, so they receive
        # no positional embedding.
        if self.register_tokens is not None:
            x = torch.cat(
                (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """Run the backbone over a list of (image batch, mask) pairs via nested blocks."""
        x = [
            self.prepare_tokens_with_masks(x, masks)
            for x, masks in zip(x_list, masks_list)
        ]

        for blk in self.blocks:
            if self.training:
                # Gradient checkpointing trades compute for activation memory.
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Run the backbone; returns a dict of normalized CLS/register/patch tokens."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect outputs of selected blocks (flat, un-chunked block list)."""
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(
            blocks_to_take
        ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect outputs of selected blocks when blocks are grouped in chunks.

        Each chunk is front-padded with nn.Identity placeholders, so the
        global block index ``i`` stays aligned with the unchunked layout.
        """
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = (
            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        )
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(
            blocks_to_take
        ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return token maps from selected blocks, optionally normalized,
        reshaped to (B, D, H', W') grids, and paired with their CLS tokens.
        """
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Drop CLS + register tokens, keeping only patch tokens.
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=True, **kwargs):
        """Feature dict when ``is_training`` (default here), else head(CLS)."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-S/16 backbone (embed_dim 384, 12 blocks, 6 heads)."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-B/16 backbone (embed_dim 768, 12 blocks, 12 heads)."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-L/16 backbone (embed_dim 1024, 24 blocks, 16 heads)."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
b/FastVGGT/vggt/models/aggregator.py @@ -0,0 +1,492 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from numpy import block +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint +from typing import Optional, Tuple, Union, List, Dict, Any + +from vggt.layers import PatchEmbed +from vggt.layers.block import Block +from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter +from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 +import time + +logger = logging.getLogger(__name__) + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +class Aggregator(nn.Module): + """ + The Aggregator applies alternating-attention over input frames, + as described in VGGT: Visual Geometry Grounded Transformer. + + Remember to set model.train() to enable gradient checkpointing to reduce memory usage. + + Args: + img_size (int): Image size in pixels. + patch_size (int): Size of each patch for PatchEmbed. + embed_dim (int): Dimension of the token embeddings. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. + num_register_tokens (int): Number of register tokens. + block_fn (nn.Module): The block type used for attention (Block by default). + qkv_bias (bool): Whether to include bias in QKV projections. + proj_bias (bool): Whether to include bias in the output projection. + ffn_bias (bool): Whether to include bias in MLP layers. + patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg". + aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"]. + aa_block_size (int): How many blocks to group under each attention type before switching. 
If not necessary, set to 1. + qk_norm (bool): Whether to apply QK normalization. + rope_freq (int): Base frequency for rotary embedding. -1 to disable. + init_values (float): Init scale for layer scale. + """ + + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + num_register_tokens=4, + block_fn=Block, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + patch_embed="dinov2_vitl14_reg", + aa_order=["frame", "global"], + aa_block_size=1, + qk_norm=True, + rope_freq=100, + init_values=0.01, + global_merging=True, + merging=0, + vis_attn_map=False, + ): + super().__init__() + + self.__build_patch_embed__( + patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim + ) + + # Initialize rotary position embedding if frequency > 0 + self.rope = ( + RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + ) + self.position_getter = PositionGetter() if self.rope is not None else None + self.global_merging = global_merging + self.merging = merging + self.vis_attn_map = vis_attn_map + self.frame_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.global_blocks = nn.ModuleList( + [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + rope=self.rope, + ) + for _ in range(depth) + ] + ) + + self.depth = depth + self.aa_order = aa_order + self.patch_size = patch_size + self.aa_block_size = aa_block_size + + # Validate that depth is divisible by aa_block_size + if self.depth % self.aa_block_size != 0: + raise ValueError( + f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})" + ) + + self.aa_block_num = 
self.depth // self.aa_block_size + + # Note: We have two camera tokens, one for the first frame and one for the rest + # The same applies for register tokens + self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim)) + self.register_token = nn.Parameter( + torch.randn(1, 2, num_register_tokens, embed_dim) + ) + + # The patch tokens start after the camera and register tokens + self.patch_start_idx = 1 + num_register_tokens + + # Initialize parameters with small values + nn.init.normal_(self.camera_token, std=1e-6) + nn.init.normal_(self.register_token, std=1e-6) + + # Register normalization constants as buffers (use bf16-compatible tensor) + for name, value in ( + ("_resnet_mean", _RESNET_MEAN), + ("_resnet_std", _RESNET_STD), + ): + self.register_buffer( + name, + torch.tensor(value, dtype=torch.bfloat16).view(1, 1, 3, 1, 1), + persistent=False, + ) + + self.use_reentrant = False # hardcoded to False + + def __build_patch_embed__( + self, + patch_embed, + img_size, + patch_size, + num_register_tokens, + interpolate_antialias=True, + interpolate_offset=0.0, + block_chunks=0, + init_values=1.0, + embed_dim=1024, + ): + """ + Build the patch embed layer. If 'conv', we use a + simple PatchEmbed conv layer. Otherwise, we use a vision transformer. 
+ """ + + if "conv" in patch_embed: + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=3, + embed_dim=embed_dim, + ) + else: + vit_models = { + "dinov2_vitl14_reg": vit_large, + "dinov2_vitb14_reg": vit_base, + "dinov2_vits14_reg": vit_small, + "dinov2_vitg2_reg": vit_giant2, + } + + self.patch_embed = vit_models[patch_embed]( + img_size=img_size, + patch_size=patch_size, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + block_chunks=block_chunks, + init_values=init_values, + ) + + # Disable gradient updates for mask token + if hasattr(self.patch_embed, "mask_token"): + self.patch_embed.mask_token.requires_grad_(False) + + def forward(self, images: torch.Tensor) -> Tuple[List[torch.Tensor], int]: + """ + Args: + images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + + Returns: + (list[torch.Tensor], int): + The list of outputs from the attention blocks, + and the patch_start_idx indicating where patch tokens begin. 
+ """ + B, S, C_in, H, W = images.shape + + if C_in != 3: + raise ValueError(f"Expected 3 input channels, got {C_in}") + + # Normalize images and reshape for patch embed - ensure bf16 computation + images = images.to(torch.bfloat16) + images = (images - self._resnet_mean) / self._resnet_std + + images = images.view(B * S, C_in, H, W) + patch_tokens = self.patch_embed(images) + del images + + if isinstance(patch_tokens, dict): + patch_tokens = patch_tokens["x_norm_patchtokens"] + + patch_tokens = patch_tokens.to(torch.bfloat16) + + _, P, C = patch_tokens.shape + + # Expand camera and register tokens to match batch size and sequence length + camera_token = slice_expand_and_flatten( + self.camera_token.to(torch.bfloat16), B, S + ) + register_token = slice_expand_and_flatten( + self.register_token.to(torch.bfloat16), B, S + ) + + tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1) + tokens = tokens.to(torch.bfloat16) + del camera_token, register_token, patch_tokens + # Explicitly clean up image data since patch embedding is complete + if "images_normalized" in locals(): + del images_normalized + + pos = None + if self.rope is not None: + pos = self.position_getter( + B * S, H // self.patch_size, W // self.patch_size, device="cuda" + ) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos_original = pos + pos = pos + 1 + pos_special = torch.zeros( + B * S, self.patch_start_idx, 2, device="cuda", dtype=torch.long + ) + pos = torch.cat([pos_special, pos], dim=1) + # Clean up temporary variables + del pos_special, pos_original + + # update P because we added special tokens + _, P, C = tokens.shape + + frame_idx = 0 + global_idx = 0 + output_list = [] + block4DPT_idx = [4, 11, 17, 23] + global_merging = None + + # Set global variables for attention visualization + if self.vis_attn_map: + import vggt.layers.attention as attn_module + + # Set the 
global variables that attention.py needs + attn_module.vis_attn_map = True + attn_module.current_images = self._load_image_paths() # Load from temp file + else: + import vggt.layers.attention as attn_module + + attn_module.vis_attn_map = False + + for block_num in range(self.aa_block_num): + torch.cuda.synchronize() + torch.cuda.empty_cache() + + need_intermediates = True if block_num in block4DPT_idx else False + if block_num % 1 == 0: + # Clean up RoPE cache to prevent accumulation + if hasattr(self, "rope") and self.rope is not None: + if hasattr(self.rope, "frequency_cache"): + self.rope.frequency_cache.clear() + # Clean up position cache + if ( + hasattr(self, "position_getter") + and self.position_getter is not None + ): + if hasattr(self.position_getter, "position_cache"): + # Keep only current size cache, clean up others + current_cache = self.position_getter.position_cache.copy() + if ( + len(current_cache) > 1 + ): # If there are multiple cache entries + self.position_getter.position_cache.clear() + # Keep only the most recently used one + if current_cache: + key = list(current_cache.keys())[-1] + self.position_getter.position_cache[key] = ( + current_cache[key] + ) + # Avoid saving block_num to instance variable to reduce references + for attn_type in self.aa_order: + if attn_type == "frame": + tokens, frame_idx, frame_intermediates = ( + self._process_frame_attention( + tokens, + B, + S, + P, + C, + frame_idx, + pos=pos, + need_intermediates=need_intermediates, + ) + ) + elif attn_type == "global": + if self.merging is None: + global_merging = None + elif self.global_merging and block_num >= self.merging: + global_merging = block_num + # Set attention_map for visualization + if self.vis_attn_map: + import vggt.layers.attention as attn_module + + attn_module.attention_map = block_num + tokens, global_idx, global_intermediates = ( + self._process_global_attention( + tokens, + B, + S, + P, + C, + global_idx, + pos=pos, + global_merging=global_merging, + 
need_intermediates=need_intermediates, + ) + ) + else: + raise ValueError(f"Unknown attention type: {attn_type}") + + if block_num not in block4DPT_idx: + if "frame_intermediates" in locals(): + del frame_intermediates + if "global_intermediates" in locals(): + del global_intermediates + else: + concat_inter = torch.cat( + [frame_intermediates[0].detach(), global_intermediates[0].detach()], + dim=-1, + ) + if concat_inter.dtype != torch.bfloat16: + concat_inter = concat_inter.to(torch.bfloat16) + output_list.append(concat_inter) + del concat_inter, frame_intermediates, global_intermediates + + # Do final cleanup before returning + del tokens, pos + if "pos_special" in locals(): + del pos_special + if "pos_original" in locals(): + del pos_original + torch.cuda.empty_cache() # Final cleanup + + return output_list, self.patch_start_idx + + def _process_frame_attention( + self, tokens, B, S, P, C, frame_idx, pos=None, need_intermediates=False + ): + """ + Process frame attention blocks. We keep tokens in shape (B*S, P, C). + """ + # If needed, reshape tokens or positions: + if tokens.shape != (B * S, P, C): + tokens = tokens.view(B, S, P, C).view(B * S, P, C) + + if pos is not None and pos.shape != (B * S, P, 2): + pos = pos.view(B, S, P, 2).view(B * S, P, 2) + + intermediates = [] if need_intermediates else None + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + tokens = self.frame_blocks[frame_idx](tokens, pos=pos) + frame_idx += 1 + if need_intermediates: + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, frame_idx, intermediates + + def _process_global_attention( + self, + tokens, + B, + S, + P, + C, + global_idx, + pos=None, + global_merging=None, + need_intermediates=False, + ): + """ + Process global attention blocks. We keep tokens in shape (B, S*P, C). 
+ """ + if tokens.shape != (B, S * P, C): + tokens = tokens.view(B, S, P, C).view(B, S * P, C) + + if pos is not None and pos.shape != (B, S * P, 2): + pos = pos.view(B, S, P, 2).view(B, S * P, 2) + + intermediates = [] if need_intermediates else None + + # by default, self.aa_block_size=1, which processes one block at a time + for _ in range(self.aa_block_size): + tokens = self.global_blocks[global_idx]( + tokens, + pos=pos, + global_merging=global_merging, + ) + global_idx += 1 + if need_intermediates: + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, global_idx, intermediates + + def _load_image_paths(self): + """Load image paths from temporary file for visualization""" + try: + import os + import tempfile + import pickle + + temp_dir = tempfile.gettempdir() + image_paths_file = os.path.join(temp_dir, "vggt_image_paths.pkl") + + if os.path.exists(image_paths_file): + with open(image_paths_file, "rb") as f: + return pickle.load(f) + else: + return [] + except Exception as e: + print(f"Warning: Could not load image paths for visualization: {e}") + return [] + + +def slice_expand_and_flatten(token_tensor, B, S): + """ + Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing: + 1) Uses the first position (index=0) for the first frame only + 2) Uses the second position (index=1) for all remaining frames (S-1 frames) + 3) Expands both to match batch size B + 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token + followed by (S-1) second-position tokens + 5) Flattens to (B*S, X, C) for processing + + Returns: + torch.Tensor: Processed tokens with shape (B*S, X, C) + """ + + query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:]) + # Slice out the "other" tokens => shape (1, S-1, ...) + others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:]) + # Concatenate => shape (B, S, ...) 
+ combined = torch.cat([query, others], dim=1) + + # Finally flatten => shape (B*S, ...) + combined = combined.view(B * S, *combined.shape[2:]) + return combined diff --git a/FastVGGT/vggt/models/vggt.py b/FastVGGT/vggt/models/vggt.py new file mode 100644 index 0000000000000000000000000000000000000000..2efeb5271de9916d05fd231a0a5ec1694f999bd0 --- /dev/null +++ b/FastVGGT/vggt/models/vggt.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin # used for model hub + +from vggt.models.aggregator import Aggregator +from vggt.heads.camera_head import CameraHead +from vggt.heads.dpt_head import DPTHead +from vggt.heads.track_head import TrackHead + + +class VGGT(nn.Module, PyTorchModelHubMixin): + def __init__( + self, + img_size=518, + patch_size=14, + embed_dim=1024, + enable_camera=True, + enable_point=False, + enable_depth=True, + enable_track=False, + merging=0, + vis_attn_map=False, + ): + super().__init__() + + self.vis_attn_map = vis_attn_map + + self.aggregator = Aggregator( + img_size=img_size, + patch_size=patch_size, + embed_dim=embed_dim, + merging=merging, + vis_attn_map=vis_attn_map, + ) + + self.camera_head = CameraHead(dim_in=2 * embed_dim) if enable_camera else None + self.point_head = ( + DPTHead( + dim_in=2 * embed_dim, + output_dim=4, + activation="inv_log", + conf_activation="expp1", + ) + if enable_point + else None + ) + self.depth_head = ( + DPTHead( + dim_in=2 * embed_dim, + output_dim=2, + activation="exp", + conf_activation="expp1", + ) + if enable_depth + else None + ) + self.track_head = ( + TrackHead(dim_in=2 * embed_dim, patch_size=patch_size) + if enable_track + else None + ) + + def update_patch_dimensions(self, patch_width: int, patch_height: int): + """ + Update patch dimensions for 
all attention layers in the model + + Args: + patch_width: Patch width (typically 37) + patch_height: Patch height (typically 28) + """ + + def update_attention_in_module(module): + for name, child in module.named_children(): + # Recursively update submodules + update_attention_in_module(child) + # If it is an attention layer, update its patch dimensions + if hasattr(child, "patch_width") and hasattr(child, "patch_height"): + child.patch_width = patch_width + child.patch_height = patch_height + + # Update all attention layers in the aggregator + update_attention_in_module(self.aggregator) + + # print( + # f"🔧 Updated model attention layer patch dimensions: {patch_width}x{patch_height}" + # ) + + def forward( + self, + images: torch.Tensor, + query_points: torch.Tensor = None, + image_paths: list = None, + ): + """ + Forward pass of the VGGT model. + + Args: + images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]. + B: batch size, S: sequence length, 3: RGB channels, H: height, W: width + query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates. + Shape: [N, 2] or [B, N, 2], where N is the number of query points. + Default: None + image_paths (list, optional): List of image file paths for attention visualization. + Only used when vis_attn_map=True. 
Default: None + + Returns: + dict: A dictionary containing the following predictions: + - pose_enc (torch.Tensor): Camera pose encoding with shape [B, S, 9] (from the last iteration) + - depth (torch.Tensor): Predicted depth maps with shape [B, S, H, W, 1] + - depth_conf (torch.Tensor): Confidence scores for depth predictions with shape [B, S, H, W] + - world_points (torch.Tensor): 3D world coordinates for each pixel with shape [B, S, H, W, 3] + - world_points_conf (torch.Tensor): Confidence scores for world points with shape [B, S, H, W] + - images (torch.Tensor): Original input images, preserved for visualization + + If query_points is provided, also includes: + - track (torch.Tensor): Point tracks with shape [B, S, N, 2] (from the last iteration), in pixel coordinates + - vis (torch.Tensor): Visibility scores for tracked points with shape [B, S, N] + - conf (torch.Tensor): Confidence scores for tracked points with shape [B, S, N] + """ + # If without batch dimension, add it + if len(images.shape) == 4: + images = images.unsqueeze(0) + + if query_points is not None and len(query_points.shape) == 2: + query_points = query_points.unsqueeze(0) + + # Save image paths globally for attention visualization + if self.vis_attn_map and image_paths is not None: + import os + import tempfile + import pickle + + # Create a temporary file to store image paths + temp_dir = tempfile.gettempdir() + image_paths_file = os.path.join(temp_dir, "vggt_image_paths.pkl") + with open(image_paths_file, "wb") as f: + pickle.dump(image_paths, f) + + aggregated_tokens_list, patch_start_idx = self.aggregator(images) + + predictions = {} + + if self.camera_head is not None: + pose_enc_list = self.camera_head(aggregated_tokens_list) + predictions["pose_enc"] = pose_enc_list[ + -1 + ] # pose encoding of the last iteration + predictions["pose_enc_list"] = pose_enc_list + + if self.depth_head is not None: + depth, depth_conf = self.depth_head( + aggregated_tokens_list, + images=images, + 
patch_start_idx=patch_start_idx, + ) + predictions["depth"] = depth + predictions["depth_conf"] = depth_conf + + if self.point_head is not None: + pts3d, pts3d_conf = self.point_head( + aggregated_tokens_list, + images=images, + patch_start_idx=patch_start_idx, + ) + predictions["world_points"] = pts3d + predictions["world_points_conf"] = pts3d_conf + + if self.track_head is not None and query_points is not None: + track_list, vis, conf = self.track_head( + aggregated_tokens_list, + images=images, + patch_start_idx=patch_start_idx, + query_points=query_points, + ) + predictions["track"] = track_list[-1] # track of the last iteration + predictions["vis"] = vis + predictions["conf"] = conf + + if not self.training: + predictions["images"] = ( + images # store the images for visualization during inference + ) + + return predictions diff --git a/FastVGGT/vggt/utils/__init__.py b/FastVGGT/vggt/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a2958f93a114495631a3ce270b99ee8ba8443f1 --- /dev/null +++ b/FastVGGT/vggt/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
\ No newline at end of file diff --git a/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95269f2111ee13cbadfcbab9cd10605555f25aee Binary files /dev/null and b/FastVGGT/vggt/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3e08e180b2f41ce54c1912296117bf616196aae Binary files /dev/null and b/FastVGGT/vggt/utils/__pycache__/eval_utils.cpython-310.pyc differ diff --git a/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7a92463e6f8f34289986fcaf2ab515c9598cf51 Binary files /dev/null and b/FastVGGT/vggt/utils/__pycache__/geometry.cpython-310.pyc differ diff --git a/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37978e9be49721f9e88beeb814a38e3dcb8754ad Binary files /dev/null and b/FastVGGT/vggt/utils/__pycache__/pose_enc.cpython-310.pyc differ diff --git a/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc b/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93fcc5eddf5936461e8cbf8ca031e150105da24e Binary files /dev/null and b/FastVGGT/vggt/utils/__pycache__/rotation.cpython-310.pyc differ diff --git a/FastVGGT/vggt/utils/eval_utils.py b/FastVGGT/vggt/utils/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a81e2cab2b36c97e35b5ebb110b6e66fc814efa7 --- /dev/null +++ b/FastVGGT/vggt/utils/eval_utils.py @@ -0,0 +1,782 @@ +from collections import deque +from PIL 
import Image +import cv2 +import io +import random +import time +from pathlib import Path +from typing import List, Tuple, Dict, Optional +from copy import deepcopy +import matplotlib.pyplot as plt +import evo.tools.plot as plot +import evo.main_ape as main_ape +import evo.main_rpe as main_rpe +import matplotlib.pyplot as plt +import numpy as np +import torch +from evo.core.metrics import PoseRelation, Unit +from evo.core.trajectory import PoseTrajectory3D +from scipy.linalg import svd +import open3d as o3d # for point cloud processing and Chamfer Distance computation + +from scipy.spatial.transform import Rotation +from torchvision import transforms as TF + + +from vggt.utils.geometry import unproject_depth_map_to_point_map +from vggt.utils.pose_enc import pose_encoding_to_extri_intri + + +def shuffle_deque(dq, seed=None): + # Set the random seed for reproducibility + if seed is not None: + random.seed(seed) + + # Convert deque to list, shuffle, and convert back + shuffled_list = list(dq) + random.shuffle(shuffled_list) + return deque(shuffled_list) + + +def imread_cv2(path, options=cv2.IMREAD_COLOR): + """Open an image or a depthmap with opencv-python.""" + if path.endswith((".exr", "EXR")): + options = cv2.IMREAD_ANYDEPTH + img = cv2.imread(path, options) + if img is None: + raise IOError(f"Could not load image={path} with {options=}") + if img.ndim == 3: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +# Import umeyama_alignment for internal use in eval_trajectory +def umeyama_alignment(src, dst, estimate_scale=True): + # Ensure inputs have correct shape + assert ( + src.shape == dst.shape + ), f"Input shapes don't match: src {src.shape}, dst {dst.shape}" + assert src.shape[0] == 3, f"Expected point cloud dimension (3,N), got {src.shape}" + + # Compute centroids + src_mean = src.mean(axis=1, keepdims=True) + dst_mean = dst.mean(axis=1, keepdims=True) + + # Center the point clouds + src_centered = src - src_mean + dst_centered = dst - dst_mean + + 
# Compute covariance matrix + cov = dst_centered @ src_centered.T + + try: + # Singular Value Decomposition + U, D, Vt = svd(cov) + V = Vt.T + + # Handle reflection case + det_UV = np.linalg.det(U @ V.T) + S = np.eye(3) + if det_UV < 0: + S[2, 2] = -1 + + # Compute rotation matrix + R = U @ S @ V.T + + if estimate_scale: + # Compute scale factor - fix dimension issue + src_var = np.sum(src_centered * src_centered) + if src_var < 1e-10: + print( + "Warning: Source point cloud variance close to zero, setting scale factor to 1.0" + ) + scale = 1.0 + else: + # Fix potential dimension issue with np.diag(S) + # Use diagonal elements directly + scale = np.sum(D * np.diag(S)) / src_var + else: + scale = 1.0 + + # Compute translation vector + t = dst_mean.ravel() - scale * (R @ src_mean).ravel() + + return scale, R, t + + except Exception as e: + print(f"Error in umeyama_alignment computation: {e}") + print( + "Returning default transformation: scale=1.0, rotation=identity matrix, translation=centroid difference" + ) + # Return default transformation + scale = 1.0 + R = np.eye(3) + t = (dst_mean - src_mean).ravel() + return scale, R, t + + +def compute_chamfer_distance(points_pred, points_gt, max_dist=1.0): + # Ensure point cloud size is not too large, which would cause slow computation + MAX_POINTS = 100000 + if points_pred.shape[0] > MAX_POINTS: + np.random.seed(33) # Fix random seed + indices = np.random.choice(points_pred.shape[0], MAX_POINTS, replace=False) + points_pred = points_pred[indices] + + if points_gt.shape[0] > MAX_POINTS: + np.random.seed(33) # Fix random seed + indices = np.random.choice(points_gt.shape[0], MAX_POINTS, replace=False) + points_gt = points_gt[indices] + + # Convert numpy point clouds to open3d point cloud objects + pcd_pred = o3d.geometry.PointCloud() + pcd_gt = o3d.geometry.PointCloud() + pcd_pred.points = o3d.utility.Vector3dVector(points_pred) + pcd_gt.points = o3d.utility.Vector3dVector(points_gt) + + # Downsample point clouds to 
accelerate computation + voxel_size = 0.05 # 5cm voxel size + pcd_pred = pcd_pred.voxel_down_sample(voxel_size) + pcd_gt = pcd_gt.voxel_down_sample(voxel_size) + + # Compute distances from predicted point cloud to GT point cloud + distances1 = np.asarray(pcd_pred.compute_point_cloud_distance(pcd_gt)) + # Compute distances from GT point cloud to predicted point cloud + distances2 = np.asarray(pcd_gt.compute_point_cloud_distance(pcd_pred)) + + # Apply distance clipping + distances1 = np.clip(distances1, 0, max_dist) + distances2 = np.clip(distances2, 0, max_dist) + + # Chamfer Distance is the sum of mean distances in both directions + chamfer_dist = np.mean(distances1) + np.mean(distances2) + + return chamfer_dist + + +def load_gt_pointcloud(scene_id, gt_ply_dir): + scene_dir = gt_ply_dir / scene_id + ply_path = scene_dir / (scene_id + "_vh_clean_2.ply") + pcd = o3d.io.read_point_cloud(str(ply_path)) + + # Convert to numpy arrays + points = np.asarray(pcd.points) + colors = None + try: + if pcd.has_colors(): + colors = np.asarray(pcd.colors) + except Exception: + colors = None + + return points, colors + + +def eval_trajectory(poses_est, poses_gt, frame_ids, align=False): + # Build reference trajectory object + traj_ref = PoseTrajectory3D( + positions_xyz=poses_gt[:, :3, 3], # Extract translation part + orientations_quat_wxyz=Rotation.from_matrix(poses_gt[:, :3, :3]).as_quat( + scalar_first=True + ), # Convert rotation matrix to quaternion + timestamps=frame_ids, + ) + + # Build estimated trajectory object + traj_est = PoseTrajectory3D( + positions_xyz=poses_est[:, :3, 3], + orientations_quat_wxyz=Rotation.from_matrix(poses_est[:, :3, :3]).as_quat( + scalar_first=True + ), + timestamps=frame_ids, + ) + + # Calculate Absolute Trajectory Error (ATE) + ate_result = main_ape.ape( + deepcopy(traj_ref), + deepcopy(traj_est), + est_name="traj", + pose_relation=PoseRelation.translation_part, + align=align, + correct_scale=align, + align_origin=True, + ) + ate = 
ate_result.stats["rmse"] + + # Get alignment transformation matrix + transform = np.eye(4) + if align: + try: + # Use umeyama algorithm to compute optimal rigid transformation (including rotation, translation and scaling) + aligned_xyz = ate_result.trajectories["traj"].positions_xyz + original_xyz = traj_est.positions_xyz + + # At least 3 points needed to compute reliable transformation + if len(aligned_xyz) >= 3 and len(original_xyz) >= 3: + # Ensure point count matches + min_points = min(len(aligned_xyz), len(original_xyz)) + aligned_xyz = aligned_xyz[:min_points] + original_xyz = original_xyz[:min_points] + + # Compute transformation matrix (scaling, rotation and translation) + try: + s, R, t = umeyama_alignment( + original_xyz.T, # Source point cloud (3xN) + aligned_xyz.T, # Target point cloud (3xN) + True, # Whether to estimate scaling + ) + + # Build complete transformation matrix + transform = np.eye(4) + transform[:3, :3] = s * R # Scaling and rotation + transform[:3, 3] = t # Translation + + except Exception as e: + print(f"umeyama_alignment failed: {e}") + else: + print( + "Insufficient points, cannot reliably compute transformation matrix" + ) + except Exception as e: + print(f"Error computing transformation matrix: {e}") + + # If the above method fails, fallback to simple translation transformation + if np.array_equal(transform, np.eye(4)) and hasattr(ate_result, "trajectories"): + try: + # Get original and aligned first position + orig_pos = traj_est.positions_xyz[0] + aligned_pos = ate_result.trajectories["traj"].positions_xyz[0] + + # Calculate translation part + translation = aligned_pos - orig_pos + + # Update translation part of transformation matrix + transform[:3, 3] = translation + print(f"Fallback to simple translation transformation: {transform}") + except Exception as e: + print(f"Error building translation transformation: {e}") + print("Will use identity matrix") + + # Calculate Absolute Rotation Error (ARE) + are_result = main_ape.ape( + 
deepcopy(traj_ref), + deepcopy(traj_est), + est_name="traj", + pose_relation=PoseRelation.rotation_angle_deg, + align=align, + correct_scale=align, + align_origin=True, + ) + are = are_result.stats["rmse"] + + # Calculate Relative Pose Error (RPE) - rotation part + rpe_rots_result = main_rpe.rpe( + deepcopy(traj_ref), + deepcopy(traj_est), + est_name="traj", + pose_relation=PoseRelation.rotation_angle_deg, + align=align, + correct_scale=align, + delta=1, + delta_unit=Unit.frames, + rel_delta_tol=0.01, + all_pairs=True, + align_origin=True, + ) + rpe_rot = rpe_rots_result.stats["rmse"] + + # Calculate Relative Pose Error (RPE) - translation part + rpe_transs_result = main_rpe.rpe( + deepcopy(traj_ref), + deepcopy(traj_est), + est_name="traj", + pose_relation=PoseRelation.translation_part, + align=align, + correct_scale=align, + delta=1, + delta_unit=Unit.frames, + rel_delta_tol=0.01, + all_pairs=True, + align_origin=True, + ) + rpe_trans = rpe_transs_result.stats["rmse"] + + # Plot trajectory graph + plot_mode = plot.PlotMode.xz # Use correct PlotMode reference + fig = plt.figure() + ax = plot.prepare_axis(fig, plot_mode) + ax.set_title(f"ATE: {round(ate, 3)}, ARE: {round(are, 3)}") + + # Use reference trajectory (GT) for plotting + plot.traj(ax, plot_mode, traj_ref, "--", "gray", "gt") + + # Use aligned trajectory for visualization + if align: + traj_est_aligned = ate_result.trajectories["traj"] + plot.traj_colormap( + ax, + traj_est_aligned, + ate_result.np_arrays["error_array"], + plot_mode, + min_map=ate_result.stats["min"], + max_map=ate_result.stats["max"], + ) + else: + plot.traj_colormap( + ax, + traj_est, + ate_result.np_arrays["error_array"], + plot_mode, + min_map=ate_result.stats["min"], + max_map=ate_result.stats["max"], + ) + + ax.legend() + + # Save image to memory buffer + buffer = io.BytesIO() + plt.savefig(buffer, format="png", dpi=90) + buffer.seek(0) + + pillow_image = Image.open(buffer) + pillow_image.load() + buffer.close() + plt.close(fig) + + 
return ( + {"ate": ate, "are": are, "rpe_rot": rpe_rot, "rpe_trans": rpe_trans}, + pillow_image, + transform, + ) + + +def load_poses(path): + # Read all txt files from pose directory + pose_files = sorted( + path.glob("*.txt"), key=lambda x: int(x.stem) + ) # Sort by numerical order + + # Check if pose files exist + if len(pose_files) == 0: + print(f"Warning: No pose files (.txt) found in directory {path}") + return None, None, None + + c2ws = [] + available_frame_ids = [] + + for pose_file in pose_files: + try: + with open(pose_file, "r") as f: + # Each file contains 16 numbers representing a 4x4 transformation matrix + nums = [float(x) for x in f.read().strip().split()] + pose = np.array(nums).reshape(4, 4) + # Check if pose is valid (no infinite or NaN values) + if not (np.isinf(pose).any() or np.isnan(pose).any()): + c2ws.append(pose) + available_frame_ids.append(int(pose_file.stem)) + else: + continue + except Exception as e: + print(f"Error reading pose file {pose_file}: {e}") + continue + + if len(c2ws) == 0: + print(f"Warning: No valid pose files found in directory {path}") + return None, None, None + + c2ws = np.stack(c2ws) + available_frame_ids = np.array(available_frame_ids) + + # Transform all poses to first frame coordinate system + first_gt_pose = c2ws[0].copy() # Save original pose of first frame + c2ws = np.linalg.inv(c2ws[0]) @ c2ws + return c2ws, first_gt_pose, available_frame_ids + + +def get_vgg_input_imgs(images: np.ndarray): + to_tensor = TF.ToTensor() + vgg_input_images = [] + final_width = None + final_height = None + + for image in images: + img = Image.fromarray(image, mode="RGB") + width, height = img.size + # Resize image, maintain aspect ratio, ensure height is multiple of 14 + new_width = 518 + new_height = round(height * (new_width / width) / 14) * 14 + img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) + img = to_tensor(img) # Convert to tensor (0, 1) + + # If height exceeds 518, perform center cropping + if 
new_height > 518: + start_y = (new_height - 518) // 2 + img = img[:, start_y : start_y + 518, :] + final_height = 518 + else: + final_height = new_height + + final_width = new_width + vgg_input_images.append(img) + + vgg_input_images = torch.stack(vgg_input_images) + + # Calculate the patch dimensions (divided by 14 for patch size) + patch_width = final_width // 14 # 518 // 14 = 37 + patch_height = final_height // 14 # computed dynamically, typically 28 + + return vgg_input_images, patch_width, patch_height + + +def get_sorted_image_paths(images_dir): + image_paths = [] + for ext in ["*.jpg", "*.png", "*.jpeg"]: + image_paths.extend(sorted(images_dir.glob(ext))) + # image_paths.sort(key=lambda x: int(x.stem)) + return image_paths + + +def to_homogeneous(extrinsics): + n = extrinsics.shape[0] + homogeneous_extrinsics = np.eye(4)[None, :, :].repeat( + n, axis=0 + ) # Create identity matrix + homogeneous_extrinsics[:, :3, :4] = extrinsics # Copy [R|t] part + return homogeneous_extrinsics + + +def umeyama_alignment(src, dst, estimate_scale=True): + # Ensure inputs have correct shape + assert ( + src.shape == dst.shape + ), f"Input shapes don't match: src {src.shape}, dst {dst.shape}" + assert src.shape[0] == 3, f"Expected point cloud dimension (3,N), got {src.shape}" + + # Compute centroids + src_mean = src.mean(axis=1, keepdims=True) + dst_mean = dst.mean(axis=1, keepdims=True) + + # Center the point clouds + src_centered = src - src_mean + dst_centered = dst - dst_mean + + # Compute covariance matrix + cov = dst_centered @ src_centered.T + + try: + # Singular Value Decomposition + U, D, Vt = svd(cov) + V = Vt.T + + # Handle reflection case + det_UV = np.linalg.det(U @ V.T) + S = np.eye(3) + if det_UV < 0: + S[2, 2] = -1 + + # Compute rotation matrix + R = U @ S @ V.T + + if estimate_scale: + # Compute scale factor - fix dimension issue + src_var = np.sum(src_centered * src_centered) + if src_var < 1e-10: + print( + "Warning: Source point cloud variance close to zero, 
setting scale factor to 1.0" + ) + scale = 1.0 + else: + # Fix potential dimension issue with np.diag(S) + # Use diagonal elements directly + scale = np.sum(D * np.diag(S)) / src_var + else: + scale = 1.0 + + # Compute translation vector + t = dst_mean.ravel() - scale * (R @ src_mean).ravel() + + return scale, R, t + + except Exception as e: + print(f"Error in umeyama_alignment computation: {e}") + print( + "Returning default transformation: scale=1.0, rotation=identity matrix, translation=centroid difference" + ) + # Return default transformation + scale = 1.0 + R = np.eye(3) + t = (dst_mean - src_mean).ravel() + return scale, R, t + + +def align_point_clouds_scale(source_pc, target_pc): + # Compute bounding box sizes of point clouds + source_min = np.min(source_pc, axis=0) + source_max = np.max(source_pc, axis=0) + target_min = np.min(target_pc, axis=0) + target_max = np.max(target_pc, axis=0) + + source_size = source_max - source_min + target_size = target_max - target_min + + # Compute point cloud centers + source_center = (source_max + source_min) / 2 + target_center = (target_max + target_min) / 2 + + # Compute overall scale factor (using diagonal length) + source_diag = np.sqrt(np.sum(source_size**2)) + target_diag = np.sqrt(np.sum(target_size**2)) + + if source_diag < 1e-8: + print("Warning: Source point cloud size close to zero") + scale_factor = 1.0 + else: + scale_factor = target_diag / source_diag + + # Apply scaling (with source point cloud center as reference) + centered_source = source_pc - source_center + scaled_centered = centered_source * scale_factor + scaled_aligned_source = scaled_centered + target_center + + return scaled_aligned_source, scale_factor + + +def get_all_scenes(data_dir: Path, num_scenes: int) -> List[str]: + all_scenes = sorted([d.name for d in data_dir.iterdir() if d.is_dir()]) + if len(all_scenes) > num_scenes: + sample_interval = max(1, len(all_scenes) // num_scenes) + return all_scenes[::sample_interval][:num_scenes] + return 
all_scenes + + +def build_frame_selection( + image_paths: List[Path], + available_pose_frame_ids: np.ndarray, + input_frame: int, +) -> Tuple[List[int], List[Path], List[int]]: + all_image_frame_ids = [int(path.stem) for path in image_paths] + valid_frame_ids = sorted( + list(set(all_image_frame_ids) & set(available_pose_frame_ids)) + ) + if len(valid_frame_ids) > input_frame: + first_frame = valid_frame_ids[0] + remaining_frames = valid_frame_ids[1:] + step = max(1, len(remaining_frames) // (input_frame - 1)) + selected_remaining = remaining_frames[::step][: input_frame - 1] + selected_frame_ids = [first_frame] + selected_remaining + else: + selected_frame_ids = valid_frame_ids + + frame_id_to_path = {int(path.stem): path for path in image_paths} + selected_image_paths = [ + frame_id_to_path[fid] for fid in selected_frame_ids if fid in frame_id_to_path + ] + + pose_frame_to_idx = {fid: idx for idx, fid in enumerate(available_pose_frame_ids)} + selected_pose_indices = [ + pose_frame_to_idx[fid] for fid in selected_frame_ids if fid in pose_frame_to_idx + ] + + return selected_frame_ids, selected_image_paths, selected_pose_indices + + +def load_images_rgb(image_paths: List[Path]) -> List[np.ndarray]: + images: List[np.ndarray] = [] + for image_path in image_paths: + img = cv2.imread(str(image_path)) + if img is None: + continue + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + images.append(img) + return images + + +@torch.no_grad() +def infer_vggt_and_reconstruct( + model: torch.nn.Module, + vgg_input: torch.Tensor, + dtype: torch.dtype, + depth_conf_thresh: float, + image_paths: list = None, +) -> Tuple[ + np.ndarray, + np.ndarray, + List[np.ndarray], + List[np.ndarray], + List[np.ndarray], + float, +]: + torch.cuda.synchronize() + start = time.time() + with torch.cuda.amp.autocast(dtype=dtype): + vgg_input_cuda = vgg_input.cuda().to(torch.bfloat16) + predictions = model(vgg_input_cuda, image_paths=image_paths) + torch.cuda.synchronize() + end = time.time() + 
inference_time_ms = (end - start) * 1000.0 + + extrinsic, intrinsic = pose_encoding_to_extri_intri( + predictions["pose_enc"], (vgg_input.shape[2], vgg_input.shape[3]) + ) + + depth_tensor = predictions["depth"] + depth_conf = predictions["depth_conf"] + depth_conf_np = depth_conf[0].detach().float().cpu().numpy() + depth_mask = depth_conf_np >= depth_conf_thresh + depth_filtered = depth_tensor[0].detach().float().cpu().numpy() + depth_filtered[~depth_mask] = np.nan + depth_np = depth_filtered + + extrinsic_np = extrinsic[0].detach().float().cpu().numpy() + intrinsic_np = intrinsic[0].detach().float().cpu().numpy() + + world_points = unproject_depth_map_to_point_map( + depth_np, extrinsic_np, intrinsic_np + ) + all_points: List[np.ndarray] = [] + all_colors: List[np.ndarray] = [] + + # Prepare RGB images aligned with vgg_input for coloring point clouds (0-255, uint8) + vgg_np = vgg_input.detach().float().cpu().numpy() # [S, 3, H, W] in [0,1] + + for frame_idx in range(world_points.shape[0]): + points = world_points[frame_idx].reshape(-1, 3) + valid_mask = ~np.isnan(points).any(axis=1) & ~np.isinf(points).any(axis=1) + valid_points = points[valid_mask] + if len(valid_points) > 0: + all_points.append(valid_points) + + # Generate corresponding colors + img_chw = vgg_np[frame_idx] # [3, H, W] + img_hwc = ( + (np.transpose(img_chw, (1, 2, 0)) * 255.0).clip(0, 255).astype(np.uint8) + ) # [H, W, 3] uint8 + rgb_flat = img_hwc.reshape(-1, 3) + valid_colors = rgb_flat[valid_mask] + all_colors.append(valid_colors) + + camera_poses = to_homogeneous(extrinsic_np) + all_cam_to_world_mat = list(camera_poses) + + return ( + extrinsic_np, + intrinsic_np, + all_points, + all_colors, + all_cam_to_world_mat, + inference_time_ms, + ) + + +def evaluate_scene_and_save( + scene: str, + c2ws: np.ndarray, + first_gt_pose: np.ndarray, + frame_ids: List[int], + all_cam_to_world_mat: List[np.ndarray], + all_world_points: List[np.ndarray], + output_scene_dir: Path, + gt_ply_dir: Path, + 
chamfer_max_dist: float, + inference_time_ms: float, + plot_flag: bool, +) -> Optional[Dict[str, float]]: + if not all_cam_to_world_mat or not all_world_points: + print(f"Skipping {scene}: failed to obtain valid camera poses or point clouds") + return None + + output_scene_dir.mkdir(parents=True, exist_ok=True) + + poses_gt = c2ws + w2cs = np.linalg.inv(poses_gt) + traj_est_poses = np.array(all_cam_to_world_mat) + n = min(len(traj_est_poses), len(w2cs)) + timestamps = frame_ids[:n] + stats_aligned, traj_plot, _ = eval_trajectory( + traj_est_poses[:n], w2cs[:n], timestamps, align=True + ) + + try: + merged_point_cloud = np.vstack(all_world_points) + gt_point_cloud, _ = load_gt_pointcloud(scene, gt_ply_dir) + if gt_point_cloud is not None: + homogeneous_points = np.hstack( + [merged_point_cloud, np.ones((merged_point_cloud.shape[0], 1))] + ) + world_points_raw = np.dot(homogeneous_points, first_gt_pose.T)[:, :3] + + world_points_scaled, scale_factor = align_point_clouds_scale( + world_points_raw, gt_point_cloud + ) + + cd_value = compute_chamfer_distance( + world_points_scaled, gt_point_cloud, max_dist=chamfer_max_dist + ) + stats_aligned["chamfer_distance"] = float(cd_value) + stats_aligned["scale_factor"] = float(scale_factor) + except Exception as e: + print(f"Error computing Chamfer Distance for {scene}: {e}") + + all_metrics: Dict[str, float] = deepcopy(stats_aligned) + for metric_name, metric_value in list(stats_aligned.items()): + all_metrics[f"aligned_{metric_name}"] = metric_value + all_metrics["inference_time_ms"] = float(inference_time_ms) + + with open(output_scene_dir / "metrics.json", "w") as f: + import json + + json.dump(all_metrics, f, indent=4) + if plot_flag: + try: + traj_plot.save(output_scene_dir / "plot.png") + except Exception: + pass + + return all_metrics + + +def compute_average_metrics_and_save( + all_scenes_metrics: Dict[str, Dict[str, Dict[str, float]]], + output_path: Path, + input_frame: int, +) -> Dict[str, float]: + metric_names = [ 
+ "chamfer_distance", + "ate", + "are", + "rpe_rot", + "rpe_trans", + "inference_time_ms", + ] + average_metrics_list: Dict[str, List[float]] = { + metric: [] for metric in metric_names + } + for _, metrics in all_scenes_metrics["scenes"].items(): + for metric_name, metric_value in metrics.items(): + if metric_name in average_metrics_list: + average_metrics_list[metric_name].append(float(metric_value)) + + average_metrics: Dict[str, float] = {} + for metric_name, values in average_metrics_list.items(): + average_metrics[metric_name] = float(np.mean(values)) if values else 0.0 + + all_scenes_metrics["average"] = average_metrics + output_path.mkdir(parents=True, exist_ok=True) + + input_frame_dir = output_path / f"input_frame_{input_frame}" + input_frame_dir.mkdir(parents=True, exist_ok=True) + + with open(input_frame_dir / "all_scenes_metrics.json", "w") as f: + import json + + json.dump(all_scenes_metrics, f, indent=4) + + with open(input_frame_dir / "average_metrics.json", "w") as f: + import json + + json.dump(average_metrics, f, indent=4) + + print("\nAverage metrics:") + for metric_name, value in average_metrics.items(): + print(f"{metric_name}: {value:.6f}") + + return average_metrics diff --git a/FastVGGT/vggt/utils/geometry.py b/FastVGGT/vggt/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..f555516dbc8a7dd8c7b15e6fbc928a5bfe8f740b --- /dev/null +++ b/FastVGGT/vggt/utils/geometry.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import torch +import numpy as np + + +from vggt.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion + + +def unproject_depth_map_to_point_map( + depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray +) -> np.ndarray: + """ + Unproject a batch of depth maps to 3D world coordinates. + + Args: + depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) + extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) + intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) + + Returns: + np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) + """ + if isinstance(depth_map, torch.Tensor): + depth_map = depth_map.cpu().numpy() + if isinstance(extrinsics_cam, torch.Tensor): + extrinsics_cam = extrinsics_cam.cpu().numpy() + if isinstance(intrinsics_cam, torch.Tensor): + intrinsics_cam = intrinsics_cam.cpu().numpy() + + world_points_list = [] + for frame_idx in range(depth_map.shape[0]): + cur_world_points, _, _ = depth_to_world_coords_points( + depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] + ) + world_points_list.append(cur_world_points) + world_points_array = np.stack(world_points_list, axis=0) + + return world_points_array + + +def depth_to_world_coords_points( + depth_map: np.ndarray, + extrinsic: np.ndarray, + intrinsic: np.ndarray, + eps=1e-8, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Convert a depth map to world coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. + + Returns: + tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). 
+ """ + if depth_map is None: + return None, None, None + + # Valid depth mask + point_mask = depth_map > eps + + # Convert depth map to camera coordinates + cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) + + # Multiply with the inverse of extrinsic matrix to transform to world coordinates + # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) + cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] + + R_cam_to_world = cam_to_world_extrinsic[:3, :3] + t_cam_to_world = cam_to_world_extrinsic[:3, 3] + + # Apply the rotation and translation to the camera coordinates + world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 + # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world + + return world_coords_points, cam_coords_points, point_mask + + +def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Convert a depth map to camera coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). 
+ + Returns: + tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) + """ + H, W = depth_map.shape + assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" + assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" + + # Intrinsic parameters + fu, fv = intrinsic[0, 0], intrinsic[1, 1] + cu, cv = intrinsic[0, 2], intrinsic[1, 2] + + # Generate grid of pixel coordinates + u, v = np.meshgrid(np.arange(W), np.arange(H)) + + # Unproject to camera coordinates + x_cam = (u - cu) * depth_map / fu + y_cam = (v - cv) * depth_map / fv + z_cam = depth_map + + # Stack to form camera coordinates + cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + return cam_coords + + +def closed_form_inverse_se3(se3, R=None, T=None): + """ + Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. + + If `R` and `T` are provided, they must correspond to the rotation and translation + components of `se3`. Otherwise, they will be extracted from `se3`. + + Args: + se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. + R (optional): Nx3x3 array or tensor of rotation matrices. + T (optional): Nx3x1 array or tensor of translation vectors. + + Returns: + Inverted SE3 matrices with the same type and device as `se3`. 
+ + Shapes: + se3: (N, 4, 4) + R: (N, 3, 3) + T: (N, 3, 1) + """ + # Check if se3 is a numpy array or a torch tensor + is_numpy = isinstance(se3, np.ndarray) + + # Validate shapes + if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4): + raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.") + + # Extract R and T if not provided + if R is None: + R = se3[:, :3, :3] # (N,3,3) + if T is None: + T = se3[:, :3, 3:] # (N,3,1) + + # Transpose R + if is_numpy: + # Compute the transpose of the rotation for NumPy + R_transposed = np.transpose(R, (0, 2, 1)) + # -R^T t for NumPy + top_right = -np.matmul(R_transposed, T) + inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1)) + else: + R_transposed = R.transpose(1, 2) # (N,3,3) + top_right = -torch.bmm(R_transposed, T) # (N,3,1) + inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1) + inverted_matrix = inverted_matrix.to(R.dtype).to(R.device) + + inverted_matrix[:, :3, :3] = R_transposed + inverted_matrix[:, :3, 3:] = top_right + + return inverted_matrix + + +# TODO: this code can be further cleaned up + + +def project_world_points_to_camera_points_batch(world_points, cam_extrinsics): + """ + Transforms 3D points to 2D using extrinsic and intrinsic parameters. + Args: + world_points (torch.Tensor): 3D points of shape BxSxHxWx3. + cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4. 
+ Returns: + """ + # TODO: merge this into project_world_points_to_cam + + # device = world_points.device + # with torch.autocast(device_type=device.type, enabled=False): + ones = torch.ones_like(world_points[..., :1]) # shape: (B, S, H, W, 1) + world_points_h = torch.cat([world_points, ones], dim=-1) # shape: (B, S, H, W, 4) + + # extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4) + extrinsics_exp = cam_extrinsics.unsqueeze(2).unsqueeze(3) + + # world_points_h: (B, S, H, W, 4) -> (B, S, H, W, 4, 1) + world_points_h_exp = world_points_h.unsqueeze(-1) + + # Now perform the matrix multiplication + # (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) broadcasts to (B, S, H, W, 3, 1) + camera_points = torch.matmul(extrinsics_exp, world_points_h_exp).squeeze(-1) + + return camera_points + + + +def project_world_points_to_cam( + world_points, + cam_extrinsics, + cam_intrinsics=None, + distortion_params=None, + default=0, + only_points_cam=False, +): + """ + Transforms 3D points to 2D using extrinsic and intrinsic parameters. + Args: + world_points (torch.Tensor): 3D points of shape Px3. + cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4. + cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3. + distortion_params (torch.Tensor): Extra parameters of shape BxN, which is used for radial distortion. + Returns: + torch.Tensor: Transformed 2D points of shape BxNx2. 
+ """ + device = world_points.device + # with torch.autocast(device_type=device.type, dtype=torch.double): + with torch.autocast(device_type=device.type, enabled=False): + N = world_points.shape[0] # Number of points + B = cam_extrinsics.shape[0] # Batch size, i.e., number of cameras + world_points_homogeneous = torch.cat( + [world_points, torch.ones_like(world_points[..., 0:1])], dim=1 + ) # Nx4 + # Reshape for batch processing + world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand( + B, -1, -1 + ) # BxNx4 + + # Step 1: Apply extrinsic parameters + # Transform 3D points to camera coordinate system for all cameras + cam_points = torch.bmm( + cam_extrinsics, world_points_homogeneous.transpose(-1, -2) + ) + + if only_points_cam: + return None, cam_points + + # Step 2: Apply intrinsic parameters and (optional) distortion + image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default) + + return image_points, cam_points + + + +def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0): + """ + Applies intrinsic parameters and optional distortion to the given 3D points. + + Args: + cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. + cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. + distortion_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + default (float, optional): Default value to replace NaNs in the output. + + Returns: + pixel_coords (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. 
+ """ + + # Normalized device coordinates (NDC) + cam_points = cam_points / cam_points[:, 2:3, :] + ndc_xy = cam_points[:, :2, :] + + # Apply distortion if distortion_params are provided + if distortion_params is not None: + x_distorted, y_distorted = apply_distortion(distortion_params, ndc_xy[:, 0], ndc_xy[:, 1]) + distorted_xy = torch.stack([x_distorted, y_distorted], dim=1) + else: + distorted_xy = ndc_xy + + # Prepare cam_points for batch matrix multiplication + cam_coords_homo = torch.cat( + (distorted_xy, torch.ones_like(distorted_xy[:, :1, :])), dim=1 + ) # Bx3xN + # Apply intrinsic parameters using batch matrix multiplication + pixel_coords = torch.bmm(cam_intrinsics, cam_coords_homo) # Bx3xN + + # Extract x and y coordinates + pixel_coords = pixel_coords[:, :2, :] # Bx2xN + + # Replace NaNs with default value + pixel_coords = torch.nan_to_num(pixel_coords, nan=default) + + return pixel_coords.transpose(1, 2) # BxNx2 + + + + +def cam_from_img(pred_tracks, intrinsics, extra_params=None): + """ + Normalize predicted tracks based on camera intrinsics. + Args: + intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3]. + pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2]. + extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + Returns: + torch.Tensor: Normalized tracks tensor. 
+ """ + + # We don't want to do intrinsics_inv = torch.inverse(intrinsics) here + # otherwise we can use something like + # tracks_normalized_homo = torch.bmm(pred_tracks_homo, intrinsics_inv.transpose(1, 2)) + + principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2) + focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2) + tracks_normalized = (pred_tracks - principal_point) / focal_length + + if extra_params is not None: + # Apply iterative undistortion + try: + tracks_normalized = iterative_undistortion( + extra_params, tracks_normalized + ) + except: + tracks_normalized = single_undistortion( + extra_params, tracks_normalized + ) + + return tracks_normalized \ No newline at end of file diff --git a/FastVGGT/vggt/utils/helper.py b/FastVGGT/vggt/utils/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..7b019189c85ff86645a4cf3756632aa8d4500649 --- /dev/null +++ b/FastVGGT/vggt/utils/helper.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + + +def randomly_limit_trues(mask: np.ndarray, max_trues: int) -> np.ndarray: + """ + If mask has more than max_trues True values, + randomly keep only max_trues of them and set the rest to False. 
+ """ + # 1D positions of all True entries + true_indices = np.flatnonzero(mask) # shape = (N_true,) + + # if already within budget, return as-is + if true_indices.size <= max_trues: + return mask + + # randomly pick which True positions to keep + sampled_indices = np.random.choice(true_indices, size=max_trues, replace=False) # shape = (max_trues,) + + # build new flat mask: True only at sampled positions + limited_flat_mask = np.zeros(mask.size, dtype=bool) + limited_flat_mask[sampled_indices] = True + + # restore original shape + return limited_flat_mask.reshape(mask.shape) + + +def create_pixel_coordinate_grid(num_frames, height, width): + """ + Creates a grid of pixel coordinates and frame indices for all frames. + Returns: + tuple: A tuple containing: + - points_xyf (numpy.ndarray): Array of shape (num_frames, height, width, 3) + with x, y coordinates and frame indices + - y_coords (numpy.ndarray): Array of y coordinates for all frames + - x_coords (numpy.ndarray): Array of x coordinates for all frames + - f_coords (numpy.ndarray): Array of frame indices for all frames + """ + # Create coordinate grids for a single frame + y_grid, x_grid = np.indices((height, width), dtype=np.float32) + x_grid = x_grid[np.newaxis, :, :] + y_grid = y_grid[np.newaxis, :, :] + + # Broadcast to all frames + x_coords = np.broadcast_to(x_grid, (num_frames, height, width)) + y_coords = np.broadcast_to(y_grid, (num_frames, height, width)) + + # Create frame indices and broadcast + f_idx = np.arange(num_frames, dtype=np.float32)[:, np.newaxis, np.newaxis] + f_coords = np.broadcast_to(f_idx, (num_frames, height, width)) + + # Stack coordinates and frame indices + points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1) + + return points_xyf diff --git a/FastVGGT/vggt/utils/load_fn.py b/FastVGGT/vggt/utils/load_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae01ba2601f7b5039220bbae94725646dd66a41 --- /dev/null +++ b/FastVGGT/vggt/utils/load_fn.py @@ 
-0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from PIL import Image +from torchvision import transforms as TF +import numpy as np + + +def load_and_preprocess_images_square(image_path_list, target_size=1024): + """ + Load and preprocess images by center padding to square and resizing to target size. + Also returns the position information of original pixels after transformation. + + Args: + image_path_list (list): List of paths to image files + target_size (int, optional): Target size for both width and height. Defaults to 518. + + Returns: + tuple: ( + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size), + torch.Tensor: Array of shape (N, 5) containing [x1, y1, x2, y2, width, height] for each image + ) + + Raises: + ValueError: If the input list is empty + """ + # Check for empty list + if len(image_path_list) == 0: + raise ValueError("At least 1 image is required") + + images = [] + original_coords = [] # Renamed from position_info to be more descriptive + to_tensor = TF.ToTensor() + + for image_path in image_path_list: + # Open image + img = Image.open(image_path) + + # If there's an alpha channel, blend onto white background + if img.mode == "RGBA": + background = Image.new("RGBA", img.size, (255, 255, 255, 255)) + img = Image.alpha_composite(background, img) + + # Convert to RGB + img = img.convert("RGB") + + # Get original dimensions + width, height = img.size + + # Make the image square by padding the shorter dimension + max_dim = max(width, height) + + # Calculate padding + left = (max_dim - width) // 2 + top = (max_dim - height) // 2 + + # Calculate scale factor for resizing + scale = target_size / max_dim + + # Calculate final coordinates of original image in target space + x1 = left * scale + y1 = top * scale + x2 = (left + 
def load_and_preprocess_images(image_path_list, mode="crop"):
    """
    A quick start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to 518px and center crops height if needed.
            - "pad": Preserves all pixels by making the largest dimension 518px
              and padding the smaller dimension to reach a square shape.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": width is fixed to 518px, aspect ratio is kept, and
          height is center-cropped if larger than 518px
        - When mode="pad": the largest dimension becomes 518px and the smaller
          dimension is padded to reach a square shape (518x518)
        - Dimensions are adjusted to be divisible by 14 for compatibility with
          the model's patch size
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = 518

    # First process all images and collect their shapes
    for image_path in image_path_list:
        # Use a context manager so the underlying file handle is released
        # promptly; PIL opens files lazily and would otherwise keep one file
        # descriptor per image alive until garbage collection.
        with Image.open(image_path) as raw_img:
            # If there's an alpha channel, blend onto a white background so
            # transparent areas become white after the RGB conversion below.
            if raw_img.mode == "RGBA":
                background = Image.new("RGBA", raw_img.size, (255, 255, 255, 255))
                img = Image.alpha_composite(background, raw_img)
            else:
                img = raw_img

            # Now convert to "RGB" (this step assigns white for transparent areas)
            img = img.convert("RGB")

            width, height = img.size

            if mode == "pad":
                # Make the largest dimension 518px while maintaining aspect
                # ratio; round the other dimension to a multiple of 14. Floor
                # at 14 so extreme aspect ratios can never collapse to a
                # zero-sized dimension.
                if width >= height:
                    new_width = target_size
                    new_height = max(14, round(height * (new_width / width) / 14) * 14)
                else:
                    new_height = target_size
                    new_width = max(14, round(width * (new_height / height) / 14) * 14)
            else:  # mode == "crop"
                # Original behavior: set width to 518px; height keeps the
                # aspect ratio, rounded to a multiple of 14 (floored at 14).
                new_width = target_size
                new_height = max(14, round(height * (new_width / width) / 14) * 14)

            # Resize with new dimensions (width, height)
            img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)

        img = to_tensor(img)  # Convert to tensor in [0, 1]

        # Center crop height if it's larger than 518 (only in crop mode)
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]

        # For pad mode, pad to make a square of target_size x target_size
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                # Pad with white (value=1.0)
                img = torch.nn.functional.pad(
                    img,
                    (pad_left, pad_right, pad_top, pad_bottom),
                    mode="constant",
                    value=1.0,
                )

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # Check if we have different shapes.
    # In theory the model can work with different shapes, but stacking into a
    # batch requires identical shapes, so pad everything to the common maximum.
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                img = torch.nn.functional.pad(
                    img,
                    (pad_left, pad_right, pad_top, pad_bottom),
                    mode="constant",
                    value=1.0,
                )
            padded_images.append(img)
        images = padded_images

    images = torch.stack(images)  # concatenate images

    # Ensure shape (1, C, H, W) when a single image was given
    if len(image_path_list) == 1 and images.dim() == 3:
        images = images.unsqueeze(0)

    return images
+ For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + """ + + # extrinsics: BxSx3x4 + # intrinsics: BxSx3x3 + + if pose_encoding_type == "absT_quaR_FoV": + R = extrinsics[:, :, :3, :3] # BxSx3x3 + T = extrinsics[:, :, :3, 3] # BxSx3 + + quat = mat_to_quat(R) + # Note the order of h and w here + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + else: + raise NotImplementedError + + return pose_encoding + + +def pose_encoding_to_extri_intri( + pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512) +): + """Convert a pose encoding back to camera extrinsics and intrinsics. + + This function performs the inverse operation of extri_intri_to_pose_encoding, + reconstructing the full camera parameters from the compact encoding. + + Args: + pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, + where B is batch size and S is sequence length. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for reconstructing intrinsics from field of view values. + For example: (256, 512). + pose_encoding_type (str): Type of pose encoding used. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. + If False, only extrinsics are returned and intrinsics will be None. + + Returns: + tuple: (extrinsics, intrinsics) + - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. 
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world + transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is + a 3x1 translation vector. + - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, + or None if build_intrinsics is False. Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point, + assumed to be at the center of the image (W/2, H/2). + """ + + intrinsics = None + + if pose_encoding_type == "absT_quaR_FoV": + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + if build_intrinsics: + H, W = image_size_hw + fy = (H / 2.0) / torch.tan(fov_h / 2.0) + fx = (W / 2.0) / torch.tan(fov_w / 2.0) + intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 + else: + raise NotImplementedError + + return extrinsics, intrinsics diff --git a/FastVGGT/vggt/utils/rotation.py b/FastVGGT/vggt/utils/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..f972afd8414c82fa1e9ed231725fd3f9f6ebde77 --- /dev/null +++ b/FastVGGT/vggt/utils/rotation.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Quaternion Order: XYZW or say ijkr, scalar-last

    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    x, y, z, w = torch.unbind(quaternions, dim=-1)
    # Scale factor 2 / |q|^2 keeps the formula valid for non-unit quaternions.
    s2 = 2.0 / (quaternions * quaternions).sum(dim=-1)

    # Pre-compute the pairwise products used by the rotation-matrix formula.
    xx, yy, zz = x * x, y * y, z * z
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    entries = torch.stack(
        (
            1 - s2 * (yy + zz),
            s2 * (xy - wz),
            s2 * (xz + wy),
            s2 * (xy + wz),
            1 - s2 * (xx + zz),
            s2 * (yz - wx),
            s2 * (xz - wy),
            s2 * (yz + wx),
            1 - s2 * (xx + yy),
        ),
        dim=-1,
    )
    return entries.reshape(quaternions.shape[:-1] + (3, 3))
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). 
+ """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) diff --git a/FastVGGT/vggt/utils/visual_track.py b/FastVGGT/vggt/utils/visual_track.py new file mode 100644 index 0000000000000000000000000000000000000000..796c114ccba00b5f7850e04b9444a6cd5c44b154 --- /dev/null +++ b/FastVGGT/vggt/utils/visual_track.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 +import torch +import numpy as np +import os + + +def color_from_xy(x, y, W, H, cmap_name="hsv"): + """ + Map (x, y) -> color in (R, G, B). + 1) Normalize x,y to [0,1]. + 2) Combine them into a single scalar c in [0,1]. + 3) Use matplotlib's colormap to convert c -> (R,G,B). + + You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y). + """ + import matplotlib.cm + import matplotlib.colors + + x_norm = x / max(W - 1, 1) + y_norm = y / max(H - 1, 1) + # Simple combination: + c = (x_norm + y_norm) / 2.0 + + cmap = matplotlib.cm.get_cmap(cmap_name) + # cmap(c) -> (r,g,b,a) in [0,1] + rgba = cmap(c) + r, g, b = rgba[0], rgba[1], rgba[2] + return (r, g, b) # in [0,1], RGB order + + +def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"): + """ + Given all tracks in one sample (b), compute a (N,3) array of RGB color values + in [0,255]. The color is determined by the (x,y) position in the first + visible frame for each track. + + Args: + tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame. + vis_mask_b: (S, N) boolean mask; if None, assume all are visible. + image_width, image_height: used for normalizing (x, y). + cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet'). + + Returns: + track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255]. 
+ """ + S, N, _ = tracks_b.shape + track_colors = np.zeros((N, 3), dtype=np.uint8) + + if vis_mask_b is None: + # treat all as visible + vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device) + + for i in range(N): + # Find first visible frame for track i + visible_frames = torch.where(vis_mask_b[:, i])[0] + if len(visible_frames) == 0: + # track is never visible; just assign black or something + track_colors[i] = (0, 0, 0) + continue + + first_s = int(visible_frames[0].item()) + # use that frame's (x,y) + x, y = tracks_b[first_s, i].tolist() + + # map (x,y) -> (R,G,B) in [0,1] + r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name) + # scale to [0,255] + r, g, b = int(r * 255), int(g * 255), int(b * 255) + track_colors[i] = (r, g, b) + + return track_colors + + +def visualize_tracks_on_images( + images, + tracks, + track_vis_mask=None, + out_dir="track_visuals_concat_by_xy", + image_format="CHW", # "CHW" or "HWC" + normalize_mode="[0,1]", + cmap_name="hsv", # e.g. "hsv", "rainbow", "jet" + frames_per_row=4, # New parameter for grid layout + save_grid=True, # Flag to control whether to save the grid image +): + """ + Visualizes frames in a grid layout with specified frames per row. + Each track's color is determined by its (x,y) position + in the first visible frame (or frame 0 if always visible). + Finally convert the BGR result to RGB before saving. + Also saves each individual frame as a separate PNG file. + + Args: + images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC. + tracks: torch.Tensor (S, N, 2), last dim = (x, y). + track_vis_mask: torch.Tensor (S, N) or None. + out_dir: folder to save visualizations. + image_format: "CHW" or "HWC". + normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255 + cmap_name: a matplotlib colormap name for color_from_xy. + frames_per_row: number of frames to display in each row of the grid. + save_grid: whether to save all frames in one grid image. 
def visualize_tracks_on_images(
    images,
    tracks,
    track_vis_mask=None,
    out_dir="track_visuals_concat_by_xy",
    image_format="CHW",  # "CHW" or "HWC"
    normalize_mode="[0,1]",
    cmap_name="hsv",  # e.g. "hsv", "rainbow", "jet"
    frames_per_row=4,  # number of frames per row of the grid
    save_grid=True,  # whether to also save all frames in one grid image
):
    """
    Visualizes frames in a grid layout with specified frames per row.
    Each track's color is determined by its (x,y) position in the first
    visible frame (or frame 0 if always visible). Drawing happens in BGR via
    OpenCV and the result is converted back to RGB before saving. Every
    individual frame is also written out as its own PNG file.

    Args:
        images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
        tracks: torch.Tensor (S, N, 2), last dim = (x, y).
        track_vis_mask: torch.Tensor (S, N) or None.
        out_dir: folder to save visualizations.
        image_format: "CHW" or "HWC".
        normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
        cmap_name: a matplotlib colormap name for color_from_xy.
        frames_per_row: number of frames to display in each row of the grid.
        save_grid: whether to save all frames in one grid image.

    Returns:
        None (saves images in out_dir).
    """
    # Accept an optional leading batch dimension of size 1.
    if len(tracks.shape) == 4:
        tracks = tracks.squeeze(0)
        images = images.squeeze(0)
        if track_vis_mask is not None:
            track_vis_mask = track_vis_mask.squeeze(0)

    import matplotlib

    matplotlib.use("Agg")  # for non-interactive (optional)

    os.makedirs(out_dir, exist_ok=True)

    S = images.shape[0]
    _, N, _ = tracks.shape  # (S, N, 2)

    # Work on CPU copies so the caller's tensors are untouched.
    images = images.cpu().clone()
    tracks = tracks.cpu().clone()
    if track_vis_mask is not None:
        track_vis_mask = track_vis_mask.cpu().clone()

    # Infer H, W from the image layout.
    if image_format == "CHW":
        H, W = images.shape[2], images.shape[3]
    else:
        H, W = images.shape[1], images.shape[2]

    # One fixed color per track, chosen from its first visible position.
    track_colors_rgb = get_track_colors_by_position(
        tracks,
        vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
        image_width=W,
        image_height=H,
        cmap_name=cmap_name,
    )

    frame_images = []

    for s in range(S):
        frame = images[s]
        if image_format == "CHW":
            frame = frame.permute(1, 2, 0)  # -> (H, W, 3)
        frame = frame.numpy().astype(np.float32)

        # Scale pixel values to [0, 255] according to the declared range.
        if normalize_mode == "[0,1]":
            frame = np.clip(frame, 0, 1) * 255.0
        elif normalize_mode == "[-1,1]":
            frame = np.clip((frame + 1.0) * 0.5 * 255.0, 0, 255.0)
        # else: no normalization

        frame = frame.astype(np.uint8)

        # OpenCV draws in BGR.
        canvas_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        if track_vis_mask is not None:
            visible_ids = torch.where(track_vis_mask[s])[0]
        else:
            visible_ids = range(N)

        points = tracks[s].numpy()  # (N, 2)
        for i in visible_ids:
            x, y = points[i]
            R, G, B = track_colors_rgb[i]
            cv2.circle(
                canvas_bgr,
                (int(round(x)), int(round(y))),
                radius=3,
                color=(int(B), int(G), int(R)),
                thickness=-1,
            )

        # Back to RGB for consistent in-memory handling.
        img_rgb = cv2.cvtColor(canvas_bgr, cv2.COLOR_BGR2RGB)

        # Save the individual frame (imwrite expects BGR).
        frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
        cv2.imwrite(frame_path, cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR))

        frame_images.append(img_rgb)

    if save_grid:
        # Ceiling division for the number of grid rows.
        num_rows = (S + frames_per_row - 1) // frames_per_row

        grid_img = None
        for row in range(num_rows):
            start = row * frames_per_row
            end = min(start + frames_per_row, S)

            row_img = np.concatenate(frame_images[start:end], axis=1)

            # Pad a short final row with black so every row has equal width.
            if end - start < frames_per_row:
                pad_width = (frames_per_row - (end - start)) * W
                padding = np.zeros((H, pad_width, 3), dtype=np.uint8)
                row_img = np.concatenate([row_img, padding], axis=1)

            grid_img = row_img if grid_img is None else np.concatenate([grid_img, row_img], axis=0)

        out_path = os.path.join(out_dir, "tracks_grid.png")
        cv2.imwrite(out_path, cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR))
        print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")

    print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")