amael-apple committed · Commit c20d7cc · 0 Parent(s)

Initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +166 -0
  2. .pre-commit-config.yaml +23 -0
  3. .python-version +1 -0
  4. ACKNOWLEDGEMENTS +214 -0
  5. CODE_OF_CONDUCT.md +70 -0
  6. CONTRIBUTING.md +11 -0
  7. LICENSE +47 -0
  8. LICENSE_MODEL +88 -0
  9. README.md +95 -0
  10. pyproject.toml +69 -0
  11. requirements.in +1 -0
  12. requirements.txt +172 -0
  13. src/sharp/__init__.py +4 -0
  14. src/sharp/cli/__init__.py +19 -0
  15. src/sharp/cli/predict.py +206 -0
  16. src/sharp/cli/render.py +120 -0
  17. src/sharp/models/__init__.py +79 -0
  18. src/sharp/models/alignment.py +126 -0
  19. src/sharp/models/blocks.py +210 -0
  20. src/sharp/models/composer.py +251 -0
  21. src/sharp/models/decoders/__init__.py +22 -0
  22. src/sharp/models/decoders/base_decoder.py +21 -0
  23. src/sharp/models/decoders/monodepth_decoder.py +37 -0
  24. src/sharp/models/decoders/multires_conv_decoder.py +116 -0
  25. src/sharp/models/decoders/unet_decoder.py +113 -0
  26. src/sharp/models/encoders/__init__.py +24 -0
  27. src/sharp/models/encoders/base_encoder.py +25 -0
  28. src/sharp/models/encoders/monodepth_encoder.py +123 -0
  29. src/sharp/models/encoders/spn_encoder.py +369 -0
  30. src/sharp/models/encoders/unet_encoder.py +117 -0
  31. src/sharp/models/encoders/vit_encoder.py +111 -0
  32. src/sharp/models/gaussian_decoder.py +267 -0
  33. src/sharp/models/heads.py +53 -0
  34. src/sharp/models/initializer.py +297 -0
  35. src/sharp/models/monodepth.py +268 -0
  36. src/sharp/models/normalizers.py +80 -0
  37. src/sharp/models/params.py +203 -0
  38. src/sharp/models/predictor.py +201 -0
  39. src/sharp/models/presets/__init__.py +23 -0
  40. src/sharp/models/presets/monodepth.py +21 -0
  41. src/sharp/models/presets/vit.py +58 -0
  42. src/sharp/utils/__init__.py +5 -0
  43. src/sharp/utils/camera.py +386 -0
  44. src/sharp/utils/color_space.py +88 -0
  45. src/sharp/utils/gaussians.py +480 -0
  46. src/sharp/utils/gsplat.py +191 -0
  47. src/sharp/utils/io.py +213 -0
  48. src/sharp/utils/linalg.py +104 -0
  49. src/sharp/utils/logging.py +45 -0
  50. src/sharp/utils/math.py +183 -0
.gitignore ADDED
@@ -0,0 +1,166 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ .DS_Store
+ *.pt
+ .aider*
.pre-commit-config.yaml ADDED
@@ -0,0 +1,23 @@
+ exclude: |
+   (?x)(
+     ^src/sharp/external
+   )
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.5.0
+     hooks:
+       - id: trailing-whitespace
+       - id: end-of-file-fixer
+       # - id: no-commit-to-branch
+       #   args: ['--branch', 'main']
+   - repo: https://github.com/charliermarsh/ruff-pre-commit
+     rev: v0.1.7
+     hooks:
+       - id: ruff
+         args: [--fix, --exit-non-zero-on-fix]
+       - id: ruff-format
+   - repo: https://github.com/pre-commit/mirrors-mypy
+     rev: v1.7.1
+     hooks:
+       - id: mypy
+         additional_dependencies: [ types-PyYAML ]
.python-version ADDED
@@ -0,0 +1 @@
+ 3.13
ACKNOWLEDGEMENTS ADDED
@@ -0,0 +1,214 @@
+ Acknowledgements
+ Portions of this Software may utilize the following copyrighted
+ material, the use of which is hereby acknowledged.
+
+ ---------------------------------------------------------------------------------
+
+ TIMM - Pytorch Image Models library
+
+ https://github.com/huggingface/pytorch-image-models
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2019 Ross Wightman
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+ -------
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,70 @@
+ # Code of Conduct
+
+ ## Our Pledge
+
+ In the interest of fostering an open and welcoming environment, we as
+ contributors and maintainers pledge to making participation in our project and
+ our community a harassment-free experience for everyone, regardless of age, body
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
+ level of experience, education, socio-economic status, nationality, personal
+ appearance, race, religion, or sexual identity and orientation.
+
+ ## Our Standards
+
+ Examples of behavior that contributes to creating a positive environment
+ include:
+
+ * Using welcoming and inclusive language
+ * Being respectful of differing viewpoints and experiences
+ * Gracefully accepting constructive criticism
+ * Focusing on what is best for the community
+ * Showing empathy towards other community members
+
+ Examples of unacceptable behavior by participants include:
+
+ * The use of sexualized language or imagery and unwelcome sexual attention or
+   advances
+ * Trolling, insulting/derogatory comments, and personal or political attacks
+ * Public or private harassment
+ * Publishing others' private information, such as a physical or electronic
+   address, without explicit permission
+ * Other conduct which could reasonably be considered inappropriate in a
+   professional setting
+
+ ## Our Responsibilities
+
+ Project maintainers are responsible for clarifying the standards of acceptable
+ behavior and are expected to take appropriate and fair corrective action in
+ response to any instances of unacceptable behavior.
+
+ Project maintainers have the right and responsibility to remove, edit, or
+ reject comments, commits, code, wiki edits, issues, and other contributions
+ that are not aligned to this Code of Conduct, or to ban temporarily or
+ permanently any contributor for other behaviors that they deem inappropriate,
+ threatening, offensive, or harmful.
+
+ ## Scope
+
+ This Code of Conduct applies within all project spaces, and it also applies when
+ an individual is representing the project or its community in public spaces.
+ Examples of representing a project or community include using an official
+ project e-mail address, posting via an official social media account, or acting
+ as an appointed representative at an online or offline event. Representation of
+ a project may be further defined and clarified by project maintainers.
+
+ ## Enforcement
+
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
+ reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
+ complaints will be reviewed and investigated and will result in a response that
+ is deemed necessary and appropriate to the circumstances. The project team is
+ obligated to maintain confidentiality with regard to the reporter of an incident.
+ Further details of specific enforcement policies may be posted separately.
+
+ Project maintainers who do not follow or enforce the Code of Conduct in good
+ faith may face temporary or permanent repercussions as determined by other
+ members of the project's leadership.
+
+ ## Attribution
+
+ This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
CONTRIBUTING.md ADDED
@@ -0,0 +1,11 @@
+ # Contribution Guide
+
+ Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
+
+ While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
+
+ ## Before you get started
+
+ By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
+
+ We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
LICENSE ADDED
@@ -0,0 +1,47 @@
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
+
+ Disclaimer: IMPORTANT: This Apple software is supplied to you by Apple
+ Inc. ("Apple") in consideration of your agreement to the following
+ terms, and your use, installation, modification or redistribution of
+ this Apple software constitutes acceptance of these terms. If you do
+ not agree with these terms, please do not use, install, modify or
+ redistribute this Apple software.
+
+ In consideration of your agreement to abide by the following terms, and
+ subject to these terms, Apple grants you a personal, non-exclusive
+ license, under Apple's copyrights in this original Apple software (the
+ "Apple Software"), to use, reproduce, modify and redistribute the Apple
+ Software, with or without modifications, in source and/or binary forms;
+ provided that if you redistribute the Apple Software in its entirety and
+ without modifications, you must retain this notice and the following
+ text and disclaimers in all such redistributions of the Apple Software.
+ Neither the name, trademarks, service marks or logos of Apple Inc. may
+ be used to endorse or promote products derived from the Apple Software
+ without specific prior written permission from Apple. Except as
+ expressly stated in this notice, no other rights or licenses, express or
+ implied, are granted by Apple herein, including but not limited to any
+ patent rights that may be infringed by your derivative works or by other
+ works in which the Apple Software may be incorporated.
+
+ The Apple Software is provided by Apple on an "AS IS" basis. APPLE
+ MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
+ THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
+ FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
+ OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
+
+ IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
+ OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
+ MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
+ AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
+ STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
+ POSSIBILITY OF SUCH DAMAGE.
+
+
+ -------------------------------------------------------------------------------
+ SOFTWARE DISTRIBUTED IN THIS REPOSITORY:
+
+ This software includes a number of subcomponents with separate
+ copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
+ -------------------------------------------------------------------------------
LICENSE_MODEL ADDED
@@ -0,0 +1,88 @@
+ Disclaimer: IMPORTANT: This Apple Machine Learning Research Model is
+ specifically developed and released by Apple Inc. ("Apple") for the sole purpose
+ of scientific research of artificial intelligence and machine-learning
+ technology. “Apple Machine Learning Research Model” means the model, including
+ but not limited to algorithms, formulas, trained model weights, parameters,
+ configurations, checkpoints, and any related materials (including
+ documentation).
+
+ This Apple Machine Learning Research Model is provided to You by
+ Apple in consideration of your agreement to the following terms, and your use,
+ modification, creation of Model Derivatives, and or redistribution of the Apple
+ Machine Learning Research Model constitutes acceptance of this Agreement. If You
+ do not agree with these terms, please do not use, modify, create Model
+ Derivatives of, or distribute this Apple Machine Learning Research Model or
+ Model Derivatives.
+
+ * License Scope: In consideration of your agreement to abide by the following
+   terms, and subject to these terms, Apple hereby grants you a personal,
+   non-exclusive, worldwide, non-transferable, royalty-free, revocable, and
+   limited license, to use, copy, modify, distribute, and create Model
+   Derivatives (defined below) of the Apple Machine Learning Research Model
+   exclusively for Research Purposes. You agree that any Model Derivatives You
+   may create or that may be created for You will be limited to Research Purposes
+   as well. “Research Purposes” means non-commercial scientific research and
+   academic development activities, such as experimentation, analysis, testing
+   conducted by You with the sole intent to advance scientific knowledge and
+   research. “Research Purposes” does not include any commercial exploitation,
+   product development or use in any commercial product or service.
+
+ * Distribution of Apple Machine Learning Research Model and Model Derivatives:
+   If you choose to redistribute Apple Machine Learning Research Model or its
+   Model Derivatives, you must provide a copy of this Agreement to such third
+   party, and ensure that the following attribution notice be provided: “Apple
+   Machine Learning Research Model is licensed under the Apple Machine Learning
+   Research Model License Agreement.” Additionally, all Model Derivatives must
+   clearly be identified as such, including disclosure of modifications and
+   changes made to the Apple Machine Learning Research Model. The name,
+   trademarks, service marks or logos of Apple may not be used to endorse or
+   promote Model Derivatives or the relationship between You and Apple. “Model
+   Derivatives” means any models or any other artifacts created by modifications,
+   improvements, adaptations, alterations to the architecture, algorithm or
+   training processes of the Apple Machine Learning Research Model, or by any
+   retraining, fine-tuning of the Apple Machine Learning Research Model.
+
+ * No Other License: Except as expressly stated in this notice, no other rights
+   or licenses, express or implied, are granted by Apple herein, including but
+   not limited to any patent, trademark, and similar intellectual property rights
+   worldwide that may be infringed by the Apple Machine Learning Research Model,
+   the Model Derivatives or by other works in which the Apple Machine Learning
+   Research Model may be incorporated.
+
+ * Compliance with Laws: Your use of Apple Machine Learning Research Model must
+   be in compliance with all applicable laws and regulations.
+
+ * Term and Termination: The term of this Agreement will begin upon your
+   acceptance of this Agreement or use of the Apple Machine Learning Research
+   Model and will continue until terminated in accordance with the following
+   terms. Apple may terminate this Agreement at any time if You are in breach of
+   any term or condition of this Agreement. Upon termination of this Agreement,
+   You must cease to use all Apple Machine Learning Research Models and Model
+   Derivatives and permanently delete any copy thereof. Sections 3, 6 and 7 will
+   survive termination.
+
+ * Disclaimer and Limitation of Liability: This Apple Machine Learning Research
+   Model and any outputs generated by the Apple Machine Learning Research Model
+   are provided on an “AS IS” basis. APPLE MAKES NO WARRANTIES, EXPRESS OR
+   IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF
+   NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE,
+   REGARDING THE APPLE MACHINE LEARNING RESEARCH MODEL OR OUTPUTS GENERATED BY
+   THE APPLE MACHINE LEARNING RESEARCH MODEL. You are solely responsible for
+   determining the appropriateness of using or redistributing the Apple Machine
+   Learning Research Model and any outputs of the Apple Machine Learning Research
+   Model and assume any risks associated with Your use of the Apple Machine
+   Learning Research Model and any output and results. IN NO EVENT SHALL APPLE BE
+   LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
+   IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF
+   THE APPLE MACHINE LEARNING RESEARCH MODEL AND ANY OUTPUTS OF THE APPLE MACHINE
+   LEARNING RESEARCH MODEL, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT,
+   TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS
+   BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ * Governing Law: This Agreement will be governed by and construed under the laws
+   of the State of California without regard to its choice of law principles. The
+   Convention on Contracts for the International Sale of Goods shall not apply to
+   the Agreement except that the arbitration clause and any arbitration hereunder
+   shall be governed by the Federal Arbitration Act, Chapters 1 and 2.
+
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
README.md ADDED
@@ -0,0 +1,95 @@
+ # Sharp Monocular View Synthesis in Less Than a Second
+
+ [![Project Page](https://img.shields.io/badge/Project-Page-green)](https://apple.github.io/ml-sharp/)
+ [![arXiv](https://img.shields.io/badge/arXiv-2512.10685-b31b1b.svg)](https://arxiv.org/abs/2512.10685)
+
+ This software project accompanies the research paper: _Sharp Monocular View Synthesis in Less Than a Second_
+ by _Lars Mescheder, Wei Dong, Shiwei Li, Xuyang Bai, Marcel Santos, Peiyun Hu, Bruno Lecouat, Mingmin Zhen, Amaël Delaunoy,
+ Tian Fang, Yanghai Tsin, Stephan Richter and Vladlen Koltun_.
+
+ ![](data/teaser.jpg)
+
+ We present SHARP, an approach to photorealistic view synthesis from a single image. Given a single photograph, SHARP regresses the parameters of a 3D Gaussian representation of the depicted scene. This is done in less than a second on a standard GPU via a single feedforward pass through a neural network. The 3D Gaussian representation produced by SHARP can then be rendered in real time, yielding high-resolution photorealistic images for nearby views. The representation is metric, with absolute scale, supporting metric camera movements. Experimental results demonstrate that SHARP delivers robust zero-shot generalization across datasets. It sets a new state of the art on multiple datasets, reducing LPIPS by 25–34% and DISTS by 21–43% versus the best prior model, while lowering the synthesis time by three orders of magnitude.
+
+ ## Getting started
+
+ We recommend first creating a Python environment:
+
+ ```
+ conda create -n sharp python=3.13
+ ```
+
+ Afterwards, you can install the project using:
+
+ ```
+ pip install -r requirements.txt
+ ```
+
+ To test the installation, run:
+
+ ```
+ sharp --help
+ ```
+
+ ## Using the CLI
+
+ To run prediction:
+
+ ```
+ sharp predict -i /path/to/input/images -o /path/to/output/gaussians
+ ```
+
+ The model checkpoint will be downloaded automatically on first run and cached locally at `~/.cache/torch/hub/checkpoints/`.
+
+ Alternatively, you can download the model directly:
+
+ ```
+ wget https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt
+ ```
+
+ To use a manually downloaded checkpoint, specify it with the `-c` flag:
+
+ ```
+ sharp predict -i /path/to/input/images -o /path/to/output/gaussians -c sharp_2572gikvuh.pt
+ ```
+
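+ You can also call the same building blocks from Python. The following is a minimal, untested sketch based on the helpers in `src/sharp/cli/predict.py`; the function names come from this repository, while the file paths are placeholders:
+
+ ```python
+ from pathlib import Path
+
+ import torch
+
+ from sharp.cli.predict import predict_image
+ from sharp.models import PredictorParams, create_predictor
+ from sharp.utils import io
+ from sharp.utils.gaussians import save_ply
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Build the predictor and load a locally downloaded checkpoint.
+ predictor = create_predictor(PredictorParams())
+ predictor.load_state_dict(torch.load("sharp_2572gikvuh.pt", weights_only=True))
+ predictor.eval().to(device)
+
+ # Load an image together with its estimated focal length in pixels.
+ image, _, f_px = io.load_rgb(Path("example.jpg"))
+ height, width = image.shape[:2]
+
+ # Regress metric 3D Gaussians and save them as a .ply file.
+ gaussians = predict_image(predictor, image, f_px, device)
+ save_ply(gaussians, f_px, (height, width), Path("example.ply"))
+ ```
+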
+ The results are 3D Gaussian splats (3DGS) written to the output folder. The 3DGS `.ply` files are compatible with various public 3DGS renderers. We follow the OpenCV coordinate convention (x right, y down, z forward), and the 3DGS scene center is roughly at (0, 0, +z). When using third-party renderers, please scale and rotate to re-center the scene accordingly.
+
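+ For example, one way to re-center a predicted splat for a renderer that expects the scene near the origin is to shift the Gaussian centers by their centroid. This untested sketch uses `plyfile` (already a dependency of this project) and assumes the common 3DGS vertex layout with `x`, `y`, `z` properties; verify the property names against your renderer's expectations:
+
+ ```python
+ import numpy as np
+ from plyfile import PlyData
+
+ ply = PlyData.read("example.ply")
+ vertices = ply["vertex"].data
+
+ # Shift the Gaussian centers so their centroid sits at the origin.
+ center = np.array([vertices["x"].mean(), vertices["y"].mean(), vertices["z"].mean()])
+ vertices["x"] -= center[0]
+ vertices["y"] -= center[1]
+ vertices["z"] -= center[2]
+
+ ply.write("example_centered.ply")
+ ```
+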
+ ### Rendering trajectories (CUDA GPU only)
+
+ Additionally, you can render videos along a camera trajectory. While Gaussian prediction works on CPU, CUDA, and MPS, rendering videos via the `--render` option currently requires a CUDA GPU. The gsplat renderer takes a while to initialize on first launch.
+
+ ```
+ sharp predict -i /path/to/input/images -o /path/to/output/gaussians --render
+
+ # Or from the intermediate gaussians:
+ sharp render -i /path/to/output/gaussians -o /path/to/output/renderings
+ ```
+
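+ Rendering can also be driven from Python. The following untested sketch mirrors `render_cli` in `src/sharp/cli/render.py` (a CUDA GPU is still required; the paths are placeholders):
+
+ ```python
+ from pathlib import Path
+
+ from sharp.cli.render import render_gaussians
+ from sharp.utils import camera
+ from sharp.utils.gaussians import load_ply
+
+ # Load a predicted splat together with its stored focal length and resolution.
+ gaussians, metadata = load_ply(Path("example.ply"))
+
+ # Render the default camera trajectory to an mp4 file.
+ render_gaussians(
+     gaussians=gaussians,
+     metadata=metadata,
+     params=camera.TrajectoryParams(),
+     output_path=Path("example.mp4"),
+ )
+ ```
+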
+ ## Evaluation
+
+ Please refer to the paper for both quantitative and qualitative evaluations.
+ Additionally, please check out this [qualitative examples page](https://apple.github.io/ml-sharp/) containing several video comparisons against related work.
+
+ ## Citation
+
+ If you find our work useful, please cite the following paper:
+
+ ```bibtex
+ @article{Sharp2025:arxiv,
+   title   = {Sharp Monocular View Synthesis in Less Than a Second},
+   author  = {Lars Mescheder and Wei Dong and Shiwei Li and Xuyang Bai and Marcel Santos and Peiyun Hu and Bruno Lecouat and Mingmin Zhen and Ama\"{e}l Delaunoy and Tian Fang and Yanghai Tsin and Stephan R. Richter and Vladlen Koltun},
+   journal = {arXiv preprint arXiv:2512.10685},
+   year    = {2025},
+   url     = {https://arxiv.org/abs/2512.10685},
+ }
+ ```
+
+ ## Acknowledgements
+
+ Our codebase is built using multiple open-source contributions; please see [ACKNOWLEDGEMENTS](ACKNOWLEDGEMENTS) for more details.
+
+ ## License
+
+ Please check out the repository [LICENSE](LICENSE) before using the provided code and
+ [LICENSE_MODEL](LICENSE_MODEL) for the released models.
pyproject.toml ADDED
@@ -0,0 +1,69 @@
+ [project]
+ name = "sharp"
+ version = "0.1"
+ description = "Inference/Network/Model code for SHARP view synthesis model."
+ readme = "README.md"
+ dependencies = [
+     "click",
+     "gsplat",
+     "imageio[ffmpeg]",
+     "matplotlib",
+     "pillow-heif",
+     "plyfile",
+     "scipy",
+     "timm",
+     "torch",
+     "torchvision",
+ ]
+
+ [project.scripts]
+ sharp = "sharp.cli:main_cli"
+
+ [project.urls]
+ Homepage = "https://github.com/apple/ml-sharp"
+ Repository = "https://github.com/apple/ml-sharp"
+
+ [build-system]
+ requires = ["setuptools", "setuptools-scm"]
+ build-backend = "setuptools.build_meta"
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
+
+ [tool.pyright]
+ include = ["src"]
+ exclude = [
+     "**/node_modules",
+     "**/__pycache__",
+ ]
+ pythonVersion = "3.13"
+
+ [tool.pytest.ini_options]
+ minversion = "6.0"
+ addopts = "-ra -q"
+ testpaths = [
+     "tests"
+ ]
+ filterwarnings = [
+     "ignore::DeprecationWarning"
+ ]
+
+ [tool.ruff]
+ line-length = 100
+ lint.select = ["E", "F", "D", "I"]
+ lint.ignore = ["D100", "D105",
+     # Imperative mood of docstring.
+     "D401",
+ ]
+ extend-exclude = [
+     "*external*",
+     "third_party",
+ ]
+ src = ["sharp"]
+ target-version = "py39"
+
+ [tool.ruff.lint.per-file-ignores]
+ "__init__.py" = ["F401", "D100", "D104"]
+
+ [tool.ruff.lint.pydocstyle]
+ convention = "google"
requirements.in ADDED
@@ -0,0 +1 @@
+ -e .
requirements.txt ADDED
@@ -0,0 +1,172 @@
+ # This file was autogenerated by uv via the following command:
+ #    uv pip compile requirements.in -o requirements.txt --universal
+ -e .
+     # via -r requirements.in
+ certifi==2025.8.3
+     # via requests
+ charset-normalizer==3.4.3
+     # via requests
+ click==8.3.0
+     # via sharp
+ colorama==0.4.6 ; sys_platform == 'win32'
+     # via
+     #   click
+     #   tqdm
+ contourpy==1.3.3
+     # via matplotlib
+ cycler==0.12.1
+     # via matplotlib
+ filelock==3.19.1
+     # via
+     #   huggingface-hub
+     #   torch
+ fonttools==4.61.0
+     # via matplotlib
+ fsspec==2025.9.0
+     # via
+     #   huggingface-hub
+     #   torch
+ gsplat==1.5.3
+     # via sharp
+ hf-xet==1.1.10 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
+     # via huggingface-hub
+ huggingface-hub==0.35.3
+     # via timm
+ idna==3.10
+     # via requests
+ imageio==2.37.0
+     # via sharp
+ imageio-ffmpeg==0.6.0
+     # via imageio
+ jaxtyping==0.3.3
+     # via gsplat
+ jinja2==3.1.6
+     # via torch
+ kiwisolver==1.4.9
+     # via matplotlib
+ markdown-it-py==4.0.0
+     # via rich
+ markupsafe==3.0.3
+     # via jinja2
+ matplotlib==3.10.6
+     # via sharp
+ mdurl==0.1.2
+     # via markdown-it-py
+ mpmath==1.3.0
+     # via sympy
+ networkx==3.5
+     # via torch
+ ninja==1.13.0
+     # via gsplat
+ numpy==2.3.3
+     # via
+     #   contourpy
+     #   gsplat
+     #   imageio
+     #   matplotlib
+     #   plyfile
+     #   scipy
+     #   torchvision
+ nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via
+     #   nvidia-cudnn-cu12
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via
+     #   nvidia-cusolver-cu12
+     #   torch
+ nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-nccl-cu12==2.27.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via
+     #   nvidia-cufft-cu12
+     #   nvidia-cusolver-cu12
+     #   nvidia-cusparse-cu12
+     #   torch
+ nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ packaging==25.0
+     # via
+     #   huggingface-hub
+     #   matplotlib
+ pillow==11.3.0
+     # via
+     #   imageio
+     #   matplotlib
+     #   pillow-heif
+     #   torchvision
+ pillow-heif==1.1.1
+     # via sharp
+ plyfile==1.1.2
+     # via sharp
+ psutil==7.1.0
+     # via imageio
+ pygments==2.19.2
+     # via rich
+ pyparsing==3.2.5
+     # via matplotlib
+ python-dateutil==2.9.0.post0
+     # via matplotlib
+ pyyaml==6.0.3
+     # via
+     #   huggingface-hub
+     #   timm
+ requests==2.32.5
+     # via huggingface-hub
+ rich==14.1.0
+     # via gsplat
+ safetensors==0.6.2
+     # via timm
+ scipy==1.16.2
+     # via sharp
+ setuptools==80.9.0
+     # via
+     #   torch
+     #   triton
+ six==1.17.0
+     # via python-dateutil
+ sympy==1.14.0
+     # via torch
+ timm==1.0.20
+     # via sharp
+ torch==2.8.0
+     # via
+     #   gsplat
+     #   sharp
+     #   timm
+     #   torchvision
+ torchvision==0.23.0
+     # via
+     #   sharp
+     #   timm
+ tqdm==4.67.1
+     # via huggingface-hub
+ triton==3.4.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+     # via torch
+ typing-extensions==4.15.0
+     # via
+     #   huggingface-hub
+     #   torch
+ urllib3==2.6.0
+     # via requests
+ wadler-lindig==0.1.7
+     # via jaxtyping
src/sharp/__init__.py ADDED
@@ -0,0 +1,4 @@
+ """For licensing see accompanying LICENSE file.
+
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
+ """
src/sharp/cli/__init__.py ADDED
@@ -0,0 +1,19 @@
+ """Command-line interface to run SHARP model.
+
+ For licensing see accompanying LICENSE file.
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
+ """
+
+ import click
+
+ from . import predict, render
+
+
+ @click.group()
+ def main_cli():
+     """Run inference for SHARP model."""
+     pass
+
+
+ main_cli.add_command(predict.predict_cli, "predict")
+ main_cli.add_command(render.render_cli, "render")
src/sharp/cli/predict.py ADDED
@@ -0,0 +1,206 @@
+ """Contains `sharp predict` CLI implementation.
+
+ For licensing see accompanying LICENSE file.
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from pathlib import Path
+
+ import click
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.data
+
+ from sharp.models import (
+     PredictorParams,
+     RGBGaussianPredictor,
+     create_predictor,
+ )
+ from sharp.utils import io
+ from sharp.utils import logging as logging_utils
+ from sharp.utils.gaussians import (
+     Gaussians3D,
+     SceneMetaData,
+     save_ply,
+     unproject_gaussians,
+ )
+
+ from .render import render_gaussians
+
+ LOGGER = logging.getLogger(__name__)
+
+ DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"
+
+
+ @click.command()
+ @click.option(
+     "-i",
+     "--input-path",
+     type=click.Path(path_type=Path, exists=True),
+     help="Path to an image or a directory containing images.",
+     required=True,
+ )
+ @click.option(
+     "-o",
+     "--output-path",
+     type=click.Path(path_type=Path, file_okay=False),
+     help="Path to save the predicted Gaussians and renderings.",
+     required=True,
+ )
+ @click.option(
+     "-c",
+     "--checkpoint-path",
+     type=click.Path(path_type=Path, dir_okay=False),
+     default=None,
+     help="Path to the .pt checkpoint. If not provided, downloads the default model automatically.",
+     required=False,
+ )
+ @click.option(
+     "--render/--no-render",
+     "with_rendering",
+     is_flag=True,
+     default=False,
+     help="Whether to render trajectory for checkpoint.",
+ )
+ @click.option(
+     "--device",
+     type=str,
+     default="default",
+     help="Device to run on. ['cpu', 'mps', 'cuda']",
+ )
+ @click.option("-v", "--verbose", is_flag=True, help="Activate debug logs.")
+ def predict_cli(
+     input_path: Path,
+     output_path: Path,
+     checkpoint_path: Path,
+     with_rendering: bool,
+     device: str,
+     verbose: bool,
+ ):
+     """Predict Gaussians from input images."""
+     logging_utils.configure(logging.DEBUG if verbose else logging.INFO)
+
+     extensions = io.get_supported_image_extensions()
+
+     image_paths = []
+     if input_path.is_file():
+         if input_path.suffix in extensions:
+             image_paths = [input_path]
+     else:
+         for ext in extensions:
+             image_paths.extend(list(input_path.glob(f"**/*{ext}")))
+
+     if len(image_paths) == 0:
+         LOGGER.info("No valid images found. Input was %s.", input_path)
+         return
+
+     LOGGER.info("Processing %d valid image files.", len(image_paths))
+
+     if device == "default":
+         if torch.cuda.is_available():
+             device = "cuda"
+         elif torch.mps.is_available():
+             device = "mps"
+         else:
+             device = "cpu"
+     LOGGER.info("Using device %s", device)
+
+     if with_rendering and device != "cuda":
+         LOGGER.warning("Can only run rendering with gsplat on CUDA. Rendering is disabled.")
+         with_rendering = False
+
+     # Load or download checkpoint
+     if checkpoint_path is None:
+         LOGGER.info("No checkpoint provided. Downloading default model from %s", DEFAULT_MODEL_URL)
+         state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)
+     else:
+         LOGGER.info("Loading checkpoint from %s", checkpoint_path)
+         state_dict = torch.load(checkpoint_path, weights_only=True)
+
+     gaussian_predictor = create_predictor(PredictorParams())
+     gaussian_predictor.load_state_dict(state_dict)
+     gaussian_predictor.eval()
+     gaussian_predictor.to(device)
+
+     output_path.mkdir(exist_ok=True, parents=True)
+
+     for image_path in image_paths:
+         LOGGER.info("Processing %s", image_path)
+         image, _, f_px = io.load_rgb(image_path)
+         height, width = image.shape[:2]
+         intrinsics = torch.tensor(
+             [
+                 [f_px, 0, (width - 1) / 2.0, 0],
+                 [0, f_px, (height - 1) / 2.0, 0],
+                 [0, 0, 1, 0],
+                 [0, 0, 0, 1],
+             ],
+             device=device,
+             dtype=torch.float32,
+         )
+         gaussians = predict_image(gaussian_predictor, image, f_px, torch.device(device))
+
+         LOGGER.info("Saving 3DGS to %s", output_path)
+         save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply")
+
+         if with_rendering:
+             output_video_path = (output_path / image_path.stem).with_suffix(".mp4")
+             LOGGER.info("Rendering trajectory to %s", output_video_path)
+
+             metadata = SceneMetaData(intrinsics[0, 0].item(), (width, height), "linearRGB")
+             render_gaussians(gaussians, metadata, output_video_path)
+
+
+ @torch.no_grad()
+ def predict_image(
+     predictor: RGBGaussianPredictor,
+     image: np.ndarray,
+     f_px: float,
+     device: torch.device,
+ ) -> Gaussians3D:
+     """Predict Gaussians from an image."""
+     internal_shape = (1536, 1536)
+
+     LOGGER.info("Running preprocessing.")
+     image_pt = torch.from_numpy(image.copy()).float().to(device).permute(2, 0, 1) / 255.0
+     _, height, width = image_pt.shape
+     disparity_factor = torch.tensor([f_px / width]).float().to(device)
+
+     image_resized_pt = F.interpolate(
+         image_pt[None],
+         size=(internal_shape[1], internal_shape[0]),
+         mode="bilinear",
+         align_corners=True,
+     )
+
+     # Predict Gaussians in the NDC space.
+     LOGGER.info("Running inference.")
+     gaussians_ndc = predictor(image_resized_pt, disparity_factor)
+
+     LOGGER.info("Running postprocessing.")
+     intrinsics = (
+         torch.tensor(
+             [
+                 [f_px, 0, width / 2, 0],
+                 [0, f_px, height / 2, 0],
+                 [0, 0, 1, 0],
+                 [0, 0, 0, 1],
+             ]
+         )
+         .float()
+         .to(device)
+     )
+     intrinsics_resized = intrinsics.clone()
+     intrinsics_resized[0] *= internal_shape[0] / width
+     intrinsics_resized[1] *= internal_shape[1] / height
+
+     # Convert Gaussians to metric space.
+     gaussians = unproject_gaussians(
+         gaussians_ndc, torch.eye(4).to(device), intrinsics_resized, internal_shape
+     )
+
+     return gaussians
src/sharp/cli/render.py ADDED
@@ -0,0 +1,120 @@
+ """Contains `sharp render` CLI implementation.
+
+ For licensing see accompanying LICENSE file.
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from pathlib import Path
+
+ import click
+ import torch
+ import torch.utils.data
+
+ from sharp.utils import camera, gsplat, io
+ from sharp.utils import logging as logging_utils
+ from sharp.utils.gaussians import Gaussians3D, SceneMetaData, load_ply
+
+ LOGGER = logging.getLogger(__name__)
+
+
+ @click.command()
+ @click.option(
+     "-i",
+     "--input-path",
+     type=click.Path(exists=True, path_type=Path),
+     help="Path to the ply or a list of plys.",
+     required=True,
+ )
+ @click.option(
+     "-o",
+     "--output-path",
+     type=click.Path(path_type=Path, file_okay=False),
+     help="Path to save the rendered videos.",
+     required=True,
+ )
+ @click.option("-v", "--verbose", is_flag=True, help="Activate debug logs.")
+ def render_cli(input_path: Path, output_path: Path, verbose: bool):
+     """Render Gaussians loaded from input PLY files."""
+     logging_utils.configure(logging.DEBUG if verbose else logging.INFO)
+
+     if not torch.cuda.is_available():
+         LOGGER.error("Rendering a checkpoint requires CUDA.")
+         exit(1)
+
+     output_path.mkdir(exist_ok=True, parents=True)
+
+     params = camera.TrajectoryParams()
+
+     if input_path.suffix == ".ply":
+         scene_paths = [input_path]
+     elif input_path.is_dir():
+         scene_paths = list(input_path.glob("*.ply"))
+     else:
+         LOGGER.error("Input path must be either directory or single PLY file.")
+         exit(1)
+
+     for scene_path in scene_paths:
+         LOGGER.info("Rendering %s", scene_path)
+         gaussians, metadata = load_ply(scene_path)
+         render_gaussians(
+             gaussians=gaussians,
+             metadata=metadata,
+             params=params,
+             output_path=(output_path / scene_path.stem).with_suffix(".mp4"),
+         )
+
+
+ def render_gaussians(
+     gaussians: Gaussians3D,
+     metadata: SceneMetaData,
+     output_path: Path,
+     params: camera.TrajectoryParams | None = None,
+ ) -> None:
+     """Render a single Gaussian checkpoint file."""
+     (width, height) = metadata.resolution_px
+     f_px = metadata.focal_length_px
+
+     if params is None:
+         params = camera.TrajectoryParams()
+
+     if not torch.cuda.is_available():
+         raise RuntimeError("Rendering a checkpoint requires CUDA.")
+
+     device = torch.device("cuda")
+
+     intrinsics = torch.tensor(
+         [
+             [f_px, 0, (width - 1) / 2.0, 0],
+             [0, f_px, (height - 1) / 2.0, 0],
+             [0, 0, 1, 0],
+             [0, 0, 0, 1],
+         ],
+         device=device,
+         dtype=torch.float32,
+     )
+     camera_model = camera.create_camera_model(
+         gaussians, intrinsics, resolution_px=metadata.resolution_px
+     )
+
+     trajectory = camera.create_eye_trajectory(
+         gaussians, params, resolution_px=metadata.resolution_px, f_px=f_px
+     )
+     renderer = gsplat.GSplatRenderer(color_space=metadata.color_space)
+     video_writer = io.VideoWriter(output_path)
+
+     for eye_position in trajectory:
+         camera_info = camera_model.compute(eye_position)
+         rendering_output = renderer(
+             gaussians.to(device),
+             extrinsics=camera_info.extrinsics[None].to(device),
+             intrinsics=camera_info.intrinsics[None].to(device),
+             image_width=camera_info.width,
+             image_height=camera_info.height,
+         )
+         color = (rendering_output.color[0].permute(1, 2, 0) * 255.0).to(dtype=torch.uint8)
+         depth = rendering_output.depth[0]
+         video_writer.add_frame(color, depth)
+     video_writer.close()
src/sharp/models/__init__.py ADDED
@@ -0,0 +1,79 @@
1
+ """Contains different Gaussian predictors.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from sharp.models.monodepth import (
10
+ create_monodepth_adaptor,
11
+ create_monodepth_dpt,
12
+ )
13
+
14
+ from .alignment import create_alignment
15
+ from .composer import GaussianComposer
16
+ from .gaussian_decoder import create_gaussian_decoder
17
+ from .heads import DirectPredictionHead
18
+ from .initializer import create_initializer
19
+ from .params import PredictorParams
20
+ from .predictor import RGBGaussianPredictor
21
+
22
+
23
+ def create_predictor(params: PredictorParams) -> RGBGaussianPredictor:
24
+ """Create gaussian predictor model specified by name."""
25
+ if params.gaussian_decoder.stride < params.initializer.stride:
26
+ raise ValueError(
27
+ "We donot expected gaussian_decoder has higher resolution than initializer."
28
+ )
29
+
30
+ scale_factor = params.gaussian_decoder.stride // params.initializer.stride
31
+ gaussian_composer = GaussianComposer(
32
+ delta_factor=params.delta_factor,
33
+ min_scale=params.min_scale,
34
+ max_scale=params.max_scale,
35
+ color_activation_type=params.color_activation_type,
36
+ opacity_activation_type=params.opacity_activation_type,
37
+ color_space=params.color_space,
38
+ scale_factor=scale_factor,
39
+ base_scale_on_predicted_mean=params.base_scale_on_predicted_mean,
40
+ )
41
+ if params.num_monodepth_layers > 1 and params.initializer.num_layers != 2:
42
+ raise KeyError("We only support num_layers = 2 when num_monodepth_layers > 1.")
43
+
44
+ monodepth_model = create_monodepth_dpt(params.monodepth)
45
+ monodepth_adaptor = create_monodepth_adaptor(
46
+ monodepth_model,
47
+ params.monodepth_adaptor,
48
+ params.num_monodepth_layers,
49
+ params.sorting_monodepth,
50
+ )
51
+
52
+ if params.num_monodepth_layers == 2:
53
+ monodepth_adaptor.replicate_head(params.num_monodepth_layers)
54
+
55
+ gaussian_decoder = create_gaussian_decoder(
56
+ params.gaussian_decoder,
57
+ dims_depth_features=monodepth_adaptor.get_feature_dims(),
58
+ )
59
+ initializer = create_initializer(
60
+ params.initializer,
61
+ )
62
+ prediction_head = DirectPredictionHead(
63
+ feature_dim=gaussian_decoder.dim_out, num_layers=initializer.num_layers
64
+ )
65
+ decoder_dim = monodepth_model.decoder.dims_decoder[-1]
66
+ return RGBGaussianPredictor(
67
+ init_model=initializer,
68
+ feature_model=gaussian_decoder,
69
+ prediction_head=prediction_head,
70
+ monodepth_model=monodepth_adaptor,
71
+ gaussian_composer=gaussian_composer,
72
+ scale_map_estimator=create_alignment(params.depth_alignment, depth_decoder_dim=decoder_dim),
73
+ )
74
+
75
+
76
+ __all__ = [
77
+ "PredictorParams",
78
+ "create_predictor",
79
+ ]
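
For orientation, a minimal sketch of driving `create_predictor` (assuming the `PredictorParams` defaults defined in params.py are self-consistent):

```python
# Sketch only: build a predictor from default parameters and count its weights.
from sharp.models import PredictorParams, create_predictor

params = PredictorParams()  # assumed to provide consistent defaults
predictor = create_predictor(params)
print(f"{sum(p.numel() for p in predictor.parameters()):,} parameters")
```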
src/sharp/models/alignment.py ADDED
@@ -0,0 +1,126 @@
1
+ """Contains modules for different types of alignment.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+
15
+ from sharp.models.decoders import UNetDecoder
16
+ from sharp.models.encoders import UNetEncoder
17
+ from sharp.utils import math as math_utils
18
+
19
+ from .params import AlignmentParams
20
+
21
+
22
+ def create_alignment(
23
+ params: AlignmentParams, depth_decoder_dim: int | None = None
24
+ ) -> nn.Module | None:
25
+ """Create depth alignment."""
26
+ if depth_decoder_dim is None:
27
+ raise ValueError("Requires depth_decoder_dim for LearnedAlignment.")
28
+ alignment = LearnedAlignment(
29
+ depth_decoder_features=params.depth_decoder_features,
30
+ depth_decoder_dim=depth_decoder_dim,
31
+ steps=params.steps,
32
+ stride=params.stride,
33
+ base_width=params.base_width,
34
+ activation_type=params.activation_type,
35
+ )
36
+
37
+ if params.frozen:
38
+ alignment.requires_grad_(False)
39
+
40
+ return alignment
41
+
42
+
43
+ class LearnedAlignment(nn.Module):
44
+ """Aligns tensors using a UNet."""
45
+
46
+ def __init__(
47
+ self,
48
+ steps: int = 4,
49
+ stride: int = 8,
50
+ base_width: int = 16,
51
+ depth_decoder_features: bool = False,
52
+ depth_decoder_dim: int = 256,
53
+ activation_type: math_utils.ActivationType = "exp",
54
+ ) -> None:
55
+ """Initialize LearnedAlignment.
56
+
57
+ Args:
58
+ steps: Number of steps in the UNet.
59
+ stride: Effective downsampling of the alignment module.
60
+ base_width: Base width of the UNet.
61
+ depth_decoder_features: Whether to use depth decoder features.
62
+ depth_decoder_dim: Dimension of the depth decoder features.
63
+ activation_type: Activation type for the alignment output.
64
+ """
65
+ super().__init__()
66
+ self.activation = math_utils.create_activation_pair(activation_type)
67
+ bias_value = self.activation.inverse(torch.tensor(1.0))
68
+
69
+ self.depth_decoder_features = depth_decoder_features
70
+ if depth_decoder_features:
71
+ dim_in = 2 + depth_decoder_dim
72
+ else:
73
+ dim_in = 2
74
+
75
+ def is_power_of_two(n: int) -> bool:
76
+ """Check if a number is a power of two."""
77
+ if n <= 0:
78
+ return False
79
+ return (n & (n - 1)) == 0
80
+
81
+ if not is_power_of_two(stride):
82
+ raise ValueError(f"Stride {stride} is not a power of two.")
83
+
84
+ steps_decoder = steps - int(math.log2(stride))
85
+ if steps_decoder < 1:
86
+ raise ValueError(f"{steps_decoder} must be greater or equal to 1.")
87
+ widths = [min(base_width << i, 1024) for i in range(steps + 1)]
88
+ self.encoder = UNetEncoder(dim_in=dim_in, width=widths, steps=steps, norm_num_groups=4)
89
+ self.decoder = UNetDecoder(
90
+ dim_out=widths[0], width=widths, steps=steps_decoder, norm_num_groups=4
91
+ )
92
+ self.conv_out = nn.Conv2d(widths[0], 1, 1, bias=True)
93
+ nn.init.zeros_(self.conv_out.weight)
94
+ nn.init.constant_(self.conv_out.bias, bias_value)
95
+
96
+ def forward(
97
+ self,
98
+ tensor_src: torch.Tensor,
99
+ tensor_tgt: torch.Tensor,
100
+ depth_decoder_features: torch.Tensor | None = None,
101
+ ) -> torch.Tensor:
102
+ """Compute alignment map."""
103
+ # Since the tensors are usually given by depth which is >= 1.0, we invert
104
+ # the tensors to have them in a reasonable range.
105
+ tensor_src = 1.0 / tensor_src.clamp(min=1e-4)
106
+ tensor_tgt = 1.0 / tensor_tgt.clamp(min=1e-4)
107
+ tensor_input = torch.cat([tensor_src, tensor_tgt], dim=1)
108
+ if self.depth_decoder_features:
109
+ height, width = tensor_src.shape[-2:]
110
+ upsampled_encodings = F.interpolate(
111
+ depth_decoder_features,
112
+ size=(height, width),
113
+ mode="bilinear",
114
+ )
115
+ tensor_input = torch.cat([tensor_input, upsampled_encodings], dim=1)
116
+ features = self.encoder(tensor_input)
117
+ output = self.conv_out(self.decoder(features))
118
+ alignment_map = self.activation.forward(output)
119
+ if alignment_map.shape[-2:] != tensor_src.shape[-2:]:
120
+ alignment_map = F.interpolate(
121
+ alignment_map,
122
+ size=tensor_src.shape[-2:],
123
+ mode="bilinear",
124
+ align_corners=False,
125
+ )
126
+ return alignment_map
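
A shape-level sketch of `LearnedAlignment` (dimensions are illustrative; the random inputs stand in for source and target depth maps):

```python
# Sketch only: align two dummy depth maps; the output is a per-pixel scale map.
import torch

alignment = LearnedAlignment(steps=4, stride=8, base_width=16)
depth_src = torch.rand(1, 1, 256, 256) * 10 + 1.0  # depths >= 1.0, as assumed above
depth_tgt = torch.rand(1, 1, 256, 256) * 10 + 1.0
scale_map = alignment(depth_src, depth_tgt)
print(scale_map.shape)  # torch.Size([1, 1, 256, 256])
```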
src/sharp/models/blocks.py ADDED
@@ -0,0 +1,210 @@
1
+ """Contains reusable network components.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import Literal
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ NormLayerName = Literal["noop", "batch_norm", "group_norm", "instance_norm"]
15
+ UpsamplingMode = Literal["transposed_conv", "nearest", "bilinear"]
16
+
17
+
18
+ def norm_layer_2d(num_features: int, norm_type: NormLayerName, num_groups: int = 8) -> nn.Module:
19
+ """Create normalization layer."""
20
+ if norm_type == "noop":
21
+ return nn.Identity()
22
+ elif norm_type == "batch_norm":
23
+ return nn.BatchNorm2d(num_features=num_features)
24
+ elif norm_type == "group_norm":
25
+ return nn.GroupNorm(num_channels=num_features, num_groups=num_groups)
26
+ elif norm_type == "instance_norm":
27
+ return nn.InstanceNorm2d(num_features=num_features)
28
+ else:
29
+ raise ValueError(f"Invalid normalization layer type: {norm_type}")
30
+
31
+
32
+ def upsampling_layer(upsampling_mode: UpsamplingMode, scale_factor: int, dim_in: int) -> nn.Module:
33
+ """Create upsampling layer."""
34
+ if upsampling_mode == "transposed_conv":
35
+ return nn.ConvTranspose2d(
36
+ in_channels=dim_in,
37
+ out_channels=dim_in,
38
+ kernel_size=scale_factor,
39
+ stride=scale_factor,
40
+ padding=0,
41
+ bias=False,
42
+ )
43
+ elif upsampling_mode in ("nearest", "bilinear"):
44
+ return nn.Upsample(scale_factor=scale_factor, mode=upsampling_mode)
45
+ else:
46
+ raise ValueError(f"Invalid upsampling mode {upsampling_mode}.")
47
+
48
+
49
+ class ResidualBlock(nn.Module):
50
+ """Generic implementation of residual blocks.
51
+
52
+ This implements a generic residual block from
53
+
54
+ He et al. - Identity Mappings in Deep Residual Networks (2016),
55
+ https://arxiv.org/abs/1603.05027
56
+
57
+ which can be further customized via factory functions.
58
+ """
59
+
60
+ def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
61
+ """Initialize ResidualBlock."""
62
+ super().__init__()
63
+ self.residual = residual
64
+ self.shortcut = shortcut
65
+
66
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
67
+ """Apply residual block."""
68
+ delta_x = self.residual(x)
69
+
70
+ if self.shortcut is not None:
71
+ x = self.shortcut(x)
72
+
73
+ return x + delta_x
74
+
75
+
76
+ def residual_block_2d(
77
+ dim_in: int,
78
+ dim_out: int,
79
+ dim_hidden: int | None = None,
80
+ actvn: nn.Module | None = None,
81
+ norm_type: NormLayerName = "noop",
82
+ norm_num_groups: int = 8,
83
+ dilation: int = 1,
84
+ kernel_size: int = 3,
85
+ ):
86
+ """Create a simple 2D residual block."""
87
+ if actvn is None:
88
+ actvn = nn.ReLU()
89
+
90
+ if dim_hidden is None:
91
+ dim_hidden = dim_out // 2
92
+
93
+ # Padding to maintain output size
94
+ # See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
95
+ padding = (dilation * (kernel_size - 1)) // 2
96
+
97
+ def _create_block(dim_in: int, dim_out: int) -> list[nn.Module]:
98
+ layers = [
99
+ norm_layer_2d(dim_in, norm_type, num_groups=norm_num_groups),
100
+ actvn,
101
+ ]
102
+
103
+ layers.append(
104
+ nn.Conv2d(
105
+ dim_in,
106
+ dim_out,
107
+ kernel_size=kernel_size,
108
+ stride=1,
109
+ dilation=dilation,
110
+ padding=padding,
111
+ )
112
+ )
113
+ return layers
114
+
115
+ residual = nn.Sequential(
116
+ *_create_block(dim_in, dim_hidden),
117
+ *_create_block(dim_hidden, dim_out),
118
+ )
119
+ shortcut = None
120
+
121
+ if dim_in != dim_out:
122
+ shortcut = nn.Conv2d(dim_in, dim_out, 1)
123
+
124
+ return ResidualBlock(residual, shortcut)
125
+
126
+
127
+ class FeatureFusionBlock2d(nn.Module):
128
+ """Feature fusion for DPT."""
129
+
130
+ # We use the name "deconv" for backward compatibility. However, "deconv" can also
131
+ # refer to some other upsampling layer or a no-op.
132
+ deconv: nn.Module
133
+
134
+ def __init__(
135
+ self,
136
+ dim_in: int,
137
+ dim_out: int | None = None,
138
+ upsampling_mode: UpsamplingMode | None = None,
139
+ batch_norm: bool = False,
140
+ ):
141
+ """Initialize feature fusion block.
142
+
143
+ Args:
144
+ dim_in: Dimensions of input.
145
+ dim_out: Dimensions of output.
146
+ batch_norm: Whether to use batch normalization in resnet blocks.
147
+ upsampling_mode: What mode to use for upsampling. None if no upsampling
148
+ is required.
149
+ """
150
+ super().__init__()
151
+ if dim_out is None:
152
+ dim_out = dim_in
153
+ self.resnet1 = self._residual_block(dim_in, batch_norm)
154
+ self.resnet2 = self._residual_block(dim_in, batch_norm)
155
+
156
+ if upsampling_mode is not None:
157
+ self.deconv = upsampling_layer(upsampling_mode, scale_factor=2, dim_in=dim_in)
158
+ else:
159
+ self.deconv = nn.Sequential()
160
+
161
+ self.out_conv = nn.Conv2d(
162
+ dim_in,
163
+ dim_out,
164
+ kernel_size=1,
165
+ stride=1,
166
+ padding=0,
167
+ bias=True,
168
+ )
169
+
170
+ self.skip_add = nn.quantized.FloatFunctional()
171
+
172
+ def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
173
+ """Process and fuse input features."""
174
+ x = x0
175
+
176
+ if x1 is not None:
177
+ res = self.resnet1(x1)
178
+ x = self.skip_add.add(x, res)
179
+
180
+ x = self.resnet2(x)
181
+ x = self.deconv(x)
182
+ x = self.out_conv(x)
183
+
184
+ return x
185
+
186
+ @staticmethod
187
+ def _residual_block(num_features: int, batch_norm: bool):
188
+ """Create a residual block."""
189
+
190
+ def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
191
+ layers = [
192
+ nn.ReLU(False),
193
+ nn.Conv2d(
194
+ dim,
195
+ dim,
196
+ kernel_size=3,
197
+ stride=1,
198
+ padding=1,
199
+ bias=not batch_norm,
200
+ ),
201
+ ]
202
+ if batch_norm:
203
+ layers.append(nn.BatchNorm2d(dim))
204
+ return layers
205
+
206
+ residual = nn.Sequential(
207
+ *_create_block(dim=num_features, batch_norm=batch_norm),
208
+ *_create_block(dim=num_features, batch_norm=batch_norm),
209
+ )
210
+ return ResidualBlock(residual)
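
A quick sketch of the two blocks above (dimensions are illustrative):

```python
# Sketch only: a width-changing residual block feeding a DPT fusion step.
import torch

block = residual_block_2d(dim_in=32, dim_out=64, norm_type="group_norm")
x = block(torch.randn(1, 32, 48, 48))  # -> [1, 64, 48, 48]

fusion = FeatureFusionBlock2d(dim_in=64, upsampling_mode="bilinear")
skip = torch.randn(1, 64, 48, 48)
print(fusion(x, skip).shape)  # torch.Size([1, 64, 96, 96])
```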
src/sharp/models/composer.py ADDED
@@ -0,0 +1,251 @@
1
+ """Defines module to compose final Gaussians from base values and delta values.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+
13
+ from sharp.models.initializer import GaussianBaseValues
14
+ from sharp.utils import math as math_utils
15
+ from sharp.utils.color_space import ColorSpace, sRGB2linearRGB
16
+ from sharp.utils.gaussians import Gaussians3D
17
+
18
+ from .params import DeltaFactor
19
+
20
+
21
+ def _get_scale_activation_constant(max_scale: float, min_scale: float) -> tuple[float, float]:
22
+ """Return constants for scale activation function."""
23
+ # To ensure for delta = 0, the value of scale_factor is 1 and the gradient is 1.
24
+ constant_a = (max_scale - min_scale) / (1 - min_scale) / (max_scale - 1)
25
+ constant_b = math_utils.inverse_sigmoid(
26
+ torch.tensor((1.0 - min_scale) / (max_scale - min_scale))
27
+ ).item()
28
+ return constant_a, constant_b
29
+
30
+
31
+ class GaussianComposer(nn.Module):
32
+ """Converts base values and deltas into Gaussians."""
33
+
34
+ color_activation_type: math_utils.ActivationType
35
+ opacity_activation_type: math_utils.ActivationType
36
+
37
+ def __init__(
38
+ self,
39
+ delta_factor: DeltaFactor,
40
+ min_scale: float,
41
+ max_scale: float,
42
+ color_activation_type: math_utils.ActivationType,
43
+ opacity_activation_type: math_utils.ActivationType,
44
+ color_space: ColorSpace,
45
+ base_scale_on_predicted_mean: bool,
46
+ scale_factor: int = 1,
47
+ ) -> None:
48
+ """Initialize GaussianComposer.
49
+
50
+ Args:
51
+ delta_factor: Multiply delta offsets by this factor.
52
+ min_scale: The minimal scale factor for gaussian scale activation.
53
+ max_scale: The maximal scale factor for gaussian scale activation.
54
+ color_activation_type: Which activation function to use for colors.
55
+ opacity_activation_type: Which activation function to use for opacities.
56
+ color_space: Which color space is used in training.
57
+ scale_factor: The scale factor to upsample the delta_values before composition.
58
+ base_scale_on_predicted_mean: Whether to account z offsets for estimating base scale.
59
+ """
60
+ super().__init__()
61
+ self.delta_factor = delta_factor
62
+ self.max_scale = max_scale
63
+ self.min_scale = min_scale
64
+ self.color_activation_type = color_activation_type
65
+ self.opacity_activation_type = opacity_activation_type
66
+ self.color_space = color_space
67
+ self.scale_factor = scale_factor
68
+ self.base_scale_on_predicted_mean = base_scale_on_predicted_mean
69
+
70
+ def upsample_delta_value(self, delta: torch.Tensor, scale_factor: int = 1):
71
+ """Upsample the delta value.
72
+
73
+ Args:
74
+ delta: The delta values predicted by gaussian predictor.
75
+ scale_factor: The scale factor to upsample the delta_values.
76
+ """
77
+ (
78
+ batch_size,
79
+ num_channels,
80
+ num_layers,
81
+ image_height,
82
+ image_width,
83
+ ) = delta.shape
84
+ new_height = image_height * scale_factor
85
+ new_width = image_width * scale_factor
86
+ upsampled_delta = F.interpolate(
87
+ delta.view(batch_size, num_channels * num_layers, image_height, image_width),
88
+ scale_factor=scale_factor,
89
+ ).view(batch_size, num_channels, num_layers, new_height, new_width)
90
+ return upsampled_delta
91
+
92
+ def forward(
93
+ self,
94
+ delta: torch.Tensor,
95
+ base_values: GaussianBaseValues,
96
+ global_scale: torch.Tensor | None = None,
97
+ flatten_output: bool = True,
98
+ ) -> Gaussians3D:
99
+ """Combine predicted delta values with base gaussian values and apply activation function.
100
+
101
+ Args:
102
+ delta: The delta values predicted by gaussian predictor.
103
+ base_values: The gaussian base values.
104
+ global_scale: Global scale of Gaussians.
105
+ flatten_output: Flatten the gaussian parameters.
106
+
107
+ Returns:
108
+ The computed 3D Gaussians.
109
+ """
110
+ # Upsample the delta if delta and base_values have different strides.
111
+ scale_factor = self.scale_factor
112
+ # For triplane head, the delta has already been upsampled.
113
+ actual_scale_factor = base_values.mean_x_ndc.shape[-1] // delta.shape[-1]
114
+ if scale_factor != 1 and actual_scale_factor != 1:
115
+ delta = self.upsample_delta_value(delta, scale_factor)
116
+
117
+ mean_vectors = self._forward_mean(base_values, delta)
118
+
119
+ # Account for the change in base scale due to z offsets.
120
+ base_scales = (
121
+ (base_values.scales * base_values.mean_inverse_z_ndc * mean_vectors[:, 2:3, ...])
122
+ if self.base_scale_on_predicted_mean
123
+ else base_values.scales
124
+ )
125
+ singular_values = self._scale_activation(
126
+ base_scales,
127
+ delta[:, 3:6],
128
+ self.min_scale,
129
+ self.max_scale,
130
+ )
131
+ quaternions = self._quaternion_activation(base_values.quaternions, delta[:, 6:10])
132
+ colors = self._color_activation(base_values.colors, delta[:, 10:13])
133
+ opacities = self._opacity_activation(base_values.opacities, delta[:, 13])
134
+
135
+ if flatten_output:
136
+ # [B, C, N, H, W] -> [B, N, H, W, C].
137
+ # NOTE: opacities is [B, N, H, W] so it doesn't need to permute.
138
+ mean_vectors = mean_vectors.permute(0, 2, 3, 4, 1).flatten(1, 3)
139
+ singular_values = singular_values.permute(0, 2, 3, 4, 1).flatten(1, 3)
140
+ quaternions = quaternions.permute(0, 2, 3, 4, 1).flatten(1, 3)
141
+ colors = colors.permute(0, 2, 3, 4, 1).flatten(1, 3)
142
+ opacities = opacities.flatten(1, 3)
143
+
144
+ # Apply global scaling to convert Gaussians to metric space.
145
+ if global_scale is not None:
146
+ mean_vectors = global_scale[:, None, None] * mean_vectors
147
+ singular_values = global_scale[:, None, None] * singular_values
148
+
149
+ return Gaussians3D(
150
+ mean_vectors=mean_vectors,
151
+ singular_values=singular_values,
152
+ quaternions=quaternions,
153
+ colors=colors,
154
+ opacities=opacities,
155
+ )
156
+
157
+ def _forward_mean(self, base_values: GaussianBaseValues, delta: torch.Tensor) -> torch.Tensor:
158
+ # Concatenate base vectors and apply mean activation.
159
+ delta_factor = torch.tensor(
160
+ [self.delta_factor.xy, self.delta_factor.xy, self.delta_factor.z],
161
+ device=delta.device,
162
+ )[None, :, None, None, None]
163
+
164
+ dtype = base_values.mean_x_ndc.dtype
165
+ device = base_values.mean_x_ndc.device
166
+ target_shape = (1, 3, 1, 1, 1)
167
+ mean_x_mask = torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device).reshape(
168
+ target_shape
169
+ )
170
+ mean_y_mask = torch.tensor([0.0, 1.0, 0.0], dtype=dtype, device=device).reshape(
171
+ target_shape
172
+ )
173
+ mean_z_mask = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device).reshape(
174
+ target_shape
175
+ )
176
+
177
+ mean_vectors_ndc = (
178
+ base_values.mean_x_ndc.repeat(target_shape) * mean_x_mask
179
+ + base_values.mean_y_ndc.repeat(target_shape) * mean_y_mask
180
+ + base_values.mean_inverse_z_ndc.repeat(target_shape) * mean_z_mask
181
+ )
182
+
183
+ mean_vectors = self._mean_activation(mean_vectors_ndc, delta_factor * delta[:, :3])
184
+ return mean_vectors
185
+
186
+ def _mean_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
187
+ """Mean activation function.
188
+
189
+ Args:
190
+ base: Tensor of shape [B, 3, H, W], where first two feature dimensions
191
+ (x,y) are in normalized device coordinates (NDC) where (-1, -1) is
192
+ the top, while the third dimension is inverse depth.
193
+ learned_delta: Tensor of shape [B, 3, H, W] with predicted delta values.
194
+
195
+ Returns:
196
+ The final mean vector after combining base and delta, with nonlinearities applied.
197
+ """
198
+ xx = base[:, 0:1] + learned_delta[:, 0:1]
199
+ yy = base[:, 1:2] + learned_delta[:, 1:2]
200
+
201
+ a = base[:, 2:3]
202
+ b = learned_delta[:, 2:3]
203
+
204
+ # Combine base and delta in (pre-softplus) inverse-depth space, then map back to depth.
205
+ inverse_zz = F.softplus(math_utils.inverse_softplus(a) + b)
206
+ zz = 1.0 / (inverse_zz + 1e-3)
207
+
208
+ mean_vectors = torch.cat([zz * xx, zz * yy, zz], dim=1)
209
+ return mean_vectors
210
+
211
+ def _scale_activation(
212
+ self,
213
+ base: torch.Tensor,
214
+ learned_delta: torch.Tensor,
215
+ min_scale: float,
216
+ max_scale: float,
217
+ ) -> torch.Tensor:
218
+ constant_a, constant_b = _get_scale_activation_constant(max_scale, min_scale)
219
+ scale_factor = (max_scale - min_scale) * torch.sigmoid(
220
+ constant_a * self.delta_factor.scale * learned_delta + constant_b
221
+ ) + min_scale
222
+ return base * scale_factor
223
+
224
+ def _quaternion_activation(
225
+ self, base: torch.Tensor, learned_delta: torch.Tensor
226
+ ) -> torch.Tensor:
227
+ # No need to normalize the quaternions, since this is also done in rendering.
228
+ return base + self.delta_factor.quaternion * learned_delta
229
+
230
+ def _color_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
231
+ # For certain activation functions we need to clamp the base value to
232
+ # a supported range.
233
+ if self.color_activation_type == "sigmoid":
234
+ base = torch.clamp(base, min=0.01, max=0.99)
235
+ elif self.color_activation_type in ("exp", "softplus"):
236
+ base = torch.clamp(base, min=0.01)
237
+
238
+ activation = math_utils.create_activation_pair(self.color_activation_type)
239
+ colors: torch.Tensor = activation.forward(
240
+ activation.inverse(base) + self.delta_factor.color * learned_delta
241
+ )
242
+ # Convert gaussian color to linear if linearRGB colorspace is specified.
243
+ if self.color_space == "linearRGB":
244
+ colors = sRGB2linearRGB(colors)
245
+ return colors
246
+
247
+ def _opacity_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
248
+ activation = math_utils.create_activation_pair(self.opacity_activation_type)
249
+ return activation.forward(
250
+ activation.inverse(base) + self.delta_factor.opacity * learned_delta
251
+ )
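
The scale activation is constructed so that a zero delta leaves the base scale untouched; a numeric sanity sketch of the constants:

```python
# Sketch only: at delta = 0 the multiplier in _scale_activation is exactly 1,
# because sigmoid(constant_b) == (1 - min_scale) / (max_scale - min_scale).
import torch

from sharp.models.composer import _get_scale_activation_constant

min_scale, max_scale = 0.5, 2.0
constant_a, constant_b = _get_scale_activation_constant(max_scale, min_scale)
factor = (max_scale - min_scale) * torch.sigmoid(torch.tensor(constant_b)) + min_scale
print(factor.item())  # 1.0
```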
src/sharp/models/decoders/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ """Contains different decoders for Gaussian predictor.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from .base_decoder import BaseDecoder
10
+ from .monodepth_decoder import (
11
+ create_monodepth_decoder,
12
+ )
13
+ from .multires_conv_decoder import MultiresConvDecoder, UpsamplingMode
14
+ from .unet_decoder import UNetDecoder
15
+
16
+ __all__ = [
17
+ "BaseDecoder",
18
+ "UNetDecoder",
19
+ "MultiresConvDecoder",
20
+ "UpsamplingMode",
21
+ "create_monodepth_decoder",
22
+ ]
src/sharp/models/decoders/base_decoder.py ADDED
@@ -0,0 +1,21 @@
1
+ """Contains the base class for decoders.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ import abc
8
+ from typing import List
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+
14
+ class BaseDecoder(nn.Module, abc.ABC):
15
+ """Base decoder class."""
16
+
17
+ dim_out: int
18
+
19
+ @abc.abstractmethod
20
+ def forward(self, encodings: List[torch.Tensor]) -> torch.Tensor:
21
+ """Decode (multi-resolution) encodings."""
src/sharp/models/decoders/monodepth_decoder.py ADDED
@@ -0,0 +1,37 @@
1
+ """Contains factory function for loading/creating monodepth decoder.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+
8
+ from __future__ import annotations
9
+
10
+ from sharp.models.presets import (
11
+ MONODEPTH_ENCODER_DIMS_MAP,
12
+ ViTPreset,
13
+ )
14
+
15
+ from .multires_conv_decoder import MultiresConvDecoder
16
+
17
+
18
+ def create_monodepth_decoder(
19
+ patch_encoder_preset: ViTPreset,
20
+ dims_decoder: list[int] | int | None = None,
21
+ ) -> MultiresConvDecoder:
22
+ """Create DepthDensePredictionTransformer model.
23
+
24
+ Args:
25
+ patch_encoder_preset: The preset patch encoder architecture in SPN.
26
+ dims_decoder: The decoder feature dimensions; defaults to the first encoder dimension.
27
+ """
28
+ dims_encoder = MONODEPTH_ENCODER_DIMS_MAP[patch_encoder_preset]
29
+ if dims_decoder is None:
30
+ dims_decoder = dims_encoder[0]
31
+ if isinstance(dims_decoder, int):
32
+ dims_decoder = [dims_decoder]
33
+ decoder = MultiresConvDecoder(
34
+ dims_encoder=[dims_decoder[0]] + list(dims_encoder), dims_decoder=dims_decoder
35
+ )
36
+
37
+ return decoder
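
Note that the decoder fuses one extra highest-resolution level, so `dims_decoder`, when given as a list, needs one entry per encoder level plus one. A sketch (assuming `MONODEPTH_ENCODER_DIMS_MAP` is a dict keyed by preset):

```python
# Sketch only: size dims_decoder to cover the extra top level plus each encoder level.
from sharp.models.presets import MONODEPTH_ENCODER_DIMS_MAP

preset = next(iter(MONODEPTH_ENCODER_DIMS_MAP))  # any available preset
dims_encoder = MONODEPTH_ENCODER_DIMS_MAP[preset]
decoder = create_monodepth_decoder(preset, dims_decoder=[256] * (len(dims_encoder) + 1))
print(decoder.dim_out)  # 256
```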
src/sharp/models/decoders/multires_conv_decoder.py ADDED
@@ -0,0 +1,116 @@
1
+ """Contains multi-res convolutional decoder.
2
+
3
+ Implements the decoder for Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
4
+
5
+ For licensing see accompanying LICENSE file.
6
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Iterable
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+ from sharp.models.blocks import FeatureFusionBlock2d, UpsamplingMode
17
+ from sharp.utils.training import checkpoint_wrapper
18
+
19
+ from .base_decoder import BaseDecoder
20
+
21
+
22
+ class MultiresConvDecoder(BaseDecoder):
23
+ """Decoder for multi-resolution encodings."""
24
+
25
+ def __init__(
26
+ self,
27
+ dims_encoder: Iterable[int],
28
+ dims_decoder: Iterable[int] | int,
29
+ grad_checkpointing: bool = False,
30
+ upsampling_mode: UpsamplingMode = "transposed_conv",
31
+ ):
32
+ """Initialize multiresolution convolutional decoder.
33
+
34
+ Args:
35
+ dims_encoder: Expected dims at each level from the encoder.
36
+ dims_decoder: Dim of decoder features.
37
+ grad_checkpointing: Whether to checkpoint gradient during training.
38
+ upsampling_mode: What method to use for upsampling.
39
+ """
40
+ super().__init__()
41
+ self.dims_encoder = list(dims_encoder)
42
+
43
+ if isinstance(dims_decoder, int):
44
+ self.dims_decoder = [dims_decoder] * len(self.dims_encoder)
45
+ else:
46
+ self.dims_decoder = list(dims_decoder)
47
+
48
+ if len(self.dims_decoder) != len(self.dims_encoder):
49
+ raise ValueError("Received dims_encoder and dims_decoder of different sizes.")
50
+
51
+ self.dim_out = self.dims_decoder[0]
52
+
53
+ num_encoders = len(self.dims_encoder)
54
+
55
+ # At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
56
+ # when the dimensions mismatch. Otherwise we do not do anything, which is
57
+ # the default behavior of monodepth.
58
+ conv0 = (
59
+ nn.Conv2d(self.dims_encoder[0], self.dims_decoder[0], kernel_size=1, bias=False)
60
+ if self.dims_encoder[0] != self.dims_decoder[0]
61
+ else nn.Identity()
62
+ )
63
+
64
+ convs = [conv0]
65
+ for i in range(1, num_encoders):
66
+ convs.append(
67
+ nn.Conv2d(
68
+ self.dims_encoder[i],
69
+ self.dims_decoder[i],
70
+ kernel_size=3,
71
+ stride=1,
72
+ padding=1,
73
+ bias=False,
74
+ )
75
+ )
76
+ self.convs = nn.ModuleList(convs)
77
+
78
+ fusions = []
79
+ for i in range(num_encoders):
80
+ fusions.append(
81
+ FeatureFusionBlock2d(
82
+ dim_in=self.dims_decoder[i],
83
+ dim_out=self.dims_decoder[i - 1] if i != 0 else self.dim_out,
84
+ upsampling_mode=upsampling_mode if i != 0 else None,
85
+ batch_norm=False,
86
+ )
87
+ )
88
+ self.fusions = nn.ModuleList(fusions)
89
+
90
+ self.grad_checkpointing = grad_checkpointing
91
+
92
+ @torch.jit.ignore
93
+ def set_grad_checkpointing(self, is_enabled=True):
94
+ """Enable grad checkpointing."""
95
+ self.grad_checkpointing = is_enabled
96
+
97
+ def forward(self, encodings: list[torch.Tensor]) -> torch.Tensor:
98
+ """Decode the multi-resolution encodings."""
99
+ num_levels = len(encodings)
100
+ num_encoders = len(self.dims_encoder)
101
+
102
+ if num_levels != num_encoders:
103
+ raise ValueError(
104
+ f"Encoder output levels={num_levels} at runtime "
105
+ f"mismatch with expected levels={num_encoders}."
106
+ )
107
+
108
+ # Project features of different encoder dims to the same decoder dim.
109
+ # Fuse features from the lowest resolution (num_levels-1)
110
+ # to the highest (0).
111
+ features = self.convs[-1](encodings[-1])
112
+ features = checkpoint_wrapper(self, self.fusions[-1], features)
113
+ for i in range(num_levels - 2, -1, -1):
114
+ features_i = self.convs[i](encodings[i])
115
+ features = checkpoint_wrapper(self, self.fusions[i], features, features_i)
116
+ return features
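
A minimal shape check (a sketch; dimensions chosen arbitrarily):

```python
# Sketch only: three encoder levels fused from coarse to fine.
import torch

decoder = MultiresConvDecoder(dims_encoder=[32, 64, 128], dims_decoder=32)
encodings = [
    torch.randn(1, 32, 64, 64),   # level 0, highest resolution
    torch.randn(1, 64, 32, 32),   # level 1
    torch.randn(1, 128, 16, 16),  # level 2, lowest resolution
]
print(decoder(encodings).shape)  # torch.Size([1, 32, 64, 64])
```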
src/sharp/models/decoders/unet_decoder.py ADDED
@@ -0,0 +1,113 @@
1
+ """Contains the UNet decoder.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import List
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ from sharp.models.blocks import (
15
+ NormLayerName,
16
+ norm_layer_2d,
17
+ residual_block_2d,
18
+ )
19
+
20
+ from .base_decoder import BaseDecoder
21
+
22
+
23
+ class UNetDecoder(BaseDecoder):
24
+ """Decoder of UNet model."""
25
+
26
+ def __init__(
27
+ self,
28
+ dim_out: int,
29
+ width: List[int] | int,
30
+ steps: int = 5,
31
+ norm_type: NormLayerName = "group_norm",
32
+ norm_num_groups=8,
33
+ blocks_per_layer=2,
34
+ ) -> None:
35
+ """Initialize UNet Decoder.
36
+
37
+ Args:
38
+ dim_out: The number of output channels.
39
+ width: Width of last input feature map from encoder
40
+ or the width list of all input feature maps from encoder.
41
+ steps: The number of upsampling steps.
42
+ norm_type: Which kind of normalization layer to use.
43
+ norm_num_groups: How many groups to use for group norm (if relevant).
44
+ blocks_per_layer: How many blocks per layer to use.
45
+ """
46
+ super().__init__()
47
+
48
+ if blocks_per_layer < 1:
49
+ raise ValueError("blocks_per_layer must be greater or equal to one.")
50
+
51
+ self.dim_out = dim_out
52
+
53
+ self.convs_up = nn.ModuleList()
54
+
55
+ self.output_dims: list[int]
56
+ # If only one number is specified, we assume each decoder layer halves the channel dimension.
57
+ if isinstance(width, int):
58
+ self.input_dims = [width >> i for i in range(0, steps + 1)]
59
+ else:
60
+ self.input_dims = width[::-1][: steps + 1]
61
+
62
+ for i_step in range(steps):
63
+ input_width = self.input_dims[i_step]
64
+ current_width = self.input_dims[i_step + 1]
65
+ convs_up_i = nn.Sequential(
66
+ nn.Upsample(scale_factor=2),
67
+ residual_block_2d(
68
+ input_width * (1 if i_step == 0 else 2),
69
+ current_width,
70
+ norm_type=norm_type,
71
+ norm_num_groups=norm_num_groups,
72
+ ),
73
+ *[
74
+ residual_block_2d(
75
+ current_width,
76
+ current_width,
77
+ norm_type=norm_type,
78
+ norm_num_groups=norm_num_groups,
79
+ )
80
+ for _ in range(blocks_per_layer - 1)
81
+ ],
82
+ )
83
+ self.convs_up.append(convs_up_i)
86
+
87
+ last_width = self.input_dims[-1]
88
+ self.conv_out = nn.Sequential(
89
+ norm_layer_2d(last_width * 2, norm_type, num_groups=norm_num_groups),
90
+ nn.ReLU(),
91
+ nn.Conv2d(last_width * 2, dim_out, 1),
92
+ norm_layer_2d(dim_out, norm_type, num_groups=norm_num_groups),
93
+ nn.ReLU(),
94
+ )
95
+
96
+ def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
97
+ """Apply UNet to image.
98
+
99
+ Args:
100
+ features: The input multi-level feature map from encoder.
101
+
102
+ Returns:
103
+ The output feature map.
104
+ """
105
+ i_feature_layer = len(features) - 1
106
+ out = self.convs_up[0](features[i_feature_layer])
107
+ i_feature_layer -= 1
108
+ for conv_up in self.convs_up[1:]: # type: ignore
109
+ out = conv_up(torch.cat([out, features[i_feature_layer]], dim=1))
110
+ i_feature_layer -= 1
111
+ out = self.conv_out(torch.cat([out, features[i_feature_layer]], dim=1))
112
+
113
+ return out
src/sharp/models/encoders/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ """Contains different encoders for Gaussian predictor.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from sharp.models.encoders.base_encoder import BaseEncoder
8
+
9
+ from .monodepth_encoder import (
10
+ MonodepthFeatureEncoder,
11
+ create_monodepth_encoder,
12
+ )
13
+ from .spn_encoder import SlidingPyramidNetwork
14
+ from .unet_encoder import UNetEncoder
15
+ from .vit_encoder import create_vit
16
+
17
+ __all__ = [
18
+ "create_vit",
19
+ "BaseEncoder",
20
+ "UNetEncoder",
21
+ "SlidingPyramidNetwork",
22
+ "MonodepthFeatureEncoder",
23
+ "create_monodepth_encoder",
24
+ ]
src/sharp/models/encoders/base_encoder.py ADDED
@@ -0,0 +1,25 @@
1
+ """Contains the base class for encoders.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ import abc
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+
13
+ class BaseEncoder(nn.Module, abc.ABC):
14
+ """Base encoder class."""
15
+
16
+ dim_in: int
17
+ output_dims: list[int]
18
+
19
+ @abc.abstractmethod
20
+ def forward(self, image: torch.Tensor) -> list[torch.Tensor]:
21
+ """Encode input image into multi-resolution encodings."""
22
+
23
+ def internal_resolution(self) -> int:
24
+ """Internal resolution of the encoder."""
25
+ return 1536
src/sharp/models/encoders/monodepth_encoder.py ADDED
@@ -0,0 +1,123 @@
1
+ """Contains Dense Transformer Prediction architecture.
2
+
3
+ Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
4
+
5
+ For licensing see accompanying LICENSE file.
6
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from sharp.models.presets import (
15
+ MONODEPTH_ENCODER_DIMS_MAP,
16
+ MONODEPTH_HOOK_IDS_MAP,
17
+ ViTPreset,
18
+ )
19
+
20
+ from .base_encoder import BaseEncoder
21
+ from .spn_encoder import SlidingPyramidNetwork
22
+ from .vit_encoder import create_vit
23
+
24
+
25
+ def create_monodepth_encoder(
26
+ patch_encoder_preset: ViTPreset,
27
+ image_encoder_preset: ViTPreset,
28
+ use_patch_overlap: bool = True,
29
+ last_encoder: int = 256,
30
+ ) -> SlidingPyramidNetwork:
31
+ """Creates DepthDensePredictionTransformer model.
32
+
33
+ Args:
34
+ patch_encoder_preset: The preset patch encoder architecture in SPN.
35
+ image_encoder_preset: The preset image encoder architecture in SPN.
36
+ use_patch_overlap: Whether to use overlap between patches in SPN.
37
+ last_encoder: Dimension of the final (highest-resolution) encoder features.
38
+ """
39
+ dims_encoder = [last_encoder] + MONODEPTH_ENCODER_DIMS_MAP[patch_encoder_preset]
40
+ patch_encoder_block_ids = MONODEPTH_HOOK_IDS_MAP[patch_encoder_preset]
41
+
42
+ patch_encoder = create_vit(
43
+ preset=patch_encoder_preset,
44
+ intermediate_features_ids=patch_encoder_block_ids,
45
+ # We always need to output intermediate features for assembly.
46
+ )
47
+ image_encoder = create_vit(
48
+ preset=image_encoder_preset,
49
+ intermediate_features_ids=None,
50
+ )
51
+
52
+ encoder = SlidingPyramidNetwork(
53
+ dims_encoder=dims_encoder,
54
+ patch_encoder=patch_encoder,
55
+ image_encoder=image_encoder,
56
+ use_patch_overlap=use_patch_overlap,
57
+ )
58
+
59
+ return encoder
60
+
61
+
62
+ class ProjectionModule(nn.Module):
63
+ """Apply projection of features."""
64
+
65
+ def __init__(self, dims_in: list[int], dims_out: list[int]) -> None:
66
+ """Initialize projection module."""
67
+ super().__init__()
68
+ if len(dims_in) != len(dims_out):
69
+ raise ValueError("Length of dims_in must be same as length of dims_out.")
70
+ self.convs = nn.ModuleList(
71
+ [nn.Conv2d(dim_in, dim_out, 1) for dim_in, dim_out in zip(dims_in, dims_out)]
72
+ )
73
+
74
+ def forward(self, encodings: list[torch.Tensor]) -> list[torch.Tensor]:
75
+ """Apply projection module."""
76
+ if len(encodings) != len(self.convs):
77
+ raise ValueError("Number of encodings must be equal to number of projections.")
78
+ return [conv(encoding) for conv, encoding in zip(self.convs, encodings)]
79
+
80
+
81
+ class MonodepthFeatureEncoder(BaseEncoder):
82
+ """A wrapper around monodepth network to extract features."""
83
+
84
+ def __init__(
85
+ self,
86
+ monodepth_encoder: SlidingPyramidNetwork,
87
+ output_dims: list[int] | None = None,
88
+ freeze_projection: bool = False,
89
+ ) -> None:
90
+ """Initialize MonodepthFeatureExtractor."""
91
+ super().__init__()
92
+
93
+ self.encoder = monodepth_encoder
94
+
95
+ # The monodepth network returns two feature maps for the first entry in
96
+ # backbone.encoder.dims_encoder.
97
+ monodepth_dims = self.encoder.dims_encoder
99
+
100
+ if output_dims is not None:
101
+ if len(output_dims) != len(monodepth_dims):
102
+ raise ValueError(
103
+ "When set, number of output dimensions must be equal to output "
104
+ f"dimensions of monodepth model {len(monodepth_dims)}."
105
+ )
106
+
107
+ self.projection = ProjectionModule(monodepth_dims, output_dims)
108
+ self.output_dims = output_dims
109
+ else:
110
+ self.projection = nn.Identity()
111
+ self.output_dims = monodepth_dims
112
+
113
+ if freeze_projection:
114
+ self.projection.requires_grad_(False)
115
+
116
+ def forward(self, input_features: torch.Tensor) -> list[torch.Tensor]:
117
+ """Extract multi-resolution features."""
118
+ encodings = self.encoder(input_features[:, :3].contiguous())
119
+ return self.projection(encodings)
120
+
121
+ def internal_resolution(self) -> int:
122
+ """Internal resolution of the encoder."""
123
+ return self.encoder.internal_resolution()
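
A quick sketch of the projection step in isolation (dimensions are illustrative):

```python
# Sketch only: project two feature levels to a common channel width.
import torch

projection = ProjectionModule(dims_in=[64, 128], dims_out=[32, 32])
encodings = [torch.randn(1, 64, 48, 48), torch.randn(1, 128, 24, 24)]
print([tuple(e.shape) for e in projection(encodings)])
# [(1, 32, 48, 48), (1, 32, 24, 24)]
```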
src/sharp/models/encoders/spn_encoder.py ADDED
@@ -0,0 +1,369 @@
1
+ """Contains Sliding Pyramid Network architecture.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import math
10
+ from typing import Iterable
11
+
12
+ import torch
13
+ import torch.fx
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from sharp.utils.training import checkpoint_wrapper
18
+
19
+ from .base_encoder import BaseEncoder
20
+ from .vit_encoder import TimmViT
21
+
22
+ # torch.fx.wrap is used here to mark functions as leaf nodes during symbolic tracing
23
+ # ensuring they are not traced but seen as atomic operation. In short, symbolic tracing
24
+ # struggles with native python functions and conditional flows.
25
+ non_traceable_ops = ("len", "int")
26
+ for op in non_traceable_ops:
27
+ torch.fx.wrap(op)
28
+
29
+
30
+ class SlidingPyramidNetwork(BaseEncoder):
31
+ """Sliding Pyramid Network.
32
+
33
+ An encoder aimed at creating multi-resolution encodings from Vision Transformers.
34
+
35
+ Reference: Bochkovskii et al. - "Depth Pro: Sharp monocular metric depth in less
36
+ than a second." (ICLR 2025)
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ dims_encoder: Iterable[int],
42
+ patch_encoder: TimmViT,
43
+ image_encoder: TimmViT,
44
+ use_patch_overlap: bool = True,
45
+ ):
46
+ """Initialize Sliding Pyramid Network.
47
+
48
+ The framework
49
+ 1. creates an image pyramid,
50
+ 2. generates overlapping patches with a sliding window at each pyramid level,
51
+ 3. creates batched encodings via vision transformer backbones,
52
+ 4. produces multi-resolution encodings.
53
+
54
+ Args:
55
+ dims_encoder: Dimensions of the encoder at different layers.
56
+ patch_encoder: Backbone used for highres part of the pyramid.
57
+ image_encoder: Backbone used for lowres part of the pyramid.
58
+ use_patch_overlap: Whether to use overlap between patches in SPN.
59
+ """
60
+ super().__init__()
61
+
62
+ self.dim_in = patch_encoder.dim_in
63
+
64
+ self.dims_encoder = list(dims_encoder)
65
+ self.patch_encoder = patch_encoder
66
+ self.image_encoder = image_encoder
67
+
68
+ base_embed_dim = patch_encoder.embed_dim
69
+ lowres_embed_dim = image_encoder.embed_dim
70
+ self.patch_size = patch_encoder.internal_resolution()
71
+
72
+ self.grad_checkpointing = False
73
+ self.use_patch_overlap = use_patch_overlap
74
+
75
+ # Retrieve intermediate feature ids registered in create_monodepth_encoder.
76
+ self.patch_intermediate_features_ids = patch_encoder.intermediate_features_ids
77
+ if (
78
+ not isinstance(self.patch_intermediate_features_ids, list)
79
+ or not len(self.patch_intermediate_features_ids) == 4
80
+ ):
81
+ raise ValueError("Patch intermediate feature ids must be a 4-item list.")
82
+
83
+ self.image_intermediate_features_ids = image_encoder.intermediate_features_ids
84
+
85
+ def _create_project_upsample_block(
86
+ dim_in: int,
87
+ dim_out: int,
88
+ upsample_layers: int,
89
+ dim_intermediate=None,
90
+ ) -> nn.Module:
91
+ if dim_intermediate is None:
92
+ dim_intermediate = dim_out
93
+ # Projection.
94
+ blocks = [
95
+ nn.Conv2d(
96
+ in_channels=dim_in,
97
+ out_channels=dim_intermediate,
98
+ kernel_size=1,
99
+ stride=1,
100
+ padding=0,
101
+ bias=False,
102
+ )
103
+ ]
104
+
105
+ # Upsampling.
106
+ blocks += [
107
+ nn.ConvTranspose2d(
108
+ in_channels=dim_intermediate if i == 0 else dim_out,
109
+ out_channels=dim_out,
110
+ kernel_size=2,
111
+ stride=2,
112
+ padding=0,
113
+ bias=False,
114
+ )
115
+ for i in range(upsample_layers)
116
+ ]
117
+
118
+ return nn.Sequential(*blocks)
119
+
120
+ self.upsample_latent0 = _create_project_upsample_block(
121
+ dim_in=base_embed_dim,
122
+ dim_out=self.dims_encoder[0],
123
+ upsample_layers=3,
124
+ dim_intermediate=self.dims_encoder[1],
125
+ )
126
+ self.upsample_latent1 = _create_project_upsample_block(
127
+ dim_in=base_embed_dim, dim_out=self.dims_encoder[1], upsample_layers=2
128
+ )
129
+
130
+ self.upsample0 = _create_project_upsample_block(
131
+ dim_in=base_embed_dim, dim_out=self.dims_encoder[2], upsample_layers=1
132
+ )
133
+ self.upsample1 = _create_project_upsample_block(
134
+ dim_in=base_embed_dim, dim_out=self.dims_encoder[3], upsample_layers=1
135
+ )
136
+ self.upsample2 = _create_project_upsample_block(
137
+ dim_in=base_embed_dim, dim_out=self.dims_encoder[4], upsample_layers=1
138
+ )
139
+
140
+ self.upsample_lowres = nn.ConvTranspose2d(
141
+ in_channels=lowres_embed_dim,
142
+ out_channels=self.dims_encoder[4],
143
+ kernel_size=2,
144
+ stride=2,
145
+ padding=0,
146
+ bias=True,
147
+ )
148
+ self.fuse_lowres = nn.Conv2d(
149
+ in_channels=(self.dims_encoder[4] + self.dims_encoder[4]),
150
+ out_channels=self.dims_encoder[4],
151
+ kernel_size=1,
152
+ stride=1,
153
+ padding=0,
154
+ bias=True,
155
+ )
156
+
157
+ def internal_resolution(self) -> int:
158
+ """Return the full image size of the SPN network."""
159
+ return self.patch_size * 4
160
+
161
+ @torch.jit.ignore
162
+ def set_grad_checkpointing(self, is_enabled=True):
163
+ """Enable grad checkpointing."""
164
+ self.grad_checkpointing = is_enabled
165
+ self.patch_encoder.set_grad_checkpointing(is_enabled)
166
+ self.image_encoder.set_grad_checkpointing(is_enabled)
167
+
168
+ @torch.jit.ignore
169
+ def set_requires_grad_(self, patch_encoder: bool, image_encoder: bool):
170
+ """Set requires grad for separate components."""
171
+ self.patch_encoder.requires_grad_(patch_encoder)
172
+ self.image_encoder.requires_grad_(image_encoder)
173
+
174
+ # Always freeze the unused TimmViT head to exclude it from the calculation of
175
+ # trainable parameters.
176
+ self.patch_encoder.head.requires_grad_(False)
177
+ self.image_encoder.head.requires_grad_(False)
178
+
179
+ # These upsamplers only affect patch encoder's feature maps.
180
+ self.upsample_latent0.requires_grad_(patch_encoder)
181
+ self.upsample_latent1.requires_grad_(patch_encoder)
182
+ self.upsample0.requires_grad_(patch_encoder)
183
+ self.upsample1.requires_grad_(patch_encoder)
184
+ self.upsample2.requires_grad_(patch_encoder)
185
+
186
+ # This upsampler affects only image encoder's feature map.
187
+ self.upsample_lowres.requires_grad_(image_encoder)
188
+
189
+ # This fuser affects both image and patch encoders.
190
+ self.fuse_lowres.requires_grad_(image_encoder or patch_encoder)
191
+
192
+ def _create_pyramid(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
193
+ """Creates a 3-level image pyramid."""
194
+ # Original resolution: 1536 by default.
195
+ x0 = x
196
+
197
+ # Middle resolution: 768 by default.
198
+ x1 = F.interpolate(x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False)
199
+
200
+ # Low resolution: 384 by default, corresponding to the backbone resolution.
201
+ x2 = F.interpolate(x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False)
202
+
203
+ return x0, x1, x2
204
+
205
+ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
206
+ """Encode input at multiple resolutions."""
207
+ batch_size = x.shape[0]
208
+
209
+ # Step 0: create a 3-level image pyramid.
210
+ x0, x1, x2 = self._create_pyramid(x)
211
+
212
+ if self.use_patch_overlap:
213
+ # Step 1: split to create batched overlapped mini-images at the ViT
214
+ # resolution.
215
+ # 5x5 @ 384x384 at the highest resolution (1536x1536).
216
+ x0_patches = split(x0, overlap_ratio=0.25, patch_size=self.patch_size)
217
+ # 3x3 @ 384x384 at the middle resolution (768x768).
218
+ x1_patches = split(x1, overlap_ratio=0.5, patch_size=self.patch_size)
219
+ # 1x1 @ 384x384 at the lowest resolution (384x384).
220
+ x2_patches = x2
221
+ padding = 3
222
+ else:
223
+ # Step 1: split to create batched overlapped mini-images at the ViT
224
+ # resolution.
225
+ # 4x4 @ 384x384 at the highest resolution (1536x1536).
226
+ x0_patches = split(x0, overlap_ratio=0.0, patch_size=self.patch_size)
227
+ # 2x2 @ 384x384 at the middle resolution (768x768).
228
+ x1_patches = split(x1, overlap_ratio=0.0, patch_size=self.patch_size)
229
+ # 1x1 @ 384x384 at the lowest resolution (384x384).
230
+ x2_patches = x2
231
+ padding = 0
232
+ x0_tile_size = x0_patches.shape[0] // batch_size  # patches per image at level 0
233
+
234
+ # Concatenate all the sliding window patches and form a batch of size
235
+ # (35=5x5+3x3+1x1) or (21=4x4+2x2+1x1).
236
+ x_pyramid_patches = torch.cat(
237
+ (x0_patches, x1_patches, x2_patches),
238
+ dim=0,
239
+ )
240
+
241
+ # Run the ViT model and get the result of large batch size.
242
+ #
243
+ # For the retrieval of intermediate features forward hooks are more concise,
244
+ # but they are not well compatible with symbolic tracing because attributes
245
+ # of submodules can be lost during tracing. Therefore, forward hooks may not
246
+ # be preserved during graph transformation, leading to unexpected behavior.
247
+ # To avoid such issues it is safer not to use them because they are not
248
+ # essential here.
249
+ x_pyramid_encodings, patch_intermediate_features = self.patch_encoder(x_pyramid_patches)
250
+
251
+ # Step 3: merging.
252
+ # Merge highres latent encoding.
253
+ # NOTE: list type check has completed in init.
254
+ x_latent0_encodings = self.patch_encoder.reshape_feature(
255
+ patch_intermediate_features[self.patch_intermediate_features_ids[0]] # type:ignore[index]
256
+ )
257
+ x_latent0_features = merge(
258
+ x_latent0_encodings[: batch_size * x0_tile_size],
259
+ batch_size=batch_size,
260
+ padding=padding,
261
+ )
262
+
263
+ x_latent1_encodings = self.patch_encoder.reshape_feature(
264
+ patch_intermediate_features[self.patch_intermediate_features_ids[1]] # type:ignore[index]
265
+ )
266
+ x_latent1_features = merge(
267
+ x_latent1_encodings[: batch_size * x0_tile_size],
268
+ batch_size=batch_size,
269
+ padding=padding,
270
+ )
271
+
272
+ # Split the 35 batch size from pyramid encoding back into 5x5+3x3+1x1.
273
+ x0_encodings, x1_encodings, x2_encodings = torch.split(
274
+ x_pyramid_encodings,
275
+ [len(x0_patches), len(x1_patches), len(x2_patches)],
276
+ dim=0,
277
+ )
278
+
279
+ # 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps.
280
+ x0_features = merge(x0_encodings, batch_size=batch_size, padding=padding)
281
+
282
+ # 48x48 feature maps by merging 3x3 @ 24x24 patches with overlaps.
283
+ x1_features = merge(x1_encodings, batch_size=batch_size, padding=2 * padding)
284
+
285
+ # 24x24 feature maps.
286
+ x2_features = x2_encodings
287
+
288
+ # Apply the image encoder.
289
+ x_lowres_features, image_intermediate_features = self.image_encoder(x2_patches)
290
+
291
+ # Upsample feature maps.
292
+ x_latent0_features = checkpoint_wrapper(self, self.upsample_latent0, x_latent0_features)
293
+ x_latent1_features = checkpoint_wrapper(self, self.upsample_latent1, x_latent1_features)
294
+
295
+ x0_features = checkpoint_wrapper(self, self.upsample0, x0_features)
296
+ x1_features = checkpoint_wrapper(self, self.upsample1, x1_features)
297
+ x2_features = checkpoint_wrapper(self, self.upsample2, x2_features)
298
+
299
+ x_lowres_features = checkpoint_wrapper(self, self.upsample_lowres, x_lowres_features)
300
+ x_lowres_features = checkpoint_wrapper(
301
+ self, self.fuse_lowres, torch.cat((x2_features, x_lowres_features), dim=1)
302
+ )
303
+
304
+ output = [
305
+ x_latent0_features,
306
+ x_latent1_features,
307
+ x0_features,
308
+ x1_features,
309
+ x_lowres_features,
310
+ ]
311
+
312
+ return output
313
+
314
+
315
+ # It seems that torch.fx.wrap can only be applied to functions, not methods.
316
+ # Hence, split and merge were converted into functions to be marked as atomic
317
+ # operations for symbolic tracing.
318
+ @torch.fx.wrap
319
+ def split(image: torch.Tensor, overlap_ratio: float = 0.25, patch_size: int = 384) -> torch.Tensor:
320
+ """Split the input into small patches with sliding window."""
321
+ patch_stride = int(patch_size * (1 - overlap_ratio))
322
+
323
+ image_size = image.shape[-1]
324
+ steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1
325
+
326
+ x_patch_list = []
327
+ for j in range(steps):
328
+ j0 = j * patch_stride
329
+ j1 = j0 + patch_size
330
+
331
+ for i in range(steps):
332
+ i0 = i * patch_stride
333
+ i1 = i0 + patch_size
334
+ x_patch_list.append(image[..., j0:j1, i0:i1])
335
+
336
+ return torch.cat(x_patch_list, dim=0)
337
+
338
+
339
+ # Decorator marking function as an atomic operator for symbolic tracing.
340
+ @torch.fx.wrap
341
+ def merge(image_patches: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
342
+ """Merge the patched input into a image with sliding window."""
343
+ steps = int(math.sqrt(image_patches.shape[0] // batch_size))
344
+
345
+ idx = 0
346
+
347
+ output_list = []
348
+ for j in range(steps):
349
+ output_row_list = []
350
+ for i in range(steps):
351
+ output = image_patches[batch_size * idx : batch_size * (idx + 1)]
352
+
353
+ if padding != 0:
354
+ if j != 0:
355
+ output = output[..., padding:, :]
356
+ if i != 0:
357
+ output = output[..., :, padding:]
358
+ if j != steps - 1:
359
+ output = output[..., :-padding, :]
360
+ if i != steps - 1:
361
+ output = output[..., :, :-padding]
362
+
363
+ output_row_list.append(output)
364
+ idx += 1
365
+
366
+ output_row = torch.cat(output_row_list, dim=-1)
367
+ output_list.append(output_row)
368
+ output = torch.cat(output_list, dim=-2)
369
+ return output
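
A round-trip sketch of the two helpers: with zero overlap and zero padding, `merge` exactly inverts `split`:

```python
# Sketch only: split a 768x768 image into four 384x384 tiles and reassemble it.
import torch

image = torch.randn(1, 3, 768, 768)
patches = split(image, overlap_ratio=0.0, patch_size=384)  # [4, 3, 384, 384]
restored = merge(patches, batch_size=1, padding=0)
assert torch.equal(restored, image)
```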
src/sharp/models/encoders/unet_encoder.py ADDED
@@ -0,0 +1,117 @@
"""Contains backbone models for feature extraction from RGBD input.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

from typing import List

import torch
from torch import nn

from sharp.models.blocks import (
    NormLayerName,
    norm_layer_2d,
    residual_block_2d,
)

from .base_encoder import BaseEncoder


class UNetEncoder(BaseEncoder):
    """Encoder of UNet model."""

    def __init__(
        self,
        dim_in: int,
        width: List[int] | int,
        steps: int = 6,
        norm_type: NormLayerName = "group_norm",
        norm_num_groups: int = 8,
        blocks_per_layer: int = 2,
    ) -> None:
        """Initialize UNet Encoder.

        Args:
            dim_in: The number of input channels.
            width: Width of the first layer (doubled at each downsampling step),
                or an explicit list of widths for all layers.
            steps: The number of downsampling steps.
            norm_type: Which kind of normalization layer to use.
            norm_num_groups: How many groups to use for group norm (if relevant).
            blocks_per_layer: How many residual blocks per layer to use.
        """
        super().__init__()

        if blocks_per_layer < 1:
            raise ValueError("blocks_per_layer must be greater or equal to one.")

        self.dim_in = dim_in
        self.width = width
        self.num_steps = steps

        self.convs_down = nn.ModuleList()

        self.output_dims: list[int]
        # If only one number is specified, we assume each layer will double the channel dimension.
        if isinstance(width, int):
            self.output_dims = [width << i for i in range(0, steps + 1)]
        else:
            if len(width) != (steps + 1):
                raise ValueError("Length of width must equal steps + 1 for UNetEncoder.")
            self.output_dims = width

        self.conv_in = nn.Sequential(
            nn.Conv2d(self.dim_in, self.output_dims[0], 3, stride=1, padding=1),
            norm_layer_2d(self.output_dims[0], norm_type, num_groups=norm_num_groups),
            nn.ReLU(),
        )

        for i_step in range(steps):
            input_width = self.output_dims[i_step]
            current_width = self.output_dims[i_step + 1]
            convs_down_i = nn.Sequential(
                nn.AvgPool2d(2, stride=2),
                residual_block_2d(
                    input_width,
                    current_width,
                    norm_type=norm_type,
                    norm_num_groups=norm_num_groups,
                ),
                *[
                    residual_block_2d(
                        current_width,
                        current_width,
                        norm_type=norm_type,
                        norm_num_groups=norm_num_groups,
                    )
                    for _ in range(blocks_per_layer - 1)
                ],
            )
            self.convs_down.append(convs_down_i)

    def forward(self, input: torch.Tensor) -> list[torch.Tensor]:
        """Apply UNet Encoder to image.

        Args:
            input: The input image.

        Returns:
            The multi-level feature maps from the encoder.
        """
        features = []

        feat_i = self.conv_in(input)
        features.append(feat_i)

        for conv_down in self.convs_down:
            feat_i = conv_down(feat_i)
            features.append(feat_i)

        return features

    @property
    def out_width(self) -> int:
        """Return the output width of the last encoder level."""
        return self.output_dims[-1]
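
As a usage sketch (illustrative; it assumes the sharp package and its blocks module are importable): an integer width doubles the channel count at every downsampling step, so the encoder below yields four feature maps at strides 1, 2, 4 and 8.

    import torch
    from sharp.models.encoders.unet_encoder import UNetEncoder

    encoder = UNetEncoder(dim_in=4, width=16, steps=3)
    print(encoder.output_dims)               # [16, 32, 64, 128]
    features = encoder(torch.randn(1, 4, 64, 64))
    print([f.shape[-1] for f in features])   # [64, 32, 16, 8]
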
src/sharp/models/encoders/vit_encoder.py ADDED
@@ -0,0 +1,111 @@
"""Contains factory functions to build and load ViT.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

import logging

import timm
import torch

from sharp.models.presets.vit import VIT_CONFIG_DICT, ViTConfig, ViTPreset

LOGGER = logging.getLogger(__name__)


class TimmViT(timm.models.VisionTransformer):
    """Contains the TIMM implementation of a vanilla ViT."""

    def __init__(self, config: ViTConfig):
        """Initialize ViT from TIMM implementation."""
        # Handle MLP layers.
        mlp_layer = timm.layers.GluMlp if config.mlp_mode == "glu" else timm.layers.Mlp

        super().__init__(
            in_chans=config.in_chans,
            embed_dim=config.embed_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            init_values=config.init_values,
            img_size=config.img_size,
            patch_size=config.patch_size,
            num_classes=config.num_classes,
            mlp_ratio=config.mlp_ratio,
            qkv_bias=config.qkv_bias,
            global_pool=config.global_pool,
            mlp_layer=mlp_layer,
        )

        # Required for extracting intermediate features.
        self.dim_in = config.in_chans
        self.intermediate_features_ids = config.intermediate_features_ids

    def reshape_feature(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Discard the class token and reshape the 1D token sequence to a 2D grid."""
        batch_size, seq_len, channel = embeddings.shape

        height, width = self.patch_embed.grid_size

        # Remove class token.
        if self.num_prefix_tokens:
            embeddings = embeddings[:, self.num_prefix_tokens :, :]

        # Shape: (batch, height, width, dim) -> (batch, dim, height, width)
        embeddings = embeddings.reshape(batch_size, height, width, channel).permute(0, 3, 1, 2)
        return embeddings

    def forward(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, dict[int, torch.Tensor]]:
        """Forward pass that also collects intermediate features.

        Adapted from timm ViT.

        Returns:
            Output features and a dict of features from intermediate layers (patch encoder only).
        """
        intermediate_features = {}

        x = self.patch_embed(input_tensor)
        batch_size, seq_len, _ = x.shape

        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if self.intermediate_features_ids is not None and idx in self.intermediate_features_ids:
                intermediate_features[idx] = x
        x = self.norm(x)

        x = self.reshape_feature(x)
        return x, intermediate_features

    def internal_resolution(self) -> int:
        """Return the internal image size of the network."""
        if isinstance(self.patch_embed.img_size, tuple):
            return self.patch_embed.img_size[0]
        else:
            return self.patch_embed.img_size


def create_vit(
    config: ViTConfig | None = None,
    preset: ViTPreset | None = "dinov2l16_384",
    intermediate_features_ids: list[int] | None = None,
) -> TimmViT:
    """Factory function for creating a ViT model."""
    if config is not None:
        LOGGER.info("Using user-defined config.")
    else:
        if preset is None:
            raise ValueError("config and preset cannot both be None.")
        LOGGER.info("Using preset ViT %s.", preset)
        config = VIT_CONFIG_DICT[preset]

    config.intermediate_features_ids = intermediate_features_ids
    model = TimmViT(config)
    LOGGER.debug(model)
    return model
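
A usage sketch for the factory (illustrative; the printed shape assumes the dinov2l16_384 preset is a ViT-L with 16-pixel patches at a 384-pixel internal resolution, numbers suggested by the preset tables further below):

    import torch
    from sharp.models.encoders.vit_encoder import create_vit

    vit = create_vit(preset="dinov2l16_384", intermediate_features_ids=[5, 11, 17, 23])
    features, intermediates = vit(torch.randn(1, 3, 384, 384))
    print(features.shape)         # e.g. (1, 1024, 24, 24): a 24 x 24 grid of patch tokens
    print(sorted(intermediates))  # [5, 11, 17, 23]
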
src/sharp/models/gaussian_decoder.py ADDED
@@ -0,0 +1,267 @@
"""Contains Dense Transformer Prediction architecture.

Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

from typing import NamedTuple

import torch
import torch.nn as nn

from sharp.models.blocks import (
    FeatureFusionBlock2d,
    NormLayerName,
    residual_block_2d,
)
from sharp.models.decoders import BaseDecoder, MultiresConvDecoder
from sharp.models.params import DPTImageEncoderType, GaussianDecoderParams


def create_gaussian_decoder(
    params: GaussianDecoderParams, dims_depth_features: list[int]
) -> GaussianDensePredictionTransformer:
    """Create the gaussian_decoder model specified by the given parameters."""
    decoder = MultiresConvDecoder(
        dims_depth_features,
        params.dims_decoder,
        grad_checkpointing=params.grad_checkpointing,
        upsampling_mode=params.upsampling_mode,
    )

    return GaussianDensePredictionTransformer(
        decoder=decoder,
        dim_in=params.dim_in,
        dim_out=params.dim_out,
        stride_out=params.stride,
        norm_type=params.norm_type,
        norm_num_groups=params.norm_num_groups,
        use_depth_input=params.use_depth_input,
        grad_checkpointing=params.grad_checkpointing,
        image_encoder_type=params.image_encoder_type,
        image_encoder_params=params,
    )


def _create_project_upsample_block(
    dim_in: int,
    dim_out: int,
    upsample_layers: int,
    dim_intermediate: int | None = None,
) -> nn.Module:
    if dim_intermediate is None:
        dim_intermediate = dim_out
    # Projection.
    blocks = [
        nn.Conv2d(
            in_channels=dim_in,
            out_channels=dim_intermediate,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
    ]

    # Upsampling.
    blocks += [
        nn.ConvTranspose2d(
            in_channels=dim_intermediate if i == 0 else dim_out,
            out_channels=dim_out,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=False,
        )
        for i in range(upsample_layers)
    ]

    return nn.Sequential(*blocks)


class ImageFeatures(NamedTuple):
    """Image features extracted from the decoder."""

    texture_features: torch.Tensor
    geometry_features: torch.Tensor


class SkipConvBackbone(nn.Module):
    """A wrapper around a conv layer that behaves like a BaseBackbone."""

    def __init__(self, dim_in: int, dim_out: int, kernel_size: int, stride_out: int):
        """Initialize SkipConvBackbone."""
        super().__init__()
        self.stride_out = stride_out
        if stride_out == 1 and kernel_size != 1:
            raise ValueError("We only support kernel_size = 1 if stride_out is 1.")
        padding: int = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            dim_in, dim_out, kernel_size=kernel_size, stride=stride_out, padding=padding
        )

    def forward(
        self,
        input_features: torch.Tensor,
        encodings: list[torch.Tensor] | None = None,
    ) -> ImageFeatures:
        """Apply SkipConvBackbone to image."""
        output = self.conv(input_features)
        return ImageFeatures(
            texture_features=output,
            geometry_features=output,
        )

    @property
    def stride(self) -> int:
        """Effective downsampling stride."""
        return self.stride_out


class GaussianDensePredictionTransformer(nn.Module):
    """Dense Prediction Transformer for Gaussians.

    Reuses monodepth decoded features for processing.
    """

    norm_type: NormLayerName

    def __init__(
        self,
        decoder: BaseDecoder,
        dim_in: int,
        dim_out: int,
        stride_out: int,
        image_encoder_params: GaussianDecoderParams,
        image_encoder_type: DPTImageEncoderType = "skip_conv",
        norm_type: NormLayerName = "group_norm",
        norm_num_groups: int = 8,
        use_depth_input: bool = True,
        grad_checkpointing: bool = False,
    ):
        """Initialize Dense Prediction Transformer for Gaussians.

        Args:
            decoder: Decoder to decode features.
            dim_in: Input dimension.
            dim_out: Final output dimension.
            stride_out: Stride of output feature map.
            image_encoder_params: The backbone parameters to configure the image encoder.
            image_encoder_type: Type of image encoder to use.
            norm_type: Type of norm layers.
            norm_num_groups: Num groups for norm layers.
            use_depth_input: Whether to use depth input.
            grad_checkpointing: Whether to use gradient checkpointing.
        """
        super().__init__()

        self.decoder = decoder
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.stride_out = stride_out
        self.norm_type = norm_type
        self.norm_num_groups = norm_num_groups
        self.use_depth_input = use_depth_input
        self.grad_checkpointing = grad_checkpointing
        self.image_encoder_type = image_encoder_type

        # Adopt an image encoder to lift dimension to monodepth feature and
        # resize to be the same resolution as the decoder output.
        dim_in = self.dim_in if use_depth_input else self.dim_in - 1
        image_encoder_params.dim_in = dim_in
        image_encoder_params.dim_out = decoder.dim_out
        self.image_encoder = self._create_image_encoder(image_encoder_params, stride_out)

        self.fusion = FeatureFusionBlock2d(decoder.dim_out)

        if stride_out == 1:
            self.upsample = _create_project_upsample_block(
                decoder.dim_out,
                decoder.dim_out,
                upsample_layers=1,
            )
        elif stride_out == 2:
            self.upsample = nn.Identity()
        else:
            raise ValueError("We only support a stride of 1 or 2 for the DPT backbone.")

        self.texture_head = self._create_head(dim_decoder=decoder.dim_out, dim_out=self.dim_out)
        self.geometry_head = self._create_head(dim_decoder=decoder.dim_out, dim_out=self.dim_out)

    def _create_head(self, dim_decoder: int, dim_out: int) -> nn.Module:
        return nn.Sequential(
            residual_block_2d(
                dim_in=dim_decoder,
                dim_out=dim_decoder,
                dim_hidden=dim_decoder // 2,
                norm_type=self.norm_type,
                norm_num_groups=self.norm_num_groups,
            ),
            residual_block_2d(
                dim_in=dim_decoder,
                dim_hidden=dim_decoder // 2,
                dim_out=dim_decoder,
                norm_type=self.norm_type,
                norm_num_groups=self.norm_num_groups,
            ),
            nn.ReLU(),
            nn.Conv2d(dim_decoder, dim_out, kernel_size=1, stride=1),
            nn.ReLU(),
        )

    def _create_image_encoder(
        self, image_encoder_params: GaussianDecoderParams, stride_out: int
    ) -> nn.Module:
        """Create and return the image encoder based on parameters."""
        if self.image_encoder_type == "skip_conv":
            # Use kernel_size = 1 only if stride_out is 1.
            return SkipConvBackbone(
                image_encoder_params.dim_in,
                image_encoder_params.dim_out,
                kernel_size=3 if stride_out != 1 else 1,
                stride_out=stride_out,
            )
        elif self.image_encoder_type == "skip_conv_kernel2":
            return SkipConvBackbone(
                image_encoder_params.dim_in,
                image_encoder_params.dim_out,
                kernel_size=stride_out,
                stride_out=stride_out,
            )
        else:
            raise ValueError(f"Unsupported image encoder type: {self.image_encoder_type}")

    def forward(self, input_features: torch.Tensor, encodings: list[torch.Tensor]) -> ImageFeatures:
        """Run monodepth and fuse features with input image to predict Gaussians.

        Args:
            input_features: The input features to use.
            encodings: Feature encodings (e.g. from monodepth network).
        """
        features = self.decoder(encodings).contiguous()
        features = self.upsample(features)

        if self.use_depth_input:
            skip_features = self.image_encoder(input_features).texture_features
        else:
            # Drop the depth channel and keep only the RGB channels.
            skip_features = self.image_encoder(input_features[:, :3].contiguous()).texture_features
        features = self.fusion(features, skip_features)

        texture_features = self.texture_head(features)
        geometry_features = self.geometry_head(features)

        return ImageFeatures(
            texture_features=texture_features,  # type: ignore
            geometry_features=geometry_features,  # type: ignore
        )

    @property
    def stride(self) -> int:
        """Internal stride of GaussianDensePredictionTransformer."""
        return self.stride_out
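
A shape sketch for the skip connection (illustrative; the dimensions mirror the defaults in GaussianDecoderParams further below). With the "skip_conv_kernel2" variant the kernel size equals the stride, so a 5-channel input at full resolution is projected to the decoder width at half resolution:

    import torch
    from sharp.models.gaussian_decoder import SkipConvBackbone

    backbone = SkipConvBackbone(dim_in=5, dim_out=128, kernel_size=2, stride_out=2)
    out = backbone(torch.randn(1, 5, 64, 64))
    print(out.texture_features.shape)  # (1, 128, 32, 32)
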
src/sharp/models/heads.py ADDED
@@ -0,0 +1,53 @@
"""Contains decoder head for direct prediction of delta values.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

import torch
from torch import nn

from .gaussian_decoder import ImageFeatures


class DirectPredictionHead(nn.Module):
    """Decodes features into delta values using convolutions."""

    def __init__(self, feature_dim: int, num_layers: int) -> None:
        """Initialize DirectPredictionHead.

        Args:
            feature_dim: Number of input features.
            num_layers: The number of layers of Gaussians to predict.
        """
        super().__init__()
        self.num_layers = num_layers

        # 14 = 3 (means) + 3 (scales) + 4 (quaternions) + 3 (colors) + 1 (opacity).
        self.geometry_prediction_head = nn.Conv2d(feature_dim, 3 * num_layers, 1)
        self.geometry_prediction_head.weight.data.zero_()
        assert self.geometry_prediction_head.bias is not None
        self.geometry_prediction_head.bias.data.zero_()

        self.texture_prediction_head = nn.Conv2d(feature_dim, (14 - 3) * num_layers, 1)
        self.texture_prediction_head.weight.data.zero_()
        assert self.texture_prediction_head.bias is not None
        self.texture_prediction_head.bias.data.zero_()

    def forward(self, image_features: ImageFeatures) -> torch.Tensor:
        """Predict deltas for 3D Gaussians.

        Args:
            image_features: Image features from decoder.

        Returns:
            The predicted deltas for Gaussian attributes.
        """
        delta_values_geometry = self.geometry_prediction_head(image_features.geometry_features)
        delta_values_texture = self.texture_prediction_head(image_features.texture_features)
        delta_values_geometry = delta_values_geometry.unflatten(1, (3, self.num_layers))
        delta_values_texture = delta_values_texture.unflatten(1, (14 - 3, self.num_layers))
        delta_values = torch.cat([delta_values_geometry, delta_values_texture], dim=1)
        return delta_values
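
A quick shape check (illustrative): because both heads are zero-initialized, the predicted deltas start at zero, so the initial Gaussians equal the base values from the initializer.

    import torch
    from sharp.models.gaussian_decoder import ImageFeatures
    from sharp.models.heads import DirectPredictionHead

    head = DirectPredictionHead(feature_dim=32, num_layers=2)
    feats = ImageFeatures(
        texture_features=torch.randn(1, 32, 8, 8),
        geometry_features=torch.randn(1, 32, 8, 8),
    )
    deltas = head(feats)
    print(deltas.shape)                 # (1, 14, 2, 8, 8)
    print(torch.count_nonzero(deltas))  # 0 at initialization
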
src/sharp/models/initializer.py ADDED
@@ -0,0 +1,297 @@
"""Contains modules to initialize Gaussians from RGBD.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

from typing import NamedTuple

import torch
from torch import nn

from .params import ColorInitOption, DepthInitOption, InitializerParams


def create_initializer(params: InitializerParams) -> nn.Module:
    """Create initializer."""
    return MultiLayerInitializer(
        num_layers=params.num_layers,
        stride=params.stride,
        base_depth=params.base_depth,
        scale_factor=params.scale_factor,
        disparity_factor=params.disparity_factor,
        color_option=params.color_option,
        first_layer_depth_option=params.first_layer_depth_option,
        rest_layer_depth_option=params.rest_layer_depth_option,
        normalize_depth=params.normalize_depth,
        feature_input_stop_grad=params.feature_input_stop_grad,
    )


class GaussianBaseValues(NamedTuple):
    """Base values for gaussian predictor.

    We predict x and y in normalized device coordinates (NDC) where (-1, -1) is the top
    left corner and (1, 1) the bottom right corner. The last component of
    mean_vectors_ndc is inverse depth.
    """

    mean_x_ndc: torch.Tensor
    mean_y_ndc: torch.Tensor
    mean_inverse_z_ndc: torch.Tensor

    scales: torch.Tensor
    quaternions: torch.Tensor
    colors: torch.Tensor
    opacities: torch.Tensor


class InitializerOutput(NamedTuple):
    """Output of initializer."""

    # Gaussian base values.
    gaussian_base_values: GaussianBaseValues

    # Feature input to the Gaussian predictor.
    feature_input: torch.Tensor

    # Global scale to unscale output.
    global_scale: torch.Tensor | None = None


class MultiLayerInitializer(nn.Module):
    """Initialize Gaussians with multilayer representation.

    The returned tensors have the shape

        batch_size x dim x num_layers x height x width

    where dim indicates the dimensionality of the property.
    Some of the dimensions might be set to 1 for efficiency reasons.
    """

    def __init__(
        self,
        num_layers: int,
        stride: int,
        base_depth: float,
        scale_factor: float,
        disparity_factor: float,
        color_option: ColorInitOption = "first_layer",
        first_layer_depth_option: DepthInitOption = "surface_min",
        rest_layer_depth_option: DepthInitOption = "surface_min",
        normalize_depth: bool = True,
        feature_input_stop_grad: bool = True,
    ) -> None:
        """Initialize MultiLayerInitializer.

        Args:
            num_layers: How many layers of Gaussians to predict.
            stride: The downsample rate of output feature map.
            base_depth: The depth of the first layer (after the foreground
                layer if use_depth=True).
            scale_factor: Multiply scale of Gaussians by this factor.
            disparity_factor: Factor to convert inverse depth to disparity.
            color_option: Which color option to initialize the multi-layer gaussians.
            first_layer_depth_option: Which depth option to initialize the first layer of gaussians.
            rest_layer_depth_option: Which depth option to initialize the rest layers of gaussians.
            normalize_depth: Whether to normalize depth to [DepthTransformParam.depth_min,
                DepthTransformParam.depth_max).
            feature_input_stop_grad: Whether to stop gradients from propagating through
                the feature inputs.
        """
        super().__init__()
        self.num_layers = num_layers
        self.stride = stride
        self.base_depth = base_depth
        self.scale_factor = scale_factor
        self.disparity_factor = disparity_factor
        self.color_option = color_option
        self.first_layer_depth_option = first_layer_depth_option
        self.rest_layer_depth_option = rest_layer_depth_option
        self.normalize_depth = normalize_depth
        self.feature_input_stop_grad = feature_input_stop_grad

    def prepare_feature_input(self, image: torch.Tensor, depth: torch.Tensor) -> torch.Tensor:
        """Prepare the feature input to the Gaussian predictor."""
        if self.feature_input_stop_grad:
            image = image.detach()
            depth = depth.detach()

        normalized_disparity = self.disparity_factor / depth
        features_in = torch.cat([image, normalized_disparity], dim=1)
        features_in = 2.0 * features_in - 1.0
        return features_in

    def forward(self, image: torch.Tensor, depth: torch.Tensor) -> InitializerOutput:
        """Construct Gaussian base values and prepare feature input.

        Args:
            image: The image to process.
            depth: The corresponding depth map from the monodepth network.

        Returns:
            The base values for Gaussians.
        """
        image = image.contiguous()
        depth = depth.contiguous()
        device = depth.device
        batch_size, _, image_height, image_width = depth.shape
        base_height, base_width = (
            image_height // self.stride,
            image_width // self.stride,
        )
        # global_scale is the inverse of the depth_factor, which is used to rescale
        # the depth such that it is numerically stable for training.
        global_scale: torch.Tensor | None = None
        if self.normalize_depth:
            depth, depth_factor = _rescale_depth(depth)
            global_scale = 1.0 / depth_factor

        def _create_disparity_layers(num_layers: int = 1) -> torch.Tensor:
            """Create multiple disparity layers."""
            disparity = torch.linspace(1.0 / self.base_depth, 0.0, num_layers + 1, device=device)
            return disparity[None, None, :-1, None, None].repeat(
                batch_size, 1, 1, base_height, base_width
            )

        def _create_surface_layer(
            depth: torch.Tensor,
            depth_pooling_mode: str,
        ) -> torch.Tensor:
            """Create a surface disparity layer from the input depth."""
            disparity = 1.0 / depth
            if depth_pooling_mode == "min":
                disparity = torch.max_pool2d(disparity, self.stride, self.stride)
            elif depth_pooling_mode == "max":
                disparity = -torch.max_pool2d(-disparity, self.stride, self.stride)
            else:
                raise ValueError(f"Invalid depth pooling mode {depth_pooling_mode}.")

            return disparity[:, :, None, :, :]

        # Input disparity dimensions:
        # (batch_size, num_channels in (1, 2), height, width)

        # Output disparity dimensions:
        # (batch_size, num_channels=1, num_layers in (1, 2), height, width)
        if self.first_layer_depth_option == "surface_min":
            first_disparity = _create_surface_layer(depth[:, 0:1], "min")
        elif self.first_layer_depth_option == "surface_max":
            first_disparity = _create_surface_layer(depth[:, 0:1], "max")
        elif self.first_layer_depth_option in ("base_depth", "linear_disparity"):
            first_disparity = _create_disparity_layers()
        else:
            raise ValueError(f"Unknown depth init option: {self.first_layer_depth_option}.")

        if self.num_layers == 1:
            disparity = first_disparity
        else:  # Fill in the rest layers.
            following_depth = depth if depth.shape[1] == 1 else depth[:, 1:]
            if self.rest_layer_depth_option == "surface_min":
                following_disparity = _create_surface_layer(following_depth, "min")
            elif self.rest_layer_depth_option == "surface_max":
                following_disparity = _create_surface_layer(following_depth, "max")
            elif self.rest_layer_depth_option == "base_depth":
                following_disparity = torch.cat(
                    [_create_disparity_layers() for i in range(self.num_layers - 1)],
                    dim=2,
                )
            elif self.rest_layer_depth_option == "linear_disparity":
                following_disparity = _create_disparity_layers(self.num_layers - 1)
            else:
                raise ValueError(f"Unknown depth init option: {self.rest_layer_depth_option}.")

            disparity = torch.cat([first_disparity, following_disparity], dim=2)

        # Prepare base values.
        base_x_ndc, base_y_ndc = _create_base_xy(depth, self.stride, self.num_layers)
        disparity_scale_factor = 2 * self.scale_factor * self.stride / float(image_width)
        base_scales = _create_base_scale(disparity, disparity_scale_factor)

        base_quaternions = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device)
        base_quaternions = base_quaternions[None, :, None, None, None]

        # Initializing the opacity this way ensures that the initial transmittance
        # is approximately
        #
        #     1 / e ~= (1 - 1 / self.num_layers)**self.num_layers
        #
        # and hence independent of the number of layers.
        base_opacities = torch.tensor([min(1.0 / self.num_layers, 0.5)], device=device)
        base_colors = torch.empty(
            batch_size, 3, self.num_layers, base_height, base_width, device=device
        ).fill_(0.5)
        # Dimensions: (batch_size, num_channels, num_layers, height, width)
        if self.color_option == "none":
            pass
        elif self.color_option == "first_layer":
            base_colors[:, :, 0] = torch.nn.functional.avg_pool2d(image, self.stride, self.stride)
        elif self.color_option == "all_layers":
            temp = torch.nn.functional.avg_pool2d(image, self.stride, self.stride)
            base_colors = temp[:, :, None, :, :].repeat(1, 1, self.num_layers, 1, 1)
        else:
            raise ValueError(f"Unknown color init option: {self.color_option}.")

        features_in = self.prepare_feature_input(image, depth)
        base_gaussians = GaussianBaseValues(
            mean_x_ndc=base_x_ndc,
            mean_y_ndc=base_y_ndc,
            mean_inverse_z_ndc=disparity,
            scales=base_scales,
            quaternions=base_quaternions,
            colors=base_colors,
            opacities=base_opacities,
        )

        return InitializerOutput(
            gaussian_base_values=base_gaussians,
            feature_input=features_in,
            global_scale=global_scale,
        )


def _create_base_xy(
    depth: torch.Tensor, stride: int, num_layers: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create base x and y coordinates for the gaussians in NDC space."""
    device = depth.device
    batch_size, _, image_height, image_width = depth.shape
    xx = torch.arange(0.5 * stride, image_width, stride, device=device)
    yy = torch.arange(0.5 * stride, image_height, stride, device=device)
    xx = 2 * xx / image_width - 1.0
    yy = 2 * yy / image_height - 1.0

    xx, yy = torch.meshgrid(xx, yy, indexing="xy")
    base_x_ndc = xx[None, None, None].repeat(batch_size, 1, num_layers, 1, 1)
    base_y_ndc = yy[None, None, None].repeat(batch_size, 1, num_layers, 1, 1)

    return base_x_ndc, base_y_ndc


def _create_base_scale(disparity: torch.Tensor, disparity_scale_factor: float) -> torch.Tensor:
    """Create base scale for the gaussians."""
    inverse_disparity = torch.ones_like(disparity) / disparity
    base_scales = inverse_disparity * disparity_scale_factor
    return base_scales


def _rescale_depth(
    depth: torch.Tensor, depth_min: float = 1.0, depth_max: float = 1e2
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rescale a depth image tensor.

    Args:
        depth: The depth tensor to transform.
        depth_min: The min depth to scale depth to.
        depth_max: The max clamp depth after scaling.

    Returns:
        The rescaled depth and rescale factor.
    """
    current_depth_min = depth.flatten(depth.ndim - 3).min(dim=-1).values
    depth_factor = depth_min / (current_depth_min + 1e-6)
    depth = (depth * depth_factor[..., None, None, None]).clamp(max=depth_max)
    return depth, depth_factor
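
A small end-to-end sketch of the initializer (illustrative values; disparity_factor and scale_factor are set to 1.0 for simplicity, and depth is offset to stay positive):

    import torch
    from sharp.models.initializer import MultiLayerInitializer

    init = MultiLayerInitializer(
        num_layers=2, stride=2, base_depth=10.0, scale_factor=1.0, disparity_factor=1.0
    )
    image = torch.rand(1, 3, 64, 64)
    depth = torch.rand(1, 1, 64, 64) + 0.5
    out = init(image, depth)
    print(out.gaussian_base_values.colors.shape)  # (1, 3, 2, 32, 32)
    print(out.gaussian_base_values.opacities)     # tensor([0.5]) = min(1 / 2, 0.5)
    print(out.feature_input.shape)                # (1, 4, 64, 64), roughly in [-1, 1]
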
src/sharp/models/monodepth.py ADDED
@@ -0,0 +1,268 @@
"""Contains Dense Transformer Prediction architecture.

Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

import copy
from typing import NamedTuple, Tuple

import torch
import torch.nn as nn

from sharp.models import normalizers
from sharp.models.decoders import MultiresConvDecoder, create_monodepth_decoder
from sharp.models.encoders import (
    SlidingPyramidNetwork,
    create_monodepth_encoder,
)
from sharp.utils import module_surgery

from .params import MonodepthAdaptorParams, MonodepthParams

DimsDecoder = Tuple[int, int, int, int, int]


class MonodepthDensePredictionTransformer(nn.Module):
    """Dense Prediction Transformer for monodepth.

    Attaches the disparity prediction head for monodepth prediction.
    """

    def __init__(
        self,
        encoder: SlidingPyramidNetwork,
        decoder: MultiresConvDecoder,
        last_dims: tuple[int, int],
    ):
        """Initialize Dense Prediction Transformer.

        Args:
            encoder: The SlidingPyramidNetwork backbone.
            decoder: The MultiresConvDecoder decoder.
            last_dims: The dimensions for the last convolution layers.
        """
        super().__init__()

        self.normalizer = normalizers.AffineRangeNormalizer(
            input_range=(0, 1), output_range=(-1, 1)
        )
        self.encoder = encoder
        self.decoder = decoder

        dim_decoder = decoder.dim_out
        self.head = nn.Sequential(
            nn.Conv2d(dim_decoder, dim_decoder // 2, kernel_size=3, stride=1, padding=1),
            nn.ConvTranspose2d(
                in_channels=dim_decoder // 2,
                out_channels=dim_decoder // 2,
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
            ),
            nn.Conv2d(
                dim_decoder // 2,
                last_dims[0],
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(True),
            nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
        )

        # Set the final convolution layer's bias to be 0.
        self.head[4].bias.data.fill_(0)

        self.grad_checkpointing = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, is_enabled: bool = True):
        """Enable or disable gradient checkpointing."""
        self.grad_checkpointing = is_enabled
        self.encoder.set_grad_checkpointing(self.grad_checkpointing)
        self.decoder.set_grad_checkpointing(self.grad_checkpointing)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Decode by projection and fusion of multi-resolution encodings."""
        encodings = self.encoder(self.normalizer(image))
        num_encoder_features = len(self.encoder.dims_encoder)
        features = self.decoder(encodings[:num_encoder_features])
        disparity = self.head(features)
        return disparity

    def internal_resolution(self) -> int:
        """Return the internal image size of the network."""
        return self.encoder.internal_resolution()


def create_monodepth_dpt(
    params: MonodepthParams | None = None,
) -> MonodepthDensePredictionTransformer:
    """Create the MonodepthDensePredictionTransformer model.

    Args:
        params: Parameters of monodepth network.

    Returns:
        The configured monodepth DPT.
    """
    if params is None:
        params = MonodepthParams()
    encoder: SlidingPyramidNetwork = create_monodepth_encoder(
        params.patch_encoder_preset,
        params.image_encoder_preset,
        use_patch_overlap=params.use_patch_overlap,
        last_encoder=params.dims_decoder[0],
    )

    decoder: MultiresConvDecoder = create_monodepth_decoder(
        params.patch_encoder_preset, params.dims_decoder
    )

    monodepth_model = MonodepthDensePredictionTransformer(
        encoder=encoder, decoder=decoder, last_dims=(32, 1)
    )

    # By default, we don't train the monodepth model.
    # However, we allow to selectively unfreeze parts of the network.
    monodepth_model.requires_grad_(False)

    monodepth_model.encoder.set_requires_grad_(
        patch_encoder=params.unfreeze_patch_encoder,
        image_encoder=params.unfreeze_image_encoder,
    )
    monodepth_model.decoder.requires_grad_(params.unfreeze_decoder)
    monodepth_model.head.requires_grad_(params.unfreeze_head)

    if not params.unfreeze_norm_layers:
        module_surgery.freeze_norm_layer(monodepth_model)

    monodepth_model.set_grad_checkpointing(params.grad_checkpointing)

    return monodepth_model


class MonodepthOutput(NamedTuple):
    """Output of the monodepth model."""

    # Disparity output from the monodepth model.
    disparity: torch.Tensor
    # Multi-level features from monodepth encoder.
    encoder_features: list[torch.Tensor]
    # Single-level feature from monodepth decoder.
    decoder_features: torch.Tensor
    # List of monodepth features to be used in gaussian predictor.
    output_features: list[torch.Tensor]
    # List of intermediate encoder features to be used in distillation.
    intermediate_features: list[torch.Tensor] = []


class MonodepthWithEncodingAdaptor(nn.Module):
    """Monodepth model with feature maps."""

    def __init__(
        self,
        monodepth_predictor: MonodepthDensePredictionTransformer,
        return_encoder_features: bool,
        return_decoder_features: bool,
        num_monodepth_layers: int,
        sorting_monodepth: bool,
    ):
        """Initialize MonodepthWithEncodingAdaptor.

        Args:
            monodepth_predictor: The monodepth model.
            return_encoder_features: Whether to return encoder features from monodepth model.
            return_decoder_features: Whether to return decoder features from monodepth model.
            num_monodepth_layers: How many layers the monodepth model predicts.
            sorting_monodepth: Whether to sort the monodepth output (for two layer monodepth).
        """
        super().__init__()
        self.monodepth_predictor = monodepth_predictor
        self.return_encoder_features = return_encoder_features
        self.return_decoder_features = return_decoder_features
        self.num_monodepth_layers = num_monodepth_layers
        self.sorting_monodepth = sorting_monodepth

    def forward(self, image: torch.Tensor) -> MonodepthOutput:
        """Process image and return disparity and feature maps."""
        inputs = self.monodepth_predictor.normalizer(image)
        encoder_output = self.monodepth_predictor.encoder(inputs)

        num_encoder_features = len(self.monodepth_predictor.encoder.dims_encoder)

        # NOTE: whether intermediate features are empty has already been decided
        # in monodepth_predictor during create_monodepth_dpt.
        encoder_features = encoder_output[:num_encoder_features]
        intermediate_features = encoder_output[num_encoder_features:]
        decoder_features = self.monodepth_predictor.decoder(encoder_features)
        disparity = self.monodepth_predictor.head(decoder_features)

        # We cannot use disparity.shape[1], otherwise the tracer will fail.
        if self.num_monodepth_layers == 2 and self.sorting_monodepth:
            first_layer_disparity = disparity.max(dim=1, keepdim=True).values
            second_layer_disparity = disparity.min(dim=1, keepdim=True).values
            disparity = torch.cat([first_layer_disparity, second_layer_disparity], dim=1)

        output_features = []
        if self.return_encoder_features:
            output_features.extend(encoder_features)

        if self.return_decoder_features:
            output_features.append(decoder_features)

        return MonodepthOutput(
            disparity=disparity,
            encoder_features=encoder_features,
            decoder_features=decoder_features,
            output_features=output_features,
            intermediate_features=intermediate_features,
        )

    def get_feature_dims(self) -> list[int]:
        """Return dimensions of output feature maps."""
        dims = []
        if self.return_encoder_features:
            dims.extend(self.monodepth_predictor.encoder.dims_encoder)

        if self.return_decoder_features:
            dims.append(self.monodepth_predictor.decoder.dim_out)

        return dims

    def internal_resolution(self) -> int:
        """Return the internal image size of the network."""
        return self.monodepth_predictor.internal_resolution()

    def replicate_head(self, num_repeat: int):
        """Replicate the last convolution layer (head[4] in DPT) for multi-layer depth."""
        conv_last = copy.deepcopy(self.monodepth_predictor.head[4])
        self.monodepth_predictor.head[4].out_channels = num_repeat
        self.monodepth_predictor.head[4].weight = nn.Parameter(
            conv_last.weight.repeat(num_repeat, 1, 1, 1)
        )
        self.monodepth_predictor.head[4].bias = nn.Parameter(conv_last.bias.repeat(num_repeat))


def create_monodepth_adaptor(
    monodepth_predictor: MonodepthDensePredictionTransformer,
    params: MonodepthAdaptorParams,
    num_monodepth_layers: int,
    sorting_monodepth: bool,
) -> MonodepthWithEncodingAdaptor:
    """Create an adaptor that returns both disparity and features."""
    adaptor = MonodepthWithEncodingAdaptor(
        monodepth_predictor=monodepth_predictor,
        return_encoder_features=params.encoder_features,
        return_decoder_features=params.decoder_features,
        num_monodepth_layers=num_monodepth_layers,
        sorting_monodepth=sorting_monodepth,
    )
    return adaptor
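
To make the sorting step concrete (a tiny illustrative example): for a two-layer prediction, taking the per-pixel max and min of the disparity channels guarantees that layer 0 is always the nearer surface (larger disparity) and layer 1 the farther one.

    import torch

    disparity = torch.tensor([[[[0.2]], [[0.7]]]])      # (batch=1, layers=2, h=1, w=1)
    first = disparity.max(dim=1, keepdim=True).values   # nearer surface: 0.7
    second = disparity.min(dim=1, keepdim=True).values  # farther surface: 0.2
    print(torch.cat([first, second], dim=1).flatten())  # tensor([0.7000, 0.2000])
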
src/sharp/models/normalizers.py ADDED
@@ -0,0 +1,80 @@
"""Contains an implementation of image normalizers for perceptual loss.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

from typing import Sequence, Union

import torch
from torch import nn


class MeanStdNormalizer(nn.Module):
    """Normalize image input by mean and std."""

    mean: torch.Tensor
    std_inv: torch.Tensor

    def __init__(
        self,
        mean: Union[Sequence[float], torch.Tensor],
        std: Union[Sequence[float], torch.Tensor],
    ):
        """Initialize MeanStdNormalizer."""
        super().__init__()
        if not isinstance(mean, torch.Tensor):
            mean = torch.as_tensor(mean).view(-1, 1, 1)
        if not isinstance(std, torch.Tensor):
            std = torch.as_tensor(std).view(-1, 1, 1)
        self.register_buffer("mean", mean)
        # We store the inverse std so normalization uses a multiplication,
        # which is better supported by the hardware.
        self.register_buffer("std_inv", 1.0 / std)

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Apply mean and std normalization over input image."""
        return (image - self.mean) * self.std_inv


class AffineRangeNormalizer(nn.Module):
    """Perform a linear mapping from input_range to output_range.

    output_range defaults to (0, 1).
    """

    def __init__(
        self,
        input_range: tuple[float, float],
        output_range: tuple[float, float] = (0, 1),
    ):
        """Initialize AffineRangeNormalizer."""
        super().__init__()
        input_min, input_max = input_range
        output_min, output_max = output_range
        if input_max <= input_min:
            raise ValueError(f"Invalid input_range: {input_range}")
        if output_max <= output_min:
            raise ValueError(f"Invalid output_range: {output_range}")

        self.scale = (output_max - output_min) / (input_max - input_min)
        self.bias = output_min - input_min * self.scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply affine range normalization over input image."""
        if self.scale != 1.0:
            x = x * self.scale

        if self.bias != 0.0:
            x = x + self.bias

        return x


class MobileNetNormalizer(AffineRangeNormalizer):
    """Image normalization as used in MobileNet."""

    def __init__(self, input_range: tuple[float, float] = (0, 1)):
        """Initialize MobileNetNormalizer."""
        super().__init__(input_range=input_range, output_range=(-1, 1))
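
A two-line sanity check for the affine mapping (illustrative): mapping (0, 1) to (-1, 1) gives scale = 2 and bias = -1.

    import torch
    from sharp.models.normalizers import AffineRangeNormalizer

    norm = AffineRangeNormalizer(input_range=(0, 1), output_range=(-1, 1))
    print(norm(torch.tensor([0.0, 0.5, 1.0])))  # tensor([-1., 0., 1.])
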
src/sharp/models/params.py ADDED
@@ -0,0 +1,203 @@
"""Contains params for backbone.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

import dataclasses
from typing import Literal

import sharp.utils.math as math_utils
from sharp.models.blocks import NormLayerName, UpsamplingMode
from sharp.models.presets import ViTPreset
from sharp.utils.color_space import ColorSpace

DimsDecoder = tuple[int, int, int, int, int]
DPTImageEncoderType = Literal["skip_conv", "skip_conv_kernel2"]

ColorInitOption = Literal[
    "none",  # Initialize as gray.
    "first_layer",  # Initialize the first layer with input image, other layers with gray.
    "all_layers",  # Initialize all layers with input image.
]
DepthInitOption = Literal[
    # Initialize the layer of gaussians on the surface using min pooling of input depth.
    "surface_min",
    # Initialize the layer of gaussians on the surface using max pooling of input depth.
    "surface_max",
    # Initialize the layer of gaussians on a plane at base_depth depth.
    "base_depth",
    # Initialize the layer of gaussians on a plane based on base_depth and the index of the layer.
    "linear_disparity",
]


@dataclasses.dataclass
class AlignmentParams:
    """Parameters for depth alignment."""

    kernel_size: int = 16
    stride: int = 1
    frozen: bool = False

    # The following parameters are only used for LearnedAlignment.
    # Number of steps in the UNet for LearnedAlignment.
    steps: int = 4
    # Activation type for LearnedAlignment.
    activation_type: math_utils.ActivationType = "exp"
    # Whether to use depth decoder features for LearnedAlignment.
    depth_decoder_features: bool = False
    # Base width of the UNet for LearnedAlignment.
    base_width: int = 16


@dataclasses.dataclass
class DeltaFactor:
    """Factors to multiply deltas with before activation.

    These factors effectively selectively reduce the learning rate.
    """

    xy: float = 0.001
    z: float = 0.001
    color: float = 0.1  # We recommend 0.1 for linearRGB and 1.0 for sRGB.
    opacity: float = 1.0
    scale: float = 1.0
    quaternion: float = 1.0


@dataclasses.dataclass
class InitializerParams:
    """Parameters for initializer."""

    # Common parameters.
    # Multiply scale of Gaussians by this factor.
    scale_factor: float = 1.0
    # Factor to convert inverse depth to disparity.
    disparity_factor: float = 1.0
    # Stride of the initializer.
    stride: int = 2

    # Parameters that only affect MultiLayerInitializer.
    # How many layers of Gaussians to predict (only available for MultiLayerInitializer).
    num_layers: int = 2
    # Which option to use for depth initialization.
    first_layer_depth_option: DepthInitOption = "surface_min"
    rest_layer_depth_option: DepthInitOption = "surface_min"
    # Which option to use for color initialization.
    color_option: ColorInitOption = "all_layers"
    # Which depth value to use for depth layers.
    base_depth: float = 10.0
    # Deactivate gradient for feature inputs.
    feature_input_stop_grad: bool = False
    # Whether to normalize depth to [DepthTransformParam.depth_min,
    # DepthTransformParam.depth_max).
    normalize_depth: bool = True

    # Output only the inpainted layer. In this case, num_layers = 1.
    output_inpainted_layer_only: bool = False
    # Whether to set the uninpainted region to zero opacities.
    set_uninpainted_opacity_to_zero: bool = False
    # Whether to concatenate the inpainting mask to the feature input.
    concat_inpainting_mask: bool = False


@dataclasses.dataclass
class MonodepthParams:
    """Parameters for monodepth network."""

    patch_encoder_preset: ViTPreset = "dinov2l16_384"
    image_encoder_preset: ViTPreset = "dinov2l16_384"

    checkpoint_uri: str | None = None
    unfreeze_patch_encoder: bool = False
    unfreeze_image_encoder: bool = False
    unfreeze_decoder: bool = False
    unfreeze_head: bool = False
    unfreeze_norm_layers: bool = False
    grad_checkpointing: bool = False
    use_patch_overlap: bool = True
    dims_decoder: DimsDecoder = (256, 256, 256, 256, 256)


@dataclasses.dataclass
class MonodepthAdaptorParams:
    """Parameters for monodepth network feature adaptor."""

    encoder_features: bool = True
    decoder_features: bool = False


@dataclasses.dataclass
class GaussianDecoderParams:
    """Parameters for backbone with default values."""

    dim_in: int = 5
    dim_out: int = 32
    # Which normalization to use in backbone.
    norm_type: NormLayerName = "group_norm"
    # How many groups to use for group normalization.
    norm_num_groups: int = 8
    # Stride of backbone.
    stride: int = 2

    patch_encoder_preset: ViTPreset = "dinov2l16_384"
    image_encoder_preset: ViTPreset = "dinov2l16_384"

    # Dimensionality of feature maps for DPT decoder.
    dims_decoder: DimsDecoder = (128, 128, 128, 128, 128)

    # Whether to use depth as input.
    use_depth_input: bool = True

    # Whether to enable gradient checkpointing for the backbone.
    grad_checkpointing: bool = False

    # What mode to use for upsampling in decoder.
    upsampling_mode: UpsamplingMode = "transposed_conv"

    # The type of image encoder.
    image_encoder_type: DPTImageEncoderType = "skip_conv_kernel2"


@dataclasses.dataclass
class PredictorParams:
    """Parameters for predictors with default values."""

    # Parameters for submodules.
    initializer: InitializerParams = dataclasses.field(default_factory=InitializerParams)
    monodepth: MonodepthParams = dataclasses.field(default_factory=MonodepthParams)
    monodepth_adaptor: MonodepthAdaptorParams = dataclasses.field(
        default_factory=MonodepthAdaptorParams
    )
    gaussian_decoder: GaussianDecoderParams = dataclasses.field(
        default_factory=GaussianDecoderParams
    )
    # How to align depth map (only relevant for RGBGaussianPredictor).
    depth_alignment: AlignmentParams = dataclasses.field(default_factory=AlignmentParams)

    # Selectively reduce learning rate for different properties.
    delta_factor: DeltaFactor = dataclasses.field(default_factory=DeltaFactor)
    # The maximum scale of Gaussians relative to initial scale.
    max_scale: float = 10.0
    # The minimum scale of Gaussians relative to initial scale.
    min_scale: float = 0.0
    # Which normalization to use in prediction head.
    norm_type: NormLayerName = "group_norm"
    # How many groups to use for group normalization.
    norm_num_groups: int = 8
    # Whether to use predicted mean to sample triplane features.
    use_predicted_mean: bool = False
    # Which activation function to use for colors / opacities.
    color_activation_type: math_utils.ActivationType = "sigmoid"
    opacity_activation_type: math_utils.ActivationType = "sigmoid"
    # Colorspace of the renderer ("linearRGB" or "sRGB").
    color_space: ColorSpace = "linearRGB"
    # A small value to avoid ill-conditioned splats.
    low_pass_filter_eps: float = 1e-2
    # How many layers of depth the monodepth model predicts.
    num_monodepth_layers: int = 2
    # Whether to sort the monodepth output (for two layer monodepth).
    sorting_monodepth: bool = False
    # Whether to account for the z offsets when estimating the base scale.
    base_scale_on_predicted_mean: bool = True
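
Since every sub-config is a dataclass with a default factory, overriding a single knob is straightforward (illustrative):

    from sharp.models.params import InitializerParams, PredictorParams

    params = PredictorParams(
        initializer=InitializerParams(num_layers=3, color_option="first_layer"),
        max_scale=5.0,
    )
    print(params.initializer.num_layers)         # 3
    print(params.gaussian_decoder.dims_decoder)  # (128, 128, 128, 128, 128)
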
src/sharp/models/predictor.py ADDED
@@ -0,0 +1,201 @@
1
+ """Contains definition of RGB-only gaussian predictor.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ from sharp.models.monodepth import MonodepthWithEncodingAdaptor
15
+ from sharp.utils.gaussians import Gaussians3D
16
+
17
+ from .composer import GaussianComposer
18
+
19
+ LOGGER = logging.getLogger(__name__)
20
+
21
+
22
+ class DepthAlignment(nn.Module):
23
+ """Depth alignment in a dedicated nn.Module.
24
+
25
+ Wrap scale_map_estimator to perform the conditional logic in a separated torch
26
+ module outside the forward of RGBGaussianPredictor. This module can be then
27
+ excluded during symbolic tracing.
28
+ """
29
+
30
+ def __init__(self, scale_map_estimator: nn.Module | None):
31
+ """Initialize DepthAlignmentWrapper.
32
+
33
+ Args:
34
+ scale_map_estimator: Module to align monodepth to ground truth depth.
35
+ """
36
+ super().__init__()
37
+ self.scale_map_estimator = scale_map_estimator
38
+
39
+ def forward(
40
+ self,
41
+ monodepth: torch.Tensor,
42
+ depth: torch.Tensor,
43
+ depth_decoder_features: torch.Tensor | None = None,
44
+ ):
45
+ """Optionally align monodepth to ground truth with a local scale map.
46
+
47
+ Args:
48
+ monodepth: The monodepth model with intermediate features to use.
49
+ depth: Ground truth depth to align predicted depth to.
50
+ depth_decoder_features: The (optional) monodepth decoder features.
51
+ """
52
+ if depth is not None and self.scale_map_estimator is not None:
53
+ depth_alignment_map = self.scale_map_estimator(
54
+ monodepth[:, 0:1], depth, depth_decoder_features
55
+ )
56
+ monodepth = depth_alignment_map * monodepth
57
+ else:
58
+ # Some losses rely on the presence of an alignment map.
59
+ # We ensure that they can be computed by creating a fake alignment map.
60
+ depth_alignment_map = torch.ones_like(monodepth)
61
+ return monodepth, depth_alignment_map
62
+
63
+
64
+ class RGBGaussianPredictor(nn.Module):
65
+ """Predicts 3D Gaussians from images."""
66
+
67
+ feature_model: nn.Module
68
+
69
+ def __init__(
70
+ self,
71
+ init_model: nn.Module,
72
+ monodepth_model: MonodepthWithEncodingAdaptor,
73
+ feature_model: nn.Module,
74
+ prediction_head: nn.Module,
75
+ gaussian_composer: GaussianComposer,
76
+ scale_map_estimator: nn.Module | None,
77
+ ) -> None:
78
+ """Initialize RGBGaussianPredictor.
79
+
80
+ Args:
81
+ init_model: A model mapping image and depth to base values.
82
+ monodepth_model: The monodepth model with intermediate features to use.
83
+ feature_model: The image2image model to predict Gaussians from.
84
+ prediction_head: Head to decode image features.
85
+ gaussian_composer: Module to compose final prediction from deltas and
86
+ base values.
87
+ scale_map_estimator: Module to align monodepth to ground truth depth.
88
+
89
+ Note:
90
+ ----
91
+ when monodepth_model is trainable, using local depth alignment can
92
+ result in the monodepth model losing its ability to predict shapes. It is
93
+ hence recommend to deactivate the corresponding flag.
94
+ """
95
+ super().__init__()
96
+ self.init_model = init_model
97
+ self.feature_model = feature_model
98
+ self.monodepth_model = monodepth_model
99
+ self.prediction_head = prediction_head
100
+ self.gaussian_composer = gaussian_composer
101
+ self.depth_alignment = DepthAlignment(scale_map_estimator)
102
+
103
+ def forward(
104
+ self,
105
+ image: torch.Tensor,
106
+ disparity_factor: torch.Tensor,
107
+ depth: torch.Tensor | None = None,
108
+ ) -> Gaussians3D:
109
+ """Predict 3D Gaussians.
110
+
111
+ Args:
112
+ image: The image to process.
113
+ disparity_factor: Factor to convert depth to disparities.
114
+ depth: Ground truth depth to align predicted depth to.
115
+
116
+ Returns:
117
+ The predicted 3D Gaussians.
118
+
119
+ Note:
120
+ ----
121
+ During training, it is recommended to feed an additional ground truth depth
122
+ map to the network to align the predicted depth to. During inference, it is
123
+ recommended to use depth_gt=None and use monodepth_disparity output from the
124
+ model instead to compute depth.
125
+ """
126
+ # Estimate depth and align to ground truth (if available).
127
+ monodepth_output = self.monodepth_model(image)
128
+ monodepth_disparity = monodepth_output.disparity
129
+
130
+ disparity_factor = disparity_factor[:, None, None, None]
131
+ monodepth = disparity_factor / monodepth_disparity.clamp(min=1e-4, max=1e4)
132
+
133
+ # In the model we apply additional alignment to provided ground truth depth
134
+ # as well as additional normalization.
135
+ #
136
+ # The overall graph looks as follows:
137
+ #
138
+ # monodepth depth # Both monodepth and depth are metric here.
139
+ # | |
140
+ # +------+-------+
141
+ # |
142
+ # +-------+--------+ # Optionally align monodepth to ground truth
143
+ # |depth_alignement| # with a local scale map.
144
+ # +-------+--------+
145
+ # |
146
+ # v
147
+ # monodepth (aligned) # Monodepth is now aligned to ground truth.
148
+ # |
149
+ # +-----+----+ # Normalize depth and compute base gaussians.
150
+ # |init_model| # in these normalized coordinates.
151
+ # +-----+----+
152
+ # |
153
+ # v
154
+ # +------ init_output # Init_output consists of features, base
155
+ # | | # gaussians and a global scale.
156
+ # | +------+-----+
157
+ # | |main network| # Compute delta values to base gaussians.
158
+ # | +------+-----+
159
+ # | |
160
+ # | V
161
+ # | delta_values # The delta values are computed with normalized depth.
162
+ # | |
163
+ # | +-------+---------+
164
+ # +--> |gaussian_composer| # Add delta to base values and unscale gaussians.
165
+ # +-------+---------+
166
+ # |
167
+ # v
168
+ # gaussians # The final Gaussians are metric again.
169
+ #
170
+
171
+ # The logic to decide whether to align monodepth to the ground truth is wrapped
172
+ # in a submodule 'DepthAlignment' to facilitate the symbolic tracing of the
173
+ # predictor. This way, the depth alignment submodule containing the conditional
174
+ # logic can be excluded during the tracing and the graph of the predictors is
175
+ # static.
176
+ monodepth, _ = self.depth_alignment(
177
+ monodepth,
178
+ depth,
179
+ monodepth_output.decoder_features,
180
+ )
181
+
182
+ init_output = self.init_model(image, monodepth)
183
+ image_features = self.feature_model(
184
+ init_output.feature_input, encodings=monodepth_output.output_features
185
+ )
186
+ delta_values = self.prediction_head(image_features)
187
+ gaussians = self.gaussian_composer(
188
+ delta=delta_values,
189
+ base_values=init_output.gaussian_base_values,
190
+ global_scale=init_output.global_scale,
191
+ )
192
+ return gaussians
193
+
194
+ def internal_resolution(self) -> int:
195
+ """Internal resolution."""
196
+ return self.monodepth_model.internal_resolution()
197
+
198
+ @property
199
+ def output_resolution(self) -> int:
200
+ """Output resolution of Gaussians."""
201
+ return self.internal_resolution() // 2
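A minimal inference sketch for the predictor above. All module instances and the input resolution are assumptions for illustration; following the note in `forward`, `depth=None` is used at inference time:

import torch

# Hypothetical instances -- the concrete constructors live elsewhere in the repo.
predictor = RGBGaussianPredictor(
    init_model=init_model,
    monodepth_model=monodepth_model,
    feature_model=feature_model,
    prediction_head=prediction_head,
    gaussian_composer=gaussian_composer,
    scale_map_estimator=None,  # disable local depth alignment
)
image = torch.rand(1, 3, 1536, 1536)  # assumed input resolution
disparity_factor = torch.ones(1)      # per-image factor converting depth to disparity
with torch.no_grad():
    gaussians = predictor(image, disparity_factor, depth=None)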
src/sharp/models/presets/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ """Contains presets for pretrained neural networks.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from .monodepth import (
8
+ MONODEPTH_ENCODER_DIMS_MAP,
9
+ MONODEPTH_HOOK_IDS_MAP,
10
+ )
11
+ from .vit import (
12
+ VIT_CONFIG_DICT,
13
+ ViTConfig,
14
+ ViTPreset,
15
+ )
16
+
17
+ __all__ = [
18
+ "ViTConfig",
19
+ "ViTPreset",
20
+ "VIT_CONFIG_DICT",
21
+ "MONODEPTH_ENCODER_DIMS_MAP",
22
+ "MONODEPTH_HOOK_IDS_MAP",
23
+ ]
src/sharp/models/presets/monodepth.py ADDED
@@ -0,0 +1,21 @@
1
+ """Contains preset for monodepth modules.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from .vit import ViTPreset
10
+
11
+ # Map each encoder preset to the number of output channels
12
+ # of each tensor in the encoder output.
13
+ MONODEPTH_ENCODER_DIMS_MAP: dict[ViTPreset, list[int]] = {
14
+ # For publication
15
+ "dinov2l16_384": [256, 512, 1024, 1024],
16
+ }
17
+
18
+ MONODEPTH_HOOK_IDS_MAP: dict[ViTPreset, list[int]] = {
19
+ # For publication
20
+ "dinov2l16_384": [5, 11, 17, 23],
21
+ }
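For illustration, the two maps are keyed by the same preset and are intended to be used together; a small sketch:

preset: ViTPreset = "dinov2l16_384"
hook_ids = MONODEPTH_HOOK_IDS_MAP[preset]          # [5, 11, 17, 23]
encoder_dims = MONODEPTH_ENCODER_DIMS_MAP[preset]  # [256, 512, 1024, 1024]
# One output channel count per hooked transformer block.
assert len(hook_ids) == len(encoder_dims)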
src/sharp/models/presets/vit.py ADDED
@@ -0,0 +1,58 @@
1
+ """Contains preset for ViT modules.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import dataclasses
10
+ from typing import Literal
11
+
12
+ ViTPreset = Literal["dinov2l16_384",]
13
+
14
+ MLPMode = Literal["vanilla", "glu"]
15
+
16
+
17
+ @dataclasses.dataclass
18
+ class ViTConfig:
19
+ """Configuration for ViT."""
20
+
21
+ in_chans: int
22
+ embed_dim: int
23
+ depth: int
24
+ num_heads: int
25
+ init_values: float
26
+
27
+ img_size: int = 384
28
+ patch_size: int = 16
29
+
30
+ num_classes: int = 21841
31
+ mlp_ratio: float = 4.0
32
+ drop_rate: float = 0.0
33
+ attn_drop_rate: float = 0.0
34
+ drop_path_rate: float = 0.0
35
+ qkv_bias: bool = True
36
+ global_pool: str = "avg"
37
+
38
+ # Properties for timm_vit.
39
+ mlp_mode: MLPMode = "vanilla"
40
+
41
+ # Properties for SPN.
42
+ intermediate_features_ids: list[int] | None = None
43
+
44
+ def asdict(self):
45
+ """Convenience method to convert the class to a dict."""
46
+ return dataclasses.asdict(self)
47
+
48
+
49
+ VIT_CONFIG_DICT: dict[ViTPreset, ViTConfig] = {
50
+ "dinov2l16_384": ViTConfig(
51
+ in_chans=3,
52
+ embed_dim=1024,
53
+ depth=24,
54
+ num_heads=16,
55
+ init_values=1e-5,
56
+ global_pool="",
57
+ ),
58
+ }
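A sketch of how a preset might be consumed; the backbone constructor and the dropped keys are assumptions, while `asdict` flattens the dataclass into keyword arguments:

config = VIT_CONFIG_DICT["dinov2l16_384"]
num_patches = (config.img_size // config.patch_size) ** 2  # 24 * 24 = 576 tokens
vit_kwargs = config.asdict()
# Drop fields a hypothetical backbone constructor would not accept.
vit_kwargs.pop("intermediate_features_ids")
vit_kwargs.pop("mlp_mode")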
src/sharp/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ """Contains utils packages.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
src/sharp/utils/camera.py ADDED
@@ -0,0 +1,386 @@
1
+ """Contains utility functionality to render different modalities.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import dataclasses
10
+ from typing import Literal, NamedTuple
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+ from .gaussians import Gaussians3D
16
+ from .linalg import eyes
17
+
18
+ TrajetoryType = Literal["swipe", "shake", "rotate", "rotate_forward"]
19
+ LookAtMode = Literal["point", "ahead"]
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class CameraInfo:
24
+ """Camera info for a pinhole camera."""
25
+
26
+ intrinsics: torch.Tensor
27
+ extrinsics: torch.Tensor
28
+ width: int
29
+ height: int
30
+
31
+
32
+ class FocusRange(NamedTuple):
33
+ """Parametrizes a range of depth / disparity values."""
34
+
35
+ min: float
36
+ focus: float
37
+ max: float
38
+
39
+
40
+ @dataclasses.dataclass
41
+ class TrajectoryParams:
42
+ """Parameters for trajectory."""
43
+
44
+ type: TrajetoryType = "rotate_forward"
45
+ lookat_mode: LookAtMode = "point"
46
+ max_disparity: float = 0.08
47
+ max_zoom: float = 0.15
48
+ distance_m: float = 0.0
49
+ num_steps: int = 60
50
+ num_repeats: int = 1
51
+
52
+
53
+ def compute_max_offset(
54
+ scene: Gaussians3D,
55
+ params: TrajectoryParams,
56
+ resolution_px: tuple[int, int],
57
+ f_px: float,
58
+ ) -> np.ndarray:
59
+ """Compute the maximum offset for camera along X/Y/Z axis."""
60
+ scene_points = scene.mean_vectors
61
+ extrinsics = torch.eye(4).to(scene_points.device)
62
+ min_depth, _, _ = _compute_depth_quantiles(scene_points, extrinsics)
63
+
64
+ r_px = resolution_px
65
+ diagonal = np.sqrt((r_px[0] / f_px) ** 2 + (r_px[1] / f_px) ** 2)
66
+ max_lateral_offset_m = params.max_disparity * diagonal * min_depth
67
+
68
+ max_medial_offset_m = params.max_zoom * min_depth
69
+ max_offset_xyz_m = np.array([max_lateral_offset_m, max_lateral_offset_m, max_medial_offset_m])
70
+
71
+ return max_offset_xyz_m
72
+
73
+
74
+ def create_eye_trajectory(
75
+ scene: Gaussians3D,
76
+ params: TrajectoryParams,
77
+ resolution_px: tuple[int, int],
78
+ f_px: float,
79
+ ) -> list[torch.Tensor]:
80
+ """Create eye trajectory for trajectory type."""
81
+ max_offset_xyz_m = compute_max_offset(
82
+ scene,
83
+ params,
84
+ resolution_px,
85
+ f_px,
86
+ )
87
+ # We place the eye trajectory on the z=distance_m plane (default 0),
88
+ # assuming the portal plane sits at its natural viewing distance.
89
+ if params.type == "swipe":
90
+ return create_eye_trajectory_swipe(
91
+ max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
92
+ )
93
+ elif params.type == "shake":
94
+ return create_eye_trajectory_shake(
95
+ max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
96
+ )
97
+ elif params.type == "rotate":
98
+ return create_eye_trajectory_rotate(
99
+ max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
100
+ )
101
+ elif params.type == "rotate_forward":
102
+ return create_eye_trajectory_rotate_forward(
103
+ max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
104
+ )
105
+ else:
106
+ raise ValueError(f"Invalid trajectory type {params.type}.")
107
+
108
+
109
+ def create_eye_trajectory_swipe(
110
+ offset_xyz_m: np.ndarray,
111
+ distance_m: float,
112
+ num_steps: int,
113
+ num_repeats: int,
114
+ ) -> list[torch.Tensor]:
115
+ """Create a left to right swipe trajectory."""
116
+ offset_x_m, _, _ = offset_xyz_m
117
+ eye_positions = [
118
+ torch.tensor([x, 0, distance_m], dtype=torch.float32)
119
+ for x in np.linspace(-offset_x_m, offset_x_m, num_steps)
120
+ ]
121
+ return eye_positions * num_repeats
122
+
123
+
124
+ def create_eye_trajectory_shake(
125
+ offset_xyz_m: np.ndarray,
126
+ distance_m: float,
127
+ num_steps: int,
128
+ num_repeats: int,
129
+ ) -> list[torch.Tensor]:
130
+ """Create a left right shake followed by an up down shake trajectory."""
131
+ num_steps_total = num_steps * num_repeats
132
+ num_steps_horizontal = num_steps_total // 2
133
+ num_steps_vertical = num_steps_total - num_steps_horizontal
134
+
135
+ offset_x_m, offset_y_m, _ = offset_xyz_m
136
+ eye_positions: list[torch.Tensor] = []
137
+ eye_positions.extend(
138
+ torch.tensor(
139
+ [offset_x_m * np.sin(2 * np.pi * t), 0.0, distance_m],
140
+ dtype=torch.float32,
141
+ )
142
+ for t in np.linspace(0, num_repeats, num_steps_horizontal)
143
+ )
144
+ eye_positions.extend(
145
+ torch.tensor(
146
+ [0.0, offset_y_m * np.sin(2 * np.pi * t), distance_m],
147
+ dtype=torch.float32,
148
+ )
149
+ for t in np.linspace(0, num_repeats, num_steps_vertical)
150
+ )
151
+
152
+ return eye_positions
153
+
154
+
155
+ def create_eye_trajectory_rotate(
156
+ offset_xyz_m: np.ndarray,
157
+ distance_m: float,
158
+ num_steps: int,
159
+ num_repeats: int,
160
+ ) -> list[torch.Tensor]:
161
+ """Create a rotating trajectory."""
162
+ num_steps_total = num_steps * num_repeats
163
+ offset_x_m, offset_y_m, _ = offset_xyz_m
164
+ eye_positions = [
165
+ torch.tensor(
166
+ [
167
+ offset_x_m * np.sin(2 * np.pi * t),
168
+ offset_y_m * np.cos(2 * np.pi * t),
169
+ distance_m,
170
+ ],
171
+ dtype=torch.float32,
172
+ )
173
+ for t in np.linspace(0, num_repeats, num_steps_total)
174
+ ]
175
+
176
+ return eye_positions
177
+
178
+
179
+ def create_eye_trajectory_rotate_forward(
180
+ offset_xyz_m: np.ndarray,
181
+ distance_m: float,
182
+ num_steps: int,
183
+ num_repeats: int,
184
+ ) -> list[torch.Tensor]:
185
+ """Create a rotating trajectory."""
186
+ num_steps_total = num_steps * num_repeats
187
+ offset_x_m, _, offset_z_m = offset_xyz_m
188
+ eye_positions = [
189
+ torch.tensor(
190
+ [
191
+ offset_x_m * np.sin(2 * np.pi * t),
192
+ 0.0,
193
+ distance_m + offset_z_m * (1.0 - np.cos(2 * np.pi * t)) / 2,
194
+ ],
195
+ dtype=torch.float32,
196
+ )
197
+ for t in np.linspace(0, num_repeats, num_steps_total)
198
+ ]
199
+
200
+ return eye_positions
201
+
202
+
203
+ def create_camera_model(
204
+ scene: Gaussians3D,
205
+ intrinsics: torch.Tensor,
206
+ resolution_px: tuple[int, int],
207
+ lookat_mode: LookAtMode = "point",
208
+ ) -> PinholeCameraModel:
209
+ """Create camera model to simulate general pinhole camera."""
210
+ screen_extrinsics = torch.eye(4)
211
+ screen_intrinsics = intrinsics.clone()
212
+
213
+ image_width, image_height = resolution_px
214
+ screen_resolution_px = get_screen_resolution_px_from_input(
215
+ width=image_width, height=image_height
216
+ )
217
+
218
+ screen_intrinsics[0] *= screen_resolution_px[0] / image_width
219
+ screen_intrinsics[1] *= screen_resolution_px[1] / image_height
220
+
221
+ camera_model = PinholeCameraModel(
222
+ scene,
223
+ screen_extrinsics=screen_extrinsics,
224
+ screen_intrinsics=screen_intrinsics,
225
+ screen_resolution_px=screen_resolution_px,
226
+ focus_depth_quantile=0.1,
227
+ min_depth_focus=2.0,
228
+ lookat_mode=lookat_mode,
229
+ )
230
+ return camera_model
231
+
232
+
233
+ def create_camera_matrix(
234
+ position: torch.Tensor,
235
+ look_at_position: torch.Tensor | None = None,
236
+ world_up: torch.Tensor | None = None,
237
+ inverse: bool = False,
238
+ ) -> torch.Tensor:
239
+ """Create camera matrix from vectors."""
240
+ device = position.device
241
+
242
+ if look_at_position is None:
243
+ look_at_position = torch.zeros(3, device=device)
244
+ if world_up is None:
245
+ world_up = torch.tensor([0.0, 0.0, 1.0], device=device)
246
+
247
+ position, look_at_position, world_up = torch.broadcast_tensors(
248
+ position, look_at_position, world_up
249
+ )
250
+
251
+ camera_front = look_at_position - position
252
+ camera_front = camera_front / camera_front.norm(dim=-1, keepdim=True)
253
+
254
+ camera_right = torch.cross(camera_front, world_up, dim=-1)
255
+ camera_right = camera_right / camera_right.norm(dim=-1, keepdim=True)
256
+
257
+ camera_down = torch.cross(camera_front, camera_right, dim=-1)
258
+ rotation_matrix = torch.stack([camera_right, camera_down, camera_front], dim=-1)
259
+
260
+ matrix = eyes(dim=4, shape=position.shape[:-1], device=device)
261
+ if inverse:
262
+ matrix[..., :3, :3] = rotation_matrix.transpose(-1, -2)
263
+ matrix[..., :3, 3:4] = -rotation_matrix.transpose(-1, -2) @ position[..., None]
264
+ else:
265
+ matrix[..., :3, :3] = rotation_matrix
266
+ matrix[..., :3, 3] = position
267
+
268
+ return matrix
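# Sanity-check sketch (not part of the module): with inverse=True the matrix
# maps world points into the camera frame, so the look-at target should land
# on the camera's +Z axis. world_up follows the y-down convention used below.
import torch

position = torch.tensor([0.0, 0.0, -2.0])
target = torch.zeros(3)
world_up = torch.tensor([0.0, -1.0, 0.0])
world_to_cam = create_camera_matrix(position, target, world_up, inverse=True)
point_cam = world_to_cam[:3, :3] @ target + world_to_cam[:3, 3]
assert torch.allclose(point_cam, torch.tensor([0.0, 0.0, 2.0]), atol=1e-6)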
269
+
270
+
271
+ class PinholeCameraModel:
272
+ """Camera model that focuses on point."""
273
+
274
+ def __init__(
275
+ self,
276
+ scene: Gaussians3D,
277
+ screen_extrinsics: torch.Tensor,
278
+ screen_intrinsics: torch.Tensor,
279
+ screen_resolution_px: tuple[int, int],
280
+ focus_depth_quantile: float = 0.1,
281
+ min_depth_focus: float = 2.0,
282
+ lookat_point: tuple[float, float, float] | None = None,
283
+ lookat_mode: LookAtMode = "point",
284
+ ) -> None:
285
+ """Initialize GeneralPinholeCameraModel.
286
+
287
+ Args:
288
+ scene: The scene to display.
289
+ screen_extrinsics: Extrinsics of the default position.
290
+ screen_intrinsics: Intrinsics to use for rendering.
291
+ screen_resolution_px: Width and height to render.
292
+ focus_depth_quantile: Where inside the depth range to focus on.
293
+ min_depth_focus: Minimum depth to focus at.
294
+ lookat_point: A point towards which the camera's Z axis is directed.
295
+ lookat_mode: "point" to look at a fixed point,
296
+ "ahead" to look straight ahead.
297
+ """
298
+ self.scene = scene
299
+ self.screen_extrinsics = screen_extrinsics
300
+ self.screen_intrinsics = screen_intrinsics
301
+ self.screen_resolution_px = screen_resolution_px
302
+
303
+ self.focus_depth_quantile = focus_depth_quantile
304
+ self.min_depth_focus = min_depth_focus
305
+ self.lookat_point = lookat_point
306
+ self.lookat_mode = lookat_mode
307
+
308
+ scene_points = scene.mean_vectors
309
+ if scene_points.ndim == 3:
310
+ scene_points = scene_points[0]
311
+ elif scene_points.ndim != 2:
312
+ raise ValueError("Unsupported dimensionality of scene points.")
313
+ self._scene_points = scene_points.cpu()
314
+
315
+ self.depth_quantiles = _compute_depth_quantiles(
316
+ self._scene_points,
317
+ self.screen_extrinsics,
318
+ q_focus=self.focus_depth_quantile,
319
+ )
320
+
321
+ def compute(self, eye_pos: torch.Tensor) -> CameraInfo:
322
+ """Compute camera for eye position."""
323
+ extrinsics = self.screen_extrinsics.clone()
324
+
325
+ origin = eye_pos if self.lookat_mode == "ahead" else torch.zeros(3)
326
+
327
+ if self.lookat_point is None:
328
+ depth_focus = max(self.min_depth_focus, self.depth_quantiles.focus)
329
+ look_at_position = origin + torch.tensor([0.0, 0.0, depth_focus])
330
+ else:
331
+ look_at_position = origin + torch.tensor([*self.lookat_point])
332
+
333
+ world_up = torch.tensor([0.0, -1.0, 0.0])
334
+ extrinsics_modifier = create_camera_matrix(
335
+ eye_pos, look_at_position, world_up, inverse=True
336
+ )
337
+ extrinsics = extrinsics_modifier @ self.screen_extrinsics
338
+
339
+ camera_info = CameraInfo(
340
+ intrinsics=self.screen_intrinsics,
341
+ extrinsics=extrinsics,
342
+ width=self.screen_resolution_px[0],
343
+ height=self.screen_resolution_px[1],
344
+ )
345
+ return camera_info
346
+
347
+ def set_screen_extrinsics(self, new_value: torch.Tensor) -> None:
348
+ """Modify the default extrinsics."""
349
+ self.screen_extrinsics = new_value
350
+ self.depth_quantiles = _compute_depth_quantiles(self._scene_points, self.screen_extrinsics)
351
+
352
+
353
+ def get_screen_resolution_px_from_input(width: int, height: int) -> tuple[int, int]:
354
+ """Get resolution for metadata dictionary."""
355
+ resolution_px = (width, height)
356
+ # Halve the dimensions for very large images.
357
+ if resolution_px[1] > 3000:
358
+ resolution_px = (resolution_px[0] // 2, resolution_px[1] // 2)
359
+ # For MP4 compatibility, force the dimensions to even numbers;
360
+ # otherwise the video cannot be played in a browser.
361
+ if resolution_px[0] % 2 != 0:
362
+ resolution_px = (resolution_px[0] + 1, resolution_px[1])
363
+ if resolution_px[1] % 2 != 0:
364
+ resolution_px = (resolution_px[0], resolution_px[1] + 1)
365
+ return resolution_px
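# Worked example (illustrative): a 1921 x 3001 portrait photo exceeds the
# 3000 px height limit, so both dimensions are halved to (960, 1500); both
# values are already even, so the parity fix-up leaves them unchanged.
assert get_screen_resolution_px_from_input(width=1921, height=3001) == (960, 1500)
# An odd width is bumped up by one pixel for MP4 compatibility.
assert get_screen_resolution_px_from_input(width=641, height=480) == (642, 480)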
366
+
367
+
368
+ def _compute_depth_quantiles(
369
+ points: torch.Tensor,
370
+ extrinsics: torch.Tensor,
371
+ q_near: float = 0.001,
372
+ q_focus: float = 0.1,
373
+ q_far: float = 0.999,
374
+ ) -> FocusRange:
375
+ """Compute disparity quantiles for scene and extrinsics id."""
376
+ points_local = points @ extrinsics[:3, :3].T + extrinsics[:3, 3]
377
+ depth_values = points_local[..., 2].flatten()
378
+ depth_values = depth_values[depth_values > 0]
379
+ q_values = torch.tensor([q_near, q_focus, q_far])
380
+ depth_quantiles_pt = torch.quantile(depth_values.cpu(), q_values)
381
+ depth_quantiles = FocusRange(
382
+ min=float(depth_quantiles_pt[0]),
383
+ focus=float(depth_quantiles_pt[1]),
384
+ max=float(depth_quantiles_pt[2]),
385
+ )
386
+ return depth_quantiles
src/sharp/utils/color_space.py ADDED
@@ -0,0 +1,88 @@
1
+ """Contains color space utility functions.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from typing import Literal
11
+
12
+ import torch
13
+
14
+ from sharp.utils.robust import robust_where
15
+
16
+ LOGGER = logging.getLogger(__name__)
17
+
18
+ ColorSpace = Literal["sRGB", "linearRGB"]
19
+
20
+
21
+ def encode_color_space(color_space: ColorSpace) -> int:
22
+ """Encode color space to integer."""
23
+ return 0 if color_space == "sRGB" else 1
24
+
25
+
26
+ def decode_color_space(color_space_index: int) -> ColorSpace:
27
+ """Decode color space index to color space."""
28
+ return "sRGB" if color_space_index == 0 else "linearRGB"
29
+
30
+
31
+ def sRGB2linearRGB(sRGB: torch.Tensor) -> torch.Tensor:
32
+ """SRGB to linearRGB conversion function.
33
+
34
+ Reference:
35
+ https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
36
+ Section 7.7.7
37
+
38
+ Args:
39
+ sRGB: Input image tensor in sRGB space.
40
+ """
41
+ # We need to use robust_where to clamp the second branch.
42
+ # Otherwise, torch.where will lead to NaN in the backward pass, see
43
+ # https://github.com/pytorch/pytorch/issues/68425
44
+ THRESHOLD = 0.04045
45
+
46
+ def branch_true_func(x):
47
+ return x / 12.92
48
+
49
+ def branch_false_func(x):
50
+ return ((x + 0.055) / 1.055) ** 2.4
51
+
52
+ return robust_where(
53
+ sRGB <= THRESHOLD,
54
+ sRGB,
55
+ branch_true_func,
56
+ branch_false_func,
57
+ branch_false_safe_value=THRESHOLD,
58
+ )
59
+
60
+
61
+ def linearRGB2sRGB(linearRGB: torch.Tensor) -> torch.Tensor:
62
+ """LinearRGB to sRGB conversion function.
63
+
64
+ Reference:
65
+ https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
66
+ Section 7.7.7
67
+
68
+ Args:
69
+ linearRGB: Input image tensor in linearRGB space.
70
+ """
71
+ # We need to use robust_where to clamp the second branch.
72
+ # Otherwise, torch.where will lead to NaN in the backward pass, see
73
+ # https://github.com/pytorch/pytorch/issues/68425
74
+ THRESHOLD = 0.0031308
75
+
76
+ def branch_true_func(x):
77
+ return x * 12.92
78
+
79
+ def branch_false_func(x):
80
+ return 1.055 * (x ** (1 / 2.4)) - 0.055
81
+
82
+ return robust_where(
83
+ linearRGB <= THRESHOLD,
84
+ linearRGB,
85
+ branch_true_func,
86
+ branch_false_func,
87
+ branch_false_safe_value=THRESHOLD,
88
+ )
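A round-trip sketch of the two conversions above; the behavior of `robust_where` at the branch point is assumed to match its documented purpose:

import torch

x = torch.linspace(0.0, 1.0, 5, requires_grad=True)
roundtrip = linearRGB2sRGB(sRGB2linearRGB(x))
assert torch.allclose(roundtrip, x, atol=1e-5)
# Gradients stay finite at x=0, which is the point of robust_where:
# a plain torch.where would propagate NaN/inf through the unused branch.
roundtrip.sum().backward()
assert torch.isfinite(x.grad).all()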
src/sharp/utils/gaussians.py ADDED
@@ -0,0 +1,480 @@
1
+ """Contains basic data structures and functionality for 3D Gaussians.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Any, Literal, NamedTuple
12
+
13
+ import numpy as np
14
+ import torch
15
+ from plyfile import PlyData, PlyElement
16
+
17
+ from sharp.utils import color_space as cs_utils
18
+ from sharp.utils import linalg
19
+
20
+ LOGGER = logging.getLogger(__name__)
21
+
22
+
23
+ BackgroundColor = Literal["black", "white", "random_color", "random_pixel"]
24
+
25
+
26
+ class Gaussians3D(NamedTuple):
27
+ """Represents a collection of 3D Gaussians."""
28
+
29
+ mean_vectors: torch.Tensor
30
+ singular_values: torch.Tensor
31
+ quaternions: torch.Tensor
32
+ colors: torch.Tensor
33
+ opacities: torch.Tensor
34
+
35
+ def to(self, device: torch.device) -> Gaussians3D:
36
+ """Move Gaussians to device."""
37
+ return Gaussians3D(
38
+ mean_vectors=self.mean_vectors.to(device),
39
+ singular_values=self.singular_values.to(device),
40
+ quaternions=self.quaternions.to(device),
41
+ colors=self.colors.to(device),
42
+ opacities=self.opacities.to(device),
43
+ )
44
+
45
+
46
+ class SceneMetaData(NamedTuple):
47
+ """Meta data about Gaussian scene."""
48
+
49
+ focal_length_px: float
50
+ resolution_px: tuple[int, int]
51
+ color_space: cs_utils.ColorSpace
52
+
53
+
54
+ def get_unprojection_matrix(
55
+ extrinsics: torch.Tensor,
56
+ intrinsics: torch.Tensor,
57
+ image_shape: tuple[int, int],
58
+ ) -> torch.Tensor:
59
+ """Compute unprojection matrix to transform Gaussians to Euclidean space.
60
+
61
+ Args:
62
+ extrinsics: The 4x4 extrinsics matrix of the camera view.
63
+ intrinsics: The 4x4 intrinsics matrix of the camera view.
64
+ image_shape: The (width, height) of the input image.
65
+
66
+ Returns:
67
+ A 4x4 matrix to transform Gaussians from NDC space to Euclidean space.
68
+ """
69
+ device = intrinsics.device
70
+ image_width, image_height = image_shape
71
+ # This matrix converts OpenCV pixel coordinates to NDC coordinates where
72
+ # (-1, -1) denotes the top left and (1, 1) the bottom right of the image.
73
+ #
74
+ # Note that premultiplying the intrinsics with ndc_matrix typically yields a matrix
75
+ # that simply scales the x-axis by 2 * focal_length / image_width and the y-axis by
76
+ # 2 * focal_length / image_height.
77
+ ndc_matrix = torch.tensor(
78
+ [
79
+ [2.0 / image_width, 0.0, -1.0, 0.0],
80
+ [0.0, 2.0 / image_height, -1.0, 0.0],
81
+ [0.0, 0.0, 1.0, 0.0],
82
+ [0.0, 0.0, 0.0, 1.0],
83
+ ],
84
+ device=device,
85
+ )
86
+ return torch.linalg.inv(ndc_matrix @ intrinsics @ extrinsics)
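# Verification sketch for the comment above (centered principal point assumed):
import torch

f_px, width, height = 500.0, 640, 480
intrinsics = torch.tensor(
    [
        [f_px, 0.0, width / 2, 0.0],
        [0.0, f_px, height / 2, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)
ndc_matrix = torch.tensor(
    [
        [2.0 / width, 0.0, -1.0, 0.0],
        [0.0, 2.0 / height, -1.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)
scaled = ndc_matrix @ intrinsics
# The product scales x by 2f/W and y by 2f/H, and the principal point cancels.
assert torch.allclose(scaled[0, 0], torch.tensor(2 * f_px / width))
assert torch.allclose(scaled[1, 1], torch.tensor(2 * f_px / height))
assert torch.allclose(scaled[:2, 2], torch.zeros(2), atol=1e-6)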
87
+
88
+
89
+ def unproject_gaussians(
90
+ gaussians_ndc: Gaussians3D,
91
+ extrinsics: torch.Tensor,
92
+ intrinsics: torch.Tensor,
93
+ image_shape: tuple[int, int],
94
+ ) -> Gaussians3D:
95
+ """Unproject Gaussians from NDC space to world coordinates."""
96
+ unprojection_matrix = get_unprojection_matrix(extrinsics, intrinsics, image_shape)
97
+ gaussians = apply_transform(gaussians_ndc, unprojection_matrix[:3])
98
+ return gaussians
99
+
100
+
101
+ def apply_transform(gaussians: Gaussians3D, transform: torch.Tensor) -> Gaussians3D:
102
+ """Apply an affine transformation to 3D Gaussians.
103
+
104
+ Args:
105
+ gaussians: The Gaussians to transform.
106
+ transform: An affine transform with shape 3x4.
107
+
108
+ Returns:
109
+ The transformed Gaussians.
110
+
111
+ Note: This operation is not differentiable.
112
+ """
113
+ transform_linear = transform[..., :3, :3]
114
+ transform_offset = transform[..., :3, 3]
115
+
116
+ mean_vectors = gaussians.mean_vectors @ transform_linear.T + transform_offset
117
+ covariance_matrices = compose_covariance_matrices(
118
+ gaussians.quaternions, gaussians.singular_values
119
+ )
120
+ covariance_matrices = (
121
+ transform_linear @ covariance_matrices @ transform_linear.transpose(-1, -2)
122
+ )
123
+ quaternions, singular_values = decompose_covariance_matrices(covariance_matrices)
124
+
125
+ return Gaussians3D(
126
+ mean_vectors=mean_vectors,
127
+ singular_values=singular_values,
128
+ quaternions=quaternions,
129
+ colors=gaussians.colors,
130
+ opacities=gaussians.opacities,
131
+ )
132
+
133
+
134
+ def decompose_covariance_matrices(
135
+ covariance_matrices: torch.Tensor,
136
+ ) -> tuple[torch.Tensor, torch.Tensor]:
137
+ """Decompose 3D covariance matrices into quaternions and singular values.
138
+
139
+ Args:
140
+ covariance_matrices: The covariance matrices to decompose.
141
+
142
+ Returns:
143
+ Quaternion and singular values corresponding to the orientation and scales of
144
+ the diagonalized matrix.
145
+
146
+ Note: This operation is not differentiable.
147
+ """
148
+ device = covariance_matrices.device
149
+ dtype = covariance_matrices.dtype
150
+
151
+ # We convert to fp64 to avoid numerical errors.
152
+ covariance_matrices = covariance_matrices.detach().cpu().to(torch.float64)
153
+ rotations, singular_values_2, _ = torch.linalg.svd(covariance_matrices)
154
+
155
+ # NOTE: in SVD, it is possible that U and VT are both reflections.
156
+ # We need to correct them.
157
+ batch_idx, gaussian_idx = torch.where(torch.linalg.det(rotations) < 0)
158
+ num_reflections = len(gaussian_idx)
159
+ if num_reflections > 0:
160
+ LOGGER.warning(
161
+ "Received %d reflection matrices from SVD. Flipping them to rotations.",
162
+ num_reflections,
163
+ )
164
+ # Flip the last column of reflection and make it a rotation.
165
+ rotations[batch_idx, gaussian_idx, :, -1] *= -1
166
+ quaternions = linalg.quaternions_from_rotation_matrices(rotations)
167
+ quaternions = quaternions.to(dtype=dtype, device=device)
168
+ singular_values = singular_values_2.sqrt().to(dtype=dtype, device=device)
169
+ return quaternions, singular_values
170
+
171
+
172
+ def compose_covariance_matrices(
173
+ quaternions: torch.Tensor, singular_values: torch.Tensor
174
+ ) -> torch.Tensor:
175
+ """Compose 3D covariance matrices into quaternions and singular values.
176
+
177
+ Args:
178
+ quaternions: The quaternions describing the principal basis.
179
+ singular_values: The scales of the diagonalized matrix.
180
+
181
+ Returns:
182
+ The 3x3 covariance matrices.
183
+ """
184
+ device = quaternions.device
185
+ rotations = linalg.rotation_matrices_from_quaternions(quaternions)
186
+ diagonal_matrix = torch.eye(3, device=device) * singular_values[..., :, None]
187
+ return rotations @ diagonal_matrix.square() @ rotations.transpose(-1, -2)
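# Round-trip sketch: composing and then decomposing recovers the singular
# values up to SVD's descending ordering (quaternions are only unique up to
# sign and axis permutation, so they are not compared directly).
import torch
import torch.nn.functional as F

quats = F.normalize(torch.randn(1, 8, 4), dim=-1)
scales = torch.rand(1, 8, 3) + 0.1
covars = compose_covariance_matrices(quats, scales)
_, scales_rec = decompose_covariance_matrices(covars)
expected = scales.sort(dim=-1, descending=True).values
assert torch.allclose(scales_rec, expected, atol=1e-4)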
188
+
189
+
190
+ def convert_spherical_harmonics_to_rgb(sh0: torch.Tensor) -> torch.Tensor:
191
+ """Convert degree-0 spherical harmonics to RGB.
192
+
193
+ Reference:
194
+ https://en.wikipedia.org/wiki/Table_of_spherical_harmonics
195
+ """
196
+ coeff_degree0 = np.sqrt(1.0 / (4.0 * np.pi))
197
+ return sh0 * coeff_degree0 + 0.5
198
+
199
+
200
+ def convert_rgb_to_spherical_harmonics(rgb: torch.Tensor) -> torch.Tensor:
201
+ """Convert RGB to degree-0 spherical harmonics.
202
+
203
+ Reference:
204
+ https://en.wikipedia.org/wiki/Table_of_spherical_harmonics
205
+ """
206
+ coeff_degree0 = np.sqrt(1.0 / (4.0 * np.pi))
207
+ return (rgb - 0.5) / coeff_degree0
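# Quick numeric check: the two conversions are exact inverses, and an all-zero
# sh0 corresponds to mid-gray (0.5) because Y_0^0 is a constant function.
import torch

rgb = torch.tensor([0.25, 0.5, 1.0])
sh0 = convert_rgb_to_spherical_harmonics(rgb)
assert torch.allclose(convert_spherical_harmonics_to_rgb(sh0), rgb)
assert torch.allclose(convert_spherical_harmonics_to_rgb(torch.zeros(3)), torch.full((3,), 0.5))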
208
+
209
+
210
+ def load_ply(path: Path) -> tuple[Gaussians3D, SceneMetaData]:
211
+ """Loads a ply from a file."""
212
+ plydata = PlyData.read(path)
213
+
214
+ vertices = next(filter(lambda x: x.name == "vertex", plydata.elements))
215
+
216
+ properties = ["x", "y", "z"]
217
+ properties.extend([f"f_dc_{i}" for i in range(3)])
218
+ properties.extend([f"scale_{i}" for i in range(3)])
219
+ properties.extend([f"rot_{i}" for i in range(3)])
220
+
221
+ for prop in properties:
222
+ if prop not in vertices:
223
+ raise KeyError(f"Incompatible ply file: property {prop} not found in ply elements.")
224
+ mean_vectors = np.stack(
225
+ (
226
+ np.asarray(vertices["x"]),
227
+ np.asarray(vertices["y"]),
228
+ np.asarray(vertices["z"]),
229
+ ),
230
+ axis=1,
231
+ )
232
+
233
+ scale_logits = np.stack(
234
+ (
235
+ np.asarray(vertices["scale_0"]),
236
+ np.asarray(vertices["scale_1"]),
237
+ np.asarray(vertices["scale_2"]),
238
+ ),
239
+ axis=1,
240
+ )
241
+
242
+ quaternions = np.stack(
243
+ (
244
+ np.asarray(vertices["rot_0"]),
245
+ np.asarray(vertices["rot_1"]),
246
+ np.asarray(vertices["rot_2"]),
247
+ np.asarray(vertices["rot_3"]),
248
+ ),
249
+ axis=1,
250
+ )
251
+
252
+ spherical_harmonics_deg0 = np.stack(
253
+ (
254
+ np.asarray(vertices["f_dc_0"]),
255
+ np.asarray(vertices["f_dc_1"]),
256
+ np.asarray(vertices["f_dc_2"]),
257
+ ),
258
+ axis=1,
259
+ )
260
+
261
+ colors = convert_spherical_harmonics_to_rgb(spherical_harmonics_deg0)
262
+
263
+ opacity_logits = np.asarray(vertices["opacity"])[..., None]
264
+
265
+ supplement_elements = [element for element in plydata.elements if element.name != "vertex"]
266
+ supplement_data: dict[str, Any] = {}
267
+ supplement_keys = ["extrinsic", "intrinsic", "color_space", "image_size"]
268
+
269
+ for element in supplement_elements:
270
+ for key in supplement_keys:
271
+ if key not in supplement_data and key in element:
272
+ supplement_data[key] = np.asarray(element[key])
273
+
274
+ # Parse intrinsics and image_size.
275
+ if "intrinsic" in supplement_data:
276
+ intrinsics_data = supplement_data["intrinsic"]
277
+
278
+ # Legacy: image_size is contained in intrinsic element.
279
+ if "image_size" not in supplement_data:
280
+ if len(intrinsics_data) != 4:
281
+ raise ValueError(
282
+ "Expect legacy intrinsics with len=4 containing image size, "
283
+ f"but received len={len(intrinsics_data)}"
284
+ )
285
+ focal_length_px = (intrinsics_data[0], intrinsics_data[1])
286
+ width = int(intrinsics_data[2])
287
+ height = int(intrinsics_data[3])
288
+
289
+ else:
290
+ if len(intrinsics_data) != 9:
291
+ raise ValueError(
292
+ "Expect 9 elements in intrinsics, " f"but received {len(intrinsics_data)}."
293
+ )
294
+ intrinsics_matrix = intrinsics_data.reshape((3, 3))
295
+ focal_length_px = (intrinsics_matrix[0, 0], intrinsics_matrix[1, 1])
296
+
297
+ image_size_data = supplement_data["image_size"]
298
+ width = image_size_data[0]
299
+ height = image_size_data[1]
300
+
301
+ # Default to VGA resolution: focal length = 512, image size = (640, 480).
302
+ else:
303
+ focal_length_px = (512, 512)
304
+ width = 640
305
+ height = 480
306
+
307
+ # Parse extrinsics.
308
+ extrinsics_data = supplement_data.get("extrinsic", np.eye(4).flatten())
309
+ extrinsics_matrix = np.eye(4)
310
+
311
+ # Legacy: extrinsics store 12 elements.
312
+ if len(extrinsics_data) == 12:
313
+ extrinsics_matrix[:3] = extrinsics_data.reshape((3, 4))
314
+ extrinsics_matrix[:3, :3] = extrinsics_matrix[:3, :3].copy().T
315
+ elif len(extrinsics_data) == 16:
316
+ extrinsics_matrix[:] = extrinsics_data.reshape((4, 4))
317
+ else:
318
+ raise ValueError(f"Unrecognized extrinsics matrix shape {len(extrinsics_data)}")
319
+
320
+ # Parse color space.
321
+ color_space_index = supplement_data.get("color_space", 1)
322
+ color_space = cs_utils.decode_color_space(color_space_index)
323
+ if color_space == "sRGB":
324
+ colors = cs_utils.sRGB2linearRGB(colors)
325
+
326
+ mean_vectors = torch.from_numpy(mean_vectors).view(1, -1, 3).float()
327
+ quaternions = torch.from_numpy(quaternions).view(1, -1, 4).float()
328
+ singular_values = torch.exp(torch.from_numpy(scale_logits).view(1, -1, 3)).float()
329
+ opacities = torch.sigmoid(torch.from_numpy(opacity_logits).view(1, -1)).float()
330
+ colors = torch.from_numpy(colors).view(1, -1, 3).float()
331
+
332
+ gaussians = Gaussians3D(
333
+ mean_vectors=mean_vectors,
334
+ quaternions=quaternions,
335
+ singular_values=singular_values,
336
+ opacities=opacities,
337
+ colors=colors,
338
+ )
339
+ metadata = SceneMetaData(focal_length_px[0], (width, height), color_space)
340
+ return gaussians, metadata
341
+
342
+
343
+ @torch.no_grad()
344
+ def save_ply(
345
+ gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
346
+ ) -> PlyData:
347
+ """Save a predicted Gaussian3D to a ply file."""
348
+
349
+ def _inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
350
+ return torch.log(tensor / (1.0 - tensor))
351
+
352
+ xyz = gaussians.mean_vectors.flatten(0, 1)
353
+ scale_logits = torch.log(gaussians.singular_values).flatten(0, 1)
354
+ quaternions = gaussians.quaternions.flatten(0, 1)
355
+
356
+ # SHARP takes an image, converts it to sRGB color space as input,
357
+ # and predicts linearRGB Gaussians as output.
358
+ # The SHARP renderer would blend linearRGB Gaussians and convert rendered images and videos
359
+ # back to sRGB for the best display quality.
360
+ #
361
+ # However, public renderers do not have such linear2sRGB conversions after rendering.
362
+ # If they render linearRGB Gaussians as-is, the output would be dark without Gamma correction.
363
+ #
364
+ # To make them compatible with public renderers, we force-convert linearRGB to sRGB during export.
365
+ # - The SHARP renderer will still handle conversions properly.
366
+ # - Public renderers will mostly work fine when treating the sRGB values as linearRGB,
367
+ # although for the best performance, it is recommended to apply the conversions.
368
+ colors = convert_rgb_to_spherical_harmonics(
369
+ cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1))
370
+ )
371
+ color_space_index = cs_utils.encode_color_space("sRGB")
372
+
373
+ # Store opacity logits.
374
+ opacity_logits = _inverse_sigmoid(gaussians.opacities).flatten(0, 1).unsqueeze(-1)
375
+
376
+ attributes = torch.cat(
377
+ (
378
+ xyz,
379
+ colors,
380
+ opacity_logits,
381
+ scale_logits,
382
+ quaternions,
383
+ ),
384
+ dim=1,
385
+ )
386
+
387
+ dtype_full = [
388
+ (attribute, "f4")
389
+ for attribute in ["x", "y", "z"]
390
+ + [f"f_dc_{i}" for i in range(3)]
391
+ + ["opacity"]
392
+ + [f"scale_{i}" for i in range(3)]
393
+ + [f"rot_{i}" for i in range(4)]
394
+ ]
395
+
396
+ num_gaussians = len(xyz)
397
+ elements = np.empty(num_gaussians, dtype=dtype_full)
398
+ elements[:] = list(map(tuple, attributes.detach().cpu().numpy()))
399
+ vertex_elements = PlyElement.describe(elements, "vertex")
400
+
401
+ # Load image-wise metadata.
402
+ image_height, image_width = image_shape
403
+
404
+ # Export image size.
405
+ dtype_image_size = [("image_size", "u4")]
406
+ image_size_array = np.empty(2, dtype=dtype_image_size)
407
+ image_size_array[:] = np.array([image_width, image_height])
408
+ image_size_element = PlyElement.describe(image_size_array, "image_size")
409
+
410
+ # Export intrinsics.
411
+ dtype_intrinsic = [("intrinsic", "f4")]
412
+ intrinsic_array = np.empty(9, dtype=dtype_intrinsic)
413
+ intrinsic = np.array(
414
+ [
415
+ f_px,
416
+ 0,
417
+ image_width * 0.5,
418
+ 0,
419
+ f_px,
420
+ image_height * 0.5,
421
+ 0,
422
+ 0,
423
+ 1,
424
+ ]
425
+ )
426
+ intrinsic_array[:] = intrinsic.flatten()
427
+ intrinsic_element = PlyElement.describe(intrinsic_array, "intrinsic")
428
+
429
+ # Export dummy extrinsics.
430
+ dtype_extrinsic = [("extrinsic", "f4")]
431
+ extrinsic_array = np.empty(16, dtype=dtype_extrinsic)
432
+ extrinsic_array[:] = np.eye(4).flatten()
433
+ extrinsic_element = PlyElement.describe(extrinsic_array, "extrinsic")
434
+
435
+ # Export number of frames and particles per frame.
436
+ dtype_frames = [("frame", "i4")]
437
+ frame_array = np.empty(2, dtype=dtype_frames)
438
+ frame_array[:] = np.array([1, num_gaussians], dtype=np.int32)
439
+ frame_element = PlyElement.describe(frame_array, "frame")
440
+
441
+ # Export disparity ranges for transform.
442
+ dtype_disparity = [("disparity", "f4")]
443
+ disparity_array = np.empty(2, dtype=dtype_disparity)
444
+
445
+ disparity = 1.0 / gaussians.mean_vectors[0, ..., -1]
446
+ quantiles = (
447
+ torch.quantile(disparity, q=torch.tensor([0.1, 0.9], device=disparity.device))
448
+ .float()
449
+ .cpu()
450
+ .numpy()
451
+ )
452
+ disparity_array[:] = quantiles
453
+ disparity_element = PlyElement.describe(disparity_array, "disparity")
454
+
455
+ # Export colorspace.
456
+ dtype_color_space = [("color_space", "u1")]
457
+ color_space_array = np.empty(1, dtype=dtype_color_space)
458
+ color_space_array[:] = np.array([color_space_index]).flatten()
459
+ color_space_element = PlyElement.describe(color_space_array, "color_space")
460
+
461
+ dtype_version = [("version", "u1")]
462
+ version_array = np.empty(3, dtype=dtype_version)
463
+ version_array[:] = np.array([1, 5, 0], dtype=np.uint8).flatten()
464
+ version_element = PlyElement.describe(version_array, "version")
465
+
466
+ plydata = PlyData(
467
+ [
468
+ vertex_elements,
469
+ extrinsic_element,
470
+ intrinsic_element,
471
+ image_size_element,
472
+ frame_element,
473
+ disparity_element,
474
+ color_space_element,
475
+ version_element,
476
+ ]
477
+ )
478
+
479
+ plydata.write(path)
480
+ return plydata
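A hypothetical round trip through the two functions above; `gaussians` is assumed to be a predicted `Gaussians3D`. Note that `save_ply` takes `image_shape` as (height, width), while the loader reports `resolution_px` as (width, height):

from pathlib import Path

path = Path("/tmp/scene.ply")
save_ply(gaussians, f_px=500.0, image_shape=(480, 640), path=path)
loaded, metadata = load_ply(path)
print(metadata.focal_length_px, metadata.resolution_px, metadata.color_space)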
src/sharp/utils/gsplat.py ADDED
@@ -0,0 +1,191 @@
1
+ """Contains utility code for gsplat renderer.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from pathlib import Path
10
+ from typing import NamedTuple
11
+
12
+ import gsplat
13
+ import torch
14
+ from torch import nn
15
+
16
+ from sharp.utils import color_space as cs_utils
17
+ from sharp.utils import io, vis
18
+ from sharp.utils.gaussians import BackgroundColor, Gaussians3D
19
+
20
+
21
+ class RenderingOutputs(NamedTuple):
22
+ """Outputs of 3D Gaussians renderer."""
23
+
24
+ color: torch.Tensor
25
+ depth: torch.Tensor
26
+ alpha: torch.Tensor
27
+
28
+
29
+ def write_renderings(rendering: RenderingOutputs, output_folder: Path, filename: str):
30
+ """Write rendered color/depth/alpha to files."""
31
+ batch_size = len(rendering.color)
32
+ if batch_size != 1:
33
+ raise RuntimeError("We only support saving rendering of batch size = 1")
34
+
35
+ def _save_image_tensor(tensor: torch.Tensor, suffix: str):
36
+ np_array = tensor.permute(1, 2, 0).numpy()
37
+ io.save_image(np_array, (output_folder / filename).with_suffix(suffix))
38
+
39
+ color = (rendering.color[0].cpu() * 255.0).to(dtype=torch.uint8)
40
+ colorized_depth = vis.colorize_depth(rendering.depth[0], val_max=100.0)
41
+ colorized_alpha = vis.colorize_alpha(rendering.alpha[0])
42
+
43
+ _save_image_tensor(color, ".color.png")
44
+ _save_image_tensor(colorized_depth, ".depth.png")
45
+ _save_image_tensor(colorized_alpha, ".alpha.png")
46
+
47
+
48
+ class GSplatRenderer(nn.Module):
49
+ """Module to render 3D Gaussians to images using gsplat."""
50
+
51
+ color_space: cs_utils.ColorSpace
52
+ background_color: BackgroundColor
53
+
54
+ def __init__(
55
+ self,
56
+ color_space: cs_utils.ColorSpace = "sRGB",
57
+ background_color: BackgroundColor = "black",
58
+ low_pass_filter_eps: float = 0.0,
59
+ ) -> None:
60
+ """Initialize gsplat renderer.
61
+
62
+ Args:
63
+ color_space: The color space to use for rendering.
64
+ background_color: The background color to use for rendering.
65
+ low_pass_filter_eps: The epsilon value for the low pass filter.
66
+ """
67
+ super().__init__()
68
+ self.color_space = color_space
69
+ self.background_color = background_color
70
+ self.low_pass_filter_eps = low_pass_filter_eps
71
+
72
+ def forward(
73
+ self,
74
+ gaussians: Gaussians3D,
75
+ extrinsics: torch.Tensor,
76
+ intrinsics: torch.Tensor,
77
+ image_width: int,
78
+ image_height: int,
79
+ ) -> RenderingOutputs:
80
+ """Predict images from gaussians.
81
+
82
+ Args:
83
+ gaussians: The Gaussians to render.
84
+ extrinsics: The extrinsics of the camera to render to in OpenCV format.
85
+ intrinsics: The intrinsics of the camera to render to in OpenCV format.
86
+ image_width: The desired output image width.
87
+ image_height: The desired output image height.
88
+ """
89
+ batch_size = len(gaussians.mean_vectors)
90
+ outputs_list: list[RenderingOutputs] = []
91
+
92
+ for ib in range(batch_size):
93
+ colors, alphas, meta = gsplat.rendering.rasterization(
94
+ means=gaussians.mean_vectors[ib],
95
+ quats=gaussians.quaternions[ib],
96
+ scales=gaussians.singular_values[ib],
97
+ opacities=gaussians.opacities[ib],
98
+ colors=gaussians.colors[ib],
99
+ viewmats=extrinsics[ib : ib + 1],
100
+ Ks=intrinsics[ib : ib + 1, :3, :3],
101
+ width=image_width,
102
+ height=image_height,
103
+ render_mode="RGB+D",
104
+ rasterize_mode="classic",
105
+ absgrad=False,
106
+ packed=False,
107
+ eps2d=self.low_pass_filter_eps,
108
+ )
109
+
110
+ rendered_color = colors[..., 0:3].permute([0, 3, 1, 2])
111
+ rendered_depth_unnormalized = colors[..., 3:4].permute([0, 3, 1, 2])
112
+ rendered_alpha = alphas.permute([0, 3, 1, 2])
113
+
114
+ # Compose with background color.
115
+ rendered_color = self.compose_with_background(
116
+ rendered_color, rendered_alpha, self.background_color
117
+ )
118
+
119
+ # Colorspace conversion.
120
+ if self.color_space == "sRGB":
121
+ pass
122
+ elif self.color_space == "linearRGB":
123
+ rendered_color = cs_utils.linearRGB2sRGB(rendered_color)
124
+ else:
125
+ ValueError("Unsupported ColorSpace type.")
126
+
127
+ # splats: (B, N, 10)
128
+ cov2d = self._conics_to_covars2d(meta["conics"])
129
+ # Set the cov2d of invisible splats to identity to avoid NaNs in the
+ # condition number calculation. Chained boolean indexing writes into a
+ # copy, so index the matrix entries directly.
130
+ splats_visible_mask = meta["depths"] > 1e-2
131
+ cov2d[~splats_visible_mask, 0, 0] = 1
132
+ cov2d[~splats_visible_mask, 1, 1] = 1
133
+ cov2d[~splats_visible_mask, 0, 1] = 0
134
+
135
+ # Normalize the depth by alpha.
136
+ rendered_depth = rendered_depth_unnormalized / torch.clip(rendered_alpha, min=1e-8)
137
+
138
+ outputs = RenderingOutputs(
139
+ color=rendered_color,
140
+ depth=rendered_depth,
141
+ alpha=rendered_alpha,
142
+ )
143
+ outputs_list.append(outputs)
144
+
145
+ return RenderingOutputs(
146
+ color=torch.cat([item.color for item in outputs_list], dim=0).contiguous(),
147
+ depth=torch.cat([item.depth for item in outputs_list], dim=0).contiguous(),
148
+ alpha=torch.cat([item.alpha for item in outputs_list], dim=0).contiguous(),
149
+ )
150
+
151
+ @staticmethod
152
+ def compose_with_background(
153
+ rendered_rgb: torch.Tensor,
154
+ rendered_alpha: torch.Tensor,
155
+ background_color: BackgroundColor,
156
+ ) -> torch.Tensor:
157
+ """Compose rendered RGB with background color."""
158
+ if background_color == "black":
159
+ return rendered_rgb
160
+ elif background_color == "white":
161
+ return rendered_rgb + (1.0 - rendered_alpha)
162
+ elif background_color == "random_color":
163
+ return (
164
+ rendered_rgb
165
+ + (1.0 - rendered_alpha)
166
+ * torch.rand(3, dtype=rendered_rgb.dtype, device=rendered_rgb.device)[
167
+ None, :, None, None
168
+ ]
169
+ )
170
+ elif background_color == "random_pixel":
171
+ return rendered_rgb + (1.0 - rendered_alpha) * torch.rand_like(rendered_rgb)
172
+ else:
173
+ raise ValueError("Unsupported BackgroundColor type.")
174
+
175
+ @staticmethod
176
+ def _conics_to_covars2d(conics: torch.Tensor, eps=1e-8) -> torch.Tensor:
177
+ """Convert conics to covariance matrices."""
178
+ a = conics[..., 0]
179
+ b = conics[..., 1]
180
+ c = conics[..., 2]
181
+ # Reconstruct determinant.
182
+ det = 1 / (a * c - b**2 + eps)
183
+ det = det.clamp(min=eps)
184
+ # Reconstruct covars2d.
185
+ covars2d = torch.zeros(*conics.shape[:-1], 2, 2, device=conics.device)
186
+ covars2d[..., 1, 1] = a * det
187
+ covars2d[..., 0, 0] = c * det
188
+ covars2d[..., 0, 1] = -b * det
189
+ covars2d[..., 1, 0] = -b * det
190
+ covars2d = torch.nan_to_num(covars2d, nan=0.0, posinf=0.0, neginf=0.0)
191
+ return covars2d
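A small verification sketch for `_conics_to_covars2d`: a conic is the inverse of a 2D covariance, so packing the upper triangle of an inverted covariance and converting back should recover the original matrix (the values here are arbitrary):

import torch

covar = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
conic = torch.linalg.inv(covar)
conics = torch.stack([conic[0, 0], conic[0, 1], conic[1, 1]])[None]  # shape (1, 3)
recovered = GSplatRenderer._conics_to_covars2d(conics)[0]
assert torch.allclose(recovered, covar, atol=1e-5)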
src/sharp/utils/io.py ADDED
@@ -0,0 +1,213 @@
1
+ """Contains image IO.
2
+
3
+ For licensing see accompanying LICENSE file.
4
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import io
10
+ import logging
11
+ from pathlib import Path
12
+ from typing import IO, Any, Protocol
13
+
14
+ import imageio.v2 as iio
15
+ import numpy as np
16
+ import pillow_heif
17
+ import torch
18
+ from PIL import ExifTags, Image, TiffTags
19
+
20
+ from .vis import METRIC_DEPTH_MAX_CLAMP_METER, colorize_depth
21
+
22
+ LOGGER = logging.getLogger(__name__)
23
+
24
+
25
+ # NOTE: unused, kept for reference.
26
+ Image.MAX_IMAGE_PIXELS = 200000000
27
+
28
+
29
+ def load_rgb(
30
+ path: Path, auto_rotate: bool = True, remove_alpha: bool = True
31
+ ) -> tuple[np.ndarray, list[bytes] | None, float]:
32
+ """Load an RGB image."""
33
+ LOGGER.debug(f"Loading image {path} ...")
34
+
35
+ if path.suffix.lower() in [".heic"]:
36
+ heif_file = pillow_heif.open_heif(path, convert_hdr_to_8bit=True)
37
+ img_pil = heif_file.to_pillow()
38
+ else:
39
+ img_pil = Image.open(path)
40
+
41
+ img_exif = extract_exif(img_pil)
42
+ icc_profile = img_pil.info.get("icc_profile", None)
43
+
44
+ # Rotate the image.
45
+ if auto_rotate:
46
+ exif_orientation = img_exif.get("Orientation", 1)
47
+ if exif_orientation == 3:
48
+ img_pil = img_pil.transpose(Image.ROTATE_180)
49
+ elif exif_orientation == 6:
50
+ img_pil = img_pil.transpose(Image.ROTATE_270)
51
+ elif exif_orientation == 8:
52
+ img_pil = img_pil.transpose(Image.ROTATE_90)
53
+ elif exif_orientation != 1:
54
+ LOGGER.warning(f"Ignoring image orientation {exif_orientation}.")
55
+
56
+ # Extract the focal length.
57
+ f_35mm = img_exif.get("FocalLengthIn35mmFilm", img_exif.get("FocalLenIn35mmFilm", None))
58
+ if f_35mm is None or f_35mm < 1:
59
+ f_35mm = img_exif.get("FocalLength", None)
60
+ if f_35mm is None:
61
+ LOGGER.warn(f"Did not find focallength in exif data of {path} - Setting to 30mm.")
62
+ f_35mm = 30.0
63
+ if f_35mm < 10.0:
64
+ LOGGER.info("Found focal length below 10mm, assuming it's not for 35mm.")
65
+ # This is a very crude approximation.
66
+ f_35mm *= 8.4
67
+
68
+ img = np.asarray(img_pil)
69
+ # Convert to RGB if single channel.
70
+ if img.ndim < 3 or img.shape[2] == 1:
71
+ img = np.dstack((img, img, img))
72
+
73
+ if remove_alpha:
74
+ img = img[:, :, :3]
75
+
76
+ LOGGER.debug(f"\tHxW: {img.shape[0]}x{img.shape[1]}")
77
+ LOGGER.debug(f"\tfocal length @ 35mm film: {f_35mm}mm")
78
+ f_px = convert_focallength(img.shape[1], img.shape[0], f_35mm)
79
+ LOGGER.debug(f"\tfocal length: {f_px:.2f}px")
80
+
81
+ return img, icc_profile, f_px
82
+
83
+
84
+ def extract_exif(img_pil: Image.Image) -> dict[str, Any]:
85
+ """Return exif information as a dictionary."""
86
+ # Get full exif description from get_ifd(0x8769):
87
+ # cf https://pillow.readthedocs.io/en/stable/releasenotes/8.2.0.html#image-getexif-exif-and-gps-ifd # noqa
88
+ img_exif = img_pil.getexif().get_ifd(0x8769)
89
+ exif_dict = {ExifTags.TAGS[k]: v for k, v in img_exif.items() if k in ExifTags.TAGS}
90
+
91
+ # https://pillow.readthedocs.io/en/stable/_modules/PIL/TiffTags.html# # noqa
92
+ tiff_tags = img_pil.getexif()
93
+ tiff_dict = {TiffTags.TAGS_V2[k].name: v for k, v in tiff_tags.items() if k in TiffTags.TAGS_V2}
94
+ return {**exif_dict, **tiff_dict}
95
+
96
+
97
+ def convert_focallength(width: float, height: float, f_mm: float = 30) -> float:
98
+ """Converts a focal length given in mm to pixels."""
99
+ return f_mm * np.sqrt(width**2.0 + height**2.0) / np.sqrt(36**2 + 24**2)
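# Worked example (illustrative): a 4032 x 3024 photo at a 30 mm equivalent
# focal length. The image diagonal is sqrt(4032**2 + 3024**2) = 5040 px and
# the full-frame diagonal is sqrt(36**2 + 24**2) ~= 43.27 mm, so
# f_px = 30 * 5040 / 43.27 ~= 3494.6 px.
assert abs(convert_focallength(4032, 3024, f_mm=30) - 3494.6) < 1.0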
100
+
101
+
102
+ def save_image(
103
+ image: np.ndarray,
104
+ output_path: Path,
105
+ icc_profile: list[bytes] | None = None,
106
+ jpeg_quality: int = 92,
107
+ ) -> None:
108
+ """Save image to given path."""
109
+ output_path.parent.mkdir(parents=True, exist_ok=True)
110
+
111
+ extensions_to_format = Image.registered_extensions()
112
+ try:
113
+ format = extensions_to_format[output_path.suffix.lower()]
114
+ except KeyError:
115
+ raise ValueError(f"Unsupported output format {output_path.suffix}.")
116
+
117
+ with output_path.open("wb") as file_handle:
118
+ write_image(
119
+ image,
120
+ file_handle,
121
+ format,
122
+ icc_profile=icc_profile,
123
+ jpeg_quality=jpeg_quality,
124
+ )
125
+
126
+
127
+ def write_image(
128
+ image: np.ndarray,
129
+ output_io: IO[bytes],
130
+ format="jpg",
131
+ icc_profile: list[bytes] | None = None,
132
+ jpeg_quality: int = 92,
133
+ ):
134
+ """Write image to binary stream."""
135
+ pil_config = {}
136
+ if format == "JPEG":
137
+ pil_config["quality"] = jpeg_quality
138
+
139
+ image_pil = Image.fromarray(image)
140
+
141
+ # Workaround to error [io.UnsupportedOperation: seek].
142
+ if format == "TIFF":
143
+ bytes_io = io.BytesIO()
144
+ image_pil.save(bytes_io, format="TIFF")
145
+ bytes_io.seek(0)
146
+ output_io.write(bytes_io.read())
147
+ return
148
+
149
+ image_pil.save(output_io, format, icc_profile=icc_profile, **pil_config)
150
+
151
+
152
+ def get_supported_image_extensions(with_heic: bool = True) -> list[str]:
153
+ """Return supported image extensions."""
154
+ exts = Image.registered_extensions()
155
+ supported_extensions = {ex for ex, f in exts.items() if f in Image.OPEN}
156
+ if with_heic:
157
+ supported_extensions.add(".heic")
158
+
159
+ supported_extensions_upper = {ex.upper() for ex in supported_extensions}
160
+ return list(supported_extensions | supported_extensions_upper)
161
+
162
+
163
+ def get_supported_video_extensions():
164
+ """Return supported video extensions."""
165
+ supported_extensions = {".mp4", ".mov"}
166
+ supported_extensions_upper = {ext.upper() for ext in supported_extensions}
167
+ return list(supported_extensions | supported_extensions_upper)
168
+
169
+
170
+ class OutputWriter(Protocol):
171
+ """Protocol for writing output to disk."""
172
+
173
+ def add_frame(self, image: torch.Tensor, depth: torch.Tensor) -> None:
174
+ """Add a single frame to output."""
175
+ ...
176
+
177
+ def close(self) -> None:
178
+ """Finish writing."""
179
+ ...
180
+
181
+
182
+ class VideoWriter(OutputWriter):
183
+ """Output writer for video output."""
184
+
185
+ def __init__(self, output_path: Path, fps: float = 30.0, render_depth: bool = True) -> None:
186
+ """Initialize VideoWriter."""
187
+ output_path.parent.mkdir(exist_ok=True, parents=True)
188
+ self.output_path = output_path
189
+ self.image_writer = iio.get_writer(output_path, fps=fps)
190
+
191
+ self.max_depth_estimate = None
+ # Default to None so add_frame and close can test for the depth writer.
+ self.depth_writer = None
192
+ if render_depth:
193
+ self.depth_writer = iio.get_writer(output_path.with_suffix(".depth.mp4"), fps=fps)
194
+
195
+ def add_frame(self, image: torch.Tensor, depth: torch.Tensor) -> None:
196
+ """Add a single frame to output."""
197
+ image_np = image.detach().cpu().numpy()
198
+ self.image_writer.append_data(image_np)
199
+
200
+ if self.depth_writer is not None:
201
+ if self.max_depth_estimate is None:
202
+ self.max_depth_estimate = depth.max().item()
203
+
204
+ colored_depth_pt = colorize_depth(
205
+ depth,
206
+ min(self.max_depth_estimate, METRIC_DEPTH_MAX_CLAMP_METER), # type: ignore[call-overload]
207
+ )
208
+ colored_depth_np = colored_depth_pt.squeeze(0).permute(1, 2, 0).cpu().numpy()
209
+ self.depth_writer.append_data(colored_depth_np)
210
+
211
+ def close(self):
212
+ """Finish writing."""
213
+ self.image_writer.close()
+ if self.depth_writer is not None:
+ self.depth_writer.close()
src/sharp/utils/linalg.py ADDED
@@ -0,0 +1,104 @@
+ """Contains linear algebra related utility functions.
+
+ For licensing see accompanying LICENSE file.
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
+ """
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn.functional as F
+ from scipy.spatial.transform import Rotation
+
+
+ def rotation_matrices_from_quaternions(quaternions: torch.Tensor) -> torch.Tensor:
+     """Convert a batch of quaternions into rotation matrices.
+
+     Args:
+         quaternions: The quaternions to convert to matrices, in scalar-first (w, x, y, z) order.
+
+     Returns:
+         The rotation matrices corresponding to the (normalized) quaternions.
+     """
+     device = quaternions.device
+     shape = quaternions.shape[:-1]
+
+     quaternions = quaternions / torch.linalg.norm(quaternions, dim=-1, keepdim=True)
+     real_part = quaternions[..., 0]
+     vector_part = quaternions[..., 1:]
+
+     vector_cross = get_cross_product_matrix(vector_part)
+     real_part = real_part[..., None, None]
+
+     matrix_outer = vector_part[..., :, None] * vector_part[..., None, :]
+     matrix_diag = real_part.square() * eyes(3, shape=shape, device=device)
+     matrix_cross_1 = 2 * real_part * vector_cross
+     matrix_cross_2 = vector_cross @ vector_cross
+
+     return matrix_outer + matrix_diag + matrix_cross_1 + matrix_cross_2
+
+
+ def quaternions_from_rotation_matrices(matrices: torch.Tensor) -> torch.Tensor:
+     """Convert a batch of rotation matrices to quaternions.
+
+     Args:
+         matrices: The matrices to convert to quaternions.
+
+     Returns:
+         The quaternions corresponding to the rotation matrices.
+
+     Note: this operation is not differentiable and will be performed on the CPU.
+     """
+     if not matrices.shape[-2:] == (3, 3):
+         raise ValueError(f"matrices have invalid shape {matrices.shape}")
+     matrices_np = matrices.detach().cpu().numpy()
+     quaternions_np = Rotation.from_matrix(matrices_np.reshape(-1, 3, 3)).as_quat()
+     # We use a convention where the w component is at the start of the quaternion.
+     quaternions_np = quaternions_np[:, [3, 0, 1, 2]]
+     quaternions_np = quaternions_np.reshape(matrices_np.shape[:-2] + (4,))
+     return torch.as_tensor(quaternions_np, device=matrices.device, dtype=matrices.dtype)
+
+
+ def get_cross_product_matrix(vectors: torch.Tensor) -> torch.Tensor:
+     """Generate the skew-symmetric cross-product matrix for each vector."""
+     if not vectors.shape[-1] == 3:
+         raise ValueError("Only 3-dimensional vectors are supported")
+     device = vectors.device
+     shape = vectors.shape[:-1]
+     unit_basis = eyes(3, shape=shape, device=device)
+     # Cross the vector with each column of the identity: column j equals
+     # v x e_j, which yields the cross-product matrix [v]x.
+     return torch.cross(vectors[..., :, None], unit_basis, dim=-2)
+
+
+ def eyes(
+     dim: int, shape: tuple[int, ...], device: torch.device | str | None = None
+ ) -> torch.Tensor:
+     """Create a batch of identity matrices."""
+     return torch.eye(dim, device=device).broadcast_to(shape + (dim, dim)).clone()
+
+
+ def quaternion_product(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
+     """Compute the Hamilton product of two quaternions."""
+     real_1 = q1[..., :1]
+     real_2 = q2[..., :1]
+     vector_1 = q1[..., 1:]
+     vector_2 = q2[..., 1:]
+
+     real_out = real_1 * real_2 - (vector_1 * vector_2).sum(dim=-1, keepdim=True)
+     vector_out = real_1 * vector_2 + real_2 * vector_1 + torch.cross(vector_1, vector_2, dim=-1)
+     return torch.concatenate([real_out, vector_out], dim=-1)
+
+
+ def quaternion_conj(q: torch.Tensor) -> torch.Tensor:
+     """Get the conjugate of a quaternion."""
+     real = q[..., :1]
+     vector = q[..., 1:]
+     return torch.concatenate([real, -vector], dim=-1)
+
+
+ def project(u: torch.Tensor, basis: torch.Tensor) -> torch.Tensor:
+     """Project tensor u onto the given unit basis."""
+     unit_u = F.normalize(u, dim=-1)
+     inner_prod = (unit_u * basis).sum(dim=-1, keepdim=True)
+     return inner_prod * u
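The file keeps the scalar-first (w, x, y, z) quaternion convention throughout, so one way to sanity-check these helpers is to rotate a vector both via the rotation matrix and via the quaternion sandwich product q * (0, v) * conj(q). A minimal sketch, assuming the package layout makes the helpers importable as sharp.utils.linalg:

    import torch
    import torch.nn.functional as F

    from sharp.utils.linalg import (
        quaternion_conj,
        quaternion_product,
        quaternions_from_rotation_matrices,
        rotation_matrices_from_quaternions,
    )

    q = F.normalize(torch.randn(4), dim=-1)  # random unit quaternion, (w, x, y, z)
    v = torch.randn(3)

    # Path 1: rotate with the 3x3 matrix.
    rotation = rotation_matrices_from_quaternions(q)
    v_mat = rotation @ v

    # Path 2: rotate with the sandwich product q * (0, v) * conj(q).
    v_quat = torch.cat([torch.zeros(1), v])
    rotated = quaternion_product(quaternion_product(q, v_quat), quaternion_conj(q))
    assert torch.allclose(v_mat, rotated[..., 1:], atol=1e-5)

    # Matrix -> quaternion roundtrip (up to sign: q and -q encode the same rotation).
    q_back = quaternions_from_rotation_matrices(rotation)
    assert torch.allclose(q_back.abs(), q.abs(), atol=1e-5)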
src/sharp/utils/logging.py ADDED
@@ -0,0 +1,45 @@
+ """Contains logging related utility functions.
+
+ For licensing see accompanying LICENSE file.
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import sys
+ from pathlib import Path
+
+
+ def configure(log_level: int, log_path: Path | None = None, prefix: str | None = None) -> None:
+     """Configure logger globally.
+
+     Args:
+         log_level: The desired verbosity level.
+         log_path: The path to write logs to.
+         prefix: The prefix of the logger.
+     """
+     logger = logging.getLogger(prefix)
+
+     # Reset logger to initial state; iterate over copies so removal does not skip entries.
+     for handler in list(logger.handlers):
+         logger.removeHandler(handler)
+
+     for log_filter in list(logger.filters):
+         logger.removeFilter(log_filter)
+
+     # Set level.
+     logger.setLevel(log_level)
+
+     formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
+
+     # Set up console handler.
+     stdout_handler = logging.StreamHandler(sys.stdout)
+     stdout_handler.setFormatter(formatter)
+     logger.addHandler(stdout_handler)
+
+     # Set up file handler.
+     if log_path is not None:
+         file_handler = logging.FileHandler(log_path, mode="w")
+         file_handler.setFormatter(formatter)
+         logger.addHandler(file_handler)
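A minimal usage sketch (not part of the commit): configure the root of a logger hierarchy once, then log through any child logger, which propagates to the configured handlers.

    import logging

    # Assumes the package is installed so the module resolves as sharp.utils.logging.
    from sharp.utils import logging as sharp_logging

    sharp_logging.configure(logging.INFO, log_path=None, prefix="sharp")

    # Child loggers under the "sharp" prefix propagate to the handlers set above.
    logging.getLogger("sharp.demo").info("Logger configured.")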
src/sharp/utils/math.py ADDED
@@ -0,0 +1,183 @@
+ """Contains utility math functions.
+
+ For licensing see accompanying LICENSE file.
+ Copyright (C) 2025 Apple Inc. All Rights Reserved.
+ """
+
+ from __future__ import annotations
+
+ from typing import Any, Callable, Literal, NamedTuple, Tuple, Union
+
+ import torch
+ from torch import autograd
+
+ ActivationType = Literal[
+     "linear",
+     "exp",
+     "sigmoid",
+     "softplus",
+     "relu_with_pushback",
+     "hard_sigmoid_with_pushback",
+ ]
+ ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
+
+
+ class ActivationPair(NamedTuple):
+     """A pair of forward and inverse activation functions."""
+
+     forward: ActivationFunction
+     inverse: ActivationFunction
+
+
+ def create_activation_pair(activation_type: ActivationType) -> ActivationPair:
+     """Create an activation function and its corresponding inverse.
+
+     Args:
+         activation_type: The activation type to create.
+
+     Returns:
+         The activation function paired with its inverse.
+     """
+     if activation_type == "linear":
+         return ActivationPair(lambda x: x, lambda x: x)
+     elif activation_type == "exp":
+         return ActivationPair(torch.exp, torch.log)
+     elif activation_type == "sigmoid":
+         return ActivationPair(torch.sigmoid, inverse_sigmoid)
+     elif activation_type == "softplus":
+         return ActivationPair(torch.nn.functional.softplus, inverse_softplus)
+     elif activation_type == "relu_with_pushback":
+         return ActivationPair(relu_with_pushback, lambda x: x)
+     elif activation_type == "hard_sigmoid_with_pushback":
+         return ActivationPair(hard_sigmoid_with_pushback, lambda x: 6.0 * x - 3.0)
+     else:
+         raise ValueError(f"Unsupported activation function: {activation_type}.")
+
+
+ def inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
+     """Compute the inverse sigmoid (logit)."""
+     return torch.log(tensor / (1.0 - tensor))
+
+
+ def inverse_softplus(tensor: torch.Tensor, eps: float = 1e-06) -> torch.Tensor:
+     """Compute the inverse softplus, i.e. log(exp(x) - 1)."""
+     tensor = tensor.clamp_min(eps)
+     sigmoid = torch.sigmoid(-tensor)
+     exp = sigmoid / (1.0 - sigmoid)
+     return tensor + torch.log(-exp + 1.0)
+
+
+ # The first value describes the threshold from where clamping will be applied, while
+ # the second value describes the value to clamp with.
+ SoftClampRange = Tuple[Union[torch.Tensor, float], Union[torch.Tensor, float]]
+
+
+ def softclamp(
+     tensor: torch.Tensor,
+     min: SoftClampRange | None = None,
+     max: SoftClampRange | None = None,
+ ) -> torch.Tensor:
+     """Clamp a tensor to min/max in a differentiable way.
+
+     Args:
+         tensor: The tensor to clamp.
+         min: Pair of threshold to start clamping and value to clamp to.
+             The first value should be larger than the second.
+         max: Pair of threshold to start clamping and value to clamp to.
+             The first value should be smaller than the second.
+
+     Returns:
+         The clamped tensor.
+     """
+
+     def soft_limit(clamp_range: SoftClampRange) -> torch.Tensor:
+         value0, value1 = clamp_range
+         return value0 + (value1 - value0) * torch.tanh((tensor - value0) / (value1 - value0))
+
+     tensor_clamped = tensor
+     if min is not None:
+         tensor_clamped = torch.maximum(tensor_clamped, soft_limit(min))
+     if max is not None:
+         tensor_clamped = torch.minimum(tensor_clamped, soft_limit(max))
+
+     return tensor_clamped
+
+
+ class ClampWithPushback(autograd.Function):
+     """Implementation of the clamp_with_pushback function."""
+
+     @staticmethod
+     def forward(
+         ctx: Any,
+         tensor: torch.Tensor,
+         min: float | None,
+         max: float | None,
+         pushback: float,
+     ) -> torch.Tensor:
+         """Apply clamp."""
+         if min is not None and max is not None and min >= max:
+             raise ValueError("Only min < max is supported.")
+
+         ctx.save_for_backward(tensor)
+         ctx.min = min
+         ctx.max = max
+         ctx.pushback = pushback
+         return torch.clamp(tensor, min=min, max=max)
+
+     @staticmethod
+     def backward(  # type: ignore[override] # Deal with buggy torch annotations.
+         ctx: Any, grad_in: torch.Tensor
+     ) -> tuple[torch.Tensor, None, None, None]:
+         """Compute the gradient of clamp with pushback."""
+         grad_out = grad_in.clone()
+         (tensor,) = ctx.saved_tensors
+
+         if ctx.min is not None:
+             mask_min = tensor < ctx.min
+             grad_out[mask_min] = -ctx.pushback
+
+         if ctx.max is not None:
+             mask_max = tensor > ctx.max
+             grad_out[mask_max] = ctx.pushback
+
+         return grad_out, None, None, None
+
+
+ def clamp_with_pushback(
+     tensor: torch.Tensor,
+     min: float | None = None,
+     max: float | None = None,
+     pushback: float = 1e-2,
+ ) -> torch.Tensor:
+     """Variant of the clamp function which avoids the vanishing-gradient problem.
+
+     This function is equivalent to adding a regularizer of the form
+
+         pushback * sum_i (
+             relu(min - preactivation_i) + relu(preactivation_i - max)
+         )
+
+     to the full loss function, which pushes clamped values back.
+
+     When used in minimization problems, pushback should be greater than
+     zero. In maximization problems, pushback should be smaller than zero.
+     """
+     output = ClampWithPushback.apply(tensor, min, max, pushback)
+     assert isinstance(output, torch.Tensor)
+     return output
+
+
+ def hard_sigmoid_with_pushback(x: torch.Tensor, slope: float = 1.0 / 6.0) -> torch.Tensor:
+     """Apply hard sigmoid with pushback.
+
+     For compatibility reasons, we follow the default PyTorch implementation with a
+     default slope of 1/6:
+
+     https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
+     """
+     return clamp_with_pushback(slope * x + 0.5, min=0.0, max=1.0)
+
+
+ def relu_with_pushback(x: torch.Tensor) -> torch.Tensor:
+     """Compute relu with pushback."""
+     return clamp_with_pushback(x, min=0.0)
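Two behaviors worth seeing concretely: each ActivationPair really is a forward/inverse pair, and clamp_with_pushback swaps the zero gradient of a saturated clamp for a constant restoring gradient. A minimal sketch, assuming the module is importable as sharp.utils.math:

    import torch

    # Assumes the package is installed so the module resolves as sharp.utils.math.
    from sharp.utils.math import clamp_with_pushback, create_activation_pair

    # inverse(forward(x)) recovers x for in-range inputs.
    pair = create_activation_pair("softplus")
    x = torch.linspace(-2.0, 2.0, 5)
    assert torch.allclose(pair.inverse(pair.forward(x)), x, atol=1e-4)

    # Saturated entries receive a constant pushback gradient instead of zero:
    # below min the gradient becomes -pushback, above max it becomes +pushback.
    t = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
    clamp_with_pushback(t, min=0.0, max=1.0, pushback=1e-2).sum().backward()
    print(t.grad)  # tensor([-0.0100, 1.0000, 0.0100])

Under gradient descent, the -pushback gradient on an entry below min increases that entry, so clamped preactivations drift back toward the valid range rather than getting stuck.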