Commit
·
c20d7cc
0
Parent(s):
Initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +166 -0
- .pre-commit-config.yaml +23 -0
- .python-version +1 -0
- ACKNOWLEDGEMENTS +214 -0
- CODE_OF_CONDUCT.md +70 -0
- CONTRIBUTING.md +11 -0
- LICENSE +47 -0
- LICENSE_MODEL +88 -0
- README.md +95 -0
- pyproject.toml +69 -0
- requirements.in +1 -0
- requirements.txt +172 -0
- src/sharp/__init__.py +4 -0
- src/sharp/cli/__init__.py +19 -0
- src/sharp/cli/predict.py +206 -0
- src/sharp/cli/render.py +120 -0
- src/sharp/models/__init__.py +79 -0
- src/sharp/models/alignment.py +126 -0
- src/sharp/models/blocks.py +210 -0
- src/sharp/models/composer.py +251 -0
- src/sharp/models/decoders/__init__.py +22 -0
- src/sharp/models/decoders/base_decoder.py +21 -0
- src/sharp/models/decoders/monodepth_decoder.py +37 -0
- src/sharp/models/decoders/multires_conv_decoder.py +116 -0
- src/sharp/models/decoders/unet_decoder.py +113 -0
- src/sharp/models/encoders/__init__.py +24 -0
- src/sharp/models/encoders/base_encoder.py +25 -0
- src/sharp/models/encoders/monodepth_encoder.py +123 -0
- src/sharp/models/encoders/spn_encoder.py +369 -0
- src/sharp/models/encoders/unet_encoder.py +117 -0
- src/sharp/models/encoders/vit_encoder.py +111 -0
- src/sharp/models/gaussian_decoder.py +267 -0
- src/sharp/models/heads.py +53 -0
- src/sharp/models/initializer.py +297 -0
- src/sharp/models/monodepth.py +268 -0
- src/sharp/models/normalizers.py +80 -0
- src/sharp/models/params.py +203 -0
- src/sharp/models/predictor.py +201 -0
- src/sharp/models/presets/__init__.py +23 -0
- src/sharp/models/presets/monodepth.py +21 -0
- src/sharp/models/presets/vit.py +58 -0
- src/sharp/utils/__init__.py +5 -0
- src/sharp/utils/camera.py +386 -0
- src/sharp/utils/color_space.py +88 -0
- src/sharp/utils/gaussians.py +480 -0
- src/sharp/utils/gsplat.py +191 -0
- src/sharp/utils/io.py +213 -0
- src/sharp/utils/linalg.py +104 -0
- src/sharp/utils/logging.py +45 -0
- src/sharp/utils/math.py +183 -0
.gitignore
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 110 |
+
.pdm.toml
|
| 111 |
+
.pdm-python
|
| 112 |
+
.pdm-build/
|
| 113 |
+
|
| 114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 115 |
+
__pypackages__/
|
| 116 |
+
|
| 117 |
+
# Celery stuff
|
| 118 |
+
celerybeat-schedule
|
| 119 |
+
celerybeat.pid
|
| 120 |
+
|
| 121 |
+
# SageMath parsed files
|
| 122 |
+
*.sage.py
|
| 123 |
+
|
| 124 |
+
# Environments
|
| 125 |
+
.env
|
| 126 |
+
.venv
|
| 127 |
+
env/
|
| 128 |
+
venv/
|
| 129 |
+
ENV/
|
| 130 |
+
env.bak/
|
| 131 |
+
venv.bak/
|
| 132 |
+
|
| 133 |
+
# Spyder project settings
|
| 134 |
+
.spyderproject
|
| 135 |
+
.spyproject
|
| 136 |
+
|
| 137 |
+
# Rope project settings
|
| 138 |
+
.ropeproject
|
| 139 |
+
|
| 140 |
+
# mkdocs documentation
|
| 141 |
+
/site
|
| 142 |
+
|
| 143 |
+
# mypy
|
| 144 |
+
.mypy_cache/
|
| 145 |
+
.dmypy.json
|
| 146 |
+
dmypy.json
|
| 147 |
+
|
| 148 |
+
# Pyre type checker
|
| 149 |
+
.pyre/
|
| 150 |
+
|
| 151 |
+
# pytype static type analyzer
|
| 152 |
+
.pytype/
|
| 153 |
+
|
| 154 |
+
# Cython debug symbols
|
| 155 |
+
cython_debug/
|
| 156 |
+
|
| 157 |
+
# PyCharm
|
| 158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 162 |
+
#.idea/
|
| 163 |
+
|
| 164 |
+
.DS_STORE
|
| 165 |
+
*.pt
|
| 166 |
+
.aider*
|
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
exclude: |
|
| 2 |
+
(?x)(
|
| 3 |
+
^src/sharp/external
|
| 4 |
+
)
|
| 5 |
+
repos:
|
| 6 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 7 |
+
rev: v4.5.0
|
| 8 |
+
hooks:
|
| 9 |
+
- id: trailing-whitespace
|
| 10 |
+
- id: end-of-file-fixer
|
| 11 |
+
# - id: no-commit-to-branch
|
| 12 |
+
# args: ['--branch', 'main']
|
| 13 |
+
- repo: https://github.com/charliermarsh/ruff-pre-commit
|
| 14 |
+
rev: v0.1.7
|
| 15 |
+
hooks:
|
| 16 |
+
- id: ruff
|
| 17 |
+
args: [--fix, --exit-non-zero-on-fix]
|
| 18 |
+
- id: ruff-format
|
| 19 |
+
- repo: https://github.com/pre-commit/mirrors-mypy
|
| 20 |
+
rev: v1.7.1
|
| 21 |
+
hooks:
|
| 22 |
+
- id: mypy
|
| 23 |
+
additional_dependencies: [ types-PyYAML ]
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.13
|
ACKNOWLEDGEMENTS
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Acknowledgements
|
| 2 |
+
Portions of this Software may utilize the following copyrighted
|
| 3 |
+
material, the use of which is hereby acknowledged.
|
| 4 |
+
|
| 5 |
+
---------------------------------------------------------------------------------
|
| 6 |
+
|
| 7 |
+
TIMM - Pytorch Image Models library
|
| 8 |
+
|
| 9 |
+
https://github.com/huggingface/pytorch-image-models
|
| 10 |
+
|
| 11 |
+
Apache License
|
| 12 |
+
Version 2.0, January 2004
|
| 13 |
+
http://www.apache.org/licenses/
|
| 14 |
+
|
| 15 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 16 |
+
|
| 17 |
+
1. Definitions.
|
| 18 |
+
|
| 19 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 20 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 21 |
+
|
| 22 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 23 |
+
the copyright owner that is granting the License.
|
| 24 |
+
|
| 25 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 26 |
+
other entities that control, are controlled by, or are under common
|
| 27 |
+
control with that entity. For the purposes of this definition,
|
| 28 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 29 |
+
direction or management of such entity, whether by contract or
|
| 30 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 31 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 32 |
+
|
| 33 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 34 |
+
exercising permissions granted by this License.
|
| 35 |
+
|
| 36 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 37 |
+
including but not limited to software source code, documentation
|
| 38 |
+
source, and configuration files.
|
| 39 |
+
|
| 40 |
+
"Object" form shall mean any form resulting from mechanical
|
| 41 |
+
transformation or translation of a Source form, including but
|
| 42 |
+
not limited to compiled object code, generated documentation,
|
| 43 |
+
and conversions to other media types.
|
| 44 |
+
|
| 45 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 46 |
+
Object form, made available under the License, as indicated by a
|
| 47 |
+
copyright notice that is included in or attached to the work
|
| 48 |
+
(an example is provided in the Appendix below).
|
| 49 |
+
|
| 50 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 51 |
+
form, that is based on (or derived from) the Work and for which the
|
| 52 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 53 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 54 |
+
of this License, Derivative Works shall not include works that remain
|
| 55 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 56 |
+
the Work and Derivative Works thereof.
|
| 57 |
+
|
| 58 |
+
"Contribution" shall mean any work of authorship, including
|
| 59 |
+
the original version of the Work and any modifications or additions
|
| 60 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 61 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 62 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 63 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 64 |
+
means any form of electronic, verbal, or written communication sent
|
| 65 |
+
to the Licensor or its representatives, including but not limited to
|
| 66 |
+
communication on electronic mailing lists, source code control systems,
|
| 67 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 68 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 69 |
+
excluding communication that is conspicuously marked or otherwise
|
| 70 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 71 |
+
|
| 72 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 73 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 74 |
+
subsequently incorporated within the Work.
|
| 75 |
+
|
| 76 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 77 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 78 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 79 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 80 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 81 |
+
Work and such Derivative Works in Source or Object form.
|
| 82 |
+
|
| 83 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 84 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 85 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 86 |
+
(except as stated in this section) patent license to make, have made,
|
| 87 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 88 |
+
where such license applies only to those patent claims licensable
|
| 89 |
+
by such Contributor that are necessarily infringed by their
|
| 90 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 91 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 92 |
+
institute patent litigation against any entity (including a
|
| 93 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 94 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 95 |
+
or contributory patent infringement, then any patent licenses
|
| 96 |
+
granted to You under this License for that Work shall terminate
|
| 97 |
+
as of the date such litigation is filed.
|
| 98 |
+
|
| 99 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 100 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 101 |
+
modifications, and in Source or Object form, provided that You
|
| 102 |
+
meet the following conditions:
|
| 103 |
+
|
| 104 |
+
(a) You must give any other recipients of the Work or
|
| 105 |
+
Derivative Works a copy of this License; and
|
| 106 |
+
|
| 107 |
+
(b) You must cause any modified files to carry prominent notices
|
| 108 |
+
stating that You changed the files; and
|
| 109 |
+
|
| 110 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 111 |
+
that You distribute, all copyright, patent, trademark, and
|
| 112 |
+
attribution notices from the Source form of the Work,
|
| 113 |
+
excluding those notices that do not pertain to any part of
|
| 114 |
+
the Derivative Works; and
|
| 115 |
+
|
| 116 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 117 |
+
distribution, then any Derivative Works that You distribute must
|
| 118 |
+
include a readable copy of the attribution notices contained
|
| 119 |
+
within such NOTICE file, excluding those notices that do not
|
| 120 |
+
pertain to any part of the Derivative Works, in at least one
|
| 121 |
+
of the following places: within a NOTICE text file distributed
|
| 122 |
+
as part of the Derivative Works; within the Source form or
|
| 123 |
+
documentation, if provided along with the Derivative Works; or,
|
| 124 |
+
within a display generated by the Derivative Works, if and
|
| 125 |
+
wherever such third-party notices normally appear. The contents
|
| 126 |
+
of the NOTICE file are for informational purposes only and
|
| 127 |
+
do not modify the License. You may add Your own attribution
|
| 128 |
+
notices within Derivative Works that You distribute, alongside
|
| 129 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 130 |
+
that such additional attribution notices cannot be construed
|
| 131 |
+
as modifying the License.
|
| 132 |
+
|
| 133 |
+
You may add Your own copyright statement to Your modifications and
|
| 134 |
+
may provide additional or different license terms and conditions
|
| 135 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 136 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 137 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 138 |
+
the conditions stated in this License.
|
| 139 |
+
|
| 140 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 141 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 142 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 143 |
+
this License, without any additional terms or conditions.
|
| 144 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 145 |
+
the terms of any separate license agreement you may have executed
|
| 146 |
+
with Licensor regarding such Contributions.
|
| 147 |
+
|
| 148 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 149 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 150 |
+
except as required for reasonable and customary use in describing the
|
| 151 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 152 |
+
|
| 153 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 154 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 155 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 156 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 157 |
+
implied, including, without limitation, any warranties or conditions
|
| 158 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 159 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 160 |
+
appropriateness of using or redistributing the Work and assume any
|
| 161 |
+
risks associated with Your exercise of permissions under this License.
|
| 162 |
+
|
| 163 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 164 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 165 |
+
unless required by applicable law (such as deliberate and grossly
|
| 166 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 167 |
+
liable to You for damages, including any direct, indirect, special,
|
| 168 |
+
incidental, or consequential damages of any character arising as a
|
| 169 |
+
result of this License or out of the use or inability to use the
|
| 170 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 171 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 172 |
+
other commercial damages or losses), even if such Contributor
|
| 173 |
+
has been advised of the possibility of such damages.
|
| 174 |
+
|
| 175 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 176 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 177 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 178 |
+
or other liability obligations and/or rights consistent with this
|
| 179 |
+
License. However, in accepting such obligations, You may act only
|
| 180 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 181 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 182 |
+
defend, and hold each Contributor harmless for any liability
|
| 183 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 184 |
+
of your accepting any such warranty or additional liability.
|
| 185 |
+
|
| 186 |
+
END OF TERMS AND CONDITIONS
|
| 187 |
+
|
| 188 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 189 |
+
|
| 190 |
+
To apply the Apache License to your work, attach the following
|
| 191 |
+
boilerplate notice, with the fields enclosed by brackets "{}"
|
| 192 |
+
replaced with your own identifying information. (Don't include
|
| 193 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 194 |
+
comment syntax for the file format. We also recommend that a
|
| 195 |
+
file or class name and description of purpose be included on the
|
| 196 |
+
same "printed page" as the copyright notice for easier
|
| 197 |
+
identification within third-party archives.
|
| 198 |
+
|
| 199 |
+
Copyright 2019 Ross Wightman
|
| 200 |
+
|
| 201 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 202 |
+
you may not use this file except in compliance with the License.
|
| 203 |
+
You may obtain a copy of the License at
|
| 204 |
+
|
| 205 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 206 |
+
|
| 207 |
+
Unless required by applicable law or agreed to in writing, software
|
| 208 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 209 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 210 |
+
See the License for the specific language governing permissions and
|
| 211 |
+
limitations under the License.
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
-------
|
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
In the interest of fostering an open and welcoming environment, we as
|
| 6 |
+
contributors and maintainers pledge to making participation in our project and
|
| 7 |
+
our community a harassment-free experience for everyone, regardless of age, body
|
| 8 |
+
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
| 9 |
+
level of experience, education, socio-economic status, nationality, personal
|
| 10 |
+
appearance, race, religion, or sexual identity and orientation.
|
| 11 |
+
|
| 12 |
+
## Our Standards
|
| 13 |
+
|
| 14 |
+
Examples of behavior that contributes to creating a positive environment
|
| 15 |
+
include:
|
| 16 |
+
|
| 17 |
+
* Using welcoming and inclusive language
|
| 18 |
+
* Being respectful of differing viewpoints and experiences
|
| 19 |
+
* Gracefully accepting constructive criticism
|
| 20 |
+
* Focusing on what is best for the community
|
| 21 |
+
* Showing empathy towards other community members
|
| 22 |
+
|
| 23 |
+
Examples of unacceptable behavior by participants include:
|
| 24 |
+
|
| 25 |
+
* The use of sexualized language or imagery and unwelcome sexual attention or
|
| 26 |
+
advances
|
| 27 |
+
* Trolling, insulting/derogatory comments, and personal or political attacks
|
| 28 |
+
* Public or private harassment
|
| 29 |
+
* Publishing others' private information, such as a physical or electronic
|
| 30 |
+
address, without explicit permission
|
| 31 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 32 |
+
professional setting
|
| 33 |
+
|
| 34 |
+
## Our Responsibilities
|
| 35 |
+
|
| 36 |
+
Project maintainers are responsible for clarifying the standards of acceptable
|
| 37 |
+
behavior and are expected to take appropriate and fair corrective action in
|
| 38 |
+
response to any instances of unacceptable behavior.
|
| 39 |
+
|
| 40 |
+
Project maintainers have the right and responsibility to remove, edit, or
|
| 41 |
+
reject comments, commits, code, wiki edits, issues, and other contributions
|
| 42 |
+
that are not aligned to this Code of Conduct, or to ban temporarily or
|
| 43 |
+
permanently any contributor for other behaviors that they deem inappropriate,
|
| 44 |
+
threatening, offensive, or harmful.
|
| 45 |
+
|
| 46 |
+
## Scope
|
| 47 |
+
|
| 48 |
+
This Code of Conduct applies within all project spaces, and it also applies when
|
| 49 |
+
an individual is representing the project or its community in public spaces.
|
| 50 |
+
Examples of representing a project or community include using an official
|
| 51 |
+
project e-mail address, posting via an official social media account, or acting
|
| 52 |
+
as an appointed representative at an online or offline event. Representation of
|
| 53 |
+
a project may be further defined and clarified by project maintainers.
|
| 54 |
+
|
| 55 |
+
## Enforcement
|
| 56 |
+
|
| 57 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 58 |
+
reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
|
| 59 |
+
complaints will be reviewed and investigated and will result in a response that
|
| 60 |
+
is deemed necessary and appropriate to the circumstances. The project team is
|
| 61 |
+
obligated to maintain confidentiality with regard to the reporter of an incident.
|
| 62 |
+
Further details of specific enforcement policies may be posted separately.
|
| 63 |
+
|
| 64 |
+
Project maintainers who do not follow or enforce the Code of Conduct in good
|
| 65 |
+
faith may face temporary or permanent repercussions as determined by other
|
| 66 |
+
members of the project's leadership.
|
| 67 |
+
|
| 68 |
+
## Attribution
|
| 69 |
+
|
| 70 |
+
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contribution Guide
|
| 2 |
+
|
| 3 |
+
Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
|
| 4 |
+
|
| 5 |
+
While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
|
| 6 |
+
|
| 7 |
+
## Before you get started
|
| 8 |
+
|
| 9 |
+
By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
|
| 10 |
+
|
| 11 |
+
We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
|
LICENSE
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 2 |
+
|
| 3 |
+
Disclaimer: IMPORTANT: This Apple software is supplied to you by Apple
|
| 4 |
+
Inc. ("Apple") in consideration of your agreement to the following
|
| 5 |
+
terms, and your use, installation, modification or redistribution of
|
| 6 |
+
this Apple software constitutes acceptance of these terms. If you do
|
| 7 |
+
not agree with these terms, please do not use, install, modify or
|
| 8 |
+
redistribute this Apple software.
|
| 9 |
+
|
| 10 |
+
In consideration of your agreement to abide by the following terms, and
|
| 11 |
+
subject to these terms, Apple grants you a personal, non-exclusive
|
| 12 |
+
license, under Apple's copyrights in this original Apple software (the
|
| 13 |
+
"Apple Software"), to use, reproduce, modify and redistribute the Apple
|
| 14 |
+
Software, with or without modifications, in source and/or binary forms;
|
| 15 |
+
provided that if you redistribute the Apple Software in its entirety and
|
| 16 |
+
without modifications, you must retain this notice and the following
|
| 17 |
+
text and disclaimers in all such redistributions of the Apple Software.
|
| 18 |
+
Neither the name, trademarks, service marks or logos of Apple Inc. may
|
| 19 |
+
be used to endorse or promote products derived from the Apple Software
|
| 20 |
+
without specific prior written permission from Apple. Except as
|
| 21 |
+
expressly stated in this notice, no other rights or licenses, express or
|
| 22 |
+
implied, are granted by Apple herein, including but not limited to any
|
| 23 |
+
patent rights that may be infringed by your derivative works or by other
|
| 24 |
+
works in which the Apple Software may be incorporated.
|
| 25 |
+
|
| 26 |
+
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
|
| 27 |
+
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
|
| 28 |
+
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
|
| 29 |
+
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
|
| 30 |
+
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
|
| 31 |
+
|
| 32 |
+
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
|
| 33 |
+
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
| 34 |
+
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
| 35 |
+
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
|
| 36 |
+
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
|
| 37 |
+
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
|
| 38 |
+
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
|
| 39 |
+
POSSIBILITY OF SUCH DAMAGE.
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
-------------------------------------------------------------------------------
|
| 43 |
+
SOFTWARE DISTRIBUTED IN THIS REPOSITORY:
|
| 44 |
+
|
| 45 |
+
This software includes a number of subcomponents with separate
|
| 46 |
+
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
|
| 47 |
+
-------------------------------------------------------------------------------
|
LICENSE_MODEL
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Disclaimer: IMPORTANT: This Apple Machine Learning Research Model is
|
| 2 |
+
specifically developed and released by Apple Inc. ("Apple") for the sole purpose
|
| 3 |
+
of scientific research of artificial intelligence and machine-learning
|
| 4 |
+
technology. “Apple Machine Learning Research Model” means the model, including
|
| 5 |
+
but not limited to algorithms, formulas, trained model weights, parameters,
|
| 6 |
+
configurations, checkpoints, and any related materials (including
|
| 7 |
+
documentation).
|
| 8 |
+
|
| 9 |
+
This Apple Machine Learning Research Model is provided to You by
|
| 10 |
+
Apple in consideration of your agreement to the following terms, and your use,
|
| 11 |
+
modification, creation of Model Derivatives, and or redistribution of the Apple
|
| 12 |
+
Machine Learning Research Model constitutes acceptance of this Agreement. If You
|
| 13 |
+
do not agree with these terms, please do not use, modify, create Model
|
| 14 |
+
Derivatives of, or distribute this Apple Machine Learning Research Model or
|
| 15 |
+
Model Derivatives.
|
| 16 |
+
|
| 17 |
+
* License Scope: In consideration of your agreement to abide by the following
|
| 18 |
+
terms, and subject to these terms, Apple hereby grants you a personal,
|
| 19 |
+
non-exclusive, worldwide, non-transferable, royalty-free, revocable, and
|
| 20 |
+
limited license, to use, copy, modify, distribute, and create Model
|
| 21 |
+
Derivatives (defined below) of the Apple Machine Learning Research Model
|
| 22 |
+
exclusively for Research Purposes. You agree that any Model Derivatives You
|
| 23 |
+
may create or that may be created for You will be limited to Research Purposes
|
| 24 |
+
as well. “Research Purposes” means non-commercial scientific research and
|
| 25 |
+
academic development activities, such as experimentation, analysis, testing
|
| 26 |
+
conducted by You with the sole intent to advance scientific knowledge and
|
| 27 |
+
research. “Research Purposes” does not include any commercial exploitation,
|
| 28 |
+
product development or use in any commercial product or service.
|
| 29 |
+
|
| 30 |
+
* Distribution of Apple Machine Learning Research Model and Model Derivatives:
|
| 31 |
+
If you choose to redistribute Apple Machine Learning Research Model or its
|
| 32 |
+
Model Derivatives, you must provide a copy of this Agreement to such third
|
| 33 |
+
party, and ensure that the following attribution notice be provided: “Apple
|
| 34 |
+
Machine Learning Research Model is licensed under the Apple Machine Learning
|
| 35 |
+
Research Model License Agreement.” Additionally, all Model Derivatives must
|
| 36 |
+
clearly be identified as such, including disclosure of modifications and
|
| 37 |
+
changes made to the Apple Machine Learning Research Model. The name,
|
| 38 |
+
trademarks, service marks or logos of Apple may not be used to endorse or
|
| 39 |
+
promote Model Derivatives or the relationship between You and Apple. “Model
|
| 40 |
+
Derivatives” means any models or any other artifacts created by modifications,
|
| 41 |
+
improvements, adaptations, alterations to the architecture, algorithm or
|
| 42 |
+
training processes of the Apple Machine Learning Research Model, or by any
|
| 43 |
+
retraining, fine-tuning of the Apple Machine Learning Research Model.
|
| 44 |
+
|
| 45 |
+
* No Other License: Except as expressly stated in this notice, no other rights
|
| 46 |
+
or licenses, express or implied, are granted by Apple herein, including but
|
| 47 |
+
not limited to any patent, trademark, and similar intellectual property rights
|
| 48 |
+
worldwide that may be infringed by the Apple Machine Learning Research Model,
|
| 49 |
+
the Model Derivatives or by other works in which the Apple Machine Learning
|
| 50 |
+
Research Model may be incorporated.
|
| 51 |
+
|
| 52 |
+
* Compliance with Laws: Your use of Apple Machine Learning Research Model must
|
| 53 |
+
be in compliance with all applicable laws and regulations.
|
| 54 |
+
|
| 55 |
+
* Term and Termination: The term of this Agreement will begin upon your
|
| 56 |
+
acceptance of this Agreement or use of the Apple Machine Learning Research
|
| 57 |
+
Model and will continue until terminated in accordance with the following
|
| 58 |
+
terms. Apple may terminate this Agreement at any time if You are in breach of
|
| 59 |
+
any term or condition of this Agreement. Upon termination of this Agreement,
|
| 60 |
+
You must cease to use all Apple Machine Learning Research Models and Model
|
| 61 |
+
Derivatives and permanently delete any copy thereof. Sections 3, 6 and 7 will
|
| 62 |
+
survive termination.
|
| 63 |
+
|
| 64 |
+
* Disclaimer and Limitation of Liability: This Apple Machine Learning Research
|
| 65 |
+
Model and any outputs generated by the Apple Machine Learning Research Model
|
| 66 |
+
are provided on an “AS IS” basis. APPLE MAKES NO WARRANTIES, EXPRESS OR
|
| 67 |
+
IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF
|
| 68 |
+
NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE,
|
| 69 |
+
REGARDING THE APPLE MACHINE LEARNING RESEARCH MODEL OR OUTPUTS GENERATED BY
|
| 70 |
+
THE APPLE MACHINE LEARNING RESEARCH MODEL. You are solely responsible for
|
| 71 |
+
determining the appropriateness of using or redistributing the Apple Machine
|
| 72 |
+
Learning Research Model and any outputs of the Apple Machine Learning Research
|
| 73 |
+
Model and assume any risks associated with Your use of the Apple Machine
|
| 74 |
+
Learning Research Model and any output and results. IN NO EVENT SHALL APPLE BE
|
| 75 |
+
LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
|
| 76 |
+
IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF
|
| 77 |
+
THE APPLE MACHINE LEARNING RESEARCH MODEL AND ANY OUTPUTS OF THE APPLE MACHINE
|
| 78 |
+
LEARNING RESEARCH MODEL, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT,
|
| 79 |
+
TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS
|
| 80 |
+
BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 81 |
+
|
| 82 |
+
* Governing Law: This Agreement will be governed by and construed under the laws
|
| 83 |
+
of the State of California without regard to its choice of law principles. The
|
| 84 |
+
Convention on Contracts for the International Sale of Goods shall not apply to
|
| 85 |
+
the Agreement except that the arbitration clause and any arbitration hereunder
|
| 86 |
+
shall be governed by the Federal Arbitration Act, Chapters 1 and 2.
|
| 87 |
+
|
| 88 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
README.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Sharp Monocular View Synthesis in Less Than a Second
|
| 2 |
+
|
| 3 |
+
[](https://apple.github.io/ml-sharp/)
|
| 4 |
+
[](https://arxiv.org/abs/2512.10685)
|
| 5 |
+
|
| 6 |
+
This software project accompanies the research paper: _Sharp Monocular View Synthesis in Less Than a Second_
|
| 7 |
+
by _Lars Mescheder, Wei Dong, Shiwei Li, Xuyang Bai, Marcel Santos, Peiyun Hu, Bruno Lecouat, Mingmin Zhen, Amaël Delaunoy,
|
| 8 |
+
Tian Fang, Yanghai Tsin, Stephan Richter and Vladlen Koltun_.
|
| 9 |
+
|
| 10 |
+

|
| 11 |
+
|
| 12 |
+
We present SHARP, an approach to photorealistic view synthesis from a single image. Given a single photograph, SHARP regresses the parameters of a 3D Gaussian representation of the depicted scene. This is done in less than a second on a standard GPU via a single feedforward pass through a neural network. The 3D Gaussian representation produced by SHARP can then be rendered in real time, yielding high-resolution photorealistic images for nearby views. The representation is metric, with absolute scale, supporting metric camera movements. Experimental results demonstrate that SHARP delivers robust zero-shot generalization across datasets. It sets a new state of the art on multiple datasets, reducing LPIPS by 25–34% and DISTS by 21–43% versus the best prior model, while lowering the synthesis time by three orders of magnitude.
|
| 13 |
+
|
| 14 |
+
## Getting started
|
| 15 |
+
|
| 16 |
+
We recommend to first create a python environment:
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
conda create -n sharp python=3.13
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
Afterwards, you can install the project using
|
| 23 |
+
|
| 24 |
+
```
|
| 25 |
+
pip install -r requirements.txt
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
To test the installation, run
|
| 29 |
+
|
| 30 |
+
```
|
| 31 |
+
sharp --help
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Using the CLI
|
| 35 |
+
|
| 36 |
+
To run prediction:
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
sharp predict -i /path/to/input/images -o /path/to/output/gaussians
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
The model checkpoint will be downloaded automatically on first run and cached locally at `~/.cache/torch/hub/checkpoints/`.
|
| 43 |
+
|
| 44 |
+
Alternatively, you can download the model directly:
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
wget https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
To use a manually downloaded checkpoint, specify it with the `-c` flag:
|
| 51 |
+
|
| 52 |
+
```
|
| 53 |
+
sharp predict -i /path/to/input/images -o /path/to/output/gaussians -c sharp_2572gikvuh.pt
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
The results will be 3D gaussian splats (3DGS) in the output folder. The 3DGS `.ply` files are compatible to various public 3DGS renderers. We follow the OpenCV coordinate convention (x right, y down, z forward). The 3DGS scene center is roughly at (0, 0, +z). When dealing with 3rdparty renderers, please scale and rotate to re-center the scene accordingly.
|
| 57 |
+
|
| 58 |
+
### Rendering trajectories (CUDA GPU only)
|
| 59 |
+
|
| 60 |
+
Additionally you can render videos with a camera trajectory. While the gaussians prediction works for all CPU, CUDA, and MPS, rendering videos via the `--render` option currently requires a CUDA GPU. The gsplat renderer takes a while to initialize at the first launch.
|
| 61 |
+
|
| 62 |
+
```
|
| 63 |
+
sharp predict -i /path/to/input/images -o /path/to/output/gaussians --render
|
| 64 |
+
|
| 65 |
+
# Or from the intermediate gaussians:
|
| 66 |
+
sharp render -i /path/to/output/gaussians -o /path/to/output/renderings
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## Evaluation
|
| 70 |
+
|
| 71 |
+
Please refer to the paper for both quantitative and qualitative evaluations.
|
| 72 |
+
Additionally, please check out this [qualitative examples page](https://apple.github.io/ml-sharp/) containing several video comparisons against related work.
|
| 73 |
+
|
| 74 |
+
## Citation
|
| 75 |
+
|
| 76 |
+
If you find our work useful, please cite the following paper:
|
| 77 |
+
|
| 78 |
+
```bibtex
|
| 79 |
+
@inproceedings{Sharp2025:arxiv,
|
| 80 |
+
title = {Sharp Monocular View Synthesis in Less Than a Second},
|
| 81 |
+
author = {Lars Mescheder and Wei Dong and Shiwei Li and Xuyang Bai and Marcel Santos and Peiyun Hu and Bruno Lecouat and Mingmin Zhen and Ama\"{e}l Delaunoyand Tian Fang and Yanghai Tsin and Stephan R. Richter and Vladlen Koltun},
|
| 82 |
+
journal = {arXiv preprint arXiv:2512.10685},
|
| 83 |
+
year = {2025},
|
| 84 |
+
url = {https://arxiv.org/abs/2512.10685},
|
| 85 |
+
}
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
## Acknowledgements
|
| 89 |
+
|
| 90 |
+
Our codebase is built using multiple opensource contributions, please see [ACKNOWLEDGEMENTS](ACKNOWLEDGEMENTS) for more details.
|
| 91 |
+
|
| 92 |
+
## License
|
| 93 |
+
|
| 94 |
+
Please check out the repository [LICENSE](LICENSE) before using the provided code and
|
| 95 |
+
[LICENSE_MODEL](LICENSE_MODEL) for the released models.
|
pyproject.toml
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "sharp"
|
| 3 |
+
version = "0.1"
|
| 4 |
+
description = "Inference/Network/Model code for SHARP view synthesis model."
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
dependencies = [
|
| 7 |
+
"click",
|
| 8 |
+
"gsplat",
|
| 9 |
+
"imageio[ffmpeg]",
|
| 10 |
+
"matplotlib",
|
| 11 |
+
"pillow-heif",
|
| 12 |
+
"plyfile",
|
| 13 |
+
"scipy",
|
| 14 |
+
"timm",
|
| 15 |
+
"torch",
|
| 16 |
+
"torchvision",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[project.scripts]
|
| 20 |
+
sharp = "sharp.cli:main_cli"
|
| 21 |
+
|
| 22 |
+
[project.urls]
|
| 23 |
+
Homepage = "https://github.com/apple/ml-sharp"
|
| 24 |
+
Repository = "https://github.com/apple/ml-sharp"
|
| 25 |
+
|
| 26 |
+
[build-system]
|
| 27 |
+
requires = ["setuptools", "setuptools-scm"]
|
| 28 |
+
build-backend = "setuptools.build_meta"
|
| 29 |
+
|
| 30 |
+
[tool.setuptools.packages.find]
|
| 31 |
+
where = ["src"]
|
| 32 |
+
|
| 33 |
+
[tool.pyright]
|
| 34 |
+
include = ["src"]
|
| 35 |
+
exclude = [
|
| 36 |
+
"**/node_modules",
|
| 37 |
+
"**/__pycache__",
|
| 38 |
+
]
|
| 39 |
+
pythonVersion = "3.13"
|
| 40 |
+
|
| 41 |
+
[tool.pytest.ini_options]
|
| 42 |
+
minversion = "6.0"
|
| 43 |
+
addopts = "-ra -q"
|
| 44 |
+
testpaths = [
|
| 45 |
+
"tests"
|
| 46 |
+
]
|
| 47 |
+
filterwarnings = [
|
| 48 |
+
"ignore::DeprecationWarning"
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
[tool.lint.per-file-ignores]
|
| 52 |
+
"__init__.py" = ["F401", "D100", "D104"]
|
| 53 |
+
|
| 54 |
+
[tool.ruff]
|
| 55 |
+
line-length = 100
|
| 56 |
+
lint.select = ["E", "F", "D", "I"]
|
| 57 |
+
lint.ignore = ["D100", "D105",
|
| 58 |
+
# Imperative mood of docstring.
|
| 59 |
+
"D401",
|
| 60 |
+
]
|
| 61 |
+
extend-exclude = [
|
| 62 |
+
"*external*",
|
| 63 |
+
"third_party",
|
| 64 |
+
]
|
| 65 |
+
src = ["sharp"]
|
| 66 |
+
target-version = "py39"
|
| 67 |
+
|
| 68 |
+
[tool.ruff.lint.pydocstyle]
|
| 69 |
+
convention = "google"
|
requirements.in
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
-e .
|
requirements.txt
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv pip compile requirements.in -o requirements.txt --universal
|
| 3 |
+
-e .
|
| 4 |
+
# via -r requirements.in
|
| 5 |
+
certifi==2025.8.3
|
| 6 |
+
# via requests
|
| 7 |
+
charset-normalizer==3.4.3
|
| 8 |
+
# via requests
|
| 9 |
+
click==8.3.0
|
| 10 |
+
# via sharp
|
| 11 |
+
colorama==0.4.6 ; sys_platform == 'win32'
|
| 12 |
+
# via
|
| 13 |
+
# click
|
| 14 |
+
# tqdm
|
| 15 |
+
contourpy==1.3.3
|
| 16 |
+
# via matplotlib
|
| 17 |
+
cycler==0.12.1
|
| 18 |
+
# via matplotlib
|
| 19 |
+
filelock==3.19.1
|
| 20 |
+
# via
|
| 21 |
+
# huggingface-hub
|
| 22 |
+
# torch
|
| 23 |
+
fonttools==4.61.0
|
| 24 |
+
# via matplotlib
|
| 25 |
+
fsspec==2025.9.0
|
| 26 |
+
# via
|
| 27 |
+
# huggingface-hub
|
| 28 |
+
# torch
|
| 29 |
+
gsplat==1.5.3
|
| 30 |
+
# via sharp
|
| 31 |
+
hf-xet==1.1.10 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
| 32 |
+
# via huggingface-hub
|
| 33 |
+
huggingface-hub==0.35.3
|
| 34 |
+
# via timm
|
| 35 |
+
idna==3.10
|
| 36 |
+
# via requests
|
| 37 |
+
imageio==2.37.0
|
| 38 |
+
# via sharp
|
| 39 |
+
imageio-ffmpeg==0.6.0
|
| 40 |
+
# via imageio
|
| 41 |
+
jaxtyping==0.3.3
|
| 42 |
+
# via gsplat
|
| 43 |
+
jinja2==3.1.6
|
| 44 |
+
# via torch
|
| 45 |
+
kiwisolver==1.4.9
|
| 46 |
+
# via matplotlib
|
| 47 |
+
markdown-it-py==4.0.0
|
| 48 |
+
# via rich
|
| 49 |
+
markupsafe==3.0.3
|
| 50 |
+
# via jinja2
|
| 51 |
+
matplotlib==3.10.6
|
| 52 |
+
# via sharp
|
| 53 |
+
mdurl==0.1.2
|
| 54 |
+
# via markdown-it-py
|
| 55 |
+
mpmath==1.3.0
|
| 56 |
+
# via sympy
|
| 57 |
+
networkx==3.5
|
| 58 |
+
# via torch
|
| 59 |
+
ninja==1.13.0
|
| 60 |
+
# via gsplat
|
| 61 |
+
numpy==2.3.3
|
| 62 |
+
# via
|
| 63 |
+
# contourpy
|
| 64 |
+
# gsplat
|
| 65 |
+
# imageio
|
| 66 |
+
# matplotlib
|
| 67 |
+
# plyfile
|
| 68 |
+
# scipy
|
| 69 |
+
# torchvision
|
| 70 |
+
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 71 |
+
# via
|
| 72 |
+
# nvidia-cudnn-cu12
|
| 73 |
+
# nvidia-cusolver-cu12
|
| 74 |
+
# torch
|
| 75 |
+
nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 76 |
+
# via torch
|
| 77 |
+
nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 78 |
+
# via torch
|
| 79 |
+
nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 80 |
+
# via torch
|
| 81 |
+
nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 82 |
+
# via torch
|
| 83 |
+
nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 84 |
+
# via torch
|
| 85 |
+
nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 86 |
+
# via torch
|
| 87 |
+
nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 88 |
+
# via torch
|
| 89 |
+
nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 90 |
+
# via torch
|
| 91 |
+
nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 92 |
+
# via
|
| 93 |
+
# nvidia-cusolver-cu12
|
| 94 |
+
# torch
|
| 95 |
+
nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 96 |
+
# via torch
|
| 97 |
+
nvidia-nccl-cu12==2.27.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 98 |
+
# via torch
|
| 99 |
+
nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 100 |
+
# via
|
| 101 |
+
# nvidia-cufft-cu12
|
| 102 |
+
# nvidia-cusolver-cu12
|
| 103 |
+
# nvidia-cusparse-cu12
|
| 104 |
+
# torch
|
| 105 |
+
nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 106 |
+
# via torch
|
| 107 |
+
packaging==25.0
|
| 108 |
+
# via
|
| 109 |
+
# huggingface-hub
|
| 110 |
+
# matplotlib
|
| 111 |
+
pillow==11.3.0
|
| 112 |
+
# via
|
| 113 |
+
# imageio
|
| 114 |
+
# matplotlib
|
| 115 |
+
# pillow-heif
|
| 116 |
+
# torchvision
|
| 117 |
+
pillow-heif==1.1.1
|
| 118 |
+
# via sharp
|
| 119 |
+
plyfile==1.1.2
|
| 120 |
+
# via sharp
|
| 121 |
+
psutil==7.1.0
|
| 122 |
+
# via imageio
|
| 123 |
+
pygments==2.19.2
|
| 124 |
+
# via rich
|
| 125 |
+
pyparsing==3.2.5
|
| 126 |
+
# via matplotlib
|
| 127 |
+
python-dateutil==2.9.0.post0
|
| 128 |
+
# via matplotlib
|
| 129 |
+
pyyaml==6.0.3
|
| 130 |
+
# via
|
| 131 |
+
# huggingface-hub
|
| 132 |
+
# timm
|
| 133 |
+
requests==2.32.5
|
| 134 |
+
# via huggingface-hub
|
| 135 |
+
rich==14.1.0
|
| 136 |
+
# via gsplat
|
| 137 |
+
safetensors==0.6.2
|
| 138 |
+
# via timm
|
| 139 |
+
scipy==1.16.2
|
| 140 |
+
# via sharp
|
| 141 |
+
setuptools==80.9.0
|
| 142 |
+
# via
|
| 143 |
+
# torch
|
| 144 |
+
# triton
|
| 145 |
+
six==1.17.0
|
| 146 |
+
# via python-dateutil
|
| 147 |
+
sympy==1.14.0
|
| 148 |
+
# via torch
|
| 149 |
+
timm==1.0.20
|
| 150 |
+
# via sharp
|
| 151 |
+
torch==2.8.0
|
| 152 |
+
# via
|
| 153 |
+
# gsplat
|
| 154 |
+
# sharp
|
| 155 |
+
# timm
|
| 156 |
+
# torchvision
|
| 157 |
+
torchvision==0.23.0
|
| 158 |
+
# via
|
| 159 |
+
# sharp
|
| 160 |
+
# timm
|
| 161 |
+
tqdm==4.67.1
|
| 162 |
+
# via huggingface-hub
|
| 163 |
+
triton==3.4.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
| 164 |
+
# via torch
|
| 165 |
+
typing-extensions==4.15.0
|
| 166 |
+
# via
|
| 167 |
+
# huggingface-hub
|
| 168 |
+
# torch
|
| 169 |
+
urllib3==2.6.0
|
| 170 |
+
# via requests
|
| 171 |
+
wadler-lindig==0.1.7
|
| 172 |
+
# via jaxtyping
|
src/sharp/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""For licensing see accompanying LICENSE file.
|
| 2 |
+
|
| 3 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 4 |
+
"""
|
src/sharp/cli/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Command-line-interface to run SHARP model.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import click
|
| 8 |
+
|
| 9 |
+
from . import predict, render
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@click.group()
|
| 13 |
+
def main_cli():
|
| 14 |
+
"""Run inference for SHARP model."""
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
main_cli.add_command(predict.predict_cli, "predict")
|
| 19 |
+
main_cli.add_command(render.render_cli, "render")
|
src/sharp/cli/predict.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains `sharp predict` CLI implementation.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import click
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import torch.utils.data
|
| 17 |
+
|
| 18 |
+
from sharp.models import (
|
| 19 |
+
PredictorParams,
|
| 20 |
+
RGBGaussianPredictor,
|
| 21 |
+
create_predictor,
|
| 22 |
+
)
|
| 23 |
+
from sharp.utils import io
|
| 24 |
+
from sharp.utils import logging as logging_utils
|
| 25 |
+
from sharp.utils.gaussians import (
|
| 26 |
+
Gaussians3D,
|
| 27 |
+
SceneMetaData,
|
| 28 |
+
save_ply,
|
| 29 |
+
unproject_gaussians,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
from .render import render_gaussians
|
| 33 |
+
|
| 34 |
+
LOGGER = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
DEFAULT_MODEL_URL = "https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@click.command()
|
| 40 |
+
@click.option(
|
| 41 |
+
"-i",
|
| 42 |
+
"--input-path",
|
| 43 |
+
type=click.Path(path_type=Path, exists=True),
|
| 44 |
+
help="Path to an image or containing a list of images.",
|
| 45 |
+
required=True,
|
| 46 |
+
)
|
| 47 |
+
@click.option(
|
| 48 |
+
"-o",
|
| 49 |
+
"--output-path",
|
| 50 |
+
type=click.Path(path_type=Path, file_okay=False),
|
| 51 |
+
help="Path to save the predicted Gaussians and renderings.",
|
| 52 |
+
required=True,
|
| 53 |
+
)
|
| 54 |
+
@click.option(
|
| 55 |
+
"-c",
|
| 56 |
+
"--checkpoint-path",
|
| 57 |
+
type=click.Path(path_type=Path, dir_okay=False),
|
| 58 |
+
default=None,
|
| 59 |
+
help="Path to the .pt checkpoint. If not provided, downloads the default model automatically.",
|
| 60 |
+
required=False,
|
| 61 |
+
)
|
| 62 |
+
@click.option(
|
| 63 |
+
"--render/--no-render",
|
| 64 |
+
"with_rendering",
|
| 65 |
+
is_flag=True,
|
| 66 |
+
default=False,
|
| 67 |
+
help="Whether to render trajectory for checkpoint.",
|
| 68 |
+
)
|
| 69 |
+
@click.option(
|
| 70 |
+
"--device",
|
| 71 |
+
type=str,
|
| 72 |
+
default="default",
|
| 73 |
+
help="Device to run on. ['cpu', 'mps', 'cuda']",
|
| 74 |
+
)
|
| 75 |
+
@click.option("-v", "--verbose", is_flag=True, help="Activate debug logs.")
|
| 76 |
+
def predict_cli(
|
| 77 |
+
input_path: Path,
|
| 78 |
+
output_path: Path,
|
| 79 |
+
checkpoint_path: Path,
|
| 80 |
+
with_rendering: bool,
|
| 81 |
+
device: str,
|
| 82 |
+
verbose: bool,
|
| 83 |
+
):
|
| 84 |
+
"""Predict Gaussians from input images."""
|
| 85 |
+
logging_utils.configure(logging.DEBUG if verbose else logging.INFO)
|
| 86 |
+
|
| 87 |
+
extensions = io.get_supported_image_extensions()
|
| 88 |
+
|
| 89 |
+
image_paths = []
|
| 90 |
+
if input_path.is_file():
|
| 91 |
+
if input_path.suffix in extensions:
|
| 92 |
+
image_paths = [input_path]
|
| 93 |
+
else:
|
| 94 |
+
for ext in extensions:
|
| 95 |
+
image_paths.extend(list(input_path.glob(f"**/*{ext}")))
|
| 96 |
+
|
| 97 |
+
if len(image_paths) == 0:
|
| 98 |
+
LOGGER.info("No valid images found. Input was %s.", input_path)
|
| 99 |
+
return
|
| 100 |
+
|
| 101 |
+
LOGGER.info("Processing %d valid image files.", len(image_paths))
|
| 102 |
+
|
| 103 |
+
if device == "default":
|
| 104 |
+
if torch.cuda.is_available():
|
| 105 |
+
device = "cuda"
|
| 106 |
+
elif torch.mps.is_available():
|
| 107 |
+
device = "mps"
|
| 108 |
+
else:
|
| 109 |
+
device = "cpu"
|
| 110 |
+
LOGGER.info("Using device %s", device)
|
| 111 |
+
|
| 112 |
+
if with_rendering and device != "cuda":
|
| 113 |
+
LOGGER.warning("Can only run rendering with gsplat on CUDA. Rendering is disabled.")
|
| 114 |
+
with_rendering = False
|
| 115 |
+
|
| 116 |
+
# Load or download checkpoint
|
| 117 |
+
if checkpoint_path is None:
|
| 118 |
+
LOGGER.info("No checkpoint provided. Downloading default model from %s", DEFAULT_MODEL_URL)
|
| 119 |
+
state_dict = torch.hub.load_state_dict_from_url(DEFAULT_MODEL_URL, progress=True)
|
| 120 |
+
else:
|
| 121 |
+
LOGGER.info("Loading checkpoint from %s", checkpoint_path)
|
| 122 |
+
state_dict = torch.load(checkpoint_path, weights_only=True)
|
| 123 |
+
|
| 124 |
+
gaussian_predictor = create_predictor(PredictorParams())
|
| 125 |
+
gaussian_predictor.load_state_dict(state_dict)
|
| 126 |
+
gaussian_predictor.eval()
|
| 127 |
+
gaussian_predictor.to(device)
|
| 128 |
+
|
| 129 |
+
output_path.mkdir(exist_ok=True, parents=True)
|
| 130 |
+
|
| 131 |
+
for image_path in image_paths:
|
| 132 |
+
LOGGER.info("Processing %s", image_path)
|
| 133 |
+
image, _, f_px = io.load_rgb(image_path)
|
| 134 |
+
height, width = image.shape[:2]
|
| 135 |
+
intrinsics = torch.tensor(
|
| 136 |
+
[
|
| 137 |
+
[f_px, 0, (width - 1) / 2.0, 0],
|
| 138 |
+
[0, f_px, (height - 1) / 2.0, 0],
|
| 139 |
+
[0, 0, 1, 0],
|
| 140 |
+
[0, 0, 0, 1],
|
| 141 |
+
],
|
| 142 |
+
device=device,
|
| 143 |
+
dtype=torch.float32,
|
| 144 |
+
)
|
| 145 |
+
gaussians = predict_image(gaussian_predictor, image, f_px, torch.device(device))
|
| 146 |
+
|
| 147 |
+
LOGGER.info("Saving 3DGS to %s", output_path)
|
| 148 |
+
save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply")
|
| 149 |
+
|
| 150 |
+
if with_rendering:
|
| 151 |
+
output_video_path = (output_path / image_path.stem).with_suffix(".mp4")
|
| 152 |
+
LOGGER.info("Rendering trajectory to %s", output_video_path)
|
| 153 |
+
|
| 154 |
+
metadata = SceneMetaData(intrinsics[0, 0].item(), (width, height), "linearRGB")
|
| 155 |
+
render_gaussians(gaussians, metadata, output_video_path)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@torch.no_grad()
|
| 159 |
+
def predict_image(
|
| 160 |
+
predictor: RGBGaussianPredictor,
|
| 161 |
+
image: np.ndarray,
|
| 162 |
+
f_px: float,
|
| 163 |
+
device: torch.device,
|
| 164 |
+
) -> Gaussians3D:
|
| 165 |
+
"""Predict Gaussians from an image."""
|
| 166 |
+
internal_shape = (1536, 1536)
|
| 167 |
+
|
| 168 |
+
LOGGER.info("Running preprocessing.")
|
| 169 |
+
image_pt = torch.from_numpy(image.copy()).float().to(device).permute(2, 0, 1) / 255.0
|
| 170 |
+
_, height, width = image_pt.shape
|
| 171 |
+
disparity_factor = torch.tensor([f_px / width]).float().to(device)
|
| 172 |
+
|
| 173 |
+
image_resized_pt = F.interpolate(
|
| 174 |
+
image_pt[None],
|
| 175 |
+
size=(internal_shape[1], internal_shape[0]),
|
| 176 |
+
mode="bilinear",
|
| 177 |
+
align_corners=True,
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Predict Gaussians in the NDC space.
|
| 181 |
+
LOGGER.info("Running inference.")
|
| 182 |
+
gaussians_ndc = predictor(image_resized_pt, disparity_factor)
|
| 183 |
+
|
| 184 |
+
LOGGER.info("Running postprocessing.")
|
| 185 |
+
intrinsics = (
|
| 186 |
+
torch.tensor(
|
| 187 |
+
[
|
| 188 |
+
[f_px, 0, width / 2, 0],
|
| 189 |
+
[0, f_px, height / 2, 0],
|
| 190 |
+
[0, 0, 1, 0],
|
| 191 |
+
[0, 0, 0, 1],
|
| 192 |
+
]
|
| 193 |
+
)
|
| 194 |
+
.float()
|
| 195 |
+
.to(device)
|
| 196 |
+
)
|
| 197 |
+
intrinsics_resized = intrinsics.clone()
|
| 198 |
+
intrinsics_resized[0] *= internal_shape[0] / width
|
| 199 |
+
intrinsics_resized[1] *= internal_shape[1] / height
|
| 200 |
+
|
| 201 |
+
# Convert Gaussians to metrics space.
|
| 202 |
+
gaussians = unproject_gaussians(
|
| 203 |
+
gaussians_ndc, torch.eye(4).to(device), intrinsics_resized, internal_shape
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
return gaussians
|
src/sharp/cli/render.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains `sharp render` CLI implementation.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import click
|
| 13 |
+
import torch
|
| 14 |
+
import torch.utils.data
|
| 15 |
+
|
| 16 |
+
from sharp.utils import camera, gsplat, io
|
| 17 |
+
from sharp.utils import logging as logging_utils
|
| 18 |
+
from sharp.utils.gaussians import Gaussians3D, SceneMetaData, load_ply
|
| 19 |
+
|
| 20 |
+
LOGGER = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@click.command()
|
| 24 |
+
@click.option(
|
| 25 |
+
"-i",
|
| 26 |
+
"--input-path",
|
| 27 |
+
type=click.Path(exists=True, path_type=Path),
|
| 28 |
+
help="Path to the ply or a list of plys.",
|
| 29 |
+
required=True,
|
| 30 |
+
)
|
| 31 |
+
@click.option(
|
| 32 |
+
"-o",
|
| 33 |
+
"--output-path",
|
| 34 |
+
type=click.Path(path_type=Path, file_okay=False),
|
| 35 |
+
help="Path to save the rendered videos.",
|
| 36 |
+
required=True,
|
| 37 |
+
)
|
| 38 |
+
@click.option("-v", "--verbose", is_flag=True, help="Activate debug logs.")
|
| 39 |
+
def render_cli(input_path: Path, output_path: Path, verbose: bool):
|
| 40 |
+
"""Predict Gaussians from input images."""
|
| 41 |
+
logging_utils.configure(logging.DEBUG if verbose else logging.INFO)
|
| 42 |
+
|
| 43 |
+
if not torch.cuda.is_available():
|
| 44 |
+
LOGGER.error("Rendering a checkpoint requires CUDA.")
|
| 45 |
+
exit(1)
|
| 46 |
+
|
| 47 |
+
output_path.mkdir(exist_ok=True, parents=True)
|
| 48 |
+
|
| 49 |
+
params = camera.TrajectoryParams()
|
| 50 |
+
|
| 51 |
+
if input_path.suffix == ".ply":
|
| 52 |
+
scene_paths = [input_path]
|
| 53 |
+
elif input_path.is_dir():
|
| 54 |
+
scene_paths = list(input_path.glob("*.ply"))
|
| 55 |
+
else:
|
| 56 |
+
LOGGER.error("Input path must be either directory or single PLY file.")
|
| 57 |
+
exit(1)
|
| 58 |
+
|
| 59 |
+
for scene_path in scene_paths:
|
| 60 |
+
LOGGER.info("Rendering %s", scene_path)
|
| 61 |
+
gaussians, metadata = load_ply(scene_path)
|
| 62 |
+
render_gaussians(
|
| 63 |
+
gaussians=gaussians,
|
| 64 |
+
metadata=metadata,
|
| 65 |
+
params=params,
|
| 66 |
+
output_path=(output_path / scene_path.stem).with_suffix(".mp4"),
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def render_gaussians(
|
| 71 |
+
gaussians: Gaussians3D,
|
| 72 |
+
metadata: SceneMetaData,
|
| 73 |
+
output_path: Path,
|
| 74 |
+
params: camera.TrajectoryParams | None = None,
|
| 75 |
+
) -> None:
|
| 76 |
+
"""Render a single gaussian checkpoint file."""
|
| 77 |
+
(width, height) = metadata.resolution_px
|
| 78 |
+
f_px = metadata.focal_length_px
|
| 79 |
+
|
| 80 |
+
if params is None:
|
| 81 |
+
params = camera.TrajectoryParams()
|
| 82 |
+
|
| 83 |
+
if not torch.cuda.is_available():
|
| 84 |
+
raise RuntimeError("Rendering a checkpoint requires CUDA.")
|
| 85 |
+
|
| 86 |
+
device = torch.device("cuda")
|
| 87 |
+
|
| 88 |
+
intrinsics = torch.tensor(
|
| 89 |
+
[
|
| 90 |
+
[f_px, 0, (width - 1) / 2., 0],
|
| 91 |
+
[0, f_px, (height - 1) / 2., 0],
|
| 92 |
+
[0, 0, 1, 0],
|
| 93 |
+
[0, 0, 0, 1],
|
| 94 |
+
],
|
| 95 |
+
device=device,
|
| 96 |
+
dtype=torch.float32,
|
| 97 |
+
)
|
| 98 |
+
camera_model = camera.create_camera_model(
|
| 99 |
+
gaussians, intrinsics, resolution_px=metadata.resolution_px
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
trajectory = camera.create_eye_trajectory(
|
| 103 |
+
gaussians, params, resolution_px=metadata.resolution_px, f_px=f_px
|
| 104 |
+
)
|
| 105 |
+
renderer = gsplat.GSplatRenderer(color_space=metadata.color_space)
|
| 106 |
+
video_writer = io.VideoWriter(output_path)
|
| 107 |
+
|
| 108 |
+
for _, eye_position in enumerate(trajectory):
|
| 109 |
+
camera_info = camera_model.compute(eye_position)
|
| 110 |
+
rendering_output = renderer(
|
| 111 |
+
gaussians.to(device),
|
| 112 |
+
extrinsics=camera_info.extrinsics[None].to(device),
|
| 113 |
+
intrinsics=camera_info.intrinsics[None].to(device),
|
| 114 |
+
image_width=camera_info.width,
|
| 115 |
+
image_height=camera_info.height,
|
| 116 |
+
)
|
| 117 |
+
color = (rendering_output.color[0].permute(1, 2, 0) * 255.0).to(dtype=torch.uint8)
|
| 118 |
+
depth = rendering_output.depth[0]
|
| 119 |
+
video_writer.add_frame(color, depth)
|
| 120 |
+
video_writer.close()
|
src/sharp/models/__init__.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains different Gaussian predictors.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from sharp.models.monodepth import (
|
| 10 |
+
create_monodepth_adaptor,
|
| 11 |
+
create_monodepth_dpt,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from .alignment import create_alignment
|
| 15 |
+
from .composer import GaussianComposer
|
| 16 |
+
from .gaussian_decoder import create_gaussian_decoder
|
| 17 |
+
from .heads import DirectPredictionHead
|
| 18 |
+
from .initializer import create_initializer
|
| 19 |
+
from .params import PredictorParams
|
| 20 |
+
from .predictor import RGBGaussianPredictor
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def create_predictor(params: PredictorParams) -> RGBGaussianPredictor:
|
| 24 |
+
"""Create gaussian predictor model specified by name."""
|
| 25 |
+
if params.gaussian_decoder.stride < params.initializer.stride:
|
| 26 |
+
raise ValueError(
|
| 27 |
+
"We donot expected gaussian_decoder has higher resolution than initializer."
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
scale_factor = params.gaussian_decoder.stride // params.initializer.stride
|
| 31 |
+
gaussian_composer = GaussianComposer(
|
| 32 |
+
delta_factor=params.delta_factor,
|
| 33 |
+
min_scale=params.min_scale,
|
| 34 |
+
max_scale=params.max_scale,
|
| 35 |
+
color_activation_type=params.color_activation_type,
|
| 36 |
+
opacity_activation_type=params.opacity_activation_type,
|
| 37 |
+
color_space=params.color_space,
|
| 38 |
+
scale_factor=scale_factor,
|
| 39 |
+
base_scale_on_predicted_mean=params.base_scale_on_predicted_mean,
|
| 40 |
+
)
|
| 41 |
+
if params.num_monodepth_layers > 1 and params.initializer.num_layers != 2:
|
| 42 |
+
raise KeyError("We only support num_layers = 2 when num_monodepth_layers > 1.")
|
| 43 |
+
|
| 44 |
+
monodepth_model = create_monodepth_dpt(params.monodepth)
|
| 45 |
+
monodepth_adaptor = create_monodepth_adaptor(
|
| 46 |
+
monodepth_model,
|
| 47 |
+
params.monodepth_adaptor,
|
| 48 |
+
params.num_monodepth_layers,
|
| 49 |
+
params.sorting_monodepth,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
if params.num_monodepth_layers == 2:
|
| 53 |
+
monodepth_adaptor.replicate_head(params.num_monodepth_layers)
|
| 54 |
+
|
| 55 |
+
gaussian_decoder = create_gaussian_decoder(
|
| 56 |
+
params.gaussian_decoder,
|
| 57 |
+
dims_depth_features=monodepth_adaptor.get_feature_dims(),
|
| 58 |
+
)
|
| 59 |
+
initializer = create_initializer(
|
| 60 |
+
params.initializer,
|
| 61 |
+
)
|
| 62 |
+
prediction_head = DirectPredictionHead(
|
| 63 |
+
feature_dim=gaussian_decoder.dim_out, num_layers=initializer.num_layers
|
| 64 |
+
)
|
| 65 |
+
decoder_dim = monodepth_model.decoder.dims_decoder[-1]
|
| 66 |
+
return RGBGaussianPredictor(
|
| 67 |
+
init_model=initializer,
|
| 68 |
+
feature_model=gaussian_decoder,
|
| 69 |
+
prediction_head=prediction_head,
|
| 70 |
+
monodepth_model=monodepth_adaptor,
|
| 71 |
+
gaussian_composer=gaussian_composer,
|
| 72 |
+
scale_map_estimator=create_alignment(params.depth_alignment, depth_decoder_dim=decoder_dim),
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
__all__ = [
|
| 77 |
+
"PredictorParams",
|
| 78 |
+
"create_predictor",
|
| 79 |
+
]
|
src/sharp/models/alignment.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains modules for different types of alignment.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
from sharp.models.decoders import UNetDecoder
|
| 16 |
+
from sharp.models.encoders import UNetEncoder
|
| 17 |
+
from sharp.utils import math as math_utils
|
| 18 |
+
|
| 19 |
+
from .params import AlignmentParams
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def create_alignment(
|
| 23 |
+
params: AlignmentParams, depth_decoder_dim: int | None = None
|
| 24 |
+
) -> nn.Module | None:
|
| 25 |
+
"""Create depth alignment."""
|
| 26 |
+
if depth_decoder_dim is None:
|
| 27 |
+
raise ValueError("Requires depth_decoder_dim for LearnedAlignment.")
|
| 28 |
+
alignment = LearnedAlignment(
|
| 29 |
+
depth_decoder_features=params.depth_decoder_features,
|
| 30 |
+
depth_decoder_dim=depth_decoder_dim,
|
| 31 |
+
steps=params.steps,
|
| 32 |
+
stride=params.stride,
|
| 33 |
+
base_width=params.base_width,
|
| 34 |
+
activation_type=params.activation_type,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
if params.frozen:
|
| 38 |
+
alignment.requires_grad_(False)
|
| 39 |
+
|
| 40 |
+
return alignment
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class LearnedAlignment(nn.Module):
|
| 44 |
+
"""Aligns tensors using a UNet."""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
steps: int = 4,
|
| 49 |
+
stride: int = 8,
|
| 50 |
+
base_width: int = 16,
|
| 51 |
+
depth_decoder_features: bool = False,
|
| 52 |
+
depth_decoder_dim: int = 256,
|
| 53 |
+
activation_type: math_utils.ActivationType = "exp",
|
| 54 |
+
) -> None:
|
| 55 |
+
"""Initialize LearnedAlignment.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
steps: Number of steps in the UNet.
|
| 59 |
+
stride: Effective downsampling of the alignment module.
|
| 60 |
+
base_width: Base width of the UNet.
|
| 61 |
+
depth_decoder_features: Whether to use depth decoder features.
|
| 62 |
+
depth_decoder_dim: Dimension of the depth decoder features.
|
| 63 |
+
activation_type: Activation type for the alignment output.
|
| 64 |
+
"""
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.activation = math_utils.create_activation_pair(activation_type)
|
| 67 |
+
bias_value = self.activation.inverse(torch.tensor(1.0))
|
| 68 |
+
|
| 69 |
+
self.depth_decoder_features = depth_decoder_features
|
| 70 |
+
if depth_decoder_features:
|
| 71 |
+
dim_in = 2 + depth_decoder_dim
|
| 72 |
+
else:
|
| 73 |
+
dim_in = 2
|
| 74 |
+
|
| 75 |
+
def is_power_of_two(n: int) -> bool:
|
| 76 |
+
"""Check if a number is a power of two."""
|
| 77 |
+
if n <= 0:
|
| 78 |
+
return False
|
| 79 |
+
return (n & (n - 1)) == 0
|
| 80 |
+
|
| 81 |
+
if not is_power_of_two(stride):
|
| 82 |
+
raise ValueError(f"Stride {stride} is not a power of two.")
|
| 83 |
+
|
| 84 |
+
steps_decoder = steps - int(math.log2(stride))
|
| 85 |
+
if steps_decoder < 1:
|
| 86 |
+
raise ValueError(f"{steps_decoder} must be greater or equal to 1.")
|
| 87 |
+
widths = [min(base_width << i, 1024) for i in range(steps + 1)]
|
| 88 |
+
self.encoder = UNetEncoder(dim_in=dim_in, width=widths, steps=steps, norm_num_groups=4)
|
| 89 |
+
self.decoder = UNetDecoder(
|
| 90 |
+
dim_out=widths[0], width=widths, steps=steps_decoder, norm_num_groups=4
|
| 91 |
+
)
|
| 92 |
+
self.conv_out = nn.Conv2d(widths[0], 1, 1, bias=True)
|
| 93 |
+
nn.init.zeros_(self.conv_out.weight)
|
| 94 |
+
nn.init.constant_(self.conv_out.bias, bias_value)
|
| 95 |
+
|
| 96 |
+
def forward(
|
| 97 |
+
self,
|
| 98 |
+
tensor_src: torch.Tensor,
|
| 99 |
+
tensor_tgt: torch.Tensor,
|
| 100 |
+
depth_decoder_features: torch.Tensor | None = None,
|
| 101 |
+
) -> torch.Tensor:
|
| 102 |
+
"""Compute alignment map."""
|
| 103 |
+
# Since the tensors are usually given by depth which is >= 1.0, we invert
|
| 104 |
+
# the tensors to have them in a reasonable range.
|
| 105 |
+
tensor_src = 1.0 / tensor_src.clamp(min=1e-4)
|
| 106 |
+
tensor_tgt = 1.0 / tensor_tgt.clamp(min=1e-4)
|
| 107 |
+
tensor_input = torch.cat([tensor_src, tensor_tgt], dim=1)
|
| 108 |
+
if self.depth_decoder_features:
|
| 109 |
+
height, width = tensor_src.shape[-2:]
|
| 110 |
+
upsampled_encodings = F.interpolate(
|
| 111 |
+
depth_decoder_features,
|
| 112 |
+
size=(height, width),
|
| 113 |
+
mode="bilinear",
|
| 114 |
+
)
|
| 115 |
+
tensor_input = torch.cat([tensor_input, upsampled_encodings], dim=1)
|
| 116 |
+
features = self.encoder(tensor_input)
|
| 117 |
+
output = self.conv_out(self.decoder(features))
|
| 118 |
+
alignment_map_lowres = self.activation.forward(output)
|
| 119 |
+
if alignment_map_lowres.shape[-2:] != tensor_src.shape[-2]:
|
| 120 |
+
alignment_map = F.interpolate(
|
| 121 |
+
alignment_map_lowres,
|
| 122 |
+
size=tensor_src.shape[-2:],
|
| 123 |
+
mode="bilinear",
|
| 124 |
+
align_corners=False,
|
| 125 |
+
)
|
| 126 |
+
return alignment_map
|
src/sharp/models/blocks.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains reusable network components.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Literal
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
NormLayerName = Literal["noop", "batch_norm", "group_norm", "instance_norm"]
|
| 15 |
+
UpsamplingMode = Literal["transposed_conv", "nearest", "bilinear"]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def norm_layer_2d(num_features: int, norm_type: NormLayerName, num_groups: int = 8) -> nn.Module:
|
| 19 |
+
"""Create normalization layer."""
|
| 20 |
+
if norm_type == "noop":
|
| 21 |
+
return nn.Identity()
|
| 22 |
+
elif norm_type == "batch_norm":
|
| 23 |
+
return nn.BatchNorm2d(num_features=num_features)
|
| 24 |
+
elif norm_type == "group_norm":
|
| 25 |
+
return nn.GroupNorm(num_channels=num_features, num_groups=num_groups)
|
| 26 |
+
elif norm_type == "instance_norm":
|
| 27 |
+
return nn.InstanceNorm2d(num_features=num_features)
|
| 28 |
+
else:
|
| 29 |
+
raise ValueError(f"Invalid normalization layer type: {norm_type}")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def upsampling_layer(upsampling_mode: UpsamplingMode, scale_factor: int, dim_in: int) -> nn.Module:
|
| 33 |
+
"""Create upsampling layer."""
|
| 34 |
+
if upsampling_mode == "transposed_conv":
|
| 35 |
+
return nn.ConvTranspose2d(
|
| 36 |
+
in_channels=dim_in,
|
| 37 |
+
out_channels=dim_in,
|
| 38 |
+
kernel_size=scale_factor,
|
| 39 |
+
stride=scale_factor,
|
| 40 |
+
padding=0,
|
| 41 |
+
bias=False,
|
| 42 |
+
)
|
| 43 |
+
elif upsampling_mode in ("nearest", "bilinear"):
|
| 44 |
+
return nn.Upsample(scale_factor=scale_factor, mode=upsampling_mode)
|
| 45 |
+
else:
|
| 46 |
+
raise ValueError(f"Invalid upsampling mode {upsampling_mode}.")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ResidualBlock(nn.Module):
|
| 50 |
+
"""Generic implementation of residual blocks.
|
| 51 |
+
|
| 52 |
+
This implements a generic residual block from
|
| 53 |
+
|
| 54 |
+
He et al. - Identity Mappings in Deep Residual Networks (2016),
|
| 55 |
+
https://arxiv.org/abs/1603.05027
|
| 56 |
+
|
| 57 |
+
which can be further customized via factory functions.
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
|
| 61 |
+
"""Initialize ResidualBlock."""
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.residual = residual
|
| 64 |
+
self.shortcut = shortcut
|
| 65 |
+
|
| 66 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 67 |
+
"""Apply residual block."""
|
| 68 |
+
delta_x = self.residual(x)
|
| 69 |
+
|
| 70 |
+
if self.shortcut is not None:
|
| 71 |
+
x = self.shortcut(x)
|
| 72 |
+
|
| 73 |
+
return x + delta_x
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def residual_block_2d(
|
| 77 |
+
dim_in: int,
|
| 78 |
+
dim_out: int,
|
| 79 |
+
dim_hidden: int | None = None,
|
| 80 |
+
actvn: nn.Module | None = None,
|
| 81 |
+
norm_type: NormLayerName = "noop",
|
| 82 |
+
norm_num_groups: int = 8,
|
| 83 |
+
dilation: int = 1,
|
| 84 |
+
kernel_size: int = 3,
|
| 85 |
+
):
|
| 86 |
+
"""Create a simple 2D residual block."""
|
| 87 |
+
if actvn is None:
|
| 88 |
+
actvn = nn.ReLU()
|
| 89 |
+
|
| 90 |
+
if dim_hidden is None:
|
| 91 |
+
dim_hidden = dim_out // 2
|
| 92 |
+
|
| 93 |
+
# Padding to maintain output size
|
| 94 |
+
# See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
| 95 |
+
padding = (dilation * (kernel_size - 1)) // 2
|
| 96 |
+
|
| 97 |
+
def _create_block(dim_in: int, dim_out: int) -> list[nn.Module]:
|
| 98 |
+
layers = [
|
| 99 |
+
norm_layer_2d(dim_in, norm_type, num_groups=norm_num_groups),
|
| 100 |
+
actvn,
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
layers.append(
|
| 104 |
+
nn.Conv2d(
|
| 105 |
+
dim_in,
|
| 106 |
+
dim_out,
|
| 107 |
+
kernel_size=kernel_size,
|
| 108 |
+
stride=1,
|
| 109 |
+
dilation=dilation,
|
| 110 |
+
padding=padding,
|
| 111 |
+
)
|
| 112 |
+
)
|
| 113 |
+
return layers
|
| 114 |
+
|
| 115 |
+
residual = nn.Sequential(
|
| 116 |
+
*_create_block(dim_in, dim_hidden),
|
| 117 |
+
*_create_block(dim_hidden, dim_out),
|
| 118 |
+
)
|
| 119 |
+
shortcut = None
|
| 120 |
+
|
| 121 |
+
if dim_in != dim_out:
|
| 122 |
+
shortcut = nn.Conv2d(dim_in, dim_out, 1)
|
| 123 |
+
|
| 124 |
+
return ResidualBlock(residual, shortcut)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class FeatureFusionBlock2d(nn.Module):
|
| 128 |
+
"""Feature fusion for DPT."""
|
| 129 |
+
|
| 130 |
+
# We use the name "deconv" for backward compatibility. However, "deconv" can also
|
| 131 |
+
# refer to some other upsampling layer or a no-op.
|
| 132 |
+
deconv: nn.Module
|
| 133 |
+
|
| 134 |
+
def __init__(
|
| 135 |
+
self,
|
| 136 |
+
dim_in: int,
|
| 137 |
+
dim_out: int | None = None,
|
| 138 |
+
upsampling_mode: UpsamplingMode | None = None,
|
| 139 |
+
batch_norm: bool = False,
|
| 140 |
+
):
|
| 141 |
+
"""Initialize feature fusion block.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
dim_in: Dimensions of input.
|
| 145 |
+
dim_out: Dimensions of output.
|
| 146 |
+
batch_norm: Whether to use batch normalization in resnet blocks.
|
| 147 |
+
upsampling_mode: What mode to use for upsampling. None if no upsampling
|
| 148 |
+
is required.
|
| 149 |
+
"""
|
| 150 |
+
super().__init__()
|
| 151 |
+
if dim_out is None:
|
| 152 |
+
dim_out = dim_in
|
| 153 |
+
self.resnet1 = self._residual_block(dim_in, batch_norm)
|
| 154 |
+
self.resnet2 = self._residual_block(dim_in, batch_norm)
|
| 155 |
+
|
| 156 |
+
if upsampling_mode is not None:
|
| 157 |
+
self.deconv = upsampling_layer(upsampling_mode, scale_factor=2, dim_in=dim_in)
|
| 158 |
+
else:
|
| 159 |
+
self.deconv = nn.Sequential()
|
| 160 |
+
|
| 161 |
+
self.out_conv = nn.Conv2d(
|
| 162 |
+
dim_in,
|
| 163 |
+
dim_out,
|
| 164 |
+
kernel_size=1,
|
| 165 |
+
stride=1,
|
| 166 |
+
padding=0,
|
| 167 |
+
bias=True,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 171 |
+
|
| 172 |
+
def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
|
| 173 |
+
"""Process and fuse input features."""
|
| 174 |
+
x = x0
|
| 175 |
+
|
| 176 |
+
if x1 is not None:
|
| 177 |
+
res = self.resnet1(x1)
|
| 178 |
+
x = self.skip_add.add(x, res)
|
| 179 |
+
|
| 180 |
+
x = self.resnet2(x)
|
| 181 |
+
x = self.deconv(x)
|
| 182 |
+
x = self.out_conv(x)
|
| 183 |
+
|
| 184 |
+
return x
|
| 185 |
+
|
| 186 |
+
@staticmethod
|
| 187 |
+
def _residual_block(num_features: int, batch_norm: bool):
|
| 188 |
+
"""Create a residual block."""
|
| 189 |
+
|
| 190 |
+
def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
|
| 191 |
+
layers = [
|
| 192 |
+
nn.ReLU(False),
|
| 193 |
+
nn.Conv2d(
|
| 194 |
+
num_features,
|
| 195 |
+
num_features,
|
| 196 |
+
kernel_size=3,
|
| 197 |
+
stride=1,
|
| 198 |
+
padding=1,
|
| 199 |
+
bias=not batch_norm,
|
| 200 |
+
),
|
| 201 |
+
]
|
| 202 |
+
if batch_norm:
|
| 203 |
+
layers.append(nn.BatchNorm2d(dim))
|
| 204 |
+
return layers
|
| 205 |
+
|
| 206 |
+
residual = nn.Sequential(
|
| 207 |
+
*_create_block(dim=num_features, batch_norm=batch_norm),
|
| 208 |
+
*_create_block(dim=num_features, batch_norm=batch_norm),
|
| 209 |
+
)
|
| 210 |
+
return ResidualBlock(residual)
|
src/sharp/models/composer.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Defines module to compose final Gaussians from base values and delta values.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
|
| 13 |
+
from sharp.models.initializer import GaussianBaseValues
|
| 14 |
+
from sharp.utils import math as math_utils
|
| 15 |
+
from sharp.utils.color_space import ColorSpace, sRGB2linearRGB
|
| 16 |
+
from sharp.utils.gaussians import Gaussians3D
|
| 17 |
+
|
| 18 |
+
from .params import DeltaFactor
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _get_scale_activation_constant(max_scale: float, min_scale: float) -> tuple[float, float]:
|
| 22 |
+
"""Return constants for scale activation function."""
|
| 23 |
+
# To ensure for delta = 0, the value of scale_factor is 1 and the gradient is 1.
|
| 24 |
+
constant_a = (max_scale - min_scale) / (1 - min_scale) / (max_scale - 1)
|
| 25 |
+
constant_b = math_utils.inverse_sigmoid(
|
| 26 |
+
torch.tensor((1.0 - min_scale) / (max_scale - min_scale))
|
| 27 |
+
).item()
|
| 28 |
+
return constant_a, constant_b
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class GaussianComposer(nn.Module):
|
| 32 |
+
"""Converts base values and deltas into Gaussians."""
|
| 33 |
+
|
| 34 |
+
color_activation_type: math_utils.ActivationType
|
| 35 |
+
opacity_activation_type: math_utils.ActivationType
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
delta_factor: DeltaFactor,
|
| 40 |
+
min_scale: float,
|
| 41 |
+
max_scale: float,
|
| 42 |
+
color_activation_type: math_utils.ActivationType,
|
| 43 |
+
opacity_activation_type: math_utils.ActivationType,
|
| 44 |
+
color_space: ColorSpace,
|
| 45 |
+
base_scale_on_predicted_mean: bool,
|
| 46 |
+
scale_factor: int = 1,
|
| 47 |
+
) -> None:
|
| 48 |
+
"""Initialize GaussianComposer.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
delta_factor: Multiply delta offsets by this factor.
|
| 52 |
+
min_scale: The minimal scale factor for gaussian scale activation.
|
| 53 |
+
max_scale: The maximal scale factor for gaussian scale activation.
|
| 54 |
+
color_activation_type: Which activation function to use for colors.
|
| 55 |
+
opacity_activation_type: Which activation function to use for opacities.
|
| 56 |
+
color_space: Which color space is used in training.
|
| 57 |
+
scale_factor: The scale factor to upsample the delta_values before composition.
|
| 58 |
+
base_scale_on_predicted_mean: Whether to account z offsets for estimating base scale.
|
| 59 |
+
"""
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.delta_factor = delta_factor
|
| 62 |
+
self.max_scale = max_scale
|
| 63 |
+
self.min_scale = min_scale
|
| 64 |
+
self.color_activation_type = color_activation_type
|
| 65 |
+
self.opacity_activation_type = opacity_activation_type
|
| 66 |
+
self.color_space = color_space
|
| 67 |
+
self.scale_factor = scale_factor
|
| 68 |
+
self.base_scale_on_predicted_mean = base_scale_on_predicted_mean
|
| 69 |
+
|
| 70 |
+
def upsample_delta_value(self, delta: torch.Tensor, scale_factor: int = 1):
|
| 71 |
+
"""Upsample the delta value.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
delta: The delta values predicted by gaussian predictor.
|
| 75 |
+
scale_factor: The scale factor to upsample the delta_values.
|
| 76 |
+
"""
|
| 77 |
+
(
|
| 78 |
+
batch_size,
|
| 79 |
+
num_channels,
|
| 80 |
+
num_layers,
|
| 81 |
+
image_height,
|
| 82 |
+
image_width,
|
| 83 |
+
) = delta.shape
|
| 84 |
+
new_height = image_height * scale_factor
|
| 85 |
+
new_width = image_width * scale_factor
|
| 86 |
+
upsampled_delta = F.interpolate(
|
| 87 |
+
delta.view(batch_size, num_channels * num_layers, image_height, image_width),
|
| 88 |
+
scale_factor=scale_factor,
|
| 89 |
+
).view(batch_size, num_channels, num_layers, new_height, new_width)
|
| 90 |
+
return upsampled_delta
|
| 91 |
+
|
| 92 |
+
def forward(
|
| 93 |
+
self,
|
| 94 |
+
delta: torch.Tensor,
|
| 95 |
+
base_values: GaussianBaseValues,
|
| 96 |
+
global_scale: torch.Tensor | None = None,
|
| 97 |
+
flatten_output: bool = True,
|
| 98 |
+
) -> Gaussians3D:
|
| 99 |
+
"""Combine predicted delta values with base gaussian values and apply activation function.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
delta: The delta values predicted by gaussian predictor.
|
| 103 |
+
base_values: The gaussian base values.
|
| 104 |
+
global_scale: Global scale of Gaussians.
|
| 105 |
+
flatten_output: Flatten the gaussian parameters.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
The computed 3D Gaussians.
|
| 109 |
+
"""
|
| 110 |
+
# Upsample the delta if delta and base_values have different strides.
|
| 111 |
+
scale_factor = self.scale_factor
|
| 112 |
+
# For triplane head, the delta has already been upsampled.
|
| 113 |
+
actual_scale_factor = base_values.mean_x_ndc.shape[-1] // delta.shape[-1]
|
| 114 |
+
if scale_factor != 1 and actual_scale_factor != 1:
|
| 115 |
+
delta = self.upsample_delta_value(delta, scale_factor)
|
| 116 |
+
|
| 117 |
+
mean_vectors = self._forward_mean(base_values, delta)
|
| 118 |
+
|
| 119 |
+
# Account for the change in base scale due to z offsets.
|
| 120 |
+
base_scales = (
|
| 121 |
+
(base_values.scales * base_values.mean_inverse_z_ndc * mean_vectors[:, 2:3, ...])
|
| 122 |
+
if self.base_scale_on_predicted_mean
|
| 123 |
+
else base_values.scales
|
| 124 |
+
)
|
| 125 |
+
singular_values = self._scale_activation(
|
| 126 |
+
base_scales,
|
| 127 |
+
delta[:, 3:6],
|
| 128 |
+
self.min_scale,
|
| 129 |
+
self.max_scale,
|
| 130 |
+
)
|
| 131 |
+
quaternions = self._quaternion_activation(base_values.quaternions, delta[:, 6:10])
|
| 132 |
+
colors = self._color_activation(base_values.colors, delta[:, 10:13])
|
| 133 |
+
opacities = self._opacity_activation(base_values.opacities, delta[:, 13])
|
| 134 |
+
|
| 135 |
+
if flatten_output:
|
| 136 |
+
# [B, C, N, H, W] -> [B, N, H, W, C].
|
| 137 |
+
# NOTE: opacities is [B, N, H, W] so it doesn't need to permute.
|
| 138 |
+
mean_vectors = mean_vectors.permute(0, 2, 3, 4, 1).flatten(1, 3)
|
| 139 |
+
singular_values = singular_values.permute(0, 2, 3, 4, 1).flatten(1, 3)
|
| 140 |
+
quaternions = quaternions.permute(0, 2, 3, 4, 1).flatten(1, 3)
|
| 141 |
+
colors = colors.permute(0, 2, 3, 4, 1).flatten(1, 3)
|
| 142 |
+
opacities = opacities.flatten(1, 3)
|
| 143 |
+
|
| 144 |
+
# Apply global scaling to convert Gaussians to metric space.
|
| 145 |
+
if global_scale is not None:
|
| 146 |
+
mean_vectors = global_scale[:, None, None] * mean_vectors
|
| 147 |
+
singular_values = global_scale[:, None, None] * singular_values
|
| 148 |
+
|
| 149 |
+
return Gaussians3D(
|
| 150 |
+
mean_vectors=mean_vectors,
|
| 151 |
+
singular_values=singular_values,
|
| 152 |
+
quaternions=quaternions,
|
| 153 |
+
colors=colors,
|
| 154 |
+
opacities=opacities,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
def _forward_mean(self, base_values: GaussianBaseValues, delta: torch.Tensor) -> torch.Tensor:
|
| 158 |
+
# Concatenate base vectors and apply mean activation.
|
| 159 |
+
delta_factor = torch.tensor(
|
| 160 |
+
[self.delta_factor.xy, self.delta_factor.xy, self.delta_factor.z],
|
| 161 |
+
device=delta.device,
|
| 162 |
+
)[None, :, None, None, None]
|
| 163 |
+
|
| 164 |
+
dtype = base_values.mean_x_ndc.dtype
|
| 165 |
+
device = base_values.mean_x_ndc.device
|
| 166 |
+
target_shape = (1, 3, 1, 1, 1)
|
| 167 |
+
mean_x_mask = torch.tensor([1.0, 0.0, 0.0], dtype=dtype, device=device).reshape(
|
| 168 |
+
target_shape
|
| 169 |
+
)
|
| 170 |
+
mean_y_mask = torch.tensor([0.0, 1.0, 0.0], dtype=dtype, device=device).reshape(
|
| 171 |
+
target_shape
|
| 172 |
+
)
|
| 173 |
+
mean_z_mask = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device).reshape(
|
| 174 |
+
target_shape
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
mean_vectors_ndc = (
|
| 178 |
+
base_values.mean_x_ndc.repeat(target_shape) * mean_x_mask
|
| 179 |
+
+ base_values.mean_y_ndc.repeat(target_shape) * mean_y_mask
|
| 180 |
+
+ base_values.mean_inverse_z_ndc.repeat(target_shape) * mean_z_mask
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
mean_vectors = self._mean_activation(mean_vectors_ndc, delta_factor * delta[:, :3])
|
| 184 |
+
return mean_vectors
|
| 185 |
+
|
| 186 |
+
def _mean_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
|
| 187 |
+
"""Mean activation function.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
base: Tensor of shape [B, 3, H, W], where first two feature dimensions
|
| 191 |
+
(x,y) are in normalized device coordinates (NDC) where (-1, -1) is
|
| 192 |
+
the top, while the third dimension is inverse depth.
|
| 193 |
+
learned_delta: Tensor of shape [B, 3, H, W] with predicted delta values.
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
Returns: The final mean vector after combining base and delta and applying nonlinearies.
|
| 197 |
+
"""
|
| 198 |
+
xx = base[:, 0:1] + learned_delta[:, 0:1]
|
| 199 |
+
yy = base[:, 1:2] + learned_delta[:, 1:2]
|
| 200 |
+
|
| 201 |
+
a = base[:, 2:3]
|
| 202 |
+
b = learned_delta[:, 2:3]
|
| 203 |
+
|
| 204 |
+
# Original formula:
|
| 205 |
+
inverse_zz = F.softplus(math_utils.inverse_softplus(a) + b)
|
| 206 |
+
zz = 1.0 / (inverse_zz + 1e-3)
|
| 207 |
+
|
| 208 |
+
mean_vectors = torch.cat([zz * xx, zz * yy, zz], dim=1)
|
| 209 |
+
return mean_vectors
|
| 210 |
+
|
| 211 |
+
def _scale_activation(
|
| 212 |
+
self,
|
| 213 |
+
base: torch.Tensor,
|
| 214 |
+
learned_delta: torch.Tensor,
|
| 215 |
+
min_scale: float,
|
| 216 |
+
max_scale: float,
|
| 217 |
+
) -> torch.Tensor:
|
| 218 |
+
constant_a, constant_b = _get_scale_activation_constant(max_scale, min_scale)
|
| 219 |
+
scale_factor = (max_scale - min_scale) * torch.sigmoid(
|
| 220 |
+
constant_a * self.delta_factor.scale * learned_delta + constant_b
|
| 221 |
+
) + min_scale
|
| 222 |
+
return base * scale_factor
|
| 223 |
+
|
| 224 |
+
def _quaternion_activation(
|
| 225 |
+
self, base: torch.Tensor, learned_delta: torch.Tensor
|
| 226 |
+
) -> torch.Tensor:
|
| 227 |
+
# No need to normalize the quaternions, since this is also done in rendering.
|
| 228 |
+
return base + self.delta_factor.quaternion * learned_delta
|
| 229 |
+
|
| 230 |
+
def _color_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
|
| 231 |
+
# For certain activation functions we need to clamp the base value to
|
| 232 |
+
# a supported range.
|
| 233 |
+
if self.color_activation_type == "sigmoid":
|
| 234 |
+
base = torch.clamp(base, min=0.01, max=0.99)
|
| 235 |
+
elif self.color_activation_type in ("exp", "softplus"):
|
| 236 |
+
base = torch.clamp(base, min=0.01)
|
| 237 |
+
|
| 238 |
+
activation = math_utils.create_activation_pair(self.color_activation_type)
|
| 239 |
+
colors: torch.Tensor = activation.forward(
|
| 240 |
+
activation.inverse(base) + self.delta_factor.color * learned_delta
|
| 241 |
+
)
|
| 242 |
+
# Convert gaussian color to linear if linearRGB colorspace is specified.
|
| 243 |
+
if self.color_space == "linearRGB":
|
| 244 |
+
colors = sRGB2linearRGB(colors)
|
| 245 |
+
return colors
|
| 246 |
+
|
| 247 |
+
def _opacity_activation(self, base: torch.Tensor, learned_delta: torch.Tensor) -> torch.Tensor:
|
| 248 |
+
activation = math_utils.create_activation_pair(self.opacity_activation_type)
|
| 249 |
+
return activation.forward(
|
| 250 |
+
activation.inverse(base) + self.delta_factor.opacity * learned_delta
|
| 251 |
+
)
|
src/sharp/models/decoders/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains different decoders for Gaussian predictor.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from .base_decoder import BaseDecoder
|
| 10 |
+
from .monodepth_decoder import (
|
| 11 |
+
create_monodepth_decoder,
|
| 12 |
+
)
|
| 13 |
+
from .multires_conv_decoder import MultiresConvDecoder, UpsamplingMode
|
| 14 |
+
from .unet_decoder import UNetDecoder
|
| 15 |
+
|
| 16 |
+
__all__ = [
|
| 17 |
+
"BaseDecoder",
|
| 18 |
+
"UNetDecoder",
|
| 19 |
+
"MultiresConvDecoder",
|
| 20 |
+
"UpsamplingMode",
|
| 21 |
+
"create_monodepth_decoder",
|
| 22 |
+
]
|
src/sharp/models/decoders/base_decoder.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains the base class for decoders.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import abc
|
| 8 |
+
from typing import List
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class BaseDecoder(nn.Module, abc.ABC):
|
| 15 |
+
"""Base decoder class."""
|
| 16 |
+
|
| 17 |
+
dim_out: int
|
| 18 |
+
|
| 19 |
+
@abc.abstractmethod
|
| 20 |
+
def forward(self, encodings: List[torch.Tensor]) -> torch.Tensor:
|
| 21 |
+
"""Decode (multi-resolution) encodings."""
|
src/sharp/models/decoders/monodepth_decoder.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains factory function for loading/creating monodepth decoder.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from sharp.models.presets import (
|
| 11 |
+
MONODEPTH_ENCODER_DIMS_MAP,
|
| 12 |
+
ViTPreset,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
from .multires_conv_decoder import MultiresConvDecoder
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_monodepth_decoder(
|
| 19 |
+
patch_encoder_preset: ViTPreset,
|
| 20 |
+
dims_decoder=None,
|
| 21 |
+
) -> MultiresConvDecoder:
|
| 22 |
+
"""Create DepthDensePredictionTransformer model.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
patch_encoder_preset: The preset patch encoder architecture in SPN.
|
| 26 |
+
dims_decoder: The decoder architecture.
|
| 27 |
+
"""
|
| 28 |
+
dims_encoder = MONODEPTH_ENCODER_DIMS_MAP[patch_encoder_preset]
|
| 29 |
+
if dims_decoder is None:
|
| 30 |
+
dims_decoder = dims_encoder[0]
|
| 31 |
+
if isinstance(dims_decoder, int):
|
| 32 |
+
dims_decoder = [dims_decoder]
|
| 33 |
+
decoder = MultiresConvDecoder(
|
| 34 |
+
dims_encoder=[dims_decoder[0]] + list(dims_encoder), dims_decoder=dims_decoder
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
return decoder
|
src/sharp/models/decoders/multires_conv_decoder.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains multi-res convolutional decoder.
|
| 2 |
+
|
| 3 |
+
Implements the decoder for Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
|
| 4 |
+
|
| 5 |
+
For licensing see accompanying LICENSE file.
|
| 6 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from typing import Iterable
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
from sharp.models.blocks import FeatureFusionBlock2d, UpsamplingMode
|
| 17 |
+
from sharp.utils.training import checkpoint_wrapper
|
| 18 |
+
|
| 19 |
+
from .base_decoder import BaseDecoder
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class MultiresConvDecoder(BaseDecoder):
|
| 23 |
+
"""Decoder for multi-resolution encodings."""
|
| 24 |
+
|
| 25 |
+
def __init__(
|
| 26 |
+
self,
|
| 27 |
+
dims_encoder: Iterable[int],
|
| 28 |
+
dims_decoder: Iterable[int] | int,
|
| 29 |
+
grad_checkpointing: bool = False,
|
| 30 |
+
upsampling_mode: UpsamplingMode = "transposed_conv",
|
| 31 |
+
):
|
| 32 |
+
"""Initialize multiresolution convolutional decoder.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
dims_encoder: Expected dims at each level from the encoder.
|
| 36 |
+
dims_decoder: Dim of decoder features.
|
| 37 |
+
grad_checkpointing: Whether to checkpoint gradient during training.
|
| 38 |
+
upsampling_mode: What method to use for upsampling.
|
| 39 |
+
"""
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.dims_encoder = list(dims_encoder)
|
| 42 |
+
|
| 43 |
+
if isinstance(dims_decoder, int):
|
| 44 |
+
self.dims_decoder = [dims_decoder] * len(self.dims_encoder)
|
| 45 |
+
else:
|
| 46 |
+
self.dims_decoder = list(dims_decoder)
|
| 47 |
+
|
| 48 |
+
if len(self.dims_decoder) != len(self.dims_encoder):
|
| 49 |
+
raise ValueError("Received dims_encoder and dims_decoder of different sizes.")
|
| 50 |
+
|
| 51 |
+
self.dim_out = self.dims_decoder[0]
|
| 52 |
+
|
| 53 |
+
num_encoders = len(self.dims_encoder)
|
| 54 |
+
|
| 55 |
+
# At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
|
| 56 |
+
# when the dimensions mismatch. Otherwise we do not do anything, which is
|
| 57 |
+
# the default behavior of monodepth.
|
| 58 |
+
conv0 = (
|
| 59 |
+
nn.Conv2d(self.dims_encoder[0], self.dims_decoder[0], kernel_size=1, bias=False)
|
| 60 |
+
if self.dims_encoder[0] != self.dims_decoder[0]
|
| 61 |
+
else nn.Identity()
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
convs = [conv0]
|
| 65 |
+
for i in range(1, num_encoders):
|
| 66 |
+
convs.append(
|
| 67 |
+
nn.Conv2d(
|
| 68 |
+
self.dims_encoder[i],
|
| 69 |
+
self.dims_decoder[i],
|
| 70 |
+
kernel_size=3,
|
| 71 |
+
stride=1,
|
| 72 |
+
padding=1,
|
| 73 |
+
bias=False,
|
| 74 |
+
)
|
| 75 |
+
)
|
| 76 |
+
self.convs = nn.ModuleList(convs)
|
| 77 |
+
|
| 78 |
+
fusions = []
|
| 79 |
+
for i in range(num_encoders):
|
| 80 |
+
fusions.append(
|
| 81 |
+
FeatureFusionBlock2d(
|
| 82 |
+
dim_in=self.dims_decoder[i],
|
| 83 |
+
dim_out=self.dims_decoder[i - 1] if i != 0 else self.dim_out,
|
| 84 |
+
upsampling_mode=upsampling_mode if i != 0 else None,
|
| 85 |
+
batch_norm=False,
|
| 86 |
+
)
|
| 87 |
+
)
|
| 88 |
+
self.fusions = nn.ModuleList(fusions)
|
| 89 |
+
|
| 90 |
+
self.grad_checkpointing = grad_checkpointing
|
| 91 |
+
|
| 92 |
+
@torch.jit.ignore
|
| 93 |
+
def set_grad_checkpointing(self, is_enabled=True):
|
| 94 |
+
"""Enable grad checkpointing."""
|
| 95 |
+
self.grad_checkpointing = is_enabled
|
| 96 |
+
|
| 97 |
+
def forward(self, encodings: list[torch.Tensor]) -> torch.Tensor:
|
| 98 |
+
"""Decode the multi-resolution encodings."""
|
| 99 |
+
num_levels = len(encodings)
|
| 100 |
+
num_encoders = len(self.dims_encoder)
|
| 101 |
+
|
| 102 |
+
if num_levels != num_encoders:
|
| 103 |
+
raise ValueError(
|
| 104 |
+
f"Encoder output levels={num_levels} at runtime "
|
| 105 |
+
f"mismatch with expected levels={num_encoders}."
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Project features of different encoder dims to the same decoder dim.
|
| 109 |
+
# Fuse features from the lowest resolution (num_levels-1)
|
| 110 |
+
# to the highest (0).
|
| 111 |
+
features = self.convs[-1](encodings[-1])
|
| 112 |
+
features = checkpoint_wrapper(self, self.fusions[-1], features)
|
| 113 |
+
for i in range(num_levels - 2, -1, -1):
|
| 114 |
+
features_i = self.convs[i](encodings[i])
|
| 115 |
+
features = checkpoint_wrapper(self, self.fusions[i], features, features_i)
|
| 116 |
+
return features
|
src/sharp/models/decoders/unet_decoder.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains the UNet decoder.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import List
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
from sharp.models.blocks import (
|
| 15 |
+
NormLayerName,
|
| 16 |
+
norm_layer_2d,
|
| 17 |
+
residual_block_2d,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from .base_decoder import BaseDecoder
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class UNetDecoder(BaseDecoder):
|
| 24 |
+
"""Decoder of UNet model."""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
dim_out: int,
|
| 29 |
+
width: List[int] | int,
|
| 30 |
+
steps: int = 5,
|
| 31 |
+
norm_type: NormLayerName = "group_norm",
|
| 32 |
+
norm_num_groups=8,
|
| 33 |
+
blocks_per_layer=2,
|
| 34 |
+
) -> None:
|
| 35 |
+
"""Initialize UNet Decoder.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
dim_out: The number of output channels.
|
| 39 |
+
width: Width of last input feature map from encoder
|
| 40 |
+
or the width list of all input feature maps from encoder.
|
| 41 |
+
steps: The number of upsampling steps.
|
| 42 |
+
norm_type: Which kind of normalization layer to use.
|
| 43 |
+
norm_num_groups: How many groups to use for group norm (if relevant).
|
| 44 |
+
blocks_per_layer: How many blocks per layer to use.
|
| 45 |
+
"""
|
| 46 |
+
super().__init__()
|
| 47 |
+
|
| 48 |
+
if blocks_per_layer < 1:
|
| 49 |
+
raise ValueError("blocks_per_layer must be greater or equal to one.")
|
| 50 |
+
|
| 51 |
+
self.dim_out = dim_out
|
| 52 |
+
|
| 53 |
+
self.convs_up = nn.ModuleList()
|
| 54 |
+
|
| 55 |
+
self.output_dims: list[int]
|
| 56 |
+
# If only one number is specified, we assume each layer will double the channel dimension.
|
| 57 |
+
if isinstance(width, int):
|
| 58 |
+
self.input_dims = [width >> i for i in range(0, steps + 1)]
|
| 59 |
+
else:
|
| 60 |
+
self.input_dims = width[::-1][: steps + 1]
|
| 61 |
+
|
| 62 |
+
for i_step in range(steps):
|
| 63 |
+
input_width = self.input_dims[i_step]
|
| 64 |
+
current_width = self.input_dims[i_step + 1]
|
| 65 |
+
convs_up_i = nn.Sequential(
|
| 66 |
+
nn.Upsample(scale_factor=2),
|
| 67 |
+
residual_block_2d(
|
| 68 |
+
input_width * (1 if i_step == 0 else 2),
|
| 69 |
+
current_width,
|
| 70 |
+
norm_type=norm_type,
|
| 71 |
+
norm_num_groups=norm_num_groups,
|
| 72 |
+
),
|
| 73 |
+
*[
|
| 74 |
+
residual_block_2d(
|
| 75 |
+
current_width,
|
| 76 |
+
current_width,
|
| 77 |
+
norm_type=norm_type,
|
| 78 |
+
norm_num_groups=norm_num_groups,
|
| 79 |
+
)
|
| 80 |
+
for _ in range(blocks_per_layer - 1)
|
| 81 |
+
],
|
| 82 |
+
)
|
| 83 |
+
self.convs_up.append(convs_up_i)
|
| 84 |
+
input_width = 2 * current_width
|
| 85 |
+
current_width //= 2
|
| 86 |
+
|
| 87 |
+
last_width = self.input_dims[-1]
|
| 88 |
+
self.conv_out = nn.Sequential(
|
| 89 |
+
norm_layer_2d(last_width * 2, norm_type, num_groups=norm_num_groups),
|
| 90 |
+
nn.ReLU(),
|
| 91 |
+
nn.Conv2d(last_width * 2, dim_out, 1),
|
| 92 |
+
norm_layer_2d(dim_out, norm_type, num_groups=norm_num_groups),
|
| 93 |
+
nn.ReLU(),
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
def forward(self, features: list[torch.Tensor]) -> torch.Tensor:
|
| 97 |
+
"""Apply UNet to image.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
features: The input multi-level feature map from encoder.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
The output feature map.
|
| 104 |
+
"""
|
| 105 |
+
i_feature_layer = len(features) - 1
|
| 106 |
+
out = self.convs_up[0](features[i_feature_layer])
|
| 107 |
+
i_feature_layer -= 1
|
| 108 |
+
for conv_up in self.convs_up[1:]: # type: ignore
|
| 109 |
+
out = conv_up(torch.cat([out, features[i_feature_layer]], dim=1))
|
| 110 |
+
i_feature_layer -= 1
|
| 111 |
+
out = self.conv_out(torch.cat([out, features[i_feature_layer]], dim=1))
|
| 112 |
+
|
| 113 |
+
return out
|
src/sharp/models/encoders/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains different encoders for Gaussian predictor.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from sharp.models.encoders.base_encoder import BaseEncoder
|
| 8 |
+
|
| 9 |
+
from .monodepth_encoder import (
|
| 10 |
+
MonodepthFeatureEncoder,
|
| 11 |
+
create_monodepth_encoder,
|
| 12 |
+
)
|
| 13 |
+
from .spn_encoder import SlidingPyramidNetwork
|
| 14 |
+
from .unet_encoder import UNetEncoder
|
| 15 |
+
from .vit_encoder import create_vit
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"create_vit",
|
| 19 |
+
"BaseEncoder",
|
| 20 |
+
"UNetEncoder",
|
| 21 |
+
"SlidingPyramidNetwork",
|
| 22 |
+
"MonodepthFeatureEncoder",
|
| 23 |
+
"create_monodepth_encoder",
|
| 24 |
+
]
|
src/sharp/models/encoders/base_encoder.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains the base class for encoders.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import abc
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BaseEncoder(nn.Module, abc.ABC):
|
| 14 |
+
"""Base encoder class."""
|
| 15 |
+
|
| 16 |
+
dim_in: int
|
| 17 |
+
output_dims: list[int]
|
| 18 |
+
|
| 19 |
+
@abc.abstractmethod
|
| 20 |
+
def forward(self, image: torch.Tensor) -> list[torch.Tensor]:
|
| 21 |
+
"""Encode input image into multi-resolution encodings."""
|
| 22 |
+
|
| 23 |
+
def internal_resolution(self) -> int:
|
| 24 |
+
"""Internal resolution of the encoder."""
|
| 25 |
+
return 1536
|
src/sharp/models/encoders/monodepth_encoder.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains Dense Transformer Prediction architecture.
|
| 2 |
+
|
| 3 |
+
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
|
| 4 |
+
|
| 5 |
+
For licensing see accompanying LICENSE file.
|
| 6 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
from sharp.models.presets import (
|
| 15 |
+
MONODEPTH_ENCODER_DIMS_MAP,
|
| 16 |
+
MONODEPTH_HOOK_IDS_MAP,
|
| 17 |
+
ViTPreset,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from .base_encoder import BaseEncoder
|
| 21 |
+
from .spn_encoder import SlidingPyramidNetwork
|
| 22 |
+
from .vit_encoder import create_vit
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def create_monodepth_encoder(
|
| 26 |
+
patch_encoder_preset: ViTPreset,
|
| 27 |
+
image_encoder_preset: ViTPreset,
|
| 28 |
+
use_patch_overlap: bool = True,
|
| 29 |
+
last_encoder: int = 256,
|
| 30 |
+
) -> SlidingPyramidNetwork:
|
| 31 |
+
"""Creates DepthDensePredictionTransformer model.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
patch_encoder_preset: The preset patch encoder architecture in SPN.
|
| 35 |
+
image_encoder_preset: The preset image encoder architecture in SPN.
|
| 36 |
+
use_patch_overlap: Whether to use overlap between patches in SPN.
|
| 37 |
+
last_encoder: last number of encoder features.
|
| 38 |
+
"""
|
| 39 |
+
dims_encoder = [last_encoder] + MONODEPTH_ENCODER_DIMS_MAP[patch_encoder_preset]
|
| 40 |
+
patch_encoder_block_ids = MONODEPTH_HOOK_IDS_MAP[patch_encoder_preset]
|
| 41 |
+
|
| 42 |
+
patch_encoder = create_vit(
|
| 43 |
+
preset=patch_encoder_preset,
|
| 44 |
+
intermediate_features_ids=patch_encoder_block_ids,
|
| 45 |
+
# We always need to output intermediate features for assembly.
|
| 46 |
+
)
|
| 47 |
+
image_encoder = create_vit(
|
| 48 |
+
preset=image_encoder_preset,
|
| 49 |
+
intermediate_features_ids=None,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
encoder = SlidingPyramidNetwork(
|
| 53 |
+
dims_encoder=dims_encoder,
|
| 54 |
+
patch_encoder=patch_encoder,
|
| 55 |
+
image_encoder=image_encoder,
|
| 56 |
+
use_patch_overlap=use_patch_overlap,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
return encoder
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ProjectionModule(nn.Module):
|
| 63 |
+
"""Apply projection of features."""
|
| 64 |
+
|
| 65 |
+
def __init__(self, dims_in: list[int], dims_out: list[int]) -> None:
|
| 66 |
+
"""Initialize projection module."""
|
| 67 |
+
super().__init__()
|
| 68 |
+
if len(dims_in) != len(dims_out):
|
| 69 |
+
raise ValueError("Length of dims_in must be same as length of dims_out.")
|
| 70 |
+
self.convs = nn.ModuleList(
|
| 71 |
+
[nn.Conv2d(dim_in, dim_out, 1) for dim_in, dim_out in zip(dims_in, dims_out)]
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
def forward(self, encodings: list[torch.Tensor]) -> list[torch.Tensor]:
|
| 75 |
+
"""Apply projection module."""
|
| 76 |
+
if len(encodings) != len(self.convs):
|
| 77 |
+
raise ValueError("Number of encodings must be equal to number of projections.")
|
| 78 |
+
return [conv(encoding) for conv, encoding in zip(self.convs, encodings)]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class MonodepthFeatureEncoder(BaseEncoder):
|
| 82 |
+
"""A wrapper around monodepth network to extract features."""
|
| 83 |
+
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
monodepth_encoder: SlidingPyramidNetwork,
|
| 87 |
+
output_dims: list[int] | None = None,
|
| 88 |
+
freeze_projection: bool = False,
|
| 89 |
+
) -> None:
|
| 90 |
+
"""Initialize MonodepthFeatureExtractor."""
|
| 91 |
+
super().__init__()
|
| 92 |
+
|
| 93 |
+
self.encoder = monodepth_encoder
|
| 94 |
+
|
| 95 |
+
# The monodepth network returns two feature maps for the first entry in
|
| 96 |
+
# backbone.encoder.dims_encoder.
|
| 97 |
+
monodepth_dims = self.encoder.dims_encoder
|
| 98 |
+
monodepth_dims = monodepth_dims
|
| 99 |
+
|
| 100 |
+
if output_dims is not None:
|
| 101 |
+
if not len(output_dims) == len(monodepth_dims):
|
| 102 |
+
raise ValueError(
|
| 103 |
+
"When set, number of output dimensions must be equal to output "
|
| 104 |
+
f"dimensions of monodepth model {len(monodepth_dims)}."
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
self.projection = ProjectionModule(monodepth_dims, output_dims)
|
| 108 |
+
self.output_dims = output_dims
|
| 109 |
+
else:
|
| 110 |
+
self.projection = nn.Identity()
|
| 111 |
+
self.output_dims = monodepth_dims
|
| 112 |
+
|
| 113 |
+
if freeze_projection:
|
| 114 |
+
self.projection.requires_grad_(False)
|
| 115 |
+
|
| 116 |
+
def forward(self, input_features: torch.Tensor) -> list[torch.Tensor]:
|
| 117 |
+
"""Extract multi-resolution features."""
|
| 118 |
+
encodings = self.encoder(input_features[:, :3].contiguous())
|
| 119 |
+
return self.projection(encodings)
|
| 120 |
+
|
| 121 |
+
def internal_resolution(self) -> int:
|
| 122 |
+
"""Internal resolution of the encoder."""
|
| 123 |
+
return self.encoder.internal_resolution()
|
src/sharp/models/encoders/spn_encoder.py
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains Sliding Pyramid Network architecture.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
from typing import Iterable
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.fx
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
from sharp.utils.training import checkpoint_wrapper
|
| 18 |
+
|
| 19 |
+
from .base_encoder import BaseEncoder
|
| 20 |
+
from .vit_encoder import TimmViT
|
| 21 |
+
|
| 22 |
+
# torch.fx.wrap is used here to mark functions as leaf nodes during symbolic tracing
|
| 23 |
+
# ensuring they are not traced but seen as atomic operation. In short, symbolic tracing
|
| 24 |
+
# struggles with native python functions and conditional flows.
|
| 25 |
+
non_traceable_ops = ("len", "int")
|
| 26 |
+
for op in non_traceable_ops:
|
| 27 |
+
torch.fx.wrap(op)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SlidingPyramidNetwork(BaseEncoder):
|
| 31 |
+
"""Sliding Pyramid Network.
|
| 32 |
+
|
| 33 |
+
An encoder aimed at creating multi-resolution encodings from Vision Transformers.
|
| 34 |
+
|
| 35 |
+
Reference: Bochkovskii et al. - "Depth pro: Sharp monocular metric depth in less
|
| 36 |
+
than a second." (ICLR 2024)
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
dims_encoder: Iterable[int],
|
| 42 |
+
patch_encoder: TimmViT,
|
| 43 |
+
image_encoder: TimmViT,
|
| 44 |
+
use_patch_overlap: bool = True,
|
| 45 |
+
):
|
| 46 |
+
"""Initialize Sliding Pyramid Network.
|
| 47 |
+
|
| 48 |
+
The framework
|
| 49 |
+
1. creates an image pyramid,
|
| 50 |
+
2. generates overlapping patches with a sliding window at each pyramid level,
|
| 51 |
+
3. creates batched encodings via vision transformer backbones,
|
| 52 |
+
4. produces multi-resolution encodings.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
dims_encoder: Dimensions of the encoder at different layers.
|
| 56 |
+
patch_encoder: Backbone used for highres part of the pyramid.
|
| 57 |
+
image_encoder: Backbone used for lowres part of the pyramid.
|
| 58 |
+
use_patch_overlap: Whether to use overlap between patches in SPN.
|
| 59 |
+
"""
|
| 60 |
+
super().__init__()
|
| 61 |
+
|
| 62 |
+
self.dim_in = patch_encoder.dim_in
|
| 63 |
+
|
| 64 |
+
self.dims_encoder = list(dims_encoder)
|
| 65 |
+
self.patch_encoder = patch_encoder
|
| 66 |
+
self.image_encoder = image_encoder
|
| 67 |
+
|
| 68 |
+
base_embed_dim = patch_encoder.embed_dim
|
| 69 |
+
lowres_embed_dim = image_encoder.embed_dim
|
| 70 |
+
self.patch_size = patch_encoder.internal_resolution()
|
| 71 |
+
|
| 72 |
+
self.grad_checkpointing = False
|
| 73 |
+
self.use_patch_overlap = use_patch_overlap
|
| 74 |
+
|
| 75 |
+
# Retrieve intermediate feature ids registered in create_monodepth_encoder.
|
| 76 |
+
self.patch_intermediate_features_ids = patch_encoder.intermediate_features_ids
|
| 77 |
+
if (
|
| 78 |
+
not isinstance(self.patch_intermediate_features_ids, list)
|
| 79 |
+
or not len(self.patch_intermediate_features_ids) == 4
|
| 80 |
+
):
|
| 81 |
+
raise ValueError("Patch intermediate feature ids must be a 4-item list.")
|
| 82 |
+
|
| 83 |
+
self.image_intermediate_features_ids = image_encoder.intermediate_features_ids
|
| 84 |
+
|
| 85 |
+
def _create_project_upsample_block(
|
| 86 |
+
dim_in: int,
|
| 87 |
+
dim_out: int,
|
| 88 |
+
upsample_layers: int,
|
| 89 |
+
dim_intermediate=None,
|
| 90 |
+
) -> nn.Module:
|
| 91 |
+
if dim_intermediate is None:
|
| 92 |
+
dim_intermediate = dim_out
|
| 93 |
+
# Projection.
|
| 94 |
+
blocks = [
|
| 95 |
+
nn.Conv2d(
|
| 96 |
+
in_channels=dim_in,
|
| 97 |
+
out_channels=dim_intermediate,
|
| 98 |
+
kernel_size=1,
|
| 99 |
+
stride=1,
|
| 100 |
+
padding=0,
|
| 101 |
+
bias=False,
|
| 102 |
+
)
|
| 103 |
+
]
|
| 104 |
+
|
| 105 |
+
# Upsampling.
|
| 106 |
+
blocks += [
|
| 107 |
+
nn.ConvTranspose2d(
|
| 108 |
+
in_channels=dim_intermediate if i == 0 else dim_out,
|
| 109 |
+
out_channels=dim_out,
|
| 110 |
+
kernel_size=2,
|
| 111 |
+
stride=2,
|
| 112 |
+
padding=0,
|
| 113 |
+
bias=False,
|
| 114 |
+
)
|
| 115 |
+
for i in range(upsample_layers)
|
| 116 |
+
]
|
| 117 |
+
|
| 118 |
+
return nn.Sequential(*blocks)
|
| 119 |
+
|
| 120 |
+
self.upsample_latent0 = _create_project_upsample_block(
|
| 121 |
+
dim_in=base_embed_dim,
|
| 122 |
+
dim_out=self.dims_encoder[0],
|
| 123 |
+
upsample_layers=3,
|
| 124 |
+
dim_intermediate=self.dims_encoder[1],
|
| 125 |
+
)
|
| 126 |
+
self.upsample_latent1 = _create_project_upsample_block(
|
| 127 |
+
dim_in=base_embed_dim, dim_out=self.dims_encoder[1], upsample_layers=2
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
self.upsample0 = _create_project_upsample_block(
|
| 131 |
+
dim_in=base_embed_dim, dim_out=self.dims_encoder[2], upsample_layers=1
|
| 132 |
+
)
|
| 133 |
+
self.upsample1 = _create_project_upsample_block(
|
| 134 |
+
dim_in=base_embed_dim, dim_out=self.dims_encoder[3], upsample_layers=1
|
| 135 |
+
)
|
| 136 |
+
self.upsample2 = _create_project_upsample_block(
|
| 137 |
+
dim_in=base_embed_dim, dim_out=self.dims_encoder[4], upsample_layers=1
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
self.upsample_lowres = nn.ConvTranspose2d(
|
| 141 |
+
in_channels=lowres_embed_dim,
|
| 142 |
+
out_channels=self.dims_encoder[4],
|
| 143 |
+
kernel_size=2,
|
| 144 |
+
stride=2,
|
| 145 |
+
padding=0,
|
| 146 |
+
bias=True,
|
| 147 |
+
)
|
| 148 |
+
self.fuse_lowres = nn.Conv2d(
|
| 149 |
+
in_channels=(self.dims_encoder[4] + self.dims_encoder[4]),
|
| 150 |
+
out_channels=self.dims_encoder[4],
|
| 151 |
+
kernel_size=1,
|
| 152 |
+
stride=1,
|
| 153 |
+
padding=0,
|
| 154 |
+
bias=True,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
def internal_resolution(self) -> int:
|
| 158 |
+
"""Return the full image size of the SPN network."""
|
| 159 |
+
return self.patch_size * 4
|
| 160 |
+
|
| 161 |
+
@torch.jit.ignore
|
| 162 |
+
def set_grad_checkpointing(self, is_enabled=True):
|
| 163 |
+
"""Enable grad checkpointing."""
|
| 164 |
+
self.grad_checkpointing = is_enabled
|
| 165 |
+
self.patch_encoder.set_grad_checkpointing(is_enabled)
|
| 166 |
+
self.image_encoder.set_grad_checkpointing(is_enabled)
|
| 167 |
+
|
| 168 |
+
@torch.jit.ignore
|
| 169 |
+
def set_requires_grad_(self, patch_encoder: bool, image_encoder: bool):
|
| 170 |
+
"""Set requires grad for separate components."""
|
| 171 |
+
self.patch_encoder.requires_grad_(patch_encoder)
|
| 172 |
+
self.image_encoder.requires_grad_(image_encoder)
|
| 173 |
+
|
| 174 |
+
# Always freeze the unused TimmViT head to exclude it from the calculation of
|
| 175 |
+
# trainable parameters.
|
| 176 |
+
self.patch_encoder.head.requires_grad_(False)
|
| 177 |
+
self.image_encoder.head.requires_grad_(False)
|
| 178 |
+
|
| 179 |
+
# These upsamplers only affect patch encoder's feature maps.
|
| 180 |
+
self.upsample_latent0.requires_grad_(patch_encoder)
|
| 181 |
+
self.upsample_latent1.requires_grad_(patch_encoder)
|
| 182 |
+
self.upsample0.requires_grad_(patch_encoder)
|
| 183 |
+
self.upsample1.requires_grad_(patch_encoder)
|
| 184 |
+
self.upsample2.requires_grad_(patch_encoder)
|
| 185 |
+
|
| 186 |
+
# This upsampler affects only image encoder's feature map.
|
| 187 |
+
self.upsample_lowres.requires_grad_(image_encoder)
|
| 188 |
+
|
| 189 |
+
# This fuser affects both image and patch encoders.
|
| 190 |
+
self.fuse_lowres.requires_grad_(image_encoder or patch_encoder)
|
| 191 |
+
|
| 192 |
+
def _create_pyramid(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 193 |
+
"""Creates a 3-level image pyramid."""
|
| 194 |
+
# Original resolution: 1536 by default.
|
| 195 |
+
x0 = x
|
| 196 |
+
|
| 197 |
+
# Middle resolution: 768 by default.
|
| 198 |
+
x1 = F.interpolate(x, size=None, scale_factor=0.5, mode="bilinear", align_corners=False)
|
| 199 |
+
|
| 200 |
+
# Low resolution: 384 by default, corresponding to the backbone resolution.
|
| 201 |
+
x2 = F.interpolate(x, size=None, scale_factor=0.25, mode="bilinear", align_corners=False)
|
| 202 |
+
|
| 203 |
+
return x0, x1, x2
|
| 204 |
+
|
| 205 |
+
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
| 206 |
+
"""Encode input at multiple resolutions."""
|
| 207 |
+
batch_size = x.shape[0]
|
| 208 |
+
|
| 209 |
+
# Step 0: create a 3-level image pyramid.
|
| 210 |
+
x0, x1, x2 = self._create_pyramid(x)
|
| 211 |
+
|
| 212 |
+
if self.use_patch_overlap:
|
| 213 |
+
# Step 1: split to create batched overlapped mini-images at the ViT
|
| 214 |
+
# resolution.
|
| 215 |
+
# 5x5 @ 384x384 at the highest resolution (1536x1536).
|
| 216 |
+
x0_patches = split(x0, overlap_ratio=0.25, patch_size=self.patch_size)
|
| 217 |
+
# 3x3 @ 384x384 at the middle resolution (768x768).
|
| 218 |
+
x1_patches = split(x1, overlap_ratio=0.5, patch_size=self.patch_size)
|
| 219 |
+
# 1x1 # 384x384 at the lowest resolution (384x384).
|
| 220 |
+
x2_patches = x2
|
| 221 |
+
padding = 3
|
| 222 |
+
else:
|
| 223 |
+
# Step 1: split to create batched overlapped mini-images at the ViT
|
| 224 |
+
# resolution.
|
| 225 |
+
# 4x4 @ 384x384 at the highest resolution (1536x1536).
|
| 226 |
+
x0_patches = split(x0, overlap_ratio=0.0, patch_size=self.patch_size)
|
| 227 |
+
# 2x2 @ 384x384 at the middle resolution (768x768).
|
| 228 |
+
x1_patches = split(x1, overlap_ratio=0.0, patch_size=self.patch_size)
|
| 229 |
+
# 1x1 # 384x384 at the lowest resolution (384x384).
|
| 230 |
+
x2_patches = x2
|
| 231 |
+
padding = 0
|
| 232 |
+
x0_tile_size = x0_patches.shape[0]
|
| 233 |
+
|
| 234 |
+
# Concatenate all the sliding window patches and form a batch of size
|
| 235 |
+
# (35=5x5+3x3+1x1) or (21=4x4+2x2+1x1).
|
| 236 |
+
x_pyramid_patches = torch.cat(
|
| 237 |
+
(x0_patches, x1_patches, x2_patches),
|
| 238 |
+
dim=0,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
# Run the ViT model and get the result of large batch size.
|
| 242 |
+
#
|
| 243 |
+
# For the retrieval of intermediate features forward hooks are more concise,
|
| 244 |
+
# but they are not well compatible with symbolic tracing because attributes
|
| 245 |
+
# of submodules can be lost during tracing. Therefore, forward hooks may not
|
| 246 |
+
# be preserved during graph transformation, leading to unexpected behavior.
|
| 247 |
+
# To avoid such issues it is safer not to use them because they are not
|
| 248 |
+
# essential here.
|
| 249 |
+
x_pyramid_encodings, patch_intermediate_features = self.patch_encoder(x_pyramid_patches)
|
| 250 |
+
|
| 251 |
+
# Step 3: merging.
|
| 252 |
+
# Merge highres latent encoding.
|
| 253 |
+
# NOTE: list type check has completed in init.
|
| 254 |
+
x_latent0_encodings = self.patch_encoder.reshape_feature(
|
| 255 |
+
patch_intermediate_features[self.patch_intermediate_features_ids[0]] # type:ignore[index]
|
| 256 |
+
)
|
| 257 |
+
x_latent0_features = merge(
|
| 258 |
+
x_latent0_encodings[: batch_size * x0_tile_size],
|
| 259 |
+
batch_size=batch_size,
|
| 260 |
+
padding=padding,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
x_latent1_encodings = self.patch_encoder.reshape_feature(
|
| 264 |
+
patch_intermediate_features[self.patch_intermediate_features_ids[1]] # type:ignore[index]
|
| 265 |
+
)
|
| 266 |
+
x_latent1_features = merge(
|
| 267 |
+
x_latent1_encodings[: batch_size * x0_tile_size],
|
| 268 |
+
batch_size=batch_size,
|
| 269 |
+
padding=padding,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
# Split the 35 batch size from pyramid encoding back into 5x5+3x3+1x1.
|
| 273 |
+
x0_encodings, x1_encodings, x2_encodings = torch.split(
|
| 274 |
+
x_pyramid_encodings,
|
| 275 |
+
[len(x0_patches), len(x1_patches), len(x2_patches)],
|
| 276 |
+
dim=0,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps.
|
| 280 |
+
x0_features = merge(x0_encodings, batch_size=batch_size, padding=padding)
|
| 281 |
+
|
| 282 |
+
# 48x84 feature maps by merging 3x3 @ 24x24 patches with overlaps.
|
| 283 |
+
x1_features = merge(x1_encodings, batch_size=batch_size, padding=2 * padding)
|
| 284 |
+
|
| 285 |
+
# 24x24 feature maps.
|
| 286 |
+
x2_features = x2_encodings
|
| 287 |
+
|
| 288 |
+
# Apply the image encoder.
|
| 289 |
+
x_lowres_features, image_intermediate_features = self.image_encoder(x2_patches)
|
| 290 |
+
|
| 291 |
+
# Upsample feature maps.
|
| 292 |
+
x_latent0_features = checkpoint_wrapper(self, self.upsample_latent0, x_latent0_features)
|
| 293 |
+
x_latent1_features = checkpoint_wrapper(self, self.upsample_latent1, x_latent1_features)
|
| 294 |
+
|
| 295 |
+
x0_features = checkpoint_wrapper(self, self.upsample0, x0_features)
|
| 296 |
+
x1_features = checkpoint_wrapper(self, self.upsample1, x1_features)
|
| 297 |
+
x2_features = checkpoint_wrapper(self, self.upsample2, x2_features)
|
| 298 |
+
|
| 299 |
+
x_lowres_features = checkpoint_wrapper(self, self.upsample_lowres, x_lowres_features)
|
| 300 |
+
x_lowres_features = checkpoint_wrapper(
|
| 301 |
+
self, self.fuse_lowres, torch.cat((x2_features, x_lowres_features), dim=1)
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
output = [
|
| 305 |
+
x_latent0_features,
|
| 306 |
+
x_latent1_features,
|
| 307 |
+
x0_features,
|
| 308 |
+
x1_features,
|
| 309 |
+
x_lowres_features,
|
| 310 |
+
]
|
| 311 |
+
|
| 312 |
+
return output
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# It seems that torch.fx.wrap can only be applied to functions, not methods.
|
| 316 |
+
# Hence, split and merge were converted into functions to be marked as atomic
|
| 317 |
+
# operations for symbolic tracing.
|
| 318 |
+
@torch.fx.wrap
|
| 319 |
+
def split(image: torch.Tensor, overlap_ratio: float = 0.25, patch_size: int = 384) -> torch.Tensor:
|
| 320 |
+
"""Split the input into small patches with sliding window."""
|
| 321 |
+
patch_stride = int(patch_size * (1 - overlap_ratio))
|
| 322 |
+
|
| 323 |
+
image_size = image.shape[-1]
|
| 324 |
+
steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1
|
| 325 |
+
|
| 326 |
+
x_patch_list = []
|
| 327 |
+
for j in range(steps):
|
| 328 |
+
j0 = j * patch_stride
|
| 329 |
+
j1 = j0 + patch_size
|
| 330 |
+
|
| 331 |
+
for i in range(steps):
|
| 332 |
+
i0 = i * patch_stride
|
| 333 |
+
i1 = i0 + patch_size
|
| 334 |
+
x_patch_list.append(image[..., j0:j1, i0:i1])
|
| 335 |
+
|
| 336 |
+
return torch.cat(x_patch_list, dim=0)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
# Decorator marking function as an atomic operator for symbolic tracing.
|
| 340 |
+
@torch.fx.wrap
|
| 341 |
+
def merge(image_patches: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
|
| 342 |
+
"""Merge the patched input into a image with sliding window."""
|
| 343 |
+
steps = int(math.sqrt(image_patches.shape[0] // batch_size))
|
| 344 |
+
|
| 345 |
+
idx = 0
|
| 346 |
+
|
| 347 |
+
output_list = []
|
| 348 |
+
for j in range(steps):
|
| 349 |
+
output_row_list = []
|
| 350 |
+
for i in range(steps):
|
| 351 |
+
output = image_patches[batch_size * idx : batch_size * (idx + 1)]
|
| 352 |
+
|
| 353 |
+
if padding != 0:
|
| 354 |
+
if j != 0:
|
| 355 |
+
output = output[..., padding:, :]
|
| 356 |
+
if i != 0:
|
| 357 |
+
output = output[..., :, padding:]
|
| 358 |
+
if j != steps - 1:
|
| 359 |
+
output = output[..., :-padding, :]
|
| 360 |
+
if i != steps - 1:
|
| 361 |
+
output = output[..., :, :-padding]
|
| 362 |
+
|
| 363 |
+
output_row_list.append(output)
|
| 364 |
+
idx += 1
|
| 365 |
+
|
| 366 |
+
output_row = torch.cat(output_row_list, dim=-1)
|
| 367 |
+
output_list.append(output_row)
|
| 368 |
+
output = torch.cat(output_list, dim=-2)
|
| 369 |
+
return output
|
src/sharp/models/encoders/unet_encoder.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains backbone models for feature extraction from RGBD input.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import List
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
from sharp.models.blocks import (
|
| 15 |
+
NormLayerName,
|
| 16 |
+
norm_layer_2d,
|
| 17 |
+
residual_block_2d,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from .base_encoder import BaseEncoder
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class UNetEncoder(BaseEncoder):
|
| 24 |
+
"""Encoder of UNet model."""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
dim_in: int,
|
| 29 |
+
width: List[int] | int,
|
| 30 |
+
steps: int = 6,
|
| 31 |
+
norm_type: NormLayerName = "group_norm",
|
| 32 |
+
norm_num_groups=8,
|
| 33 |
+
blocks_per_layer=2,
|
| 34 |
+
) -> None:
|
| 35 |
+
"""Initialize UNet Encoder.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
dim_in: The number of input channels.
|
| 39 |
+
width: Width multiplicator of intermediate layers or the width list of all layers.
|
| 40 |
+
steps: The number of downsampling steps.
|
| 41 |
+
norm_type: Which kind of normalization layer to use.
|
| 42 |
+
norm_num_groups: How many groups to use for group norm (if relevant).
|
| 43 |
+
blocks_per_layer: How many residual blocks per layer to use.
|
| 44 |
+
"""
|
| 45 |
+
super().__init__()
|
| 46 |
+
|
| 47 |
+
if blocks_per_layer < 1:
|
| 48 |
+
raise ValueError("blocks_per_layer must be greater or equal to one.")
|
| 49 |
+
|
| 50 |
+
self.dim_in = dim_in
|
| 51 |
+
self.width = width
|
| 52 |
+
self.num_steps = steps
|
| 53 |
+
|
| 54 |
+
self.convs_down = nn.ModuleList()
|
| 55 |
+
|
| 56 |
+
self.output_dims: list[int]
|
| 57 |
+
# If only one number is specified, we assume each layer will double the channel dimension.
|
| 58 |
+
if isinstance(width, int):
|
| 59 |
+
self.output_dims = [width << i for i in range(0, steps + 1)]
|
| 60 |
+
else:
|
| 61 |
+
if len(width) != (steps + 1):
|
| 62 |
+
raise ValueError("Length of width should match the steps for UNetEncoder.")
|
| 63 |
+
self.output_dims = width
|
| 64 |
+
|
| 65 |
+
self.conv_in = nn.Sequential(
|
| 66 |
+
nn.Conv2d(self.dim_in, self.output_dims[0], 3, stride=1, padding=1),
|
| 67 |
+
norm_layer_2d(self.output_dims[0], norm_type, num_groups=norm_num_groups),
|
| 68 |
+
nn.ReLU(),
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
for i_step in range(steps):
|
| 72 |
+
input_width = self.output_dims[i_step]
|
| 73 |
+
current_width = self.output_dims[i_step + 1]
|
| 74 |
+
convs_down_i = nn.Sequential(
|
| 75 |
+
nn.AvgPool2d(2, stride=2),
|
| 76 |
+
residual_block_2d(
|
| 77 |
+
input_width,
|
| 78 |
+
current_width,
|
| 79 |
+
norm_type=norm_type,
|
| 80 |
+
norm_num_groups=norm_num_groups,
|
| 81 |
+
),
|
| 82 |
+
*[
|
| 83 |
+
residual_block_2d(
|
| 84 |
+
current_width,
|
| 85 |
+
current_width,
|
| 86 |
+
norm_type=norm_type,
|
| 87 |
+
norm_num_groups=norm_num_groups,
|
| 88 |
+
)
|
| 89 |
+
for _ in range(blocks_per_layer - 1)
|
| 90 |
+
],
|
| 91 |
+
)
|
| 92 |
+
self.convs_down.append(convs_down_i)
|
| 93 |
+
|
| 94 |
+
def forward(self, input: torch.Tensor) -> list[torch.Tensor]:
|
| 95 |
+
"""Apply UNet Encoder to image.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
input: The input image.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
The output multi-level feature map from encoder.
|
| 102 |
+
"""
|
| 103 |
+
features = []
|
| 104 |
+
|
| 105 |
+
feat_i = self.conv_in(input)
|
| 106 |
+
features.append(feat_i)
|
| 107 |
+
|
| 108 |
+
for conv_down in self.convs_down:
|
| 109 |
+
feat_i = conv_down(feat_i)
|
| 110 |
+
features.append(feat_i)
|
| 111 |
+
|
| 112 |
+
return features
|
| 113 |
+
|
| 114 |
+
@property
|
| 115 |
+
def out_width(self) -> int:
|
| 116 |
+
"""Compute the output width for UNet decoder."""
|
| 117 |
+
return self.output_dims[-1]
|
src/sharp/models/encoders/vit_encoder.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains factory functions to build and load ViT.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
import timm
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from sharp.models.presets.vit import VIT_CONFIG_DICT, ViTConfig, ViTPreset
|
| 15 |
+
|
| 16 |
+
LOGGER = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TimmViT(timm.models.VisionTransformer):
|
| 20 |
+
"""Contains TIMM implementation for Vanilla ViT."""
|
| 21 |
+
|
| 22 |
+
def __init__(self, config: ViTConfig):
|
| 23 |
+
"""Initialize ViT from TIMM implementation."""
|
| 24 |
+
# Handle mlp layers.
|
| 25 |
+
mlp_layer = timm.layers.GluMlp if config.mlp_mode == "glu" else timm.layers.Mlp
|
| 26 |
+
|
| 27 |
+
super().__init__(
|
| 28 |
+
in_chans=config.in_chans,
|
| 29 |
+
embed_dim=config.embed_dim,
|
| 30 |
+
depth=config.depth,
|
| 31 |
+
num_heads=config.num_heads,
|
| 32 |
+
init_values=config.init_values,
|
| 33 |
+
img_size=config.img_size,
|
| 34 |
+
patch_size=config.patch_size,
|
| 35 |
+
num_classes=config.num_classes,
|
| 36 |
+
mlp_ratio=config.mlp_ratio,
|
| 37 |
+
qkv_bias=config.qkv_bias,
|
| 38 |
+
global_pool=config.global_pool,
|
| 39 |
+
mlp_layer=mlp_layer,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Required for extracting intermediate features.
|
| 43 |
+
self.dim_in = config.in_chans
|
| 44 |
+
self.intermediate_features_ids = config.intermediate_features_ids
|
| 45 |
+
|
| 46 |
+
def reshape_feature(self, embeddings: torch.Tensor):
|
| 47 |
+
"""Discard class token and reshape 1D feature map to a 2D grid."""
|
| 48 |
+
batch_size, seq_len, channel = embeddings.shape
|
| 49 |
+
|
| 50 |
+
height, width = self.patch_embed.grid_size
|
| 51 |
+
|
| 52 |
+
# Remove class token.
|
| 53 |
+
if self.num_prefix_tokens:
|
| 54 |
+
embeddings = embeddings[:, self.num_prefix_tokens :, :]
|
| 55 |
+
|
| 56 |
+
# Shape: (batch, height, width, dim) -> (batch, dim, height, width)
|
| 57 |
+
embeddings = embeddings.reshape(batch_size, height, width, channel).permute(0, 3, 1, 2)
|
| 58 |
+
return embeddings
|
| 59 |
+
|
| 60 |
+
def forward(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, dict[int, torch.Tensor]]:
|
| 61 |
+
"""Override forwarding with intermediate features.
|
| 62 |
+
|
| 63 |
+
Adapted from timm ViT.
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
Output features and list of features from intermediate layers (patch encoder only).
|
| 67 |
+
"""
|
| 68 |
+
intermediate_features = {}
|
| 69 |
+
|
| 70 |
+
x = self.patch_embed(input_tensor)
|
| 71 |
+
batch_size, seq_len, _ = x.shape
|
| 72 |
+
|
| 73 |
+
x = self._pos_embed(x)
|
| 74 |
+
x = self.patch_drop(x)
|
| 75 |
+
x = self.norm_pre(x)
|
| 76 |
+
|
| 77 |
+
for idx, block in enumerate(self.blocks):
|
| 78 |
+
x = block(x)
|
| 79 |
+
if self.intermediate_features_ids is not None and idx in self.intermediate_features_ids:
|
| 80 |
+
intermediate_features[idx] = x
|
| 81 |
+
x = self.norm(x)
|
| 82 |
+
|
| 83 |
+
x = self.reshape_feature(x)
|
| 84 |
+
return x, intermediate_features
|
| 85 |
+
|
| 86 |
+
def internal_resolution(self) -> int:
|
| 87 |
+
"""Return the internal image size of the network."""
|
| 88 |
+
if isinstance(self.patch_embed.img_size, tuple):
|
| 89 |
+
return self.patch_embed.img_size[0]
|
| 90 |
+
else:
|
| 91 |
+
return self.patch_embed.img_size
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def create_vit(
|
| 95 |
+
config: ViTConfig | None = None,
|
| 96 |
+
preset: ViTPreset | None = "dinov2l16_384",
|
| 97 |
+
intermediate_features_ids: list[int] | None = None,
|
| 98 |
+
) -> TimmViT:
|
| 99 |
+
"""Factory function for creating a ViT model."""
|
| 100 |
+
if config is not None:
|
| 101 |
+
LOGGER.info("Using user-defined config.")
|
| 102 |
+
else:
|
| 103 |
+
if preset is None:
|
| 104 |
+
raise ValueError("User-defined config and preset cannot be both None.")
|
| 105 |
+
LOGGER.info("Using preset ViT %s.", preset)
|
| 106 |
+
config = VIT_CONFIG_DICT[preset]
|
| 107 |
+
|
| 108 |
+
config.intermediate_features_ids = intermediate_features_ids
|
| 109 |
+
model = TimmViT(config)
|
| 110 |
+
LOGGER.debug(model)
|
| 111 |
+
return model
|
src/sharp/models/gaussian_decoder.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains Dense Transformer Prediction architecture.
|
| 2 |
+
|
| 3 |
+
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
|
| 4 |
+
|
| 5 |
+
For licensing see accompanying LICENSE file.
|
| 6 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from typing import NamedTuple
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
from sharp.models.blocks import (
|
| 17 |
+
FeatureFusionBlock2d,
|
| 18 |
+
NormLayerName,
|
| 19 |
+
residual_block_2d,
|
| 20 |
+
)
|
| 21 |
+
from sharp.models.decoders import BaseDecoder, MultiresConvDecoder
|
| 22 |
+
from sharp.models.params import DPTImageEncoderType, GaussianDecoderParams
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def create_gaussian_decoder(
|
| 26 |
+
params: GaussianDecoderParams, dims_depth_features: list[int]
|
| 27 |
+
) -> GaussianDensePredictionTransformer:
|
| 28 |
+
"""Create gaussian_decoder model specified by gaussian_decoder_name."""
|
| 29 |
+
decoder = MultiresConvDecoder(
|
| 30 |
+
dims_depth_features,
|
| 31 |
+
params.dims_decoder,
|
| 32 |
+
grad_checkpointing=params.grad_checkpointing,
|
| 33 |
+
upsampling_mode=params.upsampling_mode,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
return GaussianDensePredictionTransformer(
|
| 37 |
+
decoder=decoder,
|
| 38 |
+
dim_in=params.dim_in,
|
| 39 |
+
dim_out=params.dim_out,
|
| 40 |
+
stride_out=params.stride,
|
| 41 |
+
norm_type=params.norm_type,
|
| 42 |
+
norm_num_groups=params.norm_num_groups,
|
| 43 |
+
use_depth_input=params.use_depth_input,
|
| 44 |
+
grad_checkpointing=params.grad_checkpointing,
|
| 45 |
+
image_encoder_type=params.image_encoder_type,
|
| 46 |
+
image_encoder_params=params,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _create_project_upsample_block(
|
| 51 |
+
dim_in: int,
|
| 52 |
+
dim_out: int,
|
| 53 |
+
upsample_layers: int,
|
| 54 |
+
dim_intermediate: int | None = None,
|
| 55 |
+
) -> nn.Module:
|
| 56 |
+
if dim_intermediate is None:
|
| 57 |
+
dim_intermediate = dim_out
|
| 58 |
+
# Projection.
|
| 59 |
+
blocks = [
|
| 60 |
+
nn.Conv2d(
|
| 61 |
+
in_channels=dim_in,
|
| 62 |
+
out_channels=dim_intermediate,
|
| 63 |
+
kernel_size=1,
|
| 64 |
+
stride=1,
|
| 65 |
+
padding=0,
|
| 66 |
+
bias=False,
|
| 67 |
+
)
|
| 68 |
+
]
|
| 69 |
+
|
| 70 |
+
# Upsampling.
|
| 71 |
+
blocks += [
|
| 72 |
+
nn.ConvTranspose2d(
|
| 73 |
+
in_channels=dim_intermediate if i == 0 else dim_out,
|
| 74 |
+
out_channels=dim_out,
|
| 75 |
+
kernel_size=2,
|
| 76 |
+
stride=2,
|
| 77 |
+
padding=0,
|
| 78 |
+
bias=False,
|
| 79 |
+
)
|
| 80 |
+
for i in range(upsample_layers)
|
| 81 |
+
]
|
| 82 |
+
|
| 83 |
+
return nn.Sequential(*blocks)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class ImageFeatures(NamedTuple):
|
| 87 |
+
"""Image feature extracted from decoder."""
|
| 88 |
+
|
| 89 |
+
texture_features: torch.Tensor
|
| 90 |
+
geometry_features: torch.Tensor
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class SkipConvBackbone(nn.Module):
|
| 94 |
+
"""A wrapper around a conv layer that behaves like a BaseBackbone."""
|
| 95 |
+
|
| 96 |
+
def __init__(self, dim_in: int, dim_out: int, kernel_size: int, stride_out: int):
|
| 97 |
+
"""Initialize SkipConvBackbone."""
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.stride_out = stride_out
|
| 100 |
+
if stride_out == 1 and kernel_size != 1:
|
| 101 |
+
raise ValueError("We only support kernel_size = 1 if stride_out is 1.")
|
| 102 |
+
padding: int = (kernel_size - 1) // 2
|
| 103 |
+
self.conv = nn.Conv2d(
|
| 104 |
+
dim_in, dim_out, kernel_size=kernel_size, stride=stride_out, padding=padding
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
def forward(
|
| 108 |
+
self,
|
| 109 |
+
input_features: torch.Tensor,
|
| 110 |
+
encodings: list[torch.Tensor] | None = None,
|
| 111 |
+
) -> ImageFeatures:
|
| 112 |
+
"""Apply SkipConvBackbone to image."""
|
| 113 |
+
output = self.conv(input_features)
|
| 114 |
+
return ImageFeatures(
|
| 115 |
+
texture_features=output,
|
| 116 |
+
geometry_features=output,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
@property
|
| 120 |
+
def stride(self) -> int:
|
| 121 |
+
"""Effective downsampling stride."""
|
| 122 |
+
return self.stride_out
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class GaussianDensePredictionTransformer(nn.Module):
|
| 126 |
+
"""Dense Prediction Transformer for Gaussian.
|
| 127 |
+
|
| 128 |
+
Reuse monodepth decoded features for processing.
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
norm_type: NormLayerName
|
| 132 |
+
|
| 133 |
+
def __init__(
|
| 134 |
+
self,
|
| 135 |
+
decoder: BaseDecoder,
|
| 136 |
+
dim_in: int,
|
| 137 |
+
dim_out: int,
|
| 138 |
+
stride_out: int,
|
| 139 |
+
image_encoder_params: GaussianDecoderParams,
|
| 140 |
+
image_encoder_type: DPTImageEncoderType = "skip_conv",
|
| 141 |
+
norm_type: NormLayerName = "group_norm",
|
| 142 |
+
norm_num_groups: int = 8,
|
| 143 |
+
use_depth_input: bool = True,
|
| 144 |
+
grad_checkpointing: bool = False,
|
| 145 |
+
):
|
| 146 |
+
"""Initialize Dense Prediction Transformer for Gaussian.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
decoder: Decoder to decode features.
|
| 150 |
+
monodepth_decoder: Optional monodepth decoder to fuse monodepth decoded features.
|
| 151 |
+
dim_in: Input dimension.
|
| 152 |
+
dim_out: Final output dimension.
|
| 153 |
+
stride_out: Stride of output feature map.
|
| 154 |
+
image_encoder_params: The backbone parameters to configurate the image encoder.
|
| 155 |
+
image_encoder_type: Type of image encoder to use.
|
| 156 |
+
encoder: Encoder to generate features using monodepth model.
|
| 157 |
+
norm_type: Type of norm layers.
|
| 158 |
+
norm_num_groups: Num groups for norm layers.
|
| 159 |
+
use_depth_input: Whether to use depth input.
|
| 160 |
+
grad_checkpointing: Whether to use gradient checkpointing.
|
| 161 |
+
"""
|
| 162 |
+
super().__init__()
|
| 163 |
+
|
| 164 |
+
self.decoder = decoder
|
| 165 |
+
self.dim_in = dim_in
|
| 166 |
+
self.dim_out = dim_out
|
| 167 |
+
self.stride_out = stride_out
|
| 168 |
+
self.norm_type = norm_type
|
| 169 |
+
self.norm_num_groups = norm_num_groups
|
| 170 |
+
self.use_depth_input = use_depth_input
|
| 171 |
+
self.grad_checkpointing = grad_checkpointing
|
| 172 |
+
self.image_encoder_type = image_encoder_type
|
| 173 |
+
|
| 174 |
+
# Adopt an image encoder to lift dimension to monodepth feature and
|
| 175 |
+
# resize to be the same resolution as the decoder output.
|
| 176 |
+
dim_in = self.dim_in if use_depth_input else self.dim_in - 1
|
| 177 |
+
image_encoder_params.dim_in = dim_in
|
| 178 |
+
image_encoder_params.dim_out = decoder.dim_out
|
| 179 |
+
self.image_encoder = self._create_image_encoder(image_encoder_params, stride_out)
|
| 180 |
+
|
| 181 |
+
self.fusion = FeatureFusionBlock2d(decoder.dim_out)
|
| 182 |
+
|
| 183 |
+
if stride_out == 1:
|
| 184 |
+
self.upsample = _create_project_upsample_block(
|
| 185 |
+
decoder.dim_out,
|
| 186 |
+
decoder.dim_out,
|
| 187 |
+
upsample_layers=1,
|
| 188 |
+
)
|
| 189 |
+
elif stride_out == 2:
|
| 190 |
+
self.upsample = nn.Identity()
|
| 191 |
+
else:
|
| 192 |
+
raise ValueError("We only support stride is 1 or 2 for DPT backbone.")
|
| 193 |
+
|
| 194 |
+
self.texture_head = self._create_head(dim_decoder=decoder.dim_out, dim_out=self.dim_out)
|
| 195 |
+
self.geometry_head = self._create_head(dim_decoder=decoder.dim_out, dim_out=self.dim_out)
|
| 196 |
+
|
| 197 |
+
def _create_head(self, dim_decoder: int, dim_out: int) -> nn.Module:
|
| 198 |
+
return nn.Sequential(
|
| 199 |
+
residual_block_2d(
|
| 200 |
+
dim_in=dim_decoder,
|
| 201 |
+
dim_out=dim_decoder,
|
| 202 |
+
dim_hidden=dim_decoder // 2,
|
| 203 |
+
norm_type=self.norm_type,
|
| 204 |
+
norm_num_groups=self.norm_num_groups,
|
| 205 |
+
),
|
| 206 |
+
residual_block_2d(
|
| 207 |
+
dim_in=dim_decoder,
|
| 208 |
+
dim_hidden=dim_decoder // 2,
|
| 209 |
+
dim_out=dim_decoder,
|
| 210 |
+
norm_type=self.norm_type,
|
| 211 |
+
norm_num_groups=self.norm_num_groups,
|
| 212 |
+
),
|
| 213 |
+
nn.ReLU(),
|
| 214 |
+
nn.Conv2d(dim_decoder, dim_out, kernel_size=1, stride=1),
|
| 215 |
+
nn.ReLU(),
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
def _create_image_encoder(
|
| 219 |
+
self, image_encoder_params: GaussianDecoderParams, stride_out: int
|
| 220 |
+
) -> nn.Module:
|
| 221 |
+
"""Create image encoder and return based on parameters."""
|
| 222 |
+
if self.image_encoder_type == "skip_conv":
|
| 223 |
+
# Use kernel_size = 1 only if stride_out is 1.
|
| 224 |
+
return SkipConvBackbone(
|
| 225 |
+
image_encoder_params.dim_in,
|
| 226 |
+
image_encoder_params.dim_out,
|
| 227 |
+
kernel_size=3 if stride_out != 1 else 1,
|
| 228 |
+
stride_out=stride_out,
|
| 229 |
+
)
|
| 230 |
+
elif self.image_encoder_type == "skip_conv_kernel2":
|
| 231 |
+
return SkipConvBackbone(
|
| 232 |
+
image_encoder_params.dim_in,
|
| 233 |
+
image_encoder_params.dim_out,
|
| 234 |
+
kernel_size=stride_out,
|
| 235 |
+
stride_out=stride_out,
|
| 236 |
+
)
|
| 237 |
+
else:
|
| 238 |
+
raise ValueError(f"Unsupported image encoder type: {self.image_encoder_type}")
|
| 239 |
+
|
| 240 |
+
def forward(self, input_features: torch.Tensor, encodings: list[torch.Tensor]) -> ImageFeatures:
|
| 241 |
+
"""Run monodepth and fuse features with input image to predict Gaussians.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
input_features: The input features to use.
|
| 245 |
+
encodings: Feature encodings (e.g. from monodepth network).
|
| 246 |
+
"""
|
| 247 |
+
features = self.decoder(encodings).contiguous()
|
| 248 |
+
features = self.upsample(features)
|
| 249 |
+
|
| 250 |
+
if self.use_depth_input:
|
| 251 |
+
skip_features = self.image_encoder(input_features).texture_features
|
| 252 |
+
else:
|
| 253 |
+
skip_features = self.image_encoder(input_features[:, :3].contiguous())
|
| 254 |
+
features = self.fusion(features, skip_features)
|
| 255 |
+
|
| 256 |
+
texture_features = self.texture_head(features)
|
| 257 |
+
geometry_features = self.geometry_head(features)
|
| 258 |
+
|
| 259 |
+
return ImageFeatures(
|
| 260 |
+
texture_features=texture_features, # type: ignore
|
| 261 |
+
geometry_features=geometry_features, # type: ignore
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
@property
|
| 265 |
+
def stride(self) -> int:
|
| 266 |
+
"""Internal stride of GaussianDensePredictionTransformer."""
|
| 267 |
+
return self.stride_out
|
src/sharp/models/heads.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains decoder head for direct prediction of delta values.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
|
| 12 |
+
from .gaussian_decoder import ImageFeatures
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DirectPredictionHead(nn.Module):
|
| 16 |
+
"""Decodes features into delta values using convolutions."""
|
| 17 |
+
|
| 18 |
+
def __init__(self, feature_dim: int, num_layers: int) -> None:
|
| 19 |
+
"""Initialize DirectGaussianPredictor.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
feature_dim: Number of input features.
|
| 23 |
+
num_layers: The number of layers of Gaussians to predict.
|
| 24 |
+
"""
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.num_layers = num_layers
|
| 27 |
+
|
| 28 |
+
# 14 is 3 means, 3 scales, 4 quaternions, 3 colors and 1 opacity
|
| 29 |
+
self.geometry_prediction_head = nn.Conv2d(feature_dim, 3 * num_layers, 1)
|
| 30 |
+
self.geometry_prediction_head.weight.data.zero_()
|
| 31 |
+
assert self.geometry_prediction_head.bias is not None
|
| 32 |
+
self.geometry_prediction_head.bias.data.zero_()
|
| 33 |
+
|
| 34 |
+
self.texture_prediction_head = nn.Conv2d(feature_dim, (14 - 3) * num_layers, 1)
|
| 35 |
+
self.texture_prediction_head.weight.data.zero_()
|
| 36 |
+
assert self.texture_prediction_head.bias is not None
|
| 37 |
+
self.texture_prediction_head.bias.data.zero_()
|
| 38 |
+
|
| 39 |
+
def forward(self, image_features: ImageFeatures) -> torch.Tensor:
|
| 40 |
+
"""Predict deltas for 3D Gaussians.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
image_features: Image features from decoder.
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
The predicted deltas for Gaussian attributes.
|
| 47 |
+
"""
|
| 48 |
+
delta_values_geometry = self.geometry_prediction_head(image_features.geometry_features)
|
| 49 |
+
delta_values_texture = self.texture_prediction_head(image_features.texture_features)
|
| 50 |
+
delta_values_geometry = delta_values_geometry.unflatten(1, (3, self.num_layers))
|
| 51 |
+
delta_values_texture = delta_values_texture.unflatten(1, (14 - 3, self.num_layers))
|
| 52 |
+
delta_values = torch.cat([delta_values_geometry, delta_values_texture], dim=1)
|
| 53 |
+
return delta_values
|
src/sharp/models/initializer.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains modules to initialize Gaussians from RGBD.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import NamedTuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
from .params import ColorInitOption, DepthInitOption, InitializerParams
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def create_initializer(params: InitializerParams) -> nn.Module:
|
| 18 |
+
"""Create inpainter."""
|
| 19 |
+
return MultiLayerInitializer(
|
| 20 |
+
num_layers=params.num_layers,
|
| 21 |
+
stride=params.stride,
|
| 22 |
+
base_depth=params.base_depth,
|
| 23 |
+
scale_factor=params.scale_factor,
|
| 24 |
+
disparity_factor=params.disparity_factor,
|
| 25 |
+
color_option=params.color_option,
|
| 26 |
+
first_layer_depth_option=params.first_layer_depth_option,
|
| 27 |
+
rest_layer_depth_option=params.rest_layer_depth_option,
|
| 28 |
+
normalize_depth=params.normalize_depth,
|
| 29 |
+
feature_input_stop_grad=params.feature_input_stop_grad,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class GaussianBaseValues(NamedTuple):
|
| 34 |
+
"""Base values for gaussian predictor.
|
| 35 |
+
|
| 36 |
+
We predict x and y in normalized device coordinates (NDC) where (-1, -1) is the top
|
| 37 |
+
left corner and (1, 1) the bottom right corner. The last component of
|
| 38 |
+
mean_vectors_ndc is inverse depth.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
mean_x_ndc: torch.Tensor
|
| 42 |
+
mean_y_ndc: torch.Tensor
|
| 43 |
+
mean_inverse_z_ndc: torch.Tensor
|
| 44 |
+
|
| 45 |
+
scales: torch.Tensor
|
| 46 |
+
quaternions: torch.Tensor
|
| 47 |
+
colors: torch.Tensor
|
| 48 |
+
opacities: torch.Tensor
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class InitializerOutput(NamedTuple):
|
| 52 |
+
"""Output of initializer."""
|
| 53 |
+
|
| 54 |
+
# Gaussian base values.
|
| 55 |
+
gaussian_base_values: GaussianBaseValues
|
| 56 |
+
|
| 57 |
+
# Feature input to the Gaussian predictor.
|
| 58 |
+
feature_input: torch.Tensor
|
| 59 |
+
|
| 60 |
+
# Global scale to unscale output.
|
| 61 |
+
global_scale: torch.Tensor | None = None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class MultiLayerInitializer(nn.Module):
|
| 65 |
+
"""Initialize Gaussians with multilayer representation.
|
| 66 |
+
|
| 67 |
+
The returned tensors have the shape
|
| 68 |
+
|
| 69 |
+
batch_size x dim x num_layers x height x width
|
| 70 |
+
|
| 71 |
+
where dim indicates the dimensionality of the property.
|
| 72 |
+
Some of the dimensions might be set to 1 for efficiency reasons.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
num_layers: int,
|
| 78 |
+
stride: int,
|
| 79 |
+
base_depth: float,
|
| 80 |
+
scale_factor: float,
|
| 81 |
+
disparity_factor: float,
|
| 82 |
+
color_option: ColorInitOption = "first_layer",
|
| 83 |
+
first_layer_depth_option: DepthInitOption = "surface_min",
|
| 84 |
+
rest_layer_depth_option: DepthInitOption = "surface_min",
|
| 85 |
+
normalize_depth: bool = True,
|
| 86 |
+
feature_input_stop_grad: bool = True,
|
| 87 |
+
) -> None:
|
| 88 |
+
"""Initialize MultilayerInitializer.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
stride: The downsample rate of output feature map.
|
| 92 |
+
base_depth: The depth of the first layer (after the foreground
|
| 93 |
+
layer if use_depth=True).
|
| 94 |
+
scale_factor: Multiply scale of Gaussians by this factor.
|
| 95 |
+
disparity_factor: Factor to convert inverse depth to disparity.
|
| 96 |
+
num_layers: How many layers of Gaussians to predict.
|
| 97 |
+
color_option: Which color option to initialize the multi-layer gaussians.
|
| 98 |
+
first_layer_depth_option: Which depth option to initialize the first layer of gaussians.
|
| 99 |
+
rest_layer_depth_option: Which depth option to initialize the rest layers of gaussians.
|
| 100 |
+
normalize_depth: # Whether to normalize depth to [DepthTransformParam.depth_min,
|
| 101 |
+
DepthTransformParam.depth_max).
|
| 102 |
+
feature_input_stop_grad: Whether to not propagate gradients through feature inputs.
|
| 103 |
+
"""
|
| 104 |
+
super().__init__()
|
| 105 |
+
self.num_layers = num_layers
|
| 106 |
+
self.stride = stride
|
| 107 |
+
self.base_depth = base_depth
|
| 108 |
+
self.scale_factor = scale_factor
|
| 109 |
+
self.disparity_factor = disparity_factor
|
| 110 |
+
self.color_option = color_option
|
| 111 |
+
self.first_layer_depth_option = first_layer_depth_option
|
| 112 |
+
self.rest_layer_depth_option = rest_layer_depth_option
|
| 113 |
+
self.normalize_depth = normalize_depth
|
| 114 |
+
self.feature_input_stop_grad = feature_input_stop_grad
|
| 115 |
+
|
| 116 |
+
def prepare_feature_input(self, image: torch.Tensor, depth: torch.Tensor) -> torch.Tensor:
|
| 117 |
+
"""Prepare the feature input to the Guassian predictor."""
|
| 118 |
+
if self.feature_input_stop_grad:
|
| 119 |
+
image = image.detach()
|
| 120 |
+
depth = depth.detach()
|
| 121 |
+
|
| 122 |
+
normalized_disparity = self.disparity_factor / depth
|
| 123 |
+
features_in = torch.cat([image, normalized_disparity], dim=1)
|
| 124 |
+
features_in = 2.0 * features_in - 1.0
|
| 125 |
+
return features_in
|
| 126 |
+
|
| 127 |
+
def forward(self, image: torch.Tensor, depth: torch.Tensor) -> InitializerOutput:
|
| 128 |
+
"""Construct Gaussian base values and prepare feature input.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
image: The image to process.
|
| 132 |
+
depth: The corresponding depth map from the monodepth network.
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
The base value for Gaussians.
|
| 136 |
+
"""
|
| 137 |
+
image = image.contiguous()
|
| 138 |
+
depth = depth.contiguous()
|
| 139 |
+
device = depth.device
|
| 140 |
+
batch_size, _, image_height, image_width = depth.shape
|
| 141 |
+
base_height, base_width = (
|
| 142 |
+
image_height // self.stride,
|
| 143 |
+
image_width // self.stride,
|
| 144 |
+
)
|
| 145 |
+
# global_scale is the inverse of the depth_factor, which is used to rescale
|
| 146 |
+
# the depth such that it is numerically stable for training.
|
| 147 |
+
global_scale: torch.Tensor | None = None
|
| 148 |
+
if self.normalize_depth:
|
| 149 |
+
depth, depth_factor = _rescale_depth(depth)
|
| 150 |
+
global_scale = 1.0 / depth_factor
|
| 151 |
+
|
| 152 |
+
def _create_disparity_layers(num_layers: int = 1) -> torch.Tensor:
|
| 153 |
+
"""Create multiple disparity layers."""
|
| 154 |
+
disparity = torch.linspace(1.0 / self.base_depth, 0.0, num_layers + 1, device=device)
|
| 155 |
+
return disparity[None, None, :-1, None, None].repeat(
|
| 156 |
+
batch_size, 1, 1, base_height, base_width
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
def _create_surface_layer(
|
| 160 |
+
depth: torch.Tensor,
|
| 161 |
+
depth_pooling_mode: str,
|
| 162 |
+
) -> torch.Tensor:
|
| 163 |
+
"""Create multiple surface layers."""
|
| 164 |
+
disparity = 1.0 / depth
|
| 165 |
+
if depth_pooling_mode == "min":
|
| 166 |
+
disparity = torch.max_pool2d(disparity, self.stride, self.stride)
|
| 167 |
+
elif depth_pooling_mode == "max":
|
| 168 |
+
disparity = -torch.max_pool2d(-disparity, self.stride, self.stride)
|
| 169 |
+
else:
|
| 170 |
+
raise ValueError(f"Invalid depth pooling mode {depth_pooling_mode}.")
|
| 171 |
+
|
| 172 |
+
return disparity[:, :, None, :, :]
|
| 173 |
+
|
| 174 |
+
# Input disparity dimensions:
|
| 175 |
+
# (batch_size, num_channels in (1, 2), height, width)
|
| 176 |
+
|
| 177 |
+
# Output disparity dimensions:
|
| 178 |
+
# (batch_size, num_channels=1, num_layers in (1, 2), height, width)
|
| 179 |
+
if self.first_layer_depth_option == "surface_min":
|
| 180 |
+
first_disparity = _create_surface_layer(depth[:, 0:1], "min")
|
| 181 |
+
elif self.first_layer_depth_option == "surface_max":
|
| 182 |
+
first_disparity = _create_surface_layer(depth[:, 0:1], "max")
|
| 183 |
+
elif self.first_layer_depth_option in ("base_depth", "linear_disparity"):
|
| 184 |
+
first_disparity = _create_disparity_layers()
|
| 185 |
+
else:
|
| 186 |
+
raise ValueError(f"Unknown depth init option: {self.first_layer_depth_option}.")
|
| 187 |
+
|
| 188 |
+
if self.num_layers == 1:
|
| 189 |
+
disparity = first_disparity
|
| 190 |
+
else: # Fill in the rest layers.
|
| 191 |
+
following_depth = depth if depth.shape[1] == 1 else depth[:, 1:]
|
| 192 |
+
if self.rest_layer_depth_option == "surface_min":
|
| 193 |
+
following_disparity = _create_surface_layer(following_depth, "min")
|
| 194 |
+
elif self.rest_layer_depth_option == "surface_max":
|
| 195 |
+
following_disparity = _create_surface_layer(following_depth, "max")
|
| 196 |
+
elif self.rest_layer_depth_option == "base_depth":
|
| 197 |
+
following_disparity = torch.cat(
|
| 198 |
+
[_create_disparity_layers() for i in range(self.num_layers - 1)],
|
| 199 |
+
dim=2,
|
| 200 |
+
)
|
| 201 |
+
elif self.rest_layer_depth_option == "linear_disparity":
|
| 202 |
+
following_disparity = _create_disparity_layers(self.num_layers - 1)
|
| 203 |
+
else:
|
| 204 |
+
raise ValueError(f"Unknown depth init option: {self.rest_layer_depth_option}.")
|
| 205 |
+
|
| 206 |
+
disparity = torch.cat([first_disparity, following_disparity], dim=2)
|
| 207 |
+
|
| 208 |
+
# Prepare base values.
|
| 209 |
+
base_x_ndc, base_y_ndc = _create_base_xy(depth, self.stride, self.num_layers)
|
| 210 |
+
disparity_scale_factor = 2 * self.scale_factor * self.stride / float(image_width)
|
| 211 |
+
base_scales = _create_base_scale(disparity, disparity_scale_factor)
|
| 212 |
+
|
| 213 |
+
base_quaternions = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device)
|
| 214 |
+
base_quaternions = base_quaternions[None, :, None, None, None]
|
| 215 |
+
|
| 216 |
+
# Initializing the opacitiy this way ensures that the initial transmittance
|
| 217 |
+
# is approximately
|
| 218 |
+
#
|
| 219 |
+
# 1 / e ~= (1 - 1 / self.num_layers)**self.num_layers
|
| 220 |
+
#
|
| 221 |
+
# and hence independent of the number of layers.
|
| 222 |
+
#
|
| 223 |
+
base_opacities = torch.tensor([min(1.0 / self.num_layers, 0.5)], device=device)
|
| 224 |
+
base_colors = torch.empty(
|
| 225 |
+
batch_size, 3, self.num_layers, base_height, base_width, device=device
|
| 226 |
+
).fill_(0.5)
|
| 227 |
+
# Dimensions: (batch_size, num_channels, num_layers, height, width)
|
| 228 |
+
if self.color_option == "none":
|
| 229 |
+
pass
|
| 230 |
+
elif self.color_option == "first_layer":
|
| 231 |
+
base_colors[:, :, 0] = torch.nn.functional.avg_pool2d(image, self.stride, self.stride)
|
| 232 |
+
elif self.color_option == "all_layers":
|
| 233 |
+
temp = torch.nn.functional.avg_pool2d(image, self.stride, self.stride)
|
| 234 |
+
base_colors = temp[:, :, None, :, :].repeat(1, 1, self.num_layers, 1, 1)
|
| 235 |
+
else:
|
| 236 |
+
raise ValueError(f"Unknown color init option: {self.color_option}.")
|
| 237 |
+
|
| 238 |
+
features_in = self.prepare_feature_input(image, depth)
|
| 239 |
+
base_gaussians = GaussianBaseValues(
|
| 240 |
+
mean_x_ndc=base_x_ndc,
|
| 241 |
+
mean_y_ndc=base_y_ndc,
|
| 242 |
+
mean_inverse_z_ndc=disparity,
|
| 243 |
+
scales=base_scales,
|
| 244 |
+
quaternions=base_quaternions,
|
| 245 |
+
colors=base_colors,
|
| 246 |
+
opacities=base_opacities,
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
return InitializerOutput(
|
| 250 |
+
gaussian_base_values=base_gaussians,
|
| 251 |
+
feature_input=features_in,
|
| 252 |
+
global_scale=global_scale,
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _create_base_xy(
|
| 257 |
+
depth: torch.Tensor, stride: int, num_layers: int
|
| 258 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 259 |
+
"""Create base x and y coordinates for the gaussians in NDC space."""
|
| 260 |
+
device = depth.device
|
| 261 |
+
batch_size, _, image_height, image_width = depth.shape
|
| 262 |
+
xx = torch.arange(0.5 * stride, image_width, stride, device=device)
|
| 263 |
+
yy = torch.arange(0.5 * stride, image_height, stride, device=device)
|
| 264 |
+
xx = 2 * xx / image_width - 1.0
|
| 265 |
+
yy = 2 * yy / image_height - 1.0
|
| 266 |
+
|
| 267 |
+
xx, yy = torch.meshgrid(xx, yy, indexing="xy")
|
| 268 |
+
base_x_ndc = xx[None, None, None].repeat(batch_size, 1, num_layers, 1, 1)
|
| 269 |
+
base_y_ndc = yy[None, None, None].repeat(batch_size, 1, num_layers, 1, 1)
|
| 270 |
+
|
| 271 |
+
return base_x_ndc, base_y_ndc
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def _create_base_scale(disparity: torch.Tensor, disparity_scale_factor: float) -> torch.Tensor:
|
| 275 |
+
"""Create base scale for the gaussians."""
|
| 276 |
+
inverse_disparity = torch.ones_like(disparity) / disparity
|
| 277 |
+
base_scales = inverse_disparity * disparity_scale_factor
|
| 278 |
+
return base_scales
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def _rescale_depth(
|
| 282 |
+
depth: torch.Tensor, depth_min: float = 1.0, depth_max: float = 1e2
|
| 283 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 284 |
+
"""Rescale a depth image tensor.
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
depth: The depth tensor to transform.
|
| 288 |
+
depth_min: The min depth to scale depth to.
|
| 289 |
+
depth_max: The max clamp depth after scaling.
|
| 290 |
+
|
| 291 |
+
Returns:
|
| 292 |
+
The rescaled depth and rescale factor.
|
| 293 |
+
"""
|
| 294 |
+
current_depth_min = depth.flatten(depth.ndim - 3).min(dim=-1).values
|
| 295 |
+
depth_factor = depth_min / (current_depth_min + 1e-6)
|
| 296 |
+
depth = (depth * depth_factor[..., None, None, None]).clamp(max=depth_max)
|
| 297 |
+
return depth, depth_factor
|
src/sharp/models/monodepth.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains Dense Transformer Prediction architecture.
|
| 2 |
+
|
| 3 |
+
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
|
| 4 |
+
|
| 5 |
+
For licensing see accompanying LICENSE file.
|
| 6 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import copy
|
| 12 |
+
from typing import NamedTuple, Tuple
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
from sharp.models import normalizers
|
| 18 |
+
from sharp.models.decoders import MultiresConvDecoder, create_monodepth_decoder
|
| 19 |
+
from sharp.models.encoders import (
|
| 20 |
+
SlidingPyramidNetwork,
|
| 21 |
+
create_monodepth_encoder,
|
| 22 |
+
)
|
| 23 |
+
from sharp.utils import module_surgery
|
| 24 |
+
|
| 25 |
+
from .params import MonodepthAdaptorParams, MonodepthParams
|
| 26 |
+
|
| 27 |
+
DimsDecoder = Tuple[int, int, int, int, int]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class MonodepthDensePredictionTransformer(nn.Module):
|
| 31 |
+
"""Dense Prediction Transformer for monodepth.
|
| 32 |
+
|
| 33 |
+
Attach the disparity prediction head for monodepth prediction.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
encoder: SlidingPyramidNetwork,
|
| 39 |
+
decoder: MultiresConvDecoder,
|
| 40 |
+
last_dims: tuple[int, int],
|
| 41 |
+
):
|
| 42 |
+
"""Initialize Dense Prediction Transformer.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
encoder: The SlidingPyramidTransformer backbone.
|
| 46 |
+
decoder: The MultiresConvDecoder decoder.
|
| 47 |
+
last_dims: The dimension for the last convolution layers.
|
| 48 |
+
"""
|
| 49 |
+
super().__init__()
|
| 50 |
+
|
| 51 |
+
self.normalizer = normalizers.AffineRangeNormalizer(
|
| 52 |
+
input_range=(0, 1), output_range=(-1, 1)
|
| 53 |
+
)
|
| 54 |
+
self.encoder = encoder
|
| 55 |
+
self.decoder = decoder
|
| 56 |
+
|
| 57 |
+
dim_decoder = decoder.dim_out
|
| 58 |
+
self.head = nn.Sequential(
|
| 59 |
+
nn.Conv2d(dim_decoder, dim_decoder // 2, kernel_size=3, stride=1, padding=1),
|
| 60 |
+
nn.ConvTranspose2d(
|
| 61 |
+
in_channels=dim_decoder // 2,
|
| 62 |
+
out_channels=dim_decoder // 2,
|
| 63 |
+
kernel_size=2,
|
| 64 |
+
stride=2,
|
| 65 |
+
padding=0,
|
| 66 |
+
bias=True,
|
| 67 |
+
),
|
| 68 |
+
nn.Conv2d(
|
| 69 |
+
dim_decoder // 2,
|
| 70 |
+
last_dims[0],
|
| 71 |
+
kernel_size=3,
|
| 72 |
+
stride=1,
|
| 73 |
+
padding=1,
|
| 74 |
+
),
|
| 75 |
+
nn.ReLU(True),
|
| 76 |
+
nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
|
| 77 |
+
nn.ReLU(),
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
# Set the final convoultion layer's bias to be 0.
|
| 81 |
+
self.head[4].bias.data.fill_(0)
|
| 82 |
+
|
| 83 |
+
self.grad_checkpointing = False
|
| 84 |
+
|
| 85 |
+
@torch.jit.ignore
|
| 86 |
+
def set_grad_checkpointing(self, is_enabled=True):
|
| 87 |
+
"""Enable grad checkpointing."""
|
| 88 |
+
self.grad_checkpointing = is_enabled
|
| 89 |
+
self.encoder.set_grad_checkpointing(self.grad_checkpointing)
|
| 90 |
+
self.decoder.set_grad_checkpointing(self.grad_checkpointing)
|
| 91 |
+
|
| 92 |
+
def forward(self, image: torch.Tensor) -> torch.Tensor:
|
| 93 |
+
"""Decode by projection and fusion of multi-resolution encodings."""
|
| 94 |
+
encodings = self.encoder(self.normalizer(image))
|
| 95 |
+
num_encoder_features = len(self.encoder.dims_encoder)
|
| 96 |
+
features = self.decoder(encodings[:num_encoder_features])
|
| 97 |
+
disparity = self.head(features)
|
| 98 |
+
return disparity
|
| 99 |
+
|
| 100 |
+
def internal_resolution(self) -> int:
|
| 101 |
+
"""Return the internal image size of the network."""
|
| 102 |
+
return self.encoder.internal_resolution()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def create_monodepth_dpt(
|
| 106 |
+
params: MonodepthParams | None = None,
|
| 107 |
+
) -> MonodepthDensePredictionTransformer:
|
| 108 |
+
"""Creates DepthDensePredictionTransformer model.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
params: Parameters of monodepth network.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
The configured monodepth DPT.
|
| 115 |
+
"""
|
| 116 |
+
if params is None:
|
| 117 |
+
params = MonodepthParams()
|
| 118 |
+
encoder: SlidingPyramidNetwork = create_monodepth_encoder(
|
| 119 |
+
params.patch_encoder_preset,
|
| 120 |
+
params.image_encoder_preset,
|
| 121 |
+
use_patch_overlap=params.use_patch_overlap,
|
| 122 |
+
last_encoder=params.dims_decoder[0],
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
decoder: MultiresConvDecoder = create_monodepth_decoder(
|
| 126 |
+
params.patch_encoder_preset, params.dims_decoder
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
monodepth_model = MonodepthDensePredictionTransformer(
|
| 130 |
+
encoder=encoder, decoder=decoder, last_dims=(32, 1)
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# By default, we don't train the monodepth model.
|
| 134 |
+
# However, we allow to selectively unfreeze parts of the network.
|
| 135 |
+
monodepth_model.requires_grad_(False)
|
| 136 |
+
|
| 137 |
+
monodepth_model.encoder.set_requires_grad_(
|
| 138 |
+
patch_encoder=params.unfreeze_patch_encoder,
|
| 139 |
+
image_encoder=params.unfreeze_image_encoder,
|
| 140 |
+
)
|
| 141 |
+
monodepth_model.decoder.requires_grad_(params.unfreeze_decoder)
|
| 142 |
+
monodepth_model.head.requires_grad_(params.unfreeze_head)
|
| 143 |
+
|
| 144 |
+
if not params.unfreeze_norm_layers:
|
| 145 |
+
module_surgery.freeze_norm_layer(monodepth_model)
|
| 146 |
+
|
| 147 |
+
monodepth_model.set_grad_checkpointing(params.grad_checkpointing)
|
| 148 |
+
|
| 149 |
+
return monodepth_model
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class MonodepthOutput(NamedTuple):
|
| 153 |
+
"""Output of the monodepth model."""
|
| 154 |
+
|
| 155 |
+
# Disparity output from the monodepth model.
|
| 156 |
+
disparity: torch.Tensor
|
| 157 |
+
# Multi-level features from monodepth encoder.
|
| 158 |
+
encoder_features: list[torch.Tensor]
|
| 159 |
+
# Single-level feature from monodepth decoder.
|
| 160 |
+
decoder_features: torch.Tensor
|
| 161 |
+
# List of monodepth features to be used in gaussian predictor.
|
| 162 |
+
output_features: list[torch.Tensor]
|
| 163 |
+
# List of intermediate encoder features to be used in distillation.
|
| 164 |
+
intermediate_features: list[torch.Tensor] = []
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class MonodepthWithEncodingAdaptor(nn.Module):
|
| 168 |
+
"""Monodepth model with feature maps."""
|
| 169 |
+
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
monodepth_predictor: MonodepthDensePredictionTransformer,
|
| 173 |
+
return_encoder_features: bool,
|
| 174 |
+
return_decoder_features: bool,
|
| 175 |
+
num_monodepth_layers: int,
|
| 176 |
+
sorting_monodepth: bool,
|
| 177 |
+
):
|
| 178 |
+
"""Initialize MonodepthWithEncodingAdaptor.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
monodepth_predictor: The monodepth model.
|
| 182 |
+
return_encoder_features: Whether to return encoder features from monodepth model.
|
| 183 |
+
return_decoder_features: Whether to return decoder features from monodepth model.
|
| 184 |
+
num_monodepth_layers: How many layers the monodepth model predicts.
|
| 185 |
+
sorting_monodepth: Whether to sort the monodepth output (for two layer monodepth).
|
| 186 |
+
"""
|
| 187 |
+
super().__init__()
|
| 188 |
+
self.monodepth_predictor = monodepth_predictor
|
| 189 |
+
self.return_encoder_features = return_encoder_features
|
| 190 |
+
self.return_decoder_features = return_decoder_features
|
| 191 |
+
self.num_monodepth_layers = num_monodepth_layers
|
| 192 |
+
self.sorting_monodepth = sorting_monodepth
|
| 193 |
+
|
| 194 |
+
def forward(self, image: torch.Tensor) -> MonodepthOutput:
|
| 195 |
+
"""Process image and return disparity and feature maps."""
|
| 196 |
+
inputs = self.monodepth_predictor.normalizer(image)
|
| 197 |
+
encoder_output = self.monodepth_predictor.encoder(inputs)
|
| 198 |
+
|
| 199 |
+
num_encoder_features = len(self.monodepth_predictor.encoder.dims_encoder)
|
| 200 |
+
|
| 201 |
+
# NOTE: whether intermediate features are empty have already been decided
|
| 202 |
+
# in monodepth_predictor during create_monodepth_dpt.
|
| 203 |
+
encoder_features = encoder_output[:num_encoder_features]
|
| 204 |
+
intermediate_features = encoder_output[num_encoder_features:]
|
| 205 |
+
decoder_features = self.monodepth_predictor.decoder(encoder_features)
|
| 206 |
+
disparity = self.monodepth_predictor.head(decoder_features)
|
| 207 |
+
|
| 208 |
+
# We cannot use disparity.shape[1], otherwise the tracer will fail.
|
| 209 |
+
if self.num_monodepth_layers == 2 and self.sorting_monodepth:
|
| 210 |
+
first_layer_disparity = disparity.max(dim=1, keepdims=True).values
|
| 211 |
+
second_layer_disparity = disparity.min(dim=1, keepdims=True).values
|
| 212 |
+
disparity = torch.cat([first_layer_disparity, second_layer_disparity], dim=1)
|
| 213 |
+
|
| 214 |
+
output_features = []
|
| 215 |
+
if self.return_encoder_features:
|
| 216 |
+
output_features.extend(encoder_features)
|
| 217 |
+
|
| 218 |
+
if self.return_decoder_features:
|
| 219 |
+
output_features.append(decoder_features)
|
| 220 |
+
|
| 221 |
+
return MonodepthOutput(
|
| 222 |
+
disparity=disparity,
|
| 223 |
+
encoder_features=encoder_features,
|
| 224 |
+
decoder_features=decoder_features,
|
| 225 |
+
output_features=output_features,
|
| 226 |
+
intermediate_features=intermediate_features,
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
def get_feature_dims(self) -> list[int]:
|
| 230 |
+
"""Return dimensions of output feature maps."""
|
| 231 |
+
dims = []
|
| 232 |
+
if self.return_encoder_features:
|
| 233 |
+
dims.extend(self.monodepth_predictor.encoder.dims_encoder)
|
| 234 |
+
|
| 235 |
+
if self.return_decoder_features:
|
| 236 |
+
dims.append(self.monodepth_predictor.decoder.dim_out)
|
| 237 |
+
|
| 238 |
+
return dims
|
| 239 |
+
|
| 240 |
+
def internal_resolution(self) -> int:
|
| 241 |
+
"""Return the internal image size of the network."""
|
| 242 |
+
return self.monodepth_predictor.internal_resolution()
|
| 243 |
+
|
| 244 |
+
def replicate_head(self, num_repeat: int):
|
| 245 |
+
"""Replicate the last convolution layer (head[4] in DPT) for multi layer depth."""
|
| 246 |
+
conv_last = copy.deepcopy(self.monodepth_predictor.head[4])
|
| 247 |
+
self.monodepth_predictor.head[4].out_channels = num_repeat
|
| 248 |
+
self.monodepth_predictor.head[4].weight = nn.Parameter(
|
| 249 |
+
conv_last.weight.repeat(num_repeat, 1, 1, 1)
|
| 250 |
+
)
|
| 251 |
+
self.monodepth_predictor.head[4].bias = nn.Parameter(conv_last.bias.repeat(num_repeat))
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def create_monodepth_adaptor(
|
| 255 |
+
monodepth_predictor: MonodepthDensePredictionTransformer,
|
| 256 |
+
params: MonodepthAdaptorParams,
|
| 257 |
+
num_monodepth_layers: int,
|
| 258 |
+
sorting_monodepth: bool,
|
| 259 |
+
) -> MonodepthWithEncodingAdaptor:
|
| 260 |
+
"""Create an adaptor that returns both disparity and features."""
|
| 261 |
+
adaptor = MonodepthWithEncodingAdaptor(
|
| 262 |
+
monodepth_predictor=monodepth_predictor,
|
| 263 |
+
return_encoder_features=params.encoder_features,
|
| 264 |
+
return_decoder_features=params.decoder_features,
|
| 265 |
+
num_monodepth_layers=num_monodepth_layers,
|
| 266 |
+
sorting_monodepth=sorting_monodepth,
|
| 267 |
+
)
|
| 268 |
+
return adaptor
|
src/sharp/models/normalizers.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains an implementation of image normalizers for perceptual loss.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Sequence, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MeanStdNormalizer(nn.Module):
|
| 16 |
+
"""Normalizing image input by mean and std."""
|
| 17 |
+
|
| 18 |
+
mean: torch.Tensor
|
| 19 |
+
std_inv: torch.Tensor
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
mean: Union[Sequence[float], torch.Tensor],
|
| 24 |
+
std: Union[Sequence[float], torch.Tensor],
|
| 25 |
+
):
|
| 26 |
+
"""Initialize MeanStdNormalizer."""
|
| 27 |
+
super(MeanStdNormalizer, self).__init__()
|
| 28 |
+
if not isinstance(mean, torch.Tensor):
|
| 29 |
+
mean = torch.as_tensor(mean).view(-1, 1, 1)
|
| 30 |
+
if not isinstance(std, torch.Tensor):
|
| 31 |
+
std = torch.as_tensor(std).view(-1, 1, 1)
|
| 32 |
+
self.register_buffer("mean", mean)
|
| 33 |
+
# We use inverse std to use a multiplication which is better supported by the hardware
|
| 34 |
+
self.register_buffer("std_inv", 1.0 / std)
|
| 35 |
+
|
| 36 |
+
def forward(self, image: torch.Tensor) -> torch.Tensor:
|
| 37 |
+
"""Apply mean and std normalization over input image."""
|
| 38 |
+
return (image - self.mean) * self.std_inv
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class AffineRangeNormalizer(nn.Module):
|
| 42 |
+
"""Perform linear mapping to map input_range to output_range.
|
| 43 |
+
|
| 44 |
+
Output_range defaults to (0, 1).
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
input_range: tuple[float, float],
|
| 50 |
+
output_range: tuple[float, float] = (0, 1),
|
| 51 |
+
):
|
| 52 |
+
"""Initialize AffineRangeNormalizer."""
|
| 53 |
+
super().__init__()
|
| 54 |
+
input_min, input_max = input_range
|
| 55 |
+
output_min, output_max = output_range
|
| 56 |
+
if input_max <= input_min:
|
| 57 |
+
raise ValueError(f"Invalid input_range: {input_range}")
|
| 58 |
+
if output_max <= output_min:
|
| 59 |
+
raise ValueError(f"Invalid output_range: {output_range}")
|
| 60 |
+
|
| 61 |
+
self.scale = (output_max - output_min) / (input_max - input_min)
|
| 62 |
+
self.bias = output_min - input_min * self.scale
|
| 63 |
+
|
| 64 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 65 |
+
"""Apply affine range normalization over input image."""
|
| 66 |
+
if self.scale != 1.0:
|
| 67 |
+
x = x * self.scale
|
| 68 |
+
|
| 69 |
+
if self.bias != 0.0:
|
| 70 |
+
x = x + self.bias
|
| 71 |
+
|
| 72 |
+
return x
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class MobileNetNormalizer(AffineRangeNormalizer):
|
| 76 |
+
"""Image normalization in mobilenet."""
|
| 77 |
+
|
| 78 |
+
def __init__(self, input_range: tuple[float, float] = (0, 1)):
|
| 79 |
+
"""Initialize MobileNetNormalizer."""
|
| 80 |
+
super().__init__(input_range=input_range, output_range=(-1, 1))
|
src/sharp/models/params.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains params for backbone.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import dataclasses
|
| 8 |
+
from typing import Literal
|
| 9 |
+
|
| 10 |
+
import sharp.utils.math as math_utils
|
| 11 |
+
from sharp.models.blocks import NormLayerName, UpsamplingMode
|
| 12 |
+
from sharp.models.presets import ViTPreset
|
| 13 |
+
from sharp.utils.color_space import ColorSpace
|
| 14 |
+
|
| 15 |
+
DimsDecoder = tuple[int, int, int, int, int]
|
| 16 |
+
DPTImageEncoderType = Literal["skip_conv", "skip_conv_kernel2"]
|
| 17 |
+
|
| 18 |
+
ColorInitOption = Literal[
|
| 19 |
+
"none", # Initialize as gray.
|
| 20 |
+
"first_layer", # Initialize the first layer with input image, other layers with gray.
|
| 21 |
+
"all_layers", # Initialize all layers with input image.
|
| 22 |
+
]
|
| 23 |
+
DepthInitOption = Literal[
|
| 24 |
+
# Initialize the layer of gaussian on surface using min pooling of input depth.
|
| 25 |
+
"surface_min",
|
| 26 |
+
# Initialize the layer of gaussian on surface using max pooling of input depth
|
| 27 |
+
"surface_max",
|
| 28 |
+
# Initialize the layer of gaussian on plane using base_depth depth.
|
| 29 |
+
"base_depth",
|
| 30 |
+
# Initialize the layer of gaussian on plane based on base_depth and index of layer.
|
| 31 |
+
"linear_disparity",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclasses.dataclass
|
| 36 |
+
class AlignmentParams:
|
| 37 |
+
"""Parameters for depth alignment."""
|
| 38 |
+
|
| 39 |
+
kernel_size: int = 16
|
| 40 |
+
stride: int = 1
|
| 41 |
+
frozen: bool = False
|
| 42 |
+
|
| 43 |
+
# The following parameters are only used for LearnedAlignment.
|
| 44 |
+
# Number of steps in the UNet for LearnedAlignment.
|
| 45 |
+
steps: int = 4
|
| 46 |
+
# Activation type for LearnedAlignment.
|
| 47 |
+
activation_type: math_utils.ActivationType = "exp"
|
| 48 |
+
# Whether to use depth decoder features for LearnedAlignment.
|
| 49 |
+
depth_decoder_features: bool = False
|
| 50 |
+
# Base width of the UNet for LearnedAlignment.
|
| 51 |
+
base_width: int = 16
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclasses.dataclass
|
| 55 |
+
class DeltaFactor:
|
| 56 |
+
"""Factors to multiply deltas with before activation.
|
| 57 |
+
|
| 58 |
+
These factors effectively selectively reduce the learning rate.
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
xy: float = 0.001
|
| 62 |
+
z: float = 0.001
|
| 63 |
+
color: float = 0.1 # We recommend 0.1 for linearRGB and 1.0 for sRGB.
|
| 64 |
+
opacity: float = 1.0
|
| 65 |
+
scale: float = 1.0
|
| 66 |
+
quaternion: float = 1.0
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@dataclasses.dataclass
|
| 70 |
+
class InitializerParams:
|
| 71 |
+
"""Parameters for initializer."""
|
| 72 |
+
|
| 73 |
+
# Common parameters.
|
| 74 |
+
# Multiply scale of Gaussians by this factor.
|
| 75 |
+
scale_factor: float = 1.0
|
| 76 |
+
# Factor to convert inverse depth to disparity.
|
| 77 |
+
disparity_factor: float = 1.0
|
| 78 |
+
# Stride of the initializer.
|
| 79 |
+
stride: int = 2
|
| 80 |
+
|
| 81 |
+
# Parameters that only affect MultiLayerInitializer.
|
| 82 |
+
# How many layers of Gaussians to predict (only available for MultiLayerInitializer).
|
| 83 |
+
num_layers: int = 2
|
| 84 |
+
# Which option to use for depth initialization.
|
| 85 |
+
first_layer_depth_option: DepthInitOption = "surface_min"
|
| 86 |
+
rest_layer_depth_option: DepthInitOption = "surface_min"
|
| 87 |
+
# Which option to use for color initialization.
|
| 88 |
+
color_option: ColorInitOption = "all_layers"
|
| 89 |
+
# Which depth value to use for depth layers.
|
| 90 |
+
base_depth: float = 10.0
|
| 91 |
+
# Deactivate gradient for feature inputs.
|
| 92 |
+
feature_input_stop_grad: bool = False
|
| 93 |
+
# Whether to normalize depth to [DepthTransformParam.depth_min,
|
| 94 |
+
# DepthTransformParam.depth_max).
|
| 95 |
+
normalize_depth: bool = True
|
| 96 |
+
|
| 97 |
+
# Output only the inpainted layer. In this case, num_layers = 1.
|
| 98 |
+
output_inpainted_layer_only: bool = False
|
| 99 |
+
# Whether to set the uninpainted region to zero opacities.
|
| 100 |
+
set_uninpainted_opacity_to_zero: bool = False
|
| 101 |
+
# Whether to concatenate the inpainting mask to the feature input.
|
| 102 |
+
concat_inpainting_mask: bool = False
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@dataclasses.dataclass
|
| 106 |
+
class MonodepthParams:
|
| 107 |
+
"""Parameters for monodepth network."""
|
| 108 |
+
|
| 109 |
+
patch_encoder_preset: ViTPreset = "dinov2l16_384"
|
| 110 |
+
image_encoder_preset: ViTPreset = "dinov2l16_384"
|
| 111 |
+
|
| 112 |
+
checkpoint_uri: str | None = None
|
| 113 |
+
unfreeze_patch_encoder: bool = False
|
| 114 |
+
unfreeze_image_encoder: bool = False
|
| 115 |
+
unfreeze_decoder: bool = False
|
| 116 |
+
unfreeze_head: bool = False
|
| 117 |
+
unfreeze_norm_layers: bool = False
|
| 118 |
+
grad_checkpointing: bool = False
|
| 119 |
+
use_patch_overlap: bool = True
|
| 120 |
+
dims_decoder: DimsDecoder = (256, 256, 256, 256, 256)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@dataclasses.dataclass
|
| 124 |
+
class MonodepthAdaptorParams:
|
| 125 |
+
"""Parameters for monodepth network feature adaptor."""
|
| 126 |
+
|
| 127 |
+
encoder_features: bool = True
|
| 128 |
+
decoder_features: bool = False
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@dataclasses.dataclass
|
| 132 |
+
class GaussianDecoderParams:
|
| 133 |
+
"""Parameters for backbone with default values."""
|
| 134 |
+
|
| 135 |
+
dim_in: int = 5
|
| 136 |
+
dim_out: int = 32
|
| 137 |
+
# Which normalization to use in backbone.
|
| 138 |
+
norm_type: NormLayerName = "group_norm"
|
| 139 |
+
# How many groups to use for group normalization.
|
| 140 |
+
norm_num_groups: int = 8
|
| 141 |
+
# Stride of backbone.
|
| 142 |
+
stride: int = 2
|
| 143 |
+
|
| 144 |
+
patch_encoder_preset: ViTPreset = "dinov2l16_384"
|
| 145 |
+
image_encoder_preset: ViTPreset = "dinov2l16_384"
|
| 146 |
+
|
| 147 |
+
# Dimensionality of feature maps for DPT decoder.
|
| 148 |
+
dims_decoder: DimsDecoder = (128, 128, 128, 128, 128)
|
| 149 |
+
|
| 150 |
+
# Whether to use depth as input.
|
| 151 |
+
use_depth_input: bool = True
|
| 152 |
+
|
| 153 |
+
# Whether to enable gradient checkpointing for the backbone
|
| 154 |
+
grad_checkpointing: bool = False
|
| 155 |
+
|
| 156 |
+
# What mode to use for upsampling in decoder.
|
| 157 |
+
upsampling_mode: UpsamplingMode = "transposed_conv"
|
| 158 |
+
|
| 159 |
+
# The type of image encoder.
|
| 160 |
+
image_encoder_type: DPTImageEncoderType = "skip_conv_kernel2"
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@dataclasses.dataclass
|
| 164 |
+
class PredictorParams:
|
| 165 |
+
"""Parameters for predictors with default values."""
|
| 166 |
+
|
| 167 |
+
# Parameters for submodules.
|
| 168 |
+
initializer: InitializerParams = dataclasses.field(default_factory=InitializerParams)
|
| 169 |
+
monodepth: MonodepthParams = dataclasses.field(default_factory=MonodepthParams)
|
| 170 |
+
monodepth_adaptor: MonodepthAdaptorParams = dataclasses.field(
|
| 171 |
+
default_factory=MonodepthAdaptorParams
|
| 172 |
+
)
|
| 173 |
+
gaussian_decoder: GaussianDecoderParams = dataclasses.field(
|
| 174 |
+
default_factory=GaussianDecoderParams
|
| 175 |
+
)
|
| 176 |
+
# How to align depth map (only relevant for RGBGaussianPredictor).
|
| 177 |
+
depth_alignment: AlignmentParams = dataclasses.field(default_factory=AlignmentParams)
|
| 178 |
+
|
| 179 |
+
# Selectively reduce learning rate for different properties.
|
| 180 |
+
delta_factor: DeltaFactor = dataclasses.field(default_factory=DeltaFactor)
|
| 181 |
+
# The maximum scale of Gaussians relative to initial scale.
|
| 182 |
+
max_scale: float = 10.0
|
| 183 |
+
# The minimum scale of Gaussians relative to initial scale.
|
| 184 |
+
min_scale: float = 0.0
|
| 185 |
+
# Which normalization to use in prediction head.
|
| 186 |
+
norm_type: NormLayerName = "group_norm"
|
| 187 |
+
# How many groups to use for group normalization.
|
| 188 |
+
norm_num_groups: int = 8
|
| 189 |
+
# Whether to use predicted mean to sample triplane features.
|
| 190 |
+
use_predicted_mean: bool = False
|
| 191 |
+
# Which activation function to use for colors / opacities.
|
| 192 |
+
color_activation_type: math_utils.ActivationType = "sigmoid"
|
| 193 |
+
opacity_activation_type: math_utils.ActivationType = "sigmoid"
|
| 194 |
+
# Colorspace of the renderer ("linearRGB" or "sRGB").
|
| 195 |
+
color_space: ColorSpace = "linearRGB"
|
| 196 |
+
# A small value to avoid ill-conditioned splats
|
| 197 |
+
low_pass_filter_eps: float = 1e-2
|
| 198 |
+
# How many layer of depth does monodepth model predict.
|
| 199 |
+
num_monodepth_layers: int = 2
|
| 200 |
+
# Whether to sort the monodepth output (for two layer monodepth).
|
| 201 |
+
sorting_monodepth: bool = False
|
| 202 |
+
# Whether to account the z offsets for estimating base scale.
|
| 203 |
+
base_scale_on_predicted_mean: bool = True
|
src/sharp/models/predictor.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains definition of RGB-only gaussian predictor.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
from sharp.models.monodepth import MonodepthWithEncodingAdaptor
|
| 15 |
+
from sharp.utils.gaussians import Gaussians3D
|
| 16 |
+
|
| 17 |
+
from .composer import GaussianComposer
|
| 18 |
+
|
| 19 |
+
LOGGER = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DepthAlignment(nn.Module):
|
| 23 |
+
"""Depth alignment in a dedicated nn.Module.
|
| 24 |
+
|
| 25 |
+
Wrap scale_map_estimator to perform the conditional logic in a separated torch
|
| 26 |
+
module outside the forward of RGBGaussianPredictor. This module can be then
|
| 27 |
+
excluded during symbolic tracing.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, scale_map_estimator: nn.Module | None):
|
| 31 |
+
"""Initialize DepthAlignmentWrapper.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
scale_map_estimator: Module to align monodepth to ground truth depth.
|
| 35 |
+
"""
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.scale_map_estimator = scale_map_estimator
|
| 38 |
+
|
| 39 |
+
def forward(
|
| 40 |
+
self,
|
| 41 |
+
monodepth: torch.Tensor,
|
| 42 |
+
depth: torch.Tensor,
|
| 43 |
+
depth_decoder_features: torch.Tensor | None = None,
|
| 44 |
+
):
|
| 45 |
+
"""Optionally align monodepth to ground truth with a local scale map.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
monodepth: The monodepth model with intermediate features to use.
|
| 49 |
+
depth: Ground truth depth to align predicted depth to.
|
| 50 |
+
depth_decoder_features: The (optional) monodepth decoder features.
|
| 51 |
+
"""
|
| 52 |
+
if depth is not None and self.scale_map_estimator is not None:
|
| 53 |
+
depth_alignment_map = self.scale_map_estimator(
|
| 54 |
+
monodepth[:, 0:1], depth, depth_decoder_features
|
| 55 |
+
)
|
| 56 |
+
monodepth = depth_alignment_map * monodepth
|
| 57 |
+
else:
|
| 58 |
+
# Some losses rely on the presence of an alignment map.
|
| 59 |
+
# We ensure that they can be computed by creating a fake alignment map.
|
| 60 |
+
depth_alignment_map = torch.ones_like(monodepth)
|
| 61 |
+
return monodepth, depth_alignment_map
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class RGBGaussianPredictor(nn.Module):
|
| 65 |
+
"""Predicts 3D Gaussians from images."""
|
| 66 |
+
|
| 67 |
+
feature_model: nn.Module
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
init_model: nn.Module,
|
| 72 |
+
monodepth_model: MonodepthWithEncodingAdaptor,
|
| 73 |
+
feature_model: nn.Module,
|
| 74 |
+
prediction_head: nn.Module,
|
| 75 |
+
gaussian_composer: GaussianComposer,
|
| 76 |
+
scale_map_estimator: nn.Module | None,
|
| 77 |
+
) -> None:
|
| 78 |
+
"""Initialize RGBGaussianPredictor.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
init_model: A model mapping image and depth to base values.
|
| 82 |
+
monodepth_model: The monodepth model with intermediate features to use.
|
| 83 |
+
feature_model: The image2image model to predict Gaussians from.
|
| 84 |
+
prediction_head: Head to decode image features.
|
| 85 |
+
gaussian_composer: Module to compose final prediction from deltas and
|
| 86 |
+
base values.
|
| 87 |
+
scale_map_estimator: Module to align monodepth to ground truth depth.
|
| 88 |
+
|
| 89 |
+
Note:
|
| 90 |
+
----
|
| 91 |
+
when monodepth_model is trainable, using local depth alignment can
|
| 92 |
+
result in the monodepth model losing its ability to predict shapes. It is
|
| 93 |
+
hence recommend to deactivate the corresponding flag.
|
| 94 |
+
"""
|
| 95 |
+
super().__init__()
|
| 96 |
+
self.init_model = init_model
|
| 97 |
+
self.feature_model = feature_model
|
| 98 |
+
self.monodepth_model = monodepth_model
|
| 99 |
+
self.prediction_head = prediction_head
|
| 100 |
+
self.gaussian_composer = gaussian_composer
|
| 101 |
+
self.depth_alignment = DepthAlignment(scale_map_estimator)
|
| 102 |
+
|
| 103 |
+
def forward(
|
| 104 |
+
self,
|
| 105 |
+
image: torch.Tensor,
|
| 106 |
+
disparity_factor: torch.Tensor,
|
| 107 |
+
depth: torch.Tensor | None = None,
|
| 108 |
+
) -> Gaussians3D:
|
| 109 |
+
"""Predict 3D Gaussians.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
image: The image to process.
|
| 113 |
+
disparity_factor: Factor to convert depth to disparities.
|
| 114 |
+
depth: Ground truth depth to align predicted depth to.
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
The predicted 3D Gaussians.
|
| 118 |
+
|
| 119 |
+
Note:
|
| 120 |
+
----
|
| 121 |
+
During training, it is recommended to feed an additional ground truth depth
|
| 122 |
+
map to the network to align the predicted depth to. During inference, it is
|
| 123 |
+
recommended to use depth_gt=None and use monodepth_disparity output from the
|
| 124 |
+
model instead to compute depth.
|
| 125 |
+
"""
|
| 126 |
+
# Estimate depth and align to ground truth (if available).
|
| 127 |
+
monodepth_output = self.monodepth_model(image)
|
| 128 |
+
monodepth_disparity = monodepth_output.disparity
|
| 129 |
+
|
| 130 |
+
disparity_factor = disparity_factor[:, None, None, None]
|
| 131 |
+
monodepth = disparity_factor / monodepth_disparity.clamp(min=1e-4, max=1e4)
|
| 132 |
+
|
| 133 |
+
# In the model we apply additional alignment to provided ground truth depth
|
| 134 |
+
# as well as additional normalization.
|
| 135 |
+
#
|
| 136 |
+
# The overall graph looks as follows:
|
| 137 |
+
#
|
| 138 |
+
# monodepth depth # Both monodepth and depth are metric here.
|
| 139 |
+
# | |
|
| 140 |
+
# +------+-------+
|
| 141 |
+
# |
|
| 142 |
+
# +-------+--------+ # Optionally align monodepth to ground truth
|
| 143 |
+
# |depth_alignement| # with a local scale map.
|
| 144 |
+
# +-------+--------+
|
| 145 |
+
# |
|
| 146 |
+
# v
|
| 147 |
+
# monodepth (aligned) # Monodepth is now aligned to ground truth.
|
| 148 |
+
# |
|
| 149 |
+
# +-----+----+ # Normalize depth and compute base gaussians.
|
| 150 |
+
# |init_model| # in these normalized coordinates.
|
| 151 |
+
# +-----+----+
|
| 152 |
+
# |
|
| 153 |
+
# v
|
| 154 |
+
# +------ init_output # Init_output consists of features, base
|
| 155 |
+
# | | # gaussians and a global scale.
|
| 156 |
+
# | +------+-----+
|
| 157 |
+
# | |main network| # Compute delta values to base gaussians.
|
| 158 |
+
# | +------+-----+
|
| 159 |
+
# | |
|
| 160 |
+
# | V
|
| 161 |
+
# | delta_values # The delta values are computed with normalized depth.
|
| 162 |
+
# | |
|
| 163 |
+
# | +-------+---------+
|
| 164 |
+
# +--> |gaussian_composer| # Add delta to base values and unscale gaussians.
|
| 165 |
+
# +-------+---------+
|
| 166 |
+
# |
|
| 167 |
+
# v
|
| 168 |
+
# gaussians # The final Gaussians are metric again.
|
| 169 |
+
#
|
| 170 |
+
|
| 171 |
+
# The logic to decide whether to align monodepth to the ground truth is wrapped
|
| 172 |
+
# in a submodule 'DepthAlignement' to facilitate the symbolic tracing of the
|
| 173 |
+
# predictor. This way, the depth alignment submodule containing the conditional
|
| 174 |
+
# logic can be excluded during the tracing and the graph of the predictors is
|
| 175 |
+
# static.
|
| 176 |
+
monodepth, _ = self.depth_alignment(
|
| 177 |
+
monodepth,
|
| 178 |
+
depth,
|
| 179 |
+
monodepth_output.decoder_features,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
init_output = self.init_model(image, monodepth)
|
| 183 |
+
image_features = self.feature_model(
|
| 184 |
+
init_output.feature_input, encodings=monodepth_output.output_features
|
| 185 |
+
)
|
| 186 |
+
delta_values = self.prediction_head(image_features)
|
| 187 |
+
gaussians = self.gaussian_composer(
|
| 188 |
+
delta=delta_values,
|
| 189 |
+
base_values=init_output.gaussian_base_values,
|
| 190 |
+
global_scale=init_output.global_scale,
|
| 191 |
+
)
|
| 192 |
+
return gaussians
|
| 193 |
+
|
| 194 |
+
def internal_resolution(self) -> int:
|
| 195 |
+
"""Internal resolution."""
|
| 196 |
+
return self.monodepth_model.internal_resolution()
|
| 197 |
+
|
| 198 |
+
@property
|
| 199 |
+
def output_resolution(self) -> int:
|
| 200 |
+
"""Output resolution of Gaussians."""
|
| 201 |
+
return self.internal_resolution() // 2
|
src/sharp/models/presets/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains presets for pretrained neural networks.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .monodepth import (
|
| 8 |
+
MONODEPTH_ENCODER_DIMS_MAP,
|
| 9 |
+
MONODEPTH_HOOK_IDS_MAP,
|
| 10 |
+
)
|
| 11 |
+
from .vit import (
|
| 12 |
+
VIT_CONFIG_DICT,
|
| 13 |
+
ViTConfig,
|
| 14 |
+
ViTPreset,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"ViTConfig",
|
| 19 |
+
"ViTPreset",
|
| 20 |
+
"VIT_CONFIG_DICT",
|
| 21 |
+
"MONODEPTH_ENCODER_DIMS_MAP",
|
| 22 |
+
"MONODEPTH_HOOK_IDS_MAP",
|
| 23 |
+
]
|
src/sharp/models/presets/monodepth.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains preset for monodepth modules.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from .vit import ViTPreset
|
| 10 |
+
|
| 11 |
+
# Map the decoder configuration with the number of output channels
|
| 12 |
+
# for each tensor from the decoder output.
|
| 13 |
+
MONODEPTH_ENCODER_DIMS_MAP: dict[ViTPreset, list[int]] = {
|
| 14 |
+
# For publication
|
| 15 |
+
"dinov2l16_384": [256, 512, 1024, 1024],
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
MONODEPTH_HOOK_IDS_MAP: dict[ViTPreset, list[int]] = {
|
| 19 |
+
# For publication
|
| 20 |
+
"dinov2l16_384": [5, 11, 17, 23],
|
| 21 |
+
}
|
src/sharp/models/presets/vit.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains preset for ViT modules.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import dataclasses
|
| 10 |
+
from typing import Literal
|
| 11 |
+
|
| 12 |
+
ViTPreset = Literal["dinov2l16_384",]
|
| 13 |
+
|
| 14 |
+
MLPMode = Literal["vanilla", "glu"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclasses.dataclass
|
| 18 |
+
class ViTConfig:
|
| 19 |
+
"""Configuration for ViT."""
|
| 20 |
+
|
| 21 |
+
in_chans: int
|
| 22 |
+
embed_dim: int
|
| 23 |
+
depth: int
|
| 24 |
+
num_heads: int
|
| 25 |
+
init_values: float
|
| 26 |
+
|
| 27 |
+
img_size: int = 384
|
| 28 |
+
patch_size: int = 16
|
| 29 |
+
|
| 30 |
+
num_classes: int = 21841
|
| 31 |
+
mlp_ratio: float = 4.0
|
| 32 |
+
drop_rate: float = 0.0
|
| 33 |
+
attn_drop_rate: float = 0.0
|
| 34 |
+
drop_path_rate: float = 0.0
|
| 35 |
+
qkv_bias: bool = True
|
| 36 |
+
global_pool: str = "avg"
|
| 37 |
+
|
| 38 |
+
# Properties for timm_vit.
|
| 39 |
+
mlp_mode: MLPMode = "vanilla"
|
| 40 |
+
|
| 41 |
+
# Properties for SPN.
|
| 42 |
+
intermediate_features_ids: list[int] | None = None
|
| 43 |
+
|
| 44 |
+
def asdict(self):
|
| 45 |
+
"""Convenience method to convert the class to a dict."""
|
| 46 |
+
return dataclasses.asdict(self)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
VIT_CONFIG_DICT: dict[ViTPreset, ViTConfig] = {
|
| 50 |
+
"dinov2l16_384": ViTConfig(
|
| 51 |
+
in_chans=3,
|
| 52 |
+
embed_dim=1024,
|
| 53 |
+
depth=24,
|
| 54 |
+
num_heads=16,
|
| 55 |
+
init_values=1e-5,
|
| 56 |
+
global_pool="",
|
| 57 |
+
),
|
| 58 |
+
}
|
src/sharp/utils/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains utils packages.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
src/sharp/utils/camera.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains utility functionality to render different modalities.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import dataclasses
|
| 10 |
+
from typing import Literal, NamedTuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from .gaussians import Gaussians3D
|
| 16 |
+
from .linalg import eyes
|
| 17 |
+
|
| 18 |
+
TrajetoryType = Literal["swipe", "shake", "rotate", "rotate_forward"]
|
| 19 |
+
LookAtMode = Literal["point", "ahead"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclasses.dataclass
|
| 23 |
+
class CameraInfo:
|
| 24 |
+
"""Camera info for a pinhole camera."""
|
| 25 |
+
|
| 26 |
+
intrinsics: torch.Tensor
|
| 27 |
+
extrinsics: torch.Tensor
|
| 28 |
+
width: int
|
| 29 |
+
height: int
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class FocusRange(NamedTuple):
|
| 33 |
+
"""Parametrizes a range of depth / disparity values."""
|
| 34 |
+
|
| 35 |
+
min: float
|
| 36 |
+
focus: float
|
| 37 |
+
max: float
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclasses.dataclass
|
| 41 |
+
class TrajectoryParams:
|
| 42 |
+
"""Parameters for trajectory."""
|
| 43 |
+
|
| 44 |
+
type: TrajetoryType = "rotate_forward"
|
| 45 |
+
lookat_mode: LookAtMode = "point"
|
| 46 |
+
max_disparity: float = 0.08
|
| 47 |
+
max_zoom: float = 0.15
|
| 48 |
+
distance_m: float = 0.0
|
| 49 |
+
num_steps: int = 60
|
| 50 |
+
num_repeats: int = 1
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def compute_max_offset(
|
| 54 |
+
scene: Gaussians3D,
|
| 55 |
+
params: TrajectoryParams,
|
| 56 |
+
resolution_px: tuple[int, int],
|
| 57 |
+
f_px: float,
|
| 58 |
+
) -> np.ndarray:
|
| 59 |
+
"""Compute the maximum offset for camera along X/Y/Z axis."""
|
| 60 |
+
scene_points = scene.mean_vectors
|
| 61 |
+
extrinsics = torch.eye(4).to(scene_points.device)
|
| 62 |
+
min_depth, _, _ = _compute_depth_quantiles(scene_points, extrinsics)
|
| 63 |
+
|
| 64 |
+
r_px = resolution_px
|
| 65 |
+
diagonal = np.sqrt((r_px[0] / f_px) ** 2 + (r_px[1] / f_px) ** 2)
|
| 66 |
+
max_lateral_offset_m = params.max_disparity * diagonal * min_depth
|
| 67 |
+
|
| 68 |
+
max_medial_offset_m = params.max_zoom * min_depth
|
| 69 |
+
max_offset_xyz_m = np.array([max_lateral_offset_m, max_lateral_offset_m, max_medial_offset_m])
|
| 70 |
+
|
| 71 |
+
return max_offset_xyz_m
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def create_eye_trajectory(
|
| 75 |
+
scene: Gaussians3D,
|
| 76 |
+
params: TrajectoryParams,
|
| 77 |
+
resolution_px: tuple[int, int],
|
| 78 |
+
f_px: float,
|
| 79 |
+
) -> list[torch.Tensor]:
|
| 80 |
+
"""Create eye trajectory for trajectory type."""
|
| 81 |
+
max_offset_xyz_m = compute_max_offset(
|
| 82 |
+
scene,
|
| 83 |
+
params,
|
| 84 |
+
resolution_px,
|
| 85 |
+
f_px,
|
| 86 |
+
)
|
| 87 |
+
# We place the eye trajectory at z=distance plane (default=0),
|
| 88 |
+
# assuming portal plane is placed at z=natural_distance.
|
| 89 |
+
if params.type == "swipe":
|
| 90 |
+
return create_eye_trajectory_swipe(
|
| 91 |
+
max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
|
| 92 |
+
)
|
| 93 |
+
elif params.type == "shake":
|
| 94 |
+
return create_eye_trajectory_shake(
|
| 95 |
+
max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
|
| 96 |
+
)
|
| 97 |
+
elif params.type == "rotate":
|
| 98 |
+
return create_eye_trajectory_rotate(
|
| 99 |
+
max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
|
| 100 |
+
)
|
| 101 |
+
elif params.type == "rotate_forward":
|
| 102 |
+
return create_eye_trajectory_rotate_forward(
|
| 103 |
+
max_offset_xyz_m, params.distance_m, params.num_steps, params.num_repeats
|
| 104 |
+
)
|
| 105 |
+
else:
|
| 106 |
+
raise ValueError(f"Invalid trajectory type {params.type}.")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def create_eye_trajectory_swipe(
|
| 110 |
+
offset_xyz_m: np.ndarray,
|
| 111 |
+
distance_m: float,
|
| 112 |
+
num_steps: int,
|
| 113 |
+
num_repeats: int,
|
| 114 |
+
) -> list[torch.Tensor]:
|
| 115 |
+
"""Create a left to right swipe trajectory."""
|
| 116 |
+
offset_x_m, _, _ = offset_xyz_m
|
| 117 |
+
eye_positions = [
|
| 118 |
+
torch.tensor([x, 0, distance_m], dtype=torch.float32)
|
| 119 |
+
for x in np.linspace(-offset_x_m, offset_x_m, num_steps)
|
| 120 |
+
]
|
| 121 |
+
return eye_positions * num_repeats
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def create_eye_trajectory_shake(
|
| 125 |
+
offset_xyz_m: np.ndarray,
|
| 126 |
+
distance_m: float,
|
| 127 |
+
num_steps: int,
|
| 128 |
+
num_repeats: int,
|
| 129 |
+
) -> list[torch.Tensor]:
|
| 130 |
+
"""Create a left right shake followed by an up down shake trajectory."""
|
| 131 |
+
num_steps_total = num_steps * num_repeats
|
| 132 |
+
num_steps_horizontal = num_steps_total // 2
|
| 133 |
+
num_steps_vertical = num_steps_total - num_steps_horizontal
|
| 134 |
+
|
| 135 |
+
offset_x_m, offset_y_m, _ = offset_xyz_m
|
| 136 |
+
eye_positions: list[torch.Tensor] = []
|
| 137 |
+
eye_positions.extend(
|
| 138 |
+
torch.tensor(
|
| 139 |
+
[offset_x_m * np.sin(2 * np.pi * t), 0.0, distance_m],
|
| 140 |
+
dtype=torch.float32,
|
| 141 |
+
)
|
| 142 |
+
for t in np.linspace(0, num_repeats, num_steps_horizontal)
|
| 143 |
+
)
|
| 144 |
+
eye_positions.extend(
|
| 145 |
+
torch.tensor(
|
| 146 |
+
[0.0, offset_y_m * np.sin(2 * np.pi * t), distance_m],
|
| 147 |
+
dtype=torch.float32,
|
| 148 |
+
)
|
| 149 |
+
for t in np.linspace(0, num_repeats, num_steps_vertical)
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
return eye_positions
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def create_eye_trajectory_rotate(
|
| 156 |
+
offset_xyz_m: np.ndarray,
|
| 157 |
+
distance_m: float,
|
| 158 |
+
num_steps: int,
|
| 159 |
+
num_repeats: int,
|
| 160 |
+
) -> list[torch.Tensor]:
|
| 161 |
+
"""Create a rotating trajectory."""
|
| 162 |
+
num_steps_total = num_steps * num_repeats
|
| 163 |
+
offset_x_m, offset_y_m, _ = offset_xyz_m
|
| 164 |
+
eye_positions = [
|
| 165 |
+
torch.tensor(
|
| 166 |
+
[
|
| 167 |
+
offset_x_m * np.sin(2 * np.pi * t),
|
| 168 |
+
offset_y_m * np.cos(2 * np.pi * t),
|
| 169 |
+
distance_m,
|
| 170 |
+
],
|
| 171 |
+
dtype=torch.float32,
|
| 172 |
+
)
|
| 173 |
+
for t in np.linspace(0, num_repeats, num_steps_total)
|
| 174 |
+
]
|
| 175 |
+
|
| 176 |
+
return eye_positions
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def create_eye_trajectory_rotate_forward(
|
| 180 |
+
offset_xyz_m: np.ndarray,
|
| 181 |
+
distance_m: float,
|
| 182 |
+
num_steps: int,
|
| 183 |
+
num_repeats: int,
|
| 184 |
+
) -> list[torch.Tensor]:
|
| 185 |
+
"""Create a rotating trajectory."""
|
| 186 |
+
num_steps_total = num_steps * num_repeats
|
| 187 |
+
offset_x_m, _, offset_z_m = offset_xyz_m
|
| 188 |
+
eye_positions = [
|
| 189 |
+
torch.tensor(
|
| 190 |
+
[
|
| 191 |
+
offset_x_m * np.sin(2 * np.pi * t),
|
| 192 |
+
0.0,
|
| 193 |
+
distance_m + offset_z_m * (1.0 - np.cos(2 * np.pi * t)) / 2,
|
| 194 |
+
],
|
| 195 |
+
dtype=torch.float32,
|
| 196 |
+
)
|
| 197 |
+
for t in np.linspace(0, num_repeats, num_steps_total)
|
| 198 |
+
]
|
| 199 |
+
|
| 200 |
+
return eye_positions
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def create_camera_model(
|
| 204 |
+
scene: Gaussians3D,
|
| 205 |
+
intrinsics: torch.Tensor,
|
| 206 |
+
resolution_px: tuple[int, int],
|
| 207 |
+
lookat_mode: LookAtMode = "point",
|
| 208 |
+
) -> PinholeCameraModel:
|
| 209 |
+
"""Create camera model to simulate general pinhole camera."""
|
| 210 |
+
screen_extrinsics = torch.eye(4)
|
| 211 |
+
screen_intrinsics = intrinsics.clone()
|
| 212 |
+
|
| 213 |
+
image_width, image_height = resolution_px
|
| 214 |
+
screen_resolution_px = get_screen_resolution_px_from_input(
|
| 215 |
+
width=image_width, height=image_height
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
screen_intrinsics[0] *= screen_resolution_px[0] / image_width
|
| 219 |
+
screen_intrinsics[1] *= screen_resolution_px[1] / image_height
|
| 220 |
+
|
| 221 |
+
camera_model = PinholeCameraModel(
|
| 222 |
+
scene,
|
| 223 |
+
screen_extrinsics=screen_extrinsics,
|
| 224 |
+
screen_intrinsics=screen_intrinsics,
|
| 225 |
+
screen_resolution_px=screen_resolution_px,
|
| 226 |
+
focus_depth_quantile=0.1,
|
| 227 |
+
min_depth_focus=2.0,
|
| 228 |
+
lookat_mode=lookat_mode,
|
| 229 |
+
)
|
| 230 |
+
return camera_model
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def create_camera_matrix(
|
| 234 |
+
position: torch.Tensor,
|
| 235 |
+
look_at_position: torch.Tensor | None = None,
|
| 236 |
+
world_up: torch.Tensor | None = None,
|
| 237 |
+
inverse: bool = False,
|
| 238 |
+
) -> torch.Tensor:
|
| 239 |
+
"""Create camera matrix from vectors."""
|
| 240 |
+
device = position.device
|
| 241 |
+
|
| 242 |
+
if look_at_position is None:
|
| 243 |
+
look_at_position = torch.zeros(3, device=device)
|
| 244 |
+
if world_up is None:
|
| 245 |
+
world_up = torch.tensor([0.0, 0.0, 1.0], device=device)
|
| 246 |
+
|
| 247 |
+
position, look_at_position, world_up = torch.broadcast_tensors(
|
| 248 |
+
position, look_at_position, world_up
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
camera_front = look_at_position - position
|
| 252 |
+
camera_front = camera_front / camera_front.norm(dim=-1, keepdim=True)
|
| 253 |
+
|
| 254 |
+
camera_right = torch.cross(camera_front, world_up, dim=-1)
|
| 255 |
+
camera_right = camera_right / camera_right.norm(dim=-1, keepdim=True)
|
| 256 |
+
|
| 257 |
+
camera_down = torch.cross(camera_front, camera_right, dim=-1)
|
| 258 |
+
rotation_matrix = torch.stack([camera_right, camera_down, camera_front], dim=-1)
|
| 259 |
+
|
| 260 |
+
matrix = eyes(dim=4, shape=position.shape[:-1], device=device)
|
| 261 |
+
if inverse:
|
| 262 |
+
matrix[..., :3, :3] = rotation_matrix.transpose(-1, -2)
|
| 263 |
+
matrix[..., :3, 3:4] = -rotation_matrix.transpose(-1, -2) @ position[..., None]
|
| 264 |
+
else:
|
| 265 |
+
matrix[..., :3, :3] = rotation_matrix
|
| 266 |
+
matrix[..., :3, 3] = position
|
| 267 |
+
|
| 268 |
+
return matrix
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class PinholeCameraModel:
|
| 272 |
+
"""Camera model that focuses on point."""
|
| 273 |
+
|
| 274 |
+
def __init__(
|
| 275 |
+
self,
|
| 276 |
+
scene: Gaussians3D,
|
| 277 |
+
screen_extrinsics: torch.Tensor,
|
| 278 |
+
screen_intrinsics: torch.Tensor,
|
| 279 |
+
screen_resolution_px: tuple[int, int],
|
| 280 |
+
focus_depth_quantile: float = 0.1,
|
| 281 |
+
min_depth_focus: float = 2.0,
|
| 282 |
+
lookat_point: tuple[float, float, float] | None = None,
|
| 283 |
+
lookat_mode: LookAtMode = "point",
|
| 284 |
+
) -> None:
|
| 285 |
+
"""Initialize GeneralPinholeCameraModel.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
scene: The scene to display.
|
| 289 |
+
screen_extrinsics: Extrinsics of the default position.
|
| 290 |
+
screen_intrinsics: Intrinsics to use for rendering.
|
| 291 |
+
screen_resolution_px: Width and height to render.
|
| 292 |
+
focus_depth_quantile: Where inside the depth range to focus on.
|
| 293 |
+
min_depth_focus: Depth to focus at.
|
| 294 |
+
lookat_point: a point that the camera's Z axis directs towards.
|
| 295 |
+
lookat_mode: "point" to look at a fixed point,
|
| 296 |
+
"ahead" to look straight ahead.
|
| 297 |
+
"""
|
| 298 |
+
self.scene = scene
|
| 299 |
+
self.screen_extrinsics = screen_extrinsics
|
| 300 |
+
self.screen_intrinsics = screen_intrinsics
|
| 301 |
+
self.screen_resolution_px = screen_resolution_px
|
| 302 |
+
|
| 303 |
+
self.focus_depth_quantile = focus_depth_quantile
|
| 304 |
+
self.min_depth_focus = min_depth_focus
|
| 305 |
+
self.lookat_point = lookat_point
|
| 306 |
+
self.lookat_mode = lookat_mode
|
| 307 |
+
|
| 308 |
+
scene_points = scene.mean_vectors
|
| 309 |
+
if scene_points.ndim == 3:
|
| 310 |
+
scene_points = scene_points[0]
|
| 311 |
+
elif scene_points.ndim != 2:
|
| 312 |
+
raise ValueError("Unsupported dimensionality of scene points.")
|
| 313 |
+
self._scene_points = scene_points.cpu()
|
| 314 |
+
|
| 315 |
+
self.depth_quantiles = _compute_depth_quantiles(
|
| 316 |
+
self._scene_points,
|
| 317 |
+
self.screen_extrinsics,
|
| 318 |
+
q_focus=self.focus_depth_quantile,
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
def compute(self, eye_pos: torch.Tensor) -> CameraInfo:
|
| 322 |
+
"""Compute camera for eye position."""
|
| 323 |
+
extrinsics = self.screen_extrinsics.clone()
|
| 324 |
+
|
| 325 |
+
origin = eye_pos if self.lookat_mode == "ahead" else torch.zeros(3)
|
| 326 |
+
|
| 327 |
+
if self.lookat_point is None:
|
| 328 |
+
depth_focus = max(self.min_depth_focus, self.depth_quantiles.focus)
|
| 329 |
+
look_at_position = origin + torch.tensor([0.0, 0.0, depth_focus])
|
| 330 |
+
else:
|
| 331 |
+
look_at_position = origin + torch.tensor([*self.lookat_point])
|
| 332 |
+
|
| 333 |
+
world_up = torch.tensor([0.0, -1.0, 0.0])
|
| 334 |
+
extrinsics_modifier = create_camera_matrix(
|
| 335 |
+
eye_pos, look_at_position, world_up, inverse=True
|
| 336 |
+
)
|
| 337 |
+
extrinsics = extrinsics_modifier @ self.screen_extrinsics
|
| 338 |
+
|
| 339 |
+
camera_info = CameraInfo(
|
| 340 |
+
intrinsics=self.screen_intrinsics,
|
| 341 |
+
extrinsics=extrinsics,
|
| 342 |
+
width=self.screen_resolution_px[0],
|
| 343 |
+
height=self.screen_resolution_px[1],
|
| 344 |
+
)
|
| 345 |
+
return camera_info
|
| 346 |
+
|
| 347 |
+
def set_screen_extrinsics(self, new_value: torch.Tensor) -> None:
|
| 348 |
+
"""Modify the default extrinsics."""
|
| 349 |
+
self.screen_extrinsics = new_value
|
| 350 |
+
self.depth_quantiles = _compute_depth_quantiles(self._scene_points, self.screen_extrinsics)
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def get_screen_resolution_px_from_input(width: int, height: int) -> tuple[int, int]:
|
| 354 |
+
"""Get resolution for metadata dictionary."""
|
| 355 |
+
resolution_px = (width, height)
|
| 356 |
+
# halve the dimensions for super large image
|
| 357 |
+
if resolution_px[1] > 3000:
|
| 358 |
+
resolution_px = (resolution_px[0] // 2, resolution_px[1] // 2)
|
| 359 |
+
# for mp4 compatibility, enforce dimensions to even number,
|
| 360 |
+
# otherwise could not be played in browser
|
| 361 |
+
if resolution_px[0] % 2 != 0:
|
| 362 |
+
resolution_px = (resolution_px[0] + 1, resolution_px[1])
|
| 363 |
+
if resolution_px[1] % 2 != 0:
|
| 364 |
+
resolution_px = (resolution_px[0], resolution_px[1] + 1)
|
| 365 |
+
return resolution_px
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def _compute_depth_quantiles(
|
| 369 |
+
points: torch.Tensor,
|
| 370 |
+
extrinsics: torch.Tensor,
|
| 371 |
+
q_near: float = 0.001,
|
| 372 |
+
q_focus: float = 0.1,
|
| 373 |
+
q_far: float = 0.999,
|
| 374 |
+
) -> FocusRange:
|
| 375 |
+
"""Compute disparity quantiles for scene and extrinsics id."""
|
| 376 |
+
points_local = points @ extrinsics[:3, :3].T + extrinsics[:3, 3]
|
| 377 |
+
depth_values = points_local[..., 2].flatten()
|
| 378 |
+
depth_values = depth_values[depth_values > 0]
|
| 379 |
+
q_values = torch.tensor([q_near, q_focus, q_far])
|
| 380 |
+
depth_quantiles_pt = torch.quantile(depth_values.cpu(), q_values)
|
| 381 |
+
depth_quantiles = FocusRange(
|
| 382 |
+
min=float(depth_quantiles_pt[0]),
|
| 383 |
+
focus=float(depth_quantiles_pt[1]),
|
| 384 |
+
max=float(depth_quantiles_pt[2]),
|
| 385 |
+
)
|
| 386 |
+
return depth_quantiles
|
src/sharp/utils/color_space.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains color space utility functions.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Literal
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from sharp.utils.robust import robust_where
|
| 15 |
+
|
| 16 |
+
LOGGER = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
ColorSpace = Literal["sRGB", "linearRGB"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def encode_color_space(color_space: ColorSpace) -> int:
|
| 22 |
+
"""Encode color space to integer."""
|
| 23 |
+
return 0 if color_space == "sRGB" else 1
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def decode_color_space(color_space_index: int) -> ColorSpace:
|
| 27 |
+
"""Decode color space index to color space."""
|
| 28 |
+
return "sRGB" if color_space_index == 0 else "linearRGB"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def sRGB2linearRGB(sRGB: torch.Tensor) -> torch.Tensor:
|
| 32 |
+
"""SRGB to linearRGB conversion function.
|
| 33 |
+
|
| 34 |
+
Reference:
|
| 35 |
+
https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
| 36 |
+
Section 7.7.7
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
sRGB: Input image tensor in sRGB space.
|
| 40 |
+
"""
|
| 41 |
+
# We need to use robust_where to clamp the second branch.
|
| 42 |
+
# Otherwise, torch.where will lead to NaN in the backward pass, see
|
| 43 |
+
# https://github.com/pytorch/pytorch/issues/68425
|
| 44 |
+
THRESHOLD = 0.04045
|
| 45 |
+
|
| 46 |
+
def branch_true_func(x):
|
| 47 |
+
return x / 12.92
|
| 48 |
+
|
| 49 |
+
def branch_false_func(x):
|
| 50 |
+
return ((x + 0.055) / 1.055) ** 2.4
|
| 51 |
+
|
| 52 |
+
return robust_where(
|
| 53 |
+
sRGB <= THRESHOLD,
|
| 54 |
+
sRGB,
|
| 55 |
+
branch_true_func,
|
| 56 |
+
branch_false_func,
|
| 57 |
+
branch_false_safe_value=THRESHOLD,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def linearRGB2sRGB(linearRGB: torch.Tensor) -> torch.Tensor:
|
| 62 |
+
"""LinearRGB to sRGB conversion function.
|
| 63 |
+
|
| 64 |
+
Reference:
|
| 65 |
+
https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
| 66 |
+
Section 7.7.7
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
linearRGB: Input image tensor in linearRGB space.
|
| 70 |
+
"""
|
| 71 |
+
# We need to use robust_where to clamp the second branch.
|
| 72 |
+
# Otherwise, torch.where will lead to NaN in the backward pass, see
|
| 73 |
+
# https://github.com/pytorch/pytorch/issues/68425
|
| 74 |
+
THRESHOLD = 0.0031308
|
| 75 |
+
|
| 76 |
+
def branch_true_func(x):
|
| 77 |
+
return x * 12.92
|
| 78 |
+
|
| 79 |
+
def branch_false_func(x):
|
| 80 |
+
return 1.055 * (x ** (1 / 2.4)) - 0.055
|
| 81 |
+
|
| 82 |
+
return robust_where(
|
| 83 |
+
linearRGB <= THRESHOLD,
|
| 84 |
+
linearRGB,
|
| 85 |
+
branch_true_func,
|
| 86 |
+
branch_false_func,
|
| 87 |
+
branch_false_safe_value=THRESHOLD,
|
| 88 |
+
)
|
src/sharp/utils/gaussians.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains basic data structures and functionality for 3D Gaussians.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Literal, NamedTuple
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
from plyfile import PlyData, PlyElement
|
| 16 |
+
|
| 17 |
+
from sharp.utils import color_space as cs_utils
|
| 18 |
+
from sharp.utils import linalg
|
| 19 |
+
|
| 20 |
+
LOGGER = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
BackgroundColor = Literal["black", "white", "random_color", "random_pixel"]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Gaussians3D(NamedTuple):
|
| 27 |
+
"""Represents a collection of 3D Gaussians."""
|
| 28 |
+
|
| 29 |
+
mean_vectors: torch.Tensor
|
| 30 |
+
singular_values: torch.Tensor
|
| 31 |
+
quaternions: torch.Tensor
|
| 32 |
+
colors: torch.Tensor
|
| 33 |
+
opacities: torch.Tensor
|
| 34 |
+
|
| 35 |
+
def to(self, device: torch.device) -> Gaussians3D:
|
| 36 |
+
"""Move Gaussians to device."""
|
| 37 |
+
return Gaussians3D(
|
| 38 |
+
mean_vectors=self.mean_vectors.to(device),
|
| 39 |
+
singular_values=self.singular_values.to(device),
|
| 40 |
+
quaternions=self.quaternions.to(device),
|
| 41 |
+
colors=self.colors.to(device),
|
| 42 |
+
opacities=self.opacities.to(device),
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SceneMetaData(NamedTuple):
|
| 47 |
+
"""Meta data about Gaussian scene."""
|
| 48 |
+
|
| 49 |
+
focal_length_px: float
|
| 50 |
+
resolution_px: tuple[int, int]
|
| 51 |
+
color_space: cs_utils.ColorSpace
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def get_unprojection_matrix(
|
| 55 |
+
extrinsics: torch.Tensor,
|
| 56 |
+
intrinsics: torch.Tensor,
|
| 57 |
+
image_shape: tuple[int, int],
|
| 58 |
+
) -> torch.Tensor:
|
| 59 |
+
"""Compute unprojection matrix to transform Gaussians to Euclidean space.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
extrinsics: The 4x4 extrinsics matrix of the camera view.
|
| 63 |
+
intrinsics: The 4x4 intrinsics matrix of the camera view.
|
| 64 |
+
image_shape: The (width, height) of the input image.
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
A 4x4 matrix to transform Gaussians from NDC space to Euclidean space.
|
| 68 |
+
"""
|
| 69 |
+
device = intrinsics.device
|
| 70 |
+
image_width, image_height = image_shape
|
| 71 |
+
# This matrix converts OpenCV pixel coordinates to NDC coordinates where
|
| 72 |
+
# (-1, 1) denotes the top left and (1, 1) the bottom right of the image.
|
| 73 |
+
#
|
| 74 |
+
# Note that premultiplying the intrinsics with ndc_matrix typically yields a matrix
|
| 75 |
+
# that simply scales the x-axis by 2 * focal_length / image_width and the y-axis by
|
| 76 |
+
# 2 * focal_length / image_height.
|
| 77 |
+
ndc_matrix = torch.tensor(
|
| 78 |
+
[
|
| 79 |
+
[2.0 / image_width, 0.0, -1.0, 0.0],
|
| 80 |
+
[0.0, 2.0 / image_height, -1.0, 0.0],
|
| 81 |
+
[0.0, 0.0, 1.0, 0.0],
|
| 82 |
+
[0.0, 0.0, 0.0, 1.0],
|
| 83 |
+
],
|
| 84 |
+
device=device,
|
| 85 |
+
)
|
| 86 |
+
return torch.linalg.inv(ndc_matrix @ intrinsics @ extrinsics)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def unproject_gaussians(
|
| 90 |
+
gaussians_ndc: Gaussians3D,
|
| 91 |
+
extrinsics: torch.Tensor,
|
| 92 |
+
intrinsics: torch.Tensor,
|
| 93 |
+
image_shape: tuple[int, int],
|
| 94 |
+
) -> Gaussians3D:
|
| 95 |
+
"""Unproject Gaussians from NDC space to world coordinates."""
|
| 96 |
+
unprojection_matrix = get_unprojection_matrix(extrinsics, intrinsics, image_shape)
|
| 97 |
+
gaussians = apply_transform(gaussians_ndc, unprojection_matrix[:3])
|
| 98 |
+
return gaussians
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def apply_transform(gaussians: Gaussians3D, transform: torch.Tensor) -> Gaussians3D:
|
| 102 |
+
"""Apply an affine transformation to 3D Gaussians.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
gaussians: The Gaussians to transform.
|
| 106 |
+
transform: An affine transform with shape 3x4.
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
The transformed Gaussians.
|
| 110 |
+
|
| 111 |
+
Note: This operation is not differentiable.
|
| 112 |
+
"""
|
| 113 |
+
transform_linear = transform[..., :3, :3]
|
| 114 |
+
transform_offset = transform[..., :3, 3]
|
| 115 |
+
|
| 116 |
+
mean_vectors = gaussians.mean_vectors @ transform_linear.T + transform_offset
|
| 117 |
+
covariance_matrices = compose_covariance_matrices(
|
| 118 |
+
gaussians.quaternions, gaussians.singular_values
|
| 119 |
+
)
|
| 120 |
+
covariance_matrices = (
|
| 121 |
+
transform_linear @ covariance_matrices @ transform_linear.transpose(-1, -2)
|
| 122 |
+
)
|
| 123 |
+
quaternions, singular_values = decompose_covariance_matrices(covariance_matrices)
|
| 124 |
+
|
| 125 |
+
return Gaussians3D(
|
| 126 |
+
mean_vectors=mean_vectors,
|
| 127 |
+
singular_values=singular_values,
|
| 128 |
+
quaternions=quaternions,
|
| 129 |
+
colors=gaussians.colors,
|
| 130 |
+
opacities=gaussians.opacities,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def decompose_covariance_matrices(
|
| 135 |
+
covariance_matrices: torch.Tensor,
|
| 136 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 137 |
+
"""Decompose 3D covariance matrices into quaternions and singular values.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
covariance_matrices: The covariance matrices to decompose.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
Quaternion and singular values corresponding to the orientation and scales of
|
| 144 |
+
the diagonalized matrix.
|
| 145 |
+
|
| 146 |
+
Note: This operation is not differentiable.
|
| 147 |
+
"""
|
| 148 |
+
device = covariance_matrices.device
|
| 149 |
+
dtype = covariance_matrices.dtype
|
| 150 |
+
|
| 151 |
+
# We convert to fp64 to avoid numerical errors.
|
| 152 |
+
covariance_matrices = covariance_matrices.detach().cpu().to(torch.float64)
|
| 153 |
+
rotations, singular_values_2, _ = torch.linalg.svd(covariance_matrices)
|
| 154 |
+
|
| 155 |
+
# NOTE: in SVD, it is possible that U and VT are both reflections.
|
| 156 |
+
# We need to correct them.
|
| 157 |
+
batch_idx, gaussian_idx = torch.where(torch.linalg.det(rotations) < 0)
|
| 158 |
+
num_reflections = len(gaussian_idx)
|
| 159 |
+
if num_reflections > 0:
|
| 160 |
+
LOGGER.warning(
|
| 161 |
+
"Received %d reflection matrices from SVD. Flipping them to rotations.",
|
| 162 |
+
num_reflections,
|
| 163 |
+
)
|
| 164 |
+
# Flip the last column of reflection and make it a rotation.
|
| 165 |
+
rotations[batch_idx, gaussian_idx, :, -1] *= -1
|
| 166 |
+
quaternions = linalg.quaternions_from_rotation_matrices(rotations)
|
| 167 |
+
quaternions = quaternions.to(dtype=dtype, device=device)
|
| 168 |
+
singular_values = singular_values_2.sqrt().to(dtype=dtype, device=device)
|
| 169 |
+
return quaternions, singular_values
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def compose_covariance_matrices(
|
| 173 |
+
quaternions: torch.Tensor, singular_values: torch.Tensor
|
| 174 |
+
) -> torch.Tensor:
|
| 175 |
+
"""Compose 3D covariance matrices into quaternions and singular values.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
quaternions: The quaternions describing the principal basis.
|
| 179 |
+
singular_values: The scales of the diagonalized matrix.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
The 3x3 covariances matrices.
|
| 183 |
+
"""
|
| 184 |
+
device = quaternions.device
|
| 185 |
+
rotations = linalg.rotation_matrices_from_quaternions(quaternions)
|
| 186 |
+
diagonal_matrix = torch.eye(3, device=device) * singular_values[..., :, None]
|
| 187 |
+
return rotations @ diagonal_matrix.square() @ rotations.transpose(-1, -2)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def convert_spherical_harmonics_to_rgb(sh0: torch.Tensor) -> torch.Tensor:
|
| 191 |
+
"""Convert degree-0 spherical harmonics to RGB.
|
| 192 |
+
|
| 193 |
+
Reference:
|
| 194 |
+
https://en.wikipedia.org/wiki/Table_of_spherical_harmonics
|
| 195 |
+
"""
|
| 196 |
+
coeff_degree0 = np.sqrt(1.0 / (4.0 * np.pi))
|
| 197 |
+
return sh0 * coeff_degree0 + 0.5
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def convert_rgb_to_spherical_harmonics(rgb: torch.Tensor) -> torch.Tensor:
|
| 201 |
+
"""Convert RGB to degree-0 spherical harmonics.
|
| 202 |
+
|
| 203 |
+
Reference:
|
| 204 |
+
https://en.wikipedia.org/wiki/Table_of_spherical_harmonics
|
| 205 |
+
"""
|
| 206 |
+
coeff_degree0 = np.sqrt(1.0 / (4.0 * np.pi))
|
| 207 |
+
return (rgb - 0.5) / coeff_degree0
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def load_ply(path: Path) -> tuple[Gaussians3D, SceneMetaData]:
|
| 211 |
+
"""Loads a ply from a file."""
|
| 212 |
+
plydata = PlyData.read(path)
|
| 213 |
+
|
| 214 |
+
vertices = next(filter(lambda x: x.name == "vertex", plydata.elements))
|
| 215 |
+
|
| 216 |
+
properties = ["x", "y", "z"]
|
| 217 |
+
properties.extend([f"f_dc_{i}" for i in range(3)])
|
| 218 |
+
properties.extend([f"scale_{i}" for i in range(3)])
|
| 219 |
+
properties.extend([f"rot_{i}" for i in range(3)])
|
| 220 |
+
|
| 221 |
+
for prop in properties:
|
| 222 |
+
if prop not in vertices:
|
| 223 |
+
raise KeyError(f"Incompatible ply file: property {prop} not found in ply elements.")
|
| 224 |
+
mean_vectors = np.stack(
|
| 225 |
+
(
|
| 226 |
+
np.asarray(vertices["x"]),
|
| 227 |
+
np.asarray(vertices["y"]),
|
| 228 |
+
np.asarray(vertices["z"]),
|
| 229 |
+
),
|
| 230 |
+
axis=1,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
scale_logits = np.stack(
|
| 234 |
+
(
|
| 235 |
+
np.asarray(vertices["scale_0"]),
|
| 236 |
+
np.asarray(vertices["scale_1"]),
|
| 237 |
+
np.asarray(vertices["scale_2"]),
|
| 238 |
+
),
|
| 239 |
+
axis=1,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
quaternions = np.stack(
|
| 243 |
+
(
|
| 244 |
+
np.asarray(vertices["rot_0"]),
|
| 245 |
+
np.asarray(vertices["rot_1"]),
|
| 246 |
+
np.asarray(vertices["rot_2"]),
|
| 247 |
+
np.asarray(vertices["rot_3"]),
|
| 248 |
+
),
|
| 249 |
+
axis=1,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
spherical_harmonics_deg0 = np.stack(
|
| 253 |
+
(
|
| 254 |
+
np.asarray(vertices["f_dc_0"]),
|
| 255 |
+
np.asarray(vertices["f_dc_1"]),
|
| 256 |
+
np.asarray(vertices["f_dc_2"]),
|
| 257 |
+
),
|
| 258 |
+
axis=1,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
colors = convert_spherical_harmonics_to_rgb(spherical_harmonics_deg0)
|
| 262 |
+
|
| 263 |
+
opacity_logits = np.asarray(vertices["opacity"])[..., None]
|
| 264 |
+
|
| 265 |
+
supplement_elements = [element for element in plydata.elements if element.name != "vertex"]
|
| 266 |
+
supplement_data: dict[str, Any] = {}
|
| 267 |
+
supplement_keys = ["extrinsic", "intrinsic", "color_space", "image_size"]
|
| 268 |
+
|
| 269 |
+
for element in supplement_elements:
|
| 270 |
+
for key in supplement_keys:
|
| 271 |
+
if key not in supplement_data and key in element:
|
| 272 |
+
supplement_data[key] = np.asarray(element[key])
|
| 273 |
+
|
| 274 |
+
# Parse intrinsics and image_size.
|
| 275 |
+
if "intrinsic" in supplement_data:
|
| 276 |
+
intrinsics_data = supplement_data["intrinsic"]
|
| 277 |
+
|
| 278 |
+
# Legacy: image_size is contained in intrinsic element.
|
| 279 |
+
if "image_size" not in supplement_data:
|
| 280 |
+
if len(intrinsics_data) != 4:
|
| 281 |
+
raise ValueError(
|
| 282 |
+
"Expect legacy intrinsics with len=4 containing image size, "
|
| 283 |
+
f"but received len={len(intrinsics_data)}"
|
| 284 |
+
)
|
| 285 |
+
focal_length_px = (intrinsics_data[0], intrinsics_data[1])
|
| 286 |
+
width = int(intrinsics_data[2])
|
| 287 |
+
height = int(intrinsics_data[3])
|
| 288 |
+
|
| 289 |
+
else:
|
| 290 |
+
if len(intrinsics_data) != 9:
|
| 291 |
+
raise ValueError(
|
| 292 |
+
"Expect 9 elements in intrinsics, " f"but received {len(intrinsics_data)}."
|
| 293 |
+
)
|
| 294 |
+
intrinsics_matrix = intrinsics_data.reshape((3, 3))
|
| 295 |
+
focal_length_px = (intrinsics_matrix[0, 0], intrinsics_matrix[1, 1])
|
| 296 |
+
|
| 297 |
+
image_size_data = supplement_data["image_size"]
|
| 298 |
+
width = image_size_data[0]
|
| 299 |
+
height = image_size_data[1]
|
| 300 |
+
|
| 301 |
+
# Default to VGA resolution: focal length = 512, image size = (640, 480).
|
| 302 |
+
else:
|
| 303 |
+
focal_length_px = (512, 512)
|
| 304 |
+
width = 640
|
| 305 |
+
height = 480
|
| 306 |
+
|
| 307 |
+
# Parse extrinsics.
|
| 308 |
+
extrinsics_data = supplement_data.get("extrinsic", np.eye(4).flatten())
|
| 309 |
+
extrinsics_matrix = np.eye(4)
|
| 310 |
+
|
| 311 |
+
# Legacy: extrinsics store 12 elements.
|
| 312 |
+
if len(extrinsics_data) == 12:
|
| 313 |
+
extrinsics_matrix[:3] = extrinsics_data.reshape((3, 4))
|
| 314 |
+
extrinsics_matrix[:3, :3] = extrinsics_matrix[:3, :3].copy().T
|
| 315 |
+
elif len(extrinsics_data) == 16:
|
| 316 |
+
extrinsics_matrix[:] = extrinsics_data.reshape((4, 4))
|
| 317 |
+
else:
|
| 318 |
+
raise ValueError(f"Unrecognized extrinsics matrix shape {len(extrinsics_data)}")
|
| 319 |
+
|
| 320 |
+
# Parse color space.
|
| 321 |
+
color_space_index = supplement_data.get("color_space", 1)
|
| 322 |
+
color_space = cs_utils.decode_color_space(color_space_index)
|
| 323 |
+
if color_space == "sRGB":
|
| 324 |
+
colors = cs_utils.sRGB2linearRGB(colors)
|
| 325 |
+
|
| 326 |
+
mean_vectors = torch.from_numpy(mean_vectors).view(1, -1, 3).float()
|
| 327 |
+
quaternions = torch.from_numpy(quaternions).view(1, -1, 4).float()
|
| 328 |
+
singular_values = torch.exp(torch.from_numpy(scale_logits).view(1, -1, 3)).float()
|
| 329 |
+
opacities = torch.sigmoid(torch.from_numpy(opacity_logits).view(1, -1)).float()
|
| 330 |
+
colors = torch.from_numpy(colors).view(1, -1, 3).float()
|
| 331 |
+
|
| 332 |
+
gaussians = Gaussians3D(
|
| 333 |
+
mean_vectors=mean_vectors,
|
| 334 |
+
quaternions=quaternions,
|
| 335 |
+
singular_values=singular_values,
|
| 336 |
+
opacities=opacities,
|
| 337 |
+
colors=colors,
|
| 338 |
+
)
|
| 339 |
+
metadata = SceneMetaData(focal_length_px[0], (width, height), color_space)
|
| 340 |
+
return gaussians, metadata
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
@torch.no_grad()
|
| 344 |
+
def save_ply(
|
| 345 |
+
gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
|
| 346 |
+
) -> PlyData:
|
| 347 |
+
"""Save a predicted Gaussian3D to a ply file."""
|
| 348 |
+
|
| 349 |
+
def _inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
|
| 350 |
+
return torch.log(tensor / (1.0 - tensor))
|
| 351 |
+
|
| 352 |
+
xyz = gaussians.mean_vectors.flatten(0, 1)
|
| 353 |
+
scale_logits = torch.log(gaussians.singular_values).flatten(0, 1)
|
| 354 |
+
quaternions = gaussians.quaternions.flatten(0, 1)
|
| 355 |
+
|
| 356 |
+
# SHARP takes an image, convert it to sRGB color space as input,
|
| 357 |
+
# and predicts linearRGB Gaussians as output.
|
| 358 |
+
# The SHARP renderer would blend linearRGB Gaussians and convert rendered images and videos
|
| 359 |
+
# back to sRGB for the best display quality.
|
| 360 |
+
#
|
| 361 |
+
# However, public renderers do not have such linear2sRGB conversions after rendering.
|
| 362 |
+
# If they render linearRGB Gaussians as-is, the output would be dark without Gamma correction.
|
| 363 |
+
#
|
| 364 |
+
# To make it compatible to public renderers, we force convert linearRGB to sRGB during export.
|
| 365 |
+
# - The SHARP renderer will still handle conversions properly.
|
| 366 |
+
# - Public renderers will be mostly working fine when regarding sRGB images as linearRGB images,
|
| 367 |
+
# although for the best performance, it is recommended to apply the conversions.
|
| 368 |
+
colors = convert_rgb_to_spherical_harmonics(
|
| 369 |
+
cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1))
|
| 370 |
+
)
|
| 371 |
+
color_space_index = cs_utils.encode_color_space("sRGB")
|
| 372 |
+
|
| 373 |
+
# Store opacity logits.
|
| 374 |
+
opacity_logits = _inverse_sigmoid(gaussians.opacities).flatten(0, 1).unsqueeze(-1)
|
| 375 |
+
|
| 376 |
+
attributes = torch.cat(
|
| 377 |
+
(
|
| 378 |
+
xyz,
|
| 379 |
+
colors,
|
| 380 |
+
opacity_logits,
|
| 381 |
+
scale_logits,
|
| 382 |
+
quaternions,
|
| 383 |
+
),
|
| 384 |
+
dim=1,
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
dtype_full = [
|
| 388 |
+
(attribute, "f4")
|
| 389 |
+
for attribute in ["x", "y", "z"]
|
| 390 |
+
+ [f"f_dc_{i}" for i in range(3)]
|
| 391 |
+
+ ["opacity"]
|
| 392 |
+
+ [f"scale_{i}" for i in range(3)]
|
| 393 |
+
+ [f"rot_{i}" for i in range(4)]
|
| 394 |
+
]
|
| 395 |
+
|
| 396 |
+
num_gaussians = len(xyz)
|
| 397 |
+
elements = np.empty(num_gaussians, dtype=dtype_full)
|
| 398 |
+
elements[:] = list(map(tuple, attributes.detach().cpu().numpy()))
|
| 399 |
+
vertex_elements = PlyElement.describe(elements, "vertex")
|
| 400 |
+
|
| 401 |
+
# Load image-wise metadata.
|
| 402 |
+
image_height, image_width = image_shape
|
| 403 |
+
|
| 404 |
+
# Export image size.
|
| 405 |
+
dtype_image_size = [("image_size", "u4")]
|
| 406 |
+
image_size_array = np.empty(2, dtype=dtype_image_size)
|
| 407 |
+
image_size_array[:] = np.array([image_width, image_height])
|
| 408 |
+
image_size_element = PlyElement.describe(image_size_array, "image_size")
|
| 409 |
+
|
| 410 |
+
# Export intrinsics.
|
| 411 |
+
dtype_intrinsic = [("intrinsic", "f4")]
|
| 412 |
+
intrinsic_array = np.empty(9, dtype=dtype_intrinsic)
|
| 413 |
+
intrinsic = np.array(
|
| 414 |
+
[
|
| 415 |
+
f_px,
|
| 416 |
+
0,
|
| 417 |
+
image_width * 0.5,
|
| 418 |
+
0,
|
| 419 |
+
f_px,
|
| 420 |
+
image_height * 0.5,
|
| 421 |
+
0,
|
| 422 |
+
0,
|
| 423 |
+
1,
|
| 424 |
+
]
|
| 425 |
+
)
|
| 426 |
+
intrinsic_array[:] = intrinsic.flatten()
|
| 427 |
+
intrinsic_element = PlyElement.describe(intrinsic_array, "intrinsic")
|
| 428 |
+
|
| 429 |
+
# Export dummy extrinsics.
|
| 430 |
+
dtype_extrinsic = [("extrinsic", "f4")]
|
| 431 |
+
extrinsic_array = np.empty(16, dtype=dtype_extrinsic)
|
| 432 |
+
extrinsic_array[:] = np.eye(4).flatten()
|
| 433 |
+
extrinsic_element = PlyElement.describe(extrinsic_array, "extrinsic")
|
| 434 |
+
|
| 435 |
+
# Export number of frames and particles per frame.
|
| 436 |
+
dtype_frames = [("frame", "i4")]
|
| 437 |
+
frame_array = np.empty(2, dtype=dtype_frames)
|
| 438 |
+
frame_array[:] = np.array([1, num_gaussians], dtype=np.int32)
|
| 439 |
+
frame_element = PlyElement.describe(frame_array, "frame")
|
| 440 |
+
|
| 441 |
+
# Export disparity ranges for transform.
|
| 442 |
+
dtype_disparity = [("disparity", "f4")]
|
| 443 |
+
disparity_array = np.empty(2, dtype=dtype_disparity)
|
| 444 |
+
|
| 445 |
+
disparity = 1.0 / gaussians.mean_vectors[0, ..., -1]
|
| 446 |
+
quantiles = (
|
| 447 |
+
torch.quantile(disparity, q=torch.tensor([0.1, 0.9], device=disparity.device))
|
| 448 |
+
.float()
|
| 449 |
+
.cpu()
|
| 450 |
+
.numpy()
|
| 451 |
+
)
|
| 452 |
+
disparity_array[:] = quantiles
|
| 453 |
+
disparity_element = PlyElement.describe(disparity_array, "disparity")
|
| 454 |
+
|
| 455 |
+
# Export colorspace.
|
| 456 |
+
dtype_color_space = [("color_space", "u1")]
|
| 457 |
+
color_space_array = np.empty(1, dtype=dtype_color_space)
|
| 458 |
+
color_space_array[:] = np.array([color_space_index]).flatten()
|
| 459 |
+
color_space_element = PlyElement.describe(color_space_array, "color_space")
|
| 460 |
+
|
| 461 |
+
dtype_version = [("version", "u1")]
|
| 462 |
+
version_array = np.empty(3, dtype=dtype_version)
|
| 463 |
+
version_array[:] = np.array([1, 5, 0], dtype=np.uint8).flatten()
|
| 464 |
+
version_element = PlyElement.describe(version_array, "version")
|
| 465 |
+
|
| 466 |
+
plydata = PlyData(
|
| 467 |
+
[
|
| 468 |
+
vertex_elements,
|
| 469 |
+
extrinsic_element,
|
| 470 |
+
intrinsic_element,
|
| 471 |
+
image_size_element,
|
| 472 |
+
frame_element,
|
| 473 |
+
disparity_element,
|
| 474 |
+
color_space_element,
|
| 475 |
+
version_element,
|
| 476 |
+
]
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
plydata.write(path)
|
| 480 |
+
return plydata
|
src/sharp/utils/gsplat.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains utility code for gsplat renderer.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import NamedTuple
|
| 11 |
+
|
| 12 |
+
import gsplat
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
from sharp.utils import color_space as cs_utils
|
| 17 |
+
from sharp.utils import io, vis
|
| 18 |
+
from sharp.utils.gaussians import BackgroundColor, Gaussians3D
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class RenderingOutputs(NamedTuple):
|
| 22 |
+
"""Outputs of 3D Gaussians renderer."""
|
| 23 |
+
|
| 24 |
+
color: torch.Tensor
|
| 25 |
+
depth: torch.Tensor
|
| 26 |
+
alpha: torch.Tensor
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def write_renderings(rendering: RenderingOutputs, output_folder: Path, filename: str):
|
| 30 |
+
"""Write rendered color/depth/alpha to files."""
|
| 31 |
+
batch_size = len(rendering.color)
|
| 32 |
+
if batch_size != 1:
|
| 33 |
+
raise RuntimeError("We only support saving rendering of batch size = 1")
|
| 34 |
+
|
| 35 |
+
def _save_image_tensor(tensor: torch.Tensor, suffix: str):
|
| 36 |
+
np_array = tensor.permute(1, 2, 0).numpy()
|
| 37 |
+
io.save_image(np_array, (output_folder / filename).with_suffix(suffix))
|
| 38 |
+
|
| 39 |
+
color = (rendering.color[0].cpu() * 255.0).to(dtype=torch.uint8)
|
| 40 |
+
colorized_depth = vis.colorize_depth(rendering.depth[0], val_max=100.0)
|
| 41 |
+
colorized_alpha = vis.colorize_alpha(rendering.alpha[0])
|
| 42 |
+
|
| 43 |
+
_save_image_tensor(color, ".color.png")
|
| 44 |
+
_save_image_tensor(colorized_depth, ".depth.png")
|
| 45 |
+
_save_image_tensor(colorized_alpha, ".alpha.png")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class GSplatRenderer(nn.Module):
|
| 49 |
+
"""Module to render 3D Gaussians to images using gsplat."""
|
| 50 |
+
|
| 51 |
+
color_space: cs_utils.ColorSpace
|
| 52 |
+
background_color: BackgroundColor
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
color_space: cs_utils.ColorSpace = "sRGB",
|
| 57 |
+
background_color: BackgroundColor = "black",
|
| 58 |
+
low_pass_filter_eps: float = 0.0,
|
| 59 |
+
) -> None:
|
| 60 |
+
"""Initialize gsplat renderer.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
color_space: The color space to use for rendering.
|
| 64 |
+
background_color: The background color to use for rendering.
|
| 65 |
+
low_pass_filter_eps: The epsilon value for the low pass filter.
|
| 66 |
+
"""
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.color_space = color_space
|
| 69 |
+
self.background_color = background_color
|
| 70 |
+
self.low_pass_filter_eps = low_pass_filter_eps
|
| 71 |
+
|
| 72 |
+
def forward(
|
| 73 |
+
self,
|
| 74 |
+
gaussians: Gaussians3D,
|
| 75 |
+
extrinsics: torch.Tensor,
|
| 76 |
+
intrinsics: torch.Tensor,
|
| 77 |
+
image_width: int,
|
| 78 |
+
image_height: int,
|
| 79 |
+
) -> RenderingOutputs:
|
| 80 |
+
"""Predict images from gaussians.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
gaussians: The Gaussians to render.
|
| 84 |
+
extrinsics: The extrinsics of the camera to render to in OpenCV format.
|
| 85 |
+
intrinsics: The intriniscs of the camera to render to in OpenCV format.
|
| 86 |
+
image_width: The desired output image width.
|
| 87 |
+
image_height: The desired output image height.
|
| 88 |
+
"""
|
| 89 |
+
batch_size = len(gaussians.mean_vectors)
|
| 90 |
+
outputs_list: list[RenderingOutputs] = []
|
| 91 |
+
|
| 92 |
+
for ib in range(batch_size):
|
| 93 |
+
colors, alphas, meta = gsplat.rendering.rasterization(
|
| 94 |
+
means=gaussians.mean_vectors[ib],
|
| 95 |
+
quats=gaussians.quaternions[ib],
|
| 96 |
+
scales=gaussians.singular_values[ib],
|
| 97 |
+
opacities=gaussians.opacities[ib],
|
| 98 |
+
colors=gaussians.colors[ib],
|
| 99 |
+
viewmats=extrinsics[ib : ib + 1],
|
| 100 |
+
Ks=intrinsics[ib : ib + 1, :3, :3],
|
| 101 |
+
width=image_width,
|
| 102 |
+
height=image_height,
|
| 103 |
+
render_mode="RGB+D",
|
| 104 |
+
rasterize_mode="classic",
|
| 105 |
+
absgrad=False,
|
| 106 |
+
packed=False,
|
| 107 |
+
eps2d=self.low_pass_filter_eps,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
rendered_color = colors[..., 0:3].permute([0, 3, 1, 2])
|
| 111 |
+
rendered_depth_unnormalized = colors[..., 3:4].permute([0, 3, 1, 2])
|
| 112 |
+
rendered_alpha = alphas.permute([0, 3, 1, 2])
|
| 113 |
+
|
| 114 |
+
# Compose with background color.
|
| 115 |
+
rendered_color = self.compose_with_background(
|
| 116 |
+
rendered_color, rendered_alpha, self.background_color
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Colorspace conversion.
|
| 120 |
+
if self.color_space == "sRGB":
|
| 121 |
+
pass
|
| 122 |
+
elif self.color_space == "linearRGB":
|
| 123 |
+
rendered_color = cs_utils.linearRGB2sRGB(rendered_color)
|
| 124 |
+
else:
|
| 125 |
+
ValueError("Unsupported ColorSpace type.")
|
| 126 |
+
|
| 127 |
+
# splats: (B, N, 10)
|
| 128 |
+
cov2d = self._conics_to_covars2d(meta["conics"])
|
| 129 |
+
# Set the cov2d of invisible splats to 1 to avoid nan in condition number calculation..
|
| 130 |
+
splats_visible_mask = meta["depths"] > 1e-2
|
| 131 |
+
cov2d[~splats_visible_mask][..., 0, 0] = 1
|
| 132 |
+
cov2d[~splats_visible_mask][..., 1, 1] = 1
|
| 133 |
+
cov2d[~splats_visible_mask][..., 0, 1] = 0
|
| 134 |
+
|
| 135 |
+
# Normalize the depth by alpha.
|
| 136 |
+
rendered_depth = rendered_depth_unnormalized / torch.clip(rendered_alpha, min=1e-8)
|
| 137 |
+
|
| 138 |
+
outputs = RenderingOutputs(
|
| 139 |
+
color=rendered_color,
|
| 140 |
+
depth=rendered_depth,
|
| 141 |
+
alpha=rendered_alpha,
|
| 142 |
+
)
|
| 143 |
+
outputs_list.append(outputs)
|
| 144 |
+
|
| 145 |
+
return RenderingOutputs(
|
| 146 |
+
color=torch.cat([item.color for item in outputs_list], dim=0).contiguous(),
|
| 147 |
+
depth=torch.cat([item.depth for item in outputs_list], dim=0).contiguous(),
|
| 148 |
+
alpha=torch.cat([item.alpha for item in outputs_list], dim=0).contiguous(),
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
@staticmethod
|
| 152 |
+
def compose_with_background(
|
| 153 |
+
rendered_rgb: torch.Tensor,
|
| 154 |
+
rendered_alpha: torch.Tensor,
|
| 155 |
+
background_color: BackgroundColor,
|
| 156 |
+
) -> torch.Tensor:
|
| 157 |
+
"""Compose rendered RGB with background color."""
|
| 158 |
+
if background_color == "black":
|
| 159 |
+
return rendered_rgb
|
| 160 |
+
elif background_color == "white":
|
| 161 |
+
return rendered_rgb + (1.0 - rendered_alpha)
|
| 162 |
+
elif background_color == "random_color":
|
| 163 |
+
return (
|
| 164 |
+
rendered_rgb
|
| 165 |
+
+ (1.0 - rendered_alpha)
|
| 166 |
+
* torch.rand(3, dtype=rendered_rgb.dtype, device=rendered_rgb.device)[
|
| 167 |
+
None, :, None, None
|
| 168 |
+
]
|
| 169 |
+
)
|
| 170 |
+
elif background_color == "random_pixel":
|
| 171 |
+
return rendered_rgb + (1.0 - rendered_alpha) * torch.rand_like(rendered_rgb)
|
| 172 |
+
else:
|
| 173 |
+
raise ValueError("Unsupported BackgroundColor type.")
|
| 174 |
+
|
| 175 |
+
@staticmethod
|
| 176 |
+
def _conics_to_covars2d(conics: torch.Tensor, eps=1e-8) -> torch.Tensor:
|
| 177 |
+
"""Convert conics to covariance matrices."""
|
| 178 |
+
a = conics[..., 0]
|
| 179 |
+
b = conics[..., 1]
|
| 180 |
+
c = conics[..., 2]
|
| 181 |
+
# Reconstruct determinant.
|
| 182 |
+
det = 1 / (a * c - b**2 + eps)
|
| 183 |
+
det = det.clamp(min=eps)
|
| 184 |
+
# Reconstruct covars2d.
|
| 185 |
+
covars2d = torch.zeros(*conics.shape[:-1], 2, 2, device=conics.device)
|
| 186 |
+
covars2d[..., 1, 1] = a * det
|
| 187 |
+
covars2d[..., 0, 0] = c * det
|
| 188 |
+
covars2d[..., 0, 1] = -b * det
|
| 189 |
+
covars2d[..., 1, 0] = -b * det
|
| 190 |
+
covars2d = torch.nan_to_num(covars2d, nan=0.0, posinf=0.0, neginf=0.0)
|
| 191 |
+
return covars2d
|
src/sharp/utils/io.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains image IO.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import io
|
| 10 |
+
import logging
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import IO, Any, Protocol
|
| 13 |
+
|
| 14 |
+
import imageio.v2 as iio
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pillow_heif
|
| 17 |
+
import torch
|
| 18 |
+
from PIL import ExifTags, Image, TiffTags
|
| 19 |
+
|
| 20 |
+
from .vis import METRIC_DEPTH_MAX_CLAMP_METER, colorize_depth
|
| 21 |
+
|
| 22 |
+
LOGGER = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# NOTE: unused, kept for reference.
|
| 26 |
+
Image.MAX_IMAGE_PIXELS = 200000000
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_rgb(
|
| 30 |
+
path: Path, auto_rotate: bool = True, remove_alpha: bool = True
|
| 31 |
+
) -> tuple[np.ndarray, list[bytes] | None, float]:
|
| 32 |
+
"""Load an RGB image."""
|
| 33 |
+
LOGGER.debug(f"Loading image {path} ...")
|
| 34 |
+
|
| 35 |
+
if path.suffix.lower() in [".heic"]:
|
| 36 |
+
heif_file = pillow_heif.open_heif(path, convert_hdr_to_8bit=True)
|
| 37 |
+
img_pil = heif_file.to_pillow()
|
| 38 |
+
else:
|
| 39 |
+
img_pil = Image.open(path)
|
| 40 |
+
|
| 41 |
+
img_exif = extract_exif(img_pil)
|
| 42 |
+
icc_profile = img_pil.info.get("icc_profile", None)
|
| 43 |
+
|
| 44 |
+
# Rotate the image.
|
| 45 |
+
if auto_rotate:
|
| 46 |
+
exif_orientation = img_exif.get("Orientation", 1)
|
| 47 |
+
if exif_orientation == 3:
|
| 48 |
+
img_pil = img_pil.transpose(Image.ROTATE_180)
|
| 49 |
+
elif exif_orientation == 6:
|
| 50 |
+
img_pil = img_pil.transpose(Image.ROTATE_270)
|
| 51 |
+
elif exif_orientation == 8:
|
| 52 |
+
img_pil = img_pil.transpose(Image.ROTATE_90)
|
| 53 |
+
elif exif_orientation != 1:
|
| 54 |
+
LOGGER.warning(f"Ignoring image orientation {exif_orientation}.")
|
| 55 |
+
|
| 56 |
+
# Extract the focal length.
|
| 57 |
+
f_35mm = img_exif.get("FocalLengthIn35mmFilm", img_exif.get("FocalLenIn35mmFilm", None))
|
| 58 |
+
if f_35mm is None or f_35mm < 1:
|
| 59 |
+
f_35mm = img_exif.get("FocalLength", None)
|
| 60 |
+
if f_35mm is None:
|
| 61 |
+
LOGGER.warn(f"Did not find focallength in exif data of {path} - Setting to 30mm.")
|
| 62 |
+
f_35mm = 30.0
|
| 63 |
+
if f_35mm < 10.0:
|
| 64 |
+
LOGGER.info("Found focal length below 10mm, assuming it's not for 35mm.")
|
| 65 |
+
# This is a very crude approximation.
|
| 66 |
+
f_35mm *= 8.4
|
| 67 |
+
|
| 68 |
+
img = np.asarray(img_pil)
|
| 69 |
+
# Convert to RGB if single channel.
|
| 70 |
+
if img.ndim < 3 or img.shape[2] == 1:
|
| 71 |
+
img = np.dstack((img, img, img))
|
| 72 |
+
|
| 73 |
+
if remove_alpha:
|
| 74 |
+
img = img[:, :, :3]
|
| 75 |
+
|
| 76 |
+
LOGGER.debug(f"\tHxW: {img.shape[0]}x{img.shape[1]}")
|
| 77 |
+
LOGGER.debug(f"\tfocal length @ 35mm film: {f_35mm}mm")
|
| 78 |
+
f_px = convert_focallength(img.shape[1], img.shape[0], f_35mm)
|
| 79 |
+
LOGGER.debug(f"\tfocal length: {f_px:.2f}px")
|
| 80 |
+
|
| 81 |
+
return img, icc_profile, f_px
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def extract_exif(img_pil: Image.Image) -> dict[str, Any]:
|
| 85 |
+
"""Return exif information as a dictionary."""
|
| 86 |
+
# Get full exif description from get_ifd(0x8769):
|
| 87 |
+
# cf https://pillow.readthedocs.io/en/stable/releasenotes/8.2.0.html#image-getexif-exif-and-gps-ifd # noqa
|
| 88 |
+
img_exif = img_pil.getexif().get_ifd(0x8769)
|
| 89 |
+
exif_dict = {ExifTags.TAGS[k]: v for k, v in img_exif.items() if k in ExifTags.TAGS}
|
| 90 |
+
|
| 91 |
+
# https://pillow.readthedocs.io/en/stable/_modules/PIL/TiffTags.html# # noqa
|
| 92 |
+
tiff_tags = img_pil.getexif()
|
| 93 |
+
tiff_dict = {TiffTags.TAGS_V2[k].name: v for k, v in tiff_tags.items() if k in TiffTags.TAGS_V2}
|
| 94 |
+
return {**exif_dict, **tiff_dict}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def convert_focallength(width: float, height: float, f_mm: float = 30) -> float:
|
| 98 |
+
"""Converts a focal length given in mm to pixels."""
|
| 99 |
+
return f_mm * np.sqrt(width**2.0 + height**2.0) / np.sqrt(36**2 + 24**2)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def save_image(
|
| 103 |
+
image: np.ndarray,
|
| 104 |
+
output_path: Path,
|
| 105 |
+
icc_profile: list[bytes] | None = None,
|
| 106 |
+
jpeg_quality: int = 92,
|
| 107 |
+
) -> None:
|
| 108 |
+
"""Save image to given path."""
|
| 109 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 110 |
+
|
| 111 |
+
extensions_to_format = Image.registered_extensions()
|
| 112 |
+
try:
|
| 113 |
+
format = extensions_to_format[output_path.suffix.lower()]
|
| 114 |
+
except KeyError:
|
| 115 |
+
raise ValueError(f"Unsupported output format {output_path.suffix}.")
|
| 116 |
+
|
| 117 |
+
with output_path.open("wb") as file_handle:
|
| 118 |
+
write_image(
|
| 119 |
+
image,
|
| 120 |
+
file_handle,
|
| 121 |
+
format,
|
| 122 |
+
icc_profile=icc_profile,
|
| 123 |
+
jpeg_quality=jpeg_quality,
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def write_image(
|
| 128 |
+
image: np.ndarray,
|
| 129 |
+
output_io: IO[bytes],
|
| 130 |
+
format="jpg",
|
| 131 |
+
icc_profile: list[bytes] | None = None,
|
| 132 |
+
jpeg_quality: int = 92,
|
| 133 |
+
):
|
| 134 |
+
"""Write image to binary stream."""
|
| 135 |
+
pil_config = {}
|
| 136 |
+
if format == "JPEG":
|
| 137 |
+
pil_config["quality"] = jpeg_quality
|
| 138 |
+
|
| 139 |
+
image_pil = Image.fromarray(image)
|
| 140 |
+
|
| 141 |
+
# Workaround to error [io.UnsupportedOperation: seek].
|
| 142 |
+
if format == "TIFF":
|
| 143 |
+
bytes_io = io.BytesIO()
|
| 144 |
+
image_pil.save(bytes_io, format="TIFF")
|
| 145 |
+
bytes_io.seek(0)
|
| 146 |
+
output_io.write(bytes_io.read())
|
| 147 |
+
return
|
| 148 |
+
|
| 149 |
+
image_pil.save(output_io, format, icc_profile=icc_profile, **pil_config)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_supported_image_extensions(with_heic: bool = True) -> list[str]:
|
| 153 |
+
"""Return supported image extensions."""
|
| 154 |
+
exts = Image.registered_extensions()
|
| 155 |
+
supported_extensions = {ex for ex, f in exts.items() if f in Image.OPEN}
|
| 156 |
+
if with_heic:
|
| 157 |
+
supported_extensions.add(".heic")
|
| 158 |
+
|
| 159 |
+
supported_extensions_upper = {ex.upper() for ex in supported_extensions}
|
| 160 |
+
return list(supported_extensions | supported_extensions_upper)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def get_supported_video_extensions():
|
| 164 |
+
"""Return supported video extensions."""
|
| 165 |
+
supported_extensions = {".mp4", ".mov"}
|
| 166 |
+
supported_extensions_upper = {ext.upper() for ext in supported_extensions}
|
| 167 |
+
return list(supported_extensions | supported_extensions_upper)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class OutputWriter(Protocol):
|
| 171 |
+
"""Protocol for writing output to disk."""
|
| 172 |
+
|
| 173 |
+
def add_frame(self, image: torch.Tensor, depth: torch.Tensor) -> None:
|
| 174 |
+
"""Add a single frame to output."""
|
| 175 |
+
...
|
| 176 |
+
|
| 177 |
+
def close(self) -> None:
|
| 178 |
+
"""Finish writing."""
|
| 179 |
+
...
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class VideoWriter(OutputWriter):
|
| 183 |
+
"""Output writer for video output."""
|
| 184 |
+
|
| 185 |
+
def __init__(self, output_path: Path, fps: float = 30.0, render_depth: bool = True) -> None:
|
| 186 |
+
"""Initialize VideoWriter."""
|
| 187 |
+
output_path.parent.mkdir(exist_ok=True, parents=True)
|
| 188 |
+
self.output_path = output_path
|
| 189 |
+
self.image_writer = iio.get_writer(output_path, fps=fps)
|
| 190 |
+
|
| 191 |
+
self.max_depth_estimate = None
|
| 192 |
+
if render_depth:
|
| 193 |
+
self.depth_writer = iio.get_writer(output_path.with_suffix(".depth.mp4"), fps=fps)
|
| 194 |
+
|
| 195 |
+
def add_frame(self, image: torch.Tensor, depth: torch.Tensor) -> None:
|
| 196 |
+
"""Add a single frame to output."""
|
| 197 |
+
image_np = image.detach().cpu().numpy()
|
| 198 |
+
self.image_writer.append_data(image_np)
|
| 199 |
+
|
| 200 |
+
if self.depth_writer is not None:
|
| 201 |
+
if self.max_depth_estimate is None:
|
| 202 |
+
self.max_depth_estimate = depth.max().item()
|
| 203 |
+
|
| 204 |
+
colored_depth_pt = colorize_depth(
|
| 205 |
+
depth,
|
| 206 |
+
min(self.max_depth_estimate, METRIC_DEPTH_MAX_CLAMP_METER), # type: ignore[call-overload]
|
| 207 |
+
)
|
| 208 |
+
colored_depth_np = colored_depth_pt.squeeze(0).permute(1, 2, 0).cpu().numpy()
|
| 209 |
+
self.depth_writer.append_data(colored_depth_np)
|
| 210 |
+
|
| 211 |
+
def close(self):
|
| 212 |
+
"""Finish writing."""
|
| 213 |
+
self.image_writer.close()
|
src/sharp/utils/linalg.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains linear algebra related utility functions.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from scipy.spatial.transform import Rotation
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def rotation_matrices_from_quaternions(quaternions: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
"""Convert batch of quaternions into rotations matrices.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
quaternions: The quaternions convert to matrices.
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
The rotations matrices corresponding to the (normalized) quaternions.
|
| 22 |
+
"""
|
| 23 |
+
device = quaternions.device
|
| 24 |
+
shape = quaternions.shape[:-1]
|
| 25 |
+
|
| 26 |
+
quaternions = quaternions / torch.linalg.norm(quaternions, dim=-1, keepdim=True)
|
| 27 |
+
real_part = quaternions[..., 0]
|
| 28 |
+
vector_part = quaternions[..., 1:]
|
| 29 |
+
|
| 30 |
+
vector_cross = get_cross_product_matrix(vector_part)
|
| 31 |
+
real_part = real_part[..., None, None]
|
| 32 |
+
|
| 33 |
+
matrix_outer = vector_part[..., :, None] * vector_part[..., None, :]
|
| 34 |
+
matrix_diag = real_part.square() * eyes(3, shape=shape, device=device)
|
| 35 |
+
matrix_cross_1 = 2 * real_part * vector_cross
|
| 36 |
+
matrix_cross_2 = vector_cross @ vector_cross
|
| 37 |
+
|
| 38 |
+
return matrix_outer + matrix_diag + matrix_cross_1 + matrix_cross_2
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def quaternions_from_rotation_matrices(matrices: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
"""Convert batch of rotation matrices to quaternions.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
matrices: The matrices to convert to quaternions.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
The quaternions corresponding to the rotation matrices.
|
| 49 |
+
|
| 50 |
+
Note: this operation is not differentiable and will be performed on the CPU.
|
| 51 |
+
"""
|
| 52 |
+
if not matrices.shape[-2:] == (3, 3):
|
| 53 |
+
raise ValueError(f"matrices have invalid shape {matrices.shape}")
|
| 54 |
+
matrices_np = matrices.detach().cpu().numpy()
|
| 55 |
+
quaternions_np = Rotation.from_matrix(matrices_np.reshape(-1, 3, 3)).as_quat()
|
| 56 |
+
# We use a convention where the w component is at the start of the quaternion.
|
| 57 |
+
quaternions_np = quaternions_np[:, [3, 0, 1, 2]]
|
| 58 |
+
quaternions_np = quaternions_np.reshape(matrices_np.shape[:-2] + (4,))
|
| 59 |
+
return torch.as_tensor(quaternions_np, device=matrices.device, dtype=matrices.dtype)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_cross_product_matrix(vectors: torch.Tensor) -> torch.Tensor:
|
| 63 |
+
"""Generate cross product matrix for vector exterior product."""
|
| 64 |
+
if not vectors.shape[-1] == 3:
|
| 65 |
+
raise ValueError("Only 3-dimensional vectors are supported")
|
| 66 |
+
device = vectors.device
|
| 67 |
+
shape = vectors.shape[:-1]
|
| 68 |
+
unit_basis = eyes(3, shape=shape, device=device)
|
| 69 |
+
# We compute the matrix by multiplying each column of unit_basis with the
|
| 70 |
+
# corresponding vector.
|
| 71 |
+
return torch.cross(vectors[..., :, None], unit_basis, dim=-2)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def eyes(
|
| 75 |
+
dim: int, shape: tuple[int, ...], device: torch.device | str | None = None
|
| 76 |
+
) -> torch.Tensor:
|
| 77 |
+
"""Create batch of identity matrices."""
|
| 78 |
+
return torch.eye(dim, device=device).broadcast_to(shape + (dim, dim)).clone()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def quaternion_product(q1, q2):
|
| 82 |
+
"""Compute dot product between two quaternions."""
|
| 83 |
+
real_1 = q1[..., :1]
|
| 84 |
+
real_2 = q2[..., :1]
|
| 85 |
+
vector_1 = q1[..., 1:]
|
| 86 |
+
vector_2 = q2[..., 1:]
|
| 87 |
+
|
| 88 |
+
real_out = real_1 * real_2 - (vector_1 * vector_2).sum(dim=-1, keepdim=True)
|
| 89 |
+
vector_out = real_1 * vector_2 + real_2 * vector_1 + torch.cross(vector_1, vector_2)
|
| 90 |
+
return torch.concatenate([real_out, vector_out], dim=-1)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def quaternion_conj(q):
|
| 94 |
+
"""Get conjugate of a quaternion."""
|
| 95 |
+
real = q[..., :1]
|
| 96 |
+
vector = q[..., 1:]
|
| 97 |
+
return torch.concatenate([real, -vector], dim=-1)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def project(u: torch.Tensor, basis: torch.Tensor) -> torch.Tensor:
|
| 101 |
+
"""Project tensor u to unit basis a."""
|
| 102 |
+
unit_u = F.normalize(u, dim=-1)
|
| 103 |
+
inner_prod = (unit_u * basis).sum(dim=-1, keepdim=True)
|
| 104 |
+
return inner_prod * u
|
src/sharp/utils/logging.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains logging related utility functions.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def configure(log_level: int, log_path: Path | None = None, prefix: str | None = None) -> None:
|
| 15 |
+
"""Configure logger globally.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
log_level: The desired verbosity level.
|
| 19 |
+
log_path: The path to write logs to.
|
| 20 |
+
prefix: The prefix of the logger.
|
| 21 |
+
"""
|
| 22 |
+
logger = logging.getLogger(prefix)
|
| 23 |
+
|
| 24 |
+
# Reset logger to initial state (e.g. to avoid side effects from imports).
|
| 25 |
+
for handler in logger.handlers:
|
| 26 |
+
logger.removeHandler(handler)
|
| 27 |
+
|
| 28 |
+
for filter in logger.filters:
|
| 29 |
+
logger.removeFilter(filter)
|
| 30 |
+
|
| 31 |
+
# Set level.
|
| 32 |
+
logger.setLevel(log_level)
|
| 33 |
+
|
| 34 |
+
formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s")
|
| 35 |
+
|
| 36 |
+
# Set up console handler.
|
| 37 |
+
stdout_handler = logging.StreamHandler(sys.stdout)
|
| 38 |
+
stdout_handler.setFormatter(formatter)
|
| 39 |
+
logger.addHandler(stdout_handler)
|
| 40 |
+
|
| 41 |
+
# Set up file handler.
|
| 42 |
+
if log_path is not None:
|
| 43 |
+
file_handler = logging.FileHandler(log_path, mode="w")
|
| 44 |
+
file_handler.setFormatter(formatter)
|
| 45 |
+
logger.addHandler(file_handler)
|
src/sharp/utils/math.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Contains utility math functions.
|
| 2 |
+
|
| 3 |
+
For licensing see accompanying LICENSE file.
|
| 4 |
+
Copyright (C) 2025 Apple Inc. All Rights Reserved.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Any, Callable, Literal, NamedTuple, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import autograd
|
| 13 |
+
|
| 14 |
+
ActivationType = Literal[
|
| 15 |
+
"linear",
|
| 16 |
+
"exp",
|
| 17 |
+
"sigmoid",
|
| 18 |
+
"softplus",
|
| 19 |
+
"relu_with_pushback",
|
| 20 |
+
"hard_sigmoid_with_pushback",
|
| 21 |
+
]
|
| 22 |
+
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ActivationPair(NamedTuple):
|
| 26 |
+
"""A pair of forward and inverse activation functions."""
|
| 27 |
+
|
| 28 |
+
forward: ActivationFunction
|
| 29 |
+
inverse: ActivationFunction
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def create_activation_pair(activation_type: ActivationType) -> ActivationPair:
|
| 33 |
+
"""Create activation function and corresponding inverse function.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
activation_type: The activation type to create.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
The corresponding activation functions and the corresponding inverse function.
|
| 40 |
+
"""
|
| 41 |
+
if activation_type == "linear":
|
| 42 |
+
return ActivationPair(lambda x: x, lambda x: x)
|
| 43 |
+
elif activation_type == "exp":
|
| 44 |
+
return ActivationPair(torch.exp, torch.log)
|
| 45 |
+
elif activation_type == "sigmoid":
|
| 46 |
+
return ActivationPair(torch.sigmoid, inverse_sigmoid)
|
| 47 |
+
elif activation_type == "softplus":
|
| 48 |
+
return ActivationPair(torch.nn.functional.softplus, inverse_softplus)
|
| 49 |
+
elif activation_type == "relu_with_pushback":
|
| 50 |
+
return ActivationPair(relu_with_pushback, lambda x: x)
|
| 51 |
+
elif activation_type == "hard_sigmoid_with_pushback":
|
| 52 |
+
return ActivationPair(hard_sigmoid_with_pushback, lambda x: 6.0 * x - 3.0)
|
| 53 |
+
else:
|
| 54 |
+
raise ValueError(f"Unsupported activation function: {activation_type}.")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
|
| 58 |
+
"""Compute inverse sigmoid."""
|
| 59 |
+
return torch.log(tensor / (1.0 - tensor))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def inverse_softplus(tensor: torch.Tensor, eps: float = 1e-06) -> torch.Tensor:
|
| 63 |
+
"""Compute inverse softplus."""
|
| 64 |
+
tensor = tensor.clamp_min(eps)
|
| 65 |
+
sigmoid = torch.sigmoid(-tensor)
|
| 66 |
+
exp = sigmoid / (1.0 - sigmoid)
|
| 67 |
+
return tensor + torch.log(-exp + 1.0)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# The first value describes the threshold from where clamping will be applied, while
|
| 71 |
+
# the second value describes the value to clamp with.
|
| 72 |
+
SoftClampRange = Tuple[Union[torch.Tensor, float], Union[torch.Tensor, float]]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def softclamp(
|
| 76 |
+
tensor: torch.Tensor,
|
| 77 |
+
min: SoftClampRange | None = None,
|
| 78 |
+
max: SoftClampRange | None = None,
|
| 79 |
+
) -> torch.Tensor:
|
| 80 |
+
"""Clamp tensor to min/max in differentiable way.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
tensor: The tensor to clamp.
|
| 84 |
+
min: Pair of threshold to start clamping and value to clamp to.
|
| 85 |
+
The first value should be larger than the second.
|
| 86 |
+
max: Pair of threshold to start clamping and value to clamp to.
|
| 87 |
+
The first value should be smaller than the second.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
The clamped tensor.
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
def normalize(clamp_range: SoftClampRange) -> torch.Tensor:
|
| 94 |
+
value0, value1 = clamp_range
|
| 95 |
+
return value0 + (value1 - value0) * torch.tanh((tensor - value0) / (value1 - value0))
|
| 96 |
+
|
| 97 |
+
tensor_clamped = tensor
|
| 98 |
+
if min is not None:
|
| 99 |
+
tensor_clamped = torch.maximum(tensor_clamped, normalize(min))
|
| 100 |
+
if max is not None:
|
| 101 |
+
tensor_clamped = torch.minimum(tensor_clamped, normalize(max))
|
| 102 |
+
|
| 103 |
+
return tensor_clamped
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class ClampWithPushback(autograd.Function):
|
| 107 |
+
"""Implementation of clamp_with_pushback function."""
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def forward(
|
| 111 |
+
ctx: Any,
|
| 112 |
+
tensor: torch.Tensor,
|
| 113 |
+
min: float | None,
|
| 114 |
+
max: float | None,
|
| 115 |
+
pushback: float,
|
| 116 |
+
) -> torch.Tensor:
|
| 117 |
+
"""Apply clamp."""
|
| 118 |
+
if min is not None and max is not None and min >= max:
|
| 119 |
+
raise ValueError("Only min < max is supported.")
|
| 120 |
+
|
| 121 |
+
ctx.save_for_backward(tensor)
|
| 122 |
+
ctx.min = min
|
| 123 |
+
ctx.max = max
|
| 124 |
+
ctx.pushback = pushback
|
| 125 |
+
return torch.clamp(tensor, min=min, max=max)
|
| 126 |
+
|
| 127 |
+
@staticmethod
|
| 128 |
+
def backward( # type: ignore[override] # Deal with buggy torch annotations.
|
| 129 |
+
ctx: Any, grad_in: torch.Tensor
|
| 130 |
+
) -> tuple[torch.Tensor, None, None, None]:
|
| 131 |
+
"""Compute gradient of clamp with pushback."""
|
| 132 |
+
grad_out = grad_in.clone()
|
| 133 |
+
(tensor,) = ctx.saved_tensors
|
| 134 |
+
|
| 135 |
+
if ctx.min is not None:
|
| 136 |
+
mask_min = tensor < ctx.min
|
| 137 |
+
grad_out[mask_min] = -ctx.pushback
|
| 138 |
+
|
| 139 |
+
if ctx.max is not None:
|
| 140 |
+
mask_max = tensor > ctx.max
|
| 141 |
+
grad_out[mask_max] = ctx.pushback
|
| 142 |
+
|
| 143 |
+
return grad_out, None, None, None
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def clamp_with_pushback(
|
| 147 |
+
tensor: torch.Tensor,
|
| 148 |
+
min: float | None = None,
|
| 149 |
+
max: float | None = None,
|
| 150 |
+
pushback: float = 1e-2,
|
| 151 |
+
) -> torch.Tensor:
|
| 152 |
+
"""Variant of clamp function which avoid the vanishing gradient problem.
|
| 153 |
+
|
| 154 |
+
This function is equivalent to adding a regularizer of the form
|
| 155 |
+
|
| 156 |
+
pushback * sum_i (
|
| 157 |
+
relu(min - preactivation_i) + relu(preactivation_i - max)
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
to the full loss function, which pushes clamped values back.
|
| 161 |
+
|
| 162 |
+
When used in minimization problems, pushback should be greater than
|
| 163 |
+
zero. In maximization problems, pushback should be smaller than zero.
|
| 164 |
+
"""
|
| 165 |
+
output = ClampWithPushback.apply(tensor, min, max, pushback)
|
| 166 |
+
assert isinstance(output, torch.Tensor)
|
| 167 |
+
return output
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def hard_sigmoid_with_pushback(x: torch.Tensor, slope: float = 1.0 / 6.0) -> torch.Tensor:
|
| 171 |
+
"""Apply hard sigmoid with pushback.
|
| 172 |
+
|
| 173 |
+
For compatibility reasons, we follow the default PyTorch implementation with a
|
| 174 |
+
default slope of 1/6:
|
| 175 |
+
|
| 176 |
+
https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
|
| 177 |
+
"""
|
| 178 |
+
return clamp_with_pushback(slope * x + 0.5, min=0.0, max=1.0)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def relu_with_pushback(x: torch.Tensor) -> torch.Tensor:
|
| 182 |
+
"""Compute relu with pushback."""
|
| 183 |
+
return clamp_with_pushback(x, min=0.0)
|