yujia commited on
Commit
d4c7a24
·
0 Parent(s):

init utonia

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +28 -0
  2. .github/CODE_OF_CONDUCT.md +80 -0
  3. .github/CONTRIBUTING.md +35 -0
  4. .gitignore +18 -0
  5. .gradio/certificate.pem +31 -0
  6. LICENSE +201 -0
  7. README.md +11 -0
  8. app.py +989 -0
  9. geometry_utils.py +197 -0
  10. requirements.txt +27 -0
  11. setup.py +12 -0
  12. utonia/__init__.py +26 -0
  13. utonia/data.py +92 -0
  14. utonia/model.py +891 -0
  15. utonia/module.py +107 -0
  16. utonia/registry.py +340 -0
  17. utonia/serialization/__init__.py +8 -0
  18. utonia/serialization/default.py +82 -0
  19. utonia/serialization/hilbert.py +309 -0
  20. utonia/serialization/z_order.py +127 -0
  21. utonia/structure.py +159 -0
  22. utonia/transform.py +1226 -0
  23. utonia/utils.py +75 -0
  24. vggt/__init__.py +0 -0
  25. vggt/heads/camera_head.py +162 -0
  26. vggt/heads/dpt_head.py +497 -0
  27. vggt/heads/head_act.py +125 -0
  28. vggt/heads/track_head.py +108 -0
  29. vggt/heads/track_modules/__init__.py +5 -0
  30. vggt/heads/track_modules/base_track_predictor.py +209 -0
  31. vggt/heads/track_modules/blocks.py +246 -0
  32. vggt/heads/track_modules/modules.py +218 -0
  33. vggt/heads/track_modules/utils.py +226 -0
  34. vggt/heads/utils.py +108 -0
  35. vggt/layers/__init__.py +11 -0
  36. vggt/layers/attention.py +98 -0
  37. vggt/layers/block.py +259 -0
  38. vggt/layers/drop_path.py +34 -0
  39. vggt/layers/layer_scale.py +27 -0
  40. vggt/layers/mlp.py +40 -0
  41. vggt/layers/patch_embed.py +88 -0
  42. vggt/layers/rope.py +188 -0
  43. vggt/layers/swiglu_ffn.py +72 -0
  44. vggt/layers/vision_transformer.py +407 -0
  45. vggt/models/aggregator.py +331 -0
  46. vggt/models/vggt.py +81 -0
  47. vggt/utils/geometry.py +167 -0
  48. vggt/utils/load_fn.py +111 -0
  49. vggt/utils/pose_enc.py +73 -0
  50. vggt/utils/rotation.py +138 -0
.gitattributes ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ckpt/concerto_large.pth filter=lfs diff=lfs merge=lfs -text
2
+ example filter=lfs diff=lfs merge=lfs -text
3
+ example/pcd filter=lfs diff=lfs merge=lfs -text
4
+ example/video filter=lfs diff=lfs merge=lfs -text
5
+ example/pcd/hm3d_00012_kDgLKdMd5X8_1.ply filter=lfs diff=lfs merge=lfs -text
6
+ example/pcd/hm3d_00012_kDgLKdMd5X8_2.ply filter=lfs diff=lfs merge=lfs -text
7
+ example/pcd/s3dis_Area2_conferenceRoom1.png filter=lfs diff=lfs merge=lfs -text
8
+ example/pcd/scannet_0603.png filter=lfs diff=lfs merge=lfs -text
9
+ example/pcd/hm3d_00012_kDgLKdMd5X8_1.png filter=lfs diff=lfs merge=lfs -text
10
+ example/pcd/scannet_0024.png filter=lfs diff=lfs merge=lfs -text
11
+ example/pcd/s3dis_Area4_lobby1.png filter=lfs diff=lfs merge=lfs -text
12
+ example/pcd/scannet_0024.ply filter=lfs diff=lfs merge=lfs -text
13
+ example/pcd/hm3d_00012_kDgLKdMd5X8_2.png filter=lfs diff=lfs merge=lfs -text
14
+ example/pcd/hm3d_00113_3goH1WRaCYC.ply filter=lfs diff=lfs merge=lfs -text
15
+ example/pcd/hm3d_00113_3goH1WRaCYC.png filter=lfs diff=lfs merge=lfs -text
16
+ example/pcd/s3dis_Area2_auditorium1.ply filter=lfs diff=lfs merge=lfs -text
17
+ example/pcd/s3dis_Area2_auditorium1.png filter=lfs diff=lfs merge=lfs -text
18
+ example/pcd/scannet_0603.ply filter=lfs diff=lfs merge=lfs -text
19
+ example/pcd/s3dis_Area2_conferenceRoom1.ply filter=lfs diff=lfs merge=lfs -text
20
+ example/pcd/s3dis_Area4_lobby1.ply filter=lfs diff=lfs merge=lfs -text
21
+ example/pcd/scannetpp_2a1b555966.ply filter=lfs diff=lfs merge=lfs -text
22
+ example/pcd/scannetpp_2a1b555966.png filter=lfs diff=lfs merge=lfs -text
23
+ example/video/re10k_2.mp4 filter=lfs diff=lfs merge=lfs -text
24
+ example/video/re10k_3.mp4 filter=lfs diff=lfs merge=lfs -text
25
+ example/video/re10k_4.mp4 filter=lfs diff=lfs merge=lfs -text
26
+ example/video/conference_room.mp4 filter=lfs diff=lfs merge=lfs -text
27
+ example/video/office.mp4 filter=lfs diff=lfs merge=lfs -text
28
+ example/video/re10k_1.mp4 filter=lfs diff=lfs merge=lfs -text
.github/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <opensource-conduct@meta.com>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
.github/CONTRIBUTING.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to "sonata"
2
+
3
+ We want to make contributing to this project as easy and transparent as
4
+ possible.
5
+
6
+ ## Pull Requests
7
+
8
+ We welcome pull requests.
9
+
10
+ 1. Fork the repo and create your branch from `main`.
11
+ 2. If you've added code that should be tested, add tests.
12
+ 3. If you've changed APIs, update the documentation in the code.
13
+ 4. Ensure the test suite passes.
14
+ 5. If you haven't already, complete the Contributor License Agreement ("CLA").
15
+
16
+ ## Contributor License Agreement ("CLA")
17
+
18
+ In order to accept your pull request, we need you to submit a CLA. You only need
19
+ to do this once to work on any of Facebook's open source projects.
20
+
21
+ Complete your CLA here: <https://code.facebook.com/cla>
22
+
23
+ ## Issues
24
+
25
+ We use GitHub issues to track public bugs. Please ensure your description is
26
+ clear and has sufficient instructions to be able to reproduce the issue.
27
+
28
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
29
+ disclosure of security bugs. In those cases, please go through the process
30
+ outlined on that page and do not file a public issue.
31
+
32
+ ## License
33
+
34
+ By contributing to "sonata", you agree that your contributions will be licensed under
35
+ the [LICENSE](../LICENSE) file in the root directory of this source tree.
.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image/
2
+ __pycache__
3
+ **/build/
4
+ **/*.egg-info/
5
+ **/dist/
6
+ *.so
7
+ exp
8
+ weights
9
+ data
10
+ log
11
+ **/ckpt/
12
+ outputs/
13
+ .vscode
14
+ .idea
15
+ */.DS_Store
16
+ **/*.out
17
+ vggt/ckpt
18
+ **/demo_output*
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Utonia
3
+ emoji: ✨
4
+ colorFrom: gray
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 6.1.0
8
+ app_file: app.py
9
+ pinned: false
10
+ python_version: "3.10"
11
+ ---
app.py ADDED
@@ -0,0 +1,989 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import numpy as np
4
+ import trimesh
5
+ import torch
6
+ import time
7
+ import spaces
8
+ import cv2
9
+ import shutil
10
+ from datetime import datetime
11
+ import glob
12
+ from einops import rearrange
13
+
14
+ # Local imports
15
+ from geometry_utils import (
16
+ Coord2zup,
17
+ extract_and_align_ground_plane,
18
+ pad_0001,
19
+ T_to_C,
20
+ im_distance_to_im_depth,
21
+ im_depth_to_point_cloud,
22
+ )
23
+ # VGGT specific imports
24
+ from vggt.utils.load_fn import load_and_preprocess_images
25
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
26
+ from vggt.utils.geometry import unproject_depth_map_to_point_map
27
+ from vggt.models.vggt import VGGT
28
+
29
+ import utonia
30
+ utonia.utils.set_seed(53124)
31
+ utonia_model = utonia.load("utonia", repo_id="Pointcept/Utonia")
32
+
33
+ VGGT_model = VGGT()
34
+ _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
35
+ VGGT_model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
36
+
37
+ @spaces.GPU
38
+ def _gpu_run_vggt_inference(images_tensor):
39
+ """
40
+ GPU-only function: Run VGGT model inference on preprocessed images.
41
+ Minimizes GPU time by only doing model inference and pose encoding conversion.
42
+ """
43
+ global VGGT_model
44
+ device = "cuda" if torch.cuda.is_available() else "cpu"
45
+ # Move images to GPU
46
+ images_tensor = images_tensor.to(device)
47
+ model = VGGT_model.to(device)
48
+ model.eval()
49
+
50
+ print("Running inference...")
51
+ with torch.no_grad():
52
+ if device == "cuda":
53
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
54
+ predictions = model(images_tensor)
55
+ else:
56
+ predictions = model(images_tensor)
57
+
58
+ # Convert pose encoding to extrinsic and intrinsic matrices (GPU operation)
59
+ print("Converting pose encoding to extrinsic and intrinsic matrices...")
60
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images_tensor.shape[-2:])
61
+ predictions["extrinsic"] = extrinsic
62
+ predictions["intrinsic"] = intrinsic
63
+
64
+ # Convert to numpy (still on GPU to minimize memory transfer)
65
+ for key in predictions.keys():
66
+ if isinstance(predictions[key], torch.Tensor):
67
+ predictions[key] = predictions[key].cpu().numpy().squeeze(0)
68
+
69
+ torch.cuda.empty_cache()
70
+ return predictions
71
+
72
+ def run_model(target_dir) -> dict:
73
+ """
74
+ CPU-GPU hybrid: Handle CPU-intensive file I/O and call GPU function for inference.
75
+ """
76
+ print(f"Processing images from {target_dir}")
77
+
78
+ # Load and preprocess images (CPU)
79
+ image_names = glob.glob(os.path.join(target_dir, "images", "*"))
80
+ image_names = sorted(image_names)
81
+ print(f"Found {len(image_names)} images")
82
+ if len(image_names) == 0:
83
+ raise ValueError("No images found. Check your upload.")
84
+
85
+ images = load_and_preprocess_images(image_names)
86
+ print(f"Preprocessed images shape: {images.shape}")
87
+
88
+ # Call GPU function for inference
89
+ predictions = _gpu_run_vggt_inference(images)
90
+
91
+ # Post-processing (CPU)
92
+ print("Computing world points from depth map...")
93
+ depth_map = predictions["depth"] # (S, H, W, 1)
94
+ world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
95
+ predictions["world_points_from_depth"] = world_points
96
+
97
+ return predictions
98
+
99
+ def parse_frames(
100
+ target_dir,
101
+ conf_thres=3.0,
102
+ prediction_mode="Pointmap Regression",
103
+ ):
104
+ """
105
+ Perform reconstruction using the already-created target_dir/images.
106
+ """
107
+ if not os.path.isdir(target_dir) or target_dir == "None":
108
+ return None, "No valid target directory found. Please upload first.", None, None
109
+
110
+ start_time = time.time()
111
+
112
+ # Prepare frame_filter dropdown
113
+ target_dir_images = os.path.join(target_dir, "images")
114
+ all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
115
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
116
+
117
+ print("Running run_model...")
118
+ with torch.no_grad():
119
+ predictions = run_model(target_dir)
120
+
121
+ # Save predictions
122
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
123
+ np.savez(prediction_save_path, **predictions)
124
+
125
+ # Convert pose encoding to extrinsic and intrinsic matrices
126
+ images = predictions["images"]
127
+ Ts, Ks = predictions["extrinsic"],predictions["intrinsic"]
128
+ Ts = pad_0001(Ts)
129
+ Ts_inv = np.linalg.inv(Ts)
130
+ Cs = np.array([T_to_C(T) for T in Ts]) # (n, 3)
131
+
132
+ # [1, 8, 294, 518, 3]
133
+ world_points = predictions["world_points"]
134
+
135
+ # Compute view direction for each pixel
136
+ # (b n h w c) - (n, 3)
137
+ view_dirs = world_points - rearrange(Cs, "n c -> n 1 1 c")
138
+ view_dirs = rearrange(view_dirs, "n h w c -> (n h w) c")
139
+ view_dirs = view_dirs / np.linalg.norm(view_dirs, axis=-1, keepdims=True)
140
+
141
+ # Extract points and colors
142
+ # [1, 8, 3, 294, 518]
143
+ img_num = world_points.shape[1]
144
+ images = predictions["images"]
145
+ points = rearrange(world_points, "n h w c -> (n h w) c")
146
+ colors = rearrange(images, "n c h w -> (n h w) c")
147
+ normals = np.zeros_like(points)
148
+
149
+ if prediction_mode=="Pointmap Branch":
150
+ world_points_conf = predictions["world_points_conf"]
151
+ conf = world_points_conf.reshape(-1)
152
+ points,Ts_inv,_ = Coord2zup(points, Ts_inv)
153
+ scale = 3 / (points[:, 2].max() - points[:, 2].min())
154
+ points *= scale
155
+ Ts_inv[:, :3, 3] *= scale
156
+
157
+ normals = -np.asarray(view_dirs)
158
+ normals = normals / np.clip(np.linalg.norm(normals, axis=-1, keepdims=True), 1e-8, None)
159
+ if conf_thres == 0.0:
160
+ conf_threshold = 0.0
161
+ else:
162
+ conf_threshold = np.percentile(conf, conf_thres)
163
+ conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
164
+ points = points[conf_mask]
165
+ colors = colors[conf_mask]
166
+ normals = normals[conf_mask]
167
+
168
+ try:
169
+ points, colors, normals, _, _, _ = extract_and_align_ground_plane(
170
+ points=points,
171
+ colors=colors,
172
+ normals=normals,
173
+ )
174
+ except Exception as e:
175
+ print(f"cannot find ground, err:{e}")
176
+ elif prediction_mode=="Depthmap Branch":
177
+ # Integrate RGBD images into a TSDF volume and extract a mesh
178
+ # (n, h, w, 3)
179
+ im_colors = rearrange(images, "n c h w -> (n) h w c")
180
+ # (b, n, h, w, 3)
181
+ im_dists = world_points - rearrange(Cs, "n c -> n 1 1 c")
182
+ im_dists = np.linalg.norm(im_dists, axis=-1, keepdims=False)
183
+
184
+ # Convert distance to depth
185
+ im_depths = [] # (n, h, w, c)
186
+ for im_dist, K in zip(im_dists, Ks):
187
+ im_depth = im_distance_to_im_depth(im_dist, K)
188
+ im_depths.append(im_depth)
189
+ im_depths = np.stack(im_depths, axis=0)
190
+ points=[]
191
+ for K, T, im_depth in zip(Ks, Ts, im_depths):
192
+ point = im_depth_to_point_cloud(
193
+ im_depth=im_depth,
194
+ K=K,
195
+ T=T,
196
+ to_image=False,
197
+ ignore_invalid=False,
198
+ )
199
+ points.append(point)
200
+ points = np.vstack(points)
201
+ colors = im_colors.reshape(-1,3)
202
+ world_points_conf = predictions["depth_conf"]
203
+ conf = world_points_conf.reshape(-1)
204
+ if conf_thres == 0.0:
205
+ conf_threshold = 0.0
206
+ else:
207
+ conf_threshold = np.percentile(conf, conf_thres)
208
+ conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
209
+ points = points[conf_mask]
210
+ colors = colors[conf_mask]
211
+ points,Ts_inv,_ = Coord2zup(points, Ts_inv)
212
+ scale_factor = 3./(np.max(points[:,2])-np.min(points[:,2]))
213
+ points *= scale_factor
214
+ Ts_inv[:, :3, 3] *= scale_factor
215
+ normals = np.zeros_like(points)
216
+ try:
217
+ points, colors, normals, _, _, _ = extract_and_align_ground_plane(
218
+ points=points,
219
+ colors=colors,
220
+ normals=normals,
221
+ )
222
+ except Exception as e:
223
+ print(f"cannot find ground, err:{e}")
224
+ original_points = np.asarray(points)
225
+ original_colors = np.asarray(colors)
226
+ original_normals = np.asarray(normals)
227
+ # Cleanup
228
+ del predictions
229
+ end_time = time.time()
230
+ print(f"Total time: {end_time - start_time:.2f} seconds")
231
+ return original_points, original_colors, original_normals
232
+
233
+ def handle_uploads(input_file,input_video,conf_thres,frame_slider,prediction_mode):
234
+ """
235
+ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
236
+ images or extracted frames from video into it. Return (target_dir, image_paths).
237
+ """
238
+ start_time = time.time()
239
+
240
+ # Create a unique folder name
241
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
242
+ target_dir = f"demo_output/inputs_{timestamp}"
243
+ target_dir_images = os.path.join(target_dir, "images")
244
+ target_dir_pcds = os.path.join(target_dir, "pcds")
245
+
246
+ # Clean up if somehow that folder already exists
247
+ if os.path.exists(target_dir):
248
+ shutil.rmtree(target_dir)
249
+ os.makedirs(target_dir)
250
+ os.makedirs(target_dir_images)
251
+ os.makedirs(target_dir_pcds)
252
+ # Handle video
253
+ if input_video is not None:
254
+ print("processing video")
255
+ if isinstance(input_video, dict) and "name" in input_video:
256
+ video_path = input_video["name"]
257
+ else:
258
+ video_path = input_video
259
+
260
+ vs = cv2.VideoCapture(video_path)
261
+ fps = vs.get(cv2.CAP_PROP_FPS)
262
+ frame_interval = int(fps * frame_slider) # 1 frame/sec
263
+
264
+ count = 0
265
+ video_frame_num = 0
266
+ image_paths = []
267
+ while True:
268
+ gotit, frame = vs.read()
269
+ if not gotit:
270
+ break
271
+ count += 1
272
+ if count % frame_interval == 0:
273
+ image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
274
+ cv2.imwrite(image_path, frame)
275
+ image_paths.append(image_path)
276
+ video_frame_num += 1
277
+ # Sort final images for gallery
278
+ image_paths = sorted(image_paths)
279
+ original_points, original_colors, original_normals = parse_frames(target_dir,conf_thres,prediction_mode)
280
+ if input_file is not None:
281
+ print("processing ply")
282
+ loaded = load_point_from_file(input_file)
283
+ if loaded is None:
284
+ raise ValueError("Failed to load input point cloud file")
285
+ original_points = loaded["coord"]
286
+ original_colors = loaded["color"]
287
+ original_normals = loaded["normal"]
288
+ image_paths = None
289
+ scene_3d = trimesh.Scene()
290
+ point_cloud_data = trimesh.PointCloud(vertices=original_points, colors=original_colors, vertex_normals=original_normals)
291
+ scene_3d.add_geometry(point_cloud_data)
292
+ original_temp = os.path.join(target_dir_pcds,"original.glb")
293
+ scene_3d.export(file_obj=original_temp)
294
+ np.save(os.path.join(target_dir_pcds, f"points.npy"), original_points)
295
+ np.save(os.path.join(target_dir_pcds, f"colors.npy"), original_colors)
296
+ np.save(os.path.join(target_dir_pcds, f"normals.npy"), original_normals)
297
+ end_time = time.time()
298
+ print(f"Files copied to {target_dir}; took {end_time - start_time:.3f} seconds")
299
+ return target_dir, image_paths,original_temp, end_time - start_time
300
+
301
+
302
+ def load_point_from_file(input_file):
303
+ if input_file is None:
304
+ return None
305
+
306
+ file_path = input_file
307
+ if hasattr(input_file, "name"):
308
+ file_path = input_file.name
309
+ elif isinstance(input_file, dict) and "name" in input_file:
310
+ file_path = input_file["name"]
311
+
312
+ if not file_path:
313
+ return None
314
+
315
+ geometry = trimesh.load(file_path, process=False)
316
+ if isinstance(geometry, trimesh.Scene):
317
+ geometries = [g for g in geometry.geometry.values()]
318
+ if not geometries:
319
+ return None
320
+ geometry = geometries[0]
321
+
322
+ if isinstance(geometry, trimesh.PointCloud):
323
+ coord = np.asarray(geometry.vertices)
324
+ color = np.asarray(geometry.colors[:, :3]) if geometry.colors is not None and len(geometry.colors) else np.zeros_like(coord)
325
+ normal = np.zeros_like(coord)
326
+ if color.dtype != np.float32 and color.dtype != np.float64:
327
+ color = color.astype(np.float32) / 255.0
328
+ return {"coord": coord, "color": color, "normal": normal}
329
+
330
+ if isinstance(geometry, trimesh.Trimesh):
331
+ coord = np.asarray(geometry.vertices)
332
+ if geometry.visual is not None and hasattr(geometry.visual, "vertex_colors") and geometry.visual.vertex_colors is not None and len(geometry.visual.vertex_colors):
333
+ color = np.asarray(geometry.visual.vertex_colors[:, :3]).astype(np.float32) / 255.0
334
+ else:
335
+ color = np.zeros_like(coord)
336
+ normal = np.asarray(geometry.vertex_normals) if geometry.vertex_normals is not None and len(geometry.vertex_normals) else np.zeros_like(coord)
337
+ return {"coord": coord, "color": color, "normal": normal}
338
+
339
+ return None
340
+
341
+ def update_gallery_on_upload(input_file,input_video,conf_thres,frame_slider,prediction_mode):
342
+ """
343
+ Whenever user uploads or changes files, immediately handle them
344
+ and show in the gallery. Return (target_dir, image_paths).
345
+ If nothing is uploaded, returns "None" and empty list.
346
+ """
347
+ if not input_video and not input_file:
348
+ return None, None, None, None
349
+ target_dir, image_paths,original_view, reconstruction_time = handle_uploads(input_file,input_video,conf_thres,frame_slider,prediction_mode)
350
+ if input_file is not None:
351
+ return original_view, target_dir, [], f"Upload and preprocess complete with {reconstruction_time:.3f} sec. Click \"PCA Generate\" to begin PCA processing."
352
+ if input_video is not None:
353
+ return original_view, target_dir, image_paths, f"Upload and preprocess complete with {reconstruction_time:.3f} sec. Click \"PCA Generate\" to begin PCA processing."
354
+
355
+
356
+ def get_pca_color(feat, start = 0, brightness=1.25, center=True):
357
+ u, s, v = torch.pca_lowrank(feat, center=center, q=3*(start+1), niter=5)
358
+ projection = feat @ v
359
+ projection = projection[:, 3*start:3*(start+1)] * 0.6 + projection[:, 3*start:3*(start+1)] * 0.4
360
+ min_val = projection.min(dim=-2, keepdim=True)[0]
361
+ max_val = projection.max(dim=-2, keepdim=True)[0]
362
+ div = torch.clamp(max_val - min_val, min=1e-6)
363
+ color = (projection - min_val) / div * brightness
364
+ color = color.clamp(0.0, 1.0)
365
+ return color
366
+
367
+ def clear_fields():
368
+ """
369
+ Clears the 3D viewer, the stored target_dir, and empties the gallery.
370
+ """
371
+ return None
372
+
373
+ def PCAing_log(is_example, log_output):
374
+ """
375
+ Display a quick log message while waiting.
376
+ """
377
+ if is_example:
378
+ return log_output
379
+ return "Loading for Doing PCA..."
380
+
381
+ def reset_log():
382
+ """
383
+ Reset a quick log message.
384
+ """
385
+ return "A new point cloud file or video is uploading and preprocessing..."
386
+
387
+
388
+ @spaces.GPU
389
+ def _gpu_utonia_forward_pca(point, utonia_model_, pca_slider, bright_slider):
390
+ """
391
+ GPU-only function: Run Utonia model forward pass and PCA in one place.
392
+ Uses inference_mode overall with a scoped disable for the forward call.
393
+ """
394
+ device = "cuda" if torch.cuda.is_available() else "cpu"
395
+
396
+ # Move tensors and model to GPU
397
+ for key in point.keys():
398
+ if isinstance(point[key], torch.Tensor):
399
+ point[key] = point[key].to(device, non_blocking=True)
400
+
401
+ utonia_model_ = utonia_model_.to(device)
402
+ utonia_model_.eval()
403
+
404
+ with torch.inference_mode():
405
+ utonia_start_time = time.time()
406
+ # Disable inference_mode for model forward to avoid version counter issues
407
+ with torch.inference_mode(False):
408
+ point = utonia_model_(point)
409
+ utonia_end_time = time.time()
410
+
411
+ # Upcast point feature through hierarchical pooling
412
+ for _ in range(4):
413
+ assert "pooling_parent" in point.keys()
414
+ assert "pooling_inverse" in point.keys()
415
+ parent = point.pop("pooling_parent")
416
+ inverse = point.pop("pooling_inverse")
417
+ parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
418
+ point = parent
419
+ while "pooling_parent" in point.keys():
420
+ assert "pooling_inverse" in point.keys()
421
+ parent = point.pop("pooling_parent")
422
+ inverse = point.pop("pooling_inverse")
423
+ parent.feat = point.feat[inverse]
424
+ point = parent
425
+
426
+ pca_start_time = time.time()
427
+ pca_color = get_pca_color(point.feat, start=pca_slider, brightness=bright_slider, center=True)
428
+ pca_end_time = time.time()
429
+
430
+ # Inverse back to original scale
431
+ original_pca_color = pca_color[point.inverse]
432
+
433
+ processed_colors = original_pca_color.cpu().detach().numpy()
434
+ point_feat = point.feat.cpu().detach().numpy()
435
+ point_inverse = point.inverse.cpu().detach().numpy()
436
+ utonia_time = utonia_end_time - utonia_start_time
437
+ pca_time = pca_end_time - pca_start_time
438
+ return processed_colors, point_feat, point_inverse, utonia_time, pca_time
439
+
440
+ def gradio_demo(target_dir, pca_slider, bright_slider, if_color=True, if_normal=True, scale_value=1.0, apply_z_positive=True, normalize_coord=False):
441
+ global utonia_model
442
+ target_dir_pcds = os.path.join(target_dir, "pcds")
443
+ if not os.path.isfile(os.path.join(target_dir_pcds, "points.npy")):
444
+ return None, "No point cloud available. Please upload data first."
445
+
446
+ # CPU: Load point cloud data from disk
447
+ original_points = np.load(os.path.join(target_dir_pcds, "points.npy"))
448
+ if if_color:
449
+ original_colors = np.load(os.path.join(target_dir_pcds, "colors.npy"))
450
+ else:
451
+ original_colors = np.zeros_like(original_points)
452
+ if if_normal:
453
+ original_normals = np.load(os.path.join(target_dir_pcds, "normals.npy"))
454
+ else:
455
+ original_normals = np.zeros_like(original_points)
456
+
457
+ processed_temp = os.path.join(target_dir_pcds, "processed.glb")
458
+
459
+ point = {"coord": original_points, "color": original_colors, "normal": original_normals}
460
+ original_coord = point["coord"].copy()
461
+
462
+ # CPU: Apply transform pipeline
463
+ transform = utonia.transform.default(scale=scale_value, apply_z_positive=apply_z_positive, normalize_coord=normalize_coord)
464
+ point = transform(point)
465
+
466
+ # GPU: Run Utonia forward + PCA together (inference_mode inside GPU function)
467
+ processed_colors, point_feat, point_inverse_cpu, utonia_time, pca_time = _gpu_utonia_forward_pca(
468
+ point, utonia_model, pca_slider, bright_slider
469
+ )
470
+
471
+ # CPU: Save features
472
+ np.save(os.path.join(target_dir_pcds, "feat.npy"), point_feat)
473
+ np.save(os.path.join(target_dir_pcds, "inverse.npy"), point_inverse_cpu)
474
+
475
+ # CPU: Build and save the 3D mesh
476
+ processed_points = original_coord
477
+ feat_3d = trimesh.Scene()
478
+ feat_data = trimesh.PointCloud(vertices=processed_points, colors=processed_colors, vertex_normals=original_normals)
479
+ feat_3d.add_geometry(feat_data)
480
+ feat_3d.export(processed_temp)
481
+
482
+ return processed_temp, f"Feature visualization process finished with {utonia_time:.3f} seconds using utonia inference and {pca_time:.3f} seconds using PCA. Updating visualization."
483
+
484
+ @spaces.GPU
485
+ def _gpu_pca_slider_compute(feat_array, inverse_array, pca_slider, bright_slider):
486
+ """
487
+ GPU-only function: Compute PCA colors for slider updates.
488
+ Minimal GPU allocation for only the essential computation.
489
+ """
490
+ device = "cuda" if torch.cuda.is_available() else "cpu"
491
+ # Move data to GPU inside GPU function
492
+ feat_tensor = torch.tensor(feat_array, device=device)
493
+ inverse_tensor = torch.tensor(inverse_array, device=device)
494
+
495
+ pca_start_time = time.time()
496
+ pca_colors = get_pca_color(feat_tensor, start=pca_slider, brightness=bright_slider, center=True)
497
+ processed_colors = pca_colors[inverse_tensor].cpu().detach().numpy()
498
+ pca_end_time = time.time()
499
+ return processed_colors, (pca_end_time - pca_start_time)
500
+
501
+ def utonia_slider_update(target_dir, pca_slider, bright_slider, is_example, log_output):
502
+ """
503
+ CPU-GPU hybrid: Handle file I/O on CPU, GPU for PCA computation only.
504
+ """
505
+ if is_example == "True":
506
+ return None, log_output
507
+ else:
508
+ target_dir_pcds = os.path.join(target_dir, "pcds")
509
+ if os.path.isfile(os.path.join(target_dir_pcds, "feat.npy")):
510
+ # CPU: Load data from disk
511
+ feat = np.load(os.path.join(target_dir_pcds, "feat.npy"))
512
+ inverse = np.load(os.path.join(target_dir_pcds, "inverse.npy"))
513
+
514
+ # GPU: Compute PCA colors only (numpy arrays passed to GPU function)
515
+ processed_colors, pca_time = _gpu_pca_slider_compute(feat, inverse, pca_slider, bright_slider)
516
+
517
+ # CPU: Load additional data and build mesh
518
+ processed_points = np.load(os.path.join(target_dir_pcds, "points.npy"))
519
+ processed_normals = np.load(os.path.join(target_dir_pcds, "normals.npy"))
520
+ processed_temp = os.path.join(target_dir_pcds, "processed.glb")
521
+
522
+ feat_3d = trimesh.Scene()
523
+ feat_data = trimesh.PointCloud(vertices=processed_points, colors=processed_colors, vertex_normals=processed_normals)
524
+ feat_3d.add_geometry(feat_data)
525
+ feat_3d.export(processed_temp)
526
+
527
+ log_output = f"Feature visualization process finished with {pca_time:.3f} seconds using PCA. Updating visualization."
528
+ else:
529
+ processed_temp = None
530
+ log_output = "No representations saved, please click PCA generate first."
531
+ return processed_temp, log_output
532
+
533
+ BASE_URL = "https://huggingface.co/datasets/pointcept-bot/utonia_huggingface_demo/resolve/main/"
534
+ def get_url(path):
535
+ return f"{BASE_URL}{path}"
536
+
537
+ examples_object = [
538
+ [
539
+ get_url("object/0005df571e71437991594d0affec9c2b.png"),
540
+ get_url("object/0005df571e71437991594d0affec9c2b.ply"),
541
+ 0, 1.2, "True", 1.0, True
542
+ ],
543
+ [
544
+ get_url("object/0023687e90394c3e97ab19b0160cafb3.png"),
545
+ get_url("object/0023687e90394c3e97ab19b0160cafb3.ply"),
546
+ 0, 1.2, "True", 1.0, True
547
+ ],
548
+ [
549
+ get_url("object/0015eb3cf53b4339b2d0532cf912ab26.png"),
550
+ get_url("object/0015eb3cf53b4339b2d0532cf912ab26.ply"),
551
+ 0, 1.2, "True", 1.0, True
552
+ ],
553
+ [
554
+ get_url("object/001a5201eddf4f3b98591598584673f5.png"),
555
+ get_url("object/001a5201eddf4f3b98591598584673f5.ply"),
556
+ 0, 1.2, "True", 1.0, True
557
+ ],
558
+ ]
559
+
560
+ examples_manipulation = [
561
+ [
562
+ get_url("manipulation/000021_AUTOLab_5d05c5aa_2023-11-17-23h-40m-52s-35_46.png"),
563
+ get_url("manipulation/000021_AUTOLab_5d05c5aa_2023-11-17-23h-40m-52s-35_46.ply"),
564
+ 0, 1.2, "True", 4.0, False
565
+ ],
566
+ [
567
+ get_url("manipulation/000018_AUTOLab_44bb9c36_2023-11-23-20h-05m-45s-55_66.png"),
568
+ get_url("manipulation/000018_AUTOLab_44bb9c36_2023-11-23-20h-05m-45s-55_66.ply"),
569
+ 1, 1.0, "True", 4.0, False
570
+ ],
571
+ [
572
+ get_url("manipulation/000037_IPRL_7790ec0a_2023-07-01-09h-37m-21s-15_26.png"),
573
+ get_url("manipulation/000037_IPRL_7790ec0a_2023-07-01-09h-37m-21s-15_26.ply"),
574
+ 0, 1.2, "True", 4.0, False
575
+ ],
576
+ [
577
+ get_url("manipulation/000061_TRI_938130c4_2023-08-10-14h-40m-11s-70_81.png"),
578
+ get_url("manipulation/000061_TRI_938130c4_2023-08-10-14h-40m-11s-70_81.ply"),
579
+ 2, 1.2, "True", 4.0, False
580
+ ],
581
+ ]
582
+
583
+ examples_indoor = [
584
+ [
585
+ get_url("indoor/scene0024_00.png"),
586
+ get_url("indoor/scene0024_00.ply"),
587
+ 0, 1.0, "True", 0.5, False
588
+ ],
589
+ [
590
+ get_url("indoor/scene0603_00.png"),
591
+ get_url("indoor/scene0603_00.ply"),
592
+ 0, 1.0, "True", 0.5, False
593
+ ],
594
+ [
595
+ get_url("indoor/027cd6ea0f.png"),
596
+ get_url("indoor/027cd6ea0f.ply"),
597
+ 0, 1.0, "True", 0.5, False
598
+ ],
599
+ [
600
+ get_url("indoor/2c7c10379b.png"),
601
+ get_url("indoor/2c7c10379b.ply"),
602
+ 3, 1.0, "True", 0.5, False
603
+ ],
604
+ ]
605
+
606
+ examples_outdoor = [
607
+ [
608
+ get_url("outdoor/segment-10455472356147194054_1560_000_1580_000_with_camera_labels.png"),
609
+ get_url("outdoor/segment-10455472356147194054_1560_000_1580_000_with_camera_labels.ply"),
610
+ 1, 1.2, "True", 0.2, False
611
+ ],
612
+ [
613
+ get_url("outdoor/segment-10963653239323173269_1924_000_1944_000_with_camera_labels.png"),
614
+ get_url("outdoor/segment-10963653239323173269_1924_000_1944_000_with_camera_labels.ply"),
615
+ 0, 1.2, "True", 0.2, False
616
+ ],
617
+ [
618
+ get_url("outdoor/segment-11718898130355901268_2300_000_2320_000_with_camera_labels.png"),
619
+ get_url("outdoor/segment-11718898130355901268_2300_000_2320_000_with_camera_labels.ply"),
620
+ 0, 1.2, "True", 0.2, False
621
+ ],
622
+ [
623
+ get_url("outdoor/segment-11925224148023145510_1040_000_1060_000_with_camera_labels.png"),
624
+ get_url("outdoor/segment-11925224148023145510_1040_000_1060_000_with_camera_labels.ply"),
625
+ 0, 1.2, "True", 0.2, False
626
+ ],
627
+ ]
628
+
629
+ examples_video = [
630
+ [
631
+ get_url("video/re10k_1.mp4"),
632
+ 10.0, 1, "Depthmap Branch", 2, 1.2, "True", 0.5, False
633
+ ],
634
+ [
635
+ get_url("video/re10k_2.mp4"),
636
+ 20.0, 1, "Pointmap Branch", 2, 1.2, "True", 0.5, False
637
+ ],
638
+ [
639
+ get_url("video/re10k_3.mp4"),
640
+ 10.0, 1, "Pointmap Branch", 1, 1.2, "True", 0.5, False
641
+ ],
642
+ [
643
+ get_url("video/re10k_4.mp4"),
644
+ 10.0, 1, "Pointmap Branch", 1, 1., "True", 0.5, False
645
+ ],
646
+ ]
647
+
648
+ def example_file_updated(
649
+ preview_imgs,
650
+ inputs,
651
+ pca_slider,
652
+ bright_slider,
653
+ is_example,
654
+ scale_slider,
655
+ normalize_coord,
656
+ url_input,
657
+ ):
658
+ pass
659
+
660
+ def example_video_updated(
661
+ inputs,
662
+ conf_thres,
663
+ frame_slider,
664
+ prediction_mode,
665
+ pca_slider,
666
+ bright_slider,
667
+ is_example,
668
+ scale_slider,
669
+ normalize_coord,
670
+ url_input,
671
+ ):
672
+ pass
673
+
674
+ with gr.Blocks(
675
+ css="""
676
+ .custom-log * {
677
+ font-style: italic;
678
+ font-size: 22px !important;
679
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
680
+ -webkit-background-clip: text;
681
+ background-clip: text;
682
+ font-weight: bold !important;
683
+ color: transparent !important;
684
+ text-align: center !important;
685
+ width: 800px;
686
+ height: 100px;
687
+ }
688
+
689
+ .example-log * {
690
+ font-style: italic;
691
+ font-size: 16px !important;
692
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
693
+ -webkit-background-clip: text;
694
+ background-clip: text;
695
+ color: transparent !important;
696
+ }
697
+
698
+ .common-markdown * {
699
+ font-size: 22px !important;
700
+ -webkit-background-clip: text;
701
+ background-clip: text;
702
+ font-weight: bold !important;
703
+ color: #0ea5e9 !important;
704
+ text-align: center !important;
705
+ }
706
+
707
+ #big-box {
708
+ border: 3px solid #00bcd4;
709
+ padding: 20px;
710
+ background-color: transparent;
711
+ border-radius: 15px;
712
+ }
713
+
714
+ #my_radio .wrap {
715
+ display: flex;
716
+ flex-wrap: nowrap;
717
+ justify-content: center;
718
+ align-items: center;
719
+ }
720
+
721
+ #my_radio .wrap label {
722
+ display: flex;
723
+ width: 50%;
724
+ justify-content: center;
725
+ align-items: center;
726
+ margin: 0;
727
+ padding: 10px 0;
728
+ box-sizing: border-box;
729
+ }
730
+ """,
731
+ ) as demo:
732
+ gr.HTML(
733
+ """
734
+ <h1>Utonia: Toward One Encoder for All Point Clouds</h1>
735
+ <div style="font-size: 16px; line-height: 1.5;">
736
+ <ol>
737
+ <details style="display:inline;">
738
+ <summary style="display:inline;"><h3>Getting Started:(<strong>Click to expand</strong>)</h3></summary>
739
+ <li><strong>Before Start: We recommend cloning this space to run locally on GPU</strong> for the limited GPU time.</li>
740
+ <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Point Cloud" blocks on the left to provide your input. If you upload a video, it will be automatically split into individual frames with the specified frame gap by VGGT.</li>
741
+ <li>
742
+ <strong>[Optional] Adjust Video-Lifted Point Cloud:</strong>
743
+ Before reconstructing the video, you can fine-tune the VGGT lifting process using the options below
744
+ <details style="display:inline;">
745
+ <summary style="display:inline;">(<strong>Click to expand</strong>)</summary>
746
+ <ul>
747
+ <li><em>Frame Gap / N Sec:</em> Adjust the frame interval.</li>
748
+ <li><em>Confidence Threshold:</em> Adjust the point filtering based on confidence levels.</li>
749
+ <li><em>Select Prediction Mode:</em> Choose between "Depthmap Branch" and "Pointmap Branch."</li>
750
+ </ul>
751
+ </details>
752
+ </li>
753
+ <li><strong>PCA Generation:</strong> After reconstruction, click the "PCA Generate" button to start the representation extraction and PCA process.</li>
754
+ <li><strong>Clear:</strong> Click the "Clear" button to reset all content in the blocks.</li>
755
+ <li><strong>Point Cloud Preview:</strong> Your uploaded video or point cloud will be displayed in this block.</li>
756
+ <li><strong>PCA Result:</strong> The PCA point cloud will appear here. You can rotate, drag, and zoom to explore the model, and download the GLB file.</li>
757
+ <li>
758
+ <strong>[Optional] Adjust the Point Cloud Input:</strong>
759
+ <details style="display:inline;">
760
+ <summary style="display:inline;">(<strong>Click to expand</strong>)</summary>
761
+ <ul>
762
+ <li><em>Input with Point Cloud Color/Normal:</em> If not checked, the corresponding information will be set to zeros.</li>
763
+ <li><em>Z positive:</em> When enabled, the point cloud is transformed so that the Z-axis is strictly positive. This is the default recommendation for most inputs, except for outdoor scenes where maintaining the road at the XY plane is preferred.</li>
764
+ <li><em>Normalize coord:</em> When enabled, the point cloud will be normalized to fit within the [-1,1] range before scaling. This is typically used for single objects.</li>
765
+ </ul>
766
+ </details>
767
+ </li>
768
+ <li>
769
+ <strong>[Optional] Adjust PCA Visualization:</strong>
770
+ Fine-tune the PCA visualization using the options below
771
+ <details style="display:inline;">
772
+ <summary style="display:inline;">(<strong>Click to expand</strong>)</summary>
773
+ <ul>
774
+ <li><em>PCA Start Dimension:</em> PCA reduces high-dimensional representations into 3D vectors. Adjust the PCA start dimension to change the range of the visualization. Increasing this value can help you see PCA visualization with less variance when the initial PCA dimension shows less diversity.</li>
775
+ <li><em>PCA Brightness:</em> Adjust the brightness of the PCA visualization results.</li>
776
+ <li><em>Notice:</em> As a linear dimension reduction method, PCA has its limitation. Sometimes, the visualization cannot fully exhibit the quality of representations.</li>
777
+ </ul>
778
+ </details>
779
+ </li>
780
+ <li><strong>Adjust Scale Parameter:</strong> Adjust the scale of the point cloud. The default scale is 0.2 for outdoor, 0.5 for indoor, and 4.0 for manipulation with high resolution needs. Object data needs normalization. You can adjust it according to your desired granularity and memory limit.</li>
781
+ </details>
782
+ </ol>
783
+ </div>
784
+
785
+ """
786
+ )
787
+ _ = gr.Textbox(label="_", visible=False, value="False")
788
+ is_example = gr.Textbox(label="is_example", visible=False, value="False")
789
+ target_dir = gr.Textbox(label="Target Dir", visible=False, value="None")
790
+ preview_imgs = gr.Image(type="filepath",label="Preview Imgs", visible=False, value="None")
791
+ with gr.Row():
792
+ with gr.Column(scale=1,elem_id="big-box"):
793
+ input_file = gr.File(label="Upload Point Cloud", file_types=[".ply"])
794
+ input_video = gr.Video(label="Upload Video", interactive=True)
795
+ image_gallery = gr.Gallery(
796
+ label="Video Frame Preview",
797
+ columns=4,
798
+ height="300px",
799
+ # show_download_button=True,
800
+ object_fit="contain",
801
+ preview=True,
802
+ )
803
+
804
+ frame_slider = gr.Slider(minimum=0.1, maximum=10, value=1, step=0.1,
805
+ label="1 Frame/ N Sec", interactive=True)
806
+ conf_thres = gr.Slider(minimum=0, maximum=100, value=10, step=0.1,
807
+ label="Confidence", interactive=True)
808
+ prediction_mode = gr.Radio(
809
+ ["Depthmap Branch", "Pointmap Branch"],
810
+ label="Select a Prediction Mode",
811
+ value="Depthmap Branch",
812
+ scale=1,
813
+ elem_id="my_radio",
814
+ )
815
+ reconstruction_btn = gr.Button("Video Reconstruct")
816
+ with gr.Column(scale=2):
817
+ log_output = gr.Markdown(
818
+ "Please upload a video or point cloud ply file, then click \"PCA Generate\".", elem_classes=["custom-log"]
819
+ )
820
+ original_view = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5, label="Point Cloud Preview", camera_position = (90,None,None))
821
+ processed_view = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5, label="PCA Result", camera_position = (90,None,None))
822
+ with gr.Row():
823
+ if_color = gr.Checkbox(label="Input with point cloud color", value=True)
824
+ if_normal = gr.Checkbox(label="Input with point cloud normal", value=True)
825
+ with gr.Row():
826
+ normalize_coord = gr.Checkbox(label="Normalize coord", value=False)
827
+ apply_z_positive = gr.Checkbox(label="Z positive", value=True)
828
+ scale_slider = gr.Slider(minimum=0.001, maximum=5.0, value=1.0, step=0.0005,
829
+ label="Scale Parameter", interactive=True)
830
+ pca_slider = gr.Slider(minimum=0, maximum=5, value=0, step=1,
831
+ label="PCA Start Dimension", interactive=True)
832
+ bright_slider = gr.Slider(minimum=0.5, maximum=1.5, value=1.2, step=0.05,
833
+ label="PCA Brightness", interactive=True)
834
+ with gr.Row():
835
+ submit_btn = gr.Button("PCA Generate")
836
+ clear_btn = gr.ClearButton(
837
+ [input_video, input_file, original_view, processed_view, log_output, target_dir, image_gallery],
838
+ scale=1,
839
+ elem_id="my_clear",
840
+ )
841
+
842
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
843
+ with gr.Row():
844
+ gr.Examples(
845
+ examples=examples_object,
846
+ inputs=[
847
+ preview_imgs,
848
+ input_file,
849
+ pca_slider,
850
+ bright_slider,
851
+ is_example,
852
+ scale_slider,
853
+ normalize_coord,
854
+ ],
855
+ outputs=[
856
+ ],
857
+ label = "Object Point Cloud Examples",
858
+ fn=example_file_updated,
859
+ cache_examples=False,
860
+ examples_per_page=50,
861
+ )
862
+ with gr.Row():
863
+ gr.Examples(
864
+ examples=examples_manipulation,
865
+ inputs=[
866
+ preview_imgs,
867
+ input_file,
868
+ pca_slider,
869
+ bright_slider,
870
+ is_example,
871
+ scale_slider,
872
+ normalize_coord,
873
+ ],
874
+ outputs=[
875
+ ],
876
+ label = "Manipulation Point Cloud Examples",
877
+ fn=example_file_updated,
878
+ cache_examples=False,
879
+ examples_per_page=50,
880
+ )
881
+ with gr.Row():
882
+ gr.Examples(
883
+ examples=examples_indoor,
884
+ inputs=[
885
+ preview_imgs,
886
+ input_file,
887
+ pca_slider,
888
+ bright_slider,
889
+ is_example,
890
+ scale_slider,
891
+ normalize_coord,
892
+ ],
893
+ outputs=[
894
+ ],
895
+ label = "Indoor Point Cloud Examples",
896
+ fn=example_file_updated,
897
+ cache_examples=False,
898
+ examples_per_page=50,
899
+ )
900
+ with gr.Row():
901
+ gr.Examples(
902
+ examples=examples_outdoor,
903
+ inputs=[
904
+ preview_imgs,
905
+ input_file,
906
+ pca_slider,
907
+ bright_slider,
908
+ is_example,
909
+ scale_slider,
910
+ normalize_coord,
911
+ ],
912
+ outputs=[
913
+ ],
914
+ label = "Outdoor Point Cloud Examples",
915
+ fn=example_file_updated,
916
+ cache_examples=False,
917
+ examples_per_page=50,
918
+ )
919
+
920
+ with gr.Row():
921
+ gr.Examples(
922
+ examples=examples_video,
923
+ inputs=[
924
+ input_video,
925
+ conf_thres,
926
+ frame_slider,
927
+ prediction_mode,
928
+ pca_slider,
929
+ bright_slider,
930
+ is_example,
931
+ scale_slider,
932
+ normalize_coord,
933
+ ],
934
+ outputs=[
935
+ ],
936
+ label = "Video Examples",
937
+ fn=example_video_updated,
938
+ cache_examples=False,
939
+ examples_per_page=50,
940
+ )
941
+
942
+ reconstruction_btn.click(
943
+ fn = update_gallery_on_upload,
944
+ inputs = [input_file,input_video,conf_thres,frame_slider,prediction_mode],
945
+ outputs = [original_view, target_dir, image_gallery, log_output]
946
+ )
947
+ submit_btn.click(fn=clear_fields, inputs=[], outputs=[processed_view]).then(
948
+ fn=PCAing_log, inputs=[is_example, log_output], outputs=[log_output]
949
+ ).then(
950
+ fn=gradio_demo,
951
+ inputs=[target_dir,pca_slider,bright_slider, if_color, if_normal, scale_slider, apply_z_positive, normalize_coord],
952
+ outputs=[processed_view,log_output],
953
+ ).then(
954
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
955
+ )
956
+
957
+ pca_slider.release(fn=clear_fields, inputs=[], outputs=[processed_view]).then(
958
+ fn=PCAing_log, inputs=[is_example, log_output], outputs=[log_output]
959
+ ).then(
960
+ fn=utonia_slider_update,
961
+ inputs=[target_dir,pca_slider,bright_slider,is_example,log_output],
962
+ outputs=[processed_view, log_output],
963
+ ).then(
964
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
965
+ )
966
+ bright_slider.release(fn=clear_fields, inputs=[], outputs=[processed_view]).then(
967
+ fn=PCAing_log, inputs=[is_example, log_output], outputs=[log_output]
968
+ ).then(
969
+ fn=utonia_slider_update,
970
+ inputs=[target_dir,pca_slider,bright_slider,is_example,log_output],
971
+ outputs=[processed_view, log_output],
972
+ ).then(
973
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
974
+ )
975
+
976
+ input_file.change(fn=reset_log, inputs=[], outputs=[log_output]).then(
977
+ fn=update_gallery_on_upload,
978
+ inputs=[input_file,input_video, conf_thres,frame_slider,prediction_mode],
979
+ outputs=[original_view, target_dir, _, log_output],
980
+ )
981
+
982
+ input_video.change(fn=reset_log, inputs=[], outputs=[log_output]).then(
983
+ fn=update_gallery_on_upload,
984
+ inputs=[input_file,input_video, conf_thres,frame_slider,prediction_mode],
985
+ outputs=[original_view, target_dir, image_gallery, log_output],
986
+ )
987
+
988
+ if __name__ == "__main__":
989
+ demo.queue(max_size=20).launch(show_error=True, share=True)
geometry_utils.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.spatial.transform import Rotation as R
3
+
4
+
5
+ def pad_0001(Ts):
6
+ Ts = np.asarray(Ts)
7
+ if Ts.shape[-2:] == (4, 4):
8
+ return Ts
9
+ if Ts.shape[-2:] != (3, 4):
10
+ raise ValueError("Ts must have shape (..., 3, 4) or (..., 4, 4)")
11
+ pad = np.zeros((*Ts.shape[:-2], 1, 4), dtype=Ts.dtype)
12
+ pad[..., 0, 3] = 1
13
+ return np.concatenate([Ts, pad], axis=-2)
14
+
15
+
16
+ def T_to_C(T):
17
+ T = np.asarray(T)
18
+ if T.shape != (4, 4):
19
+ raise ValueError("T must be shape (4,4)")
20
+ Rm = T[:3, :3]
21
+ t = T[:3, 3]
22
+ return -Rm.T @ t
23
+
24
+
25
+ def im_distance_to_im_depth(im_dist, K):
26
+ im_dist = np.asarray(im_dist)
27
+ H, W = im_dist.shape[:2]
28
+ ys, xs = np.indices((H, W), dtype=np.float32)
29
+ fx, fy = K[0, 0], K[1, 1]
30
+ cx, cy = K[0, 2], K[1, 2]
31
+ x = (xs - cx) / max(fx, 1e-8)
32
+ y = (ys - cy) / max(fy, 1e-8)
33
+ ray_norm = np.sqrt(x * x + y * y + 1.0)
34
+ return im_dist / np.clip(ray_norm, 1e-8, None)
35
+
36
+
37
+ def im_depth_to_point_cloud(im_depth, K, T, to_image=False, ignore_invalid=False):
38
+ im_depth = np.asarray(im_depth)
39
+ H, W = im_depth.shape[:2]
40
+ ys, xs = np.indices((H, W), dtype=np.float32)
41
+
42
+ fx, fy = K[0, 0], K[1, 1]
43
+ cx, cy = K[0, 2], K[1, 2]
44
+
45
+ z = im_depth.reshape(-1)
46
+ x = (xs.reshape(-1) - cx) / max(fx, 1e-8) * z
47
+ y = (ys.reshape(-1) - cy) / max(fy, 1e-8) * z
48
+ pts_cam = np.stack([x, y, z], axis=-1)
49
+
50
+ T = np.asarray(T)
51
+ if T.shape != (4, 4):
52
+ raise ValueError("T must be shape (4,4)")
53
+ Rm = T[:3, :3]
54
+ t = T[:3, 3]
55
+
56
+ # Assume T is world->camera; invert to camera->world for point transform.
57
+ pts_world = (Rm.T @ (pts_cam - t).T).T
58
+
59
+ if ignore_invalid:
60
+ valid = np.isfinite(pts_world).all(axis=1) & (z > 0)
61
+ pts_world = pts_world[valid]
62
+
63
+ if to_image:
64
+ return pts_world.reshape(H, W, 3)
65
+ return pts_world
66
+
67
+ def rotx(x, theta=90):
68
+ """
69
+ Rotate x by theta degrees around the x-axis
70
+ """
71
+ theta = np.deg2rad(theta)
72
+ rot_matrix = np.array(
73
+ [
74
+ [1, 0, 0, 0],
75
+ [0, np.cos(theta), -np.sin(theta), 0],
76
+ [0, np.sin(theta), np.cos(theta), 0],
77
+ [0, 0, 0, 1],
78
+ ]
79
+ )
80
+ return rot_matrix@ x
81
+
82
+
83
+ def Coord2zup(points, extrinsics, normals = None):
84
+ """
85
+ Convert the dust3r coordinate system to the z-up coordinate system
86
+ """
87
+ points = np.concatenate([points, np.ones([points.shape[0], 1])], axis=1).T
88
+ points = rotx(points, -90)[:3].T
89
+ if normals is not None:
90
+ normals = np.concatenate([normals, np.ones([normals.shape[0], 1])], axis=1).T
91
+ normals = rotx(normals, -90)[:3].T
92
+ normals = normals / np.linalg.norm(normals, axis=1, keepdims=True)
93
+ t = np.min(points,axis=0)
94
+ points -= t
95
+ extrinsics = rotx(extrinsics, -90)
96
+ extrinsics[:, :3, 3] -= t.T
97
+ return points, extrinsics, normals
98
+
99
+ def _ransac_plane(points, distance_threshold=0.01, ransac_n=3, num_iterations=1000):
100
+ if points.shape[0] < ransac_n:
101
+ raise ValueError("Not enough points for plane fitting.")
102
+
103
+ best_inliers = None
104
+ best_plane = None
105
+ rng = np.random.default_rng(42)
106
+
107
+ for _ in range(num_iterations):
108
+ sample_idx = rng.choice(points.shape[0], size=ransac_n, replace=False)
109
+ p0, p1, p2 = points[sample_idx]
110
+ normal = np.cross(p1 - p0, p2 - p0)
111
+ norm = np.linalg.norm(normal)
112
+ if norm < 1e-8:
113
+ continue
114
+ normal = normal / norm
115
+ d = -np.dot(normal, p0)
116
+
117
+ dist = np.abs(points @ normal + d)
118
+ inliers = np.where(dist < distance_threshold)[0]
119
+ if best_inliers is None or len(inliers) > len(best_inliers):
120
+ best_inliers = inliers
121
+ best_plane = np.array([normal[0], normal[1], normal[2], d], dtype=np.float64)
122
+
123
+ if best_inliers is None or best_plane is None:
124
+ raise ValueError("Failed to fit plane with RANSAC.")
125
+
126
+ return best_plane, best_inliers
127
+
128
+
129
+ def extract_and_align_ground_plane(points,
130
+ colors=None,
131
+ normals=None,
132
+ height_percentile=20,
133
+ ransac_distance_threshold=0.01,
134
+ ransac_n=3,
135
+ ransac_iterations=1000,
136
+ max_angle_degree=40,
137
+ max_trials=6):
138
+ points = np.asarray(points)
139
+ if points.ndim != 2 or points.shape[1] != 3:
140
+ raise ValueError("points must be shaped (N, 3)")
141
+
142
+ aligned_colors = np.asarray(colors) if colors is not None else None
143
+ aligned_normals = np.asarray(normals) if normals is not None else None
144
+
145
+ z_vals = points[:, 2]
146
+ z_thresh = np.percentile(z_vals, height_percentile)
147
+ low_indices = np.where(z_vals <= z_thresh)[0]
148
+
149
+ remaining_indices = low_indices.copy()
150
+
151
+ for trial in range(max_trials):
152
+ if len(remaining_indices) < ransac_n:
153
+ raise ValueError("Not enough points left to fit a plane.")
154
+
155
+ candidate_points = points[remaining_indices]
156
+ plane_model, inliers = _ransac_plane(
157
+ candidate_points,
158
+ distance_threshold=ransac_distance_threshold,
159
+ ransac_n=ransac_n,
160
+ num_iterations=ransac_iterations,
161
+ )
162
+ a, b, c, d = plane_model
163
+ normal = np.array([a, b, c])
164
+ normal /= np.linalg.norm(normal)
165
+
166
+ angle = np.arccos(np.clip(np.dot(normal, [0, 0, 1]), -1.0, 1.0)) * 180 / np.pi
167
+ if angle <= max_angle_degree:
168
+ inliers_global = remaining_indices[inliers]
169
+
170
+ target = np.array([0, 0, 1])
171
+ axis = np.cross(normal, target)
172
+ axis_norm = np.linalg.norm(axis)
173
+
174
+ if axis_norm < 1e-6:
175
+ rotation_matrix = np.eye(3)
176
+ else:
177
+ axis /= axis_norm
178
+ rot_angle = np.arccos(np.clip(np.dot(normal, target), -1.0, 1.0))
179
+ rotation = R.from_rotvec(axis * rot_angle)
180
+ rotation_matrix = rotation.as_matrix()
181
+
182
+ rotated_points = points @ rotation_matrix.T
183
+ ground_points_z = rotated_points[inliers_global, 2]
184
+ offset = np.mean(ground_points_z)
185
+ rotated_points[:, 2] -= offset
186
+
187
+ if aligned_normals is not None and len(aligned_normals) == len(points):
188
+ aligned_normals = aligned_normals @ rotation_matrix.T
189
+ aligned_normals = aligned_normals / np.clip(np.linalg.norm(aligned_normals, axis=-1, keepdims=True), 1e-8, None)
190
+
191
+ return rotated_points, aligned_colors, aligned_normals, inliers_global, rotation_matrix, offset
192
+
193
+ else:
194
+ rejected_indices = remaining_indices[inliers]
195
+ remaining_indices = np.setdiff1d(remaining_indices, rejected_indices)
196
+
197
+ raise ValueError("Failed to find a valid ground plane within max trials.")
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base
2
+ numpy<=1.26.4
3
+ scipy
4
+ addict
5
+ timm
6
+ psutil
7
+ huggingface_hub
8
+ opencv-python-headless
9
+ einops
10
+ ninja
11
+
12
+ torch==2.4.0
13
+ torchvision==0.19.0
14
+ torchaudio==2.4.0
15
+
16
+ # Extra dependencies
17
+ --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html
18
+ torch-scatter
19
+
20
+ spconv-cu120
21
+
22
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
23
+
24
+ # Visualization
25
+ trimesh
26
+ camtools
27
+ matplotlib
setup.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
+
4
+ setup(
5
+ name="utonia",
6
+ py_modules=["utonia"],
7
+ version="1.0",
8
+ description="",
9
+ author="Yujia Zhang",
10
+ packages=find_packages(exclude=["demo*"]),
11
+ include_package_data=True,
12
+ )
utonia/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from .model import load
17
+
18
+ from . import model
19
+ from . import module
20
+ from . import structure
21
+ from . import data
22
+ from . import transform
23
+ from . import utils
24
+ from . import registry
25
+
26
+ __all__ = ["load", "model", "module", "structure", "transform", "registry", "utils"]
utonia/data.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import numpy as np
18
+ import torch
19
+ from collections.abc import Mapping, Sequence
20
+ from huggingface_hub import hf_hub_download
21
+
22
+ DATAS = [
23
+ "sample1",
24
+ "sample1_high_res",
25
+ "sample1_dino",
26
+ "sample2_outdoor",
27
+ "sample2_outdoor_multiframe",
28
+ "sample3_object",
29
+ "sample4_manipulation",
30
+ "sample5_hk",
31
+ ]
32
+
33
+
34
+ def load(
35
+ name: str = "utonia_large",
36
+ download_root: str = None,
37
+ ):
38
+ if name in DATAS:
39
+ print(f"Loading data from HuggingFace: {name} ...")
40
+ data_path = hf_hub_download(
41
+ repo_id="pointcept/demo",
42
+ filename=f"{name}.npz",
43
+ repo_type="dataset",
44
+ revision="main",
45
+ local_dir=download_root or os.path.expanduser("~/.cache/utonia/data"),
46
+ )
47
+ elif os.path.isfile(name):
48
+ print(f"Loading data in local path: {name} ...")
49
+ data_path = name
50
+ else:
51
+ raise RuntimeError(f"Data {name} not found; available models = {DATAS}")
52
+ return dict(np.load(data_path))
53
+
54
+
55
+ from torch.utils.data.dataloader import default_collate
56
+
57
+
58
+ def collate_fn(batch):
59
+ """
60
+ collate function for point cloud which support dict and list,
61
+ 'coord' is necessary to determine 'offset'
62
+ """
63
+ if not isinstance(batch, Sequence):
64
+ raise TypeError(f"{batch.dtype} is not supported.")
65
+
66
+ if isinstance(batch[0], torch.Tensor):
67
+ return torch.cat(list(batch))
68
+ elif isinstance(batch[0], str):
69
+ # str is also a kind of Sequence, judgement should before Sequence
70
+ return list(batch)
71
+ elif isinstance(batch[0], Sequence):
72
+ for data in batch:
73
+ data.append(torch.tensor([data[0].shape[0]]))
74
+ batch = [collate_fn(samples) for samples in zip(*batch)]
75
+ batch[-1] = torch.cumsum(batch[-1], dim=0).int()
76
+ return batch
77
+ elif isinstance(batch[0], Mapping):
78
+ batch = {
79
+ key: (
80
+ collate_fn([d[key] for d in batch])
81
+ if "offset" not in key
82
+ # offset -> bincount -> concat bincount-> concat offset
83
+ else torch.cumsum(
84
+ collate_fn([d[key].diff(prepend=torch.tensor([0])) for d in batch]),
85
+ dim=0,
86
+ )
87
+ )
88
+ for key in batch[0]
89
+ }
90
+ return batch
91
+ else:
92
+ return default_collate(batch)
utonia/model.py ADDED
@@ -0,0 +1,891 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Point Transformer - V3 Mode3 - Utonia
3
+ Pointcept detached version
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ import os
25
+ import math
26
+ from packaging import version
27
+ from huggingface_hub import hf_hub_download, PyTorchModelHubMixin
28
+ from addict import Dict
29
+ import torch
30
+ import torch.nn as nn
31
+ from torch.nn.init import trunc_normal_
32
+ import spconv.pytorch as spconv
33
+ import torch_scatter
34
+ from timm.layers import DropPath
35
+
36
+ try:
37
+ import flash_attn
38
+ except ImportError:
39
+ flash_attn = None
40
+
41
+ from .structure import Point
42
+ from .module import PointSequential, PointModule
43
+ from .utils import offset2bincount
44
+
45
+ MODELS = [
46
+ "utonia",
47
+ "utonia_linear_prob_head_sc",
48
+ ]
49
+
50
+
51
+ class Point3DRoPE(nn.Module):
52
+ def __init__(self, head_dim, base=10000):
53
+ super().__init__()
54
+ assert (
55
+ head_dim % 3 == 0
56
+ ), f"Head dimension must be divisible by 3 for 3D RoPE, {head_dim}"
57
+
58
+ self.head_dim = head_dim
59
+ self.chunk_dim = head_dim // 3
60
+ self.base = base
61
+ inv_freq = 1.0 / (
62
+ self.base ** (torch.arange(0, self.chunk_dim, 2).float() / self.chunk_dim)
63
+ )
64
+ self.register_buffer("inv_freq", inv_freq)
65
+
66
+ def get_cos_sin(self, xyz):
67
+ x = xyz[:, 0:1]
68
+ y = xyz[:, 1:2]
69
+ z = xyz[:, 2:3]
70
+ freqs = self.inv_freq.unsqueeze(0)
71
+
72
+ emb_x = x * freqs
73
+ emb_y = y * freqs
74
+ emb_z = z * freqs
75
+
76
+ emb_x = torch.cat((emb_x, emb_x), dim=-1)
77
+ emb_y = torch.cat((emb_y, emb_y), dim=-1)
78
+ emb_z = torch.cat((emb_z, emb_z), dim=-1)
79
+ emb_3d = torch.cat((emb_x, emb_y, emb_z), dim=-1)
80
+ return emb_3d.cos().unsqueeze(1), emb_3d.sin().unsqueeze(1)
81
+
82
+ def rotate_half(self, x):
83
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
84
+ return torch.cat((-x2, x1), dim=-1)
85
+
86
+ def forward(self, q, k, xyz):
87
+ cos, sin = self.get_cos_sin(xyz)
88
+ qs = torch.split(q, self.chunk_dim, dim=-1)
89
+ ks = torch.split(k, self.chunk_dim, dim=-1)
90
+
91
+ coss = torch.split(cos, self.chunk_dim, dim=-1)
92
+ sins = torch.split(sin, self.chunk_dim, dim=-1)
93
+
94
+ q_outs = []
95
+ k_outs = []
96
+
97
+ for i in range(3):
98
+ q_part = (qs[i] * coss[i]) + (self.rotate_half(qs[i]) * sins[i])
99
+ k_part = (ks[i] * coss[i]) + (self.rotate_half(ks[i]) * sins[i])
100
+
101
+ q_outs.append(q_part)
102
+ k_outs.append(k_part)
103
+
104
+ q_rot = torch.cat(q_outs, dim=-1)
105
+ k_rot = torch.cat(k_outs, dim=-1)
106
+
107
+ return q_rot, k_rot
108
+
109
+
110
+ class LayerScale(nn.Module):
111
+ def __init__(
112
+ self,
113
+ dim: int,
114
+ init_values: float = 1e-5,
115
+ inplace: bool = False,
116
+ ) -> None:
117
+ super().__init__()
118
+ self.inplace = inplace
119
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
120
+
121
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
122
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
123
+
124
+
125
+ class RPE(torch.nn.Module):
126
+ def __init__(self, patch_size, num_heads):
127
+ super().__init__()
128
+ self.patch_size = patch_size
129
+ self.num_heads = num_heads
130
+ self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
131
+ self.rpe_num = 2 * self.pos_bnd + 1
132
+ self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
133
+ torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)
134
+
135
+ def forward(self, coord):
136
+ idx = (
137
+ coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd
138
+ + self.pos_bnd # relative position to positive index
139
+ + torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride
140
+ )
141
+ out = self.rpe_table.index_select(0, idx.reshape(-1))
142
+ out = out.view(idx.shape + (-1,)).sum(3)
143
+ out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K)
144
+ return out
145
+
146
+
147
+ class SerializedAttention(PointModule):
148
+ def __init__(
149
+ self,
150
+ channels,
151
+ num_heads,
152
+ patch_size,
153
+ qkv_bias=True,
154
+ qk_scale=None,
155
+ attn_drop=0.0,
156
+ proj_drop=0.0,
157
+ order_index=0,
158
+ enable_rpe=False,
159
+ enable_flash=True,
160
+ upcast_attention=True,
161
+ upcast_softmax=True,
162
+ rope_base=10,
163
+ shift_coords=None,
164
+ jitter_coords=None,
165
+ rescale_coords=None,
166
+ ):
167
+ super().__init__()
168
+ assert channels % num_heads == 0
169
+ self.channels = channels
170
+ self.num_heads = num_heads
171
+ self.scale = qk_scale or (channels // num_heads) ** -0.5
172
+ self.order_index = order_index
173
+ self.upcast_attention = upcast_attention
174
+ self.upcast_softmax = upcast_softmax
175
+ self.enable_rpe = enable_rpe
176
+ self.enable_flash = enable_flash
177
+ if enable_flash:
178
+ assert (
179
+ enable_rpe is False
180
+ ), "Set enable_rpe to False when enable Flash Attention"
181
+ assert (
182
+ upcast_attention is False
183
+ ), "Set upcast_attention to False when enable Flash Attention"
184
+ assert (
185
+ upcast_softmax is False
186
+ ), "Set upcast_softmax to False when enable Flash Attention"
187
+ assert flash_attn is not None, "Make sure flash_attn is installed."
188
+ self.patch_size = patch_size
189
+ self.attn_drop = attn_drop
190
+ else:
191
+ # when disable flash attention, we still don't want to use mask
192
+ # consequently, patch size will auto set to the
193
+ # min number of patch_size_max and number of points
194
+ self.patch_size_max = patch_size
195
+ self.patch_size = 0
196
+ self.attn_drop = torch.nn.Dropout(attn_drop)
197
+
198
+ self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
199
+ self.proj = torch.nn.Linear(channels, channels)
200
+ self.proj_drop = torch.nn.Dropout(proj_drop)
201
+ self.softmax = torch.nn.Softmax(dim=-1)
202
+ self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None
203
+ self.rope = Point3DRoPE(head_dim=channels // num_heads, base=rope_base)
204
+ self.shift_coords = shift_coords
205
+ self.jitter_coords = jitter_coords
206
+ self.rescale_coords = rescale_coords
207
+
208
+ @torch.no_grad()
209
+ def get_rel_pos(self, point, order):
210
+ K = self.patch_size
211
+ rel_pos_key = f"rel_pos_{self.order_index}"
212
+ if rel_pos_key not in point.keys():
213
+ grid_coord = point.grid_coord[order]
214
+ grid_coord = grid_coord.reshape(-1, K, 3)
215
+ point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
216
+ return point[rel_pos_key]
217
+
218
+ @torch.no_grad()
219
+ def get_padding_and_inverse(self, point):
220
+ pad_key = "pad"
221
+ unpad_key = "unpad"
222
+ cu_seqlens_key = "cu_seqlens_key"
223
+ if (
224
+ pad_key not in point.keys()
225
+ or unpad_key not in point.keys()
226
+ or cu_seqlens_key not in point.keys()
227
+ ):
228
+ offset = point.offset
229
+ bincount = offset2bincount(offset)
230
+ bincount_pad = (
231
+ torch.div(
232
+ bincount + self.patch_size - 1,
233
+ self.patch_size,
234
+ rounding_mode="trunc",
235
+ )
236
+ * self.patch_size
237
+ )
238
+ # only pad point when num of points larger than patch_size
239
+ mask_pad = bincount > self.patch_size
240
+ bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
241
+ _offset = nn.functional.pad(offset, (1, 0))
242
+ _offset_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
243
+ pad = torch.arange(_offset_pad[-1], device=offset.device)
244
+ unpad = torch.arange(_offset[-1], device=offset.device)
245
+ cu_seqlens = []
246
+ for i in range(len(offset)):
247
+ unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
248
+ if bincount[i] != bincount_pad[i]:
249
+ pad[
250
+ _offset_pad[i + 1]
251
+ - self.patch_size
252
+ + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
253
+ ] = pad[
254
+ _offset_pad[i + 1]
255
+ - 2 * self.patch_size
256
+ + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
257
+ - self.patch_size
258
+ ]
259
+ pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
260
+ cu_seqlens.append(
261
+ torch.arange(
262
+ _offset_pad[i],
263
+ _offset_pad[i + 1],
264
+ step=self.patch_size,
265
+ dtype=torch.int32,
266
+ device=offset.device,
267
+ )
268
+ )
269
+ point[pad_key] = pad
270
+ point[unpad_key] = unpad
271
+ point[cu_seqlens_key] = nn.functional.pad(
272
+ torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
273
+ )
274
+ return point[pad_key], point[unpad_key], point[cu_seqlens_key]
275
+
276
+ def forward(self, point):
277
+ if not self.enable_flash:
278
+ self.patch_size = min(
279
+ offset2bincount(point.offset).min().tolist(), self.patch_size_max
280
+ )
281
+
282
+ H = self.num_heads
283
+ K = self.patch_size
284
+ C = self.channels
285
+
286
+ pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)
287
+
288
+ order = point.serialized_order[self.order_index][pad]
289
+ inverse = unpad[point.serialized_inverse[self.order_index]]
290
+
291
+ # padding and reshape feat and batch for serialized point patch
292
+ qkv = self.qkv(point.feat)[order]
293
+
294
+ rope_coord = point.coord[order].clone()
295
+ if not self.enable_flash:
296
+ # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
297
+ qkv = qkv.reshape(-1, 3, H, C // H)
298
+ q, k, v = qkv.unbind(dim=1)
299
+ q, k = self.rope(q, k, rope_coord)
300
+ qkv_roped = torch.stack([q, k, v], dim=1)
301
+ q, k, v = (
302
+ qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
303
+ )
304
+ # attn
305
+ if self.upcast_attention:
306
+ q = q.float()
307
+ k = k.float()
308
+ attn = (q * self.scale) @ k.transpose(-2, -1) # (N', H, K, K)
309
+ if self.enable_rpe:
310
+ attn = attn + self.rpe(self.get_rel_pos(point, order))
311
+ if self.upcast_softmax:
312
+ attn = attn.float()
313
+ attn = self.softmax(attn)
314
+ attn = self.attn_drop(attn).to(qkv.dtype)
315
+ feat = (attn @ v).transpose(1, 2).reshape(-1, C)
316
+ else:
317
+ qkv = qkv.reshape(-1, 3, H, C // H)
318
+ q, k, v = qkv.unbind(dim=1)
319
+ q, k = self.rope(q, k, rope_coord)
320
+ qkv_roped = torch.stack([q, k, v], dim=1)
321
+ feat = flash_attn.flash_attn_varlen_qkvpacked_func(
322
+ qkv_roped.to(torch.bfloat16),
323
+ cu_seqlens,
324
+ max_seqlen=self.patch_size,
325
+ dropout_p=self.attn_drop if self.training else 0,
326
+ softmax_scale=self.scale,
327
+ ).reshape(-1, C)
328
+ feat = feat.to(qkv.dtype)
329
+ feat = feat[inverse]
330
+
331
+ # ffn
332
+ feat = self.proj(feat)
333
+ feat = self.proj_drop(feat)
334
+ point.feat = feat
335
+ return point
336
+
337
+
338
+ class MLP(nn.Module):
339
+ def __init__(
340
+ self,
341
+ in_channels,
342
+ hidden_channels=None,
343
+ out_channels=None,
344
+ act_layer=nn.GELU,
345
+ drop=0.0,
346
+ ):
347
+ super().__init__()
348
+ out_channels = out_channels or in_channels
349
+ hidden_channels = hidden_channels or in_channels
350
+ self.fc1 = nn.Linear(in_channels, hidden_channels)
351
+ self.act = act_layer()
352
+ self.fc2 = nn.Linear(hidden_channels, out_channels)
353
+ self.drop = nn.Dropout(drop)
354
+
355
+ def forward(self, x):
356
+ x = self.fc1(x)
357
+ x = self.act(x)
358
+ x = self.drop(x)
359
+ x = self.fc2(x)
360
+ x = self.drop(x)
361
+ return x
362
+
363
+
364
+ class Block(PointModule):
365
+ def __init__(
366
+ self,
367
+ channels,
368
+ num_heads,
369
+ patch_size=48,
370
+ mlp_ratio=4.0,
371
+ qkv_bias=True,
372
+ qk_scale=None,
373
+ attn_drop=0.0,
374
+ proj_drop=0.0,
375
+ drop_path=0.0,
376
+ layer_scale=None,
377
+ norm_layer=nn.LayerNorm,
378
+ act_layer=nn.GELU,
379
+ pre_norm=True,
380
+ order_index=0,
381
+ cpe_indice_key=None,
382
+ enable_rpe=False,
383
+ enable_flash=True,
384
+ upcast_attention=True,
385
+ upcast_softmax=True,
386
+ rope_base=10,
387
+ shift_coords=None,
388
+ jitter_coords=None,
389
+ rescale_coords=None,
390
+ ):
391
+ super().__init__()
392
+ self.channels = channels
393
+ self.pre_norm = pre_norm
394
+
395
+ self.cpe = PointSequential(
396
+ spconv.SubMConv3d(
397
+ channels,
398
+ channels,
399
+ kernel_size=3,
400
+ bias=True,
401
+ indice_key=cpe_indice_key,
402
+ ),
403
+ nn.Linear(channels, channels),
404
+ norm_layer(channels),
405
+ )
406
+
407
+ self.norm1 = PointSequential(norm_layer(channels))
408
+ self.ls1 = PointSequential(
409
+ LayerScale(channels, init_values=layer_scale)
410
+ if layer_scale is not None
411
+ else nn.Identity()
412
+ )
413
+ self.attn = SerializedAttention(
414
+ channels=channels,
415
+ patch_size=patch_size,
416
+ num_heads=num_heads,
417
+ qkv_bias=qkv_bias,
418
+ qk_scale=qk_scale,
419
+ attn_drop=attn_drop,
420
+ proj_drop=proj_drop,
421
+ order_index=order_index,
422
+ enable_rpe=enable_rpe,
423
+ enable_flash=enable_flash,
424
+ upcast_attention=upcast_attention,
425
+ upcast_softmax=upcast_softmax,
426
+ rope_base=rope_base,
427
+ shift_coords=shift_coords,
428
+ jitter_coords=jitter_coords,
429
+ rescale_coords=rescale_coords,
430
+ )
431
+ self.norm2 = PointSequential(norm_layer(channels))
432
+ self.ls2 = PointSequential(
433
+ LayerScale(channels, init_values=layer_scale)
434
+ if layer_scale is not None
435
+ else nn.Identity()
436
+ )
437
+ self.mlp = PointSequential(
438
+ MLP(
439
+ in_channels=channels,
440
+ hidden_channels=int(channels * mlp_ratio),
441
+ out_channels=channels,
442
+ act_layer=act_layer,
443
+ drop=proj_drop,
444
+ )
445
+ )
446
+ self.drop_path = PointSequential(
447
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
448
+ )
449
+
450
+ def forward(self, point: Point):
451
+ shortcut = point.feat
452
+ point = self.cpe(point)
453
+ point.feat = shortcut + point.feat
454
+ shortcut = point.feat
455
+ if self.pre_norm:
456
+ point = self.norm1(point)
457
+ point = self.drop_path(self.ls1(self.attn(point)))
458
+ point.feat = shortcut + point.feat
459
+ if not self.pre_norm:
460
+ point = self.norm1(point)
461
+
462
+ shortcut = point.feat
463
+ if self.pre_norm:
464
+ point = self.norm2(point)
465
+ point = self.drop_path(self.ls2(self.mlp(point)))
466
+ point.feat = shortcut + point.feat
467
+ if not self.pre_norm:
468
+ point = self.norm2(point)
469
+ point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
470
+ return point
471
+
472
+
473
+ class GridPooling(PointModule):
474
+ def __init__(
475
+ self,
476
+ in_channels,
477
+ out_channels,
478
+ stride=2,
479
+ norm_layer=None,
480
+ act_layer=None,
481
+ reduce="max",
482
+ shuffle_orders=True,
483
+ traceable=True, # record parent and cluster
484
+ ):
485
+ super().__init__()
486
+ self.in_channels = in_channels
487
+ self.out_channels = out_channels
488
+
489
+ self.stride = stride
490
+ assert reduce in ["sum", "mean", "min", "max"]
491
+ self.reduce = reduce
492
+ self.shuffle_orders = shuffle_orders
493
+ self.traceable = traceable
494
+
495
+ self.proj = nn.Linear(in_channels, out_channels)
496
+ if norm_layer is not None:
497
+ self.norm = PointSequential(norm_layer(out_channels))
498
+ if act_layer is not None:
499
+ self.act = PointSequential(act_layer())
500
+
501
+ def forward(self, point: Point):
502
+ if "grid_coord" in point.keys():
503
+ grid_coord = point.grid_coord
504
+ elif {"coord", "grid_size"}.issubset(point.keys()):
505
+ grid_coord = torch.div(
506
+ point.coord - point.coord.min(0)[0],
507
+ point.grid_size,
508
+ rounding_mode="trunc",
509
+ ).int()
510
+ else:
511
+ raise AssertionError(
512
+ "[gird_coord] or [coord, grid_size] should be include in the Point"
513
+ )
514
+ grid_coord = torch.div(grid_coord, self.stride, rounding_mode="trunc")
515
+ grid_coord = grid_coord | point.batch.view(-1, 1) << 48
516
+ grid_coord, cluster, counts = torch.unique(
517
+ grid_coord,
518
+ sorted=True,
519
+ return_inverse=True,
520
+ return_counts=True,
521
+ dim=0,
522
+ )
523
+ grid_coord = grid_coord & ((1 << 48) - 1)
524
+ # indices of point sorted by cluster, for torch_scatter.segment_csr
525
+ _, indices = torch.sort(cluster)
526
+ # index pointer for sorted point, for torch_scatter.segment_csr
527
+ idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
528
+ # head_indices of each cluster, for reduce attr e.g. code, batch
529
+ head_indices = indices[idx_ptr[:-1]]
530
+ point_dict = Dict(
531
+ feat=torch_scatter.segment_csr(
532
+ self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
533
+ ),
534
+ coord=torch_scatter.segment_csr(
535
+ point.coord[indices], idx_ptr, reduce="mean"
536
+ ),
537
+ grid_coord=grid_coord,
538
+ batch=point.batch[head_indices],
539
+ )
540
+ if "origin_coord" in point.keys():
541
+ point_dict["origin_coord"] = torch_scatter.segment_csr(
542
+ point.origin_coord[indices], idx_ptr, reduce="mean"
543
+ )
544
+ if "condition" in point.keys():
545
+ point_dict["condition"] = point.condition
546
+ if "context" in point.keys():
547
+ point_dict["context"] = point.context
548
+ if "name" in point.keys():
549
+ point_dict["name"] = point.name
550
+ if "split" in point.keys():
551
+ point_dict["split"] = point.split
552
+ if "color" in point.keys():
553
+ point_dict["color"] = torch_scatter.segment_csr(
554
+ point.color[indices], idx_ptr, reduce="mean"
555
+ )
556
+ if "grid_size" in point.keys():
557
+ point_dict["grid_size"] = point.grid_size * self.stride
558
+
559
+ if self.traceable:
560
+ point_dict["pooling_inverse"] = cluster
561
+ point_dict["pooling_parent"] = point
562
+ point_dict["idx_ptr"] = idx_ptr
563
+ order = point.order
564
+ point = Point(point_dict)
565
+ if self.norm is not None:
566
+ point = self.norm(point)
567
+ if self.act is not None:
568
+ point = self.act(point)
569
+ point.serialization(order=order, shuffle_orders=self.shuffle_orders)
570
+ point.sparsify()
571
+ return point
572
+
573
+
574
+ class GridUnpooling(PointModule):
575
+ def __init__(
576
+ self,
577
+ in_channels,
578
+ skip_channels,
579
+ out_channels,
580
+ norm_layer=None,
581
+ act_layer=None,
582
+ traceable=False, # record parent and cluster
583
+ ):
584
+ super().__init__()
585
+ self.proj = PointSequential(nn.Linear(in_channels, out_channels))
586
+ self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))
587
+
588
+ if norm_layer is not None:
589
+ self.proj.add(norm_layer(out_channels))
590
+ self.proj_skip.add(norm_layer(out_channels))
591
+
592
+ if act_layer is not None:
593
+ self.proj.add(act_layer())
594
+ self.proj_skip.add(act_layer())
595
+
596
+ self.traceable = traceable
597
+
598
+ def forward(self, point):
599
+ assert "pooling_parent" in point.keys()
600
+ assert "pooling_inverse" in point.keys()
601
+ parent = point.pop("pooling_parent")
602
+ inverse = point.pooling_inverse
603
+ feat = point.feat
604
+
605
+ parent = self.proj_skip(parent)
606
+ parent.feat = parent.feat + self.proj(point).feat[inverse]
607
+ parent.sparse_conv_feat = parent.sparse_conv_feat.replace_feature(parent.feat)
608
+
609
+ if self.traceable:
610
+ point.feat = feat
611
+ parent["unpooling_parent"] = point
612
+ return parent
613
+
614
+
615
+ class Embedding(PointModule):
616
+ def __init__(
617
+ self,
618
+ in_channels,
619
+ embed_channels,
620
+ norm_layer=None,
621
+ act_layer=None,
622
+ mask_token=False,
623
+ ):
624
+ super().__init__()
625
+ self.in_channels = in_channels
626
+ self.embed_channels = embed_channels
627
+
628
+ self.stem = PointSequential(linear=nn.Linear(in_channels, embed_channels))
629
+ if norm_layer is not None:
630
+ self.stem.add(norm_layer(embed_channels), name="norm")
631
+ if act_layer is not None:
632
+ self.stem.add(act_layer(), name="act")
633
+
634
+ if mask_token:
635
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_channels))
636
+ else:
637
+ self.mask_token = None
638
+
639
+ def forward(self, point: Point):
640
+ point = self.stem(point)
641
+ if "mask" in point.keys():
642
+ point.feat = torch.where(
643
+ point.mask.unsqueeze(-1),
644
+ self.mask_token.to(point.feat.dtype),
645
+ point.feat,
646
+ )
647
+ return point
648
+
649
+
650
+ class PointTransformerV3(PointModule, PyTorchModelHubMixin):
651
+ def __init__(
652
+ self,
653
+ in_channels=6,
654
+ order=("z", "z-trans"),
655
+ stride=(2, 2, 2, 2),
656
+ enc_depths=(3, 3, 3, 12, 3),
657
+ enc_channels=(48, 96, 192, 384, 512),
658
+ enc_num_head=(3, 6, 12, 24, 32),
659
+ enc_patch_size=(1024, 1024, 1024, 1024, 1024),
660
+ dec_depths=(3, 3, 3, 3),
661
+ dec_channels=(96, 96, 192, 384),
662
+ dec_num_head=(6, 6, 12, 32),
663
+ dec_patch_size=(1024, 1024, 1024, 1024),
664
+ mlp_ratio=4,
665
+ qkv_bias=True,
666
+ qk_scale=None,
667
+ attn_drop=0.0,
668
+ proj_drop=0.0,
669
+ drop_path=0.3,
670
+ layer_scale=None,
671
+ pre_norm=True,
672
+ shuffle_orders=True,
673
+ enable_rpe=False,
674
+ enable_flash=True,
675
+ upcast_attention=False,
676
+ upcast_softmax=False,
677
+ traceable=False,
678
+ mask_token=False,
679
+ enc_mode=False,
680
+ freeze_encoder=False,
681
+ rope_base=10,
682
+ shift_coords=None,
683
+ jitter_coords=None,
684
+ rescale_coords=None,
685
+ ):
686
+ super().__init__()
687
+ self.num_stages = len(enc_depths)
688
+ self.order = [order] if isinstance(order, str) else order
689
+ self.enc_mode = enc_mode
690
+ self.shuffle_orders = shuffle_orders
691
+ self.freeze_encoder = freeze_encoder
692
+
693
+ assert self.num_stages == len(stride) + 1
694
+ assert self.num_stages == len(enc_depths)
695
+ assert self.num_stages == len(enc_channels)
696
+ assert self.num_stages == len(enc_num_head)
697
+ assert self.num_stages == len(enc_patch_size)
698
+ assert self.enc_mode or self.num_stages == len(dec_depths) + 1
699
+ assert self.enc_mode or self.num_stages == len(dec_channels) + 1
700
+ assert self.enc_mode or self.num_stages == len(dec_num_head) + 1
701
+ assert self.enc_mode or self.num_stages == len(dec_patch_size) + 1
702
+
703
+ # normalization layer
704
+ ln_layer = nn.LayerNorm
705
+ # activation layers
706
+ act_layer = nn.GELU
707
+
708
+ self.embedding = Embedding(
709
+ in_channels=in_channels,
710
+ embed_channels=enc_channels[0],
711
+ norm_layer=ln_layer,
712
+ act_layer=act_layer,
713
+ mask_token=mask_token,
714
+ )
715
+
716
+ # encoder
717
+ enc_drop_path = [
718
+ x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
719
+ ]
720
+ self.enc = PointSequential()
721
+ for s in range(self.num_stages):
722
+ enc_drop_path_ = enc_drop_path[
723
+ sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
724
+ ]
725
+ enc = PointSequential()
726
+ if s > 0:
727
+ enc.add(
728
+ GridPooling(
729
+ in_channels=enc_channels[s - 1],
730
+ out_channels=enc_channels[s],
731
+ stride=stride[s - 1],
732
+ norm_layer=ln_layer,
733
+ act_layer=act_layer,
734
+ ),
735
+ name="down",
736
+ )
737
+ for i in range(enc_depths[s]):
738
+ enc.add(
739
+ Block(
740
+ channels=enc_channels[s],
741
+ num_heads=enc_num_head[s],
742
+ patch_size=enc_patch_size[s],
743
+ mlp_ratio=mlp_ratio,
744
+ qkv_bias=qkv_bias,
745
+ qk_scale=qk_scale,
746
+ attn_drop=attn_drop,
747
+ proj_drop=proj_drop,
748
+ drop_path=enc_drop_path_[i],
749
+ layer_scale=layer_scale,
750
+ norm_layer=ln_layer,
751
+ act_layer=act_layer,
752
+ pre_norm=pre_norm,
753
+ order_index=i % len(self.order),
754
+ cpe_indice_key=f"stage{s}",
755
+ enable_rpe=enable_rpe,
756
+ enable_flash=enable_flash,
757
+ upcast_attention=upcast_attention,
758
+ upcast_softmax=upcast_softmax,
759
+ rope_base=rope_base,
760
+ shift_coords=shift_coords,
761
+ jitter_coords=jitter_coords,
762
+ rescale_coords=rescale_coords,
763
+ ),
764
+ name=f"block{i}",
765
+ )
766
+ if len(enc) != 0:
767
+ self.enc.add(module=enc, name=f"enc{s}")
768
+
769
+ # decoder
770
+ if not self.enc_mode:
771
+ dec_drop_path = [
772
+ x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
773
+ ]
774
+ self.dec = PointSequential()
775
+ dec_channels = list(dec_channels) + [enc_channels[-1]]
776
+ for s in reversed(range(self.num_stages - 1)):
777
+ dec_drop_path_ = dec_drop_path[
778
+ sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
779
+ ]
780
+ dec_drop_path_.reverse()
781
+ dec = PointSequential()
782
+ dec.add(
783
+ GridUnpooling(
784
+ in_channels=dec_channels[s + 1],
785
+ skip_channels=enc_channels[s],
786
+ out_channels=dec_channels[s],
787
+ norm_layer=ln_layer,
788
+ act_layer=act_layer,
789
+ traceable=traceable,
790
+ ),
791
+ name="up",
792
+ )
793
+ for i in range(dec_depths[s]):
794
+ dec.add(
795
+ Block(
796
+ channels=dec_channels[s],
797
+ num_heads=dec_num_head[s],
798
+ patch_size=dec_patch_size[s],
799
+ mlp_ratio=mlp_ratio,
800
+ qkv_bias=qkv_bias,
801
+ qk_scale=qk_scale,
802
+ attn_drop=attn_drop,
803
+ proj_drop=proj_drop,
804
+ drop_path=dec_drop_path_[i],
805
+ layer_scale=layer_scale,
806
+ norm_layer=ln_layer,
807
+ act_layer=act_layer,
808
+ pre_norm=pre_norm,
809
+ order_index=i % len(self.order),
810
+ cpe_indice_key=f"stage{s}",
811
+ enable_rpe=enable_rpe,
812
+ enable_flash=enable_flash,
813
+ upcast_attention=upcast_attention,
814
+ upcast_softmax=upcast_softmax,
815
+ rope_base=rope_base,
816
+ shift_coords=shift_coords,
817
+ jitter_coords=jitter_coords,
818
+ rescale_coords=rescale_coords,
819
+ ),
820
+ name=f"block{i}",
821
+ )
822
+ self.dec.add(module=dec, name=f"dec{s}")
823
+ if self.freeze_encoder:
824
+ for p in self.embedding.parameters():
825
+ p.requires_grad = False
826
+ for p in self.enc.parameters():
827
+ p.requires_grad = False
828
+ self.apply(self._init_weights)
829
+
830
+ @staticmethod
831
+ def _init_weights(module):
832
+ if isinstance(module, nn.Linear):
833
+ trunc_normal_(module.weight, std=0.02)
834
+ if module.bias is not None:
835
+ nn.init.zeros_(module.bias)
836
+ elif isinstance(module, spconv.SubMConv3d):
837
+ trunc_normal_(module.weight, std=0.02)
838
+ if module.bias is not None:
839
+ nn.init.zeros_(module.bias)
840
+
841
+ def forward(self, data_dict):
842
+ point = Point(data_dict)
843
+ point = self.embedding(point)
844
+
845
+ point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
846
+ point.sparsify()
847
+
848
+ point = self.enc(point)
849
+ if not self.enc_mode:
850
+ point = self.dec(point)
851
+ return point
852
+
853
+
854
+ def load(
855
+ name: str = "utonia",
856
+ repo_id="Pointcept/utonia",
857
+ download_root: str = None,
858
+ custom_config: dict = None,
859
+ ckpt_only: bool = False,
860
+ ):
861
+ if name in MODELS:
862
+ print(f"Loading checkpoint from HuggingFace: {name} ...")
863
+ ckpt_path = hf_hub_download(
864
+ repo_id=repo_id,
865
+ filename=f"{name}.pth",
866
+ repo_type="model",
867
+ revision="main",
868
+ local_dir=download_root or os.path.expanduser("~/.cache/utonia/ckpt"),
869
+ )
870
+ elif os.path.isfile(name):
871
+ print(f"Loading checkpoint in local path: {name} ...")
872
+ ckpt_path = name
873
+ else:
874
+ raise RuntimeError(f"Model {name} not found; available models = {MODELS}")
875
+
876
+ if version.parse(torch.__version__) >= version.parse("2.4"):
877
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
878
+ else:
879
+ ckpt = torch.load(ckpt_path, map_location="cpu")
880
+ if custom_config is not None:
881
+ for key, value in custom_config.items():
882
+ ckpt["config"][key] = value
883
+
884
+ if ckpt_only:
885
+ return ckpt
886
+ ckpt["config"]["drop_path"] = 0.0
887
+ model = PointTransformerV3(**ckpt["config"])
888
+ model.load_state_dict(ckpt["state_dict"])
889
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
890
+ print(f"Model params: {n_parameters / 1e6:.2f}M")
891
+ return model
utonia/module.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Point Modules
3
+ Pointcept detached version
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ import sys
25
+ import torch.nn as nn
26
+ import spconv.pytorch as spconv
27
+ from collections import OrderedDict
28
+
29
+ from .structure import Point
30
+
31
+
32
+ class PointModule(nn.Module):
33
+ r"""PointModule
34
+ placeholder, all module subclass from this will take Point in PointSequential.
35
+ """
36
+
37
+ def __init__(self, *args, **kwargs):
38
+ super().__init__(*args, **kwargs)
39
+
40
+
41
+ class PointSequential(PointModule):
42
+ r"""A sequential container.
43
+ Modules will be added to it in the order they are passed in the constructor.
44
+ Alternatively, an ordered dict of modules can also be passed in.
45
+ """
46
+
47
+ def __init__(self, *args, **kwargs):
48
+ super().__init__()
49
+ if len(args) == 1 and isinstance(args[0], OrderedDict):
50
+ for key, module in args[0].items():
51
+ self.add_module(key, module)
52
+ else:
53
+ for idx, module in enumerate(args):
54
+ self.add_module(str(idx), module)
55
+ for name, module in kwargs.items():
56
+ if sys.version_info < (3, 6):
57
+ raise ValueError("kwargs only supported in py36+")
58
+ if name in self._modules:
59
+ raise ValueError("name exists.")
60
+ self.add_module(name, module)
61
+
62
+ def __getitem__(self, idx):
63
+ if not (-len(self) <= idx < len(self)):
64
+ raise IndexError("index {} is out of range".format(idx))
65
+ if idx < 0:
66
+ idx += len(self)
67
+ it = iter(self._modules.values())
68
+ for i in range(idx):
69
+ next(it)
70
+ return next(it)
71
+
72
+ def __len__(self):
73
+ return len(self._modules)
74
+
75
+ def add(self, module, name=None):
76
+ if name is None:
77
+ name = str(len(self._modules))
78
+ if name in self._modules:
79
+ raise KeyError("name exists")
80
+ self.add_module(name, module)
81
+
82
+ def forward(self, input):
83
+ for k, module in self._modules.items():
84
+ # Point module
85
+ if isinstance(module, PointModule):
86
+ input = module(input)
87
+ # Spconv module
88
+ elif spconv.modules.is_spconv_module(module):
89
+ if isinstance(input, Point):
90
+ input.sparse_conv_feat = module(input.sparse_conv_feat)
91
+ input.feat = input.sparse_conv_feat.features
92
+ else:
93
+ input = module(input)
94
+ # PyTorch module
95
+ else:
96
+ if isinstance(input, Point):
97
+ input.feat = module(input.feat)
98
+ if "sparse_conv_feat" in input.keys():
99
+ input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(
100
+ input.feat
101
+ )
102
+ elif isinstance(input, spconv.SparseConvTensor):
103
+ if input.indices.shape[0] != 0:
104
+ input = input.replace_feature(module(input.features))
105
+ else:
106
+ input = module(input)
107
+ return input
utonia/registry.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @lint-ignore-every LICENSELINT
2
+ # Copyright (c) OpenMMLab. All rights reserved.
3
+ import inspect
4
+ import warnings
5
+ from functools import partial
6
+ from collections import abc
7
+
8
+
9
+ def is_seq_of(seq, expected_type, seq_type=None):
10
+ """Check whether it is a sequence of some type.
11
+
12
+ Args:
13
+ seq (Sequence): The sequence to be checked.
14
+ expected_type (type): Expected type of sequence items.
15
+ seq_type (type, optional): Expected sequence type.
16
+
17
+ Returns:
18
+ bool: Whether the sequence is valid.
19
+ """
20
+ if seq_type is None:
21
+ exp_seq_type = abc.Sequence
22
+ else:
23
+ assert isinstance(seq_type, type)
24
+ exp_seq_type = seq_type
25
+ if not isinstance(seq, exp_seq_type):
26
+ return False
27
+ for item in seq:
28
+ if not isinstance(item, expected_type):
29
+ return False
30
+ return True
31
+
32
+
33
+ def build_from_cfg(cfg, registry, default_args=None):
34
+ """Build a module from configs dict.
35
+
36
+ Args:
37
+ cfg (dict): Config dict. It should at least contain the key "type".
38
+ registry (:obj:`Registry`): The registry to search the type from.
39
+ default_args (dict, optional): Default initialization arguments.
40
+
41
+ Returns:
42
+ object: The constructed object.
43
+ """
44
+ if not isinstance(cfg, dict):
45
+ raise TypeError(f"cfg must be a dict, but got {type(cfg)}")
46
+ if "type" not in cfg:
47
+ if default_args is None or "type" not in default_args:
48
+ raise KeyError(
49
+ '`cfg` or `default_args` must contain the key "type", '
50
+ f"but got {cfg}\n{default_args}"
51
+ )
52
+ if not isinstance(registry, Registry):
53
+ raise TypeError(
54
+ "registry must be an mmcv.Registry object, " f"but got {type(registry)}"
55
+ )
56
+ if not (isinstance(default_args, dict) or default_args is None):
57
+ raise TypeError(
58
+ "default_args must be a dict or None, " f"but got {type(default_args)}"
59
+ )
60
+
61
+ args = cfg.copy()
62
+
63
+ if default_args is not None:
64
+ for name, value in default_args.items():
65
+ args.setdefault(name, value)
66
+
67
+ obj_type = args.pop("type")
68
+ if isinstance(obj_type, str):
69
+ obj_cls = registry.get(obj_type)
70
+ if obj_cls is None:
71
+ raise KeyError(f"{obj_type} is not in the {registry.name} registry")
72
+ elif inspect.isclass(obj_type):
73
+ obj_cls = obj_type
74
+ else:
75
+ raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
76
+ try:
77
+ return obj_cls(**args)
78
+ except Exception as e:
79
+ # Normal TypeError does not print class name.
80
+ raise type(e)(f"{obj_cls.__name__}: {e}")
81
+
82
+
83
+ class Registry:
84
+ """A registry to map strings to classes.
85
+
86
+ Registered object could be built from registry.
87
+ Example:
88
+ >>> MODELS = Registry('models')
89
+ >>> @MODELS.register_module()
90
+ >>> class ResNet:
91
+ >>> pass
92
+ >>> resnet = MODELS.build(dict(type='ResNet'))
93
+
94
+ Please refer to
95
+ https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
96
+ advanced usage.
97
+
98
+ Args:
99
+ name (str): Registry name.
100
+ build_func(func, optional): Build function to construct instance from
101
+ Registry, func:`build_from_cfg` is used if neither ``parent`` or
102
+ ``build_func`` is specified. If ``parent`` is specified and
103
+ ``build_func`` is not given, ``build_func`` will be inherited
104
+ from ``parent``. Default: None.
105
+ parent (Registry, optional): Parent registry. The class registered in
106
+ children registry could be built from parent. Default: None.
107
+ scope (str, optional): The scope of registry. It is the key to search
108
+ for children registry. If not specified, scope will be the name of
109
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
110
+ Default: None.
111
+ """
112
+
113
+ def __init__(self, name, build_func=None, parent=None, scope=None):
114
+ self._name = name
115
+ self._module_dict = dict()
116
+ self._children = dict()
117
+ self._scope = self.infer_scope() if scope is None else scope
118
+
119
+ # self.build_func will be set with the following priority:
120
+ # 1. build_func
121
+ # 2. parent.build_func
122
+ # 3. build_from_cfg
123
+ if build_func is None:
124
+ if parent is not None:
125
+ self.build_func = parent.build_func
126
+ else:
127
+ self.build_func = build_from_cfg
128
+ else:
129
+ self.build_func = build_func
130
+ if parent is not None:
131
+ assert isinstance(parent, Registry)
132
+ parent._add_children(self)
133
+ self.parent = parent
134
+ else:
135
+ self.parent = None
136
+
137
+ def __len__(self):
138
+ return len(self._module_dict)
139
+
140
+ def __contains__(self, key):
141
+ return self.get(key) is not None
142
+
143
+ def __repr__(self):
144
+ format_str = (
145
+ self.__class__.__name__ + f"(name={self._name}, "
146
+ f"items={self._module_dict})"
147
+ )
148
+ return format_str
149
+
150
+ @staticmethod
151
+ def infer_scope():
152
+ """Infer the scope of registry.
153
+
154
+ The name of the package where registry is defined will be returned.
155
+
156
+ Example:
157
+ # in mmdet/models/backbone/resnet.py
158
+ >>> MODELS = Registry('models')
159
+ >>> @MODELS.register_module()
160
+ >>> class ResNet:
161
+ >>> pass
162
+ The scope of ``ResNet`` will be ``mmdet``.
163
+
164
+
165
+ Returns:
166
+ scope (str): The inferred scope name.
167
+ """
168
+ # inspect.stack() trace where this function is called, the index-2
169
+ # indicates the frame where `infer_scope()` is called
170
+ filename = inspect.getmodule(inspect.stack()[2][0]).__name__
171
+ split_filename = filename.split(".")
172
+ return split_filename[0]
173
+
174
+ @staticmethod
175
+ def split_scope_key(key):
176
+ """Split scope and key.
177
+
178
+ The first scope will be split from key.
179
+
180
+ Examples:
181
+ >>> Registry.split_scope_key('mmdet.ResNet')
182
+ 'mmdet', 'ResNet'
183
+ >>> Registry.split_scope_key('ResNet')
184
+ None, 'ResNet'
185
+
186
+ Return:
187
+ scope (str, None): The first scope.
188
+ key (str): The remaining key.
189
+ """
190
+ split_index = key.find(".")
191
+ if split_index != -1:
192
+ return key[:split_index], key[split_index + 1 :]
193
+ else:
194
+ return None, key
195
+
196
+ @property
197
+ def name(self):
198
+ return self._name
199
+
200
+ @property
201
+ def scope(self):
202
+ return self._scope
203
+
204
+ @property
205
+ def module_dict(self):
206
+ return self._module_dict
207
+
208
+ @property
209
+ def children(self):
210
+ return self._children
211
+
212
+ def get(self, key):
213
+ """Get the registry record.
214
+
215
+ Args:
216
+ key (str): The class name in string format.
217
+
218
+ Returns:
219
+ class: The corresponding class.
220
+ """
221
+ scope, real_key = self.split_scope_key(key)
222
+ if scope is None or scope == self._scope:
223
+ # get from self
224
+ if real_key in self._module_dict:
225
+ return self._module_dict[real_key]
226
+ else:
227
+ # get from self._children
228
+ if scope in self._children:
229
+ return self._children[scope].get(real_key)
230
+ else:
231
+ # goto root
232
+ parent = self.parent
233
+ while parent.parent is not None:
234
+ parent = parent.parent
235
+ return parent.get(key)
236
+
237
+ def build(self, *args, **kwargs):
238
+ return self.build_func(*args, **kwargs, registry=self)
239
+
240
+ def _add_children(self, registry):
241
+ """Add children for a registry.
242
+
243
+ The ``registry`` will be added as children based on its scope.
244
+ The parent registry could build objects from children registry.
245
+
246
+ Example:
247
+ >>> models = Registry('models')
248
+ >>> mmdet_models = Registry('models', parent=models)
249
+ >>> @mmdet_models.register_module()
250
+ >>> class ResNet:
251
+ >>> pass
252
+ >>> resnet = models.build(dict(type='mmdet.ResNet'))
253
+ """
254
+
255
+ assert isinstance(registry, Registry)
256
+ assert registry.scope is not None
257
+ assert (
258
+ registry.scope not in self.children
259
+ ), f"scope {registry.scope} exists in {self.name} registry"
260
+ self.children[registry.scope] = registry
261
+
262
+ def _register_module(self, module_class, module_name=None, force=False):
263
+ if not inspect.isclass(module_class):
264
+ raise TypeError("module must be a class, " f"but got {type(module_class)}")
265
+
266
+ if module_name is None:
267
+ module_name = module_class.__name__
268
+ if isinstance(module_name, str):
269
+ module_name = [module_name]
270
+ for name in module_name:
271
+ if not force and name in self._module_dict:
272
+ raise KeyError(f"{name} is already registered " f"in {self.name}")
273
+ self._module_dict[name] = module_class
274
+
275
+ def deprecated_register_module(self, cls=None, force=False):
276
+ warnings.warn(
277
+ "The old API of register_module(module, force=False) "
278
+ "is deprecated and will be removed, please use the new API "
279
+ "register_module(name=None, force=False, module=None) instead."
280
+ )
281
+ if cls is None:
282
+ return partial(self.deprecated_register_module, force=force)
283
+ self._register_module(cls, force=force)
284
+ return cls
285
+
286
+ def register_module(self, name=None, force=False, module=None):
287
+ """Register a module.
288
+
289
+ A record will be added to `self._module_dict`, whose key is the class
290
+ name or the specified name, and value is the class itself.
291
+ It can be used as a decorator or a normal function.
292
+
293
+ Example:
294
+ >>> backbones = Registry('backbone')
295
+ >>> @backbones.register_module()
296
+ >>> class ResNet:
297
+ >>> pass
298
+
299
+ >>> backbones = Registry('backbone')
300
+ >>> @backbones.register_module(name='mnet')
301
+ >>> class MobileNet:
302
+ >>> pass
303
+
304
+ >>> backbones = Registry('backbone')
305
+ >>> class ResNet:
306
+ >>> pass
307
+ >>> backbones.register_module(ResNet)
308
+
309
+ Args:
310
+ name (str | None): The module name to be registered. If not
311
+ specified, the class name will be used.
312
+ force (bool, optional): Whether to override an existing class with
313
+ the same name. Default: False.
314
+ module (type): Module class to be registered.
315
+ """
316
+ if not isinstance(force, bool):
317
+ raise TypeError(f"force must be a boolean, but got {type(force)}")
318
+ # NOTE: This is a walkaround to be compatible with the old api,
319
+ # while it may introduce unexpected bugs.
320
+ if isinstance(name, type):
321
+ return self.deprecated_register_module(name, force=force)
322
+
323
+ # raise the error ahead of time
324
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
325
+ raise TypeError(
326
+ "name must be either of None, an instance of str or a sequence"
327
+ f" of str, but got {type(name)}"
328
+ )
329
+
330
+ # use it as a normal method: x.register_module(module=SomeClass)
331
+ if module is not None:
332
+ self._register_module(module_class=module, module_name=name, force=force)
333
+ return module
334
+
335
+ # use it as a decorator: @x.register_module()
336
+ def _register(cls):
337
+ self._register_module(module_class=cls, module_name=name, force=force)
338
+ return cls
339
+
340
+ return _register
utonia/serialization/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .default import (
2
+ encode,
3
+ decode,
4
+ z_order_encode,
5
+ z_order_decode,
6
+ hilbert_encode,
7
+ hilbert_decode,
8
+ )
utonia/serialization/default.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Serialization Encoding
3
+ Pointcept detached version
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ import torch
25
+ from .z_order import xyz2key as z_order_encode_
26
+ from .z_order import key2xyz as z_order_decode_
27
+ from .hilbert import encode as hilbert_encode_
28
+ from .hilbert import decode as hilbert_decode_
29
+
30
+
31
+ @torch.inference_mode()
32
+ def encode(grid_coord, batch=None, depth=16, order="z"):
33
+ assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
34
+ if order == "z":
35
+ code = z_order_encode(grid_coord, depth=depth)
36
+ elif order == "z-trans":
37
+ code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
38
+ elif order == "hilbert":
39
+ code = hilbert_encode(grid_coord, depth=depth)
40
+ elif order == "hilbert-trans":
41
+ code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
42
+ else:
43
+ raise NotImplementedError
44
+ if batch is not None:
45
+ batch = batch.long()
46
+ code = batch << depth * 3 | code
47
+ return code
48
+
49
+
50
+ @torch.inference_mode()
51
+ def decode(code, depth=16, order="z"):
52
+ assert order in {"z", "hilbert"}
53
+ batch = code >> depth * 3
54
+ code = code & ((1 << depth * 3) - 1)
55
+ if order == "z":
56
+ grid_coord = z_order_decode(code, depth=depth)
57
+ elif order == "hilbert":
58
+ grid_coord = hilbert_decode(code, depth=depth)
59
+ else:
60
+ raise NotImplementedError
61
+ return grid_coord, batch
62
+
63
+
64
+ def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
65
+ x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
66
+ # we block the support to batch, maintain batched code in Point class
67
+ code = z_order_encode_(x, y, z, b=None, depth=depth)
68
+ return code
69
+
70
+
71
+ def z_order_decode(code: torch.Tensor, depth):
72
+ x, y, z = z_order_decode_(code, depth=depth)
73
+ grid_coord = torch.stack([x, y, z], dim=-1) # (N, 3)
74
+ return grid_coord
75
+
76
+
77
+ def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
78
+ return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)
79
+
80
+
81
+ def hilbert_decode(code: torch.Tensor, depth: int = 16):
82
+ return hilbert_decode_(code, num_dims=3, num_bits=depth)
utonia/serialization/hilbert.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hilbert Order
3
+ Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+
24
+ import torch
25
+
26
+
27
+ def right_shift(binary, k=1, axis=-1):
28
+ """Right shift an array of binary values.
29
+
30
+ Parameters:
31
+ -----------
32
+ binary: An ndarray of binary values.
33
+
34
+ k: The number of bits to shift. Default 1.
35
+
36
+ axis: The axis along which to shift. Default -1.
37
+
38
+ Returns:
39
+ --------
40
+ Returns an ndarray with zero prepended and the ends truncated, along
41
+ whatever axis was specified."""
42
+
43
+ # If we're shifting the whole thing, just return zeros.
44
+ if binary.shape[axis] <= k:
45
+ return torch.zeros_like(binary)
46
+
47
+ # Determine the padding pattern.
48
+ # padding = [(0,0)] * len(binary.shape)
49
+ # padding[axis] = (k,0)
50
+
51
+ # Determine the slicing pattern to eliminate just the last one.
52
+ slicing = [slice(None)] * len(binary.shape)
53
+ slicing[axis] = slice(None, -k)
54
+ shifted = torch.nn.functional.pad(
55
+ binary[tuple(slicing)], (k, 0), mode="constant", value=0
56
+ )
57
+
58
+ return shifted
59
+
60
+
61
+ def binary2gray(binary, axis=-1):
62
+ """Convert an array of binary values into Gray codes.
63
+
64
+ This uses the classic X ^ (X >> 1) trick to compute the Gray code.
65
+
66
+ Parameters:
67
+ -----------
68
+ binary: An ndarray of binary values.
69
+
70
+ axis: The axis along which to compute the gray code. Default=-1.
71
+
72
+ Returns:
73
+ --------
74
+ Returns an ndarray of Gray codes.
75
+ """
76
+ shifted = right_shift(binary, axis=axis)
77
+
78
+ # Do the X ^ (X >> 1) trick.
79
+ gray = torch.logical_xor(binary, shifted)
80
+
81
+ return gray
82
+
83
+
84
+ def gray2binary(gray, axis=-1):
85
+ """Convert an array of Gray codes back into binary values.
86
+
87
+ Parameters:
88
+ -----------
89
+ gray: An ndarray of gray codes.
90
+
91
+ axis: The axis along which to perform Gray decoding. Default=-1.
92
+
93
+ Returns:
94
+ --------
95
+ Returns an ndarray of binary values.
96
+ """
97
+
98
+ # Loop the log2(bits) number of times necessary, with shift and xor.
99
+ shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
100
+ while shift > 0:
101
+ gray = torch.logical_xor(gray, right_shift(gray, shift))
102
+ shift = torch.div(shift, 2, rounding_mode="floor")
103
+ return gray
104
+
105
+
106
+ def encode(locs, num_dims, num_bits):
107
+ """Decode an array of locations in a hypercube into a Hilbert integer.
108
+
109
+ This is a vectorized-ish version of the Hilbert curve implementation by John
110
+ Skilling as described in:
111
+
112
+ Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
113
+ Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
114
+
115
+ Params:
116
+ -------
117
+ locs - An ndarray of locations in a hypercube of num_dims dimensions, in
118
+ which each dimension runs from 0 to 2**num_bits-1. The shape can
119
+ be arbitrary, as long as the last dimension of the same has size
120
+ num_dims.
121
+
122
+ num_dims - The dimensionality of the hypercube. Integer.
123
+
124
+ num_bits - The number of bits for each dimension. Integer.
125
+
126
+ Returns:
127
+ --------
128
+ The output is an ndarray of uint64 integers with the same shape as the
129
+ input, excluding the last dimension, which needs to be num_dims.
130
+ """
131
+
132
+ # Keep around the original shape for later.
133
+ orig_shape = locs.shape
134
+ bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
135
+ bitpack_mask_rev = bitpack_mask.flip(-1)
136
+
137
+ if orig_shape[-1] != num_dims:
138
+ raise ValueError("""
139
+ The shape of locs was surprising in that the last dimension was of size
140
+ %d, but num_dims=%d. These need to be equal.
141
+ """ % (orig_shape[-1], num_dims))
142
+
143
+ if num_dims * num_bits > 63:
144
+ raise ValueError("""
145
+ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
146
+ into a int64. Are you sure you need that many points on your Hilbert
147
+ curve?
148
+ """ % (num_dims, num_bits, num_dims * num_bits))
149
+
150
+ # Treat the location integers as 64-bit unsigned and then split them up into
151
+ # a sequence of uint8s. Preserve the association by dimension.
152
+ locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)
153
+
154
+ # Now turn these into bits and truncate to num_bits.
155
+ gray = (
156
+ locs_uint8.unsqueeze(-1)
157
+ .bitwise_and(bitpack_mask_rev)
158
+ .ne(0)
159
+ .byte()
160
+ .flatten(-2, -1)[..., -num_bits:]
161
+ )
162
+
163
+ # Run the decoding process the other way.
164
+ # Iterate forwards through the bits.
165
+ for bit in range(0, num_bits):
166
+ # Iterate forwards through the dimensions.
167
+ for dim in range(0, num_dims):
168
+ # Identify which ones have this bit active.
169
+ mask = gray[:, dim, bit]
170
+
171
+ # Where this bit is on, invert the 0 dimension for lower bits.
172
+ gray[:, 0, bit + 1 :] = torch.logical_xor(
173
+ gray[:, 0, bit + 1 :], mask[:, None]
174
+ )
175
+
176
+ # Where the bit is off, exchange the lower bits with the 0 dimension.
177
+ to_flip = torch.logical_and(
178
+ torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
179
+ torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
180
+ )
181
+ gray[:, dim, bit + 1 :] = torch.logical_xor(
182
+ gray[:, dim, bit + 1 :], to_flip
183
+ )
184
+ gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
185
+
186
+ # Now flatten out.
187
+ gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))
188
+
189
+ # Convert Gray back to binary.
190
+ hh_bin = gray2binary(gray)
191
+
192
+ # Pad back out to 64 bits.
193
+ extra_dims = 64 - num_bits * num_dims
194
+ padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)
195
+
196
+ # Convert binary values into uint8s.
197
+ hh_uint8 = (
198
+ (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
199
+ .sum(2)
200
+ .squeeze()
201
+ .type(torch.uint8)
202
+ )
203
+
204
+ # Convert uint8s into uint64s.
205
+ hh_uint64 = hh_uint8.view(torch.int64).squeeze()
206
+
207
+ return hh_uint64
208
+
209
+
210
+ def decode(hilberts, num_dims, num_bits):
211
+ """Decode an array of Hilbert integers into locations in a hypercube.
212
+
213
+ This is a vectorized-ish version of the Hilbert curve implementation by John
214
+ Skilling as described in:
215
+
216
+ Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
217
+ Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.
218
+
219
+ Params:
220
+ -------
221
+ hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
222
+ cannot have fewer bits than num_dims * num_bits.
223
+
224
+ num_dims - The dimensionality of the hypercube. Integer.
225
+
226
+ num_bits - The number of bits for each dimension. Integer.
227
+
228
+ Returns:
229
+ --------
230
+ The output is an ndarray of unsigned integers with the same shape as hilberts
231
+ but with an additional dimension of size num_dims.
232
+ """
233
+
234
+ if num_dims * num_bits > 64:
235
+ raise ValueError("""
236
+ num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
237
+ into a uint64. Are you sure you need that many points on your Hilbert
238
+ curve?
239
+ """ % (num_dims, num_bits))
240
+
241
+ # Handle the case where we got handed a naked integer.
242
+ hilberts = torch.atleast_1d(hilberts)
243
+
244
+ # Keep around the shape for later.
245
+ orig_shape = hilberts.shape
246
+ bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
247
+ bitpack_mask_rev = bitpack_mask.flip(-1)
248
+
249
+ # Treat each of the hilberts as a s equence of eight uint8.
250
+ # This treats all of the inputs as uint64 and makes things uniform.
251
+ hh_uint8 = (
252
+ hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
253
+ )
254
+
255
+ # Turn these lists of uints into lists of bits and then truncate to the size
256
+ # we actually need for using Skilling's procedure.
257
+ hh_bits = (
258
+ hh_uint8.unsqueeze(-1)
259
+ .bitwise_and(bitpack_mask_rev)
260
+ .ne(0)
261
+ .byte()
262
+ .flatten(-2, -1)[:, -num_dims * num_bits :]
263
+ )
264
+
265
+ # Take the sequence of bits and Gray-code it.
266
+ gray = binary2gray(hh_bits)
267
+
268
+ # There has got to be a better way to do this.
269
+ # I could index them differently, but the eventual packbits likes it this way.
270
+ gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)
271
+
272
+ # Iterate backwards through the bits.
273
+ for bit in range(num_bits - 1, -1, -1):
274
+ # Iterate backwards through the dimensions.
275
+ for dim in range(num_dims - 1, -1, -1):
276
+ # Identify which ones have this bit active.
277
+ mask = gray[:, dim, bit]
278
+
279
+ # Where this bit is on, invert the 0 dimension for lower bits.
280
+ gray[:, 0, bit + 1 :] = torch.logical_xor(
281
+ gray[:, 0, bit + 1 :], mask[:, None]
282
+ )
283
+
284
+ # Where the bit is off, exchange the lower bits with the 0 dimension.
285
+ to_flip = torch.logical_and(
286
+ torch.logical_not(mask[:, None]),
287
+ torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
288
+ )
289
+ gray[:, dim, bit + 1 :] = torch.logical_xor(
290
+ gray[:, dim, bit + 1 :], to_flip
291
+ )
292
+ gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)
293
+
294
+ # Pad back out to 64 bits.
295
+ extra_dims = 64 - num_bits
296
+ padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)
297
+
298
+ # Now chop these up into blocks of 8.
299
+ locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))
300
+
301
+ # Take those blocks and turn them unto uint8s.
302
+ # from IPython import embed; embed()
303
+ locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)
304
+
305
+ # Finally, treat these as uint64s.
306
+ flat_locs = locs_uint8.view(torch.int64)
307
+
308
+ # Return them in the expected shape.
309
+ return flat_locs.reshape((*orig_shape, num_dims))
utonia/serialization/z_order.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @lint-ignore-every LICENSELINT
2
+ # --------------------------------------------------------
3
+ # Octree-based Sparse Convolutional Neural Networks
4
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Written by Peng-Shuai Wang
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ from typing import Optional, Union
11
+
12
+
13
+ class KeyLUT:
14
+ def __init__(self):
15
+ r256 = torch.arange(256, dtype=torch.int64)
16
+ r512 = torch.arange(512, dtype=torch.int64)
17
+ zero = torch.zeros(256, dtype=torch.int64)
18
+ device = torch.device("cpu")
19
+
20
+ self._encode = {
21
+ device: (
22
+ self.xyz2key(r256, zero, zero, 8),
23
+ self.xyz2key(zero, r256, zero, 8),
24
+ self.xyz2key(zero, zero, r256, 8),
25
+ )
26
+ }
27
+ self._decode = {device: self.key2xyz(r512, 9)}
28
+
29
+ def encode_lut(self, device=torch.device("cpu")):
30
+ if device not in self._encode:
31
+ cpu = torch.device("cpu")
32
+ self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
33
+ return self._encode[device]
34
+
35
+ def decode_lut(self, device=torch.device("cpu")):
36
+ if device not in self._decode:
37
+ cpu = torch.device("cpu")
38
+ self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
39
+ return self._decode[device]
40
+
41
+ def xyz2key(self, x, y, z, depth):
42
+ key = torch.zeros_like(x)
43
+ for i in range(depth):
44
+ mask = 1 << i
45
+ key = (
46
+ key
47
+ | ((x & mask) << (2 * i + 2))
48
+ | ((y & mask) << (2 * i + 1))
49
+ | ((z & mask) << (2 * i + 0))
50
+ )
51
+ return key
52
+
53
+ def key2xyz(self, key, depth):
54
+ x = torch.zeros_like(key)
55
+ y = torch.zeros_like(key)
56
+ z = torch.zeros_like(key)
57
+ for i in range(depth):
58
+ x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
59
+ y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
60
+ z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
61
+ return x, y, z
62
+
63
+
64
+ _key_lut = KeyLUT()
65
+
66
+
67
+ def xyz2key(
68
+ x: torch.Tensor,
69
+ y: torch.Tensor,
70
+ z: torch.Tensor,
71
+ b: Optional[Union[torch.Tensor, int]] = None,
72
+ depth: int = 16,
73
+ ):
74
+ r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
75
+ based on pre-computed look up tables. The speed of this function is much
76
+ faster than the method based on for-loop.
77
+
78
+ Args:
79
+ x (torch.Tensor): The x coordinate.
80
+ y (torch.Tensor): The y coordinate.
81
+ z (torch.Tensor): The z coordinate.
82
+ b (torch.Tensor or int): The batch index of the coordinates, and should be
83
+ smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
84
+ :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
85
+ depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
86
+ """
87
+
88
+ EX, EY, EZ = _key_lut.encode_lut(x.device)
89
+ x, y, z = x.long(), y.long(), z.long()
90
+
91
+ mask = 255 if depth > 8 else (1 << depth) - 1
92
+ key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
93
+ if depth > 8:
94
+ mask = (1 << (depth - 8)) - 1
95
+ key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
96
+ key = key16 << 24 | key
97
+
98
+ if b is not None:
99
+ b = b.long()
100
+ key = b << 48 | key
101
+
102
+ return key
103
+
104
+
105
+ def key2xyz(key: torch.Tensor, depth: int = 16):
106
+ r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
107
+ and the batch index based on pre-computed look up tables.
108
+
109
+ Args:
110
+ key (torch.Tensor): The shuffled key.
111
+ depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
112
+ """
113
+
114
+ DX, DY, DZ = _key_lut.decode_lut(key.device)
115
+ x, y, z = torch.zeros_like(key), torch.zeros_like(key), torch.zeros_like(key)
116
+
117
+ b = key >> 48
118
+ key = key & ((1 << 48) - 1)
119
+
120
+ n = (depth + 2) // 3
121
+ for i in range(n):
122
+ k = key >> (i * 9) & 511
123
+ x = x | (DX[k] << (i * 3))
124
+ y = y | (DY[k] << (i * 3))
125
+ z = z | (DZ[k] << (i * 3))
126
+
127
+ return x, y, z, b
utonia/structure.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data structure for 3D point cloud
3
+
4
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
5
+ Please cite our work if the code is helpful to you.
6
+ """
7
+
8
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ import torch
24
+ import spconv.pytorch as spconv
25
+ from addict import Dict
26
+
27
+ from .serialization import encode
28
+ from .utils import offset2batch, batch2offset
29
+
30
+
31
+ class Point(Dict):
32
+ """
33
+ Point Structure of Pointcept
34
+
35
+ A Point (point cloud) in Pointcept is a dictionary that contains various properties of
36
+ a batched point cloud. The property with the following names have a specific definition
37
+ as follows:
38
+
39
+ - "coord": original coordinate of point cloud;
40
+ - "grid_coord": grid coordinate for specific grid size (related to GridSampling);
41
+ Point also support the following optional attributes:
42
+ - "offset": if not exist, initialized as batch size is 1;
43
+ - "batch": if not exist, initialized as batch size is 1;
44
+ - "feat": feature of point cloud, default input of model;
45
+ - "grid_size": Grid size of point cloud (related to GridSampling);
46
+ (related to Serialization)
47
+ - "serialized_depth": depth of serialization, 2 ** depth * grid_size describe the maximum of point cloud range;
48
+ - "serialized_code": a list of serialization codes;
49
+ - "serialized_order": a list of serialization order determined by code;
50
+ - "serialized_inverse": a list of inverse mapping determined by code;
51
+ (related to Sparsify: SpConv)
52
+ - "sparse_shape": Sparse shape for Sparse Conv Tensor;
53
+ - "sparse_conv_feat": SparseConvTensor init with information provide by Point;
54
+ """
55
+
56
+ def __init__(self, *args, **kwargs):
57
+ super().__init__(*args, **kwargs)
58
+ # If one of "offset" or "batch" do not exist, generate by the existing one
59
+ if "batch" not in self.keys() and "offset" in self.keys():
60
+ self["batch"] = offset2batch(self.offset)
61
+ elif "offset" not in self.keys() and "batch" in self.keys():
62
+ self["offset"] = batch2offset(self.batch)
63
+
64
+ def serialization(self, order="z", depth=None, shuffle_orders=False):
65
+ """
66
+ Point Cloud Serialization
67
+
68
+ relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
69
+ """
70
+ self["order"] = order
71
+ assert "batch" in self.keys()
72
+ if "grid_coord" not in self.keys():
73
+ # if you don't want to operate GridSampling in data augmentation,
74
+ # please add the following augmentation into your pipeline:
75
+ # dict(type="Copy", keys_dict={"grid_size": 0.01}),
76
+ # (adjust `grid_size` to what your want)
77
+ assert {"grid_size", "coord"}.issubset(self.keys())
78
+
79
+ self["grid_coord"] = torch.div(
80
+ self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
81
+ ).int()
82
+
83
+ if depth is None:
84
+ # Adaptive measure the depth of serialization cube (length = 2 ^ depth)
85
+ depth = int(self.grid_coord.max() + 1).bit_length()
86
+ self["serialized_depth"] = depth
87
+ # Maximum bit length for serialization code is 63 (int64)
88
+ assert depth * 3 + len(self.offset).bit_length() <= 63
89
+ # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.
90
+ # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
91
+ # cube with a grid size of 0.01 meter. We consider it is enough for the current stage.
92
+ # We can unlock the limitation by optimizing the z-order encoding function if necessary.
93
+ assert depth <= 16
94
+
95
+ # The serialization codes are arranged as following structures:
96
+ # [Order1 ([n]),
97
+ # Order2 ([n]),
98
+ # ...
99
+ # OrderN ([n])] (k, n)
100
+ code = [
101
+ encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order
102
+ ]
103
+ code = torch.stack(code)
104
+ order = torch.argsort(code)
105
+ inverse = torch.zeros_like(order).scatter_(
106
+ dim=1,
107
+ index=order,
108
+ src=torch.arange(0, code.shape[1], device=order.device).repeat(
109
+ code.shape[0], 1
110
+ ),
111
+ )
112
+
113
+ if shuffle_orders:
114
+ perm = torch.randperm(code.shape[0])
115
+ code = code[perm]
116
+ order = order[perm]
117
+ inverse = inverse[perm]
118
+
119
+ self["serialized_code"] = code
120
+ self["serialized_order"] = order
121
+ self["serialized_inverse"] = inverse
122
+
123
+ def sparsify(self, pad=96):
124
+ """
125
+ Point Cloud Serialization
126
+
127
+ Point cloud is sparse, here we use "sparsify" to specifically refer to
128
+ preparing "spconv.SparseConvTensor" for SpConv.
129
+
130
+ relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
131
+
132
+ pad: padding sparse for sparse shape.
133
+ """
134
+ assert {"feat", "batch"}.issubset(self.keys())
135
+ if "grid_coord" not in self.keys():
136
+ # if you don't want to operate GridSampling in data augmentation,
137
+ # please add the following augmentation into your pipeline:
138
+ # dict(type="Copy", keys_dict={"grid_size": 0.01}),
139
+ # (adjust `grid_size` to what your want)
140
+ assert {"grid_size", "coord"}.issubset(self.keys())
141
+ self["grid_coord"] = torch.div(
142
+ self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
143
+ ).int()
144
+ if "sparse_shape" in self.keys():
145
+ sparse_shape = self.sparse_shape
146
+ else:
147
+ sparse_shape = torch.add(
148
+ torch.max(self.grid_coord, dim=0).values, pad
149
+ ).tolist()
150
+ sparse_conv_feat = spconv.SparseConvTensor(
151
+ features=self.feat,
152
+ indices=torch.cat(
153
+ [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
154
+ ).contiguous(),
155
+ spatial_shape=sparse_shape,
156
+ batch_size=self.batch[-1].tolist() + 1,
157
+ )
158
+ self["sparse_shape"] = sparse_shape
159
+ self["sparse_conv_feat"] = sparse_conv_feat
utonia/transform.py ADDED
@@ -0,0 +1,1226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 3D point cloud augmentation
3
+
4
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
5
+ Please cite our work if the code is helpful to you.
6
+ """
7
+
8
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ import random
24
+ import numbers
25
+ import scipy
26
+ import scipy.ndimage
27
+ import scipy.interpolate
28
+ import scipy.stats
29
+ import numpy as np
30
+ import torch
31
+ import copy
32
+ from collections.abc import Sequence, Mapping
33
+
34
+ from .registry import Registry
35
+
36
+ TRANSFORMS = Registry("transforms")
37
+
38
+
39
+ def index_operator(data_dict, index, duplicate=False):
40
+ # index selection operator for keys in "index_valid_keys"
41
+ # custom these keys by "Update" transform in config
42
+ if "index_valid_keys" not in data_dict:
43
+ data_dict["index_valid_keys"] = [
44
+ "coord",
45
+ "color",
46
+ "normal",
47
+ "strength",
48
+ "segment",
49
+ "instance",
50
+ ]
51
+ if not duplicate:
52
+ for key in data_dict["index_valid_keys"]:
53
+ if key in data_dict:
54
+ data_dict[key] = data_dict[key][index]
55
+ return data_dict
56
+ else:
57
+ data_dict_ = dict()
58
+ for key in data_dict.keys():
59
+ if key in data_dict["index_valid_keys"]:
60
+ data_dict_[key] = data_dict[key][index]
61
+ else:
62
+ data_dict_[key] = data_dict[key]
63
+ return data_dict_
64
+
65
+
66
+ @TRANSFORMS.register_module()
67
+ class Collect(object):
68
+ def __init__(self, keys, offset_keys_dict=None, **kwargs):
69
+ """
70
+ e.g. Collect(keys=[coord], feat_keys=[coord, color])
71
+ """
72
+ if offset_keys_dict is None:
73
+ offset_keys_dict = dict(offset="coord")
74
+ self.keys = keys
75
+ self.offset_keys = offset_keys_dict
76
+ self.kwargs = kwargs
77
+
78
+ def __call__(self, data_dict):
79
+ data = dict()
80
+ if isinstance(self.keys, str):
81
+ self.keys = [self.keys]
82
+ for key in self.keys:
83
+ data[key] = data_dict[key]
84
+ for key, value in self.offset_keys.items():
85
+ data[key] = torch.tensor([data_dict[value].shape[0]])
86
+ for name, keys in self.kwargs.items():
87
+ name = name.replace("_keys", "")
88
+ assert isinstance(keys, Sequence)
89
+ data[name] = torch.cat([data_dict[key].float() for key in keys], dim=1)
90
+ return data
91
+
92
+
93
+ @TRANSFORMS.register_module()
94
+ class Copy(object):
95
+ def __init__(self, keys_dict=None):
96
+ if keys_dict is None:
97
+ keys_dict = dict(coord="origin_coord", segment="origin_segment")
98
+ self.keys_dict = keys_dict
99
+
100
+ def __call__(self, data_dict):
101
+ for key, value in self.keys_dict.items():
102
+ if isinstance(data_dict[key], np.ndarray):
103
+ data_dict[value] = data_dict[key].copy()
104
+ elif isinstance(data_dict[key], torch.Tensor):
105
+ data_dict[value] = data_dict[key].clone().detach()
106
+ else:
107
+ data_dict[value] = copy.deepcopy(data_dict[key])
108
+ return data_dict
109
+
110
+
111
+ @TRANSFORMS.register_module()
112
+ class Update(object):
113
+ def __init__(self, keys_dict=None):
114
+ if keys_dict is None:
115
+ keys_dict = dict()
116
+ self.keys_dict = keys_dict
117
+
118
+ def __call__(self, data_dict):
119
+ for key, value in self.keys_dict.items():
120
+ data_dict[key] = value
121
+ return data_dict
122
+
123
+
124
+ @TRANSFORMS.register_module()
125
+ class ToTensor(object):
126
+ def __call__(self, data):
127
+ if isinstance(data, torch.Tensor):
128
+ return data
129
+ elif isinstance(data, str):
130
+ # note that str is also a kind of sequence, judgement should before sequence
131
+ return data
132
+ elif isinstance(data, int):
133
+ return torch.LongTensor([data])
134
+ elif isinstance(data, float):
135
+ return torch.FloatTensor([data])
136
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, bool):
137
+ return torch.from_numpy(data)
138
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.integer):
139
+ return torch.from_numpy(data).long()
140
+ elif isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating):
141
+ return torch.from_numpy(data).float()
142
+ elif isinstance(data, Mapping):
143
+ result = {sub_key: self(item) for sub_key, item in data.items()}
144
+ return result
145
+ elif isinstance(data, Sequence):
146
+ result = [self(item) for item in data]
147
+ return result
148
+ else:
149
+ raise TypeError(f"type {type(data)} cannot be converted to tensor.")
150
+
151
+
152
+ @TRANSFORMS.register_module()
153
+ class NormalizeColor(object):
154
+ def __call__(self, data_dict):
155
+ if "color" in data_dict.keys():
156
+ data_dict["color"] = data_dict["color"] / 255
157
+ return data_dict
158
+
159
+
160
+ @TRANSFORMS.register_module()
161
+ class NormalizeCoord(object):
162
+ def __call__(self, data_dict):
163
+ if "coord" in data_dict.keys():
164
+ # modified from pointnet2
165
+ centroid = np.mean(data_dict["coord"], axis=0)
166
+ data_dict["coord"] -= centroid
167
+ m = np.max(np.sqrt(np.sum(data_dict["coord"] ** 2, axis=1)))
168
+ data_dict["coord"] = data_dict["coord"] / m
169
+ return data_dict
170
+
171
+
172
+ @TRANSFORMS.register_module()
173
+ class PositiveShift(object):
174
+ def __call__(self, data_dict):
175
+ if "coord" in data_dict.keys():
176
+ coord_min = np.min(data_dict["coord"], 0)
177
+ data_dict["coord"] -= coord_min
178
+ return data_dict
179
+
180
+
181
+ @TRANSFORMS.register_module()
182
+ class CenterShift(object):
183
+ def __init__(self, apply_z=True):
184
+ self.apply_z = apply_z
185
+
186
+ def __call__(self, data_dict):
187
+ if "coord" in data_dict.keys():
188
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
189
+ x_max, y_max, _ = data_dict["coord"].max(axis=0)
190
+ if self.apply_z:
191
+ shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, z_min]
192
+ else:
193
+ shift = [(x_min + x_max) / 2, (y_min + y_max) / 2, 0]
194
+ data_dict["coord"] -= shift
195
+ return data_dict
196
+
197
+
198
+ @TRANSFORMS.register_module()
199
+ class RandomShift(object):
200
+ def __init__(self, shift=((-0.2, 0.2), (-0.2, 0.2), (0, 0))):
201
+ self.shift = shift
202
+
203
+ def __call__(self, data_dict):
204
+ if "coord" in data_dict.keys():
205
+ shift_x = np.random.uniform(self.shift[0][0], self.shift[0][1])
206
+ shift_y = np.random.uniform(self.shift[1][0], self.shift[1][1])
207
+ shift_z = np.random.uniform(self.shift[2][0], self.shift[2][1])
208
+ data_dict["coord"] += [shift_x, shift_y, shift_z]
209
+ return data_dict
210
+
211
+
212
+ @TRANSFORMS.register_module()
213
+ class PointClip(object):
214
+ def __init__(self, point_cloud_range=(-80, -80, -3, 80, 80, 1)):
215
+ self.point_cloud_range = point_cloud_range
216
+
217
+ def __call__(self, data_dict):
218
+ if "coord" in data_dict.keys():
219
+ data_dict["coord"] = np.clip(
220
+ data_dict["coord"],
221
+ a_min=self.point_cloud_range[:3],
222
+ a_max=self.point_cloud_range[3:],
223
+ )
224
+ return data_dict
225
+
226
+
227
+ @TRANSFORMS.register_module()
228
+ class RandomDropout(object):
229
+ def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5):
230
+ """
231
+ upright_axis: axis index among x,y,z, i.e. 2 for z
232
+ """
233
+ self.dropout_ratio = dropout_ratio
234
+ self.dropout_application_ratio = dropout_application_ratio
235
+
236
+ def __call__(self, data_dict):
237
+ if random.random() < self.dropout_application_ratio:
238
+ n = len(data_dict["coord"])
239
+ idx = np.random.choice(n, int(n * (1 - self.dropout_ratio)), replace=False)
240
+ if "sampled_index" in data_dict:
241
+ # for ScanNet data efficient, we need to make sure labeled point is sampled.
242
+ idx = np.unique(np.append(idx, data_dict["sampled_index"]))
243
+ mask = np.zeros_like(data_dict["segment"]).astype(bool)
244
+ mask[data_dict["sampled_index"]] = True
245
+ data_dict["sampled_index"] = np.where(mask[idx])[0]
246
+ data_dict = index_operator(data_dict, idx)
247
+ return data_dict
248
+
249
+
250
+ @TRANSFORMS.register_module()
251
+ class RandomRotate(object):
252
+ def __init__(self, angle=None, center=None, axis="z", always_apply=False, p=0.5):
253
+ self.angle = [-1, 1] if angle is None else angle
254
+ self.axis = axis
255
+ self.always_apply = always_apply
256
+ self.p = p if not self.always_apply else 1
257
+ self.center = center
258
+
259
+ def __call__(self, data_dict):
260
+ if random.random() > self.p:
261
+ return data_dict
262
+ angle = np.random.uniform(self.angle[0], self.angle[1]) * np.pi
263
+ rot_cos, rot_sin = np.cos(angle), np.sin(angle)
264
+ if self.axis == "x":
265
+ rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
266
+ elif self.axis == "y":
267
+ rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
268
+ elif self.axis == "z":
269
+ rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
270
+ else:
271
+ raise NotImplementedError
272
+ if "coord" in data_dict.keys():
273
+ if self.center is None:
274
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
275
+ x_max, y_max, z_max = data_dict["coord"].max(axis=0)
276
+ center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
277
+ else:
278
+ center = self.center
279
+ data_dict["coord"] -= center
280
+ data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
281
+ data_dict["coord"] += center
282
+ if "normal" in data_dict.keys():
283
+ data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
284
+ return data_dict
285
+
286
+
287
+ @TRANSFORMS.register_module()
288
+ class RandomRotateTargetAngle(object):
289
+ def __init__(
290
+ self, angle=(1 / 2, 1, 3 / 2), center=None, axis="z", always_apply=False, p=0.75
291
+ ):
292
+ self.angle = angle
293
+ self.axis = axis
294
+ self.always_apply = always_apply
295
+ self.p = p if not self.always_apply else 1
296
+ self.center = center
297
+
298
+ def __call__(self, data_dict):
299
+ if random.random() > self.p:
300
+ return data_dict
301
+ angle = np.random.choice(self.angle) * np.pi
302
+ rot_cos, rot_sin = np.cos(angle), np.sin(angle)
303
+ if self.axis == "x":
304
+ rot_t = np.array([[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]])
305
+ elif self.axis == "y":
306
+ rot_t = np.array([[rot_cos, 0, rot_sin], [0, 1, 0], [-rot_sin, 0, rot_cos]])
307
+ elif self.axis == "z":
308
+ rot_t = np.array([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]])
309
+ else:
310
+ raise NotImplementedError
311
+ if "coord" in data_dict.keys():
312
+ if self.center is None:
313
+ x_min, y_min, z_min = data_dict["coord"].min(axis=0)
314
+ x_max, y_max, z_max = data_dict["coord"].max(axis=0)
315
+ center = [(x_min + x_max) / 2, (y_min + y_max) / 2, (z_min + z_max) / 2]
316
+ else:
317
+ center = self.center
318
+ data_dict["coord"] -= center
319
+ data_dict["coord"] = np.dot(data_dict["coord"], np.transpose(rot_t))
320
+ data_dict["coord"] += center
321
+ if "normal" in data_dict.keys():
322
+ data_dict["normal"] = np.dot(data_dict["normal"], np.transpose(rot_t))
323
+ return data_dict
324
+
325
+
326
+ @TRANSFORMS.register_module()
327
+ class RandomScale(object):
328
+ def __init__(self, scale=None, anisotropic=False):
329
+ self.scale = scale if scale is not None else [0.95, 1.05]
330
+ self.anisotropic = anisotropic
331
+
332
+ def __call__(self, data_dict):
333
+ if "coord" in data_dict.keys():
334
+ scale = np.random.uniform(
335
+ self.scale[0], self.scale[1], 3 if self.anisotropic else 1
336
+ )
337
+ data_dict["coord"] *= scale
338
+ return data_dict
339
+
340
+
341
+ @TRANSFORMS.register_module()
342
+ class RandomFlip(object):
343
+ def __init__(self, p=0.5):
344
+ self.p = p
345
+
346
+ def __call__(self, data_dict):
347
+ if np.random.rand() < self.p:
348
+ if "coord" in data_dict.keys():
349
+ data_dict["coord"][:, 0] = -data_dict["coord"][:, 0]
350
+ if "normal" in data_dict.keys():
351
+ data_dict["normal"][:, 0] = -data_dict["normal"][:, 0]
352
+ if np.random.rand() < self.p:
353
+ if "coord" in data_dict.keys():
354
+ data_dict["coord"][:, 1] = -data_dict["coord"][:, 1]
355
+ if "normal" in data_dict.keys():
356
+ data_dict["normal"][:, 1] = -data_dict["normal"][:, 1]
357
+ return data_dict
358
+
359
+
360
+ @TRANSFORMS.register_module()
361
+ class RandomJitter(object):
362
+ def __init__(self, sigma=0.01, clip=0.05):
363
+ assert clip > 0
364
+ self.sigma = sigma
365
+ self.clip = clip
366
+
367
+ def __call__(self, data_dict):
368
+ if "coord" in data_dict.keys():
369
+ jitter = np.clip(
370
+ self.sigma * np.random.randn(data_dict["coord"].shape[0], 3),
371
+ -self.clip,
372
+ self.clip,
373
+ )
374
+ data_dict["coord"] += jitter
375
+ return data_dict
376
+
377
+
378
+ @TRANSFORMS.register_module()
379
+ class ClipGaussianJitter(object):
380
+ def __init__(self, scalar=0.02, store_jitter=False):
381
+ self.scalar = scalar
382
+ self.mean = np.mean(3)
383
+ self.cov = np.identity(3)
384
+ self.quantile = 1.96
385
+ self.store_jitter = store_jitter
386
+
387
+ def __call__(self, data_dict):
388
+ if "coord" in data_dict.keys():
389
+ jitter = np.random.multivariate_normal(
390
+ self.mean, self.cov, data_dict["coord"].shape[0]
391
+ )
392
+ jitter = self.scalar * np.clip(jitter / 1.96, -1, 1)
393
+ data_dict["coord"] += jitter
394
+ if self.store_jitter:
395
+ data_dict["jitter"] = jitter
396
+ return data_dict
397
+
398
+
399
+ @TRANSFORMS.register_module()
400
+ class ChromaticAutoContrast(object):
401
+ def __init__(self, p=0.2, blend_factor=None):
402
+ self.p = p
403
+ self.blend_factor = blend_factor
404
+
405
+ def __call__(self, data_dict):
406
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
407
+ lo = np.min(data_dict["color"], 0, keepdims=True)
408
+ hi = np.max(data_dict["color"], 0, keepdims=True)
409
+ scale = 255 / (hi - lo)
410
+ contrast_feat = (data_dict["color"][:, :3] - lo) * scale
411
+ blend_factor = (
412
+ np.random.rand() if self.blend_factor is None else self.blend_factor
413
+ )
414
+ data_dict["color"][:, :3] = (1 - blend_factor) * data_dict["color"][
415
+ :, :3
416
+ ] + blend_factor * contrast_feat
417
+ return data_dict
418
+
419
+
420
+ @TRANSFORMS.register_module()
421
+ class ChromaticTranslation(object):
422
+ def __init__(self, p=0.95, ratio=0.05):
423
+ self.p = p
424
+ self.ratio = ratio
425
+
426
+ def __call__(self, data_dict):
427
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
428
+ tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.ratio
429
+ data_dict["color"][:, :3] = np.clip(tr + data_dict["color"][:, :3], 0, 255)
430
+ return data_dict
431
+
432
+
433
+ @TRANSFORMS.register_module()
434
+ class ChromaticJitter(object):
435
+ def __init__(self, p=0.95, std=0.005):
436
+ self.p = p
437
+ self.std = std
438
+
439
+ def __call__(self, data_dict):
440
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
441
+ noise = np.random.randn(data_dict["color"].shape[0], 3)
442
+ noise *= self.std * 255
443
+ data_dict["color"][:, :3] = np.clip(
444
+ noise + data_dict["color"][:, :3], 0, 255
445
+ )
446
+ return data_dict
447
+
448
+
449
+ @TRANSFORMS.register_module()
450
+ class RandomColorGrayScale(object):
451
+ def __init__(self, p):
452
+ self.p = p
453
+
454
+ @staticmethod
455
+ def rgb_to_grayscale(color, num_output_channels=1):
456
+ if color.shape[-1] < 3:
457
+ raise TypeError(
458
+ "Input color should have at least 3 dimensions, but found {}".format(
459
+ color.shape[-1]
460
+ )
461
+ )
462
+
463
+ if num_output_channels not in (1, 3):
464
+ raise ValueError("num_output_channels should be either 1 or 3")
465
+
466
+ r, g, b = color[..., 0], color[..., 1], color[..., 2]
467
+ gray = (0.2989 * r + 0.587 * g + 0.114 * b).astype(color.dtype)
468
+ gray = np.expand_dims(gray, axis=-1)
469
+
470
+ if num_output_channels == 3:
471
+ gray = np.broadcast_to(gray, color.shape)
472
+
473
+ return gray
474
+
475
+ def __call__(self, data_dict):
476
+ if np.random.rand() < self.p:
477
+ data_dict["color"] = self.rgb_to_grayscale(data_dict["color"], 3)
478
+ return data_dict
479
+
480
+
481
+ @TRANSFORMS.register_module()
482
+ class RandomColorJitter(object):
483
+ """
484
+ Random Color Jitter for 3D point cloud (refer torchvision)
485
+ """
486
+
487
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.95):
488
+ self.brightness = self._check_input(brightness, "brightness")
489
+ self.contrast = self._check_input(contrast, "contrast")
490
+ self.saturation = self._check_input(saturation, "saturation")
491
+ self.hue = self._check_input(
492
+ hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False
493
+ )
494
+ self.p = p
495
+
496
+ @staticmethod
497
+ def _check_input(
498
+ value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
499
+ ):
500
+ if isinstance(value, numbers.Number):
501
+ if value < 0:
502
+ raise ValueError(
503
+ "If {} is a single number, it must be non negative.".format(name)
504
+ )
505
+ value = [center - float(value), center + float(value)]
506
+ if clip_first_on_zero:
507
+ value[0] = max(value[0], 0.0)
508
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
509
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
510
+ raise ValueError("{} values should be between {}".format(name, bound))
511
+ else:
512
+ raise TypeError(
513
+ "{} should be a single number or a list/tuple with length 2.".format(
514
+ name
515
+ )
516
+ )
517
+
518
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
519
+ # or (0., 0.) for hue, do nothing
520
+ if value[0] == value[1] == center:
521
+ value = None
522
+ return value
523
+
524
+ @staticmethod
525
+ def blend(color1, color2, ratio):
526
+ ratio = float(ratio)
527
+ bound = 255.0
528
+ return (
529
+ (ratio * color1 + (1.0 - ratio) * color2)
530
+ .clip(0, bound)
531
+ .astype(color1.dtype)
532
+ )
533
+
534
+ @staticmethod
535
+ def rgb2hsv(rgb):
536
+ r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
537
+ maxc = np.max(rgb, axis=-1)
538
+ minc = np.min(rgb, axis=-1)
539
+ eqc = maxc == minc
540
+ cr = maxc - minc
541
+ s = cr / (np.ones_like(maxc) * eqc + maxc * (1 - eqc))
542
+ cr_divisor = np.ones_like(maxc) * eqc + cr * (1 - eqc)
543
+ rc = (maxc - r) / cr_divisor
544
+ gc = (maxc - g) / cr_divisor
545
+ bc = (maxc - b) / cr_divisor
546
+
547
+ hr = (maxc == r) * (bc - gc)
548
+ hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc)
549
+ hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc)
550
+ h = hr + hg + hb
551
+ h = (h / 6.0 + 1.0) % 1.0
552
+ return np.stack((h, s, maxc), axis=-1)
553
+
554
+ @staticmethod
555
+ def hsv2rgb(hsv):
556
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
557
+ i = np.floor(h * 6.0)
558
+ f = (h * 6.0) - i
559
+ i = i.astype(np.int32)
560
+
561
+ p = np.clip((v * (1.0 - s)), 0.0, 1.0)
562
+ q = np.clip((v * (1.0 - s * f)), 0.0, 1.0)
563
+ t = np.clip((v * (1.0 - s * (1.0 - f))), 0.0, 1.0)
564
+ i = i % 6
565
+ mask = np.expand_dims(i, axis=-1) == np.arange(6)
566
+
567
+ a1 = np.stack((v, q, p, p, t, v), axis=-1)
568
+ a2 = np.stack((t, v, v, q, p, p), axis=-1)
569
+ a3 = np.stack((p, p, t, v, v, q), axis=-1)
570
+ a4 = np.stack((a1, a2, a3), axis=-1)
571
+
572
+ return np.einsum("...na, ...nab -> ...nb", mask.astype(hsv.dtype), a4)
573
+
574
+ def adjust_brightness(self, color, brightness_factor):
575
+ if brightness_factor < 0:
576
+ raise ValueError(
577
+ "brightness_factor ({}) is not non-negative.".format(brightness_factor)
578
+ )
579
+
580
+ return self.blend(color, np.zeros_like(color), brightness_factor)
581
+
582
+ def adjust_contrast(self, color, contrast_factor):
583
+ if contrast_factor < 0:
584
+ raise ValueError(
585
+ "contrast_factor ({}) is not non-negative.".format(contrast_factor)
586
+ )
587
+ mean = np.mean(RandomColorGrayScale.rgb_to_grayscale(color))
588
+ return self.blend(color, mean, contrast_factor)
589
+
590
+ def adjust_saturation(self, color, saturation_factor):
591
+ if saturation_factor < 0:
592
+ raise ValueError(
593
+ "saturation_factor ({}) is not non-negative.".format(saturation_factor)
594
+ )
595
+ gray = RandomColorGrayScale.rgb_to_grayscale(color)
596
+ return self.blend(color, gray, saturation_factor)
597
+
598
+ def adjust_hue(self, color, hue_factor):
599
+ if not (-0.5 <= hue_factor <= 0.5):
600
+ raise ValueError(
601
+ "hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor)
602
+ )
603
+ orig_dtype = color.dtype
604
+ hsv = self.rgb2hsv(color / 255.0)
605
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
606
+ h = (h + hue_factor) % 1.0
607
+ hsv = np.stack((h, s, v), axis=-1)
608
+ color_hue_adj = (self.hsv2rgb(hsv) * 255.0).astype(orig_dtype)
609
+ return color_hue_adj
610
+
611
+ @staticmethod
612
+ def get_params(brightness, contrast, saturation, hue):
613
+ fn_idx = torch.randperm(4)
614
+ b = (
615
+ None
616
+ if brightness is None
617
+ else np.random.uniform(brightness[0], brightness[1])
618
+ )
619
+ c = None if contrast is None else np.random.uniform(contrast[0], contrast[1])
620
+ s = (
621
+ None
622
+ if saturation is None
623
+ else np.random.uniform(saturation[0], saturation[1])
624
+ )
625
+ h = None if hue is None else np.random.uniform(hue[0], hue[1])
626
+ return fn_idx, b, c, s, h
627
+
628
+ def __call__(self, data_dict):
629
+ (
630
+ fn_idx,
631
+ brightness_factor,
632
+ contrast_factor,
633
+ saturation_factor,
634
+ hue_factor,
635
+ ) = self.get_params(self.brightness, self.contrast, self.saturation, self.hue)
636
+
637
+ for fn_id in fn_idx:
638
+ if (
639
+ fn_id == 0
640
+ and brightness_factor is not None
641
+ and np.random.rand() < self.p
642
+ ):
643
+ data_dict["color"] = self.adjust_brightness(
644
+ data_dict["color"], brightness_factor
645
+ )
646
+ elif (
647
+ fn_id == 1 and contrast_factor is not None and np.random.rand() < self.p
648
+ ):
649
+ data_dict["color"] = self.adjust_contrast(
650
+ data_dict["color"], contrast_factor
651
+ )
652
+ elif (
653
+ fn_id == 2
654
+ and saturation_factor is not None
655
+ and np.random.rand() < self.p
656
+ ):
657
+ data_dict["color"] = self.adjust_saturation(
658
+ data_dict["color"], saturation_factor
659
+ )
660
+ elif fn_id == 3 and hue_factor is not None and np.random.rand() < self.p:
661
+ data_dict["color"] = self.adjust_hue(data_dict["color"], hue_factor)
662
+ return data_dict
663
+
664
+
665
+ @TRANSFORMS.register_module()
666
+ class HueSaturationTranslation(object):
667
+ @staticmethod
668
+ def rgb_to_hsv(rgb):
669
+ # Translated from source of colorsys.rgb_to_hsv
670
+ # r,g,b should be a numpy arrays with values between 0 and 255
671
+ # rgb_to_hsv returns an array of floats between 0.0 and 1.0.
672
+ rgb = rgb.astype("float")
673
+ hsv = np.zeros_like(rgb)
674
+ # in case an RGBA array was passed, just copy the A channel
675
+ hsv[..., 3:] = rgb[..., 3:]
676
+ r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]
677
+ maxc = np.max(rgb[..., :3], axis=-1)
678
+ minc = np.min(rgb[..., :3], axis=-1)
679
+ hsv[..., 2] = maxc
680
+ mask = maxc != minc
681
+ hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask]
682
+ rc = np.zeros_like(r)
683
+ gc = np.zeros_like(g)
684
+ bc = np.zeros_like(b)
685
+ rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask]
686
+ gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask]
687
+ bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask]
688
+ hsv[..., 0] = np.select(
689
+ [r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc
690
+ )
691
+ hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0
692
+ return hsv
693
+
694
+ @staticmethod
695
+ def hsv_to_rgb(hsv):
696
+ # Translated from source of colorsys.hsv_to_rgb
697
+ # h,s should be a numpy arrays with values between 0.0 and 1.0
698
+ # v should be a numpy array with values between 0.0 and 255.0
699
+ # hsv_to_rgb returns an array of uints between 0 and 255.
700
+ rgb = np.empty_like(hsv)
701
+ rgb[..., 3:] = hsv[..., 3:]
702
+ h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2]
703
+ i = (h * 6.0).astype("uint8")
704
+ f = (h * 6.0) - i
705
+ p = v * (1.0 - s)
706
+ q = v * (1.0 - s * f)
707
+ t = v * (1.0 - s * (1.0 - f))
708
+ i = i % 6
709
+ conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5]
710
+ rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v)
711
+ rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t)
712
+ rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p)
713
+ return rgb.astype("uint8")
714
+
715
+ def __init__(self, hue_max=0.5, saturation_max=0.2):
716
+ self.hue_max = hue_max
717
+ self.saturation_max = saturation_max
718
+
719
+ def __call__(self, data_dict):
720
+ if "color" in data_dict.keys():
721
+ # Assume color[:, :3] is rgb
722
+ hsv = HueSaturationTranslation.rgb_to_hsv(data_dict["color"][:, :3])
723
+ hue_val = (np.random.rand() - 0.5) * 2 * self.hue_max
724
+ sat_ratio = 1 + (np.random.rand() - 0.5) * 2 * self.saturation_max
725
+ hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1)
726
+ hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1)
727
+ data_dict["color"][:, :3] = np.clip(
728
+ HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255
729
+ )
730
+ return data_dict
731
+
732
+
733
+ @TRANSFORMS.register_module()
734
+ class RandomColorDrop(object):
735
+ def __init__(self, p=0.2, color_augment=0.0):
736
+ self.p = p
737
+ self.color_augment = color_augment
738
+
739
+ def __call__(self, data_dict):
740
+ if "color" in data_dict.keys() and np.random.rand() < self.p:
741
+ data_dict["color"] *= self.color_augment
742
+ return data_dict
743
+
744
+ def __repr__(self):
745
+ return "RandomColorDrop(color_augment: {}, p: {})".format(
746
+ self.color_augment, self.p
747
+ )
748
+
749
+
750
+ @TRANSFORMS.register_module()
751
+ class ElasticDistortion(object):
752
+ def __init__(self, distortion_params=None):
753
+ self.distortion_params = (
754
+ [[0.2, 0.4], [0.8, 1.6]] if distortion_params is None else distortion_params
755
+ )
756
+
757
+ @staticmethod
758
+ def elastic_distortion(coords, granularity, magnitude):
759
+ """
760
+ Apply elastic distortion on sparse coordinate space.
761
+ pointcloud: numpy array of (number of points, at least 3 spatial dims)
762
+ granularity: size of the noise grid (in same scale[m/cm] as the voxel grid)
763
+ magnitude: noise multiplier
764
+ """
765
+ blurx = np.ones((3, 1, 1, 1)).astype("float32") / 3
766
+ blury = np.ones((1, 3, 1, 1)).astype("float32") / 3
767
+ blurz = np.ones((1, 1, 3, 1)).astype("float32") / 3
768
+ coords_min = coords.min(0)
769
+
770
+ # Create Gaussian noise tensor of the size given by granularity.
771
+ noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
772
+ noise = np.random.randn(*noise_dim, 3).astype(np.float32)
773
+
774
+ # Smoothing.
775
+ for _ in range(2):
776
+ noise = scipy.ndimage.filters.convolve(
777
+ noise, blurx, mode="constant", cval=0
778
+ )
779
+ noise = scipy.ndimage.filters.convolve(
780
+ noise, blury, mode="constant", cval=0
781
+ )
782
+ noise = scipy.ndimage.filters.convolve(
783
+ noise, blurz, mode="constant", cval=0
784
+ )
785
+
786
+ # Trilinear interpolate noise filters for each spatial dimensions.
787
+ ax = [
788
+ np.linspace(d_min, d_max, d)
789
+ for d_min, d_max, d in zip(
790
+ coords_min - granularity,
791
+ coords_min + granularity * (noise_dim - 2),
792
+ noise_dim,
793
+ )
794
+ ]
795
+ interp = scipy.interpolate.RegularGridInterpolator(
796
+ ax, noise, bounds_error=False, fill_value=0
797
+ )
798
+ coords += interp(coords) * magnitude
799
+ return coords
800
+
801
+ def __call__(self, data_dict):
802
+ if "coord" in data_dict.keys() and self.distortion_params is not None:
803
+ if random.random() < 0.95:
804
+ for granularity, magnitude in self.distortion_params:
805
+ data_dict["coord"] = self.elastic_distortion(
806
+ data_dict["coord"], granularity, magnitude
807
+ )
808
+ return data_dict
809
+
810
+
811
+ @TRANSFORMS.register_module()
812
+ class GridSample(object):
813
+ def __init__(
814
+ self,
815
+ grid_size=0.05,
816
+ hash_type="fnv",
817
+ mode="train",
818
+ return_inverse=False,
819
+ return_grid_coord=False,
820
+ return_min_coord=False,
821
+ return_displacement=False,
822
+ project_displacement=False,
823
+ ):
824
+ self.grid_size = grid_size
825
+ self.hash = self.fnv_hash_vec if hash_type == "fnv" else self.ravel_hash_vec
826
+ assert mode in ["train", "test"]
827
+ self.mode = mode
828
+ self.return_inverse = return_inverse
829
+ self.return_grid_coord = return_grid_coord
830
+ self.return_min_coord = return_min_coord
831
+ self.return_displacement = return_displacement
832
+ self.project_displacement = project_displacement
833
+
834
+ def __call__(self, data_dict):
835
+ assert "coord" in data_dict.keys()
836
+ scaled_coord = data_dict["coord"] / np.array(self.grid_size)
837
+ grid_coord = np.floor(scaled_coord).astype(int)
838
+ min_coord = grid_coord.min(0)
839
+ grid_coord -= min_coord
840
+ scaled_coord -= min_coord
841
+ min_coord = min_coord * np.array(self.grid_size)
842
+ key = self.hash(grid_coord)
843
+ idx_sort = np.argsort(key)
844
+ key_sort = key[idx_sort]
845
+ _, inverse, count = np.unique(key_sort, return_inverse=True, return_counts=True)
846
+ if self.mode == "train": # train mode
847
+ idx_select = (
848
+ np.cumsum(np.insert(count, 0, 0)[0:-1])
849
+ + np.random.randint(0, count.max(), count.size) % count
850
+ )
851
+ idx_unique = idx_sort[idx_select]
852
+ if "sampled_index" in data_dict:
853
+ # for ScanNet data efficient, we need to make sure labeled point is sampled.
854
+ idx_unique = np.unique(
855
+ np.append(idx_unique, data_dict["sampled_index"])
856
+ )
857
+ mask = np.zeros_like(data_dict["segment"]).astype(bool)
858
+ mask[data_dict["sampled_index"]] = True
859
+ data_dict["sampled_index"] = np.where(mask[idx_unique])[0]
860
+ data_dict = index_operator(data_dict, idx_unique)
861
+ if self.return_inverse:
862
+ data_dict["inverse"] = np.zeros_like(inverse)
863
+ data_dict["inverse"][idx_sort] = inverse
864
+ if self.return_grid_coord:
865
+ data_dict["grid_coord"] = grid_coord[idx_unique]
866
+ data_dict["index_valid_keys"].append("grid_coord")
867
+ if self.return_min_coord:
868
+ data_dict["min_coord"] = min_coord.reshape([1, 3])
869
+ if self.return_displacement:
870
+ displacement = (
871
+ scaled_coord - grid_coord - 0.5
872
+ ) # [0, 1] -> [-0.5, 0.5] displacement to center
873
+ if self.project_displacement:
874
+ displacement = np.sum(
875
+ displacement * data_dict["normal"], axis=-1, keepdims=True
876
+ )
877
+ data_dict["displacement"] = displacement[idx_unique]
878
+ data_dict["index_valid_keys"].append("displacement")
879
+ return data_dict
880
+
881
+ elif self.mode == "test": # test mode
882
+ data_part_list = []
883
+ for i in range(count.max()):
884
+ idx_select = np.cumsum(np.insert(count, 0, 0)[0:-1]) + i % count
885
+ idx_part = idx_sort[idx_select]
886
+ data_part = index_operator(data_dict, idx_part, duplicate=True)
887
+ data_part["index"] = idx_part
888
+ if self.return_inverse:
889
+ data_part["inverse"] = np.zeros_like(inverse)
890
+ data_part["inverse"][idx_sort] = inverse
891
+ if self.return_grid_coord:
892
+ data_part["grid_coord"] = grid_coord[idx_part]
893
+ data_dict["index_valid_keys"].append("grid_coord")
894
+ if self.return_min_coord:
895
+ data_part["min_coord"] = min_coord.reshape([1, 3])
896
+ if self.return_displacement:
897
+ displacement = (
898
+ scaled_coord - grid_coord - 0.5
899
+ ) # [0, 1] -> [-0.5, 0.5] displacement to center
900
+ if self.project_displacement:
901
+ displacement = np.sum(
902
+ displacement * data_dict["normal"], axis=-1, keepdims=True
903
+ )
904
+ data_dict["displacement"] = displacement[idx_part]
905
+ data_dict["index_valid_keys"].append("displacement")
906
+ data_part_list.append(data_part)
907
+ return data_part_list
908
+ else:
909
+ raise NotImplementedError
910
+
911
+ @staticmethod
912
+ def ravel_hash_vec(arr):
913
+ """
914
+ Ravel the coordinates after subtracting the min coordinates.
915
+ """
916
+ assert arr.ndim == 2
917
+ arr = arr.copy()
918
+ arr -= arr.min(0)
919
+ arr = arr.astype(np.uint64, copy=False)
920
+ arr_max = arr.max(0).astype(np.uint64) + 1
921
+
922
+ keys = np.zeros(arr.shape[0], dtype=np.uint64)
923
+ # Fortran style indexing
924
+ for j in range(arr.shape[1] - 1):
925
+ keys += arr[:, j]
926
+ keys *= arr_max[j + 1]
927
+ keys += arr[:, -1]
928
+ return keys
929
+
930
+ @staticmethod
931
+ def fnv_hash_vec(arr):
932
+ """
933
+ FNV64-1A
934
+ """
935
+ assert arr.ndim == 2
936
+ # Floor first for negative coordinates
937
+ arr = arr.copy()
938
+ arr = arr.astype(np.uint64, copy=False)
939
+ hashed_arr = np.uint64(14695981039346656037) * np.ones(
940
+ arr.shape[0], dtype=np.uint64
941
+ )
942
+ for j in range(arr.shape[1]):
943
+ hashed_arr *= np.uint64(1099511628211)
944
+ hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j])
945
+ return hashed_arr
946
+
947
+
948
+ @TRANSFORMS.register_module()
949
+ class SphereCrop(object):
950
+ def __init__(self, point_max=80000, sample_rate=None, mode="random"):
951
+ self.point_max = point_max
952
+ self.sample_rate = sample_rate
953
+ assert mode in ["random", "center", "all"]
954
+ self.mode = mode
955
+
956
+ def __call__(self, data_dict):
957
+ point_max = (
958
+ int(self.sample_rate * data_dict["coord"].shape[0])
959
+ if self.sample_rate is not None
960
+ else self.point_max
961
+ )
962
+
963
+ assert "coord" in data_dict.keys()
964
+ if data_dict["coord"].shape[0] > point_max:
965
+ if self.mode == "random":
966
+ center = data_dict["coord"][
967
+ np.random.randint(data_dict["coord"].shape[0])
968
+ ]
969
+ elif self.mode == "center":
970
+ center = data_dict["coord"][data_dict["coord"].shape[0] // 2]
971
+ else:
972
+ raise NotImplementedError
973
+ idx_crop = np.argsort(np.sum(np.square(data_dict["coord"] - center), 1))[
974
+ :point_max
975
+ ]
976
+ data_dict = index_operator(data_dict, idx_crop)
977
+ return data_dict
978
+
979
+
980
+ @TRANSFORMS.register_module()
981
+ class ShufflePoint(object):
982
+ def __call__(self, data_dict):
983
+ assert "coord" in data_dict.keys()
984
+ shuffle_index = np.arange(data_dict["coord"].shape[0])
985
+ np.random.shuffle(shuffle_index)
986
+ data_dict = index_operator(data_dict, shuffle_index)
987
+ return data_dict
988
+
989
+
990
+ @TRANSFORMS.register_module()
991
+ class CropBoundary(object):
992
+ def __call__(self, data_dict):
993
+ assert "segment" in data_dict
994
+ segment = data_dict["segment"].flatten()
995
+ mask = (segment != 0) * (segment != 1)
996
+ data_dict = index_operator(data_dict, mask)
997
+ return data_dict
998
+
999
+
1000
+ @TRANSFORMS.register_module()
1001
+ class ContrastiveViewsGenerator(object):
1002
+ def __init__(
1003
+ self,
1004
+ view_keys=("coord", "color", "normal", "origin_coord"),
1005
+ view_trans_cfg=None,
1006
+ ):
1007
+ self.view_keys = view_keys
1008
+ self.view_trans = Compose(view_trans_cfg)
1009
+
1010
+ def __call__(self, data_dict):
1011
+ view1_dict = dict()
1012
+ view2_dict = dict()
1013
+ for key in self.view_keys:
1014
+ view1_dict[key] = data_dict[key].copy()
1015
+ view2_dict[key] = data_dict[key].copy()
1016
+ view1_dict = self.view_trans(view1_dict)
1017
+ view2_dict = self.view_trans(view2_dict)
1018
+ for key, value in view1_dict.items():
1019
+ data_dict["view1_" + key] = value
1020
+ for key, value in view2_dict.items():
1021
+ data_dict["view2_" + key] = value
1022
+ return data_dict
1023
+
1024
+
1025
+ @TRANSFORMS.register_module()
1026
+ class MultiViewGenerator(object):
1027
+ def __init__(
1028
+ self,
1029
+ global_view_num=2,
1030
+ global_view_scale=(0.4, 1.0),
1031
+ local_view_num=4,
1032
+ local_view_scale=(0.1, 0.4),
1033
+ global_shared_transform=None,
1034
+ global_transform=None,
1035
+ local_transform=None,
1036
+ max_size=65536,
1037
+ center_height_scale=(0, 1),
1038
+ shared_global_view=False,
1039
+ view_keys=("coord", "origin_coord", "color", "normal"),
1040
+ ):
1041
+ self.global_view_num = global_view_num
1042
+ self.global_view_scale = global_view_scale
1043
+ self.local_view_num = local_view_num
1044
+ self.local_view_scale = local_view_scale
1045
+ self.global_shared_transform = Compose(global_shared_transform)
1046
+ self.global_transform = Compose(global_transform)
1047
+ self.local_transform = Compose(local_transform)
1048
+ self.max_size = max_size
1049
+ self.center_height_scale = center_height_scale
1050
+ self.shared_global_view = shared_global_view
1051
+ self.view_keys = view_keys
1052
+ assert "coord" in view_keys
1053
+
1054
+ def get_view(self, point, center, scale):
1055
+ coord = point["coord"]
1056
+ max_size = min(self.max_size, coord.shape[0])
1057
+ size = int(np.random.uniform(*scale) * max_size)
1058
+ index = np.argsort(np.sum(np.square(coord - center), axis=-1))[:size]
1059
+ view = dict(index=index)
1060
+ for key in point.keys():
1061
+ if key in self.view_keys:
1062
+ view[key] = point[key][index]
1063
+
1064
+ if "index_valid_keys" in point.keys():
1065
+ # inherit index_valid_keys from point
1066
+ view["index_valid_keys"] = point["index_valid_keys"]
1067
+ return view
1068
+
1069
+ def __call__(self, data_dict):
1070
+ coord = data_dict["coord"]
1071
+ point = self.global_shared_transform(copy.deepcopy(data_dict))
1072
+ z_min = coord[:, 2].min()
1073
+ z_max = coord[:, 2].max()
1074
+ z_min_ = z_min + (z_max - z_min) * self.center_height_scale[0]
1075
+ z_max_ = z_min + (z_max - z_min) * self.center_height_scale[1]
1076
+ center_mask = np.logical_and(coord[:, 2] >= z_min_, coord[:, 2] <= z_max_)
1077
+ # get major global view
1078
+ major_center = coord[np.random.choice(np.where(center_mask)[0])]
1079
+ major_view = self.get_view(point, major_center, self.global_view_scale)
1080
+ major_coord = major_view["coord"]
1081
+ # get global views: restrict the center of left global view within the major global view
1082
+ if not self.shared_global_view:
1083
+ global_views = [
1084
+ self.get_view(
1085
+ point=point,
1086
+ center=major_coord[np.random.randint(major_coord.shape[0])],
1087
+ scale=self.global_view_scale,
1088
+ )
1089
+ for _ in range(self.global_view_num - 1)
1090
+ ]
1091
+ else:
1092
+ global_views = [
1093
+ {key: value.copy() for key, value in major_view.items()}
1094
+ for _ in range(self.global_view_num - 1)
1095
+ ]
1096
+
1097
+ global_views = [major_view] + global_views
1098
+
1099
+ # get local views: restrict the center of local view within the major global view
1100
+ cover_mask = np.zeros_like(major_view["index"], dtype=bool)
1101
+ local_views = []
1102
+ for i in range(self.local_view_num):
1103
+ if sum(~cover_mask) == 0:
1104
+ # reset cover mask if all points are sampled
1105
+ cover_mask[:] = False
1106
+ local_view = self.get_view(
1107
+ point=data_dict,
1108
+ center=major_coord[np.random.choice(np.where(~cover_mask)[0])],
1109
+ scale=self.local_view_scale,
1110
+ )
1111
+ local_views.append(local_view)
1112
+ cover_mask[np.isin(major_view["index"], local_view["index"])] = True
1113
+
1114
+ # augmentation and concat
1115
+ view_dict = {}
1116
+ for global_view in global_views:
1117
+ global_view.pop("index")
1118
+ global_view = self.global_transform(global_view)
1119
+ for key in self.view_keys:
1120
+ if f"global_{key}" in view_dict.keys():
1121
+ view_dict[f"global_{key}"].append(global_view[key])
1122
+ else:
1123
+ view_dict[f"global_{key}"] = [global_view[key]]
1124
+ view_dict["global_offset"] = np.cumsum(
1125
+ [data.shape[0] for data in view_dict["global_coord"]]
1126
+ )
1127
+ for local_view in local_views:
1128
+ local_view.pop("index")
1129
+ local_view = self.local_transform(local_view)
1130
+ for key in self.view_keys:
1131
+ if f"local_{key}" in view_dict.keys():
1132
+ view_dict[f"local_{key}"].append(local_view[key])
1133
+ else:
1134
+ view_dict[f"local_{key}"] = [local_view[key]]
1135
+ view_dict["local_offset"] = np.cumsum(
1136
+ [data.shape[0] for data in view_dict["local_coord"]]
1137
+ )
1138
+ for key in view_dict.keys():
1139
+ if "offset" not in key:
1140
+ view_dict[key] = np.concatenate(view_dict[key], axis=0)
1141
+ data_dict.update(view_dict)
1142
+ return data_dict
1143
+
1144
+
1145
+ @TRANSFORMS.register_module()
1146
+ class InstanceParser(object):
1147
+ def __init__(self, segment_ignore_index=(-1, 0, 1), instance_ignore_index=-1):
1148
+ self.segment_ignore_index = segment_ignore_index
1149
+ self.instance_ignore_index = instance_ignore_index
1150
+
1151
+ def __call__(self, data_dict):
1152
+ coord = data_dict["coord"]
1153
+ segment = data_dict["segment"]
1154
+ instance = data_dict["instance"]
1155
+ mask = ~np.in1d(segment, self.segment_ignore_index)
1156
+ # mapping ignored instance to ignore index
1157
+ instance[~mask] = self.instance_ignore_index
1158
+ # reorder left instance
1159
+ unique, inverse = np.unique(instance[mask], return_inverse=True)
1160
+ instance_num = len(unique)
1161
+ instance[mask] = inverse
1162
+ # init instance information
1163
+ centroid = np.ones((coord.shape[0], 3)) * self.instance_ignore_index
1164
+ bbox = np.ones((instance_num, 8)) * self.instance_ignore_index
1165
+ vacancy = [
1166
+ index for index in self.segment_ignore_index if index >= 0
1167
+ ] # vacate class index
1168
+
1169
+ for instance_id in range(instance_num):
1170
+ mask_ = instance == instance_id
1171
+ coord_ = coord[mask_]
1172
+ bbox_min = coord_.min(0)
1173
+ bbox_max = coord_.max(0)
1174
+ bbox_centroid = coord_.mean(0)
1175
+ bbox_center = (bbox_max + bbox_min) / 2
1176
+ bbox_size = bbox_max - bbox_min
1177
+ bbox_theta = np.zeros(1, dtype=coord_.dtype)
1178
+ bbox_class = np.array([segment[mask_][0]], dtype=coord_.dtype)
1179
+ # shift class index to fill vacate class index caused by segment ignore index
1180
+ bbox_class -= np.greater(bbox_class, vacancy).sum()
1181
+
1182
+ centroid[mask_] = bbox_centroid
1183
+ bbox[instance_id] = np.concatenate(
1184
+ [bbox_center, bbox_size, bbox_theta, bbox_class]
1185
+ ) # 3 + 3 + 1 + 1 = 8
1186
+ data_dict["instance"] = instance
1187
+ data_dict["instance_centroid"] = centroid
1188
+ data_dict["bbox"] = bbox
1189
+ return data_dict
1190
+
1191
+
1192
+ class Compose(object):
1193
+ def __init__(self, cfg=None):
1194
+ self.cfg = cfg if cfg is not None else []
1195
+ self.transforms = []
1196
+ for t_cfg in self.cfg:
1197
+ self.transforms.append(TRANSFORMS.build(t_cfg))
1198
+
1199
+ def __call__(self, data_dict):
1200
+ for t in self.transforms:
1201
+ data_dict = t(data_dict)
1202
+ return data_dict
1203
+
1204
+
1205
+ def default(scale = 1.0, apply_z_positive = True, normalize_coord = False):
1206
+ config = [
1207
+ *([dict(type="NormalizeCoord")] if normalize_coord else []),
1208
+ dict(type="RandomScale", scale=[scale, scale]),
1209
+ *([dict(type="CenterShift", apply_z=True)] if apply_z_positive else []),
1210
+ dict(
1211
+ type="GridSample",
1212
+ grid_size=0.01,
1213
+ hash_type="fnv",
1214
+ mode="train",
1215
+ return_grid_coord=True,
1216
+ return_inverse=True,
1217
+ ),
1218
+ dict(type="NormalizeColor"),
1219
+ dict(type="ToTensor"),
1220
+ dict(
1221
+ type="Collect",
1222
+ keys=("coord", "grid_coord", "color", "inverse"),
1223
+ feat_keys=("coord", "color", "normal"),
1224
+ ),
1225
+ ]
1226
+ return Compose(config)
utonia/utils.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ General utils
3
+
4
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
5
+ Please cite our work if the code is helpful to you.
6
+ """
7
+
8
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+
23
+ import os
24
+ import random
25
+ import numpy as np
26
+ import torch
27
+ import torch.backends.cudnn as cudnn
28
+ from datetime import datetime
29
+
30
+
31
+ @torch.no_grad()
32
+ def offset2bincount(offset):
33
+ return torch.diff(
34
+ offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long)
35
+ )
36
+
37
+
38
+ @torch.no_grad()
39
+ def bincount2offset(bincount):
40
+ return torch.cumsum(bincount, dim=0)
41
+
42
+
43
+ @torch.no_grad()
44
+ def offset2batch(offset):
45
+ bincount = offset2bincount(offset)
46
+ return torch.arange(
47
+ len(bincount), device=offset.device, dtype=torch.long
48
+ ).repeat_interleave(bincount)
49
+
50
+
51
+ @torch.no_grad()
52
+ def batch2offset(batch):
53
+ return torch.cumsum(batch.bincount(), dim=0).long()
54
+
55
+
56
+ def get_random_seed():
57
+ seed = (
58
+ os.getpid()
59
+ + int(datetime.now().strftime("%S%f"))
60
+ + int.from_bytes(os.urandom(2), "big")
61
+ )
62
+ return seed
63
+
64
+
65
+ def set_seed(seed=None):
66
+ if seed is None:
67
+ seed = get_random_seed()
68
+ random.seed(seed)
69
+ np.random.seed(seed)
70
+ torch.manual_seed(seed)
71
+ torch.cuda.manual_seed(seed)
72
+ torch.cuda.manual_seed_all(seed)
73
+ cudnn.benchmark = False
74
+ cudnn.deterministic = True
75
+ os.environ["PYTHONHASHSEED"] = str(seed)
vggt/__init__.py ADDED
File without changes
vggt/heads/camera_head.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from vggt.layers import Mlp
15
+ from vggt.layers.block import Block
16
+ from vggt.heads.head_act import activate_pose
17
+
18
+
19
+ class CameraHead(nn.Module):
20
+ """
21
+ CameraHead predicts camera parameters from token representations using iterative refinement.
22
+
23
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dim_in: int = 2048,
29
+ trunk_depth: int = 4,
30
+ pose_encoding_type: str = "absT_quaR_FoV",
31
+ num_heads: int = 16,
32
+ mlp_ratio: int = 4,
33
+ init_values: float = 0.01,
34
+ trans_act: str = "linear",
35
+ quat_act: str = "linear",
36
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
37
+ ):
38
+ super().__init__()
39
+
40
+ if pose_encoding_type == "absT_quaR_FoV":
41
+ self.target_dim = 9
42
+ else:
43
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
44
+
45
+ self.trans_act = trans_act
46
+ self.quat_act = quat_act
47
+ self.fl_act = fl_act
48
+ self.trunk_depth = trunk_depth
49
+
50
+ # Build the trunk using a sequence of transformer blocks.
51
+ self.trunk = nn.Sequential(
52
+ *[
53
+ Block(
54
+ dim=dim_in,
55
+ num_heads=num_heads,
56
+ mlp_ratio=mlp_ratio,
57
+ init_values=init_values,
58
+ )
59
+ for _ in range(trunk_depth)
60
+ ]
61
+ )
62
+
63
+ # Normalizations for camera token and trunk output.
64
+ self.token_norm = nn.LayerNorm(dim_in)
65
+ self.trunk_norm = nn.LayerNorm(dim_in)
66
+
67
+ # Learnable empty camera pose token.
68
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
69
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
70
+
71
+ # Module for producing modulation parameters: shift, scale, and a gate.
72
+ self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
73
+
74
+ # Adaptive layer normalization without affine parameters.
75
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
76
+ self.pose_branch = Mlp(
77
+ in_features=dim_in,
78
+ hidden_features=dim_in // 2,
79
+ out_features=self.target_dim,
80
+ drop=0,
81
+ )
82
+
83
+ def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
84
+ """
85
+ Forward pass to predict camera parameters.
86
+
87
+ Args:
88
+ aggregated_tokens_list (list): List of token tensors from the network;
89
+ the last tensor is used for prediction.
90
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
91
+
92
+ Returns:
93
+ list: A list of predicted camera encodings (post-activation) from each iteration.
94
+ """
95
+ # Use tokens from the last block for camera prediction.
96
+ tokens = aggregated_tokens_list[-1]
97
+
98
+ # Extract the camera tokens
99
+ pose_tokens = tokens[:, :, 0]
100
+ pose_tokens = self.token_norm(pose_tokens)
101
+
102
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
103
+ return pred_pose_enc_list
104
+
105
+ def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
106
+ """
107
+ Iteratively refine camera pose predictions.
108
+
109
+ Args:
110
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
111
+ num_iterations (int): Number of refinement iterations.
112
+
113
+ Returns:
114
+ list: List of activated camera encodings from each iteration.
115
+ """
116
+ B, S, C = pose_tokens.shape # S is expected to be 1.
117
+ pred_pose_enc = None
118
+ pred_pose_enc_list = []
119
+
120
+ for _ in range(num_iterations):
121
+ # Use a learned empty pose for the first iteration.
122
+ if pred_pose_enc is None:
123
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
124
+ else:
125
+ # Detach the previous prediction to avoid backprop through time.
126
+ pred_pose_enc = pred_pose_enc.detach()
127
+ module_input = self.embed_pose(pred_pose_enc)
128
+
129
+ # Generate modulation parameters and split them into shift, scale, and gate components.
130
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
131
+
132
+ # Adaptive layer normalization and modulation.
133
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
134
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
135
+
136
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
137
+ # Compute the delta update for the pose encoding.
138
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
139
+
140
+ if pred_pose_enc is None:
141
+ pred_pose_enc = pred_pose_enc_delta
142
+ else:
143
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
144
+
145
+ # Apply final activation functions for translation, quaternion, and field-of-view.
146
+ activated_pose = activate_pose(
147
+ pred_pose_enc,
148
+ trans_act=self.trans_act,
149
+ quat_act=self.quat_act,
150
+ fl_act=self.fl_act,
151
+ )
152
+ pred_pose_enc_list.append(activated_pose)
153
+
154
+ return pred_pose_enc_list
155
+
156
+
157
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
158
+ """
159
+ Modulate the input tensor using scaling and shifting parameters.
160
+ """
161
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
162
+ return x * (1 + scale) + shift
vggt/heads/dpt_head.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import List, Dict, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from .head_act import activate_head
18
+ from .utils import create_uv_grid, position_grid_to_embed
19
+
20
+
21
+ class DPTHead(nn.Module):
22
+ """
23
+ DPT Head for dense prediction tasks.
24
+
25
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
26
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
27
+ backbone and produces dense predictions by fusing multi-scale features.
28
+
29
+ Args:
30
+ dim_in (int): Input dimension (channels).
31
+ patch_size (int, optional): Patch size. Default is 14.
32
+ output_dim (int, optional): Number of output channels. Default is 4.
33
+ activation (str, optional): Activation type. Default is "inv_log".
34
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
35
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
36
+ out_channels (List[int], optional): Output channels for each intermediate layer.
37
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
38
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
39
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
40
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dim_in: int,
46
+ patch_size: int = 14,
47
+ output_dim: int = 4,
48
+ activation: str = "inv_log",
49
+ conf_activation: str = "expp1",
50
+ features: int = 256,
51
+ out_channels: List[int] = [256, 512, 1024, 1024],
52
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
53
+ pos_embed: bool = True,
54
+ feature_only: bool = False,
55
+ down_ratio: int = 1,
56
+ ) -> None:
57
+ super(DPTHead, self).__init__()
58
+ self.patch_size = patch_size
59
+ self.activation = activation
60
+ self.conf_activation = conf_activation
61
+ self.pos_embed = pos_embed
62
+ self.feature_only = feature_only
63
+ self.down_ratio = down_ratio
64
+ self.intermediate_layer_idx = intermediate_layer_idx
65
+
66
+ self.norm = nn.LayerNorm(dim_in)
67
+
68
+ # Projection layers for each output channel from tokens.
69
+ self.projects = nn.ModuleList(
70
+ [
71
+ nn.Conv2d(
72
+ in_channels=dim_in,
73
+ out_channels=oc,
74
+ kernel_size=1,
75
+ stride=1,
76
+ padding=0,
77
+ )
78
+ for oc in out_channels
79
+ ]
80
+ )
81
+
82
+ # Resize layers for upsampling feature maps.
83
+ self.resize_layers = nn.ModuleList(
84
+ [
85
+ nn.ConvTranspose2d(
86
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
87
+ ),
88
+ nn.ConvTranspose2d(
89
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
90
+ ),
91
+ nn.Identity(),
92
+ nn.Conv2d(
93
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
94
+ ),
95
+ ]
96
+ )
97
+
98
+ self.scratch = _make_scratch(
99
+ out_channels,
100
+ features,
101
+ expand=False,
102
+ )
103
+
104
+ # Attach additional modules to scratch.
105
+ self.scratch.stem_transpose = None
106
+ self.scratch.refinenet1 = _make_fusion_block(features)
107
+ self.scratch.refinenet2 = _make_fusion_block(features)
108
+ self.scratch.refinenet3 = _make_fusion_block(features)
109
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
110
+
111
+ head_features_1 = features
112
+ head_features_2 = 32
113
+
114
+ if feature_only:
115
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
116
+ else:
117
+ self.scratch.output_conv1 = nn.Conv2d(
118
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
119
+ )
120
+ conv2_in_channels = head_features_1 // 2
121
+
122
+ self.scratch.output_conv2 = nn.Sequential(
123
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
124
+ nn.ReLU(inplace=True),
125
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
126
+ )
127
+
128
+ def forward(
129
+ self,
130
+ aggregated_tokens_list: List[torch.Tensor],
131
+ images: torch.Tensor,
132
+ patch_start_idx: int,
133
+ frames_chunk_size: int = 8,
134
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
135
+ """
136
+ Forward pass through the DPT head, supports processing by chunking frames.
137
+ Args:
138
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
139
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
140
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
141
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
142
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
143
+ If None or larger than S, all frames are processed at once. Default: 8.
144
+
145
+ Returns:
146
+ Tensor or Tuple[Tensor, Tensor]:
147
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
148
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
149
+ """
150
+ B, S, _, H, W = images.shape
151
+
152
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
153
+ if frames_chunk_size is None or frames_chunk_size >= S:
154
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
155
+
156
+ # Otherwise, process frames in chunks to manage memory usage
157
+ assert frames_chunk_size > 0
158
+
159
+ # Process frames in batches
160
+ all_preds = []
161
+ all_conf = []
162
+
163
+ for frames_start_idx in range(0, S, frames_chunk_size):
164
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
165
+
166
+ # Process batch of frames
167
+ if self.feature_only:
168
+ chunk_output = self._forward_impl(
169
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
170
+ )
171
+ all_preds.append(chunk_output)
172
+ else:
173
+ chunk_preds, chunk_conf = self._forward_impl(
174
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
175
+ )
176
+ all_preds.append(chunk_preds)
177
+ all_conf.append(chunk_conf)
178
+
179
+ # Concatenate results along the sequence dimension
180
+ if self.feature_only:
181
+ return torch.cat(all_preds, dim=1)
182
+ else:
183
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
184
+
185
+ def _forward_impl(
186
+ self,
187
+ aggregated_tokens_list: List[torch.Tensor],
188
+ images: torch.Tensor,
189
+ patch_start_idx: int,
190
+ frames_start_idx: int = None,
191
+ frames_end_idx: int = None,
192
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
193
+ """
194
+ Implementation of the forward pass through the DPT head.
195
+
196
+ This method processes a specific chunk of frames from the sequence.
197
+
198
+ Args:
199
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
200
+ images (Tensor): Input images with shape [B, S, 3, H, W].
201
+ patch_start_idx (int): Starting index for patch tokens.
202
+ frames_start_idx (int, optional): Starting index for frames to process.
203
+ frames_end_idx (int, optional): Ending index for frames to process.
204
+
205
+ Returns:
206
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
207
+ """
208
+ if frames_start_idx is not None and frames_end_idx is not None:
209
+ images = images[:, frames_start_idx:frames_end_idx]
210
+
211
+ B, S, _, H, W = images.shape
212
+
213
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
214
+
215
+ out = []
216
+ dpt_idx = 0
217
+
218
+ for layer_idx in self.intermediate_layer_idx:
219
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
220
+
221
+ # Select frames if processing a chunk
222
+ if frames_start_idx is not None and frames_end_idx is not None:
223
+ x = x[:, frames_start_idx:frames_end_idx]
224
+
225
+ x = x.view(B * S, -1, x.shape[-1])
226
+
227
+ x = self.norm(x)
228
+
229
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
230
+
231
+ x = self.projects[dpt_idx](x)
232
+ if self.pos_embed:
233
+ x = self._apply_pos_embed(x, W, H)
234
+ x = self.resize_layers[dpt_idx](x)
235
+
236
+ out.append(x)
237
+ dpt_idx += 1
238
+
239
+ # Fuse features from multiple layers.
240
+ out = self.scratch_forward(out)
241
+ # Interpolate fused output to match target image resolution.
242
+ out = custom_interpolate(
243
+ out,
244
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
245
+ mode="bilinear",
246
+ align_corners=True,
247
+ )
248
+
249
+ if self.pos_embed:
250
+ out = self._apply_pos_embed(out, W, H)
251
+
252
+ if self.feature_only:
253
+ return out.view(B, S, *out.shape[1:])
254
+
255
+ out = self.scratch.output_conv2(out)
256
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
257
+
258
+ preds = preds.view(B, S, *preds.shape[1:])
259
+ conf = conf.view(B, S, *conf.shape[1:])
260
+ return preds, conf
261
+
262
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
263
+ """
264
+ Apply positional embedding to tensor x.
265
+ """
266
+ patch_w = x.shape[-1]
267
+ patch_h = x.shape[-2]
268
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
269
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
270
+ pos_embed = pos_embed * ratio
271
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
272
+ return x + pos_embed
273
+
274
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
275
+ """
276
+ Forward pass through the fusion blocks.
277
+
278
+ Args:
279
+ features (List[Tensor]): List of feature maps from different layers.
280
+
281
+ Returns:
282
+ Tensor: Fused feature map.
283
+ """
284
+ layer_1, layer_2, layer_3, layer_4 = features
285
+
286
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
287
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
288
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
289
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
290
+
291
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
292
+ del layer_4_rn, layer_4
293
+
294
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
295
+ del layer_3_rn, layer_3
296
+
297
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
298
+ del layer_2_rn, layer_2
299
+
300
+ out = self.scratch.refinenet1(out, layer_1_rn)
301
+ del layer_1_rn, layer_1
302
+
303
+ out = self.scratch.output_conv1(out)
304
+ return out
305
+
306
+
307
+ ################################################################################
308
+ # Modules
309
+ ################################################################################
310
+
311
+
312
+ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
313
+ return FeatureFusionBlock(
314
+ features,
315
+ nn.ReLU(inplace=True),
316
+ deconv=False,
317
+ bn=False,
318
+ expand=False,
319
+ align_corners=True,
320
+ size=size,
321
+ has_residual=has_residual,
322
+ groups=groups,
323
+ )
324
+
325
+
326
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
327
+ scratch = nn.Module()
328
+ out_shape1 = out_shape
329
+ out_shape2 = out_shape
330
+ out_shape3 = out_shape
331
+ if len(in_shape) >= 4:
332
+ out_shape4 = out_shape
333
+
334
+ if expand:
335
+ out_shape1 = out_shape
336
+ out_shape2 = out_shape * 2
337
+ out_shape3 = out_shape * 4
338
+ if len(in_shape) >= 4:
339
+ out_shape4 = out_shape * 8
340
+
341
+ scratch.layer1_rn = nn.Conv2d(
342
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
343
+ )
344
+ scratch.layer2_rn = nn.Conv2d(
345
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
346
+ )
347
+ scratch.layer3_rn = nn.Conv2d(
348
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
349
+ )
350
+ if len(in_shape) >= 4:
351
+ scratch.layer4_rn = nn.Conv2d(
352
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
353
+ )
354
+ return scratch
355
+
356
+
357
+ class ResidualConvUnit(nn.Module):
358
+ """Residual convolution module."""
359
+
360
+ def __init__(self, features, activation, bn, groups=1):
361
+ """Init.
362
+
363
+ Args:
364
+ features (int): number of features
365
+ """
366
+ super().__init__()
367
+
368
+ self.bn = bn
369
+ self.groups = groups
370
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
371
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
372
+
373
+ self.norm1 = None
374
+ self.norm2 = None
375
+
376
+ self.activation = activation
377
+ self.skip_add = nn.quantized.FloatFunctional()
378
+
379
+ def forward(self, x):
380
+ """Forward pass.
381
+
382
+ Args:
383
+ x (tensor): input
384
+
385
+ Returns:
386
+ tensor: output
387
+ """
388
+
389
+ out = self.activation(x)
390
+ out = self.conv1(out)
391
+ if self.norm1 is not None:
392
+ out = self.norm1(out)
393
+
394
+ out = self.activation(out)
395
+ out = self.conv2(out)
396
+ if self.norm2 is not None:
397
+ out = self.norm2(out)
398
+
399
+ return self.skip_add.add(out, x)
400
+
401
+
402
+ class FeatureFusionBlock(nn.Module):
403
+ """Feature fusion block."""
404
+
405
+ def __init__(
406
+ self,
407
+ features,
408
+ activation,
409
+ deconv=False,
410
+ bn=False,
411
+ expand=False,
412
+ align_corners=True,
413
+ size=None,
414
+ has_residual=True,
415
+ groups=1,
416
+ ):
417
+ """Init.
418
+
419
+ Args:
420
+ features (int): number of features
421
+ """
422
+ super(FeatureFusionBlock, self).__init__()
423
+
424
+ self.deconv = deconv
425
+ self.align_corners = align_corners
426
+ self.groups = groups
427
+ self.expand = expand
428
+ out_features = features
429
+ if self.expand == True:
430
+ out_features = features // 2
431
+
432
+ self.out_conv = nn.Conv2d(
433
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
434
+ )
435
+
436
+ if has_residual:
437
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
438
+
439
+ self.has_residual = has_residual
440
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
441
+
442
+ self.skip_add = nn.quantized.FloatFunctional()
443
+ self.size = size
444
+
445
+ def forward(self, *xs, size=None):
446
+ """Forward pass.
447
+
448
+ Returns:
449
+ tensor: output
450
+ """
451
+ output = xs[0]
452
+
453
+ if self.has_residual:
454
+ res = self.resConfUnit1(xs[1])
455
+ output = self.skip_add.add(output, res)
456
+
457
+ output = self.resConfUnit2(output)
458
+
459
+ if (size is None) and (self.size is None):
460
+ modifier = {"scale_factor": 2}
461
+ elif size is None:
462
+ modifier = {"size": self.size}
463
+ else:
464
+ modifier = {"size": size}
465
+
466
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
467
+ output = self.out_conv(output)
468
+
469
+ return output
470
+
471
+
472
+ def custom_interpolate(
473
+ x: torch.Tensor,
474
+ size: Tuple[int, int] = None,
475
+ scale_factor: float = None,
476
+ mode: str = "bilinear",
477
+ align_corners: bool = True,
478
+ ) -> torch.Tensor:
479
+ """
480
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
481
+ """
482
+ if size is None:
483
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
484
+
485
+ INT_MAX = 1610612736
486
+
487
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
488
+
489
+ if input_elements > INT_MAX:
490
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
491
+ interpolated_chunks = [
492
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
493
+ ]
494
+ x = torch.cat(interpolated_chunks, dim=0)
495
+ return x.contiguous()
496
+ else:
497
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
vggt/heads/head_act.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
13
+ """
14
+ Activate pose parameters with specified activation functions.
15
+
16
+ Args:
17
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
18
+ trans_act: Activation type for translation component
19
+ quat_act: Activation type for quaternion component
20
+ fl_act: Activation type for focal length component
21
+
22
+ Returns:
23
+ Activated pose parameters tensor
24
+ """
25
+ T = pred_pose_enc[..., :3]
26
+ quat = pred_pose_enc[..., 3:7]
27
+ fl = pred_pose_enc[..., 7:] # or fov
28
+
29
+ T = base_pose_act(T, trans_act)
30
+ quat = base_pose_act(quat, quat_act)
31
+ fl = base_pose_act(fl, fl_act) # or fov
32
+
33
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
34
+
35
+ return pred_pose_enc
36
+
37
+
38
+ def base_pose_act(pose_enc, act_type="linear"):
39
+ """
40
+ Apply basic activation function to pose parameters.
41
+
42
+ Args:
43
+ pose_enc: Tensor containing encoded pose parameters
44
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
45
+
46
+ Returns:
47
+ Activated pose parameters
48
+ """
49
+ if act_type == "linear":
50
+ return pose_enc
51
+ elif act_type == "inv_log":
52
+ return inverse_log_transform(pose_enc)
53
+ elif act_type == "exp":
54
+ return torch.exp(pose_enc)
55
+ elif act_type == "relu":
56
+ return F.relu(pose_enc)
57
+ else:
58
+ raise ValueError(f"Unknown act_type: {act_type}")
59
+
60
+
61
+ def activate_head(out, activation="norm_exp", conf_activation="expp1"):
62
+ """
63
+ Process network output to extract 3D points and confidence values.
64
+
65
+ Args:
66
+ out: Network output tensor (B, C, H, W)
67
+ activation: Activation type for 3D points
68
+ conf_activation: Activation type for confidence values
69
+
70
+ Returns:
71
+ Tuple of (3D points tensor, confidence tensor)
72
+ """
73
+ # Move channels from last dim to the 4th dimension => (B, H, W, C)
74
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
75
+
76
+ # Split into xyz (first C-1 channels) and confidence (last channel)
77
+ xyz = fmap[:, :, :, :-1]
78
+ conf = fmap[:, :, :, -1]
79
+
80
+ if activation == "norm_exp":
81
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
82
+ xyz_normed = xyz / d
83
+ pts3d = xyz_normed * torch.expm1(d)
84
+ elif activation == "norm":
85
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
86
+ elif activation == "exp":
87
+ pts3d = torch.exp(xyz)
88
+ elif activation == "relu":
89
+ pts3d = F.relu(xyz)
90
+ elif activation == "inv_log":
91
+ pts3d = inverse_log_transform(xyz)
92
+ elif activation == "xy_inv_log":
93
+ xy, z = xyz.split([2, 1], dim=-1)
94
+ z = inverse_log_transform(z)
95
+ pts3d = torch.cat([xy * z, z], dim=-1)
96
+ elif activation == "sigmoid":
97
+ pts3d = torch.sigmoid(xyz)
98
+ elif activation == "linear":
99
+ pts3d = xyz
100
+ else:
101
+ raise ValueError(f"Unknown activation: {activation}")
102
+
103
+ if conf_activation == "expp1":
104
+ conf_out = 1 + conf.exp()
105
+ elif conf_activation == "expp0":
106
+ conf_out = conf.exp()
107
+ elif conf_activation == "sigmoid":
108
+ conf_out = torch.sigmoid(conf)
109
+ else:
110
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
111
+
112
+ return pts3d, conf_out
113
+
114
+
115
+ def inverse_log_transform(y):
116
+ """
117
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
118
+
119
+ Args:
120
+ y: Input tensor
121
+
122
+ Returns:
123
+ Transformed tensor
124
+ """
125
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
vggt/heads/track_head.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch.nn as nn
8
+ from .dpt_head import DPTHead
9
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
10
+
11
+
12
+ class TrackHead(nn.Module):
13
+ """
14
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
15
+ The tracking is performed iteratively, refining predictions over multiple iterations.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ dim_in,
21
+ patch_size=14,
22
+ features=128,
23
+ iters=4,
24
+ predict_conf=True,
25
+ stride=2,
26
+ corr_levels=7,
27
+ corr_radius=4,
28
+ hidden_size=384,
29
+ ):
30
+ """
31
+ Initialize the TrackHead module.
32
+
33
+ Args:
34
+ dim_in (int): Input dimension of tokens from the backbone.
35
+ patch_size (int): Size of image patches used in the vision transformer.
36
+ features (int): Number of feature channels in the feature extractor output.
37
+ iters (int): Number of refinement iterations for tracking predictions.
38
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
39
+ stride (int): Stride value for the tracker predictor.
40
+ corr_levels (int): Number of correlation pyramid levels
41
+ corr_radius (int): Radius for correlation computation, controlling the search area.
42
+ hidden_size (int): Size of hidden layers in the tracker network.
43
+ """
44
+ super().__init__()
45
+
46
+ self.patch_size = patch_size
47
+
48
+ # Feature extractor based on DPT architecture
49
+ # Processes tokens into feature maps for tracking
50
+ self.feature_extractor = DPTHead(
51
+ dim_in=dim_in,
52
+ patch_size=patch_size,
53
+ features=features,
54
+ feature_only=True, # Only output features, no activation
55
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
56
+ pos_embed=False,
57
+ )
58
+
59
+ # Tracker module that predicts point trajectories
60
+ # Takes feature maps and predicts coordinates and visibility
61
+ self.tracker = BaseTrackerPredictor(
62
+ latent_dim=features, # Match the output_dim of feature extractor
63
+ predict_conf=predict_conf,
64
+ stride=stride,
65
+ corr_levels=corr_levels,
66
+ corr_radius=corr_radius,
67
+ hidden_size=hidden_size,
68
+ )
69
+
70
+ self.iters = iters
71
+
72
+ def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
73
+ """
74
+ Forward pass of the TrackHead.
75
+
76
+ Args:
77
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
78
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
79
+ B = batch size, S = sequence length.
80
+ patch_start_idx (int): Starting index for patch tokens.
81
+ query_points (torch.Tensor, optional): Initial query points to track.
82
+ If None, points are initialized by the tracker.
83
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
84
+
85
+ Returns:
86
+ tuple:
87
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
88
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
89
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
90
+ """
91
+ B, S, _, H, W = images.shape
92
+
93
+ # Extract features from tokens
94
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
95
+ feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
96
+
97
+ # Use default iterations if not specified
98
+ if iters is None:
99
+ iters = self.iters
100
+
101
+ # Perform tracking using the extracted features
102
+ coord_preds, vis_scores, conf_scores = self.tracker(
103
+ query_points=query_points,
104
+ fmaps=feature_maps,
105
+ iters=iters,
106
+ )
107
+
108
+ return coord_preds, vis_scores, conf_scores
vggt/heads/track_modules/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
vggt/heads/track_modules/base_track_predictor.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+
12
+ from .blocks import EfficientUpdateFormer, CorrBlock
13
+ from .utils import sample_features4d, get_2d_embedding, get_2d_sincos_pos_embed
14
+ from .modules import Mlp
15
+
16
+
17
+ class BaseTrackerPredictor(nn.Module):
18
+ def __init__(
19
+ self,
20
+ stride=1,
21
+ corr_levels=5,
22
+ corr_radius=4,
23
+ latent_dim=128,
24
+ hidden_size=384,
25
+ use_spaceatt=True,
26
+ depth=6,
27
+ max_scale=518,
28
+ predict_conf=True,
29
+ ):
30
+ super(BaseTrackerPredictor, self).__init__()
31
+ """
32
+ The base template to create a track predictor
33
+
34
+ Modified from https://github.com/facebookresearch/co-tracker/
35
+ and https://github.com/facebookresearch/vggsfm
36
+ """
37
+
38
+ self.stride = stride
39
+ self.latent_dim = latent_dim
40
+ self.corr_levels = corr_levels
41
+ self.corr_radius = corr_radius
42
+ self.hidden_size = hidden_size
43
+ self.max_scale = max_scale
44
+ self.predict_conf = predict_conf
45
+
46
+ self.flows_emb_dim = latent_dim // 2
47
+
48
+ self.corr_mlp = Mlp(
49
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
50
+ hidden_features=self.hidden_size,
51
+ out_features=self.latent_dim,
52
+ )
53
+
54
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
55
+
56
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
57
+
58
+ space_depth = depth if use_spaceatt else 0
59
+ time_depth = depth
60
+
61
+ self.updateformer = EfficientUpdateFormer(
62
+ space_depth=space_depth,
63
+ time_depth=time_depth,
64
+ input_dim=self.transformer_dim,
65
+ hidden_size=self.hidden_size,
66
+ output_dim=self.latent_dim + 2,
67
+ mlp_ratio=4.0,
68
+ add_space_attn=use_spaceatt,
69
+ )
70
+
71
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
72
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
73
+
74
+ # A linear layer to update track feats at each iteration
75
+ self.ffeat_updater = nn.Sequential(nn.Linear(self.latent_dim, self.latent_dim), nn.GELU())
76
+
77
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
78
+
79
+ if predict_conf:
80
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
81
+
82
+ def forward(self, query_points, fmaps=None, iters=6, return_feat=False, down_ratio=1, apply_sigmoid=True):
83
+ """
84
+ query_points: B x N x 2, the number of batches, tracks, and xy
85
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
86
+ note HH and WW is the size of feature maps instead of original images
87
+ """
88
+ B, N, D = query_points.shape
89
+ B, S, C, HH, WW = fmaps.shape
90
+
91
+ assert D == 2, "Input points must be 2D coordinates"
92
+
93
+ # apply a layernorm to fmaps here
94
+ fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
95
+ fmaps = fmaps.permute(0, 1, 4, 2, 3)
96
+
97
+ # Scale the input query_points because we may downsample the images
98
+ # by down_ratio or self.stride
99
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
100
+ # its query_points should be query_points/4
101
+ if down_ratio > 1:
102
+ query_points = query_points / float(down_ratio)
103
+
104
+ query_points = query_points / float(self.stride)
105
+
106
+ # Init with coords as the query points
107
+ # It means the search will start from the position of query points at the reference frames
108
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
109
+
110
+ # Sample/extract the features of the query points in the query frame
111
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
112
+
113
+ # init track feats by query feats
114
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
115
+ # back up the init coords
116
+ coords_backup = coords.clone()
117
+
118
+ fcorr_fn = CorrBlock(fmaps, num_levels=self.corr_levels, radius=self.corr_radius)
119
+
120
+ coord_preds = []
121
+
122
+ # Iterative Refinement
123
+ for _ in range(iters):
124
+ # Detach the gradients from the last iteration
125
+ # (in my experience, not very important for performance)
126
+ coords = coords.detach()
127
+
128
+ fcorrs = fcorr_fn.corr_sample(track_feats, coords)
129
+
130
+ corr_dim = fcorrs.shape[3]
131
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
132
+ fcorrs_ = self.corr_mlp(fcorrs_)
133
+
134
+ # Movement of current coords relative to query points
135
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
136
+
137
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
138
+
139
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
140
+ flows_emb = torch.cat([flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1)
141
+
142
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
143
+
144
+ # Concatenate them as the input for the transformers
145
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
146
+
147
+ # 2D positional embed
148
+ # TODO: this can be much simplified
149
+ pos_embed = get_2d_sincos_pos_embed(self.transformer_dim, grid_size=(HH, WW)).to(query_points.device)
150
+ sampled_pos_emb = sample_features4d(pos_embed.expand(B, -1, -1, -1), coords[:, 0])
151
+
152
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(1)
153
+
154
+ x = transformer_input + sampled_pos_emb
155
+
156
+ # Add the query ref token to the track feats
157
+ query_ref_token = torch.cat(
158
+ [self.query_ref_token[:, 0:1], self.query_ref_token[:, 1:2].expand(-1, S - 1, -1)], dim=1
159
+ )
160
+ x = x + query_ref_token.to(x.device).to(x.dtype)
161
+
162
+ # B, N, S, C
163
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
164
+
165
+ # Compute the delta coordinates and delta track features
166
+ delta, _ = self.updateformer(x)
167
+
168
+ # BN, S, C
169
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
170
+ delta_coords_ = delta[:, :, :2]
171
+ delta_feats_ = delta[:, :, 2:]
172
+
173
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
174
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
175
+
176
+ # Update the track features
177
+ track_feats_ = self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
178
+
179
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(0, 2, 1, 3) # BxSxNxC
180
+
181
+ # B x S x N x 2
182
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
183
+
184
+ # Force coord0 as query
185
+ # because we assume the query points should not be changed
186
+ coords[:, 0] = coords_backup[:, 0]
187
+
188
+ # The predicted tracks are in the original image scale
189
+ if down_ratio > 1:
190
+ coord_preds.append(coords * self.stride * down_ratio)
191
+ else:
192
+ coord_preds.append(coords * self.stride)
193
+
194
+ # B, S, N
195
+ vis_e = self.vis_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
196
+ if apply_sigmoid:
197
+ vis_e = torch.sigmoid(vis_e)
198
+
199
+ if self.predict_conf:
200
+ conf_e = self.conf_predictor(track_feats.reshape(B * S * N, self.latent_dim)).reshape(B, S, N)
201
+ if apply_sigmoid:
202
+ conf_e = torch.sigmoid(conf_e)
203
+ else:
204
+ conf_e = None
205
+
206
+ if return_feat:
207
+ return coord_preds, vis_e, track_feats, query_track_feat, conf_e
208
+ else:
209
+ return coord_preds, vis_e, conf_e
vggt/heads/track_modules/blocks.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Modified from https://github.com/facebookresearch/co-tracker/
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from .utils import bilinear_sampler
16
+ from .modules import Mlp, AttnBlock, CrossAttnBlock, ResidualBlock
17
+
18
+
19
+ class EfficientUpdateFormer(nn.Module):
20
+ """
21
+ Transformer model that updates track estimates.
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ space_depth=6,
27
+ time_depth=6,
28
+ input_dim=320,
29
+ hidden_size=384,
30
+ num_heads=8,
31
+ output_dim=130,
32
+ mlp_ratio=4.0,
33
+ add_space_attn=True,
34
+ num_virtual_tracks=64,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.out_channels = 2
39
+ self.num_heads = num_heads
40
+ self.hidden_size = hidden_size
41
+ self.add_space_attn = add_space_attn
42
+
43
+ # Add input LayerNorm before linear projection
44
+ self.input_norm = nn.LayerNorm(input_dim)
45
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
46
+
47
+ # Add output LayerNorm before final projection
48
+ self.output_norm = nn.LayerNorm(hidden_size)
49
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
50
+ self.num_virtual_tracks = num_virtual_tracks
51
+
52
+ if self.add_space_attn:
53
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
54
+ else:
55
+ self.virual_tracks = None
56
+
57
+ self.time_blocks = nn.ModuleList(
58
+ [
59
+ AttnBlock(
60
+ hidden_size,
61
+ num_heads,
62
+ mlp_ratio=mlp_ratio,
63
+ attn_class=nn.MultiheadAttention,
64
+ )
65
+ for _ in range(time_depth)
66
+ ]
67
+ )
68
+
69
+ if add_space_attn:
70
+ self.space_virtual_blocks = nn.ModuleList(
71
+ [
72
+ AttnBlock(
73
+ hidden_size,
74
+ num_heads,
75
+ mlp_ratio=mlp_ratio,
76
+ attn_class=nn.MultiheadAttention,
77
+ )
78
+ for _ in range(space_depth)
79
+ ]
80
+ )
81
+ self.space_point2virtual_blocks = nn.ModuleList(
82
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
83
+ )
84
+ self.space_virtual2point_blocks = nn.ModuleList(
85
+ [CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(space_depth)]
86
+ )
87
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
88
+ self.initialize_weights()
89
+
90
+ def initialize_weights(self):
91
+ def _basic_init(module):
92
+ if isinstance(module, nn.Linear):
93
+ torch.nn.init.xavier_uniform_(module.weight)
94
+ if module.bias is not None:
95
+ nn.init.constant_(module.bias, 0)
96
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
97
+
98
+ self.apply(_basic_init)
99
+
100
+ def forward(self, input_tensor, mask=None):
101
+ # Apply input LayerNorm
102
+ input_tensor = self.input_norm(input_tensor)
103
+ tokens = self.input_transform(input_tensor)
104
+
105
+ init_tokens = tokens
106
+
107
+ B, _, T, _ = tokens.shape
108
+
109
+ if self.add_space_attn:
110
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
111
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
112
+
113
+ _, N, _, _ = tokens.shape
114
+
115
+ j = 0
116
+ for i in range(len(self.time_blocks)):
117
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
118
+
119
+ time_tokens = self.time_blocks[i](time_tokens)
120
+
121
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
122
+ if self.add_space_attn and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0):
123
+ space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1) # B N T C -> (B T) N C
124
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
125
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
126
+
127
+ virtual_tokens = self.space_virtual2point_blocks[j](virtual_tokens, point_tokens, mask=mask)
128
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
129
+ point_tokens = self.space_point2virtual_blocks[j](point_tokens, virtual_tokens, mask=mask)
130
+
131
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
132
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
133
+ j += 1
134
+
135
+ if self.add_space_attn:
136
+ tokens = tokens[:, : N - self.num_virtual_tracks]
137
+
138
+ tokens = tokens + init_tokens
139
+
140
+ # Apply output LayerNorm before final projection
141
+ tokens = self.output_norm(tokens)
142
+ flow = self.flow_head(tokens)
143
+
144
+ return flow, None
145
+
146
+
147
+ class CorrBlock:
148
+ def __init__(self, fmaps, num_levels=4, radius=4, multiple_track_feats=False, padding_mode="zeros"):
149
+ """
150
+ Build a pyramid of feature maps from the input.
151
+
152
+ fmaps: Tensor (B, S, C, H, W)
153
+ num_levels: number of pyramid levels (each downsampled by factor 2)
154
+ radius: search radius for sampling correlation
155
+ multiple_track_feats: if True, split the target features per pyramid level
156
+ padding_mode: passed to grid_sample / bilinear_sampler
157
+ """
158
+ B, S, C, H, W = fmaps.shape
159
+ self.S, self.C, self.H, self.W = S, C, H, W
160
+ self.num_levels = num_levels
161
+ self.radius = radius
162
+ self.padding_mode = padding_mode
163
+ self.multiple_track_feats = multiple_track_feats
164
+
165
+ # Build pyramid: each level is half the spatial resolution of the previous
166
+ self.fmaps_pyramid = [fmaps] # level 0 is full resolution
167
+ current_fmaps = fmaps
168
+ for i in range(num_levels - 1):
169
+ B, S, C, H, W = current_fmaps.shape
170
+ # Merge batch & sequence dimensions
171
+ current_fmaps = current_fmaps.reshape(B * S, C, H, W)
172
+ # Avg pool down by factor 2
173
+ current_fmaps = F.avg_pool2d(current_fmaps, kernel_size=2, stride=2)
174
+ _, _, H_new, W_new = current_fmaps.shape
175
+ current_fmaps = current_fmaps.reshape(B, S, C, H_new, W_new)
176
+ self.fmaps_pyramid.append(current_fmaps)
177
+
178
+ # Precompute a delta grid (of shape (2r+1, 2r+1, 2)) for sampling.
179
+ # This grid is added to the (scaled) coordinate centroids.
180
+ r = self.radius
181
+ dx = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
182
+ dy = torch.linspace(-r, r, 2 * r + 1, device=fmaps.device, dtype=fmaps.dtype)
183
+ # delta: for every (dy,dx) displacement (i.e. Δx, Δy)
184
+ self.delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), dim=-1) # shape: (2r+1, 2r+1, 2)
185
+
186
+ def corr_sample(self, targets, coords):
187
+ """
188
+ Instead of storing the entire correlation pyramid, we compute each level's correlation
189
+ volume, sample it immediately, then discard it. This saves GPU memory.
190
+
191
+ Args:
192
+ targets: Tensor (B, S, N, C) — features for the current targets.
193
+ coords: Tensor (B, S, N, 2) — coordinates at full resolution.
194
+
195
+ Returns:
196
+ Tensor (B, S, N, L) where L = num_levels * (2*radius+1)**2 (concatenated sampled correlations)
197
+ """
198
+ B, S, N, C = targets.shape
199
+
200
+ # If you have multiple track features, split them per level.
201
+ if self.multiple_track_feats:
202
+ targets_split = torch.split(targets, C // self.num_levels, dim=-1)
203
+
204
+ out_pyramid = []
205
+ for i, fmaps in enumerate(self.fmaps_pyramid):
206
+ # Get current spatial resolution H, W for this pyramid level.
207
+ B, S, C, H, W = fmaps.shape
208
+ # Reshape feature maps for correlation computation:
209
+ # fmap2s: (B, S, C, H*W)
210
+ fmap2s = fmaps.view(B, S, C, H * W)
211
+ # Choose appropriate target features.
212
+ fmap1 = targets_split[i] if self.multiple_track_feats else targets # shape: (B, S, N, C)
213
+
214
+ # Compute correlation directly
215
+ corrs = compute_corr_level(fmap1, fmap2s, C)
216
+ corrs = corrs.view(B, S, N, H, W)
217
+
218
+ # Prepare sampling grid:
219
+ # Scale down the coordinates for the current level.
220
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / (2**i)
221
+ # Make sure our precomputed delta grid is on the same device/dtype.
222
+ delta_lvl = self.delta.to(coords.device).to(coords.dtype)
223
+ # Now the grid for grid_sample is:
224
+ # coords_lvl = centroid_lvl + delta_lvl (broadcasted over grid)
225
+ coords_lvl = centroid_lvl + delta_lvl.view(1, 2 * self.radius + 1, 2 * self.radius + 1, 2)
226
+
227
+ # Sample from the correlation volume using bilinear interpolation.
228
+ # We reshape corrs to (B * S * N, 1, H, W) so grid_sample acts over each target.
229
+ corrs_sampled = bilinear_sampler(
230
+ corrs.reshape(B * S * N, 1, H, W), coords_lvl, padding_mode=self.padding_mode
231
+ )
232
+ # The sampled output is (B * S * N, 1, 2r+1, 2r+1). Flatten the last two dims.
233
+ corrs_sampled = corrs_sampled.view(B, S, N, -1) # Now shape: (B, S, N, (2r+1)^2)
234
+ out_pyramid.append(corrs_sampled)
235
+
236
+ # Concatenate all levels along the last dimension.
237
+ out = torch.cat(out_pyramid, dim=-1).contiguous()
238
+ return out
239
+
240
+
241
+ def compute_corr_level(fmap1, fmap2s, C):
242
+ # fmap1: (B, S, N, C)
243
+ # fmap2s: (B, S, C, H*W)
244
+ corrs = torch.matmul(fmap1, fmap2s) # (B, S, N, H*W)
245
+ corrs = corrs.view(fmap1.shape[0], fmap1.shape[1], fmap1.shape[2], -1) # (B, S, N, H*W)
246
+ return corrs / math.sqrt(C)
vggt/heads/track_modules/modules.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from functools import partial
12
+ from typing import Callable
13
+ import collections
14
+ from torch import Tensor
15
+ from itertools import repeat
16
+
17
+
18
+ # From PyTorch internals
19
+ def _ntuple(n):
20
+ def parse(x):
21
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
22
+ return tuple(x)
23
+ return tuple(repeat(x, n))
24
+
25
+ return parse
26
+
27
+
28
+ def exists(val):
29
+ return val is not None
30
+
31
+
32
+ def default(val, d):
33
+ return val if exists(val) else d
34
+
35
+
36
+ to_2tuple = _ntuple(2)
37
+
38
+
39
+ class ResidualBlock(nn.Module):
40
+ """
41
+ ResidualBlock: construct a block of two conv layers with residual connections
42
+ """
43
+
44
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
45
+ super(ResidualBlock, self).__init__()
46
+
47
+ self.conv1 = nn.Conv2d(
48
+ in_planes,
49
+ planes,
50
+ kernel_size=kernel_size,
51
+ padding=1,
52
+ stride=stride,
53
+ padding_mode="zeros",
54
+ )
55
+ self.conv2 = nn.Conv2d(
56
+ planes,
57
+ planes,
58
+ kernel_size=kernel_size,
59
+ padding=1,
60
+ padding_mode="zeros",
61
+ )
62
+ self.relu = nn.ReLU(inplace=True)
63
+
64
+ num_groups = planes // 8
65
+
66
+ if norm_fn == "group":
67
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
68
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
69
+ if not stride == 1:
70
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
71
+
72
+ elif norm_fn == "batch":
73
+ self.norm1 = nn.BatchNorm2d(planes)
74
+ self.norm2 = nn.BatchNorm2d(planes)
75
+ if not stride == 1:
76
+ self.norm3 = nn.BatchNorm2d(planes)
77
+
78
+ elif norm_fn == "instance":
79
+ self.norm1 = nn.InstanceNorm2d(planes)
80
+ self.norm2 = nn.InstanceNorm2d(planes)
81
+ if not stride == 1:
82
+ self.norm3 = nn.InstanceNorm2d(planes)
83
+
84
+ elif norm_fn == "none":
85
+ self.norm1 = nn.Sequential()
86
+ self.norm2 = nn.Sequential()
87
+ if not stride == 1:
88
+ self.norm3 = nn.Sequential()
89
+ else:
90
+ raise NotImplementedError
91
+
92
+ if stride == 1:
93
+ self.downsample = None
94
+ else:
95
+ self.downsample = nn.Sequential(
96
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
97
+ self.norm3,
98
+ )
99
+
100
+ def forward(self, x):
101
+ y = x
102
+ y = self.relu(self.norm1(self.conv1(y)))
103
+ y = self.relu(self.norm2(self.conv2(y)))
104
+
105
+ if self.downsample is not None:
106
+ x = self.downsample(x)
107
+
108
+ return self.relu(x + y)
109
+
110
+
111
+ class Mlp(nn.Module):
112
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
113
+
114
+ def __init__(
115
+ self,
116
+ in_features,
117
+ hidden_features=None,
118
+ out_features=None,
119
+ act_layer=nn.GELU,
120
+ norm_layer=None,
121
+ bias=True,
122
+ drop=0.0,
123
+ use_conv=False,
124
+ ):
125
+ super().__init__()
126
+ out_features = out_features or in_features
127
+ hidden_features = hidden_features or in_features
128
+ bias = to_2tuple(bias)
129
+ drop_probs = to_2tuple(drop)
130
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
131
+
132
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
133
+ self.act = act_layer()
134
+ self.drop1 = nn.Dropout(drop_probs[0])
135
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
136
+ self.drop2 = nn.Dropout(drop_probs[1])
137
+
138
+ def forward(self, x):
139
+ x = self.fc1(x)
140
+ x = self.act(x)
141
+ x = self.drop1(x)
142
+ x = self.fc2(x)
143
+ x = self.drop2(x)
144
+ return x
145
+
146
+
147
+ class AttnBlock(nn.Module):
148
+ def __init__(
149
+ self,
150
+ hidden_size,
151
+ num_heads,
152
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
153
+ mlp_ratio=4.0,
154
+ **block_kwargs
155
+ ):
156
+ """
157
+ Self attention block
158
+ """
159
+ super().__init__()
160
+
161
+ self.norm1 = nn.LayerNorm(hidden_size)
162
+ self.norm2 = nn.LayerNorm(hidden_size)
163
+
164
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
165
+
166
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
167
+
168
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
169
+
170
+ def forward(self, x, mask=None):
171
+ # Prepare the mask for PyTorch's attention (it expects a different format)
172
+ # attn_mask = mask if mask is not None else None
173
+ # Normalize before attention
174
+ x = self.norm1(x)
175
+
176
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
177
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
178
+
179
+ attn_output, _ = self.attn(x, x, x)
180
+
181
+ # Add & Norm
182
+ x = x + attn_output
183
+ x = x + self.mlp(self.norm2(x))
184
+ return x
185
+
186
+
187
+ class CrossAttnBlock(nn.Module):
188
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
189
+ """
190
+ Cross attention block
191
+ """
192
+ super().__init__()
193
+
194
+ self.norm1 = nn.LayerNorm(hidden_size)
195
+ self.norm_context = nn.LayerNorm(hidden_size)
196
+ self.norm2 = nn.LayerNorm(hidden_size)
197
+
198
+ self.cross_attn = nn.MultiheadAttention(
199
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
200
+ )
201
+
202
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
203
+
204
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
205
+
206
+ def forward(self, x, context, mask=None):
207
+ # Normalize inputs
208
+ x = self.norm1(x)
209
+ context = self.norm_context(context)
210
+
211
+ # Apply cross attention
212
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
213
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
214
+
215
+ # Add & Norm
216
+ x = x + attn_output
217
+ x = x + self.mlp(self.norm2(x))
218
+ return x
vggt/heads/track_modules/utils.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from https://github.com/facebookresearch/vggsfm
8
+ # and https://github.com/facebookresearch/co-tracker/tree/main
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+
18
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
19
+ """
20
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
21
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
22
+ Args:
23
+ - embed_dim: The embedding dimension.
24
+ - grid_size: The grid size.
25
+ Returns:
26
+ - pos_embed: The generated 2D positional embedding.
27
+ """
28
+ if isinstance(grid_size, tuple):
29
+ grid_size_h, grid_size_w = grid_size
30
+ else:
31
+ grid_size_h = grid_size_w = grid_size
32
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
33
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
34
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
35
+ grid = torch.stack(grid, dim=0)
36
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
37
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
38
+ if return_grid:
39
+ return (
40
+ pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
41
+ grid,
42
+ )
43
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
44
+
45
+
46
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
47
+ """
48
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
49
+
50
+ Args:
51
+ - embed_dim: The embedding dimension.
52
+ - grid: The grid to generate the embedding from.
53
+
54
+ Returns:
55
+ - emb: The generated 2D positional embedding.
56
+ """
57
+ assert embed_dim % 2 == 0
58
+
59
+ # use half of dimensions to encode grid_h
60
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
61
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
62
+
63
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
64
+ return emb
65
+
66
+
67
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
68
+ """
69
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
70
+
71
+ Args:
72
+ - embed_dim: The embedding dimension.
73
+ - pos: The position to generate the embedding from.
74
+
75
+ Returns:
76
+ - emb: The generated 1D positional embedding.
77
+ """
78
+ assert embed_dim % 2 == 0
79
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
80
+ omega /= embed_dim / 2.0
81
+ omega = 1.0 / 10000**omega # (D/2,)
82
+
83
+ pos = pos.reshape(-1) # (M,)
84
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
85
+
86
+ emb_sin = torch.sin(out) # (M, D/2)
87
+ emb_cos = torch.cos(out) # (M, D/2)
88
+
89
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
90
+ return emb[None].float()
91
+
92
+
93
+ def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
94
+ """
95
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
96
+
97
+ Args:
98
+ - xy: The coordinates to generate the embedding from.
99
+ - C: The size of the embedding.
100
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
101
+
102
+ Returns:
103
+ - pe: The generated 2D positional embedding.
104
+ """
105
+ B, N, D = xy.shape
106
+ assert D == 2
107
+
108
+ x = xy[:, :, 0:1]
109
+ y = xy[:, :, 1:2]
110
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
111
+
112
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
113
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
114
+
115
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
116
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
117
+
118
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
119
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
120
+
121
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
122
+ if cat_coords:
123
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
124
+ return pe
125
+
126
+
127
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
128
+ r"""Sample a tensor using bilinear interpolation
129
+
130
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
131
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
132
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
133
+ convention.
134
+
135
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
136
+ :math:`B` is the batch size, :math:`C` is the number of channels,
137
+ :math:`H` is the height of the image, and :math:`W` is the width of the
138
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
139
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
140
+
141
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
142
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
143
+ that in this case the order of the components is slightly different
144
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
145
+
146
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
147
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
148
+ left-most image pixel :math:`W-1` to the center of the right-most
149
+ pixel.
150
+
151
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
152
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
153
+ the left-most pixel :math:`W` to the right edge of the right-most
154
+ pixel.
155
+
156
+ Similar conventions apply to the :math:`y` for the range
157
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
158
+ :math:`[0,T-1]` and :math:`[0,T]`.
159
+
160
+ Args:
161
+ input (Tensor): batch of input images.
162
+ coords (Tensor): batch of coordinates.
163
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
164
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
165
+
166
+ Returns:
167
+ Tensor: sampled points.
168
+ """
169
+ coords = coords.detach().clone()
170
+ ############################################################
171
+ # IMPORTANT:
172
+ coords = coords.to(input.device).to(input.dtype)
173
+ ############################################################
174
+
175
+ sizes = input.shape[2:]
176
+
177
+ assert len(sizes) in [2, 3]
178
+
179
+ if len(sizes) == 3:
180
+ # t x y -> x y t to match dimensions T H W in grid_sample
181
+ coords = coords[..., [1, 2, 0]]
182
+
183
+ if align_corners:
184
+ scale = torch.tensor(
185
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
186
+ )
187
+ else:
188
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
189
+
190
+ coords.mul_(scale) # coords = coords * scale
191
+ coords.sub_(1) # coords = coords - 1
192
+
193
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
194
+
195
+
196
+ def sample_features4d(input, coords):
197
+ r"""Sample spatial features
198
+
199
+ `sample_features4d(input, coords)` samples the spatial features
200
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
201
+
202
+ The field is sampled at coordinates :attr:`coords` using bilinear
203
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
204
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
205
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
206
+
207
+ The output tensor has one feature per point, and has shape :math:`(B,
208
+ R, C)`.
209
+
210
+ Args:
211
+ input (Tensor): spatial features.
212
+ coords (Tensor): points.
213
+
214
+ Returns:
215
+ Tensor: sampled features.
216
+ """
217
+
218
+ B, _, _, _ = input.shape
219
+
220
+ # B R 2 -> B R 1 2
221
+ coords = coords.unsqueeze(2)
222
+
223
+ # B C R 1
224
+ feats = bilinear_sampler(input, coords)
225
+
226
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
vggt/heads/utils.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
12
+ """
13
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
14
+
15
+ Args:
16
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
17
+ embed_dim: Output channel dimension for embeddings
18
+
19
+ Returns:
20
+ Tensor of shape (H, W, embed_dim) with positional embeddings
21
+ """
22
+ H, W, grid_dim = pos_grid.shape
23
+ assert grid_dim == 2
24
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
25
+
26
+ # Process x and y coordinates separately
27
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
28
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
29
+
30
+ # Combine and reshape
31
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
32
+
33
+ return emb.view(H, W, embed_dim) # [H, W, D]
34
+
35
+
36
+ def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
37
+ """
38
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
39
+
40
+ Args:
41
+ - embed_dim: The embedding dimension.
42
+ - pos: The position to generate the embedding from.
43
+
44
+ Returns:
45
+ - emb: The generated 1D positional embedding.
46
+ """
47
+ assert embed_dim % 2 == 0
48
+ omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
49
+ omega /= embed_dim / 2.0
50
+ omega = 1.0 / omega_0**omega # (D/2,)
51
+
52
+ pos = pos.reshape(-1) # (M,)
53
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
54
+
55
+ emb_sin = torch.sin(out) # (M, D/2)
56
+ emb_cos = torch.cos(out) # (M, D/2)
57
+
58
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
59
+ return emb.float()
60
+
61
+
62
+ # Inspired by https://github.com/microsoft/moge
63
+
64
+
65
+ def create_uv_grid(
66
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
67
+ ) -> torch.Tensor:
68
+ """
69
+ Create a normalized UV grid of shape (width, height, 2).
70
+
71
+ The grid spans horizontally and vertically according to an aspect ratio,
72
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
73
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
74
+
75
+ Args:
76
+ width (int): Number of points horizontally.
77
+ height (int): Number of points vertically.
78
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
79
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
80
+ device (torch.device, optional): Device on which the tensor is created.
81
+
82
+ Returns:
83
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
84
+ """
85
+ # Derive aspect ratio if not explicitly provided
86
+ if aspect_ratio is None:
87
+ aspect_ratio = float(width) / float(height)
88
+
89
+ # Compute normalized spans for X and Y
90
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
91
+ span_x = aspect_ratio / diag_factor
92
+ span_y = 1.0 / diag_factor
93
+
94
+ # Establish the linspace boundaries
95
+ left_x = -span_x * (width - 1) / width
96
+ right_x = span_x * (width - 1) / width
97
+ top_y = -span_y * (height - 1) / height
98
+ bottom_y = span_y * (height - 1) / height
99
+
100
+ # Generate 1D coordinates
101
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
102
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
103
+
104
+ # Create 2D meshgrid (width x height) and stack into UV
105
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
106
+ uv_grid = torch.stack((uu, vv), dim=-1)
107
+
108
+ return uv_grid
vggt/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
vggt/layers/attention.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+
18
+ XFORMERS_AVAILABLE = False
19
+
20
+
21
+ class Attention(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim: int,
25
+ num_heads: int = 8,
26
+ qkv_bias: bool = True,
27
+ proj_bias: bool = True,
28
+ attn_drop: float = 0.0,
29
+ proj_drop: float = 0.0,
30
+ norm_layer: nn.Module = nn.LayerNorm,
31
+ qk_norm: bool = False,
32
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
33
+ rope=None,
34
+ ) -> None:
35
+ super().__init__()
36
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
37
+ self.num_heads = num_heads
38
+ self.head_dim = dim // num_heads
39
+ self.scale = self.head_dim**-0.5
40
+ self.fused_attn = fused_attn
41
+
42
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
43
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
44
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+ self.rope = rope
49
+
50
+ def forward(self, x: Tensor, pos=None) -> Tensor:
51
+ B, N, C = x.shape
52
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
53
+ q, k, v = qkv.unbind(0)
54
+ q, k = self.q_norm(q), self.k_norm(k)
55
+
56
+ if self.rope is not None:
57
+ q = self.rope(q, pos)
58
+ k = self.rope(k, pos)
59
+
60
+ if self.fused_attn:
61
+ x = F.scaled_dot_product_attention(
62
+ q,
63
+ k,
64
+ v,
65
+ dropout_p=self.attn_drop.p if self.training else 0.0,
66
+ )
67
+ else:
68
+ q = q * self.scale
69
+ attn = q @ k.transpose(-2, -1)
70
+ attn = attn.softmax(dim=-1)
71
+ attn = self.attn_drop(attn)
72
+ x = attn @ v
73
+
74
+ x = x.transpose(1, 2).reshape(B, N, C)
75
+ x = self.proj(x)
76
+ x = self.proj_drop(x)
77
+ return x
78
+
79
+
80
+ class MemEffAttention(Attention):
81
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
82
+ assert pos is None
83
+ if not XFORMERS_AVAILABLE:
84
+ if attn_bias is not None:
85
+ raise AssertionError("xFormers is required for using nested tensors")
86
+ return super().forward(x)
87
+
88
+ B, N, C = x.shape
89
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
90
+
91
+ q, k, v = unbind(qkv, 2)
92
+
93
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
94
+ x = x.reshape([B, N, C])
95
+
96
+ x = self.proj(x)
97
+ x = self.proj_drop(x)
98
+ return x
vggt/layers/block.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ XFORMERS_AVAILABLE = False
25
+
26
+
27
+ class Block(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_heads: int,
32
+ mlp_ratio: float = 4.0,
33
+ qkv_bias: bool = True,
34
+ proj_bias: bool = True,
35
+ ffn_bias: bool = True,
36
+ drop: float = 0.0,
37
+ attn_drop: float = 0.0,
38
+ init_values=None,
39
+ drop_path: float = 0.0,
40
+ act_layer: Callable[..., nn.Module] = nn.GELU,
41
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
42
+ attn_class: Callable[..., nn.Module] = Attention,
43
+ ffn_layer: Callable[..., nn.Module] = Mlp,
44
+ qk_norm: bool = False,
45
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
46
+ rope=None,
47
+ ) -> None:
48
+ super().__init__()
49
+
50
+ self.norm1 = norm_layer(dim)
51
+
52
+ self.attn = attn_class(
53
+ dim,
54
+ num_heads=num_heads,
55
+ qkv_bias=qkv_bias,
56
+ proj_bias=proj_bias,
57
+ attn_drop=attn_drop,
58
+ proj_drop=drop,
59
+ qk_norm=qk_norm,
60
+ fused_attn=fused_attn,
61
+ rope=rope,
62
+ )
63
+
64
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
65
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
66
+
67
+ self.norm2 = norm_layer(dim)
68
+ mlp_hidden_dim = int(dim * mlp_ratio)
69
+ self.mlp = ffn_layer(
70
+ in_features=dim,
71
+ hidden_features=mlp_hidden_dim,
72
+ act_layer=act_layer,
73
+ drop=drop,
74
+ bias=ffn_bias,
75
+ )
76
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
77
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
78
+
79
+ self.sample_drop_ratio = drop_path
80
+
81
+ def forward(self, x: Tensor, pos=None) -> Tensor:
82
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
83
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
84
+
85
+ def ffn_residual_func(x: Tensor) -> Tensor:
86
+ return self.ls2(self.mlp(self.norm2(x)))
87
+
88
+ if self.training and self.sample_drop_ratio > 0.1:
89
+ # the overhead is compensated only for a drop path rate larger than 0.1
90
+ x = drop_add_residual_stochastic_depth(
91
+ x,
92
+ pos=pos,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x, pos=pos)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ pos=None,
115
+ ) -> Tensor:
116
+ # 1) extract subset using permutation
117
+ b, n, d = x.shape
118
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
119
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
120
+ x_subset = x[brange]
121
+
122
+ # 2) apply residual_func to get residual
123
+ if pos is not None:
124
+ # if necessary, apply rope to the subset
125
+ pos = pos[brange]
126
+ residual = residual_func(x_subset, pos=pos)
127
+ else:
128
+ residual = residual_func(x_subset)
129
+
130
+ x_flat = x.flatten(1)
131
+ residual = residual.flatten(1)
132
+
133
+ residual_scale_factor = b / sample_subset_size
134
+
135
+ # 3) add the residual
136
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
137
+ return x_plus_residual.view_as(x)
138
+
139
+
140
+ def get_branges_scales(x, sample_drop_ratio=0.0):
141
+ b, n, d = x.shape
142
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
143
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
144
+ residual_scale_factor = b / sample_subset_size
145
+ return brange, residual_scale_factor
146
+
147
+
148
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
149
+ if scaling_vector is None:
150
+ x_flat = x.flatten(1)
151
+ residual = residual.flatten(1)
152
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
153
+ else:
154
+ x_plus_residual = scaled_index_add(
155
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
156
+ )
157
+ return x_plus_residual
158
+
159
+
160
+ attn_bias_cache: Dict[Tuple, Any] = {}
161
+
162
+
163
+ def get_attn_bias_and_cat(x_list, branges=None):
164
+ """
165
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
166
+ """
167
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
168
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
169
+ if all_shapes not in attn_bias_cache.keys():
170
+ seqlens = []
171
+ for b, x in zip(batch_sizes, x_list):
172
+ for _ in range(b):
173
+ seqlens.append(x.shape[1])
174
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
175
+ attn_bias._batch_sizes = batch_sizes
176
+ attn_bias_cache[all_shapes] = attn_bias
177
+
178
+ if branges is not None:
179
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
180
+ else:
181
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
182
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
183
+
184
+ return attn_bias_cache[all_shapes], cat_tensors
185
+
186
+
187
+ def drop_add_residual_stochastic_depth_list(
188
+ x_list: List[Tensor],
189
+ residual_func: Callable[[Tensor, Any], Tensor],
190
+ sample_drop_ratio: float = 0.0,
191
+ scaling_vector=None,
192
+ ) -> Tensor:
193
+ # 1) generate random set of indices for dropping samples in the batch
194
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
195
+ branges = [s[0] for s in branges_scales]
196
+ residual_scale_factors = [s[1] for s in branges_scales]
197
+
198
+ # 2) get attention bias and index+concat the tensors
199
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
200
+
201
+ # 3) apply residual_func to get residual, and split the result
202
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
203
+
204
+ outputs = []
205
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
206
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
207
+ return outputs
208
+
209
+
210
+ class NestedTensorBlock(Block):
211
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
212
+ """
213
+ x_list contains a list of tensors to nest together and run
214
+ """
215
+ assert isinstance(self.attn, MemEffAttention)
216
+
217
+ if self.training and self.sample_drop_ratio > 0.0:
218
+
219
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
220
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
221
+
222
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
223
+ return self.mlp(self.norm2(x))
224
+
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=attn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ x_list = drop_add_residual_stochastic_depth_list(
232
+ x_list,
233
+ residual_func=ffn_residual_func,
234
+ sample_drop_ratio=self.sample_drop_ratio,
235
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
236
+ )
237
+ return x_list
238
+ else:
239
+
240
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
241
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
242
+
243
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
244
+ return self.ls2(self.mlp(self.norm2(x)))
245
+
246
+ attn_bias, x = get_attn_bias_and_cat(x_list)
247
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
248
+ x = x + ffn_residual_func(x)
249
+ return attn_bias.split(x)
250
+
251
+ def forward(self, x_or_x_list):
252
+ if isinstance(x_or_x_list, Tensor):
253
+ return super().forward(x_or_x_list)
254
+ elif isinstance(x_or_x_list, list):
255
+ if not XFORMERS_AVAILABLE:
256
+ raise AssertionError("xFormers is required for using nested tensors")
257
+ return self.forward_nested(x_or_x_list)
258
+ else:
259
+ raise AssertionError
vggt/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
vggt/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ init_values: Union[float, Tensor] = 1e-5,
20
+ inplace: bool = False,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.inplace = inplace
24
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
25
+
26
+ def forward(self, x: Tensor) -> Tensor:
27
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
vggt/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
vggt/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (
51
+ image_HW[0] // patch_HW[0],
52
+ image_HW[1] // patch_HW[1],
53
+ )
54
+
55
+ self.img_size = image_HW
56
+ self.patch_size = patch_HW
57
+ self.patches_resolution = patch_grid_size
58
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59
+
60
+ self.in_chans = in_chans
61
+ self.embed_dim = embed_dim
62
+
63
+ self.flatten_embedding = flatten_embedding
64
+
65
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ _, _, H, W = x.shape
70
+ patch_H, patch_W = self.patch_size
71
+
72
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
74
+
75
+ x = self.proj(x) # B C H W
76
+ H, W = x.size(2), x.size(3)
77
+ x = x.flatten(2).transpose(1, 2) # B HW C
78
+ x = self.norm(x)
79
+ if not self.flatten_embedding:
80
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81
+ return x
82
+
83
+ def flops(self) -> float:
84
+ Ho, Wo = self.patches_resolution
85
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86
+ if self.norm is not None:
87
+ flops += Ho * Wo * self.embed_dim
88
+ return flops
vggt/layers/rope.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # Implementation of 2D Rotary Position Embeddings (RoPE).
8
+
9
+ # This module provides a clean implementation of 2D Rotary Position Embeddings,
10
+ # which extends the original RoPE concept to handle 2D spatial positions.
11
+
12
+ # Inspired by:
13
+ # https://github.com/meta-llama/codellama/blob/main/llama/model.py
14
+ # https://github.com/naver-ai/rope-vit
15
+
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from typing import Dict, Tuple
22
+
23
+
24
+ class PositionGetter:
25
+ """Generates and caches 2D spatial positions for patches in a grid.
26
+
27
+ This class efficiently manages the generation of spatial coordinates for patches
28
+ in a 2D grid, caching results to avoid redundant computations.
29
+
30
+ Attributes:
31
+ position_cache: Dictionary storing precomputed position tensors for different
32
+ grid dimensions.
33
+ """
34
+
35
+ def __init__(self):
36
+ """Initializes the position generator with an empty cache."""
37
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
38
+
39
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
40
+ """Generates spatial positions for a batch of patches.
41
+
42
+ Args:
43
+ batch_size: Number of samples in the batch.
44
+ height: Height of the grid in patches.
45
+ width: Width of the grid in patches.
46
+ device: Target device for the position tensor.
47
+
48
+ Returns:
49
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
50
+ for each position in the grid, repeated for each batch item.
51
+ """
52
+ if (height, width) not in self.position_cache:
53
+ y_coords = torch.arange(height, device=device)
54
+ x_coords = torch.arange(width, device=device)
55
+ positions = torch.cartesian_prod(y_coords, x_coords)
56
+ self.position_cache[height, width] = positions
57
+
58
+ cached_positions = self.position_cache[height, width]
59
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
60
+
61
+
62
+ class RotaryPositionEmbedding2D(nn.Module):
63
+ """2D Rotary Position Embedding implementation.
64
+
65
+ This module applies rotary position embeddings to input tokens based on their
66
+ 2D spatial positions. It handles the position-dependent rotation of features
67
+ separately for vertical and horizontal dimensions.
68
+
69
+ Args:
70
+ frequency: Base frequency for the position embeddings. Default: 100.0
71
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
72
+
73
+ Attributes:
74
+ base_frequency: Base frequency for computing position embeddings.
75
+ scaling_factor: Factor to scale the computed frequencies.
76
+ frequency_cache: Cache for storing precomputed frequency components.
77
+ """
78
+
79
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
80
+ """Initializes the 2D RoPE module."""
81
+ super().__init__()
82
+ self.base_frequency = frequency
83
+ self.scaling_factor = scaling_factor
84
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
85
+
86
+ def _compute_frequency_components(
87
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
88
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
89
+ """Computes frequency components for rotary embeddings.
90
+
91
+ Args:
92
+ dim: Feature dimension (must be even).
93
+ seq_len: Maximum sequence length.
94
+ device: Target device for computations.
95
+ dtype: Data type for the computed tensors.
96
+
97
+ Returns:
98
+ Tuple of (cosine, sine) tensors for frequency components.
99
+ """
100
+ cache_key = (dim, seq_len, device, dtype)
101
+ if cache_key not in self.frequency_cache:
102
+ # Compute frequency bands
103
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
104
+ inv_freq = 1.0 / (self.base_frequency**exponents)
105
+
106
+ # Generate position-dependent frequencies
107
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
108
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
109
+
110
+ # Compute and cache frequency components
111
+ angles = angles.to(dtype)
112
+ angles = torch.cat((angles, angles), dim=-1)
113
+ cos_components = angles.cos().to(dtype)
114
+ sin_components = angles.sin().to(dtype)
115
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
116
+
117
+ return self.frequency_cache[cache_key]
118
+
119
+ @staticmethod
120
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
121
+ """Performs feature rotation by splitting and recombining feature dimensions.
122
+
123
+ Args:
124
+ x: Input tensor to rotate.
125
+
126
+ Returns:
127
+ Rotated feature tensor.
128
+ """
129
+ feature_dim = x.shape[-1]
130
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
131
+ return torch.cat((-x2, x1), dim=-1)
132
+
133
+ def _apply_1d_rope(
134
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
135
+ ) -> torch.Tensor:
136
+ """Applies 1D rotary position embeddings along one dimension.
137
+
138
+ Args:
139
+ tokens: Input token features.
140
+ positions: Position indices.
141
+ cos_comp: Cosine components for rotation.
142
+ sin_comp: Sine components for rotation.
143
+
144
+ Returns:
145
+ Tokens with applied rotary position embeddings.
146
+ """
147
+ # Embed positions with frequency components
148
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
149
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
150
+
151
+ # Apply rotation
152
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
153
+
154
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
155
+ """Applies 2D rotary position embeddings to input tokens.
156
+
157
+ Args:
158
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
159
+ The feature dimension (dim) must be divisible by 4.
160
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
161
+ the y and x coordinates for each token.
162
+
163
+ Returns:
164
+ Tensor of same shape as input with applied 2D rotary position embeddings.
165
+
166
+ Raises:
167
+ AssertionError: If input dimensions are invalid or positions are malformed.
168
+ """
169
+ # Validate inputs
170
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
171
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
172
+
173
+ # Compute feature dimension for each spatial direction
174
+ feature_dim = tokens.size(-1) // 2
175
+
176
+ # Get frequency components
177
+ max_position = int(positions.max()) + 1
178
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
179
+
180
+ # Split features for vertical and horizontal processing
181
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
182
+
183
+ # Apply RoPE separately for each dimension
184
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
185
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
186
+
187
+ # Combine processed features
188
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
vggt/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ # try:
39
+ # if XFORMERS_ENABLED:
40
+ # from xformers.ops import SwiGLU
41
+
42
+ # XFORMERS_AVAILABLE = True
43
+ # warnings.warn("xFormers is available (SwiGLU)")
44
+ # else:
45
+ # warnings.warn("xFormers is disabled (SwiGLU)")
46
+ # raise ImportError
47
+ # except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ # warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
+ class SwiGLUFFNFused(SwiGLU):
55
+ def __init__(
56
+ self,
57
+ in_features: int,
58
+ hidden_features: Optional[int] = None,
59
+ out_features: Optional[int] = None,
60
+ act_layer: Callable[..., nn.Module] = None,
61
+ drop: float = 0.0,
62
+ bias: bool = True,
63
+ ) -> None:
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67
+ super().__init__(
68
+ in_features=in_features,
69
+ hidden_features=hidden_features,
70
+ out_features=out_features,
71
+ bias=bias,
72
+ )
vggt/layers/vision_transformer.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.checkpoint import checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+ from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
20
+
21
+ logger = logging.getLogger("dinov2")
22
+
23
+
24
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
25
+ if not depth_first and include_root:
26
+ fn(module=module, name=name)
27
+ for child_name, child_module in module.named_children():
28
+ child_name = ".".join((name, child_name)) if name else child_name
29
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
30
+ if depth_first and include_root:
31
+ fn(module=module, name=name)
32
+ return module
33
+
34
+
35
+ class BlockChunk(nn.ModuleList):
36
+ def forward(self, x):
37
+ for b in self:
38
+ x = b(x)
39
+ return x
40
+
41
+
42
+ class DinoVisionTransformer(nn.Module):
43
+ def __init__(
44
+ self,
45
+ img_size=224,
46
+ patch_size=16,
47
+ in_chans=3,
48
+ embed_dim=768,
49
+ depth=12,
50
+ num_heads=12,
51
+ mlp_ratio=4.0,
52
+ qkv_bias=True,
53
+ ffn_bias=True,
54
+ proj_bias=True,
55
+ drop_path_rate=0.0,
56
+ drop_path_uniform=False,
57
+ init_values=None, # for layerscale: None or 0 => no layerscale
58
+ embed_layer=PatchEmbed,
59
+ act_layer=nn.GELU,
60
+ block_fn=Block,
61
+ ffn_layer="mlp",
62
+ block_chunks=1,
63
+ num_register_tokens=0,
64
+ interpolate_antialias=False,
65
+ interpolate_offset=0.1,
66
+ qk_norm=False,
67
+ ):
68
+ """
69
+ Args:
70
+ img_size (int, tuple): input image size
71
+ patch_size (int, tuple): patch size
72
+ in_chans (int): number of input channels
73
+ embed_dim (int): embedding dimension
74
+ depth (int): depth of transformer
75
+ num_heads (int): number of attention heads
76
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
77
+ qkv_bias (bool): enable bias for qkv if True
78
+ proj_bias (bool): enable bias for proj in attn if True
79
+ ffn_bias (bool): enable bias for ffn if True
80
+ drop_path_rate (float): stochastic depth rate
81
+ drop_path_uniform (bool): apply uniform drop rate across blocks
82
+ weight_init (str): weight init scheme
83
+ init_values (float): layer-scale init values
84
+ embed_layer (nn.Module): patch embedding layer
85
+ act_layer (nn.Module): MLP activation layer
86
+ block_fn (nn.Module): transformer block class
87
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
88
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
89
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
90
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
91
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
92
+ """
93
+ super().__init__()
94
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
95
+
96
+ # tricky but makes it work
97
+ self.use_checkpoint = False
98
+ #
99
+
100
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
101
+ self.num_tokens = 1
102
+ self.n_blocks = depth
103
+ self.num_heads = num_heads
104
+ self.patch_size = patch_size
105
+ self.num_register_tokens = num_register_tokens
106
+ self.interpolate_antialias = interpolate_antialias
107
+ self.interpolate_offset = interpolate_offset
108
+
109
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
110
+ num_patches = self.patch_embed.num_patches
111
+
112
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
113
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
114
+ assert num_register_tokens >= 0
115
+ self.register_tokens = (
116
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
117
+ )
118
+
119
+ if drop_path_uniform is True:
120
+ dpr = [drop_path_rate] * depth
121
+ else:
122
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
123
+
124
+ if ffn_layer == "mlp":
125
+ logger.info("using MLP layer as FFN")
126
+ ffn_layer = Mlp
127
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
128
+ logger.info("using SwiGLU layer as FFN")
129
+ ffn_layer = SwiGLUFFNFused
130
+ elif ffn_layer == "identity":
131
+ logger.info("using Identity layer as FFN")
132
+
133
+ def f(*args, **kwargs):
134
+ return nn.Identity()
135
+
136
+ ffn_layer = f
137
+ else:
138
+ raise NotImplementedError
139
+
140
+ blocks_list = [
141
+ block_fn(
142
+ dim=embed_dim,
143
+ num_heads=num_heads,
144
+ mlp_ratio=mlp_ratio,
145
+ qkv_bias=qkv_bias,
146
+ proj_bias=proj_bias,
147
+ ffn_bias=ffn_bias,
148
+ drop_path=dpr[i],
149
+ norm_layer=norm_layer,
150
+ act_layer=act_layer,
151
+ ffn_layer=ffn_layer,
152
+ init_values=init_values,
153
+ qk_norm=qk_norm,
154
+ )
155
+ for i in range(depth)
156
+ ]
157
+ if block_chunks > 0:
158
+ self.chunked_blocks = True
159
+ chunked_blocks = []
160
+ chunksize = depth // block_chunks
161
+ for i in range(0, depth, chunksize):
162
+ # this is to keep the block index consistent if we chunk the block list
163
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
164
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
165
+ else:
166
+ self.chunked_blocks = False
167
+ self.blocks = nn.ModuleList(blocks_list)
168
+
169
+ self.norm = norm_layer(embed_dim)
170
+ self.head = nn.Identity()
171
+
172
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
173
+
174
+ self.init_weights()
175
+
176
+ def init_weights(self):
177
+ trunc_normal_(self.pos_embed, std=0.02)
178
+ nn.init.normal_(self.cls_token, std=1e-6)
179
+ if self.register_tokens is not None:
180
+ nn.init.normal_(self.register_tokens, std=1e-6)
181
+ named_apply(init_weights_vit_timm, self)
182
+
183
+ def interpolate_pos_encoding(self, x, w, h):
184
+ previous_dtype = x.dtype
185
+ npatch = x.shape[1] - 1
186
+ N = self.pos_embed.shape[1] - 1
187
+ if npatch == N and w == h:
188
+ return self.pos_embed
189
+ pos_embed = self.pos_embed.float()
190
+ class_pos_embed = pos_embed[:, 0]
191
+ patch_pos_embed = pos_embed[:, 1:]
192
+ dim = x.shape[-1]
193
+ w0 = w // self.patch_size
194
+ h0 = h // self.patch_size
195
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
196
+ assert N == M * M
197
+ kwargs = {}
198
+ if self.interpolate_offset:
199
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
200
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
201
+ sx = float(w0 + self.interpolate_offset) / M
202
+ sy = float(h0 + self.interpolate_offset) / M
203
+ kwargs["scale_factor"] = (sx, sy)
204
+ else:
205
+ # Simply specify an output size instead of a scale factor
206
+ kwargs["size"] = (w0, h0)
207
+ patch_pos_embed = nn.functional.interpolate(
208
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
209
+ mode="bicubic",
210
+ antialias=self.interpolate_antialias,
211
+ **kwargs,
212
+ )
213
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
214
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
215
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
216
+
217
+ def prepare_tokens_with_masks(self, x, masks=None):
218
+ B, nc, w, h = x.shape
219
+ x = self.patch_embed(x)
220
+ if masks is not None:
221
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
222
+
223
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
224
+ x = x + self.interpolate_pos_encoding(x, w, h)
225
+
226
+ if self.register_tokens is not None:
227
+ x = torch.cat(
228
+ (
229
+ x[:, :1],
230
+ self.register_tokens.expand(x.shape[0], -1, -1),
231
+ x[:, 1:],
232
+ ),
233
+ dim=1,
234
+ )
235
+
236
+ return x
237
+
238
+ def forward_features_list(self, x_list, masks_list):
239
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
240
+
241
+ for blk in self.blocks:
242
+ if self.use_checkpoint:
243
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
244
+ else:
245
+ x = blk(x)
246
+
247
+ all_x = x
248
+ output = []
249
+ for x, masks in zip(all_x, masks_list):
250
+ x_norm = self.norm(x)
251
+ output.append(
252
+ {
253
+ "x_norm_clstoken": x_norm[:, 0],
254
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
255
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
256
+ "x_prenorm": x,
257
+ "masks": masks,
258
+ }
259
+ )
260
+ return output
261
+
262
+ def forward_features(self, x, masks=None):
263
+ if isinstance(x, list):
264
+ return self.forward_features_list(x, masks)
265
+
266
+ x = self.prepare_tokens_with_masks(x, masks)
267
+
268
+ for blk in self.blocks:
269
+ if self.use_checkpoint:
270
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
271
+ else:
272
+ x = blk(x)
273
+
274
+ x_norm = self.norm(x)
275
+ return {
276
+ "x_norm_clstoken": x_norm[:, 0],
277
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
278
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
279
+ "x_prenorm": x,
280
+ "masks": masks,
281
+ }
282
+
283
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+ # If n is an int, take the n last blocks. If it's a list, take them
286
+ output, total_block_len = [], len(self.blocks)
287
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288
+ for i, blk in enumerate(self.blocks):
289
+ x = blk(x)
290
+ if i in blocks_to_take:
291
+ output.append(x)
292
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
293
+ return output
294
+
295
+ def _get_intermediate_layers_chunked(self, x, n=1):
296
+ x = self.prepare_tokens_with_masks(x)
297
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
298
+ # If n is an int, take the n last blocks. If it's a list, take them
299
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
300
+ for block_chunk in self.blocks:
301
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
302
+ x = blk(x)
303
+ if i in blocks_to_take:
304
+ output.append(x)
305
+ i += 1
306
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
307
+ return output
308
+
309
+ def get_intermediate_layers(
310
+ self,
311
+ x: torch.Tensor,
312
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
313
+ reshape: bool = False,
314
+ return_class_token: bool = False,
315
+ norm=True,
316
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
317
+ if self.chunked_blocks:
318
+ outputs = self._get_intermediate_layers_chunked(x, n)
319
+ else:
320
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
321
+ if norm:
322
+ outputs = [self.norm(out) for out in outputs]
323
+ class_tokens = [out[:, 0] for out in outputs]
324
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
325
+ if reshape:
326
+ B, _, w, h = x.shape
327
+ outputs = [
328
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
329
+ for out in outputs
330
+ ]
331
+ if return_class_token:
332
+ return tuple(zip(outputs, class_tokens))
333
+ return tuple(outputs)
334
+
335
+ def forward(self, *args, is_training=True, **kwargs):
336
+ ret = self.forward_features(*args, **kwargs)
337
+ if is_training:
338
+ return ret
339
+ else:
340
+ return self.head(ret["x_norm_clstoken"])
341
+
342
+
343
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
344
+ """ViT weight initialization, original timm impl (for reproducibility)"""
345
+ if isinstance(module, nn.Linear):
346
+ trunc_normal_(module.weight, std=0.02)
347
+ if module.bias is not None:
348
+ nn.init.zeros_(module.bias)
349
+
350
+
351
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
352
+ model = DinoVisionTransformer(
353
+ patch_size=patch_size,
354
+ embed_dim=384,
355
+ depth=12,
356
+ num_heads=6,
357
+ mlp_ratio=4,
358
+ block_fn=partial(Block, attn_class=MemEffAttention),
359
+ num_register_tokens=num_register_tokens,
360
+ **kwargs,
361
+ )
362
+ return model
363
+
364
+
365
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
366
+ model = DinoVisionTransformer(
367
+ patch_size=patch_size,
368
+ embed_dim=768,
369
+ depth=12,
370
+ num_heads=12,
371
+ mlp_ratio=4,
372
+ block_fn=partial(Block, attn_class=MemEffAttention),
373
+ num_register_tokens=num_register_tokens,
374
+ **kwargs,
375
+ )
376
+ return model
377
+
378
+
379
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
380
+ model = DinoVisionTransformer(
381
+ patch_size=patch_size,
382
+ embed_dim=1024,
383
+ depth=24,
384
+ num_heads=16,
385
+ mlp_ratio=4,
386
+ block_fn=partial(Block, attn_class=MemEffAttention),
387
+ num_register_tokens=num_register_tokens,
388
+ **kwargs,
389
+ )
390
+ return model
391
+
392
+
393
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
394
+ """
395
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
396
+ """
397
+ model = DinoVisionTransformer(
398
+ patch_size=patch_size,
399
+ embed_dim=1536,
400
+ depth=40,
401
+ num_heads=24,
402
+ mlp_ratio=4,
403
+ block_fn=partial(Block, attn_class=MemEffAttention),
404
+ num_register_tokens=num_register_tokens,
405
+ **kwargs,
406
+ )
407
+ return model
vggt/models/aggregator.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Optional, Tuple, Union, List, Dict, Any
12
+
13
+ from vggt.layers import PatchEmbed
14
+ from vggt.layers.block import Block
15
+ from vggt.layers.rope import RotaryPositionEmbedding2D, PositionGetter
16
+ from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
21
+ _RESNET_STD = [0.229, 0.224, 0.225]
22
+
23
+
24
+ class Aggregator(nn.Module):
25
+ """
26
+ The Aggregator applies alternating-attention over input frames,
27
+ as described in VGGT: Visual Geometry Grounded Transformer.
28
+
29
+
30
+ Args:
31
+ img_size (int): Image size in pixels.
32
+ patch_size (int): Size of each patch for PatchEmbed.
33
+ embed_dim (int): Dimension of the token embeddings.
34
+ depth (int): Number of blocks.
35
+ num_heads (int): Number of attention heads.
36
+ mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
37
+ num_register_tokens (int): Number of register tokens.
38
+ block_fn (nn.Module): The block type used for attention (Block by default).
39
+ qkv_bias (bool): Whether to include bias in QKV projections.
40
+ proj_bias (bool): Whether to include bias in the output projection.
41
+ ffn_bias (bool): Whether to include bias in MLP layers.
42
+ patch_embed (str): Type of patch embed. e.g., "conv" or "dinov2_vitl14_reg".
43
+ aa_order (list[str]): The order of alternating attention, e.g. ["frame", "global"].
44
+ aa_block_size (int): How many blocks to group under each attention type before switching. If not necessary, set to 1.
45
+ qk_norm (bool): Whether to apply QK normalization.
46
+ rope_freq (int): Base frequency for rotary embedding. -1 to disable.
47
+ init_values (float): Init scale for layer scale.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ img_size=518,
53
+ patch_size=14,
54
+ embed_dim=1024,
55
+ depth=24,
56
+ num_heads=16,
57
+ mlp_ratio=4.0,
58
+ num_register_tokens=4,
59
+ block_fn=Block,
60
+ qkv_bias=True,
61
+ proj_bias=True,
62
+ ffn_bias=True,
63
+ patch_embed="dinov2_vitl14_reg",
64
+ aa_order=["frame", "global"],
65
+ aa_block_size=1,
66
+ qk_norm=True,
67
+ rope_freq=100,
68
+ init_values=0.01,
69
+ ):
70
+ super().__init__()
71
+
72
+ self.__build_patch_embed__(patch_embed, img_size, patch_size, num_register_tokens, embed_dim=embed_dim)
73
+
74
+ # Initialize rotary position embedding if frequency > 0
75
+ self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
76
+ self.position_getter = PositionGetter() if self.rope is not None else None
77
+
78
+ self.frame_blocks = nn.ModuleList(
79
+ [
80
+ block_fn(
81
+ dim=embed_dim,
82
+ num_heads=num_heads,
83
+ mlp_ratio=mlp_ratio,
84
+ qkv_bias=qkv_bias,
85
+ proj_bias=proj_bias,
86
+ ffn_bias=ffn_bias,
87
+ init_values=init_values,
88
+ qk_norm=qk_norm,
89
+ rope=self.rope,
90
+ )
91
+ for _ in range(depth)
92
+ ]
93
+ )
94
+
95
+ self.global_blocks = nn.ModuleList(
96
+ [
97
+ block_fn(
98
+ dim=embed_dim,
99
+ num_heads=num_heads,
100
+ mlp_ratio=mlp_ratio,
101
+ qkv_bias=qkv_bias,
102
+ proj_bias=proj_bias,
103
+ ffn_bias=ffn_bias,
104
+ init_values=init_values,
105
+ qk_norm=qk_norm,
106
+ rope=self.rope,
107
+ )
108
+ for _ in range(depth)
109
+ ]
110
+ )
111
+
112
+ self.depth = depth
113
+ self.aa_order = aa_order
114
+ self.patch_size = patch_size
115
+ self.aa_block_size = aa_block_size
116
+
117
+ # Validate that depth is divisible by aa_block_size
118
+ if self.depth % self.aa_block_size != 0:
119
+ raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
120
+
121
+ self.aa_block_num = self.depth // self.aa_block_size
122
+
123
+ # Note: We have two camera tokens, one for the first frame and one for the rest
124
+ # The same applies for register tokens
125
+ self.camera_token = nn.Parameter(torch.randn(1, 2, 1, embed_dim))
126
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, embed_dim))
127
+
128
+ # The patch tokens start after the camera and register tokens
129
+ self.patch_start_idx = 1 + num_register_tokens
130
+
131
+ # Initialize parameters with small values
132
+ nn.init.normal_(self.camera_token, std=1e-6)
133
+ nn.init.normal_(self.register_token, std=1e-6)
134
+
135
+ # Register normalization constants as buffers
136
+ for name, value in (
137
+ ("_resnet_mean", _RESNET_MEAN),
138
+ ("_resnet_std", _RESNET_STD),
139
+ ):
140
+ self.register_buffer(
141
+ name,
142
+ torch.FloatTensor(value).view(1, 1, 3, 1, 1),
143
+ persistent=False,
144
+ )
145
+
146
+ def __build_patch_embed__(
147
+ self,
148
+ patch_embed,
149
+ img_size,
150
+ patch_size,
151
+ num_register_tokens,
152
+ interpolate_antialias=True,
153
+ interpolate_offset=0.0,
154
+ block_chunks=0,
155
+ init_values=1.0,
156
+ embed_dim=1024,
157
+ ):
158
+ """
159
+ Build the patch embed layer. If 'conv', we use a
160
+ simple PatchEmbed conv layer. Otherwise, we use a vision transformer.
161
+ """
162
+
163
+ if "conv" in patch_embed:
164
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
165
+ else:
166
+ vit_models = {
167
+ "dinov2_vitl14_reg": vit_large,
168
+ "dinov2_vitb14_reg": vit_base,
169
+ "dinov2_vits14_reg": vit_small,
170
+ "dinov2_vitg2_reg": vit_giant2,
171
+ }
172
+
173
+ self.patch_embed = vit_models[patch_embed](
174
+ img_size=img_size,
175
+ patch_size=patch_size,
176
+ num_register_tokens=num_register_tokens,
177
+ interpolate_antialias=interpolate_antialias,
178
+ interpolate_offset=interpolate_offset,
179
+ block_chunks=block_chunks,
180
+ init_values=init_values,
181
+ )
182
+
183
+ # Disable gradient updates for mask token
184
+ if hasattr(self.patch_embed, "mask_token"):
185
+ self.patch_embed.mask_token.requires_grad_(False)
186
+
187
+ def forward(
188
+ self,
189
+ images: torch.Tensor,
190
+ ) -> Tuple[List[torch.Tensor], int]:
191
+ """
192
+ Args:
193
+ images (torch.Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
194
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
195
+
196
+ Returns:
197
+ (list[torch.Tensor], int):
198
+ The list of outputs from the attention blocks,
199
+ and the patch_start_idx indicating where patch tokens begin.
200
+ """
201
+ B, S, C_in, H, W = images.shape
202
+
203
+ if C_in != 3:
204
+ raise ValueError(f"Expected 3 input channels, got {C_in}")
205
+
206
+ # Normalize images and reshape for patch embed
207
+ images = (images - self._resnet_mean) / self._resnet_std
208
+
209
+ # Reshape to [B*S, C, H, W] for patch embedding
210
+ images = images.view(B * S, C_in, H, W)
211
+ patch_tokens = self.patch_embed(images)
212
+
213
+ if isinstance(patch_tokens, dict):
214
+ patch_tokens = patch_tokens["x_norm_patchtokens"]
215
+
216
+ _, P, C = patch_tokens.shape
217
+
218
+ # Expand camera and register tokens to match batch size and sequence length
219
+ camera_token = slice_expand_and_flatten(self.camera_token, B, S)
220
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
221
+
222
+ # Concatenate special tokens with patch tokens
223
+ tokens = torch.cat([camera_token, register_token, patch_tokens], dim=1)
224
+
225
+ pos = None
226
+ if self.rope is not None:
227
+ pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=images.device)
228
+
229
+ if self.patch_start_idx > 0:
230
+ # do not use position embedding for special tokens (camera and register tokens)
231
+ # so set pos to 0 for the special tokens
232
+ pos = pos + 1
233
+ pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(images.device).to(pos.dtype)
234
+ pos = torch.cat([pos_special, pos], dim=1)
235
+
236
+ # update P because we added special tokens
237
+ _, P, C = tokens.shape
238
+
239
+ frame_idx = 0
240
+ global_idx = 0
241
+ output_list = []
242
+
243
+ for _ in range(self.aa_block_num):
244
+ for attn_type in self.aa_order:
245
+ if attn_type == "frame":
246
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
247
+ tokens, B, S, P, C, frame_idx, pos=pos
248
+ )
249
+ elif attn_type == "global":
250
+ tokens, global_idx, global_intermediates = self._process_global_attention(
251
+ tokens, B, S, P, C, global_idx, pos=pos
252
+ )
253
+ else:
254
+ raise ValueError(f"Unknown attention type: {attn_type}")
255
+
256
+ for i in range(len(frame_intermediates)):
257
+ # concat frame and global intermediates, [B x S x P x 2C]
258
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
259
+ output_list.append(concat_inter)
260
+
261
+ del concat_inter
262
+ del frame_intermediates
263
+ del global_intermediates
264
+ return output_list, self.patch_start_idx
265
+
266
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, pos=None):
267
+ """
268
+ Process frame attention blocks. We keep tokens in shape (B*S, P, C).
269
+ """
270
+ # If needed, reshape tokens or positions:
271
+ if tokens.shape != (B * S, P, C):
272
+ tokens = tokens.view(B, S, P, C).view(B * S, P, C)
273
+
274
+ if pos is not None and pos.shape != (B * S, P, 2):
275
+ pos = pos.view(B, S, P, 2).view(B * S, P, 2)
276
+
277
+ intermediates = []
278
+
279
+ # by default, self.aa_block_size=1, which processes one block at a time
280
+ for _ in range(self.aa_block_size):
281
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
282
+ frame_idx += 1
283
+ intermediates.append(tokens.view(B, S, P, C))
284
+
285
+ return tokens, frame_idx, intermediates
286
+
287
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, pos=None):
288
+ """
289
+ Process global attention blocks. We keep tokens in shape (B, S*P, C).
290
+ """
291
+ if tokens.shape != (B, S * P, C):
292
+ tokens = tokens.view(B, S, P, C).view(B, S * P, C)
293
+
294
+ if pos is not None and pos.shape != (B, S * P, 2):
295
+ pos = pos.view(B, S, P, 2).view(B, S * P, 2)
296
+
297
+ intermediates = []
298
+
299
+ # by default, self.aa_block_size=1, which processes one block at a time
300
+ for _ in range(self.aa_block_size):
301
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
302
+ global_idx += 1
303
+ intermediates.append(tokens.view(B, S, P, C))
304
+
305
+ return tokens, global_idx, intermediates
306
+
307
+
308
+ def slice_expand_and_flatten(token_tensor, B, S):
309
+ """
310
+ Processes specialized tokens with shape (1, 2, X, C) for multi-frame processing:
311
+ 1) Uses the first position (index=0) for the first frame only
312
+ 2) Uses the second position (index=1) for all remaining frames (S-1 frames)
313
+ 3) Expands both to match batch size B
314
+ 4) Concatenates to form (B, S, X, C) where each sequence has 1 first-position token
315
+ followed by (S-1) second-position tokens
316
+ 5) Flattens to (B*S, X, C) for processing
317
+
318
+ Returns:
319
+ torch.Tensor: Processed tokens with shape (B*S, X, C)
320
+ """
321
+
322
+ # Slice out the "query" tokens => shape (1, 1, ...)
323
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
324
+ # Slice out the "other" tokens => shape (1, S-1, ...)
325
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
326
+ # Concatenate => shape (B, S, ...)
327
+ combined = torch.cat([query, others], dim=1)
328
+
329
+ # Finally flatten => shape (B*S, ...)
330
+ combined = combined.view(B * S, *combined.shape[2:])
331
+ return combined
vggt/models/vggt.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from huggingface_hub import PyTorchModelHubMixin # used for model hub
10
+
11
+ from vggt.models.aggregator import Aggregator
12
+ from vggt.heads.camera_head import CameraHead
13
+ from vggt.heads.dpt_head import DPTHead
14
+ from vggt.heads.track_head import TrackHead
15
+
16
+
17
+ class VGGT(nn.Module, PyTorchModelHubMixin):
18
+ def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
19
+ super().__init__()
20
+
21
+ self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
22
+ self.camera_head = CameraHead(dim_in=2 * embed_dim)
23
+ self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
24
+ self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
25
+ self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)
26
+
27
+ def forward(
28
+ self,
29
+ images: torch.Tensor,
30
+ query_points: torch.Tensor = None,
31
+ ):
32
+ """
33
+ Forward pass of the VGGT model.
34
+
35
+ Args:
36
+ images (torch.Tensor): Input images with shape [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
37
+ B: batch size, S: sequence length, 3: RGB channels, H: height, W: width
38
+ query_points (torch.Tensor, optional): Query points for tracking, in pixel coordinates.
39
+ Shape: [N, 2] or [B, N, 2], where N is the number of query points.
40
+ Default: None
41
+ """
42
+ # If without batch dimension, add it
43
+ if len(images.shape) == 4:
44
+ images = images.unsqueeze(0)
45
+ if query_points is not None and len(query_points.shape) == 2:
46
+ query_points = query_points.unsqueeze(0)
47
+
48
+ aggregated_tokens_list, patch_start_idx = self.aggregator(images)
49
+
50
+ predictions = {}
51
+
52
+ with torch.cuda.amp.autocast(enabled=False):
53
+ if self.camera_head is not None:
54
+ pose_enc_list = self.camera_head(aggregated_tokens_list)
55
+ predictions["pose_enc"] = pose_enc_list[-1] # pose encoding of the last iteration
56
+
57
+ if self.depth_head is not None:
58
+ depth, depth_conf = self.depth_head(
59
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
60
+ )
61
+ predictions["depth"] = depth
62
+ predictions["depth_conf"] = depth_conf
63
+
64
+ if self.point_head is not None:
65
+ pts3d, pts3d_conf = self.point_head(
66
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
67
+ )
68
+ predictions["world_points"] = pts3d
69
+ predictions["world_points_conf"] = pts3d_conf
70
+
71
+ if self.track_head is not None and query_points is not None:
72
+ track_list, vis, conf = self.track_head(
73
+ aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
74
+ )
75
+ predictions["track"] = track_list[-1] # track of the last iteration
76
+ predictions["vis"] = vis
77
+ predictions["conf"] = conf
78
+
79
+ predictions["images"] = images
80
+
81
+ return predictions
vggt/utils/geometry.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+ from typing import Tuple
11
+
12
+
13
+ def unproject_depth_map_to_point_map(
14
+ depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
15
+ ) -> np.ndarray:
16
+ """
17
+ Unproject a batch of depth maps to 3D world coordinates.
18
+
19
+ Args:
20
+ depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
21
+ extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
22
+ intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
23
+
24
+ Returns:
25
+ np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
26
+ """
27
+ if isinstance(depth_map, torch.Tensor):
28
+ depth_map = depth_map.cpu().numpy()
29
+ if isinstance(extrinsics_cam, torch.Tensor):
30
+ extrinsics_cam = extrinsics_cam.cpu().numpy()
31
+ if isinstance(intrinsics_cam, torch.Tensor):
32
+ intrinsics_cam = intrinsics_cam.cpu().numpy()
33
+
34
+ world_points_list = []
35
+ for frame_idx in range(depth_map.shape[0]):
36
+ cur_world_points, _, _ = depth_to_world_coords_points(
37
+ depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
38
+ )
39
+ world_points_list.append(cur_world_points)
40
+ world_points_array = np.stack(world_points_list, axis=0)
41
+
42
+ return world_points_array
43
+
44
+
45
+ def depth_to_world_coords_points(
46
+ depth_map: np.ndarray,
47
+ extrinsic: np.ndarray,
48
+ intrinsic: np.ndarray,
49
+ eps=1e-8,
50
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
51
+ """
52
+ Convert a depth map to world coordinates.
53
+
54
+ Args:
55
+ depth_map (np.ndarray): Depth map of shape (H, W).
56
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
57
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
58
+
59
+ Returns:
60
+ tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W).
61
+ """
62
+ if depth_map is None:
63
+ return None, None, None
64
+
65
+ # Valid depth mask
66
+ point_mask = depth_map > eps
67
+
68
+ # Convert depth map to camera coordinates
69
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
70
+
71
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
72
+ # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
73
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
74
+
75
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
76
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
77
+
78
+ # Apply the rotation and translation to the camera coordinates
79
+ world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
80
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
81
+
82
+ return world_coords_points, cam_coords_points, point_mask
83
+
84
+
85
+ def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
86
+ """
87
+ Convert a depth map to camera coordinates.
88
+
89
+ Args:
90
+ depth_map (np.ndarray): Depth map of shape (H, W).
91
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
92
+
93
+ Returns:
94
+ tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
95
+ """
96
+ H, W = depth_map.shape
97
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
98
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
99
+
100
+ # Intrinsic parameters
101
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
102
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
103
+
104
+ # Generate grid of pixel coordinates
105
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
106
+
107
+ # Unproject to camera coordinates
108
+ x_cam = (u - cu) * depth_map / fu
109
+ y_cam = (v - cv) * depth_map / fv
110
+ z_cam = depth_map
111
+
112
+ # Stack to form camera coordinates
113
+ cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
114
+
115
+ return cam_coords
116
+
117
+
118
+ def closed_form_inverse_se3(se3, R=None, T=None):
119
+ """
120
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
121
+
122
+ If `R` and `T` are provided, they must correspond to the rotation and translation
123
+ components of `se3`. Otherwise, they will be extracted from `se3`.
124
+
125
+ Args:
126
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
127
+ R (optional): Nx3x3 array or tensor of rotation matrices.
128
+ T (optional): Nx3x1 array or tensor of translation vectors.
129
+
130
+ Returns:
131
+ Inverted SE3 matrices with the same type and device as `se3`.
132
+
133
+ Shapes:
134
+ se3: (N, 4, 4)
135
+ R: (N, 3, 3)
136
+ T: (N, 3, 1)
137
+ """
138
+ # Check if se3 is a numpy array or a torch tensor
139
+ is_numpy = isinstance(se3, np.ndarray)
140
+
141
+ # Validate shapes
142
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
143
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
144
+
145
+ # Extract R and T if not provided
146
+ if R is None:
147
+ R = se3[:, :3, :3] # (N,3,3)
148
+ if T is None:
149
+ T = se3[:, :3, 3:] # (N,3,1)
150
+
151
+ # Transpose R
152
+ if is_numpy:
153
+ # Compute the transpose of the rotation for NumPy
154
+ R_transposed = np.transpose(R, (0, 2, 1))
155
+ # -R^T t for NumPy
156
+ top_right = -np.matmul(R_transposed, T)
157
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
158
+ else:
159
+ R_transposed = R.transpose(1, 2) # (N,3,3)
160
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
161
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
162
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
163
+
164
+ inverted_matrix[:, :3, :3] = R_transposed
165
+ inverted_matrix[:, :3, 3:] = top_right
166
+
167
+ return inverted_matrix
vggt/utils/load_fn.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision import transforms as TF
10
+
11
+
12
+ def load_and_preprocess_images(image_path_list):
13
+ """
14
+ A quick start function to load and preprocess images for model input.
15
+ This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
16
+
17
+ Args:
18
+ image_path_list (list): List of paths to image files
19
+
20
+ Returns:
21
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
22
+
23
+ Raises:
24
+ ValueError: If the input list is empty
25
+
26
+ Notes:
27
+ - Images with different dimensions will be padded with white (value=1.0)
28
+ - A warning is printed when images have different shapes
29
+ - The function ensures width=518px while maintaining aspect ratio
30
+ - Height is adjusted to be divisible by 14 for compatibility with model requirements
31
+ """
32
+ # Check for empty list
33
+ if len(image_path_list) == 0:
34
+ raise ValueError("At least 1 image is required")
35
+
36
+ images = []
37
+ shapes = set()
38
+ to_tensor = TF.ToTensor()
39
+
40
+ # First process all images and collect their shapes
41
+ for image_path in image_path_list:
42
+
43
+ # Open image
44
+ img = Image.open(image_path)
45
+
46
+ # If there's an alpha channel, blend onto white background:
47
+ if img.mode == 'RGBA':
48
+ # Create white background
49
+ background = Image.new('RGBA', img.size, (255, 255, 255, 255))
50
+ # Alpha composite onto the white background
51
+ img = Image.alpha_composite(background, img)
52
+
53
+ # Now convert to "RGB" (this step assigns white for transparent areas)
54
+ img = img.convert("RGB")
55
+
56
+ width, height = img.size
57
+ new_width = 518
58
+
59
+ # Calculate height maintaining aspect ratio, divisible by 14
60
+ new_height = round(height * (new_width / width) / 14) * 14
61
+
62
+ # Resize with new dimensions (width, height)
63
+
64
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
65
+ img = to_tensor(img) # Convert to tensor (0, 1)
66
+
67
+ # Center crop height if it's larger than 518
68
+
69
+ if new_height > 518:
70
+ start_y = (new_height - 518) // 2
71
+ img = img[:, start_y : start_y + 518, :]
72
+
73
+ shapes.add((img.shape[1], img.shape[2]))
74
+ images.append(img)
75
+
76
+ # Check if we have different shapes
77
+ # In theory our model can also work well with different shapes
78
+
79
+ if len(shapes) > 1:
80
+ print(f"Warning: Found images with different shapes: {shapes}")
81
+ # Find maximum dimensions
82
+ max_height = max(shape[0] for shape in shapes)
83
+ max_width = max(shape[1] for shape in shapes)
84
+
85
+ # Pad images if necessary
86
+ padded_images = []
87
+ for img in images:
88
+ h_padding = max_height - img.shape[1]
89
+ w_padding = max_width - img.shape[2]
90
+
91
+ if h_padding > 0 or w_padding > 0:
92
+ pad_top = h_padding // 2
93
+ pad_bottom = h_padding - pad_top
94
+ pad_left = w_padding // 2
95
+ pad_right = w_padding - pad_left
96
+
97
+ img = torch.nn.functional.pad(
98
+ img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
99
+ )
100
+ padded_images.append(img)
101
+ images = padded_images
102
+
103
+ images = torch.stack(images) # concatenate images
104
+
105
+ # Ensure correct shape when single image
106
+ if len(image_path_list) == 1:
107
+ # Verify shape is (1, C, H, W)
108
+ if images.dim() == 3:
109
+ images = images.unsqueeze(0)
110
+
111
+ return images
vggt/utils/pose_enc.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from .rotation import quat_to_mat, mat_to_quat
9
+
10
+
11
+ def extri_intri_to_pose_encoding(
12
+ extrinsics,
13
+ intrinsics,
14
+ image_size_hw=None, # e.g., (256, 512)
15
+ pose_encoding_type="absT_quaR_FoV",
16
+ min_focal_length=0.1,
17
+ max_focal_length=10,
18
+ ):
19
+
20
+ # extrinsics: BxSx3x4
21
+ # intrinsics: BxSx3x3
22
+
23
+ if pose_encoding_type == "absT_quaR_FoV":
24
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
25
+ T = extrinsics[:, :, :3, 3] # BxSx3
26
+
27
+ quat = mat_to_quat(R)
28
+ # R_reverse = quat_to_mat(quat)
29
+ # Note the order of h and w here
30
+ H, W = image_size_hw
31
+ fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
32
+ fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
33
+ pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
34
+ else:
35
+ raise NotImplementedError
36
+
37
+ return pose_encoding
38
+
39
+
40
+ def pose_encoding_to_extri_intri(
41
+ pose_encoding,
42
+ image_size_hw=None, # e.g., (256, 512)
43
+ min_focal_length=0.1,
44
+ max_focal_length=10,
45
+ pose_encoding_type="absT_quaR_FoV",
46
+ build_intrinsics=True,
47
+ ):
48
+
49
+ intrinsics = None
50
+
51
+ if pose_encoding_type == "absT_quaR_FoV":
52
+ T = pose_encoding[..., :3]
53
+ quat = pose_encoding[..., 3:7]
54
+ fov_h = pose_encoding[..., 7]
55
+ fov_w = pose_encoding[..., 8]
56
+
57
+ R = quat_to_mat(quat)
58
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
59
+
60
+ if build_intrinsics:
61
+ H, W = image_size_hw
62
+ fy = (H / 2.0) / torch.tan(fov_h / 2.0)
63
+ fx = (W / 2.0) / torch.tan(fov_w / 2.0)
64
+ intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
65
+ intrinsics[..., 0, 0] = fx
66
+ intrinsics[..., 1, 1] = fy
67
+ intrinsics[..., 0, 2] = W / 2
68
+ intrinsics[..., 1, 2] = H / 2
69
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
70
+ else:
71
+ raise NotImplementedError
72
+
73
+ return extrinsics, intrinsics
vggt/utils/rotation.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
8
+
9
+ import torch
10
+ import numpy as np
11
+ import torch.nn.functional as F
12
+
13
+
14
+ def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
15
+ """
16
+ Quaternion Order: XYZW or say ijkr, scalar-last
17
+
18
+ Convert rotations given as quaternions to rotation matrices.
19
+ Args:
20
+ quaternions: quaternions with real part last,
21
+ as tensor of shape (..., 4).
22
+
23
+ Returns:
24
+ Rotation matrices as tensor of shape (..., 3, 3).
25
+ """
26
+ i, j, k, r = torch.unbind(quaternions, -1)
27
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
28
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
29
+
30
+ o = torch.stack(
31
+ (
32
+ 1 - two_s * (j * j + k * k),
33
+ two_s * (i * j - k * r),
34
+ two_s * (i * k + j * r),
35
+ two_s * (i * j + k * r),
36
+ 1 - two_s * (i * i + k * k),
37
+ two_s * (j * k - i * r),
38
+ two_s * (i * k - j * r),
39
+ two_s * (j * k + i * r),
40
+ 1 - two_s * (i * i + j * j),
41
+ ),
42
+ -1,
43
+ )
44
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
45
+
46
+
47
+ def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Convert rotations given as rotation matrices to quaternions.
50
+
51
+ Args:
52
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
53
+
54
+ Returns:
55
+ quaternions with real part last, as tensor of shape (..., 4).
56
+ Quaternion Order: XYZW or say ijkr, scalar-last
57
+ """
58
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
59
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
60
+
61
+ batch_dim = matrix.shape[:-2]
62
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
63
+
64
+ q_abs = _sqrt_positive_part(
65
+ torch.stack(
66
+ [
67
+ 1.0 + m00 + m11 + m22,
68
+ 1.0 + m00 - m11 - m22,
69
+ 1.0 - m00 + m11 - m22,
70
+ 1.0 - m00 - m11 + m22,
71
+ ],
72
+ dim=-1,
73
+ )
74
+ )
75
+
76
+ # we produce the desired quaternion multiplied by each of r, i, j, k
77
+ quat_by_rijk = torch.stack(
78
+ [
79
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
80
+ # `int`.
81
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
82
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
83
+ # `int`.
84
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
85
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
86
+ # `int`.
87
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
88
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
89
+ # `int`.
90
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
91
+ ],
92
+ dim=-2,
93
+ )
94
+
95
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
96
+ # the candidate won't be picked.
97
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
98
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
99
+
100
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
101
+ # forall i; we pick the best-conditioned one (with the largest denominator)
102
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
103
+
104
+ # Convert from rijk to ijkr
105
+ out = out[..., [1, 2, 3, 0]]
106
+
107
+ out = standardize_quaternion(out)
108
+
109
+ return out
110
+
111
+
112
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
113
+ """
114
+ Returns torch.sqrt(torch.max(0, x))
115
+ but with a zero subgradient where x is 0.
116
+ """
117
+ ret = torch.zeros_like(x)
118
+ positive_mask = x > 0
119
+ if torch.is_grad_enabled():
120
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
121
+ else:
122
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
123
+ return ret
124
+
125
+
126
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
127
+ """
128
+ Convert a unit quaternion to a standard form: one in which the real
129
+ part is non negative.
130
+
131
+ Args:
132
+ quaternions: Quaternions with real part last,
133
+ as tensor of shape (..., 4).
134
+
135
+ Returns:
136
+ Standardized quaternions as tensor of shape (..., 4).
137
+ """
138
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)