Upload 53 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +35 -0
- LICENSE +201 -0
- README.md +183 -3
- assets/figures/obj_vis_gt.png +0 -0
- assets/figures/obj_vis_query.png +0 -0
- assets/figures/obj_vis_reference_labeled.png +0 -0
- assets/figures/scene5_vis_0.png +0 -0
- assets/figures/scene5_vis_1.png +0 -0
- assets/figures/scene5_vis_gt.png +0 -0
- assets/ho3d_test_3000/ho3d_test.json +0 -0
- assets/linemod_test_1500/linemod_test.json +0 -0
- assets/mapfree_submission.zip +3 -0
- assets/megadepth_test_1500_scene_info/0015_0.1_0.3.npz +3 -0
- assets/megadepth_test_1500_scene_info/0015_0.3_0.5.npz +3 -0
- assets/megadepth_test_1500_scene_info/0022_0.1_0.3.npz +3 -0
- assets/megadepth_test_1500_scene_info/0022_0.3_0.5.npz +3 -0
- assets/megadepth_test_1500_scene_info/0022_0.5_0.7.npz +3 -0
- assets/megadepth_test_1500_scene_info/megadepth_test_1500.txt +5 -0
- assets/scannet_test_1500/intrinsics.npz +3 -0
- assets/scannet_test_1500/scannet_test.txt +1 -0
- assets/scannet_test_1500/statistics.json +102 -0
- assets/scannet_test_1500/test.npz +3 -0
- baselines/matchers.py +72 -0
- baselines/pose.py +92 -0
- baselines/pose_solver.py +320 -0
- configs/default.py +85 -0
- configs/ho3d.yaml +19 -0
- configs/linemod.yaml +20 -0
- configs/mapfree.yaml +14 -0
- configs/matterport.yaml +11 -0
- configs/megadepth.yaml +29 -0
- configs/scannet.yaml +33 -0
- datasets/__init__.py +20 -0
- datasets/ho3d.py +331 -0
- datasets/linemod.py +441 -0
- datasets/mapfree.py +178 -0
- datasets/matterport.py +86 -0
- datasets/megadepth.py +125 -0
- datasets/sampler.py +77 -0
- datasets/scannet.py +154 -0
- eval.py +48 -0
- eval_add_reproj.py +138 -0
- eval_baselines.py +189 -0
- model/__init__.py +4 -0
- model/pl_trainer.py +201 -0
- model/relpose.py +465 -0
- requirements.txt +11 -0
- train.py +104 -0
- utils/__init__.py +19 -0
- utils/augment.py +15 -0
.gitignore
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**/__pycache__
|
| 2 |
+
/checkpoints
|
| 3 |
+
/log
|
| 4 |
+
/lightning_logs
|
| 5 |
+
/pyramid
|
| 6 |
+
*.ipynb
|
| 7 |
+
/data
|
| 8 |
+
preprocess_megadepth.py
|
| 9 |
+
*.ckpt
|
| 10 |
+
test.py
|
| 11 |
+
/RelPoseRepo
|
| 12 |
+
eval_regressor.py
|
| 13 |
+
/assets/megadepth_test_new
|
| 14 |
+
/configs/megadepth_new.yaml
|
| 15 |
+
/results
|
| 16 |
+
assets/new_submission.zip
|
| 17 |
+
|
| 18 |
+
__eval_baselines.py
|
| 19 |
+
__pose_tracking.py
|
| 20 |
+
__track.py
|
| 21 |
+
/__pose_tracking
|
| 22 |
+
|
| 23 |
+
/baselines/configs
|
| 24 |
+
/baselines/repo
|
| 25 |
+
/baselines/weights
|
| 26 |
+
/baselines/__models.py
|
| 27 |
+
baselines/demo.html
|
| 28 |
+
|
| 29 |
+
utils/__reprojection.py
|
| 30 |
+
utils/__pose_solver.py
|
| 31 |
+
utils/__generate_epipolar_imgs.py
|
| 32 |
+
utils/__visualize.py
|
| 33 |
+
|
| 34 |
+
submission.py
|
| 35 |
+
/qualitative
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,3 +1,183 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SRPose: Two-view Relative Pose Estimation with Sparse Keypoints
|
| 2 |
+
|
| 3 |
+
**SRPose**: A **S**parse keypoint-based framework for **R**elative **Pose** estimation between two views in both camera-to-world and object-to-camera scenarios.
|
| 4 |
+
|
| 5 |
+
| Reference | Query | Ground Truth |
|
| 6 |
+
|:--------:|:---------:|:--------:|
|
| 7 |
+
|  |  |  |
|
| 8 |
+
|  |  ||
|
| 9 |
+
|
| 10 |
+
## [Project page](https://frickyinn.github.io/srpose/) | [arXiv](https://arxiv.org/abs/2407.08199)
|
| 11 |
+
|
| 12 |
+
## Setup
|
| 13 |
+
Please first intall PyTorch according to [here](https://pytorch.org/get-started/locally/), then install other dependencies using pip:
|
| 14 |
+
```
|
| 15 |
+
cd SRPose
|
| 16 |
+
pip install -r requirements.txt
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
## Evaluation
|
| 20 |
+
1. Download pretrained models [here](https://drive.google.com/drive/folders/1bBlds3UX7-XDCevbIl4bnnywvWzzP5nN) for evaluation.
|
| 21 |
+
2. Create new folders:
|
| 22 |
+
```
|
| 23 |
+
mkdir checkpoints & mkdir data
|
| 24 |
+
```
|
| 25 |
+
3. Organize the downloaded checkpoints like this:
|
| 26 |
+
```
|
| 27 |
+
SRPose
|
| 28 |
+
|-- checkpoints
|
| 29 |
+
|-- ho3d.ckpt
|
| 30 |
+
|-- linemod.ckpt
|
| 31 |
+
|-- mapfree.ckpt
|
| 32 |
+
|-- matterport.ckpt
|
| 33 |
+
|-- megadepth.ckpt
|
| 34 |
+
`-- scannet.ckpt
|
| 35 |
+
...
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
### Matterport
|
| 39 |
+
1. Download Matterport dataset [here](https://github.com/jinlinyi/SparsePlanes/blob/main/docs/data.md), only `mp3d_planercnn_json.zip` and `rgb.zip` are required.
|
| 40 |
+
2. Unzip and organize the downloaded files:
|
| 41 |
+
```
|
| 42 |
+
mkdir data/mp3d
|
| 43 |
+
mkdir data/mp3d/mp3d_planercnn_json & mkdir data/mp3d/rgb
|
| 44 |
+
unzip <pathto>/mp3d_planercnn_json.zip -d data/mp3d/mp3d_planercnn_json
|
| 45 |
+
unzip <pathto>/rgb.zip -d data/mp3d/rgb
|
| 46 |
+
```
|
| 47 |
+
3. The resulted directory tree should be like this:
|
| 48 |
+
```
|
| 49 |
+
SRPose
|
| 50 |
+
|-- data
|
| 51 |
+
|-- mp3d
|
| 52 |
+
|-- mp3d_planercnn_json
|
| 53 |
+
| |-- cached_set_test.json
|
| 54 |
+
| |-- cached_set_train.json
|
| 55 |
+
| `-- cached_set_val.json
|
| 56 |
+
`-- rgb
|
| 57 |
+
|-- 17DRP5sb8fy
|
| 58 |
+
...
|
| 59 |
+
...
|
| 60 |
+
...
|
| 61 |
+
```
|
| 62 |
+
4. Evaluate with the following command:
|
| 63 |
+
```
|
| 64 |
+
python eval.py configs/matterport.yaml checkpoints/matterport.ckpt
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### ScanNet & MegaDepth
|
| 68 |
+
1. Download and organize the ScanNet-1500 and MegaDepth-1500 test sets according to the [LoFTR Training Script](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md). Note that only the test sets and the dataset indices are required.
|
| 69 |
+
2. The resulted directory tree should be:
|
| 70 |
+
```
|
| 71 |
+
SRPose
|
| 72 |
+
|-- data
|
| 73 |
+
|-- scannet
|
| 74 |
+
| |-- index
|
| 75 |
+
| |-- test
|
| 76 |
+
| `-- train (optional)
|
| 77 |
+
|-- megadepth
|
| 78 |
+
|-- index
|
| 79 |
+
|-- test
|
| 80 |
+
`-- train (optional)
|
| 81 |
+
...
|
| 82 |
+
...
|
| 83 |
+
```
|
| 84 |
+
3. Evaluate with the following commands:
|
| 85 |
+
```
|
| 86 |
+
python eval.py configs/scannet.yaml checkpoints/scannet.ckpt
|
| 87 |
+
python eval.py configs/megadepth.yaml checkpoints/megedepth.ckpt
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### HO3D
|
| 91 |
+
1. Download HO3D (version 3) dataset [here](https://www.tugraz.at/institute/icg/research/team-lepetit/research-projects/hand-object-3d-pose-annotation/), `HO3D_v3.zip` and `HO3D_v3_segmentations_rendered.zip` are required.
|
| 92 |
+
2. Unzip and organize the downloaded files:
|
| 93 |
+
```
|
| 94 |
+
mkdir data/ho3d
|
| 95 |
+
unzip <pathto>/HO3D_v3.zip -d data/ho3d
|
| 96 |
+
unzip <pathto>/HO3D_v3_segmentations_rendered.zip -d data/ho3d
|
| 97 |
+
```
|
| 98 |
+
3. Evaluate with the following commands:
|
| 99 |
+
```
|
| 100 |
+
python eval.py configs/ho3d.yaml checkpoints/ho3d.ckpt
|
| 101 |
+
python eval_add_reproj.py configs/ho3d.yaml checkpoints/ho3d.ckpt
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
### Linemod
|
| 105 |
+
1. Download Linemod dataset [here](https://bop.felk.cvut.cz/datasets/) or run the following commands:
|
| 106 |
+
```
|
| 107 |
+
cd data
|
| 108 |
+
|
| 109 |
+
export SRC=https://bop.felk.cvut.cz/media/data/bop_datasets
|
| 110 |
+
wget $SRC/lm_base.zip # Base archive with dataset info, camera parameters, etc.
|
| 111 |
+
wget $SRC/lm_models.zip # 3D object models.
|
| 112 |
+
wget $SRC/lm_test_all.zip # All test images ("_bop19" for a subset used in the BOP Challenge 2019/2020).
|
| 113 |
+
wget $SRC/lm_train_pbr.zip # PBR training images (rendered with BlenderProc4BOP).
|
| 114 |
+
|
| 115 |
+
unzip lm_base.zip # Contains folder "lm".
|
| 116 |
+
unzip lm_models.zip -d lm # Unpacks to "lm".
|
| 117 |
+
unzip lm_test_all.zip -d lm # Unpacks to "lm".
|
| 118 |
+
unzip lm_train_pbr.zip -d lm # Unpacks to "lm".
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
2. Evaluate with the following commands:
|
| 122 |
+
```
|
| 123 |
+
python eval.py configs/linemod.yaml checkpoints/linemod.ckpt
|
| 124 |
+
python eval_add_reproj.py configs/linemod.yaml checkpoints/linemod.ckpt
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
### Niantic
|
| 128 |
+
1. Download Niantic dataset [here](https://research.nianticlabs.com/mapfree-reloc-benchmark/dataset).
|
| 129 |
+
2. Unzip and organize the downloaded files:
|
| 130 |
+
```
|
| 131 |
+
mkdir data/mapfree
|
| 132 |
+
unzip <pathto>/train.zip -d data/mapfree
|
| 133 |
+
unzip <pathto>/val.zip -d data/mapfree
|
| 134 |
+
unzip <pathto>/test.zip -d data/mapfree
|
| 135 |
+
```
|
| 136 |
+
3. The ground truth of the test set is not publicly available, but you can run the following command to produce a new submission file and submit it on the [project page](https://research.nianticlabs.com/mapfree-reloc-benchmark/submit) for evaluation:
|
| 137 |
+
```
|
| 138 |
+
python eval_add_reproj.py configs/mapfree.yaml checkpoints/mapfree.ckpt
|
| 139 |
+
```
|
| 140 |
+
You should be able to find a `new_submission.zip` in `SRPose/assets/` afterwards, or you can submit the already produced file `SRPose/assets/mapfree_submission.zip` instead.
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
## Training
|
| 144 |
+
Download and organize the datasets following [Evaluation](#evaluation), then run the following command for training:
|
| 145 |
+
```
|
| 146 |
+
python train.py configs/<dataset>.yaml
|
| 147 |
+
```
|
| 148 |
+
Please refer to the `.yaml` files in `SRPose/configs/` for detailed configurations.
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
## Baselines
|
| 152 |
+
We also offer two publicly available matcher-based baselines, [LightGlue](https://github.com/cvg/LightGlue) and [LoFTR](https://github.com/zju3dv/LoFTR), for evaluation and comparison.
|
| 153 |
+
Just run the following commands:
|
| 154 |
+
```
|
| 155 |
+
# For Matterport, ScanNet and MegaDepth
|
| 156 |
+
python eval_baselines.py configs/<dataset>.yaml lightglue
|
| 157 |
+
python eval_baselines.py configs/<dataset>.yaml loftr
|
| 158 |
+
|
| 159 |
+
# For HO3D and Linemod
|
| 160 |
+
python eval_baselines.py configs/<dataset>.yaml lightglue --resize 640 --depth
|
| 161 |
+
python eval_baselines.py configs/<dataset>.yaml loftr --resize 640 --depth
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
The `--resize xx` option controls the larger dimension of cropped target object images that will be resized to.
|
| 165 |
+
The `--depth` option controls whether the depth maps will be used to obtain scaled pose estimation.
|
| 166 |
+
|
| 167 |
+
## Acknowledgements
|
| 168 |
+
In this repository, we have used codes from the following repositories. We thank all the authors for sharing great codes.
|
| 169 |
+
- [LightGlue](https://github.com/cvg/LightGlue)
|
| 170 |
+
- [LoFTR](https://github.com/zju3dv/LoFTR)
|
| 171 |
+
- [8point](https://github.com/crockwell/rel_pose)
|
| 172 |
+
- [SparsePlanes](https://github.com/jinlinyi/SparsePlanes/tree/main)
|
| 173 |
+
- [Map-free](https://github.com/nianticlabs/map-free-reloc/tree/main)
|
| 174 |
+
|
| 175 |
+
## Citation
|
| 176 |
+
```
|
| 177 |
+
@inproceedings{yin2024srpose,
|
| 178 |
+
title={SRPose: Two-view Relative Pose Estimation with Sparse Keypoints},
|
| 179 |
+
author={Yin, Rui and Zhang, Yulun and Pan, Zherong and Zhu, Jianjun and Wang, Cheng and Jia, Biao},
|
| 180 |
+
booktitle={ECCV},
|
| 181 |
+
year={2024}
|
| 182 |
+
}
|
| 183 |
+
```
|
assets/figures/obj_vis_gt.png
ADDED
|
assets/figures/obj_vis_query.png
ADDED
|
assets/figures/obj_vis_reference_labeled.png
ADDED
|
assets/figures/scene5_vis_0.png
ADDED
|
assets/figures/scene5_vis_1.png
ADDED
|
assets/figures/scene5_vis_gt.png
ADDED
|
assets/ho3d_test_3000/ho3d_test.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
assets/linemod_test_1500/linemod_test.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
assets/mapfree_submission.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:026e799de5bf9eed2a1f627e64d6981f983cb729337036390481c402c47fbc5c
|
| 3 |
+
size 6569663
|
assets/megadepth_test_1500_scene_info/0015_0.1_0.3.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d441df1d380b2ed34449b944d9f13127e695542fa275098d38a6298835672f22
|
| 3 |
+
size 231253
|
assets/megadepth_test_1500_scene_info/0015_0.3_0.5.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5f34b5231d04a84d84378c671dd26854869663b5eafeae2ebaf624a279325139
|
| 3 |
+
size 231253
|
assets/megadepth_test_1500_scene_info/0022_0.1_0.3.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ba46e6b9ec291fc7271eb9741d5c75ca04b83d3d7281e049815de9cb9024f4d9
|
| 3 |
+
size 272610
|
assets/megadepth_test_1500_scene_info/0022_0.3_0.5.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f4465da174b96deba61e5328886e4f2e687d34b890efca69e0c838736f8ae12
|
| 3 |
+
size 272610
|
assets/megadepth_test_1500_scene_info/0022_0.5_0.7.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:684ae10f03001917c3ca0d12d441f372ce3c7e6637bd1277a3cda60df4207fe9
|
| 3 |
+
size 272610
|
assets/megadepth_test_1500_scene_info/megadepth_test_1500.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
0022_0.1_0.3
|
| 2 |
+
0015_0.1_0.3
|
| 3 |
+
0015_0.3_0.5
|
| 4 |
+
0022_0.3_0.5
|
| 5 |
+
0022_0.5_0.7
|
assets/scannet_test_1500/intrinsics.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:25ac102c69e2e4e2f0ab9c0d64f4da2b815e0901630768bdfde30080ced3605c
|
| 3 |
+
size 23922
|
assets/scannet_test_1500/scannet_test.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
test.npz
|
assets/scannet_test_1500/statistics.json
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"scene0707_00": 15,
|
| 3 |
+
"scene0708_00": 15,
|
| 4 |
+
"scene0709_00": 15,
|
| 5 |
+
"scene0710_00": 15,
|
| 6 |
+
"scene0711_00": 15,
|
| 7 |
+
"scene0712_00": 15,
|
| 8 |
+
"scene0713_00": 15,
|
| 9 |
+
"scene0714_00": 15,
|
| 10 |
+
"scene0715_00": 15,
|
| 11 |
+
"scene0716_00": 15,
|
| 12 |
+
"scene0717_00": 15,
|
| 13 |
+
"scene0718_00": 15,
|
| 14 |
+
"scene0719_00": 15,
|
| 15 |
+
"scene0720_00": 15,
|
| 16 |
+
"scene0721_00": 15,
|
| 17 |
+
"scene0722_00": 15,
|
| 18 |
+
"scene0723_00": 15,
|
| 19 |
+
"scene0724_00": 15,
|
| 20 |
+
"scene0725_00": 15,
|
| 21 |
+
"scene0726_00": 15,
|
| 22 |
+
"scene0727_00": 15,
|
| 23 |
+
"scene0728_00": 15,
|
| 24 |
+
"scene0729_00": 15,
|
| 25 |
+
"scene0730_00": 15,
|
| 26 |
+
"scene0731_00": 15,
|
| 27 |
+
"scene0732_00": 15,
|
| 28 |
+
"scene0733_00": 15,
|
| 29 |
+
"scene0734_00": 15,
|
| 30 |
+
"scene0735_00": 15,
|
| 31 |
+
"scene0736_00": 15,
|
| 32 |
+
"scene0737_00": 15,
|
| 33 |
+
"scene0738_00": 15,
|
| 34 |
+
"scene0739_00": 15,
|
| 35 |
+
"scene0740_00": 15,
|
| 36 |
+
"scene0741_00": 15,
|
| 37 |
+
"scene0742_00": 15,
|
| 38 |
+
"scene0743_00": 15,
|
| 39 |
+
"scene0744_00": 15,
|
| 40 |
+
"scene0745_00": 15,
|
| 41 |
+
"scene0746_00": 15,
|
| 42 |
+
"scene0747_00": 15,
|
| 43 |
+
"scene0748_00": 15,
|
| 44 |
+
"scene0749_00": 15,
|
| 45 |
+
"scene0750_00": 15,
|
| 46 |
+
"scene0751_00": 15,
|
| 47 |
+
"scene0752_00": 15,
|
| 48 |
+
"scene0753_00": 15,
|
| 49 |
+
"scene0754_00": 15,
|
| 50 |
+
"scene0755_00": 15,
|
| 51 |
+
"scene0756_00": 15,
|
| 52 |
+
"scene0757_00": 15,
|
| 53 |
+
"scene0758_00": 15,
|
| 54 |
+
"scene0759_00": 15,
|
| 55 |
+
"scene0760_00": 15,
|
| 56 |
+
"scene0761_00": 15,
|
| 57 |
+
"scene0762_00": 15,
|
| 58 |
+
"scene0763_00": 15,
|
| 59 |
+
"scene0764_00": 15,
|
| 60 |
+
"scene0765_00": 15,
|
| 61 |
+
"scene0766_00": 15,
|
| 62 |
+
"scene0767_00": 15,
|
| 63 |
+
"scene0768_00": 15,
|
| 64 |
+
"scene0769_00": 15,
|
| 65 |
+
"scene0770_00": 15,
|
| 66 |
+
"scene0771_00": 15,
|
| 67 |
+
"scene0772_00": 15,
|
| 68 |
+
"scene0773_00": 15,
|
| 69 |
+
"scene0774_00": 15,
|
| 70 |
+
"scene0775_00": 15,
|
| 71 |
+
"scene0776_00": 15,
|
| 72 |
+
"scene0777_00": 15,
|
| 73 |
+
"scene0778_00": 15,
|
| 74 |
+
"scene0779_00": 15,
|
| 75 |
+
"scene0780_00": 15,
|
| 76 |
+
"scene0781_00": 15,
|
| 77 |
+
"scene0782_00": 15,
|
| 78 |
+
"scene0783_00": 15,
|
| 79 |
+
"scene0784_00": 15,
|
| 80 |
+
"scene0785_00": 15,
|
| 81 |
+
"scene0786_00": 15,
|
| 82 |
+
"scene0787_00": 15,
|
| 83 |
+
"scene0788_00": 15,
|
| 84 |
+
"scene0789_00": 15,
|
| 85 |
+
"scene0790_00": 15,
|
| 86 |
+
"scene0791_00": 15,
|
| 87 |
+
"scene0792_00": 15,
|
| 88 |
+
"scene0793_00": 15,
|
| 89 |
+
"scene0794_00": 15,
|
| 90 |
+
"scene0795_00": 15,
|
| 91 |
+
"scene0796_00": 15,
|
| 92 |
+
"scene0797_00": 15,
|
| 93 |
+
"scene0798_00": 15,
|
| 94 |
+
"scene0799_00": 15,
|
| 95 |
+
"scene0800_00": 15,
|
| 96 |
+
"scene0801_00": 15,
|
| 97 |
+
"scene0802_00": 15,
|
| 98 |
+
"scene0803_00": 15,
|
| 99 |
+
"scene0804_00": 15,
|
| 100 |
+
"scene0805_00": 15,
|
| 101 |
+
"scene0806_00": 15
|
| 102 |
+
}
|
assets/scannet_test_1500/test.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b982b9c1f762e7d31af552ecc1ccf1a6add013197f74ec69c84a6deaa6f580ad
|
| 3 |
+
size 71687
|
baselines/matchers.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
from lightglue import LightGlue as LightGlue_
|
| 5 |
+
from lightglue import SuperPoint
|
| 6 |
+
from lightglue.utils import rbd
|
| 7 |
+
from kornia.feature import LoFTR as LoFTR_
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def image_rgb2gray(image):
|
| 11 |
+
# in: torch.tensor - (3, H, W)
|
| 12 |
+
# out: (1, H, W)
|
| 13 |
+
image = image[0] * 0.3 + image[1] * 0.59 + image[2] * 0.11
|
| 14 |
+
return image[None]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class LightGlue():
|
| 18 |
+
def __init__(self, num_keypoints=2048, device='cuda'):
|
| 19 |
+
self.extractor = SuperPoint(max_num_keypoints=num_keypoints).eval().to(device) # load the extractor
|
| 20 |
+
self.matcher = LightGlue_(features='superpoint').eval().to(device) # load the matcher
|
| 21 |
+
self.device = device
|
| 22 |
+
|
| 23 |
+
@torch.no_grad()
|
| 24 |
+
def match(self, image0, image1):
|
| 25 |
+
start_time = time.time()
|
| 26 |
+
|
| 27 |
+
# image: torch.tensor - (3, H, W)
|
| 28 |
+
image0 = image0.to(self.device)
|
| 29 |
+
image1 = image1.to(self.device)
|
| 30 |
+
|
| 31 |
+
preprocess_time = time.time()
|
| 32 |
+
|
| 33 |
+
# extract local features
|
| 34 |
+
feats0 = self.extractor.extract(image0) # auto-resize the image, disable with resize=None
|
| 35 |
+
feats1 = self.extractor.extract(image1)
|
| 36 |
+
|
| 37 |
+
extract_time = time.time()
|
| 38 |
+
|
| 39 |
+
# match the features
|
| 40 |
+
matches01 = self.matcher({'image0': feats0, 'image1': feats1})
|
| 41 |
+
feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]] # remove batch dimension
|
| 42 |
+
matches = matches01['matches'] # indices with shape (K,2)
|
| 43 |
+
points0 = feats0['keypoints'][matches[..., 0]] # coordinates in image #0, shape (K,2)
|
| 44 |
+
points1 = feats1['keypoints'][matches[..., 1]] # coordinates in image #1, shape (K,2)
|
| 45 |
+
|
| 46 |
+
match_time = time.time()
|
| 47 |
+
|
| 48 |
+
return points0, points1, preprocess_time-start_time, extract_time-preprocess_time, match_time-extract_time
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class LoFTR():
|
| 52 |
+
def __init__(self, pretrained='indoor', device='cuda'):
|
| 53 |
+
self.loftr = LoFTR_(pretrained=pretrained).eval().to(device)
|
| 54 |
+
self.device = device
|
| 55 |
+
|
| 56 |
+
@torch.no_grad()
|
| 57 |
+
def match(self, image0, image1):
|
| 58 |
+
start_time = time.time()
|
| 59 |
+
|
| 60 |
+
# image: torch.tensor - (3, H, W)
|
| 61 |
+
image0 = image_rgb2gray(image0)[None].to(self.device)
|
| 62 |
+
image1 = image_rgb2gray(image1)[None].to(self.device)
|
| 63 |
+
|
| 64 |
+
preprocess_time = time.time()
|
| 65 |
+
|
| 66 |
+
extract_time = time.time()
|
| 67 |
+
|
| 68 |
+
out = self.loftr({'image0': image0, 'image1': image1})
|
| 69 |
+
points0, points1 = out['keypoints0'], out['keypoints1']
|
| 70 |
+
|
| 71 |
+
match_time = time.time()
|
| 72 |
+
return points0, points1, preprocess_time-start_time, extract_time-preprocess_time, match_time-extract_time
|
baselines/pose.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision.transforms import Resize
|
| 3 |
+
|
| 4 |
+
from .matchers import LightGlue, LoFTR
|
| 5 |
+
# from .__models import SuperGlue, SGMNet, ASpanFormer, DKM
|
| 6 |
+
from .pose_solver import EssentialMatrixSolver, EssentialMatrixMetricSolver, PnPSolver, ProcrustesSolver
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PoseRecover():
|
| 12 |
+
def __init__(self, matcher='lightglue', solver='procrustes', img_resize=None, device='cuda'):
|
| 13 |
+
self.device = device
|
| 14 |
+
|
| 15 |
+
if matcher == 'lightglue':
|
| 16 |
+
self.matcher = LightGlue(device=device)
|
| 17 |
+
elif matcher == 'loftr':
|
| 18 |
+
self.matcher = LoFTR(device=device)
|
| 19 |
+
# elif matcher == 'superglue':
|
| 20 |
+
# self.matcher = SuperGlue(device=device)
|
| 21 |
+
# elif matcher == 'aspanformer':
|
| 22 |
+
# self.matcher = ASpanFormer(device=device)
|
| 23 |
+
# elif matcher == 'sgmnet':
|
| 24 |
+
# self.matcher = SGMNet(device=device)
|
| 25 |
+
# elif matcher == 'dkm':
|
| 26 |
+
# self.matcher = DKM(device=device)
|
| 27 |
+
else:
|
| 28 |
+
raise NotImplementedError
|
| 29 |
+
|
| 30 |
+
self.img_resize = img_resize
|
| 31 |
+
|
| 32 |
+
self.basic_solver = EssentialMatrixSolver()
|
| 33 |
+
|
| 34 |
+
if solver == 'essential':
|
| 35 |
+
self.scaled_solver = EssentialMatrixMetricSolver()
|
| 36 |
+
elif solver == 'pnp':
|
| 37 |
+
self.scaled_solver = PnPSolver()
|
| 38 |
+
elif solver == 'procrustes':
|
| 39 |
+
self.scaled_solver = ProcrustesSolver()
|
| 40 |
+
|
| 41 |
+
def recover(self, image0, image1, K0, K1, bbox0=None, bbox1=None, mask0=None, mask1=None, depth0=None, depth1=None):
|
| 42 |
+
if self.img_resize is not None:
|
| 43 |
+
h, w = image0.shape[-2:]
|
| 44 |
+
if h > w:
|
| 45 |
+
h_new = self.img_resize
|
| 46 |
+
w_new = int(w * h_new / h)
|
| 47 |
+
else:
|
| 48 |
+
w_new = self.img_resize
|
| 49 |
+
h_new = int(h * w_new / w)
|
| 50 |
+
|
| 51 |
+
# h_new, w_new = 480, 640
|
| 52 |
+
resize = Resize((h_new, w_new), antialias=True)
|
| 53 |
+
scale0 = torch.tensor([image0.shape[-1]/w_new, image0.shape[-2]/h_new], dtype=torch.float)
|
| 54 |
+
scale1 = torch.tensor([image1.shape[-1]/w_new, image1.shape[-2]/h_new], dtype=torch.float)
|
| 55 |
+
image0 = resize(image0)
|
| 56 |
+
image1 = resize(image1)
|
| 57 |
+
|
| 58 |
+
points0, points1, preprocess_time, extract_time, match_time = self.matcher.match(image0, image1)
|
| 59 |
+
|
| 60 |
+
if self.img_resize is not None:
|
| 61 |
+
points0 *= scale0.unsqueeze(0).to(points0.device)
|
| 62 |
+
points1 *= scale1.unsqueeze(0).to(points1.device)
|
| 63 |
+
|
| 64 |
+
if bbox0 is not None and bbox1 is not None:
|
| 65 |
+
x1, y1, x2, y2 = bbox0
|
| 66 |
+
u1, v1, u2, v2 = bbox1
|
| 67 |
+
|
| 68 |
+
points0[:, 0] += x1
|
| 69 |
+
points0[:, 1] += y1
|
| 70 |
+
|
| 71 |
+
points1[:, 0] += u1
|
| 72 |
+
points1[:, 1] += v1
|
| 73 |
+
|
| 74 |
+
if mask0 is not None and mask1 is not None:
|
| 75 |
+
filtered_ind0 = mask0[(points0[:, 1]).int(), (points0[:, 0]).int()]
|
| 76 |
+
filtered_ind1 = mask1[(points1[:, 1]).int(), (points1[:, 0]).int()]
|
| 77 |
+
filtered_inds = filtered_ind0 * filtered_ind1
|
| 78 |
+
points0 = points0[filtered_inds]
|
| 79 |
+
points1 = points1[filtered_inds]
|
| 80 |
+
|
| 81 |
+
points0, points1 = points0.cpu().numpy(), points1.cpu().numpy()
|
| 82 |
+
|
| 83 |
+
start_time = time.time()
|
| 84 |
+
|
| 85 |
+
if depth0 is None or depth1 is None:
|
| 86 |
+
R_est, t_est, _ = self.basic_solver.estimate_pose(points0, points1, {'K_color0': K0, 'K_color1': K1})
|
| 87 |
+
else:
|
| 88 |
+
R_est, t_est, _ = self.scaled_solver.estimate_pose(points0, points1, {'K_color0': K0, 'K_color1': K1, 'depth0': depth0, 'depth1': depth1})
|
| 89 |
+
|
| 90 |
+
recover_time = time.time()
|
| 91 |
+
|
| 92 |
+
return R_est, t_est, points0, points1, preprocess_time, extract_time, match_time, recover_time-start_time
|
baselines/pose_solver.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2 as cv
|
| 3 |
+
import open3d as o3d
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def backproject_3d(uv, depth, K):
|
| 7 |
+
'''
|
| 8 |
+
Backprojects 2d points given by uv coordinates into 3D using their depth values and intrinsic K
|
| 9 |
+
:param uv: array [N,2]
|
| 10 |
+
:param depth: array [N]
|
| 11 |
+
:param K: array [3,3]
|
| 12 |
+
:return: xyz: array [N,3]
|
| 13 |
+
'''
|
| 14 |
+
|
| 15 |
+
uv1 = np.concatenate([uv, np.ones((uv.shape[0], 1))], axis=1)
|
| 16 |
+
xyz = depth.reshape(-1, 1) * (np.linalg.inv(K) @ uv1.T).T
|
| 17 |
+
return xyz
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class EssentialMatrixSolver:
|
| 21 |
+
'''Obtain relative pose (up to scale) given a set of 2D-2D correspondences'''
|
| 22 |
+
|
| 23 |
+
def __init__(self, ransac_pix_threshold=0.5, ransac_confidence=0.99999):
|
| 24 |
+
|
| 25 |
+
# EMat RANSAC parameters
|
| 26 |
+
self.ransac_pix_threshold = ransac_pix_threshold
|
| 27 |
+
self.ransac_confidence = ransac_confidence
|
| 28 |
+
|
| 29 |
+
def estimate_pose(self, kpts0, kpts1, data):
|
| 30 |
+
R = np.full((3, 3), np.nan)
|
| 31 |
+
t = np.full((3), np.nan)
|
| 32 |
+
if len(kpts0) < 5:
|
| 33 |
+
return R, t, 0
|
| 34 |
+
|
| 35 |
+
K0 = data['K_color0'].numpy()
|
| 36 |
+
K1 = data['K_color1'].numpy()
|
| 37 |
+
|
| 38 |
+
# normalize keypoints
|
| 39 |
+
kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
|
| 40 |
+
kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
|
| 41 |
+
|
| 42 |
+
# normalize ransac threshold
|
| 43 |
+
ransac_thr = self.ransac_pix_threshold / np.mean([K0[0, 0], K1[1, 1], K0[1, 1], K1[0, 0]])
|
| 44 |
+
|
| 45 |
+
# compute pose with OpenCV
|
| 46 |
+
E, mask = cv.findEssentialMat(
|
| 47 |
+
kpts0, kpts1, np.eye(3),
|
| 48 |
+
threshold=ransac_thr, prob=self.ransac_confidence, method=cv.RANSAC)
|
| 49 |
+
self.mask = mask
|
| 50 |
+
if E is None:
|
| 51 |
+
return R, t, 0
|
| 52 |
+
|
| 53 |
+
# recover pose from E
|
| 54 |
+
best_num_inliers = 0
|
| 55 |
+
ret = R, t, 0
|
| 56 |
+
for _E in np.split(E, len(E) / 3):
|
| 57 |
+
n, R, t, _ = cv.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
|
| 58 |
+
if n > best_num_inliers:
|
| 59 |
+
best_num_inliers = n
|
| 60 |
+
ret = (R, t[:, 0], n)
|
| 61 |
+
return ret
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class EssentialMatrixMetricSolverMEAN(EssentialMatrixSolver):
|
| 65 |
+
'''Obtains relative pose with scale using E-Mat decomposition and depth values at inlier correspondences'''
|
| 66 |
+
|
| 67 |
+
def __init__(self, cfg):
|
| 68 |
+
super().__init__(cfg)
|
| 69 |
+
|
| 70 |
+
def estimate_pose(self, kpts0, kpts1, data):
|
| 71 |
+
'''Estimates metric translation vector using by back-projecting E-mat inliers to 3D using depthmaps.
|
| 72 |
+
The metric translation vector can be obtained by looking at the residual vector (projected to the translation vector direction).
|
| 73 |
+
In this version, each 3D-3D correspondence gives an optimal scale for the translation vector.
|
| 74 |
+
We simply aggregate them by averaging them.
|
| 75 |
+
'''
|
| 76 |
+
|
| 77 |
+
# get pose up to scale
|
| 78 |
+
R, t, inliers = super().estimate_pose(kpts0, kpts1, data)
|
| 79 |
+
if inliers == 0:
|
| 80 |
+
return R, t, inliers
|
| 81 |
+
|
| 82 |
+
# backproject E-mat inliers at each camera
|
| 83 |
+
K0 = data['K_color0']
|
| 84 |
+
K1 = data['K_color1']
|
| 85 |
+
mask = self.mask.ravel() == 1 # get E-mat inlier mask from super class
|
| 86 |
+
inliers_kpts0 = np.int32(kpts0[mask])
|
| 87 |
+
inliers_kpts1 = np.int32(kpts1[mask])
|
| 88 |
+
depth_inliers_0 = data['depth0'][inliers_kpts0[:, 1], inliers_kpts0[:, 0]].numpy()
|
| 89 |
+
depth_inliers_1 = data['depth1'][inliers_kpts1[:, 1], inliers_kpts1[:, 0]].numpy()
|
| 90 |
+
# check for valid depth
|
| 91 |
+
valid = (depth_inliers_0 > 0) * (depth_inliers_1 > 0)
|
| 92 |
+
if valid.sum() < 1:
|
| 93 |
+
R = np.full((3, 3), np.nan)
|
| 94 |
+
t = np.full((3, 1), np.nan)
|
| 95 |
+
inliers = 0
|
| 96 |
+
return R, t, inliers
|
| 97 |
+
xyz0 = backproject_3d(inliers_kpts0[valid], depth_inliers_0[valid], K0)
|
| 98 |
+
xyz1 = backproject_3d(inliers_kpts1[valid], depth_inliers_1[valid], K1)
|
| 99 |
+
|
| 100 |
+
# rotate xyz0 to xyz1 CS (so that axes are parallel)
|
| 101 |
+
xyz0 = (R @ xyz0.T).T
|
| 102 |
+
|
| 103 |
+
# get average point for each camera
|
| 104 |
+
pmean0 = np.mean(xyz0, axis=0)
|
| 105 |
+
pmean1 = np.mean(xyz1, axis=0)
|
| 106 |
+
|
| 107 |
+
# find scale as the 'length' of the translation vector that minimises the 3D distance between projected points from 0 and the corresponding points in 1
|
| 108 |
+
scale = np.dot(pmean1 - pmean0, t)
|
| 109 |
+
t_metric = scale * t
|
| 110 |
+
t_metric = t_metric.reshape(3, 1)
|
| 111 |
+
|
| 112 |
+
return R, t_metric[:, 0], inliers
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class EssentialMatrixMetricSolver(EssentialMatrixSolver):
|
| 116 |
+
'''
|
| 117 |
+
Obtains relative pose with scale using E-Mat decomposition and RANSAC for scale based on depth values at inlier correspondences.
|
| 118 |
+
The scale of the translation vector is obtained using RANSAC over the possible scales recovered from 3D-3D correspondences.
|
| 119 |
+
'''
|
| 120 |
+
|
| 121 |
+
def __init__(self, ransac_pix_threshold=0.5, ransac_confidence=0.99999, ransac_scale_threshold=0.1):
|
| 122 |
+
super().__init__(ransac_pix_threshold, ransac_confidence)
|
| 123 |
+
self.ransac_scale_threshold = ransac_scale_threshold
|
| 124 |
+
|
| 125 |
+
def estimate_pose(self, kpts0, kpts1, data):
|
| 126 |
+
'''Estimates metric translation vector using by back-projecting E-mat inliers to 3D using depthmaps.
|
| 127 |
+
'''
|
| 128 |
+
|
| 129 |
+
# get pose up to scale
|
| 130 |
+
R, t, inliers = super().estimate_pose(kpts0, kpts1, data)
|
| 131 |
+
if inliers == 0:
|
| 132 |
+
return R, t, inliers
|
| 133 |
+
|
| 134 |
+
# backproject E-mat inliers at each camera
|
| 135 |
+
K0 = data['K_color0']
|
| 136 |
+
K1 = data['K_color1']
|
| 137 |
+
mask = self.mask.ravel() == 1 # get E-mat inlier mask from super class
|
| 138 |
+
inliers_kpts0 = np.int32(kpts0[mask])
|
| 139 |
+
inliers_kpts1 = np.int32(kpts1[mask])
|
| 140 |
+
depth_inliers_0 = data['depth0'][inliers_kpts0[:, 1], inliers_kpts0[:, 0]].numpy()
|
| 141 |
+
depth_inliers_1 = data['depth1'][inliers_kpts1[:, 1], inliers_kpts1[:, 0]].numpy()
|
| 142 |
+
|
| 143 |
+
# check for valid depth
|
| 144 |
+
valid = (depth_inliers_0 > 0) * (depth_inliers_1 > 0)
|
| 145 |
+
if valid.sum() < 1:
|
| 146 |
+
R = np.full((3, 3), np.nan)
|
| 147 |
+
t = np.full((3, ), np.nan)
|
| 148 |
+
inliers = 0
|
| 149 |
+
return R, t, inliers
|
| 150 |
+
xyz0 = backproject_3d(inliers_kpts0[valid], depth_inliers_0[valid], K0)
|
| 151 |
+
xyz1 = backproject_3d(inliers_kpts1[valid], depth_inliers_1[valid], K1)
|
| 152 |
+
|
| 153 |
+
# rotate xyz0 to xyz1 CS (so that axes are parallel)
|
| 154 |
+
xyz0 = (R @ xyz0.T).T
|
| 155 |
+
|
| 156 |
+
# get individual scales (for each 3D-3D correspondence)
|
| 157 |
+
scale = np.dot(xyz1 - xyz0, t.reshape(3, 1)) # [N, 1]
|
| 158 |
+
|
| 159 |
+
# RANSAC loop
|
| 160 |
+
best_inliers = 0
|
| 161 |
+
best_scale = None
|
| 162 |
+
for scale_hyp in scale:
|
| 163 |
+
inliers_hyp = (np.abs(scale - scale_hyp) < self.ransac_scale_threshold).sum().item()
|
| 164 |
+
if inliers_hyp > best_inliers:
|
| 165 |
+
best_scale = scale_hyp
|
| 166 |
+
best_inliers = inliers_hyp
|
| 167 |
+
|
| 168 |
+
# Output results
|
| 169 |
+
t_metric = best_scale * t
|
| 170 |
+
t_metric = t_metric.reshape(3, 1)
|
| 171 |
+
|
| 172 |
+
return R, t_metric[:, 0], best_inliers
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class PnPSolver:
|
| 176 |
+
'''Estimate relative pose (metric) using Perspective-n-Point algorithm (2D-3D) correspondences'''
|
| 177 |
+
|
| 178 |
+
def __init__(self, ransac_iterations=1000, reprojection_inlier_threshold=3, confidence=0.99999):
|
| 179 |
+
# PnP RANSAC parameters
|
| 180 |
+
self.ransac_iterations = ransac_iterations
|
| 181 |
+
self.reprojection_inlier_threshold = reprojection_inlier_threshold
|
| 182 |
+
self.confidence = confidence
|
| 183 |
+
|
| 184 |
+
def estimate_pose(self, pts0, pts1, data):
|
| 185 |
+
# uses nearest neighbour
|
| 186 |
+
pts0 = np.int32(pts0)
|
| 187 |
+
|
| 188 |
+
if len(pts0) < 4:
|
| 189 |
+
return np.full((3, 3), np.nan), np.full((3, 1), np.nan), 0
|
| 190 |
+
|
| 191 |
+
# get depth at correspondence points
|
| 192 |
+
depth_0 = data['depth0']
|
| 193 |
+
depth_pts0 = depth_0[pts0[:, 1], pts0[:, 0]]
|
| 194 |
+
|
| 195 |
+
# remove invalid pts (depth == 0)
|
| 196 |
+
valid = depth_pts0 > depth_0.min()
|
| 197 |
+
if valid.sum() < 4:
|
| 198 |
+
return np.full((3, 3), np.nan), np.full((3, 1), np.nan), 0
|
| 199 |
+
pts0 = pts0[valid]
|
| 200 |
+
pts1 = pts1[valid]
|
| 201 |
+
depth_pts0 = depth_pts0[valid]
|
| 202 |
+
|
| 203 |
+
# backproject points to 3D in each sensors' local coordinates
|
| 204 |
+
K0 = data['K_color0']
|
| 205 |
+
K1 = data['K_color1']
|
| 206 |
+
xyz_0 = backproject_3d(pts0, depth_pts0, K0).numpy()
|
| 207 |
+
|
| 208 |
+
# get relative pose using PnP + RANSAC
|
| 209 |
+
succ, rvec, tvec, inliers = cv.solvePnPRansac(
|
| 210 |
+
xyz_0, pts1, K1.numpy(),
|
| 211 |
+
None, iterationsCount=self.ransac_iterations,
|
| 212 |
+
reprojectionError=self.reprojection_inlier_threshold, confidence=self.confidence,
|
| 213 |
+
flags=cv.SOLVEPNP_P3P)
|
| 214 |
+
|
| 215 |
+
# refine with iterative PnP using inliers only
|
| 216 |
+
if succ and len(inliers) >= 6:
|
| 217 |
+
succ, rvec, tvec, _ = cv.solvePnPGeneric(xyz_0[inliers], pts1[inliers], K1.numpy(
|
| 218 |
+
), None, useExtrinsicGuess=True, rvec=rvec, tvec=tvec, flags=cv.SOLVEPNP_ITERATIVE)
|
| 219 |
+
rvec = rvec[0]
|
| 220 |
+
tvec = tvec[0]
|
| 221 |
+
|
| 222 |
+
# avoid degenerate solutions
|
| 223 |
+
if succ:
|
| 224 |
+
if np.linalg.norm(tvec) > 1000:
|
| 225 |
+
succ = False
|
| 226 |
+
|
| 227 |
+
if succ:
|
| 228 |
+
R, _ = cv.Rodrigues(rvec)
|
| 229 |
+
t = tvec.reshape(3, 1)
|
| 230 |
+
else:
|
| 231 |
+
R = np.full((3, 3), np.nan)
|
| 232 |
+
t = np.full((3, 1), np.nan)
|
| 233 |
+
inliers = []
|
| 234 |
+
|
| 235 |
+
return R, t[:, 0], inliers
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class ProcrustesSolver:
|
| 239 |
+
'''Estimate relative pose (metric) using 3D-3D correspondences'''
|
| 240 |
+
|
| 241 |
+
def __init__(self, ransac_max_corr_distance=0.5, refine=False):
|
| 242 |
+
|
| 243 |
+
# Procrustes RANSAC parameters
|
| 244 |
+
self.ransac_max_corr_distance = ransac_max_corr_distance
|
| 245 |
+
self.refine = refine
|
| 246 |
+
|
| 247 |
+
def estimate_pose(self, pts0, pts1, data):
|
| 248 |
+
# uses nearest neighbour
|
| 249 |
+
pts0 = np.int32(pts0)
|
| 250 |
+
pts1 = np.int32(pts1)
|
| 251 |
+
|
| 252 |
+
if len(pts0) < 3:
|
| 253 |
+
return np.full((3, 3), np.nan), np.full((3), np.nan), 0
|
| 254 |
+
|
| 255 |
+
# get depth at correspondence points
|
| 256 |
+
depth_0, depth_1 = data['depth0'], data['depth1']
|
| 257 |
+
depth_pts0 = depth_0[pts0[:, 1], pts0[:, 0]]
|
| 258 |
+
depth_pts1 = depth_1[pts1[:, 1], pts1[:, 0]]
|
| 259 |
+
|
| 260 |
+
# remove invalid pts (depth == 0)
|
| 261 |
+
valid = (depth_pts0 > depth_0.min()) * (depth_pts1 > depth_1.min())
|
| 262 |
+
if valid.sum() < 3:
|
| 263 |
+
return np.full((3, 3), np.nan), np.full((3), np.nan), 0
|
| 264 |
+
pts0 = pts0[valid]
|
| 265 |
+
pts1 = pts1[valid]
|
| 266 |
+
depth_pts0 = depth_pts0[valid]
|
| 267 |
+
depth_pts1 = depth_pts1[valid]
|
| 268 |
+
|
| 269 |
+
# backproject points to 3D in each sensors' local coordinates
|
| 270 |
+
K0 = data['K_color0']
|
| 271 |
+
K1 = data['K_color1']
|
| 272 |
+
xyz_0 = backproject_3d(pts0, depth_pts0, K0)
|
| 273 |
+
xyz_1 = backproject_3d(pts1, depth_pts1, K1)
|
| 274 |
+
|
| 275 |
+
# create open3d point cloud objects and correspondences idxs
|
| 276 |
+
pcl_0 = o3d.geometry.PointCloud()
|
| 277 |
+
pcl_0.points = o3d.utility.Vector3dVector(xyz_0)
|
| 278 |
+
pcl_1 = o3d.geometry.PointCloud()
|
| 279 |
+
pcl_1.points = o3d.utility.Vector3dVector(xyz_1)
|
| 280 |
+
corr_idx = np.arange(pts0.shape[0])
|
| 281 |
+
corr_idx = np.tile(corr_idx.reshape(-1, 1), (1, 2))
|
| 282 |
+
corr_idx = o3d.utility.Vector2iVector(corr_idx)
|
| 283 |
+
|
| 284 |
+
# obtain relative pose using procrustes
|
| 285 |
+
ransac_criteria = o3d.pipelines.registration.RANSACConvergenceCriteria()
|
| 286 |
+
res = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
|
| 287 |
+
pcl_0, pcl_1, corr_idx, self.ransac_max_corr_distance, criteria=ransac_criteria)
|
| 288 |
+
inliers = int(res.fitness * np.asarray(pcl_1.points).shape[0])
|
| 289 |
+
|
| 290 |
+
# refine with ICP
|
| 291 |
+
if self.refine:
|
| 292 |
+
# first, backproject both (whole) point clouds
|
| 293 |
+
vv, uu = np.mgrid[0:depth_0.shape[0], 0:depth_1.shape[1]]
|
| 294 |
+
uv_coords = np.concatenate([uu.reshape(-1, 1), vv.reshape(-1, 1)], axis=1)
|
| 295 |
+
|
| 296 |
+
valid = depth_0.reshape(-1) > 0
|
| 297 |
+
xyz_0 = backproject_3d(uv_coords[valid], depth_0.reshape(-1)[valid], K0)
|
| 298 |
+
|
| 299 |
+
valid = depth_1.reshape(-1) > 0
|
| 300 |
+
xyz_1 = backproject_3d(uv_coords[valid], depth_1.reshape(-1)[valid], K1)
|
| 301 |
+
|
| 302 |
+
pcl_0 = o3d.geometry.PointCloud()
|
| 303 |
+
pcl_0.points = o3d.utility.Vector3dVector(xyz_0)
|
| 304 |
+
pcl_1 = o3d.geometry.PointCloud()
|
| 305 |
+
pcl_1.points = o3d.utility.Vector3dVector(xyz_1)
|
| 306 |
+
|
| 307 |
+
icp_criteria = o3d.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-4,
|
| 308 |
+
relative_rmse=1e-4,
|
| 309 |
+
max_iteration=30)
|
| 310 |
+
|
| 311 |
+
res = o3d.pipelines.registration.registration_icp(pcl_0,
|
| 312 |
+
pcl_1,
|
| 313 |
+
self.ransac_max_corr_distance,
|
| 314 |
+
init=res.transformation,
|
| 315 |
+
criteria=icp_criteria)
|
| 316 |
+
|
| 317 |
+
R = res.transformation[:3, :3]
|
| 318 |
+
t = res.transformation[:3, -1]
|
| 319 |
+
inliers = int(res.fitness * np.asarray(pcl_1.points).shape[0])
|
| 320 |
+
return R, t, inliers
|
configs/default.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from yacs.config import CfgNode as CN
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
_CN = CN()
|
| 5 |
+
|
| 6 |
+
# Model
|
| 7 |
+
_CN.MODEL = CN()
|
| 8 |
+
_CN.MODEL.NUM_KEYPOINTS = 1024
|
| 9 |
+
_CN.MODEL.TEST_NUM_KEYPOINTS = 2048
|
| 10 |
+
_CN.MODEL.N_LAYERS = 6
|
| 11 |
+
_CN.MODEL.NUM_HEADS = 4
|
| 12 |
+
_CN.MODEL.FEATURES = 'superpoint'
|
| 13 |
+
|
| 14 |
+
# Dataset
|
| 15 |
+
_CN.DATASET = CN()
|
| 16 |
+
_CN.DATASET.TASK = None
|
| 17 |
+
_CN.DATASET.DATA_SOURCE = None
|
| 18 |
+
_CN.DATASET.DATA_ROOT = None
|
| 19 |
+
_CN.DATASET.MIN_OVERLAP_SCORE = None
|
| 20 |
+
|
| 21 |
+
## For MapFree
|
| 22 |
+
_CN.DATASET.ESTIMATED_DEPTH = None
|
| 23 |
+
|
| 24 |
+
## For Linemod(BOP)
|
| 25 |
+
_CN.DATASET.OBJECT_ID = None
|
| 26 |
+
_CN.DATASET.MIN_VISIBLE_FRACT = None
|
| 27 |
+
_CN.DATASET.MAX_ANGLE_ERROR = None
|
| 28 |
+
_CN.DATASET.JSON_PATH = None
|
| 29 |
+
|
| 30 |
+
## For MegaDepth/ScanNet
|
| 31 |
+
_CN.DATASET.TRAIN = CN()
|
| 32 |
+
_CN.DATASET.TRAIN.DATA_ROOT = None
|
| 33 |
+
_CN.DATASET.TRAIN.NPZ_ROOT = None
|
| 34 |
+
_CN.DATASET.TRAIN.LIST_PATH = None
|
| 35 |
+
_CN.DATASET.TRAIN.INTRINSIC_PATH = None
|
| 36 |
+
_CN.DATASET.TRAIN.MIN_OVERLAP_SCORE = None
|
| 37 |
+
|
| 38 |
+
_CN.DATASET.VAL = CN()
|
| 39 |
+
_CN.DATASET.VAL.DATA_ROOT = None
|
| 40 |
+
_CN.DATASET.VAL.NPZ_ROOT = None
|
| 41 |
+
_CN.DATASET.VAL.LIST_PATH = None
|
| 42 |
+
_CN.DATASET.VAL.INTRINSIC_PATH = None
|
| 43 |
+
_CN.DATASET.VAL.MIN_OVERLAP_SCORE = None
|
| 44 |
+
|
| 45 |
+
_CN.DATASET.TEST = CN()
|
| 46 |
+
_CN.DATASET.TEST.DATA_ROOT = None
|
| 47 |
+
_CN.DATASET.TEST.NPZ_ROOT = None
|
| 48 |
+
_CN.DATASET.TEST.LIST_PATH = None
|
| 49 |
+
_CN.DATASET.TEST.INTRINSIC_PATH = None
|
| 50 |
+
_CN.DATASET.TEST.MIN_OVERLAP_SCORE = None
|
| 51 |
+
|
| 52 |
+
# Train
|
| 53 |
+
_CN.TRAINER = CN()
|
| 54 |
+
|
| 55 |
+
_CN.TRAINER.EPOCHS = None
|
| 56 |
+
_CN.TRAINER.LEARNING_RATE = None
|
| 57 |
+
_CN.TRAINER.PCT_START = None
|
| 58 |
+
_CN.TRAINER.BATCH_SIZE = None
|
| 59 |
+
_CN.TRAINER.NUM_WORKERS = None
|
| 60 |
+
_CN.TRAINER.PIN_MEMORY = True
|
| 61 |
+
_CN.TRAINER.N_SAMPLES_PER_SUBSET = None
|
| 62 |
+
|
| 63 |
+
_CN.RANDOM_SEED = 0
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# _CN.EMAT_RANSAC = CN()
|
| 67 |
+
# _CN.EMAT_RANSAC.PIX_THRESHOLD = 0.5
|
| 68 |
+
# _CN.EMAT_RANSAC.SCALE_THRESHOLD = 0.1
|
| 69 |
+
# _CN.EMAT_RANSAC.CONFIDENCE = 0.99999
|
| 70 |
+
|
| 71 |
+
# _CN.PNP = CN()
|
| 72 |
+
# _CN.PNP.RANSAC_ITER = 1000
|
| 73 |
+
# _CN.PNP.REPROJECTION_INLIER_THRESHOLD = 3
|
| 74 |
+
# _CN.PNP.CONFIDENCE = 0.99999
|
| 75 |
+
|
| 76 |
+
# _CN.PROCRUSTES = CN()
|
| 77 |
+
# _CN.PROCRUSTES.MAX_CORR_DIST = 0.05 # meters
|
| 78 |
+
# _CN.PROCRUSTES.REFINE = False
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_cfg_defaults():
|
| 82 |
+
"""Get a yacs CfgNode object with default values for my_project."""
|
| 83 |
+
# Return a clone so that the defaults will not be altered
|
| 84 |
+
# This is for the "local variable" use pattern
|
| 85 |
+
return _CN.clone()
|
configs/ho3d.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL:
|
| 2 |
+
NUM_KEYPOINTS: 1024
|
| 3 |
+
|
| 4 |
+
DATASET:
|
| 5 |
+
TASK: 'object'
|
| 6 |
+
DATA_SOURCE: 'ho3d'
|
| 7 |
+
DATA_ROOT: 'data/ho3d'
|
| 8 |
+
JSON_PATH: 'assets/ho3d_test_3000/ho3d_test.json'
|
| 9 |
+
|
| 10 |
+
MAX_ANGLE_ERROR: 45
|
| 11 |
+
|
| 12 |
+
TRAINER:
|
| 13 |
+
EPOCHS: 200
|
| 14 |
+
LEARNING_RATE: 0.00002
|
| 15 |
+
BATCH_SIZE: 32
|
| 16 |
+
NUM_WORKERS: 8
|
| 17 |
+
PCT_START: 0.3
|
| 18 |
+
N_SAMPLES_PER_SUBSET: 4000
|
| 19 |
+
|
configs/linemod.yaml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL:
|
| 2 |
+
NUM_KEYPOINTS: 1200
|
| 3 |
+
|
| 4 |
+
DATASET:
|
| 5 |
+
TASK: 'object'
|
| 6 |
+
DATA_SOURCE: 'linemod'
|
| 7 |
+
DATA_ROOT: 'data/lm'
|
| 8 |
+
JSON_PATH: 'assets/linemod_test_1500/linemod_test.json'
|
| 9 |
+
|
| 10 |
+
MIN_VISIBLE_FRACT: 0.75
|
| 11 |
+
MAX_ANGLE_ERROR: 45
|
| 12 |
+
|
| 13 |
+
TRAINER:
|
| 14 |
+
EPOCHS: 200
|
| 15 |
+
LEARNING_RATE: 0.00002
|
| 16 |
+
BATCH_SIZE: 32
|
| 17 |
+
NUM_WORKERS: 8
|
| 18 |
+
PCT_START: 0.3
|
| 19 |
+
N_SAMPLES_PER_SUBSET: 200
|
| 20 |
+
|
configs/mapfree.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATASET:
|
| 2 |
+
TASK: 'scene'
|
| 3 |
+
DATA_SOURCE: 'mapfree'
|
| 4 |
+
DATA_ROOT: 'data/mapfree/'
|
| 5 |
+
# ESTIMATED_DEPTH: None # To load estimated depth map, provide the suffix to the depth files, e.g. 'dptnyu', 'dptkiti'
|
| 6 |
+
|
| 7 |
+
TRAINER:
|
| 8 |
+
EPOCHS: 200
|
| 9 |
+
LEARNING_RATE: 0.00002
|
| 10 |
+
BATCH_SIZE: 32
|
| 11 |
+
NUM_WORKERS: 6
|
| 12 |
+
PCT_START: 0.3
|
| 13 |
+
N_SAMPLES_PER_SUBSET: 200
|
| 14 |
+
|
configs/matterport.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATASET:
|
| 2 |
+
TASK: 'scene'
|
| 3 |
+
DATA_SOURCE: 'matterport'
|
| 4 |
+
DATA_ROOT: 'data/mp3d'
|
| 5 |
+
|
| 6 |
+
TRAINER:
|
| 7 |
+
EPOCHS: 200
|
| 8 |
+
LEARNING_RATE: 0.00005
|
| 9 |
+
BATCH_SIZE: 32
|
| 10 |
+
NUM_WORKERS: 8
|
| 11 |
+
PCT_START: 0.3
|
configs/megadepth.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATASET:
|
| 2 |
+
TASK: "scene"
|
| 3 |
+
DATA_SOURCE: "megadepth"
|
| 4 |
+
|
| 5 |
+
TRAIN:
|
| 6 |
+
DATA_ROOT: "data/megadepth/train"
|
| 7 |
+
NPZ_ROOT: "data/megadepth/index/scene_info_0.1_0.7"
|
| 8 |
+
LIST_PATH: "data/megadepth/index/trainvaltest_list/train_list.txt"
|
| 9 |
+
MIN_OVERLAP_SCORE: 0.0
|
| 10 |
+
|
| 11 |
+
VAL:
|
| 12 |
+
DATA_ROOT: "data/megadepth/test"
|
| 13 |
+
NPZ_ROOT: "data/megadepth/index/scene_info_val_1500"
|
| 14 |
+
LIST_PATH: "data/megadepth/index/trainvaltest_list/val_list.txt"
|
| 15 |
+
MIN_OVERLAP_SCORE: 0.0
|
| 16 |
+
|
| 17 |
+
TEST:
|
| 18 |
+
DATA_ROOT: "data/megadepth/test"
|
| 19 |
+
NPZ_ROOT: "assets/megadepth_test_1500_scene_info"
|
| 20 |
+
LIST_PATH: "assets/megadepth_test_1500_scene_info/megadepth_test_1500.txt"
|
| 21 |
+
MIN_OVERLAP_SCORE: 0.0
|
| 22 |
+
|
| 23 |
+
TRAINER:
|
| 24 |
+
EPOCHS: 500
|
| 25 |
+
LEARNING_RATE: 0.00002
|
| 26 |
+
BATCH_SIZE: 32
|
| 27 |
+
NUM_WORKERS: 8
|
| 28 |
+
PCT_START: 0.3
|
| 29 |
+
N_SAMPLES_PER_SUBSET: 200
|
configs/scannet.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATASET:
|
| 2 |
+
TASK: "scene"
|
| 3 |
+
DATA_SOURCE: "scannet"
|
| 4 |
+
|
| 5 |
+
TRAIN:
|
| 6 |
+
DATA_ROOT: "data/scannet/train"
|
| 7 |
+
NPZ_ROOT: "data/scannet/index/scene_data/train"
|
| 8 |
+
LIST_PATH: "data/scannet/index/scene_data/train_list/scannet_all.txt"
|
| 9 |
+
INTRINSIC_PATH: "data/scannet/index/intrinsics.npz"
|
| 10 |
+
MIN_OVERLAP_SCORE: 0.4
|
| 11 |
+
|
| 12 |
+
VAL:
|
| 13 |
+
DATA_ROOT: "data/scannet/test"
|
| 14 |
+
NPZ_ROOT: "assets/scannet_test_1500"
|
| 15 |
+
LIST_PATH: "assets/scannet_test_1500/scannet_test.txt"
|
| 16 |
+
INTRINSIC_PATH: "assets/scannet_test_1500/intrinsics.npz"
|
| 17 |
+
MIN_OVERLAP_SCORE: 0.0
|
| 18 |
+
|
| 19 |
+
TEST:
|
| 20 |
+
DATA_ROOT: "data/scannet/test"
|
| 21 |
+
NPZ_ROOT: "assets/scannet_test_1500"
|
| 22 |
+
LIST_PATH: "assets/scannet_test_1500/scannet_test.txt"
|
| 23 |
+
INTRINSIC_PATH: "assets/scannet_test_1500/intrinsics.npz"
|
| 24 |
+
MIN_OVERLAP_SCORE: 0.0
|
| 25 |
+
|
| 26 |
+
TRAINER:
|
| 27 |
+
EPOCHS: 500
|
| 28 |
+
LEARNING_RATE: 0.0001
|
| 29 |
+
BATCH_SIZE: 32
|
| 30 |
+
NUM_WORKERS: 8
|
| 31 |
+
PCT_START: 0.3
|
| 32 |
+
N_SAMPLES_PER_SUBSET: 200
|
| 33 |
+
|
datasets/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .matterport import build_matterport
from .linemod import build_linemod
from .megadepth import build_concat_megadepth
from .scannet import build_concat_scannet
from .ho3d import build_ho3d
from .mapfree import build_concat_mapfree
from .sampler import RandomConcatSampler

# Registry of dataset builders: dataset_dict[task][source] -> builder.
# Keys mirror the DATASET.TASK / DATASET.DATA_SOURCE fields of the yaml
# configs ("scene" for scene-level pose, "object" for object-level pose).
dataset_dict = {
    'scene': {
        'matterport': build_matterport,
        'megadepth': build_concat_megadepth,
        'scannet': build_concat_scannet,
        'mapfree': build_concat_mapfree,
    },
    'object': {
        'linemod': build_linemod,
        'ho3d': build_ho3d,
    }
}
|
datasets/ho3d.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import cv2
|
| 5 |
+
import pickle
|
| 6 |
+
import json
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 10 |
+
|
| 11 |
+
from utils.augment import Augmentor
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class HO3D(Dataset):
    """One HO3D sequence: RGB frames with object masks, depths and poses.

    At construction time, frames are filtered twice: first dropping frames
    whose meta file has no camera matrix, then dropping frames whose object
    mask contains fewer than 100 pixels. All per-frame arrays are kept in
    lock-step after both filters.
    """

    def __init__(self, data_root, sequence_path, mode):
        """
        Args:
            data_root: HO3D dataset root (contains 'train'/'evaluation').
            sequence_path: sequence folder name inside the split directory.
            mode: 'train' selects the train split; any other value maps to
                the on-disk 'evaluation' split.
        """
        self.data_root = Path(data_root)
        # HO3D stores its non-train data under 'evaluation' on disk.
        mode = 'evaluation' if mode != 'train' else 'train'
        self.sequence_dir = self.data_root / mode / sequence_path

        self.color_dir = self.sequence_dir / 'rgb'
        self.mask_dir = self.sequence_dir / 'seg'
        self.depth_dir = self.sequence_dir / 'depth'
        self.meta_dir = self.sequence_dir / 'meta'

        self.color_paths = list(self.color_dir.iterdir())
        self.color_paths = sorted(self.color_paths)

        # Masks/depths/metas share the frame stem with the RGB files.
        self.mask_paths = [self.mask_dir / f'{x.stem}.png' for x in self.color_paths]
        self.depth_paths = [self.depth_dir / f'{x.stem}.png' for x in self.color_paths]
        self.meta_paths = [self.meta_dir / f'{x.stem}.pkl' for x in self.color_paths]

        # self.glcam_in_cvcam = torch.tensor([
        # [1,0,0,0],
        # [0,-1,0,0],
        # [0,0,-1,0],
        # [0,0,0,1]
        # ]).float()
        # Filter 1: drop frames whose meta lacks intrinsics.
        self.intrinsics, self.extrinsics, self.objCorners, self.objNames, valid = self._load_meta(self.meta_paths)

        self.color_paths = np.array(self.color_paths)[valid.numpy()]
        self.mask_paths = np.array(self.mask_paths)[valid.numpy()]
        self.depth_paths = np.array(self.depth_paths)[valid.numpy()]
        self.meta_paths = np.array(self.meta_paths)[valid.numpy()]

        # Filter 2: drop frames with an (almost) empty object mask.
        self.bboxes, valid = self._load_bboxes(self.mask_paths)
        self.intrinsics = self.intrinsics[valid]
        self.extrinsics = self.extrinsics[valid]
        self.objCorners = self.objCorners[valid]
        self.objNames = self.objNames[valid.numpy()]
        self.color_paths = self.color_paths[valid.numpy()]
        self.mask_paths = self.mask_paths[valid.numpy()]
        self.depth_paths = self.depth_paths[valid.numpy()]
        self.meta_paths = self.meta_paths[valid.numpy()]

        assert len(self.color_paths) == self.intrinsics.shape[0]
        assert len(self.objNames) == self.extrinsics.shape[0]

        self.augment = Augmentor(mode=='train')

    def __len__(self):
        return len(self.color_paths)

    def _load_bboxes(self, mask_paths):
        """Compute one (x1, y1, x2, y2) bounding box per mask.

        Returns:
            bboxes: (K, 4) int tensor for the K kept frames.
            valid: bool tensor over all inputs; False where the mask has
                fewer than 100 object pixels.
        """
        bboxes = []
        valid = []
        for mask_path in mask_paths:
            mask = cv2.imread(str(mask_path))
            # mask = cv2.resize(mask, (640, 480))
            # Scales mapping native mask resolution to 640x480.
            w_scale, h_scale = 640 / mask.shape[1], 480 / mask.shape[0]
            # Object pixels are flagged in the green channel.
            obj_mask = torch.from_numpy(mask[..., 1] == 255)

            if obj_mask.float().sum() < 100:
                valid.append(False)
                continue
            valid.append(True)

            mask_inds = torch.where(obj_mask)
            x1, x2 = mask_inds[0].aminmax()  # row (vertical) extent
            y1, y2 = mask_inds[1].aminmax()  # column (horizontal) extent
            # NOTE(review): columns are scaled by h_scale and rows by
            # w_scale; this is only correct when the masks are already
            # 640x480 (both scales == 1) -- verify for other resolutions.
            bbox = torch.tensor([y1*h_scale, x1*w_scale, y2*h_scale, x2*w_scale]).int()
            bboxes.append(bbox)

        bboxes = torch.stack(bboxes)
        valid = torch.tensor(valid)

        return bboxes, valid

    def _load_meta(self, meta_paths):
        """Load per-frame annotations from the pickled HO3D meta files.

        Returns:
            intrinsics: (K, 3, 3) float tensor.
            extrinsics: (K, 4, 4) object poses (rotation from Rodrigues
                axis-angle plus translation).
            objCorners: (K, 8, 3) rest-pose box corners transformed by the
                object pose.
            objNames: (K,) string array.
            valid: bool tensor over all inputs; False where 'camMat' is None.
        """
        intrinsics = []
        extrinsics = []
        objCorners = []
        objNames = []
        valid = []
        for meta_path in meta_paths:
            with open(meta_path, 'rb') as f:
                anno = pickle.load(f, encoding='latin1')

            if anno['camMat'] is None:
                valid.append(False)
                continue
            valid.append(True)

            camMat = torch.from_numpy(anno['camMat'])
            # Build a 4x4 pose from axis-angle rotation + translation.
            ex = torch.eye(4)
            ex[:3, :3] = torch.from_numpy(cv2.Rodrigues(anno['objRot'])[0])
            ex[:3, 3] = torch.from_numpy(anno['objTrans'])
            # ex = self.glcam_in_cvcam @ ex
            objCorners3DRest = torch.from_numpy(anno['objCorners3DRest']).float()
            # objCorners3DRest = (ex[:3, :3] @ objCorners3DRest.T + ex[:3, 3:]).T
            # Apply the object pose to the rest-pose corners (row vectors).
            objCorners3DRest = objCorners3DRest @ ex[:3, :3].T + ex[:3, 3]

            intrinsics.append(camMat)
            extrinsics.append(ex)
            objCorners.append(objCorners3DRest)
            objNames.append(anno['objName'])

        intrinsics = torch.stack(intrinsics).float()
        extrinsics = torch.stack(extrinsics).float()
        objCorners = torch.stack(objCorners)
        objNames = np.array(objNames)
        valid = torch.tensor(valid)

        return intrinsics, extrinsics, objCorners, objNames, valid

    def _load_mask(self, mask_path):
        """Read the object mask, resized to 640x480, as a bool array."""
        mask = cv2.imread(str(mask_path))
        mask = cv2.resize(mask, (640, 480))
        mask = mask[..., 1] == 255
        return mask

    def _load_depth(self, depth_path):
        """Decode an HO3D depth png: depth = (R + G * 256) * fixed scale."""
        depth_scale = 0.00012498664727900177
        depth_img = cv2.imread(str(depth_path))

        dpt = depth_img[:, :, 2] + depth_img[:, :, 1] * 256
        dpt = dpt * depth_scale

        return dpt

    def __getitem__(self, idx):
        """Return one frame as a dict of tensors plus path/name metadata."""
        color = cv2.imread(str(self.color_paths[idx]))
        color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
        # color = self.augment(color)
        # (3, H, W) float image in [0, 1].
        color = (torch.tensor(color).float() / 255.0).permute(2, 0, 1)

        mask = self._load_mask(self.mask_paths[idx])
        mask = torch.from_numpy(mask)
        depth = self._load_depth(self.depth_paths[idx])
        depth = torch.from_numpy(depth)

        bbox = self.bboxes[idx]

        intrinsic = self.intrinsics[idx]
        extrinsic = self.extrinsics[idx]
        objCorners = self.objCorners[idx]
        objName = self.objNames[idx]

        return {
            'color': color,
            'mask': mask,
            'depth': depth,
            'extrinsic': extrinsic,
            'intrinsic': intrinsic,
            'objCorners': objCorners,
            'bbox': bbox,
            # Drops the first two path components to get a root-relative
            # path. NOTE(review): assumes data_root is a relative path with
            # exactly two leading components -- verify.
            'color_path': str(self.color_paths[idx]).split('/', 2)[-1],
            'objName': objName,
        }
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class HO3DPair(Dataset):
    """Frame pairs from one HO3D sequence with bounded relative rotation."""

    def __init__(self, data_root, mode, sequence_id, max_angle_error):
        """
        Args:
            data_root: HO3D dataset root.
            mode: 'train' / 'val' / 'test'; val and test are subsampled to
                at most 1500 pairs.
            sequence_id: sequence folder name.
            max_angle_error: keep pairs whose relative rotation is below
                this many degrees.
        """
        self.ho3d_dataset = HO3D(data_root, sequence_id, mode)

        # Pairwise relative rotation angle (degrees) between every frame.
        angle_err = self.get_angle_error(self.ho3d_dataset.extrinsics[:, :3, :3])
        index0, index1 = torch.where(angle_err < max_angle_error)
        # Keep each unordered pair once and drop the diagonal.
        filter = torch.where(index0 < index1)
        self.index0, self.index1 = index0[filter], index1[filter]
        # angle_err_filtered = angle_err[row, col]

        self.indices = torch.tensor(list(zip(self.index0, self.index1)))
        if mode == 'val' or mode == 'test':
            # NOTE(review): randperm is unseeded, so the val/test subset
            # differs between runs -- confirm this is intended.
            self.indices = self.indices[torch.randperm(self.indices.size(0))[:1500]]

    def get_angle_error(self, R):
        """Return the (B, B) matrix of pairwise rotation angles in degrees."""
        # R: (B, 3, 3)
        # einsum over i computes R_a^T @ R_b for every pair (a, b).
        residual = torch.einsum('aij,bik->abjk', R, R)
        trace = torch.diagonal(residual, dim1=-2, dim2=-1).sum(-1)
        # Angle from the trace identity: cos(theta) = (tr(R) - 1) / 2.
        cosine = (trace - 1) / 2
        cosine = torch.clip(cosine, -1, 1)  # guard acos against fp drift
        R_err = torch.acos(cosine)
        angle_err = R_err.rad2deg()

        return angle_err

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the stacked pair plus its relative pose and metadata."""
        idx0, idx1 = self.indices[idx]
        data0, data1 = self.ho3d_dataset[idx0], self.ho3d_dataset[idx1]

        images = torch.stack([data0['color'], data1['color']], dim=0)

        # Relative transform between the two poses: T_rel = T1 @ T0^{-1}.
        ex0, ex1 = data0['extrinsic'], data1['extrinsic']
        rel_ex = ex1 @ ex0.inverse()
        rel_R = rel_ex[:3, :3]
        rel_t = rel_ex[:3, 3]

        intrinsics = torch.stack([data0['intrinsic'], data1['intrinsic']], dim=0)
        bboxes = torch.stack([data0['bbox'], data1['bbox']])
        objCorners = torch.stack([data0['objCorners'], data1['objCorners']])

        return {
            'images': images,
            'rotation': rel_R,
            'translation': rel_t,
            'intrinsics': intrinsics,
            'bboxes': bboxes,
            'objCorners': objCorners,
            'pair_names': (data0['color_path'], data1['color_path']),
            'objName': data0['objName']
        }
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class HO3DfromJson(Dataset):
    """HO3D evaluation pairs loaded from a precomputed json index.

    Each json entry stores the pair's image paths plus ground-truth relative
    pose, intrinsics, boxes and object corners; images, masks and depths are
    read from disk on access.
    """

    def __init__(self, data_root, json_path):
        """
        Args:
            data_root: HO3D dataset root (also holds 'models/<obj>/points.xyz').
            json_path: index file keyed by stringified integer index.
        """
        self.data_root = Path(data_root)
        with open(json_path, 'r') as f:
            self.scene_info = json.load(f)

        # YCB objects covered by this evaluation set.
        self.obj_names = [
            '003_cracker_box',
            '006_mustard_bottle',
            '011_banana',
            '025_mug',
            '037_scissors'
        ]
        # Per-object (N, 3) evaluation point clouds.
        self.object_points = {obj: np.loadtxt(self.data_root / 'models' / obj / 'points.xyz') for obj in self.obj_names}

    def _load_color(self, path):
        """Read an RGB image as an (H, W, 3) uint8 array."""
        color = cv2.imread(path)
        color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
        return color

    def _load_mask(self, path):
        """Read the 640x480 bool object mask paired with an rgb path."""
        # NOTE(review): string substitution rewrites every 'rgb' occurrence
        # in the path -- breaks if 'rgb' appears elsewhere; verify.
        mask_path = str(path).replace('rgb', 'seg').replace('.jpg', '.png')
        mask = cv2.imread(str(mask_path))
        mask = cv2.resize(mask, (640, 480))
        # Object pixels are flagged in the green channel.
        mask = mask[..., 1] == 255
        return mask

    def _load_depth(self, path):
        """Decode the depth png paired with an rgb path: (R + G*256) * scale."""
        depth_scale = 0.00012498664727900177

        depth_path = str(path).replace('rgb', 'depth').replace('.jpg', '.png')
        depth_img = cv2.imread(depth_path)

        dpt = depth_img[:, :, 2] + depth_img[:, :, 1] * 256
        dpt = dpt * depth_scale

        return dpt

    def __len__(self):
        return len(self.scene_info)

    def __getitem__(self, idx):
        """Assemble one evaluation pair from its json entry."""
        info = self.scene_info[str(idx)]
        pair_names = info['pair_names']

        image0 = self._load_color(str(self.data_root / pair_names[0]))
        image0 = (torch.tensor(image0).float() / 255.0).permute(2, 0, 1)
        image1 = self._load_color(str(self.data_root / pair_names[1]))
        image1 = (torch.tensor(image1).float() / 255.0).permute(2, 0, 1)
        images = torch.stack([image0, image1], dim=0)

        mask0 = self._load_mask(str(self.data_root / pair_names[0]))
        mask0 = torch.from_numpy(mask0)
        mask1 = self._load_mask(str(self.data_root / pair_names[1]))
        mask1 = torch.from_numpy(mask1)
        masks = torch.stack([mask0, mask1], dim=0)

        depth0 = self._load_depth(str(self.data_root / pair_names[0]))
        depth0 = torch.from_numpy(depth0)
        depth1 = self._load_depth(str(self.data_root / pair_names[1]))
        depth1 = torch.from_numpy(depth1)
        depths = torch.stack([depth0, depth1], dim=0)

        # Ground-truth relative pose and per-view camera data from the json.
        rotation = torch.tensor(info['rotation']).reshape(3, 3)
        translation = torch.tensor(info['translation'])
        intrinsics = torch.tensor(info['intrinsics']).reshape(2, 3, 3)
        bboxes = torch.tensor(info['bboxes'])
        objCorners = torch.tensor(info['objCorners'])

        return {
            'images': images,
            'masks': masks,
            'depths': depths,
            'rotation': rotation,
            'translation': translation,
            'intrinsics': intrinsics,
            'bboxes': bboxes,
            'objCorners': objCorners,
            'objName': info['objName'][0],
            'point_cloud': self.object_points[info['objName'][0]]
        }
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def build_ho3d(mode, config):
    """Build the HO3D dataset for one mode.

    'train' concatenates HO3DPair datasets over every sequence found under
    <DATA_ROOT>/train except a fixed held-out list; 'val'/'test' load
    precomputed pairs from config.DATASET.JSON_PATH.

    Args:
        mode: 'train' | 'val' | 'test'.
        config: experiment config; only its DATASET node is used.
    """
    config = config.DATASET

    data_root = config.DATA_ROOT
    # All sequence ids present on disk under the train split ...
    seq_id_list = [x.stem for x in (Path(data_root) / 'train').iterdir()]
    # ... minus the sequences reserved for validation.
    # NOTE(review): list.remove raises ValueError if a held-out sequence is
    # missing on disk -- verify the data layout matches this list.
    val_id_list = ['BB14', 'SMu1', 'MC1', 'GSF14', 'SM2', 'SM3', 'SM4', 'SM5', 'MC2', 'MC4', 'MC5', 'MC6']
    for val_id in val_id_list:
        seq_id_list.remove(val_id)

    if mode == 'train':
        datasets = []
        for seq_id in tqdm(seq_id_list, desc=f'Loading HO3D {mode} dataset'):
            datasets.append(HO3DPair(data_root, mode, seq_id, config.MAX_ANGLE_ERROR))
        return ConcatDataset(datasets)

    elif mode == 'test' or mode == 'val':
        # datasets = []
        # for seq_id in tqdm(val_id_list[:5], desc=f'Loading HO3D {mode} dataset'):
        # datasets.append(HO3DPair(data_root, mode, seq_id, config.MAX_ANGLE_ERROR))
        # return ConcatDataset(datasets)

        return HO3DfromJson(config.DATA_ROOT, config.JSON_PATH)
|
| 331 |
+
|
datasets/linemod.py
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import math
|
| 4 |
+
import random
|
| 5 |
+
from tqdm import tqdm, trange
|
| 6 |
+
import plyfile
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 12 |
+
from torch.nn import functional as F
|
| 13 |
+
|
| 14 |
+
from utils import Augmentor
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Zero-padded BOP LINEMOD object id -> human-readable object name.
LINEMOD_ID_TO_NAME = {
    '000001': 'ape',
    '000002': 'benchvise',
    '000003': 'bowl',
    '000004': 'camera',
    '000005': 'can',
    '000006': 'cat',
    '000007': 'mug',
    '000008': 'driller',
    '000009': 'duck',
    '000010': 'eggbox',
    '000011': 'glue',
    '000012': 'holepuncher',
    '000013': 'iron',
    '000014': 'lamp',
    '000015': 'phone',
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def inverse_transform(trans):
    """Invert a 4x4 rigid-body transform analytically.

    For T = [R | t; 0 1] with R a rotation matrix, the inverse is
    [R^T | -R^T t; 0 1], which avoids a general matrix inversion.

    Args:
        trans: (4, 4) array-like rigid transform.

    Returns:
        (4, 4) float32 ndarray holding the inverse transform.
    """
    rotation_t = np.transpose(trans[:3, :3])
    inverse = np.zeros((4, 4), dtype=np.float32)
    inverse[:3, :3] = rotation_t
    inverse[:3, 3] = -np.matmul(rotation_t, trans[:3, 3])
    inverse[3, 3] = 1
    return inverse
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class BOPDataset(Dataset):
    """One object within one BOP-format scene (lm / lmo / tless).

    Loads per-frame color/mask/depth paths plus intrinsics and object poses
    from the BOP json annotations, keeping only frames where the object's
    visible fraction is at least `min_visible_fract`.
    """

    def __init__(self,
                 dataset_path,
                 scene_path,
                 object_id,
                 min_visible_fract,
                 mode,
                 rgb_postfix='.png',
                 object_scale=None
                 ):
        """
        Args:
            dataset_path: BOP dataset root; its directory name ('lm', 'lmo',
                'tless') selects the model folder and base scale.
            scene_path: one scene directory inside the dataset.
            object_id: integer BOP object id to track.
            min_visible_fract: minimum visible fraction for a frame to be kept.
            mode: 'train' enables augmentation.
            rgb_postfix: RGB file extension ('.png' real, '.jpg' PBR).
            object_scale: optional override; defaults to
                base_scale / object diameter so depth is diameter-normalized.
        """
        super().__init__()
        self.dataset_path = dataset_path
        self.scene_path = scene_path
        self.object_id = object_id

        # Dataset variant selects the mesh folder and the base object scale.
        if dataset_path.name == 'lm' or dataset_path.name == 'lmo':
            base_obj_scale = 1.0
            self.models_path = self.dataset_path / 'models'
        elif dataset_path.name == 'tless':
            base_obj_scale = 0.60
            self.models_path = self.dataset_path / 'models_reconst'
        else:
            raise ValueError(f'Unknown dataset type {dataset_path.name}')

        self.model_path = self.models_path / f'obj_{self.object_id:06d}.ply'
        self.pointcloud_path = self.dataset_path / 'models_eval' / f'obj_{self.object_id:06d}.ply'

        models_info_path = self.dataset_path / 'models_eval' / 'models_info.json'
        with open(models_info_path, 'r') as f:
            self.model_info = json.load(f)[str(object_id)]

        # self.center_object = center_object
        if object_scale is None:
            self.object_scale = base_obj_scale / self.model_info['diameter']
        else:
            self.object_scale = object_scale

        # self.image_scale = 1.0
        # Axis-aligned model bounds (min, max) per axis, and their center.
        self.bounds = torch.tensor([
            (self.model_info['min_x'], self.model_info['min_x'] + self.model_info['size_x']),
            (self.model_info['min_y'], self.model_info['min_y'] + self.model_info['size_y']),
            (self.model_info['min_z'], self.model_info['min_z'] + self.model_info['size_z']),
        ])
        self.centroid = self.bounds.mean(dim=1)

        self.depth_dir = self.scene_path / 'depth'
        self.mask_dir = self.scene_path / 'mask_visib'
        self.color_dir = self.scene_path / 'rgb'
        self.intrinsics_path = self.scene_path / 'scene_camera.json'
        self.extrinsics_path = self.scene_path / 'scene_gt.json'
        self.gt_info_path = self.scene_path / 'scene_gt_info.json'

        # NOTE(review): intrinsics/depth_scales cover every frame in
        # scene_camera.json while extrinsics/gt_info cover only frames where
        # this object appears -- verify these line up for lmo-style scenes.
        self.intrinsics, self.depth_scales = self.load_intrinsics(self.intrinsics_path)
        self.extrinsics, self.scene_object_inds = self.load_extrinsics(self.extrinsics_path)
        self.extrinsics = torch.stack(self.extrinsics, dim=0)
        self.gt_info = self.load_gt_info(self.gt_info_path)

        # # Compute quaternions for sampling.
        # rotation, translation = three.decompose(self.extrinsics)
        # self.quaternions = three.quaternion.mat_to_quat(rotation[:, :3, :3])

        self.depth_paths = sorted([self.depth_dir / f'{frame_ind:06d}.png'
                                   for frame_ind in self.scene_object_inds.keys()])
        # Visible-mask files are keyed by frame and per-frame annotation slot.
        self.mask_paths = [
            self.mask_dir / f'{frame_ind:06d}_{obj_ind:06d}.png'
            for frame_ind, obj_ind in self.scene_object_inds.items()
        ]
        self.color_paths = sorted([self.color_dir / f'{frame_ind:06d}{rgb_postfix}'
                                   for frame_ind in self.scene_object_inds.keys()])

        # Drop frames where the object is mostly occluded.
        visib_filter = np.array(self.gt_info['visib_fract']) >= min_visible_fract
        self.color_paths = np.array(self.color_paths)[visib_filter]
        self.mask_paths = np.array(self.mask_paths)[visib_filter]
        self.depth_paths = np.array(self.depth_paths)[visib_filter]
        self.depth_scales = np.array(self.depth_scales)[visib_filter]
        for k in self.gt_info:
            self.gt_info[k] = np.array(self.gt_info[k])[visib_filter]

        self.extrinsics = np.array(self.extrinsics)[visib_filter]
        self.intrinsics = np.array(self.intrinsics)[visib_filter]

        self.augment = Augmentor(mode=='train')

        assert len(self.depth_paths) == len(self.mask_paths)
        assert len(self.depth_paths) == len(self.color_paths)

    # def load_pointcloud(self):
    # obj = meshutils.Object3D(self.pointcloud_path)
    # points = torch.tensor(obj.vertices, dtype=torch.float32)
    # points = points * self.object_scale
    # return points

    def load_gt_info(self, path):
        """Parse scene_gt_info.json into {info_key: [per-frame value]} for
        the tracked object, ordered by integer frame index."""
        with open(path, 'r') as f:
            gt_info_json = json.load(f)
        keys = sorted([int(k) for k in gt_info_json.keys()])
        gt_info = {k: [] for k in gt_info_json['0'][0]}
        for key in keys:
            value = gt_info_json[str(key)]
            # Select this object's annotation slot within the frame.
            obj_info = value[self.scene_object_inds[key]]
            for info_k in obj_info:
                gt_info[info_k].append(obj_info[info_k])

        return gt_info

    def load_intrinsics(self, path):
        """Parse scene_camera.json into per-frame 3x3 intrinsics and depth
        scale factors, ordered by integer frame index."""
        intrinsics = []
        depth_scales = []
        with open(path, 'r') as f:
            intrinsics_json = json.load(f)
        keys = sorted([int(k) for k in intrinsics_json.keys()])
        for key in keys:
            value = intrinsics_json[str(key)]
            intrinsic_3x3 = value['cam_K']
            intrinsics.append(torch.tensor(intrinsic_3x3).reshape(3, 3))
            depth_scales.append(value['depth_scale'])

        return intrinsics, depth_scales

    def load_extrinsics(self, path):
        """Parse scene_gt.json for the tracked object.

        Returns:
            extrinsics: list of 4x4 model-to-camera poses (translation
                converted from millimeters to meters).
            scene_object_inds: {frame index -> annotation slot index} for
                frames containing the object.
        """
        extrinsics = []
        scene_object_inds = {}
        with open(path, 'r') as f:
            extrinsics_json = json.load(f)
        frame_inds = sorted([int(k) for k in extrinsics_json.keys()])
        for frame_ind in frame_inds:
            for obj_ind, cam_d in enumerate(extrinsics_json[str(frame_ind)]):
                if cam_d['obj_id'] == self.object_id:
                    rotation = torch.tensor(
                        cam_d['cam_R_m2c'], dtype=torch.float32).reshape(3, 3)
                    translation = torch.tensor(cam_d['cam_t_m2c'], dtype=torch.float32) / 1000.
                    # quaternion = three.quaternion.mat_to_quat(rotation)
                    # extrinsics.append(three.to_extrinsic_matrix(translation, quaternion))
                    extrinsic = torch.eye(4)
                    extrinsic[:3, :3] = rotation
                    extrinsic[:3, 3] = translation
                    extrinsics.append(extrinsic)
                    scene_object_inds[frame_ind] = obj_ind

        return extrinsics, scene_object_inds

    def __len__(self):
        return len(self.color_paths)

    def get_ids(self):
        """Return the frame id (file stem) of every kept frame."""
        return [p.stem for p in self.color_paths]

    def _load_color(self, path):
        """Read an RGB image as an (H, W, 3) uint8 array."""
        image = Image.open(path)
        image = np.array(image)
        return image

    def _load_mask(self, path):
        """Read a visibility mask as a 2-D bool array (first channel if
        the file has several)."""
        image = Image.open(path)
        image = np.array(image, dtype=bool)
        if len(image.shape) > 2:
            image = image[:, :, 0]
        return image

    def _load_depth(self, path):
        """Read a raw depth png as a float32 array (unscaled)."""
        image = Image.open(path)
        image = np.array(image, dtype=np.float32)
        return image

    def __getitem__(self, idx):
        """Return one frame as a dict of tensors plus metadata."""
        color = self._load_color(self.color_paths[idx])
        # color = self.augment(color)
        # (3, H, W) float image in [0, 1].
        color = (torch.tensor(color).float() / 255.0).permute(2, 0, 1)
        mask = self._load_mask(self.mask_paths[idx])
        mask = torch.tensor(mask).bool()
        depth = self._load_depth(self.depth_paths[idx])
        # Raw depth * per-frame BOP scale * object scale (diameter units).
        depth = torch.tensor(depth) * self.object_scale * self.depth_scales[idx]

        # intrinsic = self.normalize_intrinsic(self.intrinsics[idx])
        # extrinsic = self.normalize_extrinsic(self.extrinsics[idx])

        intrinsic = torch.from_numpy(self.intrinsics[idx])
        extrinsic = torch.from_numpy(self.extrinsics[idx])

        bbox_obj = self.gt_info['bbox_obj'][idx]
        bbox_visib = self.gt_info['bbox_visib'][idx]

        # Convert BOP (x, y, w, h) boxes to (x1, y1, x2, y2).
        bbox_obj = torch.tensor([bbox_obj[0], bbox_obj[1], bbox_obj[0]+bbox_obj[2], bbox_obj[1]+bbox_obj[3]])
        bbox_visib = torch.tensor([bbox_visib[0], bbox_visib[1], bbox_visib[0]+bbox_visib[2], bbox_visib[1]+bbox_visib[3]])

        visib_fract = self.gt_info['visib_fract'][idx]
        px_count_visib = self.gt_info['px_count_visib'][idx]

        return {
            'color': color,
            'mask': mask,
            'depth': depth,
            'extrinsic': extrinsic,
            'intrinsic': intrinsic,
            'bbox_obj': bbox_obj,
            'bbox_visib': bbox_visib,
            'visib_fract': visib_fract,
            'px_count_visib': px_count_visib,
            # Drops the first two path components to get a root-relative
            # path. NOTE(review): assumes a relative data_root -- verify.
            'color_path': str(self.color_paths[idx]).split('/', 2)[-1],
            'object_scale': self.object_scale,
            'depth_scale': self.depth_scales[idx]
        }
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class Linemod(Dataset):
    """Image pairs of one LINEMOD object with bounded relative rotation.

    Training uses the PBR-rendered split ('train_pbr'); val/test use the
    real BOP 'test' split, where each object has its own scene.
    NOTE(review): get_angle_error duplicates HO3DPair.get_angle_error --
    consider sharing a helper.
    """

    def __init__(self, data_root, mode, object_id, scene_id, min_visible_fract, max_angle_error):
        """
        Args:
            data_root: BOP LINEMOD root directory.
            mode: 'train' | 'val' | 'test'.
            object_id: integer BOP object id.
            scene_id: scene id used for the train split (ignored for
                val/test, where the scene id equals the object id).
            min_visible_fract: minimum visible fraction per frame.
            max_angle_error: keep pairs whose relative rotation is below
                this many degrees.
        """
        if mode == 'train':
            type_path = 'train_pbr'
            rgb_postfix = '.jpg'
            scene_id = scene_id
        elif mode == 'val' or mode == 'test':
            type_path = 'test'
            rgb_postfix = '.png'
            # In the BOP LINEMOD test split, scene id == object id.
            scene_id = object_id
        else:
            raise NotImplementedError(f'mode {mode}')

        data_root = Path(data_root)
        scene_path = data_root / type_path / f'{scene_id:06d}'
        self.bop_dataset = BOPDataset(data_root, scene_path, object_id=object_id, min_visible_fract=min_visible_fract, mode=mode, rgb_postfix=rgb_postfix)

        # Pairwise relative rotation angle (degrees) between every frame.
        angle_err = self.get_angle_error(torch.from_numpy(self.bop_dataset.extrinsics[:, :3, :3]))
        index0, index1 = torch.where(angle_err < max_angle_error)
        # Keep each unordered pair once and drop the diagonal.
        filter = torch.where(index0 < index1)
        self.index0, self.index1 = index0[filter], index1[filter]
        # angle_err_filtered = angle_err[row, col]

        self.indices = torch.tensor(list(zip(self.index0, self.index1)))
        if mode == 'val':
            # NOTE(review): unseeded randperm -- the val subset differs
            # between runs; confirm this is intended.
            self.indices = self.indices[torch.randperm(self.indices.size(0))[:1500]]

    def get_angle_error(self, R):
        """Return the (B, B) matrix of pairwise rotation angles in degrees."""
        # R: (B, 3, 3)
        # einsum over i computes R_a^T @ R_b for every pair (a, b).
        residual = torch.einsum('aij,bik->abjk', R, R)
        trace = torch.diagonal(residual, dim1=-2, dim2=-1).sum(-1)
        # Angle from the trace identity: cos(theta) = (tr(R) - 1) / 2.
        cosine = (trace - 1) / 2
        cosine = torch.clip(cosine, -1, 1)  # guard acos against fp drift
        R_err = torch.acos(cosine)
        angle_err = R_err.rad2deg()

        return angle_err

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the stacked pair plus its relative pose and metadata."""
        idx0, idx1 = self.indices[idx]
        data0, data1 = self.bop_dataset[idx0], self.bop_dataset[idx1]

        images = torch.stack([data0['color'], data1['color']], dim=0)

        # Relative transform between the two poses: T_rel = T1 @ T0^{-1}.
        ex0, ex1 = data0['extrinsic'], data1['extrinsic']
        rel_ex = ex1 @ ex0.inverse()
        rel_R = rel_ex[:3, :3]
        rel_t = rel_ex[:3, 3]

        intrinsics = torch.stack([data0['intrinsic'], data1['intrinsic']], dim=0)
        bboxes = torch.stack([data0['bbox_visib'], data1['bbox_visib']])

        return {
            'images': images,
            'rotation': rel_R,
            'translation': rel_t,
            'intrinsics': intrinsics,
            'bboxes': bboxes,
            'pair_names': (data0['color_path'], data1['color_path']),
            'object_scale': data0['object_scale'],
            'depth_scale': (data0['depth_scale'], data1['depth_scale']),
        }
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class LinemodfromJson(Dataset):
    """LINEMOD evaluation pairs read from a pre-computed JSON index file."""

    def __init__(self, data_root, json_path):
        self.data_root = Path(data_root)
        with open(json_path, 'r') as f:
            self.scene_info = json.load(f)

        models_info_path = self.data_root / 'models_eval' / 'models_info.json'
        with open(models_info_path, 'r') as f:
            model_info = json.load(f)

        # per-object diameter and evaluation point cloud, keyed by object id
        self.object_diameters = {obj: model_info[obj]['diameter'] for obj in model_info}
        self.object_points = {obj: self._load_point_cloud(obj) for obj in self.object_diameters}

    def _load_point_cloud(self, obj_id):
        """Read the evaluation model of `obj_id` as an (N, 3) float array."""
        ply_path = self.data_root / 'models_eval' / f'obj_{int(obj_id):06d}.ply'
        with open(ply_path, "rb") as fh:
            vertices = plyfile.PlyData.read(fh)["vertex"]
        return np.stack([np.array(vertices[axis]).astype(float) for axis in ("x", "y", "z")], axis=1)

    def _load_color(self, path):
        """RGB image as an (H, W, 3) uint8 array."""
        return np.array(Image.open(path))

    def _load_mask(self, path):
        """Visible-object mask stored next to the RGB image, as (H, W) bool."""
        mask_path = path.replace('rgb', 'mask_visib').replace('.png', '_000000.png')
        mask = np.array(Image.open(mask_path), dtype=bool)
        if len(mask.shape) > 2:
            mask = mask[:, :, 0]
        return mask

    def _load_depth(self, path):
        """Raw depth map stored next to the RGB image, as (H, W) float32."""
        depth_path = path.replace('rgb', 'depth')
        return np.array(Image.open(depth_path), dtype=np.float32)

    def __len__(self):
        return len(self.scene_info)

    def __getitem__(self, idx):
        info = self.scene_info[str(idx)]
        pair_names = info['pair_names']

        colors, visib_masks, depth_maps = [], [], []
        for view, rel_name in enumerate(pair_names):
            full_path = str(self.data_root / rel_name)
            rgb = torch.tensor(self._load_color(full_path)).float() / 255.0
            colors.append(rgb.permute(2, 0, 1))
            visib_masks.append(torch.tensor(self._load_mask(full_path)).bool())
            depth_maps.append(torch.tensor(self._load_depth(full_path)) * info['depth_scale'][view])

        images = torch.stack(colors, dim=0)
        masks = torch.stack(visib_masks, dim=0)
        depths = torch.stack(depth_maps, dim=0) / 1000.

        rotation = torch.tensor(info['rotation']).reshape(3, 3)
        translation = torch.tensor(info['translation'])
        intrinsics = torch.tensor(info['intrinsics']).reshape(2, 3, 3)
        bboxes = torch.tensor(info['bboxes'])

        # object id is the scene folder name, e.g. 'test/000001/rgb/x.png' -> '1'
        obj_id = str(int(pair_names[0].split('/')[1]))
        diameter = self.object_diameters[obj_id]
        point_cloud = torch.from_numpy(self.object_points[obj_id]) / 1000.

        return {
            'images': images,
            'masks': masks,
            'depths': depths,
            'rotation': rotation,
            'translation': translation,
            'intrinsics': intrinsics,
            'bboxes': bboxes,
            'diameter': diameter,
            'point_cloud': point_cloud,
        }
def build_linemod(mode, config):
    """Build the LINEMOD dataset for the requested split.

    Training concatenates one `Linemod` dataset per (object, PBR scene)
    combination; validation/testing load a fixed pre-computed JSON index.
    """
    cfg = config.DATASET

    if mode == 'train':
        datasets = []
        with tqdm(total=len(LINEMOD_ID_TO_NAME) * 50) as bar:
            bar.set_description(f'Loading Linemod {mode} datasets')
            for obj_idx, _ in enumerate(LINEMOD_ID_TO_NAME):
                for scene_idx in range(50):
                    bar.update(1)
                    # a scene that never shows this object raises KeyError
                    try:
                        datasets.append(Linemod(cfg.DATA_ROOT, mode, obj_idx + 1, scene_idx, cfg.MIN_VISIBLE_FRACT, cfg.MAX_ANGLE_ERROR))
                    except KeyError:
                        continue
        return ConcatDataset(datasets)

    elif mode == 'test' or mode == 'val':
        return LinemodfromJson(cfg.DATA_ROOT, cfg.JSON_PATH)
datasets/mapfree.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.utils.data as data
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
from transforms3d.quaternions import qinverse, qmult, rotate_vector, quat2mat
|
| 8 |
+
|
| 9 |
+
from utils.transform import correct_intrinsic_scale
|
| 10 |
+
from utils import Augmentor
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MapFreeScene(data.Dataset):
    """One scene of the Map-free Relocalization dataset.

    Yields image pairs with their relative pose, per-image intrinsics and
    pre-computed monocular depth maps (loaded from `<frame>.da.npy` files).
    """

    def __init__(self, scene_root, resize, sample_factor=1, overlap_limits=None, estimated_depth=None, mode='train'):
        """
        Args:
            scene_root: path to the scene folder (contains intrinsics.txt,
                poses.txt, seq*/ image folders and optionally overlaps.npz).
            resize: (width, height) target size used to rescale intrinsics.
            sample_factor: keep every N-th query frame (test/val only).
            overlap_limits: (min, max) covisibility bounds for training pairs.
            estimated_depth: NOTE(review) stored but never used below; depth is
                always read from the hard-coded '.da.npy' files — confirm intended.
            mode: 'train' enables photometric augmentation.
        """
        super().__init__()

        self.scene_root = Path(scene_root)
        self.resize = resize
        self.sample_factor = sample_factor
        self.estimated_depth = estimated_depth

        # load absolute poses
        self.poses = self.read_poses(self.scene_root)

        # read intrinsics
        self.K = self.read_intrinsics(self.scene_root, resize)

        # load pairs
        self.pairs = self.load_pairs(self.scene_root, overlap_limits, self.sample_factor)

        self.augment = Augmentor(mode=='train')

    @staticmethod
    def read_intrinsics(scene_root: Path, resize=None):
        """Parse intrinsics.txt into {img_name: (3, 3) float32 K matrix}.

        Each data line is: img_name fx fy cx cy W H. When `resize` is given,
        K is rescaled from the stored (W, H) to the target resolution.
        """
        Ks = {}
        with (scene_root / 'intrinsics.txt').open('r') as f:
            for line in f.readlines():
                if '#' in line:  # skip comment lines
                    continue

                line = line.strip().split(' ')
                img_name = line[0]
                fx, fy, cx, cy, W, H = map(float, line[1:])

                K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
                if resize is not None:
                    K = correct_intrinsic_scale(K, resize[0] / W, resize[1] / H)
                Ks[img_name] = K
        return Ks

    @staticmethod
    def read_poses(scene_root: Path):
        """
        Returns a dictionary that maps: img_path -> (q, t) where
        np.array q = (qw, qx qy qz) quaternion encoding rotation matrix;
        np.array t = (tx ty tz) translation vector;
        (q, t) encodes absolute pose (world-to-camera), i.e. X_c = R(q) X_W + t
        """
        poses = {}
        with (scene_root / 'poses.txt').open('r') as f:
            for line in f.readlines():
                if '#' in line:  # skip comment lines
                    continue

                line = line.strip().split(' ')
                img_name = line[0]
                qt = np.array(list(map(float, line[1:])))
                # first 4 numbers are the quaternion, last 3 the translation
                poses[img_name] = (qt[:4], qt[4:])
        return poses

    def load_pairs(self, scene_root: Path, overlap_limits: tuple = None, sample_factor: int = 1):
        """
        For training scenes, filter pairs of frames based on overlap (pre-computed in overlaps.npz)
        For test/val scenes, pairs are formed between keyframe and every other sample_factor query frames.
        If sample_factor == 1, all query frames are used. Note: sample_factor applicable only to test/val
        Returns:
            pairs: nd.array [Npairs, 4], where each column represents seqA, imA, seqB, imB, respectively
        """
        overlaps_path = scene_root / 'overlaps.npz'

        if overlaps_path.exists():
            # training scene: keep only pairs within the covisibility window
            f = np.load(overlaps_path, allow_pickle=True)
            idxs, overlaps = f['idxs'], f['overlaps']
            if overlap_limits is not None:
                min_overlap, max_overlap = overlap_limits
                mask = (overlaps > min_overlap) * (overlaps < max_overlap)
                idxs = idxs[mask]
            return idxs.copy()
        else:
            # test/val scene: pair the seq0 keyframe (implicit) with every
            # seq1 query frame; frame number is parsed from the file name
            idxs = np.zeros((len(self.poses) - 1, 4), dtype=np.uint16)
            idxs[:, 2] = 1
            idxs[:, 3] = np.array([int(fn[-9:-4])
                                   for fn in self.poses.keys() if 'seq0' not in fn], dtype=np.uint16)
            return idxs[::sample_factor]

    def get_pair_path(self, pair):
        """Map a (seqA, imA, seqB, imB) row to the two relative image paths."""
        seqA, imgA, seqB, imgB = pair
        return (f'seq{seqA}/frame_{imgA:05}.jpg', f'seq{seqB}/frame_{imgB:05}.jpg')

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        # image paths (relative to scene_root)
        img_name0, img_name1 = self.get_pair_path(self.pairs[index])
        # NOTE(review): w_new/h_new are unused because the cv2.resize calls
        # below are commented out — images are returned at native resolution
        w_new, h_new = self.resize

        image0 = cv2.imread(str(self.scene_root / img_name0))
        # image0 = cv2.resize(image0, (w_new, h_new))
        image0 = cv2.cvtColor(image0, cv2.COLOR_BGR2RGB)
        image0 = self.augment(image0)
        image0 = torch.from_numpy(image0).permute(2, 0, 1).float() / 255.

        image1 = cv2.imread(str(self.scene_root / img_name1))
        # image1 = cv2.resize(image1, (w_new, h_new))
        image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
        image1 = self.augment(image1)
        image1 = torch.from_numpy(image1).permute(2, 0, 1).float() / 255.
        images = torch.stack([image0, image1], dim=0)

        # pre-computed per-frame depth maps stored next to the images
        depth0 = np.load(str(self.scene_root / img_name0).replace('.jpg', f'.da.npy'))
        depth0 = torch.from_numpy(depth0).float()

        depth1 = np.load(str(self.scene_root / img_name1).replace('.jpg', f'.da.npy'))
        depth1 = torch.from_numpy(depth1).float()

        depths = torch.stack([depth0, depth1], dim=0)

        # get absolute pose of im0 and im1
        # quaternion and translation vector that transforms World-to-Cam
        q1, t1 = self.poses[img_name0]
        # quaternion and translation vector that transforms World-to-Cam
        q2, t2 = self.poses[img_name1]

        # get 4 x 4 relative pose transformation matrix (from im1 to im2)
        # for test/val set, q1,t1 is the identity pose, so the relative pose matches the absolute pose
        q12 = qmult(q2, qinverse(q1))
        t12 = t2 - rotate_vector(t1, q12)
        T = np.eye(4, dtype=np.float32)
        T[:3, :3] = quat2mat(q12)
        T[:3, -1] = t12
        T = torch.from_numpy(T)

        K_0 = torch.from_numpy(self.K[img_name0].copy()).reshape(3, 3)
        K_1 = torch.from_numpy(self.K[img_name1].copy()).reshape(3, 3)
        intrinsics = torch.stack([K_0, K_1], dim=0).float()

        data = {
            'images': images,
            'depths': depths,
            'rotation': T[:3, :3],
            'translation': T[:3, 3],
            'intrinsics': intrinsics,
            'scene_id': self.scene_root.stem,
            'scene_root': str(self.scene_root),
            'pair_id': index*self.sample_factor,
            'pair_names': (img_name0, img_name1),
        }

        return data
def build_concat_mapfree(mode, config):
    """Build a ConcatDataset over every Map-free scene of the given split.

    Args:
        mode: one of 'train', 'val', 'test' (also the split folder name
            under DATASET.DATA_ROOT).
        config: experiment config; reads DATASET.DATA_ROOT and
            DATASET.ESTIMATED_DEPTH.

    Returns:
        torch.utils.data.ConcatDataset with one MapFreeScene per scene folder.
    """
    assert mode in ['train', 'val', 'test'], 'Invalid dataset mode'

    data_root = Path(config.DATASET.DATA_ROOT) / mode
    # fixed: the original had a redundant `scenes = scenes = [...]` assignment
    scenes = [s.name for s in data_root.iterdir() if s.is_dir()]
    # validation uses every 5th query frame to keep evaluation fast
    sample_factor = {'train': 1, 'val': 5, 'test': 1}[mode]
    estimated_depth = config.DATASET.ESTIMATED_DEPTH

    resize = (540, 720)
    overlap_limits = (0.2, 0.7)

    # Init dataset objects for each scene
    datasets = [MapFreeScene(data_root / scene, resize, sample_factor, overlap_limits, estimated_depth, mode) for scene in scenes]

    return data.ConcatDataset(datasets)
datasets/matterport.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import json
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
|
| 8 |
+
from utils import rotation_matrix_from_quaternion, Augmentor
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Matterport3D(Dataset):
    """Matterport3D image-pair dataset read from the cached PlaneRCNN JSON splits."""

    def __init__(self, data_root, mode='train'):
        """
        Args:
            data_root: dataset root containing `mp3d_planercnn_json` and the
                image folders referenced by the cached split.
            mode: 'train', 'val' or 'test'; selects `cached_set_{mode}.json`
                and enables augmentation for 'train'.
        """
        data_root = Path(data_root)
        json_path = data_root / 'mp3d_planercnn_json' / f'cached_set_{mode}.json'

        scene_info = {'images': [], 'rotation': [], 'translation': [], 'intrinsics': []}

        with open(json_path) as f:
            split = json.load(f)

        for data in split['data']:
            # re-root the cached absolute image paths onto data_root
            images = [data_root / '/'.join(data[imgnum]['file_name'].split('/')[6:])
                      for imgnum in ('0', '1')]

            # fixed pinhole intrinsics shared by both views
            intrinsic = [
                [517.97, 0, 320],
                [0, 517.97, 240],
                [0, 0, 1]
            ]

            scene_info['images'].append(images)
            scene_info['rotation'].append(data['rel_pose']['rotation'])
            scene_info['translation'].append(data['rel_pose']['position'])
            scene_info['intrinsics'].append([intrinsic, intrinsic])

        scene_info['rotation'] = torch.tensor(scene_info['rotation'])
        scene_info['translation'] = torch.tensor(scene_info['translation'])
        scene_info['intrinsics'] = torch.tensor(scene_info['intrinsics'])

        self.scene_info = scene_info
        self.augment = Augmentor(mode=='train')

        self.is_training = mode == 'train'

    def __len__(self):
        return len(self.scene_info['images'])

    def __getitem__(self, idx):
        img_names = self.scene_info['images'][idx]
        # Clone the pose tensors: the in-place quaternion normalization below
        # would otherwise mutate the cached scene_info across epochs (the
        # original code only made a copy on the sign-flip branch).
        rotation = self.scene_info['rotation'][idx].clone()
        translation = self.scene_info['translation'][idx].clone()
        intrinsics = self.scene_info['intrinsics'][idx]

        images = []
        for i in range(2):
            image = cv2.imread(str(img_names[i]))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = self.augment(image)
            image = torch.from_numpy(image).permute(2, 0, 1)
            images.append(image)
        images = torch.stack(images)
        images = images.float() / 255.

        # canonicalize the quaternion sign (qw >= 0), then normalize it
        rotation = -rotation if rotation[0] < 0 else rotation
        rotation /= rotation.norm(2)
        rotation = rotation_matrix_from_quaternion(rotation[None,])[0]

        # invert the stored relative transform: R <- R^T, t <- -R^T t
        rotation = rotation.mT
        translation = -rotation @ translation.unsqueeze(-1)
        translation = translation[:, 0]

        return {
            'images': images,
            'rotation': rotation,
            'translation': translation,
            'intrinsics': intrinsics,
        }
def build_matterport(mode, config):
    """Factory wrapper mirroring the other dataset builders."""
    dataset = Matterport3D(config.DATASET.DATA_ROOT, mode)
    return dataset
datasets/megadepth.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 7 |
+
|
| 8 |
+
from utils import Augmentor
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MegaDepthDataset(Dataset):
    def __init__(self,
                 root_dir,
                 npz_path,
                 mode='train',
                 min_overlap_score=0.4,
                 ):
        """
        Manage one scene(npz_path) of MegaDepth dataset.

        Args:
            root_dir (str): megadepth root directory that has `phoenix`.
            npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
            mode (str): options are ['train', 'val', 'test']
            min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
        """
        super().__init__()
        self.root_dir = root_dir
        self.mode = mode
        self.scene_id = npz_path.split('.')[0]

        # prepare scene_info and pair_info
        if mode == 'test':
            min_overlap_score = 0
        # Materialize into a plain dict: np.load returns a lazy NpzFile, which
        # does not support item deletion, so the `del` below would raise.
        self.scene_info = dict(np.load(npz_path, allow_pickle=True))
        self.pair_infos = self.scene_info['pair_infos'].copy()
        del self.scene_info['pair_infos']
        # keep only pairs with sufficient covisibility
        self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score]

        self.augment = Augmentor(mode=='train')

    def __len__(self):
        return len(self.pair_infos)

    def __getitem__(self, idx):
        (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx]

        img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
        img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])

        # fixed working resolution; `scales` maps back to original pixel coords
        w_new, h_new = 640, 480

        image0 = cv2.imread(img_name0)
        scale0 = torch.tensor([image0.shape[1]/w_new, image0.shape[0]/h_new], dtype=torch.float)
        image0 = cv2.resize(image0, (w_new, h_new))
        image0 = cv2.cvtColor(image0, cv2.COLOR_BGR2RGB)
        image0 = torch.from_numpy(image0).permute(2, 0, 1).float() / 255.

        image1 = cv2.imread(img_name1)
        scale1 = torch.tensor([image1.shape[1]/w_new, image1.shape[0]/h_new], dtype=torch.float)
        image1 = cv2.resize(image1, (w_new, h_new))
        image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
        image1 = torch.from_numpy(image1).permute(2, 0, 1).float() / 255.

        scales = torch.stack([scale0, scale1], dim=0)
        images = torch.stack([image0, image1], dim=0)

        # read intrinsics of original size
        K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
        K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
        intrinsics = torch.stack([K_0, K_1], dim=0)

        # relative transform from frame 0 to frame 1
        T0 = self.scene_info['poses'][idx0]
        T1 = self.scene_info['poses'][idx1]
        T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4]  # (4, 4)

        data = {
            'images': images,
            'scales': scales,  # (2, 2): [scale_w, scale_h]
            'rotation': T_0to1[:3, :3],
            'translation': T_0to1[:3, 3],
            'intrinsics': intrinsics,
            'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
            'depth_pair_names': (self.scene_info['depth_paths'][idx0], self.scene_info['depth_paths'][idx1]),
        }

        return data
def build_concat_megadepth(mode, config):
    """Concatenate the per-scene MegaDepth datasets listed for one split."""
    if mode == 'train':
        split_cfg = config.DATASET.TRAIN
    elif mode == 'val':
        split_cfg = config.DATASET.VAL
    elif mode == 'test':
        split_cfg = config.DATASET.TEST
    else:
        raise NotImplementedError(f'mode {mode}')

    # one scene name per line in the split list file
    with open(split_cfg.LIST_PATH, 'r') as f:
        scene_names = [line.split()[0] for line in f.readlines()]

    datasets = [
        MegaDepthDataset(
            split_cfg.DATA_ROOT,
            osp.join(split_cfg.NPZ_ROOT, f'{scene}.npz'),
            mode=mode,
            min_overlap_score=split_cfg.MIN_OVERLAP_SCORE,
        )
        for scene in tqdm(scene_names, desc=f'Loading MegaDepth {mode} datasets',)
    ]

    return ConcatDataset(datasets)
datasets/sampler.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Sampler, ConcatDataset
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class RandomConcatSampler(Sampler):
    """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be drawn from each subset
    in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
    However, it is impossible to sample data without replacement between epochs, unless building a stateful sampler that lives along the entire training phase.

    For the current implementation, the randomness of sampling is ensured no matter whether the sampler is recreated across epochs and whether `torch.manual_seed()` is called.
    Args:
        shuffle (bool): shuffle the random sampled indices across all sub-datasets.
        repeat (int): repeatedly use the sampled indices multiple times for training.
            [arXiv:1902.05509, arXiv:1901.09335]
    NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
    NOTE: This sampler behaves differently with DistributedSampler.
          It assumes the dataset is split across ranks instead of replicated.
    TODO: Add a `set_epoch()` method to fulfill sampling without replacement across epochs.
          ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
    """
    def __init__(self,
                 data_source: ConcatDataset,
                 n_samples_per_subset: int,
                 subset_replacement: bool=True,
                 shuffle: bool=True,
                 repeat: int=1,
                 seed: int=None):
        if not isinstance(data_source, ConcatDataset):
            raise TypeError("data_source should be torch.utils.data.ConcatDataset")

        self.data_source = data_source
        self.n_subset = len(self.data_source.datasets)
        self.n_samples_per_subset = n_samples_per_subset
        # total indices yielded per epoch (repeat multiplies the base count)
        self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
        self.subset_replacement = subset_replacement
        self.repeat = repeat
        self.shuffle = shuffle
        # NOTE(review): torch.manual_seed seeds the GLOBAL default RNG as a
        # side effect and returns the default generator; it also raises if
        # seed is None despite the default. Confirm callers always pass an int.
        self.generator = torch.manual_seed(seed)
        assert self.repeat >= 1

    def __len__(self):
        # number of indices produced by one pass of __iter__
        return self.n_samples

    def __iter__(self):
        indices = []
        # sample from each sub-dataset
        for d_idx in range(self.n_subset):
            # [low, high) is this subset's index range inside the ConcatDataset
            low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1]
            high = self.data_source.cumulative_sizes[d_idx]
            if self.subset_replacement:
                rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ),
                                            generator=self.generator, dtype=torch.int64)
            else: # sample without replacement
                len_subset = len(self.data_source.datasets[d_idx])
                rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
                if len_subset >= self.n_samples_per_subset:
                    rand_tensor = rand_tensor[:self.n_samples_per_subset]
                else: # padding with replacement
                    # subset smaller than the quota: top up with random repeats
                    rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ),
                                                            generator=self.generator, dtype=torch.int64)
                    rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
            indices.append(rand_tensor)
        indices = torch.cat(indices)
        if self.shuffle: # shuffle the sampled dataset (from multiple subsets)
            rand_tensor = torch.randperm(len(indices), generator=self.generator)
            indices = indices[rand_tensor]

        # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
        if self.repeat > 1:
            repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
            if self.shuffle:
                # each repeated copy gets its own independent shuffle
                _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)]
                repeat_indices = map(_choice, repeat_indices)
            indices = torch.cat([indices, *repeat_indices], 0)

        assert indices.shape[0] == self.n_samples
        return iter(indices.tolist())
datasets/scannet.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import path as osp
|
| 2 |
+
import numpy as np
|
| 3 |
+
from numpy.linalg import inv
|
| 4 |
+
import cv2
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 9 |
+
|
| 10 |
+
from utils import Augmentor
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def read_scannet_pose(path):
    """Read ScanNet's Camera2World pose file and return its World2Camera inverse.

    Returns:
        pose_w2c (np.ndarray): (4, 4)
    """
    pose_c2w = np.loadtxt(path, delimiter=' ')
    pose_w2c = inv(pose_c2w)
    return pose_w2c
class ScanNetDataset(Dataset):
    def __init__(self,
                 root_dir,
                 npz_path,
                 intrinsic_path,
                 mode='train',
                 min_overlap_score=0.4,
                 pose_dir=None,
                 ):
        """Manage one scene of ScanNet Dataset.
        Args:
            root_dir (str): ScanNet root directory that contains scene folders.
            npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
            intrinsic_path (str): path to depth-camera intrinsic file.
            mode (str): options are ['train', 'val', 'test'].
            min_overlap_score (float): pairs with overlap score <= this are dropped (train only).
            pose_dir (str): ScanNet root directory that contains all poses.
                (we use a separate (optional) pose_dir since we store images and poses separately.)
        """
        super().__init__()
        self.root_dir = root_dir
        self.pose_dir = pose_dir if pose_dir is not None else root_dir
        self.mode = mode

        # prepare data_names, intrinsics and extrinsics(T)
        with np.load(npz_path) as data:
            self.data_names = data['name']
            # BUGFIX: the original condition was `mode not in ['val' or 'test']`,
            # which evaluates to `mode not in ['val']` — so the *test* split was
            # also filtered by overlap score. Only training pairs should be
            # filtered.
            if 'score' in data.keys() and mode not in ('val', 'test'):
                kept_mask = data['score'] > min_overlap_score
                self.data_names = self.data_names[kept_mask]
        self.intrinsics = dict(np.load(intrinsic_path))
        self.augment = Augmentor(mode == 'train')

    def __len__(self):
        return len(self.data_names)

    def _read_abs_pose(self, scene_name, name):
        # Absolute world-to-camera pose of one frame, read from
        # <pose_dir>/<scene>/pose/<frame>.txt.
        pth = osp.join(self.pose_dir,
                       scene_name,
                       'pose', f'{name}.txt')
        return read_scannet_pose(pth)

    def _compute_rel_pose(self, scene_name, name0, name1):
        """Relative pose T_0to1 from the two frames' absolute poses."""
        pose0 = self._read_abs_pose(scene_name, name0)
        pose1 = self._read_abs_pose(scene_name, name1)

        return np.matmul(pose1, inv(pose0))  # (4, 4)

    def __getitem__(self, idx):
        """Return a dict with the stacked image pair, relative pose (R, t),
        per-image intrinsics and the pair's relative file paths."""
        data_name = self.data_names[idx]
        scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'

        img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
        img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')

        # Color frames are resized to the depth-camera resolution so the
        # shared depth intrinsics apply.
        w_new, h_new = 640, 480

        image0 = cv2.imread(img_name0)
        image0 = cv2.resize(image0, (w_new, h_new))
        image0 = cv2.cvtColor(image0, cv2.COLOR_BGR2RGB)
        # image0 = self.augment(image0)
        image0 = torch.from_numpy(image0).permute(2, 0, 1).float() / 255.

        image1 = cv2.imread(img_name1)
        image1 = cv2.resize(image1, (w_new, h_new))
        image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
        # image1 = self.augment(image1)
        image1 = torch.from_numpy(image1).permute(2, 0, 1).float() / 255.
        images = torch.stack([image0, image1], dim=0)

        # depth0 = cv2.imread(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'), cv2.IMREAD_UNCHANGED)
        # depth0 = depth0 / 1000
        # depth0 = torch.from_numpy(depth0).float()

        # depth1 = cv2.imread(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'), cv2.IMREAD_UNCHANGED)
        # depth1 = depth1 / 1000
        # depth1 = torch.from_numpy(depth1).float()
        # depths = torch.stack([depth0, depth1], dim=0)

        # Both frames of a scene share one (depth-camera) intrinsic matrix.
        K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)
        intrinsics = torch.stack([K_0, K_1], dim=0)

        # read and compute relative poses
        T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
                              dtype=torch.float32)

        data = {
            'images': images,
            # 'depths': depths,
            'rotation': T_0to1[:3, :3],
            'translation': T_0to1[:3, 3],
            'intrinsics': intrinsics,
            'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
                           osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
        }

        return data
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def build_concat_scannet(mode, config):
    """Build a ConcatDataset of per-scene ScanNet datasets for one split.

    Args:
        mode (str): one of 'train', 'val', 'test'; selects the split config.
        config: config node exposing DATASET.{TRAIN,VAL,TEST}.

    Returns:
        torch.utils.data.ConcatDataset over every scene listed in LIST_PATH.
    """
    if mode == 'train':
        split_cfg = config.DATASET.TRAIN
    elif mode == 'val':
        split_cfg = config.DATASET.VAL
    elif mode == 'test':
        split_cfg = config.DATASET.TEST
    else:
        raise NotImplementedError(f'mode {mode}')

    # Each line of LIST_PATH names one per-scene .npz pair file; only the
    # first whitespace-separated token is the filename.
    with open(split_cfg.LIST_PATH, 'r') as f:
        npz_names = [line.split()[0] for line in f.readlines()]

    scene_datasets = []
    for npz_name in tqdm(npz_names, desc=f'Loading ScanNet {mode} datasets',):
        scene_datasets.append(ScanNetDataset(
            split_cfg.DATA_ROOT,
            osp.join(split_cfg.NPZ_ROOT, npz_name),
            split_cfg.INTRINSIC_PATH,
            mode=mode,
            min_overlap_score=split_cfg.MIN_OVERLAP_SCORE,
        ))

    return ConcatDataset(scene_datasets)
|
| 154 |
+
|
eval.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
import lightning as L
|
| 4 |
+
|
| 5 |
+
from datasets import dataset_dict
|
| 6 |
+
from model import PL_RelPose, keypoint_dict
|
| 7 |
+
from configs.default import get_cfg_defaults
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def main(args):
    """Run Lightning test loop for a trained checkpoint on the configured split."""
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    build_fn = dataset_dict[config.DATASET.TASK][config.DATASET.DATA_SOURCE]
    testset = build_fn('test', config)
    testloader = DataLoader(
        testset,
        batch_size=config.TRAINER.BATCH_SIZE,
        num_workers=config.TRAINER.NUM_WORKERS,
        pin_memory=config.TRAINER.PIN_MEMORY,
    )

    pl_relpose = PL_RelPose.load_from_checkpoint(args.ckpt_path)
    # Swap in a fresh extractor with the test-time keypoint budget; the
    # feature type is recovered from the checkpoint's hyperparameters.
    extractor_cls = keypoint_dict[pl_relpose.hparams['features']]
    pl_relpose.extractor = extractor_cls(
        max_num_keypoints=config.MODEL.TEST_NUM_KEYPOINTS,
        detection_threshold=0.0,
    ).eval()

    trainer = L.Trainer(
        devices=[0],
    )

    trainer.test(pl_relpose, dataloaders=testloader)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_parser():
    """CLI: positional config path and checkpoint path."""
    cli = argparse.ArgumentParser()
    cli.add_argument('config', type=str, help='.yaml configure file path')
    cli.add_argument('ckpt_path', type=str)
    return cli
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
    # Parse CLI arguments and launch evaluation.
    main(get_parser().parse_args())
|
eval_add_reproj.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from transforms3d.quaternions import mat2quat
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
from model import PL_RelPose, keypoint_dict
|
| 10 |
+
from utils.reproject import reprojection_error, Pose, save_submission
|
| 11 |
+
from utils.metrics import reproj, add, adi, compute_continuous_auc, relative_pose_error, rotation_angular_error
|
| 12 |
+
from datasets import dataset_dict
|
| 13 |
+
from configs.default import get_cfg_defaults
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@torch.no_grad()
def main(args):
    """Evaluate a checkpoint pair-by-pair and print pose / timing / VCRE /
    ADD metrics as a pandas table.

    Fixes vs. original: removed the duplicated assignment
    `test_num_keypoints = test_num_keypoints = ...` and the unused
    `enumerate` index.
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE
    device = args.device

    test_num_keypoints = config.MODEL.TEST_NUM_KEYPOINTS

    build_fn = dataset_dict[task][dataset]
    testset = build_fn('test', config)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1)

    pl_relpose = PL_RelPose.load_from_checkpoint(args.ckpt_path)
    # Replace the extractor with one using the evaluation keypoint budget.
    pl_relpose.extractor = keypoint_dict[pl_relpose.hparams['features']](max_num_keypoints=test_num_keypoints, detection_threshold=0.0).eval().to(device)
    pl_relpose.module = pl_relpose.module.eval().to(device)

    preprocess_times, extract_times, regress_times = [], [], []
    adds, adis = [], []
    repr_errs = []
    R_errs, t_errs = [], []
    ts_errs = []
    results_dict = defaultdict(list)
    for data in tqdm(testloader):
        # Optional per-object filtering for HO3D.
        if dataset == 'ho3d' and args.obj_name is not None and data['objName'][0] != args.obj_name:
            continue
        image0, image1 = data['images'][0]
        K0, K1 = data['intrinsics'][0]
        # Ground-truth relative transform as a 4x4 matrix.
        T = torch.eye(4)
        T[:3, :3] = data['rotation'][0]
        T[:3, 3] = data['translation'][0]
        T = T.numpy()

        # with record_function("model_inference"):
        R_est, t_est, preprocess_time, extract_time, regress_time = pl_relpose.predict_one_data(data)
        preprocess_times.append(preprocess_time)
        extract_times.append(extract_time)
        regress_times.append(regress_time)

        t_err, R_err = relative_pose_error(T, R_est.cpu().numpy(), t_est.cpu().numpy(), ignore_gt_t_thr=0.0)

        R_errs.append(R_err)
        t_errs.append(t_err)

        # Metric (L2) translation error in the dataset's units.
        ts_errs.append(torch.tensor(T[:3, 3] - t_est.cpu().numpy()).norm(2))

        if dataset == 'mapfree':
            # Virtual-correspondence reprojection error + submission record.
            repr_err = reprojection_error(R_est.cpu().numpy(), t_est.cpu().numpy(), T[:3, :3], T[:3, 3], K=K1, W=image1.shape[-1], H=image1.shape[-2])
            repr_errs.append(repr_err)
            R = R_est.detach().cpu().numpy()
            t = t_est.reshape(-1).detach().cpu().numpy()
            scene = data['scene_id'][0]
            estimated_pose = Pose(
                image_name=data['pair_names'][1][0],
                q=mat2quat(R).reshape(-1),
                t=t.reshape(-1),
                inliers=0
            )
            results_dict[scene].append(estimated_pose)

        if 'point_cloud' in data:
            adds.append(add(R_est.cpu().numpy(), t_est.cpu().numpy(), T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))
            adis.append(adi(R_est.cpu().numpy(), t_est.cpu().numpy(), T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))

    metrics = []
    values = []

    # seconds -> milliseconds
    preprocess_times = np.array(preprocess_times) * 1000
    extract_times = np.array(extract_times) * 1000
    regress_times = np.array(regress_times) * 1000

    metrics.append('Extracting Time (ms)')
    values.append(f'{np.mean(extract_times):.1f}')

    metrics.append('Recovering Time (ms)')
    values.append(f'{np.mean(regress_times):.1f}')

    metrics.append('Total Time (ms)')
    values.append(f'{np.mean(extract_times) + np.mean(regress_times):.1f}')

    # ts_errs = np.array(ts_errs)
    # print(f'Median Trans. Error (m):\t{np.median(ts_errs):.2f}')
    # print(f'Median Rot. Error (°):\t{np.median(R_errs):.2f}')

    if task == 'object':
        metrics.append('Object ADD')
        values.append(f'{compute_continuous_auc(adds, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

        metrics.append('Object ADD-S')
        values.append(f'{compute_continuous_auc(adis, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

    if dataset == 'mapfree':
        re = np.array(repr_errs)

        metrics.append('VCRE @90px Prec.')
        values.append(f'{(re < 90).mean() * 100:.2f}')

        metrics.append('VCRE Med.')
        values.append(f'{np.median(re):.2f}')

        save_submission(results_dict, 'assets/new_submission.zip')

    res = pd.DataFrame({'Metrics': metrics, 'Values': values})
    print(res)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def get_parser():
    """CLI: config + checkpoint paths, optional device and object filter."""
    cli = argparse.ArgumentParser()
    cli.add_argument('config', type=str, help='.yaml configure file path')
    cli.add_argument('ckpt_path', type=str)

    cli.add_argument('--device', type=str, default='cuda:0')
    cli.add_argument('--obj_name', type=str, default=None)

    return cli
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
    # Parse CLI arguments and launch evaluation.
    main(get_parser().parse_args())
|
eval_baselines.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import argparse
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import torch
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
# from lightglue.utils import load_image
|
| 8 |
+
from configs.default import get_cfg_defaults
|
| 9 |
+
from datasets import dataset_dict
|
| 10 |
+
from baselines.pose import PoseRecover
|
| 11 |
+
from utils.metrics import relative_pose_error, rotation_angular_error, error_auc, add, adi, compute_continuous_auc
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def main(args):
    """Evaluate a matcher + pose-solver baseline over the configured test
    split and print timing / pose-AUC / rotation / translation / ADD metrics.

    Fixes vs. original:
      * timing arrays were built from the scalar last-iteration values
        (`np.array(preprocess_time)` etc.) instead of the accumulated lists,
        so the reported averages were a single sample;
      * the NaN penalty for ADD/ADD-S was dead code because `R` had already
        been replaced by the identity before the second `np.isnan(R)` check —
        the NaN flag is now captured before overwriting.
    """
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE

    build_fn = dataset_dict[task][dataset]
    testset = build_fn('test', config)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1)

    device = args.device
    img_resize = args.resize
    poseRec = PoseRecover(matcher=args.matcher, solver=args.solver, img_resize=img_resize, device=device)

    preprocess_times, extract_times, match_times, recover_times = [], [], [], []
    R_errs, t_errs = [], []
    ts_errs = []
    adds, adis = [], []
    for data in tqdm(testloader):
        # Optional per-object filtering for HO3D.
        if dataset == 'ho3d' and args.obj_name is not None and data['objName'][0] != args.obj_name:
            continue

        image0, image1 = data['images'][0].to(device)

        bbox0, bbox1 = None, None
        if task == 'object':
            # Crop both images to their object bounding boxes.
            bbox0, bbox1 = data['bboxes'][0]
            x1, y1, x2, y2 = bbox0
            u1, v1, u2, v2 = bbox1
            image0 = image0[:, y1:y2, x1:x2]
            image1 = image1[:, v1:v2, u1:u2]

        mask0, mask1 = None, None
        if args.mask:
            mask0, mask1 = data['masks'][0].to(device)

        depth0, depth1 = None, None
        if args.depth:
            depth0, depth1 = data['depths'][0]

        K0, K1 = data['intrinsics'][0]
        # Ground-truth relative transform as a 4x4 matrix.
        T = torch.eye(4)
        T[:3, :3] = data['rotation'][0]
        T[:3, 3] = data['translation'][0]
        T = T.numpy()

        R, t, points0, points1, preprocess_time, extract_time, match_time, recover_time = poseRec.recover(image0, image1, K0, K1, bbox0, bbox1, mask0, mask1, depth0, depth1)
        preprocess_times.append(preprocess_time)
        extract_times.append(extract_time)
        match_times.append(match_time)
        recover_times.append(recover_time)

        # Remember whether the solver failed before overwriting R/t below.
        solver_failed = bool(np.isnan(R).any())
        if solver_failed:
            # Worst-case penalty for a failed solve.
            R_err = 180
            R = np.identity(3)
            t_err = 180
            t = np.array([0., 0., 0.])
        else:
            t_err, R_err = relative_pose_error(T, R, t, ignore_gt_t_thr=0.0)

        R_errs.append(R_err)
        t_errs.append(t_err)

        if args.depth:
            t = np.nan_to_num(t)
            ts_errs.append(torch.tensor(T[:3, 3] - t).norm(2))

        if task == 'object':
            if solver_failed:
                # Maximum (normalized) ADD error for a failed solve.
                adds.append(1.)
                adis.append(1.)
            else:
                adds.append(add(R, t, T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))
                adis.append(adi(R, t, T[:3, :3], T[:3, 3], data['point_cloud'][0].numpy()))

    metrics = []
    values = []

    # seconds -> milliseconds (use the accumulated lists, not the scalars)
    preprocess_times = np.array(preprocess_times) * 1000
    extract_times = np.array(extract_times) * 1000
    match_times = np.array(match_times) * 1000
    recover_times = np.array(recover_times) * 1000

    metrics.append('Extracting Time (ms)')
    values.append(f'{np.mean(extract_times):.1f}')

    metrics.append('Matching Time (ms)')
    values.append(f'{np.mean(match_times):.1f}')

    metrics.append('Recovering Time (ms)')
    values.append(f'{np.mean(recover_times):.1f}')

    metrics.append('Total Time (ms)')
    values.append(f'{np.mean(extract_times) + np.mean(match_times) + np.mean(recover_times):.1f}')

    # pose auc
    angular_thresholds = [5, 10, 20]
    pose_errors = np.max(np.stack([R_errs, t_errs]), axis=0)
    aucs = error_auc(pose_errors, angular_thresholds, mode='Pose estimation')  # (auc@5, auc@10, auc@20)
    for k in aucs:
        metrics.append(k)
        values.append(f'{aucs[k] * 100:.2f}')

    R_errs = torch.tensor(R_errs)
    t_errs = torch.tensor(t_errs)

    metrics.append('Rotation Avg. Error (°)')
    values.append(f'{R_errs.mean():.2f}')

    metrics.append('Rotation Med. Error (°)')
    values.append(f'{R_errs.median():.2f}')

    metrics.append('Rotation @30° ACC')
    values.append(f'{(R_errs < 30).float().mean() * 100:.1f}')

    metrics.append('Rotation @15° ACC')
    values.append(f'{(R_errs < 15).float().mean() * 100:.1f}')

    if args.depth:
        # Metric translation errors only make sense with depth available.
        ts_errs = torch.tensor(ts_errs)

        metrics.append('Translation Avg. Error (m)')
        values.append(f'{ts_errs.mean():.4f}')

        metrics.append('Translation Med. Error (m)')
        values.append(f'{ts_errs.median():.4f}')

        metrics.append('Translation @1m ACC')
        values.append(f'{(ts_errs < 1.0).float().mean() * 100:.1f}')

        metrics.append('Translation @10cm ACC')
        values.append(f'{(ts_errs < 0.1).float().mean() * 100:.1f}')

    if task == 'object':
        metrics.append('Object ADD')
        values.append(f'{compute_continuous_auc(adds, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

        metrics.append('Object ADD-S')
        values.append(f'{compute_continuous_auc(adis, np.linspace(0.0, 0.1, 1000)) * 100:.1f}')

    res = pd.DataFrame({'Metrics': metrics, 'Values': values})
    print(res)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def get_parser():
    """CLI: config path, matcher name, solver choice and evaluation flags."""
    cli = argparse.ArgumentParser()
    cli.add_argument('config', type=str, help='.yaml configure file path')
    cli.add_argument('matcher', type=str)
    cli.add_argument('--solver', type=str, default='procrustes')

    cli.add_argument('--resize', type=int, default=None)
    cli.add_argument('--depth', action='store_true')

    cli.add_argument('--mask', action='store_true')
    cli.add_argument('--obj_name', type=str, default=None)

    cli.add_argument('--device', type=str, default='cuda:0')

    return cli
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
if __name__ == "__main__":
    # Parse CLI arguments and launch the baseline evaluation.
    main(get_parser().parse_args())
|
model/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from .relpose import RelPose
|
| 4 |
+
from .pl_trainer import PL_RelPose, keypoint_dict
|
model/pl_trainer.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import lightning as L
|
| 5 |
+
from lightglue import SuperPoint, DISK, SIFT, ALIKED
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
from utils import rotation_angular_error, translation_angular_error, error_auc
|
| 9 |
+
from .relpose import RelPose
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Maps the feature-type string stored in hparams['features'] to its
# LightGlue keypoint-extractor class; used at train and test time to
# build `self.extractor`.
keypoint_dict = {
    'superpoint': SuperPoint,
    'disk': DISK,
    'sift': SIFT,
    'aliked': ALIKED,
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class PL_RelPose(L.LightningModule):
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
task,
|
| 24 |
+
lr,
|
| 25 |
+
epochs,
|
| 26 |
+
pct_start,
|
| 27 |
+
num_keypoints,
|
| 28 |
+
n_layers,
|
| 29 |
+
num_heads,
|
| 30 |
+
features='superpoint',
|
| 31 |
+
):
|
| 32 |
+
super().__init__()
|
| 33 |
+
|
| 34 |
+
self.extractor = keypoint_dict[features](max_num_keypoints=num_keypoints, detection_threshold=0.0).eval()
|
| 35 |
+
self.module = RelPose(features=features, task=task, n_layers=n_layers, num_heads=num_heads)
|
| 36 |
+
self.criterion = torch.nn.HuberLoss()
|
| 37 |
+
|
| 38 |
+
self.s_r = torch.nn.Parameter(torch.zeros(1))
|
| 39 |
+
# self.s_ta = torch.nn.Parameter(torch. zeros(1))
|
| 40 |
+
self.s_t = torch.nn.Parameter(torch.zeros(1))
|
| 41 |
+
|
| 42 |
+
self.r_errors = {k:[] for k in ['train', 'valid', 'test']}
|
| 43 |
+
self.ta_errors = {k:[] for k in ['train', 'valid', 'test']}
|
| 44 |
+
self.t_errors = {k:[] for k in ['train', 'valid', 'test']}
|
| 45 |
+
|
| 46 |
+
self.save_hyperparameters()
|
| 47 |
+
|
| 48 |
+
def _shared_log(self, mode, loss, loss_r, loss_t, loss_ta, loss_tn):
|
| 49 |
+
self.log_dict({
|
| 50 |
+
f'{mode}_loss/sum': loss,
|
| 51 |
+
f'{mode}_loss/r': loss_r,
|
| 52 |
+
f'{mode}_loss/t': loss_t,
|
| 53 |
+
f'{mode}_loss/ta': loss_ta,
|
| 54 |
+
f'{mode}_loss/tn': loss_tn,
|
| 55 |
+
}, on_epoch=True, sync_dist=True)
|
| 56 |
+
|
| 57 |
+
def training_step(self, batch, batch_idx):
|
| 58 |
+
loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err = self._shared_forward_step(batch, batch_idx)
|
| 59 |
+
|
| 60 |
+
self.r_errors['train'].append(r_err)
|
| 61 |
+
self.ta_errors['train'].append(ta_err)
|
| 62 |
+
self.t_errors['train'].append(t_err)
|
| 63 |
+
|
| 64 |
+
self._shared_log('train', loss, loss_r, loss_t, loss_ta, loss_tn)
|
| 65 |
+
|
| 66 |
+
return loss
|
| 67 |
+
|
| 68 |
+
def validation_step(self, batch, batch_idx):
|
| 69 |
+
loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err = self._shared_forward_step(batch, batch_idx)
|
| 70 |
+
|
| 71 |
+
self.r_errors['valid'].append(r_err)
|
| 72 |
+
self.ta_errors['valid'].append(ta_err)
|
| 73 |
+
self.t_errors['valid'].append(t_err)
|
| 74 |
+
|
| 75 |
+
self._shared_log('valid', loss, loss_r, loss_t, loss_ta, loss_tn)
|
| 76 |
+
|
| 77 |
+
def test_step(self, batch, batch_idx):
|
| 78 |
+
loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err = self._shared_forward_step(batch, batch_idx)
|
| 79 |
+
|
| 80 |
+
self.r_errors['test'].append(r_err)
|
| 81 |
+
self.ta_errors['test'].append(ta_err)
|
| 82 |
+
self.t_errors['test'].append(t_err)
|
| 83 |
+
|
| 84 |
+
self._shared_log('test', loss, loss_r, loss_t, loss_ta, loss_tn)
|
| 85 |
+
|
| 86 |
+
def _shared_forward_step(self, batch, batch_idx):
|
| 87 |
+
images = batch['images']
|
| 88 |
+
rotation = batch['rotation']
|
| 89 |
+
translation = batch['translation']
|
| 90 |
+
intrinsics = batch['intrinsics']
|
| 91 |
+
|
| 92 |
+
image0 = images[:, 0, ...]
|
| 93 |
+
image1 = images[:, 1, ...]
|
| 94 |
+
|
| 95 |
+
with torch.no_grad():
|
| 96 |
+
feats0 = self.extractor({'image': image0})
|
| 97 |
+
feats1 = self.extractor({'image': image1})
|
| 98 |
+
|
| 99 |
+
if 'scales' in batch:
|
| 100 |
+
scales = batch['scales']
|
| 101 |
+
feats0['keypoints'] *= scales[:, 0].unsqueeze(1)
|
| 102 |
+
feats1['keypoints'] *= scales[:, 1].unsqueeze(1)
|
| 103 |
+
|
| 104 |
+
if self.hparams.task == 'scene':
|
| 105 |
+
pred_r, pred_t = self.module({'image0': {**feats0, 'intrinsics': intrinsics[:, 0]}, 'image1': {**feats1, 'intrinsics': intrinsics[:, 1]}})
|
| 106 |
+
elif self.hparams.task == 'object':
|
| 107 |
+
bboxes = batch['bboxes']
|
| 108 |
+
pred_r, pred_t = self.module({'image0': {**feats0, 'intrinsics': intrinsics[:, 0], 'bbox': bboxes[:, 0]}, 'image1': {**feats1, 'intrinsics': intrinsics[:, 1]}})
|
| 109 |
+
|
| 110 |
+
r_err = rotation_angular_error(pred_r, rotation)
|
| 111 |
+
ta_err = translation_angular_error(pred_t, translation)
|
| 112 |
+
|
| 113 |
+
loss_r = self.criterion(r_err, torch.zeros_like(r_err))
|
| 114 |
+
loss_ta = self.criterion(ta_err, torch.zeros_like(ta_err))
|
| 115 |
+
loss_tn = self.criterion(pred_t / pred_t.norm(2, dim=-1, keepdim=True), translation / translation.norm(2, dim=-1, keepdim=True))
|
| 116 |
+
loss_t = self.criterion(pred_t, translation)
|
| 117 |
+
|
| 118 |
+
# loss = loss_r * torch.exp(-self.s_r) + loss_t * torch.exp(-self.s_t) + loss_ta * torch.exp(-self.s_ta) + self.s_r + self.s_t + self.s_ta
|
| 119 |
+
loss = loss_r + loss_ta + loss_t + loss_tn
|
| 120 |
+
|
| 121 |
+
r_err = r_err.detach()
|
| 122 |
+
ta_err = ta_err.detach()
|
| 123 |
+
t_err = (pred_t.detach() - translation).norm(2, dim=1)
|
| 124 |
+
|
| 125 |
+
return loss, loss_r, loss_ta, loss_t, loss_tn, r_err, ta_err, t_err
|
| 126 |
+
|
| 127 |
+
def predict_one_data(self, data, device='cuda'):
|
| 128 |
+
st_time = time.time()
|
| 129 |
+
images = data['images'].to(device)
|
| 130 |
+
intrinsics = data['intrinsics'].to(device)
|
| 131 |
+
|
| 132 |
+
image0 = images[:, 0, ...]
|
| 133 |
+
image1 = images[:, 1, ...]
|
| 134 |
+
|
| 135 |
+
preprocess = time.time()
|
| 136 |
+
|
| 137 |
+
with torch.no_grad():
|
| 138 |
+
feats0 = self.extractor({'image': image0})
|
| 139 |
+
feats1 = self.extractor({'image': image1})
|
| 140 |
+
|
| 141 |
+
extract_time = time.time()
|
| 142 |
+
|
| 143 |
+
if 'scales' in data:
|
| 144 |
+
scales = data['scales'].to(device)
|
| 145 |
+
feats0['keypoints'] *= scales[:, 0].unsqueeze(1)
|
| 146 |
+
feats1['keypoints'] *= scales[:, 1].unsqueeze(1)
|
| 147 |
+
|
| 148 |
+
if self.hparams.task == 'scene':
|
| 149 |
+
pred_r, pred_t = self.module({'image0': {**feats0, 'intrinsics': intrinsics[:, 0]}, 'image1': {**feats1, 'intrinsics': intrinsics[:, 1]}})
|
| 150 |
+
elif self.hparams.task == 'object':
|
| 151 |
+
bboxes = data['bboxes'].to(device)
|
| 152 |
+
pred_r, pred_t = self.module({'image0': {**feats0, 'intrinsics': intrinsics[:, 0], 'bbox': bboxes[:, 0]}, 'image1': {**feats1, 'intrinsics': intrinsics[:, 1]}})
|
| 153 |
+
|
| 154 |
+
regress_time = time.time()
|
| 155 |
+
|
| 156 |
+
return pred_r[0], pred_t[0], preprocess-st_time, extract_time-preprocess, regress_time-extract_time
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _shared_on_epoch_end(self, mode):
    """Aggregate accumulated per-batch errors for `mode`, log epoch metrics
    and clear the error buffers.

    Args:
        mode: one of 'train' / 'valid' / 'test' — key into the error buffers.
    """
    # Errors were accumulated in radians; report in degrees.
    r_errors = torch.hstack(self.r_errors[mode]).rad2deg()
    ta_errors = torch.hstack(self.ta_errors[mode]).rad2deg()
    # Translation direction is sign-ambiguous: fold angle into [0, 90] deg.
    ta_errors = torch.minimum(ta_errors, 180-ta_errors)

    # Pose AUC at 5/10/20 deg over the max of rotation and translation-angle error.
    auc = error_auc(torch.maximum(r_errors, ta_errors).cpu(), [5, 10, 20], mode)
    t_errors = torch.hstack(self.t_errors[mode])  # metric translation error; presumably meters — verify against dataset units

    self.log_dict({
        **auc,
        f'{mode}_Rot./Avg. Error': r_errors.mean(),
        f'{mode}_Rot./Med. Error': r_errors.median(),
        f'{mode}_Rot./@30° ACC': (r_errors < 30).float().mean(),
        f'{mode}_Rot./@15° ACC': (r_errors < 15).float().mean(),
        # f'{mode}_ta/avg': ta_errors.mean(),
        # f'{mode}_ta/med': ta_errors.median(),
        f'{mode}_Trans./Avg. Error': t_errors.mean(),
        f'{mode}_Trans./Med. Error': t_errors.median(),
        f'{mode}_Trans./@10cm ACC': (t_errors < 0.1).float().mean(),
        f'{mode}_Trans./@1m ACC': (t_errors < 1.0).float().mean(),
    }, sync_dist=True)

    # Reset buffers so the next epoch starts clean.
    self.r_errors[mode].clear()
    self.ta_errors[mode].clear()
    self.t_errors[mode].clear()
|
| 184 |
+
|
| 185 |
+
def on_train_epoch_end(self):
    # Lightning hook: log and reset training metrics at the epoch boundary.
    self._shared_on_epoch_end('train')

def on_validation_epoch_end(self):
    # Lightning hook: log and reset validation metrics.
    self._shared_on_epoch_end('valid')

def on_test_epoch_end(self):
    # Lightning hook: log and reset test metrics.
    self._shared_on_epoch_end('test')
|
| 193 |
+
|
| 194 |
+
def configure_optimizers(self):
    """Build AdamW plus a OneCycleLR schedule spanning all epochs.

    steps_per_epoch=1: Lightning steps schedulers once per epoch by default,
    so the one-cycle curve covers `epochs` total steps with warmup fraction
    `pct_start`.
    """
    optimizer = torch.optim.AdamW(self.module.parameters(), lr=self.hparams.lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=self.hparams.lr, steps_per_epoch=1, epochs=self.hparams.epochs, pct_start=self.hparams.pct_start)

    return {
        'optimizer': optimizer,
        'lr_scheduler': scheduler
    }
|
model/relpose.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from types import SimpleNamespace
|
| 3 |
+
from typing import Callable, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
from utils import rotation_matrix_from_ortho6d
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from flash_attn.modules.mha import FlashCrossAttention
|
| 13 |
+
except ModuleNotFoundError:
|
| 14 |
+
FlashCrossAttention = None
|
| 15 |
+
|
| 16 |
+
if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
|
| 17 |
+
FLASH_AVAILABLE = True
|
| 18 |
+
else:
|
| 19 |
+
FLASH_AVAILABLE = False
|
| 20 |
+
|
| 21 |
+
torch.backends.cudnn.deterministic = True
|
| 22 |
+
torch.set_float32_matmul_precision('medium')
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def normalize_keypoints(kpts, intrinsics):
    """Map pixel keypoints into normalized camera coordinates via K^-1.

    Args:
        kpts: (B, M, 2) pixel coordinates.
        intrinsics: (B, 3, 3) camera intrinsic matrices.

    Returns:
        (B, M, 2) normalized coordinates (homogeneous z dropped).
    """
    ones = torch.ones_like(kpts[..., :1])
    homogeneous = torch.cat((kpts, ones), dim=-1)
    # Apply K^-1 to every point: out[b, n] = K_b^-1 @ p_{b, n}
    unprojected = torch.einsum('bij,bnj->bni', intrinsics.inverse(), homogeneous)
    return unprojected[..., :2]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
| 38 |
+
def cosine_similarity(x, y):
    """Pairwise cosine similarity between the rows of x and y, rescaled
    from [-1, 1] to [0, 1].

    Args:
        x: (..., M, D) descriptors.
        y: (..., N, D) descriptors.

    Returns:
        (..., M, N) similarity matrix in [0, 1].
    """
    x_unit = x / x.norm(2, -1, keepdim=True)
    y_unit = y / y.norm(2, -1, keepdim=True)
    cos = x_unit @ y_unit.mT
    return 0.5 * (cos + 1)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pad x along dim -2 up to `length` and return a validity mask.

    FIX: the return annotation was `Tuple[torch.Tensor]` (a 1-tuple) although
    the function returns (padded, mask); corrected to a 2-tuple.

    Args:
        x: (..., N, D) tensor.
        length: target size of dim -2; no-op if N >= length.

    Returns:
        (padded, mask): padded is (..., max(N, length), D) with the pad region
        filled with ones; mask is (..., max(N, length), 1) bool, True on the
        original (non-padded) rows.
    """
    if length <= x.shape[-2]:
        return x, torch.ones_like(x[..., :1], dtype=torch.bool)
    # Pad with ones (not zeros) — downstream normalization assumes nonzero rows.
    pad = torch.ones(
        *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
    )
    y = torch.cat([x, pad], dim=-2)
    mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
    mask[..., : x.shape[-2], :] = True
    return y, mask
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def gather(x: torch.Tensor, indices: torch.Tensor):
    """Select rows of x along dim 1 using per-batch index lists.

    Args:
        x: (B, M, N) tensor.
        indices: (B, K) long tensor of row indices into dim 1.

    Returns:
        (B, K, N) tensor with out[b, k] = x[b, indices[b, k]].
    """
    width = x.shape[-1]
    expanded = indices.unsqueeze(-1).expand(*indices.shape, width)
    return torch.gather(x, 1, expanded)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class Attention(nn.Module):
    """Scaled dot-product attention with optional accelerated backends.

    On CUDA with flash enabled it prefers torch SDPA, then flash-attn's
    FlashCrossAttention; otherwise it falls back to SDPA or a pure-PyTorch
    einsum implementation.
    """

    def __init__(self, allow_flash: bool = True) -> None:
        super().__init__()
        if allow_flash and not FLASH_AVAILABLE:
            warnings.warn(
                "FlashAttention is not available. For optimal speed, "
                "consider installing torch >= 2.0 or flash-attn.",
                stacklevel=2,
            )
        self.enable_flash = allow_flash and FLASH_AVAILABLE
        self.has_sdp = hasattr(F, "scaled_dot_product_attention")
        if allow_flash and FlashCrossAttention:
            self.flash_ = FlashCrossAttention()
        if self.has_sdp:
            torch.backends.cuda.enable_flash_sdp(allow_flash)

    def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend q over (k, v).

        Args:
            q, k, v: (..., heads, L, head_dim) tensors.
            mask: optional boolean attention mask (True = attend) or None.

        Returns:
            Attention output with the same shape/dtype as q.
        """
        if self.enable_flash and q.device.type == "cuda":
            # use torch 2.0 scaled_dot_product_attention with flash
            if self.has_sdp:
                args = [x.contiguous() for x in [q, k, v]]
                v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
                # Fully masked rows yield NaNs under SDPA; zero them out.
                return v if mask is None else v.nan_to_num()
            else:
                assert mask is None
                q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
                m = self.flash_(q, torch.stack([k, v], 2))
                return m.transpose(-2, -3).to(q.dtype).clone()
        elif self.has_sdp:
            args = [x.contiguous() for x in [q, k, v]]
            v = F.scaled_dot_product_attention(*args, attn_mask=mask)
            return v if mask is None else v.nan_to_num()
        else:
            s = q.shape[-1] ** -0.5
            sim = torch.einsum("...id,...jd->...ij", q, k) * s
            if mask is not None:
                # BUGFIX: Tensor.masked_fill is out-of-place; the original
                # discarded its result, silently ignoring the mask on this
                # fallback path. Assign the result so masking takes effect.
                sim = sim.masked_fill(~mask, -float("inf"))
            attn = F.softmax(sim, -1)
            return torch.einsum("...ij,...jd->...id", attn, v)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class SelfBlock(nn.Module):
    """Self-attention transformer block with additive positional encoding."""

    def __init__(
        self, embed_dim: int, num_heads: int, bias: bool = True
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0
        self.head_dim = self.embed_dim // num_heads
        # Fused projection producing q, k and v in a single linear layer.
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.inner_attn = Attention()
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # Feed-forward over the concatenation [x ; attention message].
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(
        self,
        x: torch.Tensor,
        encoding: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Self-attend x with positional term `encoding` added to q and k.

        Args:
            x: (B, M, D) descriptors.
            encoding: broadcastable positional encoding added to q/k (not v).
            mask: optional boolean attention mask.

        Returns:
            (B, M, D) residual-updated descriptors.
        """
        qkv = self.Wqkv(x)
        # (B, M, 3*D) -> (B, heads, M, head_dim, 3); last axis selects q/k/v.
        qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
        q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
        # Positional encoding is added, not concatenated; v is left untouched.
        q += encoding
        k += encoding

        context = self.inner_attn(q, k, v, mask=mask)

        # Merge heads back into a single embedding before the output proj.
        message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
        return x + self.ffn(torch.cat([x, message], -1))
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class CrossBlock(nn.Module):
    """Bidirectional cross-attention block whose attention logits are
    modulated by a precomputed matchability matrix."""

    def __init__(
        self, embed_dim: int, num_heads: int, bias: bool = True
    ) -> None:
        super().__init__()
        self.heads = num_heads
        dim_head = embed_dim // num_heads
        self.scale = dim_head**-0.5
        inner_dim = dim_head * num_heads
        # A single shared q/k projection: the same features serve as query
        # and key, making the cross-image similarity symmetric.
        self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )
        # self.reg_attn = nn.Identity()
        # self.reg_sim = nn.Identity()

    def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
        # Apply the same transform to both image feature sets.
        return func(x0), func(x1)

    def forward(
        self, x0: torch.Tensor, x1: torch.Tensor, match: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> List[torch.Tensor]:
        """Exchange messages between the two images' descriptors.

        Args:
            x0: (B, M, D) descriptors of image 0.
            x1: (B, N, D) descriptors of image 1.
            match: (B, M, N) matchability weights multiplied into the logits.
            mask: optional (B, M, N) validity mask for padded keypoints.

        Returns:
            [x0, x1] residual-updated descriptors.
        """
        qk0, qk1 = self.map_(self.to_qk, x0, x1)
        v0, v1 = self.map_(self.to_v, x0, x1)
        # Split into heads: (B, L, D) -> (B, heads, L, head_dim).
        qk0, qk1, v0, v1 = map(
            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
            (qk0, qk1, v0, v1),
        )

        # Split the 1/sqrt(d) scale evenly between the two operands.
        qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
        sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
        if mask is not None:
            sim = sim.masked_fill(~mask.unsqueeze(1), -float("inf"))

        assert len(match.shape) == 3
        match = match.unsqueeze(1)
        # Weight logits by matchability so likely matches dominate attention.
        sim = sim * match
        # sim = self.reg_attn(sim)
        # match = self.reg_sim(match)

        # Softmax in each direction over the other image's keypoints.
        attn01 = F.softmax(sim, dim=-1)
        attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
        m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
        m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
        if mask is not None:
            # Fully masked rows softmax to NaN; replace them with zeros.
            m0, m1 = m0.nan_to_num(), m1.nan_to_num()

        # Merge heads and project; the FFN weights are shared across images.
        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
        m0, m1 = self.map_(self.to_out, m0, m1)
        x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
        x1 = x1 + self.ffn(torch.cat([x1, m1], -1))

        return x0, x1
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class TransformerLayer(nn.Module):
    """One interaction layer: per-image self-attention then cross-attention."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        # The same SelfBlock processes both images (shared weights).
        self.self_attn = SelfBlock(*args, **kwargs)
        self.cross_attn = CrossBlock(*args, **kwargs)

    def forward(
        self,
        desc0,
        desc1,
        encoding0,
        encoding1,
        match,
        mask0: Optional[torch.Tensor] = None,
        mask1: Optional[torch.Tensor] = None,
    ):
        # The masked path is taken when inputs were padded (torch.compile
        # static shapes); otherwise run the unmasked fast path.
        if mask0 is not None and mask1 is not None:
            return self.masked_forward(desc0, desc1, encoding0, encoding1, match, mask0, mask1)
        else:
            desc0 = self.self_attn(desc0, encoding0)
            desc1 = self.self_attn(desc1, encoding1)
            return self.cross_attn(desc0, desc1, match)

    # This part is compiled and allows padding inputs
    def masked_forward(self, desc0, desc1, encoding0, encoding1, match, mask0, mask1):
        # Build pairwise (M, N) validity masks from per-keypoint masks.
        mask = mask0 & mask1.transpose(-1, -2)
        mask0 = mask0 & mask0.transpose(-1, -2)
        mask1 = mask1 & mask1.transpose(-1, -2)
        desc0 = self.self_attn(desc0, encoding0, mask0)
        desc1 = self.self_attn(desc1, encoding1, mask1)
        return self.cross_attn(desc0, desc1, match, mask)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class RelPose(nn.Module):
    """Relative pose regressor over sparse keypoints/descriptors of two images.

    Runs matchability-modulated transformer layers over both images' local
    features, pools them, and regresses a rotation (6D -> matrix) and a
    translation vector.
    """

    default_conf = {
        "name": "RelPose",  # just for interfacing
        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 256,
        "add_scale_ori": False,
        "n_layers": 3,
        "num_heads": 4,
        "pct_pruning": 0,
        "task": "scene",
        "mp": False,  # enable mixed precision
        "weights": None,
    }

    required_data_keys = ["image0", "image1"]

    # Per-extractor overrides applied onto the default conf.
    features = {
        "superpoint": {
            "input_dim": 256,
        },
        "disk": {
            "input_dim": 128,
        },
        "aliked": {
            "input_dim": 128,
        },
        "sift": {
            "input_dim": 128,
            # "add_scale_ori": True,
        },
    }

    def __init__(self, features="superpoint", **conf) -> None:
        super().__init__()
        self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
        if features is not None:
            if features not in self.features:
                raise ValueError(
                    f"Unsupported features: {features} not in "
                    f"{{{','.join(self.features)}}}"
                )
            for k, v in self.features[features].items():
                setattr(conf, k, v)

        # Project extractor descriptors to the model width when they differ.
        if conf.input_dim != conf.descriptor_dim:
            self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()

        head_dim = conf.descriptor_dim // conf.num_heads
        # Positional encoding: 2D normalized coords (+ scale/ori if enabled).
        self.posenc = nn.Linear(
            2 + 2 * self.conf.add_scale_ori, head_dim
        )

        h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim

        self.transformers = nn.ModuleList(
            [TransformerLayer(d, h) for _ in range(n)]
        )

        # 6D rotation parameterization (Zhou et al.) decoded downstream.
        self.rotation_regressor = nn.Sequential(
            nn.Linear(conf.descriptor_dim*2, conf.descriptor_dim),
            nn.ReLU(),
            nn.Linear(conf.descriptor_dim, conf.descriptor_dim//2),
            nn.ReLU(),
            nn.Linear(conf.descriptor_dim//2, 6),
        )

        self.translation_regressor = nn.Sequential(
            nn.Linear(conf.descriptor_dim*2, conf.descriptor_dim),
            nn.ReLU(),
            nn.Linear(conf.descriptor_dim, conf.descriptor_dim//2),
            nn.ReLU(),
            nn.Linear(conf.descriptor_dim//2, 3),
        )

        # self.reg_kpts0 = nn.Identity()
        # self.reg_kpts1 = nn.Identity()

        # static lengths LightGlue is compiled for (only used with torch.compile)
        self.static_lengths = None

    def compile(
        self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
    ):
        # Compile only the masked (padded, static-shape) forward of each layer.
        for i in range(self.conf.n_layers):
            self.transformers[i].masked_forward = torch.compile(
                self.transformers[i].masked_forward, mode=mode, fullgraph=True
            )

        self.static_lengths = static_lengths

    def forward(self, data: dict) -> dict:
        """
        Match keypoints and descriptors between two images

        Input (dict):
            image0: dict
                keypoints: [B x M x 2]
                descriptors: [B x M x D]
                image: [B x C x H x W] or image_size: [B x 2]
            image1: dict
                keypoints: [B x N x 2]
                descriptors: [B x N x D]
                image: [B x C x H x W] or image_size: [B x 2]
        Output
        """
        with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
            return self._forward(data)

    def _forward(self, data: dict) -> dict:
        # Returns (R, t): (B, 3, 3) rotation matrices and (B, 3) translations.
        for key in self.required_data_keys:
            assert key in data, f"Missing key {key} in data"
        data0, data1 = data["image0"], data["image1"]
        kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
        intrinsic0, intrinsic1 = data0["intrinsics"], data1["intrinsics"]
        b, m, _ = kpts0.shape
        b, n, _ = kpts1.shape

        if self.conf.add_scale_ori:
            kpts0 = torch.cat(
                [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
            )
            kpts1 = torch.cat(
                [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
            )
        # Descriptors are detached: the extractor is not trained through here.
        desc0 = data0["descriptors"].detach().contiguous()
        desc1 = data1["descriptors"].detach().contiguous()

        assert desc0.shape[-1] == self.conf.input_dim
        assert desc1.shape[-1] == self.conf.input_dim

        mask0, mask1 = None, None
        c = max(m, n)
        # Pad to the nearest static length when the compiled path is active.
        do_compile = self.static_lengths and c <= max(self.static_lengths)
        if do_compile:
            kn = min([k for k in self.static_lengths if k >= c])
            desc0, mask0 = pad_to_length(desc0, kn)
            desc1, mask1 = pad_to_length(desc1, kn)
            kpts0, _ = pad_to_length(kpts0, kn)
            kpts1, _ = pad_to_length(kpts1, kn)

        # Descriptor cosine similarity drives both pruning and attention.
        matchability = cosine_similarity(desc0, desc1)

        assert self.conf.pct_pruning >= 0 and self.conf.pct_pruning < 1
        if self.conf.pct_pruning > 0:
            # Drop the least matchable fraction of keypoints in both images.
            ind0, ind1 = self.get_pruned_indices(matchability, self.conf.pct_pruning)

            matchability = gather(matchability, ind0)
            matchability = gather(matchability.mT, ind1).mT

            desc0 = gather(desc0, ind0)
            desc1 = gather(desc1, ind1)

            kpts0 = gather(kpts0, ind0)
            kpts1 = gather(kpts1, ind1)

        if self.conf.task == "object":
            # Keep only query keypoints inside the prompted bounding box.
            bbox = data0["bbox"]  # (B, 4)
            ind0, mask0 = self.get_prompted_indices(kpts0, bbox)

            # Slot 0 acts as a zeroed sink for out-of-box indices.
            matchability[:, 0] = torch.zeros_like(matchability[:, 0], device=matchability.device)
            desc0[:, 0] = torch.zeros_like(desc0[:, 0], device=desc0.device)
            kpts0[:, 0] = torch.zeros_like(kpts0[:, 0], device=kpts0.device)

            matchability = gather(matchability, ind0)
            desc0 = gather(desc0, ind0)
            kpts0 = gather(kpts0, ind0)

        desc0 = self.input_proj(desc0)
        desc1 = self.input_proj(desc1)

        # kpts0 = self.reg_kpts0(kpts0)
        # kpts1 = self.reg_kpts1(kpts1)

        # cache positional embeddings
        kpts0 = normalize_keypoints(kpts0, intrinsic0)
        kpts1 = normalize_keypoints(kpts1, intrinsic1)

        encoding0 = self.posenc(kpts0).unsqueeze(-3)
        encoding1 = self.posenc(kpts1).unsqueeze(-3)

        for i in range(self.conf.n_layers):
            desc0, desc1 = self.transformers[i](
                desc0, desc1, encoding0, encoding1, match=matchability, mask0=mask0, mask1=mask1,
            )

        # Strip any compile-time padding before pooling.
        desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
        if self.conf.task == 'object':
            # Masked mean over in-box keypoints only (clipped to avoid /0).
            n_kpts0 = mask0.sum(1, keepdim=True)
            n_kpts0 = torch.clip(n_kpts0, min=1)
            desc0 = (desc0 * mask0.unsqueeze(-1)).sum(1) / n_kpts0
            desc1 = desc1.mean(1)
        else:
            desc0, desc1 = desc0.mean(1), desc1.mean(1)

        feat = torch.cat([desc0, desc1], 1)

        R = self.rotation_regressor(feat)
        R = rotation_matrix_from_ortho6d(R)
        t = self.translation_regressor(feat)

        return R, t

    def get_pruned_indices(self, match, pct_pruning):
        """Indices of keypoints kept after dropping the `pct_pruning`
        least-matchable fraction per image (scored by mean matchability)."""
        matching_scores0 = match.mean(-1)
        matching_scores1 = match.mean(-2)

        num_pruning0 = int(pct_pruning * matching_scores0.size(-1))
        num_pruning1 = int(pct_pruning * matching_scores1.size(-1))

        # Ascending sort: the first `num_pruning` entries are the worst.
        _, indices0 = matching_scores0.sort()
        _, indices1 = matching_scores1.sort()

        indices0 = indices0[:, num_pruning0:]
        indices1 = indices1[:, num_pruning1:]

        return indices0, indices1

    def get_prompted_indices(self, kpts, bbox):
        # kpts: (B, M, 2)
        # bbox: (B, 4) - (x, y, x, y)
        # Returns (indices, mask): per-batch indices of in-box keypoints,
        # padded with 0 (the reserved sink slot), and a matching 0/1 mask.
        x, y = kpts[..., 0], kpts[..., 1]
        mask = (x >= bbox[:, 0].unsqueeze(-1)) & (x <= bbox[:, 2].unsqueeze(-1))
        mask &= (y >= bbox[:, 1].unsqueeze(-1)) & (y <= bbox[:, 3].unsqueeze(-1))
        # Sort valid entries to the front, zero out invalid indices, and trim
        # to the largest in-box count across the batch.
        mask_sorted, indices = mask.long().sort(descending=True)
        indices *= mask_sorted
        indices = indices[:, :mask_sorted.sum(-1).max()]
        mask_sorted = mask_sorted[:, :mask_sorted.sum(-1).max()]

        return indices, mask_sorted
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
albumentations==1.4.1
|
| 2 |
+
kornia==0.7.1
|
| 3 |
+
open3d==0.18.0
|
| 4 |
+
opencv-python==4.9.0.80
|
| 5 |
+
plyfile==1.0.3
|
| 6 |
+
scikit-learn==1.4.1.post1
|
| 7 |
+
yacs==0.1.8
|
| 8 |
+
lightning==2.2.1
|
| 9 |
+
transforms3d==0.4.1
|
| 10 |
+
pandas==2.1.1
|
| 11 |
+
lightglue @ git+https://github.com/cvg/LightGlue@main
|
train.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from torch.utils.data import DataLoader
|
| 3 |
+
import lightning as L
|
| 4 |
+
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
|
| 5 |
+
|
| 6 |
+
from datasets import dataset_dict, RandomConcatSampler
|
| 7 |
+
from model import PL_RelPose
|
| 8 |
+
from utils import seed_torch
|
| 9 |
+
from configs.default import get_cfg_defaults
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def main(args):
    """Train RelPose from a YAML config, optionally warm-starting from a
    checkpoint (`--weights`) or resuming a full trainer state (`--resume`)."""
    config = get_cfg_defaults()
    config.merge_from_file(args.config)

    task = config.DATASET.TASK
    dataset = config.DATASET.DATA_SOURCE

    batch_size = config.TRAINER.BATCH_SIZE
    num_workers = config.TRAINER.NUM_WORKERS
    pin_memory = config.TRAINER.PIN_MEMORY
    n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
    lr = config.TRAINER.LEARNING_RATE
    epochs = config.TRAINER.EPOCHS
    pct_start = config.TRAINER.PCT_START

    num_keypoints = config.MODEL.NUM_KEYPOINTS
    n_layers = config.MODEL.N_LAYERS
    num_heads = config.MODEL.NUM_HEADS
    features = config.MODEL.FEATURES

    seed = config.RANDOM_SEED
    seed_torch(seed)

    build_fn = dataset_dict[task][dataset]
    trainset = build_fn('train', config)
    validset = build_fn('val', config)

    # Multi-scene datasets are balanced with a per-subset random sampler.
    if dataset == 'scannet' or dataset == 'megadepth' or dataset == 'linemod' or dataset == 'ho3d' or dataset == 'mapfree':
        sampler = RandomConcatSampler(
            trainset,
            n_samples_per_subset=n_samples_per_subset,
            subset_replacement=True,
            shuffle=True,
            seed=seed
        )
        trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, sampler=sampler)
    else:
        trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=True)

    validloader = DataLoader(validset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

    if args.weights is None:
        pl_relpose = PL_RelPose(
            task=task,
            lr=lr,
            epochs=epochs,
            pct_start=pct_start,
            n_layers=n_layers,
            num_heads=num_heads,
            num_keypoints=num_keypoints,
            features=features,
        )
    else:
        # Warm start: load model weights but train with the new hyperparams.
        pl_relpose = PL_RelPose.load_from_checkpoint(
            checkpoint_path=args.weights,
            task=task,
            lr=lr,
            epochs=epochs,
            pct_start=pct_start,
            n_layers=n_layers,
            num_heads=num_heads,
            num_keypoints=num_keypoints,
        )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    # Keep both the latest checkpoint and the best-by-AUC@20 checkpoint.
    latest_checkpoint_callback = ModelCheckpoint()
    best_checkpoint_callback = ModelCheckpoint(monitor='valid/auc@20', mode='max')
    trainer = L.Trainer(
        devices=[0],
        # devices=[0, 1],
        # accelerator='gpu', strategy='ddp_find_unused_parameters_true',
        max_epochs=epochs,
        callbacks=[lr_monitor, latest_checkpoint_callback, best_checkpoint_callback],
        precision="bf16-mixed",
        # fast_dev_run=1,
    )

    trainer.fit(pl_relpose, trainloader, validloader, ckpt_path=args.resume)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def get_parser():
    """Build the command-line parser for training.

    Returns:
        argparse.ArgumentParser with a positional `config` path and optional
        `--resume` (trainer checkpoint) / `--weights` (model weights) paths.
    """
    p = argparse.ArgumentParser()
    p.add_argument('config', type=str, help='.yaml configure file path')
    for flag in ('--resume', '--weights'):
        p.add_argument(flag, type=str, default=None)
    return p
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
    # CLI entry point: parse arguments and launch training.
    parser = get_parser()
    args = parser.parse_args()
    main(args)
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .metrics import quat_degree_error, rotation_angular_error, translation_angular_error, error_auc
|
| 7 |
+
from .transform import rotation_matrix_from_ortho6d, rotation_matrix_from_quaternion
|
| 8 |
+
from .augment import Augmentor
|
| 9 |
+
# from .visualize import project_3D_points, plot_3D_box
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def seed_torch(seed: int) -> None:
    """Seed every RNG (Python, hash, NumPy, torch CPU and CUDA) and force
    deterministic cuDNN kernels for reproducible runs.

    Args:
        seed: the seed applied to all random number generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed every CUDA device: train.py contemplates multi-GPU (DDP)
    # runs, and manual_seed alone only covers the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
utils/augment.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import albumentations as A
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Augmentor(object):
    """Photometric augmentation pipeline; acts as a no-op when not training
    (the whole Compose fires with probability float(is_training))."""

    def __init__(self, is_training: bool):
        transforms = [
            A.MotionBlur(p=0.25),
            A.ColorJitter(p=0.25),
            A.ImageCompression(p=0.25),
            A.ISONoise(p=0.25),
            A.ToGray(p=0.1),
        ]
        self.augmentor = A.Compose(transforms, p=float(is_training))

    def __call__(self, x):
        """Augment a single image array and return the transformed image."""
        result = self.augmentor(image=x)
        return result['image']
|