聂如 commited on
Commit ·
91126af
1
Parent(s): 7829591
Add design file
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- LEGAL.md +7 -0
- LICENSE.txt +7 -0
- README.md +100 -13
- app.py +291 -0
- dust3r/croco/datasets/__init__.py +0 -0
- dust3r/croco/datasets/crops/README.MD +104 -0
- dust3r/croco/datasets/crops/extract_crops_from_images.py +159 -0
- dust3r/croco/datasets/habitat_sim/README.MD +76 -0
- dust3r/croco/datasets/habitat_sim/__init__.py +0 -0
- dust3r/croco/datasets/habitat_sim/generate_from_metadata.py +92 -0
- dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py +27 -0
- dust3r/croco/datasets/habitat_sim/generate_multiview_images.py +177 -0
- dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py +390 -0
- dust3r/croco/datasets/habitat_sim/pack_metadata_files.py +69 -0
- dust3r/croco/datasets/habitat_sim/paths.py +129 -0
- dust3r/croco/datasets/pairs_dataset.py +109 -0
- dust3r/croco/datasets/transforms.py +95 -0
- dust3r/croco/models/__pycache__/blocks.cpython-312.pyc +0 -0
- dust3r/croco/models/__pycache__/croco.cpython-312.pyc +0 -0
- dust3r/croco/models/__pycache__/dpt_block.cpython-312.pyc +0 -0
- dust3r/croco/models/__pycache__/masking.cpython-312.pyc +0 -0
- dust3r/croco/models/__pycache__/pos_embed.cpython-312.pyc +0 -0
- dust3r/croco/models/blocks.py +307 -0
- dust3r/croco/models/criterion.py +37 -0
- dust3r/croco/models/croco.py +288 -0
- dust3r/croco/models/dpt_block.py +450 -0
- dust3r/croco/models/head_downstream.py +58 -0
- dust3r/croco/models/masking.py +25 -0
- dust3r/croco/models/pos_embed.py +159 -0
- dust3r/croco/models/transformer_utils.py +1021 -0
- dust3r/croco/models/x_transformer.py +558 -0
- dust3r/croco/utils/misc.py +583 -0
- dust3r/dust3r/__init__.py +2 -0
- dust3r/dust3r/__pycache__/__init__.cpython-312.pyc +0 -0
- dust3r/dust3r/__pycache__/model.cpython-312.pyc +0 -0
- dust3r/dust3r/__pycache__/patch_embed.cpython-312.pyc +0 -0
- dust3r/dust3r/__pycache__/viz.cpython-312.pyc +0 -0
- dust3r/dust3r/datasets/CustomDataset.py +145 -0
- dust3r/dust3r/datasets/__init__.py +39 -0
- dust3r/dust3r/datasets/__pycache__/CustomDataset.cpython-312.pyc +0 -0
- dust3r/dust3r/datasets/__pycache__/__init__.cpython-312.pyc +0 -0
- dust3r/dust3r/datasets/base/__init__.py +2 -0
- dust3r/dust3r/datasets/base/__pycache__/__init__.cpython-312.pyc +0 -0
- dust3r/dust3r/datasets/base/__pycache__/base_stereo_view_dataset.cpython-312.pyc +0 -0
- dust3r/dust3r/datasets/base/__pycache__/batched_sampler.cpython-312.pyc +0 -0
- dust3r/dust3r/datasets/base/__pycache__/easy_dataset.cpython-312.pyc +0 -0
- dust3r/dust3r/datasets/base/base_stereo_view_dataset.py +774 -0
- dust3r/dust3r/datasets/base/batched_sampler.py +74 -0
- dust3r/dust3r/datasets/base/easy_dataset.py +157 -0
- dust3r/dust3r/datasets/utils/__init__.py +2 -0
LEGAL.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Legal Disclaimer
|
| 2 |
+
|
| 3 |
+
Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
|
| 4 |
+
|
| 5 |
+
法律免责声明
|
| 6 |
+
|
| 7 |
+
关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
|
LICENSE.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FLARE, Copyright (c) 2025-present Ant Group, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
|
| 2 |
+
|
| 3 |
+
A summary of the CC BY-NC-SA 4.0 license is located here:
|
| 4 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/
|
| 5 |
+
|
| 6 |
+
The CC BY-NC-SA 4.0 license is located here:
|
| 7 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
README.md
CHANGED
|
@@ -1,13 +1,100 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
---
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FLARE: Feed-forward Geometry, Appearance and Camera Estimation from Uncalibrated Sparse Views
|
| 2 |
+
[](https://zhanghe3z.github.io/FLARE/)
|
| 3 |
+
[](https://huggingface.co/AntResearch/FLARE)
|
| 4 |
+
[](https://zhanghe3z.github.io/FLARE/videos/teaser_video.mp4)
|
| 5 |
+
|
| 6 |
+
Official implementation of **FLARE** (CVPR 2025) - a feed-forward model for joint camera pose estimation, 3D reconstruction and novel view synthesis from sparse uncalibrated views.
|
| 7 |
+
|
| 8 |
+

|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
<!-- TOC start (generated with https://github.com/derlin/bitdowntoc) -->
|
| 12 |
+
|
| 13 |
+
- [📖 Overview](#-overview)
|
| 14 |
+
- [🛠️ TODO List](#-todo-list)
|
| 15 |
+
- [🌍 Installation](#-installation)
|
| 16 |
+
- [💿 Checkpoints](#-checkpoints)
|
| 17 |
+
- [🎯 Run a Demo (Point Cloud and Camera Pose Estimation) ](#-run-a-demo-point-cloud-and-camera-pose-estimation)
|
| 18 |
+
- [👀 Visualization ](#-visualization)
|
| 19 |
+
- [📜 Citation ](#-citation)
|
| 20 |
+
|
| 21 |
+
<!-- TOC end -->
|
| 22 |
+
|
| 23 |
+
## 📖 Overview
|
| 24 |
+
We present FLARE, a feed-forward model that simultaneously estimates high-quality camera poses, 3D geometry, and appearance from as few as 2-8 uncalibrated images. Our cascaded learning paradigm:
|
| 25 |
+
|
| 26 |
+
1. **Camera Pose Estimation**: Directly regress camera poses without bundle adjustment
|
| 27 |
+
2. **Geometry Reconstruction**: Decompose geometry reconstruction into two simpler sub-problems
|
| 28 |
+
3. **Appearance Modeling**: Enable photorealistic novel view synthesis via 3D Gaussians
|
| 29 |
+
|
| 30 |
+
Achieves SOTA performance with inference times <0.5 seconds!
|
| 31 |
+
|
| 32 |
+
## 🛠️ TODO List
|
| 33 |
+
- [x] Release point cloud and camera pose estimation code.
|
| 34 |
+
- [x] Updated Gradio demo (app.py).
|
| 35 |
+
- [ ] Release novel view synthesis code. (~2 weeks)
|
| 36 |
+
- [ ] Release evaluation code. (~2 weeks)
|
| 37 |
+
- [ ] Release training code.
|
| 38 |
+
- [ ] Release data processing code.
|
| 39 |
+
|
| 40 |
+
## 🌍 Installation
|
| 41 |
+
|
| 42 |
+
```
|
| 43 |
+
conda create -n flare python=3.8
|
| 44 |
+
conda activate flare
|
| 45 |
+
conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia # use the correct version of cuda for your system
|
| 46 |
+
pip install -r requirements.txt
|
| 47 |
+
conda uninstall ffmpeg
|
| 48 |
+
conda install -c conda-forge ffmpeg
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
## 💿 Checkpoints
|
| 53 |
+
Download the checkpoint from [huggingface](https://huggingface.co/AntResearch/FLARE/blob/main/geometry_pose.pth) and place it in the /checkpoints/geometry_pose.pth directory.
|
| 54 |
+
|
| 55 |
+
## 🎯 Run a Demo (Point Cloud and Camera Pose Estimation)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
sh scripts/run_pose_pointcloud.sh
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
```bash
|
| 64 |
+
torchrun --nproc_per_node=1 run_pose_pointcloud.py \
|
| 65 |
+
--test_dataset "1 @ CustomDataset(split='train', ROOT='Your/Data/Path', resolution=(512,384), seed=1, num_views=7, gt_num_image=0, aug_portrait_or_landscape=False, sequential_input=False)" \
|
| 66 |
+
--model "AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, two_confs=True, desc_conf_mode=('exp', 0, inf))" \
|
| 67 |
+
--pretrained "Your/Checkpoint/Path" \
|
| 68 |
+
--test_criterion "MeshOutput(sam=False)" --output_dir "log/" --amp 1 --seed 1 --num_workers 0
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
**To run the demo using ground truth camera poses:**
|
| 72 |
+
Enable the wpose=True flag in both the CustomDataset and AsymmetricMASt3R. An example script demonstrating this setup is provided in run_pose_pointcloud_wpose.sh.
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
sh scripts/run_pose_pointcloud_wpose.sh
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
## 👀 Visualization
|
| 79 |
+
|
| 80 |
+
```
|
| 81 |
+
sh ./visualizer/vis.sh
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
```
|
| 86 |
+
CUDA_VISIBLE_DEVICES=0 python visualizer/run_vis.py --result_npz data/mesh/IMG_1511.HEIC.JPG.JPG/pred.npz --results_folder data/mesh/IMG_1511.HEIC.JPG.JPG/
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
## 📜 Citation
|
| 91 |
+
```bibtex
|
| 92 |
+
@misc{zhang2025flarefeedforwardgeometryappearance,
|
| 93 |
+
title={FLARE: Feed-forward Geometry, Appearance and Camera Estimation from Uncalibrated Sparse Views},
|
| 94 |
+
author={Shangzhan Zhang and Jianyuan Wang and Yinghao Xu and Nan Xue and Christian Rupprecht and Xiaowei Zhou and Yujun Shen and Gordon Wetzstein},
|
| 95 |
+
year={2025},
|
| 96 |
+
eprint={2502.12138},
|
| 97 |
+
archivePrefix={arXiv},
|
| 98 |
+
primaryClass={cs.CV},
|
| 99 |
+
url={https://arxiv.org/abs/2502.12138},
|
| 100 |
+
}
|
app.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import spaces
|
| 2 |
+
import mast3r.utils.path_to_dust3r # noqa
|
| 3 |
+
import dust3r.utils.path_to_croco # noqa: F401
|
| 4 |
+
import mast3r.utils.path_to_dust3r # noqa
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import os.path as path
|
| 8 |
+
import torch
|
| 9 |
+
import tempfile
|
| 10 |
+
import gradio
|
| 11 |
+
import shutil
|
| 12 |
+
import math
|
| 13 |
+
from mast3r.model import AsymmetricMASt3R
|
| 14 |
+
import matplotlib.pyplot as pl
|
| 15 |
+
from dust3r.utils.image import load_images
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from pytorch3d.ops import knn_points
|
| 18 |
+
from dust3r.utils.geometry import xy_grid
|
| 19 |
+
import numpy as np
|
| 20 |
+
import cv2
|
| 21 |
+
from dust3r.utils.device import to_numpy
|
| 22 |
+
import trimesh
|
| 23 |
+
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
| 24 |
+
from scipy.spatial.transform import Rotation
|
| 25 |
+
|
| 26 |
+
pl.ion()
|
| 27 |
+
# for gpu >= Ampere and pytorch >= 1.12
|
| 28 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 29 |
+
batch_size = 1
|
| 30 |
+
inf = float('inf')
|
| 31 |
+
weights_path = "checkpoints/geometry_pose.pth"
|
| 32 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 33 |
+
ckpt = torch.load(weights_path, map_location=device)
|
| 34 |
+
model = AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, two_confs=True, desc_conf_mode=('exp', 0, inf))
|
| 35 |
+
model = AsymmetricMASt3R.from_pretrained("zhang3z/FLARE").to(device)
|
| 36 |
+
# model.from_pretrained(ckpt['model'], strict=False)
|
| 37 |
+
model = model.to(device).eval()
|
| 38 |
+
|
| 39 |
+
tmpdirname = tempfile.mkdtemp(suffix='_FLARE_gradio_demo')
|
| 40 |
+
image_size = 512
|
| 41 |
+
silent = True
|
| 42 |
+
gradio_delete_cache = 7200
|
| 43 |
+
backbone = torch.hub.load(
|
| 44 |
+
"facebookresearch/dinov2", "dinov2_vitb14_reg"
|
| 45 |
+
)
|
| 46 |
+
backbone = backbone.eval().cuda()
|
| 47 |
+
|
| 48 |
+
class FileState:
|
| 49 |
+
def __init__(self, outfile_name=None):
|
| 50 |
+
self.outfile_name = outfile_name
|
| 51 |
+
|
| 52 |
+
def __del__(self):
|
| 53 |
+
if self.outfile_name is not None and os.path.isfile(self.outfile_name):
|
| 54 |
+
os.remove(self.outfile_name)
|
| 55 |
+
self.outfile_name = None
|
| 56 |
+
|
| 57 |
+
def pad_to_square(reshaped_image):
|
| 58 |
+
B, C, H, W = reshaped_image.shape
|
| 59 |
+
max_dim = max(H, W)
|
| 60 |
+
pad_height = max_dim - H
|
| 61 |
+
pad_width = max_dim - W
|
| 62 |
+
padding = (pad_width // 2, pad_width - pad_width // 2,
|
| 63 |
+
pad_height // 2, pad_height - pad_height // 2)
|
| 64 |
+
padded_image = F.pad(reshaped_image, padding, mode='constant', value=0)
|
| 65 |
+
return padded_image
|
| 66 |
+
|
| 67 |
+
def generate_rank_by_dino(
|
| 68 |
+
reshaped_image, backbone, query_frame_num, image_size=336
|
| 69 |
+
):
|
| 70 |
+
# Downsample image to image_size x image_size
|
| 71 |
+
# because we found it is unnecessary to use high resolution
|
| 72 |
+
rgbs = pad_to_square(reshaped_image)
|
| 73 |
+
rgbs = F.interpolate(
|
| 74 |
+
reshaped_image,
|
| 75 |
+
(image_size, image_size),
|
| 76 |
+
mode="bilinear",
|
| 77 |
+
align_corners=True,
|
| 78 |
+
)
|
| 79 |
+
rgbs = _resnet_normalize_image(rgbs.cuda())
|
| 80 |
+
|
| 81 |
+
# Get the image features (patch level)
|
| 82 |
+
frame_feat = backbone(rgbs, is_training=True)
|
| 83 |
+
frame_feat = frame_feat["x_norm_patchtokens"]
|
| 84 |
+
frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
|
| 85 |
+
|
| 86 |
+
# Compute the similiarty matrix
|
| 87 |
+
frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
|
| 88 |
+
similarity_matrix = torch.bmm(
|
| 89 |
+
frame_feat_norm, frame_feat_norm.transpose(-1, -2)
|
| 90 |
+
)
|
| 91 |
+
similarity_matrix = similarity_matrix.mean(dim=0)
|
| 92 |
+
distance_matrix = 100 - similarity_matrix.clone()
|
| 93 |
+
|
| 94 |
+
# Ignore self-pairing
|
| 95 |
+
similarity_matrix.fill_diagonal_(-100)
|
| 96 |
+
|
| 97 |
+
similarity_sum = similarity_matrix.sum(dim=1)
|
| 98 |
+
|
| 99 |
+
# Find the most common frame
|
| 100 |
+
most_common_frame_index = torch.argmax(similarity_sum).item()
|
| 101 |
+
return most_common_frame_index
|
| 102 |
+
|
| 103 |
+
_RESNET_MEAN = [0.485, 0.456, 0.406]
|
| 104 |
+
_RESNET_STD = [0.229, 0.224, 0.225]
|
| 105 |
+
_resnet_mean = torch.tensor(_RESNET_MEAN).view(1, 3, 1, 1).cuda()
|
| 106 |
+
_resnet_std = torch.tensor(_RESNET_STD).view(1, 3, 1, 1).cuda()
|
| 107 |
+
def _resnet_normalize_image(img: torch.Tensor) -> torch.Tensor:
|
| 108 |
+
return (img - _resnet_mean) / _resnet_std
|
| 109 |
+
|
| 110 |
+
def calculate_index_mappings(query_index, S, device=None):
|
| 111 |
+
"""
|
| 112 |
+
Construct an order that we can switch [query_index] and [0]
|
| 113 |
+
so that the content of query_index would be placed at [0]
|
| 114 |
+
"""
|
| 115 |
+
new_order = torch.arange(S)
|
| 116 |
+
new_order[0] = query_index
|
| 117 |
+
new_order[query_index] = 0
|
| 118 |
+
if device is not None:
|
| 119 |
+
new_order = new_order.to(device)
|
| 120 |
+
return new_order
|
| 121 |
+
|
| 122 |
+
def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
|
| 123 |
+
cam_color=None, as_pointcloud=False,
|
| 124 |
+
transparent_cams=False, silent=False):
|
| 125 |
+
assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
|
| 126 |
+
pts3d = to_numpy(pts3d)
|
| 127 |
+
imgs = to_numpy(imgs)
|
| 128 |
+
focals = to_numpy(focals)
|
| 129 |
+
mask = to_numpy(mask)
|
| 130 |
+
cams2world = to_numpy(cams2world)
|
| 131 |
+
|
| 132 |
+
scene = trimesh.Scene()
|
| 133 |
+
# full pointcloud
|
| 134 |
+
if as_pointcloud:
|
| 135 |
+
pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
|
| 136 |
+
col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
|
| 137 |
+
valid_msk = np.isfinite(pts.sum(axis=1))
|
| 138 |
+
pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk])
|
| 139 |
+
scene.add_geometry(pct)
|
| 140 |
+
else:
|
| 141 |
+
meshes = []
|
| 142 |
+
for i in range(len(imgs)):
|
| 143 |
+
pts3d_i = pts3d[i].reshape(imgs[i].shape)
|
| 144 |
+
msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
|
| 145 |
+
meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
|
| 146 |
+
mesh = trimesh.Trimesh(**cat_meshes(meshes))
|
| 147 |
+
scene.add_geometry(mesh)
|
| 148 |
+
|
| 149 |
+
# add each camera
|
| 150 |
+
for i, pose_c2w in enumerate(cams2world):
|
| 151 |
+
if isinstance(cam_color, list):
|
| 152 |
+
camera_edge_color = cam_color[i]
|
| 153 |
+
else:
|
| 154 |
+
camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
|
| 155 |
+
add_scene_cam(scene, pose_c2w, camera_edge_color,
|
| 156 |
+
None if transparent_cams else imgs[i], focals[i],
|
| 157 |
+
imsize=imgs[i].shape[1::-1], screen_width=cam_size)
|
| 158 |
+
|
| 159 |
+
rot = np.eye(4)
|
| 160 |
+
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
|
| 161 |
+
scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
|
| 162 |
+
if not silent:
|
| 163 |
+
print('(exporting 3D scene to', outfile, ')')
|
| 164 |
+
|
| 165 |
+
scene.export(file_obj=outfile)
|
| 166 |
+
return outfile
|
| 167 |
+
|
| 168 |
+
@spaces.GPU(duration=180)
|
| 169 |
+
def local_get_reconstructed_scene(inputfiles, min_conf_thr, cam_size):
|
| 170 |
+
|
| 171 |
+
batch = load_images(inputfiles, size=image_size, verbose=not silent)
|
| 172 |
+
images = [gt['img'] for gt in batch]
|
| 173 |
+
images = torch.cat(images, dim=0)
|
| 174 |
+
images = images / 2 + 0.5
|
| 175 |
+
index = generate_rank_by_dino(images, backbone, query_frame_num=1)
|
| 176 |
+
sorted_order = calculate_index_mappings(index, len(images), device=device)
|
| 177 |
+
sorted_batch = []
|
| 178 |
+
for i in range(len(batch)):
|
| 179 |
+
sorted_batch.append(batch[sorted_order[i]])
|
| 180 |
+
batch = sorted_batch
|
| 181 |
+
ignore_keys = set(['depthmap', 'dataset', 'label', 'instance', 'idx', 'rng', 'vid'])
|
| 182 |
+
ignore_dtype_keys = set(['true_shape', 'camera_pose', 'pts3d', 'fxfycxcy', 'img_org', 'camera_intrinsics', 'depthmap', 'depth_anything', 'fxfycxcy_unorm'])
|
| 183 |
+
dtype = torch.bfloat16
|
| 184 |
+
for view in batch:
|
| 185 |
+
for name in view.keys(): # pseudo_focal
|
| 186 |
+
if name in ignore_keys:
|
| 187 |
+
continue
|
| 188 |
+
if isinstance(view[name], torch.Tensor):
|
| 189 |
+
view[name] = view[name].to(device, non_blocking=True)
|
| 190 |
+
else:
|
| 191 |
+
view[name] = torch.tensor(view[name]).to(device, non_blocking=True)
|
| 192 |
+
if view[name].dtype == torch.float32 and name not in ignore_dtype_keys:
|
| 193 |
+
view[name] = view[name].to(dtype)
|
| 194 |
+
view1 = batch[:1]
|
| 195 |
+
view2 = batch[1:]
|
| 196 |
+
with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
|
| 197 |
+
pred1, pred2, pred_cameras = model(view1, view2, True, dtype)
|
| 198 |
+
pts3d = pred2['pts3d']
|
| 199 |
+
conf = pred2['conf']
|
| 200 |
+
pts3d = pts3d.detach().cpu()
|
| 201 |
+
B, N, H, W, _ = pts3d.shape
|
| 202 |
+
thres = torch.quantile(conf.flatten(2,3), min_conf_thr, dim=-1)[0]
|
| 203 |
+
masks_conf = conf > thres[None, :, None, None]
|
| 204 |
+
masks_conf = masks_conf.cpu()
|
| 205 |
+
|
| 206 |
+
images = [view['img'] for view in view1+view2]
|
| 207 |
+
shape = torch.stack([view['true_shape'] for view in view1+view2], dim=1).detach().cpu().numpy()
|
| 208 |
+
images = torch.stack(images,1).float().permute(0,1,3,4,2).detach().cpu().numpy()
|
| 209 |
+
images = images / 2 + 0.5
|
| 210 |
+
images = images.reshape(B, N, H, W, 3)
|
| 211 |
+
# estimate focal length
|
| 212 |
+
images = images[0]
|
| 213 |
+
pts3d = pts3d[0]
|
| 214 |
+
masks_conf = masks_conf[0]
|
| 215 |
+
xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) # homogeneous (x,y,1)
|
| 216 |
+
pp = torch.tensor((W/2, H/2)).to(xy_over_z)
|
| 217 |
+
pixels = xy_grid(W, H, device=xy_over_z.device).view(1, -1, 2) - pp.view(-1, 1, 2) # B,HW,2
|
| 218 |
+
u, v = pixels[:1].unbind(dim=-1)
|
| 219 |
+
x, y, z = pts3d[:1].reshape(-1,3).unbind(dim=-1)
|
| 220 |
+
fx_votes = (u * z) / x
|
| 221 |
+
fy_votes = (v * z) / y
|
| 222 |
+
# assume square pixels, hence same focal for X and Y
|
| 223 |
+
f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
|
| 224 |
+
focal = torch.nanmedian(f_votes, dim=-1).values
|
| 225 |
+
focal = focal.item()
|
| 226 |
+
pts3d = pts3d.numpy()
|
| 227 |
+
# use PNP to estimate camera poses
|
| 228 |
+
pred_poses = []
|
| 229 |
+
for i in range(pts3d.shape[0]):
|
| 230 |
+
shape_input_each = shape[:, i]
|
| 231 |
+
mesh_grid = xy_grid(shape_input_each[0,1], shape_input_each[0,0])
|
| 232 |
+
cur_inlier = conf[0,i] > torch.quantile(conf[0,i], 0.6)
|
| 233 |
+
cur_inlier = cur_inlier.detach().cpu().numpy()
|
| 234 |
+
ransac_thres = 0.5
|
| 235 |
+
confidence = 0.9999
|
| 236 |
+
iterationsCount = 10_000
|
| 237 |
+
cur_pts3d = pts3d[i]
|
| 238 |
+
K = np.float32([(focal, 0, W/2), (0, focal, H/2), (0, 0, 1)])
|
| 239 |
+
success, r_pose, t_pose, _ = cv2.solvePnPRansac(cur_pts3d[cur_inlier].astype(np.float64), mesh_grid[cur_inlier].astype(np.float64), K, None,
|
| 240 |
+
flags=cv2.SOLVEPNP_SQPNP,
|
| 241 |
+
iterationsCount=iterationsCount,
|
| 242 |
+
reprojectionError=1,
|
| 243 |
+
confidence=confidence)
|
| 244 |
+
r_pose = cv2.Rodrigues(r_pose)[0]
|
| 245 |
+
RT = np.r_[np.c_[r_pose, t_pose], [(0,0,0,1)]]
|
| 246 |
+
cam2world = np.linalg.inv(RT)
|
| 247 |
+
pred_poses.append(cam2world)
|
| 248 |
+
pred_poses = np.stack(pred_poses, axis=0)
|
| 249 |
+
pred_poses = torch.tensor(pred_poses)
|
| 250 |
+
# use knn to clean the point cloud
|
| 251 |
+
K = 10
|
| 252 |
+
points = torch.tensor(pts3d.reshape(1,-1,3)).cuda()
|
| 253 |
+
knn = knn_points(points, points, K=K)
|
| 254 |
+
dists = knn.dists
|
| 255 |
+
mean_dists = dists.mean(dim=-1)
|
| 256 |
+
masks_dist = mean_dists < torch.quantile(mean_dists.reshape(-1), 0.95)
|
| 257 |
+
masks_dist = masks_dist.detach().cpu().numpy()
|
| 258 |
+
masks_conf = (masks_conf > 0) & masks_dist.reshape(-1,H,W)
|
| 259 |
+
masks_conf = masks_conf > 0
|
| 260 |
+
outdir = tempfile.mkdtemp(suffix='_FLARE_gradio_demo')
|
| 261 |
+
os.makedirs(outdir, exist_ok=True)
|
| 262 |
+
focals = [focal] * len(images)
|
| 263 |
+
outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)
|
| 264 |
+
|
| 265 |
+
_convert_scene_output_to_glb(outfile_name, images, pts3d, masks_conf, focals, pred_poses, as_pointcloud=True,
|
| 266 |
+
transparent_cams=False, cam_size=cam_size, silent=silent)
|
| 267 |
+
return filestate, outfile_name
|
| 268 |
+
|
| 269 |
+
css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
|
| 270 |
+
title = "FLARE Demo"
|
| 271 |
+
with gradio.Blocks(css=css, title=title, delete_cache=(gradio_delete_cache, gradio_delete_cache)) as demo:
|
| 272 |
+
filestate = gradio.State(None)
|
| 273 |
+
gradio.HTML('<h2 style="text-align: center;">3D Reconstruction with FLARE</h2>')
|
| 274 |
+
with gradio.Column():
|
| 275 |
+
inputfiles = gradio.File(file_count="multiple")
|
| 276 |
+
snapshot = gradio.Image(None, visible=False)
|
| 277 |
+
with gradio.Row():
|
| 278 |
+
# adjust the confidence threshold
|
| 279 |
+
min_conf_thr = gradio.Slider(label="min_conf_thr", value=0.1, minimum=0.0, maximum=1, step=0.05)
|
| 280 |
+
# adjust the camera size in the output pointcloud
|
| 281 |
+
cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
|
| 282 |
+
run_btn = gradio.Button("Run")
|
| 283 |
+
outmodel = gradio.Model3D()
|
| 284 |
+
|
| 285 |
+
# events
|
| 286 |
+
run_btn.click(fn=local_get_reconstructed_scene,
|
| 287 |
+
inputs=[inputfiles, min_conf_thr, cam_size],
|
| 288 |
+
outputs=[filestate, outmodel])
|
| 289 |
+
|
| 290 |
+
demo.launch(show_error=True, share=None, server_name=None, server_port=None)
|
| 291 |
+
shutil.rmtree(tmpdirname)
|
dust3r/croco/datasets/__init__.py
ADDED
|
File without changes
|
dust3r/croco/datasets/crops/README.MD
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Generation of crops from the real datasets
|
| 2 |
+
|
| 3 |
+
The instructions below allow to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL.
|
| 4 |
+
|
| 5 |
+
### Download the metadata of the crops to generate
|
| 6 |
+
|
| 7 |
+
First, download the metadata and put them in `./data/`:
|
| 8 |
+
```
|
| 9 |
+
mkdir -p data
|
| 10 |
+
cd data/
|
| 11 |
+
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip
|
| 12 |
+
unzip crop_metadata.zip
|
| 13 |
+
rm crop_metadata.zip
|
| 14 |
+
cd ..
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
### Prepare the original datasets
|
| 18 |
+
|
| 19 |
+
Second, download the original datasets in `./data/original_datasets/`.
|
| 20 |
+
```
|
| 21 |
+
mkdir -p data/original_datasets
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
##### ARKitScenes
|
| 25 |
+
|
| 26 |
+
Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`.
|
| 27 |
+
The resulting file structure should be like:
|
| 28 |
+
```
|
| 29 |
+
./data/original_datasets/ARKitScenes/
|
| 30 |
+
└───Training
|
| 31 |
+
└───40753679
|
| 32 |
+
│ │ ultrawide
|
| 33 |
+
│ │ ...
|
| 34 |
+
└───40753686
|
| 35 |
+
│
|
| 36 |
+
...
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
##### MegaDepth
|
| 40 |
+
|
| 41 |
+
Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`.
|
| 42 |
+
The resulting file structure should be like:
|
| 43 |
+
|
| 44 |
+
```
|
| 45 |
+
./data/original_datasets/MegaDepth/
|
| 46 |
+
└───0000
|
| 47 |
+
│ └───images
|
| 48 |
+
│ │ │ 1000557903_87fa96b8a4_o.jpg
|
| 49 |
+
│ │ └ ...
|
| 50 |
+
│ └─── ...
|
| 51 |
+
└───0001
|
| 52 |
+
│ │
|
| 53 |
+
│ └ ...
|
| 54 |
+
└─── ...
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
##### 3DStreetView
|
| 58 |
+
|
| 59 |
+
Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`.
|
| 60 |
+
The resulting file structure should be like:
|
| 61 |
+
|
| 62 |
+
```
|
| 63 |
+
./data/original_datasets/3DStreetView/
|
| 64 |
+
└───dataset_aligned
|
| 65 |
+
│ └───0002
|
| 66 |
+
│ │ │ 0000002_0000001_0000002_0000001.jpg
|
| 67 |
+
│ │ └ ...
|
| 68 |
+
│ └─── ...
|
| 69 |
+
└───dataset_unaligned
|
| 70 |
+
│ └───0003
|
| 71 |
+
│ │ │ 0000003_0000001_0000002_0000001.jpg
|
| 72 |
+
│ │ └ ...
|
| 73 |
+
│ └─── ...
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
##### IndoorVL
|
| 77 |
+
|
| 78 |
+
Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture).
|
| 79 |
+
|
| 80 |
+
```
|
| 81 |
+
pip install kapture
|
| 82 |
+
mkdir -p ./data/original_datasets/IndoorVL
|
| 83 |
+
cd ./data/original_datasets/IndoorVL
|
| 84 |
+
kapture_download_dataset.py update
|
| 85 |
+
kapture_download_dataset.py install "HyundaiDepartmentStore_*"
|
| 86 |
+
kapture_download_dataset.py install "GangnamStation_*"
|
| 87 |
+
cd -
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### Extract the crops
|
| 91 |
+
|
| 92 |
+
Now, extract the crops for each of the dataset:
|
| 93 |
+
```
|
| 94 |
+
for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL;
|
| 95 |
+
do
|
| 96 |
+
python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500;
|
| 97 |
+
done
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
##### Note for IndoorVL
|
| 101 |
+
|
| 102 |
+
Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper.
|
| 103 |
+
To account for it in terms of number of pre-training iterations, the pre-training command in this repository uses 125 training epochs including 12 warm-up epochs and learning rate cosine schedule of 250, instead of 100, 10 and 200 respectively.
|
| 104 |
+
The impact on the performance is negligible.
|
dust3r/croco/datasets/crops/extract_crops_from_images.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Extracting crops for pre-training
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import argparse
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import functools
|
| 13 |
+
from multiprocessing import Pool
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def arg_parser():
|
| 18 |
+
parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list')
|
| 19 |
+
|
| 20 |
+
parser.add_argument('--crops', type=str, required=True, help='crop file')
|
| 21 |
+
parser.add_argument('--root-dir', type=str, required=True, help='root directory')
|
| 22 |
+
parser.add_argument('--output-dir', type=str, required=True, help='output directory')
|
| 23 |
+
parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
|
| 24 |
+
parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads')
|
| 25 |
+
parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories')
|
| 26 |
+
parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
|
| 27 |
+
return parser
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def main(args):
|
| 31 |
+
listing_path = os.path.join(args.output_dir, 'listing.txt')
|
| 32 |
+
|
| 33 |
+
print(f'Loading list of crops ... ({args.nthread} threads)')
|
| 34 |
+
crops, num_crops_to_generate = load_crop_file(args.crops)
|
| 35 |
+
|
| 36 |
+
print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
|
| 37 |
+
num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels)
|
| 38 |
+
num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels))
|
| 39 |
+
|
| 40 |
+
jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
|
| 41 |
+
del crops
|
| 42 |
+
|
| 43 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 44 |
+
mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map
|
| 45 |
+
call = functools.partial(save_image_crops, args)
|
| 46 |
+
|
| 47 |
+
print(f"Generating cropped images to {args.output_dir} ...")
|
| 48 |
+
with open(listing_path, 'w') as listing:
|
| 49 |
+
listing.write('# pair_path\n')
|
| 50 |
+
for results in tqdm(mmap(call, jobs), total=len(jobs)):
|
| 51 |
+
for path in results:
|
| 52 |
+
listing.write(f'{path}\n')
|
| 53 |
+
print('Finished writing listing to', listing_path)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def load_crop_file(path):
|
| 57 |
+
data = open(path).read().splitlines()
|
| 58 |
+
pairs = []
|
| 59 |
+
num_crops_to_generate = 0
|
| 60 |
+
for line in tqdm(data):
|
| 61 |
+
if line.startswith('#'):
|
| 62 |
+
continue
|
| 63 |
+
line = line.split(', ')
|
| 64 |
+
if len(line) < 8:
|
| 65 |
+
img1, img2, rotation = line
|
| 66 |
+
pairs.append((img1, img2, int(rotation), []))
|
| 67 |
+
else:
|
| 68 |
+
l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
|
| 69 |
+
rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
|
| 70 |
+
pairs[-1][-1].append((rect1, rect2))
|
| 71 |
+
num_crops_to_generate += 1
|
| 72 |
+
return pairs, num_crops_to_generate
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
|
| 76 |
+
jobs = []
|
| 77 |
+
powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))]
|
| 78 |
+
|
| 79 |
+
def get_path(idx):
|
| 80 |
+
idx_array = []
|
| 81 |
+
d = idx
|
| 82 |
+
for level in range(num_levels - 1):
|
| 83 |
+
idx_array.append(idx // powers[level])
|
| 84 |
+
idx = idx % powers[level]
|
| 85 |
+
idx_array.append(d)
|
| 86 |
+
return '/'.join(map(lambda x: hex(x)[2:], idx_array))
|
| 87 |
+
|
| 88 |
+
idx = 0
|
| 89 |
+
for pair_data in tqdm(pairs):
|
| 90 |
+
img1, img2, rotation, crops = pair_data
|
| 91 |
+
if -60 <= rotation and rotation <= 60:
|
| 92 |
+
rotation = 0 # most likely not a true rotation
|
| 93 |
+
paths = [get_path(idx + k) for k in range(len(crops))]
|
| 94 |
+
idx += len(crops)
|
| 95 |
+
jobs.append(((img1, img2), rotation, crops, paths))
|
| 96 |
+
return jobs
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def load_image(path):
|
| 100 |
+
try:
|
| 101 |
+
return Image.open(path).convert('RGB')
|
| 102 |
+
except Exception as e:
|
| 103 |
+
print('skipping', path, e)
|
| 104 |
+
raise OSError()
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def save_image_crops(args, data):
|
| 108 |
+
# load images
|
| 109 |
+
img_pair, rot, crops, paths = data
|
| 110 |
+
try:
|
| 111 |
+
img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
|
| 112 |
+
except OSError as e:
|
| 113 |
+
return []
|
| 114 |
+
|
| 115 |
+
def area(sz):
|
| 116 |
+
return sz[0] * sz[1]
|
| 117 |
+
|
| 118 |
+
tgt_size = (args.imsize, args.imsize)
|
| 119 |
+
|
| 120 |
+
def prepare_crop(img, rect, rot=0):
|
| 121 |
+
# actual crop
|
| 122 |
+
img = img.crop(rect)
|
| 123 |
+
|
| 124 |
+
# resize to desired size
|
| 125 |
+
interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC
|
| 126 |
+
img = img.resize(tgt_size, resample=interp)
|
| 127 |
+
|
| 128 |
+
# rotate the image
|
| 129 |
+
rot90 = (round(rot/90) % 4) * 90
|
| 130 |
+
if rot90 == 90:
|
| 131 |
+
img = img.transpose(Image.Transpose.ROTATE_90)
|
| 132 |
+
elif rot90 == 180:
|
| 133 |
+
img = img.transpose(Image.Transpose.ROTATE_180)
|
| 134 |
+
elif rot90 == 270:
|
| 135 |
+
img = img.transpose(Image.Transpose.ROTATE_270)
|
| 136 |
+
return img
|
| 137 |
+
|
| 138 |
+
results = []
|
| 139 |
+
for (rect1, rect2), path in zip(crops, paths):
|
| 140 |
+
crop1 = prepare_crop(img1, rect1)
|
| 141 |
+
crop2 = prepare_crop(img2, rect2, rot)
|
| 142 |
+
|
| 143 |
+
fullpath1 = os.path.join(args.output_dir, path+'_1.jpg')
|
| 144 |
+
fullpath2 = os.path.join(args.output_dir, path+'_2.jpg')
|
| 145 |
+
os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
|
| 146 |
+
|
| 147 |
+
assert not os.path.isfile(fullpath1), fullpath1
|
| 148 |
+
assert not os.path.isfile(fullpath2), fullpath2
|
| 149 |
+
crop1.save(fullpath1)
|
| 150 |
+
crop2.save(fullpath2)
|
| 151 |
+
results.append(path)
|
| 152 |
+
|
| 153 |
+
return results
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
if __name__ == '__main__':
|
| 157 |
+
args = arg_parser().parse_args()
|
| 158 |
+
main(args)
|
| 159 |
+
|
dust3r/croco/datasets/habitat_sim/README.MD
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Generation of synthetic image pairs using Habitat-Sim
|
| 2 |
+
|
| 3 |
+
These instructions allow to generate pre-training pairs from the Habitat simulator.
|
| 4 |
+
As we did not save metadata of the pairs used in the original paper, they are not strictly the same, but these data use the same setting and are equivalent.
|
| 5 |
+
|
| 6 |
+
### Download Habitat-Sim scenes
|
| 7 |
+
Download Habitat-Sim scenes:
|
| 8 |
+
- Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
|
| 9 |
+
- We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets.
|
| 10 |
+
- Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or update manually paths in `paths.py`.
|
| 11 |
+
```
|
| 12 |
+
./data/
|
| 13 |
+
└──habitat-sim-data/
|
| 14 |
+
└──scene_datasets/
|
| 15 |
+
├──hm3d/
|
| 16 |
+
├──gibson/
|
| 17 |
+
├──habitat-test-scenes/
|
| 18 |
+
├──replica_cad_baked_lighting/
|
| 19 |
+
├──replica_cad/
|
| 20 |
+
├──ReplicaDataset/
|
| 21 |
+
└──scannet/
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
### Image pairs generation
|
| 25 |
+
We provide metadata to generate reproducible images pairs for pretraining and validation.
|
| 26 |
+
Experiments described in the paper used similar data, but whose generation was not reproducible at the time.
|
| 27 |
+
|
| 28 |
+
Specifications:
|
| 29 |
+
- 256x256 resolution images, with 60 degrees field of view .
|
| 30 |
+
- Up to 1000 image pairs per scene.
|
| 31 |
+
- Number of scenes considered/number of images pairs per dataset:
|
| 32 |
+
- Scannet: 1097 scenes / 985 209 pairs
|
| 33 |
+
- HM3D:
|
| 34 |
+
- hm3d/train: 800 / 800k pairs
|
| 35 |
+
- hm3d/val: 100 scenes / 100k pairs
|
| 36 |
+
- hm3d/minival: 10 scenes / 10k pairs
|
| 37 |
+
- habitat-test-scenes: 3 scenes / 3k pairs
|
| 38 |
+
- replica_cad_baked_lighting: 13 scenes / 13k pairs
|
| 39 |
+
|
| 40 |
+
- Scenes from hm3d/val and hm3d/minival pairs were not used for the pre-training but kept for validation purposes.
|
| 41 |
+
|
| 42 |
+
Download metadata and extract it:
|
| 43 |
+
```bash
|
| 44 |
+
mkdir -p data/habitat_release_metadata/
|
| 45 |
+
cd data/habitat_release_metadata/
|
| 46 |
+
wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz
|
| 47 |
+
tar -xvf multiview_habitat_metadata.tar.gz
|
| 48 |
+
cd ../..
|
| 49 |
+
# Location of the metadata
|
| 50 |
+
METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata"
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
Generate image pairs from metadata:
|
| 54 |
+
- The following command will print a list of commandlines to generate image pairs for each scene:
|
| 55 |
+
```bash
|
| 56 |
+
# Target output directory
|
| 57 |
+
PAIRS_DATASET_DIR="./data/habitat_release/"
|
| 58 |
+
python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR
|
| 59 |
+
```
|
| 60 |
+
- One can launch multiple of such commands in parallel e.g. using GNU Parallel:
|
| 61 |
+
```bash
|
| 62 |
+
python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## Metadata generation
|
| 66 |
+
|
| 67 |
+
Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible:
|
| 68 |
+
```bash
|
| 69 |
+
# Print commandlines to generate image pairs from the different scenes available.
|
| 70 |
+
PAIRS_DATASET_DIR=MY_CUSTOM_PATH
|
| 71 |
+
python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR
|
| 72 |
+
|
| 73 |
+
# Once a dataset is generated, pack metadata files for reproducibility.
|
| 74 |
+
METADATA_DIR=MY_CUSTON_PATH
|
| 75 |
+
python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR
|
| 76 |
+
```
|
dust3r/croco/datasets/habitat_sim/__init__.py
ADDED
|
File without changes
|
dust3r/croco/datasets/habitat_sim/generate_from_metadata.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Script to generate image pairs for a given scene reproducing poses provided in a metadata file.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator
|
| 9 |
+
from datasets.habitat_sim.paths import SCENES_DATASET
|
| 10 |
+
import argparse
|
| 11 |
+
import quaternion
|
| 12 |
+
import PIL.Image
|
| 13 |
+
import cv2
|
| 14 |
+
import json
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
def generate_multiview_images_from_metadata(metadata_filename,
|
| 18 |
+
output_dir,
|
| 19 |
+
overload_params = dict(),
|
| 20 |
+
scene_datasets_paths=None,
|
| 21 |
+
exist_ok=False):
|
| 22 |
+
"""
|
| 23 |
+
Generate images from a metadata file for reproducibility purposes.
|
| 24 |
+
"""
|
| 25 |
+
# Reorder paths by decreasing label length, to avoid collisions when testing if a string by such label
|
| 26 |
+
if scene_datasets_paths is not None:
|
| 27 |
+
scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key= lambda x: len(x[0]), reverse=True))
|
| 28 |
+
|
| 29 |
+
with open(metadata_filename, 'r') as f:
|
| 30 |
+
input_metadata = json.load(f)
|
| 31 |
+
metadata = dict()
|
| 32 |
+
for key, value in input_metadata.items():
|
| 33 |
+
# Optionally replace some paths
|
| 34 |
+
if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
|
| 35 |
+
if scene_datasets_paths is not None:
|
| 36 |
+
for dataset_label, dataset_path in scene_datasets_paths.items():
|
| 37 |
+
if value.startswith(dataset_label):
|
| 38 |
+
value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label)))
|
| 39 |
+
break
|
| 40 |
+
metadata[key] = value
|
| 41 |
+
|
| 42 |
+
# Overload some parameters
|
| 43 |
+
for key, value in overload_params.items():
|
| 44 |
+
metadata[key] = value
|
| 45 |
+
|
| 46 |
+
generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))])
|
| 47 |
+
generate_depth = metadata["generate_depth"]
|
| 48 |
+
|
| 49 |
+
os.makedirs(output_dir, exist_ok=exist_ok)
|
| 50 |
+
|
| 51 |
+
generator = MultiviewHabitatSimGenerator(**generation_entries)
|
| 52 |
+
|
| 53 |
+
# Generate views
|
| 54 |
+
for idx_label, data in tqdm(metadata['multiviews'].items()):
|
| 55 |
+
positions = data["positions"]
|
| 56 |
+
orientations = data["orientations"]
|
| 57 |
+
n = len(positions)
|
| 58 |
+
for oidx in range(n):
|
| 59 |
+
observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx]))
|
| 60 |
+
observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
|
| 61 |
+
# Color image saved using PIL
|
| 62 |
+
img = PIL.Image.fromarray(observation['color'][:,:,:3])
|
| 63 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
|
| 64 |
+
img.save(filename)
|
| 65 |
+
if generate_depth:
|
| 66 |
+
# Depth image as EXR file
|
| 67 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
|
| 68 |
+
cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
|
| 69 |
+
# Camera parameters
|
| 70 |
+
camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
|
| 71 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
|
| 72 |
+
with open(filename, "w") as f:
|
| 73 |
+
json.dump(camera_params, f)
|
| 74 |
+
# Save metadata
|
| 75 |
+
with open(os.path.join(output_dir, "metadata.json"), "w") as f:
|
| 76 |
+
json.dump(metadata, f)
|
| 77 |
+
|
| 78 |
+
generator.close()
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
|
| 81 |
+
parser = argparse.ArgumentParser()
|
| 82 |
+
parser.add_argument("--metadata_filename", required=True)
|
| 83 |
+
parser.add_argument("--output_dir", required=True)
|
| 84 |
+
args = parser.parse_args()
|
| 85 |
+
|
| 86 |
+
generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename,
|
| 87 |
+
output_dir=args.output_dir,
|
| 88 |
+
scene_datasets_paths=SCENES_DATASET,
|
| 89 |
+
overload_params=dict(),
|
| 90 |
+
exist_ok=True)
|
| 91 |
+
|
| 92 |
+
|
dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Script generating commandlines to generate image pairs from metadata files.
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import glob
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import argparse
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
|
| 13 |
+
parser = argparse.ArgumentParser()
|
| 14 |
+
parser.add_argument("--input_dir", required=True)
|
| 15 |
+
parser.add_argument("--output_dir", required=True)
|
| 16 |
+
parser.add_argument("--prefix", default="", help="Commanline prefix, useful e.g. to setup environment.")
|
| 17 |
+
args = parser.parse_args()
|
| 18 |
+
|
| 19 |
+
input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True)
|
| 20 |
+
|
| 21 |
+
for metadata_filename in tqdm(input_metadata_filenames):
|
| 22 |
+
output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir))
|
| 23 |
+
# Do not process the scene if the metadata file already exists
|
| 24 |
+
if os.path.exists(os.path.join(output_dir, "metadata.json")):
|
| 25 |
+
continue
|
| 26 |
+
commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
|
| 27 |
+
print(commandline)
|
dust3r/croco/datasets/habitat_sim/generate_multiview_images.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import argparse
|
| 7 |
+
import PIL.Image
|
| 8 |
+
import numpy as np
|
| 9 |
+
import json
|
| 10 |
+
from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator, NoNaviguableSpaceError
|
| 11 |
+
from datasets.habitat_sim.paths import list_scenes_available
|
| 12 |
+
import cv2
|
| 13 |
+
import quaternion
|
| 14 |
+
import shutil
|
| 15 |
+
|
| 16 |
+
def generate_multiview_images_for_scene(scene_dataset_config_file,
|
| 17 |
+
scene,
|
| 18 |
+
navmesh,
|
| 19 |
+
output_dir,
|
| 20 |
+
views_count,
|
| 21 |
+
size,
|
| 22 |
+
exist_ok=False,
|
| 23 |
+
generate_depth=False,
|
| 24 |
+
**kwargs):
|
| 25 |
+
"""
|
| 26 |
+
Generate tuples of overlapping views for a given scene.
|
| 27 |
+
generate_depth: generate depth images and camera parameters.
|
| 28 |
+
"""
|
| 29 |
+
if os.path.exists(output_dir) and not exist_ok:
|
| 30 |
+
print(f"Scene {scene}: data already generated. Ignoring generation.")
|
| 31 |
+
return
|
| 32 |
+
try:
|
| 33 |
+
print(f"Scene {scene}: {size} multiview acquisitions to generate...")
|
| 34 |
+
os.makedirs(output_dir, exist_ok=exist_ok)
|
| 35 |
+
|
| 36 |
+
metadata_filename = os.path.join(output_dir, "metadata.json")
|
| 37 |
+
|
| 38 |
+
metadata_template = dict(scene_dataset_config_file=scene_dataset_config_file,
|
| 39 |
+
scene=scene,
|
| 40 |
+
navmesh=navmesh,
|
| 41 |
+
views_count=views_count,
|
| 42 |
+
size=size,
|
| 43 |
+
generate_depth=generate_depth,
|
| 44 |
+
**kwargs)
|
| 45 |
+
metadata_template["multiviews"] = dict()
|
| 46 |
+
|
| 47 |
+
if os.path.exists(metadata_filename):
|
| 48 |
+
print("Metadata file already exists:", metadata_filename)
|
| 49 |
+
print("Loading already generated metadata file...")
|
| 50 |
+
with open(metadata_filename, "r") as f:
|
| 51 |
+
metadata = json.load(f)
|
| 52 |
+
|
| 53 |
+
for key in metadata_template.keys():
|
| 54 |
+
if key != "multiviews":
|
| 55 |
+
assert metadata_template[key] == metadata[key], f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}."
|
| 56 |
+
else:
|
| 57 |
+
print("No temporary file found. Starting generation from scratch...")
|
| 58 |
+
metadata = metadata_template
|
| 59 |
+
|
| 60 |
+
starting_id = len(metadata["multiviews"])
|
| 61 |
+
print(f"Starting generation from index {starting_id}/{size}...")
|
| 62 |
+
if starting_id >= size:
|
| 63 |
+
print("Generation already done.")
|
| 64 |
+
return
|
| 65 |
+
|
| 66 |
+
generator = MultiviewHabitatSimGenerator(scene_dataset_config_file=scene_dataset_config_file,
|
| 67 |
+
scene=scene,
|
| 68 |
+
navmesh=navmesh,
|
| 69 |
+
views_count = views_count,
|
| 70 |
+
size = size,
|
| 71 |
+
**kwargs)
|
| 72 |
+
|
| 73 |
+
for idx in tqdm(range(starting_id, size)):
|
| 74 |
+
# Generate / re-generate the observations
|
| 75 |
+
try:
|
| 76 |
+
data = generator[idx]
|
| 77 |
+
observations = data["observations"]
|
| 78 |
+
positions = data["positions"]
|
| 79 |
+
orientations = data["orientations"]
|
| 80 |
+
|
| 81 |
+
idx_label = f"{idx:08}"
|
| 82 |
+
for oidx, observation in enumerate(observations):
|
| 83 |
+
observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
|
| 84 |
+
# Color image saved using PIL
|
| 85 |
+
img = PIL.Image.fromarray(observation['color'][:,:,:3])
|
| 86 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
|
| 87 |
+
img.save(filename)
|
| 88 |
+
if generate_depth:
|
| 89 |
+
# Depth image as EXR file
|
| 90 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
|
| 91 |
+
cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
|
| 92 |
+
# Camera parameters
|
| 93 |
+
camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
|
| 94 |
+
filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
|
| 95 |
+
with open(filename, "w") as f:
|
| 96 |
+
json.dump(camera_params, f)
|
| 97 |
+
metadata["multiviews"][idx_label] = {"positions": positions.tolist(),
|
| 98 |
+
"orientations": orientations.tolist(),
|
| 99 |
+
"covisibility_ratios": data["covisibility_ratios"].tolist(),
|
| 100 |
+
"valid_fractions": data["valid_fractions"].tolist(),
|
| 101 |
+
"pairwise_visibility_ratios": data["pairwise_visibility_ratios"].tolist()}
|
| 102 |
+
except RecursionError:
|
| 103 |
+
print("Recursion error: unable to sample observations for this scene. We will stop there.")
|
| 104 |
+
break
|
| 105 |
+
|
| 106 |
+
# Regularly save a temporary metadata file, in case we need to restart the generation
|
| 107 |
+
if idx % 10 == 0:
|
| 108 |
+
with open(metadata_filename, "w") as f:
|
| 109 |
+
json.dump(metadata, f)
|
| 110 |
+
|
| 111 |
+
# Save metadata
|
| 112 |
+
with open(metadata_filename, "w") as f:
|
| 113 |
+
json.dump(metadata, f)
|
| 114 |
+
|
| 115 |
+
generator.close()
|
| 116 |
+
except NoNaviguableSpaceError:
|
| 117 |
+
pass
|
| 118 |
+
|
| 119 |
+
def create_commandline(scene_data, generate_depth, exist_ok=False):
|
| 120 |
+
"""
|
| 121 |
+
Create a commandline string to generate a scene.
|
| 122 |
+
"""
|
| 123 |
+
def my_formatting(val):
|
| 124 |
+
if val is None or val == "":
|
| 125 |
+
return '""'
|
| 126 |
+
else:
|
| 127 |
+
return val
|
| 128 |
+
commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)}
|
| 129 |
+
--scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)}
|
| 130 |
+
--navmesh {my_formatting(scene_data.navmesh)}
|
| 131 |
+
--output_dir {my_formatting(scene_data.output_dir)}
|
| 132 |
+
--generate_depth {int(generate_depth)}
|
| 133 |
+
--exist_ok {int(exist_ok)}
|
| 134 |
+
"""
|
| 135 |
+
commandline = " ".join(commandline.split())
|
| 136 |
+
return commandline
|
| 137 |
+
|
| 138 |
+
if __name__ == "__main__":
|
| 139 |
+
os.umask(2)
|
| 140 |
+
|
| 141 |
+
parser = argparse.ArgumentParser(description="""Example of use -- listing commands to generate data for scenes available:
|
| 142 |
+
> python datasets/habitat_sim/generate_multiview_habitat_images.py --list_commands
|
| 143 |
+
""")
|
| 144 |
+
|
| 145 |
+
parser.add_argument("--output_dir", type=str, required=True)
|
| 146 |
+
parser.add_argument("--list_commands", action='store_true', help="list commandlines to run if true")
|
| 147 |
+
parser.add_argument("--scene", type=str, default="")
|
| 148 |
+
parser.add_argument("--scene_dataset_config_file", type=str, default="")
|
| 149 |
+
parser.add_argument("--navmesh", type=str, default="")
|
| 150 |
+
|
| 151 |
+
parser.add_argument("--generate_depth", type=int, default=1)
|
| 152 |
+
parser.add_argument("--exist_ok", type=int, default=0)
|
| 153 |
+
|
| 154 |
+
kwargs = dict(resolution=(256,256), hfov=60, views_count = 2, size=1000)
|
| 155 |
+
|
| 156 |
+
args = parser.parse_args()
|
| 157 |
+
generate_depth=bool(args.generate_depth)
|
| 158 |
+
exist_ok = bool(args.exist_ok)
|
| 159 |
+
|
| 160 |
+
if args.list_commands:
|
| 161 |
+
# Listing scenes available...
|
| 162 |
+
scenes_data = list_scenes_available(base_output_dir=args.output_dir)
|
| 163 |
+
|
| 164 |
+
for scene_data in scenes_data:
|
| 165 |
+
print(create_commandline(scene_data, generate_depth=generate_depth, exist_ok=exist_ok))
|
| 166 |
+
else:
|
| 167 |
+
if args.scene == "" or args.output_dir == "":
|
| 168 |
+
print("Missing scene or output dir argument!")
|
| 169 |
+
print(parser.format_help())
|
| 170 |
+
else:
|
| 171 |
+
generate_multiview_images_for_scene(scene=args.scene,
|
| 172 |
+
scene_dataset_config_file = args.scene_dataset_config_file,
|
| 173 |
+
navmesh = args.navmesh,
|
| 174 |
+
output_dir = args.output_dir,
|
| 175 |
+
exist_ok=exist_ok,
|
| 176 |
+
generate_depth=generate_depth,
|
| 177 |
+
**kwargs)
|
dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
import quaternion
|
| 7 |
+
import habitat_sim
|
| 8 |
+
import json
|
| 9 |
+
from sklearn.neighbors import NearestNeighbors
|
| 10 |
+
import cv2
|
| 11 |
+
|
| 12 |
+
# OpenCV to habitat camera convention transformation
|
| 13 |
+
R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0)
|
| 14 |
+
R_HABITAT2OPENCV = R_OPENCV2HABITAT.T
|
| 15 |
+
DEG2RAD = np.pi / 180
|
| 16 |
+
|
| 17 |
+
def compute_camera_intrinsics(height, width, hfov):
|
| 18 |
+
f = width/2 / np.tan(hfov/2 * np.pi/180)
|
| 19 |
+
cu, cv = width/2, height/2
|
| 20 |
+
return f, cu, cv
|
| 21 |
+
|
| 22 |
+
def compute_camera_pose_opencv_convention(camera_position, camera_orientation):
|
| 23 |
+
R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT
|
| 24 |
+
t_cam2world = np.asarray(camera_position)
|
| 25 |
+
return R_cam2world, t_cam2world
|
| 26 |
+
|
| 27 |
+
def compute_pointmap(depthmap, hfov):
|
| 28 |
+
""" Compute a HxWx3 pointmap in camera frame from a HxW depth map."""
|
| 29 |
+
height, width = depthmap.shape
|
| 30 |
+
f, cu, cv = compute_camera_intrinsics(height, width, hfov)
|
| 31 |
+
# Cast depth map to point
|
| 32 |
+
z_cam = depthmap
|
| 33 |
+
u, v = np.meshgrid(range(width), range(height))
|
| 34 |
+
x_cam = (u - cu) / f * z_cam
|
| 35 |
+
y_cam = (v - cv) / f * z_cam
|
| 36 |
+
X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1)
|
| 37 |
+
return X_cam
|
| 38 |
+
|
| 39 |
+
def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation):
|
| 40 |
+
"""Return a 3D point cloud corresponding to valid pixels of the depth map"""
|
| 41 |
+
R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_position, camera_rotation)
|
| 42 |
+
|
| 43 |
+
X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov)
|
| 44 |
+
valid_mask = (X_cam[:,:,2] != 0.0)
|
| 45 |
+
|
| 46 |
+
X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()]
|
| 47 |
+
X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3)
|
| 48 |
+
return X_world
|
| 49 |
+
|
| 50 |
+
def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False):
|
| 51 |
+
"""
|
| 52 |
+
Compute 'overlapping' metrics based on a distance threshold between two point clouds.
|
| 53 |
+
"""
|
| 54 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud2)
|
| 55 |
+
distances, indices = nbrs.kneighbors(pointcloud1)
|
| 56 |
+
intersection1 = np.count_nonzero(distances.flatten() < distance_threshold)
|
| 57 |
+
|
| 58 |
+
data = {"intersection1": intersection1,
|
| 59 |
+
"size1": len(pointcloud1)}
|
| 60 |
+
if compute_symmetric:
|
| 61 |
+
nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud1)
|
| 62 |
+
distances, indices = nbrs.kneighbors(pointcloud2)
|
| 63 |
+
intersection2 = np.count_nonzero(distances.flatten() < distance_threshold)
|
| 64 |
+
data["intersection2"] = intersection2
|
| 65 |
+
data["size2"] = len(pointcloud2)
|
| 66 |
+
|
| 67 |
+
return data
|
| 68 |
+
|
| 69 |
+
def _append_camera_parameters(observation, hfov, camera_location, camera_rotation):
|
| 70 |
+
"""
|
| 71 |
+
Add camera parameters to the observation dictionnary produced by Habitat-Sim
|
| 72 |
+
In-place modifications.
|
| 73 |
+
"""
|
| 74 |
+
R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_location, camera_rotation)
|
| 75 |
+
height, width = observation['depth'].shape
|
| 76 |
+
f, cu, cv = compute_camera_intrinsics(height, width, hfov)
|
| 77 |
+
K = np.asarray([[f, 0, cu],
|
| 78 |
+
[0, f, cv],
|
| 79 |
+
[0, 0, 1.0]])
|
| 80 |
+
observation["camera_intrinsics"] = K
|
| 81 |
+
observation["t_cam2world"] = t_cam2world
|
| 82 |
+
observation["R_cam2world"] = R_cam2world
|
| 83 |
+
|
| 84 |
+
def look_at(eye, center, up, return_cam2world=True):
|
| 85 |
+
"""
|
| 86 |
+
Return camera pose looking at a given center point.
|
| 87 |
+
Analogous of gluLookAt function, using OpenCV camera convention.
|
| 88 |
+
"""
|
| 89 |
+
z = center - eye
|
| 90 |
+
z /= np.linalg.norm(z, axis=-1, keepdims=True)
|
| 91 |
+
y = -up
|
| 92 |
+
y = y - np.sum(y * z, axis=-1, keepdims=True) * z
|
| 93 |
+
y /= np.linalg.norm(y, axis=-1, keepdims=True)
|
| 94 |
+
x = np.cross(y, z, axis=-1)
|
| 95 |
+
|
| 96 |
+
if return_cam2world:
|
| 97 |
+
R = np.stack((x, y, z), axis=-1)
|
| 98 |
+
t = eye
|
| 99 |
+
else:
|
| 100 |
+
# World to camera transformation
|
| 101 |
+
# Transposed matrix
|
| 102 |
+
R = np.stack((x, y, z), axis=-2)
|
| 103 |
+
t = - np.einsum('...ij, ...j', R, eye)
|
| 104 |
+
return R, t
|
| 105 |
+
|
| 106 |
+
def look_at_for_habitat(eye, center, up, return_cam2world=True):
|
| 107 |
+
R, t = look_at(eye, center, up)
|
| 108 |
+
orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T)
|
| 109 |
+
return orientation, t
|
| 110 |
+
|
| 111 |
+
def generate_orientation_noise(pan_range, tilt_range, roll_range):
|
| 112 |
+
return (quaternion.from_rotation_vector(np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP)
|
| 113 |
+
* quaternion.from_rotation_vector(np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT)
|
| 114 |
+
* quaternion.from_rotation_vector(np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT))
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class NoNaviguableSpaceError(RuntimeError):
|
| 118 |
+
def __init__(self, *args):
|
| 119 |
+
super().__init__(*args)
|
| 120 |
+
|
| 121 |
+
class MultiviewHabitatSimGenerator:
|
| 122 |
+
def __init__(self,
|
| 123 |
+
scene,
|
| 124 |
+
navmesh,
|
| 125 |
+
scene_dataset_config_file,
|
| 126 |
+
resolution = (240, 320),
|
| 127 |
+
views_count=2,
|
| 128 |
+
hfov = 60,
|
| 129 |
+
gpu_id = 0,
|
| 130 |
+
size = 10000,
|
| 131 |
+
minimum_covisibility = 0.5,
|
| 132 |
+
transform = None):
|
| 133 |
+
self.scene = scene
|
| 134 |
+
self.navmesh = navmesh
|
| 135 |
+
self.scene_dataset_config_file = scene_dataset_config_file
|
| 136 |
+
self.resolution = resolution
|
| 137 |
+
self.views_count = views_count
|
| 138 |
+
assert(self.views_count >= 1)
|
| 139 |
+
self.hfov = hfov
|
| 140 |
+
self.gpu_id = gpu_id
|
| 141 |
+
self.size = size
|
| 142 |
+
self.transform = transform
|
| 143 |
+
|
| 144 |
+
# Noise added to camera orientation
|
| 145 |
+
self.pan_range = (-3, 3)
|
| 146 |
+
self.tilt_range = (-10, 10)
|
| 147 |
+
self.roll_range = (-5, 5)
|
| 148 |
+
|
| 149 |
+
# Height range to sample cameras
|
| 150 |
+
self.height_range = (1.2, 1.8)
|
| 151 |
+
|
| 152 |
+
# Random steps between the camera views
|
| 153 |
+
self.random_steps_count = 5
|
| 154 |
+
self.random_step_variance = 2.0
|
| 155 |
+
|
| 156 |
+
# Minimum fraction of the scene which should be valid (well defined depth)
|
| 157 |
+
self.minimum_valid_fraction = 0.7
|
| 158 |
+
|
| 159 |
+
# Distance threshold to see to select pairs
|
| 160 |
+
self.distance_threshold = 0.05
|
| 161 |
+
# Minimum IoU of a view point cloud with respect to the reference view to be kept.
|
| 162 |
+
self.minimum_covisibility = minimum_covisibility
|
| 163 |
+
|
| 164 |
+
# Maximum number of retries.
|
| 165 |
+
self.max_attempts_count = 100
|
| 166 |
+
|
| 167 |
+
self.seed = None
|
| 168 |
+
self._lazy_initialization()
|
| 169 |
+
|
| 170 |
+
def _lazy_initialization(self):
|
| 171 |
+
# Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
|
| 172 |
+
if self.seed == None:
|
| 173 |
+
# Re-seed numpy generator
|
| 174 |
+
np.random.seed()
|
| 175 |
+
self.seed = np.random.randint(2**32-1)
|
| 176 |
+
sim_cfg = habitat_sim.SimulatorConfiguration()
|
| 177 |
+
sim_cfg.scene_id = self.scene
|
| 178 |
+
if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "":
|
| 179 |
+
sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
|
| 180 |
+
sim_cfg.random_seed = self.seed
|
| 181 |
+
sim_cfg.load_semantic_mesh = False
|
| 182 |
+
sim_cfg.gpu_device_id = self.gpu_id
|
| 183 |
+
|
| 184 |
+
depth_sensor_spec = habitat_sim.CameraSensorSpec()
|
| 185 |
+
depth_sensor_spec.uuid = "depth"
|
| 186 |
+
depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
|
| 187 |
+
depth_sensor_spec.resolution = self.resolution
|
| 188 |
+
depth_sensor_spec.hfov = self.hfov
|
| 189 |
+
depth_sensor_spec.position = [0.0, 0.0, 0]
|
| 190 |
+
depth_sensor_spec.orientation
|
| 191 |
+
|
| 192 |
+
rgb_sensor_spec = habitat_sim.CameraSensorSpec()
|
| 193 |
+
rgb_sensor_spec.uuid = "color"
|
| 194 |
+
rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
|
| 195 |
+
rgb_sensor_spec.resolution = self.resolution
|
| 196 |
+
rgb_sensor_spec.hfov = self.hfov
|
| 197 |
+
rgb_sensor_spec.position = [0.0, 0.0, 0]
|
| 198 |
+
agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec, depth_sensor_spec])
|
| 199 |
+
|
| 200 |
+
cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
|
| 201 |
+
self.sim = habitat_sim.Simulator(cfg)
|
| 202 |
+
if self.navmesh is not None and self.navmesh != "":
|
| 203 |
+
# Use pre-computed navmesh when available (usually better than those generated automatically)
|
| 204 |
+
self.sim.pathfinder.load_nav_mesh(self.navmesh)
|
| 205 |
+
|
| 206 |
+
if not self.sim.pathfinder.is_loaded:
|
| 207 |
+
# Try to compute a navmesh
|
| 208 |
+
navmesh_settings = habitat_sim.NavMeshSettings()
|
| 209 |
+
navmesh_settings.set_defaults()
|
| 210 |
+
self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)
|
| 211 |
+
|
| 212 |
+
# Ensure that the navmesh is not empty
|
| 213 |
+
if not self.sim.pathfinder.is_loaded:
|
| 214 |
+
raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})")
|
| 215 |
+
|
| 216 |
+
self.agent = self.sim.initialize_agent(agent_id=0)
|
| 217 |
+
|
| 218 |
+
def close(self):
|
| 219 |
+
self.sim.close()
|
| 220 |
+
|
| 221 |
+
def __del__(self):
|
| 222 |
+
self.sim.close()
|
| 223 |
+
|
| 224 |
+
def __len__(self):
|
| 225 |
+
return self.size
|
| 226 |
+
|
| 227 |
+
def sample_random_viewpoint(self):
|
| 228 |
+
""" Sample a random viewpoint using the navmesh """
|
| 229 |
+
nav_point = self.sim.pathfinder.get_random_navigable_point()
|
| 230 |
+
|
| 231 |
+
# Sample a random viewpoint height
|
| 232 |
+
viewpoint_height = np.random.uniform(*self.height_range)
|
| 233 |
+
viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP
|
| 234 |
+
viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
|
| 235 |
+
return viewpoint_position, viewpoint_orientation, nav_point
|
| 236 |
+
|
| 237 |
+
def sample_other_random_viewpoint(self, observed_point, nav_point):
|
| 238 |
+
""" Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point."""
|
| 239 |
+
other_nav_point = nav_point
|
| 240 |
+
|
| 241 |
+
walk_directions = self.random_step_variance * np.asarray([1,0,1])
|
| 242 |
+
for i in range(self.random_steps_count):
|
| 243 |
+
temp = self.sim.pathfinder.snap_point(other_nav_point + walk_directions * np.random.normal(size=3))
|
| 244 |
+
# Snapping may return nan when it fails
|
| 245 |
+
if not np.isnan(temp[0]):
|
| 246 |
+
other_nav_point = temp
|
| 247 |
+
|
| 248 |
+
other_viewpoint_height = np.random.uniform(*self.height_range)
|
| 249 |
+
other_viewpoint_position = other_nav_point + other_viewpoint_height * habitat_sim.geo.UP
|
| 250 |
+
|
| 251 |
+
# Set viewing direction towards the central point
|
| 252 |
+
rotation, position = look_at_for_habitat(eye=other_viewpoint_position, center=observed_point, up=habitat_sim.geo.UP, return_cam2world=True)
|
| 253 |
+
rotation = rotation * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
|
| 254 |
+
return position, rotation, other_nav_point
|
| 255 |
+
|
| 256 |
+
def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud):
|
| 257 |
+
""" Check if a viewpoint is valid and overlaps significantly with a reference one. """
|
| 258 |
+
# Observation
|
| 259 |
+
pixels_count = self.resolution[0] * self.resolution[1]
|
| 260 |
+
valid_fraction = len(other_pointcloud) / pixels_count
|
| 261 |
+
assert valid_fraction <= 1.0 and valid_fraction >= 0.0
|
| 262 |
+
overlap = compute_pointcloud_overlaps_scikit(ref_pointcloud, other_pointcloud, self.distance_threshold, compute_symmetric=True)
|
| 263 |
+
covisibility = min(overlap["intersection1"] / pixels_count, overlap["intersection2"] / pixels_count)
|
| 264 |
+
is_valid = (valid_fraction >= self.minimum_valid_fraction) and (covisibility >= self.minimum_covisibility)
|
| 265 |
+
return is_valid, valid_fraction, covisibility
|
| 266 |
+
|
| 267 |
+
def is_other_viewpoint_overlapping(self, ref_pointcloud, observation, position, rotation):
|
| 268 |
+
""" Check if a viewpoint is valid and overlaps significantly with a reference one. """
|
| 269 |
+
# Observation
|
| 270 |
+
other_pointcloud = compute_pointcloud(observation['depth'], self.hfov, position, rotation)
|
| 271 |
+
return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
|
| 272 |
+
|
| 273 |
+
def render_viewpoint(self, viewpoint_position, viewpoint_orientation):
|
| 274 |
+
agent_state = habitat_sim.AgentState()
|
| 275 |
+
agent_state.position = viewpoint_position
|
| 276 |
+
agent_state.rotation = viewpoint_orientation
|
| 277 |
+
self.agent.set_state(agent_state)
|
| 278 |
+
viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)
|
| 279 |
+
_append_camera_parameters(viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation)
|
| 280 |
+
return viewpoint_observations
|
| 281 |
+
|
| 282 |
+
def __getitem__(self, useless_idx):
|
| 283 |
+
ref_position, ref_orientation, nav_point = self.sample_random_viewpoint()
|
| 284 |
+
ref_observations = self.render_viewpoint(ref_position, ref_orientation)
|
| 285 |
+
# Extract point cloud
|
| 286 |
+
ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
|
| 287 |
+
camera_position=ref_position, camera_rotation=ref_orientation)
|
| 288 |
+
|
| 289 |
+
pixels_count = self.resolution[0] * self.resolution[1]
|
| 290 |
+
ref_valid_fraction = len(ref_pointcloud) / pixels_count
|
| 291 |
+
assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0
|
| 292 |
+
if ref_valid_fraction < self.minimum_valid_fraction:
|
| 293 |
+
# This should produce a recursion error at some point when something is very wrong.
|
| 294 |
+
return self[0]
|
| 295 |
+
# Pick an reference observed point in the point cloud
|
| 296 |
+
observed_point = np.mean(ref_pointcloud, axis=0)
|
| 297 |
+
|
| 298 |
+
# Add the first image as reference
|
| 299 |
+
viewpoints_observations = [ref_observations]
|
| 300 |
+
viewpoints_covisibility = [ref_valid_fraction]
|
| 301 |
+
viewpoints_positions = [ref_position]
|
| 302 |
+
viewpoints_orientations = [quaternion.as_float_array(ref_orientation)]
|
| 303 |
+
viewpoints_clouds = [ref_pointcloud]
|
| 304 |
+
viewpoints_valid_fractions = [ref_valid_fraction]
|
| 305 |
+
|
| 306 |
+
for _ in range(self.views_count - 1):
|
| 307 |
+
# Generate an other viewpoint using some dummy random walk
|
| 308 |
+
successful_sampling = False
|
| 309 |
+
for sampling_attempt in range(self.max_attempts_count):
|
| 310 |
+
position, rotation, _ = self.sample_other_random_viewpoint(observed_point, nav_point)
|
| 311 |
+
# Observation
|
| 312 |
+
other_viewpoint_observations = self.render_viewpoint(position, rotation)
|
| 313 |
+
other_pointcloud = compute_pointcloud(other_viewpoint_observations['depth'], self.hfov, position, rotation)
|
| 314 |
+
|
| 315 |
+
is_valid, valid_fraction, covisibility = self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
|
| 316 |
+
if is_valid:
|
| 317 |
+
successful_sampling = True
|
| 318 |
+
break
|
| 319 |
+
if not successful_sampling:
|
| 320 |
+
print("WARNING: Maximum number of attempts reached.")
|
| 321 |
+
# Dirty hack, try using a novel original viewpoint
|
| 322 |
+
return self[0]
|
| 323 |
+
viewpoints_observations.append(other_viewpoint_observations)
|
| 324 |
+
viewpoints_covisibility.append(covisibility)
|
| 325 |
+
viewpoints_positions.append(position)
|
| 326 |
+
viewpoints_orientations.append(quaternion.as_float_array(rotation)) # WXYZ convention for the quaternion encoding.
|
| 327 |
+
viewpoints_clouds.append(other_pointcloud)
|
| 328 |
+
viewpoints_valid_fractions.append(valid_fraction)
|
| 329 |
+
|
| 330 |
+
# Estimate relations between all pairs of images
|
| 331 |
+
pairwise_visibility_ratios = np.ones((len(viewpoints_observations), len(viewpoints_observations)))
|
| 332 |
+
for i in range(len(viewpoints_observations)):
|
| 333 |
+
pairwise_visibility_ratios[i,i] = viewpoints_valid_fractions[i]
|
| 334 |
+
for j in range(i+1, len(viewpoints_observations)):
|
| 335 |
+
overlap = compute_pointcloud_overlaps_scikit(viewpoints_clouds[i], viewpoints_clouds[j], self.distance_threshold, compute_symmetric=True)
|
| 336 |
+
pairwise_visibility_ratios[i,j] = overlap['intersection1'] / pixels_count
|
| 337 |
+
pairwise_visibility_ratios[j,i] = overlap['intersection2'] / pixels_count
|
| 338 |
+
|
| 339 |
+
# IoU is relative to the image 0
|
| 340 |
+
data = {"observations": viewpoints_observations,
|
| 341 |
+
"positions": np.asarray(viewpoints_positions),
|
| 342 |
+
"orientations": np.asarray(viewpoints_orientations),
|
| 343 |
+
"covisibility_ratios": np.asarray(viewpoints_covisibility),
|
| 344 |
+
"valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float),
|
| 345 |
+
"pairwise_visibility_ratios": np.asarray(pairwise_visibility_ratios, dtype=float),
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
if self.transform is not None:
|
| 349 |
+
data = self.transform(data)
|
| 350 |
+
return data
|
| 351 |
+
|
| 352 |
+
def generate_random_spiral_trajectory(self, images_count = 100, max_radius=0.5, half_turns=5, use_constant_orientation=False):
|
| 353 |
+
"""
|
| 354 |
+
Return a list of images corresponding to a spiral trajectory from a random starting point.
|
| 355 |
+
Useful to generate nice visualisations.
|
| 356 |
+
Use an even number of half turns to get a nice "C1-continuous" loop effect
|
| 357 |
+
"""
|
| 358 |
+
ref_position, ref_orientation, navpoint = self.sample_random_viewpoint()
|
| 359 |
+
ref_observations = self.render_viewpoint(ref_position, ref_orientation)
|
| 360 |
+
ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
|
| 361 |
+
camera_position=ref_position, camera_rotation=ref_orientation)
|
| 362 |
+
pixels_count = self.resolution[0] * self.resolution[1]
|
| 363 |
+
if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction:
|
| 364 |
+
# Dirty hack: ensure that the valid part of the image is significant
|
| 365 |
+
return self.generate_random_spiral_trajectory(images_count, max_radius, half_turns, use_constant_orientation)
|
| 366 |
+
|
| 367 |
+
# Pick an observed point in the point cloud
|
| 368 |
+
observed_point = np.mean(ref_pointcloud, axis=0)
|
| 369 |
+
ref_R, ref_t = compute_camera_pose_opencv_convention(ref_position, ref_orientation)
|
| 370 |
+
|
| 371 |
+
images = []
|
| 372 |
+
is_valid = []
|
| 373 |
+
# Spiral trajectory, use_constant orientation
|
| 374 |
+
for i, alpha in enumerate(np.linspace(0, 1, images_count)):
|
| 375 |
+
r = max_radius * np.abs(np.sin(alpha * np.pi)) # Increase then decrease the radius
|
| 376 |
+
theta = alpha * half_turns * np.pi
|
| 377 |
+
x = r * np.cos(theta)
|
| 378 |
+
y = r * np.sin(theta)
|
| 379 |
+
z = 0.0
|
| 380 |
+
position = ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3,1)).flatten()
|
| 381 |
+
if use_constant_orientation:
|
| 382 |
+
orientation = ref_orientation
|
| 383 |
+
else:
|
| 384 |
+
# trajectory looking at a mean point in front of the ref observation
|
| 385 |
+
orientation, position = look_at_for_habitat(eye=position, center=observed_point, up=habitat_sim.geo.UP)
|
| 386 |
+
observations = self.render_viewpoint(position, orientation)
|
| 387 |
+
images.append(observations['color'][...,:3])
|
| 388 |
+
_is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(ref_pointcloud, observations, position, orientation)
|
| 389 |
+
is_valid.append(_is_valid)
|
| 390 |
+
return images, np.all(is_valid)
|
dust3r/croco/datasets/habitat_sim/pack_metadata_files.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
"""
|
| 4 |
+
Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere.
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
import glob
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import shutil
|
| 10 |
+
import json
|
| 11 |
+
from datasets.habitat_sim.paths import *
|
| 12 |
+
import argparse
|
| 13 |
+
import collections
|
| 14 |
+
|
| 15 |
+
if __name__ == "__main__":
|
| 16 |
+
parser = argparse.ArgumentParser()
|
| 17 |
+
parser.add_argument("input_dir")
|
| 18 |
+
parser.add_argument("output_dir")
|
| 19 |
+
args = parser.parse_args()
|
| 20 |
+
|
| 21 |
+
input_dirname = args.input_dir
|
| 22 |
+
output_dirname = args.output_dir
|
| 23 |
+
|
| 24 |
+
input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True)
|
| 25 |
+
|
| 26 |
+
images_count = collections.defaultdict(lambda : 0)
|
| 27 |
+
|
| 28 |
+
os.makedirs(output_dirname)
|
| 29 |
+
for input_filename in tqdm(input_metadata_filenames):
|
| 30 |
+
# Ignore empty files
|
| 31 |
+
with open(input_filename, "r") as f:
|
| 32 |
+
original_metadata = json.load(f)
|
| 33 |
+
if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0:
|
| 34 |
+
print("No views in", input_filename)
|
| 35 |
+
continue
|
| 36 |
+
|
| 37 |
+
relpath = os.path.relpath(input_filename, input_dirname)
|
| 38 |
+
print(relpath)
|
| 39 |
+
|
| 40 |
+
# Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability.
|
| 41 |
+
# Data paths are sorted by decreasing length to avoid potential bugs due to paths starting by the same string pattern.
|
| 42 |
+
scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True))
|
| 43 |
+
metadata = dict()
|
| 44 |
+
for key, value in original_metadata.items():
|
| 45 |
+
if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
|
| 46 |
+
known_path = False
|
| 47 |
+
for dataset, dataset_path in scenes_dataset_paths.items():
|
| 48 |
+
if value.startswith(dataset_path):
|
| 49 |
+
value = os.path.join(dataset, os.path.relpath(value, dataset_path))
|
| 50 |
+
known_path = True
|
| 51 |
+
break
|
| 52 |
+
if not known_path:
|
| 53 |
+
raise KeyError("Unknown path:" + value)
|
| 54 |
+
metadata[key] = value
|
| 55 |
+
|
| 56 |
+
# Compile some general statistics while packing data
|
| 57 |
+
scene_split = metadata["scene"].split("/")
|
| 58 |
+
upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0]
|
| 59 |
+
images_count[upper_level] += len(metadata["multiviews"])
|
| 60 |
+
|
| 61 |
+
output_filename = os.path.join(output_dirname, relpath)
|
| 62 |
+
os.makedirs(os.path.dirname(output_filename), exist_ok=True)
|
| 63 |
+
with open(output_filename, "w") as f:
|
| 64 |
+
json.dump(metadata, f)
|
| 65 |
+
|
| 66 |
+
# Print statistics
|
| 67 |
+
print("Images count:")
|
| 68 |
+
for upper_level, count in images_count.items():
|
| 69 |
+
print(f"- {upper_level}: {count}")
|
dust3r/croco/datasets/habitat_sim/paths.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Paths to Habitat-Sim scenes
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import json
|
| 10 |
+
import collections
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Hardcoded path to the different scene datasets
|
| 15 |
+
SCENES_DATASET = {
|
| 16 |
+
"hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/",
|
| 17 |
+
"gibson": "./data/habitat-sim-data/scene_datasets/gibson/",
|
| 18 |
+
"habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/",
|
| 19 |
+
"replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/",
|
| 20 |
+
"replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/",
|
| 21 |
+
"replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/",
|
| 22 |
+
"scannet": "./data/habitat-sim/scene_datasets/scannet/"
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"])
|
| 26 |
+
|
| 27 |
+
def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]):
|
| 28 |
+
scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json")
|
| 29 |
+
scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"]
|
| 30 |
+
navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
|
| 31 |
+
scenes_data = []
|
| 32 |
+
for idx in range(len(scenes)):
|
| 33 |
+
output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx])
|
| 34 |
+
# Add scene
|
| 35 |
+
data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
|
| 36 |
+
scene = scenes[idx] + ".scene_instance.json",
|
| 37 |
+
navmesh = os.path.join(base_path, navmeshes[idx]),
|
| 38 |
+
output_dir = output_dir)
|
| 39 |
+
scenes_data.append(data)
|
| 40 |
+
return scenes_data
|
| 41 |
+
|
| 42 |
+
def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]):
|
| 43 |
+
scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json")
|
| 44 |
+
scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], [])
|
| 45 |
+
navmeshes = ""#[f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
|
| 46 |
+
scenes_data = []
|
| 47 |
+
for idx in range(len(scenes)):
|
| 48 |
+
output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx])
|
| 49 |
+
data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
|
| 50 |
+
scene = scenes[idx],
|
| 51 |
+
navmesh = "",
|
| 52 |
+
output_dir = output_dir)
|
| 53 |
+
scenes_data.append(data)
|
| 54 |
+
return scenes_data
|
| 55 |
+
|
| 56 |
+
def list_replica_scenes(base_output_dir, base_path):
|
| 57 |
+
scenes_data = []
|
| 58 |
+
for scene_id in os.listdir(base_path):
|
| 59 |
+
scene = os.path.join(base_path, scene_id, "mesh.ply")
|
| 60 |
+
navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it
|
| 61 |
+
scene_dataset_config_file = ""
|
| 62 |
+
output_dir = os.path.join(base_output_dir, scene_id)
|
| 63 |
+
# Add scene only if it does not exist already, or if exist_ok
|
| 64 |
+
data = SceneData(scene_dataset_config_file = scene_dataset_config_file,
|
| 65 |
+
scene = scene,
|
| 66 |
+
navmesh = navmesh,
|
| 67 |
+
output_dir = output_dir)
|
| 68 |
+
scenes_data.append(data)
|
| 69 |
+
return scenes_data
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def list_scenes(base_output_dir, base_path):
|
| 73 |
+
"""
|
| 74 |
+
Generic method iterating through a base_path folder to find scenes.
|
| 75 |
+
"""
|
| 76 |
+
scenes_data = []
|
| 77 |
+
for root, dirs, files in os.walk(base_path, followlinks=True):
|
| 78 |
+
folder_scenes_data = []
|
| 79 |
+
for file in files:
|
| 80 |
+
name, ext = os.path.splitext(file)
|
| 81 |
+
if ext == ".glb":
|
| 82 |
+
scene = os.path.join(root, name + ".glb")
|
| 83 |
+
navmesh = os.path.join(root, name + ".navmesh")
|
| 84 |
+
if not os.path.exists(navmesh):
|
| 85 |
+
navmesh = ""
|
| 86 |
+
relpath = os.path.relpath(root, base_path)
|
| 87 |
+
output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name))
|
| 88 |
+
data = SceneData(scene_dataset_config_file="",
|
| 89 |
+
scene = scene,
|
| 90 |
+
navmesh = navmesh,
|
| 91 |
+
output_dir = output_dir)
|
| 92 |
+
folder_scenes_data.append(data)
|
| 93 |
+
|
| 94 |
+
# Specific check for HM3D:
|
| 95 |
+
# When two meshesxxxx.basis.glb and xxxx.glb are present, use the 'basis' version.
|
| 96 |
+
basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")]
|
| 97 |
+
if len(basis_scenes) != 0:
|
| 98 |
+
folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)]
|
| 99 |
+
|
| 100 |
+
scenes_data.extend(folder_scenes_data)
|
| 101 |
+
return scenes_data
|
| 102 |
+
|
| 103 |
+
def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET):
|
| 104 |
+
scenes_data = []
|
| 105 |
+
|
| 106 |
+
# HM3D
|
| 107 |
+
for split in ("minival", "train", "val", "examples"):
|
| 108 |
+
scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"),
|
| 109 |
+
base_path=f"{scenes_dataset_paths['hm3d']}/{split}")
|
| 110 |
+
|
| 111 |
+
# Gibson
|
| 112 |
+
scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"),
|
| 113 |
+
base_path=scenes_dataset_paths["gibson"])
|
| 114 |
+
|
| 115 |
+
# Habitat test scenes (just a few)
|
| 116 |
+
scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"),
|
| 117 |
+
base_path=scenes_dataset_paths["habitat-test-scenes"])
|
| 118 |
+
|
| 119 |
+
# ReplicaCAD (baked lightning)
|
| 120 |
+
scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir)
|
| 121 |
+
|
| 122 |
+
# ScanNet
|
| 123 |
+
scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"),
|
| 124 |
+
base_path=scenes_dataset_paths["scannet"])
|
| 125 |
+
|
| 126 |
+
# Replica
|
| 127 |
+
list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"),
|
| 128 |
+
base_path=scenes_dataset_paths["replica"])
|
| 129 |
+
return scenes_data
|
dust3r/croco/datasets/pairs_dataset.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
from datasets.transforms import get_pair_transforms
|
| 9 |
+
|
| 10 |
+
def load_image(impath):
|
| 11 |
+
return Image.open(impath)
|
| 12 |
+
|
| 13 |
+
def load_pairs_from_cache_file(fname, root=''):
|
| 14 |
+
assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
|
| 15 |
+
with open(fname, 'r') as fid:
|
| 16 |
+
lines = fid.read().strip().splitlines()
|
| 17 |
+
pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines]
|
| 18 |
+
return pairs
|
| 19 |
+
|
| 20 |
+
def load_pairs_from_list_file(fname, root=''):
|
| 21 |
+
assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
|
| 22 |
+
with open(fname, 'r') as fid:
|
| 23 |
+
lines = fid.read().strip().splitlines()
|
| 24 |
+
pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')]
|
| 25 |
+
return pairs
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def write_cache_file(fname, pairs, root=''):
|
| 29 |
+
if len(root)>0:
|
| 30 |
+
if not root.endswith('/'): root+='/'
|
| 31 |
+
assert os.path.isdir(root)
|
| 32 |
+
s = ''
|
| 33 |
+
for im1, im2 in pairs:
|
| 34 |
+
if len(root)>0:
|
| 35 |
+
assert im1.startswith(root), im1
|
| 36 |
+
assert im2.startswith(root), im2
|
| 37 |
+
s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):])
|
| 38 |
+
with open(fname, 'w') as fid:
|
| 39 |
+
fid.write(s[:-1])
|
| 40 |
+
|
| 41 |
+
def parse_and_cache_all_pairs(dname, data_dir='./data/'):
|
| 42 |
+
if dname=='habitat_release':
|
| 43 |
+
dirname = os.path.join(data_dir, 'habitat_release')
|
| 44 |
+
assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
|
| 45 |
+
cache_file = os.path.join(dirname, 'pairs.txt')
|
| 46 |
+
assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file
|
| 47 |
+
|
| 48 |
+
print('Parsing pairs for dataset: '+dname)
|
| 49 |
+
pairs = []
|
| 50 |
+
for root, dirs, files in os.walk(dirname):
|
| 51 |
+
if 'val' in root: continue
|
| 52 |
+
dirs.sort()
|
| 53 |
+
pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')]
|
| 54 |
+
print('Found {:,} pairs'.format(len(pairs)))
|
| 55 |
+
print('Writing cache to: '+cache_file)
|
| 56 |
+
write_cache_file(cache_file, pairs, root=dirname)
|
| 57 |
+
|
| 58 |
+
else:
|
| 59 |
+
raise NotImplementedError('Unknown dataset: '+dname)
|
| 60 |
+
|
| 61 |
+
def dnames_to_image_pairs(dnames, data_dir='./data/'):
|
| 62 |
+
"""
|
| 63 |
+
dnames: list of datasets with image pairs, separated by +
|
| 64 |
+
"""
|
| 65 |
+
all_pairs = []
|
| 66 |
+
for dname in dnames.split('+'):
|
| 67 |
+
if dname=='habitat_release':
|
| 68 |
+
dirname = os.path.join(data_dir, 'habitat_release')
|
| 69 |
+
assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
|
| 70 |
+
cache_file = os.path.join(dirname, 'pairs.txt')
|
| 71 |
+
assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file
|
| 72 |
+
pairs = load_pairs_from_cache_file(cache_file, root=dirname)
|
| 73 |
+
elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']:
|
| 74 |
+
dirname = os.path.join(data_dir, dname+'_crops')
|
| 75 |
+
assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
|
| 76 |
+
list_file = os.path.join(dirname, 'listing.txt')
|
| 77 |
+
assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. {:s}".format(dname, list_file)
|
| 78 |
+
pairs = load_pairs_from_list_file(list_file, root=dirname)
|
| 79 |
+
print(' {:s}: {:,} pairs'.format(dname, len(pairs)))
|
| 80 |
+
all_pairs += pairs
|
| 81 |
+
if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs)))
|
| 82 |
+
return all_pairs
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class PairsDataset(Dataset):
|
| 86 |
+
|
| 87 |
+
def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'):
|
| 88 |
+
super().__init__()
|
| 89 |
+
self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
|
| 90 |
+
self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize)
|
| 91 |
+
|
| 92 |
+
def __len__(self):
|
| 93 |
+
return len(self.image_pairs)
|
| 94 |
+
|
| 95 |
+
def __getitem__(self, index):
|
| 96 |
+
im1path, im2path = self.image_pairs[index]
|
| 97 |
+
im1 = load_image(im1path)
|
| 98 |
+
im2 = load_image(im2path)
|
| 99 |
+
if self.transforms is not None: im1, im2 = self.transforms(im1, im2)
|
| 100 |
+
return im1, im2
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__=="__main__":
|
| 104 |
+
import argparse
|
| 105 |
+
parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset")
|
| 106 |
+
parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored")
|
| 107 |
+
parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset")
|
| 108 |
+
args = parser.parse_args()
|
| 109 |
+
parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
|
dust3r/croco/datasets/transforms.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torchvision.transforms
|
| 6 |
+
import torchvision.transforms.functional as F
|
| 7 |
+
|
| 8 |
+
# "Pair": apply a transform on a pair
|
| 9 |
+
# "Both": apply the exact same transform to both images
|
| 10 |
+
|
| 11 |
+
class ComposePair(torchvision.transforms.Compose):
|
| 12 |
+
def __call__(self, img1, img2):
|
| 13 |
+
for t in self.transforms:
|
| 14 |
+
img1, img2 = t(img1, img2)
|
| 15 |
+
return img1, img2
|
| 16 |
+
|
| 17 |
+
class NormalizeBoth(torchvision.transforms.Normalize):
|
| 18 |
+
def forward(self, img1, img2):
|
| 19 |
+
img1 = super().forward(img1)
|
| 20 |
+
img2 = super().forward(img2)
|
| 21 |
+
return img1, img2
|
| 22 |
+
|
| 23 |
+
class ToTensorBoth(torchvision.transforms.ToTensor):
|
| 24 |
+
def __call__(self, img1, img2):
|
| 25 |
+
img1 = super().__call__(img1)
|
| 26 |
+
img2 = super().__call__(img2)
|
| 27 |
+
return img1, img2
|
| 28 |
+
|
| 29 |
+
class RandomCropPair(torchvision.transforms.RandomCrop):
|
| 30 |
+
# the crop will be intentionally different for the two images with this class
|
| 31 |
+
def forward(self, img1, img2):
|
| 32 |
+
img1 = super().forward(img1)
|
| 33 |
+
img2 = super().forward(img2)
|
| 34 |
+
return img1, img2
|
| 35 |
+
|
| 36 |
+
class ColorJitterPair(torchvision.transforms.ColorJitter):
|
| 37 |
+
# can be symmetric (same for both images) or assymetric (different jitter params for each image) depending on assymetric_prob
|
| 38 |
+
def __init__(self, assymetric_prob, **kwargs):
|
| 39 |
+
super().__init__(**kwargs)
|
| 40 |
+
self.assymetric_prob = assymetric_prob
|
| 41 |
+
def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor):
|
| 42 |
+
for fn_id in fn_idx:
|
| 43 |
+
if fn_id == 0 and brightness_factor is not None:
|
| 44 |
+
img = F.adjust_brightness(img, brightness_factor)
|
| 45 |
+
elif fn_id == 1 and contrast_factor is not None:
|
| 46 |
+
img = F.adjust_contrast(img, contrast_factor)
|
| 47 |
+
elif fn_id == 2 and saturation_factor is not None:
|
| 48 |
+
img = F.adjust_saturation(img, saturation_factor)
|
| 49 |
+
elif fn_id == 3 and hue_factor is not None:
|
| 50 |
+
img = F.adjust_hue(img, hue_factor)
|
| 51 |
+
return img
|
| 52 |
+
|
| 53 |
+
def forward(self, img1, img2):
|
| 54 |
+
|
| 55 |
+
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
|
| 56 |
+
self.brightness, self.contrast, self.saturation, self.hue
|
| 57 |
+
)
|
| 58 |
+
img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
|
| 59 |
+
if torch.rand(1) < self.assymetric_prob: # assymetric:
|
| 60 |
+
fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
|
| 61 |
+
self.brightness, self.contrast, self.saturation, self.hue
|
| 62 |
+
)
|
| 63 |
+
img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
|
| 64 |
+
return img1, img2
|
| 65 |
+
|
| 66 |
+
def get_pair_transforms(transform_str, totensor=True, normalize=True):
|
| 67 |
+
# transform_str is eg crop224+color
|
| 68 |
+
trfs = []
|
| 69 |
+
for s in transform_str.split('+'):
|
| 70 |
+
if s.startswith('crop'):
|
| 71 |
+
size = int(s[len('crop'):])
|
| 72 |
+
trfs.append(RandomCropPair(size))
|
| 73 |
+
elif s=='acolor':
|
| 74 |
+
trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0))
|
| 75 |
+
elif s=='': # if transform_str was ""
|
| 76 |
+
pass
|
| 77 |
+
else:
|
| 78 |
+
raise NotImplementedError('Unknown augmentation: '+s)
|
| 79 |
+
|
| 80 |
+
if totensor:
|
| 81 |
+
trfs.append( ToTensorBoth() )
|
| 82 |
+
if normalize:
|
| 83 |
+
trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) )
|
| 84 |
+
|
| 85 |
+
if len(trfs)==0:
|
| 86 |
+
return None
|
| 87 |
+
elif len(trfs)==1:
|
| 88 |
+
return trfs
|
| 89 |
+
else:
|
| 90 |
+
return ComposePair(trfs)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
dust3r/croco/models/__pycache__/blocks.cpython-312.pyc
ADDED
|
Binary file (19.6 kB). View file
|
|
|
dust3r/croco/models/__pycache__/croco.cpython-312.pyc
ADDED
|
Binary file (15.2 kB). View file
|
|
|
dust3r/croco/models/__pycache__/dpt_block.cpython-312.pyc
ADDED
|
Binary file (16.9 kB). View file
|
|
|
dust3r/croco/models/__pycache__/masking.cpython-312.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
dust3r/croco/models/__pycache__/pos_embed.cpython-312.pyc
ADDED
|
Binary file (8.31 kB). View file
|
|
|
dust3r/croco/models/blocks.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
# Main encoder/decoder blocks
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
# References:
|
| 9 |
+
# timm
|
| 10 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
| 11 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
|
| 12 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
| 13 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
|
| 14 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from itertools import repeat
|
| 21 |
+
import collections.abc
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _ntuple(n):
|
| 25 |
+
def parse(x):
|
| 26 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 27 |
+
return x
|
| 28 |
+
return tuple(repeat(x, n))
|
| 29 |
+
return parse
|
| 30 |
+
to_2tuple = _ntuple(2)
|
| 31 |
+
|
| 32 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
|
| 33 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 34 |
+
"""
|
| 35 |
+
if drop_prob == 0. or not training:
|
| 36 |
+
return x
|
| 37 |
+
keep_prob = 1 - drop_prob
|
| 38 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 39 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 40 |
+
if keep_prob > 0.0 and scale_by_keep:
|
| 41 |
+
random_tensor.div_(keep_prob)
|
| 42 |
+
return x * random_tensor
|
| 43 |
+
|
| 44 |
+
class DropPath(nn.Module):
|
| 45 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 46 |
+
"""
|
| 47 |
+
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
| 48 |
+
super(DropPath, self).__init__()
|
| 49 |
+
self.drop_prob = drop_prob
|
| 50 |
+
self.scale_by_keep = scale_by_keep
|
| 51 |
+
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
| 54 |
+
|
| 55 |
+
def extra_repr(self):
|
| 56 |
+
return f'drop_prob={round(self.drop_prob,3):0.3f}'
|
| 57 |
+
|
| 58 |
+
class Mlp(nn.Module):
|
| 59 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
| 60 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
|
| 61 |
+
super().__init__()
|
| 62 |
+
out_features = out_features or in_features
|
| 63 |
+
hidden_features = hidden_features or in_features
|
| 64 |
+
bias = to_2tuple(bias)
|
| 65 |
+
drop_probs = to_2tuple(drop)
|
| 66 |
+
|
| 67 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
|
| 68 |
+
self.act = act_layer()
|
| 69 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
| 70 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
|
| 71 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
x = self.fc1(x)
|
| 75 |
+
x = self.act(x)
|
| 76 |
+
x = self.drop1(x)
|
| 77 |
+
x = self.fc2(x)
|
| 78 |
+
x = self.drop2(x)
|
| 79 |
+
return x
|
| 80 |
+
|
| 81 |
+
class Attention(nn.Module):
|
| 82 |
+
|
| 83 |
+
def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.num_heads = num_heads
|
| 86 |
+
head_dim = dim // num_heads
|
| 87 |
+
self.scale = head_dim ** -0.5
|
| 88 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 89 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 90 |
+
self.proj = nn.Linear(dim, dim)
|
| 91 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 92 |
+
self.rope = rope
|
| 93 |
+
|
| 94 |
+
def forward(self, x, xpos):
|
| 95 |
+
B, N, C = x.shape
|
| 96 |
+
|
| 97 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
|
| 98 |
+
q, k, v = [qkv[:,:,i] for i in range(3)]
|
| 99 |
+
# q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
|
| 100 |
+
|
| 101 |
+
if self.rope is not None:
|
| 102 |
+
q = self.rope(q, xpos)
|
| 103 |
+
k = self.rope(k, xpos)
|
| 104 |
+
|
| 105 |
+
# attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 106 |
+
# attn = attn.softmax(dim=-1)
|
| 107 |
+
# attn = self.attn_drop(attn)
|
| 108 |
+
# x_old = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 109 |
+
# with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
|
| 110 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale = self.scale, dropout_p=0.).transpose(1, 2).reshape(B, N, C)
|
| 111 |
+
# import ipdb;ipdb.set_trace()
|
| 112 |
+
# (x - x_old).abs().mean()
|
| 113 |
+
x = self.proj(x)
|
| 114 |
+
x = self.proj_drop(x)
|
| 115 |
+
return x
|
| 116 |
+
|
| 117 |
+
class LayerNorm(nn.LayerNorm):
|
| 118 |
+
def forward(self, x):
|
| 119 |
+
t = x.dtype
|
| 120 |
+
x = super().forward(x.type(torch.float32))
|
| 121 |
+
return x.type(t)
|
| 122 |
+
|
| 123 |
+
class Block(nn.Module):
|
| 124 |
+
|
| 125 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
| 126 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
|
| 127 |
+
super().__init__()
|
| 128 |
+
self.norm1 = norm_layer(dim)
|
| 129 |
+
self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
| 130 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 131 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 132 |
+
self.norm2 = norm_layer(dim)
|
| 133 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 134 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 135 |
+
|
| 136 |
+
def forward(self, x, xpos):
|
| 137 |
+
dtype = x.dtype
|
| 138 |
+
x = x + self.drop_path(self.attn(self.norm1(x), xpos))
|
| 139 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 140 |
+
return x
|
| 141 |
+
|
| 142 |
+
class CrossAttention(nn.Module):
|
| 143 |
+
|
| 144 |
+
def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
| 145 |
+
super().__init__()
|
| 146 |
+
self.num_heads = num_heads
|
| 147 |
+
head_dim = dim // num_heads
|
| 148 |
+
self.scale = head_dim ** -0.5
|
| 149 |
+
|
| 150 |
+
self.projq = nn.Linear(dim, dim, bias=qkv_bias)
|
| 151 |
+
self.projk = nn.Linear(dim, dim, bias=qkv_bias)
|
| 152 |
+
self.projv = nn.Linear(dim, dim, bias=qkv_bias)
|
| 153 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 154 |
+
self.proj = nn.Linear(dim, dim)
|
| 155 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 156 |
+
|
| 157 |
+
self.rope = rope
|
| 158 |
+
|
| 159 |
+
def forward(self, query, key, value, qpos, kpos):
|
| 160 |
+
B, Nq, C = query.shape
|
| 161 |
+
Nk = key.shape[1]
|
| 162 |
+
Nv = value.shape[1]
|
| 163 |
+
|
| 164 |
+
q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
|
| 165 |
+
k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
|
| 166 |
+
v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
|
| 167 |
+
|
| 168 |
+
if self.rope is not None:
|
| 169 |
+
q = self.rope(q, qpos)
|
| 170 |
+
k = self.rope(k, kpos)
|
| 171 |
+
|
| 172 |
+
# attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 173 |
+
# attn = attn.softmax(dim=-1)
|
| 174 |
+
# attn = self.attn_drop(attn)
|
| 175 |
+
|
| 176 |
+
# x_old = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
|
| 177 |
+
# with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
|
| 178 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale = self.scale, dropout_p=0.).transpose(1, 2).reshape(B, Nq, C)
|
| 179 |
+
# import ipdb;ipdb.set_trace()
|
| 180 |
+
# (x - x_old).abs().mean()
|
| 181 |
+
x = self.proj(x)
|
| 182 |
+
x = self.proj_drop(x)
|
| 183 |
+
return x
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class DecoderBlock_onlyself(nn.Module):
|
| 187 |
+
|
| 188 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
| 189 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
|
| 190 |
+
super().__init__()
|
| 191 |
+
self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
| 192 |
+
# self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
| 193 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 194 |
+
self.norm1 = norm_layer(dim)
|
| 195 |
+
self.norm3 = norm_layer(dim)
|
| 196 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 197 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 198 |
+
# self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
|
| 199 |
+
|
| 200 |
+
def forward(self, x, xpos, split=False):
|
| 201 |
+
# if split==False:
|
| 202 |
+
# else:
|
| 203 |
+
# x = x.reshape(-1, y.shape[1], y.shape[2])
|
| 204 |
+
# x = x + self.drop_path(self.attn(self.norm1(x), xpos.reshape(-1, y.shape[1], 2)))
|
| 205 |
+
# x = x.reshape(y.shape[0], -1, y.shape[2])
|
| 206 |
+
x = x + self.drop_path(self.attn(self.norm1(x), xpos))
|
| 207 |
+
# y_ = self.norm_y(y)
|
| 208 |
+
# x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
|
| 209 |
+
x = x + self.drop_path(self.mlp(self.norm3(x)))
|
| 210 |
+
return x
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class DecoderBlock_onlycross(nn.Module):
|
| 215 |
+
|
| 216 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
| 217 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
|
| 218 |
+
super().__init__()
|
| 219 |
+
self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
| 220 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 221 |
+
self.norm2 = norm_layer(dim)
|
| 222 |
+
self.norm3 = norm_layer(dim)
|
| 223 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 224 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 225 |
+
self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
|
| 226 |
+
def forward(self, x, y, xpos, ypos, split=False):
|
| 227 |
+
y_ = self.norm_y(y)
|
| 228 |
+
x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
|
| 229 |
+
x = x + self.drop_path(self.mlp(self.norm3(x)))
|
| 230 |
+
return x, y
|
| 231 |
+
|
| 232 |
+
class DecoderBlock(nn.Module):
|
| 233 |
+
|
| 234 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
| 235 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
|
| 236 |
+
super().__init__()
|
| 237 |
+
self.norm1 = norm_layer(dim)
|
| 238 |
+
self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
| 239 |
+
self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
| 240 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 241 |
+
self.norm2 = norm_layer(dim)
|
| 242 |
+
self.norm3 = norm_layer(dim)
|
| 243 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 244 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 245 |
+
self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
|
| 246 |
+
|
| 247 |
+
def forward(self, x, y, xpos, ypos, split=False):
|
| 248 |
+
# if split==False:
|
| 249 |
+
x = x + self.drop_path(self.attn(self.norm1(x), xpos))
|
| 250 |
+
# else:
|
| 251 |
+
# x = x.reshape(-1, y.shape[1], y.shape[2])
|
| 252 |
+
# x = x + self.drop_path(self.attn(self.norm1(x), xpos.reshape(-1, y.shape[1], 2)))
|
| 253 |
+
# x = x.reshape(y.shape[0], -1, y.shape[2])
|
| 254 |
+
y_ = self.norm_y(y)
|
| 255 |
+
x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
|
| 256 |
+
x = x + self.drop_path(self.mlp(self.norm3(x)))
|
| 257 |
+
return x, y
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# patch embedding
|
| 261 |
+
class PositionGetter(object):
|
| 262 |
+
""" return positions of patches """
|
| 263 |
+
|
| 264 |
+
def __init__(self):
|
| 265 |
+
self.cache_positions = {}
|
| 266 |
+
|
| 267 |
+
def __call__(self, b, h, w, device):
|
| 268 |
+
if not (h,w) in self.cache_positions:
|
| 269 |
+
x = torch.arange(w, device=device)
|
| 270 |
+
y = torch.arange(h, device=device)
|
| 271 |
+
self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2)
|
| 272 |
+
pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
|
| 273 |
+
return pos
|
| 274 |
+
|
| 275 |
+
class PatchEmbed(nn.Module):
|
| 276 |
+
""" just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
|
| 277 |
+
|
| 278 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
|
| 279 |
+
super().__init__()
|
| 280 |
+
img_size = to_2tuple(img_size)
|
| 281 |
+
patch_size = to_2tuple(patch_size)
|
| 282 |
+
self.img_size = img_size
|
| 283 |
+
self.patch_size = patch_size
|
| 284 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
| 285 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
| 286 |
+
self.flatten = flatten
|
| 287 |
+
|
| 288 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 289 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 290 |
+
|
| 291 |
+
self.position_getter = PositionGetter()
|
| 292 |
+
|
| 293 |
+
def forward(self, x):
|
| 294 |
+
B, C, H, W = x.shape
|
| 295 |
+
torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
|
| 296 |
+
torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
|
| 297 |
+
x = self.proj(x)
|
| 298 |
+
pos = self.position_getter(B, x.size(2), x.size(3), x.device)
|
| 299 |
+
if self.flatten:
|
| 300 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
| 301 |
+
x = self.norm(x)
|
| 302 |
+
return x, pos
|
| 303 |
+
|
| 304 |
+
def _init_weights(self):
|
| 305 |
+
w = self.proj.weight.data
|
| 306 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 307 |
+
|
dust3r/croco/models/criterion.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Criterion to train CroCo
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# MAE: https://github.com/facebookresearch/mae
|
| 9 |
+
# --------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
class MaskedMSE(torch.nn.Module):
|
| 14 |
+
|
| 15 |
+
def __init__(self, norm_pix_loss=False, masked=True):
|
| 16 |
+
"""
|
| 17 |
+
norm_pix_loss: normalize each patch by their pixel mean and variance
|
| 18 |
+
masked: compute loss over the masked patches only
|
| 19 |
+
"""
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.norm_pix_loss = norm_pix_loss
|
| 22 |
+
self.masked = masked
|
| 23 |
+
|
| 24 |
+
def forward(self, pred, mask, target):
|
| 25 |
+
|
| 26 |
+
if self.norm_pix_loss:
|
| 27 |
+
mean = target.mean(dim=-1, keepdim=True)
|
| 28 |
+
var = target.var(dim=-1, keepdim=True)
|
| 29 |
+
target = (target - mean) / (var + 1.e-6)**.5
|
| 30 |
+
|
| 31 |
+
loss = (pred - target) ** 2
|
| 32 |
+
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
| 33 |
+
if self.masked:
|
| 34 |
+
loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches
|
| 35 |
+
else:
|
| 36 |
+
loss = loss.mean() # mean loss
|
| 37 |
+
return loss
|
dust3r/croco/models/croco.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
# CroCo model during pretraining
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
|
| 14 |
+
from functools import partial
|
| 15 |
+
from models.blocks import Block, DecoderBlock, PatchEmbed, DecoderBlock_onlyself, DecoderBlock_onlycross
|
| 16 |
+
from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
|
| 17 |
+
from models.masking import RandomMask
|
| 18 |
+
from mast3r.modules import AttnBlock
|
| 19 |
+
|
| 20 |
+
class CroCoNet(nn.Module):
|
| 21 |
+
|
| 22 |
+
def __init__(self,
|
| 23 |
+
img_size=224, # input image size
|
| 24 |
+
patch_size=16, # patch_size
|
| 25 |
+
mask_ratio=0.9, # ratios of masked tokens
|
| 26 |
+
enc_embed_dim=768, # encoder feature dimension
|
| 27 |
+
enc_depth=12, # encoder depth
|
| 28 |
+
enc_num_heads=12, # encoder number of heads in the transformer block
|
| 29 |
+
dec_embed_dim=512, # decoder feature dimension
|
| 30 |
+
dec_depth=8, # decoder depth
|
| 31 |
+
dec_num_heads=16, # decoder number of heads in the transformer block
|
| 32 |
+
mlp_ratio=4,
|
| 33 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 34 |
+
norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder
|
| 35 |
+
pos_embed='cosine', # positional embedding (either cosine or RoPE100)
|
| 36 |
+
):
|
| 37 |
+
|
| 38 |
+
super(CroCoNet, self).__init__()
|
| 39 |
+
|
| 40 |
+
# patch embeddings (with initialization done as in MAE)
|
| 41 |
+
self._set_patch_embed(img_size, patch_size, enc_embed_dim)
|
| 42 |
+
|
| 43 |
+
# mask generations
|
| 44 |
+
self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)
|
| 45 |
+
|
| 46 |
+
self.pos_embed = pos_embed
|
| 47 |
+
if pos_embed=='cosine':
|
| 48 |
+
# positional embedding of the encoder
|
| 49 |
+
enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
|
| 50 |
+
self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
|
| 51 |
+
# positional embedding of the decoder
|
| 52 |
+
dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
|
| 53 |
+
self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
|
| 54 |
+
# pos embedding in each block
|
| 55 |
+
self.rope = None # nothing for cosine
|
| 56 |
+
elif pos_embed.startswith('RoPE'): # eg RoPE100
|
| 57 |
+
self.enc_pos_embed = None # nothing to add in the encoder with RoPE
|
| 58 |
+
self.dec_pos_embed = None # nothing to add in the decoder with RoPE
|
| 59 |
+
if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
|
| 60 |
+
freq = float(pos_embed[len('RoPE'):])
|
| 61 |
+
self.rope = RoPE2D(freq=freq)
|
| 62 |
+
else:
|
| 63 |
+
raise NotImplementedError('Unknown pos_embed '+pos_embed)
|
| 64 |
+
|
| 65 |
+
# transformer for the encoder
|
| 66 |
+
self.enc_depth = enc_depth
|
| 67 |
+
self.enc_embed_dim = enc_embed_dim
|
| 68 |
+
self.enc_blocks = nn.ModuleList([
|
| 69 |
+
Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
|
| 70 |
+
for i in range(enc_depth)])
|
| 71 |
+
self.enc_norm = norm_layer(enc_embed_dim)
|
| 72 |
+
self.enc_blocks_stage2 = nn.ModuleList([
|
| 73 |
+
Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
|
| 74 |
+
for i in range(enc_depth//6)])
|
| 75 |
+
self.enc_norm_stage2 = norm_layer(enc_embed_dim)
|
| 76 |
+
# masked tokens
|
| 77 |
+
self._set_mask_token(dec_embed_dim)
|
| 78 |
+
|
| 79 |
+
# decoder
|
| 80 |
+
self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)
|
| 81 |
+
|
| 82 |
+
# prediction head
|
| 83 |
+
self._set_prediction_head(dec_embed_dim, patch_size)
|
| 84 |
+
|
| 85 |
+
# initializer weights
|
| 86 |
+
self.initialize_weights()
|
| 87 |
+
|
| 88 |
+
def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
|
| 89 |
+
self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _set_mask_generator(self, num_patches, mask_ratio):
|
| 93 |
+
self.mask_generator = RandomMask(num_patches, mask_ratio)
|
| 94 |
+
|
| 95 |
+
def _set_mask_token(self, dec_embed_dim):
|
| 96 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))
|
| 97 |
+
|
| 98 |
+
def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
|
| 99 |
+
self.dec_depth = dec_depth
|
| 100 |
+
self.dec_embed_dim = dec_embed_dim
|
| 101 |
+
# transfer from encoder to decoder
|
| 102 |
+
self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
|
| 103 |
+
# transformer for the decoder
|
| 104 |
+
self.dec_blocks = nn.ModuleList([
|
| 105 |
+
DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
|
| 106 |
+
for i in range(dec_depth)])
|
| 107 |
+
self.dec_blocks_fine = nn.ModuleList([
|
| 108 |
+
DecoderBlock_onlyself(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
|
| 109 |
+
for i in range(dec_depth)])
|
| 110 |
+
self.dec_blocks_point_cross = nn.ModuleList([
|
| 111 |
+
DecoderBlock_onlycross(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
|
| 112 |
+
for i in range(dec_depth)])
|
| 113 |
+
# final norm layer
|
| 114 |
+
self.cam_cond_encoder = nn.ModuleList([AttnBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, attn_class=nn.MultiheadAttention)
|
| 115 |
+
for _ in range(dec_depth)])
|
| 116 |
+
self.pose_token_ref = nn.Parameter(torch.randn(1, 1, dec_embed_dim))
|
| 117 |
+
self.pose_token_source = nn.Parameter(torch.randn(1, 1, dec_embed_dim))
|
| 118 |
+
|
| 119 |
+
self.cam_cond_embed = nn.ModuleList([nn.Linear(dec_embed_dim, dec_embed_dim, bias=False) for i in range(dec_depth)])
|
| 120 |
+
self.dec_norm = norm_layer(dec_embed_dim)
|
| 121 |
+
self.dec_cam_norm = norm_layer(dec_embed_dim)
|
| 122 |
+
|
| 123 |
+
def _set_prediction_head(self, dec_embed_dim, patch_size):
|
| 124 |
+
self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def initialize_weights(self):
|
| 128 |
+
# patch embed
|
| 129 |
+
self.patch_embed._init_weights()
|
| 130 |
+
# mask tokens
|
| 131 |
+
if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02)
|
| 132 |
+
# linears and layer norms
|
| 133 |
+
self.apply(self._init_weights)
|
| 134 |
+
|
| 135 |
+
def _init_weights(self, m):
|
| 136 |
+
if isinstance(m, nn.Linear):
|
| 137 |
+
# we use xavier_uniform following official JAX ViT:
|
| 138 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
| 139 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 140 |
+
nn.init.constant_(m.bias, 0)
|
| 141 |
+
elif isinstance(m, nn.LayerNorm):
|
| 142 |
+
if m.elementwise_affine == True:
|
| 143 |
+
nn.init.constant_(m.bias, 0)
|
| 144 |
+
nn.init.constant_(m.weight, 1.0)
|
| 145 |
+
|
| 146 |
+
def _encode_image_fine(self, image, shapes, dtype=torch.float32):
|
| 147 |
+
"""
|
| 148 |
+
image has B x 3 x img_size x img_size
|
| 149 |
+
do_mask: whether to perform masking or not
|
| 150 |
+
return_all_blocks: if True, return the features at the end of every block
|
| 151 |
+
instead of just the features from the last block (eg for some prediction heads)
|
| 152 |
+
"""
|
| 153 |
+
# embed the image into patches (x has size B x Npatches x C)
|
| 154 |
+
# and get position if each return patch (pos has size B x Npatches x 2)
|
| 155 |
+
x, pos = self.patch_embed_fine(image, shapes)
|
| 156 |
+
x = x.to(dtype)
|
| 157 |
+
# add positional embedding without cls token
|
| 158 |
+
B,N,C = x.size()
|
| 159 |
+
posvis = pos
|
| 160 |
+
# now apply the transformer encoder and normalization
|
| 161 |
+
for blk in self.enc_fine_blocks:
|
| 162 |
+
x = blk(x, posvis)
|
| 163 |
+
x = self.enc_fine_norm(x)
|
| 164 |
+
x, pos = self.patch_embed_fine2(x)
|
| 165 |
+
x = self.enc_fine_norm2(x)
|
| 166 |
+
return x, pos, None
|
| 167 |
+
|
| 168 |
+
def _encode_image(self, image, do_mask=False, return_all_blocks=False):
|
| 169 |
+
"""
|
| 170 |
+
image has B x 3 x img_size x img_size
|
| 171 |
+
do_mask: whether to perform masking or not
|
| 172 |
+
return_all_blocks: if True, return the features at the end of every block
|
| 173 |
+
instead of just the features from the last block (eg for some prediction heads)
|
| 174 |
+
"""
|
| 175 |
+
# embed the image into patches (x has size B x Npatches x C)
|
| 176 |
+
# and get position if each return patch (pos has size B x Npatches x 2)
|
| 177 |
+
x, pos = self.patch_embed(image)
|
| 178 |
+
# add positional embedding without cls token
|
| 179 |
+
if self.enc_pos_embed is not None:
|
| 180 |
+
x = x + self.enc_pos_embed[None,...]
|
| 181 |
+
# apply masking
|
| 182 |
+
B,N,C = x.size()
|
| 183 |
+
if do_mask:
|
| 184 |
+
masks = self.mask_generator(x)
|
| 185 |
+
x = x[~masks].view(B, -1, C)
|
| 186 |
+
posvis = pos[~masks].view(B, -1, 2)
|
| 187 |
+
else:
|
| 188 |
+
B,N,C = x.size()
|
| 189 |
+
masks = torch.zeros((B,N), dtype=bool)
|
| 190 |
+
posvis = pos
|
| 191 |
+
# now apply the transformer encoder and normalization
|
| 192 |
+
if return_all_blocks:
|
| 193 |
+
out = []
|
| 194 |
+
for blk in self.enc_blocks:
|
| 195 |
+
x = blk(x, posvis)
|
| 196 |
+
out.append(x)
|
| 197 |
+
out[-1] = self.enc_norm(out[-1])
|
| 198 |
+
return out, pos, masks
|
| 199 |
+
else:
|
| 200 |
+
for blk in self.enc_blocks:
|
| 201 |
+
x = blk(x, posvis)
|
| 202 |
+
x = self.enc_norm(x)
|
| 203 |
+
return x, pos, masks
|
| 204 |
+
|
| 205 |
+
def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
|
| 206 |
+
"""
|
| 207 |
+
return_all_blocks: if True, return the features at the end of every block
|
| 208 |
+
instead of just the features from the last block (eg for some prediction heads)
|
| 209 |
+
|
| 210 |
+
masks1 can be None => assume image1 fully visible
|
| 211 |
+
"""
|
| 212 |
+
# encoder to decoder layer
|
| 213 |
+
visf1 = self.decoder_embed(feat1)
|
| 214 |
+
f2 = self.decoder_embed(feat2)
|
| 215 |
+
# append masked tokens to the sequence
|
| 216 |
+
B,Nenc,C = visf1.size()
|
| 217 |
+
if masks1 is None: # downstreams
|
| 218 |
+
f1_ = visf1
|
| 219 |
+
else: # pretraining
|
| 220 |
+
Ntotal = masks1.size(1)
|
| 221 |
+
f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
|
| 222 |
+
f1_[~masks1] = visf1.view(B * Nenc, C)
|
| 223 |
+
# add positional embedding
|
| 224 |
+
if self.dec_pos_embed is not None:
|
| 225 |
+
f1_ = f1_ + self.dec_pos_embed
|
| 226 |
+
f2 = f2 + self.dec_pos_embed
|
| 227 |
+
# apply Transformer blocks
|
| 228 |
+
out = f1_
|
| 229 |
+
out2 = f2
|
| 230 |
+
if return_all_blocks:
|
| 231 |
+
_out, out = out, []
|
| 232 |
+
for blk in self.dec_blocks:
|
| 233 |
+
_out, out2 = blk(_out, out2, pos1, pos2)
|
| 234 |
+
out.append(_out)
|
| 235 |
+
out[-1] = self.dec_norm(out[-1])
|
| 236 |
+
else:
|
| 237 |
+
for blk in self.dec_blocks:
|
| 238 |
+
out, out2 = blk(out, out2, pos1, pos2)
|
| 239 |
+
out = self.dec_norm(out)
|
| 240 |
+
return out
|
| 241 |
+
|
| 242 |
+
def patchify(self, imgs):
|
| 243 |
+
"""
|
| 244 |
+
imgs: (B, 3, H, W)
|
| 245 |
+
x: (B, L, patch_size**2 *3)
|
| 246 |
+
"""
|
| 247 |
+
p = self.patch_embed.patch_size[0]
|
| 248 |
+
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
|
| 249 |
+
|
| 250 |
+
h = w = imgs.shape[2] // p
|
| 251 |
+
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
|
| 252 |
+
x = torch.einsum('nchpwq->nhwpqc', x)
|
| 253 |
+
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
|
| 254 |
+
|
| 255 |
+
return x
|
| 256 |
+
|
| 257 |
+
def unpatchify(self, x, channels=3):
|
| 258 |
+
"""
|
| 259 |
+
x: (N, L, patch_size**2 *channels)
|
| 260 |
+
imgs: (N, 3, H, W)
|
| 261 |
+
"""
|
| 262 |
+
patch_size = self.patch_embed.patch_size[0]
|
| 263 |
+
h = w = int(x.shape[1]**.5)
|
| 264 |
+
assert h * w == x.shape[1]
|
| 265 |
+
x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
|
| 266 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
| 267 |
+
imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size))
|
| 268 |
+
return imgs
|
| 269 |
+
|
| 270 |
+
def forward(self, img1, img2):
|
| 271 |
+
"""
|
| 272 |
+
img1: tensor of size B x 3 x img_size x img_size
|
| 273 |
+
img2: tensor of size B x 3 x img_size x img_size
|
| 274 |
+
|
| 275 |
+
out will be B x N x (3*patch_size*patch_size)
|
| 276 |
+
masks are also returned as B x N just in case
|
| 277 |
+
"""
|
| 278 |
+
# encoder of the masked first image
|
| 279 |
+
feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
|
| 280 |
+
# encoder of the second image
|
| 281 |
+
feat2, pos2, _ = self._encode_image(img2, do_mask=False)
|
| 282 |
+
# decoder
|
| 283 |
+
decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
|
| 284 |
+
# prediction head
|
| 285 |
+
out = self.prediction_head(decfeat)
|
| 286 |
+
# get target
|
| 287 |
+
target = self.patchify(img1)
|
| 288 |
+
return out, mask1, target
|
dust3r/croco/models/dpt_block.py
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# DPT head for ViTs
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/isl-org/DPT
|
| 9 |
+
# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from einops import rearrange, repeat
|
| 15 |
+
from typing import Union, Tuple, Iterable, List, Optional, Dict
|
| 16 |
+
|
| 17 |
+
def pair(t):
|
| 18 |
+
return t if isinstance(t, tuple) else (t, t)
|
| 19 |
+
|
| 20 |
+
def make_scratch(in_shape, out_shape, groups=1, expand=False):
|
| 21 |
+
scratch = nn.Module()
|
| 22 |
+
|
| 23 |
+
out_shape1 = out_shape
|
| 24 |
+
out_shape2 = out_shape
|
| 25 |
+
out_shape3 = out_shape
|
| 26 |
+
out_shape4 = out_shape
|
| 27 |
+
if expand == True:
|
| 28 |
+
out_shape1 = out_shape
|
| 29 |
+
out_shape2 = out_shape * 2
|
| 30 |
+
out_shape3 = out_shape * 4
|
| 31 |
+
out_shape4 = out_shape * 8
|
| 32 |
+
|
| 33 |
+
scratch.layer1_rn = nn.Conv2d(
|
| 34 |
+
in_shape[0],
|
| 35 |
+
out_shape1,
|
| 36 |
+
kernel_size=3,
|
| 37 |
+
stride=1,
|
| 38 |
+
padding=1,
|
| 39 |
+
bias=False,
|
| 40 |
+
groups=groups,
|
| 41 |
+
)
|
| 42 |
+
scratch.layer2_rn = nn.Conv2d(
|
| 43 |
+
in_shape[1],
|
| 44 |
+
out_shape2,
|
| 45 |
+
kernel_size=3,
|
| 46 |
+
stride=1,
|
| 47 |
+
padding=1,
|
| 48 |
+
bias=False,
|
| 49 |
+
groups=groups,
|
| 50 |
+
)
|
| 51 |
+
scratch.layer3_rn = nn.Conv2d(
|
| 52 |
+
in_shape[2],
|
| 53 |
+
out_shape3,
|
| 54 |
+
kernel_size=3,
|
| 55 |
+
stride=1,
|
| 56 |
+
padding=1,
|
| 57 |
+
bias=False,
|
| 58 |
+
groups=groups,
|
| 59 |
+
)
|
| 60 |
+
scratch.layer4_rn = nn.Conv2d(
|
| 61 |
+
in_shape[3],
|
| 62 |
+
out_shape4,
|
| 63 |
+
kernel_size=3,
|
| 64 |
+
stride=1,
|
| 65 |
+
padding=1,
|
| 66 |
+
bias=False,
|
| 67 |
+
groups=groups,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
scratch.layer_rn = nn.ModuleList([
|
| 71 |
+
scratch.layer1_rn,
|
| 72 |
+
scratch.layer2_rn,
|
| 73 |
+
scratch.layer3_rn,
|
| 74 |
+
scratch.layer4_rn,
|
| 75 |
+
])
|
| 76 |
+
|
| 77 |
+
return scratch
|
| 78 |
+
|
| 79 |
+
class ResidualConvUnit_custom(nn.Module):
|
| 80 |
+
"""Residual convolution module."""
|
| 81 |
+
|
| 82 |
+
def __init__(self, features, activation, bn):
|
| 83 |
+
"""Init.
|
| 84 |
+
Args:
|
| 85 |
+
features (int): number of features
|
| 86 |
+
"""
|
| 87 |
+
super().__init__()
|
| 88 |
+
|
| 89 |
+
self.bn = bn
|
| 90 |
+
|
| 91 |
+
self.groups = 1
|
| 92 |
+
|
| 93 |
+
self.conv1 = nn.Conv2d(
|
| 94 |
+
features,
|
| 95 |
+
features,
|
| 96 |
+
kernel_size=3,
|
| 97 |
+
stride=1,
|
| 98 |
+
padding=1,
|
| 99 |
+
bias=not self.bn,
|
| 100 |
+
groups=self.groups,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
self.conv2 = nn.Conv2d(
|
| 104 |
+
features,
|
| 105 |
+
features,
|
| 106 |
+
kernel_size=3,
|
| 107 |
+
stride=1,
|
| 108 |
+
padding=1,
|
| 109 |
+
bias=not self.bn,
|
| 110 |
+
groups=self.groups,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
if self.bn == True:
|
| 114 |
+
self.bn1 = nn.BatchNorm2d(features)
|
| 115 |
+
self.bn2 = nn.BatchNorm2d(features)
|
| 116 |
+
|
| 117 |
+
self.activation = activation
|
| 118 |
+
|
| 119 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 120 |
+
|
| 121 |
+
def forward(self, x):
|
| 122 |
+
"""Forward pass.
|
| 123 |
+
Args:
|
| 124 |
+
x (tensor): input
|
| 125 |
+
Returns:
|
| 126 |
+
tensor: output
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
out = self.activation(x)
|
| 130 |
+
out = self.conv1(out)
|
| 131 |
+
if self.bn == True:
|
| 132 |
+
out = self.bn1(out)
|
| 133 |
+
|
| 134 |
+
out = self.activation(out)
|
| 135 |
+
out = self.conv2(out)
|
| 136 |
+
if self.bn == True:
|
| 137 |
+
out = self.bn2(out)
|
| 138 |
+
|
| 139 |
+
if self.groups > 1:
|
| 140 |
+
out = self.conv_merge(out)
|
| 141 |
+
|
| 142 |
+
return self.skip_add.add(out, x)
|
| 143 |
+
|
| 144 |
+
class FeatureFusionBlock_custom(nn.Module):
|
| 145 |
+
"""Feature fusion block."""
|
| 146 |
+
|
| 147 |
+
def __init__(
|
| 148 |
+
self,
|
| 149 |
+
features,
|
| 150 |
+
activation,
|
| 151 |
+
deconv=False,
|
| 152 |
+
bn=False,
|
| 153 |
+
expand=False,
|
| 154 |
+
align_corners=True,
|
| 155 |
+
width_ratio=1,
|
| 156 |
+
):
|
| 157 |
+
"""Init.
|
| 158 |
+
Args:
|
| 159 |
+
features (int): number of features
|
| 160 |
+
"""
|
| 161 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
| 162 |
+
self.width_ratio = width_ratio
|
| 163 |
+
|
| 164 |
+
self.deconv = deconv
|
| 165 |
+
self.align_corners = align_corners
|
| 166 |
+
|
| 167 |
+
self.groups = 1
|
| 168 |
+
|
| 169 |
+
self.expand = expand
|
| 170 |
+
out_features = features
|
| 171 |
+
if self.expand == True:
|
| 172 |
+
out_features = features // 2
|
| 173 |
+
|
| 174 |
+
self.out_conv = nn.Conv2d(
|
| 175 |
+
features,
|
| 176 |
+
out_features,
|
| 177 |
+
kernel_size=1,
|
| 178 |
+
stride=1,
|
| 179 |
+
padding=0,
|
| 180 |
+
bias=True,
|
| 181 |
+
groups=1,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
| 185 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
| 186 |
+
|
| 187 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 188 |
+
|
| 189 |
+
def forward(self, *xs):
|
| 190 |
+
"""Forward pass.
|
| 191 |
+
Returns:
|
| 192 |
+
tensor: output
|
| 193 |
+
"""
|
| 194 |
+
output = xs[0]
|
| 195 |
+
|
| 196 |
+
if len(xs) == 2:
|
| 197 |
+
res = self.resConfUnit1(xs[1])
|
| 198 |
+
if self.width_ratio != 1:
|
| 199 |
+
res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')
|
| 200 |
+
|
| 201 |
+
output = self.skip_add.add(output, res)
|
| 202 |
+
# output += res
|
| 203 |
+
|
| 204 |
+
output = self.resConfUnit2(output)
|
| 205 |
+
|
| 206 |
+
if self.width_ratio != 1:
|
| 207 |
+
# and output.shape[3] < self.width_ratio * output.shape[2]
|
| 208 |
+
#size=(image.shape[])
|
| 209 |
+
if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
|
| 210 |
+
shape = 3 * output.shape[3]
|
| 211 |
+
else:
|
| 212 |
+
shape = int(self.width_ratio * 2 * output.shape[2])
|
| 213 |
+
output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear')
|
| 214 |
+
else:
|
| 215 |
+
output = nn.functional.interpolate(output, scale_factor=2,
|
| 216 |
+
mode="bilinear", align_corners=self.align_corners)
|
| 217 |
+
output = self.out_conv(output)
|
| 218 |
+
return output
|
| 219 |
+
|
| 220 |
+
def make_fusion_block(features, use_bn, width_ratio=1):
|
| 221 |
+
return FeatureFusionBlock_custom(
|
| 222 |
+
features,
|
| 223 |
+
nn.ReLU(False),
|
| 224 |
+
deconv=False,
|
| 225 |
+
bn=use_bn,
|
| 226 |
+
expand=False,
|
| 227 |
+
align_corners=True,
|
| 228 |
+
width_ratio=width_ratio,
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
class Interpolate(nn.Module):
|
| 232 |
+
"""Interpolation module."""
|
| 233 |
+
|
| 234 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
| 235 |
+
"""Init.
|
| 236 |
+
Args:
|
| 237 |
+
scale_factor (float): scaling
|
| 238 |
+
mode (str): interpolation mode
|
| 239 |
+
"""
|
| 240 |
+
super(Interpolate, self).__init__()
|
| 241 |
+
|
| 242 |
+
self.interp = nn.functional.interpolate
|
| 243 |
+
self.scale_factor = scale_factor
|
| 244 |
+
self.mode = mode
|
| 245 |
+
self.align_corners = align_corners
|
| 246 |
+
|
| 247 |
+
def forward(self, x):
|
| 248 |
+
"""Forward pass.
|
| 249 |
+
Args:
|
| 250 |
+
x (tensor): input
|
| 251 |
+
Returns:
|
| 252 |
+
tensor: interpolated data
|
| 253 |
+
"""
|
| 254 |
+
dtype = x.dtype
|
| 255 |
+
x = self.interp(
|
| 256 |
+
x.float(),
|
| 257 |
+
scale_factor=self.scale_factor,
|
| 258 |
+
mode=self.mode,
|
| 259 |
+
align_corners=self.align_corners,
|
| 260 |
+
)
|
| 261 |
+
x = x.to(dtype)
|
| 262 |
+
return x
|
| 263 |
+
|
| 264 |
+
class DPTOutputAdapter(nn.Module):
|
| 265 |
+
"""DPT output adapter.
|
| 266 |
+
|
| 267 |
+
:param num_cahnnels: Number of output channels
|
| 268 |
+
:param stride_level: tride level compared to the full-sized image.
|
| 269 |
+
E.g. 4 for 1/4th the size of the image.
|
| 270 |
+
:param patch_size_full: Int or tuple of the patch size over the full image size.
|
| 271 |
+
Patch size for smaller inputs will be computed accordingly.
|
| 272 |
+
:param hooks: Index of intermediate layers
|
| 273 |
+
:param layer_dims: Dimension of intermediate layers
|
| 274 |
+
:param feature_dim: Feature dimension
|
| 275 |
+
:param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
|
| 276 |
+
:param use_bn: If set to True, activates batch norm
|
| 277 |
+
:param dim_tokens_enc: Dimension of tokens coming from encoder
|
| 278 |
+
"""
|
| 279 |
+
|
| 280 |
+
def __init__(self,
|
| 281 |
+
num_channels: int = 1,
|
| 282 |
+
stride_level: int = 1,
|
| 283 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 284 |
+
main_tasks: Iterable[str] = ('rgb',),
|
| 285 |
+
hooks: List[int] = [2, 5, 8, 11],
|
| 286 |
+
layer_dims: List[int] = [96, 192, 384, 768],
|
| 287 |
+
feature_dim: int = 256,
|
| 288 |
+
last_dim: int = 32,
|
| 289 |
+
use_bn: bool = False,
|
| 290 |
+
dim_tokens_enc: Optional[int] = None,
|
| 291 |
+
head_type: str = 'regression',
|
| 292 |
+
output_width_ratio=1,
|
| 293 |
+
**kwargs):
|
| 294 |
+
super().__init__()
|
| 295 |
+
self.num_channels = num_channels
|
| 296 |
+
self.stride_level = stride_level
|
| 297 |
+
self.patch_size = pair(patch_size)
|
| 298 |
+
self.main_tasks = main_tasks
|
| 299 |
+
self.hooks = hooks
|
| 300 |
+
self.layer_dims = layer_dims
|
| 301 |
+
self.feature_dim = feature_dim
|
| 302 |
+
self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
|
| 303 |
+
self.head_type = head_type
|
| 304 |
+
|
| 305 |
+
# Actual patch height and width, taking into account stride of input
|
| 306 |
+
self.P_H = max(1, self.patch_size[0] // stride_level)
|
| 307 |
+
self.P_W = max(1, self.patch_size[1] // stride_level)
|
| 308 |
+
|
| 309 |
+
self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
|
| 310 |
+
|
| 311 |
+
self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
|
| 312 |
+
self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
|
| 313 |
+
self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
|
| 314 |
+
self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
|
| 315 |
+
|
| 316 |
+
if self.head_type == 'regression':
|
| 317 |
+
# The "DPTDepthModel" head
|
| 318 |
+
self.head = nn.Sequential(
|
| 319 |
+
nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
|
| 320 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
| 321 |
+
nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
|
| 322 |
+
nn.ReLU(True),
|
| 323 |
+
nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
|
| 324 |
+
)
|
| 325 |
+
elif self.head_type == 'semseg':
|
| 326 |
+
# The "DPTSegmentationModel" head
|
| 327 |
+
self.head = nn.Sequential(
|
| 328 |
+
nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
|
| 329 |
+
nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
|
| 330 |
+
nn.ReLU(True),
|
| 331 |
+
nn.Dropout(0.1, False),
|
| 332 |
+
nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
|
| 333 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
| 334 |
+
)
|
| 335 |
+
else:
|
| 336 |
+
raise ValueError('DPT head_type must be "regression" or "semseg".')
|
| 337 |
+
|
| 338 |
+
if self.dim_tokens_enc is not None:
|
| 339 |
+
self.init(dim_tokens_enc=dim_tokens_enc)
|
| 340 |
+
|
| 341 |
+
def init(self, dim_tokens_enc=768):
|
| 342 |
+
"""
|
| 343 |
+
Initialize parts of decoder that are dependent on dimension of encoder tokens.
|
| 344 |
+
Should be called when setting up MultiMAE.
|
| 345 |
+
|
| 346 |
+
:param dim_tokens_enc: Dimension of tokens coming from encoder
|
| 347 |
+
"""
|
| 348 |
+
#print(dim_tokens_enc)
|
| 349 |
+
|
| 350 |
+
# Set up activation postprocessing layers
|
| 351 |
+
if isinstance(dim_tokens_enc, int):
|
| 352 |
+
dim_tokens_enc = 4 * [dim_tokens_enc]
|
| 353 |
+
|
| 354 |
+
self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
|
| 355 |
+
|
| 356 |
+
self.act_1_postprocess = nn.Sequential(
|
| 357 |
+
nn.Conv2d(
|
| 358 |
+
in_channels=self.dim_tokens_enc[0],
|
| 359 |
+
out_channels=self.layer_dims[0],
|
| 360 |
+
kernel_size=1, stride=1, padding=0,
|
| 361 |
+
),
|
| 362 |
+
nn.ConvTranspose2d(
|
| 363 |
+
in_channels=self.layer_dims[0],
|
| 364 |
+
out_channels=self.layer_dims[0],
|
| 365 |
+
kernel_size=4, stride=4, padding=0,
|
| 366 |
+
bias=True, dilation=1, groups=1,
|
| 367 |
+
)
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
self.act_2_postprocess = nn.Sequential(
|
| 371 |
+
nn.Conv2d(
|
| 372 |
+
in_channels=self.dim_tokens_enc[1],
|
| 373 |
+
out_channels=self.layer_dims[1],
|
| 374 |
+
kernel_size=1, stride=1, padding=0,
|
| 375 |
+
),
|
| 376 |
+
nn.ConvTranspose2d(
|
| 377 |
+
in_channels=self.layer_dims[1],
|
| 378 |
+
out_channels=self.layer_dims[1],
|
| 379 |
+
kernel_size=2, stride=2, padding=0,
|
| 380 |
+
bias=True, dilation=1, groups=1,
|
| 381 |
+
)
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
self.act_3_postprocess = nn.Sequential(
|
| 385 |
+
nn.Conv2d(
|
| 386 |
+
in_channels=self.dim_tokens_enc[2],
|
| 387 |
+
out_channels=self.layer_dims[2],
|
| 388 |
+
kernel_size=1, stride=1, padding=0,
|
| 389 |
+
)
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
self.act_4_postprocess = nn.Sequential(
|
| 393 |
+
nn.Conv2d(
|
| 394 |
+
in_channels=self.dim_tokens_enc[3],
|
| 395 |
+
out_channels=self.layer_dims[3],
|
| 396 |
+
kernel_size=1, stride=1, padding=0,
|
| 397 |
+
),
|
| 398 |
+
nn.Conv2d(
|
| 399 |
+
in_channels=self.layer_dims[3],
|
| 400 |
+
out_channels=self.layer_dims[3],
|
| 401 |
+
kernel_size=3, stride=2, padding=1,
|
| 402 |
+
)
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
self.act_postprocess = nn.ModuleList([
|
| 406 |
+
self.act_1_postprocess,
|
| 407 |
+
self.act_2_postprocess,
|
| 408 |
+
self.act_3_postprocess,
|
| 409 |
+
self.act_4_postprocess
|
| 410 |
+
])
|
| 411 |
+
|
| 412 |
+
def adapt_tokens(self, encoder_tokens):
|
| 413 |
+
# Adapt tokens
|
| 414 |
+
x = []
|
| 415 |
+
x.append(encoder_tokens[:, :])
|
| 416 |
+
x = torch.cat(x, dim=-1)
|
| 417 |
+
return x
|
| 418 |
+
|
| 419 |
+
def forward(self, encoder_tokens: List[torch.Tensor], image_size):
|
| 420 |
+
#input_info: Dict):
|
| 421 |
+
assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
|
| 422 |
+
H, W = image_size
|
| 423 |
+
|
| 424 |
+
# Number of patches in height and width
|
| 425 |
+
N_H = H // (self.stride_level * self.P_H)
|
| 426 |
+
N_W = W // (self.stride_level * self.P_W)
|
| 427 |
+
|
| 428 |
+
# Hook decoder onto 4 layers from specified ViT layers
|
| 429 |
+
layers = [encoder_tokens[hook] for hook in self.hooks]
|
| 430 |
+
|
| 431 |
+
# Extract only task-relevant tokens and ignore global tokens.
|
| 432 |
+
layers = [self.adapt_tokens(l) for l in layers]
|
| 433 |
+
|
| 434 |
+
# Reshape tokens to spatial representation
|
| 435 |
+
layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
|
| 436 |
+
|
| 437 |
+
layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
|
| 438 |
+
# Project layers to chosen feature dim
|
| 439 |
+
layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
|
| 440 |
+
|
| 441 |
+
# Fuse layers using refinement stages
|
| 442 |
+
path_4 = self.scratch.refinenet4(layers[3])
|
| 443 |
+
path_3 = self.scratch.refinenet3(path_4, layers[2])
|
| 444 |
+
path_2 = self.scratch.refinenet2(path_3, layers[1])
|
| 445 |
+
path_1 = self.scratch.refinenet1(path_2, layers[0])
|
| 446 |
+
|
| 447 |
+
# Output head
|
| 448 |
+
out = self.head(path_1)
|
| 449 |
+
|
| 450 |
+
return out
|
dust3r/croco/models/head_downstream.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Heads for downstream tasks
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
A head is a module where the __init__ defines only the head hyperparameters.
|
| 10 |
+
A method setup(croconet) takes a CroCoNet and set all layers according to the head and croconet attributes.
|
| 11 |
+
The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height'
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
from .dpt_block import DPTOutputAdapter
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class PixelwiseTaskWithDPT(nn.Module):
|
| 20 |
+
""" DPT module for CroCo.
|
| 21 |
+
by default, hooks_idx will be equal to:
|
| 22 |
+
* for encoder-only: 4 equally spread layers
|
| 23 |
+
* for encoder+decoder: last encoder + 3 equally spread layers of the decoder
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, *, hooks_idx=None, layer_dims=[96,192,384,768],
|
| 27 |
+
output_width_ratio=1, num_channels=1, postprocess=None, **kwargs):
|
| 28 |
+
super(PixelwiseTaskWithDPT, self).__init__()
|
| 29 |
+
self.return_all_blocks = True # backbone needs to return all layers
|
| 30 |
+
self.postprocess = postprocess
|
| 31 |
+
self.output_width_ratio = output_width_ratio
|
| 32 |
+
self.num_channels = num_channels
|
| 33 |
+
self.hooks_idx = hooks_idx
|
| 34 |
+
self.layer_dims = layer_dims
|
| 35 |
+
|
| 36 |
+
def setup(self, croconet):
|
| 37 |
+
dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels}
|
| 38 |
+
if self.hooks_idx is None:
|
| 39 |
+
if hasattr(croconet, 'dec_blocks'): # encoder + decoder
|
| 40 |
+
step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth]
|
| 41 |
+
hooks_idx = [croconet.dec_depth+croconet.enc_depth-1-i*step for i in range(3,-1,-1)]
|
| 42 |
+
else: # encoder only
|
| 43 |
+
step = croconet.enc_depth//4
|
| 44 |
+
hooks_idx = [croconet.enc_depth-1-i*step for i in range(3,-1,-1)]
|
| 45 |
+
self.hooks_idx = hooks_idx
|
| 46 |
+
print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}')
|
| 47 |
+
dpt_args['hooks'] = self.hooks_idx
|
| 48 |
+
dpt_args['layer_dims'] = self.layer_dims
|
| 49 |
+
self.dpt = DPTOutputAdapter(**dpt_args)
|
| 50 |
+
dim_tokens = [croconet.enc_embed_dim if hook<croconet.enc_depth else croconet.dec_embed_dim for hook in self.hooks_idx]
|
| 51 |
+
dpt_init_args = {'dim_tokens_enc': dim_tokens}
|
| 52 |
+
self.dpt.init(**dpt_init_args)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def forward(self, x, img_info):
|
| 56 |
+
out = self.dpt(x, image_size=(img_info['height'],img_info['width']))
|
| 57 |
+
if self.postprocess: out = self.postprocess(out)
|
| 58 |
+
return out
|
dust3r/croco/models/masking.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
# Masking utils
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
class RandomMask(nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
random masking
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
def __init__(self, num_patches, mask_ratio):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.num_patches = num_patches
|
| 20 |
+
self.num_mask = int(mask_ratio * self.num_patches)
|
| 21 |
+
|
| 22 |
+
def __call__(self, x):
|
| 23 |
+
noise = torch.rand(x.size(0), self.num_patches, device=x.device)
|
| 24 |
+
argsort = torch.argsort(noise, dim=1)
|
| 25 |
+
return argsort < self.num_mask
|
dust3r/croco/models/pos_embed.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# --------------------------------------------------------
|
| 6 |
+
# Position embedding utils
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
# --------------------------------------------------------
|
| 16 |
+
# 2D sine-cosine position embedding
|
| 17 |
+
# References:
|
| 18 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 19 |
+
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
|
| 20 |
+
# MoCo v3: https://github.com/facebookresearch/moco-v3
|
| 21 |
+
# --------------------------------------------------------
|
| 22 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
|
| 23 |
+
"""
|
| 24 |
+
grid_size: int of the grid height and width
|
| 25 |
+
return:
|
| 26 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 27 |
+
"""
|
| 28 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 29 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 30 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 31 |
+
grid = np.stack(grid, axis=0)
|
| 32 |
+
|
| 33 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 34 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 35 |
+
if n_cls_token>0:
|
| 36 |
+
pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
|
| 37 |
+
return pos_embed
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 41 |
+
assert embed_dim % 2 == 0
|
| 42 |
+
|
| 43 |
+
# use half of dimensions to encode grid_h
|
| 44 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 45 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 46 |
+
|
| 47 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 48 |
+
return emb
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 52 |
+
"""
|
| 53 |
+
embed_dim: output dimension for each position
|
| 54 |
+
pos: a list of positions to be encoded: size (M,)
|
| 55 |
+
out: (M, D)
|
| 56 |
+
"""
|
| 57 |
+
assert embed_dim % 2 == 0
|
| 58 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
| 59 |
+
omega /= embed_dim / 2.
|
| 60 |
+
omega = 1. / 10000**omega # (D/2,)
|
| 61 |
+
|
| 62 |
+
pos = pos.reshape(-1) # (M,)
|
| 63 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
| 64 |
+
|
| 65 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 66 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 67 |
+
|
| 68 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 69 |
+
return emb
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# --------------------------------------------------------
|
| 73 |
+
# Interpolate position embeddings for high-resolution
|
| 74 |
+
# References:
|
| 75 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 76 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 77 |
+
# --------------------------------------------------------
|
| 78 |
+
def interpolate_pos_embed(model, checkpoint_model):
|
| 79 |
+
if 'pos_embed' in checkpoint_model:
|
| 80 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed']
|
| 81 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
| 82 |
+
num_patches = model.patch_embed.num_patches
|
| 83 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
| 84 |
+
# height (== width) for the checkpoint position embedding
|
| 85 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
| 86 |
+
# height (== width) for the new position embedding
|
| 87 |
+
new_size = int(num_patches ** 0.5)
|
| 88 |
+
# class_token and dist_token are kept unchanged
|
| 89 |
+
if orig_size != new_size:
|
| 90 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
| 91 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
| 92 |
+
# only the position tokens are interpolated
|
| 93 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
| 94 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
| 95 |
+
pos_tokens = torch.nn.functional.interpolate(
|
| 96 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
| 97 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
| 98 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
| 99 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
#----------------------------------------------------------
|
| 103 |
+
# RoPE2D: RoPE implementation in 2D
|
| 104 |
+
#----------------------------------------------------------
|
| 105 |
+
|
| 106 |
+
try:
|
| 107 |
+
from models.curope import cuRoPE2D
|
| 108 |
+
RoPE2D = cuRoPE2D
|
| 109 |
+
except ImportError:
|
| 110 |
+
print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
|
| 111 |
+
|
| 112 |
+
class RoPE2D(torch.nn.Module):
|
| 113 |
+
|
| 114 |
+
def __init__(self, freq=100.0, F0=1.0):
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.base = freq
|
| 117 |
+
self.F0 = F0
|
| 118 |
+
self.cache = {}
|
| 119 |
+
|
| 120 |
+
def get_cos_sin(self, D, seq_len, device, dtype):
|
| 121 |
+
if (D,seq_len,device,dtype) not in self.cache:
|
| 122 |
+
inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
|
| 123 |
+
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
| 124 |
+
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
|
| 125 |
+
freqs = torch.cat((freqs, freqs), dim=-1)
|
| 126 |
+
cos = freqs.cos() # (Seq, Dim)
|
| 127 |
+
sin = freqs.sin()
|
| 128 |
+
self.cache[D,seq_len,device,dtype] = (cos,sin)
|
| 129 |
+
return self.cache[D,seq_len,device,dtype]
|
| 130 |
+
|
| 131 |
+
@staticmethod
|
| 132 |
+
def rotate_half(x):
|
| 133 |
+
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
|
| 134 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 135 |
+
|
| 136 |
+
def apply_rope1d(self, tokens, pos1d, cos, sin):
|
| 137 |
+
assert pos1d.ndim==2
|
| 138 |
+
cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
|
| 139 |
+
sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
|
| 140 |
+
return (tokens * cos) + (self.rotate_half(tokens) * sin)
|
| 141 |
+
|
| 142 |
+
def forward(self, tokens, positions):
|
| 143 |
+
"""
|
| 144 |
+
input:
|
| 145 |
+
* tokens: batch_size x nheads x ntokens x dim
|
| 146 |
+
* positions: batch_size x ntokens x 2 (y and x position of each token)
|
| 147 |
+
output:
|
| 148 |
+
* tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
|
| 149 |
+
"""
|
| 150 |
+
assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
|
| 151 |
+
D = tokens.size(3) // 2
|
| 152 |
+
assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
|
| 153 |
+
cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
|
| 154 |
+
# split features into two along the feature dimension, and apply rope1d on each half
|
| 155 |
+
y, x = tokens.chunk(2, dim=-1)
|
| 156 |
+
y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
|
| 157 |
+
x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
|
| 158 |
+
tokens = torch.cat((y, x), dim=-1)
|
| 159 |
+
return tokens
|
dust3r/croco/models/transformer_utils.py
ADDED
|
@@ -0,0 +1,1021 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from typing import Tuple, Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from torch import nn, einsum
|
| 11 |
+
from torch.utils.checkpoint import checkpoint
|
| 12 |
+
|
| 13 |
+
from einops import rearrange, repeat
|
| 14 |
+
|
| 15 |
+
from inspect import isfunction
|
| 16 |
+
try:
|
| 17 |
+
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func, flash_attn_varlen_qkvpacked_func
|
| 18 |
+
from flash_attn.bert_padding import unpad_input, pad_input
|
| 19 |
+
except:
|
| 20 |
+
flash_attn_qkvpacked_func, flash_attn_func, flash_attn_varlen_qkvpacked_func = None, None, None
|
| 21 |
+
unpad_input, pad_input = None, None
|
| 22 |
+
from .x_transformer import AttentionLayers, BasicEncoder
|
| 23 |
+
import math
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _init_weights(module):
|
| 28 |
+
if isinstance(module, nn.Linear):
|
| 29 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
| 30 |
+
if module.bias is not None:
|
| 31 |
+
torch.nn.init.zeros_(module.bias)
|
| 32 |
+
|
| 33 |
+
def exists(val):
|
| 34 |
+
return val is not None
|
| 35 |
+
|
| 36 |
+
def default(val, d):
|
| 37 |
+
if exists(val):
|
| 38 |
+
return val
|
| 39 |
+
return d() if isfunction(d) else d
|
| 40 |
+
|
| 41 |
+
def zero_module(module):
|
| 42 |
+
"""
|
| 43 |
+
Zero out the parameters of a module and return it.
|
| 44 |
+
"""
|
| 45 |
+
for p in module.parameters():
|
| 46 |
+
p.detach().zero_()
|
| 47 |
+
return module
|
| 48 |
+
|
| 49 |
+
# Copy from CLIP GitHub
|
| 50 |
+
class LayerNorm(nn.LayerNorm):
|
| 51 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
| 52 |
+
|
| 53 |
+
def forward(self, x: torch.Tensor):
|
| 54 |
+
orig_type = x.dtype
|
| 55 |
+
ret = super().forward(x.type(torch.float32))
|
| 56 |
+
return ret.type(orig_type)
|
| 57 |
+
|
| 58 |
+
class QuickGELU(nn.Module):
|
| 59 |
+
def forward(self, x: torch.Tensor):
|
| 60 |
+
return x * torch.sigmoid(1.702 * x)
|
| 61 |
+
|
| 62 |
+
def modulate(x, shift, scale):
|
| 63 |
+
# from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
| 64 |
+
return x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MultiheadAttentionFlashV2(nn.Module):
|
| 68 |
+
def __init__(self, embed_dim, n_head, bias=False, shift_group=None, qkv_packed=False, window_size=None):
|
| 69 |
+
super().__init__()
|
| 70 |
+
|
| 71 |
+
self.head_dim = embed_dim// n_head
|
| 72 |
+
self.embed_dim = embed_dim
|
| 73 |
+
self.n_head = n_head
|
| 74 |
+
self.to_q = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 75 |
+
self.to_k = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 76 |
+
self.to_v = nn.Linear(embed_dim, embed_dim, bias=bias)
|
| 77 |
+
self.shift_group = shift_group
|
| 78 |
+
self.qkv_packed = qkv_packed
|
| 79 |
+
self.window_size = window_size
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def forward(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, need_weights=False, attn_mask=None):
|
| 83 |
+
q = q.permute(1, 0, 2)
|
| 84 |
+
k = k.permute(1, 0, 2)
|
| 85 |
+
v = v.permute(1, 0, 2)
|
| 86 |
+
|
| 87 |
+
h = self.n_head
|
| 88 |
+
q = self.to_q(q)
|
| 89 |
+
k = self.to_k(k)
|
| 90 |
+
v = self.to_v(v)
|
| 91 |
+
|
| 92 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v))
|
| 93 |
+
# print(q.dtype, k.dtype, v.dtype)
|
| 94 |
+
if self.qkv_packed:
|
| 95 |
+
bsz, q_len, heads, head_dim = q.shape
|
| 96 |
+
group_size = self.shift_group
|
| 97 |
+
nheads = self.n_head
|
| 98 |
+
qkv = torch.stack([q,k,v], dim=2)
|
| 99 |
+
qkv = qkv.reshape(bsz, q_len, 3, 2, nheads // 2, self.head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2,
|
| 100 |
+
q_len, 3,
|
| 101 |
+
nheads // 2,
|
| 102 |
+
self.head_dim)
|
| 103 |
+
|
| 104 |
+
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
| 105 |
+
key_padding_mask = torch.ones(x.shape[0], x.shape[1], device=x.device, dtype=x.dtype)
|
| 106 |
+
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
| 107 |
+
cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)
|
| 108 |
+
cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2
|
| 109 |
+
cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).min
|
| 110 |
+
cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)
|
| 111 |
+
cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)
|
| 112 |
+
cu_q_lens = cu_q_lens[cu_q_lens >= 0]
|
| 113 |
+
x_unpad = rearrange(
|
| 114 |
+
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2
|
| 115 |
+
)
|
| 116 |
+
output_unpad = flash_attn_varlen_qkvpacked_func(
|
| 117 |
+
x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=False,
|
| 118 |
+
)
|
| 119 |
+
output = rearrange(
|
| 120 |
+
pad_input(
|
| 121 |
+
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len
|
| 122 |
+
),
|
| 123 |
+
"b s (h d) -> b s h d",
|
| 124 |
+
h=nheads // 2,
|
| 125 |
+
)
|
| 126 |
+
r_out = output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim).transpose(1, 2).reshape(bsz, q_len, nheads,
|
| 127 |
+
self.head_dim)
|
| 128 |
+
else:
|
| 129 |
+
if self.shift_group is not None:
|
| 130 |
+
bsz, q_len, heads, head_dim = q.shape
|
| 131 |
+
assert q_len % self.shift_group == 0
|
| 132 |
+
|
| 133 |
+
def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
|
| 134 |
+
qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2)
|
| 135 |
+
qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)
|
| 136 |
+
return qkv
|
| 137 |
+
|
| 138 |
+
q = shift(q, bsz, q_len, self.shift_group, h, self.head_dim)
|
| 139 |
+
k = shift(k, bsz, q_len, self.shift_group, h, self.head_dim)
|
| 140 |
+
v = shift(v, bsz, q_len, self.shift_group, h, self.head_dim)
|
| 141 |
+
if self.window_size:
|
| 142 |
+
out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal, window_size=(self.window_size // 2, self.window_size // 2))
|
| 143 |
+
else:
|
| 144 |
+
out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal)
|
| 145 |
+
|
| 146 |
+
if self.shift_group is not None:
|
| 147 |
+
out = out.transpose(1, 2).contiguous()
|
| 148 |
+
out = rearrange(out, '(b l) g h d -> b (l g) h d', l=q_len // self.shift_group)
|
| 149 |
+
r_out = out.clone()
|
| 150 |
+
r_out[:, :, h//2:] = r_out[:, :, h//2:].roll(h//2, dims=1)
|
| 151 |
+
else:
|
| 152 |
+
r_out = out
|
| 153 |
+
|
| 154 |
+
r_out = rearrange(r_out, 'b n h d -> b n (h d)')
|
| 155 |
+
r_out = r_out.permute(1, 0, 2)
|
| 156 |
+
return (r_out,)
|
| 157 |
+
|
| 158 |
+
class PSUpsamplerBlock(nn.Module):
|
| 159 |
+
def __init__(self, d_model: int, d_model_out: int, scale_factor: int):
|
| 160 |
+
super().__init__()
|
| 161 |
+
|
| 162 |
+
# self.mlp = nn.Sequential(OrderedDict([
|
| 163 |
+
# ("c_fc", nn.Linear(d_model, d_model_out * scale_factor**2)),
|
| 164 |
+
# ("gelu", QuickGELU()),
|
| 165 |
+
# ("c_proj", nn.Linear(d_model_out * scale_factor**2, d_model_out * scale_factor**2))
|
| 166 |
+
# ]))
|
| 167 |
+
# self.ln_2 = LayerNorm(d_model)
|
| 168 |
+
self.scale_factor = scale_factor
|
| 169 |
+
self.d_model_out = d_model_out
|
| 170 |
+
self.residual_fc = nn.Linear(d_model, d_model_out * (scale_factor**2))
|
| 171 |
+
self.pixelshuffle = nn.PixelShuffle(scale_factor)
|
| 172 |
+
|
| 173 |
+
def forward(self, x: torch.Tensor):
|
| 174 |
+
# mlp block
|
| 175 |
+
# x.shape b, l, d
|
| 176 |
+
# y = self.ln_2(x)
|
| 177 |
+
# y = self.mlp(y)
|
| 178 |
+
# For here we have two cases:
|
| 179 |
+
# 1. If we have a modulation function for the MLP, we use it to modulate the output of the MLP
|
| 180 |
+
# 2. If we don't have a modulation function for the MLP, we use the modulation function for the attention
|
| 181 |
+
x = self.residual_fc(x)# .repeat(1, 1, self.scale_factor**2)
|
| 182 |
+
# x = x + y
|
| 183 |
+
bs, l, c = x.shape
|
| 184 |
+
resolution = int(np.sqrt(l))
|
| 185 |
+
x = x.permute(0, 2, 1).reshape(bs, c, resolution, resolution)
|
| 186 |
+
x = self.pixelshuffle(x)
|
| 187 |
+
x = x.reshape(bs, self.d_model_out, resolution*self.scale_factor*resolution*self.scale_factor)
|
| 188 |
+
x = x.permute(0, 2, 1)
|
| 189 |
+
# x = rearrange(x, 'b l (s c) -> b (l s) c', s=self.scale_factor**2)
|
| 190 |
+
return x
|
| 191 |
+
|
| 192 |
+
class ResidualAttentionBlock(nn.Module):
|
| 193 |
+
def __init__(self, d_model: int,
|
| 194 |
+
n_head: int,
|
| 195 |
+
attn_mask: torch.Tensor = None,
|
| 196 |
+
modulate_feature_size: int = None,
|
| 197 |
+
modulate_act_type: str = 'gelu',
|
| 198 |
+
cross_att: bool = None,
|
| 199 |
+
flash_v2: bool = None,
|
| 200 |
+
qkv_packed: bool = None,
|
| 201 |
+
shift_group: int = None,
|
| 202 |
+
window_size: int = None,):
|
| 203 |
+
super().__init__()
|
| 204 |
+
|
| 205 |
+
print('vit flashv2', flash_v2)
|
| 206 |
+
|
| 207 |
+
self.flash_v2 = flash_v2
|
| 208 |
+
self.window_size = window_size
|
| 209 |
+
if self.flash_v2:
|
| 210 |
+
self.attn = MultiheadAttentionFlashV2(d_model, n_head, shift_group=shift_group, qkv_packed=qkv_packed, window_size=window_size)
|
| 211 |
+
else:
|
| 212 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
| 213 |
+
|
| 214 |
+
self.ln_1 = LayerNorm(d_model)
|
| 215 |
+
self.mlp = nn.Sequential(OrderedDict([
|
| 216 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
| 217 |
+
("gelu", QuickGELU()),
|
| 218 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
| 219 |
+
]))
|
| 220 |
+
self.ln_2 = LayerNorm(d_model)
|
| 221 |
+
self.attn_mask = attn_mask
|
| 222 |
+
self.window_size = window_size
|
| 223 |
+
|
| 224 |
+
if modulate_feature_size is not None:
|
| 225 |
+
act_dict = {'gelu': QuickGELU,
|
| 226 |
+
'silu': nn.SiLU}
|
| 227 |
+
self.modulation_fn = nn.Sequential(
|
| 228 |
+
LayerNorm(modulate_feature_size),
|
| 229 |
+
act_dict[modulate_act_type](),
|
| 230 |
+
nn.Linear(modulate_feature_size, 3 * d_model, bias=True)
|
| 231 |
+
)
|
| 232 |
+
self.mlp_modulation_fn = nn.Sequential(
|
| 233 |
+
LayerNorm(modulate_feature_size),
|
| 234 |
+
act_dict[modulate_act_type](),
|
| 235 |
+
nn.Linear(modulate_feature_size, 3 * d_model, bias=True)
|
| 236 |
+
)
|
| 237 |
+
else:
|
| 238 |
+
self.modulation_fn = None
|
| 239 |
+
self.mlp_modulation_fn = None
|
| 240 |
+
|
| 241 |
+
self.cross_att = cross_att
|
| 242 |
+
if self.cross_att:
|
| 243 |
+
self.cross_att = CrossAttention(query_dim=d_model, context_dim=d_model,
|
| 244 |
+
heads=n_head, dim_head=int(d_model//n_head), dropout=0)
|
| 245 |
+
self.ln_1_5 = LayerNorm(d_model)
|
| 246 |
+
|
| 247 |
+
def attention(self, x: torch.Tensor, index):
|
| 248 |
+
if self.attn_mask is not None:
|
| 249 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
|
| 250 |
+
length = x.shape[0]
|
| 251 |
+
attn_mask = self.attn_mask[:length, :length]
|
| 252 |
+
else:
|
| 253 |
+
attn_mask = None
|
| 254 |
+
if self.window_size is not None:
|
| 255 |
+
x = x.permute(1, 0, 2)
|
| 256 |
+
b, l, c = x.shape
|
| 257 |
+
# print(x.shape)
|
| 258 |
+
assert l % self.window_size == 0
|
| 259 |
+
if index % 2 == 0:
|
| 260 |
+
x = rearrange(x, 'b (p w) c -> (b p) w c', w=self.window_size)
|
| 261 |
+
x = x.permute(1, 0, 2) # w, bp, c
|
| 262 |
+
x = self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
|
| 263 |
+
x = x.permute(1, 0, 2) # bp, w, c
|
| 264 |
+
x = rearrange(x, '(b l) w c -> b (l w) c', l=l//self.window_size, w=self.window_size)
|
| 265 |
+
x = x.permute(1, 0, 2) # w, bp, c
|
| 266 |
+
else:
|
| 267 |
+
x = torch.roll(x, shifts=self.window_size//2, dims=1)
|
| 268 |
+
x = rearrange(x, 'b (p w) c -> (b p) w c', w=self.window_size)
|
| 269 |
+
x = x.permute(1, 0, 2) # w, bp, c
|
| 270 |
+
x = self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
|
| 271 |
+
x = x.permute(1, 0, 2) # w, bp, c
|
| 272 |
+
x = rearrange(x, '(b l) w c -> b (l w) c', l=l//self.window_size, w=self.window_size)
|
| 273 |
+
x = torch.roll(x, shifts=-self.window_size//2, dims=1)
|
| 274 |
+
x = x.permute(1, 0, 2)
|
| 275 |
+
else:
|
| 276 |
+
x = self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
|
| 277 |
+
|
| 278 |
+
return x
|
| 279 |
+
|
| 280 |
+
def forward(self, x: torch.Tensor, modulation: torch.Tensor = None, context: torch.Tensor = None, index=None):
|
| 281 |
+
# self attention block
|
| 282 |
+
y = self.ln_1(x)
|
| 283 |
+
if self.modulation_fn is not None:
|
| 284 |
+
shift, scale, gate = self.modulation_fn(modulation).chunk(3, dim=1)
|
| 285 |
+
y = modulate(y, shift, scale)
|
| 286 |
+
y = self.attention(y, index)
|
| 287 |
+
# If we have modulation func for mlp as well, we will just use the gate for the attention
|
| 288 |
+
if self.modulation_fn is not None and self.mlp_modulation_fn is not None:
|
| 289 |
+
y = y * gate.unsqueeze(0)
|
| 290 |
+
x = x + y
|
| 291 |
+
|
| 292 |
+
# cross attention block
|
| 293 |
+
if self.cross_att:
|
| 294 |
+
y = self.cross_att(self.ln_1_5(x), context=context)
|
| 295 |
+
# print(y.mean().item())
|
| 296 |
+
x = x + y
|
| 297 |
+
|
| 298 |
+
# mlp block
|
| 299 |
+
y = self.ln_2(x)
|
| 300 |
+
if self.mlp_modulation_fn is not None:
|
| 301 |
+
shift, scale, gate = self.mlp_modulation_fn(modulation).chunk(3, dim=1)
|
| 302 |
+
y = modulate(y, shift, scale)
|
| 303 |
+
y = self.mlp(y)
|
| 304 |
+
# For here we have two cases:
|
| 305 |
+
# 1. If we have a modulation function for the MLP, we use it to modulate the output of the MLP
|
| 306 |
+
# 2. If we don't have a modulation function for the MLP, we use the modulation function for the attention
|
| 307 |
+
if self.modulation_fn is not None:
|
| 308 |
+
y = y * gate.unsqueeze(0)
|
| 309 |
+
x = x + y
|
| 310 |
+
|
| 311 |
+
return x
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class Transformer(nn.Module):
|
| 315 |
+
def __init__(self,
|
| 316 |
+
width: int,
|
| 317 |
+
layers: int,
|
| 318 |
+
heads: int,
|
| 319 |
+
attn_mask: torch.Tensor = None,
|
| 320 |
+
modulate_feature_size: int = None,
|
| 321 |
+
modulate_act_type: str = 'gelu',
|
| 322 |
+
cross_att_layers: int = 0,
|
| 323 |
+
return_all_layers=False,
|
| 324 |
+
flash_v2=True,
|
| 325 |
+
qkv_packed=False,
|
| 326 |
+
shift_group=None,
|
| 327 |
+
window_size=None):
|
| 328 |
+
|
| 329 |
+
super().__init__()
|
| 330 |
+
self.width = width
|
| 331 |
+
self.layers = layers
|
| 332 |
+
|
| 333 |
+
blocks = []
|
| 334 |
+
for _ in range(layers):
|
| 335 |
+
layer = ResidualAttentionBlock(width,
|
| 336 |
+
heads,
|
| 337 |
+
attn_mask,
|
| 338 |
+
modulate_feature_size=modulate_feature_size,
|
| 339 |
+
modulate_act_type=modulate_act_type,
|
| 340 |
+
cross_att = (_ + cross_att_layers)>=layers,
|
| 341 |
+
flash_v2=flash_v2,
|
| 342 |
+
qkv_packed=qkv_packed,
|
| 343 |
+
shift_group=shift_group,
|
| 344 |
+
window_size=window_size)
|
| 345 |
+
blocks.append(layer)
|
| 346 |
+
|
| 347 |
+
self.resblocks = nn.Sequential(*blocks)
|
| 348 |
+
|
| 349 |
+
self.grad_checkpointing = False
|
| 350 |
+
self.return_all_layers = return_all_layers
|
| 351 |
+
self.flash_v2 = flash_v2
|
| 352 |
+
|
| 353 |
+
def set_grad_checkpointing(self, flag=True):
|
| 354 |
+
self.grad_checkpointing = flag
|
| 355 |
+
|
| 356 |
+
def forward(self,
|
| 357 |
+
x: torch.Tensor,
|
| 358 |
+
modulation: torch.Tensor = None,
|
| 359 |
+
context: torch.Tensor = None,
|
| 360 |
+
additional_residuals = None):
|
| 361 |
+
|
| 362 |
+
all_x = []
|
| 363 |
+
if additional_residuals is not None:
|
| 364 |
+
assert len(additional_residuals) == self.layers
|
| 365 |
+
for res_i, module in enumerate(self.resblocks):
|
| 366 |
+
if self.grad_checkpointing:
|
| 367 |
+
# print("Grad checkpointing")
|
| 368 |
+
x = checkpoint(module, x, modulation, context, res_i)
|
| 369 |
+
else:
|
| 370 |
+
x = module(x, modulation, context, res_i)
|
| 371 |
+
if additional_residuals is not None:
|
| 372 |
+
add_res = additional_residuals[res_i]
|
| 373 |
+
x[:, :add_res.shape[1]] = x[:, :add_res.shape[1]] + add_res
|
| 374 |
+
all_x.append(x)
|
| 375 |
+
if self.return_all_layers:
|
| 376 |
+
return all_x
|
| 377 |
+
else:
|
| 378 |
+
return x
|
| 379 |
+
|
| 380 |
+
class GaussianUpsampler(nn.Module):
|
| 381 |
+
def __init__(self, width,
|
| 382 |
+
up_ratio,
|
| 383 |
+
ch_decay=1,
|
| 384 |
+
low_channels=64,
|
| 385 |
+
window_size=False,
|
| 386 |
+
with_additional_inputs=False):
|
| 387 |
+
|
| 388 |
+
super().__init__()
|
| 389 |
+
self.up_ratio = up_ratio
|
| 390 |
+
self.low_channels = low_channels
|
| 391 |
+
self.window_size = window_size
|
| 392 |
+
self.base_width = width
|
| 393 |
+
self.with_additional_inputs = with_additional_inputs
|
| 394 |
+
for res_log2 in range(int(np.log2(up_ratio))):
|
| 395 |
+
_width = width
|
| 396 |
+
width = max(width // ch_decay, 64)
|
| 397 |
+
heads = int(width / 64)
|
| 398 |
+
width = heads * 64
|
| 399 |
+
if self.with_additional_inputs:
|
| 400 |
+
self.add_module(f'upsampler_{res_log2}', PSUpsamplerBlock(_width+self.base_width, width, 2))
|
| 401 |
+
else:
|
| 402 |
+
self.add_module(f'upsampler_{res_log2}', PSUpsamplerBlock(_width, width, 2))
|
| 403 |
+
encoder = Transformer(width, 2, heads,
|
| 404 |
+
modulate_feature_size=None,
|
| 405 |
+
modulate_act_type=None,
|
| 406 |
+
cross_att_layers=0,
|
| 407 |
+
return_all_layers=False,
|
| 408 |
+
flash_v2=False,
|
| 409 |
+
qkv_packed=False,
|
| 410 |
+
shift_group=False,
|
| 411 |
+
window_size=window_size)
|
| 412 |
+
self.add_module(f'attention_{res_log2}', encoder)
|
| 413 |
+
self.out_channels = width
|
| 414 |
+
self.ln_post = LayerNorm(width)
|
| 415 |
+
|
| 416 |
+
def forward(self, x, additional_inputs=None):
|
| 417 |
+
if self.with_additional_inputs:
|
| 418 |
+
assert len(additional_inputs) == int(np.log2(self.up_ratio))
|
| 419 |
+
for res_log2 in range(int(np.log2(self.up_ratio))):
|
| 420 |
+
if self.with_additional_inputs:
|
| 421 |
+
add_input = additional_inputs[res_log2]
|
| 422 |
+
scale = x.shape[1] // add_input.shape[1]
|
| 423 |
+
add_input = add_input.repeat_interleave(scale, 1)
|
| 424 |
+
x = torch.cat([x, add_input], dim=2)
|
| 425 |
+
x = getattr(self, f'upsampler_{res_log2}')(x)
|
| 426 |
+
x = x.permute(1, 0, 2)
|
| 427 |
+
x = getattr(self, f'attention_{res_log2}')(x)
|
| 428 |
+
x = x.permute(1, 0, 2)
|
| 429 |
+
x = self.ln_post(x)
|
| 430 |
+
return x
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class HyperGaussianUpsampler(nn.Module):
|
| 435 |
+
def __init__(self, width,
|
| 436 |
+
resolution,
|
| 437 |
+
up_ratio,
|
| 438 |
+
ch_decay=1,
|
| 439 |
+
window_size=False,
|
| 440 |
+
with_additional_inputs=False,
|
| 441 |
+
upsampler_kwargs={}):
|
| 442 |
+
|
| 443 |
+
super().__init__()
|
| 444 |
+
self.up_ratio = up_ratio
|
| 445 |
+
self.window_size = window_size
|
| 446 |
+
self.base_width = width
|
| 447 |
+
self.with_additional_inputs = with_additional_inputs
|
| 448 |
+
self.resolution = resolution
|
| 449 |
+
for res_log2 in range(int(np.log2(up_ratio))):
|
| 450 |
+
if res_log2 == 0:
|
| 451 |
+
_width = width
|
| 452 |
+
width = width
|
| 453 |
+
heads = int(width / 64)
|
| 454 |
+
width = heads * 64
|
| 455 |
+
if self.with_additional_inputs:
|
| 456 |
+
self.add_module(f'upsampler_{res_log2}', PSUpsamplerBlock(_width+self.base_width, width, 2))
|
| 457 |
+
else:
|
| 458 |
+
self.add_module(f'upsampler_{res_log2}', PSUpsamplerBlock(_width, width, 2))
|
| 459 |
+
encoder = Transformer(width, 2, heads,
|
| 460 |
+
modulate_feature_size=None,
|
| 461 |
+
modulate_act_type=None,
|
| 462 |
+
cross_att_layers=0,
|
| 463 |
+
return_all_layers=False,
|
| 464 |
+
flash_v2=False,
|
| 465 |
+
qkv_packed=False,
|
| 466 |
+
shift_group=False,
|
| 467 |
+
window_size=window_size)
|
| 468 |
+
self.add_module(f'attention_{res_log2}', encoder)
|
| 469 |
+
self.resolution = self.resolution * 2
|
| 470 |
+
else:
|
| 471 |
+
self.resolution = self.resolution * 2
|
| 472 |
+
self.add_module(f'upsample_{res_log2}',
|
| 473 |
+
UpsamplerLayers_conv(in_channels=width,
|
| 474 |
+
out_channels=width,
|
| 475 |
+
resolution=self.resolution,
|
| 476 |
+
conv_block_type = 'convnext',
|
| 477 |
+
**upsampler_kwargs))
|
| 478 |
+
self.out_channels = width
|
| 479 |
+
# self.ln_post = LayerNorm(width)
|
| 480 |
+
self.ln_post = LayerNorm([self.resolution, self.resolution, width])
|
| 481 |
+
|
| 482 |
+
def forward(self, x, additional_inputs=None):
|
| 483 |
+
if self.with_additional_inputs:
|
| 484 |
+
assert len(additional_inputs) == int(np.log2(self.up_ratio))
|
| 485 |
+
for res_log2 in range(int(np.log2(self.up_ratio))):
|
| 486 |
+
if res_log2 == 0:
|
| 487 |
+
if self.with_additional_inputs:
|
| 488 |
+
add_input = additional_inputs[res_log2]
|
| 489 |
+
scale = x.shape[1] // add_input.shape[1]
|
| 490 |
+
add_input = add_input.repeat_interleave(scale, 1)
|
| 491 |
+
x = torch.cat([x, add_input], dim=2)
|
| 492 |
+
x = getattr(self, f'upsampler_{res_log2}')(x)
|
| 493 |
+
x = x.permute(1, 0, 2)
|
| 494 |
+
x = getattr(self, f'attention_{res_log2}')(x)
|
| 495 |
+
x = x.permute(1, 0, 2)
|
| 496 |
+
x = x.reshape(x.shape[0], int(math.sqrt(x.shape[1])), int(math.sqrt(x.shape[1])), -1).permute(0, 3, 1, 2)
|
| 497 |
+
else:
|
| 498 |
+
x = getattr(self, f'upsample_{res_log2}')(x)
|
| 499 |
+
x = self.ln_post(x.permute(0, 2, 3, 1))
|
| 500 |
+
return x
|
| 501 |
+
|
| 502 |
+
class VisionTransformer(nn.Module):
|
| 503 |
+
def __init__(self,
|
| 504 |
+
# transformer params
|
| 505 |
+
in_channels: int,
|
| 506 |
+
patch_size: int,
|
| 507 |
+
width: int,
|
| 508 |
+
layers: int,
|
| 509 |
+
heads: int,
|
| 510 |
+
weight: str = None,
|
| 511 |
+
encode_layers: int = 0,
|
| 512 |
+
shift_group = False,
|
| 513 |
+
flash_v2 = False,
|
| 514 |
+
qkv_packed = False,
|
| 515 |
+
window_size = False,
|
| 516 |
+
use_pe = False,
|
| 517 |
+
# modualtion params
|
| 518 |
+
modulate_feature_size: int = None,
|
| 519 |
+
modulate_act_type: str = 'gelu',
|
| 520 |
+
# camera condition
|
| 521 |
+
camera_condition: str = 'plucker',
|
| 522 |
+
# init params
|
| 523 |
+
disable_dino=False,
|
| 524 |
+
error_weight_init_mode='mean',
|
| 525 |
+
# other params
|
| 526 |
+
add_zero_conv=False,
|
| 527 |
+
return_all_layers=False,
|
| 528 |
+
disable_post_ln=False,
|
| 529 |
+
rope=None):
|
| 530 |
+
super().__init__()
|
| 531 |
+
self.patch_size = patch_size
|
| 532 |
+
self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 533 |
+
self.use_pe = use_pe
|
| 534 |
+
self.rope = rope
|
| 535 |
+
self.disable_dino = disable_dino
|
| 536 |
+
# if not self.disable_dino:
|
| 537 |
+
# scale = width ** -0.5
|
| 538 |
+
# self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
| 539 |
+
# self.positional_embedding = nn.Parameter(scale * torch.randn((input_res// patch_size) ** 2 + 1, width))
|
| 540 |
+
# else:
|
| 541 |
+
# if self.use_pe:
|
| 542 |
+
# self.positional_embedding = nn.Parameter(torch.zeros(1, (input_res// patch_size) ** 2, width))
|
| 543 |
+
# nn.init.trunc_normal_(self.positional_embedding, std=0.02)
|
| 544 |
+
self.ln_pre = LayerNorm(width)
|
| 545 |
+
self.add_zero_conv = add_zero_conv
|
| 546 |
+
self.return_all_layers = return_all_layers
|
| 547 |
+
self.disable_post_ln = disable_post_ln
|
| 548 |
+
self.flash_v2 = flash_v2
|
| 549 |
+
self.qkv_packed = qkv_packed
|
| 550 |
+
|
| 551 |
+
self.camera_condition = camera_condition
|
| 552 |
+
if self.camera_condition == 'plucker': assert modulate_feature_size is None
|
| 553 |
+
|
| 554 |
+
if self.add_zero_conv:
|
| 555 |
+
assert self.return_all_layers
|
| 556 |
+
self.zero_convs = nn.ModuleList([zero_module(nn.Conv1d(in_channels=width, out_channels=width, kernel_size=1, stride=1, bias=True)) for _ in range(layers)])
|
| 557 |
+
|
| 558 |
+
self.encode_layers = encode_layers
|
| 559 |
+
if self.encode_layers > 0:
|
| 560 |
+
self.encoder = Transformer(width, encode_layers, heads,
|
| 561 |
+
modulate_feature_size=modulate_feature_size,
|
| 562 |
+
modulate_act_type=modulate_act_type,
|
| 563 |
+
cross_att_layers=0,
|
| 564 |
+
return_all_layers=return_all_layers,
|
| 565 |
+
flash_v2=flash_v2,
|
| 566 |
+
qkv_packed=qkv_packed,
|
| 567 |
+
shift_group=shift_group,
|
| 568 |
+
window_size=window_size)
|
| 569 |
+
self.transformer = Transformer(width, layers-encode_layers, heads,
|
| 570 |
+
modulate_feature_size=modulate_feature_size,
|
| 571 |
+
modulate_act_type=modulate_act_type,
|
| 572 |
+
cross_att_layers=0,
|
| 573 |
+
return_all_layers=return_all_layers,
|
| 574 |
+
flash_v2=flash_v2,
|
| 575 |
+
qkv_packed=qkv_packed,
|
| 576 |
+
shift_group=shift_group,
|
| 577 |
+
window_size=window_size)
|
| 578 |
+
|
| 579 |
+
if not self.disable_post_ln:
|
| 580 |
+
self.ln_post = LayerNorm(width)
|
| 581 |
+
|
| 582 |
+
if weight is not None:
|
| 583 |
+
if not self.disable_dino:
|
| 584 |
+
if "clip" in weight:
|
| 585 |
+
raise NotImplementedError()
|
| 586 |
+
elif weight.startswith("vit_b_16"):
|
| 587 |
+
load_timm_to_clip(self, config_name=weight, init_mode=error_weight_init_mode)
|
| 588 |
+
elif weight.startswith("vit_b_8"):
|
| 589 |
+
load_timm_to_clip(self, config_name=weight, init_mode=error_weight_init_mode)
|
| 590 |
+
else:
|
| 591 |
+
raise NotImplementedError()
|
| 592 |
+
else:
|
| 593 |
+
self.apply(_init_weights)
|
| 594 |
+
|
| 595 |
+
# Init the weight and bias of modulation_fn to zero
|
| 596 |
+
if modulate_feature_size != 0:
|
| 597 |
+
for block in self.transformer.resblocks:
|
| 598 |
+
if block.modulation_fn is not None:
|
| 599 |
+
block.modulation_fn[2].weight.data.zero_()
|
| 600 |
+
block.modulation_fn[2].bias.data.zero_()
|
| 601 |
+
if block.mlp_modulation_fn is not None:
|
| 602 |
+
block.mlp_modulation_fn[2].weight.data.zero_()
|
| 603 |
+
block.mlp_modulation_fn[2].bias.data.zero_()
|
| 604 |
+
for block in self.transformer.resblocks:
|
| 605 |
+
if block.cross_att:
|
| 606 |
+
zero_module(block.cross_att.to_out)
|
| 607 |
+
|
| 608 |
+
def set_grad_checkpointing(self, flag=True):
|
| 609 |
+
self.transformer.set_grad_checkpointing(flag)
|
| 610 |
+
|
| 611 |
+
def forward(self,
|
| 612 |
+
x: torch.Tensor,
|
| 613 |
+
modulation: torch.Tensor = None,
|
| 614 |
+
context: torch.Tensor = None,
|
| 615 |
+
additional_residuals=None,
|
| 616 |
+
abla_crossview=False,
|
| 617 |
+
pos=None):
|
| 618 |
+
|
| 619 |
+
# image tokenization
|
| 620 |
+
bs, vs = x.shape[:2]
|
| 621 |
+
x = rearrange(x, 'b v c h w -> (b v) c h w')
|
| 622 |
+
pos = rearrange(pos, 'b v c h -> (b v) c h')
|
| 623 |
+
if self.camera_condition == 'plucker' and modulation is not None:
|
| 624 |
+
modulation = rearrange(modulation, 'b v c h w -> (b v) c h w')
|
| 625 |
+
x = torch.cat([x, modulation], dim=1)
|
| 626 |
+
modulation = None
|
| 627 |
+
|
| 628 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
| 629 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
| 630 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
| 631 |
+
|
| 632 |
+
# pre-normalization
|
| 633 |
+
x = self.ln_pre(x)
|
| 634 |
+
B, N, C = x.shape
|
| 635 |
+
x = x.reshape(B, N, -1, 64)
|
| 636 |
+
x = x.permute(0, 2, 1, 3)
|
| 637 |
+
# print('pre x mean: ', x.mean().item())
|
| 638 |
+
# print('pre x var: ', x.var().item())
|
| 639 |
+
x = x + self.rope(torch.ones_like(x).to(x), pos)
|
| 640 |
+
# print('x mean: ', x.mean().item())
|
| 641 |
+
# print('x var: ', x.var().item())
|
| 642 |
+
|
| 643 |
+
x = x.permute(0, 2, 1, 3)
|
| 644 |
+
x = x.reshape(B, N, -1)
|
| 645 |
+
# use encode to extract features
|
| 646 |
+
if self.encode_layers > 0:
|
| 647 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 648 |
+
x = self.encoder(x, modulation, context, additional_residuals=additional_residuals)
|
| 649 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 650 |
+
if not self.disable_dino:
|
| 651 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 652 |
+
else:
|
| 653 |
+
if not abla_crossview:
|
| 654 |
+
# flatten x along the video dimension
|
| 655 |
+
x = rearrange(x, '(b v) n d -> b (v n) d', v=vs)
|
| 656 |
+
# print(x.shape)
|
| 657 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 658 |
+
else:
|
| 659 |
+
x = x.permute(1, 0, 2)
|
| 660 |
+
x = self.transformer(x, modulation, context, additional_residuals=additional_residuals)
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
if self.add_zero_conv:
|
| 664 |
+
assert isinstance(x, (list, tuple))
|
| 665 |
+
assert len(x) == len(self.zero_convs)
|
| 666 |
+
new_x = []
|
| 667 |
+
for sub_x, sub_zero_conv in zip(x, self.zero_convs):
|
| 668 |
+
sub_x_out = sub_zero_conv(sub_x.permute(1, 2, 0))
|
| 669 |
+
new_x.append(sub_x_out.permute(2, 0, 1))
|
| 670 |
+
x = new_x
|
| 671 |
+
|
| 672 |
+
if self.return_all_layers:
|
| 673 |
+
assert isinstance(x, (list, tuple))
|
| 674 |
+
if not self.disable_post_ln:
|
| 675 |
+
x_final = x[-1].permute(1, 0, 2) # LND -> NLD
|
| 676 |
+
x_final = self.ln_post(x_final)
|
| 677 |
+
x_final = rearrange(x_final, 'b (v n) d -> b v n d', v=vs)
|
| 678 |
+
x = [s.permute(1, 0, 2) for s in x]
|
| 679 |
+
x.append(x_final)
|
| 680 |
+
return x
|
| 681 |
+
|
| 682 |
+
if not self.disable_post_ln:
|
| 683 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 684 |
+
x = self.ln_post(x)
|
| 685 |
+
if not self.disable_dino:
|
| 686 |
+
x = rearrange(x, '(b v) n d -> b v n d', b=bs, v=vs)
|
| 687 |
+
else:
|
| 688 |
+
if not abla_crossview:
|
| 689 |
+
# reshape x back to video dimension
|
| 690 |
+
x = rearrange(x, 'b (v n) d -> b v n d', v=vs)
|
| 691 |
+
else:
|
| 692 |
+
x = rearrange(x, '(b v) n d -> b v n d', v=vs)
|
| 693 |
+
return x
|
| 694 |
+
|
| 695 |
+
def extra_repr(self) -> str:
|
| 696 |
+
pass
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
class VisionTransformer_fusion(nn.Module):
|
| 700 |
+
def __init__(self,
|
| 701 |
+
# transformer params
|
| 702 |
+
in_channels: int,
|
| 703 |
+
patch_size: int,
|
| 704 |
+
width: int,
|
| 705 |
+
layers: int,
|
| 706 |
+
heads: int,
|
| 707 |
+
weight: str = None,
|
| 708 |
+
encode_layers: int = 0,
|
| 709 |
+
shift_group = False,
|
| 710 |
+
flash_v2 = False,
|
| 711 |
+
qkv_packed = False,
|
| 712 |
+
window_size = False,
|
| 713 |
+
use_pe = False,
|
| 714 |
+
# modualtion params
|
| 715 |
+
modulate_feature_size: int = None,
|
| 716 |
+
modulate_act_type: str = 'gelu',
|
| 717 |
+
# camera condition
|
| 718 |
+
camera_condition: str = 'plucker',
|
| 719 |
+
# init params
|
| 720 |
+
disable_dino=False,
|
| 721 |
+
error_weight_init_mode='mean',
|
| 722 |
+
# other params
|
| 723 |
+
add_zero_conv=False,
|
| 724 |
+
return_all_layers=False,
|
| 725 |
+
disable_post_ln=False,
|
| 726 |
+
rope=None):
|
| 727 |
+
super().__init__()
|
| 728 |
+
self.patch_size = patch_size
|
| 729 |
+
self.use_pe = use_pe
|
| 730 |
+
self.rope = rope
|
| 731 |
+
self.disable_dino = disable_dino
|
| 732 |
+
# if not self.disable_dino:
|
| 733 |
+
# scale = width ** -0.5
|
| 734 |
+
# self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
| 735 |
+
# self.positional_embedding = nn.Parameter(scale * torch.randn((input_res// patch_size) ** 2 + 1, width))
|
| 736 |
+
# else:
|
| 737 |
+
# if self.use_pe:
|
| 738 |
+
# self.positional_embedding = nn.Parameter(torch.zeros(1, (input_res// patch_size) ** 2, width))
|
| 739 |
+
# nn.init.trunc_normal_(self.positional_embedding, std=0.02)
|
| 740 |
+
self.ln_pre = LayerNorm(width)
|
| 741 |
+
self.add_zero_conv = add_zero_conv
|
| 742 |
+
self.return_all_layers = return_all_layers
|
| 743 |
+
self.disable_post_ln = disable_post_ln
|
| 744 |
+
self.flash_v2 = flash_v2
|
| 745 |
+
self.qkv_packed = qkv_packed
|
| 746 |
+
|
| 747 |
+
self.camera_condition = camera_condition
|
| 748 |
+
if self.camera_condition == 'plucker': assert modulate_feature_size is None
|
| 749 |
+
|
| 750 |
+
if self.add_zero_conv:
|
| 751 |
+
assert self.return_all_layers
|
| 752 |
+
self.zero_convs = nn.ModuleList([zero_module(nn.Conv1d(in_channels=width, out_channels=width, kernel_size=1, stride=1, bias=True)) for _ in range(layers)])
|
| 753 |
+
|
| 754 |
+
self.encode_layers = encode_layers
|
| 755 |
+
if self.encode_layers > 0:
|
| 756 |
+
self.encoder = Transformer(width, encode_layers, heads,
|
| 757 |
+
modulate_feature_size=modulate_feature_size,
|
| 758 |
+
modulate_act_type=modulate_act_type,
|
| 759 |
+
cross_att_layers=0,
|
| 760 |
+
return_all_layers=return_all_layers,
|
| 761 |
+
flash_v2=flash_v2,
|
| 762 |
+
qkv_packed=qkv_packed,
|
| 763 |
+
shift_group=shift_group,
|
| 764 |
+
window_size=window_size)
|
| 765 |
+
self.transformer = Transformer(width, layers-encode_layers, heads,
|
| 766 |
+
modulate_feature_size=modulate_feature_size,
|
| 767 |
+
modulate_act_type=modulate_act_type,
|
| 768 |
+
cross_att_layers=0,
|
| 769 |
+
return_all_layers=return_all_layers,
|
| 770 |
+
flash_v2=flash_v2,
|
| 771 |
+
qkv_packed=qkv_packed,
|
| 772 |
+
shift_group=shift_group,
|
| 773 |
+
window_size=window_size)
|
| 774 |
+
|
| 775 |
+
if not self.disable_post_ln:
|
| 776 |
+
self.ln_post = LayerNorm(width)
|
| 777 |
+
|
| 778 |
+
if weight is not None:
|
| 779 |
+
if not self.disable_dino:
|
| 780 |
+
if "clip" in weight:
|
| 781 |
+
raise NotImplementedError()
|
| 782 |
+
elif weight.startswith("vit_b_16"):
|
| 783 |
+
load_timm_to_clip(self, config_name=weight, init_mode=error_weight_init_mode)
|
| 784 |
+
elif weight.startswith("vit_b_8"):
|
| 785 |
+
load_timm_to_clip(self, config_name=weight, init_mode=error_weight_init_mode)
|
| 786 |
+
else:
|
| 787 |
+
raise NotImplementedError()
|
| 788 |
+
else:
|
| 789 |
+
self.apply(_init_weights)
|
| 790 |
+
|
| 791 |
+
# Init the weight and bias of modulation_fn to zero
|
| 792 |
+
if modulate_feature_size != 0:
|
| 793 |
+
for block in self.transformer.resblocks:
|
| 794 |
+
if block.modulation_fn is not None:
|
| 795 |
+
block.modulation_fn[2].weight.data.zero_()
|
| 796 |
+
block.modulation_fn[2].bias.data.zero_()
|
| 797 |
+
if block.mlp_modulation_fn is not None:
|
| 798 |
+
block.mlp_modulation_fn[2].weight.data.zero_()
|
| 799 |
+
block.mlp_modulation_fn[2].bias.data.zero_()
|
| 800 |
+
for block in self.transformer.resblocks:
|
| 801 |
+
if block.cross_att:
|
| 802 |
+
zero_module(block.cross_att.to_out)
|
| 803 |
+
|
| 804 |
+
def set_grad_checkpointing(self, flag=True):
|
| 805 |
+
self.transformer.set_grad_checkpointing(flag)
|
| 806 |
+
|
| 807 |
+
def forward(self,
|
| 808 |
+
x: torch.Tensor,
|
| 809 |
+
modulation: torch.Tensor = None,
|
| 810 |
+
context: torch.Tensor = None,
|
| 811 |
+
additional_residuals=None,
|
| 812 |
+
abla_crossview=False,
|
| 813 |
+
pos=None):
|
| 814 |
+
|
| 815 |
+
# image tokenization
|
| 816 |
+
bs, vs = x.shape[:2]
|
| 817 |
+
x = rearrange(x, 'b v h g -> (b v) h g') # shape = [*, grid ** 2, width]
|
| 818 |
+
pos = rearrange(pos, 'b v c h -> (b v) c h')
|
| 819 |
+
|
| 820 |
+
# pre-normalization
|
| 821 |
+
B, N, C = x.shape
|
| 822 |
+
x = x.reshape(B, N, -1, 64)
|
| 823 |
+
x = x.permute(0, 2, 1, 3)
|
| 824 |
+
# print('pre x mean: ', x.mean().item())
|
| 825 |
+
# print('pre x var: ', x.var().item())
|
| 826 |
+
x = x + self.rope(torch.ones_like(x).to(x), pos)
|
| 827 |
+
# print('x mean: ', x.mean().item())
|
| 828 |
+
# print('x var: ', x.var().item())
|
| 829 |
+
x = x.permute(0, 2, 1, 3)
|
| 830 |
+
x = x.reshape(B, N, -1)
|
| 831 |
+
# use encode to extract features
|
| 832 |
+
if self.encode_layers > 0:
|
| 833 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 834 |
+
x = self.encoder(x, modulation, context, additional_residuals=additional_residuals)
|
| 835 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 836 |
+
if not self.disable_dino:
|
| 837 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 838 |
+
else:
|
| 839 |
+
if not abla_crossview:
|
| 840 |
+
# flatten x along the video dimension
|
| 841 |
+
x = rearrange(x, '(b v) n d -> b (v n) d', v=vs)
|
| 842 |
+
# print(x.shape)
|
| 843 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 844 |
+
else:
|
| 845 |
+
x = x.permute(1, 0, 2)
|
| 846 |
+
x = self.transformer(x, modulation, context, additional_residuals=additional_residuals)
|
| 847 |
+
|
| 848 |
+
if self.add_zero_conv:
|
| 849 |
+
assert isinstance(x, (list, tuple))
|
| 850 |
+
assert len(x) == len(self.zero_convs)
|
| 851 |
+
new_x = []
|
| 852 |
+
for sub_x, sub_zero_conv in zip(x, self.zero_convs):
|
| 853 |
+
sub_x_out = sub_zero_conv(sub_x.permute(1, 2, 0))
|
| 854 |
+
new_x.append(sub_x_out.permute(2, 0, 1))
|
| 855 |
+
x = new_x
|
| 856 |
+
|
| 857 |
+
if self.return_all_layers:
|
| 858 |
+
assert isinstance(x, (list, tuple))
|
| 859 |
+
if not self.disable_post_ln:
|
| 860 |
+
x_final = x[-1].permute(1, 0, 2) # LND -> NLD
|
| 861 |
+
x_final = self.ln_post(x_final)
|
| 862 |
+
x_final = rearrange(x_final, 'b (v n) d -> b v n d', v=vs)
|
| 863 |
+
x = [s.permute(1, 0, 2) for s in x]
|
| 864 |
+
x.append(x_final)
|
| 865 |
+
return x
|
| 866 |
+
|
| 867 |
+
if not self.disable_post_ln:
|
| 868 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 869 |
+
x = self.ln_post(x)
|
| 870 |
+
if not self.disable_dino:
|
| 871 |
+
x = rearrange(x, '(b v) n d -> b v n d', b=bs, v=vs)
|
| 872 |
+
else:
|
| 873 |
+
if not abla_crossview:
|
| 874 |
+
# reshape x back to video dimension
|
| 875 |
+
x = rearrange(x, 'b (v n) d -> b v n d', v=vs)
|
| 876 |
+
else:
|
| 877 |
+
x = rearrange(x, '(b v) n d -> b v n d', v=vs)
|
| 878 |
+
return x
|
| 879 |
+
|
| 880 |
+
def extra_repr(self) -> str:
|
| 881 |
+
pass
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic'):
|
| 885 |
+
"""
|
| 886 |
+
Resize positional embeddings, implementation from google/simclr and open_clip.
|
| 887 |
+
"""
|
| 888 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
| 889 |
+
old_pos_embed = state_dict.get('positional_embedding', None)
|
| 890 |
+
if old_pos_embed is None:
|
| 891 |
+
return
|
| 892 |
+
|
| 893 |
+
# Compute the grid size and extra tokens
|
| 894 |
+
old_pos_len = state_dict["positional_embedding"].shape[0]
|
| 895 |
+
old_grid_size = round((state_dict["positional_embedding"].shape[0]) ** 0.5)
|
| 896 |
+
grid_size = round((model.positional_embedding.shape[0]) ** 0.5)
|
| 897 |
+
if old_grid_size == grid_size:
|
| 898 |
+
return
|
| 899 |
+
extra_tokens = old_pos_len - (old_grid_size ** 2)
|
| 900 |
+
|
| 901 |
+
if extra_tokens:
|
| 902 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
| 903 |
+
else:
|
| 904 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
| 905 |
+
|
| 906 |
+
# Only interpolate the positional emb part, not the extra token part.
|
| 907 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
|
| 908 |
+
pos_emb_img = F.interpolate(
|
| 909 |
+
pos_emb_img,
|
| 910 |
+
size=grid_size,
|
| 911 |
+
mode=interpolation,
|
| 912 |
+
align_corners=True,
|
| 913 |
+
)
|
| 914 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size * grid_size, -1)[0]
|
| 915 |
+
|
| 916 |
+
# Concatenate back the
|
| 917 |
+
if pos_emb_tok is not None:
|
| 918 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
| 919 |
+
else:
|
| 920 |
+
new_pos_embed = pos_emb_img
|
| 921 |
+
state_dict['positional_embedding'] = new_pos_embed
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
myname2timmname = {
|
| 925 |
+
"vit_b_16_mae": None,
|
| 926 |
+
"vit_b_16_in": "vit_base_patch16_224",
|
| 927 |
+
"vit_b_16_in21k": 'vit_base_patch16_224_in21k',
|
| 928 |
+
"vit_b_16_sam": 'vit_base_patch16_224_sam',
|
| 929 |
+
"vit_b_16_dino": 'vit_base_patch16_224_dino',
|
| 930 |
+
"vit_b_16_mill_in21k": 'vit_base_patch16_224_miil_in21k',
|
| 931 |
+
"vit_b_16_mill": 'vit_base_patch16_224_miil',
|
| 932 |
+
"vit_b_8_dino": 'vit_base_patch16_224_dino',
|
| 933 |
+
}
|
| 934 |
+
|
| 935 |
+
def load_timm_to_clip(module, config_name="vit_b_16_mae", init_mode='zero'):
|
| 936 |
+
from torch import nn
|
| 937 |
+
from clip.model import LayerNorm as CLIPLayerNorm
|
| 938 |
+
from clip.model import QuickGELU
|
| 939 |
+
|
| 940 |
+
from torch.nn import GELU
|
| 941 |
+
from torch.nn import LayerNorm
|
| 942 |
+
|
| 943 |
+
import json
|
| 944 |
+
now_dir = os.path.abspath(os.path.dirname(__file__))
|
| 945 |
+
timm2clip = json.load(open(f"{now_dir}/timm2clip_vit_b_16.json"))
|
| 946 |
+
|
| 947 |
+
assert config_name in myname2timmname, f"The name {config_name} is not one of {list(myname2timmname.keys())}"
|
| 948 |
+
try:
|
| 949 |
+
timm_weight = torch.load(f"/sensei-fs/users/hatan/model/{config_name}.pth")["model"]
|
| 950 |
+
except Exception as e:
|
| 951 |
+
try:
|
| 952 |
+
print(f"/input/yhxu/models/dino_weights/{config_name}.pth")
|
| 953 |
+
timm_weight = torch.load(f"/input/yhxu/models/dino_weights/{config_name}.pth")["model"]
|
| 954 |
+
except Exception as e:
|
| 955 |
+
try:
|
| 956 |
+
print(f"/home/yhxu/models/dino_weights/{config_name}.pth")
|
| 957 |
+
timm_weight = torch.load(f"/home/yhxu/models/dino_weights/{config_name}.pth")["model"]
|
| 958 |
+
except:
|
| 959 |
+
try:
|
| 960 |
+
timm_weight = torch.load(f"/nas2/zifan/checkpoint/dino_weights/{config_name}.pth")["model"]
|
| 961 |
+
except Exception as e:
|
| 962 |
+
print("Please download weight with support/dump_timm_weights.py. \n"
|
| 963 |
+
"If using mae weight, please check https://github.com/facebookresearch/mae,"
|
| 964 |
+
"and download the weight as vit_b_16_mae.pth")
|
| 965 |
+
assert False
|
| 966 |
+
|
| 967 |
+
# Build model's state dict
|
| 968 |
+
clipname2timmweight = {}
|
| 969 |
+
for timm_key, clip_key in timm2clip.items():
|
| 970 |
+
timm_value = timm_weight[timm_key]
|
| 971 |
+
clipname2timmweight[clip_key[len("visual."):]] = timm_value.squeeze()
|
| 972 |
+
|
| 973 |
+
# Resize positional embedding
|
| 974 |
+
resize_pos_embed(clipname2timmweight, module)
|
| 975 |
+
|
| 976 |
+
# Load weight to model.
|
| 977 |
+
model_visual_keys = set(module.state_dict().keys())
|
| 978 |
+
load_keys = set(clipname2timmweight.keys())
|
| 979 |
+
# print(f"Load not in model: {load_keys - model_visual_keys}")
|
| 980 |
+
# print(f"Model not in load: {model_visual_keys - load_keys}")
|
| 981 |
+
# status = module.load_state_dict(clipname2timmweight, strict=False)
|
| 982 |
+
try:
|
| 983 |
+
status = module.load_state_dict(clipname2timmweight, strict=False)
|
| 984 |
+
except:
|
| 985 |
+
print('conv.weight has error!')
|
| 986 |
+
if init_mode == 'zero':
|
| 987 |
+
new_weight = torch.zeros_like(clipname2timmweight['conv1.weight'])
|
| 988 |
+
new_weight = new_weight.repeat(1, 2, 1, 1)
|
| 989 |
+
new_weight[:,:3] = clipname2timmweight['conv1.weight']
|
| 990 |
+
elif init_mode == 'mean':
|
| 991 |
+
new_weight = torch.zeros_like(clipname2timmweight['conv1.weight'])
|
| 992 |
+
new_weight = new_weight.repeat(1, 3, 1, 1)
|
| 993 |
+
new_weight = ((clipname2timmweight['conv1.weight']).repeat(1, 3, 1, 1))/3
|
| 994 |
+
|
| 995 |
+
clipname2timmweight['conv1.weight'] = new_weight
|
| 996 |
+
status = module.load_state_dict(clipname2timmweight, strict=False)
|
| 997 |
+
|
| 998 |
+
# Since timm model has bias, we add it back here.
|
| 999 |
+
module.conv1.bias = nn.Parameter(clipname2timmweight['conv1.bias'])
|
| 1000 |
+
|
| 1001 |
+
# Reinit the visual weights that not covered by timm
|
| 1002 |
+
module.ln_pre.reset_parameters()
|
| 1003 |
+
|
| 1004 |
+
def convert_clip_to_timm(module):
|
| 1005 |
+
"""Copy from detectron2, frozen BN"""
|
| 1006 |
+
res = module
|
| 1007 |
+
if isinstance(module, CLIPLayerNorm):
|
| 1008 |
+
# Timm uses eps=1e-6 while CLIP uses eps=1e-5
|
| 1009 |
+
res = LayerNorm(module.normalized_shape, eps=1e-6, elementwise_affine=module.elementwise_affine)
|
| 1010 |
+
if module.elementwise_affine:
|
| 1011 |
+
res.weight.data = module.weight.data.clone().detach()
|
| 1012 |
+
res.bias.data = module.bias.data.clone().detach()
|
| 1013 |
+
elif isinstance(module, QuickGELU):
|
| 1014 |
+
# Timm uses GELU while CLIP uses QuickGELU
|
| 1015 |
+
res = GELU()
|
| 1016 |
+
else:
|
| 1017 |
+
for name, child in module.named_children():
|
| 1018 |
+
new_child = convert_clip_to_timm(child)
|
| 1019 |
+
if new_child is not child:
|
| 1020 |
+
res.add_module(name, new_child)
|
| 1021 |
+
return res
|
dust3r/croco/models/x_transformer.py
ADDED
|
@@ -0,0 +1,558 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""from https://github.com/lucidrains/x-transformers"""
|
| 2 |
+
import math
|
| 3 |
+
from random import random
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn, einsum
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.utils.checkpoint import checkpoint
|
| 9 |
+
|
| 10 |
+
from functools import partial, wraps
|
| 11 |
+
from inspect import isfunction
|
| 12 |
+
|
| 13 |
+
from einops import rearrange, repeat, reduce
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# constants
|
| 17 |
+
|
| 18 |
+
DEFAULT_DIM_HEAD = 64
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# helpers
|
| 22 |
+
|
| 23 |
+
def exists(val):
|
| 24 |
+
return val is not None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def default(val, d):
|
| 28 |
+
if exists(val):
|
| 29 |
+
return val
|
| 30 |
+
return d() if isfunction(d) else d
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def cast_tuple(val, depth):
|
| 34 |
+
return val if isinstance(val, tuple) else (val,) * depth
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# init helpers
|
| 38 |
+
|
| 39 |
+
def init_zero_(layer):
|
| 40 |
+
nn.init.constant_(layer.weight, 0.)
|
| 41 |
+
if exists(layer.bias):
|
| 42 |
+
nn.init.constant_(layer.bias, 0.)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# keyword argument helpers
|
| 46 |
+
|
| 47 |
+
def pick_and_pop(keys, d):
|
| 48 |
+
values = list(map(lambda key: d.pop(key), keys))
|
| 49 |
+
return dict(zip(keys, values))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def group_dict_by_key(cond, d):
|
| 53 |
+
return_val = [dict(), dict()]
|
| 54 |
+
for key in d.keys():
|
| 55 |
+
match = bool(cond(key))
|
| 56 |
+
ind = int(not match)
|
| 57 |
+
return_val[ind][key] = d[key]
|
| 58 |
+
return (*return_val,)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def string_begins_with(prefix, str):
|
| 62 |
+
return str.startswith(prefix)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def group_by_key_prefix(prefix, d):
|
| 66 |
+
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def groupby_prefix_and_trim(prefix, d):
|
| 70 |
+
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
| 71 |
+
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
| 72 |
+
return kwargs_without_prefix, kwargs
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# initializations
|
| 76 |
+
|
| 77 |
+
def deepnorm_init(
|
| 78 |
+
transformer,
|
| 79 |
+
beta,
|
| 80 |
+
module_name_match_list=['.ff.', '.to_v', '.to_out']
|
| 81 |
+
):
|
| 82 |
+
for name, module in transformer.named_modules():
|
| 83 |
+
if type(module) != nn.Linear:
|
| 84 |
+
continue
|
| 85 |
+
|
| 86 |
+
needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list))
|
| 87 |
+
gain = beta if needs_beta_gain else 1
|
| 88 |
+
nn.init.xavier_normal_(module.weight.data, gain=gain)
|
| 89 |
+
|
| 90 |
+
if exists(module.bias):
|
| 91 |
+
nn.init.constant_(module.bias.data, 0)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# activations
|
| 95 |
+
|
| 96 |
+
class ReluSquared(nn.Module):
|
| 97 |
+
def forward(self, x):
|
| 98 |
+
return F.relu(x) ** 2
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# norms
|
| 102 |
+
|
| 103 |
+
class Scale(nn.Module):
|
| 104 |
+
def __init__(self, value, fn):
|
| 105 |
+
super().__init__()
|
| 106 |
+
self.value = value
|
| 107 |
+
self.fn = fn
|
| 108 |
+
|
| 109 |
+
def forward(self, x, **kwargs):
|
| 110 |
+
out = self.fn(x, **kwargs)
|
| 111 |
+
scale_fn = lambda t: t * self.value
|
| 112 |
+
|
| 113 |
+
if not isinstance(out, tuple):
|
| 114 |
+
return scale_fn(out)
|
| 115 |
+
|
| 116 |
+
return (scale_fn(out[0]), *out[1:])
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class ScaleNorm(nn.Module):
|
| 120 |
+
def __init__(self, dim, eps=1e-5):
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.eps = eps
|
| 123 |
+
self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
|
| 124 |
+
|
| 125 |
+
def forward(self, x):
|
| 126 |
+
norm = torch.norm(x, dim=-1, keepdim=True)
|
| 127 |
+
return x / norm.clamp(min=self.eps) * self.g
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class RMSNorm(nn.Module):
|
| 131 |
+
def __init__(self, dim, eps=1e-8):
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.scale = dim ** -0.5
|
| 134 |
+
self.eps = eps
|
| 135 |
+
self.g = nn.Parameter(torch.ones(dim))
|
| 136 |
+
|
| 137 |
+
def forward(self, x):
|
| 138 |
+
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
| 139 |
+
return x / norm.clamp(min=self.eps) * self.g
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# residual and residual gates
|
| 143 |
+
|
| 144 |
+
class Residual(nn.Module):
|
| 145 |
+
def __init__(self, dim, scale_residual=False, scale_residual_constant=1.):
|
| 146 |
+
super().__init__()
|
| 147 |
+
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
|
| 148 |
+
self.scale_residual_constant = scale_residual_constant
|
| 149 |
+
|
| 150 |
+
def forward(self, x, residual):
|
| 151 |
+
if exists(self.residual_scale):
|
| 152 |
+
residual = residual * self.residual_scale
|
| 153 |
+
|
| 154 |
+
if self.scale_residual_constant != 1:
|
| 155 |
+
residual = residual * self.scale_residual_constant
|
| 156 |
+
|
| 157 |
+
return x + residual
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class GRUGating(nn.Module):
|
| 161 |
+
def __init__(self, dim, scale_residual=False, **kwargs):
|
| 162 |
+
super().__init__()
|
| 163 |
+
self.gru = nn.GRUCell(dim, dim)
|
| 164 |
+
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
|
| 165 |
+
|
| 166 |
+
def forward(self, x, residual):
|
| 167 |
+
if exists(self.residual_scale):
|
| 168 |
+
residual = residual * self.residual_scale
|
| 169 |
+
|
| 170 |
+
gated_output = self.gru(
|
| 171 |
+
rearrange(x, 'b n d -> (b n) d'),
|
| 172 |
+
rearrange(residual, 'b n d -> (b n) d')
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
return gated_output.reshape_as(x)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# feedforward
|
| 179 |
+
class GLU(nn.Module):
|
| 180 |
+
def __init__(self, dim_in, dim_out, activation):
|
| 181 |
+
super().__init__()
|
| 182 |
+
self.act = activation
|
| 183 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
| 184 |
+
|
| 185 |
+
def forward(self, x):
|
| 186 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
| 187 |
+
return x * self.act(gate)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class FeedForward(nn.Module):
|
| 191 |
+
def __init__(
|
| 192 |
+
self,
|
| 193 |
+
dim,
|
| 194 |
+
dim_out=None,
|
| 195 |
+
mult=4,
|
| 196 |
+
glu=False,
|
| 197 |
+
swish=False,
|
| 198 |
+
relu_squared=False,
|
| 199 |
+
post_act_ln=False,
|
| 200 |
+
dropout=0.,
|
| 201 |
+
no_bias=False,
|
| 202 |
+
zero_init_output=False
|
| 203 |
+
):
|
| 204 |
+
super().__init__()
|
| 205 |
+
inner_dim = int(dim * mult)
|
| 206 |
+
dim_out = default(dim_out, dim)
|
| 207 |
+
|
| 208 |
+
if relu_squared:
|
| 209 |
+
activation = ReluSquared()
|
| 210 |
+
elif swish:
|
| 211 |
+
activation = nn.SiLU()
|
| 212 |
+
else:
|
| 213 |
+
activation = nn.GELU()
|
| 214 |
+
|
| 215 |
+
project_in = nn.Sequential(
|
| 216 |
+
nn.Linear(dim, inner_dim, bias=not no_bias),
|
| 217 |
+
activation
|
| 218 |
+
) if not glu else GLU(dim, inner_dim, activation)
|
| 219 |
+
|
| 220 |
+
self.ff = nn.Sequential(
|
| 221 |
+
project_in,
|
| 222 |
+
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
|
| 223 |
+
nn.Dropout(dropout),
|
| 224 |
+
nn.Linear(inner_dim, dim_out, bias=not no_bias)
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# init last linear layer to 0
|
| 228 |
+
if zero_init_output:
|
| 229 |
+
init_zero_(self.ff[-1])
|
| 230 |
+
|
| 231 |
+
def forward(self, x):
|
| 232 |
+
return self.ff(x)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# attention.
|
| 236 |
+
|
| 237 |
+
class Attention(nn.Module):
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
dim,
|
| 241 |
+
kv_dim=None,
|
| 242 |
+
dim_head=DEFAULT_DIM_HEAD,
|
| 243 |
+
heads=8,
|
| 244 |
+
causal=False,
|
| 245 |
+
dropout=0.,
|
| 246 |
+
zero_init_output=False,
|
| 247 |
+
shared_kv=False,
|
| 248 |
+
value_dim_head=None,
|
| 249 |
+
flash_attention=True,
|
| 250 |
+
):
|
| 251 |
+
super().__init__()
|
| 252 |
+
self.scale = dim_head ** -0.5
|
| 253 |
+
if kv_dim is None:
|
| 254 |
+
kv_dim = dim
|
| 255 |
+
|
| 256 |
+
self.heads = heads
|
| 257 |
+
self.causal = causal
|
| 258 |
+
|
| 259 |
+
value_dim_head = default(value_dim_head, dim_head)
|
| 260 |
+
q_dim = k_dim = dim_head * heads
|
| 261 |
+
v_dim = out_dim = value_dim_head * heads
|
| 262 |
+
|
| 263 |
+
self.to_q = nn.Linear(dim, q_dim, bias=False)
|
| 264 |
+
self.to_k = nn.Linear(kv_dim, k_dim, bias=False)
|
| 265 |
+
|
| 266 |
+
# shared key / values, for further memory savings during inference
|
| 267 |
+
assert not (
|
| 268 |
+
shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
|
| 269 |
+
self.to_v = nn.Linear(kv_dim, v_dim, bias=False) if not shared_kv else None
|
| 270 |
+
|
| 271 |
+
# Convert to output
|
| 272 |
+
self.to_out = nn.Linear(out_dim, dim, bias=False)
|
| 273 |
+
|
| 274 |
+
# dropout
|
| 275 |
+
self.dropout_p = dropout
|
| 276 |
+
self.dropout = nn.Dropout(dropout)
|
| 277 |
+
|
| 278 |
+
# Flash Attention, needs PyTorch >= 1.13
|
| 279 |
+
self.flash = flash_attention
|
| 280 |
+
assert self.flash
|
| 281 |
+
|
| 282 |
+
# Use torch.nn.functional.scaled_dot_product_attention if available
|
| 283 |
+
# otherwise, we use the xformer library.
|
| 284 |
+
# self.use_xformer = True
|
| 285 |
+
self.use_xformer = not hasattr(torch.nn.functional, 'scaled_dot_product_attention')
|
| 286 |
+
|
| 287 |
+
# init output projection 0
|
| 288 |
+
if zero_init_output:
|
| 289 |
+
init_zero_(self.to_out)
|
| 290 |
+
|
| 291 |
+
def forward(
|
| 292 |
+
self,
|
| 293 |
+
x,
|
| 294 |
+
context=None,
|
| 295 |
+
mask=None,
|
| 296 |
+
context_mask=None,
|
| 297 |
+
):
|
| 298 |
+
# print("x", x.dtype)
|
| 299 |
+
h = self.heads
|
| 300 |
+
kv_input = default(context, x)
|
| 301 |
+
|
| 302 |
+
q_input = x
|
| 303 |
+
k_input = kv_input
|
| 304 |
+
v_input = kv_input
|
| 305 |
+
|
| 306 |
+
q = self.to_q(q_input)
|
| 307 |
+
k = self.to_k(k_input)
|
| 308 |
+
v = self.to_v(v_input) if exists(self.to_v) else k
|
| 309 |
+
|
| 310 |
+
# print("q", q.dtype)
|
| 311 |
+
# print("k", k.dtype)
|
| 312 |
+
# print("v", v.dtype)
|
| 313 |
+
|
| 314 |
+
if self.use_xformer:
|
| 315 |
+
# Since xformers only accepts bf16/fp16, we need to convert qkv to bf16/fp16
|
| 316 |
+
dtype = q.dtype
|
| 317 |
+
q, k, v = map(lambda t: t.bfloat16() if t.dtype == torch.float32 else t, (q, k, v))
|
| 318 |
+
|
| 319 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v))
|
| 320 |
+
try:
|
| 321 |
+
import xformers.ops as xops
|
| 322 |
+
except ImportError as e:
|
| 323 |
+
print("Please install xformers to use flash attention for PyTorch < 2.0.0.")
|
| 324 |
+
raise e
|
| 325 |
+
|
| 326 |
+
# Use the flash attention support from the xformers library
|
| 327 |
+
if self.causal:
|
| 328 |
+
attention_bias = xops.LowerTriangularMask()
|
| 329 |
+
else:
|
| 330 |
+
attention_bias = None
|
| 331 |
+
|
| 332 |
+
# The memory_efficient_attention takes the input as (batch, seq_len, heads, dim)
|
| 333 |
+
out = xops.memory_efficient_attention(
|
| 334 |
+
q, k, v, attn_bias=attention_bias,
|
| 335 |
+
# op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp),
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
out = out.to(dtype)
|
| 339 |
+
|
| 340 |
+
out = rearrange(out, 'b n h d -> b n (h d)')
|
| 341 |
+
else:
|
| 342 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
| 343 |
+
# efficient attention using Flash Attention CUDA kernels
|
| 344 |
+
out = torch.nn.functional.scaled_dot_product_attention(
|
| 345 |
+
q, k, v, attn_mask=None, dropout_p=self.dropout_p, is_causal=self.causal,
|
| 346 |
+
)
|
| 347 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
| 348 |
+
|
| 349 |
+
out = self.to_out(out)
|
| 350 |
+
|
| 351 |
+
if exists(mask):
|
| 352 |
+
mask = rearrange(mask, 'b n -> b n 1')
|
| 353 |
+
out = out.masked_fill(~mask, 0.)
|
| 354 |
+
|
| 355 |
+
return out
|
| 356 |
+
|
| 357 |
+
def extra_repr(self) -> str:
|
| 358 |
+
return f"causal: {self.causal}, flash attention: {self.flash}, " \
|
| 359 |
+
f"use_xformers (if False, use torch.nn.functional.scaled_dot_product_attention): {self.use_xformer}"
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def modulate(x, shift, scale):
|
| 363 |
+
# from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
| 364 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class AttentionLayers(nn.Module):
|
| 368 |
+
def __init__(
|
| 369 |
+
self,
|
| 370 |
+
dim,
|
| 371 |
+
depth,
|
| 372 |
+
heads=8,
|
| 373 |
+
ctx_dim=None,
|
| 374 |
+
causal=False,
|
| 375 |
+
cross_attend=False,
|
| 376 |
+
only_cross=False,
|
| 377 |
+
use_scalenorm=False,
|
| 378 |
+
use_rmsnorm=False,
|
| 379 |
+
residual_attn=False,
|
| 380 |
+
cross_residual_attn=False,
|
| 381 |
+
macaron=False,
|
| 382 |
+
pre_norm=True,
|
| 383 |
+
gate_residual=False,
|
| 384 |
+
scale_residual=False,
|
| 385 |
+
scale_residual_constant=1.,
|
| 386 |
+
deepnorm=False,
|
| 387 |
+
sandwich_norm=False,
|
| 388 |
+
zero_init_branch_output=False,
|
| 389 |
+
layer_dropout=0.,
|
| 390 |
+
# Below are the arguments used for this img2nerf projects
|
| 391 |
+
modulate_feature_size=-1,
|
| 392 |
+
checkpointing=False,
|
| 393 |
+
checkpoint_every=1,
|
| 394 |
+
**kwargs
|
| 395 |
+
):
|
| 396 |
+
super().__init__()
|
| 397 |
+
|
| 398 |
+
# Add checkpointing
|
| 399 |
+
self.checkpointing = checkpointing
|
| 400 |
+
self.checkpoint_every = checkpoint_every
|
| 401 |
+
|
| 402 |
+
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
|
| 403 |
+
attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
|
| 404 |
+
|
| 405 |
+
self.dim = dim
|
| 406 |
+
self.depth = depth
|
| 407 |
+
self.layers = nn.ModuleList([])
|
| 408 |
+
|
| 409 |
+
# determine deepnorm and residual scale
|
| 410 |
+
if deepnorm:
|
| 411 |
+
assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings'
|
| 412 |
+
pre_norm = sandwich_norm = False
|
| 413 |
+
scale_residual = True
|
| 414 |
+
scale_residual_constant = (2 * depth) ** 0.25
|
| 415 |
+
|
| 416 |
+
assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
|
| 417 |
+
self.pre_norm = pre_norm
|
| 418 |
+
self.sandwich_norm = sandwich_norm
|
| 419 |
+
|
| 420 |
+
self.residual_attn = residual_attn
|
| 421 |
+
self.cross_residual_attn = cross_residual_attn
|
| 422 |
+
self.cross_attend = cross_attend
|
| 423 |
+
|
| 424 |
+
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
|
| 425 |
+
norm_class = RMSNorm if use_rmsnorm else norm_class
|
| 426 |
+
norm_fn = partial(norm_class, dim)
|
| 427 |
+
|
| 428 |
+
if cross_attend and not only_cross:
|
| 429 |
+
default_block = ('a', 'c', 'f')
|
| 430 |
+
elif cross_attend and only_cross:
|
| 431 |
+
default_block = ('c', 'f')
|
| 432 |
+
else:
|
| 433 |
+
default_block = ('a', 'f')
|
| 434 |
+
|
| 435 |
+
if macaron:
|
| 436 |
+
default_block = ('f',) + default_block
|
| 437 |
+
|
| 438 |
+
# zero init
|
| 439 |
+
|
| 440 |
+
if zero_init_branch_output:
|
| 441 |
+
attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
|
| 442 |
+
ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
|
| 443 |
+
|
| 444 |
+
# calculate layer block order
|
| 445 |
+
layer_types = default_block * depth
|
| 446 |
+
|
| 447 |
+
self.layer_types = layer_types
|
| 448 |
+
|
| 449 |
+
# stochastic depth
|
| 450 |
+
self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
|
| 451 |
+
|
| 452 |
+
# iterate and construct layers
|
| 453 |
+
for ind, layer_type in enumerate(self.layer_types):
|
| 454 |
+
is_last_layer = ind == (len(self.layer_types) - 1)
|
| 455 |
+
|
| 456 |
+
if layer_type == 'a':
|
| 457 |
+
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
|
| 458 |
+
elif layer_type == 'c':
|
| 459 |
+
layer = Attention(dim, kv_dim=ctx_dim, heads=heads, **attn_kwargs)
|
| 460 |
+
elif layer_type == 'f':
|
| 461 |
+
layer = FeedForward(dim, **ff_kwargs)
|
| 462 |
+
layer = layer if not macaron else Scale(0.5, layer)
|
| 463 |
+
else:
|
| 464 |
+
raise Exception(f'invalid layer type {layer_type}')
|
| 465 |
+
|
| 466 |
+
residual_fn = GRUGating if gate_residual else Residual
|
| 467 |
+
residual = residual_fn(dim, scale_residual=scale_residual, scale_residual_constant=scale_residual_constant)
|
| 468 |
+
|
| 469 |
+
pre_branch_norm = norm_fn() if pre_norm else None
|
| 470 |
+
post_branch_norm = norm_fn() if sandwich_norm else None
|
| 471 |
+
post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
|
| 472 |
+
|
| 473 |
+
# The whole modulation part is copied from DiT
|
| 474 |
+
# https://github.com/facebookresearch/DiT
|
| 475 |
+
modulation = None
|
| 476 |
+
if modulate_feature_size is not None:
|
| 477 |
+
modulation = nn.Sequential(
|
| 478 |
+
nn.LayerNorm(modulate_feature_size),
|
| 479 |
+
nn.GELU(),
|
| 480 |
+
nn.Linear(modulate_feature_size, 3 * dim, bias=True)
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
norms = nn.ModuleList([
|
| 484 |
+
pre_branch_norm,
|
| 485 |
+
post_branch_norm,
|
| 486 |
+
post_main_norm,
|
| 487 |
+
])
|
| 488 |
+
|
| 489 |
+
self.layers.append(nn.ModuleList([
|
| 490 |
+
norms,
|
| 491 |
+
layer,
|
| 492 |
+
residual,
|
| 493 |
+
modulation,
|
| 494 |
+
]))
|
| 495 |
+
|
| 496 |
+
if deepnorm:
|
| 497 |
+
init_gain = (8 * depth) ** -0.25
|
| 498 |
+
deepnorm_init(self, init_gain)
|
| 499 |
+
|
| 500 |
+
def forward(
|
| 501 |
+
self,
|
| 502 |
+
x,
|
| 503 |
+
context=None,
|
| 504 |
+
modulation=None,
|
| 505 |
+
mask=None,
|
| 506 |
+
context_mask=None,
|
| 507 |
+
):
|
| 508 |
+
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
|
| 509 |
+
|
| 510 |
+
num_layers = len(self.layer_types)
|
| 511 |
+
assert num_layers % self.checkpoint_every == 0
|
| 512 |
+
|
| 513 |
+
for start_layer_idx in range(0, num_layers, self.checkpoint_every):
|
| 514 |
+
end_layer_idx = min(start_layer_idx + self.checkpoint_every, num_layers)
|
| 515 |
+
|
| 516 |
+
def run_layers(x, context, modulation, start, end):
|
| 517 |
+
for ind, (layer_type, (norm, block, residual_fn, modulation_fn), layer_dropout) in enumerate(
|
| 518 |
+
zip(self.layer_types[start: end], self.layers[start: end], self.layer_dropouts[start: end])):
|
| 519 |
+
residual = x
|
| 520 |
+
|
| 521 |
+
pre_branch_norm, post_branch_norm, post_main_norm = norm
|
| 522 |
+
|
| 523 |
+
if exists(pre_branch_norm):
|
| 524 |
+
x = pre_branch_norm(x)
|
| 525 |
+
|
| 526 |
+
if modulation_fn is not None:
|
| 527 |
+
shift, scale, gate = modulation_fn(modulation).chunk(3, dim=1)
|
| 528 |
+
x = modulate(x, shift, scale)
|
| 529 |
+
|
| 530 |
+
if layer_type == 'a':
|
| 531 |
+
out = block(x, mask=mask)
|
| 532 |
+
elif layer_type == 'c':
|
| 533 |
+
out = block(x, context=context, mask=mask, context_mask=context_mask)
|
| 534 |
+
elif layer_type == 'f':
|
| 535 |
+
out = block(x)
|
| 536 |
+
|
| 537 |
+
if exists(post_branch_norm):
|
| 538 |
+
out = post_branch_norm(out)
|
| 539 |
+
|
| 540 |
+
if modulation_fn is not None:
|
| 541 |
+
# TODO: add a option to use gate or not.
|
| 542 |
+
out = out * gate.unsqueeze(1)
|
| 543 |
+
|
| 544 |
+
x = residual_fn(out, residual)
|
| 545 |
+
|
| 546 |
+
if exists(post_main_norm):
|
| 547 |
+
x = post_main_norm(x)
|
| 548 |
+
|
| 549 |
+
return x
|
| 550 |
+
|
| 551 |
+
if self.checkpointing:
|
| 552 |
+
# print("X checkpointing")
|
| 553 |
+
x = checkpoint(run_layers, x, context, modulation, start_layer_idx, end_layer_idx)
|
| 554 |
+
else:
|
| 555 |
+
x = run_layers(x, context, modulation, start_layer_idx, end_layer_idx)
|
| 556 |
+
|
| 557 |
+
return x
|
| 558 |
+
|
dust3r/croco/utils/misc.py
ADDED
|
@@ -0,0 +1,583 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# utilitary functions for CroCo
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# MAE: https://github.com/facebookresearch/mae
|
| 9 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 10 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 11 |
+
# --------------------------------------------------------
|
| 12 |
+
|
| 13 |
+
import builtins
|
| 14 |
+
import datetime
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
import math
|
| 18 |
+
import json
|
| 19 |
+
from collections import defaultdict, deque
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
import numpy as np
|
| 22 |
+
from datetime import timedelta
|
| 23 |
+
import torch
|
| 24 |
+
import torch.distributed as dist
|
| 25 |
+
from torch import inf
|
| 26 |
+
import functools
|
| 27 |
+
from typing import cast, Dict, Iterable, List, Optional, Tuple, Union
|
| 28 |
+
from typing_extensions import deprecated
|
| 29 |
+
|
| 30 |
+
class SmoothedValue(object):
|
| 31 |
+
"""Track a series of values and provide access to smoothed values over a
|
| 32 |
+
window or the global series average.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, window_size=20, fmt=None):
|
| 36 |
+
if fmt is None:
|
| 37 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
| 38 |
+
self.deque = deque(maxlen=window_size)
|
| 39 |
+
self.total = 0.0
|
| 40 |
+
self.count = 0
|
| 41 |
+
self.fmt = fmt
|
| 42 |
+
|
| 43 |
+
def update(self, value, n=1):
|
| 44 |
+
self.deque.append(value)
|
| 45 |
+
self.count += n
|
| 46 |
+
self.total += value * n
|
| 47 |
+
|
| 48 |
+
def synchronize_between_processes(self):
|
| 49 |
+
"""
|
| 50 |
+
Warning: does not synchronize the deque!
|
| 51 |
+
"""
|
| 52 |
+
if not is_dist_avail_and_initialized():
|
| 53 |
+
return
|
| 54 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
| 55 |
+
dist.barrier()
|
| 56 |
+
dist.all_reduce(t)
|
| 57 |
+
t = t.tolist()
|
| 58 |
+
self.count = int(t[0])
|
| 59 |
+
self.total = t[1]
|
| 60 |
+
|
| 61 |
+
@property
|
| 62 |
+
def median(self):
|
| 63 |
+
d = torch.tensor(list(self.deque))
|
| 64 |
+
return d.median().item()
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def avg(self):
|
| 68 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
| 69 |
+
return d.mean().item()
|
| 70 |
+
|
| 71 |
+
@property
|
| 72 |
+
def global_avg(self):
|
| 73 |
+
return self.total / self.count
|
| 74 |
+
|
| 75 |
+
@property
|
| 76 |
+
def max(self):
|
| 77 |
+
return max(self.deque)
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def value(self):
|
| 81 |
+
return self.deque[-1]
|
| 82 |
+
|
| 83 |
+
def __str__(self):
|
| 84 |
+
return self.fmt.format(
|
| 85 |
+
median=self.median,
|
| 86 |
+
avg=self.avg,
|
| 87 |
+
global_avg=self.global_avg,
|
| 88 |
+
max=self.max,
|
| 89 |
+
value=self.value)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class MetricLogger(object):
|
| 93 |
+
def __init__(self, delimiter="\t"):
|
| 94 |
+
self.meters = defaultdict(SmoothedValue)
|
| 95 |
+
self.delimiter = delimiter
|
| 96 |
+
|
| 97 |
+
def update(self, **kwargs):
|
| 98 |
+
for k, v in kwargs.items():
|
| 99 |
+
if v is None:
|
| 100 |
+
continue
|
| 101 |
+
if isinstance(v, torch.Tensor):
|
| 102 |
+
v = v.item()
|
| 103 |
+
assert isinstance(v, (float, int))
|
| 104 |
+
self.meters[k].update(v)
|
| 105 |
+
|
| 106 |
+
def __getattr__(self, attr):
|
| 107 |
+
if attr in self.meters:
|
| 108 |
+
return self.meters[attr]
|
| 109 |
+
if attr in self.__dict__:
|
| 110 |
+
return self.__dict__[attr]
|
| 111 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
| 112 |
+
type(self).__name__, attr))
|
| 113 |
+
|
| 114 |
+
def __str__(self):
|
| 115 |
+
loss_str = []
|
| 116 |
+
for name, meter in self.meters.items():
|
| 117 |
+
loss_str.append(
|
| 118 |
+
"{}: {}".format(name, str(meter))
|
| 119 |
+
)
|
| 120 |
+
return self.delimiter.join(loss_str)
|
| 121 |
+
|
| 122 |
+
def synchronize_between_processes(self):
|
| 123 |
+
for meter in self.meters.values():
|
| 124 |
+
meter.synchronize_between_processes()
|
| 125 |
+
|
| 126 |
+
def add_meter(self, name, meter):
|
| 127 |
+
self.meters[name] = meter
|
| 128 |
+
|
| 129 |
+
def log_every(self, iterable, print_freq, header=None, max_iter=None):
|
| 130 |
+
i = 0
|
| 131 |
+
if not header:
|
| 132 |
+
header = ''
|
| 133 |
+
start_time = time.time()
|
| 134 |
+
end = time.time()
|
| 135 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
| 136 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
| 137 |
+
len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable)
|
| 138 |
+
space_fmt = ':' + str(len(str(len_iterable))) + 'd'
|
| 139 |
+
log_msg = [
|
| 140 |
+
header,
|
| 141 |
+
'[{0' + space_fmt + '}/{1}]',
|
| 142 |
+
'eta: {eta}',
|
| 143 |
+
'{meters}',
|
| 144 |
+
'time: {time}',
|
| 145 |
+
'data: {data}'
|
| 146 |
+
]
|
| 147 |
+
if torch.cuda.is_available():
|
| 148 |
+
log_msg.append('max mem: {memory:.0f}')
|
| 149 |
+
log_msg = self.delimiter.join(log_msg)
|
| 150 |
+
MB = 1024.0 * 1024.0
|
| 151 |
+
for it,obj in enumerate(iterable):
|
| 152 |
+
data_time.update(time.time() - end)
|
| 153 |
+
yield obj
|
| 154 |
+
iter_time.update(time.time() - end)
|
| 155 |
+
if i % print_freq == 0 or i == len_iterable - 1:
|
| 156 |
+
eta_seconds = iter_time.global_avg * (len_iterable - i)
|
| 157 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
| 158 |
+
if torch.cuda.is_available():
|
| 159 |
+
print(log_msg.format(
|
| 160 |
+
i, len_iterable, eta=eta_string,
|
| 161 |
+
meters=str(self),
|
| 162 |
+
time=str(iter_time), data=str(data_time),
|
| 163 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
| 164 |
+
else:
|
| 165 |
+
print(log_msg.format(
|
| 166 |
+
i, len_iterable, eta=eta_string,
|
| 167 |
+
meters=str(self),
|
| 168 |
+
time=str(iter_time), data=str(data_time)))
|
| 169 |
+
i += 1
|
| 170 |
+
end = time.time()
|
| 171 |
+
if max_iter and it >= max_iter:
|
| 172 |
+
break
|
| 173 |
+
total_time = time.time() - start_time
|
| 174 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 175 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
| 176 |
+
header, total_time_str, total_time / len_iterable))
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def setup_for_distributed(is_master):
|
| 180 |
+
"""
|
| 181 |
+
This function disables printing when not in master process
|
| 182 |
+
"""
|
| 183 |
+
builtin_print = builtins.print
|
| 184 |
+
|
| 185 |
+
def print(*args, **kwargs):
|
| 186 |
+
force = kwargs.pop('force', False)
|
| 187 |
+
force = force #or (get_world_size() > 8)
|
| 188 |
+
if is_master or force:
|
| 189 |
+
now = datetime.datetime.now().time()
|
| 190 |
+
builtin_print('[{}] '.format(now), end='') # print with time stamp
|
| 191 |
+
builtin_print(*args, **kwargs)
|
| 192 |
+
|
| 193 |
+
builtins.print = print
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def is_dist_avail_and_initialized():
|
| 197 |
+
if not dist.is_available():
|
| 198 |
+
return False
|
| 199 |
+
if not dist.is_initialized():
|
| 200 |
+
return False
|
| 201 |
+
return True
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def get_world_size():
|
| 205 |
+
if not is_dist_avail_and_initialized():
|
| 206 |
+
return 1
|
| 207 |
+
return dist.get_world_size()
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def get_rank():
|
| 211 |
+
if not is_dist_avail_and_initialized():
|
| 212 |
+
return 0
|
| 213 |
+
return dist.get_rank()
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def is_main_process():
|
| 217 |
+
return get_rank() == 0
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def save_on_master(*args, **kwargs):
|
| 221 |
+
if is_main_process():
|
| 222 |
+
torch.save(*args, **kwargs)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def init_distributed_mode(args):
|
| 226 |
+
nodist = args.nodist if hasattr(args,'nodist') else False
|
| 227 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ and not nodist:
|
| 228 |
+
args.rank = int(os.environ["RANK"])
|
| 229 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
| 230 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
| 231 |
+
else:
|
| 232 |
+
print('Not using distributed mode')
|
| 233 |
+
setup_for_distributed(is_master=True) # hack
|
| 234 |
+
args.distributed = False
|
| 235 |
+
return
|
| 236 |
+
|
| 237 |
+
# args.distributed = True
|
| 238 |
+
|
| 239 |
+
# torch.cuda.set_device(args.gpu)
|
| 240 |
+
# args.dist_backend = 'nccl'
|
| 241 |
+
# print('| distributed init (rank {}): {}, gpu {}'.format(
|
| 242 |
+
# args.rank, args.dist_url, args.gpu), flush=True)
|
| 243 |
+
# # os.environ['TORCH_NCCL_BLOCKING_WAIT'] = '0' # not to enforce timeout
|
| 244 |
+
# torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
| 245 |
+
# timeout=timedelta(seconds=72000000),
|
| 246 |
+
# world_size=args.world_size, rank=args.rank)
|
| 247 |
+
# # print('| distributed is master {} {} {}'.format(is_main_process(), is_dist_avail_and_initialized(), dist.get_rank()), flush=True)
|
| 248 |
+
# torch.distributed.barrier()
|
| 249 |
+
# setup_for_distributed(args.gpu == 0)
|
| 250 |
+
args.distributed = True
|
| 251 |
+
torch.cuda.set_device(args.gpu)
|
| 252 |
+
args.dist_backend = 'nccl'
|
| 253 |
+
print('| distributed init (rank {}): {}, gpu {}'.format(
|
| 254 |
+
args.rank, args.dist_url, args.gpu), flush=True)
|
| 255 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=timedelta(seconds=72000000),
|
| 256 |
+
world_size=args.world_size, rank=args.rank)
|
| 257 |
+
torch.distributed.barrier()
|
| 258 |
+
setup_for_distributed(args.gpu == 0)
|
| 259 |
+
|
| 260 |
+
def _no_grad(func):
|
| 261 |
+
"""
|
| 262 |
+
This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
|
| 263 |
+
clip_grad_norm_ and clip_grad_value_ themselves.
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
def _no_grad_wrapper(*args, **kwargs):
|
| 267 |
+
with torch.no_grad():
|
| 268 |
+
return func(*args, **kwargs)
|
| 269 |
+
|
| 270 |
+
functools.update_wrapper(_no_grad_wrapper, func)
|
| 271 |
+
return _no_grad_wrapper
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
@_no_grad
|
| 275 |
+
def clip_grad_norm_(
|
| 276 |
+
parameters,
|
| 277 |
+
max_norm,
|
| 278 |
+
norm_type= 2.0,
|
| 279 |
+
error_if_nonfinite = False,
|
| 280 |
+
foreach = None,
|
| 281 |
+
):
|
| 282 |
+
r"""Clip the gradient norm of an iterable of parameters.
|
| 283 |
+
|
| 284 |
+
The norm is computed over the norms of the individual gradients of all parameters,
|
| 285 |
+
as if the norms of the individual gradients were concatenated into a single vector.
|
| 286 |
+
Gradients are modified in-place.
|
| 287 |
+
|
| 288 |
+
Args:
|
| 289 |
+
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
| 290 |
+
single Tensor that will have gradients normalized
|
| 291 |
+
max_norm (float): max norm of the gradients
|
| 292 |
+
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
|
| 293 |
+
infinity norm.
|
| 294 |
+
error_if_nonfinite (bool): if True, an error is thrown if the total
|
| 295 |
+
norm of the gradients from :attr:`parameters` is ``nan``,
|
| 296 |
+
``inf``, or ``-inf``. Default: False (will switch to True in the future)
|
| 297 |
+
foreach (bool): use the faster foreach-based implementation.
|
| 298 |
+
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
|
| 299 |
+
fall back to the slow implementation for other device types.
|
| 300 |
+
Default: ``None``
|
| 301 |
+
|
| 302 |
+
Returns:
|
| 303 |
+
Total norm of the parameter gradients (viewed as a single vector).
|
| 304 |
+
"""
|
| 305 |
+
if isinstance(parameters, torch.Tensor):
|
| 306 |
+
parameters = [parameters]
|
| 307 |
+
grads = [p.grad for p in parameters if p.grad is not None]
|
| 308 |
+
max_norm = float(max_norm)
|
| 309 |
+
norm_type = float(norm_type)
|
| 310 |
+
if len(grads) == 0:
|
| 311 |
+
return torch.tensor(0.0)
|
| 312 |
+
first_device = grads[0].device
|
| 313 |
+
grouped_grads: Dict[
|
| 314 |
+
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
|
| 315 |
+
] = _group_tensors_by_device_and_dtype(
|
| 316 |
+
[grads]
|
| 317 |
+
) # type: ignore[assignment]
|
| 318 |
+
|
| 319 |
+
norms: List[Tensor] = []
|
| 320 |
+
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
|
| 321 |
+
if (foreach is None and _has_foreach_support(device_grads, device)) or (
|
| 322 |
+
foreach and _device_has_foreach_support(device)
|
| 323 |
+
):
|
| 324 |
+
norms.extend(torch._foreach_norm(device_grads, norm_type))
|
| 325 |
+
elif foreach:
|
| 326 |
+
raise RuntimeError(
|
| 327 |
+
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
|
| 328 |
+
)
|
| 329 |
+
else:
|
| 330 |
+
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
|
| 331 |
+
|
| 332 |
+
total_norm = torch.linalg.vector_norm(
|
| 333 |
+
torch.stack([norm.to(first_device) for norm in norms]), norm_type
|
| 334 |
+
)
|
| 335 |
+
return total_norm
|
| 336 |
+
|
| 337 |
+
class NativeScalerWithGradNormCount:
|
| 338 |
+
state_dict_key = "amp_scaler"
|
| 339 |
+
|
| 340 |
+
def __init__(self, enabled=True):
|
| 341 |
+
self._scaler = torch.cuda.amp.GradScaler(enabled=False)
|
| 342 |
+
|
| 343 |
+
def __call__(self, loss, optimizer, clip_grad=10, parameters=None, create_graph=False, update_grad=True):
|
| 344 |
+
self._scaler.scale(loss).backward(create_graph=create_graph)
|
| 345 |
+
if update_grad:
|
| 346 |
+
# if clip_grad is not None:
|
| 347 |
+
# assert parameters is not None
|
| 348 |
+
|
| 349 |
+
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
|
| 350 |
+
with torch.no_grad():
|
| 351 |
+
if isinstance(parameters, torch.Tensor):
|
| 352 |
+
parameters = [parameters]
|
| 353 |
+
parameters = [p for p in parameters if p.grad is not None]
|
| 354 |
+
for p in parameters:
|
| 355 |
+
if p.grad is None:
|
| 356 |
+
print(f"WARNING: found a None grad, name is {n}, step is {step}", force=True)
|
| 357 |
+
else:
|
| 358 |
+
p.grad.nan_to_num_(nan=0., posinf=1e-3, neginf=-1e-3)
|
| 359 |
+
norm = torch.nn.utils.clip_grad_norm_(parameters, 10.)
|
| 360 |
+
# else:
|
| 361 |
+
# self._scaler.unscale_(optimizer)
|
| 362 |
+
# norm = get_grad_norm_(parameters)
|
| 363 |
+
self._scaler.step(optimizer)
|
| 364 |
+
self._scaler.update()
|
| 365 |
+
else:
|
| 366 |
+
norm = None
|
| 367 |
+
return norm
|
| 368 |
+
|
| 369 |
+
def state_dict(self):
|
| 370 |
+
return self._scaler.state_dict()
|
| 371 |
+
|
| 372 |
+
def load_state_dict(self, state_dict):
|
| 373 |
+
self._scaler.load_state_dict(state_dict)
|
| 374 |
+
|
| 375 |
+
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
|
| 376 |
+
if isinstance(parameters, torch.Tensor):
|
| 377 |
+
parameters = [parameters]
|
| 378 |
+
parameters = [p for p in parameters if p.grad is not None]
|
| 379 |
+
norm_type = float(norm_type)
|
| 380 |
+
if len(parameters) == 0:
|
| 381 |
+
return torch.tensor(0.)
|
| 382 |
+
device = parameters[0].grad.device
|
| 383 |
+
if norm_type == inf:
|
| 384 |
+
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
|
| 385 |
+
else:
|
| 386 |
+
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
| 387 |
+
|
| 388 |
+
return total_norm
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def save_model(args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None):
|
| 394 |
+
output_dir = Path(args.output_dir)
|
| 395 |
+
if fname is None: fname = str(epoch)
|
| 396 |
+
checkpoint_path = output_dir / ('checkpoint-%s.pth' % fname)
|
| 397 |
+
to_save = {
|
| 398 |
+
'model': model_without_ddp.state_dict(),
|
| 399 |
+
'optimizer': optimizer.state_dict(),
|
| 400 |
+
'scaler': loss_scaler.state_dict(),
|
| 401 |
+
'args': args,
|
| 402 |
+
'epoch': epoch,
|
| 403 |
+
}
|
| 404 |
+
if best_so_far is not None: to_save['best_so_far'] = best_so_far
|
| 405 |
+
print(f'>> Saving model to {checkpoint_path} ...')
|
| 406 |
+
save_on_master(to_save, checkpoint_path)
|
| 407 |
+
if is_main_process():
|
| 408 |
+
os.system('ossutil64 cp -f %s oss://antsys-vilab/zsz/checkpoints/%s -j 200' % (checkpoint_path, checkpoint_path))
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def load_model(args, model_without_ddp, optimizer, loss_scaler):
|
| 412 |
+
args.start_epoch = 0
|
| 413 |
+
best_so_far = None
|
| 414 |
+
if args.resume is not None:
|
| 415 |
+
if args.resume.startswith('https'):
|
| 416 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 417 |
+
args.resume, map_location='cpu', check_hash=True)
|
| 418 |
+
else:
|
| 419 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
| 420 |
+
print("Resume checkpoint %s" % args.resume)
|
| 421 |
+
model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
|
| 422 |
+
args.start_epoch = checkpoint['epoch'] + 1
|
| 423 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 424 |
+
if 'scaler' in checkpoint:
|
| 425 |
+
loss_scaler.load_state_dict(checkpoint['scaler'])
|
| 426 |
+
if 'best_so_far' in checkpoint:
|
| 427 |
+
best_so_far = checkpoint['best_so_far']
|
| 428 |
+
print(" & best_so_far={:g}".format(best_so_far))
|
| 429 |
+
else:
|
| 430 |
+
print("")
|
| 431 |
+
print("With optim & sched! start_epoch={:d}".format(args.start_epoch), end='')
|
| 432 |
+
return best_so_far
|
| 433 |
+
|
| 434 |
+
def all_reduce_mean(x):
|
| 435 |
+
world_size = get_world_size()
|
| 436 |
+
if world_size > 1:
|
| 437 |
+
x_reduce = torch.tensor(x).cuda()
|
| 438 |
+
dist.all_reduce(x_reduce)
|
| 439 |
+
x_reduce /= world_size
|
| 440 |
+
return x_reduce.item()
|
| 441 |
+
else:
|
| 442 |
+
return x
|
| 443 |
+
|
| 444 |
+
def _replace(text, src, tgt, rm=''):
|
| 445 |
+
""" Advanced string replacement.
|
| 446 |
+
Given a text:
|
| 447 |
+
- replace all elements in src by the corresponding element in tgt
|
| 448 |
+
- remove all elements in rm
|
| 449 |
+
"""
|
| 450 |
+
if len(tgt) == 1:
|
| 451 |
+
tgt = tgt * len(src)
|
| 452 |
+
assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len"
|
| 453 |
+
for s,t in zip(src, tgt):
|
| 454 |
+
text = text.replace(s,t)
|
| 455 |
+
for c in rm:
|
| 456 |
+
text = text.replace(c,'')
|
| 457 |
+
return text
|
| 458 |
+
|
| 459 |
+
def filename( obj ):
|
| 460 |
+
""" transform a python obj or cmd into a proper filename.
|
| 461 |
+
- \1 gets replaced by slash '/'
|
| 462 |
+
- \2 gets replaced by comma ','
|
| 463 |
+
"""
|
| 464 |
+
if not isinstance(obj, str):
|
| 465 |
+
obj = repr(obj)
|
| 466 |
+
obj = str(obj).replace('()','')
|
| 467 |
+
obj = _replace(obj, '_,(*/\1\2','-__x%/,', rm=' )\'"')
|
| 468 |
+
assert all(len(s) < 256 for s in obj.split(os.sep)), 'filename too long (>256 characters):\n'+obj
|
| 469 |
+
return obj
|
| 470 |
+
|
| 471 |
+
def _get_num_layer_for_vit(var_name, enc_depth, dec_depth):
|
| 472 |
+
if var_name in ("cls_token", "mask_token", "pos_embed", "global_tokens"):
|
| 473 |
+
return 0
|
| 474 |
+
elif var_name.startswith("patch_embed"):
|
| 475 |
+
return 0
|
| 476 |
+
elif var_name.startswith("enc_blocks"):
|
| 477 |
+
layer_id = int(var_name.split('.')[1])
|
| 478 |
+
return layer_id + 1
|
| 479 |
+
elif var_name.startswith('decoder_embed') or var_name.startswith('enc_norm'): # part of the last black
|
| 480 |
+
return enc_depth
|
| 481 |
+
elif var_name.startswith('dec_blocks'):
|
| 482 |
+
layer_id = int(var_name.split('.')[1])
|
| 483 |
+
return enc_depth + layer_id + 1
|
| 484 |
+
elif var_name.startswith('dec_norm'): # part of the last block
|
| 485 |
+
return enc_depth + dec_depth
|
| 486 |
+
elif any(var_name.startswith(k) for k in ['head','prediction_head']):
|
| 487 |
+
return enc_depth + dec_depth + 1
|
| 488 |
+
else:
|
| 489 |
+
raise NotImplementedError(var_name)
|
| 490 |
+
|
| 491 |
+
def get_parameter_groups(model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[]):
|
| 492 |
+
parameter_group_names = {}
|
| 493 |
+
parameter_group_vars = {}
|
| 494 |
+
enc_depth, dec_depth = None, None
|
| 495 |
+
# prepare layer decay values
|
| 496 |
+
assert layer_decay==1.0 or 0.<layer_decay<1.
|
| 497 |
+
if layer_decay<1.:
|
| 498 |
+
enc_depth = model.enc_depth
|
| 499 |
+
dec_depth = model.dec_depth if hasattr(model, 'dec_blocks') else 0
|
| 500 |
+
num_layers = enc_depth+dec_depth
|
| 501 |
+
layer_decay_values = list(layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))
|
| 502 |
+
|
| 503 |
+
for name, param in model.named_parameters():
|
| 504 |
+
if not param.requires_grad:
|
| 505 |
+
continue # frozen weights
|
| 506 |
+
|
| 507 |
+
# Assign weight decay values
|
| 508 |
+
if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
|
| 509 |
+
group_name = "no_decay"
|
| 510 |
+
this_weight_decay = 0.
|
| 511 |
+
elif 'mask_token' in name or 'pathch_embed' in name or 'enc_blocks' in name:
|
| 512 |
+
group_name = "encoder"
|
| 513 |
+
this_weight_decay = weight_decay
|
| 514 |
+
else:
|
| 515 |
+
group_name = "decay"
|
| 516 |
+
this_weight_decay = weight_decay
|
| 517 |
+
|
| 518 |
+
# Assign layer ID for LR scaling
|
| 519 |
+
if layer_decay<1.:
|
| 520 |
+
skip_scale = False
|
| 521 |
+
layer_id = _get_num_layer_for_vit(name, enc_depth, dec_depth)
|
| 522 |
+
group_name = "layer_%d_%s" % (layer_id, group_name)
|
| 523 |
+
if name in no_lr_scale_list:
|
| 524 |
+
skip_scale = True
|
| 525 |
+
group_name = f'{group_name}_no_lr_scale'
|
| 526 |
+
else:
|
| 527 |
+
layer_id = 0
|
| 528 |
+
skip_scale = True
|
| 529 |
+
|
| 530 |
+
if group_name not in parameter_group_names:
|
| 531 |
+
if not skip_scale:
|
| 532 |
+
scale = layer_decay_values[layer_id]
|
| 533 |
+
elif group_name == "encoder":
|
| 534 |
+
scale = 0.5
|
| 535 |
+
else:
|
| 536 |
+
scale = 1.
|
| 537 |
+
|
| 538 |
+
parameter_group_names[group_name] = {
|
| 539 |
+
"weight_decay": this_weight_decay,
|
| 540 |
+
"params": [],
|
| 541 |
+
"lr_scale": scale
|
| 542 |
+
}
|
| 543 |
+
parameter_group_vars[group_name] = {
|
| 544 |
+
"weight_decay": this_weight_decay,
|
| 545 |
+
"params": [],
|
| 546 |
+
"lr_scale": scale
|
| 547 |
+
}
|
| 548 |
+
parameter_group_vars[group_name]["params"].append(param)
|
| 549 |
+
parameter_group_names[group_name]["params"].append(name)
|
| 550 |
+
print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
|
| 551 |
+
return list(parameter_group_vars.values())
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def adjust_learning_rate(optimizer, epoch, args):
|
| 556 |
+
"""Decay the learning rate with half-cycle cosine after warmup"""
|
| 557 |
+
lr_peak = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
|
| 558 |
+
(1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
|
| 559 |
+
T_mult = 1
|
| 560 |
+
warmup_iters = 2
|
| 561 |
+
T_0 = args.cycle_epoch
|
| 562 |
+
end = args.epochs
|
| 563 |
+
epoch_int = int(epoch) % T_0
|
| 564 |
+
decimal_part = epoch - math.floor(epoch)
|
| 565 |
+
epoch = decimal_part + epoch_int
|
| 566 |
+
T_cur = epoch
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
if T_cur < warmup_iters:
|
| 570 |
+
warmup_ratio = T_cur / warmup_iters
|
| 571 |
+
lr = args.min_lr + (lr_peak - args.min_lr) * warmup_ratio
|
| 572 |
+
else:
|
| 573 |
+
T_cur_adjusted = T_cur - warmup_iters
|
| 574 |
+
T_i = T_0 - warmup_iters
|
| 575 |
+
lr = args.min_lr + (lr_peak - args.min_lr) * (1 + math.cos(math.pi * T_cur_adjusted / T_i)) / 2
|
| 576 |
+
# 1e-5 + (1e-4-1e-5) * (1+math.cos(math.pi * (10-2) / 98)) / 2
|
| 577 |
+
for param_group in optimizer.param_groups:
|
| 578 |
+
if "lr_scale" in param_group:
|
| 579 |
+
param_group["lr"] = lr * param_group["lr_scale"]
|
| 580 |
+
else:
|
| 581 |
+
param_group["lr"] = lr
|
| 582 |
+
return lr
|
| 583 |
+
|
dust3r/dust3r/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
dust3r/dust3r/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (150 Bytes). View file
|
|
|
dust3r/dust3r/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
dust3r/dust3r/__pycache__/patch_embed.cpython-312.pyc
ADDED
|
Binary file (4.86 kB). View file
|
|
|
dust3r/dust3r/__pycache__/viz.cpython-312.pyc
ADDED
|
Binary file (22.3 kB). View file
|
|
|
dust3r/dust3r/datasets/CustomDataset.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Dataloader for preprocessed arkitscenes
|
| 6 |
+
# dataset at https://github.com/apple/ARKitScenes - Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license
|
| 7 |
+
# See datasets_preprocess/preprocess_arkitscenes.py
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
import os.path as osp
|
| 10 |
+
import cv2
|
| 11 |
+
import numpy as np
|
| 12 |
+
import random
|
| 13 |
+
import mast3r.utils.path_to_dust3r # noqa
|
| 14 |
+
from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset_test
|
| 15 |
+
from collections import deque
|
| 16 |
+
import os
|
| 17 |
+
import json
|
| 18 |
+
import time
|
| 19 |
+
import glob
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
class CustomDataset(BaseStereoViewDataset_test):
|
| 23 |
+
def __init__(self, *args, split, ROOT, wpose=False, sequential_input=False, index_list=None, **kwargs):
|
| 24 |
+
self.ROOT = ROOT
|
| 25 |
+
self.wpose = wpose
|
| 26 |
+
self.sequential_input = sequential_input
|
| 27 |
+
super().__init__(*args, **kwargs)
|
| 28 |
+
|
| 29 |
+
def __len__(self):
|
| 30 |
+
return 684000
|
| 31 |
+
|
| 32 |
+
@staticmethod
|
| 33 |
+
def image_read(image_file):
|
| 34 |
+
img = cv2.imread(image_file)
|
| 35 |
+
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 36 |
+
|
| 37 |
+
def read_cam_file(self, filename):
|
| 38 |
+
with open(filename) as f:
|
| 39 |
+
lines = [line.rstrip() for line in f.readlines()]
|
| 40 |
+
# extrinsics: line [1,5), 4x4 matrix
|
| 41 |
+
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
|
| 42 |
+
extrinsics = extrinsics.reshape((4, 4))
|
| 43 |
+
# intrinsics: line [7-10), 3x3 matrix
|
| 44 |
+
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
|
| 45 |
+
intrinsics = intrinsics.reshape((3, 3))
|
| 46 |
+
# depth_min & depth_interval: line 11
|
| 47 |
+
depth_min = float(lines[11].split()[0])
|
| 48 |
+
depth_interval = float(lines[11].split()[1])
|
| 49 |
+
return intrinsics, extrinsics, depth_min, depth_interval
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _get_views(self, idx, resolution, rng):
|
| 53 |
+
images_list = glob.glob(osp.join(self.ROOT, '*.png')) + glob.glob(osp.join(self.ROOT, '*.jpg')) + glob.glob(osp.join(self.ROOT, '*.JPG'))
|
| 54 |
+
images_list = sorted(images_list)
|
| 55 |
+
if self.num_image != len(images_list):
|
| 56 |
+
images_list = random.sample(images_list, self.num_image)
|
| 57 |
+
self.gt_num_image = 0
|
| 58 |
+
views = []
|
| 59 |
+
for image in images_list:
|
| 60 |
+
rgb_image = self.image_read(image)
|
| 61 |
+
H, W = rgb_image.shape[:2]
|
| 62 |
+
if self.wpose == False:
|
| 63 |
+
intrinsics = np.array([[W, 0, W/2], [0, H, H/2], [0, 0, 1]])
|
| 64 |
+
camera_pose = np.eye(4)
|
| 65 |
+
else:
|
| 66 |
+
image_index = image.split('/')[-1].split('.')[0]
|
| 67 |
+
proj_mat_filename = os.path.join(self.ROOT, image_index+'.txt')
|
| 68 |
+
intrinsics, camera_pose, depth_min, depth_interval = self.read_cam_file(proj_mat_filename)
|
| 69 |
+
camera_pose = np.linalg.inv(camera_pose)
|
| 70 |
+
|
| 71 |
+
depthmap = np.zeros((H, W))
|
| 72 |
+
rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
|
| 73 |
+
rgb_image, depthmap, intrinsics, resolution, rng=rng, info=None)
|
| 74 |
+
rgb_image_orig = rgb_image.copy()
|
| 75 |
+
H, W = depthmap.shape[:2]
|
| 76 |
+
fxfycxcy = np.array([intrinsics[0, 0]/W, intrinsics[1, 1]/H, intrinsics[0,2]/W, intrinsics[1,2]/H]).astype(np.float32)
|
| 77 |
+
views.append(dict(
|
| 78 |
+
img_org=rgb_image_orig,
|
| 79 |
+
img=rgb_image,
|
| 80 |
+
depthmap=depthmap.astype(np.float32),
|
| 81 |
+
camera_pose=camera_pose.astype(np.float32),
|
| 82 |
+
camera_intrinsics=intrinsics.astype(np.float32),
|
| 83 |
+
fxfycxcy=fxfycxcy,
|
| 84 |
+
dataset='custom',
|
| 85 |
+
label=image,
|
| 86 |
+
instance=image,
|
| 87 |
+
))
|
| 88 |
+
return views
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
from dust3r.datasets.base.base_stereo_view_dataset import view_name
|
| 93 |
+
from dust3r.viz import SceneViz, auto_cam_size
|
| 94 |
+
from dust3r.utils.image import rgb
|
| 95 |
+
import nerfvis.scene as scene_vis
|
| 96 |
+
|
| 97 |
+
for idx in np.random.permutation(len(dataset)):
|
| 98 |
+
views = dataset[idx]
|
| 99 |
+
# assert len(views) == 2
|
| 100 |
+
# print(view_name(views[0]), view_name(views[1]))
|
| 101 |
+
view_idxs = list(range(len(views)))
|
| 102 |
+
poses = [views[view_idx]['camera_pose'] for view_idx in view_idxs]
|
| 103 |
+
cam_size = max(auto_cam_size(poses), 0.001)
|
| 104 |
+
pts3ds = []
|
| 105 |
+
colors = []
|
| 106 |
+
valid_masks = []
|
| 107 |
+
c2ws = []
|
| 108 |
+
intrinsics = []
|
| 109 |
+
for view_idx in view_idxs:
|
| 110 |
+
pts3d = views[view_idx]['pts3d']
|
| 111 |
+
pts3ds.append(pts3d)
|
| 112 |
+
valid_mask = views[view_idx]['valid_mask']
|
| 113 |
+
valid_masks.append(valid_mask)
|
| 114 |
+
color = rgb(views[view_idx]['img'])
|
| 115 |
+
colors.append(color)
|
| 116 |
+
# viz.add_pointcloud(pts3d, colors, valid_mask)
|
| 117 |
+
c2ws.append(views[view_idx]['camera_pose'])
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
pts3ds = np.stack(pts3ds, axis=0)
|
| 121 |
+
colors = np.stack(colors, axis=0)
|
| 122 |
+
valid_masks = np.stack(valid_masks, axis=0)
|
| 123 |
+
c2ws = np.stack(c2ws)
|
| 124 |
+
scene_vis.set_title("My Scene")
|
| 125 |
+
scene_vis.set_opencv()
|
| 126 |
+
# colors = torch.zeros_like(structure).to(structure)
|
| 127 |
+
# scene_vis.add_points("points", pts3ds.reshape(-1,3)[valid_masks.reshape(-1)], vert_color=colors.reshape(-1,3)[valid_masks.reshape(-1)], point_size=1)
|
| 128 |
+
# for i in range(len(c2ws)):
|
| 129 |
+
f = 1111.0 / 2.5
|
| 130 |
+
z = 10.
|
| 131 |
+
scene_vis.add_camera_frustum("cameras", r=c2ws[:, :3, :3], t=c2ws[:, :3, 3], focal_length=f,
|
| 132 |
+
image_width=colors.shape[2], image_height=colors.shape[1],
|
| 133 |
+
z=z, connect=False, color=[1.0, 0.0, 0.0])
|
| 134 |
+
for i in range(len(c2ws)):
|
| 135 |
+
scene_vis.add_image(
|
| 136 |
+
f"images/{i}",
|
| 137 |
+
colors[i], # Can be a list of paths too (requires joblib for that)
|
| 138 |
+
r=c2ws[i, :3, :3],
|
| 139 |
+
t=c2ws[i, :3, 3],
|
| 140 |
+
# Alternatively: from nerfvis.utils import split_mat4; **split_mat4(c2ws)
|
| 141 |
+
focal_length=f,
|
| 142 |
+
z=z,
|
| 143 |
+
)
|
| 144 |
+
scene_vis.display(port=8081)
|
| 145 |
+
|
dust3r/dust3r/datasets/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
from .utils.transforms import *
|
| 4 |
+
from .base.batched_sampler import BatchedRandomSampler # noqa
|
| 5 |
+
from .CustomDataset import CustomDataset # noqa
|
| 6 |
+
|
| 7 |
+
def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
|
| 8 |
+
import torch
|
| 9 |
+
from croco.utils.misc import get_world_size, get_rank
|
| 10 |
+
# pytorch dataset
|
| 11 |
+
if isinstance(dataset, str):
|
| 12 |
+
dataset = eval(dataset)
|
| 13 |
+
|
| 14 |
+
world_size = get_world_size()
|
| 15 |
+
rank = get_rank()
|
| 16 |
+
try:
|
| 17 |
+
sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
|
| 18 |
+
rank=rank, drop_last=drop_last)
|
| 19 |
+
except (AttributeError, NotImplementedError):
|
| 20 |
+
# not avail for this dataset
|
| 21 |
+
if torch.distributed.is_initialized():
|
| 22 |
+
sampler = torch.utils.data.DistributedSampler(
|
| 23 |
+
dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
|
| 24 |
+
)
|
| 25 |
+
elif shuffle:
|
| 26 |
+
sampler = torch.utils.data.RandomSampler(dataset)
|
| 27 |
+
else:
|
| 28 |
+
sampler = torch.utils.data.SequentialSampler(dataset)
|
| 29 |
+
|
| 30 |
+
data_loader = torch.utils.data.DataLoader(
|
| 31 |
+
dataset,
|
| 32 |
+
sampler=sampler,
|
| 33 |
+
batch_size=batch_size,
|
| 34 |
+
num_workers=num_workers,
|
| 35 |
+
pin_memory=pin_mem,
|
| 36 |
+
drop_last=drop_last,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
return data_loader
|
dust3r/dust3r/datasets/__pycache__/CustomDataset.cpython-312.pyc
ADDED
|
Binary file (8.14 kB). View file
|
|
|
dust3r/dust3r/datasets/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.76 kB). View file
|
|
|
dust3r/dust3r/datasets/base/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
dust3r/dust3r/datasets/base/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (164 Bytes). View file
|
|
|
dust3r/dust3r/datasets/base/__pycache__/base_stereo_view_dataset.cpython-312.pyc
ADDED
|
Binary file (32.6 kB). View file
|
|
|
dust3r/dust3r/datasets/base/__pycache__/batched_sampler.cpython-312.pyc
ADDED
|
Binary file (4.09 kB). View file
|
|
|
dust3r/dust3r/datasets/base/__pycache__/easy_dataset.cpython-312.pyc
ADDED
|
Binary file (8.87 kB). View file
|
|
|
dust3r/dust3r/datasets/base/base_stereo_view_dataset.py
ADDED
|
@@ -0,0 +1,774 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# base class for implementing datasets
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import PIL
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from dust3r.datasets.base.easy_dataset import EasyDataset
|
| 12 |
+
from dust3r.datasets.utils.transforms import ImgNorm
|
| 13 |
+
from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, geotrf
|
| 14 |
+
import dust3r.datasets.utils.cropping as cropping
|
| 15 |
+
import random
|
| 16 |
+
import copy
|
| 17 |
+
from scipy.spatial.transform import Rotation
|
| 18 |
+
import torchvision.transforms as transforms
|
| 19 |
+
from dust3r.utils.geometry import inv, geotrf
|
| 20 |
+
import cv2
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class BaseStereoViewDataset_test (EasyDataset):
|
| 25 |
+
""" Define all basic options.
|
| 26 |
+
|
| 27 |
+
Usage:
|
| 28 |
+
class MyDataset (BaseStereoViewDataset):
|
| 29 |
+
def _get_views(self, idx, rng):
|
| 30 |
+
# overload here
|
| 31 |
+
views = []
|
| 32 |
+
views.append(dict(img=, ...))
|
| 33 |
+
return views
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self, *, # only keyword arguments
|
| 37 |
+
split=None,
|
| 38 |
+
resolution=None, # square_size or (width, height) or list of [(width,height), ...]
|
| 39 |
+
transform=ImgNorm,
|
| 40 |
+
aug_crop=False,
|
| 41 |
+
seed=None,
|
| 42 |
+
num_views=8,
|
| 43 |
+
gt_num_image=0,
|
| 44 |
+
aug_monocular=False,
|
| 45 |
+
aug_portrait_or_landscape=False,
|
| 46 |
+
aug_rot90=False,
|
| 47 |
+
aug_swap=False,
|
| 48 |
+
only_pose=False,
|
| 49 |
+
sequential_input=False,
|
| 50 |
+
overfit=False,
|
| 51 |
+
caculate_mask=False):
|
| 52 |
+
self.sequential_input = sequential_input
|
| 53 |
+
self.split = split
|
| 54 |
+
self.num_image = num_views
|
| 55 |
+
self._set_resolutions(resolution)
|
| 56 |
+
self.gt_num_image=gt_num_image
|
| 57 |
+
self.aug_monocular=aug_monocular
|
| 58 |
+
self.aug_portrait_or_landscape = aug_portrait_or_landscape
|
| 59 |
+
self.transform = transform
|
| 60 |
+
self.transform_org = transforms.Compose([transform for transform in transform.transforms if type(transform).__name__ != 'ColorJitter'])
|
| 61 |
+
self.aug_rot90 = aug_rot90
|
| 62 |
+
self.aug_swap = aug_swap
|
| 63 |
+
self.only_pose = only_pose
|
| 64 |
+
self.overfit = overfit
|
| 65 |
+
self.rendering = False
|
| 66 |
+
self.caculate_mask = caculate_mask
|
| 67 |
+
if isinstance(transform, str):
|
| 68 |
+
transform = eval(transform)
|
| 69 |
+
|
| 70 |
+
self.aug_crop = aug_crop
|
| 71 |
+
self.seed = seed
|
| 72 |
+
self.kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(9, 9))
|
| 73 |
+
|
| 74 |
+
def __len__(self):
|
| 75 |
+
return len(self.scenes)
|
| 76 |
+
|
| 77 |
+
# def sequential_sample(self, im_start, last, interal):
|
| 78 |
+
# im_list = [im_start + i * interal + random.choice(list(range(interal))) for i in range(self.num_image)]
|
| 79 |
+
# im_list += [random.choice(im_list) + random.choice(list(range(interal))) for _ in range(self.gt_num_image)]
|
| 80 |
+
# return im_list
|
| 81 |
+
def sequential_sample(self, im_start, last, interal):
|
| 82 |
+
im_list = [
|
| 83 |
+
im_start + i * interal + random.choice(list(range(-interal//2, interal//2)))
|
| 84 |
+
for i in range(self.num_image)
|
| 85 |
+
]
|
| 86 |
+
im_list += [
|
| 87 |
+
random.choice(im_list) + random.choice(list(range(-interal//2, interal//2)))
|
| 88 |
+
for _ in range(self.gt_num_image)
|
| 89 |
+
]
|
| 90 |
+
return im_list
|
| 91 |
+
|
| 92 |
+
def get_stats(self):
|
| 93 |
+
return f"{len(self)} pairs"
|
| 94 |
+
|
| 95 |
+
def __repr__(self):
|
| 96 |
+
resolutions_str = '['+';'.join(f'{w}x{h}' for w, h in self._resolutions)+']'
|
| 97 |
+
return f"""{type(self).__name__}({self.get_stats()},
|
| 98 |
+
{self.split=},
|
| 99 |
+
{self.seed=},
|
| 100 |
+
resolutions={resolutions_str},
|
| 101 |
+
{self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '')
|
| 102 |
+
|
| 103 |
+
def _get_views(self, idx, resolution, rng):
|
| 104 |
+
raise NotImplementedError()
|
| 105 |
+
|
| 106 |
+
def _swap_view_aug(self, views):
|
| 107 |
+
# if self._rng.random() < 0.5:
|
| 108 |
+
# views.reverse()
|
| 109 |
+
return random.shuffle(views)
|
| 110 |
+
|
| 111 |
+
def __getitem__(self, idx):
|
| 112 |
+
if isinstance(idx, tuple):
|
| 113 |
+
# the idx is specifying the aspect-ratio
|
| 114 |
+
idx, ar_idx = idx
|
| 115 |
+
else:
|
| 116 |
+
assert len(self._resolutions) == 1
|
| 117 |
+
ar_idx = 0
|
| 118 |
+
|
| 119 |
+
# set-up the rng
|
| 120 |
+
if self.seed: # reseed for each __getitem__
|
| 121 |
+
self._rng = np.random.default_rng(seed=self.seed + idx)
|
| 122 |
+
elif not hasattr(self, '_rng'):
|
| 123 |
+
seed = torch.initial_seed() # this is different for each dataloader process
|
| 124 |
+
self._rng = np.random.default_rng(seed=seed)
|
| 125 |
+
|
| 126 |
+
# over-loaded code
|
| 127 |
+
resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
|
| 128 |
+
flag = False
|
| 129 |
+
i = 0
|
| 130 |
+
# while flag == False and i < 100:
|
| 131 |
+
# try:
|
| 132 |
+
# views = self._get_views(idx, resolution, self._rng)
|
| 133 |
+
# flag = True
|
| 134 |
+
# except:
|
| 135 |
+
# flag = False
|
| 136 |
+
# i += 1
|
| 137 |
+
|
| 138 |
+
views = self._get_views(idx, resolution, self._rng)
|
| 139 |
+
|
| 140 |
+
# assert len(views) == self.num_image + self.gt_num_image
|
| 141 |
+
if self.only_pose == True:
|
| 142 |
+
# check data-types
|
| 143 |
+
for view in views:
|
| 144 |
+
# transpose to make sure all views are the same size
|
| 145 |
+
# this allows to check whether the RNG is is the same state each time
|
| 146 |
+
view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
|
| 147 |
+
return views
|
| 148 |
+
else:
|
| 149 |
+
for v, view in enumerate(views):
|
| 150 |
+
assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
|
| 151 |
+
view['idx'] = (idx, ar_idx, v)
|
| 152 |
+
# encode the image
|
| 153 |
+
width, height = view['img'].size
|
| 154 |
+
view['true_shape'] = np.int32((height, width))
|
| 155 |
+
view['img'] = self.transform_org(view['img'])
|
| 156 |
+
view['img_org' ] = self.transform_org(view['img_org'])
|
| 157 |
+
if 'depth_anything' not in view:
|
| 158 |
+
view['depth_anything'] = np.zeros_like(view['depthmap'])
|
| 159 |
+
# if view['img_org'].shape[1] != 224:
|
| 160 |
+
# print(view['img_org' ].shape)
|
| 161 |
+
# print(view['img'].shape)
|
| 162 |
+
assert 'camera_intrinsics' in view
|
| 163 |
+
if 'camera_pose' not in view:
|
| 164 |
+
view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
|
| 165 |
+
else:
|
| 166 |
+
assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
|
| 167 |
+
assert 'pts3d' not in view
|
| 168 |
+
assert 'valid_mask' not in view
|
| 169 |
+
assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
|
| 170 |
+
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
|
| 171 |
+
|
| 172 |
+
view['pts3d'] = pts3d
|
| 173 |
+
view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
|
| 174 |
+
# print(view['pts3d'].shape)
|
| 175 |
+
# print(view['valid_mask'].shape)
|
| 176 |
+
|
| 177 |
+
# check all datatypes
|
| 178 |
+
for key, val in view.items():
|
| 179 |
+
res, err_msg = is_good_type(key, val)
|
| 180 |
+
assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
|
| 181 |
+
K = view['camera_intrinsics']
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
for view in views:
|
| 185 |
+
fxfycxcy = view['fxfycxcy'].copy()
|
| 186 |
+
H, W = view['img'].shape[1:]
|
| 187 |
+
fxfycxcy[0] = fxfycxcy[0] * W
|
| 188 |
+
fxfycxcy[1] = fxfycxcy[1] * H
|
| 189 |
+
fxfycxcy[2] = fxfycxcy[2] * W
|
| 190 |
+
fxfycxcy[3] = fxfycxcy[3] * H
|
| 191 |
+
view['fxfycxcy_unorm'] = fxfycxcy
|
| 192 |
+
|
| 193 |
+
# last thing done!
|
| 194 |
+
for view in views:
|
| 195 |
+
view['render_mask'] = np.ones((view['img'].shape[1], view['img'].shape[2]), dtype=np.uint8) > 0.1
|
| 196 |
+
|
| 197 |
+
for view in views:
|
| 198 |
+
# transpose to make sure all views are the same size
|
| 199 |
+
transpose_to_landscape(view)
|
| 200 |
+
if 'sky_mask' in view:
|
| 201 |
+
view.pop('sky_mask')
|
| 202 |
+
# this allows to check whether the RNG is is the same state each time
|
| 203 |
+
view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
|
| 204 |
+
return views
|
| 205 |
+
|
| 206 |
+
def _set_resolutions(self, resolutions):
|
| 207 |
+
assert resolutions is not None, 'undefined resolution'
|
| 208 |
+
|
| 209 |
+
if not isinstance(resolutions, list):
|
| 210 |
+
resolutions = [resolutions]
|
| 211 |
+
|
| 212 |
+
self._resolutions = []
|
| 213 |
+
for resolution in resolutions:
|
| 214 |
+
if isinstance(resolution, int):
|
| 215 |
+
width = height = resolution
|
| 216 |
+
else:
|
| 217 |
+
width, height = resolution
|
| 218 |
+
assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'
|
| 219 |
+
assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'
|
| 220 |
+
assert width >= height
|
| 221 |
+
self._resolutions.append((width, height))
|
| 222 |
+
|
| 223 |
+
def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None, depth_anything=None):
|
| 224 |
+
""" This function:
|
| 225 |
+
- first downsizes the image with LANCZOS inteprolation,
|
| 226 |
+
which is better than bilinear interpolation in
|
| 227 |
+
"""
|
| 228 |
+
if not isinstance(image, PIL.Image.Image):
|
| 229 |
+
image = PIL.Image.fromarray(image)
|
| 230 |
+
|
| 231 |
+
# transpose the resolution if necessary
|
| 232 |
+
W, H = image.size # new size
|
| 233 |
+
assert resolution[0] >= resolution[1]
|
| 234 |
+
if H > 1.1 * W:
|
| 235 |
+
# image is portrait mode
|
| 236 |
+
resolution = resolution[::-1]
|
| 237 |
+
elif 0.7 < H / W < 1.3 and resolution[0] != resolution[1] and self.aug_portrait_or_landscape:
|
| 238 |
+
# image is square, so we chose (portrait, landscape) randomly
|
| 239 |
+
if rng.integers(2):
|
| 240 |
+
resolution = resolution[::-1]
|
| 241 |
+
# resolution = resolution[::-1]
|
| 242 |
+
# high-quality Lanczos down-scaling
|
| 243 |
+
target_resolution = np.array(resolution)
|
| 244 |
+
if depth_anything is not None:
|
| 245 |
+
image, depthmap, intrinsics, depth_anything = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution, depth_anything=depth_anything)
|
| 246 |
+
else:
|
| 247 |
+
image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
|
| 248 |
+
|
| 249 |
+
# actual cropping (if necessary) with bilinear interpolation
|
| 250 |
+
offset_factor = 0.5
|
| 251 |
+
intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
|
| 252 |
+
crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
|
| 253 |
+
if depth_anything is not None:
|
| 254 |
+
image, depthmap, intrinsics2, depth_anything = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox, depth_anything=depth_anything)
|
| 255 |
+
return image, depthmap, intrinsics2, depth_anything
|
| 256 |
+
else:
|
| 257 |
+
image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
|
| 258 |
+
return image, depthmap, intrinsics2
|
| 259 |
+
|
| 260 |
+
def _crop_resize_if_necessary_test(self, image, depthmap, intrinsics, resolution, rng=None, info=None, depth_anything=None):
|
| 261 |
+
""" This function:
|
| 262 |
+
- first downsizes the image with LANCZOS inteprolation,
|
| 263 |
+
which is better than bilinear interpolation in
|
| 264 |
+
"""
|
| 265 |
+
if not isinstance(image, PIL.Image.Image):
|
| 266 |
+
image = PIL.Image.fromarray(image)
|
| 267 |
+
|
| 268 |
+
# transpose the resolution if necessary
|
| 269 |
+
W, H = image.size # new size
|
| 270 |
+
assert resolution[0] >= resolution[1]
|
| 271 |
+
if H > 1.1 * W:
|
| 272 |
+
# image is portrait mode
|
| 273 |
+
resolution = resolution[::-1]
|
| 274 |
+
|
| 275 |
+
# resolution = resolution[::-1]
|
| 276 |
+
# high-quality Lanczos down-scaling
|
| 277 |
+
target_resolution = np.array(resolution)
|
| 278 |
+
if depth_anything is not None:
|
| 279 |
+
image, depthmap, intrinsics, depth_anything = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution, depth_anything=depth_anything)
|
| 280 |
+
else:
|
| 281 |
+
image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
|
| 282 |
+
|
| 283 |
+
# actual cropping (if necessary) with bilinear interpolation
|
| 284 |
+
offset_factor = 0.5
|
| 285 |
+
intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
|
| 286 |
+
crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
|
| 287 |
+
if depth_anything is not None:
|
| 288 |
+
image, depthmap, intrinsics2, depth_anything = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox, depth_anything=depth_anything)
|
| 289 |
+
else:
|
| 290 |
+
image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
|
| 291 |
+
|
| 292 |
+
return image, depthmap, intrinsics2
|
| 293 |
+
|
| 294 |
+
def rotate_90(views, k=1):
|
| 295 |
+
# print('rotation =', k)
|
| 296 |
+
RT = np.eye(4, dtype=np.float32)
|
| 297 |
+
RT[:3, :3] = Rotation.from_euler('z', 90 * k, degrees=True).as_matrix()
|
| 298 |
+
|
| 299 |
+
for view in views:
|
| 300 |
+
view['img'] = torch.rot90(view['img'], k=k, dims=(-2, -1)) # WARNING!! dims=(-1,-2) != dims=(-2,-1)
|
| 301 |
+
view['depthmap'] = np.rot90(view['depthmap'], k=k).copy()
|
| 302 |
+
view['camera_pose'] = view['camera_pose'] @ RT
|
| 303 |
+
|
| 304 |
+
RT2 = np.eye(3, dtype=np.float32)
|
| 305 |
+
RT2[:2, :2] = RT[:2, :2] * ((1, -1), (-1, 1))
|
| 306 |
+
H, W = view['depthmap'].shape
|
| 307 |
+
if k % 4 == 0:
|
| 308 |
+
pass
|
| 309 |
+
elif k % 4 == 1:
|
| 310 |
+
# top-left (0,0) pixel becomes (0,H-1)
|
| 311 |
+
RT2[:2, 2] = (0, H - 1)
|
| 312 |
+
elif k % 4 == 2:
|
| 313 |
+
# top-left (0,0) pixel becomes (W-1,H-1)
|
| 314 |
+
RT2[:2, 2] = (W - 1, H - 1)
|
| 315 |
+
elif k % 4 == 3:
|
| 316 |
+
# top-left (0,0) pixel becomes (W-1,0)
|
| 317 |
+
RT2[:2, 2] = (W - 1, 0)
|
| 318 |
+
else:
|
| 319 |
+
raise ValueError(f'Bad value for {k=}')
|
| 320 |
+
|
| 321 |
+
view['camera_intrinsics'][:2, 2] = geotrf(RT2, view['camera_intrinsics'][:2, 2])
|
| 322 |
+
if k % 2 == 1:
|
| 323 |
+
K = view['camera_intrinsics']
|
| 324 |
+
np.fill_diagonal(K, K.diagonal()[[1, 0, 2]])
|
| 325 |
+
|
| 326 |
+
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
|
| 327 |
+
view['pts3d'] = pts3d
|
| 328 |
+
view['valid_mask'] = np.rot90(view['valid_mask'], k=k).copy()
|
| 329 |
+
view['true_shape'] = np.int32((H, W))
|
| 330 |
+
intrinsics = view['camera_intrinsics']
|
| 331 |
+
fxfycxcy = np.array([intrinsics[0, 0]/W, intrinsics[1, 1]/H, intrinsics[0,2]/W, intrinsics[1,2]/H]).astype(np.float32)
|
| 332 |
+
view['fxfycxcy'] = fxfycxcy
|
| 333 |
+
|
| 334 |
+
def reciprocal_1d(corres_1_to_2, corres_2_to_1, shape1, shape2, ret_recip=False):
|
| 335 |
+
is_reciprocal1 = np.abs(unravel_xy(corres_2_to_1[corres_1_to_2], shape1) - unravel_xy(np.arange(len(corres_1_to_2)), shape1)).sum(-1) < 4
|
| 336 |
+
pos1 = is_reciprocal1.nonzero()[0]
|
| 337 |
+
pos2 = corres_1_to_2[pos1]
|
| 338 |
+
if ret_recip:
|
| 339 |
+
return is_reciprocal1, pos1, pos2
|
| 340 |
+
return pos1, pos2
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def reproject_view(pts3d, view2):
|
| 344 |
+
shape = view2['pts3d'].shape[:2]
|
| 345 |
+
return reproject(pts3d, view2['camera_intrinsics'], inv(view2['camera_pose']), shape)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def reproject(pts3d, K, world2cam, shape):
|
| 349 |
+
H, W, THREE = pts3d.shape
|
| 350 |
+
assert THREE == 3
|
| 351 |
+
|
| 352 |
+
# reproject in camera2 space
|
| 353 |
+
with np.errstate(divide='ignore', invalid='ignore'):
|
| 354 |
+
pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)
|
| 355 |
+
|
| 356 |
+
# quantize to pixel positions
|
| 357 |
+
return (H, W), ravel_xy(pos, shape)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def ravel_xy(pos, shape):
|
| 361 |
+
H, W = shape
|
| 362 |
+
with np.errstate(invalid='ignore'):
|
| 363 |
+
qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
|
| 364 |
+
quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
|
| 365 |
+
return quantized_pos
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def unravel_xy(pos, shape):
|
| 369 |
+
# convert (x+W*y) back to 2d (x,y) coordinates
|
| 370 |
+
return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
class BaseStereoViewDataset (EasyDataset):
|
| 374 |
+
""" Define all basic options.
|
| 375 |
+
|
| 376 |
+
Usage:
|
| 377 |
+
class MyDataset (BaseStereoViewDataset):
|
| 378 |
+
def _get_views(self, idx, rng):
|
| 379 |
+
# overload here
|
| 380 |
+
views = []
|
| 381 |
+
views.append(dict(img=, ...))
|
| 382 |
+
return views
|
| 383 |
+
"""
|
| 384 |
+
|
| 385 |
+
def __init__(self, *, # only keyword arguments
|
| 386 |
+
split=None,
|
| 387 |
+
resolution=None, # square_size or (width, height) or list of [(width,height), ...]
|
| 388 |
+
transform=ImgNorm,
|
| 389 |
+
aug_crop=False,
|
| 390 |
+
seed=None,
|
| 391 |
+
num_views=8,
|
| 392 |
+
gt_num_image=0,
|
| 393 |
+
aug_monocular=False,
|
| 394 |
+
aug_portrait_or_landscape=True,
|
| 395 |
+
aug_rot90=False,
|
| 396 |
+
aug_swap=False,
|
| 397 |
+
only_pose=False,
|
| 398 |
+
sequential_input=False,
|
| 399 |
+
overfit=False,
|
| 400 |
+
caculate_mask=False):
|
| 401 |
+
self.sequential_input = sequential_input
|
| 402 |
+
self.split = split
|
| 403 |
+
self.num_image = num_views
|
| 404 |
+
self._set_resolutions(resolution)
|
| 405 |
+
self.gt_num_image=gt_num_image
|
| 406 |
+
self.aug_monocular=aug_monocular
|
| 407 |
+
self.aug_portrait_or_landscape = aug_portrait_or_landscape
|
| 408 |
+
self.transform = transform
|
| 409 |
+
self.transform_org = transforms.Compose([transform for transform in transform.transforms if type(transform).__name__ != 'ColorJitter'])
|
| 410 |
+
self.aug_rot90 = aug_rot90
|
| 411 |
+
self.aug_swap = aug_swap
|
| 412 |
+
self.only_pose = only_pose
|
| 413 |
+
self.overfit = overfit
|
| 414 |
+
self.rendering = False
|
| 415 |
+
self.caculate_mask = caculate_mask
|
| 416 |
+
if isinstance(transform, str):
|
| 417 |
+
transform = eval(transform)
|
| 418 |
+
|
| 419 |
+
self.aug_crop = aug_crop
|
| 420 |
+
self.seed = seed
|
| 421 |
+
self.kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(9, 9))
|
| 422 |
+
|
| 423 |
+
def __len__(self):
|
| 424 |
+
return len(self.scenes)
|
| 425 |
+
|
| 426 |
+
# def sequential_sample(self, im_start, last, interal):
|
| 427 |
+
# im_list = [im_start + i * interal + random.choice(list(range(interal))) for i in range(self.num_image)]
|
| 428 |
+
# im_list += [random.choice(im_list) + random.choice(list(range(interal))) for _ in range(self.gt_num_image)]
|
| 429 |
+
# return im_list
|
| 430 |
+
def sequential_sample(self, im_start, last, interal):
|
| 431 |
+
im_list = [
|
| 432 |
+
im_start + i * interal + random.choice(list(range(-interal//2, interal//2)))
|
| 433 |
+
for i in range(self.num_image)
|
| 434 |
+
]
|
| 435 |
+
im_list += [
|
| 436 |
+
random.choice(im_list) + random.choice(list(range(-interal//2, interal//2)))
|
| 437 |
+
for _ in range(self.gt_num_image)
|
| 438 |
+
]
|
| 439 |
+
return im_list
|
| 440 |
+
|
| 441 |
+
def get_stats(self):
|
| 442 |
+
return f"{len(self)} pairs"
|
| 443 |
+
|
| 444 |
+
def __repr__(self):
|
| 445 |
+
resolutions_str = '['+';'.join(f'{w}x{h}' for w, h in self._resolutions)+']'
|
| 446 |
+
return f"""{type(self).__name__}({self.get_stats()},
|
| 447 |
+
{self.split=},
|
| 448 |
+
{self.seed=},
|
| 449 |
+
resolutions={resolutions_str},
|
| 450 |
+
{self.transform=})""".replace('self.', '').replace('\n', '').replace(' ', '')
|
| 451 |
+
|
| 452 |
+
def _get_views(self, idx, resolution, rng):
|
| 453 |
+
raise NotImplementedError()
|
| 454 |
+
|
| 455 |
+
def _swap_view_aug(self, views):
|
| 456 |
+
# if self._rng.random() < 0.5:
|
| 457 |
+
# views.reverse()
|
| 458 |
+
return random.shuffle(views)
|
| 459 |
+
|
| 460 |
+
def __getitem__(self, idx):
|
| 461 |
+
if isinstance(idx, tuple):
|
| 462 |
+
# the idx is specifying the aspect-ratio
|
| 463 |
+
idx, ar_idx = idx
|
| 464 |
+
else:
|
| 465 |
+
assert len(self._resolutions) == 1
|
| 466 |
+
ar_idx = 0
|
| 467 |
+
|
| 468 |
+
# set-up the rng
|
| 469 |
+
if self.seed: # reseed for each __getitem__
|
| 470 |
+
self._rng = np.random.default_rng(seed=self.seed + idx)
|
| 471 |
+
elif not hasattr(self, '_rng'):
|
| 472 |
+
seed = torch.initial_seed() # this is different for each dataloader process
|
| 473 |
+
self._rng = np.random.default_rng(seed=seed)
|
| 474 |
+
|
| 475 |
+
# over-loaded code
|
| 476 |
+
resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
|
| 477 |
+
flag = False
|
| 478 |
+
i = 0
|
| 479 |
+
# views = self._get_views(idx, resolution, self._rng)
|
| 480 |
+
while flag == False and i < 1000:
|
| 481 |
+
try:
|
| 482 |
+
views = self._get_views(idx, resolution, self._rng)
|
| 483 |
+
flag = True
|
| 484 |
+
except:
|
| 485 |
+
flag = False
|
| 486 |
+
i += 1
|
| 487 |
+
|
| 488 |
+
# assert len(views) == self.num_image + self.gt_num_image
|
| 489 |
+
if self.only_pose == True:
|
| 490 |
+
# check data-types
|
| 491 |
+
for view in views:
|
| 492 |
+
# transpose to make sure all views are the same size
|
| 493 |
+
# this allows to check whether the RNG is is the same state each time
|
| 494 |
+
view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
|
| 495 |
+
return views
|
| 496 |
+
else:
|
| 497 |
+
for v, view in enumerate(views):
|
| 498 |
+
assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
|
| 499 |
+
view['idx'] = (idx, ar_idx, v)
|
| 500 |
+
# encode the image
|
| 501 |
+
width, height = view['img'].size
|
| 502 |
+
view['true_shape'] = np.int32((height, width))
|
| 503 |
+
view['img'] = self.transform(view['img'])
|
| 504 |
+
view['img_org' ] = self.transform_org(view['img_org'])
|
| 505 |
+
if 'depth_anything' not in view:
|
| 506 |
+
view['depth_anything'] = np.zeros_like(view['depthmap'])
|
| 507 |
+
# if view['img_org'].shape[1] != 224:
|
| 508 |
+
# print(view['img_org' ].shape)
|
| 509 |
+
# print(view['img'].shape)
|
| 510 |
+
assert 'camera_intrinsics' in view
|
| 511 |
+
if 'camera_pose' not in view:
|
| 512 |
+
view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
|
| 513 |
+
else:
|
| 514 |
+
assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
|
| 515 |
+
assert 'pts3d' not in view
|
| 516 |
+
assert 'valid_mask' not in view
|
| 517 |
+
assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
|
| 518 |
+
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
|
| 519 |
+
|
| 520 |
+
view['pts3d'] = pts3d
|
| 521 |
+
view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
|
| 522 |
+
# print(view['pts3d'].shape)
|
| 523 |
+
# print(view['valid_mask'].shape)
|
| 524 |
+
|
| 525 |
+
# check all datatypes
|
| 526 |
+
for key, val in view.items():
|
| 527 |
+
res, err_msg = is_good_type(key, val)
|
| 528 |
+
assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
|
| 529 |
+
K = view['camera_intrinsics']
|
| 530 |
+
|
| 531 |
+
# if self.aug_swap:
|
| 532 |
+
# self._swap_view_aug(views)
|
| 533 |
+
|
| 534 |
+
if self.aug_monocular:
|
| 535 |
+
if self._rng.random() < self.aug_monocular:
|
| 536 |
+
random_idxs = random.choices(list(range(len(views)-1)), k = self.num_image + self.gt_num_image-1)
|
| 537 |
+
views_copy = [views[-1]] + [copy.deepcopy(views[random_idxs[i]]) for i in range(len(views)-1)]
|
| 538 |
+
views = views_copy
|
| 539 |
+
|
| 540 |
+
# if self.aug_rot90 is False:
|
| 541 |
+
# pass
|
| 542 |
+
# elif self.aug_rot90 == 'same':
|
| 543 |
+
# rotate_90(views, k=self._rng.choice(4))
|
| 544 |
+
# elif self.aug_rot90 == 'diff':
|
| 545 |
+
# views_list = []
|
| 546 |
+
# for view in views:
|
| 547 |
+
# views_list += rotate_90([view], k=self._rng.choice(4))
|
| 548 |
+
# views = views_list
|
| 549 |
+
# else:
|
| 550 |
+
# raise ValueError(f'Bad value for {self.aug_rot90=}')
|
| 551 |
+
if self.rendering==False:
|
| 552 |
+
self._rng.shuffle(views)
|
| 553 |
+
|
| 554 |
+
if self.caculate_mask:
|
| 555 |
+
for view1 in views[self.num_image:]:
|
| 556 |
+
render_mask = []
|
| 557 |
+
start = True
|
| 558 |
+
# images = []
|
| 559 |
+
height, width = view1['true_shape']
|
| 560 |
+
for view2 in views[:self.num_image]:
|
| 561 |
+
shape1, corres1_to_2 = reproject_view(view1['pts3d'], view2)
|
| 562 |
+
shape2, corres2_to_1 = reproject_view(view2['pts3d'], view1)
|
| 563 |
+
# compute reciprocal correspondences:
|
| 564 |
+
# pos1 == valid pixels (correspondences) in image1
|
| 565 |
+
# corres1_to_2 = unravel_xy(corres1_to_2, shape2)
|
| 566 |
+
# corres2_to_1 = unravel_xy(corres2_to_1, shape1)
|
| 567 |
+
is_reciprocal1, pos1, pos2 = reciprocal_1d(corres1_to_2, corres2_to_1, shape1, shape2, ret_recip=True)
|
| 568 |
+
render_mask.append(is_reciprocal1.reshape(shape1))
|
| 569 |
+
# is_reciprocal1 = is_reciprocal1.reshape(shape1)
|
| 570 |
+
# plt.subplot(1, 3, 1)
|
| 571 |
+
# plt.imshow(is_reciprocal1)
|
| 572 |
+
# plt.subplot(1, 3, 2)
|
| 573 |
+
# plt.imshow(view1['img'].permute(1, 2, 0) / 2 + 0.5)
|
| 574 |
+
# plt.subplot(1, 3, 3)
|
| 575 |
+
# plt.imshow(view2['img'].permute(1, 2, 0) / 2 + 0.5)
|
| 576 |
+
# plt.savefig('/data0/zsz/mast3recon/test/est.png')
|
| 577 |
+
# import ipdb; ipdb.set_trace()
|
| 578 |
+
# images.append(view2['img'])
|
| 579 |
+
if start:
|
| 580 |
+
view2['render_mask'] = np.ones((view2['img'].shape[1], view2['img'].shape[2]), dtype=np.uint8) > 0.1
|
| 581 |
+
start = False
|
| 582 |
+
render_mask = np.stack(render_mask, axis=0).sum(0)
|
| 583 |
+
render_mask = cv2.dilate(render_mask/16, self.kernel)
|
| 584 |
+
view1['render_mask'] = render_mask > 1e-5
|
| 585 |
+
# images = torch.concatenate(images, dim=2)
|
| 586 |
+
# import matplotlib.pyplot as plt
|
| 587 |
+
# plt.subplot(3, 4, 1)
|
| 588 |
+
# plt.imshow(render_mask)
|
| 589 |
+
# plt.subplot(3, 4, 2)
|
| 590 |
+
# plt.imshow(view1['img'].permute(1, 2, 0) / 2 + 0.5)
|
| 591 |
+
# for i, image in enumerate(images):
|
| 592 |
+
# plt.subplot(3, 4, 3+i)
|
| 593 |
+
# plt.imshow(image.permute(1, 2, 0) / 2 + 0.5)
|
| 594 |
+
# plt.savefig('/data0/zsz/mast3recon/test/est.png')
|
| 595 |
+
# import ipdb; ipdb.set_trace()
|
| 596 |
+
# if view1['render_mask'].shape != (height, width):
|
| 597 |
+
# import ipdb; ipdb.set_trace()
|
| 598 |
+
assert view1['render_mask'].shape == (height, width)
|
| 599 |
+
else:
|
| 600 |
+
for view in views:
|
| 601 |
+
view['render_mask'] = np.ones((view['img'].shape[1], view['img'].shape[2]), dtype=np.uint8) > 0.1
|
| 602 |
+
|
| 603 |
+
for view in views:
|
| 604 |
+
fxfycxcy = view['fxfycxcy'].copy()
|
| 605 |
+
H, W = view['img'].shape[1:]
|
| 606 |
+
fxfycxcy[0] = fxfycxcy[0] * W
|
| 607 |
+
fxfycxcy[1] = fxfycxcy[1] * H
|
| 608 |
+
fxfycxcy[2] = fxfycxcy[2] * W
|
| 609 |
+
fxfycxcy[3] = fxfycxcy[3] * H
|
| 610 |
+
view['fxfycxcy_unorm'] = fxfycxcy
|
| 611 |
+
|
| 612 |
+
# last thing done!
|
| 613 |
+
for view in views:
|
| 614 |
+
# transpose to make sure all views are the same size
|
| 615 |
+
transpose_to_landscape(view)
|
| 616 |
+
if 'sky_mask' in view:
|
| 617 |
+
view.pop('sky_mask')
|
| 618 |
+
# this allows to check whether the RNG is is the same state each time
|
| 619 |
+
view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
|
| 620 |
+
return views
|
| 621 |
+
|
| 622 |
+
def _set_resolutions(self, resolutions):
|
| 623 |
+
assert resolutions is not None, 'undefined resolution'
|
| 624 |
+
|
| 625 |
+
if not isinstance(resolutions, list):
|
| 626 |
+
resolutions = [resolutions]
|
| 627 |
+
|
| 628 |
+
self._resolutions = []
|
| 629 |
+
for resolution in resolutions:
|
| 630 |
+
if isinstance(resolution, int):
|
| 631 |
+
width = height = resolution
|
| 632 |
+
else:
|
| 633 |
+
width, height = resolution
|
| 634 |
+
assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'
|
| 635 |
+
assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'
|
| 636 |
+
assert width >= height
|
| 637 |
+
self._resolutions.append((width, height))
|
| 638 |
+
|
| 639 |
+
def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None, depth_anything=None):
|
| 640 |
+
""" This function:
|
| 641 |
+
- first downsizes the image with LANCZOS inteprolation,
|
| 642 |
+
which is better than bilinear interpolation in
|
| 643 |
+
"""
|
| 644 |
+
if not isinstance(image, PIL.Image.Image):
|
| 645 |
+
image = PIL.Image.fromarray(image)
|
| 646 |
+
|
| 647 |
+
# transpose the resolution if necessary
|
| 648 |
+
W, H = image.size # new size
|
| 649 |
+
assert resolution[0] >= resolution[1]
|
| 650 |
+
if H > 1.1 * W:
|
| 651 |
+
# image is portrait mode
|
| 652 |
+
resolution = resolution[::-1]
|
| 653 |
+
elif 0.7 < H / W < 1.3 and resolution[0] != resolution[1] and self.aug_portrait_or_landscape:
|
| 654 |
+
# image is square, so we chose (portrait, landscape) randomly
|
| 655 |
+
if rng.integers(2):
|
| 656 |
+
resolution = resolution[::-1]
|
| 657 |
+
# resolution = resolution[::-1]
|
| 658 |
+
# high-quality Lanczos down-scaling
|
| 659 |
+
target_resolution = np.array(resolution)
|
| 660 |
+
if depth_anything is not None:
|
| 661 |
+
image, depthmap, intrinsics, depth_anything = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution, depth_anything=depth_anything)
|
| 662 |
+
else:
|
| 663 |
+
image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
|
| 664 |
+
|
| 665 |
+
# actual cropping (if necessary) with bilinear interpolation
|
| 666 |
+
offset_factor = 0.5
|
| 667 |
+
intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
|
| 668 |
+
crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
|
| 669 |
+
if depth_anything is not None:
|
| 670 |
+
image, depthmap, intrinsics2, depth_anything = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox, depth_anything=depth_anything)
|
| 671 |
+
return image, depthmap, intrinsics2, depth_anything
|
| 672 |
+
else:
|
| 673 |
+
image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
|
| 674 |
+
return image, depthmap, intrinsics2
|
| 675 |
+
|
| 676 |
+
def _crop_resize_if_necessary_test(self, image, depthmap, intrinsics, resolution, rng=None, info=None, depth_anything=None):
|
| 677 |
+
""" This function:
|
| 678 |
+
- first downsizes the image with LANCZOS inteprolation,
|
| 679 |
+
which is better than bilinear interpolation in
|
| 680 |
+
"""
|
| 681 |
+
if not isinstance(image, PIL.Image.Image):
|
| 682 |
+
image = PIL.Image.fromarray(image)
|
| 683 |
+
|
| 684 |
+
# transpose the resolution if necessary
|
| 685 |
+
W, H = image.size # new size
|
| 686 |
+
assert resolution[0] >= resolution[1]
|
| 687 |
+
if H > 1.1 * W:
|
| 688 |
+
# image is portrait mode
|
| 689 |
+
resolution = resolution[::-1]
|
| 690 |
+
|
| 691 |
+
# resolution = resolution[::-1]
|
| 692 |
+
# high-quality Lanczos down-scaling
|
| 693 |
+
target_resolution = np.array(resolution)
|
| 694 |
+
if depth_anything is not None:
|
| 695 |
+
image, depthmap, intrinsics, depth_anything = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution, depth_anything=depth_anything)
|
| 696 |
+
else:
|
| 697 |
+
image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
|
| 698 |
+
|
| 699 |
+
# actual cropping (if necessary) with bilinear interpolation
|
| 700 |
+
offset_factor = 0.5
|
| 701 |
+
intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
|
| 702 |
+
crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
|
| 703 |
+
if depth_anything is not None:
|
| 704 |
+
image, depthmap, intrinsics2, depth_anything = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox, depth_anything=depth_anything)
|
| 705 |
+
else:
|
| 706 |
+
image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
|
| 707 |
+
|
| 708 |
+
return image, depthmap, intrinsics2
|
| 709 |
+
|
| 710 |
+
def is_good_type(key, v):
|
| 711 |
+
""" returns (is_good, err_msg)
|
| 712 |
+
"""
|
| 713 |
+
if isinstance(v, (str, int, tuple)):
|
| 714 |
+
return True, None
|
| 715 |
+
if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
|
| 716 |
+
return False, f"bad {v.dtype=}"
|
| 717 |
+
return True, None
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
def view_name(view, batch_index=None):
|
| 721 |
+
def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x
|
| 722 |
+
db = sel(view['dataset'])
|
| 723 |
+
label = sel(view['label'])
|
| 724 |
+
instance = sel(view['instance'])
|
| 725 |
+
return f"{db}/{label}/{instance}"
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def transpose_to_landscape(view):
|
| 729 |
+
height, width = view['true_shape']
|
| 730 |
+
|
| 731 |
+
if width < height:
|
| 732 |
+
# rectify portrait to landscape
|
| 733 |
+
assert view['img'].shape == (3, height, width)
|
| 734 |
+
view['img'] = view['img'].swapaxes(1, 2)
|
| 735 |
+
# try:
|
| 736 |
+
if 'render_mask' in view:
|
| 737 |
+
assert view['render_mask'].shape == (height, width)
|
| 738 |
+
# except:
|
| 739 |
+
# import ipdb; ipdb.set_trace()
|
| 740 |
+
view['render_mask'] = view['render_mask'].swapaxes(0, 1)
|
| 741 |
+
|
| 742 |
+
assert view['img_org'].shape == (3, height, width)
|
| 743 |
+
view['img_org'] = view['img_org'].swapaxes(1, 2)
|
| 744 |
+
|
| 745 |
+
assert view['valid_mask'].shape == (height, width)
|
| 746 |
+
view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)
|
| 747 |
+
|
| 748 |
+
assert view['depthmap'].shape == (height, width)
|
| 749 |
+
view['depthmap'] = view['depthmap'].swapaxes(0, 1)
|
| 750 |
+
|
| 751 |
+
assert view['pts3d'].shape == (height, width, 3)
|
| 752 |
+
view['pts3d'] = view['pts3d'].swapaxes(0, 1)
|
| 753 |
+
|
| 754 |
+
assert view['depth_anything'].shape == (height, width)
|
| 755 |
+
view['depth_anything'] = view['depth_anything'].swapaxes(0, 1)
|
| 756 |
+
|
| 757 |
+
# transpose x and y pixels
|
| 758 |
+
view['camera_intrinsics'] = view['camera_intrinsics']#[[1, 0, 2]]
|
| 759 |
+
# view['fxfycxcy'] = view['fxfycxcy']
|
| 760 |
+
# print(view['img'].shape)
|
| 761 |
+
# print(view['img_org'].shape)
|
| 762 |
+
# print(view['valid_mask'].shape)
|
| 763 |
+
# print(view['depthmap'].shape)
|
| 764 |
+
# print(view['pts3d'].shape)
|
| 765 |
+
# print(view['camera_intrinsics'].shape)
|
| 766 |
+
# print(view['fxfycxcy'].shape)
|
| 767 |
+
# assert view['img'].shape == (3, height, width)
|
| 768 |
+
# assert view['img_org'].shape == (3, height, width)
|
| 769 |
+
# assert view['valid_mask'].shape == (height, width)
|
| 770 |
+
# assert view['depthmap'].shape == (height, width)
|
| 771 |
+
# assert view['pts3d'].shape == (height, width, 3)
|
| 772 |
+
# assert view['camera_intrinsics'].shape == (3, 3)
|
| 773 |
+
# assert view['fxfycxcy'].shape == (4,)
|
| 774 |
+
|
dust3r/dust3r/datasets/base/batched_sampler.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# Random sampling under a constraint
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class BatchedRandomSampler:
|
| 12 |
+
""" Random sampling under a constraint: each sample in the batch has the same feature,
|
| 13 |
+
which is chosen randomly from a known pool of 'features' for each batch.
|
| 14 |
+
|
| 15 |
+
For instance, the 'feature' could be the image aspect-ratio.
|
| 16 |
+
|
| 17 |
+
The index returned is a tuple (sample_idx, feat_idx).
|
| 18 |
+
This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True):
|
| 22 |
+
self.batch_size = batch_size
|
| 23 |
+
self.pool_size = pool_size
|
| 24 |
+
|
| 25 |
+
self.len_dataset = N = len(dataset)
|
| 26 |
+
self.total_size = round_by(N, batch_size*world_size) if drop_last else N
|
| 27 |
+
assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode'
|
| 28 |
+
|
| 29 |
+
# distributed sampler
|
| 30 |
+
self.world_size = world_size
|
| 31 |
+
self.rank = rank
|
| 32 |
+
self.epoch = None
|
| 33 |
+
|
| 34 |
+
def __len__(self):
|
| 35 |
+
return self.total_size // self.world_size
|
| 36 |
+
|
| 37 |
+
def set_epoch(self, epoch):
|
| 38 |
+
self.epoch = epoch
|
| 39 |
+
|
| 40 |
+
def __iter__(self):
|
| 41 |
+
# prepare RNG
|
| 42 |
+
if self.epoch is None:
|
| 43 |
+
assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used'
|
| 44 |
+
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
| 45 |
+
else:
|
| 46 |
+
seed = self.epoch + 777
|
| 47 |
+
rng = np.random.default_rng(seed=seed)
|
| 48 |
+
|
| 49 |
+
# random indices (will restart from 0 if not drop_last)
|
| 50 |
+
sample_idxs = np.arange(self.total_size)
|
| 51 |
+
rng.shuffle(sample_idxs)
|
| 52 |
+
|
| 53 |
+
# random feat_idxs (same across each batch)
|
| 54 |
+
n_batches = (self.total_size+self.batch_size-1) // self.batch_size
|
| 55 |
+
feat_idxs = rng.integers(self.pool_size, size=n_batches)
|
| 56 |
+
feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
|
| 57 |
+
feat_idxs = feat_idxs.ravel()[:self.total_size]
|
| 58 |
+
|
| 59 |
+
# put them together
|
| 60 |
+
idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2)
|
| 61 |
+
|
| 62 |
+
# Distributed sampler: we select a subset of batches
|
| 63 |
+
# make sure the slice for each node is aligned with batch_size
|
| 64 |
+
size_per_proc = self.batch_size * ((self.total_size + self.world_size *
|
| 65 |
+
self.batch_size-1) // (self.world_size * self.batch_size))
|
| 66 |
+
idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc]
|
| 67 |
+
|
| 68 |
+
yield from (tuple(idx) for idx in idxs)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def round_by(total, multiple, up=False):
|
| 72 |
+
if up:
|
| 73 |
+
total = total + multiple-1
|
| 74 |
+
return (total//multiple) * multiple
|
dust3r/dust3r/datasets/base/easy_dataset.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 3 |
+
#
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
# A dataset base class that you can easily resize and combine.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
import numpy as np
|
| 8 |
+
from dust3r.datasets.base.batched_sampler import BatchedRandomSampler
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class EasyDataset:
|
| 12 |
+
""" a dataset that you can easily resize and combine.
|
| 13 |
+
Examples:
|
| 14 |
+
---------
|
| 15 |
+
2 * dataset ==> duplicate each element 2x
|
| 16 |
+
|
| 17 |
+
10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)
|
| 18 |
+
|
| 19 |
+
dataset1 + dataset2 ==> concatenate datasets
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __add__(self, other):
|
| 23 |
+
return CatDataset([self, other])
|
| 24 |
+
|
| 25 |
+
def __rmul__(self, factor):
|
| 26 |
+
return MulDataset(factor, self)
|
| 27 |
+
|
| 28 |
+
def __rmatmul__(self, factor):
|
| 29 |
+
return ResizedDataset(factor, self)
|
| 30 |
+
|
| 31 |
+
def set_epoch(self, epoch):
|
| 32 |
+
pass # nothing to do by default
|
| 33 |
+
|
| 34 |
+
def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True):
|
| 35 |
+
if not (shuffle):
|
| 36 |
+
raise NotImplementedError() # cannot deal yet
|
| 37 |
+
num_of_aspect_ratios = len(self._resolutions)
|
| 38 |
+
return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class MulDataset (EasyDataset):
|
| 42 |
+
""" Artifically augmenting the size of a dataset.
|
| 43 |
+
"""
|
| 44 |
+
multiplicator: int
|
| 45 |
+
|
| 46 |
+
def __init__(self, multiplicator, dataset):
|
| 47 |
+
assert isinstance(multiplicator, int) and multiplicator > 0
|
| 48 |
+
self.multiplicator = multiplicator
|
| 49 |
+
self.dataset = dataset
|
| 50 |
+
|
| 51 |
+
def __len__(self):
|
| 52 |
+
return self.multiplicator * len(self.dataset)
|
| 53 |
+
|
| 54 |
+
def __repr__(self):
|
| 55 |
+
return f'{self.multiplicator}*{repr(self.dataset)}'
|
| 56 |
+
|
| 57 |
+
def __getitem__(self, idx):
|
| 58 |
+
if isinstance(idx, tuple):
|
| 59 |
+
idx, other = idx
|
| 60 |
+
return self.dataset[idx // self.multiplicator, other]
|
| 61 |
+
else:
|
| 62 |
+
return self.dataset[idx // self.multiplicator]
|
| 63 |
+
|
| 64 |
+
@property
|
| 65 |
+
def _resolutions(self):
|
| 66 |
+
return self.dataset._resolutions
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ResizedDataset (EasyDataset):
|
| 70 |
+
""" Artifically changing the size of a dataset.
|
| 71 |
+
"""
|
| 72 |
+
new_size: int
|
| 73 |
+
|
| 74 |
+
def __init__(self, new_size, dataset):
|
| 75 |
+
assert isinstance(new_size, int) and new_size > 0
|
| 76 |
+
self.new_size = new_size
|
| 77 |
+
self.dataset = dataset
|
| 78 |
+
|
| 79 |
+
def __len__(self):
|
| 80 |
+
return self.new_size
|
| 81 |
+
|
| 82 |
+
def __repr__(self):
|
| 83 |
+
size_str = str(self.new_size)
|
| 84 |
+
for i in range((len(size_str)-1) // 3):
|
| 85 |
+
sep = -4*i-3
|
| 86 |
+
size_str = size_str[:sep] + '_' + size_str[sep:]
|
| 87 |
+
return f'{size_str} @ {repr(self.dataset)}'
|
| 88 |
+
|
| 89 |
+
def set_epoch(self, epoch):
|
| 90 |
+
# this random shuffle only depends on the epoch
|
| 91 |
+
rng = np.random.default_rng(seed=epoch+777)
|
| 92 |
+
|
| 93 |
+
# shuffle all indices
|
| 94 |
+
perm = rng.permutation(len(self.dataset))
|
| 95 |
+
|
| 96 |
+
# rotary extension until target size is met
|
| 97 |
+
shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset)))
|
| 98 |
+
self._idxs_mapping = shuffled_idxs[:self.new_size]
|
| 99 |
+
|
| 100 |
+
assert len(self._idxs_mapping) == self.new_size
|
| 101 |
+
|
| 102 |
+
def __getitem__(self, idx):
|
| 103 |
+
assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()'
|
| 104 |
+
if isinstance(idx, tuple):
|
| 105 |
+
idx, other = idx
|
| 106 |
+
return self.dataset[self._idxs_mapping[idx], other]
|
| 107 |
+
else:
|
| 108 |
+
return self.dataset[self._idxs_mapping[idx]]
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def _resolutions(self):
|
| 112 |
+
return self.dataset._resolutions
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class CatDataset (EasyDataset):
|
| 116 |
+
""" Concatenation of several datasets
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
def __init__(self, datasets):
|
| 120 |
+
for dataset in datasets:
|
| 121 |
+
assert isinstance(dataset, EasyDataset)
|
| 122 |
+
self.datasets = datasets
|
| 123 |
+
self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])
|
| 124 |
+
|
| 125 |
+
def __len__(self):
|
| 126 |
+
return self._cum_sizes[-1]
|
| 127 |
+
|
| 128 |
+
def __repr__(self):
|
| 129 |
+
# remove uselessly long transform
|
| 130 |
+
return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets)
|
| 131 |
+
|
| 132 |
+
def set_epoch(self, epoch):
|
| 133 |
+
for dataset in self.datasets:
|
| 134 |
+
dataset.set_epoch(epoch)
|
| 135 |
+
|
| 136 |
+
def __getitem__(self, idx):
|
| 137 |
+
other = None
|
| 138 |
+
if isinstance(idx, tuple):
|
| 139 |
+
idx, other = idx
|
| 140 |
+
|
| 141 |
+
if not (0 <= idx < len(self)):
|
| 142 |
+
raise IndexError()
|
| 143 |
+
|
| 144 |
+
db_idx = np.searchsorted(self._cum_sizes, idx, 'right')
|
| 145 |
+
dataset = self.datasets[db_idx]
|
| 146 |
+
new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)
|
| 147 |
+
|
| 148 |
+
if other is not None:
|
| 149 |
+
new_idx = (new_idx, other)
|
| 150 |
+
return dataset[new_idx]
|
| 151 |
+
|
| 152 |
+
@property
|
| 153 |
+
def _resolutions(self):
|
| 154 |
+
resolutions = self.datasets[0]._resolutions
|
| 155 |
+
for dataset in self.datasets[1:]:
|
| 156 |
+
assert tuple(dataset._resolutions) == tuple(resolutions)
|
| 157 |
+
return resolutions
|
dust3r/dust3r/datasets/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|