FrickYinn commited on
Commit
e170a8e
·
verified ·
1 Parent(s): ee4e79a

Upload 53 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +35 -0
  2. LICENSE +201 -0
  3. README.md +183 -3
  4. assets/figures/obj_vis_gt.png +0 -0
  5. assets/figures/obj_vis_query.png +0 -0
  6. assets/figures/obj_vis_reference_labeled.png +0 -0
  7. assets/figures/scene5_vis_0.png +0 -0
  8. assets/figures/scene5_vis_1.png +0 -0
  9. assets/figures/scene5_vis_gt.png +0 -0
  10. assets/ho3d_test_3000/ho3d_test.json +0 -0
  11. assets/linemod_test_1500/linemod_test.json +0 -0
  12. assets/mapfree_submission.zip +3 -0
  13. assets/megadepth_test_1500_scene_info/0015_0.1_0.3.npz +3 -0
  14. assets/megadepth_test_1500_scene_info/0015_0.3_0.5.npz +3 -0
  15. assets/megadepth_test_1500_scene_info/0022_0.1_0.3.npz +3 -0
  16. assets/megadepth_test_1500_scene_info/0022_0.3_0.5.npz +3 -0
  17. assets/megadepth_test_1500_scene_info/0022_0.5_0.7.npz +3 -0
  18. assets/megadepth_test_1500_scene_info/megadepth_test_1500.txt +5 -0
  19. assets/scannet_test_1500/intrinsics.npz +3 -0
  20. assets/scannet_test_1500/scannet_test.txt +1 -0
  21. assets/scannet_test_1500/statistics.json +102 -0
  22. assets/scannet_test_1500/test.npz +3 -0
  23. baselines/matchers.py +72 -0
  24. baselines/pose.py +92 -0
  25. baselines/pose_solver.py +320 -0
  26. configs/default.py +85 -0
  27. configs/ho3d.yaml +19 -0
  28. configs/linemod.yaml +20 -0
  29. configs/mapfree.yaml +14 -0
  30. configs/matterport.yaml +11 -0
  31. configs/megadepth.yaml +29 -0
  32. configs/scannet.yaml +33 -0
  33. datasets/__init__.py +20 -0
  34. datasets/ho3d.py +331 -0
  35. datasets/linemod.py +441 -0
  36. datasets/mapfree.py +178 -0
  37. datasets/matterport.py +86 -0
  38. datasets/megadepth.py +125 -0
  39. datasets/sampler.py +77 -0
  40. datasets/scannet.py +154 -0
  41. eval.py +48 -0
  42. eval_add_reproj.py +138 -0
  43. eval_baselines.py +189 -0
  44. model/__init__.py +4 -0
  45. model/pl_trainer.py +201 -0
  46. model/relpose.py +465 -0
  47. requirements.txt +11 -0
  48. train.py +104 -0
  49. utils/__init__.py +19 -0
  50. utils/augment.py +15 -0
.gitignore ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ **/__pycache__
2
+ /checkpoints
3
+ /log
4
+ /lightning_logs
5
+ /pyramid
6
+ *.ipynb
7
+ /data
8
+ preprocess_megadepth.py
9
+ *.ckpt
10
+ test.py
11
+ /RelPoseRepo
12
+ eval_regressor.py
13
+ /assets/megadepth_test_new
14
+ /configs/megadepth_new.yaml
15
+ /results
16
+ assets/new_submission.zip
17
+
18
+ __eval_baselines.py
19
+ __pose_tracking.py
20
+ __track.py
21
+ /__pose_tracking
22
+
23
+ /baselines/configs
24
+ /baselines/repo
25
+ /baselines/weights
26
+ /baselines/__models.py
27
+ baselines/demo.html
28
+
29
+ utils/__reprojection.py
30
+ utils/__pose_solver.py
31
+ utils/__generate_epipolar_imgs.py
32
+ utils/__visualize.py
33
+
34
+ submission.py
35
+ /qualitative
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,183 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SRPose: Two-view Relative Pose Estimation with Sparse Keypoints
2
+
3
+ **SRPose**: A **S**parse keypoint-based framework for **R**elative **Pose** estimation between two views in both camera-to-world and object-to-camera scenarios.
4
+
5
+ | Reference | Query | Ground Truth |
6
+ |:--------:|:---------:|:--------:|
7
+ | ![](assets/figures/scene5_vis_0.png) | ![](assets/figures/scene5_vis_1.png) | ![](assets/figures/scene5_vis_gt.png) |
8
+ | ![](assets/figures/obj_vis_reference_labeled.png) | ![](assets/figures/obj_vis_query.png) |![](assets/figures/obj_vis_gt.png)|
9
+
10
+ ## [Project page](https://frickyinn.github.io/srpose/) | [arXiv](https://arxiv.org/abs/2407.08199)
11
+
12
+ ## Setup
13
+ Please first install PyTorch according to [here](https://pytorch.org/get-started/locally/), then install other dependencies using pip:
14
+ ```
15
+ cd SRPose
16
+ pip install -r requirements.txt
17
+ ```
18
+
19
+ ## Evaluation
20
+ 1. Download pretrained models [here](https://drive.google.com/drive/folders/1bBlds3UX7-XDCevbIl4bnnywvWzzP5nN) for evaluation.
21
+ 2. Create new folders:
22
+ ```
23
+ mkdir checkpoints && mkdir data
24
+ ```
25
+ 3. Organize the downloaded checkpoints like this:
26
+ ```
27
+ SRPose
28
+ |-- checkpoints
29
+ |-- ho3d.ckpt
30
+ |-- linemod.ckpt
31
+ |-- mapfree.ckpt
32
+ |-- matterport.ckpt
33
+ |-- megadepth.ckpt
34
+ `-- scannet.ckpt
35
+ ...
36
+ ```
37
+
38
+ ### Matterport
39
+ 1. Download Matterport dataset [here](https://github.com/jinlinyi/SparsePlanes/blob/main/docs/data.md), only `mp3d_planercnn_json.zip` and `rgb.zip` are required.
40
+ 2. Unzip and organize the downloaded files:
41
+ ```
42
+ mkdir data/mp3d
43
+ mkdir data/mp3d/mp3d_planercnn_json && mkdir data/mp3d/rgb
44
+ unzip <pathto>/mp3d_planercnn_json.zip -d data/mp3d/mp3d_planercnn_json
45
+ unzip <pathto>/rgb.zip -d data/mp3d/rgb
46
+ ```
47
+ 3. The resulting directory tree should be like this:
48
+ ```
49
+ SRPose
50
+ |-- data
51
+ |-- mp3d
52
+ |-- mp3d_planercnn_json
53
+ | |-- cached_set_test.json
54
+ | |-- cached_set_train.json
55
+ | `-- cached_set_val.json
56
+ `-- rgb
57
+ |-- 17DRP5sb8fy
58
+ ...
59
+ ...
60
+ ...
61
+ ```
62
+ 4. Evaluate with the following command:
63
+ ```
64
+ python eval.py configs/matterport.yaml checkpoints/matterport.ckpt
65
+ ```
66
+
67
+ ### ScanNet & MegaDepth
68
+ 1. Download and organize the ScanNet-1500 and MegaDepth-1500 test sets according to the [LoFTR Training Script](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md). Note that only the test sets and the dataset indices are required.
69
+ 2. The resulting directory tree should be:
70
+ ```
71
+ SRPose
72
+ |-- data
73
+ |-- scannet
74
+ | |-- index
75
+ | |-- test
76
+ | `-- train (optional)
77
+ |-- megadepth
78
+ |-- index
79
+ |-- test
80
+ `-- train (optional)
81
+ ...
82
+ ...
83
+ ```
84
+ 3. Evaluate with the following commands:
85
+ ```
86
+ python eval.py configs/scannet.yaml checkpoints/scannet.ckpt
87
+ python eval.py configs/megadepth.yaml checkpoints/megadepth.ckpt
88
+ ```
89
+
90
+ ### HO3D
91
+ 1. Download HO3D (version 3) dataset [here](https://www.tugraz.at/institute/icg/research/team-lepetit/research-projects/hand-object-3d-pose-annotation/), `HO3D_v3.zip` and `HO3D_v3_segmentations_rendered.zip` are required.
92
+ 2. Unzip and organize the downloaded files:
93
+ ```
94
+ mkdir data/ho3d
95
+ unzip <pathto>/HO3D_v3.zip -d data/ho3d
96
+ unzip <pathto>/HO3D_v3_segmentations_rendered.zip -d data/ho3d
97
+ ```
98
+ 3. Evaluate with the following commands:
99
+ ```
100
+ python eval.py configs/ho3d.yaml checkpoints/ho3d.ckpt
101
+ python eval_add_reproj.py configs/ho3d.yaml checkpoints/ho3d.ckpt
102
+ ```
103
+
104
+ ### Linemod
105
+ 1. Download Linemod dataset [here](https://bop.felk.cvut.cz/datasets/) or run the following commands:
106
+ ```
107
+ cd data
108
+
109
+ export SRC=https://bop.felk.cvut.cz/media/data/bop_datasets
110
+ wget $SRC/lm_base.zip # Base archive with dataset info, camera parameters, etc.
111
+ wget $SRC/lm_models.zip # 3D object models.
112
+ wget $SRC/lm_test_all.zip # All test images ("_bop19" for a subset used in the BOP Challenge 2019/2020).
113
+ wget $SRC/lm_train_pbr.zip # PBR training images (rendered with BlenderProc4BOP).
114
+
115
+ unzip lm_base.zip # Contains folder "lm".
116
+ unzip lm_models.zip -d lm # Unpacks to "lm".
117
+ unzip lm_test_all.zip -d lm # Unpacks to "lm".
118
+ unzip lm_train_pbr.zip -d lm # Unpacks to "lm".
119
+ ```
120
+
121
+ 2. Evaluate with the following commands:
122
+ ```
123
+ python eval.py configs/linemod.yaml checkpoints/linemod.ckpt
124
+ python eval_add_reproj.py configs/linemod.yaml checkpoints/linemod.ckpt
125
+ ```
126
+
127
+ ### Niantic
128
+ 1. Download Niantic dataset [here](https://research.nianticlabs.com/mapfree-reloc-benchmark/dataset).
129
+ 2. Unzip and organize the downloaded files:
130
+ ```
131
+ mkdir data/mapfree
132
+ unzip <pathto>/train.zip -d data/mapfree
133
+ unzip <pathto>/val.zip -d data/mapfree
134
+ unzip <pathto>/test.zip -d data/mapfree
135
+ ```
136
+ 3. The ground truth of the test set is not publicly available, but you can run the following command to produce a new submission file and submit it on the [project page](https://research.nianticlabs.com/mapfree-reloc-benchmark/submit) for evaluation:
137
+ ```
138
+ python eval_add_reproj.py configs/mapfree.yaml checkpoints/mapfree.ckpt
139
+ ```
140
+ You should be able to find a `new_submission.zip` in `SRPose/assets/` afterwards, or you can submit the already produced file `SRPose/assets/mapfree_submission.zip` instead.
141
+
142
+
143
+ ## Training
144
+ Download and organize the datasets following [Evaluation](#evaluation), then run the following command for training:
145
+ ```
146
+ python train.py configs/<dataset>.yaml
147
+ ```
148
+ Please refer to the `.yaml` files in `SRPose/configs/` for detailed configurations.
149
+
150
+
151
+ ## Baselines
152
+ We also offer two publicly available matcher-based baselines, [LightGlue](https://github.com/cvg/LightGlue) and [LoFTR](https://github.com/zju3dv/LoFTR), for evaluation and comparison.
153
+ Just run the following commands:
154
+ ```
155
+ # For Matterport, ScanNet and MegaDepth
156
+ python eval_baselines.py configs/<dataset>.yaml lightglue
157
+ python eval_baselines.py configs/<dataset>.yaml loftr
158
+
159
+ # For HO3D and Linemod
160
+ python eval_baselines.py configs/<dataset>.yaml lightglue --resize 640 --depth
161
+ python eval_baselines.py configs/<dataset>.yaml loftr --resize 640 --depth
162
+ ```
163
+
164
+ The `--resize xx` option controls the larger dimension of cropped target object images that will be resized to.
165
+ The `--depth` option controls whether the depth maps will be used to obtain scaled pose estimation.
166
+
167
+ ## Acknowledgements
168
+ In this repository, we have used codes from the following repositories. We thank all the authors for sharing great codes.
169
+ - [LightGlue](https://github.com/cvg/LightGlue)
170
+ - [LoFTR](https://github.com/zju3dv/LoFTR)
171
+ - [8point](https://github.com/crockwell/rel_pose)
172
+ - [SparsePlanes](https://github.com/jinlinyi/SparsePlanes/tree/main)
173
+ - [Map-free](https://github.com/nianticlabs/map-free-reloc/tree/main)
174
+
175
+ ## Citation
176
+ ```
177
+ @inproceedings{yin2024srpose,
178
+ title={SRPose: Two-view Relative Pose Estimation with Sparse Keypoints},
179
+ author={Yin, Rui and Zhang, Yulun and Pan, Zherong and Zhu, Jianjun and Wang, Cheng and Jia, Biao},
180
+ booktitle={ECCV},
181
+ year={2024}
182
+ }
183
+ ```
assets/figures/obj_vis_gt.png ADDED
assets/figures/obj_vis_query.png ADDED
assets/figures/obj_vis_reference_labeled.png ADDED
assets/figures/scene5_vis_0.png ADDED
assets/figures/scene5_vis_1.png ADDED
assets/figures/scene5_vis_gt.png ADDED
assets/ho3d_test_3000/ho3d_test.json ADDED
The diff for this file is too large to render. See raw diff
 
assets/linemod_test_1500/linemod_test.json ADDED
The diff for this file is too large to render. See raw diff
 
assets/mapfree_submission.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:026e799de5bf9eed2a1f627e64d6981f983cb729337036390481c402c47fbc5c
3
+ size 6569663
assets/megadepth_test_1500_scene_info/0015_0.1_0.3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d441df1d380b2ed34449b944d9f13127e695542fa275098d38a6298835672f22
3
+ size 231253
assets/megadepth_test_1500_scene_info/0015_0.3_0.5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f34b5231d04a84d84378c671dd26854869663b5eafeae2ebaf624a279325139
3
+ size 231253
assets/megadepth_test_1500_scene_info/0022_0.1_0.3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba46e6b9ec291fc7271eb9741d5c75ca04b83d3d7281e049815de9cb9024f4d9
3
+ size 272610
assets/megadepth_test_1500_scene_info/0022_0.3_0.5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f4465da174b96deba61e5328886e4f2e687d34b890efca69e0c838736f8ae12
3
+ size 272610
assets/megadepth_test_1500_scene_info/0022_0.5_0.7.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:684ae10f03001917c3ca0d12d441f372ce3c7e6637bd1277a3cda60df4207fe9
3
+ size 272610
assets/megadepth_test_1500_scene_info/megadepth_test_1500.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ 0022_0.1_0.3
2
+ 0015_0.1_0.3
3
+ 0015_0.3_0.5
4
+ 0022_0.3_0.5
5
+ 0022_0.5_0.7
assets/scannet_test_1500/intrinsics.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25ac102c69e2e4e2f0ab9c0d64f4da2b815e0901630768bdfde30080ced3605c
3
+ size 23922
assets/scannet_test_1500/scannet_test.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ test.npz
assets/scannet_test_1500/statistics.json ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "scene0707_00": 15,
3
+ "scene0708_00": 15,
4
+ "scene0709_00": 15,
5
+ "scene0710_00": 15,
6
+ "scene0711_00": 15,
7
+ "scene0712_00": 15,
8
+ "scene0713_00": 15,
9
+ "scene0714_00": 15,
10
+ "scene0715_00": 15,
11
+ "scene0716_00": 15,
12
+ "scene0717_00": 15,
13
+ "scene0718_00": 15,
14
+ "scene0719_00": 15,
15
+ "scene0720_00": 15,
16
+ "scene0721_00": 15,
17
+ "scene0722_00": 15,
18
+ "scene0723_00": 15,
19
+ "scene0724_00": 15,
20
+ "scene0725_00": 15,
21
+ "scene0726_00": 15,
22
+ "scene0727_00": 15,
23
+ "scene0728_00": 15,
24
+ "scene0729_00": 15,
25
+ "scene0730_00": 15,
26
+ "scene0731_00": 15,
27
+ "scene0732_00": 15,
28
+ "scene0733_00": 15,
29
+ "scene0734_00": 15,
30
+ "scene0735_00": 15,
31
+ "scene0736_00": 15,
32
+ "scene0737_00": 15,
33
+ "scene0738_00": 15,
34
+ "scene0739_00": 15,
35
+ "scene0740_00": 15,
36
+ "scene0741_00": 15,
37
+ "scene0742_00": 15,
38
+ "scene0743_00": 15,
39
+ "scene0744_00": 15,
40
+ "scene0745_00": 15,
41
+ "scene0746_00": 15,
42
+ "scene0747_00": 15,
43
+ "scene0748_00": 15,
44
+ "scene0749_00": 15,
45
+ "scene0750_00": 15,
46
+ "scene0751_00": 15,
47
+ "scene0752_00": 15,
48
+ "scene0753_00": 15,
49
+ "scene0754_00": 15,
50
+ "scene0755_00": 15,
51
+ "scene0756_00": 15,
52
+ "scene0757_00": 15,
53
+ "scene0758_00": 15,
54
+ "scene0759_00": 15,
55
+ "scene0760_00": 15,
56
+ "scene0761_00": 15,
57
+ "scene0762_00": 15,
58
+ "scene0763_00": 15,
59
+ "scene0764_00": 15,
60
+ "scene0765_00": 15,
61
+ "scene0766_00": 15,
62
+ "scene0767_00": 15,
63
+ "scene0768_00": 15,
64
+ "scene0769_00": 15,
65
+ "scene0770_00": 15,
66
+ "scene0771_00": 15,
67
+ "scene0772_00": 15,
68
+ "scene0773_00": 15,
69
+ "scene0774_00": 15,
70
+ "scene0775_00": 15,
71
+ "scene0776_00": 15,
72
+ "scene0777_00": 15,
73
+ "scene0778_00": 15,
74
+ "scene0779_00": 15,
75
+ "scene0780_00": 15,
76
+ "scene0781_00": 15,
77
+ "scene0782_00": 15,
78
+ "scene0783_00": 15,
79
+ "scene0784_00": 15,
80
+ "scene0785_00": 15,
81
+ "scene0786_00": 15,
82
+ "scene0787_00": 15,
83
+ "scene0788_00": 15,
84
+ "scene0789_00": 15,
85
+ "scene0790_00": 15,
86
+ "scene0791_00": 15,
87
+ "scene0792_00": 15,
88
+ "scene0793_00": 15,
89
+ "scene0794_00": 15,
90
+ "scene0795_00": 15,
91
+ "scene0796_00": 15,
92
+ "scene0797_00": 15,
93
+ "scene0798_00": 15,
94
+ "scene0799_00": 15,
95
+ "scene0800_00": 15,
96
+ "scene0801_00": 15,
97
+ "scene0802_00": 15,
98
+ "scene0803_00": 15,
99
+ "scene0804_00": 15,
100
+ "scene0805_00": 15,
101
+ "scene0806_00": 15
102
+ }
assets/scannet_test_1500/test.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b982b9c1f762e7d31af552ecc1ccf1a6add013197f74ec69c84a6deaa6f580ad
3
+ size 71687
baselines/matchers.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import time
3
+
4
+ from lightglue import LightGlue as LightGlue_
5
+ from lightglue import SuperPoint
6
+ from lightglue.utils import rbd
7
+ from kornia.feature import LoFTR as LoFTR_
8
+
9
+
10
def image_rgb2gray(image):
    """Collapse an RGB image tensor to single-channel grayscale.

    in:  torch.tensor - (3, H, W)
    out: torch.tensor - (1, H, W), luma-weighted as 0.3*R + 0.59*G + 0.11*B
    """
    r, g, b = image[0], image[1], image[2]
    gray = 0.3 * r + 0.59 * g + 0.11 * b
    return gray.unsqueeze(0)
15
+
16
+
17
class LightGlue():
    """SuperPoint + LightGlue sparse matcher wrapper with per-stage timing."""

    def __init__(self, num_keypoints=2048, device='cuda'):
        # Keypoint extractor and matcher, both frozen in eval mode.
        self.extractor = SuperPoint(max_num_keypoints=num_keypoints).eval().to(device)  # load the extractor
        self.matcher = LightGlue_(features='superpoint').eval().to(device)  # load the matcher
        self.device = device

    @torch.no_grad()
    def match(self, image0, image1):
        """Match two RGB tensors (3, H, W).

        Returns matched keypoint coordinates for each image, each of shape
        (K, 2), plus preprocess / extract / match wall-clock durations.
        """
        t0 = time.time()

        image0 = image0.to(self.device)
        image1 = image1.to(self.device)
        t1 = time.time()

        # Extract local features (SuperPoint auto-resizes the image;
        # disable with resize=None).
        feats0 = self.extractor.extract(image0)
        feats1 = self.extractor.extract(image1)
        t2 = time.time()

        # Run the matcher, then drop the batch dimension from all outputs.
        matches01 = self.matcher({'image0': feats0, 'image1': feats1})
        feats0, feats1, matches01 = (rbd(x) for x in (feats0, feats1, matches01))
        pair_idx = matches01['matches']                    # index pairs, shape (K, 2)
        points0 = feats0['keypoints'][pair_idx[..., 0]]    # coordinates in image #0
        points1 = feats1['keypoints'][pair_idx[..., 1]]    # coordinates in image #1
        t3 = time.time()

        return points0, points1, t1 - t0, t2 - t1, t3 - t2
49
+
50
+
51
class LoFTR():
    """Kornia LoFTR dense matcher wrapper with per-stage timing."""

    def __init__(self, pretrained='indoor', device='cuda'):
        self.loftr = LoFTR_(pretrained=pretrained).eval().to(device)
        self.device = device

    @torch.no_grad()
    def match(self, image0, image1):
        """Match two RGB tensors (3, H, W).

        Returns matched keypoints for each image plus preprocess / extract /
        match durations. LoFTR is detector-free, so the extract stage is
        effectively zero.
        """
        t0 = time.time()

        # LoFTR expects batched single-channel input: (1, 1, H, W).
        gray0 = image_rgb2gray(image0)[None].to(self.device)
        gray1 = image_rgb2gray(image1)[None].to(self.device)
        t1 = time.time()

        t2 = time.time()  # no separate feature-extraction stage

        result = self.loftr({'image0': gray0, 'image1': gray1})
        points0, points1 = result['keypoints0'], result['keypoints1']
        t3 = time.time()

        return points0, points1, t1 - t0, t2 - t1, t3 - t2
baselines/pose.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.transforms import Resize
3
+
4
+ from .matchers import LightGlue, LoFTR
5
+ # from .__models import SuperGlue, SGMNet, ASpanFormer, DKM
6
+ from .pose_solver import EssentialMatrixSolver, EssentialMatrixMetricSolver, PnPSolver, ProcrustesSolver
7
+
8
+ import time
9
+
10
+
11
class PoseRecover():
    """Two-view relative pose recovery: match keypoints with a feature
    matcher, then solve for (R, t) with a pose solver.

    Parameters
    ----------
    matcher : 'lightglue' | 'loftr'
    solver : 'essential' | 'pnp' | 'procrustes'
        Metric (scaled) solver, used only when depth maps are supplied;
        without depth an up-to-scale essential-matrix solver is used.
    img_resize : int | None
        If set, the larger image dimension is resized to this value before
        matching; matched points are scaled back to original coordinates.
    device : str
        Torch device for the matcher.
    """

    def __init__(self, matcher='lightglue', solver='procrustes', img_resize=None, device='cuda'):
        self.device = device

        if matcher == 'lightglue':
            self.matcher = LightGlue(device=device)
        elif matcher == 'loftr':
            self.matcher = LoFTR(device=device)
        # elif matcher == 'superglue':
        #     self.matcher = SuperGlue(device=device)
        # elif matcher == 'aspanformer':
        #     self.matcher = ASpanFormer(device=device)
        # elif matcher == 'sgmnet':
        #     self.matcher = SGMNet(device=device)
        # elif matcher == 'dkm':
        #     self.matcher = DKM(device=device)
        else:
            raise NotImplementedError

        self.img_resize = img_resize

        # Up-to-scale solver, always available for the no-depth path.
        self.basic_solver = EssentialMatrixSolver()

        if solver == 'essential':
            self.scaled_solver = EssentialMatrixMetricSolver()
        elif solver == 'pnp':
            self.scaled_solver = PnPSolver()
        elif solver == 'procrustes':
            self.scaled_solver = ProcrustesSolver()
        else:
            # Fail fast: previously an unknown solver name was silently
            # accepted and only surfaced as an AttributeError the first
            # time depth maps were supplied to recover().
            raise NotImplementedError

    def recover(self, image0, image1, K0, K1, bbox0=None, bbox1=None, mask0=None, mask1=None, depth0=None, depth1=None):
        """Estimate relative pose between image0 and image1.

        bbox0/bbox1 (x1, y1, x2, y2): if the images are crops, matched
        points are shifted back into full-image coordinates.
        mask0/mask1: boolean maps indexed in full-image coordinates; a match
        is kept only if both endpoints lie on their mask.
        depth0/depth1: if both given, the metric solver is used.

        Returns (R_est, t_est, points0, points1, preprocess_time,
        extract_time, match_time, recover_time).
        """
        if self.img_resize is not None:
            # Target size derived from image0's aspect ratio; both images
            # are resized to it, and each point set is mapped back with its
            # own image's original dimensions.
            h, w = image0.shape[-2:]
            if h > w:
                h_new = self.img_resize
                w_new = int(w * h_new / h)
            else:
                w_new = self.img_resize
                h_new = int(h * w_new / w)

            resize = Resize((h_new, w_new), antialias=True)
            scale0 = torch.tensor([image0.shape[-1]/w_new, image0.shape[-2]/h_new], dtype=torch.float)
            scale1 = torch.tensor([image1.shape[-1]/w_new, image1.shape[-2]/h_new], dtype=torch.float)
            image0 = resize(image0)
            image1 = resize(image1)

        points0, points1, preprocess_time, extract_time, match_time = self.matcher.match(image0, image1)

        if self.img_resize is not None:
            # Map matched points back to original image coordinates.
            points0 *= scale0.unsqueeze(0).to(points0.device)
            points1 *= scale1.unsqueeze(0).to(points1.device)

        if bbox0 is not None and bbox1 is not None:
            # Shift crop-local coordinates into full-image coordinates.
            x1, y1, x2, y2 = bbox0
            u1, v1, u2, v2 = bbox1

            points0[:, 0] += x1
            points0[:, 1] += y1

            points1[:, 0] += u1
            points1[:, 1] += v1

        if mask0 is not None and mask1 is not None:
            # Keep matches whose endpoints both fall on the object masks.
            filtered_ind0 = mask0[(points0[:, 1]).int(), (points0[:, 0]).int()]
            filtered_ind1 = mask1[(points1[:, 1]).int(), (points1[:, 0]).int()]
            filtered_inds = filtered_ind0 * filtered_ind1
            points0 = points0[filtered_inds]
            points1 = points1[filtered_inds]

        points0, points1 = points0.cpu().numpy(), points1.cpu().numpy()

        start_time = time.time()

        # Without depth, only an up-to-scale pose can be recovered.
        if depth0 is None or depth1 is None:
            R_est, t_est, _ = self.basic_solver.estimate_pose(points0, points1, {'K_color0': K0, 'K_color1': K1})
        else:
            R_est, t_est, _ = self.scaled_solver.estimate_pose(points0, points1, {'K_color0': K0, 'K_color1': K1, 'depth0': depth0, 'depth1': depth1})

        recover_time = time.time()

        return R_est, t_est, points0, points1, preprocess_time, extract_time, match_time, recover_time-start_time
baselines/pose_solver.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2 as cv
3
+ import open3d as o3d
4
+
5
+
6
def backproject_3d(uv, depth, K):
    """Lift 2D pixel coordinates into 3D camera-frame points.

    Each pixel (u, v) is mapped to the homogeneous ray K^-1 [u, v, 1]^T and
    scaled by its depth value.

    :param uv: array [N, 2] of pixel coordinates
    :param depth: array [N] of per-pixel depth values
    :param K: array [3, 3] camera intrinsic matrix
    :return: xyz: array [N, 3] of back-projected 3D points
    """
    ones = np.ones((uv.shape[0], 1))
    uv_homog = np.concatenate([uv, ones], axis=1)
    # (K^-1 @ p)^T for every row p, written as a right-multiplication.
    rays = uv_homog @ np.linalg.inv(K).T
    return depth.reshape(-1, 1) * rays
18
+
19
+
20
class EssentialMatrixSolver:
    '''Obtain relative pose (up to scale) given a set of 2D-2D correspondences'''

    def __init__(self, ransac_pix_threshold=0.5, ransac_confidence=0.99999):
        # EMat RANSAC parameters
        self.ransac_pix_threshold = ransac_pix_threshold  # inlier threshold, in pixels
        self.ransac_confidence = ransac_confidence

    def estimate_pose(self, kpts0, kpts1, data):
        """Estimate relative pose from 2D-2D correspondences via the essential matrix.

        :param kpts0: array [N, 2] keypoints in image 0 (pixels)
        :param kpts1: array [N, 2] keypoints in image 1 (pixels)
        :param data: dict with 'K_color0'/'K_color1' intrinsics (tensors with .numpy())
        :return: (R [3,3], t [3] unit-norm, num_inliers); NaNs and 0 on failure
        """
        R = np.full((3, 3), np.nan)
        t = np.full((3), np.nan)
        if len(kpts0) < 5:  # five-point algorithm needs at least 5 matches
            return R, t, 0

        K0 = data['K_color0'].numpy()
        K1 = data['K_color1'].numpy()

        # normalize keypoints: subtract principal point, divide by focal length
        kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
        kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]

        # express the pixel RANSAC threshold in normalized image coordinates
        ransac_thr = self.ransac_pix_threshold / np.mean([K0[0, 0], K1[1, 1], K0[1, 1], K1[0, 0]])

        # compute pose with OpenCV
        E, mask = cv.findEssentialMat(
            kpts0, kpts1, np.eye(3),
            threshold=ransac_thr, prob=self.ransac_confidence, method=cv.RANSAC)
        # kept on the instance so metric subclasses can reuse the inlier set
        self.mask = mask
        if E is None:
            return R, t, 0

        # recover pose from E; OpenCV may return several 3x3 candidates
        # stacked vertically, so split into individual matrices.
        # FIX: use integer floor division for the section count (was `/`,
        # which passes a float to np.split).
        best_num_inliers = 0
        ret = R, t, 0
        for _E in np.split(E, len(E) // 3):
            n, R, t, _ = cv.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
            if n > best_num_inliers:
                best_num_inliers = n
                ret = (R, t[:, 0], n)  # keep the candidate with most cheirality inliers
        return ret
62
+
63
+
64
class EssentialMatrixMetricSolverMEAN(EssentialMatrixSolver):
    '''Obtains relative pose with scale using E-Mat decomposition and depth values at inlier correspondences'''

    def __init__(self, ransac_pix_threshold=0.5, ransac_confidence=0.99999):
        # FIX: the parent __init__ takes explicit RANSAC parameters, not a cfg
        # object; the previous `__init__(self, cfg)` passed cfg through as the
        # pixel threshold. Mirror the parent/sibling-class signature instead.
        super().__init__(ransac_pix_threshold, ransac_confidence)

    def estimate_pose(self, kpts0, kpts1, data):
        '''Estimates metric translation vector using by back-projecting E-mat inliers to 3D using depthmaps.
        The metric translation vector can be obtained by looking at the residual vector (projected to the translation vector direction).
        In this version, each 3D-3D correspondence gives an optimal scale for the translation vector.
        We simply aggregate them by averaging them.

        :param kpts0: array [N, 2] keypoints in image 0 (pixels)
        :param kpts1: array [N, 2] keypoints in image 1 (pixels)
        :param data: dict with 'K_color0'/'K_color1' intrinsics and
            'depth0'/'depth1' depth-map tensors indexed as [v, u]
        :return: (R [3,3], metric t [3], num_inliers)
        '''

        # get pose up to scale
        R, t, inliers = super().estimate_pose(kpts0, kpts1, data)
        if inliers == 0:
            return R, t, inliers

        # backproject E-mat inliers at each camera
        K0 = data['K_color0']
        K1 = data['K_color1']
        mask = self.mask.ravel() == 1  # get E-mat inlier mask from super class
        inliers_kpts0 = np.int32(kpts0[mask])
        inliers_kpts1 = np.int32(kpts1[mask])
        depth_inliers_0 = data['depth0'][inliers_kpts0[:, 1], inliers_kpts0[:, 0]].numpy()
        depth_inliers_1 = data['depth1'][inliers_kpts1[:, 1], inliers_kpts1[:, 0]].numpy()
        # check for valid depth
        valid = (depth_inliers_0 > 0) * (depth_inliers_1 > 0)
        if valid.sum() < 1:
            R = np.full((3, 3), np.nan)
            # FIX: return t with shape (3,) like the success path and the
            # sibling EssentialMatrixMetricSolver (was (3, 1)).
            t = np.full((3,), np.nan)
            inliers = 0
            return R, t, inliers
        xyz0 = backproject_3d(inliers_kpts0[valid], depth_inliers_0[valid], K0)
        xyz1 = backproject_3d(inliers_kpts1[valid], depth_inliers_1[valid], K1)

        # rotate xyz0 to xyz1 CS (so that axes are parallel)
        xyz0 = (R @ xyz0.T).T

        # get average point for each camera
        pmean0 = np.mean(xyz0, axis=0)
        pmean1 = np.mean(xyz1, axis=0)

        # find scale as the 'length' of the translation vector that minimises
        # the 3D distance between projected points from 0 and the
        # corresponding points in 1
        scale = np.dot(pmean1 - pmean0, t)
        t_metric = scale * t
        t_metric = t_metric.reshape(3, 1)

        return R, t_metric[:, 0], inliers
113
+
114
+
115
class EssentialMatrixMetricSolver(EssentialMatrixSolver):
    '''
    Obtains relative pose with scale using E-Mat decomposition and RANSAC for scale based on depth values at inlier correspondences.
    The scale of the translation vector is obtained using RANSAC over the possible scales recovered from 3D-3D correspondences.
    '''

    def __init__(self, ransac_pix_threshold=0.5, ransac_confidence=0.99999, ransac_scale_threshold=0.1):
        super().__init__(ransac_pix_threshold, ransac_confidence)
        # inlier threshold (in metres) for the 1D scale-RANSAC below
        self.ransac_scale_threshold = ransac_scale_threshold

    def estimate_pose(self, kpts0, kpts1, data):
        '''Estimates metric translation vector using by back-projecting E-mat inliers to 3D using depthmaps.

        :param kpts0: array [N, 2] keypoints in image 0 (pixels)
        :param kpts1: array [N, 2] keypoints in image 1 (pixels)
        :param data: dict with 'K_color0'/'K_color1' intrinsics and
            'depth0'/'depth1' depth-map tensors indexed as [v, u]
        :return: (R [3,3], metric t [3], best_inliers from the scale RANSAC)
        '''

        # get pose up to scale
        R, t, inliers = super().estimate_pose(kpts0, kpts1, data)
        if inliers == 0:
            return R, t, inliers

        # backproject E-mat inliers at each camera
        K0 = data['K_color0']
        K1 = data['K_color1']
        mask = self.mask.ravel() == 1  # get E-mat inlier mask from super class
        inliers_kpts0 = np.int32(kpts0[mask])
        inliers_kpts1 = np.int32(kpts1[mask])
        depth_inliers_0 = data['depth0'][inliers_kpts0[:, 1], inliers_kpts0[:, 0]].numpy()
        depth_inliers_1 = data['depth1'][inliers_kpts1[:, 1], inliers_kpts1[:, 0]].numpy()

        # check for valid depth (zero depth marks missing measurements)
        valid = (depth_inliers_0 > 0) * (depth_inliers_1 > 0)
        if valid.sum() < 1:
            R = np.full((3, 3), np.nan)
            t = np.full((3, ), np.nan)
            inliers = 0
            return R, t, inliers
        xyz0 = backproject_3d(inliers_kpts0[valid], depth_inliers_0[valid], K0)
        xyz1 = backproject_3d(inliers_kpts1[valid], depth_inliers_1[valid], K1)

        # rotate xyz0 to xyz1 CS (so that axes are parallel)
        xyz0 = (R @ xyz0.T).T

        # get individual scales (for each 3D-3D correspondence): each residual
        # vector projected onto the (unit) translation direction
        scale = np.dot(xyz1 - xyz0, t.reshape(3, 1))  # [N, 1]

        # RANSAC loop: every per-correspondence scale is a hypothesis; keep
        # the one that agrees with the most other correspondences
        best_inliers = 0
        best_scale = None
        for scale_hyp in scale:
            inliers_hyp = (np.abs(scale - scale_hyp) < self.ransac_scale_threshold).sum().item()
            if inliers_hyp > best_inliers:
                best_scale = scale_hyp
                best_inliers = inliers_hyp

        # Output results (scale is guaranteed non-empty here, so best_scale is set)
        t_metric = best_scale * t
        t_metric = t_metric.reshape(3, 1)

        return R, t_metric[:, 0], best_inliers
173
+
174
+
175
class PnPSolver:
    '''Estimate relative pose (metric) using Perspective-n-Point algorithm (2D-3D) correspondences'''

    def __init__(self, ransac_iterations=1000, reprojection_inlier_threshold=3, confidence=0.99999):
        # PnP RANSAC parameters
        self.ransac_iterations = ransac_iterations
        self.reprojection_inlier_threshold = reprojection_inlier_threshold  # pixels
        self.confidence = confidence

    def estimate_pose(self, pts0, pts1, data):
        """Back-project pts0 with depth0, then solve PnP against pts1.

        :param pts0: array [N, 2] keypoints in image 0 (pixels; cast to int to index depth)
        :param pts1: array [N, 2] keypoints in image 1 (pixels)
        :param data: dict with 'K_color0'/'K_color1' intrinsics (tensors) and
            'depth0' depth map indexed as [v, u]
        :return: (R [3,3], t [3], inliers). NOTE(review): on the early `len < 4`
            failure t is returned with shape (3, 1), on the other paths (3,) —
            inconsistent; callers should not rely on the failure shape.
        """
        # uses nearest neighbour
        pts0 = np.int32(pts0)

        if len(pts0) < 4:  # solvePnPRansac with P3P needs at least 4 points
            return np.full((3, 3), np.nan), np.full((3, 1), np.nan), 0

        # get depth at correspondence points
        depth_0 = data['depth0']
        depth_pts0 = depth_0[pts0[:, 1], pts0[:, 0]]

        # remove invalid pts (depth == 0)
        # NOTE(review): uses depth_0.min() as the invalid sentinel rather than
        # a literal 0 — presumably the map's minimum marks missing depth; confirm.
        valid = depth_pts0 > depth_0.min()
        if valid.sum() < 4:
            return np.full((3, 3), np.nan), np.full((3, 1), np.nan), 0
        pts0 = pts0[valid]
        pts1 = pts1[valid]
        depth_pts0 = depth_pts0[valid]

        # backproject points to 3D in each sensors' local coordinates
        K0 = data['K_color0']
        K1 = data['K_color1']
        xyz_0 = backproject_3d(pts0, depth_pts0, K0).numpy()

        # get relative pose using PnP + RANSAC
        succ, rvec, tvec, inliers = cv.solvePnPRansac(
            xyz_0, pts1, K1.numpy(),
            None, iterationsCount=self.ransac_iterations,
            reprojectionError=self.reprojection_inlier_threshold, confidence=self.confidence,
            flags=cv.SOLVEPNP_P3P)

        # refine with iterative PnP using inliers only
        if succ and len(inliers) >= 6:
            succ, rvec, tvec, _ = cv.solvePnPGeneric(xyz_0[inliers], pts1[inliers], K1.numpy(
            ), None, useExtrinsicGuess=True, rvec=rvec, tvec=tvec, flags=cv.SOLVEPNP_ITERATIVE)
            # solvePnPGeneric returns lists of solutions; take the first
            rvec = rvec[0]
            tvec = tvec[0]

        # avoid degenerate solutions: reject absurdly large translations
        if succ:
            if np.linalg.norm(tvec) > 1000:
                succ = False

        if succ:
            R, _ = cv.Rodrigues(rvec)  # axis-angle -> rotation matrix
            t = tvec.reshape(3, 1)
        else:
            R = np.full((3, 3), np.nan)
            t = np.full((3, 1), np.nan)
            inliers = []

        return R, t[:, 0], inliers
236
+
237
+
238
class ProcrustesSolver:
    '''Estimate relative pose (metric) using 3D-3D correspondences'''

    def __init__(self, ransac_max_corr_distance=0.5, refine=False):

        # Procrustes RANSAC parameters
        self.ransac_max_corr_distance = ransac_max_corr_distance  # metres
        self.refine = refine  # optionally polish the RANSAC result with ICP

    def estimate_pose(self, pts0, pts1, data):
        """Back-project both keypoint sets with their depth maps, then register
        the two point clouds with correspondence-based RANSAC (+ optional ICP).

        :param pts0: array [N, 2] keypoints in image 0 (pixels)
        :param pts1: array [N, 2] keypoints in image 1 (pixels)
        :param data: dict with 'K_color0'/'K_color1' intrinsics and
            'depth0'/'depth1' depth maps indexed as [v, u]
        :return: (R [3,3], t [3], inliers) mapping camera-0 points to camera 1
        """
        # uses nearest neighbour
        pts0 = np.int32(pts0)
        pts1 = np.int32(pts1)

        if len(pts0) < 3:  # rigid 3D-3D registration needs >= 3 correspondences
            return np.full((3, 3), np.nan), np.full((3), np.nan), 0

        # get depth at correspondence points
        depth_0, depth_1 = data['depth0'], data['depth1']
        depth_pts0 = depth_0[pts0[:, 1], pts0[:, 0]]
        depth_pts1 = depth_1[pts1[:, 1], pts1[:, 0]]

        # remove invalid pts (depth == 0)
        # NOTE(review): invalid depth is detected as <= map minimum — presumably
        # the minimum marks missing measurements; confirm against the data loader.
        valid = (depth_pts0 > depth_0.min()) * (depth_pts1 > depth_1.min())
        if valid.sum() < 3:
            return np.full((3, 3), np.nan), np.full((3), np.nan), 0
        pts0 = pts0[valid]
        pts1 = pts1[valid]
        depth_pts0 = depth_pts0[valid]
        depth_pts1 = depth_pts1[valid]

        # backproject points to 3D in each sensors' local coordinates
        K0 = data['K_color0']
        K1 = data['K_color1']
        xyz_0 = backproject_3d(pts0, depth_pts0, K0)
        xyz_1 = backproject_3d(pts1, depth_pts1, K1)

        # create open3d point cloud objects and correspondences idxs
        pcl_0 = o3d.geometry.PointCloud()
        pcl_0.points = o3d.utility.Vector3dVector(xyz_0)
        pcl_1 = o3d.geometry.PointCloud()
        pcl_1.points = o3d.utility.Vector3dVector(xyz_1)
        # correspondences are index-aligned: point i in pcl_0 <-> point i in pcl_1
        corr_idx = np.arange(pts0.shape[0])
        corr_idx = np.tile(corr_idx.reshape(-1, 1), (1, 2))
        corr_idx = o3d.utility.Vector2iVector(corr_idx)

        # obtain relative pose using procrustes
        ransac_criteria = o3d.pipelines.registration.RANSACConvergenceCriteria()
        res = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
            pcl_0, pcl_1, corr_idx, self.ransac_max_corr_distance, criteria=ransac_criteria)
        # fitness is the inlier fraction; convert back to a count
        inliers = int(res.fitness * np.asarray(pcl_1.points).shape[0])

        # refine with ICP
        if self.refine:
            # first, backproject both (whole) point clouds
            # NOTE(review): mgrid mixes depth_0.shape[0] with depth_1.shape[1] —
            # only correct when both depth maps share a resolution; confirm.
            vv, uu = np.mgrid[0:depth_0.shape[0], 0:depth_1.shape[1]]
            uv_coords = np.concatenate([uu.reshape(-1, 1), vv.reshape(-1, 1)], axis=1)

            valid = depth_0.reshape(-1) > 0
            xyz_0 = backproject_3d(uv_coords[valid], depth_0.reshape(-1)[valid], K0)

            valid = depth_1.reshape(-1) > 0
            xyz_1 = backproject_3d(uv_coords[valid], depth_1.reshape(-1)[valid], K1)

            pcl_0 = o3d.geometry.PointCloud()
            pcl_0.points = o3d.utility.Vector3dVector(xyz_0)
            pcl_1 = o3d.geometry.PointCloud()
            pcl_1.points = o3d.utility.Vector3dVector(xyz_1)

            icp_criteria = o3d.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-4,
                                                                             relative_rmse=1e-4,
                                                                             max_iteration=30)

            # seed ICP with the RANSAC transform
            res = o3d.pipelines.registration.registration_icp(pcl_0,
                                                              pcl_1,
                                                              self.ransac_max_corr_distance,
                                                              init=res.transformation,
                                                              criteria=icp_criteria)

        R = res.transformation[:3, :3]
        t = res.transformation[:3, -1]
        inliers = int(res.fitness * np.asarray(pcl_1.points).shape[0])
        return R, t, inliers
configs/default.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from yacs.config import CfgNode as CN


# Default configuration tree; per-experiment YAML files (configs/*.yaml)
# override these values via yacs merging.
_CN = CN()

# Model
_CN.MODEL = CN()
_CN.MODEL.NUM_KEYPOINTS = 1024        # keypoints per image at train time
_CN.MODEL.TEST_NUM_KEYPOINTS = 2048   # keypoints per image at eval time
_CN.MODEL.N_LAYERS = 6
_CN.MODEL.NUM_HEADS = 4
_CN.MODEL.FEATURES = 'superpoint'     # local feature extractor backbone

# Dataset
_CN.DATASET = CN()
_CN.DATASET.TASK = None          # 'scene' or 'object' (see datasets/__init__.py)
_CN.DATASET.DATA_SOURCE = None   # dataset name, e.g. 'scannet', 'linemod'
_CN.DATASET.DATA_ROOT = None
_CN.DATASET.MIN_OVERLAP_SCORE = None

## For MapFree
_CN.DATASET.ESTIMATED_DEPTH = None

## For Linemod(BOP)
_CN.DATASET.OBJECT_ID = None
_CN.DATASET.MIN_VISIBLE_FRACT = None
_CN.DATASET.MAX_ANGLE_ERROR = None
_CN.DATASET.JSON_PATH = None

## For MegaDepth/ScanNet
_CN.DATASET.TRAIN = CN()
_CN.DATASET.TRAIN.DATA_ROOT = None
_CN.DATASET.TRAIN.NPZ_ROOT = None
_CN.DATASET.TRAIN.LIST_PATH = None
_CN.DATASET.TRAIN.INTRINSIC_PATH = None
_CN.DATASET.TRAIN.MIN_OVERLAP_SCORE = None

_CN.DATASET.VAL = CN()
_CN.DATASET.VAL.DATA_ROOT = None
_CN.DATASET.VAL.NPZ_ROOT = None
_CN.DATASET.VAL.LIST_PATH = None
_CN.DATASET.VAL.INTRINSIC_PATH = None
_CN.DATASET.VAL.MIN_OVERLAP_SCORE = None

_CN.DATASET.TEST = CN()
_CN.DATASET.TEST.DATA_ROOT = None
_CN.DATASET.TEST.NPZ_ROOT = None
_CN.DATASET.TEST.LIST_PATH = None
_CN.DATASET.TEST.INTRINSIC_PATH = None
_CN.DATASET.TEST.MIN_OVERLAP_SCORE = None

# Train
_CN.TRAINER = CN()

_CN.TRAINER.EPOCHS = None
_CN.TRAINER.LEARNING_RATE = None
_CN.TRAINER.PCT_START = None            # OneCycle-style warmup fraction
_CN.TRAINER.BATCH_SIZE = None
_CN.TRAINER.NUM_WORKERS = None
_CN.TRAINER.PIN_MEMORY = True
_CN.TRAINER.N_SAMPLES_PER_SUBSET = None  # used by RandomConcatSampler

_CN.RANDOM_SEED = 0


# _CN.EMAT_RANSAC = CN()
# _CN.EMAT_RANSAC.PIX_THRESHOLD = 0.5
# _CN.EMAT_RANSAC.SCALE_THRESHOLD = 0.1
# _CN.EMAT_RANSAC.CONFIDENCE = 0.99999

# _CN.PNP = CN()
# _CN.PNP.RANSAC_ITER = 1000
# _CN.PNP.REPROJECTION_INLIER_THRESHOLD = 3
# _CN.PNP.CONFIDENCE = 0.99999

# _CN.PROCRUSTES = CN()
# _CN.PROCRUSTES.MAX_CORR_DIST = 0.05  # meters
# _CN.PROCRUSTES.REFINE = False


def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered
    # This is for the "local variable" use pattern
    return _CN.clone()
configs/ho3d.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ NUM_KEYPOINTS: 1024
3
+
4
+ DATASET:
5
+ TASK: 'object'
6
+ DATA_SOURCE: 'ho3d'
7
+ DATA_ROOT: 'data/ho3d'
8
+ JSON_PATH: 'assets/ho3d_test_3000/ho3d_test.json'
9
+
10
+ MAX_ANGLE_ERROR: 45
11
+
12
+ TRAINER:
13
+ EPOCHS: 200
14
+ LEARNING_RATE: 0.00002
15
+ BATCH_SIZE: 32
16
+ NUM_WORKERS: 8
17
+ PCT_START: 0.3
18
+ N_SAMPLES_PER_SUBSET: 4000
19
+
configs/linemod.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ NUM_KEYPOINTS: 1200
3
+
4
+ DATASET:
5
+ TASK: 'object'
6
+ DATA_SOURCE: 'linemod'
7
+ DATA_ROOT: 'data/lm'
8
+ JSON_PATH: 'assets/linemod_test_1500/linemod_test.json'
9
+
10
+ MIN_VISIBLE_FRACT: 0.75
11
+ MAX_ANGLE_ERROR: 45
12
+
13
+ TRAINER:
14
+ EPOCHS: 200
15
+ LEARNING_RATE: 0.00002
16
+ BATCH_SIZE: 32
17
+ NUM_WORKERS: 8
18
+ PCT_START: 0.3
19
+ N_SAMPLES_PER_SUBSET: 200
20
+
configs/mapfree.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DATASET:
2
+ TASK: 'scene'
3
+ DATA_SOURCE: 'mapfree'
4
+ DATA_ROOT: 'data/mapfree/'
5
+ # ESTIMATED_DEPTH: None # To load estimated depth map, provide the suffix to the depth files, e.g. 'dptnyu', 'dptkiti'
6
+
7
+ TRAINER:
8
+ EPOCHS: 200
9
+ LEARNING_RATE: 0.00002
10
+ BATCH_SIZE: 32
11
+ NUM_WORKERS: 6
12
+ PCT_START: 0.3
13
+ N_SAMPLES_PER_SUBSET: 200
14
+
configs/matterport.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DATASET:
2
+ TASK: 'scene'
3
+ DATA_SOURCE: 'matterport'
4
+ DATA_ROOT: 'data/mp3d'
5
+
6
+ TRAINER:
7
+ EPOCHS: 200
8
+ LEARNING_RATE: 0.00005
9
+ BATCH_SIZE: 32
10
+ NUM_WORKERS: 8
11
+ PCT_START: 0.3
configs/megadepth.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DATASET:
2
+ TASK: "scene"
3
+ DATA_SOURCE: "megadepth"
4
+
5
+ TRAIN:
6
+ DATA_ROOT: "data/megadepth/train"
7
+ NPZ_ROOT: "data/megadepth/index/scene_info_0.1_0.7"
8
+ LIST_PATH: "data/megadepth/index/trainvaltest_list/train_list.txt"
9
+ MIN_OVERLAP_SCORE: 0.0
10
+
11
+ VAL:
12
+ DATA_ROOT: "data/megadepth/test"
13
+ NPZ_ROOT: "data/megadepth/index/scene_info_val_1500"
14
+ LIST_PATH: "data/megadepth/index/trainvaltest_list/val_list.txt"
15
+ MIN_OVERLAP_SCORE: 0.0
16
+
17
+ TEST:
18
+ DATA_ROOT: "data/megadepth/test"
19
+ NPZ_ROOT: "assets/megadepth_test_1500_scene_info"
20
+ LIST_PATH: "assets/megadepth_test_1500_scene_info/megadepth_test_1500.txt"
21
+ MIN_OVERLAP_SCORE: 0.0
22
+
23
+ TRAINER:
24
+ EPOCHS: 500
25
+ LEARNING_RATE: 0.00002
26
+ BATCH_SIZE: 32
27
+ NUM_WORKERS: 8
28
+ PCT_START: 0.3
29
+ N_SAMPLES_PER_SUBSET: 200
configs/scannet.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DATASET:
2
+ TASK: "scene"
3
+ DATA_SOURCE: "scannet"
4
+
5
+ TRAIN:
6
+ DATA_ROOT: "data/scannet/train"
7
+ NPZ_ROOT: "data/scannet/index/scene_data/train"
8
+ LIST_PATH: "data/scannet/index/scene_data/train_list/scannet_all.txt"
9
+ INTRINSIC_PATH: "data/scannet/index/intrinsics.npz"
10
+ MIN_OVERLAP_SCORE: 0.4
11
+
12
+ VAL:
13
+ DATA_ROOT: "data/scannet/test"
14
+ NPZ_ROOT: "assets/scannet_test_1500"
15
+ LIST_PATH: "assets/scannet_test_1500/scannet_test.txt"
16
+ INTRINSIC_PATH: "assets/scannet_test_1500/intrinsics.npz"
17
+ MIN_OVERLAP_SCORE: 0.0
18
+
19
+ TEST:
20
+ DATA_ROOT: "data/scannet/test"
21
+ NPZ_ROOT: "assets/scannet_test_1500"
22
+ LIST_PATH: "assets/scannet_test_1500/scannet_test.txt"
23
+ INTRINSIC_PATH: "assets/scannet_test_1500/intrinsics.npz"
24
+ MIN_OVERLAP_SCORE: 0.0
25
+
26
+ TRAINER:
27
+ EPOCHS: 500
28
+ LEARNING_RATE: 0.0001
29
+ BATCH_SIZE: 32
30
+ NUM_WORKERS: 8
31
+ PCT_START: 0.3
32
+ N_SAMPLES_PER_SUBSET: 200
33
+
datasets/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from .matterport import build_matterport
from .linemod import build_linemod
from .megadepth import build_concat_megadepth
from .scannet import build_concat_scannet
from .ho3d import build_ho3d
from .mapfree import build_concat_mapfree
from .sampler import RandomConcatSampler

# Registry mapping DATASET.TASK -> DATASET.DATA_SOURCE -> builder function.
# Each builder has the signature build_*(mode, config) and returns a
# torch.utils.data.Dataset.
dataset_dict = {
    'scene': {
        'matterport': build_matterport,
        'megadepth': build_concat_megadepth,
        'scannet': build_concat_scannet,
        'mapfree': build_concat_mapfree,
    },
    'object': {
        'linemod': build_linemod,
        'ho3d': build_ho3d,
    }
}
datasets/ho3d.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import numpy as np
3
+ from PIL import Image
4
+ import cv2
5
+ import pickle
6
+ import json
7
+ from tqdm import tqdm
8
+ import torch
9
+ from torch.utils.data import Dataset, ConcatDataset
10
+
11
+ from utils.augment import Augmentor
12
+
13
+
14
class HO3D(Dataset):
    """Single HO3D sequence: per-frame RGB, segmentation, depth and object pose.

    Frames without camera metadata or with a (nearly) empty object mask are
    filtered out at construction time.
    """

    def __init__(self, data_root, sequence_path, mode):
        self.data_root = Path(data_root)
        # HO3D ships 'train' and 'evaluation' splits; anything non-train maps
        # to 'evaluation'.
        mode = 'evaluation' if mode != 'train' else 'train'
        self.sequence_dir = self.data_root / mode / sequence_path

        self.color_dir = self.sequence_dir / 'rgb'
        self.mask_dir = self.sequence_dir / 'seg'
        self.depth_dir = self.sequence_dir / 'depth'
        self.meta_dir = self.sequence_dir / 'meta'

        self.color_paths = list(self.color_dir.iterdir())
        self.color_paths = sorted(self.color_paths)

        # mask/depth/meta files share the RGB frame's stem
        self.mask_paths = [self.mask_dir / f'{x.stem}.png' for x in self.color_paths]
        self.depth_paths = [self.depth_dir / f'{x.stem}.png' for x in self.color_paths]
        self.meta_paths = [self.meta_dir / f'{x.stem}.pkl' for x in self.color_paths]

        # First filter: drop frames whose meta file has no camera matrix.
        self.intrinsics, self.extrinsics, self.objCorners, self.objNames, valid = self._load_meta(self.meta_paths)

        self.color_paths = np.array(self.color_paths)[valid.numpy()]
        self.mask_paths = np.array(self.mask_paths)[valid.numpy()]
        self.depth_paths = np.array(self.depth_paths)[valid.numpy()]
        self.meta_paths = np.array(self.meta_paths)[valid.numpy()]

        # Second filter: drop frames whose object mask is (almost) empty.
        self.bboxes, valid = self._load_bboxes(self.mask_paths)
        self.intrinsics = self.intrinsics[valid]
        self.extrinsics = self.extrinsics[valid]
        self.objCorners = self.objCorners[valid]
        self.objNames = self.objNames[valid.numpy()]
        self.color_paths = self.color_paths[valid.numpy()]
        self.mask_paths = self.mask_paths[valid.numpy()]
        self.depth_paths = self.depth_paths[valid.numpy()]
        self.meta_paths = self.meta_paths[valid.numpy()]

        assert len(self.color_paths) == self.intrinsics.shape[0]
        assert len(self.objNames) == self.extrinsics.shape[0]

        self.augment = Augmentor(mode=='train')

    def __len__(self):
        return len(self.color_paths)

    def _load_bboxes(self, mask_paths):
        """Compute an object bbox per mask; masks with < 100 object pixels are invalid.

        Returns (bboxes [M, 4] int tensor, valid [N] bool tensor).
        """
        bboxes = []
        valid = []
        for mask_path in mask_paths:
            mask = cv2.imread(str(mask_path))
            # masks are read at native resolution; scales map coords to 640x480
            w_scale, h_scale = 640 / mask.shape[1], 480 / mask.shape[0]
            # object pixels are marked in the green channel
            obj_mask = torch.from_numpy(mask[..., 1] == 255)

            if obj_mask.float().sum() < 100:
                valid.append(False)
                continue
            valid.append(True)

            mask_inds = torch.where(obj_mask)
            # mask_inds[0] are row (y) indices, mask_inds[1] column (x) indices
            x1, x2 = mask_inds[0].aminmax()
            y1, y2 = mask_inds[1].aminmax()
            # NOTE(review): bbox stored as [y1*h_scale, x1*w_scale, ...] pairs a
            # column coordinate with h_scale and a row coordinate with w_scale —
            # looks swapped unless masks are square; confirm against consumers.
            bbox = torch.tensor([y1*h_scale, x1*w_scale, y2*h_scale, x2*w_scale]).int()
            bboxes.append(bbox)

        bboxes = torch.stack(bboxes)
        valid = torch.tensor(valid)

        return bboxes, valid

    def _load_meta(self, meta_paths):
        """Load per-frame camera intrinsics, object pose and posed 3D corners.

        Frames whose annotation lacks 'camMat' are marked invalid.
        Returns (intrinsics [M,3,3], extrinsics [M,4,4], objCorners [M,8,3],
        objNames [M], valid [N] bool tensor).
        """
        intrinsics = []
        extrinsics = []
        objCorners = []
        objNames = []
        valid = []
        for meta_path in meta_paths:
            with open(meta_path, 'rb') as f:
                anno = pickle.load(f, encoding='latin1')

            if anno['camMat'] is None:
                valid.append(False)
                continue
            valid.append(True)

            camMat = torch.from_numpy(anno['camMat'])
            # object pose: axis-angle rotation + translation -> 4x4 matrix
            ex = torch.eye(4)
            ex[:3, :3] = torch.from_numpy(cv2.Rodrigues(anno['objRot'])[0])
            ex[:3, 3] = torch.from_numpy(anno['objTrans'])
            objCorners3DRest = torch.from_numpy(anno['objCorners3DRest']).float()
            # transform rest-pose corners into the camera frame
            objCorners3DRest = objCorners3DRest @ ex[:3, :3].T + ex[:3, 3]

            intrinsics.append(camMat)
            extrinsics.append(ex)
            objCorners.append(objCorners3DRest)
            objNames.append(anno['objName'])

        intrinsics = torch.stack(intrinsics).float()
        extrinsics = torch.stack(extrinsics).float()
        objCorners = torch.stack(objCorners)
        objNames = np.array(objNames)
        valid = torch.tensor(valid)

        return intrinsics, extrinsics, objCorners, objNames, valid

    def _load_mask(self, mask_path):
        # boolean object mask, resized to 640x480 (green channel == 255)
        mask = cv2.imread(str(mask_path))
        mask = cv2.resize(mask, (640, 480))
        mask = mask[..., 1] == 255
        return mask

    def _load_depth(self, depth_path):
        # HO3D encodes depth across two PNG channels; this constant converts
        # the combined 16-bit value to metres.
        depth_scale = 0.00012498664727900177
        depth_img = cv2.imread(str(depth_path))

        dpt = depth_img[:, :, 2] + depth_img[:, :, 1] * 256
        dpt = dpt * depth_scale

        return dpt

    def __getitem__(self, idx):
        """Return one frame as a dict of tensors (color, mask, depth, pose, ...)."""
        color = cv2.imread(str(self.color_paths[idx]))
        color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
        # color = self.augment(color)  # augmentation currently disabled
        color = (torch.tensor(color).float() / 255.0).permute(2, 0, 1)  # CHW in [0,1]

        mask = self._load_mask(self.mask_paths[idx])
        mask = torch.from_numpy(mask)
        depth = self._load_depth(self.depth_paths[idx])
        depth = torch.from_numpy(depth)

        bbox = self.bboxes[idx]

        intrinsic = self.intrinsics[idx]
        extrinsic = self.extrinsics[idx]
        objCorners = self.objCorners[idx]
        objName = self.objNames[idx]

        return {
            'color': color,
            'mask': mask,
            'depth': depth,
            'extrinsic': extrinsic,
            'intrinsic': intrinsic,
            'objCorners': objCorners,
            'bbox': bbox,
            # path relative to the dataset root (strip the first two components)
            'color_path': str(self.color_paths[idx]).split('/', 2)[-1],
            'objName': objName,
        }
169
+
170
+
171
class HO3DPair(Dataset):
    """Frame pairs from one HO3D sequence whose relative rotation is below a
    maximum angular error; yields image pairs with their relative pose."""

    def __init__(self, data_root, mode, sequence_id, max_angle_error):
        self.ho3d_dataset = HO3D(data_root, sequence_id, mode)

        # All pairwise rotation errors; keep pairs under the angle budget,
        # and only the upper triangle (i < j) to avoid duplicates.
        angle_err = self.get_angle_error(self.ho3d_dataset.extrinsics[:, :3, :3])
        index0, index1 = torch.where(angle_err < max_angle_error)
        # NOTE: `filter` shadows the builtin; kept as-is here
        filter = torch.where(index0 < index1)
        self.index0, self.index1 = index0[filter], index1[filter]

        self.indices = torch.tensor(list(zip(self.index0, self.index1)))
        if mode == 'val' or mode == 'test':
            # subsample a fixed-size random evaluation set
            self.indices = self.indices[torch.randperm(self.indices.size(0))[:1500]]

    def get_angle_error(self, R):
        """Pairwise geodesic rotation distance in degrees.

        :param R: tensor (B, 3, 3) of rotation matrices
        :return: tensor (B, B) of angular errors
        """
        # R_a^T R_b for all pairs; its trace gives the rotation angle
        residual = torch.einsum('aij,bik->abjk', R, R)
        trace = torch.diagonal(residual, dim1=-2, dim2=-1).sum(-1)
        cosine = (trace - 1) / 2
        cosine = torch.clip(cosine, -1, 1)  # guard acos against numeric drift
        R_err = torch.acos(cosine)
        angle_err = R_err.rad2deg()

        return angle_err

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return a pair dict: stacked images plus relative (R, t) from frame 0 to 1."""
        idx0, idx1 = self.indices[idx]
        data0, data1 = self.ho3d_dataset[idx0], self.ho3d_dataset[idx1]

        images = torch.stack([data0['color'], data1['color']], dim=0)

        # relative pose: cam0-object pose composed with inverse of cam1's
        ex0, ex1 = data0['extrinsic'], data1['extrinsic']
        rel_ex = ex1 @ ex0.inverse()
        rel_R = rel_ex[:3, :3]
        rel_t = rel_ex[:3, 3]

        intrinsics = torch.stack([data0['intrinsic'], data1['intrinsic']], dim=0)
        bboxes = torch.stack([data0['bbox'], data1['bbox']])
        objCorners = torch.stack([data0['objCorners'], data1['objCorners']])

        return {
            'images': images,
            'rotation': rel_R,
            'translation': rel_t,
            'intrinsics': intrinsics,
            'bboxes': bboxes,
            'objCorners': objCorners,
            'pair_names': (data0['color_path'], data1['color_path']),
            'objName': data0['objName']
        }
224
+
225
+
226
class HO3DfromJson(Dataset):
    """Fixed HO3D evaluation pairs loaded from a pre-generated JSON manifest.

    The JSON maps string indices to pair records (paths, relative pose,
    intrinsics, bboxes, object corners), so evaluation is reproducible.
    """

    def __init__(self, data_root, json_path):
        self.data_root = Path(data_root)
        with open(json_path, 'r') as f:
            self.scene_info = json.load(f)

        # YCB objects present in the evaluation split
        self.obj_names = [
            '003_cracker_box',
            '006_mustard_bottle',
            '011_banana',
            '025_mug',
            '037_scissors'
        ]
        # model point clouds, used downstream for ADD-style metrics
        self.object_points = {obj: np.loadtxt(self.data_root / 'models' / obj / 'points.xyz') for obj in self.obj_names}

    def _load_color(self, path):
        # BGR -> RGB uint8 image
        color = cv2.imread(path)
        color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
        return color

    def _load_mask(self, path):
        # mask lives next to the RGB frame under 'seg/' with a .png extension
        mask_path = str(path).replace('rgb', 'seg').replace('.jpg', '.png')
        mask = cv2.imread(str(mask_path))
        mask = cv2.resize(mask, (640, 480))
        mask = mask[..., 1] == 255  # object pixels marked in the green channel
        return mask

    def _load_depth(self, path):
        # two PNG channels encode a 16-bit depth value; constant converts to metres
        depth_scale = 0.00012498664727900177

        depth_path = str(path).replace('rgb', 'depth').replace('.jpg', '.png')
        depth_img = cv2.imread(depth_path)

        dpt = depth_img[:, :, 2] + depth_img[:, :, 1] * 256
        dpt = dpt * depth_scale

        return dpt

    def __len__(self):
        return len(self.scene_info)

    def __getitem__(self, idx):
        """Assemble one evaluation pair from its JSON record."""
        info = self.scene_info[str(idx)]
        pair_names = info['pair_names']

        image0 = self._load_color(str(self.data_root / pair_names[0]))
        image0 = (torch.tensor(image0).float() / 255.0).permute(2, 0, 1)  # CHW in [0,1]
        image1 = self._load_color(str(self.data_root / pair_names[1]))
        image1 = (torch.tensor(image1).float() / 255.0).permute(2, 0, 1)
        images = torch.stack([image0, image1], dim=0)

        mask0 = self._load_mask(str(self.data_root / pair_names[0]))
        mask0 = torch.from_numpy(mask0)
        mask1 = self._load_mask(str(self.data_root / pair_names[1]))
        mask1 = torch.from_numpy(mask1)
        masks = torch.stack([mask0, mask1], dim=0)

        depth0 = self._load_depth(str(self.data_root / pair_names[0]))
        depth0 = torch.from_numpy(depth0)
        depth1 = self._load_depth(str(self.data_root / pair_names[1]))
        depth1 = torch.from_numpy(depth1)
        depths = torch.stack([depth0, depth1], dim=0)

        rotation = torch.tensor(info['rotation']).reshape(3, 3)
        translation = torch.tensor(info['translation'])
        intrinsics = torch.tensor(info['intrinsics']).reshape(2, 3, 3)
        bboxes = torch.tensor(info['bboxes'])
        objCorners = torch.tensor(info['objCorners'])

        return {
            'images': images,
            'masks': masks,
            'depths': depths,
            'rotation': rotation,
            'translation': translation,
            'intrinsics': intrinsics,
            'bboxes': bboxes,
            'objCorners': objCorners,
            'objName': info['objName'][0],
            'point_cloud': self.object_points[info['objName'][0]]
        }
307
+
308
+
309
def build_ho3d(mode, config):
    """Build the HO3D dataset for the requested split.

    Train: concatenation of per-sequence HO3DPair datasets over every sequence
    not held out for validation. Val/test: the fixed JSON-defined pair set.

    :param mode: 'train', 'val' or 'test'
    :param config: full config node; only config.DATASET is used
    :return: a torch Dataset
    """
    config = config.DATASET

    data_root = config.DATA_ROOT
    seq_id_list = [x.stem for x in (Path(data_root) / 'train').iterdir()]
    # sequences held out from training (used for val/test pair generation)
    val_id_list = ['BB14', 'SMu1', 'MC1', 'GSF14', 'SM2', 'SM3', 'SM4', 'SM5', 'MC2', 'MC4', 'MC5', 'MC6']
    for held_out_id in val_id_list:
        seq_id_list.remove(held_out_id)

    if mode == 'train':
        per_sequence = [
            HO3DPair(data_root, mode, sequence_id, config.MAX_ANGLE_ERROR)
            for sequence_id in tqdm(seq_id_list, desc=f'Loading HO3D {mode} dataset')
        ]
        return ConcatDataset(per_sequence)

    elif mode == 'test' or mode == 'val':
        return HO3DfromJson(config.DATA_ROOT, config.JSON_PATH)
331
+
datasets/linemod.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ import math
4
+ import random
5
+ from tqdm import tqdm, trange
6
+ import plyfile
7
+
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset, ConcatDataset
12
+ from torch.nn import functional as F
13
+
14
+ from utils import Augmentor
15
+
16
+
17
# Mapping from the zero-padded BOP object-id string (used as directory /
# file-name component, e.g. 'obj_000001.ply') to the LINEMOD object name.
LINEMOD_ID_TO_NAME = {
    '000001': 'ape',
    '000002': 'benchvise',
    '000003': 'bowl',
    '000004': 'camera',
    '000005': 'can',
    '000006': 'cat',
    '000007': 'mug',
    '000008': 'driller',
    '000009': 'duck',
    '000010': 'eggbox',
    '000011': 'glue',
    '000012': 'holepuncher',
    '000013': 'iron',
    '000014': 'lamp',
    '000015': 'phone',
}
34
+
35
+
36
def inverse_transform(trans):
    """Invert a rigid-body transform.

    Args:
        trans: array whose top-left 3x3 block is a rotation matrix and whose
            first three rows of the last column hold the translation.

    Returns:
        A new (4, 4) float32 array: the inverse transform
        [R^T, -R^T t; 0 0 0 1].
    """
    rot_inv = trans[:3, :3].T
    t_inv = -rot_inv @ trans[:3, 3]
    inverse = np.zeros((4, 4), dtype=np.float32)
    inverse[3, 3] = 1
    inverse[:3, :3] = rot_inv
    inverse[:3, 3] = t_inv
    return inverse
46
+
47
+
48
class BOPDataset(Dataset):
    """Per-frame dataset over a single BOP-format scene for one object.

    Loads RGB, visible-object mask, depth, intrinsics and object-to-camera
    extrinsics for every frame in which ``object_id`` appears, dropping
    frames whose visible fraction is below ``min_visible_fract``.
    """

    def __init__(self,
                 dataset_path,
                 scene_path,
                 object_id,
                 min_visible_fract,
                 mode,
                 rgb_postfix='.png',
                 object_scale=None
                 ):
        """
        Args:
            dataset_path (Path): dataset root; its directory name
                ('lm'/'lmo'/'tless') selects the model directory and base
                object scale.
            scene_path (Path): scene directory with rgb/, depth/, mask_visib/
                and the scene_*.json metadata files.
            object_id (int): BOP object id to extract from this scene.
            min_visible_fract (float): visibility threshold for keeping frames.
            mode (str): 'train' enables photometric augmentation.
            rgb_postfix (str): file extension of the RGB frames.
            object_scale (float, optional): depth scale factor; if None it is
                derived from the object diameter in models_info.json.
        """
        super().__init__()
        self.dataset_path = dataset_path
        self.scene_path = scene_path
        self.object_id = object_id

        if dataset_path.name == 'lm' or dataset_path.name == 'lmo':
            base_obj_scale = 1.0
            self.models_path = self.dataset_path / 'models'
        elif dataset_path.name == 'tless':
            base_obj_scale = 0.60
            self.models_path = self.dataset_path / 'models_reconst'
        else:
            raise ValueError(f'Unknown dataset type {dataset_path.name}')

        self.model_path = self.models_path / f'obj_{self.object_id:06d}.ply'
        self.pointcloud_path = self.dataset_path / 'models_eval' / f'obj_{self.object_id:06d}.ply'

        models_info_path = self.dataset_path / 'models_eval' / 'models_info.json'
        with open(models_info_path, 'r') as f:
            self.model_info = json.load(f)[str(object_id)]

        # self.center_object = center_object
        # Default scale normalises depth by the object diameter unless the
        # caller pins an explicit scale.
        if object_scale is None:
            self.object_scale = base_obj_scale / self.model_info['diameter']
        else:
            self.object_scale = object_scale

        # self.image_scale = 1.0
        # Axis-aligned model bounds, one (min, max) row per axis, plus midpoint.
        self.bounds = torch.tensor([
            (self.model_info['min_x'], self.model_info['min_x'] + self.model_info['size_x']),
            (self.model_info['min_y'], self.model_info['min_y'] + self.model_info['size_y']),
            (self.model_info['min_z'], self.model_info['min_z'] + self.model_info['size_z']),
        ])
        self.centroid = self.bounds.mean(dim=1)

        self.depth_dir = self.scene_path / 'depth'
        self.mask_dir = self.scene_path / 'mask_visib'
        self.color_dir = self.scene_path / 'rgb'
        self.intrinsics_path = self.scene_path / 'scene_camera.json'
        self.extrinsics_path = self.scene_path / 'scene_gt.json'
        self.gt_info_path = self.scene_path / 'scene_gt_info.json'

        self.intrinsics, self.depth_scales = self.load_intrinsics(self.intrinsics_path)
        self.extrinsics, self.scene_object_inds = self.load_extrinsics(self.extrinsics_path)
        self.extrinsics = torch.stack(self.extrinsics, dim=0)
        self.gt_info = self.load_gt_info(self.gt_info_path)

        # # Compute quaternions for sampling.
        # rotation, translation = three.decompose(self.extrinsics)
        # self.quaternions = three.quaternion.mat_to_quat(rotation[:, :3, :3])

        self.depth_paths = sorted([self.depth_dir / f'{frame_ind:06d}.png'
                                   for frame_ind in self.scene_object_inds.keys()])
        # Mask files are named <frame>_<per-frame annotation index>.png; the
        # dict preserves the sorted frame order produced by load_extrinsics.
        self.mask_paths = [
            self.mask_dir / f'{frame_ind:06d}_{obj_ind:06d}.png'
            for frame_ind, obj_ind in self.scene_object_inds.items()
        ]
        self.color_paths = sorted([self.color_dir / f'{frame_ind:06d}{rgb_postfix}'
                                   for frame_ind in self.scene_object_inds.keys()])

        # Drop frames where the object is barely visible.
        visib_filter = np.array(self.gt_info['visib_fract']) >= min_visible_fract
        self.color_paths = np.array(self.color_paths)[visib_filter]
        self.mask_paths = np.array(self.mask_paths)[visib_filter]
        self.depth_paths = np.array(self.depth_paths)[visib_filter]
        self.depth_scales = np.array(self.depth_scales)[visib_filter]
        for k in self.gt_info:
            self.gt_info[k] = np.array(self.gt_info[k])[visib_filter]

        self.extrinsics = np.array(self.extrinsics)[visib_filter]
        self.intrinsics = np.array(self.intrinsics)[visib_filter]

        self.augment = Augmentor(mode=='train')

        assert len(self.depth_paths) == len(self.mask_paths)
        assert len(self.depth_paths) == len(self.color_paths)

    # def load_pointcloud(self):
    #     obj = meshutils.Object3D(self.pointcloud_path)
    #     points = torch.tensor(obj.vertices, dtype=torch.float32)
    #     points = points * self.object_scale
    #     return points

    def load_gt_info(self, path):
        """Collect the per-frame GT annotations (bboxes, visibility counts)
        of the tracked object, keyed by annotation field name.

        NOTE(review): field names are taken from frame '0', annotation 0 —
        assumes frame 0 exists and all frames share the same fields.
        """
        with open(path, 'r') as f:
            gt_info_json = json.load(f)
        keys = sorted([int(k) for k in gt_info_json.keys()])
        gt_info = {k: [] for k in gt_info_json['0'][0]}
        for key in keys:
            value = gt_info_json[str(key)]
            obj_info = value[self.scene_object_inds[key]]
            for info_k in obj_info:
                gt_info[info_k].append(obj_info[info_k])

        return gt_info

    def load_intrinsics(self, path):
        """Parse scene_camera.json into per-frame 3x3 intrinsics and the
        per-frame BOP depth scale, in sorted frame order."""
        intrinsics = []
        depth_scales = []
        with open(path, 'r') as f:
            intrinsics_json = json.load(f)
        keys = sorted([int(k) for k in intrinsics_json.keys()])
        for key in keys:
            value = intrinsics_json[str(key)]
            intrinsic_3x3 = value['cam_K']
            intrinsics.append(torch.tensor(intrinsic_3x3).reshape(3, 3))
            depth_scales.append(value['depth_scale'])

        return intrinsics, depth_scales

    def load_extrinsics(self, path):
        """Parse scene_gt.json for the tracked object.

        Returns:
            extrinsics: list of 4x4 object-to-camera transforms
                (translation converted from millimetres to metres).
            scene_object_inds: {frame_ind: per-frame annotation index},
                inserted in sorted frame order.
        """
        extrinsics = []
        scene_object_inds = {}
        with open(path, 'r') as f:
            extrinsics_json = json.load(f)
        frame_inds = sorted([int(k) for k in extrinsics_json.keys()])
        for frame_ind in frame_inds:
            for obj_ind, cam_d in enumerate(extrinsics_json[str(frame_ind)]):
                if cam_d['obj_id'] == self.object_id:
                    rotation = torch.tensor(
                        cam_d['cam_R_m2c'], dtype=torch.float32).reshape(3, 3)
                    translation = torch.tensor(cam_d['cam_t_m2c'], dtype=torch.float32) / 1000.
                    # quaternion = three.quaternion.mat_to_quat(rotation)
                    # extrinsics.append(three.to_extrinsic_matrix(translation, quaternion))
                    extrinsic = torch.eye(4)
                    extrinsic[:3, :3] = rotation
                    extrinsic[:3, 3] = translation
                    extrinsics.append(extrinsic)
                    scene_object_inds[frame_ind] = obj_ind

        return extrinsics, scene_object_inds

    def __len__(self):
        return len(self.color_paths)

    def get_ids(self):
        """Frame ids (file stems) of the kept color frames."""
        return [p.stem for p in self.color_paths]

    def _load_color(self, path):
        """Read an RGB image as a numpy array."""
        image = Image.open(path)
        image = np.array(image)
        return image

    def _load_mask(self, path):
        """Read a mask as a boolean array; collapse multi-channel masks to
        their first channel."""
        image = Image.open(path)
        image = np.array(image, dtype=bool)
        if len(image.shape) > 2:
            image = image[:, :, 0]
        return image

    def _load_depth(self, path):
        """Read a depth image as float32 (raw stored units)."""
        image = Image.open(path)
        image = np.array(image, dtype=np.float32)
        return image

    def __getitem__(self, idx):
        """Return one frame as a dict of image, mask, depth, pose,
        intrinsics and GT visibility info."""
        color = self._load_color(self.color_paths[idx])
        # color = self.augment(color)
        color = (torch.tensor(color).float() / 255.0).permute(2, 0, 1)
        mask = self._load_mask(self.mask_paths[idx])
        mask = torch.tensor(mask).bool()
        depth = self._load_depth(self.depth_paths[idx])
        # Raw depth -> object-normalised units via the per-frame depth scale.
        depth = torch.tensor(depth) * self.object_scale * self.depth_scales[idx]

        # intrinsic = self.normalize_intrinsic(self.intrinsics[idx])
        # extrinsic = self.normalize_extrinsic(self.extrinsics[idx])

        intrinsic = torch.from_numpy(self.intrinsics[idx])
        extrinsic = torch.from_numpy(self.extrinsics[idx])

        # Convert (x, y, w, h) boxes to (x0, y0, x1, y1).
        bbox_obj = self.gt_info['bbox_obj'][idx]
        bbox_visib = self.gt_info['bbox_visib'][idx]

        bbox_obj = torch.tensor([bbox_obj[0], bbox_obj[1], bbox_obj[0]+bbox_obj[2], bbox_obj[1]+bbox_obj[3]])
        bbox_visib = torch.tensor([bbox_visib[0], bbox_visib[1], bbox_visib[0]+bbox_visib[2], bbox_visib[1]+bbox_visib[3]])

        visib_fract = self.gt_info['visib_fract'][idx]
        px_count_visib = self.gt_info['px_count_visib'][idx]

        return {
            'color': color,
            'mask': mask,
            'depth': depth,
            'extrinsic': extrinsic,
            'intrinsic': intrinsic,
            'bbox_obj': bbox_obj,
            'bbox_visib': bbox_visib,
            'visib_fract': visib_fract,
            'px_count_visib': px_count_visib,
            # Path with the first two components dropped (relative to dataset root).
            'color_path': str(self.color_paths[idx]).split('/', 2)[-1],
            'object_scale': self.object_scale,
            'depth_scale': self.depth_scales[idx]
        }
251
+
252
+
253
class Linemod(Dataset):
    """Image-pair dataset over one LINEMOD (BOP) scene for relative pose.

    Builds all ordered frame pairs (i < j) whose ground-truth relative
    rotation is below ``max_angle_error`` degrees and serves them as
    two-view samples.
    """

    def __init__(self, data_root, mode, object_id, scene_id, min_visible_fract, max_angle_error):
        """
        Args:
            data_root (str): BOP dataset root.
            mode (str): 'train' uses the PBR renders; 'val'/'test' use the
                real test scenes, where the scene id equals the object id.
            object_id (int): BOP object id.
            scene_id (int): scene index (only meaningful for 'train').
            min_visible_fract (float): per-frame visibility filter.
            max_angle_error (float): maximum relative rotation (degrees)
                allowed between the two frames of a pair.
        """
        if mode == 'train':
            type_path = 'train_pbr'
            rgb_postfix = '.jpg'
        elif mode == 'val' or mode == 'test':
            type_path = 'test'
            rgb_postfix = '.png'
            # In the real test split each object lives in its own scene.
            scene_id = object_id
        else:
            raise NotImplementedError(f'mode {mode}')

        data_root = Path(data_root)
        scene_path = data_root / type_path / f'{scene_id:06d}'
        self.bop_dataset = BOPDataset(data_root, scene_path, object_id=object_id, min_visible_fract=min_visible_fract, mode=mode, rgb_postfix=rgb_postfix)

        # Keep only ordered pairs (i < j) with a small enough relative rotation.
        angle_err = self.get_angle_error(torch.from_numpy(self.bop_dataset.extrinsics[:, :3, :3]))
        index0, index1 = torch.where(angle_err < max_angle_error)
        keep = torch.where(index0 < index1)  # renamed from `filter` (shadowed the builtin)
        self.index0, self.index1 = index0[keep], index1[keep]

        # (N, 2) pair-index table; torch.stack avoids the slow
        # torch.tensor(list(zip(...))) round-trip of the original.
        self.indices = torch.stack([self.index0, self.index1], dim=1)
        if mode == 'val':
            # Subsample validation pairs to a fixed budget.
            self.indices = self.indices[torch.randperm(self.indices.size(0))[:1500]]

    def get_angle_error(self, R):
        """Pairwise geodesic rotation distance in degrees.

        Args:
            R: (B, 3, 3) batch of rotation matrices.

        Returns:
            (B, B) tensor where entry (a, b) is angle(R_a^T R_b) in degrees.
        """
        # einsum('aij,bik->abjk') computes R_a^T R_b for every pair (a, b).
        residual = torch.einsum('aij,bik->abjk', R, R)
        trace = torch.diagonal(residual, dim1=-2, dim2=-1).sum(-1)
        cosine = (trace - 1) / 2
        cosine = torch.clip(cosine, -1, 1)  # guard acos against fp drift
        return torch.acos(cosine).rad2deg()

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return one two-view sample with relative pose cam0 -> cam1."""
        idx0, idx1 = self.indices[idx]
        data0, data1 = self.bop_dataset[idx0], self.bop_dataset[idx1]

        images = torch.stack([data0['color'], data1['color']], dim=0)

        # Relative pose mapping camera-0 coordinates to camera-1.
        ex0, ex1 = data0['extrinsic'], data1['extrinsic']
        rel_ex = ex1 @ ex0.inverse()
        rel_R = rel_ex[:3, :3]
        rel_t = rel_ex[:3, 3]

        intrinsics = torch.stack([data0['intrinsic'], data1['intrinsic']], dim=0)
        bboxes = torch.stack([data0['bbox_visib'], data1['bbox_visib']])

        return {
            'images': images,
            'rotation': rel_R,
            'translation': rel_t,
            'intrinsics': intrinsics,
            'bboxes': bboxes,
            'pair_names': (data0['color_path'], data1['color_path']),
            'object_scale': data0['object_scale'],
            'depth_scale': (data0['depth_scale'], data1['depth_scale']),
        }
318
+
319
+
320
class LinemodfromJson(Dataset):
    """Pre-sampled LINEMOD test pairs loaded from a JSON index.

    Each JSON entry holds a pair of image paths, the relative pose,
    per-view intrinsics, visible-object bounding boxes and depth scales.
    Evaluation point clouds and diameters come from models_eval/.
    """

    def __init__(self, data_root, json_path):
        """
        Args:
            data_root (str|Path): BOP dataset root.
            json_path (str|Path): index of the precomputed test pairs.
        """
        self.data_root = Path(data_root)
        with open(json_path, 'r') as f:
            self.scene_info = json.load(f)

        # self.image_scale = 1.0

        models_info_path = self.data_root / 'models_eval' / 'models_info.json'
        with open(models_info_path, 'r') as f:
            model_info = json.load(f)

        # Object diameter (millimetres) and evaluation point cloud per object id.
        self.object_diameters = {obj: model_info[obj]['diameter'] for obj in model_info}
        self.object_points = {obj: self._load_point_cloud(obj) for obj in self.object_diameters}

    def _load_point_cloud(self, obj_id):
        """Load the (N, 3) vertex array of the evaluation model (millimetres)."""
        with open(self.data_root / 'models_eval' / f'obj_{int(obj_id):06d}.ply', "rb") as f:
            plydata = plyfile.PlyData.read(f)
        xyz = np.stack([np.array(plydata["vertex"][c]).astype(float) for c in ("x", "y", "z")], axis=1)
        return xyz

    def _load_color(self, path):
        """Read an RGB image as a numpy array."""
        image = Image.open(path)
        # new_shape = (int(image.width * self.image_scale), int(image.height * self.image_scale))
        # image = image.resize(new_shape)
        image = np.array(image)
        return image

    def _load_mask(self, path):
        """Read the visible-object mask matching an RGB path.

        NOTE(review): rewrites 'rgb' -> 'mask_visib' in the path and assumes
        the target object is annotation index 0 (suffix '_000000') — confirm
        for scenes with multiple annotated objects.
        """
        path = path.replace('rgb', 'mask_visib').replace('.png', '_000000.png')
        image = Image.open(path)
        # new_shape = (int(image.width * self.image_scale), int(image.height * self.image_scale))
        # image = image.resize(new_shape)
        image = np.array(image, dtype=bool)
        if len(image.shape) > 2:
            image = image[:, :, 0]
        return image

    def _load_depth(self, path):
        """Read the depth image matching an RGB path ('rgb' -> 'depth')."""
        path = path.replace('rgb', 'depth')
        image = Image.open(path)
        # new_shape = (int(image.width * self.image_scale), int(image.height * self.image_scale))
        # image = image.resize(new_shape)
        image = np.array(image, dtype=np.float32)
        return image

    def __len__(self):
        return len(self.scene_info)

    def __getitem__(self, idx):
        """Return one test pair; depths and point cloud converted to metres."""
        info = self.scene_info[str(idx)]
        pair_names = info['pair_names']

        image0 = self._load_color(str(self.data_root / pair_names[0]))
        image0 = (torch.tensor(image0).float() / 255.0).permute(2, 0, 1)
        image1 = self._load_color(str(self.data_root / pair_names[1]))
        image1 = (torch.tensor(image1).float() / 255.0).permute(2, 0, 1)
        images = torch.stack([image0, image1], dim=0)

        mask0 = self._load_mask(str(self.data_root / pair_names[0]))
        mask0 = torch.tensor(mask0).bool()
        mask1 = self._load_mask(str(self.data_root / pair_names[1]))
        mask1 = torch.tensor(mask1).bool()
        masks = torch.stack([mask0, mask1], dim=0)

        # Apply per-view depth scales, then millimetres -> metres.
        depth0 = self._load_depth(str(self.data_root / pair_names[0]))
        depth0 = torch.tensor(depth0) * info['depth_scale'][0]
        depth1 = self._load_depth(str(self.data_root / pair_names[1]))
        depth1 = torch.tensor(depth1) * info['depth_scale'][1]
        depths = torch.stack([depth0, depth1], dim=0) / 1000.

        rotation = torch.tensor(info['rotation']).reshape(3, 3)
        translation = torch.tensor(info['translation'])
        intrinsics = torch.tensor(info['intrinsics']).reshape(2, 3, 3)
        bboxes = torch.tensor(info['bboxes'])

        # Object id is the scene directory name, e.g. 'test/000001/rgb/...'.
        obj_id = str(int(pair_names[0].split('/')[1]))
        diameter = self.object_diameters[obj_id]
        point_cloud = torch.from_numpy(self.object_points[obj_id]) / 1000.

        return {
            'images': images,
            'masks': masks,
            'depths': depths,
            'rotation': rotation,
            'translation': translation,
            'intrinsics': intrinsics,
            'bboxes': bboxes,
            'diameter': diameter,
            'point_cloud': point_cloud,
        }
411
+
412
+
413
def build_linemod(mode, config):
    """Build the LINEMOD dataset for the given split.

    Args:
        mode: 'train' (PBR renders), or 'val'/'test' (pre-sampled JSON pairs).
        config: experiment config; only the ``DATASET`` sub-config is used.

    Returns:
        A ``ConcatDataset`` of per-(object, scene) ``Linemod`` datasets for
        training, or a ``LinemodfromJson`` dataset for val/test.

    Raises:
        ValueError: if ``mode`` is not a supported split (the original
            silently returned ``None``).
    """
    config = config.DATASET

    if mode == 'train':
        datasets = []
        # 15 objects x 50 PBR scenes; a scene that does not contain the
        # object raises KeyError during loading and is simply skipped.
        with tqdm(total=len(LINEMOD_ID_TO_NAME) * 50) as t:
            t.set_description(f'Loading Linemod {mode} datasets')
            for i, _ in enumerate(LINEMOD_ID_TO_NAME):
                for j in range(50):
                    t.update(1)
                    try:
                        datasets.append(Linemod(config.DATA_ROOT, mode, i+1, j, config.MIN_VISIBLE_FRACT, config.MAX_ANGLE_ERROR))
                    except KeyError:
                        continue
        return ConcatDataset(datasets)

    if mode in ('test', 'val'):
        return LinemodfromJson(config.DATA_ROOT, config.JSON_PATH)

    raise ValueError(f'Unknown dataset mode {mode!r}')
datasets/mapfree.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import torch
4
+ import torch.utils.data as data
5
+ import cv2
6
+ import numpy as np
7
+ from transforms3d.quaternions import qinverse, qmult, rotate_vector, quat2mat
8
+
9
+ from utils.transform import correct_intrinsic_scale
10
+ from utils import Augmentor
11
+
12
+
13
class MapFreeScene(data.Dataset):
    """Two-view dataset over a single Map-free Relocalisation scene.

    Serves (reference, query) image pairs with precomputed depth maps,
    per-image intrinsics and the ground-truth relative pose.
    """

    def __init__(self, scene_root, resize, sample_factor=1, overlap_limits=None, estimated_depth=None, mode='train'):
        """
        Args:
            scene_root (str|Path): scene directory (seq*/ frames, poses.txt,
                intrinsics.txt and optionally overlaps.npz).
            resize ((int, int)): (width, height) the intrinsics are rescaled to.
            sample_factor (int): keep every Nth pair (test/val only).
            overlap_limits ((float, float)): (min, max) co-visibility bounds
                used to filter training pairs.
            estimated_depth: NOTE(review): stored but never read here; depth
                is always loaded from the '.da.npy' files in __getitem__.
            mode (str): 'train' enables photometric augmentation.
        """
        super().__init__()

        self.scene_root = Path(scene_root)
        self.resize = resize
        self.sample_factor = sample_factor
        self.estimated_depth = estimated_depth

        # load absolute poses
        self.poses = self.read_poses(self.scene_root)

        # read intrinsics
        self.K = self.read_intrinsics(self.scene_root, resize)

        # load pairs
        self.pairs = self.load_pairs(self.scene_root, overlap_limits, self.sample_factor)

        self.augment = Augmentor(mode=='train')

    @staticmethod
    def read_intrinsics(scene_root: Path, resize=None):
        """Parse intrinsics.txt into {img_name: 3x3 K}, rescaled to `resize`."""
        Ks = {}
        with (scene_root / 'intrinsics.txt').open('r') as f:
            for line in f.readlines():
                if '#' in line:
                    continue

                line = line.strip().split(' ')
                img_name = line[0]
                fx, fy, cx, cy, W, H = map(float, line[1:])

                K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
                if resize is not None:
                    K = correct_intrinsic_scale(K, resize[0] / W, resize[1] / H)
                Ks[img_name] = K
        return Ks

    @staticmethod
    def read_poses(scene_root: Path):
        """
        Returns a dictionary that maps: img_path -> (q, t) where
        np.array q = (qw, qx qy qz) quaternion encoding rotation matrix;
        np.array t = (tx ty tz) translation vector;
        (q, t) encodes absolute pose (world-to-camera), i.e. X_c = R(q) X_W + t
        """
        poses = {}
        with (scene_root / 'poses.txt').open('r') as f:
            for line in f.readlines():
                if '#' in line:
                    continue

                line = line.strip().split(' ')
                img_name = line[0]
                qt = np.array(list(map(float, line[1:])))
                poses[img_name] = (qt[:4], qt[4:])
        return poses

    def load_pairs(self, scene_root: Path, overlap_limits: tuple = None, sample_factor: int = 1):
        """
        For training scenes, filter pairs of frames based on overlap (pre-computed in overlaps.npz)
        For test/val scenes, pairs are formed between keyframe and every other sample_factor query frames.
        If sample_factor == 1, all query frames are used. Note: sample_factor applicable only to test/val
        Returns:
            pairs: nd.array [Npairs, 4], where each column represents seqA, imA, seqB, imB, respectively
        """
        overlaps_path = scene_root / 'overlaps.npz'

        if overlaps_path.exists():
            f = np.load(overlaps_path, allow_pickle=True)
            idxs, overlaps = f['idxs'], f['overlaps']
            if overlap_limits is not None:
                min_overlap, max_overlap = overlap_limits
                mask = (overlaps > min_overlap) * (overlaps < max_overlap)
                idxs = idxs[mask]
            return idxs.copy()
        else:
            # No overlap file: pair the seq0 keyframe (seq 0, frame 0) with
            # every seq1 frame, sub-sampled by sample_factor.
            idxs = np.zeros((len(self.poses) - 1, 4), dtype=np.uint16)
            idxs[:, 2] = 1
            idxs[:, 3] = np.array([int(fn[-9:-4])
                                   for fn in self.poses.keys() if 'seq0' not in fn], dtype=np.uint16)
            return idxs[::sample_factor]

    def get_pair_path(self, pair):
        """Map a (seqA, imgA, seqB, imgB) row to the two relative frame paths."""
        seqA, imgA, seqB, imgB = pair
        return (f'seq{seqA}/frame_{imgA:05}.jpg', f'seq{seqB}/frame_{imgB:05}.jpg')

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        """Return one (reference, query) pair with depths, pose and intrinsics."""
        # image paths (relative to scene_root)
        img_name0, img_name1 = self.get_pair_path(self.pairs[index])
        # NOTE(review): w_new/h_new are unused because the cv2.resize calls
        # below are commented out, while K was rescaled to `resize` at load
        # time — confirm the images on disk already match that resolution.
        w_new, h_new = self.resize

        image0 = cv2.imread(str(self.scene_root / img_name0))
        # image0 = cv2.resize(image0, (w_new, h_new))
        image0 = cv2.cvtColor(image0, cv2.COLOR_BGR2RGB)
        image0 = self.augment(image0)
        image0 = torch.from_numpy(image0).permute(2, 0, 1).float() / 255.

        image1 = cv2.imread(str(self.scene_root / img_name1))
        # image1 = cv2.resize(image1, (w_new, h_new))
        image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
        image1 = self.augment(image1)
        image1 = torch.from_numpy(image1).permute(2, 0, 1).float() / 255.
        images = torch.stack([image0, image1], dim=0)


        # Precomputed depth maps stored next to the frames as '.da.npy'.
        depth0 = np.load(str(self.scene_root / img_name0).replace('.jpg', f'.da.npy'))
        depth0 = torch.from_numpy(depth0).float()

        depth1 = np.load(str(self.scene_root / img_name1).replace('.jpg', f'.da.npy'))
        depth1 = torch.from_numpy(depth1).float()

        depths = torch.stack([depth0, depth1], dim=0)

        # get absolute pose of im0 and im1
        # quaternion and translation vector that transforms World-to-Cam
        q1, t1 = self.poses[img_name0]
        # quaternion and translation vector that transforms World-to-Cam
        q2, t2 = self.poses[img_name1]

        # get 4 x 4 relative pose transformation matrix (from im1 to im2)
        # for test/val set, q1,t1 is the identity pose, so the relative pose matches the absolute pose
        q12 = qmult(q2, qinverse(q1))
        t12 = t2 - rotate_vector(t1, q12)
        T = np.eye(4, dtype=np.float32)
        T[:3, :3] = quat2mat(q12)
        T[:3, -1] = t12
        T = torch.from_numpy(T)

        K_0 = torch.from_numpy(self.K[img_name0].copy()).reshape(3, 3)
        K_1 = torch.from_numpy(self.K[img_name1].copy()).reshape(3, 3)
        intrinsics = torch.stack([K_0, K_1], dim=0).float()

        data = {
            'images': images,
            'depths': depths,
            'rotation': T[:3, :3],
            'translation': T[:3, 3],
            'intrinsics': intrinsics,
            'scene_id': self.scene_root.stem,
            'scene_root': str(self.scene_root),
            'pair_id': index*self.sample_factor,
            'pair_names': (img_name0, img_name1),
        }

        return data
162
+
163
+
164
def build_concat_mapfree(mode, config):
    """Concatenate every MapFreeScene of the requested split.

    Args:
        mode: 'train', 'val' or 'test'.
        config: experiment config providing ``DATASET.DATA_ROOT`` and
            ``DATASET.ESTIMATED_DEPTH``.

    Returns:
        ``torch.utils.data.ConcatDataset`` over every scene directory
        found under <DATA_ROOT>/<mode>.
    """
    assert mode in ['train', 'val', 'test'], 'Invalid dataset mode'

    data_root = Path(config.DATASET.DATA_ROOT) / mode
    # Fixed duplicated `scenes = scenes = ...` assignment; sort for a
    # deterministic dataset order (iterdir order is filesystem-dependent).
    scenes = sorted(s.name for s in data_root.iterdir() if s.is_dir())
    sample_factor = {'train': 1, 'val': 5, 'test': 1}[mode]
    estimated_depth = config.DATASET.ESTIMATED_DEPTH

    resize = (540, 720)
    overlap_limits = (0.2, 0.7)

    # Init dataset objects for each scene
    datasets = [MapFreeScene(data_root / scene, resize, sample_factor, overlap_limits, estimated_depth, mode) for scene in scenes]

    return data.ConcatDataset(datasets)
datasets/matterport.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from pathlib import Path
4
+ import json
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+
8
+ from utils import rotation_matrix_from_quaternion, Augmentor
9
+
10
+
11
class Matterport3D(Dataset):
    """Matterport3D relative-pose pair dataset (mp3d_planercnn split JSONs).

    Each item is a pair of RGB images plus the inverse relative camera pose
    (rotation matrix, translation) and shared pinhole intrinsics.
    """

    def __init__(self, data_root, mode='train'):
        """
        Args:
            data_root (str|Path): dataset root containing the images and the
                mp3d_planercnn_json/cached_set_<mode>.json split file.
            mode (str): 'train' enables photometric augmentation.
        """
        data_root = Path(data_root)
        json_path = data_root / 'mp3d_planercnn_json' / f'cached_set_{mode}.json'

        scene_info = {'images': [], 'rotation': [], 'translation': [], 'intrinsics': []}

        with open(json_path) as f:
            split = json.load(f)

        for _, data in enumerate(split['data']):
            images = []
            for imgnum in ['0', '1']:
                # File names in the JSON are absolute; keep only the tail
                # after the sixth component and re-root it at data_root.
                img_name = data_root / '/'.join(data[imgnum]['file_name'].split('/')[6:])
                images.append(img_name)

            rel_rotation = data['rel_pose']['rotation']
            rel_translation = data['rel_pose']['position']
            # Fixed pinhole intrinsics shared by both views (the principal
            # point suggests 640x480 frames).
            intrinsic = [
                [517.97, 0, 320],
                [0, 517.97, 240],
                [0, 0, 1]
            ]
            intrinsics = [intrinsic, intrinsic]

            scene_info['images'].append(images)
            scene_info['rotation'].append(rel_rotation)
            scene_info['translation'].append(rel_translation)
            scene_info['intrinsics'].append(intrinsics)

        scene_info['rotation'] = torch.tensor(scene_info['rotation'])
        scene_info['translation'] = torch.tensor(scene_info['translation'])
        scene_info['intrinsics'] = torch.tensor(scene_info['intrinsics'])

        self.scene_info = scene_info
        self.augment = Augmentor(mode=='train')

        self.is_training = mode == 'train'

    def __len__(self):
        return len(self.scene_info['images'])

    def __getitem__(self, idx):
        """Return one image pair with the inverted relative pose."""
        img_names = self.scene_info['images'][idx]
        rotation = self.scene_info['rotation'][idx]
        translation = self.scene_info['translation'][idx]
        intrinsics = self.scene_info['intrinsics'][idx]

        images = []
        for i in range(2):
            image = cv2.imread(str(img_names[i]))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = self.augment(image)
            image = torch.from_numpy(image).permute(2, 0, 1)
            images.append(image)
        images = torch.stack(images)
        images = images.float() / 255.

        # Canonicalise the quaternion sign, then normalise OUT-OF-PLACE:
        # `rotation` is a view into scene_info['rotation'], and the original
        # in-place `/=` wrote the normalisation back into the cached tensor.
        rotation = -rotation if rotation[0] < 0 else rotation
        rotation = rotation / rotation.norm(2)
        rotation = rotation_matrix_from_quaternion(rotation[None,])[0]

        # Invert the pose: R^T, -R^T t.
        rotation = rotation.mT
        translation = -rotation @ translation.unsqueeze(-1)
        translation = translation[:, 0]

        return {
            'images': images,
            'rotation': rotation,
            'translation': translation,
            'intrinsics': intrinsics,
        }
83
+
84
+
85
def build_matterport(mode, config):
    """Instantiate the Matterport3D dataset for the requested split."""
    data_root = config.DATASET.DATA_ROOT
    return Matterport3D(data_root, mode)
datasets/megadepth.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import numpy as np
3
+ import cv2
4
+ from tqdm import tqdm
5
+ import torch
6
+ from torch.utils.data import Dataset, ConcatDataset
7
+
8
+ from utils import Augmentor
9
+
10
+
11
class MegaDepthDataset(Dataset):
    """Two-view dataset over one MegaDepth scene described by an .npz index."""

    def __init__(self,
                 root_dir,
                 npz_path,
                 mode='train',
                 min_overlap_score=0.4,
                 ):
        """
        Manage one scene (npz_path) of the MegaDepth dataset.

        Args:
            root_dir (str): megadepth root directory that has `phoenix`.
            npz_path (str): {scene_id}.npz path with image pair information.
            mode (str): options are ['train', 'val', 'test'].
            min_overlap_score (float): minimum co-visibility of a kept pair,
                in [0, 1]. Forced to 0 when testing.
        """
        super().__init__()
        self.root_dir = root_dir
        self.mode = mode
        # NOTE(review): split('.') also truncates at dots inside the scene
        # name (e.g. '0015_0.1_0.3.npz'); kept since scene_id is unused here.
        self.scene_id = npz_path.split('.')[0]

        # prepare scene_info and pair_info
        if mode == 'test':
            min_overlap_score = 0
        # Materialise into a plain dict: np.load on an .npz returns an
        # NpzFile, which does not support item deletion, so the `del`
        # below would raise TypeError on the raw NpzFile.
        self.scene_info = dict(np.load(npz_path, allow_pickle=True))
        self.pair_infos = self.scene_info['pair_infos'].copy()
        del self.scene_info['pair_infos']
        # pair_info := ((idx0, idx1), overlap_score, central_matches)
        self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]

        self.augment = Augmentor(mode=='train')

    def __len__(self):
        return len(self.pair_infos)

    def __getitem__(self, idx):
        """Return one RGB pair resized to 640x480 with pose and intrinsics."""
        (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]

        img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
        img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])

        w_new, h_new = 640, 480

        image0 = cv2.imread(img_name0)
        # Keep the original->resized scale so intrinsics can be corrected downstream.
        scale0 = torch.tensor([image0.shape[1]/w_new, image0.shape[0]/h_new], dtype=torch.float)
        image0 = cv2.resize(image0, (w_new, h_new))
        image0 = cv2.cvtColor(image0, cv2.COLOR_BGR2RGB)
        # image0 = self.augment(image0)
        image0 = torch.from_numpy(image0).permute(2, 0, 1).float() / 255.

        image1 = cv2.imread(img_name1)
        scale1 = torch.tensor([image1.shape[1]/w_new, image1.shape[0]/h_new], dtype=torch.float)
        image1 = cv2.resize(image1, (w_new, h_new))
        image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
        # image1 = self.augment(image1)
        image1 = torch.from_numpy(image1).permute(2, 0, 1).float() / 255.

        scales = torch.stack([scale0, scale1], dim=0)
        images = torch.stack([image0, image1], dim=0)

        # read intrinsics of original size
        K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
        K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
        intrinsics = torch.stack([K_0, K_1], dim=0)

        # read and compute relative poses (cam0 -> cam1)
        T0 = self.scene_info['poses'][idx0]
        T1 = self.scene_info['poses'][idx1]
        T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4]  # (4, 4)

        data = {
            'images': images,
            'scales': scales,  # (2, 2): [scale_w, scale_h]
            'rotation': T_0to1[:3, :3],
            'translation': T_0to1[:3, 3],
            'intrinsics': intrinsics,
            'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
            'depth_pair_names': (self.scene_info['depth_paths'][idx0], self.scene_info['depth_paths'][idx1]),
        }

        return data
92
+
93
+
94
def build_concat_megadepth(mode, config):
    """Build a ConcatDataset over all MegaDepth scenes listed for `mode`."""
    if mode not in ('train', 'val', 'test'):
        raise NotImplementedError(f'mode {mode}')
    # Select DATASET.TRAIN / DATASET.VAL / DATASET.TEST by name.
    split_cfg = getattr(config.DATASET, mode.upper())

    with open(split_cfg.LIST_PATH, 'r') as f:
        scene_names = [line.split()[0] for line in f.readlines()]

    scene_files = [f'{name}.npz' for name in scene_names]
    datasets = []
    for scene_file in tqdm(scene_files, desc=f'Loading MegaDepth {mode} datasets',):
        scene_path = osp.join(split_cfg.NPZ_ROOT, scene_file)
        datasets.append(MegaDepthDataset(
            split_cfg.DATA_ROOT,
            scene_path,
            mode=mode,
            min_overlap_score=split_cfg.MIN_OVERLAP_SCORE,
        ))

    return ConcatDataset(datasets)
datasets/sampler.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Sampler, ConcatDataset
3
+
4
+
5
class RandomConcatSampler(Sampler):
    """Random sampler for a ConcatDataset.

    At each epoch, `n_samples_per_subset` indices are drawn from every subset of
    the ConcatDataset. If `subset_replacement` is True, sampling within each
    subset is done with replacement; otherwise a random permutation of the
    subset is used (padded with replacement when the subset is smaller than
    `n_samples_per_subset`).

    Sampling without replacement *between* epochs is not supported unless a
    stateful sampler lives along the entire training phase.

    Args:
        data_source: the ConcatDataset to sample from.
        n_samples_per_subset: number of indices drawn from each subset per epoch.
        subset_replacement: sample with replacement inside each subset.
        shuffle (bool): shuffle the sampled indices across all sub-datasets.
        repeat (int): repeatedly use the sampled indices multiple times for training.
            [arXiv:1902.05509, arXiv:1901.09335]
        seed: optional seed for this sampler's private RNG.

    NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples).
    NOTE: This sampler behaves differently from DistributedSampler:
        it assumes the dataset is split across ranks instead of replicated.
    TODO: Add a `set_epoch()` method to fulfil sampling without replacement across epochs.
    ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
    """
    def __init__(self,
                 data_source: ConcatDataset,
                 n_samples_per_subset: int,
                 subset_replacement: bool = True,
                 shuffle: bool = True,
                 repeat: int = 1,
                 seed: int = None):
        if not isinstance(data_source, ConcatDataset):
            raise TypeError("data_source should be torch.utils.data.ConcatDataset")

        self.data_source = data_source
        self.n_subset = len(self.data_source.datasets)
        self.n_samples_per_subset = n_samples_per_subset
        self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
        self.subset_replacement = subset_replacement
        self.repeat = repeat
        self.shuffle = shuffle
        # BUGFIX: the original `torch.manual_seed(seed)` re-seeded the *global*
        # RNG as a side effect and raised TypeError when `seed` was None (the
        # declared default). Use a private generator instead.
        self.generator = torch.Generator()
        if seed is not None:
            self.generator.manual_seed(seed)
        assert self.repeat >= 1

    def __len__(self):
        return self.n_samples

    def __iter__(self):
        indices = []
        # draw `n_samples_per_subset` indices from each sub-dataset
        for d_idx in range(self.n_subset):
            low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1]
            high = self.data_source.cumulative_sizes[d_idx]
            if self.subset_replacement:
                rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
                                            generator=self.generator, dtype=torch.int64)
            else:  # sample without replacement
                len_subset = len(self.data_source.datasets[d_idx])
                rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
                if len_subset >= self.n_samples_per_subset:
                    rand_tensor = rand_tensor[:self.n_samples_per_subset]
                else:  # padding with replacement
                    rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
                                                            generator=self.generator, dtype=torch.int64)
                    rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
            indices.append(rand_tensor)
        indices = torch.cat(indices)
        if self.shuffle:  # shuffle the sampled dataset (from multiple subsets)
            rand_tensor = torch.randperm(len(indices), generator=self.generator)
            indices = indices[rand_tensor]

        # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
        if self.repeat > 1:
            repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
            if self.shuffle:
                _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
                repeat_indices = map(_choice, repeat_indices)
            indices = torch.cat([indices, *repeat_indices], 0)

        assert indices.shape[0] == self.n_samples
        return iter(indices.tolist())
datasets/scannet.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import path as osp
2
+ import numpy as np
3
+ from numpy.linalg import inv
4
+ import cv2
5
+ from tqdm import tqdm
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset, ConcatDataset
9
+
10
+ from utils import Augmentor
11
+
12
+
13
def read_scannet_pose(path):
    """Load a ScanNet camera-to-world pose file and return its world-to-camera inverse.

    Args:
        path (str): text file holding a space-delimited 4x4 camera-to-world matrix.

    Returns:
        pose_w2c (np.ndarray): (4, 4) world-to-camera transform.
    """
    pose_c2w = np.loadtxt(path, delimiter=' ')
    return inv(pose_c2w)
22
+
23
+
24
class ScanNetDataset(Dataset):
    """One ScanNet scene: yields RGB image pairs with intrinsics and relative pose."""

    def __init__(self,
                 root_dir,
                 npz_path,
                 intrinsic_path,
                 mode='train',
                 min_overlap_score=0.4,
                 pose_dir=None,
                 ):
        """Manage one scene of ScanNet Dataset.
        Args:
            root_dir (str): ScanNet root directory that contains scene folders.
            npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
            intrinsic_path (str): path to depth-camera intrinsic file.
            mode (str): options are ['train', 'val', 'test'].
            min_overlap_score (float): training pairs whose covisibility score is
                not above this threshold are discarded.
            pose_dir (str): ScanNet root directory that contains all poses.
                (we use a separate (optional) pose_dir since we store images and poses separately.)
        """
        super().__init__()
        self.root_dir = root_dir
        self.pose_dir = pose_dir if pose_dir is not None else root_dir
        self.mode = mode

        # prepare data_names, intrinsics and extrinsics(T)
        with np.load(npz_path) as data:
            self.data_names = data['name']
            # BUGFIX: the original condition was `mode not in ['val' or 'test']`,
            # which evaluates to `['val']` and therefore wrongly applied the
            # overlap filter to the test split as well.
            if 'score' in data.keys() and mode not in ('val', 'test'):
                kept_mask = data['score'] > min_overlap_score
                self.data_names = self.data_names[kept_mask]
        self.intrinsics = dict(np.load(intrinsic_path))
        self.augment = Augmentor(mode=='train')

    def __len__(self):
        return len(self.data_names)

    def _read_abs_pose(self, scene_name, name):
        # poses live under <pose_dir>/<scene>/pose/<frame>.txt
        pth = osp.join(self.pose_dir,
                       scene_name,
                       'pose', f'{name}.txt')
        return read_scannet_pose(pth)

    def _compute_rel_pose(self, scene_name, name0, name1):
        """Relative world2cam transform taking frame `name0` into frame `name1`."""
        pose0 = self._read_abs_pose(scene_name, name0)
        pose1 = self._read_abs_pose(scene_name, name1)

        return np.matmul(pose1, inv(pose0))  # (4, 4)

    def __getitem__(self, idx):
        """Return a dict with stacked RGB images, intrinsics, relative pose and pair names."""
        data_name = self.data_names[idx]
        scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'

        img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
        img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')

        # ScanNet color frames are resized to the depth-camera resolution
        w_new, h_new = 640, 480

        image0 = cv2.imread(img_name0)
        image0 = cv2.resize(image0, (w_new, h_new))
        image0 = cv2.cvtColor(image0, cv2.COLOR_BGR2RGB)
        # image0 = self.augment(image0)
        image0 = torch.from_numpy(image0).permute(2, 0, 1).float() / 255.

        image1 = cv2.imread(img_name1)
        image1 = cv2.resize(image1, (w_new, h_new))
        image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
        # image1 = self.augment(image1)
        image1 = torch.from_numpy(image1).permute(2, 0, 1).float() / 255.
        images = torch.stack([image0, image1], dim=0)

        # one shared intrinsic matrix per scene; duplicated for both views
        K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)
        intrinsics = torch.stack([K_0, K_1], dim=0)

        # read and compute relative poses
        T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
                              dtype=torch.float32)

        data = {
            'images': images,
            'rotation': T_0to1[:3, :3],
            'translation': T_0to1[:3, 3],
            'intrinsics': intrinsics,
            'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
                           osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
        }

        return data
121
+
122
+
123
def build_concat_scannet(mode, config):
    """Concatenate per-scene ScanNet datasets listed for the given split.

    Args:
        mode (str): 'train', 'val' or 'test'; selects the split sub-config.
        config: yacs-style config exposing DATASET.{TRAIN,VAL,TEST} with
            DATA_ROOT / NPZ_ROOT / LIST_PATH / INTRINSIC_PATH / MIN_OVERLAP_SCORE.

    Returns:
        torch.utils.data.ConcatDataset over one ScanNetDataset per scene npz.
    """
    if mode == 'train':
        split_cfg = config.DATASET.TRAIN
    elif mode == 'val':
        split_cfg = config.DATASET.VAL
    elif mode == 'test':
        split_cfg = config.DATASET.TEST
    else:
        raise NotImplementedError(f'mode {mode}')

    data_root = split_cfg.DATA_ROOT
    npz_root = split_cfg.NPZ_ROOT
    intrinsic_path = split_cfg.INTRINSIC_PATH
    min_overlap_score = split_cfg.MIN_OVERLAP_SCORE

    # one npz file name per line; only the first whitespace-separated token matters
    with open(split_cfg.LIST_PATH, 'r') as f:
        npz_names = [line.split()[0] for line in f.readlines()]

    scene_datasets = [
        ScanNetDataset(
            data_root,
            osp.join(npz_root, npz_name),
            intrinsic_path,
            mode=mode,
            min_overlap_score=min_overlap_score,
        )
        for npz_name in tqdm(npz_names, desc=f'Loading ScanNet {mode} datasets',)
    ]

    return ConcatDataset(scene_datasets)
154
+
eval.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from torch.utils.data import DataLoader
3
+ import lightning as L
4
+
5
+ from datasets import dataset_dict
6
+ from model import PL_RelPose, keypoint_dict
7
+ from configs.default import get_cfg_defaults
8
+
9
+
10
def main(args):
    """Evaluate a trained PL_RelPose checkpoint on the configured test split.

    Loads the yaml config, builds the test DataLoader, restores the checkpoint
    with a freshly configured keypoint extractor, and runs Lightning's test loop.
    """
    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config)

    build_fn = dataset_dict[cfg.DATASET.TASK][cfg.DATASET.DATA_SOURCE]
    test_dataset = build_fn('test', cfg)
    test_loader = DataLoader(
        test_dataset,
        batch_size=cfg.TRAINER.BATCH_SIZE,
        num_workers=cfg.TRAINER.NUM_WORKERS,
        pin_memory=cfg.TRAINER.PIN_MEMORY,
    )

    model = PL_RelPose.load_from_checkpoint(args.ckpt_path)
    # rebuild the extractor so the test-time keypoint budget takes effect
    model.extractor = keypoint_dict[model.hparams['features']](
        max_num_keypoints=cfg.MODEL.TEST_NUM_KEYPOINTS,
        detection_threshold=0.0,
    ).eval()

    trainer = L.Trainer(devices=[0])
    trainer.test(model, dataloaders=test_loader)
35
+
36
+
37
def get_parser():
    """Build the CLI parser: positional config path and checkpoint path."""
    parser = argparse.ArgumentParser()
    for name, kwargs in (
        ('config', dict(type=str, help='.yaml configure file path')),
        ('ckpt_path', dict(type=str)),
    ):
        parser.add_argument(name, **kwargs)
    return parser
43
+
44
+
45
if __name__ == "__main__":
    # CLI entry point: parse arguments and launch evaluation.
    main(get_parser().parse_args())
eval_add_reproj.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+ import torch
4
+ from collections import defaultdict
5
+ from tqdm import tqdm
6
+ from transforms3d.quaternions import mat2quat
7
+ import pandas as pd
8
+
9
+ from model import PL_RelPose, keypoint_dict
10
+ from utils.reproject import reprojection_error, Pose, save_submission
11
+ from utils.metrics import reproj, add, adi, compute_continuous_auc, relative_pose_error, rotation_angular_error
12
+ from datasets import dataset_dict
13
+ from configs.default import get_cfg_defaults
14
+
15
+
16
@torch.no_grad()
def main(args):
    """Evaluate a PL_RelPose checkpoint and report timing plus pose metrics.

    Depending on the configured dataset/task it additionally reports object
    ADD / ADD-S AUC and Map-free VCRE metrics; for 'mapfree' a submission zip
    is written to assets/new_submission.zip.
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE
    device = args.device

    # BUGFIX: was a duplicated `test_num_keypoints = test_num_keypoints = ...`
    test_num_keypoints = config.MODEL.TEST_NUM_KEYPOINTS

    build_fn = dataset_dict[task][dataset]
    testset = build_fn('test', config)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1)

    pl_relpose = PL_RelPose.load_from_checkpoint(args.ckpt_path)
    # rebuild the extractor so the test-time keypoint budget takes effect
    pl_relpose.extractor = keypoint_dict[pl_relpose.hparams['features']](max_num_keypoints=test_num_keypoints, detection_threshold=0.0).eval().to(device)
    pl_relpose.module = pl_relpose.module.eval().to(device)

    preprocess_times, extract_times, regress_times = [], [], []
    adds, adis = [], []
    repr_errs = []
    R_errs, t_errs = [], []
    ts_errs = []
    results_dict = defaultdict(list)
    for i, data in enumerate(tqdm(testloader)):
        # optionally restrict HO3D evaluation to a single object
        if dataset == 'ho3d' and args.obj_name is not None and data['objName'][0] != args.obj_name:
            continue
        image0, image1 = data['images'][0]
        K0, K1 = data['intrinsics'][0]
        # ground-truth relative pose as a 4x4 matrix
        T = torch.eye(4)
        T[:3, :3] = data['rotation'][0]
        T[:3, 3] = data['translation'][0]
        T = T.numpy()

        R_est, t_est, preprocess_time, extract_time, regress_time = pl_relpose.predict_one_data(data)
        preprocess_times.append(preprocess_time)
        extract_times.append(extract_time)
        regress_times.append(regress_time)

        t_err, R_err = relative_pose_error(T, R_est.cpu().numpy(), t_est.cpu().numpy(), ignore_gt_t_thr=0.0)

        R_errs.append(R_err)
        t_errs.append(t_err)

        # metric (Euclidean) translation error in meters
        ts_errs.append(torch.tensor(T[:3, 3] - t_est.cpu().numpy()).norm(2))

        if dataset == 'mapfree':
            repr_err = reprojection_error(R_est.cpu().numpy(), t_est.cpu().numpy(), T[:3, :3], T[:3, 3], K=K1, W=image1.shape[-1], H=image1.shape[-2])
            repr_errs.append(repr_err)
            R = R_est.detach().cpu().numpy()
            t = t_est.reshape(-1).detach().cpu().numpy()
            scene = data['scene_id'][0]
            estimated_pose = Pose(
                image_name=data['pair_names'][1][0],
                q=mat2quat(R).reshape(-1),
                t=t.reshape(-1),
                inliers=0
            )
            results_dict[scene].append(estimated_pose)

        if 'point_cloud' in data:
            adds.append(add(R_est.cpu().numpy(), t_est.cpu().numpy(), T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))
            adis.append(adi(R_est.cpu().numpy(), t_est.cpu().numpy(), T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))

    metrics = []
    values = []

    # convert seconds -> milliseconds
    preprocess_times = np.array(preprocess_times) * 1000
    extract_times = np.array(extract_times) * 1000
    regress_times = np.array(regress_times) * 1000

    metrics.append('Extracting Time (ms)')
    values.append(f'{np.mean(extract_times):.1f}')

    metrics.append('Recovering Time (ms)')
    values.append(f'{np.mean(regress_times):.1f}')

    metrics.append('Total Time (ms)')
    values.append(f'{np.mean(extract_times) + np.mean(regress_times):.1f}')

    if task == 'object':
        metrics.append('Object ADD')
        values.append(f'{compute_continuous_auc(adds, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

        metrics.append('Object ADD-S')
        values.append(f'{compute_continuous_auc(adis, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

    if dataset == 'mapfree':
        re = np.array(repr_errs)

        metrics.append('VCRE @90px Prec.')
        values.append(f'{(re < 90).mean() * 100:.2f}')

        metrics.append('VCRE Med.')
        values.append(f'{np.median(re):.2f}')

        save_submission(results_dict, 'assets/new_submission.zip')

    res = pd.DataFrame({'Metrics': metrics, 'Values': values})
    print(res)
122
+
123
+
124
def get_parser():
    """CLI parser: config + checkpoint positionals, plus device and HO3D object filter."""
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str, help='.yaml configure file path')
    parser.add_argument('ckpt_path', type=str)
    for flag, kwargs in (
        ('--device', dict(type=str, default='cuda:0')),
        ('--obj_name', dict(type=str, default=None)),
    ):
        parser.add_argument(flag, **kwargs)
    return parser
133
+
134
+
135
if __name__ == "__main__":
    # CLI entry point: parse arguments and launch evaluation.
    main(get_parser().parse_args())
eval_baselines.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import argparse
3
+ from tqdm import tqdm
4
+ import torch
5
+ import pandas as pd
6
+
7
+ # from lightglue.utils import load_image
8
+ from configs.default import get_cfg_defaults
9
+ from datasets import dataset_dict
10
+ from baselines.pose import PoseRecover
11
+ from utils.metrics import relative_pose_error, rotation_angular_error, error_auc, add, adi, compute_continuous_auc
12
+
13
+
14
def main(args):
    """Evaluate a matcher + pose-solver baseline over the configured test split.

    Reports per-stage timing, pose AUC, rotation statistics and — when depth
    or object ground truth is available — metric translation accuracy and
    object ADD / ADD-S AUC.
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE

    build_fn = dataset_dict[task][dataset]
    testset = build_fn('test', config)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1)

    device = args.device
    img_resize = args.resize
    poseRec = PoseRecover(matcher=args.matcher, solver=args.solver, img_resize=img_resize, device=device)

    preprocess_times, extract_times, match_times, recover_times = [], [], [], []
    R_errs, t_errs = [], []
    ts_errs = []
    adds, adis = [], []
    for i, data in enumerate(tqdm(testloader)):
        # optionally restrict HO3D evaluation to a single object
        if dataset == 'ho3d' and args.obj_name is not None and data['objName'][0] != args.obj_name:
            continue

        image0, image1 = data['images'][0].to(device)

        # object task: crop both views to their ground-truth bounding boxes
        bbox0, bbox1 = None, None
        if task == 'object':
            bbox0, bbox1 = data['bboxes'][0]
            x1, y1, x2, y2 = bbox0
            u1, v1, u2, v2 = bbox1
            image0 = image0[:, y1:y2, x1:x2]
            image1 = image1[:, v1:v2, u1:u2]

        mask0, mask1 = None, None
        if args.mask:
            mask0, mask1 = data['masks'][0].to(device)

        depth0, depth1 = None, None
        if args.depth:
            depth0, depth1 = data['depths'][0]

        K0, K1 = data['intrinsics'][0]
        # ground-truth relative pose as a 4x4 matrix
        T = torch.eye(4)
        T[:3, :3] = data['rotation'][0]
        T[:3, 3] = data['translation'][0]
        T = T.numpy()

        R, t, points0, points1, preprocess_time, extract_time, match_time, recover_time = poseRec.recover(image0, image1, K0, K1, bbox0, bbox1, mask0, mask1, depth0, depth1)
        preprocess_times.append(preprocess_time)
        extract_times.append(extract_time)
        match_times.append(match_time)
        recover_times.append(recover_time)

        # a failed recovery (NaNs) counts as the worst-case error
        if np.isnan(R).any():
            R_err = 180
            R = np.identity(3)
            t_err = 180
            t = np.array([0., 0., 0.])
        else:
            t_err, R_err = relative_pose_error(T, R, t, ignore_gt_t_thr=0.0)

        R_errs.append(R_err)
        t_errs.append(t_err)

        if args.depth:
            t = np.nan_to_num(t)
            ts_errs.append(torch.tensor(T[:3, 3] - t).norm(2))

        if task == 'object':
            if np.isnan(R).any():
                adds.append(1.)
                adis.append(1.)
            else:
                adds.append(add(R, t, T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))
                adis.append(adi(R, t, T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))

    metrics = []
    values = []

    # convert seconds -> milliseconds
    # BUGFIX: the original converted the last-iteration scalars
    # (`preprocess_time`, `extract_time`, `recover_time`) instead of the
    # accumulated lists, so the reported means covered a single sample.
    preprocess_times = np.array(preprocess_times) * 1000
    extract_times = np.array(extract_times) * 1000
    match_times = np.array(match_times) * 1000
    recover_times = np.array(recover_times) * 1000

    metrics.append('Extracting Time (ms)')
    values.append(f'{np.mean(extract_times):.1f}')

    metrics.append('Matching Time (ms)')
    values.append(f'{np.mean(match_times):.1f}')

    metrics.append('Recovering Time (ms)')
    values.append(f'{np.mean(recover_times):.1f}')

    metrics.append('Total Time (ms)')
    values.append(f'{np.mean(extract_times) + np.mean(match_times) + np.mean(recover_times):.1f}')

    # pose AUC over the max of rotation and translation-direction error
    angular_thresholds = [5, 10, 20]
    pose_errors = np.max(np.stack([R_errs, t_errs]), axis=0)
    aucs = error_auc(pose_errors, angular_thresholds, mode='Pose estimation')  # (auc@5, auc@10, auc@20)
    for k in aucs:
        metrics.append(k)
        values.append(f'{aucs[k] * 100:.2f}')

    R_errs = torch.tensor(R_errs)
    t_errs = torch.tensor(t_errs)

    metrics.append('Rotation Avg. Error (°)')
    values.append(f'{R_errs.mean():.2f}')

    metrics.append('Rotation Med. Error (°)')
    values.append(f'{R_errs.median():.2f}')

    metrics.append('Rotation @30° ACC')
    values.append(f'{(R_errs < 30).float().mean() * 100:.1f}')

    metrics.append('Rotation @15° ACC')
    values.append(f'{(R_errs < 15).float().mean() * 100:.1f}')

    if args.depth:
        ts_errs = torch.tensor(ts_errs)

        metrics.append('Translation Avg. Error (m)')
        values.append(f'{ts_errs.mean():.4f}')

        metrics.append('Translation Med. Error (m)')
        values.append(f'{ts_errs.median():.4f}')

        metrics.append('Translation @1m ACC')
        values.append(f'{(ts_errs < 1.0).float().mean() * 100:.1f}')

        metrics.append('Translation @10cm ACC')
        values.append(f'{(ts_errs < 0.1).float().mean() * 100:.1f}')

    if task == 'object':
        metrics.append('Object ADD')
        values.append(f'{compute_continuous_auc(adds, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

        metrics.append('Object ADD-S')
        values.append(f'{compute_continuous_auc(adis, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

    res = pd.DataFrame({'Metrics': metrics, 'Values': values})
    print(res)
167
+
168
+
169
def get_parser():
    """Build the CLI parser for baseline evaluation.

    Positionals: config (.yaml path) and matcher name. Options select the
    pose solver, optional resize / depth / mask inputs, an HO3D object
    filter, and the compute device.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str, help='.yaml configure file path')
    parser.add_argument('matcher', type=str)
    for flag, kwargs in (
        ('--solver', dict(type=str, default='procrustes')),
        ('--resize', dict(type=int, default=None)),
        ('--depth', dict(action='store_true')),
        ('--mask', dict(action='store_true')),
        ('--obj_name', dict(type=str, default=None)),
        ('--device', dict(type=str, default='cuda:0')),
    ):
        parser.add_argument(flag, **kwargs)
    return parser
184
+
185
+
186
if __name__ == "__main__":
    # CLI entry point: parse arguments and launch evaluation.
    main(get_parser().parse_args())
model/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import torch
2
+
3
+ from .relpose import RelPose
4
+ from .pl_trainer import PL_RelPose, keypoint_dict
model/pl_trainer.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import numpy as np
3
+ import torch
4
+ import lightning as L
5
+ from lightglue import SuperPoint, DISK, SIFT, ALIKED
6
+ import time
7
+
8
+ from utils import rotation_angular_error, translation_angular_error, error_auc
9
+ from .relpose import RelPose
10
+
11
+
12
# Maps the `features` hyper-parameter string to the corresponding LightGlue
# keypoint-extractor class used to build the (frozen) feature extractor.
keypoint_dict = {
    'superpoint': SuperPoint,
    'disk': DISK,
    'sift': SIFT,
    'aliked': ALIKED,
}
18
+
19
+
20
class PL_RelPose(L.LightningModule):
    """Lightning wrapper around RelPose: a frozen keypoint extractor feeding a
    trainable relative-pose regressor, trained with a combined rotation /
    translation loss and reporting pose-error metrics at each epoch end."""

    def __init__(
        self,
        task,
        lr,
        epochs,
        pct_start,
        num_keypoints,
        n_layers,
        num_heads,
        features='superpoint',
    ):
        """Build extractor + RelPose regressor and the per-split error buffers.

        Args:
            task: 'scene' or 'object' — selects which extra inputs (bboxes)
                are forwarded to the regressor.
            lr: peak learning rate for AdamW / OneCycleLR.
            epochs: total training epochs (drives the OneCycle schedule).
            pct_start: OneCycleLR warm-up fraction.
            num_keypoints: max keypoints extracted per image during training.
            n_layers, num_heads: RelPose transformer depth / attention heads.
            features: key into `keypoint_dict` selecting the extractor type.
        """
        super().__init__()

        # Extractor is kept in eval mode and only ever run under no_grad (frozen).
        self.extractor = keypoint_dict[features](max_num_keypoints=num_keypoints, detection_threshold=0.0).eval()
        self.module = RelPose(features=features, task=task, n_layers=n_layers, num_heads=num_heads)
        self.criterion = torch.nn.HuberLoss()

        # Learnable log-variance weights; the weighted loss is currently
        # commented out in _shared_forward_step, so these are unused there.
        self.s_r = torch.nn.Parameter(torch.zeros(1))
        # self.s_ta = torch.nn.Parameter(torch. zeros(1))
        self.s_t = torch.nn.Parameter(torch.zeros(1))

        # Per-split error accumulators, flushed in _shared_on_epoch_end.
        self.r_errors = {k:[] for k in ['train', 'valid', 'test']}
        self.ta_errors = {k:[] for k in ['train', 'valid', 'test']}
        self.t_errors = {k:[] for k in ['train', 'valid', 'test']}

        self.save_hyperparameters()

    def _shared_log(self, mode, loss, loss_r, loss_t, loss_ta, loss_tn):
        """Log the total loss and its components for the given split."""
        self.log_dict({
            f'{mode}_loss/sum': loss,
            f'{mode}_loss/r': loss_r,
            f'{mode}_loss/t': loss_t,
            f'{mode}_loss/ta': loss_ta,
            f'{mode}_loss/tn': loss_tn,
        }, on_epoch=True, sync_dist=True)

    def training_step(self, batch, batch_idx):
        loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err = self._shared_forward_step(batch, batch_idx)

        self.r_errors['train'].append(r_err)
        self.ta_errors['train'].append(ta_err)
        self.t_errors['train'].append(t_err)

        self._shared_log('train', loss, loss_r, loss_t, loss_ta, loss_tn)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err = self._shared_forward_step(batch, batch_idx)

        self.r_errors['valid'].append(r_err)
        self.ta_errors['valid'].append(ta_err)
        self.t_errors['valid'].append(t_err)

        self._shared_log('valid', loss, loss_r, loss_t, loss_ta, loss_tn)

    def test_step(self, batch, batch_idx):
        loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err = self._shared_forward_step(batch, batch_idx)

        self.r_errors['test'].append(r_err)
        self.ta_errors['test'].append(ta_err)
        self.t_errors['test'].append(t_err)

        self._shared_log('test', loss, loss_r, loss_t, loss_ta, loss_tn)

    def _shared_forward_step(self, batch, batch_idx):
        """Extract keypoints, regress the relative pose, and compute losses.

        Returns:
            (loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err)
            where r_err/ta_err are angular errors (radians) and t_err is the
            per-sample Euclidean translation error.
        """
        images = batch['images']
        rotation = batch['rotation']
        translation = batch['translation']
        intrinsics = batch['intrinsics']

        image0 = images[:, 0, ...]
        image1 = images[:, 1, ...]

        # extractor is frozen: no gradients through keypoint detection
        with torch.no_grad():
            feats0 = self.extractor({'image': image0})
            feats1 = self.extractor({'image': image1})

        # map keypoints from resized-image coordinates back to original pixels
        if 'scales' in batch:
            scales = batch['scales']
            feats0['keypoints'] *= scales[:, 0].unsqueeze(1)
            feats1['keypoints'] *= scales[:, 1].unsqueeze(1)

        if self.hparams.task == 'scene':
            pred_r, pred_t = self.module({'image0': {**feats0, 'intrinsics': intrinsics[:, 0]}, 'image1': {**feats1, 'intrinsics': intrinsics[:, 1]}})
        elif self.hparams.task == 'object':
            # object task additionally conditions on the reference-view bbox
            bboxes = batch['bboxes']
            pred_r, pred_t = self.module({'image0': {**feats0, 'intrinsics': intrinsics[:, 0], 'bbox': bboxes[:, 0]}, 'image1': {**feats1, 'intrinsics': intrinsics[:, 1]}})

        r_err = rotation_angular_error(pred_r, rotation)
        ta_err = translation_angular_error(pred_t, translation)

        # Huber losses: angular errors are regressed toward zero; translation is
        # supervised both as a unit direction (tn) and in metric scale (t).
        loss_r = self.criterion(r_err, torch.zeros_like(r_err))
        loss_ta = self.criterion(ta_err, torch.zeros_like(ta_err))
        loss_tn = self.criterion(pred_t / pred_t.norm(2, dim=-1, keepdim=True), translation / translation.norm(2, dim=-1, keepdim=True))
        loss_t = self.criterion(pred_t, translation)

        # loss = loss_r * torch.exp(-self.s_r) + loss_t * torch.exp(-self.s_t) + loss_ta * torch.exp(-self.s_ta) + self.s_r + self.s_t + self.s_ta
        loss = loss_r + loss_ta + loss_t + loss_tn

        r_err = r_err.detach()
        ta_err = ta_err.detach()
        t_err = (pred_t.detach() - translation).norm(2, dim=1)

        return loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err

    def predict_one_data(self, data, device='cuda'):
        """Run inference on a single (batched) sample.

        Returns:
            (R, t, preprocess_s, extract_s, regress_s): predicted rotation and
            translation for the first batch element plus wall-clock timings in
            seconds for each stage.
        """
        st_time = time.time()
        images = data['images'].to(device)
        intrinsics = data['intrinsics'].to(device)

        image0 = images[:, 0, ...]
        image1 = images[:, 1, ...]

        preprocess = time.time()

        with torch.no_grad():
            feats0 = self.extractor({'image': image0})
            feats1 = self.extractor({'image': image1})

        extract_time = time.time()

        # map keypoints from resized-image coordinates back to original pixels
        if 'scales' in data:
            scales = data['scales'].to(device)
            feats0['keypoints'] *= scales[:, 0].unsqueeze(1)
            feats1['keypoints'] *= scales[:, 1].unsqueeze(1)

        if self.hparams.task == 'scene':
            pred_r, pred_t = self.module({'image0': {**feats0, 'intrinsics': intrinsics[:, 0]}, 'image1': {**feats1, 'intrinsics': intrinsics[:, 1]}})
        elif self.hparams.task == 'object':
            bboxes = data['bboxes'].to(device)
            pred_r, pred_t = self.module({'image0': {**feats0, 'intrinsics': intrinsics[:, 0], 'bbox': bboxes[:, 0]}, 'image1': {**feats1, 'intrinsics': intrinsics[:, 1]}})

        regress_time = time.time()

        return pred_r[0], pred_t[0], preprocess-st_time, extract_time-preprocess, regress_time-extract_time


    def _shared_on_epoch_end(self, mode):
        """Aggregate the split's accumulated errors into epoch metrics and clear them."""
        r_errors = torch.hstack(self.r_errors[mode]).rad2deg()
        ta_errors = torch.hstack(self.ta_errors[mode]).rad2deg()
        # translation direction is sign-ambiguous: fold angles above 90°
        ta_errors = torch.minimum(ta_errors, 180-ta_errors)

        auc = error_auc(torch.maximum(r_errors, ta_errors).cpu(), [5, 10, 20], mode)
        t_errors = torch.hstack(self.t_errors[mode])

        self.log_dict({
            **auc,
            f'{mode}_Rot./Avg. Error': r_errors.mean(),
            f'{mode}_Rot./Med. Error': r_errors.median(),
            f'{mode}_Rot./@30° ACC': (r_errors < 30).float().mean(),
            f'{mode}_Rot./@15° ACC': (r_errors < 15).float().mean(),
            # f'{mode}_ta/avg': ta_errors.mean(),
            # f'{mode}_ta/med': ta_errors.median(),
            f'{mode}_Trans./Avg. Error': t_errors.mean(),
            f'{mode}_Trans./Med. Error': t_errors.median(),
            f'{mode}_Trans./@10cm ACC': (t_errors < 0.1).float().mean(),
            f'{mode}_Trans./@1m ACC': (t_errors < 1.0).float().mean(),
        }, sync_dist=True)

        self.r_errors[mode].clear()
        self.ta_errors[mode].clear()
        self.t_errors[mode].clear()

    def on_train_epoch_end(self):
        self._shared_on_epoch_end('train')

    def on_validation_epoch_end(self):
        self._shared_on_epoch_end('valid')

    def on_test_epoch_end(self):
        self._shared_on_epoch_end('test')

    def configure_optimizers(self):
        """AdamW over the regressor only (extractor stays frozen) with a per-epoch OneCycle schedule."""
        optimizer = torch.optim.AdamW(self.module.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.hparams.lr, steps_per_epoch=1, epochs=self.hparams.epochs, pct_start=self.hparams.pct_start)

        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler
        }
model/relpose.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from types import SimpleNamespace
3
+ from typing import Callable, List, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from utils import rotation_matrix_from_ortho6d
10
+
11
# flash-attn is optional; degrade gracefully when it is not installed.
try:
    from flash_attn.modules.mha import FlashCrossAttention
except ModuleNotFoundError:
    FlashCrossAttention = None

# Flash-style attention is usable if either flash-attn is installed or torch
# provides F.scaled_dot_product_attention (torch >= 2.0).
if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
    FLASH_AVAILABLE = True
else:
    FLASH_AVAILABLE = False

# Module-level side effects on import: deterministic cuDNN kernels and
# reduced-precision float32 matmul ("medium" allows bfloat16 accumulation).
torch.backends.cudnn.deterministic = True
torch.set_float32_matmul_precision('medium')
23
+
24
+
25
def normalize_keypoints(kpts, intrinsics):
    """Map pixel keypoints to normalized camera coordinates via K^-1.

    Args:
        kpts: (B, M, 2) pixel coordinates.
        intrinsics: (B, 3, 3) camera matrices.

    Returns:
        (B, M, 2) normalized coordinates (homogeneous z dropped).
    """
    batch, num, _ = kpts.shape
    ones = torch.ones((batch, num, 1), device=kpts.device)
    homogeneous = torch.cat([kpts, ones], dim=2)
    unprojected = (intrinsics.inverse() @ homogeneous.mT).mT
    return unprojected[..., :2]
35
+
36
+
37
+ # @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
38
def cosine_similarity(x, y):
    """Pairwise cosine similarity between rows of x and y, rescaled to [0, 1].

    Args:
        x: (..., M, D), y: (..., N, D).

    Returns:
        (..., M, N) similarities, 1 = identical direction, 0 = opposite.
    """
    x_unit = x / x.norm(2, -1, keepdim=True)
    y_unit = y / y.norm(2, -1, keepdim=True)
    raw = torch.einsum('...id,...jd->...ij', x_unit, y_unit)
    return (raw + 1) / 2
42
+
43
+
44
def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pad `x` (..., N, D) along dim -2 up to `length`, padding with ones.

    Fixed: the return annotation previously declared a one-element
    `Tuple[torch.Tensor]` although two tensors are returned.

    Args:
        x: tensor of shape (..., N, D).
        length: target size of dim -2.

    Returns:
        (padded, mask): `padded` has shape (..., max(N, length), D); `mask` is
        boolean of shape (..., max(N, length), 1), True for real entries and
        False for padding. When N >= length, `x` is returned unchanged with an
        all-True mask.
    """
    if length <= x.shape[-2]:
        return x, torch.ones_like(x[..., :1], dtype=torch.bool)
    # Pad with ones (matches the LightGlue convention for padded descriptors).
    pad = torch.ones(
        *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
    )
    y = torch.cat([x, pad], dim=-2)
    mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
    mask[..., : x.shape[-2], :] = True
    return y, mask
54
+
55
+
56
def gather(x: torch.Tensor, indices: torch.tensor):
    """Batched row selection along dim 1: out[b, i, :] = x[b, indices[b, i], :].

    Args:
        x: (B, M, N) source tensor.
        indices: (B, K) integer row indices per batch element.

    Returns:
        (B, K, N) gathered rows.
    """
    batch, _, cols = x.shape
    batch_idx = torch.arange(batch).reshape(batch, 1, 1)
    col_idx = torch.arange(cols)
    row_idx = indices.unsqueeze(-1)
    return x[batch_idx, row_idx, col_idx]
61
+
62
+
63
class Attention(nn.Module):
    """Scaled dot-product attention with optional fast backends.

    Backend priority at call time: flash-attn / torch SDPA on CUDA when
    enabled, then torch SDPA on any device, then a plain einsum fallback.

    Fixed bug: on the einsum fallback path the attention mask was silently
    ignored, because `masked_fill` is out-of-place and its result was
    discarded.
    """

    def __init__(self, allow_flash: bool = True) -> None:
        super().__init__()
        if allow_flash and not FLASH_AVAILABLE:
            warnings.warn(
                "FlashAttention is not available. For optimal speed, "
                "consider installing torch >= 2.0 or flash-attn.",
                stacklevel=2,
            )
        self.enable_flash = allow_flash and FLASH_AVAILABLE
        self.has_sdp = hasattr(F, "scaled_dot_product_attention")
        if allow_flash and FlashCrossAttention:
            self.flash_ = FlashCrossAttention()
        if self.has_sdp:
            torch.backends.cuda.enable_flash_sdp(allow_flash)

    def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over (..., N, D) q/k/v; `mask` is a boolean attention mask
        (True = attend). Fully-masked rows are zeroed via nan_to_num."""
        if self.enable_flash and q.device.type == "cuda":
            # use torch 2.0 scaled_dot_product_attention with flash
            if self.has_sdp:
                args = [x.contiguous() for x in [q, k, v]]
                v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
                return v if mask is None else v.nan_to_num()
            else:
                assert mask is None
                q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
                m = self.flash_(q, torch.stack([k, v], 2))
                return m.transpose(-2, -3).to(q.dtype).clone()
        elif self.has_sdp:
            args = [x.contiguous() for x in [q, k, v]]
            v = F.scaled_dot_product_attention(*args, attn_mask=mask)
            return v if mask is None else v.nan_to_num()
        else:
            s = q.shape[-1] ** -0.5
            sim = torch.einsum("...id,...jd->...ij", q, k) * s
            if mask is not None:
                # FIX: masked_fill returns a new tensor; the original dropped it.
                sim = sim.masked_fill(~mask, -float("inf"))
            attn = F.softmax(sim, -1)
            return torch.einsum("...ij,...jd->...id", attn, v)
102
+
103
+
104
class SelfBlock(nn.Module):
    """Self-attention block with additive positional encoding on q/k and a
    residual feed-forward update conditioned on [input, attention message]."""

    def __init__(
        self, embed_dim: int, num_heads: int, bias: bool = True
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0
        self.head_dim = self.embed_dim // num_heads
        # Fused projection producing q, k and v in one linear layer.
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.inner_attn = Attention()
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(
        self,
        x: torch.Tensor,
        encoding: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """x: (B, N, D); encoding: broadcastable to (B, H, N, D/H); returns (B, N, D)."""
        projected = self.Wqkv(x)
        # (B, N, 3D) -> (B, H, N, D/H, 3)
        per_head = projected.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
        # Positional encoding is added to queries and keys only, not values.
        q = per_head[..., 0] + encoding
        k = per_head[..., 1] + encoding
        v = per_head[..., 2]

        context = self.inner_attn(q, k, v, mask=mask)

        message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
        return x + self.ffn(torch.cat([x, message], -1))
139
+
140
+
141
class CrossBlock(nn.Module):
    """Bidirectional cross-attention between two keypoint sets.

    Attention logits are multiplied elementwise by a precomputed matchability
    prior (`match`, values in [0, 1] from `cosine_similarity`) before the
    softmax, biasing attention toward descriptor pairs that already look alike.
    """

    def __init__(
        self, embed_dim: int, num_heads: int, bias: bool = True
    ) -> None:
        super().__init__()
        self.heads = num_heads
        dim_head = embed_dim // num_heads
        # 1/sqrt(d) scaling, split as sqrt(scale) onto each operand in forward.
        self.scale = dim_head**-0.5
        inner_dim = dim_head * num_heads
        # Shared to_qk: the same projection serves as queries for one image and
        # keys for the other (symmetric cross-attention).
        self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )
        # self.reg_attn = nn.Identity()
        # self.reg_sim = nn.Identity()

    def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
        # Apply the same function to both images' tensors.
        return func(x0), func(x1)

    def forward(
        self, x0: torch.Tensor, x1: torch.Tensor, match: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> List[torch.Tensor]:
        """Cross-attend x0 <-> x1.

        Args:
            x0: (B, M, D) descriptors of image 0.
            x1: (B, N, D) descriptors of image 1.
            match: (B, M, N) matchability prior in [0, 1].
            mask: optional (B, M, N) boolean validity mask (True = valid pair).

        Returns:
            Updated (x0, x1), same shapes as the inputs.
        """
        qk0, qk1 = self.map_(self.to_qk, x0, x1)
        v0, v1 = self.map_(self.to_v, x0, x1)
        # (B, N, H*Dh) -> (B, H, N, Dh)
        qk0, qk1, v0, v1 = map(
            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
            (qk0, qk1, v0, v1),
        )

        # Distribute the 1/sqrt(d) factor evenly over both operands.
        qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
        sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
        if mask is not None:
            sim = sim.masked_fill(~mask.unsqueeze(1), -float("inf"))

        assert len(match.shape) == 3
        match = match.unsqueeze(1)  # broadcast the prior over heads
        # Multiplicative bias: the matchability prior rescales the logits.
        sim = sim * match
        # sim = self.reg_attn(sim)
        # match = self.reg_sim(match)

        # Softmax over the *other* image's axis for each direction.
        attn01 = F.softmax(sim, dim=-1)
        attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
        m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
        # attn10 is (B, H, N, M); transposing back makes this equivalent to
        # einsum("bhij, bhjd -> bhid", attn10, v0).
        m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
        if mask is not None:
            # Rows with no valid partner yield NaN after the -inf softmax; zero them.
            m0, m1 = m0.nan_to_num(), m1.nan_to_num()

        # (B, H, N, Dh) -> (B, N, H*Dh), project back, then residual FFN update.
        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
        m0, m1 = self.map_(self.to_out, m0, m1)
        x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
        x1 = x1 + self.ffn(torch.cat([x1, m1], -1))

        return x0, x1
200
+
201
+
202
class TransformerLayer(nn.Module):
    """One interaction layer: per-image self-attention followed by
    matchability-biased cross-attention between the two images."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.self_attn = SelfBlock(*args, **kwargs)
        self.cross_attn = CrossBlock(*args, **kwargs)

    def forward(
        self,
        desc0,
        desc1,
        encoding0,
        encoding1,
        match,
        mask0: Optional[torch.Tensor] = None,
        mask1: Optional[torch.Tensor] = None,
    ):
        # Without both validity masks, run the plain (unpadded) path.
        if mask0 is None or mask1 is None:
            refined0 = self.self_attn(desc0, encoding0)
            refined1 = self.self_attn(desc1, encoding1)
            return self.cross_attn(refined0, refined1, match)
        return self.masked_forward(desc0, desc1, encoding0, encoding1, match, mask0, mask1)

    # This part is compiled and allows padding inputs
    def masked_forward(self, desc0, desc1, encoding0, encoding1, match, mask0, mask1):
        # Build pairwise masks from the (B, N, 1) per-point validity masks.
        cross_mask = mask0 & mask1.transpose(-1, -2)
        self_mask0 = mask0 & mask0.transpose(-1, -2)
        self_mask1 = mask1 & mask1.transpose(-1, -2)
        refined0 = self.self_attn(desc0, encoding0, self_mask0)
        refined1 = self.self_attn(desc1, encoding1, self_mask1)
        return self.cross_attn(refined0, refined1, match, cross_mask)
233
+
234
+
235
class RelPose(nn.Module):
    """Direct relative-pose regression from two images' local features.

    Given LightGlue-style input dicts (keypoints + descriptors per image), the
    network refines the descriptors with stacked self-/cross-attention layers
    biased by a cosine-similarity matchability prior, pools them into a global
    pair feature, and regresses a rotation (6D ortho representation -> 3x3
    matrix via `rotation_matrix_from_ortho6d`) and a translation vector.
    """

    default_conf = {
        "name": "RelPose",  # just for interfacing
        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 256,  # transformer working width
        "add_scale_ori": False,  # append keypoint scale/orientation to positions
        "n_layers": 3,  # number of TransformerLayer blocks
        "num_heads": 4,
        "pct_pruning": 0,  # fraction of least-matchable keypoints to drop, in [0, 1)
        "task": "scene",  # "scene" (whole image) or "object" (bbox-prompted)
        "mp": False,  # enable mixed precision
        "weights": None,
    }

    required_data_keys = ["image0", "image1"]

    # Per-extractor overrides applied on top of default_conf in __init__.
    features = {
        "superpoint": {
            "input_dim": 256,
        },
        "disk": {
            "input_dim": 128,
        },
        "aliked": {
            "input_dim": 128,
        },
        "sift": {
            "input_dim": 128,
            # "add_scale_ori": True,
        },
    }

    def __init__(self, features="superpoint", **conf) -> None:
        """Build the model.

        Args:
            features: name of the keypoint extractor whose descriptor size to
                use (key of `self.features`), or None to keep `conf` as given.
            **conf: overrides merged onto `default_conf`.
        """
        super().__init__()
        self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
        if features is not None:
            if features not in self.features:
                raise ValueError(
                    f"Unsupported features: {features} not in "
                    f"{{{','.join(self.features)}}}"
                )
            for k, v in self.features[features].items():
                setattr(conf, k, v)

        # Project incoming descriptors to the transformer width if they differ.
        if conf.input_dim != conf.descriptor_dim:
            self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()

        # Linear positional encoding of normalized 2D coords (+ optional
        # scale/ori) to the per-head dimension, shared across heads.
        head_dim = conf.descriptor_dim // conf.num_heads
        self.posenc = nn.Linear(
            2 + 2 * self.conf.add_scale_ori, head_dim
        )

        h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim

        self.transformers = nn.ModuleList(
            [TransformerLayer(d, h) for _ in range(n)]
        )

        # MLP head producing the 6D rotation representation.
        self.rotation_regressor = nn.Sequential(
            nn.Linear(conf.descriptor_dim*2, conf.descriptor_dim),
            nn.ReLU(),
            nn.Linear(conf.descriptor_dim, conf.descriptor_dim//2),
            nn.ReLU(),
            nn.Linear(conf.descriptor_dim//2, 6),
        )

        # MLP head producing the 3D translation.
        self.translation_regressor = nn.Sequential(
            nn.Linear(conf.descriptor_dim*2, conf.descriptor_dim),
            nn.ReLU(),
            nn.Linear(conf.descriptor_dim, conf.descriptor_dim//2),
            nn.ReLU(),
            nn.Linear(conf.descriptor_dim//2, 3),
        )

        # static lengths LightGlue is compiled for (only used with torch.compile)
        self.static_lengths = None

    def compile(
        self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
    ):
        """torch.compile the masked per-layer forward for fixed padded lengths.

        After calling this, `_forward` pads inputs up to the nearest static
        length so the compiled graphs can be reused.
        """
        for i in range(self.conf.n_layers):
            self.transformers[i].masked_forward = torch.compile(
                self.transformers[i].masked_forward, mode=mode, fullgraph=True
            )

        self.static_lengths = static_lengths

    def forward(self, data: dict) -> dict:
        """
        Match keypoints and descriptors between two images

        Input (dict):
            image0: dict
                keypoints: [B x M x 2]
                descriptors: [B x M x D]
                intrinsics: [B x 3 x 3]
                bbox: [B x 4] (only read when conf.task == "object")
            image1: dict
                keypoints: [B x N x 2]
                descriptors: [B x N x D]
                intrinsics: [B x 3 x 3]
        Output:
            (R, t): rotation matrices [B x 3 x 3] and translations [B x 3].
        """
        # Optionally run under autocast; no-op when conf.mp is False.
        with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
            return self._forward(data)

    def _forward(self, data: dict) -> dict:
        for key in self.required_data_keys:
            assert key in data, f"Missing key {key} in data"
        data0, data1 = data["image0"], data["image1"]
        kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
        intrinsic0, intrinsic1 = data0["intrinsics"], data1["intrinsics"]
        b, m, _ = kpts0.shape
        b, n, _ = kpts1.shape

        # Optionally append keypoint scale and orientation as extra channels.
        if self.conf.add_scale_ori:
            kpts0 = torch.cat(
                [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
            )
            kpts1 = torch.cat(
                [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
            )
        # Descriptors are treated as fixed inputs: no gradient to the extractor.
        desc0 = data0["descriptors"].detach().contiguous()
        desc1 = data1["descriptors"].detach().contiguous()

        assert desc0.shape[-1] == self.conf.input_dim
        assert desc1.shape[-1] == self.conf.input_dim

        # Pad to a compiled static length when torch.compile is in use.
        mask0, mask1 = None, None
        c = max(m, n)
        do_compile = self.static_lengths and c <= max(self.static_lengths)
        if do_compile:
            kn = min([k for k in self.static_lengths if k >= c])
            desc0, mask0 = pad_to_length(desc0, kn)
            desc1, mask1 = pad_to_length(desc1, kn)
            kpts0, _ = pad_to_length(kpts0, kn)
            kpts1, _ = pad_to_length(kpts1, kn)

        # Matchability prior in [0, 1] from raw descriptor cosine similarity.
        matchability = cosine_similarity(desc0, desc1)

        assert self.conf.pct_pruning >= 0 and self.conf.pct_pruning < 1
        if self.conf.pct_pruning > 0:
            # Drop the least-matchable fraction of keypoints in both images.
            ind0, ind1 = self.get_pruned_indices(matchability, self.conf.pct_pruning)

            matchability = gather(matchability, ind0)
            matchability = gather(matchability.mT, ind1).mT

            desc0 = gather(desc0, ind0)
            desc1 = gather(desc1, ind1)

            kpts0 = gather(kpts0, ind0)
            kpts1 = gather(kpts1, ind1)

        if self.conf.task == "object":
            bbox = data0["bbox"]  # (B, 4)
            ind0, mask0 = self.get_prompted_indices(kpts0, bbox)

            # Slot 0 is zeroed so that invalid (out-of-box) indices, which
            # get_prompted_indices remaps to 0, gather zero entries.
            matchability[:, 0] = torch.zeros_like(matchability[:, 0], device=matchability.device)
            desc0[:, 0] = torch.zeros_like(desc0[:, 0], device=desc0.device)
            kpts0[:, 0] = torch.zeros_like(kpts0[:, 0], device=kpts0.device)

            matchability = gather(matchability, ind0)
            desc0 = gather(desc0, ind0)
            kpts0 = gather(kpts0, ind0)
            # NOTE(review): here mask0 becomes a (B, K) 0/1 long tensor while
            # mask1 stays None (unless padding above ran), so TransformerLayer
            # takes its unmasked path and mask0 is only used for the masked
            # pooling below. If compile-padding and object task are combined,
            # mask0 is overwritten with an incompatible shape — verify.

        desc0 = self.input_proj(desc0)
        desc1 = self.input_proj(desc1)

        # Positional encodings from intrinsics-normalized keypoints; the
        # unsqueeze(-3) adds the head axis for broadcasting over heads.
        kpts0 = normalize_keypoints(kpts0, intrinsic0)
        kpts1 = normalize_keypoints(kpts1, intrinsic1)

        encoding0 = self.posenc(kpts0).unsqueeze(-3)
        encoding1 = self.posenc(kpts1).unsqueeze(-3)

        for i in range(self.conf.n_layers):
            desc0, desc1 = self.transformers[i](
                desc0, desc1, encoding0, encoding1, match=matchability, mask0=mask0, mask1=mask1,
            )

        # Trim compile-padding back to the original keypoint counts (a no-op
        # when pruning/prompting already shrank the tensors below m/n).
        desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
        if self.conf.task == 'object':
            # Masked average over in-box keypoints only (clip avoids div-by-0).
            n_kpts0 = mask0.sum(1, keepdim=True)
            n_kpts0 = torch.clip(n_kpts0, min=1)
            desc0 = (desc0 * mask0.unsqueeze(-1)).sum(1) / n_kpts0
            desc1 = desc1.mean(1)
        else:
            desc0, desc1 = desc0.mean(1), desc1.mean(1)

        # Global pair feature -> pose heads.
        feat = torch.cat([desc0, desc1], 1)

        R = self.rotation_regressor(feat)
        R = rotation_matrix_from_ortho6d(R)
        t = self.translation_regressor(feat)

        return R, t

    def get_pruned_indices(self, match, pct_pruning):
        """Return, per image, the indices of keypoints that survive pruning.

        Keypoints are ranked by their mean matchability over the other image;
        the lowest `pct_pruning` fraction is dropped.

        Args:
            match: (B, M, N) matchability matrix.
            pct_pruning: fraction in [0, 1) to remove on each side.

        Returns:
            (indices0, indices1): surviving indices for image 0 and image 1,
            ordered by ascending matchability score.
        """
        matching_scores0 = match.mean(-1)
        matching_scores1 = match.mean(-2)

        num_pruning0 = int(pct_pruning * matching_scores0.size(-1))
        num_pruning1 = int(pct_pruning * matching_scores1.size(-1))

        _, indices0 = matching_scores0.sort()
        _, indices1 = matching_scores1.sort()

        indices0 = indices0[:, num_pruning0:]
        indices1 = indices1[:, num_pruning1:]

        return indices0, indices1

    def get_prompted_indices(self, kpts, bbox):
        """Select keypoints lying inside a per-sample bounding box.

        Args:
            kpts: (B, M, 2) pixel coordinates.
            bbox: (B, 4) boxes as (x_min, y_min, x_max, y_max).

        Returns:
            indices: (B, K) keypoint indices, K = largest in-box count in the
                batch; out-of-box slots are remapped to index 0 (the caller
                zeroes entry 0 beforehand so they gather zeros).
            mask_sorted: (B, K) 0/1 long tensor marking valid slots.
        """
        x, y = kpts[..., 0], kpts[..., 1]
        mask = (x >= bbox[:, 0].unsqueeze(-1)) & (x <= bbox[:, 2].unsqueeze(-1))
        mask &= (y >= bbox[:, 1].unsqueeze(-1)) & (y <= bbox[:, 3].unsqueeze(-1))
        # Sort valid entries to the front, then truncate to the batch maximum.
        mask_sorted, indices = mask.long().sort(descending=True)
        indices *= mask_sorted  # invalid slots -> index 0
        indices = indices[:, :mask_sorted.sum(-1).max()]
        mask_sorted = mask_sorted[:, :mask_sorted.sum(-1).max()]

        return indices, mask_sorted
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ albumentations==1.4.1
2
+ kornia==0.7.1
3
+ open3d==0.18.0
4
+ opencv-python==4.9.0.80
5
+ plyfile==1.0.3
6
+ scikit-learn==1.4.1.post1
7
+ yacs==0.1.8
8
+ lightning==2.2.1
9
+ transforms3d==0.4.1
10
+ pandas==2.1.1
11
+ lightglue @ git+https://github.com/cvg/LightGlue@main
train.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from torch.utils.data import DataLoader
3
+ import lightning as L
4
+ from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
5
+
6
+ from datasets import dataset_dict, RandomConcatSampler
7
+ from model import PL_RelPose
8
+ from utils import seed_torch
9
+ from configs.default import get_cfg_defaults
10
+
11
+
12
def main(args):
    """Train PL_RelPose from a YAML config.

    Args:
        args: argparse namespace with `config` (YAML path), optional `resume`
            (trainer checkpoint to continue from) and `weights` (model
            checkpoint to initialize from).
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE

    batch_size = config.TRAINER.BATCH_SIZE
    num_workers = config.TRAINER.NUM_WORKERS
    pin_memory = config.TRAINER.PIN_MEMORY
    n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
    lr = config.TRAINER.LEARNING_RATE
    epochs = config.TRAINER.EPOCHS
    pct_start = config.TRAINER.PCT_START

    num_keypoints = config.MODEL.NUM_KEYPOINTS
    n_layers = config.MODEL.N_LAYERS
    num_heads = config.MODEL.NUM_HEADS
    features = config.MODEL.FEATURES

    seed = config.RANDOM_SEED
    seed_torch(seed)

    build_fn = dataset_dict[task][dataset]
    trainset = build_fn('train', config)
    validset = build_fn('val', config)

    # Concat-style datasets are sampled per subset for balanced training.
    if dataset in ('scannet', 'megadepth', 'linemod', 'ho3d', 'mapfree'):
        sampler = RandomConcatSampler(
            trainset,
            n_samples_per_subset=n_samples_per_subset,
            subset_replacement=True,
            shuffle=True,
            seed=seed
        )
        trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, sampler=sampler)
    else:
        trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True)

    validloader = DataLoader(validset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

    if args.weights is None:
        pl_relpose = PL_RelPose(
            task=task,
            lr=lr,
            epochs=epochs,
            pct_start=pct_start,
            n_layers=n_layers,
            num_heads=num_heads,
            num_keypoints=num_keypoints,
            features=features,
        )
    else:
        pl_relpose = PL_RelPose.load_from_checkpoint(
            checkpoint_path=args.weights,
            task=task,
            lr=lr,
            epochs=epochs,
            pct_start=pct_start,
            n_layers=n_layers,
            num_heads=num_heads,
            num_keypoints=num_keypoints,
            # FIX: `features` was dropped on this branch, so fine-tuning from
            # weights silently ignored the configured feature type while all
            # other hyperparameters were overridden. Pass it for consistency
            # with the fresh-construction branch above.
            features=features,
        )

    # Checkpointing: keep the latest epoch plus the best valid AUC@20 model.
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    latest_checkpoint_callback = ModelCheckpoint()
    best_checkpoint_callback = ModelCheckpoint(monitor='valid/auc@20', mode='max')
    trainer = L.Trainer(
        devices=[0],
        max_epochs=epochs,
        callbacks=[lr_monitor, latest_checkpoint_callback, best_checkpoint_callback],
        precision="bf16-mixed",
    )

    trainer.fit(pl_relpose, trainloader, validloader, ckpt_path=args.resume)
92
def get_parser():
    """Build the command-line parser for the training script.

    Positional `config` is the YAML config path; `--resume` continues a
    trainer run and `--weights` initializes from a model checkpoint.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', type=str, help='.yaml configure file path')
    for optional in ('--resume', '--weights'):
        parser.add_argument(optional, type=str, default=None)
    return parser
99
+
100
+
101
if __name__ == "__main__":
    # Script entry point: parse CLI args and launch training.
    parser = get_parser()
    args = parser.parse_args()
    main(args)
utils/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+
6
+ from .metrics import quat_degree_error, rotation_angular_error, translation_angular_error, error_auc
7
+ from .transform import rotation_matrix_from_ortho6d, rotation_matrix_from_quaternion
8
+ from .augment import Augmentor
9
+ # from .visualize import project_3D_points, plot_3D_box
10
+
11
+
12
+ def seed_torch(seed):
13
+ random.seed(seed)
14
+ os.environ['PYTHONHASHSEED'] = str(seed)
15
+ np.random.seed(seed)
16
+ torch.manual_seed(seed)
17
+ torch.cuda.manual_seed(seed)
18
+ torch.backends.cudnn.benchmark = False
19
+ torch.backends.cudnn.deterministic = True
utils/augment.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+
3
+
4
class Augmentor(object):
    """Photometric augmentation wrapper around an albumentations pipeline.

    The whole pipeline fires with probability 1.0 when `is_training` is True
    and 0.0 otherwise, making the augmentor a no-op at evaluation time.
    """

    def __init__(self, is_training: bool):
        transforms = [
            A.MotionBlur(p=0.25),
            A.ColorJitter(p=0.25),
            A.ImageCompression(p=0.25),
            A.ISONoise(p=0.25),
            A.ToGray(p=0.1),
        ]
        self.augmentor = A.Compose(transforms, p=float(is_training))

    def __call__(self, x):
        # x: HWC uint8 image array; returns the (possibly) augmented image.
        augmented = self.augmentor(image=x)
        return augmented['image']