qizhangslam commited on
Commit
1acc7fb
·
verified ·
1 Parent(s): f9347d2

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/LICENSE +52 -0
  2. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/NOTICE +21 -0
  3. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/README.MD +124 -0
  4. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/croco-stereo-flow-demo.ipynb +191 -0
  5. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/__init__.py +0 -0
  6. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/crops/README.MD +104 -0
  7. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/crops/extract_crops_from_images.py +183 -0
  8. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/README.MD +76 -0
  9. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/__init__.py +0 -0
  10. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/generate_from_metadata.py +125 -0
  11. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/generate_from_metadata_files.py +36 -0
  12. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/generate_multiview_images.py +231 -0
  13. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py +501 -0
  14. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/pack_metadata_files.py +80 -0
  15. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/paths.py +179 -0
  16. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/pairs_dataset.py +162 -0
  17. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/transforms.py +135 -0
  18. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/interactive_demo.ipynb +271 -0
  19. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/blocks.py +385 -0
  20. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/criterion.py +38 -0
  21. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/croco.py +330 -0
  22. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/croco_downstream.py +141 -0
  23. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/curope/__init__.py +4 -0
  24. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/curope/curope.cpp +69 -0
  25. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/curope/curope2d.py +40 -0
  26. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/curope/kernels.cu +108 -0
  27. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/curope/setup.py +34 -0
  28. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/dpt_block.py +513 -0
  29. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/head_downstream.py +83 -0
  30. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/masking.py +26 -0
  31. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/pos_embed.py +179 -0
  32. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/pretrain.py +391 -0
  33. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/README.MD +318 -0
  34. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/augmentor.py +396 -0
  35. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/criterion.py +351 -0
  36. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/datasets_flow.py +936 -0
  37. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/datasets_stereo.py +991 -0
  38. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/download_model.sh +12 -0
  39. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/engine.py +367 -0
  40. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/test.py +303 -0
  41. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/train.py +455 -0
  42. outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/utils/misc.py +635 -0
  43. outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/__init__.py +0 -0
  44. outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/alignment.py +416 -0
  45. outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/blocks.py +531 -0
  46. outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/bedlam.py +297 -0
  47. outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/eden.py +94 -0
  48. outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/irs.py +86 -0
  49. outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/kitti.py +323 -0
  50. outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/mp3d.py +132 -0
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/LICENSE ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2
+
3
+ A summary of the CC BY-NC-SA 4.0 license is located here:
4
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
5
+
6
+ The CC BY-NC-SA 4.0 license is located here:
7
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
8
+
9
+
10
+ SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py
11
+
12
+ ***************************
13
+
14
+ NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py
15
+
16
+ This software is being redistributed in a modifiled form. The original form is available here:
17
+
18
+ https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19
+
20
+ This software in this file incorporates parts of the following software available here:
21
+
22
+ Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py
23
+ available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE
24
+
25
+ MoCo v3: https://github.com/facebookresearch/moco-v3
26
+ available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE
27
+
28
+ DeiT: https://github.com/facebookresearch/deit
29
+ available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE
30
+
31
+
32
+ ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:
33
+
34
+ https://github.com/facebookresearch/mae/blob/main/LICENSE
35
+
36
+ Attribution-NonCommercial 4.0 International
37
+
38
+ ***************************
39
+
40
+ NOTICE WITH RESPECT TO THE FILE: models/blocks.py
41
+
42
+ This software is being redistributed in a modifiled form. The original form is available here:
43
+
44
+ https://github.com/rwightman/pytorch-image-models
45
+
46
+ ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:
47
+
48
+ https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
49
+
50
+ Apache License
51
+ Version 2.0, January 2004
52
+ http://www.apache.org/licenses/
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/NOTICE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CroCo
2
+ Copyright 2022-present NAVER Corp.
3
+
4
+ This project contains subcomponents with separate copyright notices and license terms.
5
+ Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
6
+
7
+ ====
8
+
9
+ facebookresearch/mae
10
+ https://github.com/facebookresearch/mae
11
+
12
+ Attribution-NonCommercial 4.0 International
13
+
14
+ ====
15
+
16
+ rwightman/pytorch-image-models
17
+ https://github.com/rwightman/pytorch-image-models
18
+
19
+ Apache License
20
+ Version 2.0, January 2004
21
+ http://www.apache.org/licenses/
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/README.MD ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow
2
+
3
+ [[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)]
4
+
5
+ This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), refered to as CroCo v2:
6
+
7
+ ![image](assets/arch.jpg)
8
+
9
+ ```bibtex
10
+ @inproceedings{croco,
11
+ title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}},
12
+ author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}},
13
+ booktitle={{NeurIPS}},
14
+ year={2022}
15
+ }
16
+
17
+ @inproceedings{croco_v2,
18
+ title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}},
19
+ author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me},
20
+ booktitle={ICCV},
21
+ year={2023}
22
+ }
23
+ ```
24
+
25
+ ## License
26
+
27
+ The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information.
28
+ Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License.
29
+ Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license.
30
+
31
+ ## Preparation
32
+
33
+ 1. Install dependencies on a machine with a NVidia GPU using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can ignore the line installing it and use a more recent python version.
34
+
35
+ ```bash
36
+ conda create -n croco python=3.7 cmake=3.14.0
37
+ conda activate croco
38
+ conda install habitat-sim headless -c conda-forge -c aihabitat
39
+ conda install pytorch torchvision -c pytorch
40
+ conda install notebook ipykernel matplotlib
41
+ conda install ipywidgets widgetsnbextension
42
+ conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation
43
+
44
+ ```
45
+
46
+ 2. Compile cuda kernels for RoPE
47
+
48
+ CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels.
49
+ ```bash
50
+ cd models/curope/
51
+ python setup.py build_ext --inplace
52
+ cd ../../
53
+ ```
54
+
55
+ This can be a bit long as we compile for all cuda architectures, feel free to update L9 of `models/curope/setup.py` to compile for specific architectures only.
56
+ You might also need to set the environment `CUDA_HOME` in case you use a custom cuda installation.
57
+
58
+ In case you cannot provide, we also provide a slow pytorch version, which will be automatically loaded.
59
+
60
+ 3. Download pre-trained model
61
+
62
+ We provide several pre-trained models:
63
+
64
+ | modelname | pre-training data | pos. embed. | Encoder | Decoder |
65
+ |------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------|
66
+ | [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small |
67
+ | [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small |
68
+ | [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base |
69
+ | [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base |
70
+
71
+ To download a specific model, i.e., the first one (`CroCo.pth`)
72
+ ```bash
73
+ mkdir -p pretrained_models/
74
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/
75
+ ```
76
+
77
+ ## Reconstruction example
78
+
79
+ Simply run after downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or update the corresponding line in `demo.py`)
80
+ ```bash
81
+ python demo.py
82
+ ```
83
+
84
+ ## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator
85
+
86
+ First download the test scene from Habitat:
87
+ ```bash
88
+ python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/
89
+ ```
90
+
91
+ Then, run the Notebook demo `interactive_demo.ipynb`.
92
+
93
+ In this demo, you should be able to sample a random reference viewpoint from an [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change viewpoint and select a masked target view to reconstruct using CroCo.
94
+ ![croco_interactive_demo](https://user-images.githubusercontent.com/1822210/200516576-7937bc6a-55f8-49ed-8618-3ddf89433ea4.jpg)
95
+
96
+ ## Pre-training
97
+
98
+ ### CroCo
99
+
100
+ To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command:
101
+ ```
102
+ torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/
103
+ ```
104
+
105
+ Our CroCo pre-training was launched on a single server with 4 GPUs.
106
+ It should take around 10 days with A100 or 15 days with V100 to do the 400 pre-training epochs, but decent performances are obtained earlier in training.
107
+ Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experimented if it is valid in our case.
108
+ The first run can take a few minutes to start, to parse all available pre-training pairs.
109
+
110
+ ### CroCo v2
111
+
112
+ For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD).
113
+ Then, run the following command for the largest model (ViT-L encoder, Base decoder):
114
+ ```
115
+ torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/
116
+ ```
117
+
118
+ Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per gpu in all cases.
119
+ The largest model should take around 12 days on A100.
120
+ Note that, while the code contains the same scaling rule of the learning rate as MAE when changing the effective batch size, we did not experimented if it is valid in our case.
121
+
122
+ ## Stereo matching and Optical flow downstream tasks
123
+
124
+ For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD).
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/croco-stereo-flow-demo.ipynb ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9bca0f41",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Simple inference example with CroCo-Stereo or CroCo-Flow"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "80653ef7",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
19
+ "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "id": "4f033862",
25
+ "metadata": {},
26
+ "source": [
27
+ "First download the model(s) of your choice by running\n",
28
+ "```\n",
29
+ "bash stereoflow/download_model.sh crocostereo.pth\n",
30
+ "bash stereoflow/download_model.sh crocoflow.pth\n",
31
+ "```"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "id": "1fb2e392",
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "import torch\n",
42
+ "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
43
+ "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
44
+ "import matplotlib.pylab as plt"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "id": "e0e25d77",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "from stereoflow.test import _load_model_and_criterion\n",
55
+ "from stereoflow.engine import tiled_pred\n",
56
+ "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n",
57
+ "from stereoflow.datasets_flow import flowToColor\n",
58
+ "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "markdown",
63
+ "id": "86a921f5",
64
+ "metadata": {},
65
+ "source": [
66
+ "### CroCo-Stereo example"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "id": "64e483cb",
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "image1 = np.asarray(Image.open('<path_to_left_image>'))\n",
77
+ "image2 = np.asarray(Image.open('<path_to_right_image>'))"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "id": "f0d04303",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "47dc14b5",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
98
+ "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
99
+ "with torch.inference_mode():\n",
100
+ " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
101
+ "pred = pred.squeeze(0).squeeze(0).cpu().numpy()"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "id": "583b9f16",
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "plt.imshow(vis_disparity(pred))\n",
112
+ "plt.axis('off')"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "markdown",
117
+ "id": "d2df5d70",
118
+ "metadata": {},
119
+ "source": [
120
+ "### CroCo-Flow example"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "id": "9ee257a7",
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "image1 = np.asarray(Image.open('<path_to_first_image>'))\n",
131
+ "image2 = np.asarray(Image.open('<path_to_second_image>'))"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "id": "d5edccf0",
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "id": "b19692c3",
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": [
151
+ "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
152
+ "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
153
+ "with torch.inference_mode():\n",
154
+ " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
155
+ "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()"
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "id": "26f79db3",
162
+ "metadata": {},
163
+ "outputs": [],
164
+ "source": [
165
+ "plt.imshow(flowToColor(pred))\n",
166
+ "plt.axis('off')"
167
+ ]
168
+ }
169
+ ],
170
+ "metadata": {
171
+ "kernelspec": {
172
+ "display_name": "Python 3 (ipykernel)",
173
+ "language": "python",
174
+ "name": "python3"
175
+ },
176
+ "language_info": {
177
+ "codemirror_mode": {
178
+ "name": "ipython",
179
+ "version": 3
180
+ },
181
+ "file_extension": ".py",
182
+ "mimetype": "text/x-python",
183
+ "name": "python",
184
+ "nbconvert_exporter": "python",
185
+ "pygments_lexer": "ipython3",
186
+ "version": "3.9.7"
187
+ }
188
+ },
189
+ "nbformat": 4,
190
+ "nbformat_minor": 5
191
+ }
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/__init__.py ADDED
File without changes
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/crops/README.MD ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Generation of crops from the real datasets
2
+
3
+ The instructions below allow to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL.
4
+
5
+ ### Download the metadata of the crops to generate
6
+
7
+ First, download the metadata and put them in `./data/`:
8
+ ```
9
+ mkdir -p data
10
+ cd data/
11
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip
12
+ unzip crop_metadata.zip
13
+ rm crop_metadata.zip
14
+ cd ..
15
+ ```
16
+
17
+ ### Prepare the original datasets
18
+
19
+ Second, download the original datasets in `./data/original_datasets/`.
20
+ ```
21
+ mkdir -p data/original_datasets
22
+ ```
23
+
24
+ ##### ARKitScenes
25
+
26
+ Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`.
27
+ The resulting file structure should be like:
28
+ ```
29
+ ./data/original_datasets/ARKitScenes/
30
+ └───Training
31
+ └───40753679
32
+ │ │ ultrawide
33
+ │ │ ...
34
+ └───40753686
35
+
36
+ ...
37
+ ```
38
+
39
+ ##### MegaDepth
40
+
41
+ Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`.
42
+ The resulting file structure should be like:
43
+
44
+ ```
45
+ ./data/original_datasets/MegaDepth/
46
+ └───0000
47
+ │ └───images
48
+ │ │ │ 1000557903_87fa96b8a4_o.jpg
49
+ │ │ └ ...
50
+ │ └─── ...
51
+ └───0001
52
+ │ │
53
+ │ └ ...
54
+ └─── ...
55
+ ```
56
+
57
+ ##### 3DStreetView
58
+
59
+ Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`.
60
+ The resulting file structure should be like:
61
+
62
+ ```
63
+ ./data/original_datasets/3DStreetView/
64
+ └───dataset_aligned
65
+ │ └───0002
66
+ │ │ │ 0000002_0000001_0000002_0000001.jpg
67
+ │ │ └ ...
68
+ │ └─── ...
69
+ └───dataset_unaligned
70
+ │ └───0003
71
+ │ │ │ 0000003_0000001_0000002_0000001.jpg
72
+ │ │ └ ...
73
+ │ └─── ...
74
+ ```
75
+
76
+ ##### IndoorVL
77
+
78
+ Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture).
79
+
80
+ ```
81
+ pip install kapture
82
+ mkdir -p ./data/original_datasets/IndoorVL
83
+ cd ./data/original_datasets/IndoorVL
84
+ kapture_download_dataset.py update
85
+ kapture_download_dataset.py install "HyundaiDepartmentStore_*"
86
+ kapture_download_dataset.py install "GangnamStation_*"
87
+ cd -
88
+ ```
89
+
90
+ ### Extract the crops
91
+
92
+ Now, extract the crops for each of the dataset:
93
+ ```
94
+ for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL;
95
+ do
96
+ python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500;
97
+ done
98
+ ```
99
+
100
+ ##### Note for IndoorVL
101
+
102
+ Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper.
103
+ To account for it in terms of number of pre-training iterations, the pre-training command in this repository uses 125 training epochs including 12 warm-up epochs and learning rate cosine schedule of 250, instead of 100, 10 and 200 respectively.
104
+ The impact on the performance is negligible.
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/crops/extract_crops_from_images.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Extracting crops for pre-training
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import argparse
10
+ from tqdm import tqdm
11
+ from PIL import Image
12
+ import functools
13
+ from multiprocessing import Pool
14
+ import math
15
+
16
+
17
+ def arg_parser():
18
+ parser = argparse.ArgumentParser(
19
+ "Generate cropped image pairs from image crop list"
20
+ )
21
+
22
+ parser.add_argument("--crops", type=str, required=True, help="crop file")
23
+ parser.add_argument("--root-dir", type=str, required=True, help="root directory")
24
+ parser.add_argument(
25
+ "--output-dir", type=str, required=True, help="output directory"
26
+ )
27
+ parser.add_argument("--imsize", type=int, default=256, help="size of the crops")
28
+ parser.add_argument(
29
+ "--nthread", type=int, required=True, help="number of simultaneous threads"
30
+ )
31
+ parser.add_argument(
32
+ "--max-subdir-levels",
33
+ type=int,
34
+ default=5,
35
+ help="maximum number of subdirectories",
36
+ )
37
+ parser.add_argument(
38
+ "--ideal-number-pairs-in-dir",
39
+ type=int,
40
+ default=500,
41
+ help="number of pairs stored in a dir",
42
+ )
43
+ return parser
44
+
45
+
46
+ def main(args):
47
+ listing_path = os.path.join(args.output_dir, "listing.txt")
48
+
49
+ print(f"Loading list of crops ... ({args.nthread} threads)")
50
+ crops, num_crops_to_generate = load_crop_file(args.crops)
51
+
52
+ print(f"Preparing jobs ({len(crops)} candidate image pairs)...")
53
+ num_levels = min(
54
+ math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)),
55
+ args.max_subdir_levels,
56
+ )
57
+ num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1 / num_levels))
58
+
59
+ jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
60
+ del crops
61
+
62
+ os.makedirs(args.output_dir, exist_ok=True)
63
+ mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map
64
+ call = functools.partial(save_image_crops, args)
65
+
66
+ print(f"Generating cropped images to {args.output_dir} ...")
67
+ with open(listing_path, "w") as listing:
68
+ listing.write("# pair_path\n")
69
+ for results in tqdm(mmap(call, jobs), total=len(jobs)):
70
+ for path in results:
71
+ listing.write(f"{path}\n")
72
+ print("Finished writing listing to", listing_path)
73
+
74
+
75
+ def load_crop_file(path):
76
+ data = open(path).read().splitlines()
77
+ pairs = []
78
+ num_crops_to_generate = 0
79
+ for line in tqdm(data):
80
+ if line.startswith("#"):
81
+ continue
82
+ line = line.split(", ")
83
+ if len(line) < 8:
84
+ img1, img2, rotation = line
85
+ pairs.append((img1, img2, int(rotation), []))
86
+ else:
87
+ l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
88
+ rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
89
+ pairs[-1][-1].append((rect1, rect2))
90
+ num_crops_to_generate += 1
91
+ return pairs, num_crops_to_generate
92
+
93
+
94
+ def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
95
+ jobs = []
96
+ powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))]
97
+
98
+ def get_path(idx):
99
+ idx_array = []
100
+ d = idx
101
+ for level in range(num_levels - 1):
102
+ idx_array.append(idx // powers[level])
103
+ idx = idx % powers[level]
104
+ idx_array.append(d)
105
+ return "/".join(map(lambda x: hex(x)[2:], idx_array))
106
+
107
+ idx = 0
108
+ for pair_data in tqdm(pairs):
109
+ img1, img2, rotation, crops = pair_data
110
+ if -60 <= rotation and rotation <= 60:
111
+ rotation = 0 # most likely not a true rotation
112
+ paths = [get_path(idx + k) for k in range(len(crops))]
113
+ idx += len(crops)
114
+ jobs.append(((img1, img2), rotation, crops, paths))
115
+ return jobs
116
+
117
+
118
+ def load_image(path):
119
+ try:
120
+ return Image.open(path).convert("RGB")
121
+ except Exception as e:
122
+ print("skipping", path, e)
123
+ raise OSError()
124
+
125
+
126
+ def save_image_crops(args, data):
127
+ # load images
128
+ img_pair, rot, crops, paths = data
129
+ try:
130
+ img1, img2 = [
131
+ load_image(os.path.join(args.root_dir, impath)) for impath in img_pair
132
+ ]
133
+ except OSError as e:
134
+ return []
135
+
136
+ def area(sz):
137
+ return sz[0] * sz[1]
138
+
139
+ tgt_size = (args.imsize, args.imsize)
140
+
141
+ def prepare_crop(img, rect, rot=0):
142
+ # actual crop
143
+ img = img.crop(rect)
144
+
145
+ # resize to desired size
146
+ interp = (
147
+ Image.Resampling.LANCZOS
148
+ if area(img.size) > 4 * area(tgt_size)
149
+ else Image.Resampling.BICUBIC
150
+ )
151
+ img = img.resize(tgt_size, resample=interp)
152
+
153
+ # rotate the image
154
+ rot90 = (round(rot / 90) % 4) * 90
155
+ if rot90 == 90:
156
+ img = img.transpose(Image.Transpose.ROTATE_90)
157
+ elif rot90 == 180:
158
+ img = img.transpose(Image.Transpose.ROTATE_180)
159
+ elif rot90 == 270:
160
+ img = img.transpose(Image.Transpose.ROTATE_270)
161
+ return img
162
+
163
+ results = []
164
+ for (rect1, rect2), path in zip(crops, paths):
165
+ crop1 = prepare_crop(img1, rect1)
166
+ crop2 = prepare_crop(img2, rect2, rot)
167
+
168
+ fullpath1 = os.path.join(args.output_dir, path + "_1.jpg")
169
+ fullpath2 = os.path.join(args.output_dir, path + "_2.jpg")
170
+ os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
171
+
172
+ assert not os.path.isfile(fullpath1), fullpath1
173
+ assert not os.path.isfile(fullpath2), fullpath2
174
+ crop1.save(fullpath1)
175
+ crop2.save(fullpath2)
176
+ results.append(path)
177
+
178
+ return results
179
+
180
+
181
+ if __name__ == "__main__":
182
+ args = arg_parser().parse_args()
183
+ main(args)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/README.MD ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Generation of synthetic image pairs using Habitat-Sim
2
+
3
+ These instructions allow to generate pre-training pairs from the Habitat simulator.
4
+ As we did not save metadata of the pairs used in the original paper, they are not strictly the same, but these data use the same setting and are equivalent.
5
+
6
+ ### Download Habitat-Sim scenes
7
+ Download Habitat-Sim scenes:
8
+ - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
9
+ - We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets.
10
+ - Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or update manually paths in `paths.py`.
11
+ ```
12
+ ./data/
13
+ └──habitat-sim-data/
14
+ └──scene_datasets/
15
+ ├──hm3d/
16
+ ├──gibson/
17
+ ├──habitat-test-scenes/
18
+ ├──replica_cad_baked_lighting/
19
+ ├──replica_cad/
20
+ ├──ReplicaDataset/
21
+ └──scannet/
22
+ ```
23
+
24
+ ### Image pairs generation
25
+ We provide metadata to generate reproducible images pairs for pretraining and validation.
26
+ Experiments described in the paper used similar data, but whose generation was not reproducible at the time.
27
+
28
+ Specifications:
29
+ - 256x256 resolution images, with 60 degrees field of view .
30
+ - Up to 1000 image pairs per scene.
31
+ - Number of scenes considered/number of images pairs per dataset:
32
+ - Scannet: 1097 scenes / 985 209 pairs
33
+ - HM3D:
34
+ - hm3d/train: 800 / 800k pairs
35
+ - hm3d/val: 100 scenes / 100k pairs
36
+ - hm3d/minival: 10 scenes / 10k pairs
37
+ - habitat-test-scenes: 3 scenes / 3k pairs
38
+ - replica_cad_baked_lighting: 13 scenes / 13k pairs
39
+
40
+ - Scenes from hm3d/val and hm3d/minival pairs were not used for the pre-training but kept for validation purposes.
41
+
42
+ Download metadata and extract it:
43
+ ```bash
44
+ mkdir -p data/habitat_release_metadata/
45
+ cd data/habitat_release_metadata/
46
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz
47
+ tar -xvf multiview_habitat_metadata.tar.gz
48
+ cd ../..
49
+ # Location of the metadata
50
+ METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata"
51
+ ```
52
+
53
+ Generate image pairs from metadata:
54
+ - The following command will print a list of commandlines to generate image pairs for each scene:
55
+ ```bash
56
+ # Target output directory
57
+ PAIRS_DATASET_DIR="./data/habitat_release/"
58
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR
59
+ ```
60
+ - One can launch multiple of such commands in parallel e.g. using GNU Parallel:
61
+ ```bash
62
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16
63
+ ```
64
+
65
+ ## Metadata generation
66
+
67
+ Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible:
68
+ ```bash
69
+ # Print commandlines to generate image pairs from the different scenes available.
70
+ PAIRS_DATASET_DIR=MY_CUSTOM_PATH
71
+ python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR
72
+
73
+ # Once a dataset is generated, pack metadata files for reproducibility.
74
+ METADATA_DIR=MY_CUSTON_PATH
75
+ python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR
76
+ ```
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/__init__.py ADDED
File without changes
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/generate_from_metadata.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Script to generate image pairs for a given scene reproducing poses provided in a metadata file.
6
+ """
7
+ import os
8
+ from datasets.habitat_sim.multiview_habitat_sim_generator import (
9
+ MultiviewHabitatSimGenerator,
10
+ )
11
+ from datasets.habitat_sim.paths import SCENES_DATASET
12
+ import argparse
13
+ import quaternion
14
+ import PIL.Image
15
+ import cv2
16
+ import json
17
+ from tqdm import tqdm
18
+
19
+
20
+ def generate_multiview_images_from_metadata(
21
+ metadata_filename,
22
+ output_dir,
23
+ overload_params=dict(),
24
+ scene_datasets_paths=None,
25
+ exist_ok=False,
26
+ ):
27
+ """
28
+ Generate images from a metadata file for reproducibility purposes.
29
+ """
30
+ # Reorder paths by decreasing label length, to avoid collisions when testing if a string by such label
31
+ if scene_datasets_paths is not None:
32
+ scene_datasets_paths = dict(
33
+ sorted(scene_datasets_paths.items(), key=lambda x: len(x[0]), reverse=True)
34
+ )
35
+
36
+ with open(metadata_filename, "r") as f:
37
+ input_metadata = json.load(f)
38
+ metadata = dict()
39
+ for key, value in input_metadata.items():
40
+ # Optionally replace some paths
41
+ if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
42
+ if scene_datasets_paths is not None:
43
+ for dataset_label, dataset_path in scene_datasets_paths.items():
44
+ if value.startswith(dataset_label):
45
+ value = os.path.normpath(
46
+ os.path.join(
47
+ dataset_path, os.path.relpath(value, dataset_label)
48
+ )
49
+ )
50
+ break
51
+ metadata[key] = value
52
+
53
+ # Overload some parameters
54
+ for key, value in overload_params.items():
55
+ metadata[key] = value
56
+
57
+ generation_entries = dict(
58
+ [
59
+ (key, value)
60
+ for key, value in metadata.items()
61
+ if not (key in ("multiviews", "output_dir", "generate_depth"))
62
+ ]
63
+ )
64
+ generate_depth = metadata["generate_depth"]
65
+
66
+ os.makedirs(output_dir, exist_ok=exist_ok)
67
+
68
+ generator = MultiviewHabitatSimGenerator(**generation_entries)
69
+
70
+ # Generate views
71
+ for idx_label, data in tqdm(metadata["multiviews"].items()):
72
+ positions = data["positions"]
73
+ orientations = data["orientations"]
74
+ n = len(positions)
75
+ for oidx in range(n):
76
+ observation = generator.render_viewpoint(
77
+ positions[oidx], quaternion.from_float_array(orientations[oidx])
78
+ )
79
+ observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
80
+ # Color image saved using PIL
81
+ img = PIL.Image.fromarray(observation["color"][:, :, :3])
82
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
83
+ img.save(filename)
84
+ if generate_depth:
85
+ # Depth image as EXR file
86
+ filename = os.path.join(
87
+ output_dir, f"{idx_label}_{observation_label}_depth.exr"
88
+ )
89
+ cv2.imwrite(
90
+ filename,
91
+ observation["depth"],
92
+ [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF],
93
+ )
94
+ # Camera parameters
95
+ camera_params = dict(
96
+ [
97
+ (key, observation[key].tolist())
98
+ for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")
99
+ ]
100
+ )
101
+ filename = os.path.join(
102
+ output_dir, f"{idx_label}_{observation_label}_camera_params.json"
103
+ )
104
+ with open(filename, "w") as f:
105
+ json.dump(camera_params, f)
106
+ # Save metadata
107
+ with open(os.path.join(output_dir, "metadata.json"), "w") as f:
108
+ json.dump(metadata, f)
109
+
110
+ generator.close()
111
+
112
+
113
+ if __name__ == "__main__":
114
+ parser = argparse.ArgumentParser()
115
+ parser.add_argument("--metadata_filename", required=True)
116
+ parser.add_argument("--output_dir", required=True)
117
+ args = parser.parse_args()
118
+
119
+ generate_multiview_images_from_metadata(
120
+ metadata_filename=args.metadata_filename,
121
+ output_dir=args.output_dir,
122
+ scene_datasets_paths=SCENES_DATASET,
123
+ overload_params=dict(),
124
+ exist_ok=True,
125
+ )
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/generate_from_metadata_files.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Script generating commandlines to generate image pairs from metadata files.
6
+ """
7
+ import os
8
+ import glob
9
+ from tqdm import tqdm
10
+ import argparse
11
+
12
+ if __name__ == "__main__":
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--input_dir", required=True)
15
+ parser.add_argument("--output_dir", required=True)
16
+ parser.add_argument(
17
+ "--prefix",
18
+ default="",
19
+ help="Commanline prefix, useful e.g. to setup environment.",
20
+ )
21
+ args = parser.parse_args()
22
+
23
+ input_metadata_filenames = glob.iglob(
24
+ f"{args.input_dir}/**/metadata.json", recursive=True
25
+ )
26
+
27
+ for metadata_filename in tqdm(input_metadata_filenames):
28
+ output_dir = os.path.join(
29
+ args.output_dir,
30
+ os.path.relpath(os.path.dirname(metadata_filename), args.input_dir),
31
+ )
32
+ # Do not process the scene if the metadata file already exists
33
+ if os.path.exists(os.path.join(output_dir, "metadata.json")):
34
+ continue
35
+ commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
36
+ print(commandline)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/generate_multiview_images.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ from tqdm import tqdm
6
+ import argparse
7
+ import PIL.Image
8
+ import numpy as np
9
+ import json
10
+ from datasets.habitat_sim.multiview_habitat_sim_generator import (
11
+ MultiviewHabitatSimGenerator,
12
+ NoNaviguableSpaceError,
13
+ )
14
+ from datasets.habitat_sim.paths import list_scenes_available
15
+ import cv2
16
+ import quaternion
17
+ import shutil
18
+
19
+
20
+ def generate_multiview_images_for_scene(
21
+ scene_dataset_config_file,
22
+ scene,
23
+ navmesh,
24
+ output_dir,
25
+ views_count,
26
+ size,
27
+ exist_ok=False,
28
+ generate_depth=False,
29
+ **kwargs,
30
+ ):
31
+ """
32
+ Generate tuples of overlapping views for a given scene.
33
+ generate_depth: generate depth images and camera parameters.
34
+ """
35
+ if os.path.exists(output_dir) and not exist_ok:
36
+ print(f"Scene {scene}: data already generated. Ignoring generation.")
37
+ return
38
+ try:
39
+ print(f"Scene {scene}: {size} multiview acquisitions to generate...")
40
+ os.makedirs(output_dir, exist_ok=exist_ok)
41
+
42
+ metadata_filename = os.path.join(output_dir, "metadata.json")
43
+
44
+ metadata_template = dict(
45
+ scene_dataset_config_file=scene_dataset_config_file,
46
+ scene=scene,
47
+ navmesh=navmesh,
48
+ views_count=views_count,
49
+ size=size,
50
+ generate_depth=generate_depth,
51
+ **kwargs,
52
+ )
53
+ metadata_template["multiviews"] = dict()
54
+
55
+ if os.path.exists(metadata_filename):
56
+ print("Metadata file already exists:", metadata_filename)
57
+ print("Loading already generated metadata file...")
58
+ with open(metadata_filename, "r") as f:
59
+ metadata = json.load(f)
60
+
61
+ for key in metadata_template.keys():
62
+ if key != "multiviews":
63
+ assert (
64
+ metadata_template[key] == metadata[key]
65
+ ), f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}."
66
+ else:
67
+ print("No temporary file found. Starting generation from scratch...")
68
+ metadata = metadata_template
69
+
70
+ starting_id = len(metadata["multiviews"])
71
+ print(f"Starting generation from index {starting_id}/{size}...")
72
+ if starting_id >= size:
73
+ print("Generation already done.")
74
+ return
75
+
76
+ generator = MultiviewHabitatSimGenerator(
77
+ scene_dataset_config_file=scene_dataset_config_file,
78
+ scene=scene,
79
+ navmesh=navmesh,
80
+ views_count=views_count,
81
+ size=size,
82
+ **kwargs,
83
+ )
84
+
85
+ for idx in tqdm(range(starting_id, size)):
86
+ # Generate / re-generate the observations
87
+ try:
88
+ data = generator[idx]
89
+ observations = data["observations"]
90
+ positions = data["positions"]
91
+ orientations = data["orientations"]
92
+
93
+ idx_label = f"{idx:08}"
94
+ for oidx, observation in enumerate(observations):
95
+ observation_label = (
96
+ f"{oidx + 1}" # Leonid is indexing starting from 1
97
+ )
98
+ # Color image saved using PIL
99
+ img = PIL.Image.fromarray(observation["color"][:, :, :3])
100
+ filename = os.path.join(
101
+ output_dir, f"{idx_label}_{observation_label}.jpeg"
102
+ )
103
+ img.save(filename)
104
+ if generate_depth:
105
+ # Depth image as EXR file
106
+ filename = os.path.join(
107
+ output_dir, f"{idx_label}_{observation_label}_depth.exr"
108
+ )
109
+ cv2.imwrite(
110
+ filename,
111
+ observation["depth"],
112
+ [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF],
113
+ )
114
+ # Camera parameters
115
+ camera_params = dict(
116
+ [
117
+ (key, observation[key].tolist())
118
+ for key in (
119
+ "camera_intrinsics",
120
+ "R_cam2world",
121
+ "t_cam2world",
122
+ )
123
+ ]
124
+ )
125
+ filename = os.path.join(
126
+ output_dir,
127
+ f"{idx_label}_{observation_label}_camera_params.json",
128
+ )
129
+ with open(filename, "w") as f:
130
+ json.dump(camera_params, f)
131
+ metadata["multiviews"][idx_label] = {
132
+ "positions": positions.tolist(),
133
+ "orientations": orientations.tolist(),
134
+ "covisibility_ratios": data["covisibility_ratios"].tolist(),
135
+ "valid_fractions": data["valid_fractions"].tolist(),
136
+ "pairwise_visibility_ratios": data[
137
+ "pairwise_visibility_ratios"
138
+ ].tolist(),
139
+ }
140
+ except RecursionError:
141
+ print(
142
+ "Recursion error: unable to sample observations for this scene. We will stop there."
143
+ )
144
+ break
145
+
146
+ # Regularly save a temporary metadata file, in case we need to restart the generation
147
+ if idx % 10 == 0:
148
+ with open(metadata_filename, "w") as f:
149
+ json.dump(metadata, f)
150
+
151
+ # Save metadata
152
+ with open(metadata_filename, "w") as f:
153
+ json.dump(metadata, f)
154
+
155
+ generator.close()
156
+ except NoNaviguableSpaceError:
157
+ pass
158
+
159
+
160
+ def create_commandline(scene_data, generate_depth, exist_ok=False):
161
+ """
162
+ Create a commandline string to generate a scene.
163
+ """
164
+
165
+ def my_formatting(val):
166
+ if val is None or val == "":
167
+ return '""'
168
+ else:
169
+ return val
170
+
171
+ commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)}
172
+ --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)}
173
+ --navmesh {my_formatting(scene_data.navmesh)}
174
+ --output_dir {my_formatting(scene_data.output_dir)}
175
+ --generate_depth {int(generate_depth)}
176
+ --exist_ok {int(exist_ok)}
177
+ """
178
+ commandline = " ".join(commandline.split())
179
+ return commandline
180
+
181
+
182
+ if __name__ == "__main__":
183
+ os.umask(2)
184
+
185
+ parser = argparse.ArgumentParser(
186
+ description="""Example of use -- listing commands to generate data for scenes available:
187
+ > python datasets/habitat_sim/generate_multiview_habitat_images.py --list_commands
188
+ """
189
+ )
190
+
191
+ parser.add_argument("--output_dir", type=str, required=True)
192
+ parser.add_argument(
193
+ "--list_commands", action="store_true", help="list commandlines to run if true"
194
+ )
195
+ parser.add_argument("--scene", type=str, default="")
196
+ parser.add_argument("--scene_dataset_config_file", type=str, default="")
197
+ parser.add_argument("--navmesh", type=str, default="")
198
+
199
+ parser.add_argument("--generate_depth", type=int, default=1)
200
+ parser.add_argument("--exist_ok", type=int, default=0)
201
+
202
+ kwargs = dict(resolution=(256, 256), hfov=60, views_count=2, size=1000)
203
+
204
+ args = parser.parse_args()
205
+ generate_depth = bool(args.generate_depth)
206
+ exist_ok = bool(args.exist_ok)
207
+
208
+ if args.list_commands:
209
+ # Listing scenes available...
210
+ scenes_data = list_scenes_available(base_output_dir=args.output_dir)
211
+
212
+ for scene_data in scenes_data:
213
+ print(
214
+ create_commandline(
215
+ scene_data, generate_depth=generate_depth, exist_ok=exist_ok
216
+ )
217
+ )
218
+ else:
219
+ if args.scene == "" or args.output_dir == "":
220
+ print("Missing scene or output dir argument!")
221
+ print(parser.format_help())
222
+ else:
223
+ generate_multiview_images_for_scene(
224
+ scene=args.scene,
225
+ scene_dataset_config_file=args.scene_dataset_config_file,
226
+ navmesh=args.navmesh,
227
+ output_dir=args.output_dir,
228
+ exist_ok=exist_ok,
229
+ generate_depth=generate_depth,
230
+ **kwargs,
231
+ )
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ import numpy as np
6
+ import quaternion
7
+ import habitat_sim
8
+ import json
9
+ from sklearn.neighbors import NearestNeighbors
10
+ import cv2
11
+
12
+ # OpenCV to habitat camera convention transformation
13
+ R_OPENCV2HABITAT = np.stack(
14
+ (habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0
15
+ )
16
+ R_HABITAT2OPENCV = R_OPENCV2HABITAT.T
17
+ DEG2RAD = np.pi / 180
18
+
19
+
20
+ def compute_camera_intrinsics(height, width, hfov):
21
+ f = width / 2 / np.tan(hfov / 2 * np.pi / 180)
22
+ cu, cv = width / 2, height / 2
23
+ return f, cu, cv
24
+
25
+
26
+ def compute_camera_pose_opencv_convention(camera_position, camera_orientation):
27
+ R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT
28
+ t_cam2world = np.asarray(camera_position)
29
+ return R_cam2world, t_cam2world
30
+
31
+
32
+ def compute_pointmap(depthmap, hfov):
33
+ """Compute a HxWx3 pointmap in camera frame from a HxW depth map."""
34
+ height, width = depthmap.shape
35
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
36
+ # Cast depth map to point
37
+ z_cam = depthmap
38
+ u, v = np.meshgrid(range(width), range(height))
39
+ x_cam = (u - cu) / f * z_cam
40
+ y_cam = (v - cv) / f * z_cam
41
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1)
42
+ return X_cam
43
+
44
+
45
+ def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation):
46
+ """Return a 3D point cloud corresponding to valid pixels of the depth map"""
47
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(
48
+ camera_position, camera_rotation
49
+ )
50
+
51
+ X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov)
52
+ valid_mask = X_cam[:, :, 2] != 0.0
53
+
54
+ X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()]
55
+ X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3)
56
+ return X_world
57
+
58
+
59
+ def compute_pointcloud_overlaps_scikit(
60
+ pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False
61
+ ):
62
+ """
63
+ Compute 'overlapping' metrics based on a distance threshold between two point clouds.
64
+ """
65
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(pointcloud2)
66
+ distances, indices = nbrs.kneighbors(pointcloud1)
67
+ intersection1 = np.count_nonzero(distances.flatten() < distance_threshold)
68
+
69
+ data = {"intersection1": intersection1, "size1": len(pointcloud1)}
70
+ if compute_symmetric:
71
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(pointcloud1)
72
+ distances, indices = nbrs.kneighbors(pointcloud2)
73
+ intersection2 = np.count_nonzero(distances.flatten() < distance_threshold)
74
+ data["intersection2"] = intersection2
75
+ data["size2"] = len(pointcloud2)
76
+
77
+ return data
78
+
79
+
80
+ def _append_camera_parameters(observation, hfov, camera_location, camera_rotation):
81
+ """
82
+ Add camera parameters to the observation dictionnary produced by Habitat-Sim
83
+ In-place modifications.
84
+ """
85
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(
86
+ camera_location, camera_rotation
87
+ )
88
+ height, width = observation["depth"].shape
89
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
90
+ K = np.asarray([[f, 0, cu], [0, f, cv], [0, 0, 1.0]])
91
+ observation["camera_intrinsics"] = K
92
+ observation["t_cam2world"] = t_cam2world
93
+ observation["R_cam2world"] = R_cam2world
94
+
95
+
96
+ def look_at(eye, center, up, return_cam2world=True):
97
+ """
98
+ Return camera pose looking at a given center point.
99
+ Analogous of gluLookAt function, using OpenCV camera convention.
100
+ """
101
+ z = center - eye
102
+ z /= np.linalg.norm(z, axis=-1, keepdims=True)
103
+ y = -up
104
+ y = y - np.sum(y * z, axis=-1, keepdims=True) * z
105
+ y /= np.linalg.norm(y, axis=-1, keepdims=True)
106
+ x = np.cross(y, z, axis=-1)
107
+
108
+ if return_cam2world:
109
+ R = np.stack((x, y, z), axis=-1)
110
+ t = eye
111
+ else:
112
+ # World to camera transformation
113
+ # Transposed matrix
114
+ R = np.stack((x, y, z), axis=-2)
115
+ t = -np.einsum("...ij, ...j", R, eye)
116
+ return R, t
117
+
118
+
119
+ def look_at_for_habitat(eye, center, up, return_cam2world=True):
120
+ R, t = look_at(eye, center, up)
121
+ orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T)
122
+ return orientation, t
123
+
124
+
125
+ def generate_orientation_noise(pan_range, tilt_range, roll_range):
126
+ return (
127
+ quaternion.from_rotation_vector(
128
+ np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP
129
+ )
130
+ * quaternion.from_rotation_vector(
131
+ np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT
132
+ )
133
+ * quaternion.from_rotation_vector(
134
+ np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT
135
+ )
136
+ )
137
+
138
+
139
+ class NoNaviguableSpaceError(RuntimeError):
140
+ def __init__(self, *args):
141
+ super().__init__(*args)
142
+
143
+
144
+ class MultiviewHabitatSimGenerator:
145
+ def __init__(
146
+ self,
147
+ scene,
148
+ navmesh,
149
+ scene_dataset_config_file,
150
+ resolution=(240, 320),
151
+ views_count=2,
152
+ hfov=60,
153
+ gpu_id=0,
154
+ size=10000,
155
+ minimum_covisibility=0.5,
156
+ transform=None,
157
+ ):
158
+ self.scene = scene
159
+ self.navmesh = navmesh
160
+ self.scene_dataset_config_file = scene_dataset_config_file
161
+ self.resolution = resolution
162
+ self.views_count = views_count
163
+ assert self.views_count >= 1
164
+ self.hfov = hfov
165
+ self.gpu_id = gpu_id
166
+ self.size = size
167
+ self.transform = transform
168
+
169
+ # Noise added to camera orientation
170
+ self.pan_range = (-3, 3)
171
+ self.tilt_range = (-10, 10)
172
+ self.roll_range = (-5, 5)
173
+
174
+ # Height range to sample cameras
175
+ self.height_range = (1.2, 1.8)
176
+
177
+ # Random steps between the camera views
178
+ self.random_steps_count = 5
179
+ self.random_step_variance = 2.0
180
+
181
+ # Minimum fraction of the scene which should be valid (well defined depth)
182
+ self.minimum_valid_fraction = 0.7
183
+
184
+ # Distance threshold to see to select pairs
185
+ self.distance_threshold = 0.05
186
+ # Minimum IoU of a view point cloud with respect to the reference view to be kept.
187
+ self.minimum_covisibility = minimum_covisibility
188
+
189
+ # Maximum number of retries.
190
+ self.max_attempts_count = 100
191
+
192
+ self.seed = None
193
+ self._lazy_initialization()
194
+
195
+ def _lazy_initialization(self):
196
+ # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
197
+ if self.seed == None:
198
+ # Re-seed numpy generator
199
+ np.random.seed()
200
+ self.seed = np.random.randint(2**32 - 1)
201
+ sim_cfg = habitat_sim.SimulatorConfiguration()
202
+ sim_cfg.scene_id = self.scene
203
+ if (
204
+ self.scene_dataset_config_file is not None
205
+ and self.scene_dataset_config_file != ""
206
+ ):
207
+ sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
208
+ sim_cfg.random_seed = self.seed
209
+ sim_cfg.load_semantic_mesh = False
210
+ sim_cfg.gpu_device_id = self.gpu_id
211
+
212
+ depth_sensor_spec = habitat_sim.CameraSensorSpec()
213
+ depth_sensor_spec.uuid = "depth"
214
+ depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
215
+ depth_sensor_spec.resolution = self.resolution
216
+ depth_sensor_spec.hfov = self.hfov
217
+ depth_sensor_spec.position = [0.0, 0.0, 0]
218
+ depth_sensor_spec.orientation
219
+
220
+ rgb_sensor_spec = habitat_sim.CameraSensorSpec()
221
+ rgb_sensor_spec.uuid = "color"
222
+ rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
223
+ rgb_sensor_spec.resolution = self.resolution
224
+ rgb_sensor_spec.hfov = self.hfov
225
+ rgb_sensor_spec.position = [0.0, 0.0, 0]
226
+ agent_cfg = habitat_sim.agent.AgentConfiguration(
227
+ sensor_specifications=[rgb_sensor_spec, depth_sensor_spec]
228
+ )
229
+
230
+ cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
231
+ self.sim = habitat_sim.Simulator(cfg)
232
+ if self.navmesh is not None and self.navmesh != "":
233
+ # Use pre-computed navmesh when available (usually better than those generated automatically)
234
+ self.sim.pathfinder.load_nav_mesh(self.navmesh)
235
+
236
+ if not self.sim.pathfinder.is_loaded:
237
+ # Try to compute a navmesh
238
+ navmesh_settings = habitat_sim.NavMeshSettings()
239
+ navmesh_settings.set_defaults()
240
+ self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)
241
+
242
+ # Ensure that the navmesh is not empty
243
+ if not self.sim.pathfinder.is_loaded:
244
+ raise NoNaviguableSpaceError(
245
+ f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})"
246
+ )
247
+
248
+ self.agent = self.sim.initialize_agent(agent_id=0)
249
+
250
+ def close(self):
251
+ self.sim.close()
252
+
253
+ def __del__(self):
254
+ self.sim.close()
255
+
256
+ def __len__(self):
257
+ return self.size
258
+
259
+ def sample_random_viewpoint(self):
260
+ """Sample a random viewpoint using the navmesh"""
261
+ nav_point = self.sim.pathfinder.get_random_navigable_point()
262
+
263
+ # Sample a random viewpoint height
264
+ viewpoint_height = np.random.uniform(*self.height_range)
265
+ viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP
266
+ viewpoint_orientation = quaternion.from_rotation_vector(
267
+ np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP
268
+ ) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
269
+ return viewpoint_position, viewpoint_orientation, nav_point
270
+
271
+ def sample_other_random_viewpoint(self, observed_point, nav_point):
272
+ """Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point."""
273
+ other_nav_point = nav_point
274
+
275
+ walk_directions = self.random_step_variance * np.asarray([1, 0, 1])
276
+ for i in range(self.random_steps_count):
277
+ temp = self.sim.pathfinder.snap_point(
278
+ other_nav_point + walk_directions * np.random.normal(size=3)
279
+ )
280
+ # Snapping may return nan when it fails
281
+ if not np.isnan(temp[0]):
282
+ other_nav_point = temp
283
+
284
+ other_viewpoint_height = np.random.uniform(*self.height_range)
285
+ other_viewpoint_position = (
286
+ other_nav_point + other_viewpoint_height * habitat_sim.geo.UP
287
+ )
288
+
289
+ # Set viewing direction towards the central point
290
+ rotation, position = look_at_for_habitat(
291
+ eye=other_viewpoint_position,
292
+ center=observed_point,
293
+ up=habitat_sim.geo.UP,
294
+ return_cam2world=True,
295
+ )
296
+ rotation = rotation * generate_orientation_noise(
297
+ self.pan_range, self.tilt_range, self.roll_range
298
+ )
299
+ return position, rotation, other_nav_point
300
+
301
+ def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud):
302
+ """Check if a viewpoint is valid and overlaps significantly with a reference one."""
303
+ # Observation
304
+ pixels_count = self.resolution[0] * self.resolution[1]
305
+ valid_fraction = len(other_pointcloud) / pixels_count
306
+ assert valid_fraction <= 1.0 and valid_fraction >= 0.0
307
+ overlap = compute_pointcloud_overlaps_scikit(
308
+ ref_pointcloud,
309
+ other_pointcloud,
310
+ self.distance_threshold,
311
+ compute_symmetric=True,
312
+ )
313
+ covisibility = min(
314
+ overlap["intersection1"] / pixels_count,
315
+ overlap["intersection2"] / pixels_count,
316
+ )
317
+ is_valid = (valid_fraction >= self.minimum_valid_fraction) and (
318
+ covisibility >= self.minimum_covisibility
319
+ )
320
+ return is_valid, valid_fraction, covisibility
321
+
322
+ def is_other_viewpoint_overlapping(
323
+ self, ref_pointcloud, observation, position, rotation
324
+ ):
325
+ """Check if a viewpoint is valid and overlaps significantly with a reference one."""
326
+ # Observation
327
+ other_pointcloud = compute_pointcloud(
328
+ observation["depth"], self.hfov, position, rotation
329
+ )
330
+ return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
331
+
332
+ def render_viewpoint(self, viewpoint_position, viewpoint_orientation):
333
+ agent_state = habitat_sim.AgentState()
334
+ agent_state.position = viewpoint_position
335
+ agent_state.rotation = viewpoint_orientation
336
+ self.agent.set_state(agent_state)
337
+ viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)
338
+ _append_camera_parameters(
339
+ viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation
340
+ )
341
+ return viewpoint_observations
342
+
343
+ def __getitem__(self, useless_idx):
344
+ ref_position, ref_orientation, nav_point = self.sample_random_viewpoint()
345
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
346
+ # Extract point cloud
347
+ ref_pointcloud = compute_pointcloud(
348
+ depthmap=ref_observations["depth"],
349
+ hfov=self.hfov,
350
+ camera_position=ref_position,
351
+ camera_rotation=ref_orientation,
352
+ )
353
+
354
+ pixels_count = self.resolution[0] * self.resolution[1]
355
+ ref_valid_fraction = len(ref_pointcloud) / pixels_count
356
+ assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0
357
+ if ref_valid_fraction < self.minimum_valid_fraction:
358
+ # This should produce a recursion error at some point when something is very wrong.
359
+ return self[0]
360
+ # Pick an reference observed point in the point cloud
361
+ observed_point = np.mean(ref_pointcloud, axis=0)
362
+
363
+ # Add the first image as reference
364
+ viewpoints_observations = [ref_observations]
365
+ viewpoints_covisibility = [ref_valid_fraction]
366
+ viewpoints_positions = [ref_position]
367
+ viewpoints_orientations = [quaternion.as_float_array(ref_orientation)]
368
+ viewpoints_clouds = [ref_pointcloud]
369
+ viewpoints_valid_fractions = [ref_valid_fraction]
370
+
371
+ for _ in range(self.views_count - 1):
372
+ # Generate an other viewpoint using some dummy random walk
373
+ successful_sampling = False
374
+ for sampling_attempt in range(self.max_attempts_count):
375
+ position, rotation, _ = self.sample_other_random_viewpoint(
376
+ observed_point, nav_point
377
+ )
378
+ # Observation
379
+ other_viewpoint_observations = self.render_viewpoint(position, rotation)
380
+ other_pointcloud = compute_pointcloud(
381
+ other_viewpoint_observations["depth"], self.hfov, position, rotation
382
+ )
383
+
384
+ is_valid, valid_fraction, covisibility = (
385
+ self.is_other_pointcloud_overlapping(
386
+ ref_pointcloud, other_pointcloud
387
+ )
388
+ )
389
+ if is_valid:
390
+ successful_sampling = True
391
+ break
392
+ if not successful_sampling:
393
+ print("WARNING: Maximum number of attempts reached.")
394
+ # Dirty hack, try using a novel original viewpoint
395
+ return self[0]
396
+ viewpoints_observations.append(other_viewpoint_observations)
397
+ viewpoints_covisibility.append(covisibility)
398
+ viewpoints_positions.append(position)
399
+ viewpoints_orientations.append(
400
+ quaternion.as_float_array(rotation)
401
+ ) # WXYZ convention for the quaternion encoding.
402
+ viewpoints_clouds.append(other_pointcloud)
403
+ viewpoints_valid_fractions.append(valid_fraction)
404
+
405
+ # Estimate relations between all pairs of images
406
+ pairwise_visibility_ratios = np.ones(
407
+ (len(viewpoints_observations), len(viewpoints_observations))
408
+ )
409
+ for i in range(len(viewpoints_observations)):
410
+ pairwise_visibility_ratios[i, i] = viewpoints_valid_fractions[i]
411
+ for j in range(i + 1, len(viewpoints_observations)):
412
+ overlap = compute_pointcloud_overlaps_scikit(
413
+ viewpoints_clouds[i],
414
+ viewpoints_clouds[j],
415
+ self.distance_threshold,
416
+ compute_symmetric=True,
417
+ )
418
+ pairwise_visibility_ratios[i, j] = (
419
+ overlap["intersection1"] / pixels_count
420
+ )
421
+ pairwise_visibility_ratios[j, i] = (
422
+ overlap["intersection2"] / pixels_count
423
+ )
424
+
425
+ # IoU is relative to the image 0
426
+ data = {
427
+ "observations": viewpoints_observations,
428
+ "positions": np.asarray(viewpoints_positions),
429
+ "orientations": np.asarray(viewpoints_orientations),
430
+ "covisibility_ratios": np.asarray(viewpoints_covisibility),
431
+ "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float),
432
+ "pairwise_visibility_ratios": np.asarray(
433
+ pairwise_visibility_ratios, dtype=float
434
+ ),
435
+ }
436
+
437
+ if self.transform is not None:
438
+ data = self.transform(data)
439
+ return data
440
+
441
+ def generate_random_spiral_trajectory(
442
+ self,
443
+ images_count=100,
444
+ max_radius=0.5,
445
+ half_turns=5,
446
+ use_constant_orientation=False,
447
+ ):
448
+ """
449
+ Return a list of images corresponding to a spiral trajectory from a random starting point.
450
+ Useful to generate nice visualisations.
451
+ Use an even number of half turns to get a nice "C1-continuous" loop effect
452
+ """
453
+ ref_position, ref_orientation, navpoint = self.sample_random_viewpoint()
454
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
455
+ ref_pointcloud = compute_pointcloud(
456
+ depthmap=ref_observations["depth"],
457
+ hfov=self.hfov,
458
+ camera_position=ref_position,
459
+ camera_rotation=ref_orientation,
460
+ )
461
+ pixels_count = self.resolution[0] * self.resolution[1]
462
+ if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction:
463
+ # Dirty hack: ensure that the valid part of the image is significant
464
+ return self.generate_random_spiral_trajectory(
465
+ images_count, max_radius, half_turns, use_constant_orientation
466
+ )
467
+
468
+ # Pick an observed point in the point cloud
469
+ observed_point = np.mean(ref_pointcloud, axis=0)
470
+ ref_R, ref_t = compute_camera_pose_opencv_convention(
471
+ ref_position, ref_orientation
472
+ )
473
+
474
+ images = []
475
+ is_valid = []
476
+ # Spiral trajectory, use_constant orientation
477
+ for i, alpha in enumerate(np.linspace(0, 1, images_count)):
478
+ r = max_radius * np.abs(
479
+ np.sin(alpha * np.pi)
480
+ ) # Increase then decrease the radius
481
+ theta = alpha * half_turns * np.pi
482
+ x = r * np.cos(theta)
483
+ y = r * np.sin(theta)
484
+ z = 0.0
485
+ position = (
486
+ ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3, 1)).flatten()
487
+ )
488
+ if use_constant_orientation:
489
+ orientation = ref_orientation
490
+ else:
491
+ # trajectory looking at a mean point in front of the ref observation
492
+ orientation, position = look_at_for_habitat(
493
+ eye=position, center=observed_point, up=habitat_sim.geo.UP
494
+ )
495
+ observations = self.render_viewpoint(position, orientation)
496
+ images.append(observations["color"][..., :3])
497
+ _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(
498
+ ref_pointcloud, observations, position, orientation
499
+ )
500
+ is_valid.append(_is_valid)
501
+ return images, np.all(is_valid)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/pack_metadata_files.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ """
4
+ Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere.
5
+ """
6
+ import os
7
+ import glob
8
+ from tqdm import tqdm
9
+ import shutil
10
+ import json
11
+ from datasets.habitat_sim.paths import *
12
+ import argparse
13
+ import collections
14
+
15
+ if __name__ == "__main__":
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument("input_dir")
18
+ parser.add_argument("output_dir")
19
+ args = parser.parse_args()
20
+
21
+ input_dirname = args.input_dir
22
+ output_dirname = args.output_dir
23
+
24
+ input_metadata_filenames = glob.iglob(
25
+ f"{input_dirname}/**/metadata.json", recursive=True
26
+ )
27
+
28
+ images_count = collections.defaultdict(lambda: 0)
29
+
30
+ os.makedirs(output_dirname)
31
+ for input_filename in tqdm(input_metadata_filenames):
32
+ # Ignore empty files
33
+ with open(input_filename, "r") as f:
34
+ original_metadata = json.load(f)
35
+ if (
36
+ "multiviews" not in original_metadata
37
+ or len(original_metadata["multiviews"]) == 0
38
+ ):
39
+ print("No views in", input_filename)
40
+ continue
41
+
42
+ relpath = os.path.relpath(input_filename, input_dirname)
43
+ print(relpath)
44
+
45
+ # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability.
46
+ # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting by the same string pattern.
47
+ scenes_dataset_paths = dict(
48
+ sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True)
49
+ )
50
+ metadata = dict()
51
+ for key, value in original_metadata.items():
52
+ if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
53
+ known_path = False
54
+ for dataset, dataset_path in scenes_dataset_paths.items():
55
+ if value.startswith(dataset_path):
56
+ value = os.path.join(
57
+ dataset, os.path.relpath(value, dataset_path)
58
+ )
59
+ known_path = True
60
+ break
61
+ if not known_path:
62
+ raise KeyError("Unknown path:" + value)
63
+ metadata[key] = value
64
+
65
+ # Compile some general statistics while packing data
66
+ scene_split = metadata["scene"].split("/")
67
+ upper_level = (
68
+ "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0]
69
+ )
70
+ images_count[upper_level] += len(metadata["multiviews"])
71
+
72
+ output_filename = os.path.join(output_dirname, relpath)
73
+ os.makedirs(os.path.dirname(output_filename), exist_ok=True)
74
+ with open(output_filename, "w") as f:
75
+ json.dump(metadata, f)
76
+
77
+ # Print statistics
78
+ print("Images count:")
79
+ for upper_level, count in images_count.items():
80
+ print(f"- {upper_level}: {count}")
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/habitat_sim/paths.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Paths to Habitat-Sim scenes
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import collections
11
+ from tqdm import tqdm
12
+
13
+
14
+ # Hardcoded path to the different scene datasets
15
+ SCENES_DATASET = {
16
+ "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/",
17
+ "gibson": "./data/habitat-sim-data/scene_datasets/gibson/",
18
+ "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/",
19
+ "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/",
20
+ "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/",
21
+ "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/",
22
+ "scannet": "./data/habitat-sim/scene_datasets/scannet/",
23
+ }
24
+
25
+ SceneData = collections.namedtuple(
26
+ "SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"]
27
+ )
28
+
29
+
30
+ def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]):
31
+ scene_dataset_config_file = os.path.join(
32
+ base_path, "replicaCAD.scene_dataset_config.json"
33
+ )
34
+ scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"]
35
+ navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + [
36
+ "empty_stage.navmesh"
37
+ ]
38
+ scenes_data = []
39
+ for idx in range(len(scenes)):
40
+ output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx])
41
+ # Add scene
42
+ data = SceneData(
43
+ scene_dataset_config_file=scene_dataset_config_file,
44
+ scene=scenes[idx] + ".scene_instance.json",
45
+ navmesh=os.path.join(base_path, navmeshes[idx]),
46
+ output_dir=output_dir,
47
+ )
48
+ scenes_data.append(data)
49
+ return scenes_data
50
+
51
+
52
+ def list_replica_cad_baked_lighting_scenes(
53
+ base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]
54
+ ):
55
+ scene_dataset_config_file = os.path.join(
56
+ base_path, "replicaCAD_baked.scene_dataset_config.json"
57
+ )
58
+ scenes = sum(
59
+ [[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], []
60
+ )
61
+ navmeshes = "" # [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
62
+ scenes_data = []
63
+ for idx in range(len(scenes)):
64
+ output_dir = os.path.join(
65
+ base_output_dir, "replica_cad_baked_lighting", scenes[idx]
66
+ )
67
+ data = SceneData(
68
+ scene_dataset_config_file=scene_dataset_config_file,
69
+ scene=scenes[idx],
70
+ navmesh="",
71
+ output_dir=output_dir,
72
+ )
73
+ scenes_data.append(data)
74
+ return scenes_data
75
+
76
+
77
+ def list_replica_scenes(base_output_dir, base_path):
78
+ scenes_data = []
79
+ for scene_id in os.listdir(base_path):
80
+ scene = os.path.join(base_path, scene_id, "mesh.ply")
81
+ navmesh = os.path.join(
82
+ base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh"
83
+ ) # Not sure if I should use it
84
+ scene_dataset_config_file = ""
85
+ output_dir = os.path.join(base_output_dir, scene_id)
86
+ # Add scene only if it does not exist already, or if exist_ok
87
+ data = SceneData(
88
+ scene_dataset_config_file=scene_dataset_config_file,
89
+ scene=scene,
90
+ navmesh=navmesh,
91
+ output_dir=output_dir,
92
+ )
93
+ scenes_data.append(data)
94
+ return scenes_data
95
+
96
+
97
+ def list_scenes(base_output_dir, base_path):
98
+ """
99
+ Generic method iterating through a base_path folder to find scenes.
100
+ """
101
+ scenes_data = []
102
+ for root, dirs, files in os.walk(base_path, followlinks=True):
103
+ folder_scenes_data = []
104
+ for file in files:
105
+ name, ext = os.path.splitext(file)
106
+ if ext == ".glb":
107
+ scene = os.path.join(root, name + ".glb")
108
+ navmesh = os.path.join(root, name + ".navmesh")
109
+ if not os.path.exists(navmesh):
110
+ navmesh = ""
111
+ relpath = os.path.relpath(root, base_path)
112
+ output_dir = os.path.abspath(
113
+ os.path.join(base_output_dir, relpath, name)
114
+ )
115
+ data = SceneData(
116
+ scene_dataset_config_file="",
117
+ scene=scene,
118
+ navmesh=navmesh,
119
+ output_dir=output_dir,
120
+ )
121
+ folder_scenes_data.append(data)
122
+
123
+ # Specific check for HM3D:
124
+ # When two meshesxxxx.basis.glb and xxxx.glb are present, use the 'basis' version.
125
+ basis_scenes = [
126
+ data.scene[: -len(".basis.glb")]
127
+ for data in folder_scenes_data
128
+ if data.scene.endswith(".basis.glb")
129
+ ]
130
+ if len(basis_scenes) != 0:
131
+ folder_scenes_data = [
132
+ data
133
+ for data in folder_scenes_data
134
+ if not (data.scene[: -len(".glb")] in basis_scenes)
135
+ ]
136
+
137
+ scenes_data.extend(folder_scenes_data)
138
+ return scenes_data
139
+
140
+
141
+ def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET):
142
+ scenes_data = []
143
+
144
+ # HM3D
145
+ for split in ("minival", "train", "val", "examples"):
146
+ scenes_data += list_scenes(
147
+ base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"),
148
+ base_path=f"{scenes_dataset_paths['hm3d']}/{split}",
149
+ )
150
+
151
+ # Gibson
152
+ scenes_data += list_scenes(
153
+ base_output_dir=os.path.join(base_output_dir, "gibson"),
154
+ base_path=scenes_dataset_paths["gibson"],
155
+ )
156
+
157
+ # Habitat test scenes (just a few)
158
+ scenes_data += list_scenes(
159
+ base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"),
160
+ base_path=scenes_dataset_paths["habitat-test-scenes"],
161
+ )
162
+
163
+ # ReplicaCAD (baked lightning)
164
+ scenes_data += list_replica_cad_baked_lighting_scenes(
165
+ base_output_dir=base_output_dir
166
+ )
167
+
168
+ # ScanNet
169
+ scenes_data += list_scenes(
170
+ base_output_dir=os.path.join(base_output_dir, "scannet"),
171
+ base_path=scenes_dataset_paths["scannet"],
172
+ )
173
+
174
+ # Replica
175
+ list_replica_scenes(
176
+ base_output_dir=os.path.join(base_output_dir, "replica"),
177
+ base_path=scenes_dataset_paths["replica"],
178
+ )
179
+ return scenes_data
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/pairs_dataset.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
+
8
+ from datasets.transforms import get_pair_transforms
9
+
10
+
11
+ def load_image(impath):
12
+ return Image.open(impath)
13
+
14
+
15
+ def load_pairs_from_cache_file(fname, root=""):
16
+ assert os.path.isfile(
17
+ fname
18
+ ), "cannot parse pairs from {:s}, file does not exist".format(fname)
19
+ with open(fname, "r") as fid:
20
+ lines = fid.read().strip().splitlines()
21
+ pairs = [
22
+ (os.path.join(root, l.split()[0]), os.path.join(root, l.split()[1]))
23
+ for l in lines
24
+ ]
25
+ return pairs
26
+
27
+
28
+ def load_pairs_from_list_file(fname, root=""):
29
+ assert os.path.isfile(
30
+ fname
31
+ ), "cannot parse pairs from {:s}, file does not exist".format(fname)
32
+ with open(fname, "r") as fid:
33
+ lines = fid.read().strip().splitlines()
34
+ pairs = [
35
+ (os.path.join(root, l + "_1.jpg"), os.path.join(root, l + "_2.jpg"))
36
+ for l in lines
37
+ if not l.startswith("#")
38
+ ]
39
+ return pairs
40
+
41
+
42
+ def write_cache_file(fname, pairs, root=""):
43
+ if len(root) > 0:
44
+ if not root.endswith("/"):
45
+ root += "/"
46
+ assert os.path.isdir(root)
47
+ s = ""
48
+ for im1, im2 in pairs:
49
+ if len(root) > 0:
50
+ assert im1.startswith(root), im1
51
+ assert im2.startswith(root), im2
52
+ s += "{:s} {:s}\n".format(im1[len(root) :], im2[len(root) :])
53
+ with open(fname, "w") as fid:
54
+ fid.write(s[:-1])
55
+
56
+
57
+ def parse_and_cache_all_pairs(dname, data_dir="./data/"):
58
+ if dname == "habitat_release":
59
+ dirname = os.path.join(data_dir, "habitat_release")
60
+ assert os.path.isdir(dirname), (
61
+ "cannot find folder for habitat_release pairs: " + dirname
62
+ )
63
+ cache_file = os.path.join(dirname, "pairs.txt")
64
+ assert not os.path.isfile(cache_file), (
65
+ "cache file already exists: " + cache_file
66
+ )
67
+
68
+ print("Parsing pairs for dataset: " + dname)
69
+ pairs = []
70
+ for root, dirs, files in os.walk(dirname):
71
+ if "val" in root:
72
+ continue
73
+ dirs.sort()
74
+ pairs += [
75
+ (
76
+ os.path.join(root, f),
77
+ os.path.join(root, f[: -len("_1.jpeg")] + "_2.jpeg"),
78
+ )
79
+ for f in sorted(files)
80
+ if f.endswith("_1.jpeg")
81
+ ]
82
+ print("Found {:,} pairs".format(len(pairs)))
83
+ print("Writing cache to: " + cache_file)
84
+ write_cache_file(cache_file, pairs, root=dirname)
85
+
86
+ else:
87
+ raise NotImplementedError("Unknown dataset: " + dname)
88
+
89
+
90
+ def dnames_to_image_pairs(dnames, data_dir="./data/"):
91
+ """
92
+ dnames: list of datasets with image pairs, separated by +
93
+ """
94
+ all_pairs = []
95
+ for dname in dnames.split("+"):
96
+ if dname == "habitat_release":
97
+ dirname = os.path.join(data_dir, "habitat_release")
98
+ assert os.path.isdir(dirname), (
99
+ "cannot find folder for habitat_release pairs: " + dirname
100
+ )
101
+ cache_file = os.path.join(dirname, "pairs.txt")
102
+ assert os.path.isfile(cache_file), (
103
+ "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "
104
+ + cache_file
105
+ )
106
+ pairs = load_pairs_from_cache_file(cache_file, root=dirname)
107
+ elif dname in ["ARKitScenes", "MegaDepth", "3DStreetView", "IndoorVL"]:
108
+ dirname = os.path.join(data_dir, dname + "_crops")
109
+ assert os.path.isdir(
110
+ dirname
111
+ ), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
112
+ list_file = os.path.join(dirname, "listing.txt")
113
+ assert os.path.isfile(
114
+ list_file
115
+ ), "cannot find list file for {:s} pairs, see instructions. {:s}".format(
116
+ dname, list_file
117
+ )
118
+ pairs = load_pairs_from_list_file(list_file, root=dirname)
119
+ print(" {:s}: {:,} pairs".format(dname, len(pairs)))
120
+ all_pairs += pairs
121
+ if "+" in dnames:
122
+ print(" Total: {:,} pairs".format(len(all_pairs)))
123
+ return all_pairs
124
+
125
+
126
+ class PairsDataset(Dataset):
127
+
128
+ def __init__(
129
+ self, dnames, trfs="", totensor=True, normalize=True, data_dir="./data/"
130
+ ):
131
+ super().__init__()
132
+ self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
133
+ self.transforms = get_pair_transforms(
134
+ transform_str=trfs, totensor=totensor, normalize=normalize
135
+ )
136
+
137
+ def __len__(self):
138
+ return len(self.image_pairs)
139
+
140
+ def __getitem__(self, index):
141
+ im1path, im2path = self.image_pairs[index]
142
+ im1 = load_image(im1path)
143
+ im2 = load_image(im2path)
144
+ if self.transforms is not None:
145
+ im1, im2 = self.transforms(im1, im2)
146
+ return im1, im2
147
+
148
+
149
+ if __name__ == "__main__":
150
+ import argparse
151
+
152
+ parser = argparse.ArgumentParser(
153
+ prog="Computing and caching list of pairs for a given dataset"
154
+ )
155
+ parser.add_argument(
156
+ "--data_dir", default="./data/", type=str, help="path where data are stored"
157
+ )
158
+ parser.add_argument(
159
+ "--dataset", default="habitat_release", type=str, help="name of the dataset"
160
+ )
161
+ args = parser.parse_args()
162
+ parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/datasets/transforms.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+ import torchvision.transforms
6
+ import torchvision.transforms.functional as F
7
+
8
+ # "Pair": apply a transform on a pair
9
+ # "Both": apply the exact same transform to both images
10
+
11
+
12
+ class ComposePair(torchvision.transforms.Compose):
13
+ def __call__(self, img1, img2):
14
+ for t in self.transforms:
15
+ img1, img2 = t(img1, img2)
16
+ return img1, img2
17
+
18
+
19
+ class NormalizeBoth(torchvision.transforms.Normalize):
20
+ def forward(self, img1, img2):
21
+ img1 = super().forward(img1)
22
+ img2 = super().forward(img2)
23
+ return img1, img2
24
+
25
+
26
+ class ToTensorBoth(torchvision.transforms.ToTensor):
27
+ def __call__(self, img1, img2):
28
+ img1 = super().__call__(img1)
29
+ img2 = super().__call__(img2)
30
+ return img1, img2
31
+
32
+
33
+ class RandomCropPair(torchvision.transforms.RandomCrop):
34
+ # the crop will be intentionally different for the two images with this class
35
+ def forward(self, img1, img2):
36
+ img1 = super().forward(img1)
37
+ img2 = super().forward(img2)
38
+ return img1, img2
39
+
40
+
41
+ class ColorJitterPair(torchvision.transforms.ColorJitter):
42
+ # can be symmetric (same for both images) or assymetric (different jitter params for each image) depending on assymetric_prob
43
+ def __init__(self, assymetric_prob, **kwargs):
44
+ super().__init__(**kwargs)
45
+ self.assymetric_prob = assymetric_prob
46
+
47
+ def jitter_one(
48
+ self,
49
+ img,
50
+ fn_idx,
51
+ brightness_factor,
52
+ contrast_factor,
53
+ saturation_factor,
54
+ hue_factor,
55
+ ):
56
+ for fn_id in fn_idx:
57
+ if fn_id == 0 and brightness_factor is not None:
58
+ img = F.adjust_brightness(img, brightness_factor)
59
+ elif fn_id == 1 and contrast_factor is not None:
60
+ img = F.adjust_contrast(img, contrast_factor)
61
+ elif fn_id == 2 and saturation_factor is not None:
62
+ img = F.adjust_saturation(img, saturation_factor)
63
+ elif fn_id == 3 and hue_factor is not None:
64
+ img = F.adjust_hue(img, hue_factor)
65
+ return img
66
+
67
+ def forward(self, img1, img2):
68
+
69
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = (
70
+ self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
71
+ )
72
+ img1 = self.jitter_one(
73
+ img1,
74
+ fn_idx,
75
+ brightness_factor,
76
+ contrast_factor,
77
+ saturation_factor,
78
+ hue_factor,
79
+ )
80
+ if torch.rand(1) < self.assymetric_prob: # assymetric:
81
+ (
82
+ fn_idx,
83
+ brightness_factor,
84
+ contrast_factor,
85
+ saturation_factor,
86
+ hue_factor,
87
+ ) = self.get_params(
88
+ self.brightness, self.contrast, self.saturation, self.hue
89
+ )
90
+ img2 = self.jitter_one(
91
+ img2,
92
+ fn_idx,
93
+ brightness_factor,
94
+ contrast_factor,
95
+ saturation_factor,
96
+ hue_factor,
97
+ )
98
+ return img1, img2
99
+
100
+
101
+ def get_pair_transforms(transform_str, totensor=True, normalize=True):
102
+ # transform_str is eg crop224+color
103
+ trfs = []
104
+ for s in transform_str.split("+"):
105
+ if s.startswith("crop"):
106
+ size = int(s[len("crop") :])
107
+ trfs.append(RandomCropPair(size))
108
+ elif s == "acolor":
109
+ trfs.append(
110
+ ColorJitterPair(
111
+ assymetric_prob=1.0,
112
+ brightness=(0.6, 1.4),
113
+ contrast=(0.6, 1.4),
114
+ saturation=(0.6, 1.4),
115
+ hue=0.0,
116
+ )
117
+ )
118
+ elif s == "": # if transform_str was ""
119
+ pass
120
+ else:
121
+ raise NotImplementedError("Unknown augmentation: " + s)
122
+
123
+ if totensor:
124
+ trfs.append(ToTensorBoth())
125
+ if normalize:
126
+ trfs.append(
127
+ NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
128
+ )
129
+
130
+ if len(trfs) == 0:
131
+ return None
132
+ elif len(trfs) == 1:
133
+ return trfs
134
+ else:
135
+ return ComposePair(trfs)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/interactive_demo.ipynb ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Interactive demo of Cross-view Completion."
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
17
+ "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "import torch\n",
27
+ "import numpy as np\n",
28
+ "from models.croco import CroCoNet\n",
29
+ "from ipywidgets import interact, interactive, fixed, interact_manual\n",
30
+ "import ipywidgets as widgets\n",
31
+ "import matplotlib.pyplot as plt\n",
32
+ "import quaternion\n",
33
+ "import models.masking"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "metadata": {},
39
+ "source": [
40
+ "### Load CroCo model"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')\n",
50
+ "model = CroCoNet( **ckpt.get('croco_kwargs',{}))\n",
51
+ "msg = model.load_state_dict(ckpt['model'], strict=True)\n",
52
+ "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
53
+ "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
54
+ "model = model.eval()\n",
55
+ "model = model.to(device=device)\n",
56
+ "print(msg)\n",
57
+ "\n",
58
+ "def process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches=False):\n",
59
+ " \"\"\"\n",
60
+ " Perform Cross-View completion using two input images, specified using Numpy arrays.\n",
61
+ " \"\"\"\n",
62
+ " # Replace the mask generator\n",
63
+ " model.mask_generator = models.masking.RandomMask(model.patch_embed.num_patches, masking_ratio)\n",
64
+ "\n",
65
+ " # ImageNet-1k color normalization\n",
66
+ " imagenet_mean = torch.as_tensor([0.485, 0.456, 0.406]).reshape(1,3,1,1).to(device)\n",
67
+ " imagenet_std = torch.as_tensor([0.229, 0.224, 0.225]).reshape(1,3,1,1).to(device)\n",
68
+ "\n",
69
+ " normalize_input_colors = True\n",
70
+ " is_output_normalized = True\n",
71
+ " with torch.no_grad():\n",
72
+ " # Cast data to torch\n",
73
+ " target_image = (torch.as_tensor(target_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
74
+ " ref_image = (torch.as_tensor(ref_image, dtype=torch.float, device=device).permute(2,0,1) / 255)[None]\n",
75
+ "\n",
76
+ " if normalize_input_colors:\n",
77
+ " ref_image = (ref_image - imagenet_mean) / imagenet_std\n",
78
+ " target_image = (target_image - imagenet_mean) / imagenet_std\n",
79
+ "\n",
80
+ " out, mask, _ = model(target_image, ref_image)\n",
81
+ " # # get target\n",
82
+ " if not is_output_normalized:\n",
83
+ " predicted_image = model.unpatchify(out)\n",
84
+ " else:\n",
85
+ " # The output only contains higher order information,\n",
86
+ " # we retrieve mean and standard deviation from the actual target image\n",
87
+ " patchified = model.patchify(target_image)\n",
88
+ " mean = patchified.mean(dim=-1, keepdim=True)\n",
89
+ " var = patchified.var(dim=-1, keepdim=True)\n",
90
+ " pred_renorm = out * (var + 1.e-6)**.5 + mean\n",
91
+ " predicted_image = model.unpatchify(pred_renorm)\n",
92
+ "\n",
93
+ " image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])\n",
94
+ " masked_target_image = (1 - image_masks) * target_image\n",
95
+ " \n",
96
+ " if not reconstruct_unmasked_patches:\n",
97
+ " # Replace unmasked patches by their actual values\n",
98
+ " predicted_image = predicted_image * image_masks + masked_target_image\n",
99
+ "\n",
100
+ " # Unapply color normalization\n",
101
+ " if normalize_input_colors:\n",
102
+ " predicted_image = predicted_image * imagenet_std + imagenet_mean\n",
103
+ " masked_target_image = masked_target_image * imagenet_std + imagenet_mean\n",
104
+ " \n",
105
+ " # Cast to Numpy\n",
106
+ " masked_target_image = np.asarray(torch.clamp(masked_target_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
107
+ " predicted_image = np.asarray(torch.clamp(predicted_image.squeeze(0).permute(1,2,0) * 255, 0, 255).cpu().numpy(), dtype=np.uint8)\n",
108
+ " return masked_target_image, predicted_image"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "markdown",
113
+ "metadata": {},
114
+ "source": [
115
+ "### Use the Habitat simulator to render images from arbitrary viewpoints (requires habitat_sim to be installed)"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "import os\n",
125
+ "os.environ[\"MAGNUM_LOG\"]=\"quiet\"\n",
126
+ "os.environ[\"HABITAT_SIM_LOG\"]=\"quiet\"\n",
127
+ "import habitat_sim\n",
128
+ "\n",
129
+ "scene = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.glb\"\n",
130
+ "navmesh = \"habitat-sim-data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh\"\n",
131
+ "\n",
132
+ "sim_cfg = habitat_sim.SimulatorConfiguration()\n",
133
+ "if use_gpu: sim_cfg.gpu_device_id = 0\n",
134
+ "sim_cfg.scene_id = scene\n",
135
+ "sim_cfg.load_semantic_mesh = False\n",
136
+ "rgb_sensor_spec = habitat_sim.CameraSensorSpec()\n",
137
+ "rgb_sensor_spec.uuid = \"color\"\n",
138
+ "rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR\n",
139
+ "rgb_sensor_spec.resolution = (224,224)\n",
140
+ "rgb_sensor_spec.hfov = 56.56\n",
141
+ "rgb_sensor_spec.position = [0.0, 0.0, 0.0]\n",
142
+ "rgb_sensor_spec.orientation = [0, 0, 0]\n",
143
+ "agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec])\n",
144
+ "\n",
145
+ "\n",
146
+ "cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])\n",
147
+ "sim = habitat_sim.Simulator(cfg)\n",
148
+ "if navmesh is not None:\n",
149
+ " sim.pathfinder.load_nav_mesh(navmesh)\n",
150
+ "agent = sim.initialize_agent(agent_id=0)\n",
151
+ "\n",
152
+ "def sample_random_viewpoint():\n",
153
+ " \"\"\" Sample a random viewpoint using the navmesh \"\"\"\n",
154
+ " nav_point = sim.pathfinder.get_random_navigable_point()\n",
155
+ " # Sample a random viewpoint height\n",
156
+ " viewpoint_height = np.random.uniform(1.0, 1.6)\n",
157
+ " viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP\n",
158
+ " viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(-np.pi, np.pi) * habitat_sim.geo.UP)\n",
159
+ " return viewpoint_position, viewpoint_orientation\n",
160
+ "\n",
161
+ "def render_viewpoint(position, orientation):\n",
162
+ " agent_state = habitat_sim.AgentState()\n",
163
+ " agent_state.position = position\n",
164
+ " agent_state.rotation = orientation\n",
165
+ " agent.set_state(agent_state)\n",
166
+ " viewpoint_observations = sim.get_sensor_observations(agent_ids=0)\n",
167
+ " image = viewpoint_observations['color'][:,:,:3]\n",
168
+ " image = np.asarray(np.clip(1.5 * np.asarray(image, dtype=float), 0, 255), dtype=np.uint8)\n",
169
+ " return image"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "markdown",
174
+ "metadata": {},
175
+ "source": [
176
+ "### Sample a random reference view"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "ref_position, ref_orientation = sample_random_viewpoint()\n",
186
+ "ref_image = render_viewpoint(ref_position, ref_orientation)\n",
187
+ "plt.clf()\n",
188
+ "fig, axes = plt.subplots(1,1, squeeze=False, num=1)\n",
189
+ "axes[0,0].imshow(ref_image)\n",
190
+ "for ax in axes.flatten():\n",
191
+ " ax.set_xticks([])\n",
192
+ " ax.set_yticks([])"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "metadata": {},
198
+ "source": [
199
+ "### Interactive cross-view completion using CroCo"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "reconstruct_unmasked_patches = False\n",
209
+ "\n",
210
+ "def show_demo(masking_ratio, x, y, z, panorama, elevation):\n",
211
+ " R = quaternion.as_rotation_matrix(ref_orientation)\n",
212
+ " target_position = ref_position + x * R[:,0] + y * R[:,1] + z * R[:,2]\n",
213
+ " target_orientation = (ref_orientation\n",
214
+ " * quaternion.from_rotation_vector(-elevation * np.pi/180 * habitat_sim.geo.LEFT) \n",
215
+ " * quaternion.from_rotation_vector(-panorama * np.pi/180 * habitat_sim.geo.UP))\n",
216
+ " \n",
217
+ " ref_image = render_viewpoint(ref_position, ref_orientation)\n",
218
+ " target_image = render_viewpoint(target_position, target_orientation)\n",
219
+ "\n",
220
+ " masked_target_image, predicted_image = process_images(ref_image, target_image, masking_ratio, reconstruct_unmasked_patches)\n",
221
+ "\n",
222
+ " fig, axes = plt.subplots(1,4, squeeze=True, dpi=300)\n",
223
+ " axes[0].imshow(ref_image)\n",
224
+ " axes[0].set_xlabel(\"Reference\")\n",
225
+ " axes[1].imshow(masked_target_image)\n",
226
+ " axes[1].set_xlabel(\"Masked target\")\n",
227
+ " axes[2].imshow(predicted_image)\n",
228
+ " axes[2].set_xlabel(\"Reconstruction\") \n",
229
+ " axes[3].imshow(target_image)\n",
230
+ " axes[3].set_xlabel(\"Target\")\n",
231
+ " for ax in axes.flatten():\n",
232
+ " ax.set_xticks([])\n",
233
+ " ax.set_yticks([])\n",
234
+ "\n",
235
+ "interact(show_demo,\n",
236
+ " masking_ratio=widgets.FloatSlider(description='masking', value=0.9, min=0.0, max=1.0),\n",
237
+ " x=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
238
+ " y=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
239
+ " z=widgets.FloatSlider(value=0.0, min=-0.5, max=0.5, step=0.05),\n",
240
+ " panorama=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5),\n",
241
+ " elevation=widgets.FloatSlider(value=0.0, min=-20, max=20, step=0.5));"
242
+ ]
243
+ }
244
+ ],
245
+ "metadata": {
246
+ "kernelspec": {
247
+ "display_name": "Python 3 (ipykernel)",
248
+ "language": "python",
249
+ "name": "python3"
250
+ },
251
+ "language_info": {
252
+ "codemirror_mode": {
253
+ "name": "ipython",
254
+ "version": 3
255
+ },
256
+ "file_extension": ".py",
257
+ "mimetype": "text/x-python",
258
+ "name": "python",
259
+ "nbconvert_exporter": "python",
260
+ "pygments_lexer": "ipython3",
261
+ "version": "3.7.13"
262
+ },
263
+ "vscode": {
264
+ "interpreter": {
265
+ "hash": "f9237820cd248d7e07cb4fb9f0e4508a85d642f19d831560c0a4b61f3e907e67"
266
+ }
267
+ }
268
+ },
269
+ "nbformat": 4,
270
+ "nbformat_minor": 2
271
+ }
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/blocks.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Main encoder/decoder blocks
7
+ # --------------------------------------------------------
8
+ # References:
9
+ # timm
10
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
11
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
12
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
13
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
14
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from itertools import repeat
21
+ import collections.abc
22
+ from torch.nn.functional import scaled_dot_product_attention
23
+
24
+
25
+ def _ntuple(n):
26
+ def parse(x):
27
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
28
+ return x
29
+ return tuple(repeat(x, n))
30
+
31
+ return parse
32
+
33
+
34
+ to_2tuple = _ntuple(2)
35
+
36
+
37
+ def drop_path(
38
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
39
+ ):
40
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
41
+ if drop_prob == 0.0 or not training:
42
+ return x
43
+ keep_prob = 1 - drop_prob
44
+ shape = (x.shape[0],) + (1,) * (
45
+ x.ndim - 1
46
+ ) # work with diff dim tensors, not just 2D ConvNets
47
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
48
+ if keep_prob > 0.0 and scale_by_keep:
49
+ random_tensor.div_(keep_prob)
50
+ return x * random_tensor
51
+
52
+
53
+ class DropPath(nn.Module):
54
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
55
+
56
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
57
+ super(DropPath, self).__init__()
58
+ self.drop_prob = drop_prob
59
+ self.scale_by_keep = scale_by_keep
60
+
61
+ def forward(self, x):
62
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
63
+
64
+ def extra_repr(self):
65
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
66
+
67
+
68
+ class Mlp(nn.Module):
69
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
70
+
71
+ def __init__(
72
+ self,
73
+ in_features,
74
+ hidden_features=None,
75
+ out_features=None,
76
+ act_layer=nn.GELU,
77
+ bias=True,
78
+ drop=0.0,
79
+ ):
80
+ super().__init__()
81
+ out_features = out_features or in_features
82
+ hidden_features = hidden_features or in_features
83
+ bias = to_2tuple(bias)
84
+ drop_probs = to_2tuple(drop)
85
+
86
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
87
+ self.act = act_layer()
88
+ self.drop1 = nn.Dropout(drop_probs[0])
89
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
90
+ self.drop2 = nn.Dropout(drop_probs[1])
91
+
92
+ def forward(self, x):
93
+ return self.drop2(self.fc2(self.drop1(self.act(self.fc1(x)))))
94
+
95
+
96
+ class Attention(nn.Module):
97
+
98
+ def __init__(
99
+ self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
100
+ ):
101
+ super().__init__()
102
+ self.num_heads = num_heads
103
+ head_dim = dim // num_heads
104
+ self.scale = head_dim**-0.5
105
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(dim, dim)
108
+ self.proj_drop = nn.Dropout(proj_drop)
109
+ self.rope = rope.float() if rope is not None else None
110
+
111
+ def forward(self, x, xpos):
112
+ B, N, C = x.shape
113
+
114
+ qkv = (
115
+ self.qkv(x)
116
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
117
+ .transpose(1, 3)
118
+ )
119
+ q, k, v = [qkv[:, :, i] for i in range(3)]
120
+ # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
121
+
122
+ q_type = q.dtype
123
+ k_type = k.dtype
124
+ if self.rope is not None:
125
+ q = q.to(torch.float16)
126
+ k = k.to(torch.float16)
127
+ with torch.autocast(device_type="cuda", enabled=False):
128
+ q = self.rope(q, xpos)
129
+ k = self.rope(k, xpos)
130
+ q = q.to(q_type)
131
+ k = k.to(k_type)
132
+
133
+ # attn = (q @ k.transpose(-2, -1)) * self.scale
134
+ # attn = attn.softmax(dim=-1)
135
+ # attn = self.attn_drop(attn)
136
+
137
+ # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
138
+ # x = memory_efficient_attention(query=q.permute(0, 2, 1, 3), key=k.permute(0, 2, 1, 3), value=v.permute(0, 2, 1, 3), p=self.attn_drop.p, scale=self.scale).reshape(B, N, C)
139
+ x = (
140
+ scaled_dot_product_attention(
141
+ query=q, key=k, value=v, dropout_p=self.attn_drop.p, scale=self.scale
142
+ )
143
+ .transpose(1, 2)
144
+ .reshape(B, N, C)
145
+ )
146
+ x = self.proj(x)
147
+ x = self.proj_drop(x)
148
+ return x
149
+
150
+
151
+ class Block(nn.Module):
152
+
153
+ def __init__(
154
+ self,
155
+ dim,
156
+ num_heads,
157
+ mlp_ratio=4.0,
158
+ qkv_bias=False,
159
+ drop=0.0,
160
+ attn_drop=0.0,
161
+ drop_path=0.0,
162
+ act_layer=nn.GELU,
163
+ norm_layer=nn.LayerNorm,
164
+ rope=None,
165
+ ):
166
+ super().__init__()
167
+ self.norm1 = norm_layer(dim)
168
+ self.attn = Attention(
169
+ dim,
170
+ rope=rope,
171
+ num_heads=num_heads,
172
+ qkv_bias=qkv_bias,
173
+ attn_drop=attn_drop,
174
+ proj_drop=drop,
175
+ )
176
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
177
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
178
+ self.norm2 = norm_layer(dim)
179
+ mlp_hidden_dim = int(dim * mlp_ratio)
180
+ self.mlp = Mlp(
181
+ in_features=dim,
182
+ hidden_features=mlp_hidden_dim,
183
+ act_layer=act_layer,
184
+ drop=drop,
185
+ )
186
+
187
+ def forward(self, x, xpos):
188
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
189
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
190
+ return x
191
+
192
+
193
+ class CrossAttention(nn.Module):
194
+
195
+ def __init__(
196
+ self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
197
+ ):
198
+ super().__init__()
199
+ self.num_heads = num_heads
200
+ head_dim = dim // num_heads
201
+ self.scale = head_dim**-0.5
202
+
203
+ self.projq = nn.Linear(dim, dim, bias=qkv_bias)
204
+ self.projk = nn.Linear(dim, dim, bias=qkv_bias)
205
+ self.projv = nn.Linear(dim, dim, bias=qkv_bias)
206
+ self.attn_drop = nn.Dropout(attn_drop)
207
+ self.proj = nn.Linear(dim, dim)
208
+ self.proj_drop = nn.Dropout(proj_drop)
209
+
210
+ self.rope = rope.float() if rope is not None else None
211
+
212
+ def forward(self, query, key, value, qpos, kpos):
213
+ B, Nq, C = query.shape
214
+ Nk = key.shape[1]
215
+ Nv = value.shape[1]
216
+
217
+ q = (
218
+ self.projq(query)
219
+ .reshape(B, Nq, self.num_heads, C // self.num_heads)
220
+ .permute(0, 2, 1, 3)
221
+ )
222
+ k = (
223
+ self.projk(key)
224
+ .reshape(B, Nk, self.num_heads, C // self.num_heads)
225
+ .permute(0, 2, 1, 3)
226
+ )
227
+ v = (
228
+ self.projv(value)
229
+ .reshape(B, Nv, self.num_heads, C // self.num_heads)
230
+ .permute(0, 2, 1, 3)
231
+ )
232
+
233
+ q_type = q.dtype
234
+ k_type = k.dtype
235
+ if self.rope is not None:
236
+ if qpos is not None:
237
+ q = q.to(torch.float16)
238
+ with torch.autocast(device_type="cuda", enabled=False):
239
+ q = self.rope(q, qpos)
240
+ q = q.to(q_type)
241
+
242
+ if kpos is not None:
243
+ k = k.to(torch.float16)
244
+ with torch.autocast(device_type="cuda", enabled=False):
245
+ k = self.rope(k, kpos)
246
+ k = k.to(k_type)
247
+
248
+ # attn = (q @ k.transpose(-2, -1)) * self.scale
249
+ # attn = attn.softmax(dim=-1)
250
+ # attn = self.attn_drop(attn)
251
+
252
+ # x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
253
+
254
+ # x = memory_efficient_attention(query=q.permute(0, 2, 1, 3), key=k.permute(0, 2, 1, 3), value=v.permute(0, 2, 1, 3), p=self.attn_drop.p, scale=self.scale).reshape(B, Nq, C)
255
+ x = (
256
+ scaled_dot_product_attention(
257
+ query=q, key=k, value=v, dropout_p=self.attn_drop.p, scale=self.scale
258
+ )
259
+ .transpose(1, 2)
260
+ .reshape(B, Nq, C)
261
+ )
262
+
263
+ x = self.proj(x)
264
+ x = self.proj_drop(x)
265
+ return x
266
+
267
+
268
+ class DecoderBlock(nn.Module):
269
+
270
+ def __init__(
271
+ self,
272
+ dim,
273
+ num_heads,
274
+ mlp_ratio=4.0,
275
+ qkv_bias=False,
276
+ drop=0.0,
277
+ attn_drop=0.0,
278
+ drop_path=0.0,
279
+ act_layer=nn.GELU,
280
+ norm_layer=nn.LayerNorm,
281
+ norm_mem=True,
282
+ rope=None,
283
+ ):
284
+ super().__init__()
285
+ self.norm1 = norm_layer(dim)
286
+ self.attn = Attention(
287
+ dim,
288
+ rope=rope,
289
+ num_heads=num_heads,
290
+ qkv_bias=qkv_bias,
291
+ attn_drop=attn_drop,
292
+ proj_drop=drop,
293
+ )
294
+ self.cross_attn = CrossAttention(
295
+ dim,
296
+ rope=rope,
297
+ num_heads=num_heads,
298
+ qkv_bias=qkv_bias,
299
+ attn_drop=attn_drop,
300
+ proj_drop=drop,
301
+ )
302
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
303
+ self.norm2 = norm_layer(dim)
304
+ self.norm3 = norm_layer(dim)
305
+ mlp_hidden_dim = int(dim * mlp_ratio)
306
+ self.mlp = Mlp(
307
+ in_features=dim,
308
+ hidden_features=mlp_hidden_dim,
309
+ act_layer=act_layer,
310
+ drop=drop,
311
+ )
312
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
313
+
314
+ def forward(self, x, y, xpos, ypos):
315
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
316
+ y_ = self.norm_y(y)
317
+ x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
318
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
319
+ return x, y
320
+
321
+
322
+ # patch embedding
323
+ class PositionGetter(object):
324
+ """return positions of patches"""
325
+
326
+ def __init__(self):
327
+ self.cache_positions = {}
328
+
329
+ def __call__(self, b, h, w, device):
330
+ if not (h, w) in self.cache_positions:
331
+ x = torch.arange(w, device=device)
332
+ y = torch.arange(h, device=device)
333
+ self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h, w, 2)
334
+ pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone()
335
+ return pos
336
+
337
+
338
+ class PatchEmbed(nn.Module):
339
+ """just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
340
+
341
+ def __init__(
342
+ self,
343
+ img_size=224,
344
+ patch_size=16,
345
+ in_chans=3,
346
+ embed_dim=768,
347
+ norm_layer=None,
348
+ flatten=True,
349
+ ):
350
+ super().__init__()
351
+ img_size = to_2tuple(img_size)
352
+ patch_size = to_2tuple(patch_size)
353
+ self.img_size = img_size
354
+ self.patch_size = patch_size
355
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
356
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
357
+ self.flatten = flatten
358
+
359
+ self.proj = nn.Conv2d(
360
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
361
+ )
362
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
363
+
364
+ self.position_getter = PositionGetter()
365
+
366
+ def forward(self, x):
367
+ B, C, H, W = x.shape
368
+ torch._assert(
369
+ H == self.img_size[0],
370
+ f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
371
+ )
372
+ torch._assert(
373
+ W == self.img_size[1],
374
+ f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
375
+ )
376
+ x = self.proj(x)
377
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
378
+ if self.flatten:
379
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
380
+ x = self.norm(x)
381
+ return x, pos
382
+
383
+ def _init_weights(self):
384
+ w = self.proj.weight.data
385
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/criterion.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Criterion to train CroCo
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # MAE: https://github.com/facebookresearch/mae
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+
13
+
14
+ class MaskedMSE(torch.nn.Module):
15
+
16
+ def __init__(self, norm_pix_loss=False, masked=True):
17
+ """
18
+ norm_pix_loss: normalize each patch by their pixel mean and variance
19
+ masked: compute loss over the masked patches only
20
+ """
21
+ super().__init__()
22
+ self.norm_pix_loss = norm_pix_loss
23
+ self.masked = masked
24
+
25
+ def forward(self, pred, mask, target):
26
+
27
+ if self.norm_pix_loss:
28
+ mean = target.mean(dim=-1, keepdim=True)
29
+ var = target.var(dim=-1, keepdim=True)
30
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
31
+
32
+ loss = (pred - target) ** 2
33
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
34
+ if self.masked:
35
+ loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches
36
+ else:
37
+ loss = loss.mean() # mean loss
38
+ return loss
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/croco.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # CroCo model during pretraining
7
+ # --------------------------------------------------------
8
+
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
14
+ from functools import partial
15
+
16
+ from models.blocks import Block, DecoderBlock, PatchEmbed
17
+ from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
18
+ from models.masking import RandomMask
19
+
20
+ from transformers import PretrainedConfig
21
+ from transformers import PreTrainedModel
22
+
23
+
24
+ class CrocoConfig(PretrainedConfig):
25
+ model_type = "croco"
26
+
27
+ def __init__(
28
+ self,
29
+ img_size=224, # input image size
30
+ patch_size=16, # patch_size
31
+ mask_ratio=0.9, # ratios of masked tokens
32
+ enc_embed_dim=768, # encoder feature dimension
33
+ enc_depth=12, # encoder depth
34
+ enc_num_heads=12, # encoder number of heads in the transformer block
35
+ dec_embed_dim=512, # decoder feature dimension
36
+ dec_depth=8, # decoder depth
37
+ dec_num_heads=16, # decoder number of heads in the transformer block
38
+ mlp_ratio=4,
39
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
40
+ norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder
41
+ pos_embed="cosine", # positional embedding (either cosine or RoPE100)
42
+ ):
43
+ super().__init__()
44
+ self.img_size = img_size
45
+ self.patch_size = patch_size
46
+ self.mask_ratio = mask_ratio
47
+ self.enc_embed_dim = enc_embed_dim
48
+ self.enc_depth = enc_depth
49
+ self.enc_num_heads = enc_num_heads
50
+ self.dec_embed_dim = dec_embed_dim
51
+ self.dec_depth = dec_depth
52
+ self.dec_num_heads = dec_num_heads
53
+ self.mlp_ratio = mlp_ratio
54
+ self.norm_layer = norm_layer
55
+ self.norm_im2_in_dec = norm_im2_in_dec
56
+ self.pos_embed = pos_embed
57
+
58
+
59
+ class CroCoNet(PreTrainedModel):
60
+
61
+ config_class = CrocoConfig
62
+ base_model_prefix = "croco"
63
+
64
+ def __init__(self, config: CrocoConfig):
65
+
66
+ super().__init__(config)
67
+
68
+ # patch embeddings (with initialization done as in MAE)
69
+ self._set_patch_embed(config.img_size, config.patch_size, config.enc_embed_dim)
70
+
71
+ # mask generations
72
+ self._set_mask_generator(self.patch_embed.num_patches, config.mask_ratio)
73
+
74
+ self.pos_embed = config.pos_embed
75
+ if config.pos_embed == "cosine":
76
+ # positional embedding of the encoder
77
+ enc_pos_embed = get_2d_sincos_pos_embed(
78
+ config.enc_embed_dim,
79
+ int(self.patch_embed.num_patches**0.5),
80
+ n_cls_token=0,
81
+ )
82
+ self.register_buffer(
83
+ "enc_pos_embed", torch.from_numpy(enc_pos_embed).float()
84
+ )
85
+ # positional embedding of the decoder
86
+ dec_pos_embed = get_2d_sincos_pos_embed(
87
+ config.dec_embed_dim,
88
+ int(self.patch_embed.num_patches**0.5),
89
+ n_cls_token=0,
90
+ )
91
+ self.register_buffer(
92
+ "dec_pos_embed", torch.from_numpy(dec_pos_embed).float()
93
+ )
94
+ # pos embedding in each block
95
+ self.rope = None # nothing for cosine
96
+ elif config.pos_embed.startswith("RoPE"): # eg RoPE100
97
+ self.enc_pos_embed = None # nothing to add in the encoder with RoPE
98
+ self.dec_pos_embed = None # nothing to add in the decoder with RoPE
99
+ if RoPE2D is None:
100
+ raise ImportError(
101
+ "Cannot find cuRoPE2D, please install it following the README instructions"
102
+ )
103
+ freq = float(config.pos_embed[len("RoPE") :])
104
+ self.rope = RoPE2D(freq=freq)
105
+ else:
106
+ raise NotImplementedError("Unknown pos_embed " + config.pos_embed)
107
+
108
+ # transformer for the encoder
109
+ self.enc_depth = config.enc_depth
110
+ self.enc_embed_dim = config.enc_embed_dim
111
+ self.enc_blocks = nn.ModuleList(
112
+ [
113
+ Block(
114
+ config.enc_embed_dim,
115
+ config.enc_num_heads,
116
+ config.mlp_ratio,
117
+ qkv_bias=True,
118
+ norm_layer=config.norm_layer,
119
+ rope=self.rope,
120
+ )
121
+ for i in range(config.enc_depth)
122
+ ]
123
+ )
124
+ self.enc_norm = config.norm_layer(config.enc_embed_dim)
125
+
126
+ # masked tokens
127
+ # self._set_mask_token(config.dec_embed_dim)
128
+ self.mask_token = None
129
+
130
+ # decoder
131
+ self._set_decoder(
132
+ config.enc_embed_dim,
133
+ config.dec_embed_dim,
134
+ config.dec_num_heads,
135
+ config.dec_depth,
136
+ config.mlp_ratio,
137
+ config.norm_layer,
138
+ config.norm_im2_in_dec,
139
+ )
140
+
141
+ # prediction head
142
+ self._set_prediction_head(config.dec_embed_dim, config.patch_size)
143
+
144
+ # initializer weights
145
+ self.initialize_weights()
146
+
147
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
148
+ self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)
149
+
150
+ def _set_mask_generator(self, num_patches, mask_ratio):
151
+ self.mask_generator = RandomMask(num_patches, mask_ratio)
152
+
153
+ def _set_mask_token(self, dec_embed_dim):
154
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))
155
+
156
+ def _set_decoder(
157
+ self,
158
+ enc_embed_dim,
159
+ dec_embed_dim,
160
+ dec_num_heads,
161
+ dec_depth,
162
+ mlp_ratio,
163
+ norm_layer,
164
+ norm_im2_in_dec,
165
+ ):
166
+ self.dec_depth = dec_depth
167
+ self.dec_embed_dim = dec_embed_dim
168
+ # transfer from encoder to decoder
169
+ self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
170
+ # transformer for the decoder
171
+ self.dec_blocks = nn.ModuleList(
172
+ [
173
+ DecoderBlock(
174
+ dec_embed_dim,
175
+ dec_num_heads,
176
+ mlp_ratio=mlp_ratio,
177
+ qkv_bias=True,
178
+ norm_layer=norm_layer,
179
+ norm_mem=norm_im2_in_dec,
180
+ rope=self.rope,
181
+ )
182
+ for i in range(dec_depth)
183
+ ]
184
+ )
185
+ # final norm layer
186
+ self.dec_norm = norm_layer(dec_embed_dim)
187
+
188
+ def _set_prediction_head(self, dec_embed_dim, patch_size):
189
+ self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)
190
+
191
+ def initialize_weights(self):
192
+ # patch embed
193
+ self.patch_embed._init_weights()
194
+ # mask tokens
195
+ if self.mask_token is not None:
196
+ torch.nn.init.normal_(self.mask_token, std=0.02)
197
+ # linears and layer norms
198
+ self.apply(self._init_weights)
199
+
200
+ def _init_weights(self, m):
201
+ if isinstance(m, nn.Linear):
202
+ # we use xavier_uniform following official JAX ViT:
203
+ torch.nn.init.xavier_uniform_(m.weight)
204
+ if isinstance(m, nn.Linear) and m.bias is not None:
205
+ nn.init.constant_(m.bias, 0)
206
+ elif isinstance(m, nn.LayerNorm):
207
+ nn.init.constant_(m.bias, 0)
208
+ nn.init.constant_(m.weight, 1.0)
209
+
210
+ def _encode_image(self, image, do_mask=False, return_all_blocks=False):
211
+ """
212
+ image has B x 3 x img_size x img_size
213
+ do_mask: whether to perform masking or not
214
+ return_all_blocks: if True, return the features at the end of every block
215
+ instead of just the features from the last block (eg for some prediction heads)
216
+ """
217
+ # embed the image into patches (x has size B x Npatches x C)
218
+ # and get position if each return patch (pos has size B x Npatches x 2)
219
+ x, pos = self.patch_embed(image)
220
+ # add positional embedding without cls token
221
+ if self.enc_pos_embed is not None:
222
+ x = x + self.enc_pos_embed[None, ...]
223
+ # apply masking
224
+ B, N, C = x.size()
225
+ if do_mask:
226
+ masks = self.mask_generator(x)
227
+ x = x[~masks].view(B, -1, C)
228
+ posvis = pos[~masks].view(B, -1, 2)
229
+ else:
230
+ B, N, C = x.size()
231
+ masks = torch.zeros((B, N), dtype=bool)
232
+ posvis = pos
233
+ # now apply the transformer encoder and normalization
234
+ if return_all_blocks:
235
+ out = []
236
+ for blk in self.enc_blocks:
237
+ x = blk(x, posvis)
238
+ out.append(x)
239
+ out[-1] = self.enc_norm(out[-1])
240
+ return out, pos, masks
241
+ else:
242
+ for blk in self.enc_blocks:
243
+ x = blk(x, posvis)
244
+ x = self.enc_norm(x)
245
+ return x, pos, masks
246
+
247
+ def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
248
+ """
249
+ return_all_blocks: if True, return the features at the end of every block
250
+ instead of just the features from the last block (eg for some prediction heads)
251
+
252
+ masks1 can be None => assume image1 fully visible
253
+ """
254
+ # encoder to decoder layer
255
+ visf1 = self.decoder_embed(feat1)
256
+ f2 = self.decoder_embed(feat2)
257
+ # append masked tokens to the sequence
258
+ B, Nenc, C = visf1.size()
259
+ if masks1 is None: # downstreams
260
+ f1_ = visf1
261
+ else: # pretraining
262
+ Ntotal = masks1.size(1)
263
+ f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
264
+ f1_[~masks1] = visf1.view(B * Nenc, C)
265
+ # add positional embedding
266
+ if self.dec_pos_embed is not None:
267
+ f1_ = f1_ + self.dec_pos_embed
268
+ f2 = f2 + self.dec_pos_embed
269
+ # apply Transformer blocks
270
+ out = f1_
271
+ out2 = f2
272
+ if return_all_blocks:
273
+ _out, out = out, []
274
+ for blk in self.dec_blocks:
275
+ _out, out2 = blk(_out, out2, pos1, pos2)
276
+ out.append(_out)
277
+ out[-1] = self.dec_norm(out[-1])
278
+ else:
279
+ for blk in self.dec_blocks:
280
+ out, out2 = blk(out, out2, pos1, pos2)
281
+ out = self.dec_norm(out)
282
+ return out
283
+
284
+ def patchify(self, imgs):
285
+ """
286
+ imgs: (B, 3, H, W)
287
+ x: (B, L, patch_size**2 *3)
288
+ """
289
+ p = self.patch_embed.patch_size[0]
290
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
291
+
292
+ h = w = imgs.shape[2] // p
293
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
294
+ x = torch.einsum("nchpwq->nhwpqc", x)
295
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
296
+
297
+ return x
298
+
299
+ def unpatchify(self, x, channels=3):
300
+ """
301
+ x: (N, L, patch_size**2 *channels)
302
+ imgs: (N, 3, H, W)
303
+ """
304
+ patch_size = self.patch_embed.patch_size[0]
305
+ h = w = int(x.shape[1] ** 0.5)
306
+ assert h * w == x.shape[1]
307
+ x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
308
+ x = torch.einsum("nhwpqc->nchpwq", x)
309
+ imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size))
310
+ return imgs
311
+
312
+ # def forward(self, img1, img2):
313
+ # """
314
+ # img1: tensor of size B x 3 x img_size x img_size
315
+ # img2: tensor of size B x 3 x img_size x img_size
316
+
317
+ # out will be B x N x (3*patch_size*patch_size)
318
+ # masks are also returned as B x N just in case
319
+ # """
320
+ # # encoder of the masked first image
321
+ # feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
322
+ # # encoder of the second image
323
+ # feat2, pos2, _ = self._encode_image(img2, do_mask=False)
324
+ # # decoder
325
+ # decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
326
+ # # prediction head
327
+ # out = self.prediction_head(decfeat)
328
+ # # get target
329
+ # target = self.patchify(img1)
330
+ # return out, mask1, target
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/croco_downstream.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # CroCo model for downstream tasks
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+
10
+ from .croco import CroCoNet
11
+
12
+
13
+ def croco_args_from_ckpt(ckpt):
14
+ if "croco_kwargs" in ckpt: # CroCo v2 released models
15
+ return ckpt["croco_kwargs"]
16
+ elif "args" in ckpt and hasattr(
17
+ ckpt["args"], "model"
18
+ ): # pretrained using the official code release
19
+ s = ckpt[
20
+ "args"
21
+ ].model # eg "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"
22
+ assert s.startswith("CroCoNet(")
23
+ return eval(
24
+ "dict" + s[len("CroCoNet") :]
25
+ ) # transform it into the string of a dictionary and evaluate it
26
+ else: # CroCo v1 released models
27
+ return dict()
28
+
29
+
30
+ class CroCoDownstreamMonocularEncoder(CroCoNet):
31
+
32
+ def __init__(self, head, **kwargs):
33
+ """Build network for monocular downstream task, only using the encoder.
34
+ It takes an extra argument head, that is called with the features
35
+ and a dictionary img_info containing 'width' and 'height' keys
36
+ The head is setup with the croconet arguments in this init function
37
+ NOTE: It works by *calling super().__init__() but with redefined setters
38
+
39
+ """
40
+ super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs)
41
+ head.setup(self)
42
+ self.head = head
43
+
44
+ def _set_mask_generator(self, *args, **kwargs):
45
+ """No mask generator"""
46
+ return
47
+
48
+ def _set_mask_token(self, *args, **kwargs):
49
+ """No mask token"""
50
+ self.mask_token = None
51
+ return
52
+
53
+ def _set_decoder(self, *args, **kwargs):
54
+ """No decoder"""
55
+ return
56
+
57
+ def _set_prediction_head(self, *args, **kwargs):
58
+ """No 'prediction head' for downstream tasks."""
59
+ return
60
+
61
+ def forward(self, img):
62
+ """
63
+ img if of size batch_size x 3 x h x w
64
+ """
65
+ B, C, H, W = img.size()
66
+ img_info = {"height": H, "width": W}
67
+ need_all_layers = (
68
+ hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks
69
+ )
70
+ out, _, _ = self._encode_image(
71
+ img, do_mask=False, return_all_blocks=need_all_layers
72
+ )
73
+ return self.head(out, img_info)
74
+
75
+
76
+ class CroCoDownstreamBinocular(CroCoNet):
77
+
78
+ def __init__(self, head, **kwargs):
79
+ """Build network for binocular downstream task
80
+ It takes an extra argument head, that is called with the features
81
+ and a dictionary img_info containing 'width' and 'height' keys
82
+ The head is setup with the croconet arguments in this init function
83
+ """
84
+ super(CroCoDownstreamBinocular, self).__init__(**kwargs)
85
+ head.setup(self)
86
+ self.head = head
87
+
88
+ def _set_mask_generator(self, *args, **kwargs):
89
+ """No mask generator"""
90
+ return
91
+
92
+ def _set_mask_token(self, *args, **kwargs):
93
+ """No mask token"""
94
+ self.mask_token = None
95
+ return
96
+
97
+ def _set_prediction_head(self, *args, **kwargs):
98
+ """No prediction head for downstream tasks, define your own head"""
99
+ return
100
+
101
+ def encode_image_pairs(self, img1, img2, return_all_blocks=False):
102
+ """run encoder for a pair of images
103
+ it is actually ~5% faster to concatenate the images along the batch dimension
104
+ than to encode them separately
105
+ """
106
+ ## the two commented lines below is the naive version with separate encoding
107
+ # out, pos, _ = self._encode_image(img1, do_mask=False, return_all_blocks=return_all_blocks)
108
+ # out2, pos2, _ = self._encode_image(img2, do_mask=False, return_all_blocks=False)
109
+ ## and now the faster version
110
+ out, pos, _ = self._encode_image(
111
+ torch.cat((img1, img2), dim=0),
112
+ do_mask=False,
113
+ return_all_blocks=return_all_blocks,
114
+ )
115
+ if return_all_blocks:
116
+ out, out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out])))
117
+ out2 = out2[-1]
118
+ else:
119
+ out, out2 = out.chunk(2, dim=0)
120
+ pos, pos2 = pos.chunk(2, dim=0)
121
+ return out, out2, pos, pos2
122
+
123
+ def forward(self, img1, img2):
124
+ B, C, H, W = img1.size()
125
+ img_info = {"height": H, "width": W}
126
+ return_all_blocks = (
127
+ hasattr(self.head, "return_all_blocks") and self.head.return_all_blocks
128
+ )
129
+ out, out2, pos, pos2 = self.encode_image_pairs(
130
+ img1, img2, return_all_blocks=return_all_blocks
131
+ )
132
+ if return_all_blocks:
133
+ decout = self._decoder(
134
+ out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks
135
+ )
136
+ decout = out + decout
137
+ else:
138
+ decout = self._decoder(
139
+ out, pos, None, out2, pos2, return_all_blocks=return_all_blocks
140
+ )
141
+ return self.head(decout, img_info)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/curope/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ from .curope2d import cuRoPE2D
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/curope/curope.cpp ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ Copyright (C) 2022-present Naver Corporation. All rights reserved.
3
+ Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ */
5
+
6
+ #include <torch/extension.h>
7
+
8
+ // forward declaration
9
+ void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd );
10
+
11
+ void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd )
12
+ {
13
+ const int B = tokens.size(0);
14
+ const int N = tokens.size(1);
15
+ const int H = tokens.size(2);
16
+ const int D = tokens.size(3) / 4;
17
+
18
+ auto tok = tokens.accessor<float, 4>();
19
+ auto pos = positions.accessor<int64_t, 3>();
20
+
21
+ for (int b = 0; b < B; b++) {
22
+ for (int x = 0; x < 2; x++) { // y and then x (2d)
23
+ for (int n = 0; n < N; n++) {
24
+
25
+ // grab the token position
26
+ const int p = pos[b][n][x];
27
+
28
+ for (int h = 0; h < H; h++) {
29
+ for (int d = 0; d < D; d++) {
30
+ // grab the two values
31
+ float u = tok[b][n][h][d+0+x*2*D];
32
+ float v = tok[b][n][h][d+D+x*2*D];
33
+
34
+ // grab the cos,sin
35
+ const float inv_freq = fwd * p / powf(base, d/float(D));
36
+ float c = cosf(inv_freq);
37
+ float s = sinf(inv_freq);
38
+
39
+ // write the result
40
+ tok[b][n][h][d+0+x*2*D] = u*c - v*s;
41
+ tok[b][n][h][d+D+x*2*D] = v*c + u*s;
42
+ }
43
+ }
44
+ }
45
+ }
46
+ }
47
+ }
48
+
49
+ void rope_2d( torch::Tensor tokens, // B,N,H,D
50
+ const torch::Tensor positions, // B,N,2
51
+ const float base,
52
+ const float fwd )
53
+ {
54
+ TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions");
55
+ TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions");
56
+ TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions");
57
+ TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions");
58
+ TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2");
59
+ TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" );
60
+
61
+ if (tokens.is_cuda())
62
+ rope_2d_cuda( tokens, positions, base, fwd );
63
+ else
64
+ rope_2d_cpu( tokens, positions, base, fwd );
65
+ }
66
+
67
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
68
+ m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward");
69
+ }
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/curope/curope2d.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+
6
+ try:
7
+ import curope as _kernels # run `python setup.py install`
8
+ except ModuleNotFoundError:
9
+ from . import curope as _kernels # run `python setup.py build_ext --inplace`
10
+
11
+
12
+ class cuRoPE2D_func(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ def forward(ctx, tokens, positions, base, F0=1):
16
+ ctx.save_for_backward(positions)
17
+ ctx.saved_base = base
18
+ ctx.saved_F0 = F0
19
+ # tokens = tokens.clone() # uncomment this if inplace doesn't work
20
+ _kernels.rope_2d(tokens, positions, base, F0)
21
+ ctx.mark_dirty(tokens)
22
+ return tokens
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad_res):
26
+ positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0
27
+ _kernels.rope_2d(grad_res, positions, base, -F0)
28
+ ctx.mark_dirty(grad_res)
29
+ return grad_res, None, None, None
30
+
31
+
32
+ class cuRoPE2D(torch.nn.Module):
33
+ def __init__(self, freq=100.0, F0=1.0):
34
+ super().__init__()
35
+ self.base = freq
36
+ self.F0 = F0
37
+
38
+ def forward(self, tokens, positions):
39
+ cuRoPE2D_func.apply(tokens.transpose(1, 2), positions, self.base, self.F0)
40
+ return tokens
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/curope/kernels.cu ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ Copyright (C) 2022-present Naver Corporation. All rights reserved.
3
+ Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ */
5
+
6
+ #include <torch/extension.h>
7
+ #include <cuda.h>
8
+ #include <cuda_runtime.h>
9
+ #include <vector>
10
+
11
+ #define CHECK_CUDA(tensor) {\
12
+ TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
13
+ TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
14
+ void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}
15
+
16
+
17
+ template < typename scalar_t >
18
+ __global__ void rope_2d_cuda_kernel(
19
+ //scalar_t* __restrict__ tokens,
20
+ torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> tokens,
21
+ const int64_t* __restrict__ pos,
22
+ const float base,
23
+ const float fwd )
24
+ // const int N, const int H, const int D )
25
+ {
26
+ // tokens shape = (B, N, H, D)
27
+ const int N = tokens.size(1);
28
+ const int H = tokens.size(2);
29
+ const int D = tokens.size(3);
30
+
31
+ // each block update a single token, for all heads
32
+ // each thread takes care of a single output
33
+ extern __shared__ float shared[];
34
+ float* shared_inv_freq = shared + D;
35
+
36
+ const int b = blockIdx.x / N;
37
+ const int n = blockIdx.x % N;
38
+
39
+ const int Q = D / 4;
40
+ // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D]
41
+ // u_Y v_Y u_X v_X
42
+
43
+ // shared memory: first, compute inv_freq
44
+ if (threadIdx.x < Q)
45
+ shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q));
46
+ __syncthreads();
47
+
48
+ // start of X or Y part
49
+ const int X = threadIdx.x < D/2 ? 0 : 1;
50
+ const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X
51
+
52
+ // grab the cos,sin appropriate for me
53
+ const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q];
54
+ const float cos = cosf(freq);
55
+ const float sin = sinf(freq);
56
+ /*
57
+ float* shared_cos_sin = shared + D + D/4;
58
+ if ((threadIdx.x % (D/2)) < Q)
59
+ shared_cos_sin[m+0] = cosf(freq);
60
+ else
61
+ shared_cos_sin[m+Q] = sinf(freq);
62
+ __syncthreads();
63
+ const float cos = shared_cos_sin[m+0];
64
+ const float sin = shared_cos_sin[m+Q];
65
+ */
66
+
67
+ for (int h = 0; h < H; h++)
68
+ {
69
+ // then, load all the token for this head in shared memory
70
+ shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
71
+ __syncthreads();
72
+
73
+ const float u = shared[m];
74
+ const float v = shared[m+Q];
75
+
76
+ // write output
77
+ if ((threadIdx.x % (D/2)) < Q)
78
+ tokens[b][n][h][threadIdx.x] = u*cos - v*sin;
79
+ else
80
+ tokens[b][n][h][threadIdx.x] = v*cos + u*sin;
81
+ }
82
+ }
83
+
84
+ void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd )
85
+ {
86
+ const int B = tokens.size(0); // batch size
87
+ const int N = tokens.size(1); // sequence length
88
+ const int H = tokens.size(2); // number of heads
89
+ const int D = tokens.size(3); // dimension per head
90
+
91
+ TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous");
92
+ TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous");
93
+ TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape");
94
+ TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4");
95
+
96
+ // one block for each layer, one thread per local-max
97
+ const int THREADS_PER_BLOCK = D;
98
+ const int N_BLOCKS = B * N; // each block takes care of H*D values
99
+ const int SHARED_MEM = sizeof(float) * (D + D/4);
100
+
101
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] {
102
+ rope_2d_cuda_kernel<scalar_t> <<<N_BLOCKS, THREADS_PER_BLOCK, SHARED_MEM>>> (
103
+ //tokens.data_ptr<scalar_t>(),
104
+ tokens.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
105
+ pos.data_ptr<int64_t>(),
106
+ base, fwd); //, N, H, D );
107
+ }));
108
+ }
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/curope/setup.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ from setuptools import setup
5
+ from torch import cuda
6
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7
+
8
+ # compile for all possible CUDA architectures
9
+ all_cuda_archs = cuda.get_gencode_flags().replace("compute=", "arch=").split()
10
+ # alternatively, you can list cuda archs that you want, eg:
11
+ # all_cuda_archs = [
12
+ # '-gencode', 'arch=compute_70,code=sm_70',
13
+ # '-gencode', 'arch=compute_75,code=sm_75',
14
+ # '-gencode', 'arch=compute_80,code=sm_80',
15
+ # '-gencode', 'arch=compute_86,code=sm_86'
16
+ # ]
17
+
18
+ setup(
19
+ name="curope",
20
+ ext_modules=[
21
+ CUDAExtension(
22
+ name="curope",
23
+ sources=[
24
+ "curope.cpp",
25
+ "kernels.cu",
26
+ ],
27
+ extra_compile_args=dict(
28
+ nvcc=["-O3", "--ptxas-options=-v", "--use_fast_math"] + all_cuda_archs,
29
+ cxx=["-O3"],
30
+ ),
31
+ )
32
+ ],
33
+ cmdclass={"build_ext": BuildExtension},
34
+ )
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/dpt_block.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # DPT head for ViTs
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # https://github.com/isl-org/DPT
9
+ # https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange, repeat
15
+ from typing import Union, Tuple, Iterable, List, Optional, Dict
16
+
17
+
18
+ def pair(t):
19
+ return t if isinstance(t, tuple) else (t, t)
20
+
21
+
22
+ def make_scratch(in_shape, out_shape, groups=1, expand=False):
23
+ scratch = nn.Module()
24
+
25
+ out_shape1 = out_shape
26
+ out_shape2 = out_shape
27
+ out_shape3 = out_shape
28
+ out_shape4 = out_shape
29
+ if expand == True:
30
+ out_shape1 = out_shape
31
+ out_shape2 = out_shape * 2
32
+ out_shape3 = out_shape * 4
33
+ out_shape4 = out_shape * 8
34
+
35
+ scratch.layer1_rn = nn.Conv2d(
36
+ in_shape[0],
37
+ out_shape1,
38
+ kernel_size=3,
39
+ stride=1,
40
+ padding=1,
41
+ bias=False,
42
+ groups=groups,
43
+ )
44
+ scratch.layer2_rn = nn.Conv2d(
45
+ in_shape[1],
46
+ out_shape2,
47
+ kernel_size=3,
48
+ stride=1,
49
+ padding=1,
50
+ bias=False,
51
+ groups=groups,
52
+ )
53
+ scratch.layer3_rn = nn.Conv2d(
54
+ in_shape[2],
55
+ out_shape3,
56
+ kernel_size=3,
57
+ stride=1,
58
+ padding=1,
59
+ bias=False,
60
+ groups=groups,
61
+ )
62
+ scratch.layer4_rn = nn.Conv2d(
63
+ in_shape[3],
64
+ out_shape4,
65
+ kernel_size=3,
66
+ stride=1,
67
+ padding=1,
68
+ bias=False,
69
+ groups=groups,
70
+ )
71
+
72
+ scratch.layer_rn = nn.ModuleList(
73
+ [
74
+ scratch.layer1_rn,
75
+ scratch.layer2_rn,
76
+ scratch.layer3_rn,
77
+ scratch.layer4_rn,
78
+ ]
79
+ )
80
+
81
+ return scratch
82
+
83
+
84
+ class ResidualConvUnit_custom(nn.Module):
85
+ """Residual convolution module."""
86
+
87
+ def __init__(self, features, activation, bn):
88
+ """Init.
89
+ Args:
90
+ features (int): number of features
91
+ """
92
+ super().__init__()
93
+
94
+ self.bn = bn
95
+
96
+ self.groups = 1
97
+
98
+ self.conv1 = nn.Conv2d(
99
+ features,
100
+ features,
101
+ kernel_size=3,
102
+ stride=1,
103
+ padding=1,
104
+ bias=not self.bn,
105
+ groups=self.groups,
106
+ )
107
+
108
+ self.conv2 = nn.Conv2d(
109
+ features,
110
+ features,
111
+ kernel_size=3,
112
+ stride=1,
113
+ padding=1,
114
+ bias=not self.bn,
115
+ groups=self.groups,
116
+ )
117
+
118
+ if self.bn == True:
119
+ self.bn1 = nn.BatchNorm2d(features)
120
+ self.bn2 = nn.BatchNorm2d(features)
121
+
122
+ self.activation = activation
123
+
124
+ self.skip_add = nn.quantized.FloatFunctional()
125
+
126
+ def forward(self, x):
127
+ """Forward pass.
128
+ Args:
129
+ x (tensor): input
130
+ Returns:
131
+ tensor: output
132
+ """
133
+
134
+ out = self.activation(x)
135
+ out = self.conv1(out)
136
+ if self.bn == True:
137
+ out = self.bn1(out)
138
+
139
+ out = self.activation(out)
140
+ out = self.conv2(out)
141
+ if self.bn == True:
142
+ out = self.bn2(out)
143
+
144
+ if self.groups > 1:
145
+ out = self.conv_merge(out)
146
+
147
+ return self.skip_add.add(out, x)
148
+
149
+
150
+ class FeatureFusionBlock_custom(nn.Module):
151
+ """Feature fusion block."""
152
+
153
+ def __init__(
154
+ self,
155
+ features,
156
+ activation,
157
+ deconv=False,
158
+ bn=False,
159
+ expand=False,
160
+ align_corners=True,
161
+ width_ratio=1,
162
+ ):
163
+ """Init.
164
+ Args:
165
+ features (int): number of features
166
+ """
167
+ super(FeatureFusionBlock_custom, self).__init__()
168
+ self.width_ratio = width_ratio
169
+
170
+ self.deconv = deconv
171
+ self.align_corners = align_corners
172
+
173
+ self.groups = 1
174
+
175
+ self.expand = expand
176
+ out_features = features
177
+ if self.expand == True:
178
+ out_features = features // 2
179
+
180
+ self.out_conv = nn.Conv2d(
181
+ features,
182
+ out_features,
183
+ kernel_size=1,
184
+ stride=1,
185
+ padding=0,
186
+ bias=True,
187
+ groups=1,
188
+ )
189
+
190
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
191
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
192
+
193
+ self.skip_add = nn.quantized.FloatFunctional()
194
+
195
+ def forward(self, *xs):
196
+ """Forward pass.
197
+ Returns:
198
+ tensor: output
199
+ """
200
+ output = xs[0]
201
+
202
+ if len(xs) == 2:
203
+ res = self.resConfUnit1(xs[1])
204
+ if self.width_ratio != 1:
205
+ res = F.interpolate(
206
+ res, size=(output.shape[2], output.shape[3]), mode="bilinear"
207
+ )
208
+
209
+ output = self.skip_add.add(output, res)
210
+ # output += res
211
+
212
+ output = self.resConfUnit2(output)
213
+
214
+ if self.width_ratio != 1:
215
+ # and output.shape[3] < self.width_ratio * output.shape[2]
216
+ # size=(image.shape[])
217
+ if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
218
+ shape = 3 * output.shape[3]
219
+ else:
220
+ shape = int(self.width_ratio * 2 * output.shape[2])
221
+ output = F.interpolate(
222
+ output, size=(2 * output.shape[2], shape), mode="bilinear"
223
+ )
224
+ else:
225
+ output = nn.functional.interpolate(
226
+ output,
227
+ scale_factor=2,
228
+ mode="bilinear",
229
+ align_corners=self.align_corners,
230
+ )
231
+ output = self.out_conv(output)
232
+ return output
233
+
234
+
235
+ def make_fusion_block(features, use_bn, width_ratio=1):
236
+ return FeatureFusionBlock_custom(
237
+ features,
238
+ nn.ReLU(False),
239
+ deconv=False,
240
+ bn=use_bn,
241
+ expand=False,
242
+ align_corners=True,
243
+ width_ratio=width_ratio,
244
+ )
245
+
246
+
247
+ class Interpolate(nn.Module):
248
+ """Interpolation module."""
249
+
250
+ def __init__(self, scale_factor, mode, align_corners=False):
251
+ """Init.
252
+ Args:
253
+ scale_factor (float): scaling
254
+ mode (str): interpolation mode
255
+ """
256
+ super(Interpolate, self).__init__()
257
+
258
+ self.interp = nn.functional.interpolate
259
+ self.scale_factor = scale_factor
260
+ self.mode = mode
261
+ self.align_corners = align_corners
262
+
263
+ def forward(self, x):
264
+ """Forward pass.
265
+ Args:
266
+ x (tensor): input
267
+ Returns:
268
+ tensor: interpolated data
269
+ """
270
+
271
+ x = self.interp(
272
+ x,
273
+ scale_factor=self.scale_factor,
274
+ mode=self.mode,
275
+ align_corners=self.align_corners,
276
+ )
277
+
278
+ return x
279
+
280
+
281
+ class DPTOutputAdapter(nn.Module):
282
+ """DPT output adapter.
283
+
284
+ :param num_cahnnels: Number of output channels
285
+ :param stride_level: tride level compared to the full-sized image.
286
+ E.g. 4 for 1/4th the size of the image.
287
+ :param patch_size_full: Int or tuple of the patch size over the full image size.
288
+ Patch size for smaller inputs will be computed accordingly.
289
+ :param hooks: Index of intermediate layers
290
+ :param layer_dims: Dimension of intermediate layers
291
+ :param feature_dim: Feature dimension
292
+ :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
293
+ :param use_bn: If set to True, activates batch norm
294
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
295
+ """
296
+
297
+ def __init__(
298
+ self,
299
+ num_channels: int = 1,
300
+ stride_level: int = 1,
301
+ patch_size: Union[int, Tuple[int, int]] = 16,
302
+ main_tasks: Iterable[str] = ("rgb",),
303
+ hooks: List[int] = [2, 5, 8, 11],
304
+ layer_dims: List[int] = [96, 192, 384, 768],
305
+ feature_dim: int = 256,
306
+ last_dim: int = 32,
307
+ use_bn: bool = False,
308
+ dim_tokens_enc: Optional[int] = None,
309
+ head_type: str = "regression",
310
+ output_width_ratio=1,
311
+ **kwargs
312
+ ):
313
+ super().__init__()
314
+ self.num_channels = num_channels
315
+ self.stride_level = stride_level
316
+ self.patch_size = pair(patch_size)
317
+ self.main_tasks = main_tasks
318
+ self.hooks = hooks
319
+ self.layer_dims = layer_dims
320
+ self.feature_dim = feature_dim
321
+ self.dim_tokens_enc = (
322
+ dim_tokens_enc * len(self.main_tasks)
323
+ if dim_tokens_enc is not None
324
+ else None
325
+ )
326
+ self.head_type = head_type
327
+
328
+ # Actual patch height and width, taking into account stride of input
329
+ self.P_H = max(1, self.patch_size[0] // stride_level)
330
+ self.P_W = max(1, self.patch_size[1] // stride_level)
331
+
332
+ self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
333
+
334
+ self.scratch.refinenet1 = make_fusion_block(
335
+ feature_dim, use_bn, output_width_ratio
336
+ )
337
+ self.scratch.refinenet2 = make_fusion_block(
338
+ feature_dim, use_bn, output_width_ratio
339
+ )
340
+ self.scratch.refinenet3 = make_fusion_block(
341
+ feature_dim, use_bn, output_width_ratio
342
+ )
343
+ self.scratch.refinenet4 = make_fusion_block(
344
+ feature_dim, use_bn, output_width_ratio
345
+ )
346
+
347
+ if self.head_type == "regression":
348
+ # The "DPTDepthModel" head
349
+ self.head = nn.Sequential(
350
+ nn.Conv2d(
351
+ feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1
352
+ ),
353
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
354
+ nn.Conv2d(
355
+ feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1
356
+ ),
357
+ nn.ReLU(True),
358
+ nn.Conv2d(
359
+ last_dim, self.num_channels, kernel_size=1, stride=1, padding=0
360
+ ),
361
+ )
362
+ elif self.head_type == "semseg":
363
+ # The "DPTSegmentationModel" head
364
+ self.head = nn.Sequential(
365
+ nn.Conv2d(
366
+ feature_dim, feature_dim, kernel_size=3, padding=1, bias=False
367
+ ),
368
+ nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
369
+ nn.ReLU(True),
370
+ nn.Dropout(0.1, False),
371
+ nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
372
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
373
+ )
374
+ else:
375
+ raise ValueError('DPT head_type must be "regression" or "semseg".')
376
+
377
+ if self.dim_tokens_enc is not None:
378
+ self.init(dim_tokens_enc=dim_tokens_enc)
379
+
380
+ def init(self, dim_tokens_enc=768):
381
+ """
382
+ Initialize parts of decoder that are dependent on dimension of encoder tokens.
383
+ Should be called when setting up MultiMAE.
384
+
385
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
386
+ """
387
+ # print(dim_tokens_enc)
388
+
389
+ # Set up activation postprocessing layers
390
+ if isinstance(dim_tokens_enc, int):
391
+ dim_tokens_enc = 4 * [dim_tokens_enc]
392
+
393
+ self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
394
+
395
+ self.act_1_postprocess = nn.Sequential(
396
+ nn.Conv2d(
397
+ in_channels=self.dim_tokens_enc[0],
398
+ out_channels=self.layer_dims[0],
399
+ kernel_size=1,
400
+ stride=1,
401
+ padding=0,
402
+ ),
403
+ nn.ConvTranspose2d(
404
+ in_channels=self.layer_dims[0],
405
+ out_channels=self.layer_dims[0],
406
+ kernel_size=4,
407
+ stride=4,
408
+ padding=0,
409
+ bias=True,
410
+ dilation=1,
411
+ groups=1,
412
+ ),
413
+ )
414
+
415
+ self.act_2_postprocess = nn.Sequential(
416
+ nn.Conv2d(
417
+ in_channels=self.dim_tokens_enc[1],
418
+ out_channels=self.layer_dims[1],
419
+ kernel_size=1,
420
+ stride=1,
421
+ padding=0,
422
+ ),
423
+ nn.ConvTranspose2d(
424
+ in_channels=self.layer_dims[1],
425
+ out_channels=self.layer_dims[1],
426
+ kernel_size=2,
427
+ stride=2,
428
+ padding=0,
429
+ bias=True,
430
+ dilation=1,
431
+ groups=1,
432
+ ),
433
+ )
434
+
435
+ self.act_3_postprocess = nn.Sequential(
436
+ nn.Conv2d(
437
+ in_channels=self.dim_tokens_enc[2],
438
+ out_channels=self.layer_dims[2],
439
+ kernel_size=1,
440
+ stride=1,
441
+ padding=0,
442
+ )
443
+ )
444
+
445
+ self.act_4_postprocess = nn.Sequential(
446
+ nn.Conv2d(
447
+ in_channels=self.dim_tokens_enc[3],
448
+ out_channels=self.layer_dims[3],
449
+ kernel_size=1,
450
+ stride=1,
451
+ padding=0,
452
+ ),
453
+ nn.Conv2d(
454
+ in_channels=self.layer_dims[3],
455
+ out_channels=self.layer_dims[3],
456
+ kernel_size=3,
457
+ stride=2,
458
+ padding=1,
459
+ ),
460
+ )
461
+
462
+ self.act_postprocess = nn.ModuleList(
463
+ [
464
+ self.act_1_postprocess,
465
+ self.act_2_postprocess,
466
+ self.act_3_postprocess,
467
+ self.act_4_postprocess,
468
+ ]
469
+ )
470
+
471
+ def adapt_tokens(self, encoder_tokens):
472
+ # Adapt tokens
473
+ x = []
474
+ x.append(encoder_tokens[:, :])
475
+ x = torch.cat(x, dim=-1)
476
+ return x
477
+
478
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size):
479
+ # input_info: Dict):
480
+ assert (
481
+ self.dim_tokens_enc is not None
482
+ ), "Need to call init(dim_tokens_enc) function first"
483
+ H, W = image_size
484
+
485
+ # Number of patches in height and width
486
+ N_H = H // (self.stride_level * self.P_H)
487
+ N_W = W // (self.stride_level * self.P_W)
488
+
489
+ # Hook decoder onto 4 layers from specified ViT layers
490
+ layers = [encoder_tokens[hook] for hook in self.hooks]
491
+
492
+ # Extract only task-relevant tokens and ignore global tokens.
493
+ layers = [self.adapt_tokens(l) for l in layers]
494
+
495
+ # Reshape tokens to spatial representation
496
+ layers = [
497
+ rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers
498
+ ]
499
+
500
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
501
+ # Project layers to chosen feature dim
502
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
503
+
504
+ # Fuse layers using refinement stages
505
+ path_4 = self.scratch.refinenet4(layers[3])
506
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
507
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
508
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
509
+
510
+ # Output head
511
+ out = self.head(path_1)
512
+
513
+ return out
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/head_downstream.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Heads for downstream tasks
6
+ # --------------------------------------------------------
7
+
8
+ """
9
+ A head is a module where the __init__ defines only the head hyperparameters.
10
+ A method setup(croconet) takes a CroCoNet and set all layers according to the head and croconet attributes.
11
+ The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height'
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from .dpt_block import DPTOutputAdapter
17
+
18
+
19
+ class PixelwiseTaskWithDPT(nn.Module):
20
+ """DPT module for CroCo.
21
+ by default, hooks_idx will be equal to:
22
+ * for encoder-only: 4 equally spread layers
23
+ * for encoder+decoder: last encoder + 3 equally spread layers of the decoder
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ *,
29
+ hooks_idx=None,
30
+ layer_dims=[96, 192, 384, 768],
31
+ output_width_ratio=1,
32
+ num_channels=1,
33
+ postprocess=None,
34
+ **kwargs,
35
+ ):
36
+ super(PixelwiseTaskWithDPT, self).__init__()
37
+ self.return_all_blocks = True # backbone needs to return all layers
38
+ self.postprocess = postprocess
39
+ self.output_width_ratio = output_width_ratio
40
+ self.num_channels = num_channels
41
+ self.hooks_idx = hooks_idx
42
+ self.layer_dims = layer_dims
43
+
44
+ def setup(self, croconet):
45
+ dpt_args = {
46
+ "output_width_ratio": self.output_width_ratio,
47
+ "num_channels": self.num_channels,
48
+ }
49
+ if self.hooks_idx is None:
50
+ if hasattr(croconet, "dec_blocks"): # encoder + decoder
51
+ step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth]
52
+ hooks_idx = [
53
+ croconet.dec_depth + croconet.enc_depth - 1 - i * step
54
+ for i in range(3, -1, -1)
55
+ ]
56
+ else: # encoder only
57
+ step = croconet.enc_depth // 4
58
+ hooks_idx = [
59
+ croconet.enc_depth - 1 - i * step for i in range(3, -1, -1)
60
+ ]
61
+ self.hooks_idx = hooks_idx
62
+ print(
63
+ f" PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}"
64
+ )
65
+ dpt_args["hooks"] = self.hooks_idx
66
+ dpt_args["layer_dims"] = self.layer_dims
67
+ self.dpt = DPTOutputAdapter(**dpt_args)
68
+ dim_tokens = [
69
+ (
70
+ croconet.enc_embed_dim
71
+ if hook < croconet.enc_depth
72
+ else croconet.dec_embed_dim
73
+ )
74
+ for hook in self.hooks_idx
75
+ ]
76
+ dpt_init_args = {"dim_tokens_enc": dim_tokens}
77
+ self.dpt.init(**dpt_init_args)
78
+
79
+ def forward(self, x, img_info):
80
+ out = self.dpt(x, image_size=(img_info["height"], img_info["width"]))
81
+ if self.postprocess:
82
+ out = self.postprocess(out)
83
+ return out
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/masking.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Masking utils
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
+ class RandomMask(nn.Module):
14
+ """
15
+ random masking
16
+ """
17
+
18
+ def __init__(self, num_patches, mask_ratio):
19
+ super().__init__()
20
+ self.num_patches = num_patches
21
+ self.num_mask = int(mask_ratio * self.num_patches)
22
+
23
+ def __call__(self, x):
24
+ noise = torch.rand(x.size(0), self.num_patches, device=x.device)
25
+ argsort = torch.argsort(noise, dim=1)
26
+ return argsort < self.num_mask
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/models/pos_embed.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Position embedding utils
7
+ # --------------------------------------------------------
8
+
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+
15
+ # --------------------------------------------------------
16
+ # 2D sine-cosine position embedding
17
+ # References:
18
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
20
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
21
+ # --------------------------------------------------------
22
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
23
+ """
24
+ grid_size: int of the grid height and width
25
+ return:
26
+ pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
27
+ """
28
+ grid_h = np.arange(grid_size, dtype=np.float32)
29
+ grid_w = np.arange(grid_size, dtype=np.float32)
30
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
31
+ grid = np.stack(grid, axis=0)
32
+
33
+ grid = grid.reshape([2, 1, grid_size, grid_size])
34
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
35
+ if n_cls_token > 0:
36
+ pos_embed = np.concatenate(
37
+ [np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0
38
+ )
39
+ return pos_embed
40
+
41
+
42
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
43
+ assert embed_dim % 2 == 0
44
+
45
+ # use half of dimensions to encode grid_h
46
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
47
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
48
+
49
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
50
+ return emb
51
+
52
+
53
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
54
+ """
55
+ embed_dim: output dimension for each position
56
+ pos: a list of positions to be encoded: size (M,)
57
+ out: (M, D)
58
+ """
59
+ assert embed_dim % 2 == 0
60
+ omega = np.arange(embed_dim // 2, dtype=float)
61
+ omega /= embed_dim / 2.0
62
+ omega = 1.0 / 10000**omega # (D/2,)
63
+
64
+ pos = pos.reshape(-1) # (M,)
65
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
66
+
67
+ emb_sin = np.sin(out) # (M, D/2)
68
+ emb_cos = np.cos(out) # (M, D/2)
69
+
70
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
71
+ return emb
72
+
73
+
74
+ # --------------------------------------------------------
75
+ # Interpolate position embeddings for high-resolution
76
+ # References:
77
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
78
+ # DeiT: https://github.com/facebookresearch/deit
79
+ # --------------------------------------------------------
80
+ def interpolate_pos_embed(model, checkpoint_model):
81
+ if "pos_embed" in checkpoint_model:
82
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
83
+ embedding_size = pos_embed_checkpoint.shape[-1]
84
+ num_patches = model.patch_embed.num_patches
85
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
86
+ # height (== width) for the checkpoint position embedding
87
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
88
+ # height (== width) for the new position embedding
89
+ new_size = int(num_patches**0.5)
90
+ # class_token and dist_token are kept unchanged
91
+ if orig_size != new_size:
92
+ print(
93
+ "Position interpolate from %dx%d to %dx%d"
94
+ % (orig_size, orig_size, new_size, new_size)
95
+ )
96
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
97
+ # only the position tokens are interpolated
98
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
99
+ pos_tokens = pos_tokens.reshape(
100
+ -1, orig_size, orig_size, embedding_size
101
+ ).permute(0, 3, 1, 2)
102
+ pos_tokens = torch.nn.functional.interpolate(
103
+ pos_tokens,
104
+ size=(new_size, new_size),
105
+ mode="bicubic",
106
+ align_corners=False,
107
+ )
108
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
109
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
110
+ checkpoint_model["pos_embed"] = new_pos_embed
111
+
112
+
113
+ # ----------------------------------------------------------
114
+ # RoPE2D: RoPE implementation in 2D
115
+ # ----------------------------------------------------------
116
+
117
+ try:
118
+ from models.curope import cuRoPE2D
119
+
120
+ RoPE2D = cuRoPE2D
121
+ except ImportError:
122
+ print(
123
+ "Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead"
124
+ )
125
+
126
+ class RoPE2D(torch.nn.Module):
127
+
128
+ def __init__(self, freq=100.0, F0=1.0):
129
+ super().__init__()
130
+ self.base = freq
131
+ self.F0 = F0
132
+ self.cache = {}
133
+
134
+ def get_cos_sin(self, D, seq_len, device, dtype):
135
+ if (D, seq_len, device, dtype) not in self.cache:
136
+ inv_freq = 1.0 / (
137
+ self.base ** (torch.arange(0, D, 2).float().to(device) / D)
138
+ )
139
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
140
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
141
+ freqs = torch.cat((freqs, freqs), dim=-1)
142
+ cos = freqs.cos() # (Seq, Dim)
143
+ sin = freqs.sin()
144
+ self.cache[D, seq_len, device, dtype] = (cos, sin)
145
+ return self.cache[D, seq_len, device, dtype]
146
+
147
+ @staticmethod
148
+ def rotate_half(x):
149
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
150
+ return torch.cat((-x2, x1), dim=-1)
151
+
152
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
153
+ assert pos1d.ndim == 2
154
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
155
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
156
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
157
+
158
+ def forward(self, tokens, positions):
159
+ """
160
+ input:
161
+ * tokens: batch_size x nheads x ntokens x dim
162
+ * positions: batch_size x ntokens x 2 (y and x position of each token)
163
+ output:
164
+ * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
165
+ """
166
+ assert (
167
+ tokens.size(3) % 2 == 0
168
+ ), "number of dimensions should be a multiple of two"
169
+ D = tokens.size(3) // 2
170
+ assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2
171
+ cos, sin = self.get_cos_sin(
172
+ D, int(positions.max()) + 1, tokens.device, tokens.dtype
173
+ )
174
+ # split features into two along the feature dimension, and apply rope1d on each half
175
+ y, x = tokens.chunk(2, dim=-1)
176
+ y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
177
+ x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
178
+ tokens = torch.cat((y, x), dim=-1)
179
+ return tokens
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/pretrain.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Pre-training CroCo
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # MAE: https://github.com/facebookresearch/mae
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
11
+ # --------------------------------------------------------
12
+ import argparse
13
+ import datetime
14
+ import json
15
+ import numpy as np
16
+ import os
17
+ import sys
18
+ import time
19
+ import math
20
+ from pathlib import Path
21
+ from typing import Iterable
22
+
23
+ import torch
24
+ import torch.distributed as dist
25
+ import torch.backends.cudnn as cudnn
26
+ from torch.utils.tensorboard import SummaryWriter
27
+ import torchvision.transforms as transforms
28
+ import torchvision.datasets as datasets
29
+
30
+ import utils.misc as misc
31
+ from utils.misc import NativeScalerWithGradNormCount as NativeScaler
32
+ from models.croco import CroCoNet
33
+ from models.criterion import MaskedMSE
34
+ from datasets.pairs_dataset import PairsDataset
35
+
36
+
37
+ def get_args_parser():
38
+ parser = argparse.ArgumentParser("CroCo pre-training", add_help=False)
39
+ # model and criterion
40
+ parser.add_argument(
41
+ "--model",
42
+ default="CroCoNet()",
43
+ type=str,
44
+ help="string containing the model to build",
45
+ )
46
+ parser.add_argument(
47
+ "--norm_pix_loss",
48
+ default=1,
49
+ choices=[0, 1],
50
+ help="apply per-patch mean/std normalization before applying the loss",
51
+ )
52
+ # dataset
53
+ parser.add_argument(
54
+ "--dataset", default="habitat_release", type=str, help="training set"
55
+ )
56
+ parser.add_argument(
57
+ "--transforms", default="crop224+acolor", type=str, help="transforms to apply"
58
+ ) # in the paper, we also use some homography and rotation, but find later that they were not useful or even harmful
59
+ # training
60
+ parser.add_argument("--seed", default=0, type=int, help="Random seed")
61
+ parser.add_argument(
62
+ "--batch_size",
63
+ default=64,
64
+ type=int,
65
+ help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus",
66
+ )
67
+ parser.add_argument(
68
+ "--epochs",
69
+ default=800,
70
+ type=int,
71
+ help="Maximum number of epochs for the scheduler",
72
+ )
73
+ parser.add_argument(
74
+ "--max_epoch", default=400, type=int, help="Stop training at this epoch"
75
+ )
76
+ parser.add_argument(
77
+ "--accum_iter",
78
+ default=1,
79
+ type=int,
80
+ help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)",
81
+ )
82
+ parser.add_argument(
83
+ "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)"
84
+ )
85
+ parser.add_argument(
86
+ "--lr",
87
+ type=float,
88
+ default=None,
89
+ metavar="LR",
90
+ help="learning rate (absolute lr)",
91
+ )
92
+ parser.add_argument(
93
+ "--blr",
94
+ type=float,
95
+ default=1.5e-4,
96
+ metavar="LR",
97
+ help="base learning rate: absolute_lr = base_lr * total_batch_size / 256",
98
+ )
99
+ parser.add_argument(
100
+ "--min_lr",
101
+ type=float,
102
+ default=0.0,
103
+ metavar="LR",
104
+ help="lower lr bound for cyclic schedulers that hit 0",
105
+ )
106
+ parser.add_argument(
107
+ "--warmup_epochs", type=int, default=40, metavar="N", help="epochs to warmup LR"
108
+ )
109
+ parser.add_argument(
110
+ "--amp",
111
+ type=int,
112
+ default=1,
113
+ choices=[0, 1],
114
+ help="Use Automatic Mixed Precision for pretraining",
115
+ )
116
+ # others
117
+ parser.add_argument("--num_workers", default=8, type=int)
118
+ parser.add_argument(
119
+ "--world_size", default=1, type=int, help="number of distributed processes"
120
+ )
121
+ parser.add_argument("--local_rank", default=-1, type=int)
122
+ parser.add_argument(
123
+ "--dist_url", default="env://", help="url used to set up distributed training"
124
+ )
125
+ parser.add_argument(
126
+ "--save_freq",
127
+ default=1,
128
+ type=int,
129
+ help="frequence (number of epochs) to save checkpoint in checkpoint-last.pth",
130
+ )
131
+ parser.add_argument(
132
+ "--keep_freq",
133
+ default=20,
134
+ type=int,
135
+ help="frequence (number of epochs) to save checkpoint in checkpoint-%d.pth",
136
+ )
137
+ parser.add_argument(
138
+ "--print_freq",
139
+ default=20,
140
+ type=int,
141
+ help="frequence (number of iterations) to print infos while training",
142
+ )
143
+ # paths
144
+ parser.add_argument(
145
+ "--output_dir",
146
+ default="./output/",
147
+ type=str,
148
+ help="path where to save the output",
149
+ )
150
+ parser.add_argument(
151
+ "--data_dir", default="./data/", type=str, help="path where data are stored"
152
+ )
153
+ return parser
154
+
155
+
156
+ def main(args):
157
+ misc.init_distributed_mode(args)
158
+ global_rank = misc.get_rank()
159
+ world_size = misc.get_world_size()
160
+
161
+ print("output_dir: " + args.output_dir)
162
+ if args.output_dir:
163
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
164
+
165
+ # auto resume
166
+ last_ckpt_fname = os.path.join(args.output_dir, f"checkpoint-last.pth")
167
+ args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None
168
+
169
+ print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
170
+ print("{}".format(args).replace(", ", ",\n"))
171
+
172
+ device = "cuda" if torch.cuda.is_available() else "cpu"
173
+ device = torch.device(device)
174
+
175
+ # fix the seed
176
+ seed = args.seed + misc.get_rank()
177
+ torch.manual_seed(seed)
178
+ np.random.seed(seed)
179
+
180
+ cudnn.benchmark = True
181
+
182
+ ## training dataset and loader
183
+ print(
184
+ "Building dataset for {:s} with transforms {:s}".format(
185
+ args.dataset, args.transforms
186
+ )
187
+ )
188
+ dataset = PairsDataset(args.dataset, trfs=args.transforms, data_dir=args.data_dir)
189
+ if world_size > 1:
190
+ sampler_train = torch.utils.data.DistributedSampler(
191
+ dataset, num_replicas=world_size, rank=global_rank, shuffle=True
192
+ )
193
+ print("Sampler_train = %s" % str(sampler_train))
194
+ else:
195
+ sampler_train = torch.utils.data.RandomSampler(dataset)
196
+ data_loader_train = torch.utils.data.DataLoader(
197
+ dataset,
198
+ sampler=sampler_train,
199
+ batch_size=args.batch_size,
200
+ num_workers=args.num_workers,
201
+ pin_memory=True,
202
+ drop_last=True,
203
+ )
204
+
205
+ ## model
206
+ print("Loading model: {:s}".format(args.model))
207
+ model = eval(args.model)
208
+ print(
209
+ "Loading criterion: MaskedMSE(norm_pix_loss={:s})".format(
210
+ str(bool(args.norm_pix_loss))
211
+ )
212
+ )
213
+ criterion = MaskedMSE(norm_pix_loss=bool(args.norm_pix_loss))
214
+
215
+ model.to(device)
216
+ model_without_ddp = model
217
+ print("Model = %s" % str(model_without_ddp))
218
+
219
+ eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
220
+ if args.lr is None: # only base_lr is specified
221
+ args.lr = args.blr * eff_batch_size / 256
222
+ print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
223
+ print("actual lr: %.2e" % args.lr)
224
+ print("accumulate grad iterations: %d" % args.accum_iter)
225
+ print("effective batch size: %d" % eff_batch_size)
226
+
227
+ if args.distributed:
228
+ model = torch.nn.parallel.DistributedDataParallel(
229
+ model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True
230
+ )
231
+ model_without_ddp = model.module
232
+
233
+ param_groups = misc.get_parameter_groups(
234
+ model_without_ddp, args.weight_decay
235
+ ) # following timm: set wd as 0 for bias and norm layers
236
+ optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
237
+ print(optimizer)
238
+ loss_scaler = NativeScaler()
239
+
240
+ misc.load_model(
241
+ args=args,
242
+ model_without_ddp=model_without_ddp,
243
+ optimizer=optimizer,
244
+ loss_scaler=loss_scaler,
245
+ )
246
+
247
+ if global_rank == 0 and args.output_dir is not None:
248
+ log_writer = SummaryWriter(log_dir=args.output_dir)
249
+ else:
250
+ log_writer = None
251
+
252
+ print(f"Start training until {args.max_epoch} epochs")
253
+ start_time = time.time()
254
+ for epoch in range(args.start_epoch, args.max_epoch):
255
+ if world_size > 1:
256
+ data_loader_train.sampler.set_epoch(epoch)
257
+
258
+ train_stats = train_one_epoch(
259
+ model,
260
+ criterion,
261
+ data_loader_train,
262
+ optimizer,
263
+ device,
264
+ epoch,
265
+ loss_scaler,
266
+ log_writer=log_writer,
267
+ args=args,
268
+ )
269
+
270
+ if args.output_dir and epoch % args.save_freq == 0:
271
+ misc.save_model(
272
+ args=args,
273
+ model_without_ddp=model_without_ddp,
274
+ optimizer=optimizer,
275
+ loss_scaler=loss_scaler,
276
+ epoch=epoch,
277
+ fname="last",
278
+ )
279
+
280
+ if (
281
+ args.output_dir
282
+ and (epoch % args.keep_freq == 0 or epoch + 1 == args.max_epoch)
283
+ and (epoch > 0 or args.max_epoch == 1)
284
+ ):
285
+ misc.save_model(
286
+ args=args,
287
+ model_without_ddp=model_without_ddp,
288
+ optimizer=optimizer,
289
+ loss_scaler=loss_scaler,
290
+ epoch=epoch,
291
+ )
292
+
293
+ log_stats = {
294
+ **{f"train_{k}": v for k, v in train_stats.items()},
295
+ "epoch": epoch,
296
+ }
297
+
298
+ if args.output_dir and misc.is_main_process():
299
+ if log_writer is not None:
300
+ log_writer.flush()
301
+ with open(
302
+ os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
303
+ ) as f:
304
+ f.write(json.dumps(log_stats) + "\n")
305
+
306
+ total_time = time.time() - start_time
307
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
308
+ print("Training time {}".format(total_time_str))
309
+
310
+
311
+ def train_one_epoch(
312
+ model: torch.nn.Module,
313
+ criterion: torch.nn.Module,
314
+ data_loader: Iterable,
315
+ optimizer: torch.optim.Optimizer,
316
+ device: torch.device,
317
+ epoch: int,
318
+ loss_scaler,
319
+ log_writer=None,
320
+ args=None,
321
+ ):
322
+ model.train(True)
323
+ metric_logger = misc.MetricLogger(delimiter=" ")
324
+ metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
325
+ header = "Epoch: [{}]".format(epoch)
326
+ accum_iter = args.accum_iter
327
+
328
+ optimizer.zero_grad()
329
+
330
+ if log_writer is not None:
331
+ print("log_dir: {}".format(log_writer.log_dir))
332
+
333
+ for data_iter_step, (image1, image2) in enumerate(
334
+ metric_logger.log_every(data_loader, args.print_freq, header)
335
+ ):
336
+
337
+ # we use a per iteration lr scheduler
338
+ if data_iter_step % accum_iter == 0:
339
+ misc.adjust_learning_rate(
340
+ optimizer, data_iter_step / len(data_loader) + epoch, args
341
+ )
342
+
343
+ image1 = image1.to(device, non_blocking=True)
344
+ image2 = image2.to(device, non_blocking=True)
345
+ with torch.cuda.amp.autocast(enabled=bool(args.amp)):
346
+ out, mask, target = model(image1, image2)
347
+ loss = criterion(out, mask, target)
348
+
349
+ loss_value = loss.item()
350
+
351
+ if not math.isfinite(loss_value):
352
+ print("Loss is {}, stopping training".format(loss_value))
353
+ sys.exit(1)
354
+
355
+ loss /= accum_iter
356
+ loss_scaler(
357
+ loss,
358
+ optimizer,
359
+ parameters=model.parameters(),
360
+ update_grad=(data_iter_step + 1) % accum_iter == 0,
361
+ )
362
+ if (data_iter_step + 1) % accum_iter == 0:
363
+ optimizer.zero_grad()
364
+
365
+ torch.cuda.synchronize()
366
+
367
+ metric_logger.update(loss=loss_value)
368
+
369
+ lr = optimizer.param_groups[0]["lr"]
370
+ metric_logger.update(lr=lr)
371
+
372
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
373
+ if (
374
+ log_writer is not None
375
+ and ((data_iter_step + 1) % (accum_iter * args.print_freq)) == 0
376
+ ):
377
+ # x-axis is based on epoch_1000x in the tensorboard, calibrating differences curves when batch size changes
378
+ epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
379
+ log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
380
+ log_writer.add_scalar("lr", lr, epoch_1000x)
381
+
382
+ # gather the stats from all processes
383
+ metric_logger.synchronize_between_processes()
384
+ print("Averaged stats:", metric_logger)
385
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
386
+
387
+
388
+ if __name__ == "__main__":
389
+ args = get_args_parser()
390
+ args = args.parse_args()
391
+ main(args)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/README.MD ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## CroCo-Stereo and CroCo-Flow
2
+
3
+ This README explains how to use CroCo-Stereo and CroCo-Flow as well as how they were trained.
4
+ All commands should be launched from the root directory.
5
+
6
+ ### Simple inference example
7
+
8
+ We provide a simple inference exemple for CroCo-Stereo and CroCo-Flow in the Totebook `croco-stereo-flow-demo.ipynb`.
9
+ Before running it, please download the trained models with:
10
+ ```
11
+ bash stereoflow/download_model.sh crocostereo.pth
12
+ bash stereoflow/download_model.sh crocoflow.pth
13
+ ```
14
+
15
+ ### Prepare data for training or evaluation
16
+
17
+ Put the datasets used for training/evaluation in `./data/stereoflow` (or update the paths at the top of `stereoflow/datasets_stereo.py` and `stereoflow/datasets_flow.py`).
18
+ Please find below on the file structure should look for each dataset:
19
+ <details>
20
+ <summary>FlyingChairs</summary>
21
+
22
+ ```
23
+ ./data/stereoflow/FlyingChairs/
24
+ └───chairs_split.txt
25
+ └───data/
26
+ └─── ...
27
+ ```
28
+ </details>
29
+
30
+ <details>
31
+ <summary>MPI-Sintel</summary>
32
+
33
+ ```
34
+ ./data/stereoflow/MPI-Sintel/
35
+ └───training/
36
+ │ └───clean/
37
+ │ └───final/
38
+ │ └───flow/
39
+ └───test/
40
+ └───clean/
41
+ └───final/
42
+ ```
43
+ </details>
44
+
45
+ <details>
46
+ <summary>SceneFlow (including FlyingThings)</summary>
47
+
48
+ ```
49
+ ./data/stereoflow/SceneFlow/
50
+ └───Driving/
51
+ │ └───disparity/
52
+ │ └───frames_cleanpass/
53
+ │ └───frames_finalpass/
54
+ └───FlyingThings/
55
+ │ └───disparity/
56
+ │ └───frames_cleanpass/
57
+ │ └───frames_finalpass/
58
+ │ └───optical_flow/
59
+ └───Monkaa/
60
+ └───disparity/
61
+ └───frames_cleanpass/
62
+ └───frames_finalpass/
63
+ ```
64
+ </details>
65
+
66
+ <details>
67
+ <summary>TartanAir</summary>
68
+
69
+ ```
70
+ ./data/stereoflow/TartanAir/
71
+ └───abandonedfactory/
72
+ │ └───.../
73
+ └───abandonedfactory_night/
74
+ │ └───.../
75
+ └───.../
76
+ ```
77
+ </details>
78
+
79
+ <details>
80
+ <summary>Booster</summary>
81
+
82
+ ```
83
+ ./data/stereoflow/booster_gt/
84
+ └───train/
85
+ └───balanced/
86
+ └───Bathroom/
87
+ └───Bedroom/
88
+ └───...
89
+ ```
90
+ </details>
91
+
92
+ <details>
93
+ <summary>CREStereo</summary>
94
+
95
+ ```
96
+ ./data/stereoflow/crenet_stereo_trainset/
97
+ └───stereo_trainset/
98
+ └───crestereo/
99
+ └───hole/
100
+ └───reflective/
101
+ └───shapenet/
102
+ └───tree/
103
+ ```
104
+ </details>
105
+
106
+ <details>
107
+ <summary>ETH3D Two-view Low-res</summary>
108
+
109
+ ```
110
+ ./data/stereoflow/eth3d_lowres/
111
+ └───test/
112
+ │ └───lakeside_1l/
113
+ │ └───...
114
+ └───train/
115
+ │ └───delivery_area_1l/
116
+ │ └───...
117
+ └───train_gt/
118
+ └───delivery_area_1l/
119
+ └───...
120
+ ```
121
+ </details>
122
+
123
+ <details>
124
+ <summary>KITTI 2012</summary>
125
+
126
+ ```
127
+ ./data/stereoflow/kitti-stereo-2012/
128
+ └───testing/
129
+ │ └───colored_0/
130
+ │ └───colored_1/
131
+ └───training/
132
+ └───colored_0/
133
+ └───colored_1/
134
+ └───disp_occ/
135
+ └───flow_occ/
136
+ ```
137
+ </details>
138
+
139
+ <details>
140
+ <summary>KITTI 2015</summary>
141
+
142
+ ```
143
+ ./data/stereoflow/kitti-stereo-2015/
144
+ └───testing/
145
+ │ └───image_2/
146
+ │ └───image_3/
147
+ └───training/
148
+ └───image_2/
149
+ └───image_3/
150
+ └───disp_occ_0/
151
+ └───flow_occ/
152
+ ```
153
+ </details>
154
+
155
+ <details>
156
+ <summary>Middlebury</summary>
157
+
158
+ ```
159
+ ./data/stereoflow/middlebury
160
+ └───2005/
161
+ │ └───train/
162
+ │ └───Art/
163
+ │ └───...
164
+ └───2006/
165
+ │ └───Aloe/
166
+ │ └───Baby1/
167
+ │ └───...
168
+ └───2014/
169
+ │ └───Adirondack-imperfect/
170
+ │ └───Adirondack-perfect/
171
+ │ └───...
172
+ └───2021/
173
+ │ └───data/
174
+ │ └───artroom1/
175
+ │ └───artroom2/
176
+ │ └───...
177
+ └───MiddEval3_F/
178
+ └───test/
179
+ │ └───Australia/
180
+ │ └───...
181
+ └───train/
182
+ └───Adirondack/
183
+ └───...
184
+ ```
185
+ </details>
186
+
187
+ <details>
188
+ <summary>Spring</summary>
189
+
190
+ ```
191
+ ./data/stereoflow/spring/
192
+ └───test/
193
+ │ └───0003/
194
+ │ └───...
195
+ └───train/
196
+ └───0001/
197
+ └───...
198
+ ```
199
+ </details>
200
+
201
+
202
+ ### CroCo-Stereo
203
+
204
+ ##### Main model
205
+
206
+ The main training of CroCo-Stereo was performed on a series of datasets, and it was used as it for Middlebury v3 benchmark.
207
+
208
+ ```
209
+ # Download the model
210
+ bash stereoflow/download_model.sh crocostereo.pth
211
+ # Middlebury v3 submission
212
+ python stereoflow/test.py --model stereoflow_models/crocostereo.pth --dataset "MdEval3('all_full')" --save submission --tile_overlap 0.9
213
+ # Training command that was used, using checkpoint-last.pth
214
+ python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main/
215
+ # or it can be launched on multiple gpus (while maintaining the effective batch size), e.g. on 3 gpus:
216
+ torchrun --nproc_per_node 3 stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 2 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main/
217
+ ```
218
+
219
+ For evaluation of validation set, we also provide the model trained on the `subtrain` subset of the training sets.
220
+
221
+ ```
222
+ # Download the model
223
+ bash stereoflow/download_model.sh crocostereo_subtrain.pth
224
+ # Evaluation on validation sets
225
+ python stereoflow/test.py --model stereoflow_models/crocostereo_subtrain.pth --dataset "MdEval3('subval_full')+ETH3DLowRes('subval')+SceneFlow('test_finalpass')+SceneFlow('test_cleanpass')" --save metrics --tile_overlap 0.9
226
+ # Training command that was used (same as above but on subtrain, using checkpoint-best.pth), can also be launched on multiple gpus
227
+ python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('subtrain')+50*Md05('subtrain')+50*Md06('subtrain')+50*Md14('subtrain')+50*Md21('subtrain')+50*MdEval3('subtrain_full')+Booster('subtrain_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main_subtrain/
228
+ ```
229
+
230
+ ##### Other models
231
+
232
+ <details>
233
+ <summary>Model for ETH3D</summary>
234
+ The model used for the submission on ETH3D is trained with the same command but using an unbounded Laplacian loss.
235
+
236
+ # Download the model
237
+ bash stereoflow/download_model.sh crocostereo_eth3d.pth
238
+ # ETH3D submission
239
+ python stereoflow/test.py --model stereoflow_models/crocostereo_eth3d.pth --dataset "ETH3DLowRes('all')" --save submission --tile_overlap 0.9
240
+ # Training command that was used
241
+ python -u stereoflow/train.py stereo --criterion "LaplacianLoss()" --tile_conf_mode conf_expbeta3 --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main_eth3d/
242
+
243
+ </details>
244
+
245
+ <details>
246
+ <summary>Main model finetuned on Kitti</summary>
247
+
248
+ # Download the model
249
+ bash stereoflow/download_model.sh crocostereo_finetune_kitti.pth
250
+ # Kitti submission
251
+ python stereoflow/test.py --model stereoflow_models/crocostereo_finetune_kitti.pth --dataset "Kitti15('test')" --save submission --tile_overlap 0.9
252
+ # Training that was used
253
+ python -u stereoflow/train.py stereo --crop 352 1216 --criterion "LaplacianLossBounded2()" --dataset "Kitti12('train')+Kitti15('train')" --lr 3e-5 --batch_size 1 --accum_iter 6 --epochs 20 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocostereo.pth --output_dir xps/crocostereo/finetune_kitti/ --save_every 5
254
+ </details>
255
+
256
+ <details>
257
+ <summary>Main model finetuned on Spring</summary>
258
+
259
+ # Download the model
260
+ bash stereoflow/download_model.sh crocostereo_finetune_spring.pth
261
+ # Spring submission
262
+ python stereoflow/test.py --model stereoflow_models/crocostereo_finetune_spring.pth --dataset "Spring('test')" --save submission --tile_overlap 0.9
263
+ # Training command that was used
264
+ python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "Spring('train')" --lr 3e-5 --batch_size 6 --epochs 8 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocostereo.pth --output_dir xps/crocostereo/finetune_spring/
265
+ </details>
266
+
267
+ <details>
268
+ <summary>Smaller models</summary>
269
+ To train CroCo-Stereo with smaller CroCo pretrained models, simply replace the <code>--pretrained</code> argument. To download the smaller CroCo-Stereo models based on CroCo v2 pretraining with ViT-Base encoder and Small encoder, use <code>bash stereoflow/download_model.sh crocostereo_subtrain_vitb_smalldecoder.pth</code>, and for the model with a ViT-Base encoder and a Base decoder, use <code>bash stereoflow/download_model.sh crocostereo_subtrain_vitb_basedecoder.pth</code>.
270
+ </details>
271
+
272
+
273
+ ### CroCo-Flow
274
+
275
+ ##### Main model
276
+
277
+ The main training of CroCo-Flow was performed on the FlyingThings, FlyingChairs, MPI-Sintel and TartanAir datasets.
278
+ It was used for our submission to the MPI-Sintel benchmark.
279
+
280
+ ```
281
+ # Download the model
282
+ bash stereoflow/download_model.sh crocoflow.pth
283
+ # Evaluation
284
+ python stereoflow/test.py --model stereoflow_models/crocoflow.pth --dataset "MPISintel('subval_cleanpass')+MPISintel('subval_finalpass')" --save metrics --tile_overlap 0.9
285
+ # Sintel submission
286
+ python stereoflow/test.py --model stereoflow_models/crocoflow.pth --dataset "MPISintel('test_allpass')" --save submission --tile_overlap 0.9
287
+ # Training command that was used, with checkpoint-best.pth
288
+ python -u stereoflow/train.py flow --criterion "LaplacianLossBounded()" --dataset "40*MPISintel('subtrain_cleanpass')+40*MPISintel('subtrain_finalpass')+4*FlyingThings('train_allpass')+4*FlyingChairs('train')+TartanAir('train')" --val_dataset "MPISintel('subval_cleanpass')+MPISintel('subval_finalpass')" --lr 2e-5 --batch_size 8 --epochs 240 --img_per_epoch 30000 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocoflow/main/
289
+ ```
290
+
291
+ ##### Other models
292
+
293
+ <details>
294
+ <summary>Main model finetuned on Kitti</summary>
295
+
296
+ # Download the model
297
+ bash stereoflow/download_model.sh crocoflow_finetune_kitti.pth
298
+ # Kitti submission
299
+ python stereoflow/test.py --model stereoflow_models/crocoflow_finetune_kitti.pth --dataset "Kitti15('test')" --save submission --tile_overlap 0.99
300
+ # Training that was used, with checkpoint-last.pth
301
+ python -u stereoflow/train.py flow --crop 352 1216 --criterion "LaplacianLossBounded()" --dataset "Kitti15('train')+Kitti12('train')" --lr 2e-5 --batch_size 1 --accum_iter 8 --epochs 150 --save_every 5 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocoflow.pth --output_dir xps/crocoflow/finetune_kitti/
302
+ </details>
303
+
304
+ <details>
305
+ <summary>Main model finetuned on Spring</summary>
306
+
307
+ # Download the model
308
+ bash stereoflow/download_model.sh crocoflow_finetune_spring.pth
309
+ # Spring submission
310
+ python stereoflow/test.py --model stereoflow_models/crocoflow_finetune_spring.pth --dataset "Spring('test')" --save submission --tile_overlap 0.9
311
+ # Training command that was used, with checkpoint-last.pth
312
+ python -u stereoflow/train.py flow --criterion "LaplacianLossBounded()" --dataset "Spring('train')" --lr 2e-5 --batch_size 8 --epochs 12 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocoflow.pth --output_dir xps/crocoflow/finetune_spring/
313
+ </details>
314
+
315
+ <details>
316
+ <summary>Smaller models</summary>
317
+ To train CroCo-Flow with smaller CroCo pretrained models, simply replace the <code>--pretrained</code> argument. To download the smaller CroCo-Flow models based on CroCo v2 pretraining with ViT-Base encoder and Small encoder, use <code>bash stereoflow/download_model.sh crocoflow_vitb_smalldecoder.pth</code>, and for the model with a ViT-Base encoder and a Base decoder, use <code>bash stereoflow/download_model.sh crocoflow_vitb_basedecoder.pth</code>.
318
+ </details>
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/augmentor.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Data augmentation for training stereo and flow
6
+ # --------------------------------------------------------
7
+
8
+ # References
9
+ # https://github.com/autonomousvision/unimatch/blob/master/dataloader/stereo/transforms.py
10
+ # https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/transforms.py
11
+
12
+
13
+ import numpy as np
14
+ import random
15
+ from PIL import Image
16
+
17
+ import cv2
18
+
19
+ cv2.setNumThreads(0)
20
+ cv2.ocl.setUseOpenCL(False)
21
+
22
+ import torch
23
+ from torchvision.transforms import ColorJitter
24
+ import torchvision.transforms.functional as FF
25
+
26
+
27
+ class StereoAugmentor(object):
28
+
29
+ def __init__(
30
+ self,
31
+ crop_size,
32
+ scale_prob=0.5,
33
+ scale_xonly=True,
34
+ lhth=800.0,
35
+ lminscale=0.0,
36
+ lmaxscale=1.0,
37
+ hminscale=-0.2,
38
+ hmaxscale=0.4,
39
+ scale_interp_nearest=True,
40
+ rightjitterprob=0.5,
41
+ v_flip_prob=0.5,
42
+ color_aug_asym=True,
43
+ color_choice_prob=0.5,
44
+ ):
45
+ self.crop_size = crop_size
46
+ self.scale_prob = scale_prob
47
+ self.scale_xonly = scale_xonly
48
+ self.lhth = lhth
49
+ self.lminscale = lminscale
50
+ self.lmaxscale = lmaxscale
51
+ self.hminscale = hminscale
52
+ self.hmaxscale = hmaxscale
53
+ self.scale_interp_nearest = scale_interp_nearest
54
+ self.rightjitterprob = rightjitterprob
55
+ self.v_flip_prob = v_flip_prob
56
+ self.color_aug_asym = color_aug_asym
57
+ self.color_choice_prob = color_choice_prob
58
+
59
+ def _random_scale(self, img1, img2, disp):
60
+ ch, cw = self.crop_size
61
+ h, w = img1.shape[:2]
62
+ if self.scale_prob > 0.0 and np.random.rand() < self.scale_prob:
63
+ min_scale, max_scale = (
64
+ (self.lminscale, self.lmaxscale)
65
+ if min(h, w) < self.lhth
66
+ else (self.hminscale, self.hmaxscale)
67
+ )
68
+ scale_x = 2.0 ** np.random.uniform(min_scale, max_scale)
69
+ scale_x = np.clip(scale_x, (cw + 8) / float(w), None)
70
+ scale_y = 1.0
71
+ if not self.scale_xonly:
72
+ scale_y = scale_x
73
+ scale_y = np.clip(scale_y, (ch + 8) / float(h), None)
74
+ img1 = cv2.resize(
75
+ img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
76
+ )
77
+ img2 = cv2.resize(
78
+ img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
79
+ )
80
+ disp = (
81
+ cv2.resize(
82
+ disp,
83
+ None,
84
+ fx=scale_x,
85
+ fy=scale_y,
86
+ interpolation=(
87
+ cv2.INTER_LINEAR
88
+ if not self.scale_interp_nearest
89
+ else cv2.INTER_NEAREST
90
+ ),
91
+ )
92
+ * scale_x
93
+ )
94
+ else: # check if we need to resize to be able to crop
95
+ h, w = img1.shape[:2]
96
+ clip_scale = (cw + 8) / float(w)
97
+ if clip_scale > 1.0:
98
+ scale_x = clip_scale
99
+ scale_y = scale_x if not self.scale_xonly else 1.0
100
+ img1 = cv2.resize(
101
+ img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
102
+ )
103
+ img2 = cv2.resize(
104
+ img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
105
+ )
106
+ disp = (
107
+ cv2.resize(
108
+ disp,
109
+ None,
110
+ fx=scale_x,
111
+ fy=scale_y,
112
+ interpolation=(
113
+ cv2.INTER_LINEAR
114
+ if not self.scale_interp_nearest
115
+ else cv2.INTER_NEAREST
116
+ ),
117
+ )
118
+ * scale_x
119
+ )
120
+ return img1, img2, disp
121
+
122
+ def _random_crop(self, img1, img2, disp):
123
+ h, w = img1.shape[:2]
124
+ ch, cw = self.crop_size
125
+ assert ch <= h and cw <= w, (img1.shape, h, w, ch, cw)
126
+ offset_x = np.random.randint(w - cw + 1)
127
+ offset_y = np.random.randint(h - ch + 1)
128
+ img1 = img1[offset_y : offset_y + ch, offset_x : offset_x + cw]
129
+ img2 = img2[offset_y : offset_y + ch, offset_x : offset_x + cw]
130
+ disp = disp[offset_y : offset_y + ch, offset_x : offset_x + cw]
131
+ return img1, img2, disp
132
+
133
+ def _random_vflip(self, img1, img2, disp):
134
+ # vertical flip
135
+ if self.v_flip_prob > 0 and np.random.rand() < self.v_flip_prob:
136
+ img1 = np.copy(np.flipud(img1))
137
+ img2 = np.copy(np.flipud(img2))
138
+ disp = np.copy(np.flipud(disp))
139
+ return img1, img2, disp
140
+
141
+ def _random_rotate_shift_right(self, img2):
142
+ if self.rightjitterprob > 0.0 and np.random.rand() < self.rightjitterprob:
143
+ angle, pixel = 0.1, 2
144
+ px = np.random.uniform(-pixel, pixel)
145
+ ag = np.random.uniform(-angle, angle)
146
+ image_center = (
147
+ np.random.uniform(0, img2.shape[0]),
148
+ np.random.uniform(0, img2.shape[1]),
149
+ )
150
+ rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
151
+ img2 = cv2.warpAffine(
152
+ img2, rot_mat, img2.shape[1::-1], flags=cv2.INTER_LINEAR
153
+ )
154
+ trans_mat = np.float32([[1, 0, 0], [0, 1, px]])
155
+ img2 = cv2.warpAffine(
156
+ img2, trans_mat, img2.shape[1::-1], flags=cv2.INTER_LINEAR
157
+ )
158
+ return img2
159
+
160
+ def _random_color_contrast(self, img1, img2):
161
+ if np.random.random() < 0.5:
162
+ contrast_factor = np.random.uniform(0.8, 1.2)
163
+ img1 = FF.adjust_contrast(img1, contrast_factor)
164
+ if self.color_aug_asym and np.random.random() < 0.5:
165
+ contrast_factor = np.random.uniform(0.8, 1.2)
166
+ img2 = FF.adjust_contrast(img2, contrast_factor)
167
+ return img1, img2
168
+
169
+ def _random_color_gamma(self, img1, img2):
170
+ if np.random.random() < 0.5:
171
+ gamma = np.random.uniform(0.7, 1.5)
172
+ img1 = FF.adjust_gamma(img1, gamma)
173
+ if self.color_aug_asym and np.random.random() < 0.5:
174
+ gamma = np.random.uniform(0.7, 1.5)
175
+ img2 = FF.adjust_gamma(img2, gamma)
176
+ return img1, img2
177
+
178
+ def _random_color_brightness(self, img1, img2):
179
+ if np.random.random() < 0.5:
180
+ brightness = np.random.uniform(0.5, 2.0)
181
+ img1 = FF.adjust_brightness(img1, brightness)
182
+ if self.color_aug_asym and np.random.random() < 0.5:
183
+ brightness = np.random.uniform(0.5, 2.0)
184
+ img2 = FF.adjust_brightness(img2, brightness)
185
+ return img1, img2
186
+
187
+ def _random_color_hue(self, img1, img2):
188
+ if np.random.random() < 0.5:
189
+ hue = np.random.uniform(-0.1, 0.1)
190
+ img1 = FF.adjust_hue(img1, hue)
191
+ if self.color_aug_asym and np.random.random() < 0.5:
192
+ hue = np.random.uniform(-0.1, 0.1)
193
+ img2 = FF.adjust_hue(img2, hue)
194
+ return img1, img2
195
+
196
+ def _random_color_saturation(self, img1, img2):
197
+ if np.random.random() < 0.5:
198
+ saturation = np.random.uniform(0.8, 1.2)
199
+ img1 = FF.adjust_saturation(img1, saturation)
200
+ if self.color_aug_asym and np.random.random() < 0.5:
201
+ saturation = np.random.uniform(-0.8, 1.2)
202
+ img2 = FF.adjust_saturation(img2, saturation)
203
+ return img1, img2
204
+
205
+ def _random_color(self, img1, img2):
206
+ trfs = [
207
+ self._random_color_contrast,
208
+ self._random_color_gamma,
209
+ self._random_color_brightness,
210
+ self._random_color_hue,
211
+ self._random_color_saturation,
212
+ ]
213
+ img1 = Image.fromarray(img1.astype("uint8"))
214
+ img2 = Image.fromarray(img2.astype("uint8"))
215
+ if np.random.random() < self.color_choice_prob:
216
+ # A single transform
217
+ t = random.choice(trfs)
218
+ img1, img2 = t(img1, img2)
219
+ else:
220
+ # Combination of trfs
221
+ # Random order
222
+ random.shuffle(trfs)
223
+ for t in trfs:
224
+ img1, img2 = t(img1, img2)
225
+ img1 = np.array(img1).astype(np.float32)
226
+ img2 = np.array(img2).astype(np.float32)
227
+ return img1, img2
228
+
229
+ def __call__(self, img1, img2, disp, dataset_name):
230
+ img1, img2, disp = self._random_scale(img1, img2, disp)
231
+ img1, img2, disp = self._random_crop(img1, img2, disp)
232
+ img1, img2, disp = self._random_vflip(img1, img2, disp)
233
+ img2 = self._random_rotate_shift_right(img2)
234
+ img1, img2 = self._random_color(img1, img2)
235
+ return img1, img2, disp
236
+
237
+
238
+ class FlowAugmentor:
239
+
240
+ def __init__(
241
+ self,
242
+ crop_size,
243
+ min_scale=-0.2,
244
+ max_scale=0.5,
245
+ spatial_aug_prob=0.8,
246
+ stretch_prob=0.8,
247
+ max_stretch=0.2,
248
+ h_flip_prob=0.5,
249
+ v_flip_prob=0.1,
250
+ asymmetric_color_aug_prob=0.2,
251
+ ):
252
+
253
+ # spatial augmentation params
254
+ self.crop_size = crop_size
255
+ self.min_scale = min_scale
256
+ self.max_scale = max_scale
257
+ self.spatial_aug_prob = spatial_aug_prob
258
+ self.stretch_prob = stretch_prob
259
+ self.max_stretch = max_stretch
260
+
261
+ # flip augmentation params
262
+ self.h_flip_prob = h_flip_prob
263
+ self.v_flip_prob = v_flip_prob
264
+
265
+ # photometric augmentation params
266
+ self.photo_aug = ColorJitter(
267
+ brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14
268
+ )
269
+
270
+ self.asymmetric_color_aug_prob = asymmetric_color_aug_prob
271
+
272
+ def color_transform(self, img1, img2):
273
+ """Photometric augmentation"""
274
+
275
+ # asymmetric
276
+ if np.random.rand() < self.asymmetric_color_aug_prob:
277
+ img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
278
+ img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
279
+
280
+ # symmetric
281
+ else:
282
+ image_stack = np.concatenate([img1, img2], axis=0)
283
+ image_stack = np.array(
284
+ self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
285
+ )
286
+ img1, img2 = np.split(image_stack, 2, axis=0)
287
+
288
+ return img1, img2
289
+
290
+ def _resize_flow(self, flow, scale_x, scale_y, factor=1.0):
291
+ if np.all(np.isfinite(flow)):
292
+ flow = cv2.resize(
293
+ flow,
294
+ None,
295
+ fx=scale_x / factor,
296
+ fy=scale_y / factor,
297
+ interpolation=cv2.INTER_LINEAR,
298
+ )
299
+ flow = flow * [scale_x, scale_y]
300
+ else: # sparse version
301
+ fx, fy = scale_x, scale_y
302
+ ht, wd = flow.shape[:2]
303
+ coords = np.meshgrid(np.arange(wd), np.arange(ht))
304
+ coords = np.stack(coords, axis=-1)
305
+
306
+ coords = coords.reshape(-1, 2).astype(np.float32)
307
+ flow = flow.reshape(-1, 2).astype(np.float32)
308
+ valid = np.isfinite(flow[:, 0])
309
+
310
+ coords0 = coords[valid]
311
+ flow0 = flow[valid]
312
+
313
+ ht1 = int(round(ht * fy / factor))
314
+ wd1 = int(round(wd * fx / factor))
315
+
316
+ rescale = np.expand_dims(np.array([fx, fy]), axis=0)
317
+ coords1 = coords0 * rescale / factor
318
+ flow1 = flow0 * rescale
319
+
320
+ xx = np.round(coords1[:, 0]).astype(np.int32)
321
+ yy = np.round(coords1[:, 1]).astype(np.int32)
322
+
323
+ v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
324
+ xx = xx[v]
325
+ yy = yy[v]
326
+ flow1 = flow1[v]
327
+
328
+ flow = np.inf * np.ones(
329
+ [ht1, wd1, 2], dtype=np.float32
330
+ ) # invalid value every where, before we fill it with the correct ones
331
+ flow[yy, xx] = flow1
332
+ return flow
333
+
334
+ def spatial_transform(self, img1, img2, flow, dname):
335
+
336
+ if np.random.rand() < self.spatial_aug_prob:
337
+ # randomly sample scale
338
+ ht, wd = img1.shape[:2]
339
+ clip_min_scale = np.maximum(
340
+ (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)
341
+ )
342
+ min_scale, max_scale = self.min_scale, self.max_scale
343
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
344
+ scale_x = scale
345
+ scale_y = scale
346
+ if np.random.rand() < self.stretch_prob:
347
+ scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
348
+ scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
349
+ scale_x = np.clip(scale_x, clip_min_scale, None)
350
+ scale_y = np.clip(scale_y, clip_min_scale, None)
351
+ # rescale the images
352
+ img1 = cv2.resize(
353
+ img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
354
+ )
355
+ img2 = cv2.resize(
356
+ img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
357
+ )
358
+ flow = self._resize_flow(
359
+ flow, scale_x, scale_y, factor=2.0 if dname == "Spring" else 1.0
360
+ )
361
+ elif dname == "Spring":
362
+ flow = self._resize_flow(flow, 1.0, 1.0, factor=2.0)
363
+
364
+ if self.h_flip_prob > 0.0 and np.random.rand() < self.h_flip_prob: # h-flip
365
+ img1 = img1[:, ::-1]
366
+ img2 = img2[:, ::-1]
367
+ flow = flow[:, ::-1] * [-1.0, 1.0]
368
+
369
+ if self.v_flip_prob > 0.0 and np.random.rand() < self.v_flip_prob: # v-flip
370
+ img1 = img1[::-1, :]
371
+ img2 = img2[::-1, :]
372
+ flow = flow[::-1, :] * [1.0, -1.0]
373
+
374
+ # In case no cropping
375
+ if img1.shape[0] - self.crop_size[0] > 0:
376
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
377
+ else:
378
+ y0 = 0
379
+ if img1.shape[1] - self.crop_size[1] > 0:
380
+ x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
381
+ else:
382
+ x0 = 0
383
+
384
+ img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
385
+ img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
386
+ flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
387
+
388
+ return img1, img2, flow
389
+
390
+ def __call__(self, img1, img2, flow, dname):
391
+ img1, img2, flow = self.spatial_transform(img1, img2, flow, dname)
392
+ img1, img2 = self.color_transform(img1, img2)
393
+ img1 = np.ascontiguousarray(img1)
394
+ img2 = np.ascontiguousarray(img2)
395
+ flow = np.ascontiguousarray(flow)
396
+ return img1, img2, flow
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/criterion.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Losses, metrics per batch, metrics per dataset
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def _get_gtnorm(gt):
14
+ if gt.size(1) == 1: # stereo
15
+ return gt
16
+ # flow
17
+ return torch.sqrt(torch.sum(gt**2, dim=1, keepdims=True)) # Bx1xHxW
18
+
19
+
20
+ ############ losses without confidence
21
+
22
+
23
+ class L1Loss(nn.Module):
24
+
25
+ def __init__(self, max_gtnorm=None):
26
+ super().__init__()
27
+ self.max_gtnorm = max_gtnorm
28
+ self.with_conf = False
29
+
30
+ def _error(self, gt, predictions):
31
+ return torch.abs(gt - predictions)
32
+
33
+ def forward(self, predictions, gt, inspect=False):
34
+ mask = torch.isfinite(gt)
35
+ if self.max_gtnorm is not None:
36
+ mask *= _get_gtnorm(gt).expand(-1, gt.size(1), -1, -1) < self.max_gtnorm
37
+ if inspect:
38
+ return self._error(gt, predictions)
39
+ return self._error(gt[mask], predictions[mask]).mean()
40
+
41
+
42
+ ############## losses with confience
43
+ ## there are several parametrizations
44
+
45
+
46
+ class LaplacianLoss(nn.Module): # used for CroCo-Stereo on ETH3D, d'=exp(d)
47
+
48
+ def __init__(self, max_gtnorm=None):
49
+ super().__init__()
50
+ self.max_gtnorm = max_gtnorm
51
+ self.with_conf = True
52
+
53
+ def forward(self, predictions, gt, conf):
54
+ mask = torch.isfinite(gt)
55
+ mask = mask[:, 0, :, :]
56
+ if self.max_gtnorm is not None:
57
+ mask *= _get_gtnorm(gt)[:, 0, :, :] < self.max_gtnorm
58
+ conf = conf.squeeze(1)
59
+ return (
60
+ torch.abs(gt - predictions).sum(dim=1)[mask] / torch.exp(conf[mask])
61
+ + conf[mask]
62
+ ).mean() # + torch.log(2) => which is a constant
63
+
64
+
65
+ class LaplacianLossBounded(
66
+ nn.Module
67
+ ): # used for CroCo-Flow ; in the equation of the paper, we have a=1/b
68
+ def __init__(self, max_gtnorm=10000.0, a=0.25, b=4.0):
69
+ super().__init__()
70
+ self.max_gtnorm = max_gtnorm
71
+ self.with_conf = True
72
+ self.a, self.b = a, b
73
+
74
+ def forward(self, predictions, gt, conf):
75
+ mask = torch.isfinite(gt)
76
+ mask = mask[:, 0, :, :]
77
+ if self.max_gtnorm is not None:
78
+ mask *= _get_gtnorm(gt)[:, 0, :, :] < self.max_gtnorm
79
+ conf = conf.squeeze(1)
80
+ conf = (self.b - self.a) * torch.sigmoid(conf) + self.a
81
+ return (
82
+ torch.abs(gt - predictions).sum(dim=1)[mask] / conf[mask]
83
+ + torch.log(conf)[mask]
84
+ ).mean() # + torch.log(2) => which is a constant
85
+
86
+
87
+ class LaplacianLossBounded2(
88
+ nn.Module
89
+ ): # used for CroCo-Stereo (except for ETH3D) ; in the equation of the paper, we have a=b
90
+ def __init__(self, max_gtnorm=None, a=3.0, b=3.0):
91
+ super().__init__()
92
+ self.max_gtnorm = max_gtnorm
93
+ self.with_conf = True
94
+ self.a, self.b = a, b
95
+
96
+ def forward(self, predictions, gt, conf):
97
+ mask = torch.isfinite(gt)
98
+ mask = mask[:, 0, :, :]
99
+ if self.max_gtnorm is not None:
100
+ mask *= _get_gtnorm(gt)[:, 0, :, :] < self.max_gtnorm
101
+ conf = conf.squeeze(1)
102
+ conf = 2 * self.a * (torch.sigmoid(conf / self.b) - 0.5)
103
+ return (
104
+ torch.abs(gt - predictions).sum(dim=1)[mask] / torch.exp(conf[mask])
105
+ + conf[mask]
106
+ ).mean() # + torch.log(2) => which is a constant
107
+
108
+
109
+ ############## metrics per batch
110
+
111
+
112
+ class StereoMetrics(nn.Module):
113
+
114
+ def __init__(self, do_quantile=False):
115
+ super().__init__()
116
+ self.bad_ths = [0.5, 1, 2, 3]
117
+ self.do_quantile = do_quantile
118
+
119
+ def forward(self, predictions, gt):
120
+ B = predictions.size(0)
121
+ metrics = {}
122
+ gtcopy = gt.clone()
123
+ mask = torch.isfinite(gtcopy)
124
+ gtcopy[~mask] = (
125
+ 999999.0 # we make a copy and put a non-infinite value, such that it does not become nan once multiplied by the mask value 0
126
+ )
127
+ Npx = mask.view(B, -1).sum(dim=1)
128
+ L1error = (torch.abs(gtcopy - predictions) * mask).view(B, -1)
129
+ L2error = (torch.square(gtcopy - predictions) * mask).view(B, -1)
130
+ # avgerr
131
+ metrics["avgerr"] = torch.mean(L1error.sum(dim=1) / Npx)
132
+ # rmse
133
+ metrics["rmse"] = torch.sqrt(L2error.sum(dim=1) / Npx).mean(dim=0)
134
+ # err > t for t in [0.5,1,2,3]
135
+ for ths in self.bad_ths:
136
+ metrics["bad@{:.1f}".format(ths)] = (
137
+ ((L1error > ths) * mask.view(B, -1)).sum(dim=1) / Npx
138
+ ).mean(dim=0) * 100
139
+ return metrics
140
+
141
+
142
+ class FlowMetrics(nn.Module):
143
+ def __init__(self):
144
+ super().__init__()
145
+ self.bad_ths = [1, 3, 5]
146
+
147
+ def forward(self, predictions, gt):
148
+ B = predictions.size(0)
149
+ metrics = {}
150
+ mask = torch.isfinite(gt[:, 0, :, :]) # both x and y would be infinite
151
+ Npx = mask.view(B, -1).sum(dim=1)
152
+ gtcopy = (
153
+ gt.clone()
154
+ ) # to compute L1/L2 error, we need to have non-infinite value, the error computed at this locations will be ignored
155
+ gtcopy[:, 0, :, :][~mask] = 999999.0
156
+ gtcopy[:, 1, :, :][~mask] = 999999.0
157
+ L1error = (torch.abs(gtcopy - predictions).sum(dim=1) * mask).view(B, -1)
158
+ L2error = (
159
+ torch.sqrt(torch.sum(torch.square(gtcopy - predictions), dim=1)) * mask
160
+ ).view(B, -1)
161
+ metrics["L1err"] = torch.mean(L1error.sum(dim=1) / Npx)
162
+ metrics["EPE"] = torch.mean(L2error.sum(dim=1) / Npx)
163
+ for ths in self.bad_ths:
164
+ metrics["bad@{:.1f}".format(ths)] = (
165
+ ((L2error > ths) * mask.view(B, -1)).sum(dim=1) / Npx
166
+ ).mean(dim=0) * 100
167
+ return metrics
168
+
169
+
170
+ ############## metrics per dataset
171
+ ## we update the average and maintain the number of pixels while adding data batch per batch
172
+ ## at the beggining, call reset()
173
+ ## after each batch, call add_batch(...)
174
+ ## at the end: call get_results()
175
+
176
+
177
+ class StereoDatasetMetrics(nn.Module):
178
+
179
+ def __init__(self):
180
+ super().__init__()
181
+ self.bad_ths = [0.5, 1, 2, 3]
182
+
183
+ def reset(self):
184
+ self.agg_N = 0 # number of pixels so far
185
+ self.agg_L1err = torch.tensor(0.0) # L1 error so far
186
+ self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels
187
+ self._metrics = None
188
+
189
+ def add_batch(self, predictions, gt):
190
+ assert predictions.size(1) == 1, predictions.size()
191
+ assert gt.size(1) == 1, gt.size()
192
+ if (
193
+ gt.size(2) == predictions.size(2) * 2
194
+ and gt.size(3) == predictions.size(3) * 2
195
+ ): # special case for Spring ...
196
+ L1err = torch.minimum(
197
+ torch.minimum(
198
+ torch.minimum(
199
+ torch.sum(torch.abs(gt[:, :, 0::2, 0::2] - predictions), dim=1),
200
+ torch.sum(torch.abs(gt[:, :, 1::2, 0::2] - predictions), dim=1),
201
+ ),
202
+ torch.sum(torch.abs(gt[:, :, 0::2, 1::2] - predictions), dim=1),
203
+ ),
204
+ torch.sum(torch.abs(gt[:, :, 1::2, 1::2] - predictions), dim=1),
205
+ )
206
+ valid = torch.isfinite(L1err)
207
+ else:
208
+ valid = torch.isfinite(gt[:, 0, :, :]) # both x and y would be infinite
209
+ L1err = torch.sum(torch.abs(gt - predictions), dim=1)
210
+ N = valid.sum()
211
+ Nnew = self.agg_N + N
212
+ self.agg_L1err = (
213
+ float(self.agg_N) / Nnew * self.agg_L1err
214
+ + L1err[valid].mean().cpu() * float(N) / Nnew
215
+ )
216
+ self.agg_N = Nnew
217
+ for i, th in enumerate(self.bad_ths):
218
+ self.agg_Nbad[i] += (L1err[valid] > th).sum().cpu()
219
+
220
+ def _compute_metrics(self):
221
+ if self._metrics is not None:
222
+ return
223
+ out = {}
224
+ out["L1err"] = self.agg_L1err.item()
225
+ for i, th in enumerate(self.bad_ths):
226
+ out["bad@{:.1f}".format(th)] = (
227
+ float(self.agg_Nbad[i]) / self.agg_N
228
+ ).item() * 100.0
229
+ self._metrics = out
230
+
231
+ def get_results(self):
232
+ self._compute_metrics() # to avoid recompute them multiple times
233
+ return self._metrics
234
+
235
+
236
+ class FlowDatasetMetrics(nn.Module):
237
+
238
+ def __init__(self):
239
+ super().__init__()
240
+ self.bad_ths = [0.5, 1, 3, 5]
241
+ self.speed_ths = [(0, 10), (10, 40), (40, torch.inf)]
242
+
243
+ def reset(self):
244
+ self.agg_N = 0 # number of pixels so far
245
+ self.agg_L1err = torch.tensor(0.0) # L1 error so far
246
+ self.agg_L2err = torch.tensor(0.0) # L2 (=EPE) error so far
247
+ self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels
248
+ self.agg_EPEspeed = [
249
+ torch.tensor(0.0) for _ in self.speed_ths
250
+ ] # EPE per speed bin so far
251
+ self.agg_Nspeed = [0 for _ in self.speed_ths] # N pixels per speed bin so far
252
+ self._metrics = None
253
+ self.pairname_results = {}
254
+
255
+ def add_batch(self, predictions, gt):
256
+ assert predictions.size(1) == 2, predictions.size()
257
+ assert gt.size(1) == 2, gt.size()
258
+ if (
259
+ gt.size(2) == predictions.size(2) * 2
260
+ and gt.size(3) == predictions.size(3) * 2
261
+ ): # special case for Spring ...
262
+ L1err = torch.minimum(
263
+ torch.minimum(
264
+ torch.minimum(
265
+ torch.sum(torch.abs(gt[:, :, 0::2, 0::2] - predictions), dim=1),
266
+ torch.sum(torch.abs(gt[:, :, 1::2, 0::2] - predictions), dim=1),
267
+ ),
268
+ torch.sum(torch.abs(gt[:, :, 0::2, 1::2] - predictions), dim=1),
269
+ ),
270
+ torch.sum(torch.abs(gt[:, :, 1::2, 1::2] - predictions), dim=1),
271
+ )
272
+ L2err = torch.minimum(
273
+ torch.minimum(
274
+ torch.minimum(
275
+ torch.sqrt(
276
+ torch.sum(
277
+ torch.square(gt[:, :, 0::2, 0::2] - predictions), dim=1
278
+ )
279
+ ),
280
+ torch.sqrt(
281
+ torch.sum(
282
+ torch.square(gt[:, :, 1::2, 0::2] - predictions), dim=1
283
+ )
284
+ ),
285
+ ),
286
+ torch.sqrt(
287
+ torch.sum(
288
+ torch.square(gt[:, :, 0::2, 1::2] - predictions), dim=1
289
+ )
290
+ ),
291
+ ),
292
+ torch.sqrt(
293
+ torch.sum(torch.square(gt[:, :, 1::2, 1::2] - predictions), dim=1)
294
+ ),
295
+ )
296
+ valid = torch.isfinite(L1err)
297
+ gtspeed = (
298
+ torch.sqrt(torch.sum(torch.square(gt[:, :, 0::2, 0::2]), dim=1))
299
+ + torch.sqrt(torch.sum(torch.square(gt[:, :, 0::2, 1::2]), dim=1))
300
+ + torch.sqrt(torch.sum(torch.square(gt[:, :, 1::2, 0::2]), dim=1))
301
+ + torch.sqrt(torch.sum(torch.square(gt[:, :, 1::2, 1::2]), dim=1))
302
+ ) / 4.0 # let's just average them
303
+ else:
304
+ valid = torch.isfinite(gt[:, 0, :, :]) # both x and y would be infinite
305
+ L1err = torch.sum(torch.abs(gt - predictions), dim=1)
306
+ L2err = torch.sqrt(torch.sum(torch.square(gt - predictions), dim=1))
307
+ gtspeed = torch.sqrt(torch.sum(torch.square(gt), dim=1))
308
+ N = valid.sum()
309
+ Nnew = self.agg_N + N
310
+ self.agg_L1err = (
311
+ float(self.agg_N) / Nnew * self.agg_L1err
312
+ + L1err[valid].mean().cpu() * float(N) / Nnew
313
+ )
314
+ self.agg_L2err = (
315
+ float(self.agg_N) / Nnew * self.agg_L2err
316
+ + L2err[valid].mean().cpu() * float(N) / Nnew
317
+ )
318
+ self.agg_N = Nnew
319
+ for i, th in enumerate(self.bad_ths):
320
+ self.agg_Nbad[i] += (L2err[valid] > th).sum().cpu()
321
+ for i, (th1, th2) in enumerate(self.speed_ths):
322
+ vv = (gtspeed[valid] >= th1) * (gtspeed[valid] < th2)
323
+ iNspeed = vv.sum()
324
+ if iNspeed == 0:
325
+ continue
326
+ iNnew = self.agg_Nspeed[i] + iNspeed
327
+ self.agg_EPEspeed[i] = (
328
+ float(self.agg_Nspeed[i]) / iNnew * self.agg_EPEspeed[i]
329
+ + float(iNspeed) / iNnew * L2err[valid][vv].mean().cpu()
330
+ )
331
+ self.agg_Nspeed[i] = iNnew
332
+
333
+ def _compute_metrics(self):
334
+ if self._metrics is not None:
335
+ return
336
+ out = {}
337
+ out["L1err"] = self.agg_L1err.item()
338
+ out["EPE"] = self.agg_L2err.item()
339
+ for i, th in enumerate(self.bad_ths):
340
+ out["bad@{:.1f}".format(th)] = (
341
+ float(self.agg_Nbad[i]) / self.agg_N
342
+ ).item() * 100.0
343
+ for i, (th1, th2) in enumerate(self.speed_ths):
344
+ out["s{:d}{:s}".format(th1, "-" + str(th2) if th2 < torch.inf else "+")] = (
345
+ self.agg_EPEspeed[i].item()
346
+ )
347
+ self._metrics = out
348
+
349
+ def get_results(self):
350
+ self._compute_metrics() # to avoid recompute them multiple times
351
+ return self._metrics
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/datasets_flow.py ADDED
@@ -0,0 +1,936 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Dataset structure for flow
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import os.path as osp
10
+ import pickle
11
+ import numpy as np
12
+ import struct
13
+ from PIL import Image
14
+ import json
15
+ import h5py
16
+ import torch
17
+ from torch.utils import data
18
+
19
+ from .augmentor import FlowAugmentor
20
+ from .datasets_stereo import _read_img, img_to_tensor, dataset_to_root, _read_pfm
21
+ from copy import deepcopy
22
+
23
+ dataset_to_root = deepcopy(dataset_to_root)
24
+
25
+ dataset_to_root.update(
26
+ **{
27
+ "TartanAir": "./data/stereoflow/TartanAir",
28
+ "FlyingChairs": "./data/stereoflow/FlyingChairs/",
29
+ "FlyingThings": osp.join(dataset_to_root["SceneFlow"], "FlyingThings") + "/",
30
+ "MPISintel": "./data/stereoflow//MPI-Sintel/" + "/",
31
+ }
32
+ )
33
+ cache_dir = "./data/stereoflow/datasets_flow_cache/"
34
+
35
+
36
+ def flow_to_tensor(disp):
37
+ return torch.from_numpy(disp).float().permute(2, 0, 1)
38
+
39
+
40
+ class FlowDataset(data.Dataset):
41
+
42
+ def __init__(self, split, augmentor=False, crop_size=None, totensor=True):
43
+ self.split = split
44
+ if not augmentor:
45
+ assert crop_size is None
46
+ if crop_size is not None:
47
+ assert augmentor
48
+ self.crop_size = crop_size
49
+ self.augmentor_str = augmentor
50
+ self.augmentor = FlowAugmentor(crop_size) if augmentor else None
51
+ self.totensor = totensor
52
+ self.rmul = 1 # keep track of rmul
53
+ self.has_constant_resolution = True # whether the dataset has constant resolution or not (=> don't use batch_size>1 at test time)
54
+ self._prepare_data()
55
+ self._load_or_build_cache()
56
+
57
+ def prepare_data(self):
58
+ """
59
+ to be defined for each dataset
60
+ """
61
+ raise NotImplementedError
62
+
63
+ def __len__(self):
64
+ return len(
65
+ self.pairnames
66
+ ) # each pairname is typically of the form (str, int1, int2)
67
+
68
+ def __getitem__(self, index):
69
+ pairname = self.pairnames[index]
70
+
71
+ # get filenames
72
+ img1name = self.pairname_to_img1name(pairname)
73
+ img2name = self.pairname_to_img2name(pairname)
74
+ flowname = (
75
+ self.pairname_to_flowname(pairname)
76
+ if self.pairname_to_flowname is not None
77
+ else None
78
+ )
79
+
80
+ # load images and disparities
81
+ img1 = _read_img(img1name)
82
+ img2 = _read_img(img2name)
83
+ flow = self.load_flow(flowname) if flowname is not None else None
84
+
85
+ # apply augmentations
86
+ if self.augmentor is not None:
87
+ img1, img2, flow = self.augmentor(img1, img2, flow, self.name)
88
+
89
+ if self.totensor:
90
+ img1 = img_to_tensor(img1)
91
+ img2 = img_to_tensor(img2)
92
+ if flow is not None:
93
+ flow = flow_to_tensor(flow)
94
+ else:
95
+ flow = torch.tensor(
96
+ []
97
+ ) # to allow dataloader batching with default collate_gn
98
+ pairname = str(
99
+ pairname
100
+ ) # transform potential tuple to str to be able to batch it
101
+
102
+ return img1, img2, flow, pairname
103
+
104
+ def __rmul__(self, v):
105
+ self.rmul *= v
106
+ self.pairnames = v * self.pairnames
107
+ return self
108
+
109
+ def __str__(self):
110
+ return f"{self.__class__.__name__}_{self.split}"
111
+
112
+ def __repr__(self):
113
+ s = f"{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})"
114
+ if self.rmul == 1:
115
+ s += f"\n\tnum pairs: {len(self.pairnames)}"
116
+ else:
117
+ s += f"\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})"
118
+ return s
119
+
120
+ def _set_root(self):
121
+ self.root = dataset_to_root[self.name]
122
+ assert os.path.isdir(
123
+ self.root
124
+ ), f"could not find root directory for dataset {self.name}: {self.root}"
125
+
126
+ def _load_or_build_cache(self):
127
+ cache_file = osp.join(cache_dir, self.name + ".pkl")
128
+ if osp.isfile(cache_file):
129
+ with open(cache_file, "rb") as fid:
130
+ self.pairnames = pickle.load(fid)[self.split]
131
+ else:
132
+ tosave = self._build_cache()
133
+ os.makedirs(cache_dir, exist_ok=True)
134
+ with open(cache_file, "wb") as fid:
135
+ pickle.dump(tosave, fid)
136
+ self.pairnames = tosave[self.split]
137
+
138
+
139
+ class TartanAirDataset(FlowDataset):
140
+
141
+ def _prepare_data(self):
142
+ self.name = "TartanAir"
143
+ self._set_root()
144
+ assert self.split in ["train"]
145
+ self.pairname_to_img1name = lambda pairname: osp.join(
146
+ self.root, pairname[0], "image_left/{:06d}_left.png".format(pairname[1])
147
+ )
148
+ self.pairname_to_img2name = lambda pairname: osp.join(
149
+ self.root, pairname[0], "image_left/{:06d}_left.png".format(pairname[2])
150
+ )
151
+ self.pairname_to_flowname = lambda pairname: osp.join(
152
+ self.root,
153
+ pairname[0],
154
+ "flow/{:06d}_{:06d}_flow.npy".format(pairname[1], pairname[2]),
155
+ )
156
+ self.pairname_to_str = lambda pairname: os.path.join(
157
+ pairname[0][pairname[0].find("/") + 1 :],
158
+ "{:06d}_{:06d}".format(pairname[1], pairname[2]),
159
+ )
160
+ self.load_flow = _read_numpy_flow
161
+
162
+ def _build_cache(self):
163
+ seqs = sorted(os.listdir(self.root))
164
+ pairs = [
165
+ (osp.join(s, s, difficulty, Pxxx), int(a[:6]), int(a[:6]) + 1)
166
+ for s in seqs
167
+ for difficulty in ["Easy", "Hard"]
168
+ for Pxxx in sorted(os.listdir(osp.join(self.root, s, s, difficulty)))
169
+ for a in sorted(
170
+ os.listdir(osp.join(self.root, s, s, difficulty, Pxxx, "image_left/"))
171
+ )[:-1]
172
+ ]
173
+ assert len(pairs) == 306268, "incorrect parsing of pairs in TartanAir"
174
+ tosave = {"train": pairs}
175
+ return tosave
176
+
177
+
178
+ class FlyingChairsDataset(FlowDataset):
179
+
180
+ def _prepare_data(self):
181
+ self.name = "FlyingChairs"
182
+ self._set_root()
183
+ assert self.split in ["train", "val"]
184
+ self.pairname_to_img1name = lambda pairname: osp.join(
185
+ self.root, "data", pairname + "_img1.ppm"
186
+ )
187
+ self.pairname_to_img2name = lambda pairname: osp.join(
188
+ self.root, "data", pairname + "_img2.ppm"
189
+ )
190
+ self.pairname_to_flowname = lambda pairname: osp.join(
191
+ self.root, "data", pairname + "_flow.flo"
192
+ )
193
+ self.pairname_to_str = lambda pairname: pairname
194
+ self.load_flow = _read_flo_file
195
+
196
+ def _build_cache(self):
197
+ split_file = osp.join(self.root, "chairs_split.txt")
198
+ split_list = np.loadtxt(split_file, dtype=np.int32)
199
+ trainpairs = ["{:05d}".format(i) for i in np.where(split_list == 1)[0] + 1]
200
+ valpairs = ["{:05d}".format(i) for i in np.where(split_list == 2)[0] + 1]
201
+ assert (
202
+ len(trainpairs) == 22232 and len(valpairs) == 640
203
+ ), "incorrect parsing of pairs in MPI-Sintel"
204
+ tosave = {"train": trainpairs, "val": valpairs}
205
+ return tosave
206
+
207
+
208
+ class FlyingThingsDataset(FlowDataset):
209
+
210
+ def _prepare_data(self):
211
+ self.name = "FlyingThings"
212
+ self._set_root()
213
+ assert self.split in [
214
+ f"{set_}_{pass_}pass{camstr}"
215
+ for set_ in ["train", "test", "test1024"]
216
+ for camstr in ["", "_rightcam"]
217
+ for pass_ in ["clean", "final", "all"]
218
+ ]
219
+ self.pairname_to_img1name = lambda pairname: osp.join(
220
+ self.root,
221
+ f"frames_{pairname[3]}pass",
222
+ pairname[0].replace("into_future", "").replace("into_past", ""),
223
+ "{:04d}.png".format(pairname[1]),
224
+ )
225
+ self.pairname_to_img2name = lambda pairname: osp.join(
226
+ self.root,
227
+ f"frames_{pairname[3]}pass",
228
+ pairname[0].replace("into_future", "").replace("into_past", ""),
229
+ "{:04d}.png".format(pairname[2]),
230
+ )
231
+ self.pairname_to_flowname = lambda pairname: osp.join(
232
+ self.root,
233
+ "optical_flow",
234
+ pairname[0],
235
+ "OpticalFlowInto{f:s}_{i:04d}_{c:s}.pfm".format(
236
+ f="Future" if "future" in pairname[0] else "Past",
237
+ i=pairname[1],
238
+ c="L" if "left" in pairname[0] else "R",
239
+ ),
240
+ )
241
+ self.pairname_to_str = lambda pairname: os.path.join(
242
+ pairname[3] + "pass",
243
+ pairname[0],
244
+ "Into{f:s}_{i:04d}_{c:s}".format(
245
+ f="Future" if "future" in pairname[0] else "Past",
246
+ i=pairname[1],
247
+ c="L" if "left" in pairname[0] else "R",
248
+ ),
249
+ )
250
+ self.load_flow = _read_pfm_flow
251
+
252
+ def _build_cache(self):
253
+ tosave = {}
254
+ # train and test splits for the different passes
255
+ for set_ in ["train", "test"]:
256
+ sroot = osp.join(self.root, "optical_flow", set_.upper())
257
+ fname_to_i = lambda f: int(
258
+ f[len("OpticalFlowIntoFuture_") : -len("_L.pfm")]
259
+ )
260
+ pp = [
261
+ (osp.join(set_.upper(), d, s, "into_future/left"), fname_to_i(fname))
262
+ for d in sorted(os.listdir(sroot))
263
+ for s in sorted(os.listdir(osp.join(sroot, d)))
264
+ for fname in sorted(
265
+ os.listdir(osp.join(sroot, d, s, "into_future/left"))
266
+ )[:-1]
267
+ ]
268
+ pairs = [(a, i, i + 1) for a, i in pp]
269
+ pairs += [(a.replace("into_future", "into_past"), i + 1, i) for a, i in pp]
270
+ assert (
271
+ len(pairs) == {"train": 40302, "test": 7866}[set_]
272
+ ), "incorrect parsing of pairs Flying Things"
273
+ for cam in ["left", "right"]:
274
+ camstr = "" if cam == "left" else f"_{cam}cam"
275
+ for pass_ in ["final", "clean"]:
276
+ tosave[f"{set_}_{pass_}pass{camstr}"] = [
277
+ (a.replace("left", cam), i, j, pass_) for a, i, j in pairs
278
+ ]
279
+ tosave[f"{set_}_allpass{camstr}"] = (
280
+ tosave[f"{set_}_cleanpass{camstr}"]
281
+ + tosave[f"{set_}_finalpass{camstr}"]
282
+ )
283
+ # test1024: this is the same split as unimatch 'validation' split
284
+ # see https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/datasets.py#L229
285
+ test1024_nsamples = 1024
286
+ alltest_nsamples = len(tosave["test_cleanpass"]) # 7866
287
+ stride = alltest_nsamples // test1024_nsamples
288
+ remove = alltest_nsamples % test1024_nsamples
289
+ for cam in ["left", "right"]:
290
+ camstr = "" if cam == "left" else f"_{cam}cam"
291
+ for pass_ in ["final", "clean"]:
292
+ tosave[f"test1024_{pass_}pass{camstr}"] = sorted(
293
+ tosave[f"test_{pass_}pass{camstr}"]
294
+ )[:-remove][
295
+ ::stride
296
+ ] # warning, it was not sorted before
297
+ assert (
298
+ len(tosave["test1024_cleanpass"]) == 1024
299
+ ), "incorrect parsing of pairs in Flying Things"
300
+ tosave[f"test1024_allpass{camstr}"] = (
301
+ tosave[f"test1024_cleanpass{camstr}"]
302
+ + tosave[f"test1024_finalpass{camstr}"]
303
+ )
304
+ return tosave
305
+
306
+
307
+ class MPISintelDataset(FlowDataset):
308
+
309
+ def _prepare_data(self):
310
+ self.name = "MPISintel"
311
+ self._set_root()
312
+ assert self.split in [
313
+ s + "_" + p
314
+ for s in ["train", "test", "subval", "subtrain"]
315
+ for p in ["cleanpass", "finalpass", "allpass"]
316
+ ]
317
+ self.pairname_to_img1name = lambda pairname: osp.join(
318
+ self.root, pairname[0], "frame_{:04d}.png".format(pairname[1])
319
+ )
320
+ self.pairname_to_img2name = lambda pairname: osp.join(
321
+ self.root, pairname[0], "frame_{:04d}.png".format(pairname[1] + 1)
322
+ )
323
+ self.pairname_to_flowname = lambda pairname: (
324
+ None
325
+ if pairname[0].startswith("test/")
326
+ else osp.join(
327
+ self.root,
328
+ pairname[0].replace("/clean/", "/flow/").replace("/final/", "/flow/"),
329
+ "frame_{:04d}.flo".format(pairname[1]),
330
+ )
331
+ )
332
+ self.pairname_to_str = lambda pairname: osp.join(
333
+ pairname[0], "frame_{:04d}".format(pairname[1])
334
+ )
335
+ self.load_flow = _read_flo_file
336
+
337
+ def _build_cache(self):
338
+ trainseqs = sorted(os.listdir(self.root + "training/clean"))
339
+ trainpairs = [
340
+ (osp.join("training/clean", s), i)
341
+ for s in trainseqs
342
+ for i in range(1, len(os.listdir(self.root + "training/clean/" + s)))
343
+ ]
344
+ subvalseqs = ["temple_2", "temple_3"]
345
+ subtrainseqs = [s for s in trainseqs if s not in subvalseqs]
346
+ subvalpairs = [(p, i) for p, i in trainpairs if any(s in p for s in subvalseqs)]
347
+ subtrainpairs = [
348
+ (p, i) for p, i in trainpairs if any(s in p for s in subtrainseqs)
349
+ ]
350
+ testseqs = sorted(os.listdir(self.root + "test/clean"))
351
+ testpairs = [
352
+ (osp.join("test/clean", s), i)
353
+ for s in testseqs
354
+ for i in range(1, len(os.listdir(self.root + "test/clean/" + s)))
355
+ ]
356
+ assert (
357
+ len(trainpairs) == 1041
358
+ and len(testpairs) == 552
359
+ and len(subvalpairs) == 98
360
+ and len(subtrainpairs) == 943
361
+ ), "incorrect parsing of pairs in MPI-Sintel"
362
+ tosave = {}
363
+ tosave["train_cleanpass"] = trainpairs
364
+ tosave["test_cleanpass"] = testpairs
365
+ tosave["subval_cleanpass"] = subvalpairs
366
+ tosave["subtrain_cleanpass"] = subtrainpairs
367
+ for t in ["train", "test", "subval", "subtrain"]:
368
+ tosave[t + "_finalpass"] = [
369
+ (p.replace("/clean/", "/final/"), i)
370
+ for p, i in tosave[t + "_cleanpass"]
371
+ ]
372
+ tosave[t + "_allpass"] = tosave[t + "_cleanpass"] + tosave[t + "_finalpass"]
373
+ return tosave
374
+
375
+ def submission_save_pairname(self, pairname, prediction, outdir, _time):
376
+ assert prediction.shape[2] == 2
377
+ outfile = os.path.join(
378
+ outdir, "submission", self.pairname_to_str(pairname) + ".flo"
379
+ )
380
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
381
+ writeFlowFile(prediction, outfile)
382
+
383
+ def finalize_submission(self, outdir):
384
+ assert self.split == "test_allpass"
385
+ bundle_exe = "/nfs/data/ffs-3d/datasets/StereoFlow/MPI-Sintel/bundler/linux-x64/bundler" # eg <bundle_exe> <path_to_results_for_clean> <path_to_results_for_final> <output/bundled.lzma>
386
+ if os.path.isfile(bundle_exe):
387
+ cmd = f'{bundle_exe} "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"'
388
+ print(cmd)
389
+ os.system(cmd)
390
+ print(f'Done. Submission file at: "{outdir}/submission/bundled.lzma"')
391
+ else:
392
+ print("Could not find bundler executable for submission.")
393
+ print("Please download it and run:")
394
+ print(
395
+ f'<bundle_exe> "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"'
396
+ )
397
+
398
+
399
+ class SpringDataset(FlowDataset):
400
+
401
+ def _prepare_data(self):
402
+ self.name = "Spring"
403
+ self._set_root()
404
+ assert self.split in ["train", "test", "subtrain", "subval"]
405
+ self.pairname_to_img1name = lambda pairname: osp.join(
406
+ self.root,
407
+ pairname[0],
408
+ pairname[1],
409
+ "frame_" + pairname[3],
410
+ "frame_{:s}_{:04d}.png".format(pairname[3], pairname[4]),
411
+ )
412
+ self.pairname_to_img2name = lambda pairname: osp.join(
413
+ self.root,
414
+ pairname[0],
415
+ pairname[1],
416
+ "frame_" + pairname[3],
417
+ "frame_{:s}_{:04d}.png".format(
418
+ pairname[3], pairname[4] + (1 if pairname[2] == "FW" else -1)
419
+ ),
420
+ )
421
+ self.pairname_to_flowname = lambda pairname: (
422
+ None
423
+ if pairname[0] == "test"
424
+ else osp.join(
425
+ self.root,
426
+ pairname[0],
427
+ pairname[1],
428
+ f"flow_{pairname[2]}_{pairname[3]}",
429
+ f"flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5",
430
+ )
431
+ )
432
+ self.pairname_to_str = lambda pairname: osp.join(
433
+ pairname[0],
434
+ pairname[1],
435
+ f"flow_{pairname[2]}_{pairname[3]}",
436
+ f"flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}",
437
+ )
438
+ self.load_flow = _read_hdf5_flow
439
+
440
+ def _build_cache(self):
441
+ # train
442
+ trainseqs = sorted(os.listdir(osp.join(self.root, "train")))
443
+ trainpairs = []
444
+ for leftright in ["left", "right"]:
445
+ for fwbw in ["FW", "BW"]:
446
+ trainpairs += [
447
+ (
448
+ "train",
449
+ s,
450
+ fwbw,
451
+ leftright,
452
+ int(f[len(f"flow_{fwbw}_{leftright}_") : -len(".flo5")]),
453
+ )
454
+ for s in trainseqs
455
+ for f in sorted(
456
+ os.listdir(
457
+ osp.join(self.root, "train", s, f"flow_{fwbw}_{leftright}")
458
+ )
459
+ )
460
+ ]
461
+ # test
462
+ testseqs = sorted(os.listdir(osp.join(self.root, "test")))
463
+ testpairs = []
464
+ for leftright in ["left", "right"]:
465
+ testpairs += [
466
+ (
467
+ "test",
468
+ s,
469
+ "FW",
470
+ leftright,
471
+ int(f[len(f"frame_{leftright}_") : -len(".png")]),
472
+ )
473
+ for s in testseqs
474
+ for f in sorted(
475
+ os.listdir(osp.join(self.root, "test", s, f"frame_{leftright}"))
476
+ )[:-1]
477
+ ]
478
+ testpairs += [
479
+ (
480
+ "test",
481
+ s,
482
+ "BW",
483
+ leftright,
484
+ int(f[len(f"frame_{leftright}_") : -len(".png")]) + 1,
485
+ )
486
+ for s in testseqs
487
+ for f in sorted(
488
+ os.listdir(osp.join(self.root, "test", s, f"frame_{leftright}"))
489
+ )[:-1]
490
+ ]
491
+ # subtrain / subval
492
+ subtrainpairs = [p for p in trainpairs if p[1] != "0041"]
493
+ subvalpairs = [p for p in trainpairs if p[1] == "0041"]
494
+ assert (
495
+ len(trainpairs) == 19852
496
+ and len(testpairs) == 3960
497
+ and len(subtrainpairs) == 19472
498
+ and len(subvalpairs) == 380
499
+ ), "incorrect parsing of pairs in Spring"
500
+ tosave = {
501
+ "train": trainpairs,
502
+ "test": testpairs,
503
+ "subtrain": subtrainpairs,
504
+ "subval": subvalpairs,
505
+ }
506
+ return tosave
507
+
508
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
509
+ assert prediction.ndim == 3
510
+ assert prediction.shape[2] == 2
511
+ assert prediction.dtype == np.float32
512
+ outfile = osp.join(
513
+ outdir,
514
+ pairname[0],
515
+ pairname[1],
516
+ f"flow_{pairname[2]}_{pairname[3]}",
517
+ f"flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5",
518
+ )
519
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
520
+ writeFlo5File(prediction, outfile)
521
+
522
+ def finalize_submission(self, outdir):
523
+ assert self.split == "test"
524
+ exe = "{self.root}/flow_subsampling"
525
+ if os.path.isfile(exe):
526
+ cmd = f'cd "{outdir}/test"; {exe} .'
527
+ print(cmd)
528
+ os.system(cmd)
529
+ print(f"Done. Submission file at {outdir}/test/flow_submission.hdf5")
530
+ else:
531
+ print("Could not find flow_subsampling executable for submission.")
532
+ print("Please download it and run:")
533
+ print(f'cd "{outdir}/test"; <flow_subsampling_exe> .')
534
+
535
+
536
+ class Kitti12Dataset(FlowDataset):
537
+
538
+ def _prepare_data(self):
539
+ self.name = "Kitti12"
540
+ self._set_root()
541
+ assert self.split in ["train", "test"]
542
+ self.pairname_to_img1name = lambda pairname: osp.join(
543
+ self.root, pairname + "_10.png"
544
+ )
545
+ self.pairname_to_img2name = lambda pairname: osp.join(
546
+ self.root, pairname + "_11.png"
547
+ )
548
+ self.pairname_to_flowname = (
549
+ None
550
+ if self.split == "test"
551
+ else lambda pairname: osp.join(
552
+ self.root, pairname.replace("/colored_0/", "/flow_occ/") + "_10.png"
553
+ )
554
+ )
555
+ self.pairname_to_str = lambda pairname: pairname.replace("/colored_0/", "/")
556
+ self.load_flow = _read_kitti_flow
557
+
558
+ def _build_cache(self):
559
+ trainseqs = ["training/colored_0/%06d" % (i) for i in range(194)]
560
+ testseqs = ["testing/colored_0/%06d" % (i) for i in range(195)]
561
+ assert (
562
+ len(trainseqs) == 194 and len(testseqs) == 195
563
+ ), "incorrect parsing of pairs in Kitti12"
564
+ tosave = {"train": trainseqs, "test": testseqs}
565
+ return tosave
566
+
567
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
568
+ assert prediction.ndim == 3
569
+ assert prediction.shape[2] == 2
570
+ outfile = os.path.join(outdir, pairname.split("/")[-1] + "_10.png")
571
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
572
+ writeFlowKitti(outfile, prediction)
573
+
574
+ def finalize_submission(self, outdir):
575
+ assert self.split == "test"
576
+ cmd = f'cd {outdir}/; zip -r "kitti12_flow_results.zip" .'
577
+ print(cmd)
578
+ os.system(cmd)
579
+ print(f"Done. Submission file at {outdir}/kitti12_flow_results.zip")
580
+
581
+
582
+ class Kitti15Dataset(FlowDataset):
583
+
584
+ def _prepare_data(self):
585
+ self.name = "Kitti15"
586
+ self._set_root()
587
+ assert self.split in ["train", "subtrain", "subval", "test"]
588
+ self.pairname_to_img1name = lambda pairname: osp.join(
589
+ self.root, pairname + "_10.png"
590
+ )
591
+ self.pairname_to_img2name = lambda pairname: osp.join(
592
+ self.root, pairname + "_11.png"
593
+ )
594
+ self.pairname_to_flowname = (
595
+ None
596
+ if self.split == "test"
597
+ else lambda pairname: osp.join(
598
+ self.root, pairname.replace("/image_2/", "/flow_occ/") + "_10.png"
599
+ )
600
+ )
601
+ self.pairname_to_str = lambda pairname: pairname.replace("/image_2/", "/")
602
+ self.load_flow = _read_kitti_flow
603
+
604
+ def _build_cache(self):
605
+ trainseqs = ["training/image_2/%06d" % (i) for i in range(200)]
606
+ subtrainseqs = trainseqs[:-10]
607
+ subvalseqs = trainseqs[-10:]
608
+ testseqs = ["testing/image_2/%06d" % (i) for i in range(200)]
609
+ assert (
610
+ len(trainseqs) == 200
611
+ and len(subtrainseqs) == 190
612
+ and len(subvalseqs) == 10
613
+ and len(testseqs) == 200
614
+ ), "incorrect parsing of pairs in Kitti15"
615
+ tosave = {
616
+ "train": trainseqs,
617
+ "subtrain": subtrainseqs,
618
+ "subval": subvalseqs,
619
+ "test": testseqs,
620
+ }
621
+ return tosave
622
+
623
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
624
+ assert prediction.ndim == 3
625
+ assert prediction.shape[2] == 2
626
+ outfile = os.path.join(outdir, "flow", pairname.split("/")[-1] + "_10.png")
627
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
628
+ writeFlowKitti(outfile, prediction)
629
+
630
+ def finalize_submission(self, outdir):
631
+ assert self.split == "test"
632
+ cmd = f'cd {outdir}/; zip -r "kitti15_flow_results.zip" flow'
633
+ print(cmd)
634
+ os.system(cmd)
635
+ print(f"Done. Submission file at {outdir}/kitti15_flow_results.zip")
636
+
637
+
638
+ import cv2
639
+
640
+
641
+ def _read_numpy_flow(filename):
642
+ return np.load(filename)
643
+
644
+
645
+ def _read_pfm_flow(filename):
646
+ f, _ = _read_pfm(filename)
647
+ assert np.all(f[:, :, 2] == 0.0)
648
+ return np.ascontiguousarray(f[:, :, :2])
649
+
650
+
651
+ TAG_FLOAT = 202021.25 # tag to check the sanity of the file
652
+ TAG_STRING = "PIEH" # string containing the tag
653
+ MIN_WIDTH = 1
654
+ MAX_WIDTH = 99999
655
+ MIN_HEIGHT = 1
656
+ MAX_HEIGHT = 99999
657
+
658
+
659
+ def readFlowFile(filename):
660
+ """
661
+ readFlowFile(<FILENAME>) reads a flow file <FILENAME> into a 2-band np.array.
662
+ if <FILENAME> does not exist, an IOError is raised.
663
+ if <FILENAME> does not finish by '.flo' or the tag, the width, the height or the file's size is illegal, an Expcetion is raised.
664
+ ---- PARAMETERS ----
665
+ filename: string containg the name of the file to read a flow
666
+ ---- OUTPUTS ----
667
+ a np.array of dimension (height x width x 2) containing the flow of type 'float32'
668
+ """
669
+
670
+ # check filename
671
+ if not filename.endswith(".flo"):
672
+ raise Exception(
673
+ "readFlowFile({:s}): filename must finish with '.flo'".format(filename)
674
+ )
675
+
676
+ # open the file and read it
677
+ with open(filename, "rb") as f:
678
+ # check tag
679
+ tag = struct.unpack("f", f.read(4))[0]
680
+ if tag != TAG_FLOAT:
681
+ raise Exception("flow_utils.readFlowFile({:s}): wrong tag".format(filename))
682
+ # read dimension
683
+ w, h = struct.unpack("ii", f.read(8))
684
+ if w < MIN_WIDTH or w > MAX_WIDTH:
685
+ raise Exception(
686
+ "flow_utils.readFlowFile({:s}: illegal width {:d}".format(filename, w)
687
+ )
688
+ if h < MIN_HEIGHT or h > MAX_HEIGHT:
689
+ raise Exception(
690
+ "flow_utils.readFlowFile({:s}: illegal height {:d}".format(filename, h)
691
+ )
692
+ flow = np.fromfile(f, "float32")
693
+ if not flow.shape == (h * w * 2,):
694
+ raise Exception(
695
+ "flow_utils.readFlowFile({:s}: illegal size of the file".format(
696
+ filename
697
+ )
698
+ )
699
+ flow.shape = (h, w, 2)
700
+ return flow
701
+
702
+
703
+ def writeFlowFile(flow, filename):
704
+ """
705
+ writeFlowFile(flow,<FILENAME>) write flow to the file <FILENAME>.
706
+ if <FILENAME> does not exist, an IOError is raised.
707
+ if <FILENAME> does not finish with '.flo' or the flow has not 2 bands, an Exception is raised.
708
+ ---- PARAMETERS ----
709
+ flow: np.array of dimension (height x width x 2) containing the flow to write
710
+ filename: string containg the name of the file to write a flow
711
+ """
712
+
713
+ # check filename
714
+ if not filename.endswith(".flo"):
715
+ raise Exception(
716
+ "flow_utils.writeFlowFile(<flow>,{:s}): filename must finish with '.flo'".format(
717
+ filename
718
+ )
719
+ )
720
+
721
+ if not flow.shape[2:] == (2,):
722
+ raise Exception(
723
+ "flow_utils.writeFlowFile(<flow>,{:s}): <flow> must have 2 bands".format(
724
+ filename
725
+ )
726
+ )
727
+
728
+ # open the file and write it
729
+ with open(filename, "wb") as f:
730
+ # write TAG
731
+ f.write(TAG_STRING.encode("utf-8"))
732
+ # write dimension
733
+ f.write(struct.pack("ii", flow.shape[1], flow.shape[0]))
734
+ # write the flow
735
+
736
+ flow.astype(np.float32).tofile(f)
737
+
738
+
739
+ _read_flo_file = readFlowFile
740
+
741
+
742
+ def _read_kitti_flow(filename):
743
+ flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
744
+ flow = flow[:, :, ::-1].astype(np.float32)
745
+ valid = flow[:, :, 2] > 0
746
+ flow = flow[:, :, :2]
747
+ flow = (flow - 2**15) / 64.0
748
+ flow[~valid, 0] = np.inf
749
+ flow[~valid, 1] = np.inf
750
+ return flow
751
+
752
+
753
+ _read_hd1k_flow = _read_kitti_flow
754
+
755
+
756
+ def writeFlowKitti(filename, uv):
757
+ uv = 64.0 * uv + 2**15
758
+ valid = np.ones([uv.shape[0], uv.shape[1], 1])
759
+ uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
760
+ cv2.imwrite(filename, uv[..., ::-1])
761
+
762
+
763
+ def writeFlo5File(flow, filename):
764
+ with h5py.File(filename, "w") as f:
765
+ f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5)
766
+
767
+
768
+ def _read_hdf5_flow(filename):
769
+ flow = np.asarray(h5py.File(filename)["flow"])
770
+ flow[np.isnan(flow)] = np.inf # make invalid values as +inf
771
+ return flow.astype(np.float32)
772
+
773
+
774
+ # flow visualization
775
+ RY = 15
776
+ YG = 6
777
+ GC = 4
778
+ CB = 11
779
+ BM = 13
780
+ MR = 6
781
+ UNKNOWN_THRESH = 1e9
782
+
783
+
784
+ def colorTest():
785
+ """
786
+ flow_utils.colorTest(): display an example of image showing the color encoding scheme
787
+ """
788
+ import matplotlib.pylab as plt
789
+
790
+ truerange = 1
791
+ h, w = 151, 151
792
+ trange = truerange * 1.04
793
+ s2 = round(h / 2)
794
+ x, y = np.meshgrid(range(w), range(h))
795
+ u = x * trange / s2 - trange
796
+ v = y * trange / s2 - trange
797
+ img = _computeColor(
798
+ np.concatenate((u[:, :, np.newaxis], v[:, :, np.newaxis]), 2)
799
+ / trange
800
+ / np.sqrt(2)
801
+ )
802
+ plt.imshow(img)
803
+ plt.axis("off")
804
+ plt.axhline(round(h / 2), color="k")
805
+ plt.axvline(round(w / 2), color="k")
806
+
807
+
808
+ def flowToColor(flow, maxflow=None, maxmaxflow=None, saturate=False):
809
+ """
810
+ flow_utils.flowToColor(flow): return a color code flow field, normalized based on the maximum l2-norm of the flow
811
+ flow_utils.flowToColor(flow,maxflow): return a color code flow field, normalized by maxflow
812
+ ---- PARAMETERS ----
813
+ flow: flow to display of shape (height x width x 2)
814
+ maxflow (default:None): if given, normalize the flow by its value, otherwise by the flow norm
815
+ maxmaxflow (default:None): if given, normalize the flow by the max of its value and the flow norm
816
+ ---- OUTPUT ----
817
+ an np.array of shape (height x width x 3) of type uint8 containing a color code of the flow
818
+ """
819
+ h, w, n = flow.shape
820
+ # check size of flow
821
+ assert n == 2, "flow_utils.flowToColor(flow): flow must have 2 bands"
822
+ # fix unknown flow
823
+ unknown_idx = np.max(np.abs(flow), 2) > UNKNOWN_THRESH
824
+ flow[unknown_idx] = 0.0
825
+ # compute max flow if needed
826
+ if maxflow is None:
827
+ maxflow = flowMaxNorm(flow)
828
+ if maxmaxflow is not None:
829
+ maxflow = min(maxmaxflow, maxflow)
830
+ # normalize flow
831
+ eps = np.spacing(1) # minimum positive float value to avoid division by 0
832
+ # compute the flow
833
+ img = _computeColor(flow / (maxflow + eps), saturate=saturate)
834
+ # put black pixels in unknown location
835
+ img[np.tile(unknown_idx[:, :, np.newaxis], [1, 1, 3])] = 0.0
836
+ return img
837
+
838
+
839
+ def flowMaxNorm(flow):
840
+ """
841
+ flow_utils.flowMaxNorm(flow): return the maximum of the l2-norm of the given flow
842
+ ---- PARAMETERS ----
843
+ flow: the flow
844
+
845
+ ---- OUTPUT ----
846
+ a float containing the maximum of the l2-norm of the flow
847
+ """
848
+ return np.max(np.sqrt(np.sum(np.square(flow), 2)))
849
+
850
+
851
+ def _computeColor(flow, saturate=True):
852
+ """
853
+ flow_utils._computeColor(flow): compute color codes for the flow field flow
854
+
855
+ ---- PARAMETERS ----
856
+ flow: np.array of dimension (height x width x 2) containing the flow to display
857
+ ---- OUTPUTS ----
858
+ an np.array of dimension (height x width x 3) containing the color conversion of the flow
859
+ """
860
+ # set nan to 0
861
+ nanidx = np.isnan(flow[:, :, 0])
862
+ flow[nanidx] = 0.0
863
+
864
+ # colorwheel
865
+ ncols = RY + YG + GC + CB + BM + MR
866
+ nchans = 3
867
+ colorwheel = np.zeros((ncols, nchans), "uint8")
868
+ col = 0
869
+ # RY
870
+ colorwheel[:RY, 0] = 255
871
+ colorwheel[:RY, 1] = [(255 * i) // RY for i in range(RY)]
872
+ col += RY
873
+ # YG
874
+ colorwheel[col : col + YG, 0] = [255 - (255 * i) // YG for i in range(YG)]
875
+ colorwheel[col : col + YG, 1] = 255
876
+ col += YG
877
+ # GC
878
+ colorwheel[col : col + GC, 1] = 255
879
+ colorwheel[col : col + GC, 2] = [(255 * i) // GC for i in range(GC)]
880
+ col += GC
881
+ # CB
882
+ colorwheel[col : col + CB, 1] = [255 - (255 * i) // CB for i in range(CB)]
883
+ colorwheel[col : col + CB, 2] = 255
884
+ col += CB
885
+ # BM
886
+ colorwheel[col : col + BM, 0] = [(255 * i) // BM for i in range(BM)]
887
+ colorwheel[col : col + BM, 2] = 255
888
+ col += BM
889
+ # MR
890
+ colorwheel[col : col + MR, 0] = 255
891
+ colorwheel[col : col + MR, 2] = [255 - (255 * i) // MR for i in range(MR)]
892
+
893
+ # compute utility variables
894
+ rad = np.sqrt(np.sum(np.square(flow), 2)) # magnitude
895
+ a = np.arctan2(-flow[:, :, 1], -flow[:, :, 0]) / np.pi # angle
896
+ fk = (a + 1) / 2 * (ncols - 1) # map [-1,1] to [0,ncols-1]
897
+ k0 = np.floor(fk).astype("int")
898
+ k1 = k0 + 1
899
+ k1[k1 == ncols] = 0
900
+ f = fk - k0
901
+
902
+ if not saturate:
903
+ rad = np.minimum(rad, 1)
904
+
905
+ # compute the image
906
+ img = np.zeros((flow.shape[0], flow.shape[1], nchans), "uint8")
907
+ for i in range(nchans):
908
+ tmp = colorwheel[:, i].astype("float")
909
+ col0 = tmp[k0] / 255
910
+ col1 = tmp[k1] / 255
911
+ col = (1 - f) * col0 + f * col1
912
+ idx = rad <= 1
913
+ col[idx] = 1 - rad[idx] * (1 - col[idx]) # increase saturation with radius
914
+ col[~idx] *= 0.75 # out of range
915
+ img[:, :, i] = (255 * col * (1 - nanidx.astype("float"))).astype("uint8")
916
+
917
+ return img
918
+
919
+
920
+ # flow dataset getter
921
+
922
+
923
+ def get_train_dataset_flow(dataset_str, augmentor=True, crop_size=None):
924
+ dataset_str = dataset_str.replace("(", "Dataset(")
925
+ if augmentor:
926
+ dataset_str = dataset_str.replace(")", ", augmentor=True)")
927
+ if crop_size is not None:
928
+ dataset_str = dataset_str.replace(
929
+ ")", ", crop_size={:s})".format(str(crop_size))
930
+ )
931
+ return eval(dataset_str)
932
+
933
+
934
+ def get_test_datasets_flow(dataset_str):
935
+ dataset_str = dataset_str.replace("(", "Dataset(")
936
+ return [eval(s) for s in dataset_str.split("+")]
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/datasets_stereo.py ADDED
@@ -0,0 +1,991 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Dataset structure for stereo
6
+ # --------------------------------------------------------
7
+
8
+ import sys, os
9
+ import os.path as osp
10
+ import pickle
11
+ import numpy as np
12
+ from PIL import Image
13
+ import json
14
+ import h5py
15
+ from glob import glob
16
+ import cv2
17
+
18
+ import torch
19
+ from torch.utils import data
20
+
21
+ from .augmentor import StereoAugmentor
22
+
23
+
24
+ dataset_to_root = {
25
+ "CREStereo": "./data/stereoflow//crenet_stereo_trainset/stereo_trainset/crestereo/",
26
+ "SceneFlow": "./data/stereoflow//SceneFlow/",
27
+ "ETH3DLowRes": "./data/stereoflow/eth3d_lowres/",
28
+ "Booster": "./data/stereoflow/booster_gt/",
29
+ "Middlebury2021": "./data/stereoflow/middlebury/2021/data/",
30
+ "Middlebury2014": "./data/stereoflow/middlebury/2014/",
31
+ "Middlebury2006": "./data/stereoflow/middlebury/2006/",
32
+ "Middlebury2005": "./data/stereoflow/middlebury/2005/train/",
33
+ "MiddleburyEval3": "./data/stereoflow/middlebury/MiddEval3/",
34
+ "Spring": "./data/stereoflow/spring/",
35
+ "Kitti15": "./data/stereoflow/kitti-stereo-2015/",
36
+ "Kitti12": "./data/stereoflow/kitti-stereo-2012/",
37
+ }
38
+ cache_dir = "./data/stereoflow/datasets_stereo_cache/"
39
+
40
+
41
+ in1k_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
42
+ in1k_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
43
+
44
+
45
+ def img_to_tensor(img):
46
+ img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
47
+ img = (img - in1k_mean) / in1k_std
48
+ return img
49
+
50
+
51
+ def disp_to_tensor(disp):
52
+ return torch.from_numpy(disp)[None, :, :]
53
+
54
+
55
+ class StereoDataset(data.Dataset):
56
+
57
+ def __init__(self, split, augmentor=False, crop_size=None, totensor=True):
58
+ self.split = split
59
+ if not augmentor:
60
+ assert crop_size is None
61
+ if crop_size:
62
+ assert augmentor
63
+ self.crop_size = crop_size
64
+ self.augmentor_str = augmentor
65
+ self.augmentor = StereoAugmentor(crop_size) if augmentor else None
66
+ self.totensor = totensor
67
+ self.rmul = 1 # keep track of rmul
68
+ self.has_constant_resolution = True # whether the dataset has constant resolution or not (=> don't use batch_size>1 at test time)
69
+ self._prepare_data()
70
+ self._load_or_build_cache()
71
+
72
+ def prepare_data(self):
73
+ """
74
+ to be defined for each dataset
75
+ """
76
+ raise NotImplementedError
77
+
78
+ def __len__(self):
79
+ return len(self.pairnames)
80
+
81
+ def __getitem__(self, index):
82
+ pairname = self.pairnames[index]
83
+
84
+ # get filenames
85
+ Limgname = self.pairname_to_Limgname(pairname)
86
+ Rimgname = self.pairname_to_Rimgname(pairname)
87
+ Ldispname = (
88
+ self.pairname_to_Ldispname(pairname)
89
+ if self.pairname_to_Ldispname is not None
90
+ else None
91
+ )
92
+
93
+ # load images and disparities
94
+ Limg = _read_img(Limgname)
95
+ Rimg = _read_img(Rimgname)
96
+ disp = self.load_disparity(Ldispname) if Ldispname is not None else None
97
+
98
+ # sanity check
99
+ if disp is not None:
100
+ assert np.all(disp > 0) or self.name == "Spring", (
101
+ self.name,
102
+ pairname,
103
+ Ldispname,
104
+ )
105
+
106
+ # apply augmentations
107
+ if self.augmentor is not None:
108
+ Limg, Rimg, disp = self.augmentor(Limg, Rimg, disp, self.name)
109
+
110
+ if self.totensor:
111
+ Limg = img_to_tensor(Limg)
112
+ Rimg = img_to_tensor(Rimg)
113
+ if disp is None:
114
+ disp = torch.tensor(
115
+ []
116
+ ) # to allow dataloader batching with default collate_gn
117
+ else:
118
+ disp = disp_to_tensor(disp)
119
+
120
+ return Limg, Rimg, disp, str(pairname)
121
+
122
+ def __rmul__(self, v):
123
+ self.rmul *= v
124
+ self.pairnames = v * self.pairnames
125
+ return self
126
+
127
+ def __str__(self):
128
+ return f"{self.__class__.__name__}_{self.split}"
129
+
130
+ def __repr__(self):
131
+ s = f"{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})"
132
+ if self.rmul == 1:
133
+ s += f"\n\tnum pairs: {len(self.pairnames)}"
134
+ else:
135
+ s += f"\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})"
136
+ return s
137
+
138
+ def _set_root(self):
139
+ self.root = dataset_to_root[self.name]
140
+ assert os.path.isdir(
141
+ self.root
142
+ ), f"could not find root directory for dataset {self.name}: {self.root}"
143
+
144
+ def _load_or_build_cache(self):
145
+ cache_file = osp.join(cache_dir, self.name + ".pkl")
146
+ if osp.isfile(cache_file):
147
+ with open(cache_file, "rb") as fid:
148
+ self.pairnames = pickle.load(fid)[self.split]
149
+ else:
150
+ tosave = self._build_cache()
151
+ os.makedirs(cache_dir, exist_ok=True)
152
+ with open(cache_file, "wb") as fid:
153
+ pickle.dump(tosave, fid)
154
+ self.pairnames = tosave[self.split]
155
+
156
+
157
+ class CREStereoDataset(StereoDataset):
158
+
159
+ def _prepare_data(self):
160
+ self.name = "CREStereo"
161
+ self._set_root()
162
+ assert self.split in ["train"]
163
+ self.pairname_to_Limgname = lambda pairname: osp.join(
164
+ self.root, pairname + "_left.jpg"
165
+ )
166
+ self.pairname_to_Rimgname = lambda pairname: osp.join(
167
+ self.root, pairname + "_right.jpg"
168
+ )
169
+ self.pairname_to_Ldispname = lambda pairname: osp.join(
170
+ self.root, pairname + "_left.disp.png"
171
+ )
172
+ self.pairname_to_str = lambda pairname: pairname
173
+ self.load_disparity = _read_crestereo_disp
174
+
175
+ def _build_cache(self):
176
+ allpairs = [
177
+ s + "/" + f[: -len("_left.jpg")]
178
+ for s in sorted(os.listdir(self.root))
179
+ for f in sorted(os.listdir(self.root + "/" + s))
180
+ if f.endswith("_left.jpg")
181
+ ]
182
+ assert len(allpairs) == 200000, "incorrect parsing of pairs in CreStereo"
183
+ tosave = {"train": allpairs}
184
+ return tosave
185
+
186
+
187
+ class SceneFlowDataset(StereoDataset):
188
+
189
+ def _prepare_data(self):
190
+ self.name = "SceneFlow"
191
+ self._set_root()
192
+ assert self.split in [
193
+ "train_finalpass",
194
+ "train_cleanpass",
195
+ "train_allpass",
196
+ "test_finalpass",
197
+ "test_cleanpass",
198
+ "test_allpass",
199
+ "test1of100_cleanpass",
200
+ "test1of100_finalpass",
201
+ ]
202
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
203
+ self.pairname_to_Rimgname = lambda pairname: osp.join(
204
+ self.root, pairname
205
+ ).replace("/left/", "/right/")
206
+ self.pairname_to_Ldispname = (
207
+ lambda pairname: osp.join(self.root, pairname)
208
+ .replace("/frames_finalpass/", "/disparity/")
209
+ .replace("/frames_cleanpass/", "/disparity/")[:-4]
210
+ + ".pfm"
211
+ )
212
+ self.pairname_to_str = lambda pairname: pairname[:-4]
213
+ self.load_disparity = _read_sceneflow_disp
214
+
215
+ def _build_cache(self):
216
+ trainpairs = []
217
+ # driving
218
+ pairs = sorted(glob(self.root + "Driving/frames_finalpass/*/*/*/left/*.png"))
219
+ pairs = list(map(lambda x: x[len(self.root) :], pairs))
220
+ assert len(pairs) == 4400, "incorrect parsing of pairs in SceneFlow"
221
+ trainpairs += pairs
222
+ # monkaa
223
+ pairs = sorted(glob(self.root + "Monkaa/frames_finalpass/*/left/*.png"))
224
+ pairs = list(map(lambda x: x[len(self.root) :], pairs))
225
+ assert len(pairs) == 8664, "incorrect parsing of pairs in SceneFlow"
226
+ trainpairs += pairs
227
+ # flyingthings
228
+ pairs = sorted(
229
+ glob(self.root + "FlyingThings/frames_finalpass/TRAIN/*/*/left/*.png")
230
+ )
231
+ pairs = list(map(lambda x: x[len(self.root) :], pairs))
232
+ assert len(pairs) == 22390, "incorrect parsing of pairs in SceneFlow"
233
+ trainpairs += pairs
234
+ assert len(trainpairs) == 35454, "incorrect parsing of pairs in SceneFlow"
235
+ testpairs = sorted(
236
+ glob(self.root + "FlyingThings/frames_finalpass/TEST/*/*/left/*.png")
237
+ )
238
+ testpairs = list(map(lambda x: x[len(self.root) :], testpairs))
239
+ assert len(testpairs) == 4370, "incorrect parsing of pairs in SceneFlow"
240
+ test1of100pairs = testpairs[::100]
241
+ assert len(test1of100pairs) == 44, "incorrect parsing of pairs in SceneFlow"
242
+ # all
243
+ tosave = {
244
+ "train_finalpass": trainpairs,
245
+ "train_cleanpass": list(
246
+ map(
247
+ lambda x: x.replace("frames_finalpass", "frames_cleanpass"),
248
+ trainpairs,
249
+ )
250
+ ),
251
+ "test_finalpass": testpairs,
252
+ "test_cleanpass": list(
253
+ map(
254
+ lambda x: x.replace("frames_finalpass", "frames_cleanpass"),
255
+ testpairs,
256
+ )
257
+ ),
258
+ "test1of100_finalpass": test1of100pairs,
259
+ "test1of100_cleanpass": list(
260
+ map(
261
+ lambda x: x.replace("frames_finalpass", "frames_cleanpass"),
262
+ test1of100pairs,
263
+ )
264
+ ),
265
+ }
266
+ tosave["train_allpass"] = tosave["train_finalpass"] + tosave["train_cleanpass"]
267
+ tosave["test_allpass"] = tosave["test_finalpass"] + tosave["test_cleanpass"]
268
+ return tosave
269
+
270
+
271
+ class Md21Dataset(StereoDataset):
272
+
273
+ def _prepare_data(self):
274
+ self.name = "Middlebury2021"
275
+ self._set_root()
276
+ assert self.split in ["train", "subtrain", "subval"]
277
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
278
+ self.pairname_to_Rimgname = lambda pairname: osp.join(
279
+ self.root, pairname.replace("/im0", "/im1")
280
+ )
281
+ self.pairname_to_Ldispname = lambda pairname: osp.join(
282
+ self.root, pairname.split("/")[0], "disp0.pfm"
283
+ )
284
+ self.pairname_to_str = lambda pairname: pairname[:-4]
285
+ self.load_disparity = _read_middlebury_disp
286
+
287
+ def _build_cache(self):
288
+ seqs = sorted(os.listdir(self.root))
289
+ trainpairs = []
290
+ for s in seqs:
291
+ # trainpairs += [s+'/im0.png'] # we should remove it, it is included as such in other lightings
292
+ trainpairs += [
293
+ s + "/ambient/" + b + "/" + a
294
+ for b in sorted(os.listdir(osp.join(self.root, s, "ambient")))
295
+ for a in sorted(os.listdir(osp.join(self.root, s, "ambient", b)))
296
+ if a.startswith("im0")
297
+ ]
298
+ assert len(trainpairs) == 355
299
+ subtrainpairs = [
300
+ p for p in trainpairs if any(p.startswith(s + "/") for s in seqs[:-2])
301
+ ]
302
+ subvalpairs = [
303
+ p for p in trainpairs if any(p.startswith(s + "/") for s in seqs[-2:])
304
+ ]
305
+ assert (
306
+ len(subtrainpairs) == 335 and len(subvalpairs) == 20
307
+ ), "incorrect parsing of pairs in Middlebury 2021"
308
+ tosave = {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs}
309
+ return tosave
310
+
311
+
312
+ class Md14Dataset(StereoDataset):
313
+
314
+ def _prepare_data(self):
315
+ self.name = "Middlebury2014"
316
+ self._set_root()
317
+ assert self.split in ["train", "subtrain", "subval"]
318
+ self.pairname_to_Limgname = lambda pairname: osp.join(
319
+ self.root, osp.dirname(pairname), "im0.png"
320
+ )
321
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname)
322
+ self.pairname_to_Ldispname = lambda pairname: osp.join(
323
+ self.root, osp.dirname(pairname), "disp0.pfm"
324
+ )
325
+ self.pairname_to_str = lambda pairname: pairname[:-4]
326
+ self.load_disparity = _read_middlebury_disp
327
+ self.has_constant_resolution = False
328
+
329
+ def _build_cache(self):
330
+ seqs = sorted(os.listdir(self.root))
331
+ trainpairs = []
332
+ for s in seqs:
333
+ trainpairs += [s + "/im1.png", s + "/im1E.png", s + "/im1L.png"]
334
+ assert len(trainpairs) == 138
335
+ valseqs = ["Umbrella-imperfect", "Vintage-perfect"]
336
+ assert all(s in seqs for s in valseqs)
337
+ subtrainpairs = [
338
+ p for p in trainpairs if not any(p.startswith(s + "/") for s in valseqs)
339
+ ]
340
+ subvalpairs = [
341
+ p for p in trainpairs if any(p.startswith(s + "/") for s in valseqs)
342
+ ]
343
+ assert (
344
+ len(subtrainpairs) == 132 and len(subvalpairs) == 6
345
+ ), "incorrect parsing of pairs in Middlebury 2014"
346
+ tosave = {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs}
347
+ return tosave
348
+
349
+
350
+ class Md06Dataset(StereoDataset):
351
+
352
+ def _prepare_data(self):
353
+ self.name = "Middlebury2006"
354
+ self._set_root()
355
+ assert self.split in ["train", "subtrain", "subval"]
356
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
357
+ self.pairname_to_Rimgname = lambda pairname: osp.join(
358
+ self.root, osp.dirname(pairname), "view5.png"
359
+ )
360
+ self.pairname_to_Ldispname = lambda pairname: osp.join(
361
+ self.root, pairname.split("/")[0], "disp1.png"
362
+ )
363
+ self.load_disparity = _read_middlebury20052006_disp
364
+ self.has_constant_resolution = False
365
+
366
+ def _build_cache(self):
367
+ seqs = sorted(os.listdir(self.root))
368
+ trainpairs = []
369
+ for s in seqs:
370
+ for i in ["Illum1", "Illum2", "Illum3"]:
371
+ for e in ["Exp0", "Exp1", "Exp2"]:
372
+ trainpairs.append(osp.join(s, i, e, "view1.png"))
373
+ assert len(trainpairs) == 189
374
+ valseqs = ["Rocks1", "Wood2"]
375
+ assert all(s in seqs for s in valseqs)
376
+ subtrainpairs = [
377
+ p for p in trainpairs if not any(p.startswith(s + "/") for s in valseqs)
378
+ ]
379
+ subvalpairs = [
380
+ p for p in trainpairs if any(p.startswith(s + "/") for s in valseqs)
381
+ ]
382
+ assert (
383
+ len(subtrainpairs) == 171 and len(subvalpairs) == 18
384
+ ), "incorrect parsing of pairs in Middlebury 2006"
385
+ tosave = {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs}
386
+ return tosave
387
+
388
+
389
+ class Md05Dataset(StereoDataset):
390
+
391
+ def _prepare_data(self):
392
+ self.name = "Middlebury2005"
393
+ self._set_root()
394
+ assert self.split in ["train", "subtrain", "subval"]
395
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
396
+ self.pairname_to_Rimgname = lambda pairname: osp.join(
397
+ self.root, osp.dirname(pairname), "view5.png"
398
+ )
399
+ self.pairname_to_Ldispname = lambda pairname: osp.join(
400
+ self.root, pairname.split("/")[0], "disp1.png"
401
+ )
402
+ self.pairname_to_str = lambda pairname: pairname[:-4]
403
+ self.load_disparity = _read_middlebury20052006_disp
404
+
405
+ def _build_cache(self):
406
+ seqs = sorted(os.listdir(self.root))
407
+ trainpairs = []
408
+ for s in seqs:
409
+ for i in ["Illum1", "Illum2", "Illum3"]:
410
+ for e in ["Exp0", "Exp1", "Exp2"]:
411
+ trainpairs.append(osp.join(s, i, e, "view1.png"))
412
+ assert len(trainpairs) == 54, "incorrect parsing of pairs in Middlebury 2005"
413
+ valseqs = ["Reindeer"]
414
+ assert all(s in seqs for s in valseqs)
415
+ subtrainpairs = [
416
+ p for p in trainpairs if not any(p.startswith(s + "/") for s in valseqs)
417
+ ]
418
+ subvalpairs = [
419
+ p for p in trainpairs if any(p.startswith(s + "/") for s in valseqs)
420
+ ]
421
+ assert (
422
+ len(subtrainpairs) == 45 and len(subvalpairs) == 9
423
+ ), "incorrect parsing of pairs in Middlebury 2005"
424
+ tosave = {"train": trainpairs, "subtrain": subtrainpairs, "subval": subvalpairs}
425
+ return tosave
426
+
427
+
428
+ class MdEval3Dataset(StereoDataset):
429
+
430
+ def _prepare_data(self):
431
+ self.name = "MiddleburyEval3"
432
+ self._set_root()
433
+ assert self.split in [
434
+ s + "_" + r
435
+ for s in ["train", "subtrain", "subval", "test", "all"]
436
+ for r in ["full", "half", "quarter"]
437
+ ]
438
+ if self.split.endswith("_full"):
439
+ self.root = self.root.replace("/MiddEval3", "/MiddEval3_F")
440
+ elif self.split.endswith("_half"):
441
+ self.root = self.root.replace("/MiddEval3", "/MiddEval3_H")
442
+ else:
443
+ assert self.split.endswith("_quarter")
444
+ self.pairname_to_Limgname = lambda pairname: osp.join(
445
+ self.root, pairname, "im0.png"
446
+ )
447
+ self.pairname_to_Rimgname = lambda pairname: osp.join(
448
+ self.root, pairname, "im1.png"
449
+ )
450
+ self.pairname_to_Ldispname = lambda pairname: (
451
+ None
452
+ if pairname.startswith("test")
453
+ else osp.join(self.root, pairname, "disp0GT.pfm")
454
+ )
455
+ self.pairname_to_str = lambda pairname: pairname
456
+ self.load_disparity = _read_middlebury_disp
457
+ # for submission only
458
+ self.submission_methodname = "CroCo-Stereo"
459
+ self.submission_sresolution = (
460
+ "F"
461
+ if self.split.endswith("_full")
462
+ else ("H" if self.split.endswith("_half") else "Q")
463
+ )
464
+
465
+ def _build_cache(self):
466
+ trainpairs = ["train/" + s for s in sorted(os.listdir(self.root + "train/"))]
467
+ testpairs = ["test/" + s for s in sorted(os.listdir(self.root + "test/"))]
468
+ subvalpairs = trainpairs[-1:]
469
+ subtrainpairs = trainpairs[:-1]
470
+ allpairs = trainpairs + testpairs
471
+ assert (
472
+ len(trainpairs) == 15
473
+ and len(testpairs) == 15
474
+ and len(subvalpairs) == 1
475
+ and len(subtrainpairs) == 14
476
+ and len(allpairs) == 30
477
+ ), "incorrect parsing of pairs in Middlebury Eval v3"
478
+ tosave = {}
479
+ for r in ["full", "half", "quarter"]:
480
+ tosave.update(
481
+ **{
482
+ "train_" + r: trainpairs,
483
+ "subtrain_" + r: subtrainpairs,
484
+ "subval_" + r: subvalpairs,
485
+ "test_" + r: testpairs,
486
+ "all_" + r: allpairs,
487
+ }
488
+ )
489
+ return tosave
490
+
491
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
492
+ assert prediction.ndim == 2
493
+ assert prediction.dtype == np.float32
494
+ outfile = os.path.join(
495
+ outdir,
496
+ pairname.split("/")[0].replace("train", "training")
497
+ + self.submission_sresolution,
498
+ pairname.split("/")[1],
499
+ "disp0" + self.submission_methodname + ".pfm",
500
+ )
501
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
502
+ writePFM(outfile, prediction)
503
+ timefile = os.path.join(
504
+ os.path.dirname(outfile), "time" + self.submission_methodname + ".txt"
505
+ )
506
+ with open(timefile, "w") as fid:
507
+ fid.write(str(time))
508
+
509
+ def finalize_submission(self, outdir):
510
+ cmd = f'cd {outdir}/; zip -r "{self.submission_methodname}.zip" .'
511
+ print(cmd)
512
+ os.system(cmd)
513
+ print(f"Done. Submission file at {outdir}/{self.submission_methodname}.zip")
514
+
515
+
516
+ class ETH3DLowResDataset(StereoDataset):
517
+
518
+ def _prepare_data(self):
519
+ self.name = "ETH3DLowRes"
520
+ self._set_root()
521
+ assert self.split in ["train", "test", "subtrain", "subval", "all"]
522
+ self.pairname_to_Limgname = lambda pairname: osp.join(
523
+ self.root, pairname, "im0.png"
524
+ )
525
+ self.pairname_to_Rimgname = lambda pairname: osp.join(
526
+ self.root, pairname, "im1.png"
527
+ )
528
+ self.pairname_to_Ldispname = (
529
+ None
530
+ if self.split == "test"
531
+ else lambda pairname: (
532
+ None
533
+ if pairname.startswith("test/")
534
+ else osp.join(
535
+ self.root, pairname.replace("train/", "train_gt/"), "disp0GT.pfm"
536
+ )
537
+ )
538
+ )
539
+ self.pairname_to_str = lambda pairname: pairname
540
+ self.load_disparity = _read_eth3d_disp
541
+ self.has_constant_resolution = False
542
+
543
+ def _build_cache(self):
544
+ trainpairs = ["train/" + s for s in sorted(os.listdir(self.root + "train/"))]
545
+ testpairs = ["test/" + s for s in sorted(os.listdir(self.root + "test/"))]
546
+ assert (
547
+ len(trainpairs) == 27 and len(testpairs) == 20
548
+ ), "incorrect parsing of pairs in ETH3D Low Res"
549
+ subvalpairs = [
550
+ "train/delivery_area_3s",
551
+ "train/electro_3l",
552
+ "train/playground_3l",
553
+ ]
554
+ assert all(p in trainpairs for p in subvalpairs)
555
+ subtrainpairs = [p for p in trainpairs if not p in subvalpairs]
556
+ assert (
557
+ len(subvalpairs) == 3 and len(subtrainpairs) == 24
558
+ ), "incorrect parsing of pairs in ETH3D Low Res"
559
+ tosave = {
560
+ "train": trainpairs,
561
+ "test": testpairs,
562
+ "subtrain": subtrainpairs,
563
+ "subval": subvalpairs,
564
+ "all": trainpairs + testpairs,
565
+ }
566
+ return tosave
567
+
568
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
569
+ assert prediction.ndim == 2
570
+ assert prediction.dtype == np.float32
571
+ outfile = os.path.join(
572
+ outdir, "low_res_two_view", pairname.split("/")[1] + ".pfm"
573
+ )
574
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
575
+ writePFM(outfile, prediction)
576
+ timefile = outfile[:-4] + ".txt"
577
+ with open(timefile, "w") as fid:
578
+ fid.write("runtime " + str(time))
579
+
580
+ def finalize_submission(self, outdir):
581
+ cmd = f'cd {outdir}/; zip -r "eth3d_low_res_two_view_results.zip" low_res_two_view'
582
+ print(cmd)
583
+ os.system(cmd)
584
+ print(f"Done. Submission file at {outdir}/eth3d_low_res_two_view_results.zip")
585
+
586
+
587
+ class BoosterDataset(StereoDataset):
588
+
589
+ def _prepare_data(self):
590
+ self.name = "Booster"
591
+ self._set_root()
592
+ assert self.split in [
593
+ "train_balanced",
594
+ "test_balanced",
595
+ "subtrain_balanced",
596
+ "subval_balanced",
597
+ ] # we use only the balanced version
598
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
599
+ self.pairname_to_Rimgname = lambda pairname: osp.join(
600
+ self.root, pairname
601
+ ).replace("/camera_00/", "/camera_02/")
602
+ self.pairname_to_Ldispname = lambda pairname: osp.join(
603
+ self.root, osp.dirname(pairname), "../disp_00.npy"
604
+ ) # same images with different colors, same gt per sequence
605
+ self.pairname_to_str = lambda pairname: pairname[:-4].replace(
606
+ "/camera_00/", "/"
607
+ )
608
+ self.load_disparity = _read_booster_disp
609
+
610
+ def _build_cache(self):
611
+ trainseqs = sorted(os.listdir(self.root + "train/balanced"))
612
+ trainpairs = [
613
+ "train/balanced/" + s + "/camera_00/" + imname
614
+ for s in trainseqs
615
+ for imname in sorted(
616
+ os.listdir(self.root + "train/balanced/" + s + "/camera_00/")
617
+ )
618
+ ]
619
+ testpairs = [
620
+ "test/balanced/" + s + "/camera_00/" + imname
621
+ for s in sorted(os.listdir(self.root + "test/balanced"))
622
+ for imname in sorted(
623
+ os.listdir(self.root + "test/balanced/" + s + "/camera_00/")
624
+ )
625
+ ]
626
+ assert len(trainpairs) == 228 and len(testpairs) == 191
627
+ subtrainpairs = [p for p in trainpairs if any(s in p for s in trainseqs[:-2])]
628
+ subvalpairs = [p for p in trainpairs if any(s in p for s in trainseqs[-2:])]
629
+ # warning: if we do validation split, we should split scenes!!!
630
+ tosave = {
631
+ "train_balanced": trainpairs,
632
+ "test_balanced": testpairs,
633
+ "subtrain_balanced": subtrainpairs,
634
+ "subval_balanced": subvalpairs,
635
+ }
636
+ return tosave
637
+
638
+
639
+ class SpringDataset(StereoDataset):
640
+
641
+ def _prepare_data(self):
642
+ self.name = "Spring"
643
+ self._set_root()
644
+ assert self.split in ["train", "test", "subtrain", "subval"]
645
+ self.pairname_to_Limgname = lambda pairname: osp.join(
646
+ self.root, pairname + ".png"
647
+ )
648
+ self.pairname_to_Rimgname = (
649
+ lambda pairname: osp.join(self.root, pairname + ".png")
650
+ .replace("frame_right", "<frame_right>")
651
+ .replace("frame_left", "frame_right")
652
+ .replace("<frame_right>", "frame_left")
653
+ )
654
+ self.pairname_to_Ldispname = lambda pairname: (
655
+ None
656
+ if pairname.startswith("test")
657
+ else osp.join(self.root, pairname + ".dsp5")
658
+ .replace("frame_left", "disp1_left")
659
+ .replace("frame_right", "disp1_right")
660
+ )
661
+ self.pairname_to_str = lambda pairname: pairname
662
+ self.load_disparity = _read_hdf5_disp
663
+
664
+ def _build_cache(self):
665
+ trainseqs = sorted(os.listdir(osp.join(self.root, "train")))
666
+ trainpairs = [
667
+ osp.join("train", s, "frame_left", f[:-4])
668
+ for s in trainseqs
669
+ for f in sorted(os.listdir(osp.join(self.root, "train", s, "frame_left")))
670
+ ]
671
+ testseqs = sorted(os.listdir(osp.join(self.root, "test")))
672
+ testpairs = [
673
+ osp.join("test", s, "frame_left", f[:-4])
674
+ for s in testseqs
675
+ for f in sorted(os.listdir(osp.join(self.root, "test", s, "frame_left")))
676
+ ]
677
+ testpairs += [p.replace("frame_left", "frame_right") for p in testpairs]
678
+ """maxnorm = {'0001': 32.88, '0002': 228.5, '0004': 298.2, '0005': 142.5, '0006': 113.6, '0007': 27.3, '0008': 554.5, '0009': 155.6, '0010': 126.1, '0011': 87.6, '0012': 303.2, '0013': 24.14, '0014': 82.56, '0015': 98.44, '0016': 156.9, '0017': 28.17, '0018': 21.03, '0020': 178.0, '0021': 58.06, '0022': 354.2, '0023': 8.79, '0024': 97.06, '0025': 55.16, '0026': 91.9, '0027': 156.6, '0030': 200.4, '0032': 58.66, '0033': 373.5, '0036': 149.4, '0037': 5.625, '0038': 37.0, '0039': 12.2, '0041': 453.5, '0043': 457.0, '0044': 379.5, '0045': 161.8, '0047': 105.44} # => let'use 0041"""
679
+ subtrainpairs = [p for p in trainpairs if p.split("/")[1] != "0041"]
680
+ subvalpairs = [p for p in trainpairs if p.split("/")[1] == "0041"]
681
+ assert (
682
+ len(trainpairs) == 5000
683
+ and len(testpairs) == 2000
684
+ and len(subtrainpairs) == 4904
685
+ and len(subvalpairs) == 96
686
+ ), "incorrect parsing of pairs in Spring"
687
+ tosave = {
688
+ "train": trainpairs,
689
+ "test": testpairs,
690
+ "subtrain": subtrainpairs,
691
+ "subval": subvalpairs,
692
+ }
693
+ return tosave
694
+
695
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
696
+ assert prediction.ndim == 2
697
+ assert prediction.dtype == np.float32
698
+ outfile = (
699
+ os.path.join(outdir, pairname + ".dsp5")
700
+ .replace("frame_left", "disp1_left")
701
+ .replace("frame_right", "disp1_right")
702
+ )
703
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
704
+ writeDsp5File(prediction, outfile)
705
+
706
+ def finalize_submission(self, outdir):
707
+ assert self.split == "test"
708
+ exe = "{self.root}/disp1_subsampling"
709
+ if os.path.isfile(exe):
710
+ cmd = f'cd "{outdir}/test"; {exe} .'
711
+ print(cmd)
712
+ os.system(cmd)
713
+ else:
714
+ print("Could not find disp1_subsampling executable for submission.")
715
+ print("Please download it and run:")
716
+ print(f'cd "{outdir}/test"; <disp1_subsampling_exe> .')
717
+
718
+
719
+ class Kitti12Dataset(StereoDataset):
720
+
721
+ def _prepare_data(self):
722
+ self.name = "Kitti12"
723
+ self._set_root()
724
+ assert self.split in ["train", "test"]
725
+ self.pairname_to_Limgname = lambda pairname: osp.join(
726
+ self.root, pairname + "_10.png"
727
+ )
728
+ self.pairname_to_Rimgname = lambda pairname: osp.join(
729
+ self.root, pairname.replace("/colored_0/", "/colored_1/") + "_10.png"
730
+ )
731
+ self.pairname_to_Ldispname = (
732
+ None
733
+ if self.split == "test"
734
+ else lambda pairname: osp.join(
735
+ self.root, pairname.replace("/colored_0/", "/disp_occ/") + "_10.png"
736
+ )
737
+ )
738
+ self.pairname_to_str = lambda pairname: pairname.replace("/colored_0/", "/")
739
+ self.load_disparity = _read_kitti_disp
740
+
741
+ def _build_cache(self):
742
+ trainseqs = ["training/colored_0/%06d" % (i) for i in range(194)]
743
+ testseqs = ["testing/colored_0/%06d" % (i) for i in range(195)]
744
+ assert (
745
+ len(trainseqs) == 194 and len(testseqs) == 195
746
+ ), "incorrect parsing of pairs in Kitti12"
747
+ tosave = {"train": trainseqs, "test": testseqs}
748
+ return tosave
749
+
750
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
751
+ assert prediction.ndim == 2
752
+ assert prediction.dtype == np.float32
753
+ outfile = os.path.join(outdir, pairname.split("/")[-1] + "_10.png")
754
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
755
+ img = (prediction * 256).astype("uint16")
756
+ Image.fromarray(img).save(outfile)
757
+
758
+ def finalize_submission(self, outdir):
759
+ assert self.split == "test"
760
+ cmd = f'cd {outdir}/; zip -r "kitti12_results.zip" .'
761
+ print(cmd)
762
+ os.system(cmd)
763
+ print(f"Done. Submission file at {outdir}/kitti12_results.zip")
764
+
765
+
766
+ class Kitti15Dataset(StereoDataset):
767
+
768
+ def _prepare_data(self):
769
+ self.name = "Kitti15"
770
+ self._set_root()
771
+ assert self.split in ["train", "subtrain", "subval", "test"]
772
+ self.pairname_to_Limgname = lambda pairname: osp.join(
773
+ self.root, pairname + "_10.png"
774
+ )
775
+ self.pairname_to_Rimgname = lambda pairname: osp.join(
776
+ self.root, pairname.replace("/image_2/", "/image_3/") + "_10.png"
777
+ )
778
+ self.pairname_to_Ldispname = (
779
+ None
780
+ if self.split == "test"
781
+ else lambda pairname: osp.join(
782
+ self.root, pairname.replace("/image_2/", "/disp_occ_0/") + "_10.png"
783
+ )
784
+ )
785
+ self.pairname_to_str = lambda pairname: pairname.replace("/image_2/", "/")
786
+ self.load_disparity = _read_kitti_disp
787
+
788
+ def _build_cache(self):
789
+ trainseqs = ["training/image_2/%06d" % (i) for i in range(200)]
790
+ subtrainseqs = trainseqs[:-5]
791
+ subvalseqs = trainseqs[-5:]
792
+ testseqs = ["testing/image_2/%06d" % (i) for i in range(200)]
793
+ assert (
794
+ len(trainseqs) == 200
795
+ and len(subtrainseqs) == 195
796
+ and len(subvalseqs) == 5
797
+ and len(testseqs) == 200
798
+ ), "incorrect parsing of pairs in Kitti15"
799
+ tosave = {
800
+ "train": trainseqs,
801
+ "subtrain": subtrainseqs,
802
+ "subval": subvalseqs,
803
+ "test": testseqs,
804
+ }
805
+ return tosave
806
+
807
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
808
+ assert prediction.ndim == 2
809
+ assert prediction.dtype == np.float32
810
+ outfile = os.path.join(outdir, "disp_0", pairname.split("/")[-1] + "_10.png")
811
+ os.makedirs(os.path.dirname(outfile), exist_ok=True)
812
+ img = (prediction * 256).astype("uint16")
813
+ Image.fromarray(img).save(outfile)
814
+
815
+ def finalize_submission(self, outdir):
816
+ assert self.split == "test"
817
+ cmd = f'cd {outdir}/; zip -r "kitti15_results.zip" disp_0'
818
+ print(cmd)
819
+ os.system(cmd)
820
+ print(f"Done. Submission file at {outdir}/kitti15_results.zip")
821
+
822
+
823
+ ### auxiliary functions
824
+
825
+
826
+ def _read_img(filename):
827
+ # convert to RGB for scene flow finalpass data
828
+ img = np.asarray(Image.open(filename).convert("RGB"))
829
+ return img
830
+
831
+
832
+ def _read_booster_disp(filename):
833
+ disp = np.load(filename)
834
+ disp[disp == 0.0] = np.inf
835
+ return disp
836
+
837
+
838
+ def _read_png_disp(filename, coef=1.0):
839
+ disp = np.asarray(Image.open(filename))
840
+ disp = disp.astype(np.float32) / coef
841
+ disp[disp == 0.0] = np.inf
842
+ return disp
843
+
844
+
845
+ def _read_pfm_disp(filename):
846
+ disp = np.ascontiguousarray(_read_pfm(filename)[0])
847
+ disp[disp <= 0] = (
848
+ np.inf
849
+ ) # eg /nfs/data/ffs-3d/datasets/middlebury/2014/Shopvac-imperfect/disp0.pfm
850
+ return disp
851
+
852
+
853
+ def _read_npy_disp(filename):
854
+ return np.load(filename)
855
+
856
+
857
+ def _read_crestereo_disp(filename):
858
+ return _read_png_disp(filename, coef=32.0)
859
+
860
+
861
+ def _read_middlebury20052006_disp(filename):
862
+ return _read_png_disp(filename, coef=1.0)
863
+
864
+
865
+ def _read_kitti_disp(filename):
866
+ return _read_png_disp(filename, coef=256.0)
867
+
868
+
869
+ _read_sceneflow_disp = _read_pfm_disp
870
+ _read_eth3d_disp = _read_pfm_disp
871
+ _read_middlebury_disp = _read_pfm_disp
872
+ _read_carla_disp = _read_pfm_disp
873
+ _read_tartanair_disp = _read_npy_disp
874
+
875
+
876
+ def _read_hdf5_disp(filename):
877
+ disp = np.asarray(h5py.File(filename)["disparity"])
878
+ disp[np.isnan(disp)] = np.inf # make invalid values as +inf
879
+ # disp[disp==0.0] = np.inf # make invalid values as +inf
880
+ return disp.astype(np.float32)
881
+
882
+
883
+ import re
884
+
885
+
886
+ def _read_pfm(file):
887
+ file = open(file, "rb")
888
+
889
+ color = None
890
+ width = None
891
+ height = None
892
+ scale = None
893
+ endian = None
894
+
895
+ header = file.readline().rstrip()
896
+ if header.decode("ascii") == "PF":
897
+ color = True
898
+ elif header.decode("ascii") == "Pf":
899
+ color = False
900
+ else:
901
+ raise Exception("Not a PFM file.")
902
+
903
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
904
+ if dim_match:
905
+ width, height = list(map(int, dim_match.groups()))
906
+ else:
907
+ raise Exception("Malformed PFM header.")
908
+
909
+ scale = float(file.readline().decode("ascii").rstrip())
910
+ if scale < 0: # little-endian
911
+ endian = "<"
912
+ scale = -scale
913
+ else:
914
+ endian = ">" # big-endian
915
+
916
+ data = np.fromfile(file, endian + "f")
917
+ shape = (height, width, 3) if color else (height, width)
918
+
919
+ data = np.reshape(data, shape)
920
+ data = np.flipud(data)
921
+ return data, scale
922
+
923
+
924
+ def writePFM(file, image, scale=1):
925
+ file = open(file, "wb")
926
+
927
+ color = None
928
+
929
+ if image.dtype.name != "float32":
930
+ raise Exception("Image dtype must be float32.")
931
+
932
+ image = np.flipud(image)
933
+
934
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
935
+ color = True
936
+ elif (
937
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
938
+ ): # greyscale
939
+ color = False
940
+ else:
941
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
942
+
943
+ file.write("PF\n" if color else "Pf\n".encode())
944
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
945
+
946
+ endian = image.dtype.byteorder
947
+
948
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
949
+ scale = -scale
950
+
951
+ file.write("%f\n".encode() % scale)
952
+
953
+ image.tofile(file)
954
+
955
+
956
+ def writeDsp5File(disp, filename):
957
+ with h5py.File(filename, "w") as f:
958
+ f.create_dataset("disparity", data=disp, compression="gzip", compression_opts=5)
959
+
960
+
961
+ # disp visualization
962
+
963
+
964
+ def vis_disparity(disp, m=None, M=None):
965
+ if m is None:
966
+ m = disp.min()
967
+ if M is None:
968
+ M = disp.max()
969
+ disp_vis = (disp - m) / (M - m) * 255.0
970
+ disp_vis = disp_vis.astype("uint8")
971
+ disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
972
+ return disp_vis
973
+
974
+
975
+ # dataset getter
976
+
977
+
978
+ def get_train_dataset_stereo(dataset_str, augmentor=True, crop_size=None):
979
+ dataset_str = dataset_str.replace("(", "Dataset(")
980
+ if augmentor:
981
+ dataset_str = dataset_str.replace(")", ", augmentor=True)")
982
+ if crop_size is not None:
983
+ dataset_str = dataset_str.replace(
984
+ ")", ", crop_size={:s})".format(str(crop_size))
985
+ )
986
+ return eval(dataset_str)
987
+
988
+
989
+ def get_test_datasets_stereo(dataset_str):
990
+ dataset_str = dataset_str.replace("(", "Dataset(")
991
+ return [eval(s) for s in dataset_str.split("+")]
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/download_model.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ model=$1
5
+ outfile="stereoflow_models/${model}"
6
+ if [[ ! -f $outfile ]]
7
+ then
8
+ mkdir -p stereoflow_models/;
9
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/StereoFlow_models/$1 -P stereoflow_models/;
10
+ else
11
+ echo "Model ${model} already downloaded in ${outfile}."
12
+ fi
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/engine.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Main function for training one epoch or testing
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ import sys
10
+ from typing import Iterable
11
+ import numpy as np
12
+ import torch
13
+ import torchvision
14
+
15
+ from utils import misc as misc
16
+
17
+
18
+ def split_prediction_conf(predictions, with_conf=False):
19
+ if not with_conf:
20
+ return predictions, None
21
+ conf = predictions[:, -1:, :, :]
22
+ predictions = predictions[:, :-1, :, :]
23
+ return predictions, conf
24
+
25
+
26
+ def train_one_epoch(
27
+ model: torch.nn.Module,
28
+ criterion: torch.nn.Module,
29
+ metrics: torch.nn.Module,
30
+ data_loader: Iterable,
31
+ optimizer: torch.optim.Optimizer,
32
+ device: torch.device,
33
+ epoch: int,
34
+ loss_scaler,
35
+ log_writer=None,
36
+ print_freq=20,
37
+ args=None,
38
+ ):
39
+ model.train(True)
40
+ metric_logger = misc.MetricLogger(delimiter=" ")
41
+ metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
42
+ header = "Epoch: [{}]".format(epoch)
43
+
44
+ accum_iter = args.accum_iter
45
+
46
+ optimizer.zero_grad()
47
+
48
+ details = {}
49
+
50
+ if log_writer is not None:
51
+ print("log_dir: {}".format(log_writer.log_dir))
52
+
53
+ if args.img_per_epoch:
54
+ iter_per_epoch = args.img_per_epoch // args.batch_size + int(
55
+ args.img_per_epoch % args.batch_size > 0
56
+ )
57
+ assert (
58
+ len(data_loader) >= iter_per_epoch
59
+ ), "Dataset is too small for so many iterations"
60
+ len_data_loader = iter_per_epoch
61
+ else:
62
+ len_data_loader, iter_per_epoch = len(data_loader), None
63
+
64
+ for data_iter_step, (image1, image2, gt, pairname) in enumerate(
65
+ metric_logger.log_every(
66
+ data_loader, print_freq, header, max_iter=iter_per_epoch
67
+ )
68
+ ):
69
+
70
+ image1 = image1.to(device, non_blocking=True)
71
+ image2 = image2.to(device, non_blocking=True)
72
+ gt = gt.to(device, non_blocking=True)
73
+
74
+ # we use a per iteration (instead of per epoch) lr scheduler
75
+ if data_iter_step % accum_iter == 0:
76
+ misc.adjust_learning_rate(
77
+ optimizer, data_iter_step / len_data_loader + epoch, args
78
+ )
79
+
80
+ with torch.cuda.amp.autocast(enabled=bool(args.amp)):
81
+ prediction = model(image1, image2)
82
+ prediction, conf = split_prediction_conf(prediction, criterion.with_conf)
83
+ batch_metrics = metrics(prediction.detach(), gt)
84
+ loss = (
85
+ criterion(prediction, gt)
86
+ if conf is None
87
+ else criterion(prediction, gt, conf)
88
+ )
89
+
90
+ loss_value = loss.item()
91
+ if not math.isfinite(loss_value):
92
+ print("Loss is {}, stopping training".format(loss_value))
93
+ sys.exit(1)
94
+
95
+ loss /= accum_iter
96
+ loss_scaler(
97
+ loss,
98
+ optimizer,
99
+ parameters=model.parameters(),
100
+ update_grad=(data_iter_step + 1) % accum_iter == 0,
101
+ )
102
+ if (data_iter_step + 1) % accum_iter == 0:
103
+ optimizer.zero_grad()
104
+
105
+ torch.cuda.synchronize()
106
+
107
+ metric_logger.update(loss=loss_value)
108
+ for k, v in batch_metrics.items():
109
+ metric_logger.update(**{k: v.item()})
110
+ lr = optimizer.param_groups[0]["lr"]
111
+ metric_logger.update(lr=lr)
112
+
113
+ # if args.dsitributed: loss_value_reduce = misc.all_reduce_mean(loss_value)
114
+ time_to_log = (data_iter_step + 1) % (
115
+ args.tboard_log_step * accum_iter
116
+ ) == 0 or data_iter_step == len_data_loader - 1
117
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
118
+ if log_writer is not None and time_to_log:
119
+ epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000)
120
+ # We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes.
121
+ log_writer.add_scalar("train/loss", loss_value_reduce, epoch_1000x)
122
+ log_writer.add_scalar("lr", lr, epoch_1000x)
123
+ for k, v in batch_metrics.items():
124
+ log_writer.add_scalar("train/" + k, v.item(), epoch_1000x)
125
+
126
+ # gather the stats from all processes
127
+ # if args.distributed: metric_logger.synchronize_between_processes()
128
+ print("Averaged stats:", metric_logger)
129
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
130
+
131
+
132
+ @torch.no_grad()
133
+ def validate_one_epoch(
134
+ model: torch.nn.Module,
135
+ criterion: torch.nn.Module,
136
+ metrics: torch.nn.Module,
137
+ data_loaders: list[Iterable],
138
+ device: torch.device,
139
+ epoch: int,
140
+ log_writer=None,
141
+ args=None,
142
+ ):
143
+
144
+ model.eval()
145
+ metric_loggers = []
146
+ header = "Epoch: [{}]".format(epoch)
147
+ print_freq = 20
148
+
149
+ conf_mode = args.tile_conf_mode
150
+ crop = args.crop
151
+
152
+ if log_writer is not None:
153
+ print("log_dir: {}".format(log_writer.log_dir))
154
+
155
+ results = {}
156
+ dnames = []
157
+ image1, image2, gt, prediction = None, None, None, None
158
+ for didx, data_loader in enumerate(data_loaders):
159
+ dname = str(data_loader.dataset)
160
+ dnames.append(dname)
161
+ metric_loggers.append(misc.MetricLogger(delimiter=" "))
162
+ for data_iter_step, (image1, image2, gt, pairname) in enumerate(
163
+ metric_loggers[didx].log_every(data_loader, print_freq, header)
164
+ ):
165
+ image1 = image1.to(device, non_blocking=True)
166
+ image2 = image2.to(device, non_blocking=True)
167
+ gt = gt.to(device, non_blocking=True)
168
+ if dname.startswith("Spring"):
169
+ assert (
170
+ gt.size(2) == image1.size(2) * 2
171
+ and gt.size(3) == image1.size(3) * 2
172
+ )
173
+ gt = (
174
+ gt[:, :, 0::2, 0::2]
175
+ + gt[:, :, 0::2, 1::2]
176
+ + gt[:, :, 1::2, 0::2]
177
+ + gt[:, :, 1::2, 1::2]
178
+ ) / 4.0 # we approximate the gt based on the 2x upsampled ones
179
+
180
+ with torch.inference_mode():
181
+ prediction, tiled_loss, c = tiled_pred(
182
+ model,
183
+ criterion,
184
+ image1,
185
+ image2,
186
+ gt,
187
+ conf_mode=conf_mode,
188
+ overlap=args.val_overlap,
189
+ crop=crop,
190
+ with_conf=criterion.with_conf,
191
+ )
192
+ batch_metrics = metrics(prediction.detach(), gt)
193
+ loss = (
194
+ criterion(prediction.detach(), gt)
195
+ if not criterion.with_conf
196
+ else criterion(prediction.detach(), gt, c)
197
+ )
198
+ loss_value = loss.item()
199
+ metric_loggers[didx].update(loss_tiled=tiled_loss.item())
200
+ metric_loggers[didx].update(**{f"loss": loss_value})
201
+ for k, v in batch_metrics.items():
202
+ metric_loggers[didx].update(**{dname + "_" + k: v.item()})
203
+
204
+ results = {
205
+ k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items()
206
+ }
207
+ if len(dnames) > 1:
208
+ for k in batch_metrics.keys():
209
+ results["AVG_" + k] = sum(
210
+ results[dname + "_" + k] for dname in dnames
211
+ ) / len(dnames)
212
+
213
+ if log_writer is not None:
214
+ epoch_1000x = int((1 + epoch) * 1000)
215
+ for k, v in results.items():
216
+ log_writer.add_scalar("val/" + k, v, epoch_1000x)
217
+
218
+ print("Averaged stats:", results)
219
+ return results
220
+
221
+
222
+ import torch.nn.functional as F
223
+
224
+
225
+ def _resize_img(img, new_size):
226
+ return F.interpolate(img, size=new_size, mode="bicubic", align_corners=False)
227
+
228
+
229
+ def _resize_stereo_or_flow(data, new_size):
230
+ assert data.ndim == 4
231
+ assert data.size(1) in [1, 2]
232
+ scale_x = new_size[1] / float(data.size(3))
233
+ out = F.interpolate(data, size=new_size, mode="bicubic", align_corners=False)
234
+ out[:, 0, :, :] *= scale_x
235
+ if out.size(1) == 2:
236
+ scale_y = new_size[0] / float(data.size(2))
237
+ out[:, 1, :, :] *= scale_y
238
+ print(scale_x, new_size, data.shape)
239
+ return out
240
+
241
+
242
+ @torch.no_grad()
243
+ def tiled_pred(
244
+ model,
245
+ criterion,
246
+ img1,
247
+ img2,
248
+ gt,
249
+ overlap=0.5,
250
+ bad_crop_thr=0.05,
251
+ downscale=False,
252
+ crop=512,
253
+ ret="loss",
254
+ conf_mode="conf_expsigmoid_10_5",
255
+ with_conf=False,
256
+ return_time=False,
257
+ ):
258
+
259
+ # for each image, we are going to run inference on many overlapping patches
260
+ # then, all predictions will be weighted-averaged
261
+ if gt is not None:
262
+ B, C, H, W = gt.shape
263
+ else:
264
+ B, _, H, W = img1.shape
265
+ C = model.head.num_channels - int(with_conf)
266
+ win_height, win_width = crop[0], crop[1]
267
+
268
+ # upscale to be larger than the crop
269
+ do_change_scale = H < win_height or W < win_width
270
+ if do_change_scale:
271
+ upscale_factor = max(win_width / W, win_height / W)
272
+ original_size = (H, W)
273
+ new_size = (round(H * upscale_factor), round(W * upscale_factor))
274
+ img1 = _resize_img(img1, new_size)
275
+ img2 = _resize_img(img2, new_size)
276
+ # resize gt just for the computation of tiled losses
277
+ if gt is not None:
278
+ gt = _resize_stereo_or_flow(gt, new_size)
279
+ H, W = img1.shape[2:4]
280
+
281
+ if conf_mode.startswith("conf_expsigmoid_"): # conf_expsigmoid_30_10
282
+ beta, betasigmoid = map(float, conf_mode[len("conf_expsigmoid_") :].split("_"))
283
+ elif conf_mode.startswith("conf_expbeta"): # conf_expbeta3
284
+ beta = float(conf_mode[len("conf_expbeta") :])
285
+ else:
286
+ raise NotImplementedError(f"conf_mode {conf_mode} is not implemented")
287
+
288
+ def crop_generator():
289
+ for sy in _overlapping(H, win_height, overlap):
290
+ for sx in _overlapping(W, win_width, overlap):
291
+ yield sy, sx, sy, sx, True
292
+
293
+ # keep track of weighted sum of prediction*weights and weights
294
+ accu_pred = img1.new_zeros(
295
+ (B, C, H, W)
296
+ ) # accumulate the weighted sum of predictions
297
+ accu_conf = img1.new_zeros((B, H, W)) + 1e-16 # accumulate the weights
298
+ accu_c = img1.new_zeros(
299
+ (B, H, W)
300
+ ) # accumulate the weighted sum of confidences ; not so useful except for computing some losses
301
+
302
+ tiled_losses = []
303
+
304
+ if return_time:
305
+ start = torch.cuda.Event(enable_timing=True)
306
+ end = torch.cuda.Event(enable_timing=True)
307
+ start.record()
308
+
309
+ for sy1, sx1, sy2, sx2, aligned in crop_generator():
310
+ # compute optical flow there
311
+ pred = model(_crop(img1, sy1, sx1), _crop(img2, sy2, sx2))
312
+ pred, predconf = split_prediction_conf(pred, with_conf=with_conf)
313
+
314
+ if gt is not None:
315
+ gtcrop = _crop(gt, sy1, sx1)
316
+ if criterion is not None and gt is not None:
317
+ tiled_losses.append(
318
+ criterion(pred, gtcrop).item()
319
+ if predconf is None
320
+ else criterion(pred, gtcrop, predconf).item()
321
+ )
322
+
323
+ if conf_mode.startswith("conf_expsigmoid_"):
324
+ conf = torch.exp(
325
+ -beta * 2 * (torch.sigmoid(predconf / betasigmoid) - 0.5)
326
+ ).view(B, win_height, win_width)
327
+ elif conf_mode.startswith("conf_expbeta"):
328
+ conf = torch.exp(-beta * predconf).view(B, win_height, win_width)
329
+ else:
330
+ raise NotImplementedError
331
+
332
+ accu_pred[..., sy1, sx1] += pred * conf[:, None, :, :]
333
+ accu_conf[..., sy1, sx1] += conf
334
+ accu_c[..., sy1, sx1] += predconf.view(B, win_height, win_width) * conf
335
+
336
+ pred = accu_pred / accu_conf[:, None, :, :]
337
+ c = accu_c / accu_conf
338
+ assert not torch.any(torch.isnan(pred))
339
+
340
+ if return_time:
341
+ end.record()
342
+ torch.cuda.synchronize()
343
+ time = start.elapsed_time(end) / 1000.0 # this was in milliseconds
344
+
345
+ if do_change_scale:
346
+ pred = _resize_stereo_or_flow(pred, original_size)
347
+
348
+ if return_time:
349
+ return pred, torch.mean(torch.tensor(tiled_losses)), c, time
350
+ return pred, torch.mean(torch.tensor(tiled_losses)), c
351
+
352
+
353
+ def _overlapping(total, window, overlap=0.5):
354
+ assert total >= window and 0 <= overlap < 1, (total, window, overlap)
355
+ num_windows = 1 + int(np.ceil((total - window) / ((1 - overlap) * window)))
356
+ offsets = np.linspace(0, total - window, num_windows).round().astype(int)
357
+ yield from (slice(x, x + window) for x in offsets)
358
+
359
+
360
+ def _crop(img, sy, sx):
361
+ B, THREE, H, W = img.shape
362
+ if 0 <= sy.start and sy.stop <= H and 0 <= sx.start and sx.stop <= W:
363
+ return img[:, :, sy, sx]
364
+ l, r = max(0, -sx.start), max(0, sx.stop - W)
365
+ t, b = max(0, -sy.start), max(0, sy.stop - H)
366
+ img = torch.nn.functional.pad(img, (l, r, t, b), mode="constant")
367
+ return img[:, :, slice(sy.start + t, sy.stop + t), slice(sx.start + l, sx.stop + l)]
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/test.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Main test function
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import argparse
10
+ import pickle
11
+ from PIL import Image
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+
15
+ import torch
16
+ from torch.utils.data import DataLoader
17
+
18
+ import utils.misc as misc
19
+ from models.croco_downstream import CroCoDownstreamBinocular
20
+ from models.head_downstream import PixelwiseTaskWithDPT
21
+
22
+ from stereoflow.criterion import *
23
+ from stereoflow.datasets_stereo import get_test_datasets_stereo
24
+ from stereoflow.datasets_flow import get_test_datasets_flow
25
+ from stereoflow.engine import tiled_pred
26
+
27
+ from stereoflow.datasets_stereo import vis_disparity
28
+ from stereoflow.datasets_flow import flowToColor
29
+
30
+
31
+ def get_args_parser():
32
+ parser = argparse.ArgumentParser("Test CroCo models on stereo/flow", add_help=False)
33
+ # important argument
34
+ parser.add_argument(
35
+ "--model", required=True, type=str, help="Path to the model to evaluate"
36
+ )
37
+ parser.add_argument(
38
+ "--dataset",
39
+ required=True,
40
+ type=str,
41
+ help="test dataset (there can be multiple dataset separated by a +)",
42
+ )
43
+ # tiling
44
+ parser.add_argument(
45
+ "--tile_conf_mode",
46
+ type=str,
47
+ default="",
48
+ help="Weights for the tiling aggregation based on confidence (empty means use the formula from the loaded checkpoint",
49
+ )
50
+ parser.add_argument(
51
+ "--tile_overlap", type=float, default=0.7, help="overlap between tiles"
52
+ )
53
+ # save (it will automatically go to <model_path>_<dataset_str>/<tile_str>_<save>)
54
+ parser.add_argument(
55
+ "--save",
56
+ type=str,
57
+ nargs="+",
58
+ default=[],
59
+ help="what to save: \
60
+ metrics (pickle file), \
61
+ pred (raw prediction save as torch tensor), \
62
+ visu (visualization in png of each prediction), \
63
+ err10 (visualization in png of the error clamp at 10 for each prediction), \
64
+ submission (submission file)",
65
+ )
66
+ # other (no impact)
67
+ parser.add_argument("--num_workers", default=4, type=int)
68
+ return parser
69
+
70
+
71
+ def _load_model_and_criterion(model_path, do_load_metrics, device):
72
+ print("loading model from", model_path)
73
+ assert os.path.isfile(model_path)
74
+ ckpt = torch.load(model_path, "cpu")
75
+
76
+ ckpt_args = ckpt["args"]
77
+ task = ckpt_args.task
78
+ tile_conf_mode = ckpt_args.tile_conf_mode
79
+ num_channels = {"stereo": 1, "flow": 2}[task]
80
+ with_conf = eval(ckpt_args.criterion).with_conf
81
+ if with_conf:
82
+ num_channels += 1
83
+ print("head: PixelwiseTaskWithDPT()")
84
+ head = PixelwiseTaskWithDPT()
85
+ head.num_channels = num_channels
86
+ print("croco_args:", ckpt_args.croco_args)
87
+ model = CroCoDownstreamBinocular(head, **ckpt_args.croco_args)
88
+ msg = model.load_state_dict(ckpt["model"], strict=True)
89
+ model.eval()
90
+ model = model.to(device)
91
+
92
+ if do_load_metrics:
93
+ if task == "stereo":
94
+ metrics = StereoDatasetMetrics().to(device)
95
+ else:
96
+ metrics = FlowDatasetMetrics().to(device)
97
+ else:
98
+ metrics = None
99
+
100
+ return model, metrics, ckpt_args.crop, with_conf, task, tile_conf_mode
101
+
102
+
103
+ def _save_batch(
104
+ pred, gt, pairnames, dataset, task, save, outdir, time, submission_dir=None
105
+ ):
106
+
107
+ for i in range(len(pairnames)):
108
+
109
+ pairname = (
110
+ eval(pairnames[i]) if pairnames[i].startswith("(") else pairnames[i]
111
+ ) # unbatch pairname
112
+ fname = os.path.join(outdir, dataset.pairname_to_str(pairname))
113
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
114
+
115
+ predi = pred[i, ...]
116
+ if gt is not None:
117
+ gti = gt[i, ...]
118
+
119
+ if "pred" in save:
120
+ torch.save(predi.squeeze(0).cpu(), fname + "_pred.pth")
121
+
122
+ if "visu" in save:
123
+ if task == "stereo":
124
+ disparity = predi.permute((1, 2, 0)).squeeze(2).cpu().numpy()
125
+ m, M = None
126
+ if gt is not None:
127
+ mask = torch.isfinite(gti)
128
+ m = gt[mask].min()
129
+ M = gt[mask].max()
130
+ img_disparity = vis_disparity(disparity, m=m, M=M)
131
+ Image.fromarray(img_disparity).save(fname + "_pred.png")
132
+ else:
133
+ # normalize flowToColor according to the maxnorm of gt (or prediction if not available)
134
+ flowNorm = (
135
+ torch.sqrt(
136
+ torch.sum((gti if gt is not None else predi) ** 2, dim=0)
137
+ )
138
+ .max()
139
+ .item()
140
+ )
141
+ imgflow = flowToColor(
142
+ predi.permute((1, 2, 0)).cpu().numpy(), maxflow=flowNorm
143
+ )
144
+ Image.fromarray(imgflow).save(fname + "_pred.png")
145
+
146
+ if "err10" in save:
147
+ assert gt is not None
148
+ L2err = torch.sqrt(torch.sum((gti - predi) ** 2, dim=0))
149
+ valid = torch.isfinite(gti[0, :, :])
150
+ L2err[~valid] = 0.0
151
+ L2err = torch.clamp(L2err, max=10.0)
152
+ red = (L2err * 255.0 / 10.0).to(dtype=torch.uint8)[:, :, None]
153
+ zer = torch.zeros_like(red)
154
+ imgerr = torch.cat((red, zer, zer), dim=2).cpu().numpy()
155
+ Image.fromarray(imgerr).save(fname + "_err10.png")
156
+
157
+ if "submission" in save:
158
+ assert submission_dir is not None
159
+ predi_np = (
160
+ predi.permute(1, 2, 0).squeeze(2).cpu().numpy()
161
+ ) # transform into HxWx2 for flow or HxW for stereo
162
+ dataset.submission_save_pairname(pairname, predi_np, submission_dir, time)
163
+
164
+
165
+ def main(args):
166
+
167
+ # load the pretrained model and metrics
168
+ device = (
169
+ torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
170
+ )
171
+ model, metrics, cropsize, with_conf, task, tile_conf_mode = (
172
+ _load_model_and_criterion(args.model, "metrics" in args.save, device)
173
+ )
174
+ if args.tile_conf_mode == "":
175
+ args.tile_conf_mode = tile_conf_mode
176
+
177
+ # load the datasets
178
+ datasets = (
179
+ get_test_datasets_stereo if task == "stereo" else get_test_datasets_flow
180
+ )(args.dataset)
181
+ dataloaders = [
182
+ DataLoader(
183
+ dataset,
184
+ batch_size=1,
185
+ shuffle=False,
186
+ num_workers=args.num_workers,
187
+ pin_memory=True,
188
+ drop_last=False,
189
+ )
190
+ for dataset in datasets
191
+ ]
192
+
193
+ # run
194
+ for i, dataloader in enumerate(dataloaders):
195
+ dataset = datasets[i]
196
+ dstr = args.dataset.split("+")[i]
197
+
198
+ outdir = args.model + "_" + misc.filename(dstr)
199
+ if "metrics" in args.save and len(args.save) == 1:
200
+ fname = os.path.join(
201
+ outdir, f"conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}.pkl"
202
+ )
203
+ if os.path.isfile(fname) and len(args.save) == 1:
204
+ print(" metrics already compute in " + fname)
205
+ with open(fname, "rb") as fid:
206
+ results = pickle.load(fid)
207
+ for k, v in results.items():
208
+ print("{:s}: {:.3f}".format(k, v))
209
+ continue
210
+
211
+ if "submission" in args.save:
212
+ dirname = (
213
+ f"submission_conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}"
214
+ )
215
+ submission_dir = os.path.join(outdir, dirname)
216
+ else:
217
+ submission_dir = None
218
+
219
+ print("")
220
+ print("saving {:s} in {:s}".format("+".join(args.save), outdir))
221
+ print(repr(dataset))
222
+
223
+ if metrics is not None:
224
+ metrics.reset()
225
+
226
+ for data_iter_step, (image1, image2, gt, pairnames) in enumerate(
227
+ tqdm(dataloader)
228
+ ):
229
+
230
+ do_flip = (
231
+ task == "stereo"
232
+ and dstr.startswith("Spring")
233
+ and any("right" in p for p in pairnames)
234
+ ) # we flip the images and will flip the prediction after as we assume img1 is on the left
235
+
236
+ image1 = image1.to(device, non_blocking=True)
237
+ image2 = image2.to(device, non_blocking=True)
238
+ gt = (
239
+ gt.to(device, non_blocking=True) if gt.numel() > 0 else None
240
+ ) # special case for test time
241
+ if do_flip:
242
+ assert all("right" in p for p in pairnames)
243
+ image1 = image1.flip(
244
+ dims=[3]
245
+ ) # this is already the right frame, let's flip it
246
+ image2 = image2.flip(dims=[3])
247
+ gt = gt # that is ok
248
+
249
+ with torch.inference_mode():
250
+ pred, _, _, time = tiled_pred(
251
+ model,
252
+ None,
253
+ image1,
254
+ image2,
255
+ None if dataset.name == "Spring" else gt,
256
+ conf_mode=args.tile_conf_mode,
257
+ overlap=args.tile_overlap,
258
+ crop=cropsize,
259
+ with_conf=with_conf,
260
+ return_time=True,
261
+ )
262
+
263
+ if do_flip:
264
+ pred = pred.flip(dims=[3])
265
+
266
+ if metrics is not None:
267
+ metrics.add_batch(pred, gt)
268
+
269
+ if any(k in args.save for k in ["pred", "visu", "err10", "submission"]):
270
+ _save_batch(
271
+ pred,
272
+ gt,
273
+ pairnames,
274
+ dataset,
275
+ task,
276
+ args.save,
277
+ outdir,
278
+ time,
279
+ submission_dir=submission_dir,
280
+ )
281
+
282
+ # print
283
+ if metrics is not None:
284
+ results = metrics.get_results()
285
+ for k, v in results.items():
286
+ print("{:s}: {:.3f}".format(k, v))
287
+
288
+ # save if needed
289
+ if "metrics" in args.save:
290
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
291
+ with open(fname, "wb") as fid:
292
+ pickle.dump(results, fid)
293
+ print("metrics saved in", fname)
294
+
295
+ # finalize submission if needed
296
+ if "submission" in args.save:
297
+ dataset.finalize_submission(submission_dir)
298
+
299
+
300
+ if __name__ == "__main__":
301
+ args = get_args_parser()
302
+ args = args.parse_args()
303
+ main(args)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/stereoflow/train.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Main training function
6
+ # --------------------------------------------------------
7
+
8
+ import argparse
9
+ import datetime
10
+ import json
11
+ import numpy as np
12
+ import os
13
+ import sys
14
+ import time
15
+
16
+ import torch
17
+ import torch.distributed as dist
18
+ import torch.backends.cudnn as cudnn
19
+ from torch.utils.tensorboard import SummaryWriter
20
+ import torchvision.transforms as transforms
21
+ import torchvision.datasets as datasets
22
+ from torch.utils.data import DataLoader
23
+
24
+ import utils
25
+ import utils.misc as misc
26
+ from utils.misc import NativeScalerWithGradNormCount as NativeScaler
27
+ from models.croco_downstream import CroCoDownstreamBinocular, croco_args_from_ckpt
28
+ from models.pos_embed import interpolate_pos_embed
29
+ from models.head_downstream import PixelwiseTaskWithDPT
30
+
31
+ from stereoflow.datasets_stereo import (
32
+ get_train_dataset_stereo,
33
+ get_test_datasets_stereo,
34
+ )
35
+ from stereoflow.datasets_flow import get_train_dataset_flow, get_test_datasets_flow
36
+ from stereoflow.engine import train_one_epoch, validate_one_epoch
37
+ from stereoflow.criterion import *
38
+
39
+
40
+ def get_args_parser():
41
+ # prepare subparsers
42
+ parser = argparse.ArgumentParser(
43
+ "Finetuning CroCo models on stereo or flow", add_help=False
44
+ )
45
+ subparsers = parser.add_subparsers(
46
+ title="Task (stereo or flow)", dest="task", required=True
47
+ )
48
+ parser_stereo = subparsers.add_parser("stereo", help="Training stereo model")
49
+ parser_flow = subparsers.add_parser("flow", help="Training flow model")
50
+
51
+ def add_arg(
52
+ name_or_flags, default=None, default_stereo=None, default_flow=None, **kwargs
53
+ ):
54
+ if default is not None:
55
+ assert (
56
+ default_stereo is None and default_flow is None
57
+ ), "setting default makes default_stereo and default_flow disabled"
58
+ parser_stereo.add_argument(
59
+ name_or_flags,
60
+ default=default if default is not None else default_stereo,
61
+ **kwargs,
62
+ )
63
+ parser_flow.add_argument(
64
+ name_or_flags,
65
+ default=default if default is not None else default_flow,
66
+ **kwargs,
67
+ )
68
+
69
+ # output dir
70
+ add_arg(
71
+ "--output_dir",
72
+ required=True,
73
+ type=str,
74
+ help="path where to save, if empty, automatically created",
75
+ )
76
+ # model
77
+ add_arg(
78
+ "--crop",
79
+ type=int,
80
+ nargs="+",
81
+ default_stereo=[352, 704],
82
+ default_flow=[320, 384],
83
+ help="size of the random image crops used during training.",
84
+ )
85
+ add_arg(
86
+ "--pretrained",
87
+ required=True,
88
+ type=str,
89
+ help="Load pretrained model (required as croco arguments come from there)",
90
+ )
91
+ # criterion
92
+ add_arg(
93
+ "--criterion",
94
+ default_stereo="LaplacianLossBounded2()",
95
+ default_flow="LaplacianLossBounded()",
96
+ type=str,
97
+ help="string to evaluate to get criterion",
98
+ )
99
+ add_arg("--bestmetric", default_stereo="avgerr", default_flow="EPE", type=str)
100
+ # dataset
101
+ add_arg("--dataset", type=str, required=True, help="training set")
102
+ # training
103
+ add_arg("--seed", default=0, type=int, help="seed")
104
+ add_arg(
105
+ "--batch_size",
106
+ default_stereo=6,
107
+ default_flow=8,
108
+ type=int,
109
+ help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus",
110
+ )
111
+ add_arg("--epochs", default=32, type=int, help="number of training epochs")
112
+ add_arg(
113
+ "--img_per_epoch",
114
+ type=int,
115
+ default=None,
116
+ help="Fix the number of images seen in an epoch (None means use all training pairs)",
117
+ )
118
+ add_arg(
119
+ "--accum_iter",
120
+ default=1,
121
+ type=int,
122
+ help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)",
123
+ )
124
+ add_arg(
125
+ "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)"
126
+ )
127
+ add_arg(
128
+ "--lr",
129
+ type=float,
130
+ default_stereo=3e-5,
131
+ default_flow=2e-5,
132
+ metavar="LR",
133
+ help="learning rate (absolute lr)",
134
+ )
135
+ add_arg(
136
+ "--min_lr",
137
+ type=float,
138
+ default=0.0,
139
+ metavar="LR",
140
+ help="lower lr bound for cyclic schedulers that hit 0",
141
+ )
142
+ add_arg(
143
+ "--warmup_epochs", type=int, default=1, metavar="N", help="epochs to warmup LR"
144
+ )
145
+ add_arg(
146
+ "--optimizer",
147
+ default="AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))",
148
+ type=str,
149
+ help="Optimizer from torch.optim [ default: AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) ]",
150
+ )
151
+ add_arg(
152
+ "--amp",
153
+ default=0,
154
+ type=int,
155
+ choices=[0, 1],
156
+ help="enable automatic mixed precision training",
157
+ )
158
+ # validation
159
+ add_arg(
160
+ "--val_dataset",
161
+ type=str,
162
+ default="",
163
+ help="Validation sets, multiple separated by + (empty string means that no validation is performed)",
164
+ )
165
+ add_arg(
166
+ "--tile_conf_mode",
167
+ type=str,
168
+ default_stereo="conf_expsigmoid_15_3",
169
+ default_flow="conf_expsigmoid_10_5",
170
+ help="Weights for tile aggregation",
171
+ )
172
+ add_arg(
173
+ "--val_overlap", default=0.7, type=float, help="Overlap value for the tiling"
174
+ )
175
+ # others
176
+ add_arg("--num_workers", default=8, type=int)
177
+ add_arg("--eval_every", type=int, default=1, help="Val loss evaluation frequency")
178
+ add_arg("--save_every", type=int, default=1, help="Save checkpoint frequency")
179
+ add_arg(
180
+ "--start_from",
181
+ type=str,
182
+ default=None,
183
+ help="Start training using weights from an other model (eg for finetuning)",
184
+ )
185
+ add_arg(
186
+ "--tboard_log_step",
187
+ type=int,
188
+ default=100,
189
+ help="Log to tboard every so many steps",
190
+ )
191
+ add_arg(
192
+ "--dist_url", default="env://", help="url used to set up distributed training"
193
+ )
194
+
195
+ return parser
196
+
197
+
198
+ def main(args):
199
+ misc.init_distributed_mode(args)
200
+ global_rank = misc.get_rank()
201
+ num_tasks = misc.get_world_size()
202
+
203
+ assert os.path.isfile(args.pretrained)
204
+ print("output_dir: " + args.output_dir)
205
+ os.makedirs(args.output_dir, exist_ok=True)
206
+
207
+ # fix the seed for reproducibility
208
+ seed = args.seed + misc.get_rank()
209
+ torch.manual_seed(seed)
210
+ np.random.seed(seed)
211
+ cudnn.benchmark = True
212
+
213
+ # Metrics / criterion
214
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
215
+ metrics = (StereoMetrics if args.task == "stereo" else FlowMetrics)().to(device)
216
+ criterion = eval(args.criterion).to(device)
217
+ print("Criterion: ", args.criterion)
218
+
219
+ # Prepare model
220
+ assert os.path.isfile(args.pretrained)
221
+ ckpt = torch.load(args.pretrained, "cpu")
222
+ croco_args = croco_args_from_ckpt(ckpt)
223
+ croco_args["img_size"] = (args.crop[0], args.crop[1])
224
+ print("Croco args: " + str(croco_args))
225
+ args.croco_args = croco_args # saved for test time
226
+ # prepare head
227
+ num_channels = {"stereo": 1, "flow": 2}[args.task]
228
+ if criterion.with_conf:
229
+ num_channels += 1
230
+ print(f"Building head PixelwiseTaskWithDPT() with {num_channels} channel(s)")
231
+ head = PixelwiseTaskWithDPT()
232
+ head.num_channels = num_channels
233
+ # build model and load pretrained weights
234
+ model = CroCoDownstreamBinocular(head, **croco_args)
235
+ interpolate_pos_embed(model, ckpt["model"])
236
+ msg = model.load_state_dict(ckpt["model"], strict=False)
237
+ print(msg)
238
+
239
+ total_params = sum(p.numel() for p in model.parameters())
240
+ total_params_trainable = sum(
241
+ p.numel() for p in model.parameters() if p.requires_grad
242
+ )
243
+ print(f"Total params: {total_params}")
244
+ print(f"Total params trainable: {total_params_trainable}")
245
+ model_without_ddp = model.to(device)
246
+
247
+ eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
248
+ print("lr: %.2e" % args.lr)
249
+ print("accumulate grad iterations: %d" % args.accum_iter)
250
+ print("effective batch size: %d" % eff_batch_size)
251
+
252
+ if args.distributed:
253
+ model = torch.nn.parallel.DistributedDataParallel(
254
+ model, device_ids=[args.gpu], static_graph=True
255
+ )
256
+ model_without_ddp = model.module
257
+
258
+ # following timm: set wd as 0 for bias and norm layers
259
+ param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay)
260
+ optimizer = eval(f"torch.optim.{args.optimizer}")
261
+ print(optimizer)
262
+ loss_scaler = NativeScaler()
263
+
264
+ # automatic restart
265
+ last_ckpt_fname = os.path.join(args.output_dir, f"checkpoint-last.pth")
266
+ args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None
267
+
268
+ if not args.resume and args.start_from:
269
+ print(f"Starting from an other model's weights: {args.start_from}")
270
+ best_so_far = None
271
+ args.start_epoch = 0
272
+ ckpt = torch.load(args.start_from, "cpu")
273
+ msg = model_without_ddp.load_state_dict(ckpt["model"], strict=False)
274
+ print(msg)
275
+ else:
276
+ best_so_far = misc.load_model(
277
+ args=args,
278
+ model_without_ddp=model_without_ddp,
279
+ optimizer=optimizer,
280
+ loss_scaler=loss_scaler,
281
+ )
282
+
283
+ if best_so_far is None:
284
+ best_so_far = np.inf
285
+
286
+ # tensorboard
287
+ log_writer = None
288
+ if global_rank == 0 and args.output_dir is not None:
289
+ log_writer = SummaryWriter(
290
+ log_dir=args.output_dir, purge_step=args.start_epoch * 1000
291
+ )
292
+
293
+ # dataset and loader
294
+ print("Building Train Data loader for dataset: ", args.dataset)
295
+ train_dataset = (
296
+ get_train_dataset_stereo if args.task == "stereo" else get_train_dataset_flow
297
+ )(args.dataset, crop_size=args.crop)
298
+
299
+ def _print_repr_dataset(d):
300
+ if isinstance(d, torch.utils.data.dataset.ConcatDataset):
301
+ for dd in d.datasets:
302
+ _print_repr_dataset(dd)
303
+ else:
304
+ print(repr(d))
305
+
306
+ _print_repr_dataset(train_dataset)
307
+ print(" total length:", len(train_dataset))
308
+ if args.distributed:
309
+ sampler_train = torch.utils.data.DistributedSampler(
310
+ train_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
311
+ )
312
+ else:
313
+ sampler_train = torch.utils.data.RandomSampler(train_dataset)
314
+ data_loader_train = torch.utils.data.DataLoader(
315
+ train_dataset,
316
+ sampler=sampler_train,
317
+ batch_size=args.batch_size,
318
+ num_workers=args.num_workers,
319
+ pin_memory=True,
320
+ drop_last=True,
321
+ )
322
+ if args.val_dataset == "":
323
+ data_loaders_val = None
324
+ else:
325
+ print("Building Val Data loader for datasets: ", args.val_dataset)
326
+ val_datasets = (
327
+ get_test_datasets_stereo
328
+ if args.task == "stereo"
329
+ else get_test_datasets_flow
330
+ )(args.val_dataset)
331
+ for val_dataset in val_datasets:
332
+ print(repr(val_dataset))
333
+ data_loaders_val = [
334
+ DataLoader(
335
+ val_dataset,
336
+ batch_size=1,
337
+ shuffle=False,
338
+ num_workers=args.num_workers,
339
+ pin_memory=True,
340
+ drop_last=False,
341
+ )
342
+ for val_dataset in val_datasets
343
+ ]
344
+ bestmetric = (
345
+ "AVG_"
346
+ if len(data_loaders_val) > 1
347
+ else str(data_loaders_val[0].dataset) + "_"
348
+ ) + args.bestmetric
349
+
350
+ print(f"Start training for {args.epochs} epochs")
351
+ start_time = time.time()
352
+ # Training Loop
353
+ for epoch in range(args.start_epoch, args.epochs):
354
+
355
+ if args.distributed:
356
+ data_loader_train.sampler.set_epoch(epoch)
357
+
358
+ # Train
359
+ epoch_start = time.time()
360
+ train_stats = train_one_epoch(
361
+ model,
362
+ criterion,
363
+ metrics,
364
+ data_loader_train,
365
+ optimizer,
366
+ device,
367
+ epoch,
368
+ loss_scaler,
369
+ log_writer=log_writer,
370
+ args=args,
371
+ )
372
+ epoch_time = time.time() - epoch_start
373
+
374
+ if args.distributed:
375
+ dist.barrier()
376
+
377
+ # Validation (current naive implementation runs the validation on every gpu ... not smart ...)
378
+ if (
379
+ data_loaders_val is not None
380
+ and args.eval_every > 0
381
+ and (epoch + 1) % args.eval_every == 0
382
+ ):
383
+ val_epoch_start = time.time()
384
+ val_stats = validate_one_epoch(
385
+ model,
386
+ criterion,
387
+ metrics,
388
+ data_loaders_val,
389
+ device,
390
+ epoch,
391
+ log_writer=log_writer,
392
+ args=args,
393
+ )
394
+ val_epoch_time = time.time() - val_epoch_start
395
+
396
+ val_best = val_stats[bestmetric]
397
+
398
+ # Save best of all
399
+ if val_best <= best_so_far:
400
+ best_so_far = val_best
401
+ misc.save_model(
402
+ args=args,
403
+ model_without_ddp=model_without_ddp,
404
+ optimizer=optimizer,
405
+ loss_scaler=loss_scaler,
406
+ epoch=epoch,
407
+ best_so_far=best_so_far,
408
+ fname="best",
409
+ )
410
+
411
+ log_stats = {
412
+ **{f"train_{k}": v for k, v in train_stats.items()},
413
+ "epoch": epoch,
414
+ **{f"val_{k}": v for k, v in val_stats.items()},
415
+ }
416
+ else:
417
+ log_stats = {
418
+ **{f"train_{k}": v for k, v in train_stats.items()},
419
+ "epoch": epoch,
420
+ }
421
+
422
+ if args.distributed:
423
+ dist.barrier()
424
+
425
+ # Save stuff
426
+ if args.output_dir and (
427
+ (epoch + 1) % args.save_every == 0 or epoch + 1 == args.epochs
428
+ ):
429
+ misc.save_model(
430
+ args=args,
431
+ model_without_ddp=model_without_ddp,
432
+ optimizer=optimizer,
433
+ loss_scaler=loss_scaler,
434
+ epoch=epoch,
435
+ best_so_far=best_so_far,
436
+ fname="last",
437
+ )
438
+
439
+ if args.output_dir:
440
+ if log_writer is not None:
441
+ log_writer.flush()
442
+ with open(
443
+ os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
444
+ ) as f:
445
+ f.write(json.dumps(log_stats) + "\n")
446
+
447
+ total_time = time.time() - start_time
448
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
449
+ print("Training time {}".format(total_time_str))
450
+
451
+
452
+ if __name__ == "__main__":
453
+ args = get_args_parser()
454
+ args = args.parse_args()
455
+ main(args)
outdoor_v48_4gpu_v2/code/05_02-14:21:58/croco/utils/misc.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilitary functions for CroCo
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # MAE: https://github.com/facebookresearch/mae
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
11
+ # --------------------------------------------------------
12
+
13
+ import builtins
14
+ import datetime
15
+ import os
16
+ import time
17
+ import math
18
+ import json
19
+ from collections import defaultdict, deque
20
+ from pathlib import Path
21
+ import numpy as np
22
+ from itertools import islice
23
+
24
+ import torch
25
+ import torch.distributed as dist
26
+ from torch import inf
27
+ from accelerate import Accelerator
28
+ from accelerate.logging import get_logger
29
+
30
+ printer = get_logger(__name__, log_level="DEBUG")
31
+
32
+
33
+ class SmoothedValue(object):
34
+ """Track a series of values and provide access to smoothed values."""
35
+
36
+ def __init__(self, window_size=20, fmt=None):
37
+ if fmt is None:
38
+ fmt = "{median:.4f} ({global_avg:.4f})"
39
+ self.deque = deque(maxlen=window_size)
40
+ self.total = 0.0
41
+ self.count = 0
42
+ self.fmt = fmt
43
+
44
+ def update(self, value, n=1):
45
+ self.deque.append(value)
46
+ self.count += n
47
+ self.total += value * n
48
+
49
+ def synchronize_between_processes(self, accelerator: Accelerator):
50
+ """Synchronize the count and total across all processes."""
51
+ if accelerator.num_processes == 1:
52
+ return
53
+ t = torch.tensor(
54
+ [self.count, self.total], dtype=torch.float64, device=accelerator.device
55
+ )
56
+ accelerator.wait_for_everyone()
57
+ accelerator.reduce(t, reduction="sum")
58
+ t = t.tolist()
59
+ self.count = int(t[0])
60
+ self.total = t[1]
61
+
62
+ @property
63
+ def median(self):
64
+ return torch.tensor(list(self.deque)).median().item()
65
+
66
+ @property
67
+ def avg(self):
68
+ return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()
69
+
70
+ @property
71
+ def global_avg(self):
72
+ return self.total / self.count
73
+
74
+ @property
75
+ def max(self):
76
+ return max(self.deque)
77
+
78
+ @property
79
+ def value(self):
80
+ return self.deque[-1]
81
+
82
+ def __str__(self):
83
+ return self.fmt.format(
84
+ median=self.median,
85
+ avg=self.avg,
86
+ global_avg=self.global_avg,
87
+ max=self.max,
88
+ value=self.value,
89
+ )
90
+
91
+
92
+ class MetricLogger(object):
93
+ def __init__(self, delimiter="\t"):
94
+ self.meters = defaultdict(SmoothedValue)
95
+ self.delimiter = delimiter
96
+
97
+ def update(self, **kwargs):
98
+ for k, v in kwargs.items():
99
+ if v is None:
100
+ continue
101
+ if isinstance(v, torch.Tensor):
102
+ if v.ndim > 0:
103
+ continue
104
+ v = v.item()
105
+ if isinstance(v, list):
106
+ continue
107
+ assert isinstance(v, (float, int))
108
+ self.meters[k].update(v)
109
+
110
+ def __getattr__(self, attr):
111
+ if attr in self.meters:
112
+ return self.meters[attr]
113
+ if attr in self.__dict__:
114
+ return self.__dict__[attr]
115
+ raise AttributeError(
116
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
117
+ )
118
+
119
+ def __str__(self):
120
+ loss_str = []
121
+ for name, meter in self.meters.items():
122
+ loss_str.append("{}: {}".format(name, str(meter)))
123
+ return self.delimiter.join(loss_str)
124
+
125
+ def synchronize_between_processes(self, accelerator):
126
+ for meter in self.meters.values():
127
+ meter.synchronize_between_processes(accelerator)
128
+
129
+ def add_meter(self, name, meter):
130
+ self.meters[name] = meter
131
+
132
+ def log_every(
133
+ self, iterable, print_freq, accelerator: Accelerator, header=None, max_iter=None, start_step=0,
134
+ ):
135
+ # `start_step` is used for resume: skip the first `start_step` iterations
136
+ # while keeping correct logging indices/ETA based on the original iterable length.
137
+ start_step = int(start_step or 0)
138
+ i = start_step
139
+ if not header:
140
+ header = ""
141
+ start_time = time.time()
142
+ end = time.time()
143
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
144
+ data_time = SmoothedValue(fmt="{avg:.4f}")
145
+ # Compute the total length BEFORE slicing; `islice` itself has no __len__.
146
+ len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable)
147
+ if start_step > 0:
148
+ iterable = islice(iterable, start_step, None)
149
+ space_fmt = ":" + str(len(str(len_iterable))) + "d"
150
+ log_msg = [
151
+ header,
152
+ "[{0" + space_fmt + "}/{1}]",
153
+ "eta: {eta}",
154
+ "{meters}",
155
+ "time: {time}",
156
+ "data: {data}",
157
+ ]
158
+ if torch.cuda.is_available():
159
+ log_msg.append("max mem: {memory:.0f}")
160
+ log_msg = self.delimiter.join(log_msg)
161
+ MB = 1024.0 * 1024.0
162
+ for it, obj in enumerate(iterable):
163
+ data_time.update(time.time() - end)
164
+ yield obj
165
+ iter_time.update(time.time() - end)
166
+ if i % print_freq == 0 or i == len_iterable - 1:
167
+ eta_seconds = iter_time.global_avg * (len_iterable - i)
168
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
169
+ if torch.cuda.is_available():
170
+ if accelerator.is_main_process:
171
+ printer.info(
172
+ log_msg.format(
173
+ i,
174
+ len_iterable,
175
+ eta=eta_string,
176
+ meters=str(self),
177
+ time=str(iter_time),
178
+ data=str(data_time),
179
+ memory=torch.cuda.max_memory_allocated() / MB,
180
+ )
181
+ )
182
+ else:
183
+ if accelerator.is_main_process:
184
+ printer.info(
185
+ log_msg.format(
186
+ i,
187
+ len_iterable,
188
+ eta=eta_string,
189
+ meters=str(self),
190
+ time=str(iter_time),
191
+ data=str(data_time),
192
+ )
193
+ )
194
+ i += 1
195
+ end = time.time()
196
+ if max_iter and it >= max_iter:
197
+ break
198
+ # if i + start_step >= len_iterable:
199
+ # break
200
+ total_time = time.time() - start_time
201
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
202
+ if accelerator.is_main_process:
203
+ printer.info(
204
+ "{} Total time: {} ({:.4f} s / it)".format(
205
+ header, total_time_str, total_time / len_iterable
206
+ )
207
+ )
208
+
209
+
210
+ def setup_for_distributed(is_master):
211
+ """
212
+ This function disables printing when not in master process
213
+ """
214
+ builtin_print = builtins.print
215
+
216
+ def print(*args, **kwargs):
217
+ force = kwargs.pop("force", False)
218
+ force = force or (get_world_size() > 8)
219
+ if is_master or force:
220
+ now = datetime.datetime.now().time()
221
+ builtin_print("[{}] ".format(now), end="") # print with time stamp
222
+ builtin_print(*args, **kwargs)
223
+
224
+ builtins.print = print
225
+
226
+
227
+ def is_dist_avail_and_initialized():
228
+ if not dist.is_available():
229
+ return False
230
+ if not dist.is_initialized():
231
+ return False
232
+ return True
233
+
234
+
235
+ def get_world_size():
236
+ if not is_dist_avail_and_initialized():
237
+ return 1
238
+ return dist.get_world_size()
239
+
240
+
241
+ def get_rank():
242
+ if not is_dist_avail_and_initialized():
243
+ return 0
244
+ return dist.get_rank()
245
+
246
+
247
+ def is_main_process(accelerator: Accelerator):
248
+ return accelerator.is_main_process
249
+
250
+
251
+ def save_on_master(accelerator: Accelerator, *args, **kwargs):
252
+ if is_main_process(accelerator):
253
+ # torch.save(*args, **kwargs)
254
+ accelerator.save(*args, **kwargs)
255
+ # unwrapped_model = accelerator.unwrap_model(model)
256
+ # accelerator.save(unwrapped_model.state_dict(), checkpoint_path)
257
+
258
+
259
+ def init_distributed_mode(args):
260
+ nodist = args.nodist if hasattr(args, "nodist") else False
261
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ and not nodist:
262
+ args.rank = int(os.environ["RANK"])
263
+ args.world_size = int(os.environ["WORLD_SIZE"])
264
+ args.gpu = int(os.environ["LOCAL_RANK"])
265
+ else:
266
+ print("Not using distributed mode")
267
+ setup_for_distributed(is_master=True) # hack
268
+ args.distributed = False
269
+ return
270
+
271
+ args.distributed = True
272
+
273
+ torch.cuda.set_device(args.gpu)
274
+ args.dist_backend = "nccl"
275
+ print(
276
+ "| distributed init (rank {}): {}, gpu {}".format(
277
+ args.rank, args.dist_url, args.gpu
278
+ ),
279
+ flush=True,
280
+ )
281
+ torch.distributed.init_process_group(
282
+ backend=args.dist_backend,
283
+ init_method=args.dist_url,
284
+ world_size=args.world_size,
285
+ rank=args.rank,
286
+ )
287
+ torch.distributed.barrier()
288
+ setup_for_distributed(args.rank == 0)
289
+
290
+
291
+ class NativeScalerWithGradNormCount:
292
+ state_dict_key = "amp_scaler"
293
+
294
+ def __init__(self, enabled=True, accelerator: Accelerator = None):
295
+ self.accelerator = accelerator
296
+
297
+ def __call__(
298
+ self,
299
+ loss,
300
+ optimizer,
301
+ clip_grad=None,
302
+ parameters=None,
303
+ create_graph=False,
304
+ update_grad=True,
305
+ ):
306
+ self.accelerator.backward(
307
+ loss, create_graph=create_graph
308
+ ) # .backward(create_graph=create_graph)
309
+ if update_grad:
310
+ if clip_grad is not None:
311
+ assert parameters is not None
312
+ # self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
313
+ norm = self.accelerator.clip_grad_norm_(parameters, clip_grad)
314
+ else:
315
+ if self.accelerator.scaler is not None:
316
+ self.accelerator.unscale_gradients()
317
+ norm = get_grad_norm_(parameters)
318
+ optimizer.step()
319
+ else:
320
+ norm = None
321
+ return norm
322
+
323
+ def state_dict(self):
324
+ if self.accelerator.scaler is not None:
325
+ return self.accelerator.scaler.state_dict()
326
+ else:
327
+ return {}
328
+
329
+ def load_state_dict(self, state_dict):
330
+ if self.accelerator.scaler is not None:
331
+ self.accelerator.scaler.load_state_dict(state_dict)
332
+
333
+
334
+ # class NativeScalerWithGradNormCount:
335
+ # state_dict_key = "amp_scaler"
336
+
337
+ # def __init__(self, enabled=True, accelerator:Accelerator=None):
338
+ # self._scaler = torch.cuda.amp.GradScaler(enabled=enabled)
339
+ # self.accelerator = accelerator
340
+
341
+ # def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
342
+ # # self.accelerator.backward(loss, create_graph=create_graph) #.backward(create_graph=create_graph)
343
+ # self._scaler.scale(loss).backward(create_graph=create_graph)
344
+ # if update_grad:
345
+ # if clip_grad is not None:
346
+ # assert parameters is not None
347
+ # # #self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
348
+ # # norm = self.accelerator.clip_grad_norm_(parameters, clip_grad)
349
+ # self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
350
+ # norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
351
+ # else:
352
+ # # if self.accelerator.scaler is not None:
353
+ # # self.accelerator.unscale_gradients()
354
+ # # norm = get_grad_norm_(parameters)
355
+ # self._scaler.unscale_(optimizer)
356
+ # norm = get_grad_norm_(parameters)
357
+ # # optimizer.step()
358
+ # self._scaler.step(optimizer)
359
+ # self._scaler.update()
360
+ # else:
361
+ # norm = None
362
+ # return norm
363
+
364
+ # # def state_dict(self):
365
+ # # if self.accelerator.scaler is not None:
366
+ # # return self.accelerator.scaler.state_dict()
367
+ # # else:
368
+ # # return {}
369
+
370
+ # # def load_state_dict(self, state_dict):
371
+ # # if self.accelerator.scaler is not None:
372
+ # # self.accelerator.scaler.load_state_dict(state_dict)
373
+
374
+ # def state_dict(self):
375
+ # return self._scaler.state_dict()
376
+
377
+ # def load_state_dict(self, state_dict):
378
+ # self._scaler.load_state_dict(state_dict)
379
+
380
+
381
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
382
+ if isinstance(parameters, torch.Tensor):
383
+ parameters = [parameters]
384
+ parameters = [p for p in parameters if p.grad is not None]
385
+ norm_type = float(norm_type)
386
+ if len(parameters) == 0:
387
+ return torch.tensor(0.0)
388
+ device = parameters[0].grad.device
389
+ if norm_type == inf:
390
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
391
+ else:
392
+ total_norm = torch.norm(
393
+ torch.stack(
394
+ [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
395
+ ),
396
+ norm_type,
397
+ )
398
+ return total_norm
399
+
400
+
401
+ def save_model(
402
+ accelerator,
403
+ args,
404
+ epoch,
405
+ model_without_ddp,
406
+ optimizer,
407
+ loss_scaler,
408
+ step,
409
+ fname=None,
410
+ best_so_far=None,
411
+ ):
412
+ if accelerator.is_main_process:
413
+ output_dir = Path(args.output_dir)
414
+ if fname is None:
415
+ fname = str(epoch)
416
+ checkpoint_path = output_dir / ("checkpoint-%s.pth" % fname)
417
+ to_save = {
418
+ "model": model_without_ddp.state_dict(),
419
+ "optimizer": optimizer.state_dict(),
420
+ "scaler": loss_scaler.state_dict(),
421
+ "args": args,
422
+ "epoch": epoch,
423
+ "step": step,
424
+ }
425
+ if best_so_far is not None:
426
+ to_save["best_so_far"] = best_so_far
427
+ print(f">> Saving model to {checkpoint_path} ...")
428
+ save_on_master(accelerator, to_save, checkpoint_path)
429
+
430
+ to_save = {
431
+ "model": model_without_ddp.state_dict(),
432
+ }
433
+ checkpoint_path = output_dir / ("model.pth")
434
+ save_on_master(accelerator, to_save, checkpoint_path)
435
+
436
+
437
+ def load_model(args, model_without_ddp, optimizer, loss_scaler):
438
+ args.start_epoch = 0
439
+ args.start_step = 0
440
+ best_so_far = None
441
+ if args.resume is not None:
442
+ if args.resume.startswith("https"):
443
+ checkpoint = torch.hub.load_state_dict_from_url(
444
+ args.resume, map_location="cpu", check_hash=True
445
+ )
446
+ else:
447
+ checkpoint = torch.load(args.resume, map_location="cuda", weights_only=False)
448
+ printer.info("Resume checkpoint %s" % args.resume)
449
+ state_dict = checkpoint["model"]
450
+ new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
451
+ model_without_ddp.load_state_dict(new_state_dict, strict=True)
452
+ args.start_epoch = checkpoint["epoch"] + 1
453
+ if "step" in checkpoint:
454
+ args.start_step = checkpoint["step"]
455
+ device = next(model_without_ddp.parameters()).device
456
+ printer.info(f"Moving optimizer state to device: {device}")
457
+
458
+ if "optimizer" in checkpoint:
459
+ for state in checkpoint["optimizer"]["state"].values():
460
+ for k, v in state.items():
461
+ if isinstance(v, torch.Tensor):
462
+ state[k] = v.to(device)
463
+
464
+ optimizer.load_state_dict(checkpoint["optimizer"])
465
+
466
+ if "scaler" in checkpoint:
467
+ loss_scaler.load_state_dict(checkpoint["scaler"])
468
+ if "best_so_far" in checkpoint:
469
+ best_so_far = checkpoint["best_so_far"]
470
+ printer.info(" & best_so_far={:g}".format(best_so_far))
471
+ else:
472
+ printer.info("")
473
+ printer.info("With optim & sched! start_epoch={:d}".format(args.start_epoch))
474
+ return best_so_far
475
+
476
+
477
+ def all_reduce_mean(x, accelerator):
478
+ """Use accelerator to all-reduce and compute mean."""
479
+ if accelerator.state.num_processes > 1:
480
+ x_reduce = torch.tensor(x).cuda()
481
+ accelerator.reduce(x_reduce, reduce_op="SUM")
482
+ x_reduce /= accelerator.state.num_processes
483
+ return x_reduce.item()
484
+ else:
485
+ return x
486
+
487
+
488
+ def _replace(text, src, tgt, rm=""):
489
+ """Advanced string replacement.
490
+ Given a text:
491
+ - replace all elements in src by the corresponding element in tgt
492
+ - remove all elements in rm
493
+ """
494
+ if len(tgt) == 1:
495
+ tgt = tgt * len(src)
496
+ assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len"
497
+ for s, t in zip(src, tgt):
498
+ text = text.replace(s, t)
499
+ for c in rm:
500
+ text = text.replace(c, "")
501
+ return text
502
+
503
+
504
+ def filename(obj):
505
+ """transform a python obj or cmd into a proper filename.
506
+ - \1 gets replaced by slash '/'
507
+ - \2 gets replaced by comma ','
508
+ """
509
+ if not isinstance(obj, str):
510
+ obj = repr(obj)
511
+ obj = str(obj).replace("()", "")
512
+ obj = _replace(obj, "_,(*/\1\2", "-__x%/,", rm=" )'\"")
513
+ assert all(len(s) < 256 for s in obj.split(os.sep)), (
514
+ "filename too long (>256 characters):\n" + obj
515
+ )
516
+ return obj
517
+
518
+
519
+ def _get_num_layer_for_vit(var_name, enc_depth, dec_depth):
520
+ if var_name in ("cls_token", "mask_token", "pos_embed", "global_tokens"):
521
+ return 0
522
+ elif var_name.startswith("patch_embed"):
523
+ return 0
524
+ elif var_name.startswith("enc_blocks"):
525
+ layer_id = int(var_name.split(".")[1])
526
+ return layer_id + 1
527
+ elif var_name.startswith("decoder_embed") or var_name.startswith(
528
+ "enc_norm"
529
+ ): # part of the last black
530
+ return enc_depth
531
+ elif var_name.startswith("dec_blocks"):
532
+ layer_id = int(var_name.split(".")[1])
533
+ return enc_depth + layer_id + 1
534
+ elif var_name.startswith("dec_norm"): # part of the last block
535
+ return enc_depth + dec_depth
536
+ elif any(var_name.startswith(k) for k in ["head", "prediction_head"]):
537
+ return enc_depth + dec_depth + 1
538
+ else:
539
+ raise NotImplementedError(var_name)
540
+
541
+
542
+ def get_parameter_groups(
543
+ model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[]
544
+ ):
545
+ parameter_group_names = {}
546
+ parameter_group_vars = {}
547
+ enc_depth, dec_depth = None, None
548
+ # prepare layer decay values
549
+ assert layer_decay == 1.0 or 0.0 < layer_decay < 1.0
550
+ if layer_decay < 1.0:
551
+ enc_depth = model.enc_depth
552
+ dec_depth = model.dec_depth if hasattr(model, "dec_blocks") else 0
553
+ num_layers = enc_depth + dec_depth
554
+ layer_decay_values = list(
555
+ layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)
556
+ )
557
+
558
+ for name, param in model.named_parameters():
559
+ if not param.requires_grad:
560
+ continue # frozen weights
561
+
562
+ # Assign weight decay values
563
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
564
+ if "enc_blocks" in name:
565
+ group_name = "no_decay_enc_blocks"
566
+ else:
567
+ group_name = "no_decay"
568
+ this_weight_decay = 0.0
569
+ else:
570
+ if "enc_blocks" in name:
571
+ group_name = "decay_enc_blocks"
572
+ else:
573
+ group_name = "decay"
574
+ this_weight_decay = weight_decay
575
+
576
+ # Assign layer ID for LR scaling
577
+ if layer_decay < 1.0:
578
+ skip_scale = False
579
+ layer_id = _get_num_layer_for_vit(name, enc_depth, dec_depth)
580
+ group_name = "layer_%d_%s" % (layer_id, group_name)
581
+ if name in no_lr_scale_list:
582
+ skip_scale = True
583
+ group_name = f"{group_name}_no_lr_scale"
584
+ else:
585
+ layer_id = 0
586
+ skip_scale = True
587
+
588
+ if group_name not in parameter_group_names:
589
+ if not skip_scale:
590
+ scale = layer_decay_values[layer_id]
591
+ else:
592
+ scale = 1.0
593
+
594
+ if "enc_blocks" in group_name:
595
+ scale *= 1.0
596
+ parameter_group_names[group_name] = {
597
+ "weight_decay": this_weight_decay,
598
+ "params": [],
599
+ "lr_scale": scale,
600
+ }
601
+ parameter_group_vars[group_name] = {
602
+ "weight_decay": this_weight_decay,
603
+ "params": [],
604
+ "lr_scale": scale,
605
+ }
606
+
607
+ parameter_group_vars[group_name]["params"].append(param)
608
+ parameter_group_names[group_name]["params"].append(name)
609
+ printer.info("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
610
+ return list(parameter_group_vars.values())
611
+
612
+
613
+ def adjust_learning_rate(optimizer, epoch, args):
614
+ """Decay the learning rate with half-cycle cosine after warmup"""
615
+
616
+ if epoch < args.warmup_epochs:
617
+ lr = args.lr * epoch / args.warmup_epochs
618
+ else:
619
+ # lr = args.lr
620
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
621
+ 1.0
622
+ + math.cos(
623
+ math.pi
624
+ * (epoch - args.warmup_epochs)
625
+ / (args.epochs - args.warmup_epochs)
626
+ )
627
+ )
628
+
629
+ for param_group in optimizer.param_groups:
630
+ if "lr_scale" in param_group:
631
+ param_group["lr"] = lr * param_group["lr_scale"]
632
+ else:
633
+ param_group["lr"] = lr
634
+
635
+ return lr
outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/__init__.py ADDED
File without changes
outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/alignment.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ import math
3
+ from collections import namedtuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.types
10
+ # import utils3d
11
+
12
+
13
+ def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
14
+ "Scatter the minimum value along the given dimension of `input` into `src` at the indices specified in `index`."
15
+ shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
16
+ minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
17
+ minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index))
18
+ indices = torch.full(shape, -1, dtype=torch.long, device=src.device)
19
+ indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim]
20
+ return torch.return_types.min((minimum, indices))
21
+
22
+
23
+ def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
24
+ batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
25
+ n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
26
+ splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
27
+ splited_kwargs = {k: [v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks] for k, v in kwargs.items()}
28
+ results = []
29
+ for i in range(n_chunks):
30
+ chunk_args = tuple(arg[i] for arg in splited_args)
31
+ chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
32
+ results.append(fn(*chunk_args, **chunk_kwargs))
33
+
34
+ if isinstance(results[0], tuple):
35
+ return tuple(torch.cat(r, dim=0) for r in zip(*results))
36
+ else:
37
+ return torch.cat(results, dim=0)
38
+
39
+
40
+ def _pad_inf(x_: torch.Tensor):
41
+ return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)
42
+
43
+
44
+ def _pad_cumsum(cumsum: torch.Tensor):
45
+ return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)
46
+
47
+
48
+ def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
49
+ return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)
50
+
51
+
52
+ def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
53
+ """
54
+ If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.
55
+
56
+ w_i must be >= 0.
57
+
58
+ ### Parameters:
59
+ - `x`: tensor of shape (..., n)
60
+ - `y`: tensor of shape (..., n)
61
+ - `w`: tensor of shape (..., n)
62
+ - `trunc`: optional, float or tensor of shape (..., n) or None
63
+
64
+ ### Returns:
65
+ - `a`: tensor of shape (...), differentiable
66
+ - `loss`: tensor of shape (...), value of loss function at `a`, detached
67
+ - `index`: tensor of shape (...), where a = y[idx] / x[idx]
68
+ """
69
+ if trunc is None:
70
+ x, y, w = torch.broadcast_tensors(x, y, w)
71
+ sign = torch.sign(x)
72
+ x, y = x * sign, y * sign
73
+ y_div_x = y / x.clamp_min(eps)
74
+ y_div_x, argsort = y_div_x.sort(dim=-1)
75
+
76
+ wx = torch.gather(x * w, dim=-1, index=argsort)
77
+ derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
78
+ search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)
79
+
80
+ a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
81
+ index = argsort.gather(dim=-1, index=search).squeeze(-1)
82
+ loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)
83
+
84
+ else:
85
+ # Reshape to (batch_size, n) for simplicity
86
+ x, y, w = torch.broadcast_tensors(x, y, w)
87
+ batch_shape = x.shape[:-1]
88
+ batch_size = math.prod(batch_shape)
89
+ x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])
90
+
91
+ sign = torch.sign(x)
92
+ x, y = x * sign, y * sign
93
+ wx, wy = w * x, w * y
94
+ xyw = torch.stack([x, y, w], dim=-1) # Stacked for convenient gathering
95
+
96
+ y_div_x = A = y / x.clamp_min(eps)
97
+ B = (wy - trunc) / wx.clamp_min(eps)
98
+ C = (wy + trunc) / wx.clamp_min(eps)
99
+ with torch.no_grad():
100
+ # Caculate prefix sum by orders of A, B, C
101
+ A, A_argsort = A.sort(dim=-1)
102
+ Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
103
+ A, Q_A = _pad_inf(A), _pad_cumsum(Q_A) # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.
104
+
105
+ B, B_argsort = B.sort(dim=-1)
106
+ Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
107
+ B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)
108
+
109
+ C, C_argsort = C.sort(dim=-1)
110
+ Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
111
+ C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)
112
+
113
+ # Caculate left and right derivative of A
114
+ j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
115
+ j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
116
+ j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
117
+ left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
118
+ j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
119
+ j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
120
+ j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
121
+ right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
122
+
123
+ # Find extrema
124
+ is_extrema = (left_derivative < 0) & (right_derivative >= 0)
125
+ is_extrema[..., 0] |= ~is_extrema.any(dim=-1) # In case all derivatives are zero, take the first one as extrema.
126
+ where_extrema_batch, where_extrema_index = torch.where(is_extrema)
127
+
128
+ # Calculate objective value at extrema
129
+ extrema_a = y_div_x[where_extrema_batch, where_extrema_index] # (num_extrema,)
130
+ MAX_ELEMENTS = 4096 ** 2 # Split into small batches to avoid OOM in case there are too many extrema.(~1G)
131
+ SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
132
+ extrema_value = torch.cat([
133
+ _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
134
+ for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
135
+ ]) # (num_extrema,)
136
+
137
+ # Find minima among corresponding extrema
138
+ minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value) # (batch_size,)
139
+ index = where_extrema_index[indices]
140
+
141
+ a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
142
+ a = a.reshape(batch_shape)
143
+ loss = minima.reshape(batch_shape)
144
+ index = index.reshape(batch_shape)
145
+
146
+ return a, loss, index
147
+
148
+
149
+ def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
150
+ """
151
+ Align `depth_src` to `depth_tgt` with given constant weights.
152
+
153
+ ### Parameters:
154
+ - `depth_src: torch.Tensor` of shape (..., N)
155
+ - `depth_tgt: torch.Tensor` of shape (..., N)
156
+
157
+ """
158
+ scale, _, _ = align(depth_src, depth_tgt, weight, trunc)
159
+
160
+ return scale
161
+
162
+
163
+ def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
164
+ """
165
+ Align `depth_src` to `depth_tgt` with given constant weights.
166
+
167
+ ### Parameters:
168
+ - `depth_src: torch.Tensor` of shape (..., N)
169
+ - `depth_tgt: torch.Tensor` of shape (..., N)
170
+ - `weight: torch.Tensor` of shape (..., N)
171
+ - `trunc: float` or tensor of shape (..., N) or None
172
+
173
+ ### Returns:
174
+ - `scale: torch.Tensor` of shape (...).
175
+ - `shift: torch.Tensor` of shape (...).
176
+ """
177
+ dtype, device = depth_src.dtype, depth_src.device
178
+
179
+ # Flatten batch dimensions for simplicity
180
+ batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
181
+ batch_size = math.prod(batch_shape)
182
+ depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)
183
+
184
+ # Here, we take anchors only for non-zero weights.
185
+ # Although the results will be still correct even anchor points have zero weight,
186
+ # it is wasting computation and may cause instability in some cases, e.g. too many extrema.
187
+ anchors_where_batch, anchors_where_n = torch.where(weight > 0)
188
+
189
+ # Stop gradient when solving optimal anchors
190
+ with torch.no_grad():
191
+ depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n] # (anchors)
192
+ depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n] # (anchors)
193
+
194
+ depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None] # (anchors, n)
195
+ depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None] # (anchors, n)
196
+ weight_anchored = weight[anchors_where_batch, :] # (anchors, n)
197
+
198
+ scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc) # (anchors)
199
+
200
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss) # (batch_size,)
201
+
202
+ # Reproduce by indexing for shorter compute graph
203
+ index_1 = anchors_where_n[index_anchor] # (batch_size,)
204
+ index_2 = index[index_anchor] # (batch_size,)
205
+
206
+ tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
207
+ tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)
208
+
209
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
210
+ shift = tgt_1 - scale * src_1
211
+
212
+ scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)
213
+
214
+ return scale, shift
215
+
216
+ def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
217
+ """
218
+ Align `depth_src` to `depth_tgt` with given constant weights using IRLS.
219
+ """
220
+ dtype, device = depth_src.dtype, depth_src.device
221
+
222
+ w = weight
223
+ x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)
224
+ y = depth_tgt
225
+
226
+ for i in range(max_iter):
227
+ beta = (x.transpose(-1, -2) @ (w * y)) @ (x.transpose(-1, -2) @ (w[..., None] * x)).inverse().transpose(-2, -1)
228
+ w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)
229
+
230
+ return beta[..., 0], beta[..., 1]
231
+
232
+
233
+ def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
234
+ """
235
+ ### Parameters:
236
+ - `points_src: torch.Tensor` of shape (..., N, 3)
237
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
238
+ - `weight: torch.Tensor` of shape (..., N)
239
+
240
+ ### Returns:
241
+ - `a: torch.Tensor` of shape (...). Only positive solutions are garunteed. You should filter out negative scales before using it.
242
+ - `b: torch.Tensor` of shape (...)
243
+ """
244
+ dtype, device = points_src.dtype, points_src.device
245
+
246
+ scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight[..., None].expand_as(points_src).flatten(-2), trunc)
247
+
248
+ return scale
249
+
250
+
251
+ def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
252
+ """
253
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
254
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
255
+
256
+ ### Parameters:
257
+ - `points_src: torch.Tensor` of shape (..., N, 3)
258
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
259
+ - `weights: torch.Tensor` of shape (..., N)
260
+
261
+ ### Returns:
262
+ - `scale: torch.Tensor` of shape (...).
263
+ - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
264
+ """
265
+ dtype, device = points_src.dtype, points_src.device
266
+
267
+ # Flatten batch dimensions for simplicity
268
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
269
+ batch_size = math.prod(batch_shape)
270
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
271
+
272
+ # Take anchors
273
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
274
+ with torch.no_grad():
275
+ zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
276
+ points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
277
+ points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
278
+
279
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
280
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
281
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
282
+
283
+ # Solve optimal scale and shift for each anchor
284
+ MAX_ELEMENTS = 2 ** 20
285
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
286
+
287
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
288
+
289
+ # Reproduce by indexing for shorter compute graph
290
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
291
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
292
+
293
+ zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
294
+ points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
295
+ tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
296
+ tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
297
+
298
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
299
+ shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
300
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
301
+
302
+ return scale, shift
303
+
304
+
305
+ def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
306
+ """
307
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
308
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
309
+
310
+ ### Parameters:
311
+ - `points_src: torch.Tensor` of shape (..., N, 3)
312
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
313
+ - `weights: torch.Tensor` of shape (..., N)
314
+
315
+ ### Returns:
316
+ - `scale: torch.Tensor` of shape (...).
317
+ - `shift: torch.Tensor` of shape (..., 3)
318
+ """
319
+ dtype, device = points_src.dtype, points_src.device
320
+
321
+ # Flatten batch dimensions for simplicity
322
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
323
+ batch_size = math.prod(batch_shape)
324
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
325
+
326
+ # Take anchors
327
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
328
+
329
+ with torch.no_grad():
330
+ points_src_anchor = points_src[anchor_where_batch, anchor_where_n] # (anchors, 3)
331
+ points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n] # (anchors, 3)
332
+
333
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
334
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
335
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
336
+
337
+ # Solve optimal scale and shift for each anchor
338
+ MAX_ELEMENTS = 2 ** 20
339
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // 2, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
340
+
341
+ # Get optimal scale and shift for each batch element
342
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
343
+
344
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
345
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
346
+
347
+ src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
348
+ src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
349
+
350
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
351
+ shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
352
+
353
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
354
+
355
+ return scale, shift
356
+
357
+
358
+ def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
359
+ """
360
+ Align `points_src` to `points_tgt` with respect to a Z-axis shift.
361
+
362
+ ### Parameters:
363
+ - `points_src: torch.Tensor` of shape (..., N, 3)
364
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
365
+ - `weights: torch.Tensor` of shape (..., N)
366
+
367
+ ### Returns:
368
+ - `scale: torch.Tensor` of shape (...).
369
+ - `shift: torch.Tensor` of shape (..., 3)
370
+ """
371
+ dtype, device = points_src.dtype, points_src.device
372
+
373
+ shift, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc)
374
+ shift = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)
375
+
376
+ return shift
377
+
378
+
379
+ def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
380
+ """
381
+ Align `points_src` to `points_tgt` with respect to a Z-axis shift.
382
+
383
+ ### Parameters:
384
+ - `points_src: torch.Tensor` of shape (..., N, 3)
385
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
386
+ - `weights: torch.Tensor` of shape (..., N)
387
+
388
+ ### Returns:
389
+ - `scale: torch.Tensor` of shape (...).
390
+ - `shift: torch.Tensor` of shape (..., 3)
391
+ """
392
+ dtype, device = points_src.dtype, points_src.device
393
+
394
+ shift, _, _ = align(torch.ones_like(points_src).swapaxes(-2, -1), (points_tgt - points_src).swapaxes(-2, -1), weight[..., None, :], trunc)
395
+
396
+ return shift
397
+
398
+
399
+ def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
400
+ """
401
+ Solve `min sum_i w_i * (a * x_i + b - y_i ) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.
402
+
403
+ ### Parameters:
404
+ - `x: torch.Tensor` of shape (..., N)
405
+ - `y: torch.Tensor` of shape (..., N)
406
+ - `w: torch.Tensor` of shape (..., N)
407
+
408
+ ### Returns:
409
+ - `a: torch.Tensor` of shape (...,)
410
+ - `b: torch.Tensor` of shape (...,)
411
+ """
412
+ w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
413
+ A = torch.stack([w_sqrt * x, torch.ones_like(x)], dim=-1)
414
+ B = (w_sqrt * y)[..., None]
415
+ a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
416
+ return a, b
outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/blocks.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # modified from DUSt3R
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from itertools import repeat
11
+ import collections.abc
12
+ from torch.nn.functional import scaled_dot_product_attention
13
+ from functools import partial
14
+
15
+
16
+ def _ntuple(n):
17
+ def parse(x):
18
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
19
+ return x
20
+ return tuple(repeat(x, n))
21
+
22
+ return parse
23
+
24
+
25
+ to_2tuple = _ntuple(2)
26
+
27
+
28
+ def drop_path(
29
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
30
+ ):
31
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
32
+ if drop_prob == 0.0 or not training:
33
+ return x
34
+ keep_prob = 1 - drop_prob
35
+ shape = (x.shape[0],) + (1,) * (
36
+ x.ndim - 1
37
+ ) # work with diff dim tensors, not just 2D ConvNets
38
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
39
+ if keep_prob > 0.0 and scale_by_keep:
40
+ random_tensor.div_(keep_prob)
41
+ return x * random_tensor
42
+
43
+
44
+ class DropPath(nn.Module):
45
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
46
+
47
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
48
+ super(DropPath, self).__init__()
49
+ self.drop_prob = drop_prob
50
+ self.scale_by_keep = scale_by_keep
51
+
52
+ def forward(self, x):
53
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
54
+
55
+ def extra_repr(self):
56
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
57
+
58
+
59
+ class Mlp(nn.Module):
60
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
61
+
62
+ def __init__(
63
+ self,
64
+ in_features,
65
+ hidden_features=None,
66
+ out_features=None,
67
+ act_layer=nn.GELU,
68
+ bias=True,
69
+ drop=0.0,
70
+ ):
71
+ super().__init__()
72
+ out_features = out_features or in_features
73
+ hidden_features = hidden_features or in_features
74
+ bias = to_2tuple(bias)
75
+ drop_probs = to_2tuple(drop)
76
+
77
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
78
+ self.act = act_layer()
79
+ self.drop1 = nn.Dropout(drop_probs[0])
80
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
81
+ self.drop2 = nn.Dropout(drop_probs[1])
82
+
83
+ def forward(self, x):
84
+ return self.drop2(self.fc2(self.drop1(self.act(self.fc1(x)))))
85
+
86
+
87
+ class Attention(nn.Module):
88
+
89
+ def __init__(
90
+ self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
91
+ ):
92
+ super().__init__()
93
+ self.num_heads = num_heads
94
+ head_dim = dim // num_heads
95
+ self.scale = head_dim**-0.5
96
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
97
+ self.attn_drop = nn.Dropout(attn_drop)
98
+ self.proj = nn.Linear(dim, dim)
99
+ self.proj_drop = nn.Dropout(proj_drop)
100
+ self.rope = rope.float() if rope is not None else None
101
+
102
+ def forward(self, x, xpos):
103
+ B, N, C = x.shape
104
+
105
+ qkv = (
106
+ self.qkv(x)
107
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
108
+ .transpose(1, 3)
109
+ )
110
+ q, k, v = [qkv[:, :, i] for i in range(3)]
111
+
112
+ q_type = q.dtype
113
+ k_type = k.dtype
114
+ if self.rope is not None:
115
+ q = q.float()
116
+ k = k.float()
117
+ with torch.autocast(device_type="cuda", enabled=False):
118
+ q = self.rope(q, xpos)
119
+ k = self.rope(k, xpos)
120
+ q = q.to(q_type)
121
+ k = k.to(k_type)
122
+
123
+ x = (
124
+ scaled_dot_product_attention(
125
+ query=q, key=k, value=v, dropout_p=self.attn_drop.p, scale=self.scale
126
+ )
127
+ .transpose(1, 2)
128
+ .reshape(B, N, C)
129
+ )
130
+
131
+ x = self.proj(x)
132
+ x = self.proj_drop(x)
133
+ return x
134
+
135
+
136
+ class Block(nn.Module):
137
+
138
+ def __init__(
139
+ self,
140
+ dim,
141
+ num_heads,
142
+ mlp_ratio=4.0,
143
+ qkv_bias=False,
144
+ drop=0.0,
145
+ attn_drop=0.0,
146
+ drop_path=0.0,
147
+ act_layer=nn.GELU,
148
+ norm_layer=nn.LayerNorm,
149
+ rope=None,
150
+ ):
151
+ super().__init__()
152
+ self.norm1 = norm_layer(dim)
153
+ self.attn = Attention(
154
+ dim,
155
+ rope=rope,
156
+ num_heads=num_heads,
157
+ qkv_bias=qkv_bias,
158
+ attn_drop=attn_drop,
159
+ proj_drop=drop,
160
+ )
161
+
162
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
163
+ self.norm2 = norm_layer(dim)
164
+ mlp_hidden_dim = int(dim * mlp_ratio)
165
+ self.mlp = Mlp(
166
+ in_features=dim,
167
+ hidden_features=mlp_hidden_dim,
168
+ act_layer=act_layer,
169
+ drop=drop,
170
+ )
171
+
172
+ def forward(self, x, xpos):
173
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
174
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
175
+ return x
176
+
177
+
178
+ class CrossAttention(nn.Module):
179
+
180
+ def __init__(
181
+ self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
182
+ ):
183
+ super().__init__()
184
+ self.num_heads = num_heads
185
+ head_dim = dim // num_heads
186
+ self.scale = head_dim**-0.5
187
+
188
+ self.projq = nn.Linear(dim, dim, bias=qkv_bias)
189
+ self.projk = nn.Linear(dim, dim, bias=qkv_bias)
190
+ self.projv = nn.Linear(dim, dim, bias=qkv_bias)
191
+ self.attn_drop = nn.Dropout(attn_drop)
192
+ self.proj = nn.Linear(dim, dim)
193
+ self.proj_drop = nn.Dropout(proj_drop)
194
+
195
+ self.rope = rope.float() if rope is not None else None
196
+
197
+ def forward(self, query, key, value, qpos, kpos):
198
+ B, Nq, C = query.shape
199
+ Nk = key.shape[1]
200
+ Nv = value.shape[1]
201
+
202
+ q = (
203
+ self.projq(query)
204
+ .reshape(B, Nq, self.num_heads, C // self.num_heads)
205
+ .permute(0, 2, 1, 3)
206
+ )
207
+ k = (
208
+ self.projk(key)
209
+ .reshape(B, Nk, self.num_heads, C // self.num_heads)
210
+ .permute(0, 2, 1, 3)
211
+ )
212
+ v = (
213
+ self.projv(value)
214
+ .reshape(B, Nv, self.num_heads, C // self.num_heads)
215
+ .permute(0, 2, 1, 3)
216
+ )
217
+
218
+ q_type = q.dtype
219
+ k_type = k.dtype
220
+ if self.rope is not None:
221
+ if qpos is not None:
222
+ q = q.float()
223
+ with torch.autocast(device_type="cuda", enabled=False):
224
+ q = self.rope(q, qpos)
225
+ q = q.to(q_type)
226
+
227
+ if kpos is not None:
228
+ k = k.float()
229
+ with torch.autocast(device_type="cuda", enabled=False):
230
+ k = self.rope(k, kpos)
231
+ k = k.to(k_type)
232
+
233
+ x = (
234
+ scaled_dot_product_attention(
235
+ query=q, key=k, value=v, dropout_p=self.attn_drop.p, scale=self.scale
236
+ )
237
+ .transpose(1, 2)
238
+ .reshape(B, Nq, C)
239
+ )
240
+
241
+ x = self.proj(x)
242
+ x = self.proj_drop(x)
243
+ return x
244
+
245
+
246
+ class DecoderBlock(nn.Module):
247
+
248
+ def __init__(
249
+ self,
250
+ dim,
251
+ num_heads,
252
+ mlp_ratio=4.0,
253
+ qkv_bias=False,
254
+ drop=0.0,
255
+ attn_drop=0.0,
256
+ drop_path=0.0,
257
+ act_layer=nn.GELU,
258
+ norm_layer=nn.LayerNorm,
259
+ norm_mem=True,
260
+ rope=None,
261
+ ):
262
+ super().__init__()
263
+ self.norm1 = norm_layer(dim)
264
+ self.attn = Attention(
265
+ dim,
266
+ rope=rope,
267
+ num_heads=num_heads,
268
+ qkv_bias=qkv_bias,
269
+ attn_drop=attn_drop,
270
+ proj_drop=drop,
271
+ )
272
+ self.cross_attn = CrossAttention(
273
+ dim,
274
+ rope=rope,
275
+ num_heads=num_heads,
276
+ qkv_bias=qkv_bias,
277
+ attn_drop=attn_drop,
278
+ proj_drop=drop,
279
+ )
280
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
281
+ self.norm2 = norm_layer(dim)
282
+ self.norm3 = norm_layer(dim)
283
+ mlp_hidden_dim = int(dim * mlp_ratio)
284
+ self.mlp = Mlp(
285
+ in_features=dim,
286
+ hidden_features=mlp_hidden_dim,
287
+ act_layer=act_layer,
288
+ drop=drop,
289
+ )
290
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
291
+
292
+ def forward(self, x, y, xpos, ypos):
293
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
294
+ y_ = self.norm_y(y)
295
+ x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
296
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
297
+ return x, y
298
+
299
+
300
+ class CustomDecoderBlock(nn.Module):
301
+
302
+ def __init__(
303
+ self,
304
+ dim,
305
+ num_heads,
306
+ mlp_ratio=4.0,
307
+ qkv_bias=False,
308
+ drop=0.0,
309
+ attn_drop=0.0,
310
+ drop_path=0.0,
311
+ act_layer=nn.GELU,
312
+ norm_layer=nn.LayerNorm,
313
+ norm_mem=True,
314
+ rope=None,
315
+ ):
316
+ super().__init__()
317
+ self.norm1 = norm_layer(dim)
318
+ self.attn = Attention(
319
+ dim,
320
+ rope=rope,
321
+ num_heads=num_heads,
322
+ qkv_bias=qkv_bias,
323
+ attn_drop=attn_drop,
324
+ proj_drop=drop,
325
+ )
326
+ self.cross_attn = CrossAttention(
327
+ dim,
328
+ rope=rope,
329
+ num_heads=num_heads,
330
+ qkv_bias=qkv_bias,
331
+ attn_drop=attn_drop,
332
+ proj_drop=drop,
333
+ )
334
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
335
+ self.norm2 = norm_layer(dim)
336
+ self.norm3 = norm_layer(dim)
337
+ mlp_hidden_dim = int(dim * mlp_ratio)
338
+ self.mlp = Mlp(
339
+ in_features=dim,
340
+ hidden_features=mlp_hidden_dim,
341
+ act_layer=act_layer,
342
+ drop=drop,
343
+ )
344
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
345
+ self.norm_z = norm_layer(dim) if norm_mem else nn.Identity()
346
+
347
+ def forward(self, x, y, z, xpos, ypos):
348
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
349
+ y_ = self.norm_y(y)
350
+ z_ = self.norm_z(z)
351
+ x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, z_, xpos, ypos))
352
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
353
+ return x, y
354
+
355
+
356
+ class ModLN(nn.Module):
357
+ """
358
+ Modulation with adaLN.
359
+
360
+ References:
361
+ DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L101
362
+ """
363
+
364
+ def __init__(self, inner_dim: int, mod_dim: int, eps: float):
365
+ super().__init__()
366
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
367
+ self.mlp = nn.Sequential(
368
+ nn.SiLU(),
369
+ nn.Linear(mod_dim, inner_dim * 2),
370
+ )
371
+
372
+ @staticmethod
373
+ def modulate(x, shift, scale):
374
+
375
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
376
+
377
+ def forward(self, x: torch.Tensor, mod: torch.Tensor) -> torch.Tensor:
378
+ shift, scale = self.mlp(mod).chunk(2, dim=-1) # [N, D]
379
+ return self.modulate(self.norm(x), shift, scale) # [N, L, D]
380
+
381
+
382
+ class ConditionModulationBlock(nn.Module):
383
+
384
+ def __init__(
385
+ self,
386
+ dim,
387
+ num_heads,
388
+ mlp_ratio=4.0,
389
+ qkv_bias=False,
390
+ drop=0.0,
391
+ attn_drop=0.0,
392
+ drop_path=0.0,
393
+ act_layer=nn.GELU,
394
+ norm_layer=partial(ModLN, eps=1e-6),
395
+ rope=None,
396
+ ):
397
+ super().__init__()
398
+ self.norm1 = norm_layer(dim, dim)
399
+ self.attn = Attention(
400
+ dim,
401
+ rope=rope,
402
+ num_heads=num_heads,
403
+ qkv_bias=qkv_bias,
404
+ attn_drop=attn_drop,
405
+ proj_drop=drop,
406
+ )
407
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
408
+ self.norm2 = norm_layer(dim, dim)
409
+ mlp_hidden_dim = int(dim * mlp_ratio)
410
+ self.mlp = Mlp(
411
+ in_features=dim,
412
+ hidden_features=mlp_hidden_dim,
413
+ act_layer=act_layer,
414
+ drop=drop,
415
+ )
416
+
417
+ def forward(self, x, mod, xpos):
418
+ x = x + self.drop_path(self.attn(self.norm1(x, mod), xpos))
419
+ x = x + self.drop_path(self.mlp(self.norm2(x, mod)))
420
+ return x
421
+
422
+
423
+ class PositionGetter(object):
424
+ """return positions of patches"""
425
+
426
+ def __init__(self):
427
+ self.cache_positions = {}
428
+
429
+ def __call__(self, b, h, w, device):
430
+ if not (h, w) in self.cache_positions:
431
+ x = torch.arange(w, device=device)
432
+ y = torch.arange(h, device=device)
433
+ self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h, w, 2)
434
+ pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone()
435
+ return pos
436
+
437
+
438
+ class PatchEmbed(nn.Module):
439
+ """just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
440
+
441
+ def __init__(
442
+ self,
443
+ img_size=224,
444
+ patch_size=16,
445
+ in_chans=3,
446
+ embed_dim=768,
447
+ norm_layer=None,
448
+ flatten=True,
449
+ ):
450
+ super().__init__()
451
+ img_size = to_2tuple(img_size)
452
+ patch_size = to_2tuple(patch_size)
453
+ self.img_size = img_size
454
+ self.patch_size = patch_size
455
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
456
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
457
+ self.flatten = flatten
458
+
459
+ self.proj = nn.Conv2d(
460
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
461
+ )
462
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
463
+
464
+ self.position_getter = PositionGetter()
465
+
466
+ def forward(self, x):
467
+ B, C, H, W = x.shape
468
+ torch._assert(
469
+ H == self.img_size[0],
470
+ f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
471
+ )
472
+ torch._assert(
473
+ W == self.img_size[1],
474
+ f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
475
+ )
476
+ x = self.proj(x)
477
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
478
+ if self.flatten:
479
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
480
+ x = self.norm(x)
481
+ return x, pos
482
+
483
+ def _init_weights(self):
484
+ w = self.proj.weight.data
485
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
486
+
487
+
488
+ if __name__ == "__main__":
489
+ import os
490
+ import sys
491
+
492
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
493
+ import dust3r.utils.path_to_croco
494
+ from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
495
+ from functools import partial
496
+ from torch.utils.checkpoint import checkpoint
497
+
498
+ torch.manual_seed(0)
499
+
500
+ enc_blocks_ray_map = (
501
+ nn.ModuleList(
502
+ [
503
+ Block(
504
+ 768,
505
+ 16,
506
+ 4,
507
+ qkv_bias=True,
508
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
509
+ rope=RoPE2D(100),
510
+ )
511
+ for _ in range(2)
512
+ ]
513
+ )
514
+ .cuda()
515
+ .train()
516
+ )
517
+
518
+ x = torch.randn(2, 196, 768, requires_grad=True).cuda()
519
+ xpos = torch.arange(0, 196).unsqueeze(0).unsqueeze(-1).repeat(2, 1, 2).cuda().long()
520
+ enc_blocks_ray_map.zero_grad()
521
+ for blk in enc_blocks_ray_map:
522
+
523
+ x = checkpoint(blk, x, xpos)
524
+ enc_blocks_ray_map.zero_grad()
525
+ x.sum().backward()
526
+
527
+ grad_not_checkpointed = {}
528
+ for name, param in enc_blocks_ray_map.named_parameters():
529
+ grad_not_checkpointed[name] = param.grad.data.clone()
530
+ print(name, grad_not_checkpointed[name])
531
+ break
outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/bedlam.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+ from tqdm import tqdm
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2
12
+
13
+ invalid_seqs = [
14
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000042",
15
+ "20221024_10_100_batch01handhair_zoom_suburb_d_seq_000059",
16
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000079",
17
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000978",
18
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000081",
19
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000268",
20
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000089",
21
+ "20221013_3_250_batch01hand_orbit_bigOffice_seq_000189",
22
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000034",
23
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000889",
24
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000293",
25
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000067",
26
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000904",
27
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000434",
28
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000044",
29
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000013",
30
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000396",
31
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000012",
32
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000082",
33
+ "20221013_3_250_batch01hand_orbit_bigOffice_seq_000120",
34
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000324",
35
+ "20221013_3_250_batch01hand_static_bigOffice_seq_000038",
36
+ "20221012_3-10_500_batch01hand_zoom_highSchoolGym_seq_000486",
37
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000421",
38
+ "20221013_3_250_batch01hand_orbit_bigOffice_seq_000226",
39
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000012",
40
+ "20221013_3_250_batch01hand_orbit_bigOffice_seq_000149",
41
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000311",
42
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000080",
43
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000122",
44
+ "20221012_3-10_500_batch01hand_zoom_highSchoolGym_seq_000079",
45
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000077",
46
+ "20221014_3_250_batch01hand_orbit_archVizUI3_time15_seq_000095",
47
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000062",
48
+ "20221013_3_250_batch01hand_static_bigOffice_seq_000015",
49
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000095",
50
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000119",
51
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000297",
52
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000011",
53
+ "20221013_3_250_batch01hand_orbit_bigOffice_seq_000196",
54
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000316",
55
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000283",
56
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000085",
57
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000287",
58
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000163",
59
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000804",
60
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000842",
61
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000027",
62
+ "20221013_3_250_batch01hand_orbit_bigOffice_seq_000182",
63
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000982",
64
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000029",
65
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000031",
66
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000025",
67
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000250",
68
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000785",
69
+ "20221024_10_100_batch01handhair_zoom_suburb_d_seq_000069",
70
+ "20221013_3_250_batch01hand_orbit_bigOffice_seq_000122",
71
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000246",
72
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000352",
73
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000425",
74
+ "20221013_3_250_batch01hand_orbit_bigOffice_seq_000192",
75
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000900",
76
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000043",
77
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000063",
78
+ "20221014_3_250_batch01hand_orbit_archVizUI3_time15_seq_000096",
79
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000091",
80
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000013",
81
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000309",
82
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000114",
83
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000969",
84
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000361",
85
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000267",
86
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000083",
87
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000383",
88
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000890",
89
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000003",
90
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000045",
91
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000317",
92
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000076",
93
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000082",
94
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000907",
95
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000279",
96
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000076",
97
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000004",
98
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000061",
99
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000811",
100
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000800",
101
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000841",
102
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000794",
103
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000308",
104
+ "20221024_10_100_batch01handhair_zoom_suburb_d_seq_000064",
105
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000284",
106
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000752",
107
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000269",
108
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000036",
109
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000419",
110
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000290",
111
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000322",
112
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000818",
113
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000327",
114
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000326",
115
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000002",
116
+ "20221024_10_100_batch01handhair_zoom_suburb_d_seq_000060",
117
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000348",
118
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000059",
119
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000016",
120
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000817",
121
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000332",
122
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000094",
123
+ "20221013_3_250_batch01hand_orbit_bigOffice_seq_000193",
124
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000779",
125
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000177",
126
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000368",
127
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000023",
128
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000024",
129
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000310",
130
+ "20221014_3_250_batch01hand_orbit_archVizUI3_time15_seq_000086",
131
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000038",
132
+ "20221024_10_100_batch01handhair_zoom_suburb_d_seq_000071",
133
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000768",
134
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000017",
135
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000053",
136
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000097",
137
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000856",
138
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000827",
139
+ "20221013_3_250_batch01hand_orbit_bigOffice_seq_000161",
140
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000084",
141
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000106",
142
+ "20221013_3_250_batch01hand_orbit_bigOffice_seq_000207",
143
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000007",
144
+ "20221024_3-10_100_batch01handhair_static_highSchoolGym_seq_000013",
145
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000251",
146
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000796",
147
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000105",
148
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000251",
149
+ "20221019_3-8_250_highbmihand_orbit_stadium_seq_000046",
150
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000334",
151
+ "20221019_3-8_1000_highbmihand_static_suburb_d_seq_000453",
152
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000373",
153
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000283",
154
+ "20221010_3-10_500_batch01hand_zoom_suburb_d_seq_000249",
155
+ ]
156
+ hdri_scenes = [
157
+ "20221010_3_1000_batch01hand",
158
+ "20221017_3_1000_batch01hand",
159
+ "20221018_3-8_250_batch01hand",
160
+ "20221019_3_250_highbmihand",
161
+ ]
162
+
163
+
164
+ class BEDLAM_Multi(BaseMultiViewDataset):
165
+ def __init__(self, *args, ROOT, **kwargs):
166
+ self.ROOT = ROOT
167
+ self.pose_root = os.path.join(
168
+ os.path.dirname(ROOT), f"{os.path.basename(ROOT)}_pose"
169
+ )
170
+ assert os.path.exists(self.pose_root)
171
+ self.video = True
172
+ self.is_metric = True
173
+ self.max_interval = 4
174
+ super().__init__(*args, **kwargs)
175
+ self.loaded_data = self._load_data()
176
+
177
+ def _load_data(self):
178
+ self.scenes = os.listdir(self.ROOT)
179
+
180
+ offset = 0
181
+ scenes = []
182
+ sceneids = []
183
+ scene_img_list = []
184
+ images = []
185
+ start_img_ids = []
186
+
187
+ j = 0
188
+ for scene in tqdm(self.scenes):
189
+ if scene in invalid_seqs:
190
+ continue
191
+ if any([scene.startswith(x) for x in hdri_scenes]):
192
+ continue
193
+ if "closeup" in scene:
194
+ continue
195
+ scene_dir = osp.join(self.ROOT, scene)
196
+ rgb_dir = osp.join(scene_dir, "rgb")
197
+ basenames = sorted(
198
+ [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")]
199
+ )
200
+ num_imgs = len(basenames)
201
+ img_ids = list(np.arange(num_imgs) + offset)
202
+ cut_off = (
203
+ self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
204
+ )
205
+ if num_imgs < cut_off:
206
+ print(f"Skipping {scene}")
207
+ continue
208
+ start_img_ids_ = img_ids[: num_imgs - cut_off + 1]
209
+
210
+ start_img_ids.extend(start_img_ids_)
211
+ sceneids.extend([j] * num_imgs)
212
+ images.extend(basenames)
213
+ scenes.append(scene)
214
+ scene_img_list.append(img_ids)
215
+
216
+ # offset groups
217
+ offset += num_imgs
218
+ j += 1
219
+
220
+ self.scenes = scenes
221
+ assert len(set(self.scenes) - set(os.listdir(self.pose_root))) == 0
222
+ self.sceneids = sceneids
223
+ self.images = images
224
+ self.start_img_ids = start_img_ids
225
+ self.scene_img_list = scene_img_list
226
+
227
+ def __len__(self):
228
+ return len(self.start_img_ids)
229
+
230
+ def get_image_num(self):
231
+ return len(self.images)
232
+
233
+ def _get_views(self, idx, resolution, rng, num_views):
234
+ start_id = self.start_img_ids[idx]
235
+ all_image_ids = self.scene_img_list[self.sceneids[start_id]]
236
+ pos, ordered_video = self.get_seq_from_start_id(
237
+ num_views,
238
+ start_id,
239
+ all_image_ids,
240
+ rng,
241
+ max_interval=self.max_interval,
242
+ video_prob=1.0,
243
+ fix_interval_prob=1.0,
244
+ )
245
+ image_idxs = np.array(all_image_ids)[pos]
246
+
247
+ views = []
248
+ for v, view_idx in enumerate(image_idxs):
249
+ scene_id = self.sceneids[view_idx]
250
+ scene_dir = osp.join(self.ROOT, self.scenes[scene_id])
251
+ rgb_dir = osp.join(scene_dir, "rgb")
252
+ depth_dir = osp.join(scene_dir, "depth")
253
+ cam_dir = osp.join(osp.join(self.pose_root, self.scenes[scene_id]), "cam")
254
+
255
+ basename = self.images[view_idx]
256
+
257
+ # Load RGB image
258
+ rgb_image = imread_cv2(osp.join(rgb_dir, basename + ".png"))
259
+ # Load depthmap
260
+ depthmap = np.load(osp.join(depth_dir, basename + ".npy"))
261
+ depthmap[~np.isfinite(depthmap)] = 0 # invalid
262
+ depthmap[depthmap > 200.0] = 0.0
263
+
264
+ cam = np.load(osp.join(cam_dir, basename + ".npz"))
265
+ camera_pose = cam["pose"]
266
+ intrinsics = cam["intrinsics"]
267
+ rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
268
+ rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
269
+ )
270
+
271
+ # generate img mask and raymap mask
272
+ img_mask, ray_mask = self.get_img_and_ray_masks(
273
+ self.is_metric, v, rng, p=[0.85, 0.10, 0.05]
274
+ )
275
+
276
+ views.append(
277
+ dict(
278
+ img=rgb_image,
279
+ depthmap=depthmap.astype(np.float32),
280
+ camera_pose=camera_pose.astype(np.float32),
281
+ camera_intrinsics=intrinsics.astype(np.float32),
282
+ dataset="BEDLAM",
283
+ label=self.scenes[scene_id] + "_" + basename,
284
+ instance=osp.join(rgb_dir, basename + ".png"),
285
+ is_metric=self.is_metric,
286
+ is_video=ordered_video,
287
+ quantile=np.array(1, dtype=np.float32),
288
+ img_mask=img_mask,
289
+ ray_mask=ray_mask,
290
+ camera_only=False,
291
+ depth_only=False,
292
+ single_view=False,
293
+ reset=False,
294
+ )
295
+ )
296
+ assert len(views) == num_views
297
+ return views
outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/eden.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+ from tqdm import tqdm
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2
12
+
13
+
14
+ class EDEN_Multi(BaseMultiViewDataset):
15
+ def __init__(self, *args, ROOT, **kwargs):
16
+ self.ROOT = ROOT
17
+ self.video = True
18
+ self.is_metric = True
19
+ super().__init__(*args, **kwargs)
20
+ self.loaded_data = self._load_data()
21
+
22
+ def _load_data(self):
23
+ scenes = os.listdir(self.ROOT)
24
+ img_names = []
25
+ for scene in scenes:
26
+ scene_dir = osp.join(self.ROOT, scene)
27
+ rgb_dir = osp.join(scene_dir, "rgb")
28
+ basenames = sorted(
29
+ [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")]
30
+ )
31
+ img_names.extend([(scene, basename) for basename in basenames])
32
+
33
+ self.img_names = img_names
34
+
35
+ def __len__(self):
36
+ return len(self.img_names)
37
+
38
+ def get_image_num(self):
39
+ return len(self.img_names)
40
+
41
+ def _get_views(self, idx, resolution, rng, num_views):
42
+ new_seed = rng.integers(0, 2**32) + idx
43
+ new_rng = np.random.default_rng(new_seed)
44
+ img_names = new_rng.permutation(self.img_names)
45
+
46
+ views = []
47
+ i = 0
48
+ while len(views) < num_views:
49
+ # Load RGB image
50
+ scene, img_name = img_names[i]
51
+ try:
52
+ rgb_image = imread_cv2(
53
+ osp.join(self.ROOT, scene, "rgb", f"{img_name}.png")
54
+ )
55
+ depthmap = np.load(
56
+ osp.join(self.ROOT, scene, "depth", f"{img_name}.npy")
57
+ )
58
+ depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)
59
+
60
+ intrinsics = np.load(
61
+ osp.join(self.ROOT, scene, "cam", f"{img_name}.npz")
62
+ )["intrinsics"]
63
+ # camera pose is not provided, placeholder
64
+ camera_pose = np.eye(4)
65
+ except:
66
+ i += 1
67
+ continue
68
+
69
+ rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
70
+ rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name
71
+ )
72
+
73
+ views.append(
74
+ dict(
75
+ img=rgb_image,
76
+ depthmap=depthmap.astype(np.float32),
77
+ camera_pose=camera_pose.astype(np.float32),
78
+ camera_intrinsics=intrinsics.astype(np.float32),
79
+ dataset="EDEN",
80
+ label=img_name,
81
+ instance=osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"),
82
+ is_metric=self.is_metric,
83
+ is_video=False,
84
+ quantile=np.array(1.0, dtype=np.float32),
85
+ img_mask=True,
86
+ ray_mask=False,
87
+ camera_only=False,
88
+ depth_only=False,
89
+ single_view=True,
90
+ reset=True,
91
+ )
92
+ )
93
+ i += 1
94
+ return views
outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/irs.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import cv2
3
+ import numpy as np
4
+ import itertools
5
+ import os
6
+ import sys
7
+
8
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
9
+ from tqdm import tqdm
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2
12
+
13
+
14
+ class IRS(BaseMultiViewDataset):
15
+ def __init__(self, *args, ROOT, **kwargs):
16
+ self.ROOT = ROOT
17
+ self.video = False
18
+ self.is_metric = True
19
+ super().__init__(*args, **kwargs)
20
+ self.loaded_data = self._load_data()
21
+
22
+ def _load_data(self):
23
+ scenes = os.listdir(self.ROOT)
24
+ img_names = []
25
+ for scene in scenes:
26
+ scene_dir = osp.join(self.ROOT, scene)
27
+ rgb_dir = osp.join(scene_dir, "rgb")
28
+ basenames = sorted(
29
+ [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")]
30
+ )
31
+ img_names.extend([(scene, basename) for basename in basenames])
32
+
33
+ self.img_names = img_names
34
+
35
+ def __len__(self):
36
+ return len(self.img_names)
37
+
38
+ def get_image_num(self):
39
+ return len(self.img_names)
40
+
41
+ def _get_views(self, idx, resolution, rng, num_views):
42
+ new_seed = rng.integers(0, 2**32) + idx
43
+ new_rng = np.random.default_rng(new_seed)
44
+ img_names = new_rng.choice(self.img_names, num_views, replace=False)
45
+
46
+ views = []
47
+ for v, img_name in enumerate(img_names):
48
+ # Load RGB image
49
+ scene, img_name = img_name
50
+ rgb_image = imread_cv2(osp.join(self.ROOT, scene, "rgb", f"{img_name}.png"))
51
+ depthmap = np.load(osp.join(self.ROOT, scene, "depth", f"{img_name}.npy"))
52
+ depthmap[depthmap > 200] = 0.0
53
+ depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)
54
+
55
+ intrinsics = np.load(osp.join(self.ROOT, scene, "cam", f"{img_name}.npz"))[
56
+ "intrinsics"
57
+ ]
58
+ # camera pose is not provided, placeholder
59
+ camera_pose = np.eye(4)
60
+
61
+ rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
62
+ rgb_image, depthmap, intrinsics, resolution, rng=rng, info=img_name
63
+ )
64
+
65
+ views.append(
66
+ dict(
67
+ img=rgb_image,
68
+ depthmap=depthmap.astype(np.float32),
69
+ camera_pose=camera_pose.astype(np.float32),
70
+ camera_intrinsics=intrinsics.astype(np.float32),
71
+ dataset="irs",
72
+ label=img_name,
73
+ instance=f"{str(idx)}_{img_name}",
74
+ is_metric=self.is_metric,
75
+ is_video=False,
76
+ quantile=np.array(1.0, dtype=np.float32),
77
+ img_mask=True,
78
+ ray_mask=False,
79
+ camera_only=False,
80
+ depth_only=False,
81
+ single_view=True,
82
+ reset=True,
83
+ )
84
+ )
85
+ assert len(views) == num_views
86
+ return views
outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/kitti.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """KITTI Odometry training dataset loader.
2
+
3
+ Mirrors VirtualKITTI2_Multi but for the real KITTI odometry benchmark
4
+ (11 sequences 00-10 with ground-truth poses).
5
+
6
+ Optional sparse LiDAR depth supervision: when `velodyne_root` is provided,
7
+ loads the 64-beam Velodyne scan per frame, applies the `Tr` velo→cam0
8
+ rigid transform from calib.txt, then projects via P2 onto image_2 to
9
+ build a sparse depthmap. Pixels without a lidar return are marked -1.
10
+
11
+ Layout expected:
12
+ ROOT/sequences/<00..10>/image_2/{NNNNNN}.png
13
+ ROOT/sequences/<00..10>/calib.txt # P0 P1 P2 P3 + Tr (3x4 each)
14
+ ROOT/poses/<00..10>.txt # one 12-float c2w 3x4 row per frame
15
+ velodyne_root/sequences/<00..10>/velodyne/{NNNNNN}.bin # (N,4) float32 [x y z r]
16
+
17
+ Train/eval split (matching `setup_lingbot_env.sh` and SLAM-eval convention):
18
+ train: 00, 01, 02, 03, 04, 05, 06, 07, 08
19
+ test: 09, 10
20
+ """
21
+ import os
22
+ import os.path as osp
23
+ import sys
24
+
25
+ import cv2
26
+ import numpy as np
27
+
28
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
29
+
30
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
31
+ from dust3r.utils.image import imread_cv2
32
+
33
+ TRAIN_SEQS = ["00", "01", "02", "03", "04", "05", "06", "07", "08"]
34
+ TEST_SEQS = ["09", "10"]
35
+
36
+
37
+ def _parse_calib(calib_path: str):
38
+ """Read calib.txt; return (P2 (3,4) left-color cam projection,
39
+ Tr (3,4) velodyne→cam0 rigid transform, or None if absent)."""
40
+ P2 = None
41
+ Tr = None
42
+ with open(calib_path) as fh:
43
+ for line in fh:
44
+ line = line.strip()
45
+ if line.startswith("P2:"):
46
+ vals = list(map(float, line[3:].split()))
47
+ P2 = np.array(vals).reshape(3, 4)
48
+ elif line.startswith("Tr:") or line.startswith("Tr_velo_to_cam:"):
49
+ pref = "Tr_velo_to_cam:" if line.startswith("Tr_velo_to_cam:") else "Tr:"
50
+ vals = list(map(float, line[len(pref):].split()))
51
+ Tr = np.array(vals).reshape(3, 4)
52
+ if P2 is None:
53
+ raise RuntimeError(f"P2 not found in {calib_path}")
54
+ return P2, Tr
55
+
56
+
57
+ def _parse_calib_p2(calib_path: str) -> np.ndarray:
58
+ """Backward-compat shim returning only P2."""
59
+ P2, _ = _parse_calib(calib_path)
60
+ return P2
61
+
62
+
63
+ def _project_velo_to_depth(velo_pts, P2, Tr, H, W,
64
+ min_depth=0.5, max_depth=80.0):
65
+ """Project a velodyne scan onto image_2 to build a sparse depthmap.
66
+
67
+ Args:
68
+ velo_pts: (N,4) float32 [x,y,z,reflectance] in velodyne frame.
69
+ P2: (3,4) left-color cam projection (cam0 → image_2).
70
+ Tr: (3,4) velodyne → cam0 rigid transform.
71
+ H, W: target depthmap shape.
72
+ Returns:
73
+ (H,W) float32 depthmap, -1 where no lidar return.
74
+ """
75
+ Tr_h = np.eye(4, dtype=np.float64)
76
+ Tr_h[:3, :] = Tr
77
+ pts_h = np.concatenate(
78
+ [velo_pts[:, :3].astype(np.float64), np.ones((velo_pts.shape[0], 1))],
79
+ axis=1,
80
+ )
81
+ cam0 = pts_h @ Tr_h.T # (N,4) in cam0 frame
82
+ in_front = cam0[:, 2] > min_depth
83
+ cam0 = cam0[in_front]
84
+ if cam0.shape[0] == 0:
85
+ return np.full((H, W), -1.0, dtype=np.float32)
86
+ uv_h = cam0 @ P2.T # (M,3)
87
+ z = uv_h[:, 2]
88
+ valid = z > min_depth
89
+ z = z[valid]
90
+ u = uv_h[valid, 0] / z
91
+ v = uv_h[valid, 1] / z
92
+ in_img = (u >= 0) & (u < W) & (v >= 0) & (v < H) & (z < max_depth)
93
+ u = u[in_img].astype(np.int32)
94
+ v = v[in_img].astype(np.int32)
95
+ z = z[in_img]
96
+ depthmap = np.full((H, W), -1.0, dtype=np.float32)
97
+ if z.size == 0:
98
+ return depthmap
99
+ # Multiple lidar points hitting one pixel: keep the closest (smallest z).
100
+ order = np.argsort(-z) # descending; smallest z written last → wins
101
+ depthmap[v[order], u[order]] = z[order].astype(np.float32)
102
+ return depthmap
103
+
104
+
105
+ def _load_velodyne_bin(bin_path: str) -> np.ndarray:
106
+ """Load (N,4) float32 [x,y,z,reflectance] from KITTI .bin file."""
107
+ return np.fromfile(bin_path, dtype=np.float32).reshape(-1, 4)
108
+
109
+
110
+ def _load_kitti_poses(poses_path: str) -> np.ndarray:
111
+ """Load KITTI poses.txt into (N,4,4) c2w matrices (homogeneous)."""
112
+ raw = np.loadtxt(poses_path) # (N, 12) row-major 3x4
113
+ N = raw.shape[0]
114
+ out = np.zeros((N, 4, 4), dtype=np.float32)
115
+ out[:, :3, :] = raw.reshape(N, 3, 4)
116
+ out[:, 3, 3] = 1.0
117
+ return out
118
+
119
+
120
+ class KITTI_Multi(BaseMultiViewDataset):
121
+ """Real KITTI odometry, image_2 (left color), camera-only supervision."""
122
+
123
+ def __init__(self, ROOT, *args, velodyne_root=None, **kwargs):
124
+ self.ROOT = ROOT
125
+ # Optional separate root for Velodyne .bin scans (e.g. semantickitti);
126
+ # if None, expects them at ROOT/sequences/<sid>/velodyne/.
127
+ self.velodyne_root = velodyne_root
128
+ self.video = True
129
+ self.is_metric = True # poses are metric
130
+ self.max_interval = 4
131
+ super().__init__(*args, **kwargs)
132
+ self._load_data(self.split)
133
+
134
+ def _load_data(self, split=None):
135
+ seq_ids = TRAIN_SEQS if split == "train" else TEST_SEQS
136
+ seq_dirs = []
137
+ for sid in seq_ids:
138
+ seq_path = osp.join(self.ROOT, "sequences", sid)
139
+ if not osp.isdir(seq_path):
140
+ continue
141
+ img_dir = osp.join(seq_path, "image_2")
142
+ poses_path = osp.join(self.ROOT, "poses", f"{sid}.txt")
143
+ if not osp.isdir(img_dir) or not osp.isfile(poses_path):
144
+ continue
145
+ seq_dirs.append(sid)
146
+
147
+ if not seq_dirs:
148
+ raise RuntimeError(f"No KITTI sequences found at {self.ROOT}")
149
+
150
+ offset = 0
151
+ scenes = [] # list of seq ids
152
+ seq_intrinsics = [] # (3,3) per seq
153
+ seq_p2 = [] # (3,4) P2 per seq (needed for lidar projection)
154
+ seq_tr = [] # (3,4) Tr velo→cam0 per seq, or None
155
+ seq_velo_dir = [] # absolute velodyne dir per seq, or None
156
+ seq_poses = [] # (N_i,4,4) per seq
157
+ sceneids = [] # per-frame seq idx (global)
158
+ scene_img_list = [] # per-seq global frame idx list
159
+ start_img_ids = []
160
+ j = 0
161
+
162
+ for sid in seq_dirs:
163
+ seq_path = osp.join(self.ROOT, "sequences", sid)
164
+ img_dir = osp.join(seq_path, "image_2")
165
+ calib_path = osp.join(seq_path, "calib.txt")
166
+ poses_path = osp.join(self.ROOT, "poses", f"{sid}.txt")
167
+
168
+ P2, Tr = _parse_calib(calib_path)
169
+ K = P2[:, :3] # left-color intrinsics
170
+ poses_c2w = _load_kitti_poses(poses_path) # (N,4,4)
171
+ n_imgs = poses_c2w.shape[0]
172
+
173
+ # Resolve velodyne dir (per-sequence). Try ROOT first; fall back to
174
+ # velodyne_root if provided. None = no lidar supervision for this seq.
175
+ velo_dir = None
176
+ velo_calib_path = None
177
+ for cand_velo, cand_calib in (
178
+ (osp.join(seq_path, "velodyne"), calib_path),
179
+ (
180
+ osp.join(self.velodyne_root, "sequences", sid, "velodyne") if self.velodyne_root else None,
181
+ osp.join(self.velodyne_root, "sequences", sid, "calib.txt") if self.velodyne_root else None,
182
+ ),
183
+ ):
184
+ if cand_velo and osp.isdir(cand_velo):
185
+ velo_dir = cand_velo
186
+ velo_calib_path = cand_calib
187
+ break
188
+
189
+ # KITTI odometry eval calib.txt may ship without Tr (velo→cam0).
190
+ # If so, parse Tr from the velodyne_root's calib.txt (semantickitti
191
+ # ships full P0..P3 + Tr per sequence).
192
+ if Tr is None and velo_calib_path and osp.isfile(velo_calib_path) and velo_calib_path != calib_path:
193
+ _, Tr_velo = _parse_calib(velo_calib_path)
194
+ if Tr_velo is not None:
195
+ Tr = Tr_velo
196
+
197
+ cut_off = (
198
+ self.num_views
199
+ if not self.allow_repeat
200
+ else max(self.num_views // 3, 3)
201
+ )
202
+ if n_imgs < cut_off:
203
+ continue
204
+
205
+ img_ids = list(np.arange(n_imgs) + offset)
206
+ start_img_ids_ = img_ids[: n_imgs - cut_off + 1]
207
+
208
+ scenes.append(sid)
209
+ seq_intrinsics.append(K.astype(np.float32))
210
+ seq_p2.append(P2.astype(np.float64))
211
+ seq_tr.append(Tr.astype(np.float64) if Tr is not None else None)
212
+ seq_velo_dir.append(velo_dir)
213
+ seq_poses.append(poses_c2w)
214
+ scene_img_list.append(img_ids)
215
+ sceneids.extend([j] * n_imgs)
216
+ start_img_ids.extend(start_img_ids_)
217
+ offset += n_imgs
218
+ j += 1
219
+
220
+ self.scenes = scenes
221
+ self.seq_intrinsics = seq_intrinsics
222
+ self.seq_p2 = seq_p2
223
+ self.seq_tr = seq_tr
224
+ self.seq_velo_dir = seq_velo_dir
225
+ self.seq_poses = seq_poses
226
+ self.sceneids = sceneids
227
+ self.start_img_ids = start_img_ids
228
+ self.scene_img_list = scene_img_list
229
+
230
+ def __len__(self):
231
+ return len(self.start_img_ids)
232
+
233
+ def get_image_num(self):
234
+ return sum(len(p) for p in self.seq_poses)
235
+
236
+ def get_stats(self):
237
+ return f"{len(self)} groups of views across {len(self.scenes)} KITTI sequences"
238
+
239
+ def _get_views(self, idx, resolution, rng, num_views):
240
+ start_id = self.start_img_ids[idx]
241
+ scene_id = self.sceneids[start_id]
242
+ all_image_ids = self.scene_img_list[scene_id]
243
+ n_frames = len(all_image_ids)
244
+ seq_id_str = self.scenes[scene_id]
245
+ seq_path = osp.join(self.ROOT, "sequences", seq_id_str)
246
+ img_dir = osp.join(seq_path, "image_2")
247
+ K = self.seq_intrinsics[scene_id]
248
+ P2 = self.seq_p2[scene_id]
249
+ Tr = self.seq_tr[scene_id]
250
+ velo_dir = self.seq_velo_dir[scene_id]
251
+ has_lidar = velo_dir is not None and Tr is not None
252
+ poses = self.seq_poses[scene_id]
253
+
254
+ pos, ordered_video = self.get_seq_from_start_id(
255
+ num_views,
256
+ start_id,
257
+ all_image_ids,
258
+ rng,
259
+ max_interval=self.max_interval,
260
+ video_prob=1.0,
261
+ fix_interval_prob=0.9,
262
+ )
263
+ # `pos` are 0-based local positions within all_image_ids
264
+ # (matching upstream SLAMFormer get_seq_from_start_id semantics).
265
+ local_idxs = np.asarray(pos, dtype=int)
266
+
267
+ views = []
268
+ for v, lid in enumerate(local_idxs):
269
+ lid = int(lid)
270
+ img_path = osp.join(img_dir, f"{lid:06d}.png")
271
+ image = imread_cv2(img_path)
272
+ H, W = image.shape[:2]
273
+
274
+ # If velodyne available, project LiDAR scan to image_2 → sparse depthmap.
275
+ # Otherwise emit invalid depthmap and fall back to camera-only.
276
+ if has_lidar:
277
+ bin_path = osp.join(velo_dir, f"{lid:06d}.bin")
278
+ if osp.isfile(bin_path):
279
+ velo_pts = _load_velodyne_bin(bin_path)
280
+ depthmap = _project_velo_to_depth(velo_pts, P2, Tr, H, W)
281
+ frame_has_lidar = (depthmap > 0).any()
282
+ else:
283
+ depthmap = np.full((H, W), -1.0, dtype=np.float32)
284
+ frame_has_lidar = False
285
+ else:
286
+ depthmap = np.full((H, W), -1.0, dtype=np.float32)
287
+ frame_has_lidar = False
288
+
289
+ intrinsics = K.copy()
290
+ camera_pose = poses[lid].astype(np.float32)
291
+
292
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
293
+ image, depthmap, intrinsics, resolution, rng, info=(seq_path, img_path)
294
+ )
295
+
296
+ img_mask, ray_mask = self.get_img_and_ray_masks(
297
+ self.is_metric, v, rng, p=[0.85, 0.1, 0.05]
298
+ )
299
+
300
+ views.append(
301
+ dict(
302
+ img=image,
303
+ depthmap=depthmap,
304
+ camera_pose=camera_pose,
305
+ camera_intrinsics=intrinsics,
306
+ dataset="KITTI",
307
+ label=seq_path,
308
+ is_metric=self.is_metric,
309
+ instance=f"{seq_id_str}/image_2/{lid:06d}.png",
310
+ is_video=ordered_video,
311
+ quantile=np.array(1.0, dtype=np.float32),
312
+ img_mask=img_mask,
313
+ ray_mask=ray_mask,
314
+ # If frame has any valid lidar return, supervise depth too.
315
+ camera_only=not frame_has_lidar,
316
+ depth_only=False,
317
+ single_view=False,
318
+ reset=False,
319
+ scene_tag=f"kitti/{seq_id_str}",
320
+ )
321
+ )
322
+ assert len(views) == num_views
323
+ return views
outdoor_v48_4gpu_v2/code/05_02-14:21:58/dust3r/datasets/mp3d.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import os
3
+ import sys
4
+ import itertools
5
+
6
+ sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
7
+ import cv2
8
+ import numpy as np
9
+
10
+ from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
11
+ from dust3r.utils.image import imread_cv2
12
+
13
+
14
+ class MP3D_Multi(BaseMultiViewDataset):
15
+ def __init__(self, *args, split, ROOT, **kwargs):
16
+ self.ROOT = ROOT
17
+ self.video = False
18
+ self.is_metric = True
19
+ super().__init__(*args, **kwargs)
20
+
21
+ self.loaded_data = self._load_data()
22
+
23
+ def _load_data(self):
24
+ scenes = os.listdir(self.ROOT)
25
+ offset = 0
26
+ overlaps = {scene: [] for scene in scenes}
27
+ scene_img_list = {scene: [] for scene in scenes}
28
+ images = []
29
+
30
+ j = 0
31
+ for scene in scenes:
32
+ scene_dir = osp.join(self.ROOT, scene)
33
+ rgb_dir = osp.join(scene_dir, "rgb")
34
+ basenames = sorted(
35
+ [f[:-4] for f in os.listdir(rgb_dir) if f.endswith(".png")]
36
+ )
37
+ overlap = np.load(osp.join(scene_dir, "overlap.npy"))
38
+ overlaps[scene] = overlap
39
+ num_imgs = len(basenames)
40
+
41
+ images.extend(
42
+ [(scene, i, basename) for i, basename in enumerate(basenames)]
43
+ )
44
+ scene_img_list[scene] = np.arange(num_imgs) + offset
45
+ offset += num_imgs
46
+ j += 1
47
+
48
+ self.scenes = scenes
49
+ self.scene_img_list = scene_img_list
50
+ self.images = images
51
+ self.overlaps = overlaps
52
+
53
+ def __len__(self):
54
+ return len(self.images)
55
+
56
+ def get_image_num(self):
57
+ return len(self.images)
58
+
59
+ def _get_views(self, idx, resolution, rng, num_views):
60
+ num_views_posible = 0
61
+ num_unique = num_views if not self.allow_repeat else max(num_views // 3, 3)
62
+ while num_views_posible < num_unique - 1:
63
+ scene, img_idx, _ = self.images[idx]
64
+ overlap = self.overlaps[scene]
65
+ sel_img_idx = np.where(overlap[:, 0] == img_idx)[0]
66
+ overlap_sel = overlap[sel_img_idx]
67
+ overlap_sel = overlap_sel[
68
+ (overlap_sel[:, 2] > 0.01) * (overlap_sel[:, 2] < 1)
69
+ ]
70
+ num_views_posible = len(overlap_sel)
71
+ if num_views_posible >= num_unique - 1:
72
+ break
73
+ idx = rng.choice(len(self.images))
74
+
75
+ ref_id = self.scene_img_list[scene][img_idx]
76
+ ids = self.scene_img_list[scene][overlap_sel[:, 1].astype(np.int64)]
77
+ replace = False if not self.allow_repeat else True
78
+ image_idxs = rng.choice(
79
+ ids,
80
+ num_views - 1,
81
+ replace=replace,
82
+ p=overlap_sel[:, 2] / np.sum(overlap_sel[:, 2]),
83
+ )
84
+ image_idxs = np.concatenate([[ref_id], image_idxs])
85
+
86
+ ordered_video = False
87
+ views = []
88
+ for v, view_idx in enumerate(image_idxs):
89
+ scene, _, basename = self.images[view_idx]
90
+ scene_dir = osp.join(self.ROOT, scene)
91
+ rgb_path = osp.join(scene_dir, "rgb", basename + ".png")
92
+ depth_path = osp.join(scene_dir, "depth", basename + ".npy")
93
+ cam_path = osp.join(scene_dir, "cam", basename + ".npz")
94
+
95
+ rgb_image = imread_cv2(rgb_path, cv2.IMREAD_COLOR)
96
+ depthmap = np.load(depth_path).astype(np.float32)
97
+ depthmap[~np.isfinite(depthmap)] = 0 # invalid
98
+ cam_file = np.load(cam_path)
99
+ intrinsics = cam_file["intrinsics"]
100
+ camera_pose = cam_file["pose"]
101
+
102
+ rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
103
+ rgb_image, depthmap, intrinsics, resolution, rng=rng, info=view_idx
104
+ )
105
+
106
+ # generate img mask and raymap mask
107
+ img_mask, ray_mask = self.get_img_and_ray_masks(
108
+ self.is_metric, v, rng, p=[0.85, 0.1, 0.05]
109
+ )
110
+
111
+ views.append(
112
+ dict(
113
+ img=rgb_image,
114
+ depthmap=depthmap.astype(np.float32),
115
+ camera_pose=camera_pose.astype(np.float32),
116
+ camera_intrinsics=intrinsics.astype(np.float32),
117
+ dataset="mp3d",
118
+ label=scene + "_" + rgb_path,
119
+ instance=f"{str(idx)}_{str(view_idx)}",
120
+ is_metric=self.is_metric,
121
+ is_video=ordered_video,
122
+ quantile=np.array(0.99, dtype=np.float32),
123
+ img_mask=img_mask,
124
+ ray_mask=ray_mask,
125
+ camera_only=False,
126
+ depth_only=False,
127
+ single_view=False,
128
+ reset=False,
129
+ )
130
+ )
131
+ assert len(views) == num_views
132
+ return views