diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f6b1f326ca4ab7cf0c8798856f8fe0020ff82d58 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..36d011d729e50ca240e2550593ec4667a980651c --- /dev/null +++ b/.gitignore @@ -0,0 +1,152 @@ +uniflowmatch.egg-info/** +ufm_model_refine/** +ufm_model/** +/home/inf/UniFlowMatch/convert_old_ckpt.py +checkpoints/** + +# Byte-compiled / optimized / DLL files +__pycache__/ +**/__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Profiling data +.prof + +# Folder specific to your needs +**/tmp/ +**/outputs/skyseg.onnx +skyseg.onnx + +# pixi environments +.pixi +*.egg-info \ No newline at end of file diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..3494f2451357fe75edbc48740d72c569c8512721 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,58 @@ +Attribution-NonCommercial 2.5 Generic +CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE LEGAL SERVICES. DISTRIBUTION OF THIS LICENSE DOES NOT CREATE AN ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE INFORMATION PROVIDED, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM ITS USE. +License + +THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED. + +BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS. + +1. Definitions + +"Collective Work" means a work, such as a periodical issue, anthology or encyclopedia, in which the Work in its entirety in unmodified form, along with a number of other contributions, constituting separate and independent works in themselves, are assembled into a collective whole. A work that constitutes a Collective Work will not be considered a Derivative Work (as defined below) for the purposes of this License. +"Derivative Work" means a work based upon the Work or upon the Work and other pre-existing works, such as a translation, musical arrangement, dramatization, fictionalization, motion picture version, sound recording, art reproduction, abridgment, condensation, or any other form in which the Work may be recast, transformed, or adapted, except that a work that constitutes a Collective Work will not be considered a Derivative Work for the purpose of this License. For the avoidance of doubt, where the Work is a musical composition or sound recording, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered a Derivative Work for the purpose of this License. +"Licensor" means the individual or entity that offers the Work under the terms of this License. +"Original Author" means the individual or entity who created the Work. +"Work" means the copyrightable work of authorship offered under the terms of this License. +"You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation. +2. Fair Use Rights. Nothing in this license is intended to reduce, limit, or restrict any rights arising from fair use, first sale or other limitations on the exclusive rights of the copyright owner under copyright law or other applicable laws. + +3. License Grant. Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below: + +to reproduce the Work, to incorporate the Work into one or more Collective Works, and to reproduce the Work as incorporated in the Collective Works; +to create and reproduce Derivative Works; +to distribute copies or phonorecords of, display publicly, perform publicly, and perform publicly by means of a digital audio transmission the Work including as incorporated in Collective Works; +to distribute copies or phonorecords of, display publicly, perform publicly, and perform publicly by means of a digital audio transmission Derivative Works; +The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. All rights not expressly granted by Licensor are hereby reserved, including but not limited to the rights set forth in Sections 4(d) and 4(e). + +4. Restrictions. The license granted in Section 3 above is expressly made subject to and limited by the following restrictions: + +You may distribute, publicly display, publicly perform, or publicly digitally perform the Work only under the terms of this License, and You must include a copy of, or the Uniform Resource Identifier for, this License with every copy or phonorecord of the Work You distribute, publicly display, publicly perform, or publicly digitally perform. You may not offer or impose any terms on the Work that alter or restrict the terms of this License or the recipients' exercise of the rights granted hereunder. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties. You may not distribute, publicly display, publicly perform, or publicly digitally perform the Work with any technological measures that control access or use of the Work in a manner inconsistent with the terms of this License Agreement. The above applies to the Work as incorporated in a Collective Work, but this does not require the Collective Work apart from the Work itself to be made subject to the terms of this License. If You create a Collective Work, upon notice from any Licensor You must, to the extent practicable, remove from the Collective Work any credit as required by clause 4(c), as requested. If You create a Derivative Work, upon notice from any Licensor You must, to the extent practicable, remove from the Derivative Work any credit as required by clause 4(c), as requested. +You may not exercise any of the rights granted to You in Section 3 above in any manner that is primarily intended for or directed toward commercial advantage or private monetary compensation. The exchange of the Work for other copyrighted works by means of digital file-sharing or otherwise shall not be considered to be intended for or directed toward commercial advantage or private monetary compensation, provided there is no payment of any monetary compensation in connection with the exchange of copyrighted works. +If you distribute, publicly display, publicly perform, or publicly digitally perform the Work or any Derivative Works or Collective Works, You must keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of Original Author (or pseudonym, if applicable) if supplied, and/or (ii) if the Original Author and/or Licensor designate another party or parties (e.g. a sponsor institute, publishing entity, journal) for attribution in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; the title of the Work if supplied; to the extent reasonably practicable, the Uniform Resource Identifier, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and in the case of a Derivative Work, a credit identifying the use of the Work in the Derivative Work (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). Such credit may be implemented in any reasonable manner; provided, however, that in the case of a Derivative Work or Collective Work, at a minimum such credit will appear where any other comparable authorship credit appears and in a manner at least as prominent as such other comparable authorship credit. +For the avoidance of doubt, where the Work is a musical composition: + +Performance Royalties Under Blanket Licenses . Licensor reserves the exclusive right to collect, whether individually or via a performance rights society (e.g. ASCAP, BMI, SESAC), royalties for the public performance or public digital performance (e.g. webcast) of the Work if that performance is primarily intended for or directed toward commercial advantage or private monetary compensation. +Mechanical Rights and Statutory Royalties . Licensor reserves the exclusive right to collect, whether individually or via a music rights agency or designated agent (e.g. Harry Fox Agency), royalties for any phonorecord You create from the Work ("cover version") and distribute, subject to the compulsory license created by 17 USC Section 115 of the US Copyright Act (or the equivalent in other jurisdictions), if Your distribution of such cover version is primarily intended for or directed toward commercial advantage or private monetary compensation. +Webcasting Rights and Statutory Royalties. For the avoidance of doubt, where the Work is a sound recording, Licensor reserves the exclusive right to collect, whether individually or via a performance-rights society (e.g. SoundExchange), royalties for the public digital performance (e.g. webcast) of the Work, subject to the compulsory license created by 17 USC Section 114 of the US Copyright Act (or the equivalent in other jurisdictions), if Your public digital performance is primarily intended for or directed toward commercial advantage or private monetary compensation. +5. Representations, Warranties and Disclaimer + +UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU. + +6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +7. Termination + +This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Derivative Works or Collective Works from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License. +Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above. +8. Miscellaneous + +Each time You distribute or publicly digitally perform the Work or a Collective Work, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License. +Each time You distribute or publicly digitally perform a Derivative Work, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License. +If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. +No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent. +This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You. +Creative Commons is not a party to this License, and makes no warranty whatsoever in connection with the Work. Creative Commons will not be liable to You or any party on any legal theory for any damages whatsoever, including without limitation any general, special, incidental or consequential damages arising in connection to this license. Notwithstanding the foregoing two (2) sentences, if Creative Commons has expressly identified itself as the Licensor hereunder, it shall have all rights and obligations of Licensor. + +Except for the limited purpose of indicating to the public that the Work is licensed under the CCPL, neither party will use the trademark "Creative Commons" or any related trademark or logo of Creative Commons without the prior written consent of Creative Commons. Any permitted use will be in compliance with Creative Commons' then-current trademark usage guidelines, as may be published on its website or otherwise made available upon request from time to time. + +Creative Commons may be contacted at https://creativecommons.org/ . \ No newline at end of file diff --git a/UniCeption/.gitignore b/UniCeption/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..299697e0092addb2014ab016f34ab070bb871d72 --- /dev/null +++ b/UniCeption/.gitignore @@ -0,0 +1,167 @@ +# Local Folders +checkpoints/ +local/ +reference_data/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/UniCeption/.pre-commit-config.yaml b/UniCeption/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..45c229ab04cde7b14f395667eb343083988db078 --- /dev/null +++ b/UniCeption/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +default_language_version: + python: python3 +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - repo: https://github.com/pre-commit/mirrors-isort + rev: 'v5.10.1' + hooks: + - id: isort + - repo: https://github.com/psf/black + rev: '23.3.0' + hooks: + - id: black diff --git a/UniCeption/.pylintrc b/UniCeption/.pylintrc new file mode 100644 index 0000000000000000000000000000000000000000..5c771e6244c02612f39871c932bbf15e516a4362 --- /dev/null +++ b/UniCeption/.pylintrc @@ -0,0 +1,399 @@ +# This Pylint rcfile contains a best-effort configuration to uphold the +# best-practices and style described in the Google Python style guide: +# https://google.github.io/styleguide/pyguide.html +# +# Its canonical open-source location is: +# https://google.github.io/styleguide/pylintrc + +[MAIN] + +# Files or directories to be skipped. They should be base names, not paths. +ignore=third_party + +# Files or directories matching the regex patterns are skipped. The regex +# matches against base names, not paths. +ignore-patterns= + +# Pickle collected data for later comparisons. +persistent=no + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Use multiple processes to speed up Pylint. +jobs=4 + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +#enable= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=R, + abstract-method, + apply-builtin, + arguments-differ, + attribute-defined-outside-init, + backtick, + bad-option-value, + basestring-builtin, + buffer-builtin, + c-extension-no-member, + consider-using-enumerate, + cmp-builtin, + cmp-method, + coerce-builtin, + coerce-method, + delslice-method, + div-method, + eq-without-hash, + execfile-builtin, + file-builtin, + filter-builtin-not-iterating, + fixme, + getslice-method, + global-statement, + hex-method, + idiv-method, + implicit-str-concat, + import-error, + import-self, + import-star-module-level, + input-builtin, + intern-builtin, + invalid-str-codec, + locally-disabled, + long-builtin, + long-suffix, + map-builtin-not-iterating, + misplaced-comparison-constant, + missing-function-docstring, + metaclass-assignment, + next-method-called, + next-method-defined, + no-absolute-import, + no-init, # added + no-member, + no-name-in-module, + no-self-use, + nonzero-method, + oct-method, + old-division, + old-ne-operator, + old-octal-literal, + old-raise-syntax, + parameter-unpacking, + print-statement, + raising-string, + range-builtin-not-iterating, + raw_input-builtin, + rdiv-method, + reduce-builtin, + relative-import, + reload-builtin, + round-builtin, + setslice-method, + signature-differs, + standarderror-builtin, + suppressed-message, + sys-max-int, + trailing-newlines, + unichr-builtin, + unicode-builtin, + unnecessary-pass, + unpacking-in-except, + useless-else-on-loop, + useless-suppression, + using-cmp-argument, + wrong-import-order, + xrange-builtin, + zip-builtin-not-iterating, + + +[REPORTS] + +# Set the output format. Available formats are text, parseable, colorized, msvs +# (visual studio) and html. You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages +reports=no + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + + +[BASIC] + +# Good variable names which should always be accepted, separated by a comma +good-names=main,_ + +# Bad variable names which should always be refused, separated by a comma +bad-names= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl + +# Regular expression matching correct function names +function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct constant names +const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct attribute names +attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ + +# Regular expression matching correct argument names +argument-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=^[a-z][a-z0-9_]*$ + +# Regular expression matching correct class names +class-rgx=^_?[A-Z][a-zA-Z0-9]*$ + +# Regular expression matching correct module names +module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ + +# Regular expression matching correct method names +method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=12 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + + +[FORMAT] + +# Maximum number of characters on a single line. +max-line-length=80 + +# TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt +# lines made too long by directives to pytype. + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=(?x)( + ^\s*(\#\ )??$| + ^\s*(from\s+\S+\s+)?import\s+.+$) + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=yes + +# Maximum number of lines in a module +max-module-lines=99999 + +# String used as indentation unit. The internal Google style guide mandates 2 +# spaces. Google's externaly-published style guide says 4, consistent with +# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google +# projects (like TensorFlow). +indent-string=' ' + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=TODO + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=yes + + +[VARIABLES] + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging,absl.logging,tensorflow.io.logging + + +[SIMILARITIES] + +# Minimum lines number of a similarity. +min-similarity-lines=4 + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[IMPORTS] + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=regsub, + TERMIOS, + Bastion, + rexec, + sets + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant, absl + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls, + class_ + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs diff --git a/UniCeption/LICENSE b/UniCeption/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6248c659a47c6158c93ef43e4c5ecf59f62cf15e --- /dev/null +++ b/UniCeption/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright (c) 2024, AirLab Stacks + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/UniCeption/README.md b/UniCeption/README.md new file mode 100644 index 0000000000000000000000000000000000000000..dd70229825272d3ccfbbdf0b57e799498e0e4b62 --- /dev/null +++ b/UniCeption/README.md @@ -0,0 +1,155 @@ +# UniCeption + +UniCeption houses modular building blocks for developing and training generalizable perception models for all things related to 3D, 4D, spatial AI and scene understanding. +It is designed to be flexible and extensible, allowing researchers to easily experiment with different architectures and configurations. + +Please refer to the [Developer Guidelines](#developer-guidelines) for contributing to the project. + +## Installation + +Clone the repository to your local machine by running the following command: + +```bash +git clone git@github.com:castacks/UniCeption.git +cd UniCeption +``` + +### Standard Installation + +Install the `uniception` package in development mode by running the following commands: + +```bash +# Please use Conda or Python Virtual Environment based on your preference +# For Conda Environment +conda create --name uniception python=3.12 +conda activate uniception +# For Python Virtual Environment +virtualenv uniception +source uniception/bin/activate + +# Install UniCeption with base dependencies (includes PyTorch) +pip install -e . + +# Optional: Install with XFormers support +pip install -e ".[xformers]" + +# Optional: Install with development tools +pip install -e ".[dev]" + +# Optional: Install all optional dependencies +pip install -e ".[all]" + +# Setup pre-commit hooks for development +pre-commit install +``` + +### Optional: CroCo RoPE Extension Installation + +To use CroCo models with the custom RoPE kernel: + +```bash +# Recommended: Use the console script +uniception-install-croco + +# Alternative: Set environment variable during installation +INSTALL_CROCO_ROPE=true pip install -e . + +# Manual compilation (if needed) +cd uniception/models/libs/croco/curope +python setup.py build_ext --inplace +cd ../../../../../ +``` + +### Installation Validation and Dependency Checking + +After installation, use these console scripts to validate your setup: + +```bash +# Validate installation and check dependencies +uniception-validate + +# Check which optional dependencies are available +uniception-check-deps +``` + +### Advanced Installation Options + +#### Docker Installation (No Internet Access) + +If you're working in a Docker container that already has Python dependencies installed but no internet access, you can install UniCeption in development mode without triggering network requests: + +```bash +# Install only the package structure without dependencies +pip install -e . --no-deps +``` + +**Note:** This command assumes your Docker image already contains all required dependencies (PyTorch, etc.). Use `uniception-validate` after installation to verify all dependencies are available. + +#### Offline Installation + +For environments without internet access: + +```bash +# 1. On a machine with internet access, prepare offline wheels +uniception-prepare-offline --output-dir offline_wheels --extras all + +# 2. Copy the offline_wheels directory to your offline environment +# 3. Run the offline installation +cd offline_wheels +INSTALL_CROCO_ROPE=true INSTALL_XFORMERS=true ./install_offline.sh +``` + +#### Downloading Checkpoints + +Download UniCeption format custom checkpoints: + +```bash +# Download all available checkpoints +uniception-download-checkpoints + +# Download specific folders only (e.g., encoders and prediction heads) +uniception-download-checkpoints --folders encoders prediction_heads + +# Specify custom destination +uniception-download-checkpoints --destination /path/to/checkpoints +``` + +**Available options:** +- `--folders`: Specify which folders to download. Choices: `encoders`, `info_sharing`, `prediction_heads`, `examples` (default: all folders) +- `--destination`: Custom destination folder for downloaded checkpoints (default: current directory) + +--- + +## Currently Supported Components + +### Encoders + +Please refer to the `uniception/models/encoders` directory for the supported encoders and documentation for adding new encoders. The supported encoders can be listed by running: + +```bash +python3 -m uniception.models.encoders.list +``` + +--- + +## Information Sharing Blocks + +Please refer to the `uniception/models/info_sharing` directory for the supported information sharing blocks. + +--- + +## Prediction Heads + +Please refer to the `uniception/models/prediction_heads` directory for the supported prediction heads. + +--- + +## Developer Guidelines + +Please follow these guidelines when contributing to UniCeption: +- **Code Style**: Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) for code style. +- **Documentation**: Add docstrings to all classes and methods. +- **Unit Tests**: Add necessary unit tests to the `tests` folder. +- **Linting**: Run `black` & `isort` on your code before committing. For example, you can run `black . && isort .`. + +Please create a pull request for any changes you make, and ensure that all tests pass before merging. diff --git a/UniCeption/examples/models/cosmos/autoencoding.py b/UniCeption/examples/models/cosmos/autoencoding.py new file mode 100644 index 0000000000000000000000000000000000000000..2b7f467205056e1b447545a5d0ce422f901a97c9 --- /dev/null +++ b/UniCeption/examples/models/cosmos/autoencoding.py @@ -0,0 +1,48 @@ +import os + +import cv2 +import torch +from matplotlib import pyplot as plt + +from uniception.models.encoders.base import ViTEncoderInput +from uniception.models.encoders.cosmos import CosmosEncoder +from uniception.models.prediction_heads.cosmos import CosmosSingleChannel + +base_path = os.path.dirname(os.path.abspath(__file__)) + +encoder = CosmosEncoder( + name="cosmos", + patch_size=8, + pretrained_checkpoint_path=os.path.join( + base_path, "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI8x8/encoder.pth" + ), +) + +decoder = CosmosSingleChannel( + patch_size=8, + pretrained_checkpoint_path=os.path.join(base_path, "../../../checkpoints/prediction_heads/cosmos/decoder_8.pth"), +) + +example_image = cv2.imread(os.path.join(base_path, "./example.png")) +example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB) +example_tensor = torch.tensor(example_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0 +example_tensor = example_tensor * 2.0 - 1.0 # Normalize to [-1, 1] according to the COSMOS Encoder + +encoded_latent = encoder(ViTEncoderInput("cosmos", example_tensor)).features + +decoded_image = decoder(encoded_latent) +decoded_image = (decoded_image + 1.0) / 2.0 # Denormalize to [0, 1] for visualization + +# plot the original and decoded images +plt.figure(figsize=(10, 5)) +plt.subplot(1, 2, 1) +plt.imshow(example_image) +plt.title("Original Image") +plt.axis("off") + +plt.subplot(1, 2, 2) +plt.imshow(decoded_image.squeeze().detach().permute(1, 2, 0).cpu().numpy()) +plt.title("Decoded Image") +plt.axis("off") + +plt.savefig(os.path.join(base_path, "example_decoded.png")) diff --git a/UniCeption/examples/models/cosmos/example.png b/UniCeption/examples/models/cosmos/example.png new file mode 100644 index 0000000000000000000000000000000000000000..bb207516833ed69d1b048865271a97611b28feeb --- /dev/null +++ b/UniCeption/examples/models/cosmos/example.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e6ee5528f76e5c0794e2708d688877b0f06f2139a11e883a3832ad57f19f89c +size 710967 diff --git a/UniCeption/examples/models/cosmos/example_decoded.png b/UniCeption/examples/models/cosmos/example_decoded.png new file mode 100644 index 0000000000000000000000000000000000000000..e2da9d0b8ab8cdaa29c7e33a853eed137539a1f3 --- /dev/null +++ b/UniCeption/examples/models/cosmos/example_decoded.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f948b50b602260352e14fca5f51999f01bd98b8e167dd1595451418380eaed21 +size 347571 diff --git a/UniCeption/examples/models/dust3r/convert_dust3r_weights_to_uniception.py b/UniCeption/examples/models/dust3r/convert_dust3r_weights_to_uniception.py new file mode 100644 index 0000000000000000000000000000000000000000..7c70b33b0a69c04bfc74f97d2812c3352a4771bc --- /dev/null +++ b/UniCeption/examples/models/dust3r/convert_dust3r_weights_to_uniception.py @@ -0,0 +1,331 @@ +""" +This file extracts the cross-attention transformer & prediction head weights from dust3r checkpoints into uniception format. + +Special Notice: dust3r have changed their released weights before/after CVPR, and +uniception uses the checkpoint BEFORE CVPR (they perform better). So please make sure you are not converting +the newly downloaded weights. Consult Yuchen and Nikhil on where to find the old weights. +""" + +import argparse +import os + +import torch +from torch import nn + +from uniception.models.info_sharing.cross_attention_transformer import MultiViewCrossAttentionTransformerIFR +from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor +from uniception.models.prediction_heads.linear import LinearFeature + + +def extract_cross_attention_weights(checkpoint_path, output_folder, output_filename): + "Extract the UniCeption format cross attention weights from the original CroCoV2/DUSt3R/MASt3R checkpoints." + # Load checkpoint + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + + # Filter the relevant keys for the cross attention model and duplicate if necessary + filtered_checkpoint = checkpoint["model"] + filtered_checkpoint = {k: v for k, v in filtered_checkpoint.items() if "dec" in k} + duplicate_checkpoint = {} + if not any(k.startswith("dec_blocks2") for k in filtered_checkpoint): + print("Duplicating dec_blocks to dec_blocks2") + for key, value in filtered_checkpoint.items(): + if key.startswith("dec_blocks"): + duplicate_checkpoint[key.replace("dec_blocks", "dec_blocks2")] = value + filtered_checkpoint = {**filtered_checkpoint, **duplicate_checkpoint} + new_checkpoint = {} + for k, v in filtered_checkpoint.items(): + if "decoder_embed" in k: + new_key = k.replace("decoder_embed", "proj_embed") + new_checkpoint[new_key] = v + elif "dec_blocks." in k: + new_key = k.replace("dec_blocks.", "multi_view_branches.0.") + new_checkpoint[new_key] = v + elif "dec_blocks2." in k: + new_key = k.replace("dec_blocks2.", "multi_view_branches.1.") + new_checkpoint[new_key] = v + elif "dec_norm" in k: + new_key = k.replace("dec_norm", "norm") + new_checkpoint[new_key] = v + + # Init model + model = MultiViewCrossAttentionTransformerIFR( + name="MV-CAT-IFR", + input_embed_dim=1024, + num_views=2, + indices=[5, 8], + norm_intermediate=False, + ) + + # Load new checkpoint + print(model.load_state_dict(new_checkpoint)) + + # Save the checkpoint + save_checkpoint = {} + save_checkpoint["model"] = model.state_dict() + os.makedirs(os.path.join(output_folder, "cross_attn_transformer"), exist_ok=True) + save_path = os.path.join(output_folder, "cross_attn_transformer", output_filename) + torch.save(save_checkpoint, save_path) + + +def extract_dust3r_dpt_checkpoints(checkpoint_path, output_folder, output_filename): + "Extract the UniCeption format DPT head weights from the original DUSt3R checkpoint." + source_ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + + for head in ["head1", "head2"]: + # Extract head weights from the checkpoint + dpt_head_weights = {k: v for k, v in source_ckpt["model"].items() if k.startswith(f"downstream_{head}")} + dpt_head_weights = {k.replace(f"downstream_{head}.dpt.", ""): v for k, v in dpt_head_weights.items()} + dpt_feature_weights = {k: v for k, v in dpt_head_weights.items() if not (k.startswith("head"))} + + # Construct the DPTFeature module and load the weights + dpt = DPTFeature( + patch_size=16, + hooks=[0, 1, 2, 3], + input_feature_dims=[1024, 768, 768, 768], + layer_dims=[96, 192, 384, 768], + feature_dim=256, + use_bn=False, + output_width_ratio=1, + ) + + dpt.load_state_dict(dpt_feature_weights, strict=True) + + # Construct the dpt processor module and load the weights + dpt_processor_weights = {k.replace("head.", ""): v for k, v in dpt_head_weights.items() if k.startswith("head")} + + # Replace the keys according to: + key_replace_dict = { + "0.weight": "conv1.weight", + "0.bias": "conv1.bias", + "2.weight": "conv2.0.weight", + "2.bias": "conv2.0.bias", + "4.weight": "conv2.2.weight", + "4.bias": "conv2.2.bias", + } + + dpt_processor_weights = {key_replace_dict.get(k, k): v for k, v in dpt_processor_weights.items()} + + dpt_reg_processor = DPTRegressionProcessor(input_feature_dim=256, output_dim=4, hidden_dims=[128, 128]) + + dpt_reg_processor.load_state_dict(dpt_processor_weights, strict=True) + + # Save the state_dicts of the DPTFeature and DPTRegressionProcessor + dpt_feature_path = os.path.join(output_folder, "dpt_feature_head", output_filename + f"_feature_{head}.pth") + dpt_reg_processor_path = os.path.join( + output_folder, "dpt_reg_processor", output_filename + f"_reg_processor{head[-1]}.pth" + ) + + os.makedirs(os.path.dirname(dpt_feature_path), exist_ok=True) + os.makedirs(os.path.dirname(dpt_reg_processor_path), exist_ok=True) + + torch.save({"model": dpt.state_dict()}, dpt_feature_path) + torch.save({"model": dpt_reg_processor.state_dict()}, dpt_reg_processor_path) + + +def extract_dust3r_linear_checkpoints(checkpoint_path, output_folder, output_filename): + "Extract the UniCeption format linear head weights from the original DUSt3R checkpoint." + test_linear_to_conv() + + source_ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + + for head in ["head1", "head2"]: + linear_head_params = {k: v for k, v in source_ckpt["model"].items() if k.startswith(f"downstream_{head}")} + linear_head_params = {k.replace(f"downstream_{head}.proj.", ""): v for k, v in linear_head_params.items()} + + assert set(linear_head_params.keys()) == {"weight", "bias"} + + input_feature_dim = 768 + output_dim = 4 + patch_size = 16 + + linear = nn.Linear(input_feature_dim, output_dim * patch_size * patch_size, bias=True) + linear.load_state_dict(linear_head_params, strict=True) + + conv_layer = linear_to_conv2d(linear) + + linear_feature = LinearFeature(input_feature_dim, 4, patch_size) + linear_feature.linear.load_state_dict(conv_layer.state_dict(), strict=True) + + linear_feature_path = os.path.join( + output_folder, "linear_feature_head", output_filename + f"_feature_{head}.pth" + ) + os.makedirs(os.path.dirname(linear_feature_path), exist_ok=True) + torch.save({"model": linear_feature.state_dict()}, linear_feature_path) + + +def extract_mast3r_dpt_checkpoints(checkpoint_path, output_folder, output_filename): + "Extract the UniCeption format DPT head weights from the original MASt3R checkpoint." + source_ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + + for head in ["head1", "head2"]: + dpt_head = {k: v for k, v in source_ckpt["model"].items() if k.startswith(f"downstream_{head}")} + dpt_head = {k.replace(f"downstream_{head}.", ""): v for k, v in dpt_head.items()} + dpt_head = {k.replace("dpt.", ""): v for k, v in dpt_head.items()} + + dpt_feature_weights = { + k: v for k, v in dpt_head.items() if not (k.startswith("head") or k.startswith("head_local_features")) + } + + dpt = DPTFeature( + patch_size=16, + hooks=[0, 1, 2, 3], + input_feature_dims=[1024, 768, 768, 768], + layer_dims=[96, 192, 384, 768], + feature_dim=256, + use_bn=False, + output_width_ratio=1, + ) + + dpt.load_state_dict(dpt_feature_weights, strict=True) + + dpt_processor_weights = { + k.replace("head.", ""): v + for k, v in dpt_head.items() + if (k.startswith("head") and not k.startswith("head_local_features")) + } + + # Replace the keys according to: + key_replace_dict = { + "0.weight": "conv1.weight", + "0.bias": "conv1.bias", + "2.weight": "conv2.0.weight", + "2.bias": "conv2.0.bias", + "4.weight": "conv2.2.weight", + "4.bias": "conv2.2.bias", + } + + dpt_processor_weights = {key_replace_dict.get(k, k): v for k, v in dpt_processor_weights.items()} + + dpt_reg_processor = DPTRegressionProcessor(input_feature_dim=256, output_dim=4, hidden_dims=[128, 128]) + + dpt_reg_processor.load_state_dict(dpt_processor_weights, strict=True) + + # Save the state_dicts of the DPTFeature and DPTRegressionProcessor + dpt_feature_path = os.path.join(output_folder, "dpt_feature_head", output_filename + f"_feature_{head}.pth") + dpt_reg_processor_path = os.path.join( + output_folder, "dpt_reg_processor", output_filename + f"_reg_processor{head[-1]}.pth" + ) + + os.makedirs(os.path.dirname(dpt_feature_path), exist_ok=True) + os.makedirs(os.path.dirname(dpt_reg_processor_path), exist_ok=True) + + torch.save({"model": dpt.state_dict()}, dpt_feature_path) + torch.save({"model": dpt_reg_processor.state_dict()}, dpt_reg_processor_path) + + +def linear_to_conv2d(linear_layer): + """ + Converts a nn.Linear layer to an equivalent nn.Conv2d layer with a 1x1 kernel. + + Parameters: + - linear_layer (nn.Linear): The Linear layer to convert. + + Returns: + - conv_layer (nn.Conv2d): The equivalent Conv2d layer. + """ + # Extract in_features and out_features from the Linear layer + in_features = linear_layer.in_features + out_features = linear_layer.out_features + bias = linear_layer.bias is not None + + # Create a Conv2d layer with a 1x1 kernel + conv_layer = nn.Conv2d( + in_channels=in_features, out_channels=out_features, kernel_size=1, stride=1, padding=0, bias=bias + ) + + # Reshape Linear weights to match Conv2d weights + conv_weight = linear_layer.weight.data.view(out_features, in_features, 1, 1).clone() + conv_layer.weight.data = conv_weight + + # Copy bias if it exists + if bias: + conv_layer.bias.data = linear_layer.bias.data.clone() + + return conv_layer + + +def test_linear_to_conv(): + "Test the linear_to_conv2d function." + batch_size = 4 + height = 16 + width = 24 + in_channels = 3 + out_channels = 5 + + # Sample input tensor in BHWC format + x_linear = torch.randn(batch_size, height, width, in_channels) + + # Define Linear layer + linear_layer = nn.Linear(in_channels, out_channels) + output_linear = linear_layer(x_linear) + + # Transpose input tensor to BCHW format for Conv2d + x_conv = x_linear.permute(0, 3, 1, 2) + + # Define Conv2d layer + conv_layer = linear_to_conv2d(linear_layer) + + # Get Conv2d output and transpose back to BHWC format + output_conv = conv_layer(x_conv).permute(0, 2, 3, 1) + + # Verify that outputs are the same + assert torch.allclose(output_linear, output_conv, atol=1e-6) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Extract dust3r checkpoints to uniception format") + + parser.add_argument( + "-dcf", "--dust3r_checkpoints_folder", type=str, required=True, help="Path to the dust3r checkpoints folder" + ) + parser.add_argument("-of", "--output_folder", type=str, required=True, help="Path to the output folder") + + args = parser.parse_args() + + output_folder = args.output_folder + info_sharing_output_folder = os.path.join(output_folder, "info_sharing") + pred_head_output_folder = os.path.join(output_folder, "prediction_heads") + os.makedirs(output_folder, exist_ok=True) + os.makedirs(info_sharing_output_folder, exist_ok=True) + os.makedirs(pred_head_output_folder, exist_ok=True) + + # Extract croco checkpoint + print("Extracting CroCo checkpoint...") + croco_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "CroCo_V2_ViTLarge_BaseDecoder.pth") + extract_cross_attention_weights( + croco_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_CroCo.pth" + ) + + # Extract dust3r 224 linear checkpoint + print("Extracting DUSt3R 224 linear checkpoint...") + dust3r_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "DUSt3R_ViTLarge_BaseDecoder_224_linear.pth") + extract_cross_attention_weights( + dust3r_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_DUSt3R_224_linear.pth" + ) + extract_dust3r_linear_checkpoints(dust3r_ckpt_filepath, pred_head_output_folder, "DUSt3R_224_linear") + + # Extract dust3r 512 linear checkpoint + print("Extracting DUSt3R 512 linear checkpoint...") + dust3r_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "DUSt3R_ViTLarge_BaseDecoder_512_linear.pth") + extract_cross_attention_weights( + dust3r_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_DUSt3R_512_linear.pth" + ) + extract_dust3r_linear_checkpoints(dust3r_ckpt_filepath, pred_head_output_folder, "DUSt3R_512_linear") + + # Extract dust3r 512 dpt checkpoint + print("Extracting DUSt3R 512 dpt checkpoint...") + dust3r_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth") + extract_cross_attention_weights( + dust3r_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_DUSt3R_512_dpt.pth" + ) + extract_dust3r_dpt_checkpoints(dust3r_ckpt_filepath, pred_head_output_folder, "DUSt3R_512_dpt") + + # Extract mast3r 512 dpt checkpoint + print("Extracting MASt3R 512 dpt checkpoint...") + mast3r_ckpt_path = os.path.join( + args.dust3r_checkpoints_folder, "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" + ) + extract_cross_attention_weights( + mast3r_ckpt_path, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_MASt3R_512_dpt.pth" + ) + extract_mast3r_dpt_checkpoints(mast3r_ckpt_path, pred_head_output_folder, "MASt3R_512_dpt") diff --git a/UniCeption/examples/models/dust3r/dust3r.py b/UniCeption/examples/models/dust3r/dust3r.py new file mode 100644 index 0000000000000000000000000000000000000000..bebadb990a22ac3da376133b51d41b08bc5b2607 --- /dev/null +++ b/UniCeption/examples/models/dust3r/dust3r.py @@ -0,0 +1,261 @@ +""" +Initalizing Pre-trained DUSt3R using UniCeption +""" + +import argparse +import os +from io import BytesIO + +import numpy as np +import requests +import rerun as rr +import torch +from PIL import Image + +from uniception.models.factory import DUSt3R +from uniception.utils.viz import script_add_rerun_args + + +def get_model_configurations_and_checkpoints(): + """ + Get different DUSt3R model configurations and paths to refactored checkpoints. + + Returns: + Tuple[List[str], dict]: A tuple containing the model configurations and paths to refactored checkpoints. + """ + # Initialize model configurations + model_configurations = ["dust3r_224_linear", "dust3r_512_linear", "dust3r_512_dpt", "dust3r_512_dpt_mast3r"] + + # Get paths to pretrained checkpoints + current_file_path = os.path.abspath(__file__) + relative_checkpoint_path = os.path.join(os.path.dirname(current_file_path), "../../../checkpoints") + + # Initialize model configurations + model_to_checkpoint_path = { + "dust3r_512_dpt": { + "encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_512_DUSt3R_dpt.pth", + "info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_DUSt3R_512_dpt.pth", + "feature_head": [ + f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/DUSt3R_512_dpt_feature_head1.pth", + f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/DUSt3R_512_dpt_feature_head2.pth", + ], + "regressor": [ + f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/DUSt3R_512_dpt_reg_processor1.pth", + f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/DUSt3R_512_dpt_reg_processor2.pth", + ], + "ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", + }, + "dust3r_512_dpt_mast3r": { + "encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_512_MASt3R.pth", + "info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_MASt3R_512_dpt.pth", + "feature_head": [ + f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/MASt3R_512_dpt_feature_head1.pth", + f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/MASt3R_512_dpt_feature_head2.pth", + ], + "regressor": [ + f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/MASt3R_512_dpt_reg_processor1.pth", + f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/MASt3R_512_dpt_reg_processor2.pth", + ], + "ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_512_dpt_mast3r.pth", + }, + "dust3r_512_linear": { + "encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_512_DUSt3R_linear.pth", + "info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_DUSt3R_512_linear.pth", + "feature_head": [ + f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_512_linear_feature_head1.pth", + f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_512_linear_feature_head2.pth", + ], + "regressor": None, + "ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_512_linear.pth", + }, + "dust3r_224_linear": { + "encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_224_DUSt3R_linear.pth", + "info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_DUSt3R_224_linear.pth", + "feature_head": [ + f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_224_linear_feature_head1.pth", + f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_224_linear_feature_head2.pth", + ], + "regressor": None, + "ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth", + }, + } + return model_configurations, model_to_checkpoint_path + + +def get_parser(): + "Argument parser for the script." + parser = argparse.ArgumentParser() + parser.add_argument("--viz", action="store_true") + + return parser + + +if __name__ == "__main__": + # Parse arguments + parser = get_parser() + script_add_rerun_args(parser) # Options: --addr + args = parser.parse_args() + + # Set up Rerun for visualization + if args.viz: + rr.script_setup(args, f"UniCeption_DUSt3R_Inference") + rr.set_time("stable_time", sequence=0) + + # the reference data are collected under this setting. + # may use (False, "high") to test the relative error at TF32 precision + torch.backends.cuda.matmul.allow_tf32 = False + torch.set_float32_matmul_precision("highest") + + # Get paths to pretrained checkpoints + current_file_path = os.path.abspath(__file__) + relative_checkpoint_path = os.path.join(os.path.dirname(current_file_path), "../../../checkpoints") + model_configurations, model_to_checkpoint_path = get_model_configurations_and_checkpoints() + + MODEL_TO_VERIFICATION_PATH = { + "dust3r_512_dpt": { + "head_output": os.path.join( + os.path.dirname(current_file_path), + "../../../reference_data/dust3r_pre_cvpr", + "DUSt3R_512_dpt", + "03_head_output.npz", + ) + }, + "dust3r_512_dpt_mast3r": { + "head_output": os.path.join( + os.path.dirname(current_file_path), + "../../../reference_data/dust3r_pre_cvpr", + "MASt3R_512_dpt", + "03_head_output.npz", + ) + }, + "dust3r_512_linear": { + "head_output": os.path.join( + os.path.dirname(current_file_path), + "../../../reference_data/dust3r_pre_cvpr", + "DUSt3R_512_linear", + "03_head_output.npz", + ) + }, + "dust3r_224_linear": { + "head_output": os.path.join( + os.path.dirname(current_file_path), + "../../../reference_data/dust3r_pre_cvpr", + "DUSt3R_224_linear", + "03_head_output.npz", + ) + }, + } + + # Test different DUSt3R models using UniCeption modules + for model_name in model_configurations: + dust3r_model = DUSt3R( + name=model_name, + img_size=(512, 512) if "512" in model_name else (224, 224), + patch_embed_cls="PatchEmbedDust3R", + pred_head_type="linear" if "linear" in model_name else "dpt", + pretrained_checkpoint_path=model_to_checkpoint_path[model_name]["ckpt_path"], + # pretrained_encoder_checkpoint_path=model_to_checkpoint_path[model_name]["encoder"], + # pretrained_info_sharing_checkpoint_path=model_to_checkpoint_path[model_name]["info_sharing"], + # pretrained_pred_head_checkpoint_paths=model_to_checkpoint_path[model_name]["feature_head"], + # pretrained_pred_head_regressor_checkpoint_paths=model_to_checkpoint_path[model_name]["regressor"], + # override_encoder_checkpoint_attributes=True, + ) + print("DUSt3R model initialized successfully!") + + # Initalize device + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + dust3r_model.to(device) + + # Initalize two example images + img0_url = ( + "https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau1.png" + ) + img1_url = ( + "https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau2.png" + ) + response = requests.get(img0_url) + img0 = Image.open(BytesIO(response.content)) + response = requests.get(img1_url) + img1 = Image.open(BytesIO(response.content)) + img0_tensor = torch.from_numpy(np.array(img0))[..., :3].permute(2, 0, 1).unsqueeze(0).float() / 255 + img1_tensor = torch.from_numpy(np.array(img1))[..., :3].permute(2, 0, 1).unsqueeze(0).float() / 255 + + # Normalize images according to DUSt3R's normalization + img0_tensor = (img0_tensor - 0.5) / 0.5 + img1_tensor = (img1_tensor - 0.5) / 0.5 + img_tensor = torch.cat((img0_tensor, img1_tensor), dim=0).to(device) + + # Run a forward pass + view1 = {"img": img_tensor, "instance": [0, 1], "data_norm_type": "dust3r"} + view2 = {"img": view1["img"][[1, 0]].clone().to(device), "instance": [1, 0], "data_norm_type": "dust3r"} + + res1, res2 = dust3r_model(view1, view2) + print("Forward pass completed successfully!") + + # Automatically test the results against the reference result from vanilla dust3r code if they exist + reference_output_path = MODEL_TO_VERIFICATION_PATH[model_name]["head_output"] + if os.path.exists(reference_output_path): + reference_output_data = np.load(reference_output_path) + + # Check against the reference output + check_dict = { + "head1_pts3d": ( + res1["pts3d"].detach().cpu().numpy(), + reference_output_data["head1_pts3d"], + ), + "head2_pts3d": ( + res2["pts3d_in_other_view"].detach().cpu().numpy(), + reference_output_data["head2_pts3d"], + ), + "head1_conf": ( + res1["conf"].detach().squeeze(-1).cpu().numpy(), + reference_output_data["head1_conf"], + ), + "head2_conf": ( + res2["conf"].detach().squeeze(-1).cpu().numpy(), + reference_output_data["head2_conf"], + ), + } + + compute_abs_and_rel_error = lambda x, y: (np.abs(x - y).max(), np.linalg.norm(x - y) / np.linalg.norm(x)) + + print(f"===== Checking for {model_name} model =====") + for key, (output, reference) in check_dict.items(): + abs_error, rel_error = compute_abs_and_rel_error(output, reference) + print(f"{key} abs_error: {abs_error}, rel_error: {rel_error}") + + assert abs_error < 1e-2 and rel_error < 1e-3, f"Error in {key} output" + + points1 = res1["pts3d"][0].detach().cpu().numpy() + points2 = res2["pts3d_in_other_view"][0].detach().cpu().numpy() + conf_mask1 = res1["conf"][0].squeeze(-1).detach().cpu().numpy() > 3.0 + conf_mask2 = res2["conf"][0].squeeze(-1).detach().cpu().numpy() > 3.0 + + if args.viz: + rr.log(f"{model_name}", rr.ViewCoordinates.RDF, static=True) + filtered_pts3d1 = points1[conf_mask1] + filtered_pts3d1_colors = np.array(img0)[..., :3][conf_mask1] / 255 + filtered_pts3d2 = points2[conf_mask2] + filtered_pts3d2_colors = np.array(img1)[..., :3][conf_mask2] / 255 + rr.log( + f"{model_name}/view1", + rr.Points3D( + positions=filtered_pts3d1.reshape(-1, 3), + colors=filtered_pts3d1_colors.reshape(-1, 3), + ), + ) + rr.log( + f"{model_name}/view2", + rr.Points3D( + positions=filtered_pts3d2.reshape(-1, 3), + colors=filtered_pts3d2_colors.reshape(-1, 3), + ), + ) + print( + "Visualizations logged to Rerun: rerun+http://127.0.0.1:/proxy." + "For example, to spawn viewer: rerun --connect rerun+http://127.0.0.1:/proxy" + "Replace with the actual port." + ) diff --git a/UniCeption/examples/models/dust3r/profile_dust3r.py b/UniCeption/examples/models/dust3r/profile_dust3r.py new file mode 100644 index 0000000000000000000000000000000000000000..0b3eb09f75e135d5a26d55a7c042bdce5ec75552 --- /dev/null +++ b/UniCeption/examples/models/dust3r/profile_dust3r.py @@ -0,0 +1,47 @@ +import torch +from dust3r import get_model_configurations_and_checkpoints + +from uniception.models.factory import DUSt3R +from uniception.utils.profile import benchmark_torch_function + +if __name__ == "__main__": + # Get model configurations and checkpoints + model_configurations, model_to_checkpoint_path = get_model_configurations_and_checkpoints() + + # Test different DUSt3R models using UniCeption modules + for model_name in model_configurations: + dust3r_model = DUSt3R( + name=model_name, + img_size=(512, 512) if "512" in model_name else (224, 224), + patch_embed_cls="PatchEmbedDust3R", + pred_head_type="linear" if "linear" in model_name else "dpt", + pretrained_checkpoint_path=model_to_checkpoint_path[model_name]["ckpt_path"], + ) + print(f"DUSt3R model ({model_name}) initialized successfully!") + + # Initialize device + device = "cuda:0" if torch.cuda.is_available() else "cpu" + dust3r_model.to(device) + print(f"Running on {device}") + + # Generate random input tensors + img_size = (512, 512) if "512" in model_name else (224, 224) + batch_sizes = [1, 2, 4, 8] + + for batch_size in batch_sizes: + # Prepare input views + view1_instances = range(batch_size) + view1_img_tensor = torch.randn(batch_size, 3, *img_size).to(device) + view1 = {"img": view1_img_tensor, "instance": view1_instances, "data_norm_type": "dust3r"} + view2_instances = range(batch_size) + view2_instances = [id + batch_size for id in view2_instances] + view2_img_tensor = torch.randn(batch_size, 3, *img_size).to(device) + view2 = {"img": view2_img_tensor, "instance": view2_instances, "data_norm_type": "dust3r"} + + # Benchmark the forward pass of the model + with torch.no_grad(): + with torch.autocast("cuda", enabled=True): + execution_time = benchmark_torch_function(dust3r_model, view1, view2) + print( + f"\033[92mForward pass for {model_name}, batch size : {batch_size} completed in {execution_time:.3f} milliseconds\033[0m" + ) diff --git a/UniCeption/pyproject.toml b/UniCeption/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..6c296aae8096bf4ed985a7303bafd6db7291ea83 --- /dev/null +++ b/UniCeption/pyproject.toml @@ -0,0 +1,21 @@ +[tool.black] +line-length = 120 +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | cuda + | dist +)/ +''' + +[tool.isort] +profile = "black" +line_length = 120 diff --git a/UniCeption/scripts/check_dependencies.py b/UniCeption/scripts/check_dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ea73ed8a2f1f6fd2691d8bddab044d3dc79360 --- /dev/null +++ b/UniCeption/scripts/check_dependencies.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +""" +Console script to check UniCeption dependencies. +""" + +import sys +from pathlib import Path + +# Add the parent directory to the path to import uniception +sys.path.insert(0, str(Path(__file__).parent.parent)) + + +def check_dependencies(): + """Check if optional dependencies are available.""" + try: + import torch + + print(f"PyTorch version: {torch.__version__}") + if torch.cuda.is_available(): + print(f"CUDA available: {torch.version.cuda}") + else: + print("CUDA not available") + except ImportError: + print("PyTorch not installed") + + try: + import xformers + + print(f"XFormers version: {xformers.__version__}") + except ImportError: + print("XFormers not installed") + + try: + from uniception.models.libs.croco.curope import cuRoPE2D + + print("CroCo RoPE extension available") + except ImportError: + print("CroCo RoPE extension not available") + + +def main(): + """Main entry point for the check dependencies script.""" + print("Checking UniCeption Dependencies...") + print("=" * 40) + check_dependencies() + + +if __name__ == "__main__": + main() diff --git a/UniCeption/scripts/download_checkpoints.py b/UniCeption/scripts/download_checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab974be3e51dbd4ff3fa14687dae9f3be6ca0bc --- /dev/null +++ b/UniCeption/scripts/download_checkpoints.py @@ -0,0 +1,48 @@ +"Download the UniCeption format checkpoints from the AirLab Data Server" + +import argparse +import os + +from minio import Minio +from minio.error import S3Error +from tqdm import tqdm + + +def main(): + parser = argparse.ArgumentParser(description="Download UniCeption format checkpoints from AirLab Data Server") + parser.add_argument( + "--folders", + nargs="+", + default=["encoders", "info_sharing", "prediction_heads", "examples"], + help="List of folders to download (default: all folders). Choices: encoders, info_sharing, prediction_heads, examples", + ) + parser.add_argument("--destination", type=str, default="./", help="Destination folder for downloaded checkpoints") + args = parser.parse_args() + + access_key = "bT79gQYtfhpxFIitlpns" + secret_key = "g7mSvUJ5k2a9mKv9IbhwXmUQjQX52MLwulhW9ONO" + client = Minio("airlab-share-02.andrew.cmu.edu:9000", access_key=access_key, secret_key=secret_key, secure=True) + + bucket_name = "uniception" + + def download_folder(folder_name, bucket_name, client, destination_folder): + folder_name = f"checkpoints/{folder_name}/" + objects = client.list_objects(bucket_name, prefix=folder_name, recursive=True) + for obj in tqdm(objects, desc=f"Downloading {folder_name}"): + destination_file = os.path.join(destination_folder, obj.object_name) + if not os.path.exists(destination_file): + os.makedirs(os.path.dirname(destination_file), exist_ok=True) + try: + client.fget_object(bucket_name, obj.object_name, destination_file) + print(f"Downloaded {obj.object_name} to {destination_file}") + except S3Error as e: + print(f"Error downloading {obj.object_name}: {e}") + else: + print(f"File {destination_file} already exists. Skipping...") + + for folder in args.folders: + download_folder(folder, bucket_name, client, args.destination) + + +if __name__ == "__main__": + main() diff --git a/UniCeption/scripts/install_croco_rope.py b/UniCeption/scripts/install_croco_rope.py new file mode 100644 index 0000000000000000000000000000000000000000..8374763cb977ce37155588454723ecf33d533e21 --- /dev/null +++ b/UniCeption/scripts/install_croco_rope.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +""" +Console script to install CroCo RoPE extension. +""" + +import os +import subprocess +import sys +from pathlib import Path + + +def install_croco_rope(): + """Install CroCo RoPE extension.""" + try: + # Find the project root (where setup.py is located) + script_dir = Path(__file__).parent + project_root = script_dir.parent + curope_path = project_root / "uniception" / "models" / "libs" / "croco" / "curope" + + if curope_path.exists(): + print("Installing CroCo RoPE extension...") + original_cwd = os.getcwd() + try: + os.chdir(curope_path) + subprocess.check_call([sys.executable, "setup.py", "build_ext", "--inplace"]) + print("CroCo RoPE extension installed successfully!") + return True + except subprocess.CalledProcessError as e: + print(f"Warning: Failed to install CroCo RoPE extension: {e}") + print("You can install it later by running:") + print(f"cd {curope_path} && python setup.py build_ext --inplace") + return False + finally: + os.chdir(original_cwd) + else: + print("Warning: CroCo RoPE source code not found.") + print(f"Expected location: {curope_path}") + return False + except Exception as e: + print(f"Warning: Error during CroCo RoPE installation: {e}") + return False + + +def main(): + """Main entry point for the CroCo RoPE installation script.""" + print("UniCeption CroCo RoPE Extension Installer") + print("=" * 45) + + success = install_croco_rope() + + if success: + print("\n✓ CroCo RoPE extension installation completed successfully!") + sys.exit(0) + else: + print("\n⚠ CroCo RoPE extension installation failed or skipped.") + print("This is typically due to missing CUDA development tools.") + print("The extension is optional and UniCeption will work without it.") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/UniCeption/scripts/prepare_offline_install.py b/UniCeption/scripts/prepare_offline_install.py new file mode 100644 index 0000000000000000000000000000000000000000..ff052443554f855428ceded39127a6ae14e57eaf --- /dev/null +++ b/UniCeption/scripts/prepare_offline_install.py @@ -0,0 +1,399 @@ +#!/usr/bin/env python3 +""" +Script to prepare dependencies for offline installation. + +This script downloads all necessary wheel files for offline installation +of UniCeption in environments without internet access. +""" + +import argparse +import os +import subprocess +import sys +from pathlib import Path + + +def download_wheels(output_dir: Path, extras: list = None): + """Download wheel files for offline installation.""" + output_dir.mkdir(parents=True, exist_ok=True) + + # Create temporary requirements files + temp_dir = output_dir / "temp" + temp_dir.mkdir(exist_ok=True) + + try: + # Create requirements files + create_requirements_files(temp_dir, extras) + + # Download base dependencies + base_cmd = [ + sys.executable, + "-m", + "pip", + "download", + "--dest", + str(output_dir), + "-r", + str(temp_dir / "requirements-base.txt"), + ] + + print(f"Downloading base dependencies to {output_dir}...") + subprocess.check_call(base_cmd) + + # Download optional dependencies if requested + if extras: + for extra in extras: + if extra == "all": + # Download all extras + for req_file in ["requirements-xformers.txt", "requirements-dev.txt"]: + if (temp_dir / req_file).exists(): + cmd = [ + sys.executable, + "-m", + "pip", + "download", + "--dest", + str(output_dir), + "-r", + str(temp_dir / req_file), + ] + print( + f"Downloading {req_file.replace('requirements-', '').replace('.txt', '')} dependencies..." + ) + try: + subprocess.check_call(cmd) + except subprocess.CalledProcessError as e: + print(f"Warning: Failed to download {extra} dependencies: {e}") + else: + req_file = temp_dir / f"requirements-{extra}.txt" + if req_file.exists(): + cmd = [sys.executable, "-m", "pip", "download", "--dest", str(output_dir), "-r", str(req_file)] + print(f"Downloading {extra} dependencies...") + try: + subprocess.check_call(cmd) + except subprocess.CalledProcessError as e: + print(f"Warning: Failed to download {extra} dependencies: {e}") + + # Create final offline installation files + create_offline_installation_files(output_dir) + + print("Download completed successfully!") + + except subprocess.CalledProcessError as e: + print(f"Error downloading wheels: {e}") + sys.exit(1) + finally: + # Clean up temporary files + import shutil + + if temp_dir.exists(): + shutil.rmtree(temp_dir) + + +def create_requirements_files(temp_dir: Path, extras: list = None): + """Create temporary requirements files for downloading.""" + + # Base requirements (including PyTorch) + base_reqs = [ + "numpy", + "torch", + "torchvision", + "torchaudio", + "timm", + "black", + "jaxtyping", + "matplotlib", + "Pillow", + "scikit-learn", + "einops", + "rerun-sdk", + "pre-commit", + "minio", + "pytest", + "isort", + ] + + # Write base requirements + with open(temp_dir / "requirements-base.txt", "w") as f: + for req in base_reqs: + f.write(f"{req}\n") + + # XFormers requirements + with open(temp_dir / "requirements-xformers.txt", "w") as f: + f.write("xformers\n") + + # Dev requirements + dev_reqs = [ + "black", + "isort", + "pre-commit", + "pytest", + ] + + with open(temp_dir / "requirements-dev.txt", "w") as f: + for req in dev_reqs: + f.write(f"{req}\n") + + +def create_offline_installation_files(output_dir: Path): + """Create requirements files and installation script for offline use.""" + + # Base requirements (including PyTorch) + base_reqs = [ + "numpy", + "torch", + "torchvision", + "torchaudio", + "timm", + "black", + "jaxtyping", + "matplotlib", + "Pillow", + "scikit-learn", + "einops", + "rerun-sdk", + "pre-commit", + "minio", + "pytest", + "isort", + ] + + # Write base requirements + with open(output_dir / "requirements-base.txt", "w") as f: + for req in base_reqs: + f.write(f"{req}\n") + + # XFormers requirements + with open(output_dir / "requirements-xformers.txt", "w") as f: + f.write("xformers\n") + + # Dev requirements + dev_reqs = [ + "black", + "isort", + "pre-commit", + "pytest", + ] + + with open(output_dir / "requirements-dev.txt", "w") as f: + for req in dev_reqs: + f.write(f"{req}\n") + + # Create installation script + install_script = output_dir / "install_offline.sh" + with open(install_script, "w") as f: + f.write( + """#!/bin/bash +# Offline installation script for UniCeption + +set -e + +echo "Installing UniCeption dependencies offline..." + +# Check if we're in the right directory +if [ ! -f "requirements-base.txt" ]; then + echo "Error: requirements-base.txt not found. Please run this script from the offline_wheels directory." + exit 1 +fi + +# Install base dependencies (includes PyTorch) +echo "Installing base dependencies (including PyTorch)..." +pip install --no-index --find-links . -r requirements-base.txt + +# Install XFormers if requested +if [ "$INSTALL_XFORMERS" = "true" ]; then + echo "Installing XFormers..." + pip install --no-index --find-links . -r requirements-xformers.txt +fi + +# Install dev dependencies if requested +if [ "$INSTALL_DEV" = "true" ]; then + echo "Installing development dependencies..." + pip install --no-index --find-links . -r requirements-dev.txt +fi + +# Navigate back to UniCeption directory and install the package +echo "Installing UniCeption package..." +cd .. +pip install --no-deps -e . + +# Install CroCo RoPE extension if requested +if [ "$INSTALL_CROCO_ROPE" = "true" ]; then + echo "Installing CroCo RoPE extension..." + cd uniception/models/libs/croco/curope + python setup.py build_ext --inplace + cd - +fi + +echo "Offline installation completed successfully!" +echo "" +echo "To verify installation, run:" +echo "python setup.py check_deps" +""" + ) + + # Make script executable + install_script.chmod(0o755) + + # Create Windows batch script as well + install_bat = output_dir / "install_offline.bat" + with open(install_bat, "w") as f: + f.write( + """@echo off +REM Offline installation script for UniCeption (Windows) + +echo Installing UniCeption dependencies offline... + +REM Check if we're in the right directory +if not exist "requirements-base.txt" ( + echo Error: requirements-base.txt not found. Please run this script from the offline_wheels directory. + exit /b 1 +) + +REM Install base dependencies (includes PyTorch) +echo Installing base dependencies (including PyTorch)... +pip install --no-index --find-links . -r requirements-base.txt + +REM Install XFormers if requested +if "%INSTALL_XFORMERS%"=="true" ( + echo Installing XFormers... + pip install --no-index --find-links . -r requirements-xformers.txt +) + +REM Install dev dependencies if requested +if "%INSTALL_DEV%"=="true" ( + echo Installing development dependencies... + pip install --no-index --find-links . -r requirements-dev.txt +) + +REM Navigate back to UniCeption directory and install the package +echo Installing UniCeption package... +cd .. +pip install --no-deps -e . + +REM Install CroCo RoPE extension if requested +if "%INSTALL_CROCO_ROPE%"=="true" ( + echo Installing CroCo RoPE extension... + cd uniception\\models\\libs\\croco\\curope + python setup.py build_ext --inplace + cd ..\\..\\..\\..\\.. +) + +echo Offline installation completed successfully! +echo. +echo To verify installation, run: +echo python setup.py check_deps +""" + ) + + # Create a README for offline installation + readme_file = output_dir / "README_OFFLINE.md" + with open(readme_file, "w") as f: + f.write( + """# UniCeption Offline Installation + +This directory contains all the necessary files for installing UniCeption without internet access. + +## Files Included + +- `requirements-base.txt` - Core dependencies (including PyTorch) +- `requirements-xformers.txt` - XFormers dependency +- `requirements-dev.txt` - Development dependencies +- `install_offline.sh` - Installation script for Unix/Linux/macOS +- `install_offline.bat` - Installation script for Windows +- `*.whl` files - Downloaded wheel packages + +## Installation Instructions + +### Unix/Linux/macOS + +```bash +# Set environment variables for optional components +export INSTALL_XFORMERS=true # Install XFormers +export INSTALL_DEV=true # Install development tools +export INSTALL_CROCO_ROPE=true # Compile CroCo RoPE extension + +# Run the installation script +./install_offline.sh +``` + +### Windows + +```cmd +REM Set environment variables for optional components +set INSTALL_XFORMERS=true +set INSTALL_DEV=true +set INSTALL_CROCO_ROPE=true + +REM Run the installation script +install_offline.bat +``` + +## Manual Installation + +If the scripts don't work, you can install manually: + +```bash +# Install base dependencies (includes PyTorch) +pip install --no-index --find-links . -r requirements-base.txt + +# Install optional dependencies as needed +pip install --no-index --find-links . -r requirements-xformers.txt +pip install --no-index --find-links . -r requirements-dev.txt + +# Install UniCeption package (from parent directory) +cd .. +pip install --no-deps -e . + +# Compile CroCo RoPE extension (optional) +cd uniception/models/libs/croco/curope +python setup.py build_ext --inplace +``` + +## Verification + +After installation, verify everything is working: + +```bash +cd .. # Go back to UniCeption root directory +python setup.py check_deps +``` + +## Notes + +- PyTorch, TorchVision, and TorchAudio are now included in the base requirements +- XFormers is optional and only needed for certain performance optimizations +- CroCo RoPE extension compilation requires a CUDA-enabled environment +""" + ) + + print(f"Created offline installation files in {output_dir}") + print("Files created:") + print(" - requirements-base.txt (includes PyTorch)") + print(" - requirements-xformers.txt") + print(" - requirements-dev.txt") + print(" - install_offline.sh (Unix/Linux/macOS)") + print(" - install_offline.bat (Windows)") + print(" - README_OFFLINE.md") + + +def create_offline_requirements(output_dir: Path): + """Create requirements files for offline installation.""" + # This function is now replaced by create_offline_installation_files + pass + + +def main(): + parser = argparse.ArgumentParser(description="Prepare UniCeption for offline installation") + parser.add_argument( + "--output-dir", type=Path, default="offline_wheels", help="Directory to store downloaded wheels" + ) + parser.add_argument("--extras", nargs="+", choices=["xformers", "dev", "all"], help="Extra dependencies to include") + + args = parser.parse_args() + + download_wheels(args.output_dir, args.extras) + + +if __name__ == "__main__": + main() diff --git a/UniCeption/scripts/validate_installation.py b/UniCeption/scripts/validate_installation.py new file mode 100644 index 0000000000000000000000000000000000000000..db2bcac581a88cc168d687ed3835786404a8c555 --- /dev/null +++ b/UniCeption/scripts/validate_installation.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +""" +Validation script for UniCeption installation. + +This script validates that all components of UniCeption are correctly installed +and provides helpful diagnostics. +""" + +import importlib +import sys +from pathlib import Path + + +def check_package_installation(): + """Check if UniCeption package is properly installed.""" + try: + import uniception + + print("✓ UniCeption package is installed") + + # Check if we can import core modules + try: + from uniception.models.encoders import UniCeptionViTEncoderBase + + print("✓ Core encoder modules are available") + except ImportError as e: + print(f"✗ Failed to import core encoder modules: {e}") + + return True + except ImportError as e: + print(f"✗ UniCeption package not found: {e}") + return False + + +def check_dependencies(): + """Check optional dependencies.""" + dependencies = { + "torch": "PyTorch", + "torchvision": "TorchVision", + "torchaudio": "TorchAudio", + "xformers": "XFormers", + "timm": "Timm (PyTorch Image Models)", + "einops": "Einops", + "matplotlib": "Matplotlib", + "numpy": "NumPy", + "PIL": "Pillow", + } + + available = [] + missing = [] + + for module, name in dependencies.items(): + try: + mod = importlib.import_module(module) + version = getattr(mod, "__version__", "unknown") + available.append(f"✓ {name}: {version}") + except ImportError: + missing.append(f"✗ {name}: not installed") + + print("\nDependency Status:") + for dep in available: + print(f" {dep}") + + if missing: + print("\nMissing Dependencies:") + for dep in missing: + print(f" {dep}") + + return len(missing) == 0 + + +def check_cuda_support(): + """Check CUDA support.""" + try: + import torch + + if torch.cuda.is_available(): + print(f"\n✓ CUDA is available") + print(f" CUDA version: {torch.version.cuda}") + print(f" Available devices: {torch.cuda.device_count()}") + for i in range(torch.cuda.device_count()): + print(f" Device {i}: {torch.cuda.get_device_name(i)}") + else: + print(f"\n⚠ CUDA is not available (CPU-only mode)") + except ImportError: + print(f"\n⚠ PyTorch not installed - cannot check CUDA support") + + +def check_croco_rope(): + """Check CroCo RoPE extension.""" + try: + from uniception.models.libs.croco.curope import cuRoPE2D + + print("\n✓ CroCo RoPE extension is available") + return True + except ImportError: + print("\n✗ CroCo RoPE extension not available") + print(" To install: cd uniception/models/libs/croco/curope && python setup.py build_ext --inplace") + return False + + +def check_model_availability(): + """Check if models can be loaded.""" + try: + # Try to check if encoder modules are available + from uniception.models import encoders + + print(f"\n✓ Encoder module is available") + + # Try to run the encoder list command + try: + import subprocess + + result = subprocess.run( + [sys.executable, "-m", "uniception.models.encoders.list"], capture_output=True, text=True, timeout=10 + ) + + if result.returncode == 0: + lines = result.stdout.strip().split("\n") + encoder_count = len([line for line in lines if line.strip() and not line.startswith("Available")]) + print(f"✓ Available encoders: {encoder_count}") + return True + else: + print(f"⚠ Encoder listing returned non-zero exit code: {result.returncode}") + return False + + except subprocess.TimeoutExpired: + print(f"⚠ Encoder listing timed out") + return False + except Exception as e: + print(f"⚠ Could not run encoder listing: {e}") + return False + + except Exception as e: + print(f"\n✗ Failed to access encoder modules: {e}") + return False + + +def check_file_structure(): + """Check if the project file structure is correct.""" + base_path = Path(__file__).parent.parent + required_dirs = [ + "uniception", + "uniception/models", + "uniception/models/encoders", + "uniception/models/info_sharing", + "uniception/models/prediction_heads", + "scripts", + "tests", + ] + + missing_dirs = [] + for dir_path in required_dirs: + full_path = base_path / dir_path + if not full_path.exists(): + missing_dirs.append(dir_path) + + if missing_dirs: + print(f"\n✗ Missing directories:") + for dir_path in missing_dirs: + print(f" - {dir_path}") + return False + else: + print(f"\n✓ Project structure is correct") + return True + + +def main(): + """Run all validation checks.""" + print("UniCeption Installation Validation") + print("=" * 40) + + checks = [ + ("Package Installation", check_package_installation), + ("Dependencies", check_dependencies), + ("CUDA Support", check_cuda_support), + ("CroCo RoPE Extension", check_croco_rope), + ("Model Availability", check_model_availability), + ("File Structure", check_file_structure), + ] + + results = [] + for name, check_func in checks: + print(f"\nChecking {name}...") + try: + result = check_func() + results.append((name, result)) + except Exception as e: + print(f"✗ Error during {name} check: {e}") + results.append((name, False)) + + # Summary + print("\n" + "=" * 40) + print("Validation Summary:") + passed = 0 + for name, result in results: + status = "✓ PASS" if result else "✗ FAIL" + print(f" {name}: {status}") + if result: + passed += 1 + + print(f"\nOverall: {passed}/{len(results)} checks passed") + + if passed == len(results): + print("🎉 All checks passed! UniCeption is ready to use.") + return 0 + else: + print("⚠ Some checks failed. Please review the issues above.") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/UniCeption/setup.py b/UniCeption/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..d7f576c4d27fabedf4e2db3921f5f9a5699c0018 --- /dev/null +++ b/UniCeption/setup.py @@ -0,0 +1,188 @@ +"""Package installation setup.""" + +import os +import subprocess +import sys +from pathlib import Path + +from setuptools import find_packages, setup +from setuptools.command.develop import develop +from setuptools.command.install import install + + +def install_croco_rope(): + """Install CroCo RoPE extension.""" + try: + curope_path = Path(__file__).parent / "uniception" / "models" / "libs" / "croco" / "curope" + if curope_path.exists(): + print("Installing CroCo RoPE extension...") + original_cwd = os.getcwd() + try: + os.chdir(curope_path) + subprocess.check_call([sys.executable, "setup.py", "build_ext", "--inplace"]) + print("CroCo RoPE extension installed successfully!") + return True + except subprocess.CalledProcessError as e: + print(f"Warning: Failed to install CroCo RoPE extension: {e}") + print("You can install it later by running:") + print(f"cd {curope_path} && python setup.py build_ext --inplace") + return False + finally: + os.chdir(original_cwd) + else: + print("Warning: CroCo RoPE source code not found.") + return False + except Exception as e: + print(f"Warning: Error during CroCo RoPE installation: {e}") + return False + + +def check_dependencies(): + """Check if optional dependencies are available.""" + try: + import torch + + print(f"PyTorch version: {torch.__version__}") + if torch.cuda.is_available(): + print(f"CUDA available: {torch.version.cuda}") + else: + print("CUDA not available") + except ImportError: + print("PyTorch not installed") + + try: + import xformers + + print(f"XFormers version: {xformers.__version__}") + except ImportError: + print("XFormers not installed") + + try: + from uniception.models.libs.croco.curope import cuRoPE2D + + print("CroCo RoPE extension available") + except ImportError: + print("CroCo RoPE extension not available") + + +class CustomDevelopCommand(develop): + """Custom development installation command.""" + + def run(self): + develop.run(self) + # Only install CroCo RoPE if explicitly requested + if os.getenv("INSTALL_CROCO_ROPE", "false").lower() in ("true", "1", "yes"): + install_croco_rope() + + +class CustomInstallCommand(install): + """Custom installation command.""" + + def run(self): + install.run(self) + # Only install CroCo RoPE if explicitly requested + if os.getenv("INSTALL_CROCO_ROPE", "false").lower() in ("true", "1", "yes"): + install_croco_rope() + + +class CrocoInstallCommand(install): + """Install command that includes CroCo RoPE extension.""" + + def run(self): + install.run(self) + install_croco_rope() + + +class CheckDependenciesCommand(install): + """Command to check available dependencies.""" + + def run(self): + check_dependencies() + + +# Core dependencies (including PyTorch which is essential for UniCeption) +install_requires = [ + "numpy", + "torch", + "torchvision", + "torchaudio", + "timm", + "black", + "jaxtyping", + "matplotlib", + "Pillow", + "scikit-learn", + "einops", + "rerun-sdk", + "pre-commit", + "minio", + "pytest", + "isort", +] + +# Optional dependencies +extras_require = { + "xformers": [ + "xformers", # Will be installed from PyTorch wheel index + ], + "dev": [ + "black", + "isort", + "pre-commit", + "pytest", + ], + "minimal": [ + # Minimal dependencies for basic functionality without PyTorch + "numpy", + "matplotlib", + "Pillow", + "scikit-learn", + "einops", + ], +} + +# All optional dependencies combined (excluding minimal since it's subset of install_requires) +extras_require["all"] = list(set(extras_require["xformers"] + extras_require["dev"])) + +setup( + name="uniception", + version="0.1.0", + description="Generalizable Perception Stack for 3D, 4D, spatial AI and scene understanding", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + author="AirLab", + license="BSD Clause-3", + packages=find_packages(), + package_dir={"": "."}, + include_package_data=True, + python_requires=">=3.10", + install_requires=install_requires, + extras_require=extras_require, + cmdclass={ + "develop": CustomDevelopCommand, + "install": CustomInstallCommand, + "install_croco": CrocoInstallCommand, + "check_deps": CheckDependenciesCommand, + }, + entry_points={ + "console_scripts": [ + "uniception-download-checkpoints=scripts.download_checkpoints:main", + "uniception-validate=scripts.validate_installation:main", + "uniception-prepare-offline=scripts.prepare_offline_install:main", + "uniception-check-deps=scripts.check_dependencies:main", + "uniception-install-croco=scripts.install_croco_rope:main", + ], + }, + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + keywords="computer-vision, 3d-vision, spatial-ai, perception, deep-learning, pytorch", +) diff --git a/UniCeption/tests/models/encoders/conftest.py b/UniCeption/tests/models/encoders/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..a421574242826741901ef3b226b2e984a377511f --- /dev/null +++ b/UniCeption/tests/models/encoders/conftest.py @@ -0,0 +1,26 @@ +import pytest + + +def pytest_addoption(parser): + # Add custom command-line options + parser.addoption("--encoder-name", action="store", default=None, help="Specify the encoder name to test") + + parser.addoption( + "--device", + action="store", + default="cpu", + choices=["cpu", "gpu"], + help="Specify the device to use (default: cpu)", + ) + + +@pytest.fixture +def encoder_name(request): + # Access the value of the custom option for encoder name + return request.config.getoption("--encoder-name") + + +@pytest.fixture +def device(request): + # Access the value of the custom option for device + return request.config.getoption("--device") diff --git a/UniCeption/tests/models/encoders/test_encoders.py b/UniCeption/tests/models/encoders/test_encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd783d80730d1fe376b93f37c90fb2fcf915ca6 --- /dev/null +++ b/UniCeption/tests/models/encoders/test_encoders.py @@ -0,0 +1,204 @@ +import os +import random +from functools import lru_cache +from typing import Tuple + +import numpy as np +import pytest +import requests +import torch +from PIL import Image + +from uniception.models.encoders import * +from uniception.models.encoders.image_normalizations import * + + +@pytest.fixture(scope="module") +def norm_types(): + return IMAGE_NORMALIZATION_DICT.keys() + + +@pytest.fixture(scope="module") +def encoders(): + return [ + "croco", + "dust3r_224", + "dust3r_512", + "dust3r_512_dpt", + "mast3r_512", + "dinov2_base", + "dinov2_large", + "dinov2_large_reg", + "dinov2_large_dav2", + "dinov2_giant", + "dinov2_giant_reg", + "radio_v2.5-b", + "radio_v2.5-l", + "e-radio_v2", + "naradio_v2.5-b", + "naradio_v2.5-l", + "cosmosx8", + "patch_embedder", + ] + + +@pytest.fixture(scope="module") +def encoder_configs(encoders): + # Adjust the number of configs to match the number of encoders + return [{}] * len(encoders) + + +@pytest.fixture +def device(request): + # Access the value of the custom option for device + device_str = request.config.getoption("--device") + if device_str == "gpu" and torch.cuda.is_available(): + device = torch.device("cuda") # Use the default CUDA device + else: + device = torch.device("cpu") + print(f"Using device: {device.type.upper()}") + return device + + +@pytest.fixture +def example_input(device): + @lru_cache(maxsize=3) + def _get_example_input( + image_size: Tuple[int, int], + image_norm_type: str = "dummy", + img_selection: int = 1, + return_viz_img: bool = False, + ) -> torch.Tensor: + url = f"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau{img_selection}.png" + image = Image.open(requests.get(url, stream=True).raw) + image = image.resize(image_size) + image = image.convert("RGB") + + img = torch.from_numpy(np.array(image)) + viz_img = img.clone() + + # Normalize the image + image_normalization = IMAGE_NORMALIZATION_DICT[image_norm_type] + img_mean = image_normalization.mean + img_std = image_normalization.std + img = (img.float() / 255.0 - img_mean) / img_std + + # Convert to BCHW format + img = img.permute(2, 0, 1).unsqueeze(0).to(device) + + if return_viz_img: + return img, viz_img + else: + return img + + return _get_example_input + + +def inference_encoder(encoder, encoder_input): + # Encoder expects a ViTEncoderInput object + return encoder(encoder_input).features + + +def test_make_dummy_encoder(device): + print(f"Testing Init of Dummy Encoder on {device.type.upper()}") + encoder = _make_encoder_test("dummy").to(device) + + # Check if the encoder has parameters + try: + params = list(encoder.parameters()) + if not params: + print("Warning: The encoder has no parameters.") + else: + # Verify if the model is on the right device + assert params[0].is_cuda == (device.type == "cuda") + + except Exception as e: + print(f"Error: {e}") + assert False # Fail the test if any error occurs + + assert encoder is not None + + +def test_all_encoder_basics(encoders, encoder_configs, norm_types, example_input, encoder_name, device): + if encoder_name: + encoders = [encoder_name] # Override default encoders with the one specified + + for encoder_name, encoder_config in zip(encoders, encoder_configs): + print(f"Testing encoder: {encoder_name} on {device.type.upper()}") + + encoder = _make_encoder_test(encoder_name, **encoder_config).to(device) + _check_baseclass_attribute(encoder, norm_types) + _check_norm_check_function(encoder) + + if isinstance(encoder, UniCeptionViTEncoderBase): + _check_vit_encoder_attribute(encoder) + _test_vit_encoder_patch_size(encoder, example_input) + + +def _check_baseclass_attribute(encoder, norm_types): + assert hasattr(encoder, "name") + assert hasattr(encoder, "size") + assert hasattr(encoder, "data_norm_type") + + assert isinstance(encoder.name, str) + assert isinstance(encoder.size, str) or encoder.size is None + assert isinstance(encoder.data_norm_type, str) + + # Check if the data_norm_type is in the list of normalization types + assert encoder.data_norm_type in norm_types + + +def _check_norm_check_function(encoder): + assert hasattr(encoder, "_check_data_normalization_type") + + encoder_notm_type = encoder.data_norm_type + + try: + encoder._check_data_normalization_type(encoder_notm_type) + except AssertionError: + assert False + + try: + encoder._check_data_normalization_type("some_nonexistent_norm_type") + assert False + except AssertionError: + pass + + +def _check_vit_encoder_attribute(encoder): + assert hasattr(encoder, "patch_size") + assert isinstance(encoder.patch_size, int) + assert encoder.patch_size > 0 + + +def _test_vit_encoder_patch_size(encoder, example_input): + print(f"Testing {encoder.name} inference") + image_size = (14 * encoder.patch_size, 14 * encoder.patch_size) + + img = example_input(image_size, encoder.data_norm_type) + # Create an instance of ViTEncoderInput with correct attributes + encoder_input = ViTEncoderInput( + data_norm_type=encoder.data_norm_type, + image=img, + ) + + encoder_output = inference_encoder(encoder, encoder_input) + + assert isinstance(encoder_output, torch.Tensor) + assert encoder_output.shape[2] == 14 + assert encoder_output.shape[3] == 14 + + +@pytest.fixture(scope="session", autouse=True) +def seed_everything(): + seed = 42 + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + print(f"Seed set to: {seed} (type: {type(seed)})") + + # Turn XFormers off for testing on CPU + os.environ["XFORMERS_DISABLED"] = "1" diff --git a/UniCeption/tests/models/encoders/viz_image_encoders.py b/UniCeption/tests/models/encoders/viz_image_encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..6213bf2d0ca0cb6127171e22fcba29256c31e9c1 --- /dev/null +++ b/UniCeption/tests/models/encoders/viz_image_encoders.py @@ -0,0 +1,294 @@ +""" +PCA Visualization of UniCeption Image Encoders +""" + +import os +import random +from functools import lru_cache +from typing import Tuple + +import numpy as np +import requests +import torch +import torch.nn.functional as F +from matplotlib import pyplot as plt +from PIL import Image +from sklearn.decomposition import PCA + +from uniception.models.encoders import * +from uniception.models.encoders.image_normalizations import * + + +class TestEncoders: + def __init__(self, pca_save_folder, *args, **kwargs): + super(TestEncoders, self).__init__(*args, **kwargs) + + self.pca_save_folder = pca_save_folder + + self.norm_types = IMAGE_NORMALIZATION_DICT.keys() + + self.encoders = [ + "croco", + "dust3r_224", + "dust3r_512", + "dust3r_512_dpt", + "mast3r_512", + "dinov2_large", + "dinov2_large_reg", + "dinov2_large_dav2", + "dinov2_giant", + "dinov2_giant_reg", + "radio_v2.5-b", + "radio_v2.5-l", + "e-radio_v2", + ] + + self.encoder_configs = [{}] * len(self.encoders) + + def inference_encoder(self, encoder, input): + return encoder(input) + + def visualize_all_encoders(self): + for encoder, encoder_config in zip(self.encoders, self.encoder_configs): + encoder = _make_encoder_test(encoder, **encoder_config) + self._visualize_encoder_features_consistency(encoder, (224, 224)) + + def _visualize_encoder_features(self, encoder, image_size: Tuple[int, int]): + img, viz_img = self._get_example_input(image_size, encoder.data_norm_type, return_viz_img=True) + # input and output of the encoder + encoder_input: ViTEncoderInput = ViTEncoderInput( + data_norm_type=encoder.data_norm_type, + image=img, + ) + + encoder_output = self.inference_encoder(encoder, encoder_input) + encoder_output = encoder_output.features + + self.assertTrue(isinstance(encoder_output, torch.Tensor)) + + # visualize the features + pca_viz = get_pca_map(encoder_output.permute(0, 2, 3, 1), image_size, return_pca_stats=False) + + # plot the input image and the PCA features + fig, axs = plt.subplots(1, 2, figsize=(12, 6)) + axs[0].imshow(viz_img) + axs[0].set_title("Input Image") + axs[0].axis("off") + axs[1].imshow(pca_viz) + axs[1].set_title(f"PCA Features of {encoder.name}") + axs[1].axis("off") + plt.savefig(f"{self.pca_save_folder}/pca_{encoder.name}.png", bbox_inches="tight") + plt.close() + + def _visualize_encoder_features_consistency(self, encoder, image_size: Tuple[int, int]): + img0, viz_img0 = self._get_example_input( + image_size, encoder.data_norm_type, img_selection=1, return_viz_img=True + ) + img1, viz_img1 = self._get_example_input( + image_size, encoder.data_norm_type, img_selection=2, return_viz_img=True + ) + # input and output of the encoder + encoder_input0: ViTEncoderInput = ViTEncoderInput( + data_norm_type=encoder.data_norm_type, + image=img0, + ) + + encoder_input1: ViTEncoderInput = ViTEncoderInput( + data_norm_type=encoder.data_norm_type, + image=img1, + ) + + encoder_output0 = self.inference_encoder(encoder, encoder_input0) + encoder_output0 = encoder_output0.features + + encoder_output1 = self.inference_encoder(encoder, encoder_input1) + encoder_output1 = encoder_output1.features + + # get a common PCA codec + cat_feats = torch.cat([encoder_output0, encoder_output1], dim=3) + + pca_viz = get_pca_map(cat_feats.permute(0, 2, 3, 1), (image_size[0], image_size[1] * 2), return_pca_stats=True) + + # concatenate the input images along the width dimension + cat_imgs = torch.cat([viz_img0, viz_img1], dim=1) + + # plot the input image and the PCA features + fig, axs = plt.subplots(1, 2, figsize=(12, 6)) + axs[0].imshow(cat_imgs) + axs[0].set_title("Input Images") + axs[0].axis("off") + axs[1].imshow(pca_viz[0]) + axs[1].set_title(f"PCA Features of {encoder.name}") + axs[1].axis("off") + plt.savefig(f"{self.pca_save_folder}/multi_pca_{encoder.name}.png", bbox_inches="tight") + plt.close() + + @lru_cache(maxsize=3) + def _get_example_input( + self, + image_size: Tuple[int, int], + image_norm_type: str = "dummy", + img_selection: int = 1, + return_viz_img: bool = False, + ) -> torch.Tensor: + url = f"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau{img_selection}.png" + image = Image.open(requests.get(url, stream=True).raw) + image = image.resize(image_size) + image = image.convert("RGB") + + img = torch.from_numpy(np.array(image)) + viz_img = img.clone() + + # Normalize the images + image_normalization = IMAGE_NORMALIZATION_DICT[image_norm_type] + + img_mean, img_std = image_normalization.mean, image_normalization.std + + img = (img.float() / 255.0 - img_mean) / img_std + + # convert to BCHW format + img = img.permute(2, 0, 1).unsqueeze(0) + + if return_viz_img: + return img, viz_img + else: + return img + + +def render_pca_as_rgb(features): + """ + Perform PCA on the given feature tensor and render the first 3 principal components as RGB. + + Args: + features (torch.Tensor): Feature tensor of shape (B, C, H, W). + + Returns: + np.ndarray: RGB image of shape (H, W, 3). + """ + # Ensure input is a 4D tensor + assert features.dim() == 4, "Input tensor must be 4D (B, C, H, W)" + + B, C, H, W = features.shape + + # Reshape the tensor to (B * H * W, C) + reshaped_features = features.permute(0, 2, 3, 1).contiguous().view(-1, C).cpu().numpy() + + # Perform PCA + pca = PCA(n_components=3) + principal_components = pca.fit_transform(reshaped_features) + + # Rescale the principal components to [0, 1] + principal_components = (principal_components - principal_components.min(axis=0)) / ( + principal_components.max(axis=0) - principal_components.min(axis=0) + ) + + # Reshape the principal components to (B, H, W, 3) + principal_components = principal_components.reshape(B, H, W, 3) + + # Convert the principal components to RGB image (take the first batch) + rgb_image = principal_components[0] + + return rgb_image + + +def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False): + # features: (N, C) + # m: a hyperparam controlling how many std dev outside for outliers + assert len(features.shape) == 2, "features should be (N, C)" + reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2] + colors = features @ reduction_mat + if remove_first_component: + colors_min = colors.min(dim=0).values + colors_max = colors.max(dim=0).values + tmp_colors = (colors - colors_min) / (colors_max - colors_min) + fg_mask = tmp_colors[..., 0] < 0.2 + reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2] + colors = features @ reduction_mat + else: + fg_mask = torch.ones_like(colors[:, 0]).bool() + d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values) + mdev = torch.median(d, dim=0).values + s = d / mdev + try: + rins = colors[fg_mask][s[:, 0] < m, 0] + gins = colors[fg_mask][s[:, 1] < m, 1] + bins = colors[fg_mask][s[:, 2] < m, 2] + rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()]) + rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()]) + except: + rins = colors + gins = colors + bins = colors + rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()]) + rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()]) + + return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat) + + +def get_pca_map( + feature_map: torch.Tensor, + img_size, + interpolation="bicubic", + return_pca_stats=False, + pca_stats=None, +): + """ + feature_map: (1, h, w, C) is the feature map of a single image. + """ + if feature_map.shape[0] != 1: + # make it (1, h, w, C) + feature_map = feature_map[None] + if pca_stats is None: + reduct_mat, color_min, color_max = get_robust_pca(feature_map.reshape(-1, feature_map.shape[-1])) + else: + reduct_mat, color_min, color_max = pca_stats + pca_color = feature_map @ reduct_mat + pca_color = (pca_color - color_min) / (color_max - color_min) + pca_color = pca_color.clamp(0, 1) + pca_color = F.interpolate( + pca_color.permute(0, 3, 1, 2), + size=img_size, + mode=interpolation, + ).permute(0, 2, 3, 1) + pca_color = pca_color.detach().cpu().numpy().squeeze(0) + if return_pca_stats: + return pca_color, (reduct_mat, color_min, color_max) + return pca_color + + +def seed_everything(seed=42): + """ + Set the `seed` value for torch and numpy seeds. Also turns on + deterministic execution for cudnn. + + Parameters: + - seed: A hashable seed value + """ + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + print(f"Seed set to: {seed} (type: {type(seed)})") + + +if __name__ == "__main__": + # Turn XFormers off for testing on CPU + os.environ["XFORMERS_DISABLED"] = "1" + + # Seed everything for consistent testing + seed_everything() + + # Create local directory for storing the PCA images + current_file_path = os.path.abspath(__file__) + relative_pca_image_folder = os.path.join(os.path.dirname(current_file_path), "../../../local/encoders/pca_images") + os.makedirs(relative_pca_image_folder, exist_ok=True) + + # Initialize the test class + test = TestEncoders(pca_save_folder=relative_pca_image_folder) + + # Visualize the PCA of all encoders + test.visualize_all_encoders() + + print(f"The PCA visualizations of all encoders are saved successfully to {relative_pca_image_folder}!") diff --git a/UniCeption/tests/models/info_sharing/viz_mulit_view_cross_attn_transformers.py b/UniCeption/tests/models/info_sharing/viz_mulit_view_cross_attn_transformers.py new file mode 100644 index 0000000000000000000000000000000000000000..7f1d52aed21c2c3633d1b89afc33a9a4f249dddc --- /dev/null +++ b/UniCeption/tests/models/info_sharing/viz_mulit_view_cross_attn_transformers.py @@ -0,0 +1,337 @@ +""" +PCA Visualization of UniCeption Image Encoders + Multi-View Cross Attention Transformers +""" + +import os +import random +from functools import lru_cache +from typing import Tuple + +import numpy as np +import requests +import torch +import torch.nn.functional as F +from matplotlib import pyplot as plt +from PIL import Image +from sklearn.decomposition import PCA + +from uniception.models.encoders import * +from uniception.models.encoders.image_normalizations import * +from uniception.models.info_sharing.base import MultiViewTransformerInput +from uniception.models.info_sharing.cross_attention_transformer import MultiViewCrossAttentionTransformerIFR +from uniception.models.libs.croco.pos_embed import RoPE2D, get_2d_sincos_pos_embed + + +def _make_mv_cross_attention_transformer_test(model_str: str, **kwargs): + current_file_path = os.path.abspath(__file__) + relative_checkpoint_path = os.path.join( + os.path.dirname(current_file_path), "../../../checkpoints/info_sharing/cross_attn_transformer" + ) + rope = RoPE2D(float(100)) + if model_str == "croco": + return MultiViewCrossAttentionTransformerIFR( + name="croco_base_decoder", + input_embed_dim=1024, + num_views=2, + indices=[12 * 2 // 4, 12 * 3 // 4], + norm_intermediate=False, + custom_positional_encoding=rope, + pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_CroCo.pth", + **kwargs, + ) + elif model_str == "dust3r_224": + return MultiViewCrossAttentionTransformerIFR( + name="dust3r_224_base_decoder", + input_embed_dim=1024, + num_views=2, + indices=[12 * 2 // 4, 12 * 3 // 4], + norm_intermediate=False, + custom_positional_encoding=rope, + pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_DUSt3R_224_linear.pth", + **kwargs, + ) + elif model_str == "dust3r_512": + return MultiViewCrossAttentionTransformerIFR( + name="dust3r_512_base_decoder", + input_embed_dim=1024, + num_views=2, + indices=[12 * 2 // 4, 12 * 3 // 4], + norm_intermediate=False, + custom_positional_encoding=rope, + pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_DUSt3R_512_linear.pth", + **kwargs, + ) + elif model_str == "dust3r_512_dpt": + return MultiViewCrossAttentionTransformerIFR( + name="dust3r_512_dpt_base_decoder", + input_embed_dim=1024, + num_views=2, + indices=[12 * 2 // 4, 12 * 3 // 4], + norm_intermediate=False, + custom_positional_encoding=rope, + pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_DUSt3R_512_dpt.pth", + **kwargs, + ) + elif model_str == "mast3r_512": + return MultiViewCrossAttentionTransformerIFR( + name="mast3r_512_base_decoder", + input_embed_dim=1024, + num_views=2, + indices=[12 * 2 // 4, 12 * 3 // 4], + norm_intermediate=False, + custom_positional_encoding=rope, + pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_MASt3R.pth", + **kwargs, + ) + + +class TestMultiViewTransformers: + def __init__(self, pca_save_folder, *args, **kwargs): + super(TestMultiViewTransformers, self).__init__(*args, **kwargs) + + self.pca_save_folder = pca_save_folder + + self.norm_types = IMAGE_NORMALIZATION_DICT.keys() + + self.models = [ + "croco", + "dust3r_224", + "dust3r_512", + "dust3r_512_dpt", + "mast3r_512", + ] + + self.model_configs = [{}] * len(self.models) + + def inference_encoder(self, encoder, input): + return encoder(input) + + def inference_info_sharing(self, info_sharing, input): + return info_sharing(input) + + def visualize_all_models(self): + for model, model_config in zip(self.models, self.model_configs): + encoder = _make_encoder_test(model, **model_config) + info_sharing = _make_mv_cross_attention_transformer_test(model, **model_config) + self._visualize_model_features_consistency(encoder, info_sharing, (224, 224)) + + def _visualize_model_features_consistency(self, encoder, info_sharing, image_size: Tuple[int, int]): + img0, viz_img0 = self._get_example_input( + image_size, encoder.data_norm_type, img_selection=1, return_viz_img=True + ) + img1, viz_img1 = self._get_example_input( + image_size, encoder.data_norm_type, img_selection=2, return_viz_img=True + ) + # input and output of the encoder + encoder_input0: ViTEncoderInput = ViTEncoderInput( + data_norm_type=encoder.data_norm_type, + image=img0, + ) + + encoder_input1: ViTEncoderInput = ViTEncoderInput( + data_norm_type=encoder.data_norm_type, + image=img1, + ) + + encoder_output0 = self.inference_encoder(encoder, encoder_input0) + encoder_output0 = encoder_output0.features + + encoder_output1 = self.inference_encoder(encoder, encoder_input1) + encoder_output1 = encoder_output1.features + + # pass the encoder outputs to the info sharing model + multi_view_features = [encoder_output0, encoder_output1] + info_sharing_input = MultiViewTransformerInput(features=multi_view_features) + info_sharing_output = self.inference_info_sharing(info_sharing, info_sharing_input) + final_layer_multi_view_features = info_sharing_output[0].features + + # get a common PCA codec + cat_feats = torch.cat(final_layer_multi_view_features, dim=3) + + pca_viz = get_pca_map(cat_feats.permute(0, 2, 3, 1), (image_size[0], image_size[1] * 2), return_pca_stats=True) + + # concatenate the input images along the width dimension + cat_imgs = torch.cat([viz_img0, viz_img1], dim=1) + + # plot the input image and the PCA features + fig, axs = plt.subplots(1, 2, figsize=(12, 6)) + axs[0].imshow(cat_imgs) + axs[0].set_title("Input Images") + axs[0].axis("off") + axs[1].imshow(pca_viz[0]) + axs[1].set_title(f"PCA Features of {encoder.name} + Base Decoder") + axs[1].axis("off") + plt.savefig(f"{self.pca_save_folder}/multi_pca_{encoder.name}.png", bbox_inches="tight") + plt.close() + + @lru_cache(maxsize=3) + def _get_example_input( + self, + image_size: Tuple[int, int], + image_norm_type: str = "dummy", + img_selection: int = 1, + return_viz_img: bool = False, + ) -> torch.Tensor: + url = f"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau{img_selection}.png" + image = Image.open(requests.get(url, stream=True).raw) + image = image.resize(image_size) + image = image.convert("RGB") + + img = torch.from_numpy(np.array(image)) + viz_img = img.clone() + + # Normalize the images + image_normalization = IMAGE_NORMALIZATION_DICT[image_norm_type] + + img_mean, img_std = image_normalization.mean, image_normalization.std + + img = (img.float() / 255.0 - img_mean) / img_std + + # convert to BCHW format + img = img.permute(2, 0, 1).unsqueeze(0) + + if return_viz_img: + return img, viz_img + else: + return img + + +def render_pca_as_rgb(features): + """ + Perform PCA on the given feature tensor and render the first 3 principal components as RGB. + + Args: + features (torch.Tensor): Feature tensor of shape (B, C, H, W). + + Returns: + np.ndarray: RGB image of shape (H, W, 3). + """ + # Ensure input is a 4D tensor + assert features.dim() == 4, "Input tensor must be 4D (B, C, H, W)" + + B, C, H, W = features.shape + + # Reshape the tensor to (B * H * W, C) + reshaped_features = features.permute(0, 2, 3, 1).contiguous().view(-1, C).cpu().numpy() + + # Perform PCA + pca = PCA(n_components=3) + principal_components = pca.fit_transform(reshaped_features) + + # Rescale the principal components to [0, 1] + principal_components = (principal_components - principal_components.min(axis=0)) / ( + principal_components.max(axis=0) - principal_components.min(axis=0) + ) + + # Reshape the principal components to (B, H, W, 3) + principal_components = principal_components.reshape(B, H, W, 3) + + # Convert the principal components to RGB image (take the first batch) + rgb_image = principal_components[0] + + return rgb_image + + +def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False): + # features: (N, C) + # m: a hyperparam controlling how many std dev outside for outliers + assert len(features.shape) == 2, "features should be (N, C)" + reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2] + colors = features @ reduction_mat + if remove_first_component: + colors_min = colors.min(dim=0).values + colors_max = colors.max(dim=0).values + tmp_colors = (colors - colors_min) / (colors_max - colors_min) + fg_mask = tmp_colors[..., 0] < 0.2 + reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2] + colors = features @ reduction_mat + else: + fg_mask = torch.ones_like(colors[:, 0]).bool() + d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values) + mdev = torch.median(d, dim=0).values + s = d / mdev + try: + rins = colors[fg_mask][s[:, 0] < m, 0] + gins = colors[fg_mask][s[:, 1] < m, 1] + bins = colors[fg_mask][s[:, 2] < m, 2] + rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()]) + rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()]) + except: + rins = colors + gins = colors + bins = colors + rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()]) + rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()]) + + return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat) + + +def get_pca_map( + feature_map: torch.Tensor, + img_size, + interpolation="bicubic", + return_pca_stats=False, + pca_stats=None, +): + """ + feature_map: (1, h, w, C) is the feature map of a single image. + """ + if feature_map.shape[0] != 1: + # make it (1, h, w, C) + feature_map = feature_map[None] + if pca_stats is None: + reduct_mat, color_min, color_max = get_robust_pca(feature_map.reshape(-1, feature_map.shape[-1])) + else: + reduct_mat, color_min, color_max = pca_stats + pca_color = feature_map @ reduct_mat + pca_color = (pca_color - color_min) / (color_max - color_min) + pca_color = pca_color.clamp(0, 1) + pca_color = F.interpolate( + pca_color.permute(0, 3, 1, 2), + size=img_size, + mode=interpolation, + ).permute(0, 2, 3, 1) + pca_color = pca_color.detach().cpu().numpy().squeeze(0) + if return_pca_stats: + return pca_color, (reduct_mat, color_min, color_max) + return pca_color + + +def seed_everything(seed=42): + """ + Set the `seed` value for torch and numpy seeds. Also turns on + deterministic execution for cudnn. + + Parameters: + - seed: A hashable seed value + """ + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + print(f"Seed set to: {seed} (type: {type(seed)})") + + +if __name__ == "__main__": + # Turn XFormers off for testing on CPU + os.environ["XFORMERS_DISABLED"] = "1" + + # Seed everything for consistent testing + seed_everything() + + # Create local directory for storing the PCA images + current_file_path = os.path.abspath(__file__) + relative_pca_image_folder = os.path.join( + os.path.dirname(current_file_path), "../../../local/info_sharing/pca_images" + ) + os.makedirs(relative_pca_image_folder, exist_ok=True) + + # Initialize the test class + test = TestMultiViewTransformers(pca_save_folder=relative_pca_image_folder) + + # Visualize the PCA of all models + test.visualize_all_models() + + print(f"The PCA visualizations of all models are saved successfully to {relative_pca_image_folder}!") diff --git a/UniCeption/uniception/__init__.py b/UniCeption/uniception/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/UniCeption/uniception/models/encoders/README.md b/UniCeption/uniception/models/encoders/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1c58406498b969d0217165f586ef31141b81f3a7 --- /dev/null +++ b/UniCeption/uniception/models/encoders/README.md @@ -0,0 +1,129 @@ +# UniCeption Encoders + +## Currently Supported Encoders + +### UniCeptionViTEncoderBase: + +- `CroCoEncoder` + - `CroCoIntermediateFeatureReturner` +- `DINOv2Encoder` + - `DINOv2IntermediateFeatureReturner` +- `PatchEmbedder` +- `RADIOEncoder` + - `RADIOIntermediateFeatureReturner` + +# Developer Guidelines for UniCeption Encoders + +## Overview + +This folder contains the implementation of various UniCeption encoders. Each encoder must adhere to a specific structure and follow certain guidelines to ensure consistency and compatibility across different projects. + +## Directory Structure + +The encoders and other necessary dependencies/tests for encoders are organized as follows: +``` +uniception/ +├── models/ +│ ├── encoders/ +│ │ ├── __init__.py +│ │ ├── base.py +│ │ ├── croco.py +│ │ ├── dinov2.py +│ │ ├── radio.py +│ │ ├── image_normalizations.py +│ └── ... +│ └── libs/ +│ │ ├── external_dependency_folders/ +| | | ├── external_dependency_files +tests/ +├── models/ +│ ├── encoders/ +│ │ ├── test_encoders.py +│ │ ├── viz_image_encoders.py +│ │ └── ... +| └── ... +└── ... +``` + +## Adding a New Encoder + +To add a new encoder, follow these steps: + +1. **Create a New Encoder File**: + - Create a new file in the `encoders` directory, e.g., `new_encoder.py`. + - Define the new encoder class in this file, inheriting from `UniCeptionEncoderBase` or `UniCeptionViTEncoderBase`. + - Please look at the base class for the necessary attributes and methods to implement. + +2. **Define Input Data Normalization**: + - Add the corresponding normalization for the encoder to respective normalization files, for example, image normalizations should be added to `image_normalizations.py`. + - Ensure the normalization is added to the dictionaries present in the files, for example, `IMAGE_NORMALIZATION_DICT`. + +4. **Implement the Encoder Class**: + - Inherit from `UniCeptionEncoderBase` or `UniCeptionViTEncoderBase` or other UniCeption base classes. + - Implement the `forward` method. + - Ensure the encoder class has the necessary attributes and methods. + +4. **Update `__init__.py`**: + - Import the new encoder class in `__init__.py`. + - Add the new encoder to the encoder configuration dictionary `ENCODER_CONFIGS` so that it can be instantiated via the encoder factory. + - Update the `_make_encoder_test` function to include the new encoder. + +5. **Run Encoder Unit Tests**: + - Run `pytest -vs tests/models/encoders/test_encoders.py --encoder-name=""` to test the basic expected functionality of UniCeption encoders. + - Also, add your new encoder to the list in the encoders() in `tests/models/encoders/test_encoders.py` so that it can be tested along with all the existing encoders. + - Optionally, for image encoders, the unit tests in `tests/models/encoders/viz_image_encoders.py` save PCA visualizations of the encoder outputs to the `local/pca_images` directory. + +## Example Encoder Implementation + +Here is an example of how to implement a new encoder: + +```python +# new_encoder.py +import torch +from uniception.models.encoders.base import UniCeptionEncoderBase, EncoderInput, EncoderOutput + +class NewEncoder(UniCeptionEncoderBase): + def __init__(self, name: str, data_norm_type: str, *args, **kwargs): + super().__init__(name=name, data_norm_type=data_norm_type, *args, **kwargs) + # Initialize encoder-specific layers and parameters here + + def forward(self, encoder_input: EncoderInput) -> EncoderOutput: + self._check_data_normalization_type(encoder_input.data_norm_type) + # Implement the forward pass + return EncoderOutput() +``` + +## Example Normalization + +Add the normalization for the new encoder, for example, to `image_normalizations.py`: + +```python +# image_normalizations.py +IMAGE_NORMALIZATION_DICT = { + "dummy": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])), + "croco": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])), + "dust3r": ImageNormalization(mean=torch.tensor([0.5, 0.5, 0.5]), std=torch.tensor([0.5, 0.5, 0.5])), + "dinov2": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])), + "radio": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])), + "new_encoder": ImageNormalization(mean=torch.tensor([0.5, 0.5, 0.5]), std=torch.tensor([0.2, 0.2, 0.2])), +} +``` + +## Example Unit Testing + +Add the new encoder to the encoder factory in `__init__.py` and the encoder list in `tests/models/encoders/test_encoders.py`. Additional tests can also be added as required. + +Look at `tests/models/encoders/test_encoders.py` to see what tests are run. + +Additionally, if the new encoder is an image encoder, you can add to the encoder list in `tests/models/encoders/viz_image_encoders.py` for saving PCA visualizations of the encoder outputs to the `local/pca_images` directory. + +## Developer Guidelines + +Please follow these guidelines when contributing to the UniCeption encoders: +- **Consistency**: Ensure that the new encoder follows the structure and naming conventions of existing encoders. +- **Code Style**: Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) for code style. +- **Documentation**: Add docstrings to all classes and methods. +- **Unit Tests**: Add necessary unit tests for the encoder class. +- **Linting**: Run `black` on your code before committing. For example, you can run `black uniception`. + +## Happy Coding! diff --git a/UniCeption/uniception/models/encoders/__init__.py b/UniCeption/uniception/models/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d6599fe4a4e11f7e0f9a218342c112c946657f17 --- /dev/null +++ b/UniCeption/uniception/models/encoders/__init__.py @@ -0,0 +1,235 @@ +""" +Encoder Factory for UniCeption +""" + +import os + +from uniception.models.encoders.base import ( + EncoderGlobalRepInput, + EncoderInput, + UniCeptionEncoderBase, + UniCeptionViTEncoderBase, + ViTEncoderInput, + ViTEncoderNonImageInput, + ViTEncoderOutput, +) +from uniception.models.encoders.cosmos import CosmosEncoder +from uniception.models.encoders.croco import CroCoEncoder, CroCoIntermediateFeatureReturner +from uniception.models.encoders.dense_rep_encoder import DenseRepresentationEncoder +from uniception.models.encoders.dinov2 import DINOv2Encoder, DINOv2IntermediateFeatureReturner +from uniception.models.encoders.global_rep_encoder import GlobalRepresentationEncoder +from uniception.models.encoders.naradio import NARADIOEncoder +from uniception.models.encoders.patch_embedder import PatchEmbedder +from uniception.models.encoders.radio import RADIOEncoder, RADIOIntermediateFeatureReturner + +# Define encoder configurations +ENCODER_CONFIGS = { + "croco": { + "class": CroCoEncoder, + "intermediate_feature_returner_class": CroCoIntermediateFeatureReturner, + "supported_models": ["CroCov2", "DUSt3R", "MASt3R"], + }, + "dense_rep_encoder": { + "class": DenseRepresentationEncoder, + "supported_models": ["Dense-Representation-Encoder"], + }, + "dinov2": { + "class": DINOv2Encoder, + "intermediate_feature_returner_class": DINOv2IntermediateFeatureReturner, + "supported_models": ["DINOv2", "DINOv2-Registers", "DINOv2-Depth-Anythingv2"], + }, + "global_rep_encoder": { + "class": GlobalRepresentationEncoder, + "supported_models": ["Global-Representation-Encoder"], + }, + "patch_embedder": { + "class": PatchEmbedder, + "supported_models": ["Patch-Embedder"], + }, + "radio": { + "class": RADIOEncoder, + "intermediate_feature_returner_class": RADIOIntermediateFeatureReturner, + "supported_models": ["RADIO", "E-RADIO"], + }, + "cosmos": { + "class": CosmosEncoder, + "supported_models": ["Cosmos-Tokenizer CI8x8", "Cosmos-Tokenizer CI16x16"], + }, + "naradio": { + "class": NARADIOEncoder, + "supported_models": ["RADIO"], + }, + # Add other encoders here +} + + +def encoder_factory(encoder_str: str, **kwargs) -> UniCeptionEncoderBase: + """ + Encoder factory for UniCeption. + Please use python3 -m uniception.models.encoders.list to see available encoders. + + Args: + encoder_str (str): Name of the encoder to create. + **kwargs: Additional keyword arguments to pass to the encoder constructor. + + Returns: + UniCeptionEncoderBase: An instance of the specified encoder. + """ + if encoder_str not in ENCODER_CONFIGS: + raise ValueError( + f"Unknown encoder: {encoder_str}. For valid encoder_str options, please use python3 -m uniception.models.encoders.list" + ) + + encoder_config = ENCODER_CONFIGS[encoder_str] + encoder_class = encoder_config["class"] + + return encoder_class(**kwargs) + + +def feature_returner_encoder_factory(encoder_str: str, **kwargs) -> UniCeptionEncoderBase: + """ + Factory for UniCeption Encoders with support for intermediate feature returning. + Please use python3 -m uniception.models.encoders.list to see available encoders. + + Args: + encoder_str (str): Name of the encoder to create. + **kwargs: Additional keyword arguments to pass to the encoder constructor. + + Returns: + UniCeptionEncoderBase: An instance of the specified encoder. + """ + if encoder_str not in ENCODER_CONFIGS: + raise ValueError( + f"Unknown encoder: {encoder_str}. For valid encoder_str options, please use python3 -m uniception.models.encoders.list" + ) + + encoder_config = ENCODER_CONFIGS[encoder_str] + encoder_class = encoder_config["intermediate_feature_returner_class"] + + return encoder_class(**kwargs) + + +def get_available_encoders() -> list: + """ + Get a list of available encoders in UniCeption. + + Returns: + list: A list of available encoder names. + """ + return list(ENCODER_CONFIGS.keys()) + + +def print_available_encoder_models(): + """ + Print the currently supported encoders in UniCeption. + """ + print("Currently Supported Encoders in UniCeption:\nFormat -> encoder_str: supported_models") + for encoder_name, config in ENCODER_CONFIGS.items(): + print(f"{encoder_name}: {', '.join(config['supported_models'])}") + + +def _make_encoder_test(encoder_str: str, **kwargs) -> UniCeptionEncoderBase: + "Function to create encoders for testing purposes." + current_file_path = os.path.abspath(__file__) + relative_checkpoint_path = os.path.join(os.path.dirname(current_file_path), "../../../checkpoints/encoders") + if encoder_str == "dummy": + return UniCeptionEncoderBase(name="dummy", data_norm_type="dummy") + elif encoder_str == "croco": + return CroCoEncoder( + name="croco", + data_norm_type="croco", + pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_224.pth", + patch_embed_cls="PatchEmbedCroCo", + ) + elif encoder_str == "dust3r_224": + return CroCoEncoder( + name="dust3r_224", + data_norm_type="dust3r", + pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_224_DUSt3R_linear.pth", + patch_embed_cls="PatchEmbedDust3R", + ) + elif encoder_str == "dust3r_512": + return CroCoEncoder( + name="dust3r_512", + data_norm_type="dust3r", + pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_512_DUSt3R_linear.pth", + patch_embed_cls="ManyAR_PatchEmbed", + img_size=(512, 512), + ) + elif encoder_str == "dust3r_512_dpt": + return CroCoEncoder( + name="dust3r_512_dpt", + data_norm_type="dust3r", + pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_512_DUSt3R_dpt.pth", + patch_embed_cls="ManyAR_PatchEmbed", + img_size=(512, 512), + ) + elif encoder_str == "mast3r_512": + return CroCoEncoder( + name="mast3r_512", + data_norm_type="dust3r", + pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_512_MASt3R.pth", + patch_embed_cls="ManyAR_PatchEmbed", + img_size=(512, 512), + ) + elif "dinov2" in encoder_str: + size = encoder_str.split("_")[1] + size_single_cap_letter = size[0].upper() + if "reg" in encoder_str: + with_registers = True + pretrained_checkpoint_path = None + elif "dav2" in encoder_str: + with_registers = False + pretrained_checkpoint_path = ( + f"{relative_checkpoint_path}/DINOv2_ViT{size_single_cap_letter}_DepthAnythingV2.pth" + ) + else: + with_registers = False + pretrained_checkpoint_path = None + return DINOv2Encoder( + name=encoder_str.replace("_reg", ""), + size=size, + with_registers=with_registers, + pretrained_checkpoint_path=pretrained_checkpoint_path, + ) + elif "naradio" in encoder_str: + return NARADIOEncoder( + name=encoder_str, + model_version=encoder_str.replace("na", ""), + ) + elif "radio" in encoder_str: + if "e-radio" in encoder_str: + eradio_input_shape = (224, 224) + else: + eradio_input_shape = None + return RADIOEncoder( + name=encoder_str, + model_version=encoder_str, + eradio_input_shape=eradio_input_shape, + ) + elif "cosmos" in encoder_str: + patch_size = int(encoder_str.split("x")[-1]) + return CosmosEncoder( + name=encoder_str, + patch_size=patch_size, + pretrained_checkpoint_path=f"{relative_checkpoint_path}/Cosmos-Tokenizer-CI{patch_size}x{patch_size}/encoder.pth", + ) + elif "patch_embedder" in encoder_str: + return PatchEmbedder( + name=encoder_str, + ) + else: + raise ValueError(f"Unknown encoder: {encoder_str}") + + +__all__ = [ + "encoder_factory", + "get_available_encoders", + "print_available_encoder_models", + "_make_encoder_test", + "UniCeptionEncoderBase", + "UniCeptionViTEncoderBase", + "EncoderInput", + "ViTEncoderInput", + "ViTEncoderOutput", +] diff --git a/UniCeption/uniception/models/encoders/base.py b/UniCeption/uniception/models/encoders/base.py new file mode 100644 index 0000000000000000000000000000000000000000..acd04b12bae3da631c38b17cefb7c3ded1e07e70 --- /dev/null +++ b/UniCeption/uniception/models/encoders/base.py @@ -0,0 +1,157 @@ +""" +Base Encoder Class for UniCeption +""" + +from dataclasses import dataclass +from typing import Optional + +import torch.nn as nn +from jaxtyping import Float +from torch import Tensor +from torch.utils.checkpoint import checkpoint + + +@dataclass +class EncoderInput: + "Data class for Encoder Input" + + data_norm_type: str + # Add other fields that are required by the specific implementation of the encoder. + + +@dataclass +class EncoderOutput: + "Data class for Encoder Output" + + pass + + +@dataclass +class EncoderGlobalRepInput: + "Data class for Encoder Global Representation Input" + + data: Float[Tensor, "batch channel"] + + +@dataclass +class EncoderGlobalRepOutput: + "Data class for Encoder Global Representation Output" + + features: Float[Tensor, "batch enc_embed_dim"] + + +class UniCeptionEncoderBase(nn.Module): + "Encoder Base Class for UniCeption" + + def __init__( + self, + name: str, + data_norm_type: str, + size: Optional[str] = None, + *args, + **kwargs, + ): + """ + Base class for all encoders in UniCeption. + """ + super().__init__(*args, **kwargs) + + self.name: str = name + self.size: Optional[str] = size + + self.data_norm_type: str = data_norm_type + + def forward( + self, + encoder_input: EncoderInput, + ) -> EncoderOutput: + """ + Forward interface for the UniCeption encoders. + + We expect the "data_norm_type" field to be present in the encoder_input to check for normalization type. + + Args: + encoder_input (EncoderInput): Input to the encoder. We expect the following fields: "data_norm_type: str". + This is also includes the other fields that are required by the specific implementation of the encoder. + + Returns: + EncoderOutput: Output of the encoder. + """ + + raise NotImplementedError + + def _check_data_normalization_type(self, data_norm_type: str): + """ + Check if the input normalization type matches the encoder's expected input data normalization type. + + Args: + data_norm_type (str): Data normalization type. + + Raises: + AssertionError: If the data normalization type does not match the encoder's expected input data normalization type. + """ + + assert ( + data_norm_type == self.data_norm_type + ), f"Input normalization type {data_norm_type} does not match the encoder's normalization type {self.data_norm_type}." + + +@dataclass +class ViTEncoderInput(EncoderInput): + "Data class for Vision Transformer Encoder Input" + + image: Float[Tensor, "batch channel height width"] + + +@dataclass +class ViTEncoderNonImageInput: + "Data class for Vision (2D-Grid) Transformer Encoder Non-Image Input" + + data: Float[Tensor, "batch channel height width"] + + +@dataclass +class ViTEncoderOutput(EncoderOutput): + "Data class for Vision Transformer Encoder Output" + + features: Float[Tensor, "batch enc_embed_dim feat_height feat_width"] + + +class UniCeptionViTEncoderBase(UniCeptionEncoderBase): + "Vision Transformer Encoder Base Class for UniCeption" + + def __init__( + self, + patch_size: int, + gradient_checkpointing: bool = False, + *args, + **kwargs, + ): + """ + Base class for all Vision Transformer encoders in UniCeption. + """ + super().__init__(*args, **kwargs) + + self.patch_size = patch_size + self.gradient_checkpointing = gradient_checkpointing + + def wrap_module_with_gradient_checkpointing(self, module: nn.Module): + """ + Wrapper for Gradient Checkpointing + References: https://github.com/microsoft/MoGe + """ + + class _CheckpointingWrapper(module.__class__): + _restore_cls = module.__class__ + + def forward(self, *args, **kwargs): + return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) + + module.__class__ = _CheckpointingWrapper + return module + + +if __name__ == "__main__": + dummy_model = UniCeptionEncoderBase(name="name", data_norm_type="norm") + dummy_vit_model = UniCeptionViTEncoderBase(name="name", data_norm_type="norm", patch_size=16) + print("Dummy Base Encoders created successfully!") diff --git a/UniCeption/uniception/models/encoders/cosmos.py b/UniCeption/uniception/models/encoders/cosmos.py new file mode 100644 index 0000000000000000000000000000000000000000..3c06bbf711c65c33a0a15dd38f9ecd4c3f4d3761 --- /dev/null +++ b/UniCeption/uniception/models/encoders/cosmos.py @@ -0,0 +1,137 @@ +""" +Encoder Class for Cosmos +""" + +import torch + +from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput +from uniception.models.libs.cosmos_tokenizer.modules import ContinuousFormulation, EncoderType +from uniception.models.libs.cosmos_tokenizer.networks import TokenizerConfigs + + +class CosmosEncoder(UniCeptionViTEncoderBase): + "Uniception Cosmos Encoder" + + def __init__( + self, + name: str, + data_norm_type: str = "cosmos", + patch_size: int = 8, + pretrained_checkpoint_path: str = None, + *args, + **kwargs, + ): + """ + Cosmos Encoder for extracting spatial features from images. + + Args: + name (str): Name of the encoder. + data_norm_type (str): Image normalization type. Default: "cosmos" + patch_size (int): Patch size for the encoder. Default: 8 + pretrained_checkpoint_path (str): Path to the pretrained checkpoint. Default: None + """ + # Init the base class + super().__init__(name=name, data_norm_type=data_norm_type, patch_size=patch_size, *args, **kwargs) + + # Init Cosmos Encoder sepecific attributes + tokenizer_config = TokenizerConfigs["CI"].value.copy() + tokenizer_config.update(dict(spatial_compression=self.patch_size)) + + z_factor = tokenizer_config["z_factor"] + z_channels = tokenizer_config["z_channels"] + latent_channels = tokenizer_config["latent_channels"] + encoder_name = kwargs.get("encoder", EncoderType.Default.name) + print(tokenizer_config) + del tokenizer_config["z_factor"] + del tokenizer_config["z_channels"] + del tokenizer_config["latent_channels"] + self.encoder = EncoderType[encoder_name].value(z_channels=z_factor * z_channels, **tokenizer_config) + self.quant_conv = torch.nn.Conv2d(z_factor * z_channels, z_factor * latent_channels, 1) + formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) + self.distribution = ContinuousFormulation[formulation_name].value() + + # Load the pretrained checkpoint + if pretrained_checkpoint_path is not None: + print(f"Loading custom pretrained Cosmos checkpoint from {pretrained_checkpoint_path}") + ckpt = torch.load(pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]: + """Encodes an image into a latent embedding or code. + + Args: + input_tensor: The input tensor Bx3xHxW layout, range [-1..1]. + Returns: + For continuous image (CI) tokenizer, the tuple contains: + - The latent embedding, Bx16x(h)x(w), where the compression + rate is (H/h x W/w), and channel dimension of 16. + For discrete image (DI) tokenizer, the tuple contains: + - The indices, Bx(h)x(w), from a codebook of size 64K, which + corresponds to FSQ levels of (8,8,8,5,5,5). + - The discrete code, Bx6x(h)x(w), where the compression rate is + again (H/h x W/w), and channel dimension of 6. + """ + x = self.encoder(input_tensor) + x = self.quant_conv(x) + output_latent = self.distribution(x) + + if isinstance(output_latent, torch.Tensor): + return output_latent + return output_latent[:-1] + + def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput: + """ + Cosmos Encoder Forward Pass + + Args: + encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor. + + Returns: + ViTEncoderOutput: Output data from the encoder. + """ + # Check image normalization type + self._check_data_normalization_type(encoder_input.data_norm_type) + + # Check the dtype and shape of the input image + assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor" + assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)" + batch_size, channels, height, width = encoder_input.image.shape + assert channels == 3, "Input must have 3 channels" + assert ( + height % self.patch_size == 0 and width % self.patch_size == 0 + ), f"Input shape must be divisible by patch size: {self.patch_size}" + + # Extract the features from the DINOv2 model + features = self.encode(encoder_input.image)[0].contiguous() + + return ViTEncoderOutput(features=features) + + +if __name__ == "__main__": + + # initialize different variants of the Cosmos Encoder, untrained + for is_continuous in [True]: + for patch_size in [8, 16]: + encoder = CosmosEncoder(name="cosmos", patch_size=patch_size) + + # # initialize from trained checkpoint, with/without jit inference capability + PRETRAINED_JIT_CHECKPOINTS = { + ("CI", 8): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI8x8/encoder.pth", + ("CI", 16): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI16x16/encoder.pth", + } + + for patch_size in [8, 16]: + + encoder = CosmosEncoder( + name="cosmos", + patch_size=patch_size, + pretrained_checkpoint_path=PRETRAINED_JIT_CHECKPOINTS[("CI", patch_size)], + ) + + # example inference + dummy_image = torch.randn(1, 3, 256, 256).cuda() + + encoder_input = ViTEncoderInput(data_norm_type="cosmos", image=dummy_image) + + encoder = encoder.cuda() + encoder_output = encoder(encoder_input) diff --git a/UniCeption/uniception/models/encoders/croco.py b/UniCeption/uniception/models/encoders/croco.py new file mode 100644 index 0000000000000000000000000000000000000000..3ca5fc11105377a94a1d8938ba78e4d728c185bd --- /dev/null +++ b/UniCeption/uniception/models/encoders/croco.py @@ -0,0 +1,457 @@ +""" +Encoder Class for CroCo & DUSt3R +""" + +from functools import partial +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput +from uniception.models.libs.croco.blocks import Block +from uniception.models.libs.croco.patch_embed import get_patch_embed +from uniception.models.libs.croco.pos_embed import RoPE2D +from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices + + +class CroCoEncoder(UniCeptionViTEncoderBase): + "UniCeption CroCov2 Encoder" + + def __init__( + self, + name: str, + data_norm_type: str, + patch_embed_cls: str = "PatchEmbedDust3R", + img_size: Union[int, Tuple[int, int]] = (224, 224), + patch_size: int = 16, + enc_embed_dim: int = 1024, + enc_depth: int = 24, + enc_num_heads: int = 16, + mlp_ratio: int = 4, + norm_layer: Callable = partial(nn.LayerNorm, eps=1e-6), + pos_embed: str = "RoPE100", + pretrained_checkpoint_path: str = None, + override_checkpoint_attributes: bool = False, + *args, + **kwargs, + ): + """ + References: https://github.com/naver/dust3r, https://github.com/naver/croco + + Args: + name (str): Name of the encoder. + data_norm_type (str): Input data normalization type. + patch_embed_cls (str, optional): The class to use for patch embedding. + Defaults to 'PatchEmbedDust3R'. Options: ['PatchEmbedCroCo', 'PatchEmbedDust3R', 'ManyAR_PatchEmbed']. + img_size (int, optional): The size of the input image. Defaults to 224. + patch_size (int, optional): The size of the patches to divide the image into. Defaults to 16. + enc_embed_dim (int, optional): The dimension of the encoder's embedding. Defaults to 768. + enc_depth (int, optional): The number of encoder layers/transformer blocks. Defaults to 12. + enc_num_heads (int, optional): The number of encoder heads. Defaults to 12. + mlp_ratio (int, optional): The MLP ratio used for the CroCo encoder transformer. Defaults to 4. + norm_layer (nn.Module, optional): The normalization layer to use in the transformer. Defaults to nn.LayerNorm with eps=1e-6. + pos_embed (str, optional): Positional Embedding. Defaults to 'RoPE100'. Options: ['RoPEfreq']. + pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. Defaults to None. + """ + # Init the base class + super().__init__( + name=name, + data_norm_type=data_norm_type, + patch_size=patch_size, + *args, + **kwargs, + ) + + # Init the CroCo Encoder specific attributes + self.patch_embed_cls = patch_embed_cls + self.img_size = img_size + self.enc_embed_dim = enc_embed_dim + self.enc_depth = enc_depth + self.enc_num_heads = enc_num_heads + self.mlp_ratio = mlp_ratio + self.norm_layer = norm_layer + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.override_checkpoint_attributes = override_checkpoint_attributes + + # Init the positional embedding + self.pos_embed = pos_embed + if pos_embed.startswith("RoPE"): # eg RoPE100 + self.enc_pos_embed = None # nothing to add in the encoder with RoPE + self.dec_pos_embed = None # nothing to add in the decoder with RoPE + if RoPE2D is None: + raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions") + freq = float(pos_embed[len("RoPE") :]) + self.rope = RoPE2D(freq=freq) + else: + raise NotImplementedError("Unknown pos_embed " + pos_embed) + + # Init the patch embedding + self._set_patch_embed(img_size, patch_size, enc_embed_dim) + + # Init the encoder + self._set_encoder(enc_depth, enc_embed_dim, enc_num_heads, mlp_ratio, norm_layer, self.rope) + + # Initialize random weights + self.initialize_weights() + + # Load the pretrained CroCo checkpoint if provided + if pretrained_checkpoint_path: + print(f"Loading pretrained CroCo checkpoint from {pretrained_checkpoint_path}") + ckpt = torch.load(pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + if not override_checkpoint_attributes: + ckpt_data_norm_type = ckpt["data_norm_type"] + ckpt_patch_embed_cls = ckpt["patch_embed_cls"] + assert ( + data_norm_type == ckpt_data_norm_type + ), f"Data normalization type {data_norm_type} does not match the checkpoint {ckpt_data_norm_type}." + assert ( + patch_embed_cls == ckpt_patch_embed_cls + ), f"Patch embedding class {patch_embed_cls} does not match the checkpoint {ckpt_patch_embed_cls}." + else: + print("No pretrained checkpoint provided. Randomly initializing the CroCo encoder.") + + def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768): + "Set the patch embedding scheme" + self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim) + + def _set_encoder(self, enc_depth, enc_embed_dim, enc_num_heads, mlp_ratio, norm_layer, rope): + "Set the encoder" + self.enc_blocks = nn.ModuleList( + [ + Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=rope) + for _ in range(enc_depth) + ] + ) + self.enc_norm = norm_layer(enc_embed_dim) + + def initialize_weights(self): + "Initialize the weights of the patch embedding and the transformer encoder" + # Patch embedding + self.patch_embed._init_weights() + # Linears and layer norms + self.apply(self._init_weights) + + def _init_weights(self, m): + "Initialize the transformer encoder weights" + if isinstance(m, nn.Linear): + # We use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput: + """ + CroCov2 Encoder Forward Pass + + Args: + encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor. + + Returns: + ViTEncoderOutput: Output data from the encoder. + """ + # Check image normalization type + self._check_data_normalization_type(encoder_input.data_norm_type) + + # Get the true shape of the image for landscape/portrait mode check in patch embedding + batch_size, _, height, width = encoder_input.image.shape + if hasattr(encoder_input, "true_shape"): + true_shape = encoder_input.true_shape + else: + true_shape = torch.tensor([height, width])[None].repeat(batch_size, 1) + + # Embed the image into patches + features, pos = self.patch_embed(encoder_input.image, true_shape=true_shape) + + # Now apply the transformer encoder and normalization + for blk in self.enc_blocks: + features = blk(features, pos) + features = self.enc_norm(features) + + # Resize the features to the expected shape + # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size) + features = features.permute(0, 2, 1) + features = features.reshape( + -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size + ).contiguous() + + return ViTEncoderOutput(features=features) + + +class CroCoIntermediateFeatureReturner(CroCoEncoder, IntermediateFeatureReturner): + "Intermediate Feature Returner for UniCeption CroCo Encoder" + + def __init__( + self, + name: str, + data_norm_type: str, + patch_embed_cls: str = "PatchEmbedDust3R", + img_size: Union[int, Tuple[int, int]] = (224, 224), + patch_size: int = 16, + enc_embed_dim: int = 1024, + enc_depth: int = 24, + enc_num_heads: int = 16, + mlp_ratio: int = 4, + norm_layer: Callable = partial(nn.LayerNorm, eps=1e-6), + pos_embed: str = "RoPE100", + pretrained_checkpoint_path: str = None, + indices: Optional[Union[int, List[int]]] = None, + norm_intermediate: bool = True, + stop_early: bool = False, + intermediates_only: bool = True, + *args, + **kwargs, + ): + """ + Intermediate Feature Returner for the CroCo Encoder. + + Args: + name (str): Name of the encoder. + data_norm_type (str): Input data normalization type. + patch_embed_cls (str, optional): The class to use for patch embedding. + Defaults to 'PatchEmbedDust3R'. Options: ['PatchEmbedCroCo', 'PatchEmbedDust3R', 'ManyAR_PatchEmbed']. + img_size (int, optional): The size of the input image. Defaults to 224. + patch_size (int, optional): The size of the patches to divide the image into. Defaults to 16. + enc_embed_dim (int, optional): The dimension of the encoder's embedding. Defaults to 768. + enc_depth (int, optional): The number of encoder layers/transformer blocks. Defaults to 12. + enc_num_heads (int, optional): The number of encoder heads. Defaults to 12. + mlp_ratio (int, optional): The MLP ratio used for the CroCo encoder transformer. Defaults to 4. + norm_layer (nn.Module, optional): The normalization layer to use in the transformer. Defaults to nn.LayerNorm with eps=1e-6. + pos_embed (str, optional): Positional Embedding. Defaults to 'RoPE100'. Options: ['cosine', 'RoPE100']. + pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. Defaults to None. + indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to None. Options: + - None: Return all intermediate layers. + - int: Return the last n layers. + - List[int]: Return the intermediate layers at the specified indices. + norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True. + stop_early (bool, optional): Whether to stop early. Defaults to False. + intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True. + """ + # Init the base classes + CroCoEncoder.__init__( + self, + name=name, + data_norm_type=data_norm_type, + patch_embed_cls=patch_embed_cls, + img_size=img_size, + patch_size=patch_size, + enc_embed_dim=enc_embed_dim, + enc_depth=enc_depth, + enc_num_heads=enc_num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + pos_embed=pos_embed, + pretrained_checkpoint_path=pretrained_checkpoint_path, + *args, + **kwargs, + ) + IntermediateFeatureReturner.__init__( + self, + indices=indices, + norm_intermediate=norm_intermediate, + stop_early=stop_early, + intermediates_only=intermediates_only, + ) + + def forward( + self, encoder_input: ViTEncoderInput + ) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: + """ + CroCov2 Encoder Forward Pass with Intermediate Feature Return + + Args: + encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor. + + Returns: + Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder. + If `intermediates_only` is True, returns a list of intermediate features. + Otherwise, returns a tuple with the final features and a list of intermediate features. + """ + # Check image normalization type + self._check_data_normalization_type(encoder_input.data_norm_type) + + # Get the true shape of the image for landscape/portrait mode check in patch embedding + batch_size, _, height, width = encoder_input.image.shape + if hasattr(encoder_input, "true_shape"): + true_shape = encoder_input.true_shape + else: + true_shape = torch.tensor([height, width])[None].repeat(batch_size, 1) + + # Embed the image into patches + features, pos = self.patch_embed(encoder_input.image, true_shape=true_shape) + + # Get indices of the intermediate features to return + intermediate_features = [] + take_indices, max_index = feature_take_indices(len(self.enc_blocks), self.indices) + + # Get the blocks based on early stopping + if torch.jit.is_scripting() or not self.stop_early: # can't slice blocks in torchscript + blocks = self.enc_blocks + else: + blocks = self.enc_blocks[: max_index + 1] + + # Now apply the transformer encoder and normalization + for blk_idx, blk in enumerate(blocks): + features = blk(features, pos) + if blk_idx in take_indices: + # Normalize intermediates with final norm layer if enabled + intermediate_features.append(self.enc_norm(features) if self.norm_intermediate else features) + + # Reshape the intermediate features and convert to ViTEncoderOutput class + intermediate_features = [ + intermediate.permute(0, 2, 1) + .reshape(-1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size) + .contiguous() + for intermediate in intermediate_features + ] + intermediate_features = [ViTEncoderOutput(features=intermediate) for intermediate in intermediate_features] + + # Return only the intermediate features if enabled + if self.intermediates_only: + return intermediate_features + + # Normalize and reshape the final features + features = self.enc_norm(features) + # Resize the features to the expected shape + # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size) + features = features.permute(0, 2, 1) + features = features.reshape( + -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size + ).contiguous() + final_features = ViTEncoderOutput(features=features) + + return final_features, intermediate_features + + +if __name__ == "__main__": + # Init the pre-trained CroCo Encoder + pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_224.pth" + croco_encoder = CroCoEncoder( + name="croco", + data_norm_type="croco", + pretrained_checkpoint_path=pretrained_checkpoint_path, + patch_embed_cls="PatchEmbedCroCo", + ) + + # Init the pre-trained DUSt3R CroCo Encoder + pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_224_DUSt3R_linear.pth" + dust3r_encoder = CroCoEncoder( + name="dust3r_224", + data_norm_type="dust3r", + pretrained_checkpoint_path=pretrained_checkpoint_path, + patch_embed_cls="PatchEmbedDust3R", + ) + + # Init the pre-trained DUSt3R 512 linear CroCo Encoder + pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_DUSt3R_linear.pth" + dust3r_encoder_512 = CroCoEncoder( + name="dust3r_512", + data_norm_type="dust3r", + pretrained_checkpoint_path=pretrained_checkpoint_path, + patch_embed_cls="ManyAR_PatchEmbed", + img_size=(512, 512), + ) + + # Init the pre-trained DUSt3R 512 DPT CroCo Encoder + pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_DUSt3R_dpt.pth" + dust3r_encoder_512_dpt = CroCoEncoder( + name="dust3r_512_dpt", + data_norm_type="dust3r", + pretrained_checkpoint_path=pretrained_checkpoint_path, + patch_embed_cls="ManyAR_PatchEmbed", + img_size=(512, 512), + ) + + # Init the MASt3R 512 CroCo Encoder + pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_MASt3R.pth" + mast3r_encoder_512 = CroCoEncoder( + name="mast3r_512", + data_norm_type="dust3r", + pretrained_checkpoint_path=pretrained_checkpoint_path, + patch_embed_cls="ManyAR_PatchEmbed", + img_size=(512, 512), + ) + + print("All CroCo & DUSt3R Encoders have been initialized successfully!") + + # Intermediate Feature Returner Tests + print("Running Intermediate Feature Returner Tests...") + pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_DUSt3R_dpt.pth" + + # Run the intermediate feature returner with last-n index + dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner( + name="dust3r_512_dpt", + data_norm_type="dust3r", + pretrained_checkpoint_path=pretrained_checkpoint_path, + patch_embed_cls="ManyAR_PatchEmbed", + img_size=(512, 512), + indices=6, # Last 6 layers + ) + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r") + output = dust3r_intermediate_feature_returner(dummy_input) + assert isinstance(output, list), "Output must be a list of intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices" + + # Run the intermediate feature returner with specific indices + dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner( + name="dust3r_512_dpt", + data_norm_type="dust3r", + pretrained_checkpoint_path=pretrained_checkpoint_path, + patch_embed_cls="ManyAR_PatchEmbed", + img_size=(512, 512), + indices=[0, 2, 4, 6], # Specific layers + ) + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r") + output = dust3r_intermediate_feature_returner(dummy_input) + assert isinstance(output, list), "Output must be a list of intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices" + + # Test the normalizing of intermediate features + dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner( + name="dust3r_512_dpt", + data_norm_type="dust3r", + pretrained_checkpoint_path=pretrained_checkpoint_path, + patch_embed_cls="ManyAR_PatchEmbed", + img_size=(512, 512), + indices=[-1], + norm_intermediate=False, + intermediates_only=False, + ) + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r") + output = dust3r_intermediate_feature_returner(dummy_input) + assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features" + assert isinstance(output[1], list), "Second element of output must be a list of intermediate features" + assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + if not isinstance(dust3r_intermediate_feature_returner.enc_norm, torch.nn.Identity): + assert not torch.equal( + output[0].features, output[1][0].features + ), "Final features and intermediate features must be different" + + dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner( + name="dust3r_512_dpt", + data_norm_type="dust3r", + pretrained_checkpoint_path=pretrained_checkpoint_path, + patch_embed_cls="ManyAR_PatchEmbed", + img_size=(512, 512), + indices=[-1], + norm_intermediate=True, + intermediates_only=False, + ) + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r") + output = dust3r_intermediate_feature_returner(dummy_input) + assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features" + assert isinstance(output[1], list), "Second element of output must be a list of intermediate features" + assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + assert torch.equal( + output[0].features, output[1][0].features + ), "Final features and intermediate features must be same" + + print("All Intermediate Feature Returner Tests have passed successfully!") diff --git a/UniCeption/uniception/models/encoders/dense_rep_encoder.py b/UniCeption/uniception/models/encoders/dense_rep_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..30d4f2eb751be3be3a61914bc2d2884732958775 --- /dev/null +++ b/UniCeption/uniception/models/encoders/dense_rep_encoder.py @@ -0,0 +1,344 @@ +""" +Encoder class for Dense Representation Encoder +""" + +import math +from functools import partial +from typing import Callable, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ + +from uniception.models.encoders.base import ( + UniCeptionViTEncoderBase, + ViTEncoderInput, + ViTEncoderNonImageInput, + ViTEncoderOutput, +) + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class ResidualBlock(nn.Module): + "Redidual block for Dense Representation Encoder" + + def __init__(self, in_channels: int, out_channels: int, act_layer: Type[nn.Module] = nn.GELU): + super(ResidualBlock, self).__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.act = act_layer() + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.shortcut = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x): + identity = self.shortcut(x) + out = self.conv1(x) + out = self.act(out) + out = self.conv2(out) + out += identity + + return self.act(out) + + +class DenseRepresentationEncoder(UniCeptionViTEncoderBase): + "UniCeption Dense Representation Encoder" + + def __init__( + self, + name: str, + in_chans: int = 3, + enc_embed_dim: int = 1024, + apply_pe: bool = True, + input_size_for_pe: Union[int, Tuple[int, int]] = 518, + patch_size: int = 14, + intermediate_dims: List[int] = [588, 768, 1024], + data_norm_type: str = "dense_rep_encoder", + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Optional[Callable] = partial(nn.LayerNorm, eps=1e-6), + post_pe_norm_layer: Optional[Callable] = partial(nn.LayerNorm, eps=1e-6), + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained_checkpoint_path: str = None, + *args, + **kwargs, + ): + """ + Dense Representation Encoder for extracting patch-wise features from a spatial input of size (B, C, H, W). + Uses a convolution based patchify followed by some residual blocks. + Also applies positional encoding with interpolation to the patch-wise features if required. + + Args: + in_chans (int): Number of input channels. + enc_embed_dim (int): Embedding dimension of the encoder. + apply_pe (bool): Whether to apply positional encoding. + input_size_for_pe (Union[int, Tuple[int, int]]): Input size for positional encoding. + patch_size (int): Patch size of the encoder. + intermediate_dims (List[int]): Intermediate dimensions of the encoder. + data_norm_type (str): Data normalization type. (Used for checking if the input images are normalized correctly.) + act_layer (Type[nn.Module]): Activation layer. + norm_layer (Optional[Callable]): Normalization layer. + post_pe_norm_layer (Optional[Callable]): Normalization layer after positional encoding. + interpolate_antialias (bool): Whether to apply antialiasing in interpolation. + interpolate_offset (float): Offset for interpolation. + pretrained_checkpoint_path (str): Path to pretrained checkpoint. + """ + # Init the base class + super().__init__( + name=name, + data_norm_type=data_norm_type, + patch_size=patch_size, + *args, + **kwargs, + ) + + # Init the specific attributes + self.in_chans = in_chans + self.enc_embed_dim = enc_embed_dim + self.intermediate_dims = intermediate_dims + self.apply_pe = apply_pe + + # Initialize the encoder with a pixel unshuffle and conv projection to patchify the input + self.unshuffle = nn.PixelUnshuffle(self.patch_size) + self.conv_in = nn.Conv2d(self.in_chans * (self.patch_size**2), self.intermediate_dims[0], 3, 1, 1) + + # Add residual blocks + layers = [] + for intermediate_idx in range(len(self.intermediate_dims) - 1): + layers.append( + ResidualBlock( + in_channels=self.intermediate_dims[intermediate_idx], + out_channels=self.intermediate_dims[intermediate_idx + 1], + act_layer=act_layer, + ) + ) + + # Final projection to match encoder embeddings dim + layers.append( + nn.Conv2d( + in_channels=self.intermediate_dims[-1], + out_channels=self.enc_embed_dim, + kernel_size=1, + stride=1, + padding=0, + ) + ) + self.encoder = nn.Sequential(*layers) + + # Init norm layer after encoder if required + self.norm_layer = norm_layer(enc_embed_dim) if norm_layer else nn.Identity() + if isinstance(self.norm_layer, nn.LayerNorm): + nn.init.constant_(self.norm_layer.bias, 0) + nn.init.constant_(self.norm_layer.weight, 1.0) + + if self.apply_pe: + # Init the patch resolution details required for positional encoding + patch_HW = make_2tuple(patch_size) + self.input_size_for_pe = make_2tuple(input_size_for_pe) + self.patches_resolution = ( + self.input_size_for_pe[0] // patch_HW[0], + self.input_size_for_pe[1] // patch_HW[1], + ) + self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + + # Init the sinusodial positional encodings + self.register_buffer( + "pos_embed", + self._get_sinusoid_encoding_table(self.num_patches, self.enc_embed_dim, 70007), + ) + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + # Init the norm layer after positional encoding if required + self.post_pe_norm = post_pe_norm_layer(enc_embed_dim) if post_pe_norm_layer else nn.Identity() + if isinstance(self.post_pe_norm, nn.LayerNorm): + nn.init.constant_(self.post_pe_norm.bias, 0) + nn.init.constant_(self.post_pe_norm.weight, 1.0) + + # Load the pretrained checkpoint if provided + self.pretrained_checkpoint_path = pretrained_checkpoint_path + if self.pretrained_checkpoint_path: + print( + f"Loading custom pretrained Dense Representation Encoder checkpoint from {self.pretrained_checkpoint_path} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def _get_sinusoid_encoding_table(self, n_position, d_hid, base): + "Sinusoid position encoding table" + + def get_position_angle_vec(position): + return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) + + return torch.FloatTensor(sinusoid_table) + + def interpolate_pos_encoding(self, features, height, width): + """ + Interpolate the positional encoding to the expected size. + + Args: + features (torch.Tensor): Input tensor of shape (B, N, C). + height (int, float): Height of the input tensor. + width (int, float): Width of the input tensor. + + Returns: + torch.Tensor: Interpolated positional encoding tensor of shape (1, N, C). + """ + previous_dtype = features.dtype + npatch = features.shape[1] + N = self.pos_embed.unsqueeze(0).shape[1] + if npatch == N and height == width: + return self.pos_embed.unsqueeze(0) + patch_pos_embed = self.pos_embed.unsqueeze(0).float() + dim = features.shape[-1] + height0 = height // self.patch_size + width0 = width // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sh = float(height0 + self.interpolate_offset) / M + sw = float(width0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sh, sw) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (height0, width0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (height0, width0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return patch_pos_embed.to(previous_dtype) + + def forward(self, encoder_input: Union[ViTEncoderInput, ViTEncoderNonImageInput]) -> ViTEncoderOutput: + """ + Dense Representation Encoder Forward Pass + + Args: + encoder_input (Union[ViTEncoderInput, ViTEncoderNonImageInput]): Input data for the encoder. + If input type is ViTEncoderInput, input data must contain image normalization type and normalized image tensor. + If input type is ViTEncoderNonImageInput, input data must contain a tensor of size (B, C, H, W). + + Returns: + ViTEncoderOutput: Output data from the encoder. + """ + # Get the input data and verify normalization if the input type is ViTEncoderInput + if isinstance(encoder_input, ViTEncoderInput): + self._check_data_normalization_type(encoder_input.data_norm_type) + input_data = encoder_input.image + elif isinstance(encoder_input, ViTEncoderNonImageInput): + input_data = encoder_input.data + else: + raise ValueError("Unsupported input type for Dense Representation Encoder.") + + # Check the dtype and shape of the input + assert isinstance(input_data, torch.Tensor), "Input must be a torch.Tensor" + assert input_data.ndim == 4, "Input must be of shape (B, C, H, W)" + assert input_data.shape[1] == self.in_chans, f"Input channels must be {self.in_chans}" + batch_size, channels, height, width = input_data.shape + assert ( + height % self.patch_size == 0 and width % self.patch_size == 0 + ), f"Input shape must be divisible by patch size: {self.patch_size}" + + # Encode the dense representation + features = self.unshuffle(input_data) + features = self.conv_in(features) + features = self.encoder(features) + features = features.flatten(2).transpose( + 1, 2 + ) # (B, E, H / Patch_Size, W / Patch_Size) -> (B, H / Patch_Size * W / Patch_Size, E) + features = self.norm_layer(features) # Normalize the features after patch encoding + + # Apply positional encoding if required + if self.apply_pe: + features = features + self.interpolate_pos_encoding( + features, height, width + ) # (B, H / Patch_Size * W / Patch_Size, E) + features = self.post_pe_norm(features) # Normalize the features after positional encoding + + # Resize the features to the expected shape + # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size) + features = features.permute(0, 2, 1) + features = features.reshape( + -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size + ).contiguous() + + return ViTEncoderOutput(features=features) + + +if __name__ == "__main__": + # Init Dense Representation Encoder for images as input + patch_embedder = DenseRepresentationEncoder( + name="dense_rep_encoder", + data_norm_type="dense_rep_encoder", + input_size_for_pe=518, + patch_size=14, + in_chans=3, + enc_embed_dim=1024, + apply_pe=False, + ) + + # Test dummy image input + dummy_image = torch.randn(1, 3, 518, 518) + patch_embedder_output = patch_embedder(ViTEncoderInput(data_norm_type="dense_rep_encoder", image=dummy_image)) + assert patch_embedder_output.features.shape == ( + 1, + 1024, + 37, + 37, + ), "Output features must have shape (1, 1024, 37, 37)" + + # Init Dense Representation Encoder for non-image data as input + patch_embedder = DenseRepresentationEncoder( + name="dense_rep_encoder", + data_norm_type="dense_rep_encoder", + input_size_for_pe=518, + patch_size=14, + in_chans=6, + enc_embed_dim=1024, + ) + + # Init Dense Representation Encoder for single channel input + patch_embedder = DenseRepresentationEncoder( + name="dense_rep_encoder", + data_norm_type="dense_rep_encoder", + input_size_for_pe=518, + patch_size=14, + in_chans=1, + enc_embed_dim=1024, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + apply_pe=True, + ) + + # Test dummy non-image input + dummy_image = torch.randn(1, 1, 980, 980) + patch_embedder_output = patch_embedder(ViTEncoderNonImageInput(data=dummy_image)) + assert patch_embedder_output.features.shape == ( + 1, + 1024, + 70, + 70, + ), "Output features must have shape (1, 1024, 70, 70)" + + print("All variants of Dense Representation Encoder have been initialized successfully!") diff --git a/UniCeption/uniception/models/encoders/dinov2.py b/UniCeption/uniception/models/encoders/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..3002ef7131d76dc4021a7e27bee389f10d71170c --- /dev/null +++ b/UniCeption/uniception/models/encoders/dinov2.py @@ -0,0 +1,333 @@ +""" +Encoder Class for DINOv2 +""" + +from typing import List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput +from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner + + +class DINOv2Encoder(UniCeptionViTEncoderBase): + "UniCeption DINOv2 Encoder" + + def __init__( + self, + name: str, + data_norm_type: str = "dinov2", + patch_size: int = 14, + size: str = "large", + with_registers: bool = False, + pretrained_checkpoint_path: str = None, + torch_hub_force_reload: bool = False, + gradient_checkpointing: bool = False, + keep_first_n_layers: Optional[int] = None, + use_pytorch_sdpa=True, + *args, + **kwargs, + ): + """ + DINOv2 Encoder for extracting spatial features from images. + + Args: + name (str): Name of the encoder. + data_norm_type (str): Image normalization type. Default: "dinov2" + patch_size (int): Patch size for the encoder. Default: 14 + size (str): Size variant of the DINOv2 model. Options: ["small", "base", "large", "giant"]. Default: "large" + with_registers (bool): Whether to use the DINOv2 model with registers. Default: False + pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of DINOv2. Default: None + torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False + gradient_checkpointing (bool): Whether to use gradient checkpointing to save GPU memory during backward call. Default: False + keep_first_n_layers (Optional[int]): If specified, only the first n layers of the model will be kept. Default: None + use_pytorch_sdpa (bool): Whether to use PyTorch native SDPA for attention layers. Default: True + """ + # Init the base class + name = name if not with_registers else f"{name}_reg" + super().__init__( + name=name, + data_norm_type=data_norm_type, + patch_size=patch_size, + gradient_checkpointing=gradient_checkpointing, + *args, + **kwargs, + ) + + # Init the DINOv2 Encoder specific attributes + self.version = size + self.with_registers = with_registers + self.enc_embed_dim = {"small": 384, "base": 768, "large": 1024, "giant": 1536}[self.version] + + # Define DINOv2 model factory + DINO_MODELS = { + # No registers + False: { + "small": "dinov2_vits14", + "base": "dinov2_vitb14", + "large": "dinov2_vitl14", + "giant": "dinov2_vitg14", + }, + # With registers + True: { + "small": "dinov2_vits14_reg", + "base": "dinov2_vitb14_reg", + "large": "dinov2_vitl14_reg", + "giant": "dinov2_vitg14_reg", + }, + } + + # Load the pretrained DINOv2 model from torch hub + print(f"Loading pretrained {DINO_MODELS[self.with_registers][self.version]} from torch hub") + try: # Requires internet access + self.model = torch.hub.load( + "facebookresearch/dinov2", + DINO_MODELS[self.with_registers][self.version], + force_reload=torch_hub_force_reload, + ) + except: # Load from cache + self.model = torch.hub.load("facebookresearch/dinov2", DINO_MODELS[self.with_registers][self.version]) + + del ( + self.model.mask_token + ) # This parameter is unused in producing patch features, and will lead to unused parameters + + # Keep only the first n layers of the model if keep_first_n_layers is specified + if keep_first_n_layers is not None: + self.model.blocks = nn.ModuleList(self.model.blocks[:keep_first_n_layers]) + + # Use Native Torch SDPA for attention layers if specified (instead of DINOv2's XFormers) + if use_pytorch_sdpa: + self.enable_pytorch_native_sdpa() + + # Wrap the transformer blocks with support for gradient checkpointing if required + if self.gradient_checkpointing: + for i in range(len(self.model.blocks)): + self.model.blocks[i] = self.wrap_module_with_gradient_checkpointing(self.model.blocks[i]) + + # Load the custom pretrained checkpoint if provided + if pretrained_checkpoint_path: + print(f"Loading custom pretrained DINOv2 checkpoint from {pretrained_checkpoint_path}") + ckpt = torch.load(pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def enable_pytorch_native_sdpa(self): + "Enable PyTorch native SDPA for attention layers" + for i in range(len(self.model.blocks)): + self.model.blocks[i].attn = self.wrap_dinov2_attention_with_sdpa(self.model.blocks[i].attn) + + def wrap_dinov2_attention_with_sdpa(self, module: nn.Module): + "Wrap DINOv2 attention module with PyTorch native SDPA" + assert torch.__version__ >= "2.0", "SDPA requires PyTorch 2.0 or later" + + class _AttentionWrapper(module.__class__): + "SDPA Attention Wrapper Class" + + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + ) # (3, B, H, N, C // H) + + q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H) + + x = F.scaled_dot_product_attention(q, k, v, attn_bias) + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + module.__class__ = _AttentionWrapper + return module + + def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput: + """ + DINOv2 Encoder Forward Pass + + Args: + encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor. + + Returns: + ViTEncoderOutput: Output data from the encoder. + """ + # Check image normalization type + self._check_data_normalization_type(encoder_input.data_norm_type) + + # Check the dtype and shape of the input image + assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor" + assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)" + batch_size, channels, height, width = encoder_input.image.shape + assert channels == 3, "Input must have 3 channels" + assert ( + height % self.patch_size == 0 and width % self.patch_size == 0 + ), f"Input shape must be divisible by patch size: {self.patch_size}" + + # Extract the features from the DINOv2 model + features = self.model.forward_features(encoder_input.image)["x_norm_patchtokens"] + + # Resize the features to the expected shape + # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size) + features = features.permute(0, 2, 1) + features = features.reshape( + -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size + ).contiguous() + + return ViTEncoderOutput(features=features) + + +class DINOv2IntermediateFeatureReturner(DINOv2Encoder, IntermediateFeatureReturner): + "Intermediate Feature Returner for UniCeption DINOv2 Encoder" + + def __init__( + self, + name: str, + data_norm_type: str = "dinov2", + patch_size: int = 14, + size: str = "large", + with_registers: bool = False, + pretrained_checkpoint_path: str = None, + indices: Optional[Union[int, List[int]]] = 1, + keep_first_n_layers: Optional[int] = None, + norm_intermediate: bool = True, + *args, + **kwargs, + ): + """ + DINOv2 Encoder for extracting spatial features from images. + + Args: + name (str): Name of the encoder. + data_norm_type (str): Image normalization type. Default: "dinov2" + patch_size (int): Patch size for the encoder. Default: 14 + size (str): Size variant of the DINOv2 model. Options: ["small", "base", "large", "giant"] + with_registers (bool): Whether to use the DINOv2 model with registers. + pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of DINOv2. + indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to 1. Options: + - int: Return the last n layers. + - List[int]: Return the intermediate layers at the specified indices. + keep_first_n_layers (Optional[int], optional): If specified, only the first n layers of the model will be kept. Defaults to None. + norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True. + """ + # Init the base classes + DINOv2Encoder.__init__( + self, + name=name, + data_norm_type=data_norm_type, + patch_size=patch_size, + size=size, + with_registers=with_registers, + keep_first_n_layers=keep_first_n_layers, + pretrained_checkpoint_path=pretrained_checkpoint_path, + *args, + **kwargs, + ) + IntermediateFeatureReturner.__init__( + self, + indices=indices, + norm_intermediate=norm_intermediate, + ) + + def forward(self, encoder_input: ViTEncoderInput) -> List[ViTEncoderOutput]: + """ + DINOv2 Encoder Forward Pass with Intermediate Feature Return + + Args: + encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor. + + Returns: + List[ViTEncoderOutput]: Output data from the encoder. Returns a list of intermediate features. + """ + # Check image normalization type + self._check_data_normalization_type(encoder_input.data_norm_type) + + # Check the dtype and shape of the input image + assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor" + assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)" + batch_size, channels, height, width = encoder_input.image.shape + assert channels == 3, "Input must have 3 channels" + assert ( + height % self.patch_size == 0 and width % self.patch_size == 0 + ), f"Input shape must be divisible by patch size: {self.patch_size}" + + if self.indices is None: + self.indices = range(len(self.model.blocks)) + + # Extract the intermediate features from the DINOv2 model + intermediate_features = self.model.get_intermediate_layers( + encoder_input.image, n=self.indices, reshape=True, norm=self.norm_intermediate + ) + + # Convert the intermediate features to a list of ViTEncoderOutput + intermediate_features = [ViTEncoderOutput(features=features) for features in intermediate_features] + + return intermediate_features + + +if __name__ == "__main__": + # Init different variants of DINOv2 + for size in ["small", "base", "large", "giant"]: + for with_registers in [False, True]: + name = f"dinov2_{size}" + dinov2_encoder = DINOv2Encoder(name=name, size=size, with_registers=with_registers) + + # Init the custom pretrained DINOv2 encoders + for size in ["small", "base", "large"]: + pretrained_checkpoints_dict = { + "small": "../../../checkpoints/encoders/DINOv2_ViTS_DepthAnythingV2.pth", + "base": "../../../checkpoints/encoders/DINOv2_ViTB_DepthAnythingV2.pth", + "large": "../../../checkpoints/encoders/DINOv2_ViTL_DepthAnythingV2.pth", + } + name = f"dinov2_dav2_{size}" + dinov2_encoder = DINOv2Encoder( + name=name, size=size, with_registers=False, pretrained_checkpoint_path=pretrained_checkpoints_dict[size] + ) + + print("All DINOv2 Encoders have been initialized successfully!") + + # Intermediate Feature Returner Tests + print("Running Intermediate Feature Returner Tests...") + + # Run the intermediate feature returner with last-n index + dinov2_intermediate_feature_returner = DINOv2IntermediateFeatureReturner( + name="dinov2_base", size="base", indices=6 + ) # Last 6 layers + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dinov2") + output = dinov2_intermediate_feature_returner(dummy_input) + assert isinstance(output, list), "Output must be a list of intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices" + + # Run the intermediate feature returner with specific indices + dinov2_intermediate_feature_returner = DINOv2IntermediateFeatureReturner( + name="dinov2_base", size="base", indices=[0, 2, 4, 6] + ) # Specific layers + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dinov2") + output = dinov2_intermediate_feature_returner(dummy_input) + assert isinstance(output, list), "Output must be a list of intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices" + + print("All Intermediate Feature Returner Tests have passed successfully!") + + from uniception.models.encoders.utils import profile_encoder + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # Profile the DINOv2 Encoder + dinov2_encoder = DINOv2Encoder( + name="dinov2_large", size="large", use_pytorch_sdpa=True, gradient_checkpointing=True, keep_first_n_layers=12 + ).cuda() + dummy_input = ViTEncoderInput(image=torch.randn(24, 3, 560, 420).cuda(), data_norm_type="dinov2") + + class Profiler: + @profile_encoder(num_warmup=3, num_runs=20, autocast_precision="bfloat16", use_compile=True, dynamic=False) + def run_fn(self): + output = dinov2_encoder(dummy_input) + return output + + profiler = Profiler() + profiler.run_fn() diff --git a/UniCeption/uniception/models/encoders/global_rep_encoder.py b/UniCeption/uniception/models/encoders/global_rep_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..48c52851bb7448a7f1e504136380147d4323ce4d --- /dev/null +++ b/UniCeption/uniception/models/encoders/global_rep_encoder.py @@ -0,0 +1,115 @@ +""" +Encoder class for Global Representation Encoder +""" + +from functools import partial +from typing import Callable, List, Optional, Type, Union + +import torch +import torch.nn as nn + +from uniception.models.encoders.base import EncoderGlobalRepInput, EncoderGlobalRepOutput + + +class GlobalRepresentationEncoder(nn.Module): + "UniCeption Global Representation Encoder" + + def __init__( + self, + name: str, + in_chans: int = 3, + enc_embed_dim: int = 1024, + intermediate_dims: List[int] = [128, 256, 512], + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6), + pretrained_checkpoint_path: Optional[str] = None, + *args, + **kwargs, + ): + """ + Global Representation Encoder for projecting a global representation to a desired latent dimension. + + Args: + name (str): Name of the Encoder. + in_chans (int): Number of input channels. + enc_embed_dim (int): Embedding dimension of the encoder. + intermediate_dims (List[int]): List of intermediate dimensions of the encoder. + act_layer (Type[nn.Module]): Activation layer to use in the encoder. + norm_layer (Union[Type[nn.Module], Callable[..., nn.Module]]): Final normalization layer to use in the encoder. + pretrained_checkpoint_path (Optional[str]): Path to pretrained checkpoint. (default: None) + """ + super().__init__(*args, **kwargs) + + # Initialize the attributes + self.name = name + self.in_chans = in_chans + self.enc_embed_dim = enc_embed_dim + self.intermediate_dims = intermediate_dims + self.pretrained_checkpoint_path = pretrained_checkpoint_path + + # Init the activation layer + self.act_layer = act_layer() + + # Initialize the encoder + self.encoder = nn.Sequential( + nn.Linear(self.in_chans, self.intermediate_dims[0]), + self.act_layer, + ) + for intermediate_idx in range(1, len(self.intermediate_dims)): + self.encoder = nn.Sequential( + self.encoder, + nn.Linear(self.intermediate_dims[intermediate_idx - 1], self.intermediate_dims[intermediate_idx]), + self.act_layer, + ) + self.encoder = nn.Sequential( + self.encoder, + nn.Linear(self.intermediate_dims[-1], self.enc_embed_dim), + ) + + # Init weights of the final norm layer + self.norm_layer = norm_layer(enc_embed_dim) if norm_layer else nn.Identity() + if isinstance(self.norm_layer, nn.LayerNorm): + nn.init.constant_(self.norm_layer.bias, 0) + nn.init.constant_(self.norm_layer.weight, 1.0) + + # Load pretrained weights if provided + if self.pretrained_checkpoint_path is not None: + print( + f"Loading pretrained Global Representation Encoder checkpoint from {self.pretrained_checkpoint_path} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def forward(self, encoder_input: EncoderGlobalRepInput) -> EncoderGlobalRepOutput: + """ + Global Representation Encoder Forward Pass + + Args: + encoder_input (EncoderGlobalRepInput): Input data for the encoder. + The provided data must contain a tensor of size (B, C). + + Returns: + EncoderGlobalRepOutput: Output features from the encoder. + """ + # Get the input data and verify the shape of the input + input_data = encoder_input.data + assert input_data.ndim == 2, "Input data must have shape (B, C)" + assert input_data.shape[1] == self.in_chans, f"Input data must have {self.in_chans} channels" + + # Encode the global representation + features = self.encoder(input_data) + + # Normalize the output + features = self.norm_layer(features) + + return EncoderGlobalRepOutput(features=features) + + +if __name__ == "__main__": + dummy_model = GlobalRepresentationEncoder( + name="dummy", in_chans=3, enc_embed_dim=1024, intermediate_dims=[128, 256, 512] + ) + dummy_input = EncoderGlobalRepInput(data=torch.randn(4, 3)) + dummy_output = dummy_model(dummy_input) + assert dummy_output.features.shape == (4, 1024), "Output features must have shape (B, 1024)" + print("Global Representation Encoder has been initialized successfully!") diff --git a/UniCeption/uniception/models/encoders/image_normalizations.py b/UniCeption/uniception/models/encoders/image_normalizations.py new file mode 100644 index 0000000000000000000000000000000000000000..35e4c0db0ec5d4a5cd7b1c438c23611e6984ca14 --- /dev/null +++ b/UniCeption/uniception/models/encoders/image_normalizations.py @@ -0,0 +1,35 @@ +""" +Image normalizations for the different UniCeption image encoders. +Image encoders defined in UniCeption must have their corresponding image normalization defined here. +""" + +from dataclasses import dataclass + +import torch + + +@dataclass +class ImageNormalization: + mean: torch.Tensor + std: torch.Tensor + + +IMAGE_NORMALIZATION_DICT = { + "dummy": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])), + "croco": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])), + "dust3r": ImageNormalization(mean=torch.tensor([0.5, 0.5, 0.5]), std=torch.tensor([0.5, 0.5, 0.5])), + "dinov2": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])), + "identity": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])), + "patch_embedder": ImageNormalization( + mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]) + ), + "radio": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])), + "sea_raft": ImageNormalization( + mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0]) / 255 + ), # Sea-RAFT uses 0-255 in FP32 + "unimatch": ImageNormalization( + mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0]) / 255 + ), # UniMatch uses 0-255 in FP32 + "roma": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])), + "cosmos": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([0.5, 0.5, 0.5])), +} diff --git a/UniCeption/uniception/models/encoders/list.py b/UniCeption/uniception/models/encoders/list.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4ab13ef18ebe7ec8422837f0ab81165b0469f6 --- /dev/null +++ b/UniCeption/uniception/models/encoders/list.py @@ -0,0 +1,10 @@ +""" +List available UniCeption encoders. +""" + +import argparse + +from uniception.models.encoders import print_available_encoder_models + +if __name__ == "__main__": + print_available_encoder_models() diff --git a/UniCeption/uniception/models/encoders/naradio.py b/UniCeption/uniception/models/encoders/naradio.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddaa88de27f1387c22e5cea19194dc5877f0fa5 --- /dev/null +++ b/UniCeption/uniception/models/encoders/naradio.py @@ -0,0 +1,502 @@ +""" +Encoder Class for NARADIO (RayFronts) +""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention.flex_attention import flex_attention + +from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput +from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner + + +class GaussKernelAttn(nn.Module): + """Implementation of Gaussian Kernel based Attention using FlexAttention""" + + def __init__( + self, + orig_attn, + gauss_std: float, + dim: int, + qk_norm: bool = False, + num_prefix_tokens: int = 8, + patch_size: int = 16, + ) -> None: + super().__init__() + num_heads = orig_attn.num_heads + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + + self.addition_cache = dict() + self.input_resolution = None # to be set when calling forward + self.gauss_std = gauss_std + self.patch_size = patch_size + + self.qkv = orig_attn.qkv + self.q_norm = orig_attn.q_norm if qk_norm else nn.Identity() + self.k_norm = orig_attn.k_norm if qk_norm else nn.Identity() + self.attn_drop = orig_attn.attn_drop + self.proj = orig_attn.proj + self.proj_drop = orig_attn.proj_drop + self.num_prefix_tokens = num_prefix_tokens + + @staticmethod + def gaussian_window(dim1, dim2, std=7.0): + constant = 1 / (std * math.sqrt(2)) + ks = list() + for dim in [dim1, dim2]: + start = -(dim - 1) / 2.0 + k = torch.linspace(start=start * constant, end=(start + (dim - 1)) * constant, steps=dim, dtype=torch.float) + ks.append(k) + dist_square_to_mu = (torch.stack(torch.meshgrid(*ks, indexing="ij")) ** 2).sum(0) + + return torch.exp(-dist_square_to_mu) + + @staticmethod + def get_attention_addition(dim1, dim2, window, num_prefix_tokens=8): + m = torch.einsum("ij,kl->ijkl", torch.eye(dim1), torch.eye(dim2)) + m = m.permute((0, 3, 1, 2)).contiguous() + out = F.conv2d(m.view(-1, dim1, dim2).unsqueeze(1), window.unsqueeze(0).unsqueeze(1), padding="same").squeeze(1) + + out = out.view(dim1 * dim2, dim1 * dim2) + if num_prefix_tokens > 0: + v_adjusted = torch.vstack([torch.zeros((num_prefix_tokens, dim1 * dim2)), out]) + out = torch.hstack([torch.zeros((dim1 * dim2 + num_prefix_tokens, num_prefix_tokens)), v_adjusted]) + + return out + + def prepare_gaussian_addition(self, n_patches, device): + """Prepare the Gaussian addition matrix for the current input""" + # Check if we have a cached addition matrix for these dimensions + if n_patches not in self.addition_cache: + window_size = [side * 2 - 1 for side in n_patches] + window = self.gaussian_window(*window_size, std=self.gauss_std) + addition = self.get_attention_addition(*n_patches, window, self.num_prefix_tokens).to(device) + + # Cache the addition matrix + self.addition_cache[n_patches] = addition + + # Return the cached addition matrix + return self.addition_cache[n_patches] + + def gauss_score_mod(self, score, b, h, q_idx, kv_idx, addition): + """Score modification function for FlexAttention""" + # Adding the precomputed Gaussian pattern to the attention score + return score + addition[q_idx, kv_idx] + + def set_input_resolution(self, input_resolution: Tuple[int, int]): + """Set the input resolution for the Gaussian attention window""" + self.input_resolution = input_resolution + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + assert self.input_resolution is not None, "input_resolution must be set before forward pass" + h, w = self.input_resolution + n_patches = (w // self.patch_size, h // self.patch_size) + + qkv = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) + q, k = self.q_norm(q), self.k_norm(k) + + q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + + addition = self.prepare_gaussian_addition(n_patches, device=x.device) + + # Create a score_mod function with the current addition matrix + score_mod = lambda score, b, h, q_idx, kv_idx: self.gauss_score_mod(score, b, h, q_idx, kv_idx, addition) + + # Use FlexAttention + attn_output = flex_attention(q, k, v, score_mod=score_mod) + + # Reshape output and apply projection + attn_output = attn_output.transpose(1, 2).reshape(B, N, C) + attn_output = self.proj(attn_output) + attn_output = self.proj_drop(attn_output) + + return attn_output + + +class NARADIOEncoder(UniCeptionViTEncoderBase): + """ + UniCeption NARADIO (RayFronts) Encoder based on NACLIP & RADIO + + The model modifies the attention of the last layer of RADIO following NACLIP, + thereby improving the spatial patch features. + """ + + def __init__( + self, + name: str, + data_norm_type: str = "radio", + patch_size: int = 16, + model_version: str = "radio_v2.5-l", + gauss_std: float = 7.0, + pretrained_checkpoint_path: str = None, + eradio_input_shape: Optional[tuple] = None, + torch_hub_force_reload: bool = False, + keep_first_n_layers: Optional[int] = None, + *args, + **kwargs, + ): + """ + NARADIO Encoder for extracting spatial features from images. + + Args: + name (str): Name of the encoder. + data_norm_type (str): Image normalization type. Default: "radio" + patch_size (int): Patch size for the encoder. Default: 16 + model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l" + gauss_std: Standard deviation of the gaussian kernel. Default: 7.0 + pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO. Default: None + eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None + torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False + keep_first_n_layers (Optional[int]): Number of layers to keep from the pretrained model. Default: None + """ + # Init the base class + super().__init__( + name=name, + data_norm_type=data_norm_type, + patch_size=patch_size, + *args, + **kwargs, + ) + + # Init the RADIO Encoder specific attributes + self.model_version = model_version + self.enc_embed_dim = { + "radio_v2.5-b": 768, + "radio_v2.5-l": 1024, + "radio_v2.5-h": 1280, + "radio_v2.5-g": 1536, + "e-radio_v2": 1536, + }[self.model_version] + + if self.model_version == "radio_v2.5-g": + assert patch_size == 14, "Patch size must be 14 for RADIO v2.5-g" + else: + assert patch_size == 16, "Patch size must be 16 for all other versions of RADIO" + + # Load the pretrained RADIO model from torch hub + print(f"Loading pretrained {self.model_version} from torch hub") + try: # Requires internet access + self.model = torch.hub.load( + "NVlabs/RADIO", + "radio_model", + version=self.model_version, + progress=True, + skip_validation=True, + force_reload=torch_hub_force_reload, + ) + except: # Load from cache + self.model = torch.hub.load( + "NVlabs/RADIO", + "radio_model", + version=self.model_version, + progress=True, + skip_validation=True, + ) + + # Delete the excess blocks if keep_first_n_layers is specified + if keep_first_n_layers is not None: + assert keep_first_n_layers < len( + self.model.model.blocks + ), "keep_first_n_layers must be less than the number of blocks" + print(f"Keeping only the first {keep_first_n_layers} layers of the model") + self.model.model.blocks = torch.nn.ModuleList(self.model.model.blocks[:keep_first_n_layers]) + + # Set the optimal window size for E-RADIO models + if "e-radio" in self.model_version: + assert eradio_input_shape is not None, "Input shape (height, width) must be provided for E-RADIO models" + self.model.model.set_optimal_window_size(eradio_input_shape) + + # Load the custom pretrained checkpoint if provided + if pretrained_checkpoint_path is not None: + print(f"Loading custom pretrained NARADIO checkpoint from {pretrained_checkpoint_path}") + ckpt = torch.load(pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + # Replace the attention of the last ViT block with the Gaussian Kernel based attention + self.model.model.blocks[-1] = GaussKernelAttn( + self.model.model.blocks[-1].attn, + gauss_std, + dim=self.enc_embed_dim, + num_prefix_tokens=self.model.num_summary_tokens, + patch_size=self.patch_size, + ) + + def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput: + """ + NARADIO Encoder Forward Pass + + Args: + encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor. + + Returns: + ViTEncoderOutput: Output data from the encoder. + """ + # Check image normalization type + self._check_data_normalization_type(encoder_input.data_norm_type) + + # Check the dtype and shape of the input image + assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor" + assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)" + batch_size, channels, height, width = encoder_input.image.shape + assert channels == 3, "Input must have 3 channels" + assert ( + height % self.patch_size == 0 and width % self.patch_size == 0 + ), f"Input shape must be divisible by patch size: {self.patch_size}" + + # Set input resolution for Gaussian attention + self.model.model.blocks[-1].set_input_resolution((height, width)) + + # Forward pass throught the RADIO encoder + summary, features = self.model(encoder_input.image) + + # Resize the features to the expected shape + # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size) + features = features.permute(0, 2, 1) + features = features.reshape( + -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size + ).contiguous() + + return ViTEncoderOutput(features=features) + + +class NARADIOIntermediateFeatureReturner(NARADIOEncoder, IntermediateFeatureReturner): + "Intermediate Feature Returner for UniCeption NARADIO Encoder" + + def __init__( + self, + name: str, + data_norm_type: str = "radio", + patch_size: int = 16, + model_version: str = "radio_v2.5-l", + gauss_std: float = 7.0, + pretrained_checkpoint_path: str = None, + eradio_input_shape: Optional[tuple] = None, + indices: Union[int, List[int]] = [-1], + norm_intermediate: bool = True, + stop_early: bool = False, + intermediates_only: bool = True, + feature_adaptor: Optional[str] = None, + keep_first_n_layers: Optional[int] = None, + *args, + **kwargs, + ): + """ + Intermediate Feature Returner for the NARADIO Encoder. + + Args: + name (str): Name of the encoder. + data_norm_type (str): Image normalization type. Default: "radio" + patch_size (int): Patch size for the encoder. Default: 16 + model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l" + gauss_std (float): Standard deviation of the gaussian kernel. Default: 7.0 + pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO. + eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None + indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to [-1]. Options: + - int: Return the last n layers. + - List[int]: Return the intermediate layers at the specified indices. + norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True. + stop_early (bool, optional): Whether to stop early. Defaults to False. + intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True. + feature_adaptor (Optional[str], optional): Feature adaptor to use. Defaults to None. Currently supported: "dino_v2". + keep_first_n_layers (Optional[int], optional): Number of layers to keep from the pretrained model. Defaults to None. + """ + # Init the base classes + NARADIOEncoder.__init__( + self, + name=name, + data_norm_type=data_norm_type, + patch_size=patch_size, + model_version=model_version, + gauss_std=gauss_std, + pretrained_checkpoint_path=pretrained_checkpoint_path, + eradio_input_shape=eradio_input_shape, + keep_first_n_layers=keep_first_n_layers, + *args, + **kwargs, + ) + IntermediateFeatureReturner.__init__( + self, + indices=indices, + norm_intermediate=norm_intermediate, + stop_early=stop_early, + intermediates_only=intermediates_only, + ) + + # Convert indices to absolute indices if indices is None + if self.indices is None: + self.indices = list(range(len(self.model.model.blocks))) + + self.feature_adaptor = feature_adaptor + if self.feature_adaptor is None: + pass + elif self.feature_adaptor == "dino_v2": + # Initialize a dummy radio encoder with the adaptor setting + dummy_model = torch.hub.load( + "NVlabs/RADIO", + "radio_model", + version=self.model_version, + progress=True, + skip_validation=True, + adaptor_names="dino_v2", + ) + + # Extract its feature converter weights + self.spatial_feature_converter = dummy_model.adaptors["dino_v2"].feat_mlp + + # Update the embedding dimension because the features have been projected + self.enc_embed_dim = self.spatial_feature_converter.final[-1].out_features + + del dummy_model + else: + raise ValueError("Unsupported feature adaptor. Supported: dino_v2") + + def forward( + self, encoder_input: ViTEncoderInput + ) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: + """ + NARADIO Encoder Forward Pass with Intermediate Feature Return + + Args: + encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor. + + Returns: + Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder. + If `intermediates_only` is True, returns a list of intermediate features. + Otherwise, returns a tuple with the final features and a list of intermediate features. + """ + # Check image normalization type + self._check_data_normalization_type(encoder_input.data_norm_type) + + # Check the dtype and shape of the input image + assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor" + assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)" + batch_size, channels, height, width = encoder_input.image.shape + assert channels == 3, "Input must have 3 channels" + assert ( + height % self.patch_size == 0 and width % self.patch_size == 0 + ), f"Input shape must be divisible by patch size: {self.patch_size}" + + # Set input resolution for Gaussian attention + self.model.model.blocks[-1].set_input_resolution((height, width)) + + # Extract the final features and intermediate features accordingly + model_outputs = self.model.forward_intermediates( + encoder_input.image, + indices=self.indices, + return_prefix_tokens=False, + norm=self.norm_intermediate, + stop_early=self.stop_early, + output_fmt="NLC", + intermediates_only=self.intermediates_only, + ) + + # Extract the final features and intermediate features accordingly + final_features, intermediate_features = None, None + if self.intermediates_only: + intermediate_features = model_outputs + else: + final_features = model_outputs[0].features.contiguous() + intermediate_features = model_outputs[1] + + # Optionally convert the features using the feature adaptor + Hp, Wp = height // self.patch_size, width // self.patch_size + + # Convert final features + if final_features is not None: + if self.feature_adaptor is not None: + final_features = self.spatial_feature_converter(final_features) + + # Convert to BCHW and package + final_features = final_features.view(batch_size, Hp, Wp, -1).permute(0, 3, 1, 2) + final_features = ViTEncoderOutput(features=final_features) + + # Convert intermediate features + if intermediate_features is not None: + num_intermediate = len(intermediate_features) + all_intermediate_feats_tensor = torch.cat(intermediate_features, dim=0) + if self.feature_adaptor is not None: + all_intermediate_feats_tensor = self.spatial_feature_converter(all_intermediate_feats_tensor) + # Convert to BCHW + all_intermediate_feats_tensor = all_intermediate_feats_tensor.view( + num_intermediate * batch_size, Hp, Wp, -1 + ).permute(0, 3, 1, 2) + all_intermediate_feats = torch.chunk(all_intermediate_feats_tensor, num_intermediate, dim=0) + intermediate_features = [ViTEncoderOutput(features=x) for x in all_intermediate_feats] + + # Return the final features and intermediate features accordingly + if self.intermediates_only: + return intermediate_features + else: + return final_features, intermediate_features + + +if __name__ == "__main__": + # Init different versions of the RADIO Encoder + for model_version in ["radio_v2.5-b", "radio_v2.5-l"]: + naradio_encoder = NARADIOEncoder(name="NARADIOv2.5", model_version=model_version) + + print("All NARADIO Encoders have been initialized successfully!") + + # Intermediate Feature Returner Tests + print("Running Intermediate Feature Returner Tests...") + + # Run the intermediate feature returner with last-n index + naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner( + name="NARADIOv2.5", model_version="radio_v2.5-b", indices=6 + ) # Last 6 layers + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio") + output = naradio_intermediate_feature_returner(dummy_input) + assert isinstance(output, list), "Output must be a list of intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices" + + # Run the intermediate feature returner with specific indices + naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner( + name="NARADIOv2.5", model_version="radio_v2.5-b", indices=[0, 2, 4, 6] + ) # Specific layers + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio") + output = naradio_intermediate_feature_returner(dummy_input) + assert isinstance(output, list), "Output must be a list of intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices" + + # Test the normalizing of intermediate features + naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner( + name="NARADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=False, intermediates_only=False + ) # Do not normalize + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio") + output = naradio_intermediate_feature_returner(dummy_input) + assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features" + assert isinstance(output[1], list), "Second element of output must be a list of intermediate features" + assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + if not isinstance(naradio_intermediate_feature_returner.model.model.norm, torch.nn.Identity): + assert not torch.equal( + output[0].features, output[1][0].features + ), "Final features and intermediate features must be different" + + naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner( + name="NARADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=True, intermediates_only=False + ) + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio") + output = naradio_intermediate_feature_returner(dummy_input) + assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features" + assert isinstance(output[1], list), "Second element of output must be a list of intermediate features" + assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + assert torch.equal( + output[0].features, output[1][0].features + ), "Final features and intermediate features must be same" + + print("All Intermediate Feature Returner Tests have passed successfully!") diff --git a/UniCeption/uniception/models/encoders/patch_embedder.py b/UniCeption/uniception/models/encoders/patch_embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d7425bd4116b35934934ee47c414fde3f4950b --- /dev/null +++ b/UniCeption/uniception/models/encoders/patch_embedder.py @@ -0,0 +1,235 @@ +""" +Encoder class for Patch Embedder +""" + +import math +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ + +from uniception.models.encoders.base import ( + UniCeptionViTEncoderBase, + ViTEncoderInput, + ViTEncoderNonImageInput, + ViTEncoderOutput, +) + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbedder(UniCeptionViTEncoderBase): + "UniCeption Patch Embedder" + + def __init__( + self, + name: str, + data_norm_type: str = "patch_embedder", + input_size: Union[int, Tuple[int, int]] = 518, + patch_size: int = 14, + in_chans: int = 3, + enc_embed_dim: int = 1024, + norm_layer: Optional[Callable] = None, + post_pe_norm_layer: Optional[Callable] = partial(nn.LayerNorm, eps=1e-6), + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained_checkpoint_path: str = None, + *args, + **kwargs, + ): + """ + Patch Encoder for extracting patch-wise features from a spatial input of size (B, C, H, W). + Learnable positional encoding is also applied to the patch-wise features. + """ + # Init the base class + super().__init__( + name=name, + data_norm_type=data_norm_type, + patch_size=patch_size, + *args, + **kwargs, + ) + + # Init the Patch Embedder specific attributes + patch_HW = make_2tuple(patch_size) + self.input_size = make_2tuple(input_size) + self.patches_resolution = (self.input_size[0] // patch_HW[0], self.input_size[1] // patch_HW[1]) + self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + self.in_chans = in_chans + self.enc_embed_dim = enc_embed_dim + + # Init the Patch Embedder layers + self.proj = nn.Conv2d(in_chans, enc_embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(enc_embed_dim) if norm_layer else nn.Identity() + + # Init the learnable positional encodings + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.enc_embed_dim)) + trunc_normal_(self.pos_embed, std=0.02) + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + # Init the norm layer after positional encoding + self.post_pe_norm = post_pe_norm_layer(enc_embed_dim) if post_pe_norm_layer else nn.Identity() + + # Load the pretrained checkpoint if provided + self.pretrained_checkpoint_path = pretrained_checkpoint_path + if self.pretrained_checkpoint_path: + print(f"Loading custom pretrained Patch Embedder checkpoint from {self.pretrained_checkpoint_path} ...") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def interpolate_pos_encoding(self, features, height, width): + """ + Interpolate the positional encoding to the expected size. + + Args: + features (torch.Tensor): Input tensor of shape (B, N, C). + height (int, float): Height of the input tensor. + width (int, float): Width of the input tensor. + + Returns: + torch.Tensor: Interpolated positional encoding tensor of shape (1, N, C). + """ + previous_dtype = features.dtype + npatch = features.shape[1] + N = self.pos_embed.shape[1] + if npatch == N and height == width: + return self.pos_embed + patch_pos_embed = self.pos_embed.float() + dim = features.shape[-1] + height0 = height // self.patch_size + width0 = width // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sh = float(height0 + self.interpolate_offset) / M + sw = float(width0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sh, sw) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (height0, width0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (height0, width0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return patch_pos_embed.to(previous_dtype) + + def forward(self, encoder_input: Union[ViTEncoderInput, ViTEncoderNonImageInput]) -> ViTEncoderOutput: + """ + Patch Embedder Forward Pass + + Args: + encoder_input (Union[ViTEncoderInput, ViTEncoderNonImageInput]): Input data for the encoder. + If input type is ViTEncoderInput, input data must contain image normalization type and normalized image tensor. + If input type is ViTEncoderNonImageInput, input data must contain a tensor of size (B, C, H, W). + + Returns: + ViTEncoderOutput: Output data from the encoder. + """ + # Get the input data and verify normalization if the input type is ViTEncoderInput + if isinstance(encoder_input, ViTEncoderInput): + self._check_data_normalization_type(encoder_input.data_norm_type) + input_data = encoder_input.image + elif isinstance(encoder_input, ViTEncoderNonImageInput): + input_data = encoder_input.data + else: + raise ValueError("Unsupported input type for Patch Embedder.") + + # Check the dtype and shape of the input + assert isinstance(input_data, torch.Tensor), "Input must be a torch.Tensor" + assert input_data.ndim == 4, "Input must be of shape (B, C, H, W)" + batch_size, channels, height, width = input_data.shape + assert ( + height % self.patch_size == 0 and width % self.patch_size == 0 + ), f"Input shape must be divisible by patch size: {self.patch_size}" + + # Patchify the input data and project into expected latent space + features = self.proj(input_data) # (B, C, H, W) -> (B, E, H / Patch_Size, W / Patch_Size) + features = features.flatten(2).transpose( + 1, 2 + ) # (B, E, H / Patch_Size, W / Patch_Size) -> (B, H / Patch_Size * W / Patch_Size, E) + features = self.norm(features) # Normalize the features after patch embedding + features = features + self.interpolate_pos_encoding( + features, height, width + ) # (B, H / Patch_Size * W / Patch_Size, E) + features = self.post_pe_norm(features) # Normalize the features after positional encoding + + # Resize the features to the expected shape + # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size) + features = features.permute(0, 2, 1) + features = features.reshape( + -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size + ).contiguous() + + return ViTEncoderOutput(features=features) + + +if __name__ == "__main__": + # Init Patch Embedder for images as input + patch_embedder = PatchEmbedder( + name="patch_embedder", + data_norm_type="patch_embedder", + input_size=518, + patch_size=14, + in_chans=3, + enc_embed_dim=1024, + ) + + # Test dummy image input + dummy_image = torch.randn(1, 3, 518, 518) + patch_embedder_output = patch_embedder(ViTEncoderInput(data_norm_type="patch_embedder", image=dummy_image)) + assert patch_embedder_output.features.shape == ( + 1, + 1024, + 37, + 37, + ), "Output features must have shape (1, 1024, 37, 37)" + + # Init Patch Embedder for non-image data as input + patch_embedder = PatchEmbedder( + name="patch_embedder", + data_norm_type="patch_embedder", + input_size=518, + patch_size=14, + in_chans=6, + enc_embed_dim=1024, + ) + + # Init Patch Embedder for single channel input + patch_embedder = PatchEmbedder( + name="patch_embedder", + data_norm_type="patch_embedder", + input_size=518, + patch_size=14, + in_chans=1, + enc_embed_dim=1024, + ) + + # Test dummy non-image input + dummy_image = torch.randn(1, 1, 518, 518) + patch_embedder_output = patch_embedder(ViTEncoderNonImageInput(data=dummy_image)) + assert patch_embedder_output.features.shape == ( + 1, + 1024, + 37, + 37, + ), "Output features must have shape (1, 1024, 37, 37)" + + print("All variants of Patch Embedder have been initialized successfully!") diff --git a/UniCeption/uniception/models/encoders/radio.py b/UniCeption/uniception/models/encoders/radio.py new file mode 100644 index 0000000000000000000000000000000000000000..59cc3de409931b1ae299597fbe6288d5248937cc --- /dev/null +++ b/UniCeption/uniception/models/encoders/radio.py @@ -0,0 +1,367 @@ +""" +Encoder Class for RADIO (Nvidia) +""" + +from typing import List, Optional, Tuple, Union + +import torch + +from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput +from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner + + +class RADIOEncoder(UniCeptionViTEncoderBase): + "UniCeption RADIO Encoder" + + def __init__( + self, + name: str, + data_norm_type: str = "radio", + patch_size: int = 16, + model_version: str = "radio_v2.5-l", + pretrained_checkpoint_path: str = None, + eradio_input_shape: Optional[tuple] = None, + torch_hub_force_reload: bool = False, + keep_first_n_layers: Optional[int] = None, + *args, + **kwargs, + ): + """ + RADIO Encoder for extracting spatial features from images. + + Args: + name (str): Name of the encoder. + data_norm_type (str): Image normalization type. Default: "radio" + patch_size (int): Patch size for the encoder. Default: 16 + model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l" + pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO. Default: None + eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None + torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False + keep_first_n_layers (Optional[int]): Number of layers to keep from the pretrained model. Default: None + """ + # Init the base class + super().__init__( + name=name, + data_norm_type=data_norm_type, + patch_size=patch_size, + *args, + **kwargs, + ) + + # Init the RADIO Encoder specific attributes + self.model_version = model_version + self.enc_embed_dim = { + "radio_v2.5-b": 768, + "radio_v2.5-l": 1024, + "radio_v2.5-h": 1280, + "radio_v2.5-g": 1536, + "e-radio_v2": 1536, + }[self.model_version] + + if self.model_version == "radio_v2.5-g": + assert patch_size == 14, "Patch size must be 14 for RADIO v2.5-g" + else: + assert patch_size == 16, "Patch size must be 16 for all other versions of RADIO" + + # Load the pretrained RADIO model from torch hub + print(f"Loading pretrained {self.model_version} from torch hub") + try: # Requires internet access + self.model = torch.hub.load( + "NVlabs/RADIO", + "radio_model", + version=self.model_version, + progress=True, + skip_validation=True, + force_reload=torch_hub_force_reload, + ) + except: # Load from cache + self.model = torch.hub.load( + "NVlabs/RADIO", + "radio_model", + version=self.model_version, + progress=True, + skip_validation=True, + ) + + # Delete the excess blocks if keep_first_n_layers is specified + if keep_first_n_layers is not None: + assert keep_first_n_layers < len( + self.model.model.blocks + ), "keep_first_n_layers must be less than the number of blocks" + print(f"Keeping only the first {keep_first_n_layers} layers of the model") + self.model.model.blocks = torch.nn.ModuleList(self.model.model.blocks[:keep_first_n_layers]) + + # Set the optimal window size for E-RADIO models + if "e-radio" in self.model_version: + assert eradio_input_shape is not None, "Input shape (height, width) must be provided for E-RADIO models" + self.model.model.set_optimal_window_size(eradio_input_shape) + + # Load the custom pretrained checkpoint if provided + if pretrained_checkpoint_path is not None: + print(f"Loading custom pretrained RADIO checkpoint from {pretrained_checkpoint_path}") + ckpt = torch.load(pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput: + """ + RADIO Encoder Forward Pass + + Args: + encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor. + + Returns: + ViTEncoderOutput: Output data from the encoder. + """ + # Check image normalization type + self._check_data_normalization_type(encoder_input.data_norm_type) + + # Check the dtype and shape of the input image + assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor" + assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)" + batch_size, channels, height, width = encoder_input.image.shape + assert channels == 3, "Input must have 3 channels" + assert ( + height % self.patch_size == 0 and width % self.patch_size == 0 + ), f"Input shape must be divisible by patch size: {self.patch_size}" + + # Forward pass throught the RADIO encoder + summary, features = self.model(encoder_input.image) + + # Resize the features to the expected shape + # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size) + features = features.permute(0, 2, 1) + features = features.reshape( + -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size + ).contiguous() + + return ViTEncoderOutput(features=features) + + +class RADIOIntermediateFeatureReturner(RADIOEncoder, IntermediateFeatureReturner): + "Intermediate Feature Returner for UniCeption RADIO Encoder" + + def __init__( + self, + name: str, + data_norm_type: str = "radio", + patch_size: int = 16, + model_version: str = "radio_v2.5-l", + pretrained_checkpoint_path: str = None, + eradio_input_shape: Optional[tuple] = None, + indices: Union[int, List[int]] = [-1], + norm_intermediate: bool = True, + stop_early: bool = False, + intermediates_only: bool = True, + feature_adaptor: Optional[str] = None, + keep_first_n_layers: Optional[int] = None, + *args, + **kwargs, + ): + """ + Intermediate Feature Returner for the RADIO Encoder. + + Args: + name (str): Name of the encoder. + data_norm_type (str): Image normalization type. Default: "radio" + patch_size (int): Patch size for the encoder. Default: 16 + model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l" + pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO. + eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None + indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to [-1]. Options: + - int: Return the last n layers. + - List[int]: Return the intermediate layers at the specified indices. + norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True. + stop_early (bool, optional): Whether to stop early. Defaults to False. + intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True. + feature_adaptor (Optional[str], optional): Feature adaptor to use. Defaults to None. Currently supported: "dino_v2". + keep_first_n_layers (Optional[int], optional): Number of layers to keep from the pretrained model. Defaults to None. + """ + # Init the base classes + RADIOEncoder.__init__( + self, + name=name, + data_norm_type=data_norm_type, + patch_size=patch_size, + model_version=model_version, + pretrained_checkpoint_path=pretrained_checkpoint_path, + eradio_input_shape=eradio_input_shape, + keep_first_n_layers=keep_first_n_layers, + *args, + **kwargs, + ) + IntermediateFeatureReturner.__init__( + self, + indices=indices, + norm_intermediate=norm_intermediate, + stop_early=stop_early, + intermediates_only=intermediates_only, + ) + + # Convert indices to absolute indices if indices is None + if self.indices is None: + self.indices = list(range(len(self.model.model.blocks))) + + self.feature_adaptor = feature_adaptor + if self.feature_adaptor is None: + pass + elif self.feature_adaptor == "dino_v2": + # Initialize a dummy radio encoder with the adaptor setting + dummy_model = torch.hub.load( + "NVlabs/RADIO", + "radio_model", + version=self.model_version, + progress=True, + skip_validation=True, + adaptor_names="dino_v2", + ) + + # Extract its feature converter weights + self.spatial_feature_converter = dummy_model.adaptors["dino_v2"].feat_mlp + + # Update the embedding dimension because the features have been projected + self.enc_embed_dim = self.spatial_feature_converter.final[-1].out_features + + del dummy_model + else: + raise ValueError("Unsupported feature adaptor. Supported: dino_v2") + + def forward( + self, encoder_input: ViTEncoderInput + ) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: + """ + RADIO Encoder Forward Pass with Intermediate Feature Return + + Args: + encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor. + + Returns: + Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder. + If `intermediates_only` is True, returns a list of intermediate features. + Otherwise, returns a tuple with the final features and a list of intermediate features. + """ + # Check image normalization type + self._check_data_normalization_type(encoder_input.data_norm_type) + + # Check the dtype and shape of the input image + assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor" + assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)" + batch_size, channels, height, width = encoder_input.image.shape + assert channels == 3, "Input must have 3 channels" + assert ( + height % self.patch_size == 0 and width % self.patch_size == 0 + ), f"Input shape must be divisible by patch size: {self.patch_size}" + + # Extract the final features and intermediate features accordingly + model_outputs = self.model.forward_intermediates( + encoder_input.image, + indices=self.indices, + return_prefix_tokens=False, + norm=self.norm_intermediate, + stop_early=self.stop_early, + output_fmt="NLC", + intermediates_only=self.intermediates_only, + ) + + # Extract the final features and intermediate features accordingly + final_features, intermediate_features = None, None + if self.intermediates_only: + intermediate_features = model_outputs + else: + final_features = model_outputs[0].features.contiguous() + intermediate_features = model_outputs[1] + + # Optionally convert the features using the feature adaptor + Hp, Wp = height // self.patch_size, width // self.patch_size + + # Convert final features + if final_features is not None: + if self.feature_adaptor is not None: + final_features = self.spatial_feature_converter(final_features) + + # Convert to BCHW and package + final_features = final_features.view(batch_size, Hp, Wp, -1).permute(0, 3, 1, 2) + final_features = ViTEncoderOutput(features=final_features) + + # Convert intermediate features + if intermediate_features is not None: + num_intermediate = len(intermediate_features) + all_intermediate_feats_tensor = torch.cat(intermediate_features, dim=0) + if self.feature_adaptor is not None: + all_intermediate_feats_tensor = self.spatial_feature_converter(all_intermediate_feats_tensor) + # Convert to BCHW + all_intermediate_feats_tensor = all_intermediate_feats_tensor.view( + num_intermediate * batch_size, Hp, Wp, -1 + ).permute(0, 3, 1, 2) + all_intermediate_feats = torch.chunk(all_intermediate_feats_tensor, num_intermediate, dim=0) + intermediate_features = [ViTEncoderOutput(features=x) for x in all_intermediate_feats] + + # Return the final features and intermediate features accordingly + if self.intermediates_only: + return intermediate_features + else: + return final_features, intermediate_features + + +if __name__ == "__main__": + # Init different versions of the RADIO Encoder + for model_version in ["radio_v2.5-b", "radio_v2.5-l"]: + radio_encoder = RADIOEncoder(name="RADIOv2.5", model_version=model_version) + + # Init the E-RADIO Encoder + eradio_input_shape = (512, 512) + eradio_encoder = RADIOEncoder(name="E-RADIO", model_version="e-radio_v2", eradio_input_shape=eradio_input_shape) + + print("All RADIO Encoders have been initialized successfully!") + + # Intermediate Feature Returner Tests + print("Running Intermediate Feature Returner Tests...") + + # Run the intermediate feature returner with last-n index + radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner( + name="RADIOv2.5", model_version="radio_v2.5-b", indices=6 + ) # Last 6 layers + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio") + output = radio_intermediate_feature_returner(dummy_input) + assert isinstance(output, list), "Output must be a list of intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices" + + # Run the intermediate feature returner with specific indices + radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner( + name="RADIOv2.5", model_version="radio_v2.5-b", indices=[0, 2, 4, 6] + ) # Specific layers + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio") + output = radio_intermediate_feature_returner(dummy_input) + assert isinstance(output, list), "Output must be a list of intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices" + + # Test the normalizing of intermediate features + radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner( + name="RADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=False, intermediates_only=False + ) # Do not normalize + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio") + output = radio_intermediate_feature_returner(dummy_input) + assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features" + assert isinstance(output[1], list), "Second element of output must be a list of intermediate features" + assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + if not isinstance(radio_intermediate_feature_returner.model.model.norm, torch.nn.Identity): + assert not torch.equal( + output[0].features, output[1][0].features + ), "Final features and intermediate features must be different" + + radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner( + name="RADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=True, intermediates_only=False + ) + dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio") + output = radio_intermediate_feature_returner(dummy_input) + assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features" + assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features" + assert isinstance(output[1], list), "Second element of output must be a list of intermediate features" + assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput" + assert torch.equal( + output[0].features, output[1][0].features + ), "Final features and intermediate features must be same" + + print("All Intermediate Feature Returner Tests have passed successfully!") diff --git a/UniCeption/uniception/models/encoders/utils.py b/UniCeption/uniception/models/encoders/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af2e67518e02e34a59742e88d55c2011aaa100ed --- /dev/null +++ b/UniCeption/uniception/models/encoders/utils.py @@ -0,0 +1,86 @@ +""" +Utility functions for UniCeption Encoders. +""" + +import functools + +import numpy as np +import torch + + +def profile_encoder(num_warmup=3, num_runs=20, autocast_precision="float16", use_compile=False, dynamic=True): + def decorator(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + device = "cuda" + autocast_dtype = getattr(torch, autocast_precision) + + # Compile the model if requested + if use_compile: + compiled_func = torch.compile(func, dynamic=dynamic, mode="max-autotune") + else: + compiled_func = func + + with torch.autocast("cuda", dtype=autocast_dtype): + # Warm-up runs + for _ in range(num_warmup): + output = compiled_func(self, *args, **kwargs) + if isinstance(output, torch.Tensor): + output.sum().backward() + else: + output.features.sum().backward() + torch.cuda.synchronize() + + # Clear memory cache + torch.cuda.empty_cache() + + # Lists to store results + forward_times, backward_times, memory_usages = [], [], [] + + for _ in range(num_runs): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + torch.cuda.reset_peak_memory_stats() + memory_before = torch.cuda.max_memory_allocated(device) + + # Forward pass + start_event.record() + output = compiled_func(self, *args, **kwargs) + end_event.record() + torch.cuda.synchronize() + forward_times.append(start_event.elapsed_time(end_event)) + + # Backward pass + start_event.record() + if isinstance(output, torch.Tensor): + output.sum().backward() + else: + output.features.sum().backward() + end_event.record() + torch.cuda.synchronize() + backward_times.append(start_event.elapsed_time(end_event)) + + memory_after = torch.cuda.max_memory_allocated(device) + memory_usages.append((memory_after - memory_before) / 1e6) # Convert to MB + + # Compute mean and standard deviation + fwd_mean, fwd_std = np.mean(forward_times), np.std(forward_times) + bwd_mean, bwd_std = np.mean(backward_times), np.std(backward_times) + mem_mean, mem_std = np.mean(memory_usages), np.std(memory_usages) + + compile_status = ( + "with torch.compile (dynamic=True)" + if use_compile and dynamic + else "with torch.compile (dynamic=False)" if use_compile else "without torch.compile" + ) + print(f"Profiling results {compile_status}:") + print(f"Forward Pass Time: {fwd_mean:.2f} ± {fwd_std:.2f} ms") + print(f"Backward Pass Time: {bwd_mean:.2f} ± {bwd_std:.2f} ms") + print(f"Peak GPU Memory Usage: {mem_mean:.2f} ± {mem_std:.2f} MB") + + return output + + return wrapper + + return decorator diff --git a/UniCeption/uniception/models/factory/__init__.py b/UniCeption/uniception/models/factory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..477ee9468bff13a3430759acd0ddc3282016fdb1 --- /dev/null +++ b/UniCeption/uniception/models/factory/__init__.py @@ -0,0 +1,3 @@ +from uniception.models.factory.dust3r import DUSt3R + +__all__ = ["DUSt3R"] diff --git a/UniCeption/uniception/models/factory/dust3r.py b/UniCeption/uniception/models/factory/dust3r.py new file mode 100644 index 0000000000000000000000000000000000000000..8f1cde4d1f450c0364eb651c5f36ef929e69e503 --- /dev/null +++ b/UniCeption/uniception/models/factory/dust3r.py @@ -0,0 +1,332 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn + +from uniception.models.encoders import ViTEncoderInput +from uniception.models.encoders.croco import CroCoEncoder +from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT +from uniception.models.info_sharing.base import MultiViewTransformerInput +from uniception.models.info_sharing.cross_attention_transformer import ( + MultiViewCrossAttentionTransformer, + MultiViewCrossAttentionTransformerIFR, +) +from uniception.models.libs.croco.pos_embed import RoPE2D, get_2d_sincos_pos_embed +from uniception.models.prediction_heads.adaptors import PointMapWithConfidenceAdaptor +from uniception.models.prediction_heads.base import AdaptorInput, PredictionHeadInput, PredictionHeadLayeredInput +from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor +from uniception.models.prediction_heads.linear import LinearFeature + + +def is_symmetrized(gt1, gt2): + "Function to check if input pairs are symmetrized, i.e., (a, b) and (b, a) always exist in the input" + x = gt1["instance"] + y = gt2["instance"] + if len(x) == len(y) and len(x) == 1: + return False # special case of batchsize 1 + ok = True + for i in range(0, len(x), 2): + ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i]) + return ok + + +def interleave(tensor1, tensor2): + "Interleave two tensors along the first dimension (used to avoid redundant encoding for symmetrized pairs)" + res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) + res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) + return res1, res2 + + +class DUSt3R(nn.Module): + "DUSt3R defined with UniCeption Modules" + + def __init__( + self, + name: str, + data_norm_type: str = "dust3r", + img_size: tuple = (224, 224), + patch_embed_cls: str = "PatchEmbedDust3R", + pred_head_type: str = "linear", + pred_head_output_dim: int = 4, + pred_head_feature_dim: int = 256, + depth_mode: Tuple[str, float, float] = ("exp", -float("inf"), float("inf")), + conf_mode: Tuple[str, float, float] = ("exp", 1, float("inf")), + pos_embed: str = "RoPE100", + pretrained_checkpoint_path: str = None, + pretrained_encoder_checkpoint_path: str = None, + pretrained_info_sharing_checkpoint_path: str = None, + pretrained_pred_head_checkpoint_paths: List[str] = [None, None], + pretrained_pred_head_regressor_checkpoint_paths: List[str] = [None, None], + override_encoder_checkpoint_attributes: bool = False, + *args, + **kwargs, + ): + """ + Two-view model containing siamese encoders followed by a two-view cross-attention transformer and respective downstream heads. + The goal is to output scene representation directly, both images in view1's frame (hence the asymmetry). + + Args: + name (str): Name of the model. + data_norm_type (str): Type of data normalization. (default: "dust3r") + img_size (tuple): Size of input images. (default: (224, 224)) + patch_embed_cls (str): Class for patch embedding. (default: "PatchEmbedDust3R"). Options: + - "PatchEmbedDust3R" + - "ManyAR_PatchEmbed" + pred_head_type (str): Type of prediction head. (default: "linear"). Options: + - "linear" + - "dpt" + pred_head_output_dim (int): Output dimension of prediction head. (default: 4) + pred_head_feature_dim (int): Feature dimension of prediction head. (default: 256) + depth_mode (Tuple[str, float, float]): Depth mode settings (mode=['linear', 'square', 'exp'], vmin, vmax). (default: ('exp', -inf, inf)) + conf_mode (Tuple[str, float, float]): Confidence mode settings (mode=['linear', 'square', 'exp'], vmin, vmax). (default: ('exp', 1, inf)) + pos_embed (str): Position embedding type. (default: 'RoPE100') + landscape_only (bool): Run downstream head only in landscape orientation. (default: True) + pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None) + pretrained_encoder_checkpoint_path (str): Path to pretrained encoder checkpoint. (default: None) + pretrained_info_sharing_checkpoint_path (str): Path to pretrained info_sharing checkpoint. (default: None) + pretrained_pred_head_checkpoint_paths (List[str]): Paths to pretrained prediction head checkpoints. (default: None) + pretrained_pred_head_regressor_checkpoint_paths (List[str]): Paths to pretrained prediction head regressor checkpoints. (default: None) + override_encoder_checkpoint_attributes (bool): Whether to override encoder checkpoint attributes. (default: False) + """ + super().__init__(*args, **kwargs) + + # Initalize the attributes + self.name = name + self.data_norm_type = data_norm_type + self.img_size = img_size + self.patch_embed_cls = patch_embed_cls + self.pred_head_type = pred_head_type + self.pred_head_output_dim = pred_head_output_dim + self.depth_mode = depth_mode + self.conf_mode = conf_mode + self.pos_embed = pos_embed + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.pretrained_encoder_checkpoint_path = pretrained_encoder_checkpoint_path + self.pretrained_info_sharing_checkpoint_path = pretrained_info_sharing_checkpoint_path + self.pretrained_pred_head_checkpoint_paths = pretrained_pred_head_checkpoint_paths + self.pretrained_pred_head_regressor_checkpoint_paths = pretrained_pred_head_regressor_checkpoint_paths + self.override_encoder_checkpoint_attributes = override_encoder_checkpoint_attributes + + # Initialize RoPE for the CroCo Encoder & Two-View Cross Attention Transformer + freq = float(pos_embed[len("RoPE") :]) + self.rope = RoPE2D(freq=freq) + + # Initialize Encoder + self.encoder = CroCoEncoder( + name=name, + data_norm_type=data_norm_type, + patch_embed_cls=patch_embed_cls, + img_size=img_size, + pretrained_checkpoint_path=pretrained_encoder_checkpoint_path, + override_checkpoint_attributes=override_encoder_checkpoint_attributes, + ) + + # Initialize Multi-View Cross Attention Transformer + if self.pred_head_type == "linear": + # Returns only normalized last layer features + self.info_sharing = MultiViewCrossAttentionTransformer( + name="base_info_sharing", + input_embed_dim=self.encoder.enc_embed_dim, + num_views=2, + custom_positional_encoding=self.rope, + pretrained_checkpoint_path=pretrained_info_sharing_checkpoint_path, + ) + elif self.pred_head_type == "dpt": + # Returns intermediate features and normalized last layer features + self.info_sharing = MultiViewCrossAttentionTransformerIFR( + name="base_info_sharing", + input_embed_dim=self.encoder.enc_embed_dim, + num_views=2, + indices=[5, 8], + norm_intermediate=False, + custom_positional_encoding=self.rope, + pretrained_checkpoint_path=pretrained_info_sharing_checkpoint_path, + ) + else: + raise ValueError(f"Invalid prediction head type: {pred_head_type}. Must be 'linear' or 'dpt'.") + + # Initialize Prediction Heads + if pred_head_type == "linear": + # Initialize Prediction Head 1 + self.head1 = LinearFeature( + input_feature_dim=self.info_sharing.dim, + output_dim=pred_head_output_dim, + patch_size=self.encoder.patch_size, + pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[0], + ) + # Initialize Prediction Head 2 + self.head2 = LinearFeature( + input_feature_dim=self.info_sharing.dim, + output_dim=pred_head_output_dim, + patch_size=self.encoder.patch_size, + pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[1], + ) + elif pred_head_type == "dpt": + # Initialze Predction Head 1 + self.dpt_feature_head1 = DPTFeature( + patch_size=self.encoder.patch_size, + hooks=[0, 1, 2, 3], + input_feature_dims=[self.encoder.enc_embed_dim] + [self.info_sharing.dim] * 3, + feature_dim=pred_head_feature_dim, + pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[0], + ) + self.dpt_regressor_head1 = DPTRegressionProcessor( + input_feature_dim=pred_head_feature_dim, + output_dim=pred_head_output_dim, + pretrained_checkpoint_path=pretrained_pred_head_regressor_checkpoint_paths[0], + ) + self.head1 = nn.Sequential(self.dpt_feature_head1, self.dpt_regressor_head1) + # Initialize Prediction Head 2 + self.dpt_feature_head2 = DPTFeature( + patch_size=self.encoder.patch_size, + hooks=[0, 1, 2, 3], + input_feature_dims=[self.encoder.enc_embed_dim] + [self.info_sharing.dim] * 3, + feature_dim=pred_head_feature_dim, + pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[1], + ) + self.dpt_regressor_head2 = DPTRegressionProcessor( + input_feature_dim=pred_head_feature_dim, + output_dim=pred_head_output_dim, + pretrained_checkpoint_path=pretrained_pred_head_regressor_checkpoint_paths[1], + ) + self.head2 = nn.Sequential(self.dpt_feature_head2, self.dpt_regressor_head2) + + # Initialize Final Output Adaptor + self.adaptor = PointMapWithConfidenceAdaptor( + name="pointmap", + pointmap_mode=depth_mode[0], + pointmap_vmin=depth_mode[1], + pointmap_vmax=depth_mode[2], + confidence_type=conf_mode[0], + confidence_vmin=conf_mode[1], + confidence_vmax=conf_mode[2], + ) + + # Load pretrained weights + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained DUSt3R weights from {self.pretrained_checkpoint_path} ...") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def _encode_image_pairs(self, img1, img2, data_norm_type): + "Encode two different batches of images (each batch can have different image shape)" + if img1.shape[-2:] == img2.shape[-2:]: + encoder_input = ViTEncoderInput(image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type) + encoder_output = self.encoder(encoder_input) + out, out2 = encoder_output.features.chunk(2, dim=0) + else: + encoder_input = ViTEncoderInput(image=img1, data_norm_type=data_norm_type) + out = self.encoder(encoder_input) + out = out.features + encoder_input2 = ViTEncoderInput(image=img2) + out2 = self.encoder(encoder_input2) + out2 = out2.features + + return out, out2 + + def _encode_symmetrized(self, view1, view2): + "Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input" + img1 = view1["img"] + img2 = view2["img"] + if is_symmetrized(view1, view2): + # Computing half of forward pass' + feat1, feat2 = self._encode_image_pairs(img1[::2], img2[::2], data_norm_type=view1["data_norm_type"]) + feat1, feat2 = interleave(feat1, feat2) + else: + feat1, feat2 = self._encode_image_pairs(img1, img2, data_norm_type=view1["data_norm_type"]) + + return feat1, feat2 + + def _downstream_head(self, head_num, decout, img_shape): + "Run the respective prediction heads" + head = getattr(self, f"head{head_num}") + if self.pred_head_type == "linear": + head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"]) + elif self.pred_head_type == "dpt": + head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape) + + return head(head_input) + + def forward(self, view1, view2): + """ + Forward pass for DUSt3R performing the following operations: + 1. Encodes the two input views (images). + 2. Combines the encoded features using a two-view cross-attention transformer. + 3. Passes the combined features through the respective prediction heads. + 4. Returns the processed final outputs for both views. + + Args: + view1 (dict): Dictionary containing the first view's images and instance information. + "img" is a required key and value is a tensor of shape (B, C, H, W). + view2 (dict): Dictionary containing the second view's images and instance information. + "img" is a required key and value is a tensor of shape (B, C, H, W). + + Returns: + Tuple[dict, dict]: A tuple containing the final outputs for both views. + """ + # Get input shapes + _, _, height1, width1 = view1["img"].shape + _, _, height2, width2 = view2["img"].shape + shape1 = (int(height1), int(width1)) + shape2 = (int(height2), int(width2)) + + # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width) + feat1, feat2 = self._encode_symmetrized(view1, view2) + + # Combine all images into view-centric representation + info_sharing_input = MultiViewTransformerInput(features=[feat1, feat2]) + if self.pred_head_type == "linear": + final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input) + elif self.pred_head_type == "dpt": + final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing( + info_sharing_input + ) + + if self.pred_head_type == "linear": + # Define feature dictionary for linear head + info_sharing_outputs = { + "1": final_info_sharing_multi_view_feat.features[0].float(), + "2": final_info_sharing_multi_view_feat.features[1].float(), + } + elif self.pred_head_type == "dpt": + # Define feature dictionary for DPT head + info_sharing_outputs = { + "1": [ + feat1.float(), + intermediate_info_sharing_multi_view_feat[0].features[0].float(), + intermediate_info_sharing_multi_view_feat[1].features[0].float(), + final_info_sharing_multi_view_feat.features[0].float(), + ], + "2": [ + feat2.float(), + intermediate_info_sharing_multi_view_feat[0].features[1].float(), + intermediate_info_sharing_multi_view_feat[1].features[1].float(), + final_info_sharing_multi_view_feat.features[1].float(), + ], + } + + # Downstream task prediction + with torch.autocast("cuda", enabled=False): + # Prediction heads + head_output1 = self._downstream_head(1, info_sharing_outputs, shape1) + head_output2 = self._downstream_head(2, info_sharing_outputs, shape2) + + # Post-process outputs + final_output1 = self.adaptor( + AdaptorInput(adaptor_feature=head_output1.decoded_channels, output_shape_hw=shape1) + ) + final_output2 = self.adaptor( + AdaptorInput(adaptor_feature=head_output2.decoded_channels, output_shape_hw=shape2) + ) + + # Convert outputs to dictionary + res1 = { + "pts3d": final_output1.value.permute(0, 2, 3, 1).contiguous(), + "conf": final_output1.confidence.permute(0, 2, 3, 1).contiguous(), + } + res2 = { + "pts3d_in_other_view": final_output2.value.permute(0, 2, 3, 1).contiguous(), + "conf": final_output2.confidence.permute(0, 2, 3, 1).contiguous(), + } + + return res1, res2 diff --git a/UniCeption/uniception/models/info_sharing/README.md b/UniCeption/uniception/models/info_sharing/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0c164a0eb9927dc59cb667293539f94c5f7af271 --- /dev/null +++ b/UniCeption/uniception/models/info_sharing/README.md @@ -0,0 +1,18 @@ +# UniCeption Information Sharing Blocks + +## Currently Supported Information Sharing Architectures + +### UniCeptionInfoSharingBase: + +- `MultiViewCrossAttentionTransformer` + - `MultiViewCrossAttentionTransformerIFR` +- `MultiViewGlobalAttentionTransformer` + - `MultiViewGlobalAttentionTransformerIFR` +- `MultiViewAlternatingAttentionTransformer` + - `MultiViewAlternatingAttentionTransformerIFR` + +## Developer Guidelines + +Please follow the main UniCeption developer guidelines described in `README.md` when contributing to the information sharing blocks. Make sure to test your different implementations and add necessary unit tests. + +## Happy Coding! diff --git a/UniCeption/uniception/models/info_sharing/__init__.py b/UniCeption/uniception/models/info_sharing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d962f767f7bb38f32af68014404fea4ebe9b4b39 --- /dev/null +++ b/UniCeption/uniception/models/info_sharing/__init__.py @@ -0,0 +1,35 @@ +from uniception.models.info_sharing.alternating_attention_transformer import ( + MultiViewAlternatingAttentionTransformer, + MultiViewAlternatingAttentionTransformerIFR, +) +from uniception.models.info_sharing.cross_attention_transformer import ( + MultiViewCrossAttentionTransformer, + MultiViewCrossAttentionTransformerIFR, + MultiViewTransformerInput, +) +from uniception.models.info_sharing.diff_cross_attention_transformer import ( + DifferentialMultiViewCrossAttentionTransformer, + DifferentialMultiViewCrossAttentionTransformerIFR, +) +from uniception.models.info_sharing.global_attention_transformer import ( + MultiViewGlobalAttentionTransformer, + MultiViewGlobalAttentionTransformerIFR, +) + +INFO_SHARING_CLASSES = { + "cross_attention": (MultiViewCrossAttentionTransformer, MultiViewCrossAttentionTransformerIFR), + "diff_cross_attention": ( + DifferentialMultiViewCrossAttentionTransformer, + DifferentialMultiViewCrossAttentionTransformerIFR, + ), + "alternating_attention": ( + MultiViewAlternatingAttentionTransformer, + MultiViewAlternatingAttentionTransformerIFR, + ), + "global_attention": ( + MultiViewGlobalAttentionTransformer, + MultiViewGlobalAttentionTransformerIFR, + ), +} + +__all__ = ["INFO_SHARING_CLASSES", "MultiViewTransformerInput"] diff --git a/UniCeption/uniception/models/info_sharing/alternating_attention_transformer.py b/UniCeption/uniception/models/info_sharing/alternating_attention_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f6768042c22a3f1c1c643fb4938170d22bdaf774 --- /dev/null +++ b/UniCeption/uniception/models/info_sharing/alternating_attention_transformer.py @@ -0,0 +1,944 @@ +""" +UniCeption Alternating-Attention Transformer for Information Sharing +""" + +from functools import partial +from typing import Callable, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn as nn + +from uniception.models.info_sharing.base import ( + MultiViewTransformerInput, + MultiViewTransformerOutput, + UniCeptionInfoSharingBase, +) +from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices +from uniception.models.utils.positional_encoding import PositionGetter +from uniception.models.utils.transformer_blocks import Mlp, SelfAttentionBlock + + +class MultiViewAlternatingAttentionTransformer(UniCeptionInfoSharingBase): + "UniCeption Multi-View Alternating-Attention Transformer for information sharing across image features from different views." + + def __init__( + self, + name: str, + input_embed_dim: int, + use_pe_for_non_reference_views: bool = False, + max_num_views_for_pe: int = 1000, + use_rand_idx_pe_for_non_reference_views: bool = True, + size: Optional[str] = None, + depth: int = 12, + dim: int = 768, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6), + mlp_layer: Type[nn.Module] = Mlp, + custom_positional_encoding: Optional[Callable] = None, + pretrained_checkpoint_path: Optional[str] = None, + gradient_checkpointing: bool = False, + *args, + **kwargs, + ): + """ + Initialize the Multi-View Alternating-Attention Transformer for information sharing across image features from different views. + Alternates between global and frame-level attention. + + Args: + input_embed_dim (int): Dimension of input embeddings. + use_pe_for_non_reference_views (bool): Whether to use view positional encoding for input non-referenec views. (default: False) + max_num_views_for_pe (int): Maximum number of views for positional encoding. (default: 1000) + use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views. (default: True) + size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None) + depth (int): Number of transformer layers. (default: 12, base size) + dim (int): Dimension of the transformer. (default: 768, base size) + num_heads (int): Number of attention heads. (default: 12, base size) + mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.) + qkv_bias (bool): Whether to include bias in qkv projection (default: True) + qk_norm (bool): Whether to normalize q and k (default: False) + proj_drop (float): Dropout rate for output (default: 0.) + attn_drop (float): Dropout rate for attention weights (default: 0.) + init_values (float): Initial value for LayerScale gamma (default: None) + drop_path (float): Dropout rate for stochastic depth (default: 0.) + act_layer (nn.Module): Activation layer (default: nn.GELU) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + mlp_layer (nn.Module): MLP layer (default: Mlp) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None) + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False) + """ + # Initialize the base class + super().__init__(name=name, size=size, *args, **kwargs) + + # Initialize the specific attributes of the transformer + self.input_embed_dim = input_embed_dim + self.use_pe_for_non_reference_views = use_pe_for_non_reference_views + self.max_num_views_for_pe = max_num_views_for_pe + self.use_rand_idx_pe_for_non_reference_views = use_rand_idx_pe_for_non_reference_views + self.depth = depth + self.dim = dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + self.proj_drop = proj_drop + self.attn_drop = attn_drop + self.init_values = init_values + self.drop_path = drop_path + self.act_layer = act_layer + self.norm_layer = norm_layer + self.mlp_layer = mlp_layer + self.custom_positional_encoding = custom_positional_encoding + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.gradient_checkpointing = gradient_checkpointing + + # Initialize the projection layer for input embeddings + if self.input_embed_dim != self.dim: + self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True) + else: + self.proj_embed = nn.Identity() + + # Initialize the self-attention blocks which ingest all views at once + self.self_attention_blocks = nn.ModuleList( + [ + SelfAttentionBlock( + dim=self.dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_norm=self.qk_norm, + proj_drop=self.proj_drop, + attn_drop=self.attn_drop, + init_values=self.init_values, + drop_path=self.drop_path, + act_layer=self.act_layer, + norm_layer=self.norm_layer, + mlp_layer=self.mlp_layer, + custom_positional_encoding=self.custom_positional_encoding, + ) + for _ in range(self.depth) + ] + ) + + # Initialize the final normalization layer + self.norm = self.norm_layer(self.dim) + + # Initialize the position getter for patch positions if required + if self.custom_positional_encoding is not None: + self.position_getter = PositionGetter() + + if self.use_pe_for_non_reference_views: + # Initialize the positional encoding table for the different views + self.register_buffer( + "view_pos_table", + self._get_sinusoid_encoding_table(self.max_num_views_for_pe, self.dim, 10000), + ) + else: + # Initialize the positional encoding table for the reference view + self.register_buffer( + "view_pos_table", + self._get_sinusoid_encoding_table(1, self.dim, 10000), + ) + + # Initialize random weights + self.initialize_weights() + + # Apply gradient checkpointing if enabled + if self.gradient_checkpointing: + for i, block in enumerate(self.self_attention_blocks): + self.self_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block) + + # Load pretrained weights if provided + if self.pretrained_checkpoint_path is not None: + print( + f"Loading pretrained multi-view Alternating-Attention transformer weights from {self.pretrained_checkpoint_path} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def _get_sinusoid_encoding_table(self, n_position, d_hid, base): + "Sinusoid position encoding table" + + def get_position_angle_vec(position): + return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) + + return torch.FloatTensor(sinusoid_table) + + def initialize_weights(self): + "Initialize weights of the transformer." + # Linears and layer norms + self.apply(self._init_weights) + + def _init_weights(self, m): + "Initialize the transformer linear and layer norm weights." + if isinstance(m, nn.Linear): + # We use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + model_input: MultiViewTransformerInput, + ) -> MultiViewTransformerOutput: + """ + Forward interface for the Multi-View Alternating-Attention Transformer. + + Args: + model_input (MultiViewTransformerInput): Input to the model. + Expects the features to be a list of size (batch, input_embed_dim, height, width), + where each entry corresponds to a different view. + Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token) + which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens). + + Returns: + MultiViewTransformerOutput: Output of the model post information sharing. + """ + # Check that the number of views matches the input and the features are of expected shape + if self.use_pe_for_non_reference_views: + assert ( + len(model_input.features) <= self.max_num_views_for_pe + ), f"Expected less than {self.max_num_views_for_pe} views, got {len(model_input.features)}" + assert all( + view_features.shape[1] == self.input_embed_dim for view_features in model_input.features + ), f"All views must have input dimension {self.input_embed_dim}" + assert all( + view_features.ndim == 4 for view_features in model_input.features + ), "All views must have 4 dimensions (N, C, H, W)" + + # Initialize the multi-view features from the model input and number of views for current input + multi_view_features = model_input.features + num_of_views = len(multi_view_features) + batch_size, _, height, width = multi_view_features[0].shape + num_of_tokens_per_view = height * width + + # Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape) + multi_view_features = torch.stack(multi_view_features, dim=1) + + # Resize the multi-view features from NVCHW to NLC, where L = V * H * W + multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C) + multi_view_features = multi_view_features.reshape( + batch_size, num_of_views * height * width, self.input_embed_dim + ).contiguous() + + # Process additional input tokens if provided + if model_input.additional_input_tokens is not None: + + additional_tokens = model_input.additional_input_tokens + assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)" + assert ( + additional_tokens.shape[1] == self.input_embed_dim + ), f"Additional tokens must have input dimension {self.input_embed_dim}" + assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens" + + # Reshape to channel-last format for transformer processing + additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C) + + # Concatenate the additional tokens to the multi-view features + multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1) + + # Project input features to the transformer dimension + multi_view_features = self.proj_embed(multi_view_features) + + # Create patch positions for each view if custom positional encoding is used + if self.custom_positional_encoding is not None: + multi_view_positions = [ + self.position_getter(batch_size, height, width, multi_view_features.device) + ] * num_of_views # List of length V, where each tensor is (N, H * W, C) + multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C) + else: + multi_view_positions = [None] * num_of_views + + # Add None positions for additional tokens if they exist + if model_input.additional_input_tokens is not None: + + additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[1] + multi_view_positions = multi_view_positions + additional_tokens_positions + + # Add positional encoding for reference view (idx 0) + ref_view_pe = self.view_pos_table[0].clone().detach() + ref_view_pe = ref_view_pe.reshape((1, 1, self.dim)) + ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1) + ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :] + ref_view_features = ref_view_features + ref_view_pe + + if self.use_pe_for_non_reference_views: + # Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled) + if self.use_rand_idx_pe_for_non_reference_views: + non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views_for_pe, size=(num_of_views - 1,)) + else: + non_ref_view_pe_indices = torch.arange(1, num_of_views) + non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach() + non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim)) + non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1) + non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1) + non_ref_view_features = multi_view_features[ + :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, : + ] + non_ref_view_features = non_ref_view_features + non_ref_view_pe + else: + non_ref_view_features = multi_view_features[ + :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, : + ] + + # Concatenate the reference and non-reference view features + # Handle additional tokens (no view-based positional encoding for them) + if model_input.additional_input_tokens is not None: + + additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :] + multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1) + else: + multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1) + + # Loop over the depth of the transformer + for depth_idx in range(self.depth): + if depth_idx % 2 == 0: + # Apply the self-attention block and update the multi-view features + # Global attention across all views + multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions) + else: + # Handle additional tokens separately for frame-level attention + additional_features = None + additional_positions = None + if model_input.additional_input_tokens is not None: + + # Extract additional token features + additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :] + # Keep only view features for frame-level attention + multi_view_features = multi_view_features[:, : num_of_views * num_of_tokens_per_view, :] + + # Handle positions for additional tokens if custom positional encoding is used + if self.custom_positional_encoding is not None: + additional_positions = multi_view_positions[:, num_of_views * num_of_tokens_per_view :, :] + multi_view_positions = multi_view_positions[:, : num_of_views * num_of_tokens_per_view, :] + + # Reshape the multi-view features from (N, V * H * W, C) to (N * V, H * W, C) + multi_view_features = multi_view_features.reshape( + batch_size * num_of_views, num_of_tokens_per_view, self.dim + ).contiguous() # (N * V, H * W, C) + if multi_view_positions[0] is not None: + multi_view_positions = multi_view_positions.reshape( + batch_size * num_of_views, num_of_tokens_per_view, 2 + ).contiguous() # (N * V, H * W, C) + + # Apply the self-attention block and update the multi-view features + # Frame-level attention within each view + multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions) + + # Reshape the multi-view features from (N * V, H * W, C) back to (N, V * H * W, C) + multi_view_features = multi_view_features.reshape( + batch_size, num_of_views * num_of_tokens_per_view, self.dim + ).contiguous() # (N, V * H * W, C) + if multi_view_positions[0] is not None: + multi_view_positions = multi_view_positions.reshape( + batch_size, num_of_views * num_of_tokens_per_view, 2 + ).contiguous() # (N, V * H * W, C) + + # Reattach additional tokens if they exist + if additional_features is not None: + multi_view_features = torch.cat([multi_view_features, additional_features], dim=1) + # Reattach positions for additional tokens if they exist + if additional_positions is not None: + multi_view_positions = torch.cat([multi_view_positions, additional_positions], dim=1) + + # Normalize the output features + output_multi_view_features = self.norm(multi_view_features) + + # Extract only the view features (excluding additional tokens) + view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :] + + # Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W) + view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C) + view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W) + + # Split the output multi-view features into separate views + view_features = view_features.split(1, dim=1) + view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features] + + # Extract and return additional token features if provided + if model_input.additional_input_tokens is not None: + + additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :] + additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T) + return MultiViewTransformerOutput( + features=view_features, additional_token_features=additional_token_features + ) + else: + return MultiViewTransformerOutput(features=view_features) + + +class MultiViewAlternatingAttentionTransformerIFR( + MultiViewAlternatingAttentionTransformer, IntermediateFeatureReturner +): + "Intermediate Feature Returner for UniCeption Multi-View Alternating-Attention Transformer" + + def __init__( + self, + name: str, + input_embed_dim: int, + use_pe_for_non_reference_views: bool = False, + max_num_views_for_pe: int = 1000, + use_rand_idx_pe_for_non_reference_views: bool = True, + size: Optional[str] = None, + depth: int = 12, + dim: int = 768, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + mlp_layer: nn.Module = Mlp, + custom_positional_encoding: Callable = None, + pretrained_checkpoint_path: str = None, + indices: Optional[Union[int, List[int]]] = None, + norm_intermediate: bool = True, + intermediates_only: bool = False, + gradient_checkpointing: bool = False, + *args, + **kwargs, + ): + """ + Initialize the Multi-View Alternating-Attention Transformer for information sharing across image features from different views. + Extends the base class to return intermediate features. + + Args: + input_embed_dim (int): Dimension of input embeddings. + use_pe_for_non_reference_views (bool): Whether to use view positional encoding for input non-referenec views. (default: False) + max_num_views_for_pe (int): Maximum number of views for positional encoding. (default: 1000) + use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views. (default: True) + use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views. + size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None) + depth (int): Number of transformer layers. (default: 12, base size) + dim (int): Dimension of the transformer. (default: 768, base size) + num_heads (int): Number of attention heads. (default: 12, base size) + mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.) + qkv_bias (bool): Whether to include bias in qkv projection (default: False) + qk_norm (bool): Whether to normalize q and k (default: False) + proj_drop (float): Dropout rate for output (default: 0.) + attn_drop (float): Dropout rate for attention weights (default: 0.) + init_values (float): Initial value for LayerScale gamma (default: None) + drop_path (float): Dropout rate for stochastic depth (default: 0.) + act_layer (nn.Module): Activation layer (default: nn.GELU) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + mlp_layer (nn.Module): MLP layer (default: Mlp) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None) + indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options: + - None: Return all intermediate layers. + - int: Return the last n layers. + - List[int]: Return the intermediate layers at the specified indices. + norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True) + intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False) + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False) + """ + # Init the base classes + MultiViewAlternatingAttentionTransformer.__init__( + self, + name=name, + input_embed_dim=input_embed_dim, + use_pe_for_non_reference_views=use_pe_for_non_reference_views, + max_num_views_for_pe=max_num_views_for_pe, + use_rand_idx_pe_for_non_reference_views=use_rand_idx_pe_for_non_reference_views, + size=size, + depth=depth, + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_drop=proj_drop, + attn_drop=attn_drop, + init_values=init_values, + drop_path=drop_path, + act_layer=act_layer, + norm_layer=norm_layer, + mlp_layer=mlp_layer, + custom_positional_encoding=custom_positional_encoding, + pretrained_checkpoint_path=pretrained_checkpoint_path, + gradient_checkpointing=gradient_checkpointing, + *args, + **kwargs, + ) + IntermediateFeatureReturner.__init__( + self, + indices=indices, + norm_intermediate=norm_intermediate, + intermediates_only=intermediates_only, + ) + + def forward( + self, + model_input: MultiViewTransformerInput, + ) -> Union[ + List[MultiViewTransformerOutput], + Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]], + ]: + """ + Forward interface for the Multi-View Alternating-Attention Transformer with Intermediate Feature Return. + + Args: + model_input (MultiViewTransformerInput): Input to the model. + Expects the features to be a list of size (batch, input_embed_dim, height, width), + where each entry corresponds to a different view. + Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token) + which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens). + + Returns: + Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]: + Output of the model post information sharing. + If intermediates_only is True, returns a list of intermediate outputs. + If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs. + """ + # Check that the number of views matches the input and the features are of expected shape + if self.use_pe_for_non_reference_views: + assert ( + len(model_input.features) <= self.max_num_views_for_pe + ), f"Expected less than {self.max_num_views_for_pe} views, got {len(model_input.features)}" + assert all( + view_features.shape[1] == self.input_embed_dim for view_features in model_input.features + ), f"All views must have input dimension {self.input_embed_dim}" + assert all( + view_features.ndim == 4 for view_features in model_input.features + ), "All views must have 4 dimensions (N, C, H, W)" + + # Get the indices of the intermediate features to return + intermediate_multi_view_features = [] + take_indices, _ = feature_take_indices(self.depth, self.indices) + + # Initialize the multi-view features from the model input and number of views for current input + multi_view_features = model_input.features + num_of_views = len(multi_view_features) + batch_size, _, height, width = multi_view_features[0].shape + num_of_tokens_per_view = height * width + + # Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape) + multi_view_features = torch.stack(multi_view_features, dim=1) + + # Resize the multi-view features from NVCHW to NLC, where L = V * H * W + multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C) + multi_view_features = multi_view_features.reshape( + batch_size, num_of_views * height * width, self.input_embed_dim + ).contiguous() + + # Process additional input tokens if provided + if model_input.additional_input_tokens is not None: + + additional_tokens = model_input.additional_input_tokens + assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)" + assert ( + additional_tokens.shape[1] == self.input_embed_dim + ), f"Additional tokens must have input dimension {self.input_embed_dim}" + assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens" + + # Reshape to channel-last format for transformer processing + additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C) + + # Concatenate the additional tokens to the multi-view features + multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1) + + # Project input features to the transformer dimension + multi_view_features = self.proj_embed(multi_view_features) + + # Create patch positions for each view if custom positional encoding is used + if self.custom_positional_encoding is not None: + multi_view_positions = [ + self.position_getter(batch_size, height, width, multi_view_features.device) + ] * num_of_views # List of length V, where each tensor is (N, H * W, C) + multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C) + else: + multi_view_positions = [None] * num_of_views + + # Add None positions for additional tokens if they exist + if model_input.additional_input_tokens is not None: + + additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[1] + multi_view_positions = multi_view_positions + additional_tokens_positions + + # Add positional encoding for reference view (idx 0) + ref_view_pe = self.view_pos_table[0].clone().detach() + ref_view_pe = ref_view_pe.reshape((1, 1, self.dim)) + ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1) + ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :] + ref_view_features = ref_view_features + ref_view_pe + + if self.use_pe_for_non_reference_views: + # Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled) + if self.use_rand_idx_pe_for_non_reference_views: + non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views_for_pe, size=(num_of_views - 1,)) + else: + non_ref_view_pe_indices = torch.arange(1, num_of_views) + non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach() + non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim)) + non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1) + non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1) + non_ref_view_features = multi_view_features[ + :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, : + ] + non_ref_view_features = non_ref_view_features + non_ref_view_pe + else: + non_ref_view_features = multi_view_features[ + :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, : + ] + + # Concatenate the reference and non-reference view features + # Handle additional tokens (no view-based positional encoding for them) + if model_input.additional_input_tokens is not None: + + additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :] + multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1) + else: + multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1) + + # Loop over the depth of the transformer + for depth_idx in range(self.depth): + if depth_idx % 2 == 0: + # Apply the self-attention block and update the multi-view features + # Global attention across all views + multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions) + else: + # Handle additional tokens separately for frame-level attention + additional_features = None + additional_positions = None + if model_input.additional_input_tokens is not None: + + # Extract additional token features + additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :] + # Keep only view features for frame-level attention + multi_view_features = multi_view_features[:, : num_of_views * num_of_tokens_per_view, :] + + # Handle positions for additional tokens if custom positional encoding is used + if self.custom_positional_encoding is not None: + additional_positions = multi_view_positions[:, num_of_views * num_of_tokens_per_view :, :] + multi_view_positions = multi_view_positions[:, : num_of_views * num_of_tokens_per_view, :] + + # Reshape the multi-view features from (N, V * H * W, C) to (N * V, H * W, C) + multi_view_features = multi_view_features.reshape( + batch_size * num_of_views, num_of_tokens_per_view, self.dim + ).contiguous() # (N * V, H * W, C) + if multi_view_positions[0] is not None: + multi_view_positions = multi_view_positions.reshape( + batch_size * num_of_views, num_of_tokens_per_view, 2 + ).contiguous() # (N * V, H * W, C) + + # Apply the self-attention block and update the multi-view features + # Frame-level attention within each view + multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions) + + # Reshape the multi-view features from (N * V, H * W, C) back to (N, V * H * W, C) + multi_view_features = multi_view_features.reshape( + batch_size, num_of_views * num_of_tokens_per_view, self.dim + ).contiguous() # (N, V * H * W, C) + if multi_view_positions[0] is not None: + multi_view_positions = multi_view_positions.reshape( + batch_size, num_of_views * num_of_tokens_per_view, 2 + ).contiguous() # (N, V * H * W, C) + + # Reattach additional tokens if they exist + if additional_features is not None: + multi_view_features = torch.cat([multi_view_features, additional_features], dim=1) + # Reattach positions for additional tokens if they exist + if additional_positions is not None: + multi_view_positions = torch.cat([multi_view_positions, additional_positions], dim=1) + if depth_idx in take_indices: + # Normalize the intermediate features with final norm layer if enabled + intermediate_multi_view_features.append( + self.norm(multi_view_features) if self.norm_intermediate else multi_view_features + ) + + # Reshape the intermediate features and convert to MultiViewTransformerOutput class + for idx in range(len(intermediate_multi_view_features)): + # Get the current intermediate features + current_features = intermediate_multi_view_features[idx] + + # Extract additional token features if provided + additional_token_features = None + if model_input.additional_input_tokens is not None: + + additional_token_features = current_features[:, num_of_views * num_of_tokens_per_view :, :] + additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T) + # Only keep the view features for reshaping + current_features = current_features[:, : num_of_views * num_of_tokens_per_view, :] + + # Reshape the intermediate multi-view features (N, V * H * W, C) back to (N, V, C, H, W) + current_features = current_features.reshape( + batch_size, num_of_views, height, width, self.dim + ) # (N, V, H, W, C) + current_features = current_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W) + + # Split the intermediate multi-view features into separate views + current_features = current_features.split(1, dim=1) + current_features = [ + intermediate_view_features.squeeze(dim=1) for intermediate_view_features in current_features + ] + + intermediate_multi_view_features[idx] = MultiViewTransformerOutput( + features=current_features, additional_token_features=additional_token_features + ) + + # Return only the intermediate features if enabled + if self.intermediates_only: + return intermediate_multi_view_features + + # Normalize the output features + output_multi_view_features = self.norm(multi_view_features) + + # Extract view features (excluding additional tokens) + additional_token_features = None + if model_input.additional_input_tokens is not None: + + additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :] + additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T) + view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :] + else: + view_features = output_multi_view_features + + # Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W) + view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C) + view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W) + + # Split the output multi-view features into separate views + view_features = view_features.split(1, dim=1) + view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features] + + output_multi_view_features = MultiViewTransformerOutput( + features=view_features, additional_token_features=additional_token_features + ) + + return output_multi_view_features, intermediate_multi_view_features + + +def dummy_positional_encoding(x, xpos): + "Dummy function for positional encoding of tokens" + x = x + xpos = xpos + return x + + +def test_reshape_for_frame_attention(): + "Test the reshape function for frame-level attention in the Alternating Attention Transformer" + batch_size = 2 + num_of_views = 3 + height = width = 2 + dim = 4 + num_of_tokens_per_view = height * width + + # Create tensor with recognizable pattern + x = torch.zeros(batch_size, num_of_views * num_of_tokens_per_view, dim) + for b in range(batch_size): + for v in range(num_of_views): + for h in range(height): + for w in range(width): + token_idx = v * num_of_tokens_per_view + h * width + w + x[b, token_idx] = torch.tensor([b, v, h, w]) + + # Apply reshape + reshaped = x.reshape(batch_size * num_of_views, num_of_tokens_per_view, dim).contiguous() + + # Verify shape + assert reshaped.shape == (batch_size * num_of_views, num_of_tokens_per_view, dim) + + # Verify content (check a few values) + for b in range(batch_size): + for v in range(num_of_views): + for h in range(height): + for w in range(width): + batch_view_idx = b * num_of_views + v + token_idx = h * width + w + expected = torch.tensor([b, v, h, w]) + assert torch.all(reshaped[batch_view_idx, token_idx] == expected) + + # Verify reshape back works + back_to_original = reshaped.reshape(batch_size, num_of_views * num_of_tokens_per_view, dim) + assert torch.all(x == back_to_original) + + print("Reshape test passed!") + + +if __name__ == "__main__": + # Unit test the reshape logic used for frame-level attention + test_reshape_for_frame_attention() + + # Init multi-view alternating-attention transformer with no custom positional encoding and run a forward pass + for num_views in [2, 3, 4]: + print(f"Testing MultiViewAlternatingAttentionTransformer with {num_views} views ...") + # No positional encoding for non-reference views + model = MultiViewAlternatingAttentionTransformer( + name="MV-AAT", + input_embed_dim=1024, + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + model_input = MultiViewTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + # Sequential idx based positional encoding + model = MultiViewAlternatingAttentionTransformer( + name="MV-AAT", + input_embed_dim=1024, + use_pe_for_non_reference_views=True, + max_num_views_for_pe=1000, + use_rand_idx_pe_for_non_reference_views=False, + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + model_input = MultiViewTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + # Random idx based positional encoding + model = MultiViewAlternatingAttentionTransformer( + name="MV-AAT", + input_embed_dim=1024, + use_pe_for_non_reference_views=True, + max_num_views_for_pe=1000, + use_rand_idx_pe_for_non_reference_views=True, + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + model_input = MultiViewTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + + # Init multi-view alternating-attention transformer with custom positional encoding and run a forward pass + for num_views in [2, 3, 4]: + print( + f"Testing MultiViewAlternatingAttentionTransformer with {num_views} views and custom positional encoding ..." + ) + model = MultiViewAlternatingAttentionTransformer( + name="MV-AAT", + input_embed_dim=1024, + custom_positional_encoding=dummy_positional_encoding, + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + model_input = MultiViewTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + + print("All multi-view alternating-attention transformers initialized and tested successfully!") + + # Intermediate Feature Returner Tests + print("Running Intermediate Feature Returner Tests ...") + + # Run the intermediate feature returner with last-n index + model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR( + name="MV-AAT-IFR", + input_embed_dim=1024, + indices=6, # Last 6 layers + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + assert isinstance(output, tuple) + assert isinstance(output[0], MultiViewTransformerOutput) + assert len(output[1]) == 6 + assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1]) + assert len(output[1][0].features) == 2 + + # Run the intermediate feature returner with specific indices + model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR( + name="MV-AAT-IFR", + input_embed_dim=1024, + indices=[0, 2, 4, 6], # Specific indices + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + assert isinstance(output, tuple) + assert isinstance(output[0], MultiViewTransformerOutput) + assert len(output[1]) == 4 + assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1]) + assert len(output[1][0].features) == 2 + + # Test the normalizing of intermediate features + model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR( + name="MV-AAT-IFR", + input_embed_dim=1024, + indices=[-1], # Last layer + norm_intermediate=False, # Disable normalization + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + for view_idx in range(2): + assert not torch.equal( + output[0].features[view_idx], output[1][-1].features[view_idx] + ), "Final features and intermediate features (last layer) must be different." + + model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR( + name="MV-AAT-IFR", + input_embed_dim=1024, + indices=[-1], # Last layer + norm_intermediate=True, + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + for view_idx in range(2): + assert torch.equal( + output[0].features[view_idx], output[1][-1].features[view_idx] + ), "Final features and intermediate features (last layer) must be same." + + print("All Intermediate Feature Returner Tests passed!") + + # Test additonal input tokens for MultiViewAlternatingAttentionTransformer + print("Testing MultiViewAlternatingAttentionTransformer with additional input tokens ...") + model = MultiViewAlternatingAttentionTransformer( + name="MV-AAT", + input_embed_dim=1024, + ) + num_views = 2 + num_additional_tokens = 5 + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + additional_tokens = torch.rand(1, 1024, num_additional_tokens) + model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + assert model_output.additional_token_features is not None + assert model_output.additional_token_features.shape == (1, model.dim, num_additional_tokens) + + # Test additonal input tokens for MultiViewAlternatingAttentionTransformerIFR + print("Testing MultiViewAlternatingAttentionTransformerIFR with additional input tokens ...") + model_ifr = MultiViewAlternatingAttentionTransformerIFR( + name="MV-AAT-IFR", + input_embed_dim=1024, + indices=[0, 2, 4], + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + additional_tokens = torch.rand(1, 1024, num_additional_tokens) + model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens) + output = model_ifr(model_input) + assert isinstance(output, tuple) + assert isinstance(output[0], MultiViewTransformerOutput) + assert output[0].additional_token_features is not None + assert output[0].additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens) + assert len(output[1]) == 3 + assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1]) + assert all(intermediate.additional_token_features is not None for intermediate in output[1]) + assert all( + intermediate.additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens) + for intermediate in output[1] + ) + + print("All tests using additional input tokens passed!") diff --git a/UniCeption/uniception/models/info_sharing/base.py b/UniCeption/uniception/models/info_sharing/base.py new file mode 100644 index 0000000000000000000000000000000000000000..99f8f84c13c1affc6130567582ea0c8889c77ee4 --- /dev/null +++ b/UniCeption/uniception/models/info_sharing/base.py @@ -0,0 +1,116 @@ +""" +Base Information Sharing Class for UniCeption +""" + +from dataclasses import dataclass +from typing import List, Optional + +import torch.nn as nn +from jaxtyping import Float +from torch import Tensor +from torch.utils.checkpoint import checkpoint + + +@dataclass +class InfoSharingInput: + pass + + +@dataclass +class InfoSharingOutput: + pass + + +class UniCeptionInfoSharingBase(nn.Module): + "Information Sharing Base Class for UniCeption" + + def __init__( + self, + name: str, + size: Optional[str] = None, + *args, + **kwargs, + ): + """ + Base class for all models in UniCeption. + """ + super().__init__(*args, **kwargs) + + self.name: str = name + self.size: Optional[str] = size + + def forward( + self, + model_input: InfoSharingInput, + ) -> InfoSharingOutput: + """ + Forward interface for the UniCeption information sharing models. + + Args: + model_input (InfoSharingInput): Input to the model. + This is also includes the other fields that are required by the specific implementation of the model. + + Returns: + InfoSharingOutput: Output of the model. + """ + + raise NotImplementedError + + def wrap_module_with_gradient_checkpointing(self, module: nn.Module): + """ + Wrapper for Gradient Checkpointing + """ + + class _CheckpointingWrapper(module.__class__): + _restore_cls = module.__class__ + + def forward(self, *args, **kwargs): + return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) + + module.__class__ = _CheckpointingWrapper + return module + + +@dataclass +class MultiViewTransformerInput(InfoSharingInput): + """ + Input class for Multi-View Transformer. + """ + + features: List[Float[Tensor, "batch input_embed_dim feat_height feat_width"]] + additional_input_tokens: Optional[Float[Tensor, "batch input_embed_dim num_additional_tokens"]] = None + + +@dataclass +class MultiViewTransformerOutput(InfoSharingOutput): + """ + Output class for Multi-View Transformer. + """ + + features: List[Float[Tensor, "batch transformer_embed_dim feat_height feat_width"]] + additional_token_features: Optional[Float[Tensor, "batch transformer_embed_dim num_additional_tokens"]] = None + + +@dataclass +class MultiSetTransformerInput(InfoSharingInput): + """ + Input class for Multi-Set Transformer. + """ + + features: List[Float[Tensor, "batch input_embed_dim num_tokens"]] + additional_input_tokens: Optional[Float[Tensor, "batch input_embed_dim num_additional_tokens"]] = None + + +@dataclass +class MultiSetTransformerOutput(InfoSharingOutput): + """ + Output class for Multi-Set Transformer. + """ + + features: List[Float[Tensor, "batch transformer_embed_dim num_tokens"]] + additional_token_features: Optional[Float[Tensor, "batch transformer_embed_dim num_additional_tokens"]] = None + + +if __name__ == "__main__": + dummy_model = UniCeptionInfoSharingBase(name="dummy") + print("Dummy Base InfoSharing model created successfully!") diff --git a/UniCeption/uniception/models/info_sharing/cross_attention_transformer.py b/UniCeption/uniception/models/info_sharing/cross_attention_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..32de983927a983064df9e44e3334d4668b8eeb7a --- /dev/null +++ b/UniCeption/uniception/models/info_sharing/cross_attention_transformer.py @@ -0,0 +1,582 @@ +""" +UniCeption Cross-Attention Transformer for Information Sharing +""" + +from copy import deepcopy +from functools import partial +from typing import Callable, List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn + +from uniception.models.info_sharing.base import ( + MultiViewTransformerInput, + MultiViewTransformerOutput, + UniCeptionInfoSharingBase, +) +from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices +from uniception.models.utils.positional_encoding import PositionGetter +from uniception.models.utils.transformer_blocks import CrossAttentionBlock, Mlp + + +class MultiViewCrossAttentionTransformer(UniCeptionInfoSharingBase): + "UniCeption Multi-View Cross-Attention Transformer for information sharing across image features from different views." + + def __init__( + self, + name: str, + input_embed_dim: int, + num_views: int, + size: Optional[str] = None, + depth: int = 12, + dim: int = 768, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6), + mlp_layer: Type[nn.Module] = Mlp, + custom_positional_encoding: Optional[Callable] = None, + norm_cross_tokens: bool = True, + pretrained_checkpoint_path: Optional[str] = None, + gradient_checkpointing: bool = False, + *args, + **kwargs, + ): + """ + Initialize the Multi-View Cross-Attention Transformer for information sharing across image features from different views. + Creates a cross-attention transformer with multiple branches for each view. + + Args: + input_embed_dim (int): Dimension of input embeddings. + num_views (int): Number of views (input feature sets). + size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None) + depth (int): Number of transformer layers. (default: 12, base size) + dim (int): Dimension of the transformer. (default: 768, base size) + num_heads (int): Number of attention heads. (default: 12, base size) + mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.) + qkv_bias (bool): Whether to include bias in qkv projection (default: True) + qk_norm (bool): Whether to normalize q and k (default: False) + proj_drop (float): Dropout rate for output (default: 0.) + attn_drop (float): Dropout rate for attention weights (default: 0.) + init_values (float): Initial value for LayerScale gamma (default: None) + drop_path (float): Dropout rate for stochastic depth (default: 0.) + act_layer (nn.Module): Activation layer (default: nn.GELU) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + mlp_layer (nn.Module): MLP layer (default: Mlp) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + norm_cross_tokens (bool): Whether to normalize cross tokens (default: True) + pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None) + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False) + """ + # Initialize the base class + super().__init__(name=name, size=size, *args, **kwargs) + + # Initialize the specific attributes of the transformer + self.input_embed_dim = input_embed_dim + self.num_views = num_views + self.depth = depth + self.dim = dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + self.proj_drop = proj_drop + self.attn_drop = attn_drop + self.init_values = init_values + self.drop_path = drop_path + self.act_layer = act_layer + self.norm_layer = norm_layer + self.mlp_layer = mlp_layer + self.custom_positional_encoding = custom_positional_encoding + self.norm_cross_tokens = norm_cross_tokens + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.gradient_checkpointing = gradient_checkpointing + + # Initialize the projection layer for input embeddings + if self.input_embed_dim != self.dim: + self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True) + else: + self.proj_embed = nn.Identity() + + # Initialize the cross-attention blocks for a single view + cross_attention_blocks = nn.ModuleList( + [ + CrossAttentionBlock( + dim=self.dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_norm=self.qk_norm, + proj_drop=self.proj_drop, + attn_drop=self.attn_drop, + init_values=self.init_values, + drop_path=self.drop_path, + act_layer=self.act_layer, + norm_layer=self.norm_layer, + mlp_layer=self.mlp_layer, + custom_positional_encoding=self.custom_positional_encoding, + norm_cross_tokens=self.norm_cross_tokens, + ) + for _ in range(self.depth) + ] + ) + + # Copy the cross-attention blocks for all other views + self.multi_view_branches = nn.ModuleList([cross_attention_blocks]) + for _ in range(1, self.num_views): + self.multi_view_branches.append(deepcopy(cross_attention_blocks)) + + # Initialize the final normalization layer + self.norm = self.norm_layer(self.dim) + + # Initialize the position getter for patch positions if required + if self.custom_positional_encoding is not None: + self.position_getter = PositionGetter() + + # Initialize random weights + self.initialize_weights() + + # Apply gradient checkpointing if enabled + if self.gradient_checkpointing: + for i, block in enumerate(self.cross_attention_blocks): + self.cross_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block) + + # Load pretrained weights if provided + if self.pretrained_checkpoint_path is not None: + print( + f"Loading pretrained multi-view cross-attention transformer weights from {self.pretrained_checkpoint_path} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def initialize_weights(self): + "Initialize weights of the transformer." + # Linears and layer norms + self.apply(self._init_weights) + + def _init_weights(self, m): + "Initialize the transformer linear and layer norm weights." + if isinstance(m, nn.Linear): + # We use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + model_input: MultiViewTransformerInput, + ) -> MultiViewTransformerOutput: + """ + Forward interface for the Multi-View Cross-Attention Transformer. + + Args: + model_input (MultiViewTransformerInput): Input to the model. + Expects the features to be a list of size (batch, input_embed_dim, height, width), + where each entry corresponds to a different view. + + Returns: + MultiViewTransformerOutput: Output of the model post information sharing. + """ + # Check that the number of views matches the input and the features are of expected shape + assert ( + len(model_input.features) == self.num_views + ), f"Expected {self.num_views} views, got {len(model_input.features)}" + assert all( + view_features.shape[1] == self.input_embed_dim for view_features in model_input.features + ), f"All views must have input dimension {self.input_embed_dim}" + assert all( + view_features.ndim == 4 for view_features in model_input.features + ), "All views must have 4 dimensions (N, C, H, W)" + + # Initialize the multi-view features from the model input + multi_view_features = model_input.features + + # Resize the multi-view features from NCHW to NLC + batch_size, _, height, width = multi_view_features[0].shape + multi_view_features = [ + view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous() + for view_features in multi_view_features + ] + + # Create patch positions for each view if custom positional encoding is used + if self.custom_positional_encoding is not None: + multi_view_positions = [ + self.position_getter(batch_size, height, width, view_features.device) + for view_features in multi_view_features + ] + else: + multi_view_positions = [None] * self.num_views + + # Project input features to the transformer dimension + multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features] + + # Pass through each view's cross-attention blocks + # Loop over the depth of the transformer + for depth_idx in range(self.depth): + updated_multi_view_features = [] + # Loop over each view + for view_idx, view_features in enumerate(multi_view_features): + # Get all the other views + other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx] + # Concatenate all the tokens from the other views + other_views_features = torch.cat(other_views_features, dim=1) + # Get the positions for the current view + view_positions = multi_view_positions[view_idx] + # Get the positions for all other views + other_views_positions = ( + torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1) + if view_positions is not None + else None + ) + # Apply the cross-attention block and update the multi-view features + updated_view_features = self.multi_view_branches[view_idx][depth_idx]( + view_features, other_views_features, view_positions, other_views_positions + ) + # Keep track of the updated view features + updated_multi_view_features.append(updated_view_features) + # Update the multi-view features for the next depth + multi_view_features = updated_multi_view_features + + # Normalize the output features + output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features] + + # Resize the output multi-view features back to NCHW + output_multi_view_features = [ + view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous() + for view_features in output_multi_view_features + ] + + return MultiViewTransformerOutput(features=output_multi_view_features) + + +class MultiViewCrossAttentionTransformerIFR(MultiViewCrossAttentionTransformer, IntermediateFeatureReturner): + "Intermediate Feature Returner for UniCeption Multi-View Cross-Attention Transformer" + + def __init__( + self, + name: str, + input_embed_dim: int, + num_views: int, + size: Optional[str] = None, + depth: int = 12, + dim: int = 768, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + mlp_layer: nn.Module = Mlp, + custom_positional_encoding: Callable = None, + norm_cross_tokens: bool = True, + pretrained_checkpoint_path: str = None, + indices: Optional[Union[int, List[int]]] = None, + norm_intermediate: bool = True, + intermediates_only: bool = False, + gradient_checkpointing: bool = False, + *args, + **kwargs, + ): + """ + Initialize the Multi-View Cross-Attention Transformer for information sharing across image features from different views. + Creates a cross-attention transformer with multiple branches for each view. + Extends the base class to return intermediate features. + + Args: + input_embed_dim (int): Dimension of input embeddings. + num_views (int): Number of views (input feature sets). + size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None) + depth (int): Number of transformer layers. (default: 12, base size) + dim (int): Dimension of the transformer. (default: 768, base size) + num_heads (int): Number of attention heads. (default: 12, base size) + mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.) + qkv_bias (bool): Whether to include bias in qkv projection (default: True) + qk_norm (bool): Whether to normalize q and k (default: False) + proj_drop (float): Dropout rate for output (default: 0.) + attn_drop (float): Dropout rate for attention weights (default: 0.) + init_values (float): Initial value for LayerScale gamma (default: None) + drop_path (float): Dropout rate for stochastic depth (default: 0.) + act_layer (nn.Module): Activation layer (default: nn.GELU) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + mlp_layer (nn.Module): MLP layer (default: Mlp) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + norm_cross_tokens (bool): Whether to normalize cross tokens (default: True) + pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None) + indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options: + - None: Return all intermediate layers. + - int: Return the last n layers. + - List[int]: Return the intermediate layers at the specified indices. + norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True) + intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False) + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False) + """ + # Init the base classes + MultiViewCrossAttentionTransformer.__init__( + self, + name=name, + input_embed_dim=input_embed_dim, + num_views=num_views, + size=size, + depth=depth, + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_drop=proj_drop, + attn_drop=attn_drop, + init_values=init_values, + drop_path=drop_path, + act_layer=act_layer, + norm_layer=norm_layer, + mlp_layer=mlp_layer, + custom_positional_encoding=custom_positional_encoding, + norm_cross_tokens=norm_cross_tokens, + pretrained_checkpoint_path=pretrained_checkpoint_path, + gradient_checkpointing=gradient_checkpointing, + *args, + **kwargs, + ) + IntermediateFeatureReturner.__init__( + self, + indices=indices, + norm_intermediate=norm_intermediate, + intermediates_only=intermediates_only, + ) + + def forward( + self, + model_input: MultiViewTransformerInput, + ) -> Union[ + List[MultiViewTransformerOutput], + Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]], + ]: + """ + Forward interface for the Multi-View Cross-Attention Transformer with Intermediate Feature Return. + + Args: + model_input (MultiViewTransformerInput): Input to the model. + Expects the features to be a list of size (batch, input_embed_dim, height, width), + where each entry corresponds to a different view. + + Returns: + Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]: + Output of the model post information sharing. + If intermediates_only is True, returns a list of intermediate outputs. + If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs. + """ + # Check that the number of views matches the input and the features are of expected shape + assert ( + len(model_input.features) == self.num_views + ), f"Expected {self.num_views} views, got {len(model_input.features)}" + assert all( + view_features.shape[1] == self.input_embed_dim for view_features in model_input.features + ), f"All views must have input dimension {self.input_embed_dim}" + assert all( + view_features.ndim == 4 for view_features in model_input.features + ), "All views must have 4 dimensions (N, C, H, W)" + + # Get the indices of the intermediate features to return + intermediate_multi_view_features = [] + take_indices, _ = feature_take_indices(self.depth, self.indices) + + # Initialize the multi-view features from the model input + multi_view_features = model_input.features + + # Resize the multi-view features from NCHW to NLC + batch_size, _, height, width = multi_view_features[0].shape + multi_view_features = [ + view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous() + for view_features in multi_view_features + ] + + # Create patch positions for each view if custom positional encoding is used + if self.custom_positional_encoding is not None: + multi_view_positions = [ + self.position_getter(batch_size, height, width, view_features.device) + for view_features in multi_view_features + ] + else: + multi_view_positions = [None] * self.num_views + + # Project input features to the transformer dimension + multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features] + + # Pass through each view's cross-attention blocks + # Loop over the depth of the transformer + for depth_idx in range(self.depth): + updated_multi_view_features = [] + # Loop over each view + for view_idx, view_features in enumerate(multi_view_features): + # Get all the other views + other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx] + # Concatenate all the tokens from the other views + other_views_features = torch.cat(other_views_features, dim=1) + # Get the positions for the current view + view_positions = multi_view_positions[view_idx] + # Get the positions for all other views + other_views_positions = ( + torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1) + if view_positions is not None + else None + ) + # Apply the cross-attention block and update the multi-view features + updated_view_features = self.multi_view_branches[view_idx][depth_idx]( + view_features, other_views_features, view_positions, other_views_positions + ) + # Keep track of the updated view features + updated_multi_view_features.append(updated_view_features) + # Update the multi-view features for the next depth + multi_view_features = updated_multi_view_features + # Append the intermediate features if required + if depth_idx in take_indices: + # Normalize the intermediate features with final norm layer if enabled + intermediate_multi_view_features.append( + [self.norm(view_features) for view_features in multi_view_features] + if self.norm_intermediate + else multi_view_features + ) + + # Reshape the intermediate features and convert to MultiViewTransformerOutput class + for idx in range(len(intermediate_multi_view_features)): + intermediate_multi_view_features[idx] = [ + view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous() + for view_features in intermediate_multi_view_features[idx] + ] + intermediate_multi_view_features[idx] = MultiViewTransformerOutput( + features=intermediate_multi_view_features[idx] + ) + + # Return only the intermediate features if enabled + if self.intermediates_only: + return intermediate_multi_view_features + + # Normalize the output features + output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features] + + # Resize the output multi-view features back to NCHW + output_multi_view_features = [ + view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous() + for view_features in output_multi_view_features + ] + + output_multi_view_features = MultiViewTransformerOutput(features=output_multi_view_features) + + return output_multi_view_features, intermediate_multi_view_features + + +def dummy_positional_encoding(x, xpos): + "Dummy function for positional encoding of tokens" + x = x + xpos = xpos + return x + + +if __name__ == "__main__": + # Init multi-view cross-attention transformer with no custom positional encoding and run a forward pass + for num_views in [2, 3, 4]: + print(f"Testing MultiViewCrossAttentionTransformer with {num_views} views ...") + model = MultiViewCrossAttentionTransformer(name="MV-CAT", input_embed_dim=1024, num_views=num_views) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + model_input = MultiViewTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + + # Init multi-view cross-attention transformer with custom positional encoding and run a forward pass + for num_views in [2, 3, 4]: + print(f"Testing MultiViewCrossAttentionTransformer with {num_views} views and custom positional encoding ...") + model = MultiViewCrossAttentionTransformer( + name="MV-CAT", + input_embed_dim=1024, + num_views=num_views, + custom_positional_encoding=dummy_positional_encoding, + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + model_input = MultiViewTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + + print("All multi-view cross-attention transformers initialized and tested successfully!") + + # Intermediate Feature Returner Tests + print("Running Intermediate Feature Returner Tests ...") + + # Run the intermediate feature returner with last-n index + model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR( + name="MV-CAT-IFR", + input_embed_dim=1024, + num_views=2, + indices=6, # Last 6 layers + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + assert isinstance(output, tuple) + assert isinstance(output[0], MultiViewTransformerOutput) + assert len(output[1]) == 6 + assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1]) + assert len(output[1][0].features) == 2 + + # Run the intermediate feature returner with specific indices + model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR( + name="MV-CAT-IFR", + input_embed_dim=1024, + num_views=2, + indices=[0, 2, 4, 6], # Specific indices + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + assert isinstance(output, tuple) + assert isinstance(output[0], MultiViewTransformerOutput) + assert len(output[1]) == 4 + assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1]) + assert len(output[1][0].features) == 2 + + # Test the normalizing of intermediate features + model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR( + name="MV-CAT-IFR", + input_embed_dim=1024, + num_views=2, + indices=[-1], # Last layer + norm_intermediate=False, # Disable normalization + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + for view_idx in range(2): + assert not torch.equal( + output[0].features[view_idx], output[1][-1].features[view_idx] + ), "Final features and intermediate features (last layer) must be different." + + model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR( + name="MV-CAT-IFR", + input_embed_dim=1024, + num_views=2, + indices=[-1], # Last layer + norm_intermediate=True, + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + for view_idx in range(2): + assert torch.equal( + output[0].features[view_idx], output[1][-1].features[view_idx] + ), "Final features and intermediate features (last layer) must be same." + + print("All Intermediate Feature Returner Tests passed!") diff --git a/UniCeption/uniception/models/info_sharing/diff_cross_attention_transformer.py b/UniCeption/uniception/models/info_sharing/diff_cross_attention_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d6e4d55f91123de08d812944ba837a9e0e83b653 --- /dev/null +++ b/UniCeption/uniception/models/info_sharing/diff_cross_attention_transformer.py @@ -0,0 +1,588 @@ +""" +UniCeption Cross-Attention Transformer for Information Sharing +""" + +from copy import deepcopy +from functools import partial +from typing import Callable, List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn + +from uniception.models.info_sharing.base import UniCeptionInfoSharingBase +from uniception.models.info_sharing.cross_attention_transformer import ( + MultiViewTransformerInput, + MultiViewTransformerOutput, + PositionGetter, +) +from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices +from uniception.models.utils.transformer_blocks import DiffCrossAttentionBlock, Mlp + + +class DifferentialMultiViewCrossAttentionTransformer(UniCeptionInfoSharingBase): + "UniCeption Multi-View Cross-Attention Transformer for information sharing across image features from different views." + + def __init__( + self, + name: str, + input_embed_dim: int, + num_views: int, + size: Optional[str] = None, + depth: int = 12, + dim: int = 768, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6), + mlp_layer: Type[nn.Module] = Mlp, + custom_positional_encoding: Optional[Callable] = None, + norm_cross_tokens: bool = True, + pretrained_checkpoint_path: Optional[str] = None, + gradient_checkpointing: bool = False, + *args, + **kwargs, + ): + """ + Initialize the Multi-View Cross-Attention Transformer for information sharing across image features from different views. + Creates a cross-attention transformer with multiple branches for each view. + + Args: + input_embed_dim (int): Dimension of input embeddings. + num_views (int): Number of views (input feature sets). + depth (int): Number of transformer layers. (default: 12, base size) + dim (int): Dimension of the transformer. (default: 768, base size) + num_heads (int): Number of attention heads. (default: 12, base size) + mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.) + qkv_bias (bool): Whether to include bias in qkv projection (default: False) + qk_norm (bool): Whether to normalize q and k (default: False) + proj_drop (float): Dropout rate for output (default: 0.) + attn_drop (float): Dropout rate for attention weights (default: 0.) + init_values (float): Initial value for LayerScale gamma (default: None) + drop_path (float): Dropout rate for stochastic depth (default: 0.) + act_layer (nn.Module): Activation layer (default: nn.GELU) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + mlp_layer (nn.Module): MLP layer (default: Mlp) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + norm_cross_tokens (bool): Whether to normalize cross tokens (default: True) + pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None) + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False) + """ + # Initialize the base class + super().__init__(name=name, size=size, *args, **kwargs) + + # Initialize the specific attributes of the transformer + self.input_embed_dim = input_embed_dim + self.num_views = num_views + self.depth = depth + self.dim = dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + self.proj_drop = proj_drop + self.attn_drop = attn_drop + self.init_values = init_values + self.drop_path = drop_path + self.act_layer = act_layer + self.norm_layer = norm_layer + self.mlp_layer = mlp_layer + self.custom_positional_encoding = custom_positional_encoding + self.norm_cross_tokens = norm_cross_tokens + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.gradient_checkpointing = gradient_checkpointing + + # Initialize the projection layer for input embeddings + if self.input_embed_dim != self.dim: + self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True) + else: + self.proj_embed = nn.Identity() + + # Initialize the cross-attention blocks for a single view + assert num_heads % 2 == 0, "Number of heads must be divisible by 2 for differential cross-attention." + cross_attention_blocks = nn.ModuleList( + [ + DiffCrossAttentionBlock( + depth=i, + dim=self.dim, + num_heads=self.num_heads // 2, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_norm=self.qk_norm, + proj_drop=self.proj_drop, + attn_drop=self.attn_drop, + init_values=self.init_values, + drop_path=self.drop_path, + act_layer=self.act_layer, + norm_layer=self.norm_layer, + mlp_layer=self.mlp_layer, + custom_positional_encoding=self.custom_positional_encoding, + norm_cross_tokens=self.norm_cross_tokens, + ) + for i in range(self.depth) + ] + ) + + # Copy the cross-attention blocks for all other views + self.multi_view_branches = nn.ModuleList([cross_attention_blocks]) + for _ in range(1, self.num_views): + self.multi_view_branches.append(deepcopy(cross_attention_blocks)) + + # Initialize the final normalization layer + self.norm = self.norm_layer(self.dim) + + # Initialize the position getter for patch positions if required + if self.custom_positional_encoding is not None: + self.position_getter = PositionGetter() + + # Initialize random weights + self.initialize_weights() + + # Apply gradient checkpointing if enabled + if self.gradient_checkpointing: + for i, block in enumerate(self.cross_attention_blocks): + self.cross_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block) + + # Load pretrained weights if provided + if self.pretrained_checkpoint_path is not None: + print( + f"Loading pretrained multi-view cross-attention transformer weights from {self.pretrained_checkpoint_path} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def initialize_weights(self): + "Initialize weights of the transformer." + # Linears and layer norms + self.apply(self._init_weights) + + def _init_weights(self, m): + "Initialize the transformer linear and layer norm weights." + if isinstance(m, nn.Linear): + # We use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + model_input: MultiViewTransformerInput, + ) -> MultiViewTransformerOutput: + """ + Forward interface for the Multi-View Cross-Attention Transformer. + + Args: + model_input (MultiViewTransformerInput): Input to the model. + Expects the features to be a list of size (batch, input_embed_dim, height, width), + where each entry corresponds to a different view. + + Returns: + MultiViewTransformerOutput: Output of the model post information sharing. + """ + # Check that the number of views matches the input and the features are of expected shape + assert ( + len(model_input.features) == self.num_views + ), f"Expected {self.num_views} views, got {len(model_input.features)}" + assert all( + view_features.shape[1] == self.input_embed_dim for view_features in model_input.features + ), f"All views must have input dimension {self.input_embed_dim}" + assert all( + view_features.ndim == 4 for view_features in model_input.features + ), "All views must have 4 dimensions (N, C, H, W)" + + # Initialize the multi-view features from the model input + multi_view_features = model_input.features + + # Resize the multi-view features from NCHW to NLC + batch_size, _, height, width = multi_view_features[0].shape + multi_view_features = [ + view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous() + for view_features in multi_view_features + ] + + # Create patch positions for each view if custom positional encoding is used + if self.custom_positional_encoding is not None: + multi_view_positions = [ + self.position_getter(batch_size, height, width, view_features.device) + for view_features in multi_view_features + ] + else: + multi_view_positions = [None] * self.num_views + + # Project input features to the transformer dimension + multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features] + + # Pass through each view's cross-attention blocks + # Loop over the depth of the transformer + for depth_idx in range(self.depth): + updated_multi_view_features = [] + # Loop over each view + for view_idx, view_features in enumerate(multi_view_features): + # Get all the other views + other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx] + # Concatenate all the tokens from the other views + other_views_features = torch.cat(other_views_features, dim=1) + # Get the positions for the current view + view_positions = multi_view_positions[view_idx] + # Get the positions for all other views + other_views_positions = ( + torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1) + if view_positions is not None + else None + ) + # Apply the cross-attention block and update the multi-view features + updated_view_features = self.multi_view_branches[view_idx][depth_idx]( + view_features, other_views_features, view_positions, other_views_positions + ) + # Keep track of the updated view features + updated_multi_view_features.append(updated_view_features) + # Update the multi-view features for the next depth + multi_view_features = updated_multi_view_features + + # Normalize the output features + output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features] + + # Resize the output multi-view features back to NCHW + output_multi_view_features = [ + view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous() + for view_features in output_multi_view_features + ] + + return MultiViewTransformerOutput(features=output_multi_view_features) + + +class DifferentialMultiViewCrossAttentionTransformerIFR( + DifferentialMultiViewCrossAttentionTransformer, IntermediateFeatureReturner +): + "Intermediate Feature Returner for UniCeption Multi-View Cross-Attention Transformer" + + def __init__( + self, + name: str, + input_embed_dim: int, + num_views: int, + size: Optional[str] = None, + depth: int = 12, + dim: int = 768, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + mlp_layer: nn.Module = Mlp, + custom_positional_encoding: Callable = None, + norm_cross_tokens: bool = True, + pretrained_checkpoint_path: str = None, + indices: Optional[Union[int, List[int]]] = None, + norm_intermediate: bool = True, + intermediates_only: bool = False, + gradient_checkpointing: bool = False, + *args, + **kwargs, + ): + """ + Initialize the Multi-View Cross-Attention Transformer for information sharing across image features from different views. + Creates a cross-attention transformer with multiple branches for each view. + Extends the base class to return intermediate features. + + Args: + input_embed_dim (int): Dimension of input embeddings. + num_views (int): Number of views (input feature sets). + depth (int): Number of transformer layers. (default: 12, base size) + dim (int): Dimension of the transformer. (default: 768, base size) + num_heads (int): Number of attention heads. (default: 12, base size) + mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.) + qkv_bias (bool): Whether to include bias in qkv projection (default: False) + qk_norm (bool): Whether to normalize q and k (default: False) + proj_drop (float): Dropout rate for output (default: 0.) + attn_drop (float): Dropout rate for attention weights (default: 0.) + init_values (float): Initial value for LayerScale gamma (default: None) + drop_path (float): Dropout rate for stochastic depth (default: 0.) + act_layer (nn.Module): Activation layer (default: nn.GELU) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + mlp_layer (nn.Module): MLP layer (default: Mlp) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + norm_cross_tokens (bool): Whether to normalize cross tokens (default: True) + pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None) + indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options: + - None: Return all intermediate layers. + - int: Return the last n layers. + - List[int]: Return the intermediate layers at the specified indices. + norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True) + intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False) + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False) + """ + # Init the base classes + DifferentialMultiViewCrossAttentionTransformer.__init__( + self, + name=name, + input_embed_dim=input_embed_dim, + num_views=num_views, + size=size, + depth=depth, + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_drop=proj_drop, + attn_drop=attn_drop, + init_values=init_values, + drop_path=drop_path, + act_layer=act_layer, + norm_layer=norm_layer, + mlp_layer=mlp_layer, + custom_positional_encoding=custom_positional_encoding, + norm_cross_tokens=norm_cross_tokens, + pretrained_checkpoint_path=pretrained_checkpoint_path, + gradient_checkpointing=gradient_checkpointing, + *args, + **kwargs, + ) + IntermediateFeatureReturner.__init__( + self, + indices=indices, + norm_intermediate=norm_intermediate, + intermediates_only=intermediates_only, + ) + + def forward( + self, + model_input: MultiViewTransformerInput, + ) -> Union[ + List[MultiViewTransformerOutput], + Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]], + ]: + """ + Forward interface for the Multi-View Cross-Attention Transformer with Intermediate Feature Return. + + Args: + model_input (MultiViewTransformerInput): Input to the model. + Expects the features to be a list of size (batch, input_embed_dim, height, width), + where each entry corresponds to a different view. + + Returns: + Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]: + Output of the model post information sharing. + If intermediates_only is True, returns a list of intermediate outputs. + If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs. + """ + # Check that the number of views matches the input and the features are of expected shape + assert ( + len(model_input.features) == self.num_views + ), f"Expected {self.num_views} views, got {len(model_input.features)}" + assert all( + view_features.shape[1] == self.input_embed_dim for view_features in model_input.features + ), f"All views must have input dimension {self.input_embed_dim}" + assert all( + view_features.ndim == 4 for view_features in model_input.features + ), "All views must have 4 dimensions (N, C, H, W)" + + # Get the indices of the intermediate features to return + intermediate_multi_view_features = [] + take_indices, _ = feature_take_indices(self.depth, self.indices) + + # Initialize the multi-view features from the model input + multi_view_features = model_input.features + + # Resize the multi-view features from NCHW to NLC + batch_size, _, height, width = multi_view_features[0].shape + multi_view_features = [ + view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous() + for view_features in multi_view_features + ] + + # Create patch positions for each view if custom positional encoding is used + if self.custom_positional_encoding is not None: + multi_view_positions = [ + self.position_getter(batch_size, height, width, view_features.device) + for view_features in multi_view_features + ] + else: + multi_view_positions = [None] * self.num_views + + # Project input features to the transformer dimension + multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features] + + # Pass through each view's cross-attention blocks + # Loop over the depth of the transformer + for depth_idx in range(self.depth): + updated_multi_view_features = [] + # Loop over each view + for view_idx, view_features in enumerate(multi_view_features): + # Get all the other views + other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx] + # Concatenate all the tokens from the other views + other_views_features = torch.cat(other_views_features, dim=1) + # Get the positions for the current view + view_positions = multi_view_positions[view_idx] + # Get the positions for all other views + other_views_positions = ( + torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1) + if view_positions is not None + else None + ) + # Apply the cross-attention block and update the multi-view features + updated_view_features = self.multi_view_branches[view_idx][depth_idx]( + view_features, other_views_features, view_positions, other_views_positions + ) + # Keep track of the updated view features + updated_multi_view_features.append(updated_view_features) + # Update the multi-view features for the next depth + multi_view_features = updated_multi_view_features + # Append the intermediate features if required + if depth_idx in take_indices: + # Normalize the intermediate features with final norm layer if enabled + intermediate_multi_view_features.append( + [self.norm(view_features) for view_features in multi_view_features] + if self.norm_intermediate + else multi_view_features + ) + + # Reshape the intermediate features and convert to MultiViewTransformerOutput class + for idx in range(len(intermediate_multi_view_features)): + intermediate_multi_view_features[idx] = [ + view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous() + for view_features in intermediate_multi_view_features[idx] + ] + intermediate_multi_view_features[idx] = MultiViewTransformerOutput( + features=intermediate_multi_view_features[idx] + ) + + # Return only the intermediate features if enabled + if self.intermediates_only: + return intermediate_multi_view_features + + # Normalize the output features + output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features] + + # Resize the output multi-view features back to NCHW + output_multi_view_features = [ + view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous() + for view_features in output_multi_view_features + ] + + output_multi_view_features = MultiViewTransformerOutput(features=output_multi_view_features) + + return output_multi_view_features, intermediate_multi_view_features + + +def dummy_positional_encoding(x, xpos): + "Dummy function for positional encoding of tokens" + x = x + xpos = xpos + return x + + +if __name__ == "__main__": + # Init multi-view cross-attention transformer with no custom positional encoding and run a forward pass + for num_views in [2, 3, 4]: + print(f"Testing MultiViewCrossAttentionTransformer with {num_views} views ...") + model = DifferentialMultiViewCrossAttentionTransformer( + name="MV-DCAT", input_embed_dim=1024, num_views=num_views + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + model_input = MultiViewTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + + # Init multi-view cross-attention transformer with custom positional encoding and run a forward pass + for num_views in [2, 3, 4]: + print( + f"Testing Differential MultiViewCrossAttentionTransformer with {num_views} views and custom positional encoding ..." + ) + model = DifferentialMultiViewCrossAttentionTransformer( + name="MV-DCAT", + input_embed_dim=1024, + num_views=num_views, + custom_positional_encoding=dummy_positional_encoding, + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + model_input = MultiViewTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + + print("All multi-view cross-attention transformers initialized and tested successfully!") + + # Intermediate Feature Returner Tests + print("Running Intermediate Feature Returner Tests ...") + + # Run the intermediate feature returner with last-n index + model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR( + name="MV-DCAT-IFR", + input_embed_dim=1024, + num_views=2, + indices=6, # Last 6 layers + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + assert isinstance(output, tuple) + assert isinstance(output[0], MultiViewTransformerOutput) + assert len(output[1]) == 6 + assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1]) + assert len(output[1][0].features) == 2 + + # Run the intermediate feature returner with specific indices + model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR( + name="MV-DCAT-IFR", + input_embed_dim=1024, + num_views=2, + indices=[0, 2, 4, 6], # Specific indices + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + assert isinstance(output, tuple) + assert isinstance(output[0], MultiViewTransformerOutput) + assert len(output[1]) == 4 + assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1]) + assert len(output[1][0].features) == 2 + + # Test the normalizing of intermediate features + model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR( + name="MV-DCAT-IFR", + input_embed_dim=1024, + num_views=2, + indices=[-1], # Last layer + norm_intermediate=False, # Disable normalization + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + for view_idx in range(2): + assert not torch.equal( + output[0].features[view_idx], output[1][-1].features[view_idx] + ), "Final features and intermediate features (last layer) must be different." + + model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR( + name="MV-DCAT-IFR", + input_embed_dim=1024, + num_views=2, + indices=[-1], # Last layer + norm_intermediate=True, + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + for view_idx in range(2): + assert torch.equal( + output[0].features[view_idx], output[1][-1].features[view_idx] + ), "Final features and intermediate features (last layer) must be same." + + print("All Intermediate Feature Returner Tests passed!") diff --git a/UniCeption/uniception/models/info_sharing/global_attention_transformer.py b/UniCeption/uniception/models/info_sharing/global_attention_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..00c77eb0e4759ac30c5c9412fd8bf5f6c3297111 --- /dev/null +++ b/UniCeption/uniception/models/info_sharing/global_attention_transformer.py @@ -0,0 +1,1107 @@ +""" +UniCeption Global-Attention Transformer for Information Sharing +""" + +from functools import partial +from typing import Callable, List, Optional, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn as nn + +from uniception.models.info_sharing.base import ( + MultiSetTransformerInput, + MultiSetTransformerOutput, + MultiViewTransformerInput, + MultiViewTransformerOutput, + UniCeptionInfoSharingBase, +) +from uniception.models.libs.croco.pos_embed import RoPE2D +from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices +from uniception.models.utils.positional_encoding import PositionGetter +from uniception.models.utils.transformer_blocks import Mlp, SelfAttentionBlock + + +class MultiViewGlobalAttentionTransformer(UniCeptionInfoSharingBase): + "UniCeption Multi-View Global-Attention Transformer for information sharing across image features from different views." + + def __init__( + self, + name: str, + input_embed_dim: int, + max_num_views: int, + use_rand_idx_pe_for_non_reference_views: bool, + size: Optional[str] = None, + depth: int = 12, + dim: int = 768, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6), + mlp_layer: Type[nn.Module] = Mlp, + custom_positional_encoding: Optional[Union[str, Callable]] = None, + pretrained_checkpoint_path: Optional[str] = None, + gradient_checkpointing: bool = False, + *args, + **kwargs, + ): + """ + Initialize the Multi-View Global-Attention Transformer for information sharing across image features from different views. + + Args: + input_embed_dim (int): Dimension of input embeddings. + max_num_views (int): Maximum number of views for positional encoding. + use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views. + size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None) + depth (int): Number of transformer layers. (default: 12, base size) + dim (int): Dimension of the transformer. (default: 768, base size) + num_heads (int): Number of attention heads. (default: 12, base size) + mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.) + qkv_bias (bool): Whether to include bias in qkv projection (default: True) + qk_norm (bool): Whether to normalize q and k (default: False) + proj_drop (float): Dropout rate for output (default: 0.) + attn_drop (float): Dropout rate for attention weights (default: 0.) + init_values (float): Initial value for LayerScale gamma (default: None) + drop_path (float): Dropout rate for stochastic depth (default: 0.) + act_layer (nn.Module): Activation layer (default: nn.GELU) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + mlp_layer (nn.Module): MLP layer (default: Mlp) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None) + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False) + """ + # Initialize the base class + super().__init__(name=name, size=size, *args, **kwargs) + + # Initialize the specific attributes of the transformer + self.input_embed_dim = input_embed_dim + self.max_num_views = max_num_views + self.use_rand_idx_pe_for_non_reference_views = use_rand_idx_pe_for_non_reference_views + self.depth = depth + self.dim = dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + self.proj_drop = proj_drop + self.attn_drop = attn_drop + self.init_values = init_values + self.drop_path = drop_path + self.act_layer = act_layer + self.norm_layer = norm_layer + self.mlp_layer = mlp_layer + self.custom_positional_encoding = custom_positional_encoding + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.gradient_checkpointing = gradient_checkpointing + + # Initialize the projection layer for input embeddings + if self.input_embed_dim != self.dim: + self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True) + else: + self.proj_embed = nn.Identity() + + # Initialize custom position encodings + if self.custom_positional_encoding is not None and isinstance(self.custom_positional_encoding, str): + if self.custom_positional_encoding == "rope": + self.rope = RoPE2D(freq=100.0, F0=1.0) + self.custom_positional_encoding = self.rope + else: + raise ValueError(f"Unknown custom positional encoding: {self.custom_positional_encoding}") + + # Initialize the self-attention blocks which ingest all views at once + self.self_attention_blocks = nn.ModuleList( + [ + SelfAttentionBlock( + dim=self.dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_norm=self.qk_norm, + proj_drop=self.proj_drop, + attn_drop=self.attn_drop, + init_values=self.init_values, + drop_path=self.drop_path, + act_layer=self.act_layer, + norm_layer=self.norm_layer, + mlp_layer=self.mlp_layer, + custom_positional_encoding=self.custom_positional_encoding, + ) + for _ in range(self.depth) + ] + ) + + # Initialize the final normalization layer + self.norm = self.norm_layer(self.dim) + + # Initialize the position getter for patch positions if required + if self.custom_positional_encoding is not None: + self.position_getter = PositionGetter() + + # Initialize the positional encoding table for the different views + self.register_buffer( + "view_pos_table", + self._get_sinusoid_encoding_table(self.max_num_views, self.dim, 10000), + ) + + # Initialize random weights + self.initialize_weights() + + # Load pretrained weights if provided + if self.pretrained_checkpoint_path is not None: + print( + f"Loading pretrained multi-view global-attention transformer weights from {self.pretrained_checkpoint_path} ..." + ) + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + # Apply gradient checkpointing if enabled + if self.gradient_checkpointing: + for i, block in enumerate(self.self_attention_blocks): + self.self_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block) + + def _get_sinusoid_encoding_table(self, n_position, d_hid, base): + "Sinusoid position encoding table" + + def get_position_angle_vec(position): + return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) + + return torch.FloatTensor(sinusoid_table) + + def initialize_weights(self): + "Initialize weights of the transformer." + # Linears and layer norms + self.apply(self._init_weights) + + def _init_weights(self, m): + "Initialize the transformer linear and layer norm weights." + if isinstance(m, nn.Linear): + # We use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + model_input: MultiViewTransformerInput, + ) -> MultiViewTransformerOutput: + """ + Forward interface for the Multi-View Global-Attention Transformer. + + Args: + model_input (MultiViewTransformerInput): Input to the model. + Expects the features to be a list of size (batch, input_embed_dim, height, width), + where each entry corresponds to a different view. + Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token) + which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens). + + Returns: + MultiViewTransformerOutput: Output of the model post information sharing. + """ + # Check that the number of views matches the input and the features are of expected shape + assert ( + len(model_input.features) <= self.max_num_views + ), f"Expected less than {self.max_num_views} views, got {len(model_input.features)}" + assert all( + curr_view_features.shape[1] == self.input_embed_dim for curr_view_features in model_input.features + ), f"All views must have input dimension {self.input_embed_dim}" + assert all( + curr_view_features.ndim == 4 for curr_view_features in model_input.features + ), "All views must have 4 dimensions (N, C, H, W)" + + # Initialize the multi-view features from the model input and number of views for current input + multi_view_features = model_input.features + num_of_views = len(multi_view_features) + batch_size, _, height, width = multi_view_features[0].shape + num_of_tokens_per_view = height * width + + # Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape) + multi_view_features = torch.stack(multi_view_features, dim=1) + + # Resize the multi-view features from NVCHW to NLC, where L = V * H * W + multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C) + multi_view_features = multi_view_features.reshape( + batch_size, num_of_views * height * width, self.input_embed_dim + ).contiguous() + + # Process additional input tokens if provided + if model_input.additional_input_tokens is not None: + additional_tokens = model_input.additional_input_tokens + assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)" + assert ( + additional_tokens.shape[1] == self.input_embed_dim + ), f"Additional tokens must have input dimension {self.input_embed_dim}" + assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens" + + # Reshape to channel-last format for transformer processing + additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C) + + # Concatenate the additional tokens to the multi-view features + multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1) + + # Project input features to the transformer dimension + multi_view_features = self.proj_embed(multi_view_features) + + # Create patch positions for each view if custom positional encoding is used + if self.custom_positional_encoding is not None: + multi_view_positions = [ + self.position_getter(batch_size, height, width, multi_view_features.device) + ] * num_of_views # List of length V, where each tensor is (N, H * W, C) + multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C) + else: + multi_view_positions = [None] * num_of_views + + # Add None positions for additional tokens if they exist + if model_input.additional_input_tokens is not None: + additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[1] + multi_view_positions = multi_view_positions + additional_tokens_positions + + # Add positional encoding for reference view (idx 0) + ref_view_pe = self.view_pos_table[0].clone().detach() + ref_view_pe = ref_view_pe.reshape((1, 1, self.dim)) + ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1) + ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :] + ref_view_features = ref_view_features + ref_view_pe + + # Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled) + if self.use_rand_idx_pe_for_non_reference_views: + non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views, size=(num_of_views - 1,)) + else: + non_ref_view_pe_indices = torch.arange(1, num_of_views) + non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach() + non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim)) + non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1) + non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1) + non_ref_view_features = multi_view_features[ + :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, : + ] + non_ref_view_features = non_ref_view_features + non_ref_view_pe + + # Concatenate the reference and non-reference view features + # Handle additional tokens (no view-based positional encoding for them) + if model_input.additional_input_tokens is not None: + additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :] + multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1) + else: + multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1) + + # Loop over the depth of the transformer + for depth_idx in range(self.depth): + # Apply the self-attention block and update the multi-view features + multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions) + + # Normalize the output features + output_multi_view_features = self.norm(multi_view_features) + + # Extract only the view features (excluding additional tokens) + view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :] + + # Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W) + view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C) + view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W) + + # Split the output multi-view features into separate views + view_features = view_features.split(1, dim=1) + view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features] + + # Extract and return additional token features if provided + if model_input.additional_input_tokens is not None: + additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :] + additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T) + return MultiViewTransformerOutput( + features=view_features, additional_token_features=additional_token_features + ) + else: + return MultiViewTransformerOutput(features=view_features) + + +class MultiViewGlobalAttentionTransformerIFR(MultiViewGlobalAttentionTransformer, IntermediateFeatureReturner): + "Intermediate Feature Returner for UniCeption Multi-View Global-Attention Transformer" + + def __init__( + self, + name: str, + input_embed_dim: int, + max_num_views: int, + use_rand_idx_pe_for_non_reference_views: bool, + size: Optional[str] = None, + depth: int = 12, + dim: int = 768, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + mlp_layer: nn.Module = Mlp, + custom_positional_encoding: Callable = None, + pretrained_checkpoint_path: str = None, + indices: Optional[Union[int, List[int]]] = None, + norm_intermediate: bool = True, + intermediates_only: bool = False, + gradient_checkpointing: bool = False, + *args, + **kwargs, + ): + """ + Initialize the Multi-View Global-Attention Transformer for information sharing across image features from different views. + Extends the base class to return intermediate features. + + Args: + input_embed_dim (int): Dimension of input embeddings. + max_num_views (int): Maximum number of views for positional encoding. + use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views. + size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None) + depth (int): Number of transformer layers. (default: 12, base size) + dim (int): Dimension of the transformer. (default: 768, base size) + num_heads (int): Number of attention heads. (default: 12, base size) + mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.) + qkv_bias (bool): Whether to include bias in qkv projection (default: False) + qk_norm (bool): Whether to normalize q and k (default: False) + proj_drop (float): Dropout rate for output (default: 0.) + attn_drop (float): Dropout rate for attention weights (default: 0.) + init_values (float): Initial value for LayerScale gamma (default: None) + drop_path (float): Dropout rate for stochastic depth (default: 0.) + act_layer (nn.Module): Activation layer (default: nn.GELU) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + mlp_layer (nn.Module): MLP layer (default: Mlp) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None) + indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options: + - None: Return all intermediate layers. + - int: Return the last n layers. + - List[int]: Return the intermediate layers at the specified indices. + norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True) + intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False) + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False) + """ + # Init the base classes + MultiViewGlobalAttentionTransformer.__init__( + self, + name=name, + input_embed_dim=input_embed_dim, + max_num_views=max_num_views, + use_rand_idx_pe_for_non_reference_views=use_rand_idx_pe_for_non_reference_views, + size=size, + depth=depth, + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_drop=proj_drop, + attn_drop=attn_drop, + init_values=init_values, + drop_path=drop_path, + act_layer=act_layer, + norm_layer=norm_layer, + mlp_layer=mlp_layer, + custom_positional_encoding=custom_positional_encoding, + pretrained_checkpoint_path=pretrained_checkpoint_path, + gradient_checkpointing=gradient_checkpointing, + *args, + **kwargs, + ) + IntermediateFeatureReturner.__init__( + self, + indices=indices, + norm_intermediate=norm_intermediate, + intermediates_only=intermediates_only, + ) + + def forward( + self, + model_input: MultiViewTransformerInput, + ) -> Union[ + List[MultiViewTransformerOutput], + Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]], + ]: + """ + Forward interface for the Multi-View Global-Attention Transformer with Intermediate Feature Return. + + Args: + model_input (MultiViewTransformerInput): Input to the model. + Expects the features to be a list of size (batch, input_embed_dim, height, width), + where each entry corresponds to a different view. + Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token) + which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens). + + Returns: + Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]: + Output of the model post information sharing. + If intermediates_only is True, returns a list of intermediate outputs. + If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs. + """ + # Check that the number of views matches the input and the features are of expected shape + assert ( + len(model_input.features) <= self.max_num_views + ), f"Expected {self.num_views} views, got {len(model_input.features)}" + assert all( + curr_view_features.shape[1] == self.input_embed_dim for curr_view_features in model_input.features + ), f"All views must have input dimension {self.input_embed_dim}" + assert all( + curr_view_features.ndim == 4 for curr_view_features in model_input.features + ), "All views must have 4 dimensions (N, C, H, W)" + + # Get the indices of the intermediate features to return + intermediate_multi_view_features = [] + take_indices, _ = feature_take_indices(self.depth, self.indices) + + # Initialize the multi-view features from the model input and number of views for current input + multi_view_features = model_input.features + num_of_views = len(multi_view_features) + batch_size, _, height, width = multi_view_features[0].shape + num_of_tokens_per_view = height * width + + # Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape) + multi_view_features = torch.stack(multi_view_features, dim=1) + + # Resize the multi-view features from NVCHW to NLC, where L = V * H * W + multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C) + multi_view_features = multi_view_features.reshape( + batch_size, num_of_views * height * width, self.input_embed_dim + ).contiguous() + + # Process additional input tokens if provided + if model_input.additional_input_tokens is not None: + additional_tokens = model_input.additional_input_tokens + assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)" + assert ( + additional_tokens.shape[1] == self.input_embed_dim + ), f"Additional tokens must have input dimension {self.input_embed_dim}" + assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens" + + # Reshape to channel-last format for transformer processing + additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C) + + # Concatenate the additional tokens to the multi-view features + multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1) + + # Project input features to the transformer dimension + multi_view_features = self.proj_embed(multi_view_features) + + # Create patch positions for each view if custom positional encoding is used + if self.custom_positional_encoding is not None: + multi_view_positions = [ + self.position_getter(batch_size, height, width, multi_view_features.device) + ] * num_of_views # List of length V, where each tensor is (N, H * W, C) + multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C) + else: + multi_view_positions = [None] * num_of_views + + # Add None positions for additional tokens if they exist + if model_input.additional_input_tokens is not None: + additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[1] + multi_view_positions = multi_view_positions + additional_tokens_positions + + # Add positional encoding for reference view (idx 0) + ref_view_pe = self.view_pos_table[0].clone().detach() + ref_view_pe = ref_view_pe.reshape((1, 1, self.dim)) + ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1) + ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :] + ref_view_features = ref_view_features + ref_view_pe + + # Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled) + if self.use_rand_idx_pe_for_non_reference_views: + non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views, size=(num_of_views - 1,)) + else: + non_ref_view_pe_indices = torch.arange(1, num_of_views) + non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach() + non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim)) + non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1) + non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1) + non_ref_view_features = multi_view_features[ + :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, : + ] + non_ref_view_features = non_ref_view_features + non_ref_view_pe + + # Concatenate the reference and non-reference view features + # Handle additional tokens (no view-based positional encoding for them) + if model_input.additional_input_tokens is not None: + additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :] + multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1) + else: + multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1) + + # Loop over the depth of the transformer + for depth_idx in range(self.depth): + # Apply the self-attention block and update the multi-view features + multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions) + if depth_idx in take_indices: + # Normalize the intermediate features with final norm layer if enabled + intermediate_multi_view_features.append( + self.norm(multi_view_features) if self.norm_intermediate else multi_view_features + ) + + # Reshape the intermediate features and convert to MultiViewTransformerOutput class + for idx in range(len(intermediate_multi_view_features)): + # Get the current intermediate features + current_features = intermediate_multi_view_features[idx] + + # Extract additional token features if provided + additional_token_features = None + if model_input.additional_input_tokens is not None: + additional_token_features = current_features[:, num_of_views * num_of_tokens_per_view :, :] + additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T) + # Only keep the view features for reshaping + current_features = current_features[:, : num_of_views * num_of_tokens_per_view, :] + + # Reshape the intermediate multi-view features (N, V * H * W, C) back to (N, V, C, H, W) + current_features = current_features.reshape( + batch_size, num_of_views, height, width, self.dim + ) # (N, V, H, W, C) + current_features = current_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W) + + # Split the intermediate multi-view features into separate views + current_features = current_features.split(1, dim=1) + current_features = [ + intermediate_view_features.squeeze(dim=1) for intermediate_view_features in current_features + ] + + intermediate_multi_view_features[idx] = MultiViewTransformerOutput( + features=current_features, additional_token_features=additional_token_features + ) + + # Return only the intermediate features if enabled + if self.intermediates_only: + return intermediate_multi_view_features + + # Normalize the output features + output_multi_view_features = self.norm(multi_view_features) + + # Extract view features (excluding additional tokens) + additional_token_features = None + if model_input.additional_input_tokens is not None: + additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :] + additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T) + view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :] + else: + view_features = output_multi_view_features + + # Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W) + view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C) + view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W) + + # Split the output multi-view features into separate views + view_features = view_features.split(1, dim=1) + view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features] + + output_multi_view_features = MultiViewTransformerOutput( + features=view_features, additional_token_features=additional_token_features + ) + + return output_multi_view_features, intermediate_multi_view_features + + +class GlobalAttentionTransformer(UniCeptionInfoSharingBase): + "UniCeption Global-Attention Transformer for information sharing across different set of features." + + def __init__( + self, + name: str, + input_embed_dim: int, + max_num_sets: int, + use_rand_idx_pe_for_non_reference_sets: bool, + size: Optional[str] = None, + depth: int = 12, + dim: int = 768, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6), + mlp_layer: Type[nn.Module] = Mlp, + pretrained_checkpoint_path: Optional[str] = None, + gradient_checkpointing: bool = False, + *args, + **kwargs, + ): + """ + Initialize the Global-Attention Transformer for information sharing across features from different sets. + + Args: + input_embed_dim (int): Dimension of input embeddings. + max_num_sets (int): Maximum number of sets for positional encoding. + use_rand_idx_pe_for_non_reference_sets (bool): Whether to use random index positional encoding for non-reference sets. + size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None) + depth (int): Number of transformer layers. (default: 12, base size) + dim (int): Dimension of the transformer. (default: 768, base size) + num_heads (int): Number of attention heads. (default: 12, base size) + mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.) + qkv_bias (bool): Whether to include bias in qkv projection (default: True) + qk_norm (bool): Whether to normalize q and k (default: False) + proj_drop (float): Dropout rate for output (default: 0.) + attn_drop (float): Dropout rate for attention weights (default: 0.) + init_values (float): Initial value for LayerScale gamma (default: None) + drop_path (float): Dropout rate for stochastic depth (default: 0.) + act_layer (nn.Module): Activation layer (default: nn.GELU) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + mlp_layer (nn.Module): MLP layer (default: Mlp) + pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None) + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False) + """ + # Initialize the base class + super().__init__(name=name, size=size, *args, **kwargs) + + # Initialize the specific attributes of the transformer + self.input_embed_dim = input_embed_dim + self.max_num_sets = max_num_sets + self.use_rand_idx_pe_for_non_reference_sets = use_rand_idx_pe_for_non_reference_sets + self.depth = depth + self.dim = dim + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + self.proj_drop = proj_drop + self.attn_drop = attn_drop + self.init_values = init_values + self.drop_path = drop_path + self.act_layer = act_layer + self.norm_layer = norm_layer + self.mlp_layer = mlp_layer + self.pretrained_checkpoint_path = pretrained_checkpoint_path + self.gradient_checkpointing = gradient_checkpointing + + # Initialize the projection layer for input embeddings + if self.input_embed_dim != self.dim: + self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True) + else: + self.proj_embed = nn.Identity() + + # Initialize the self-attention blocks which ingest all sets at once + self.self_attention_blocks = nn.ModuleList( + [ + SelfAttentionBlock( + dim=self.dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_norm=self.qk_norm, + proj_drop=self.proj_drop, + attn_drop=self.attn_drop, + init_values=self.init_values, + drop_path=self.drop_path, + act_layer=self.act_layer, + norm_layer=self.norm_layer, + mlp_layer=self.mlp_layer, + ) + for _ in range(self.depth) + ] + ) + + # Initialize the final normalization layer + self.norm = self.norm_layer(self.dim) + + # Initialize the positional encoding table for the different sets + self.register_buffer( + "set_pos_table", + self._get_sinusoid_encoding_table(self.max_num_sets, self.dim, 10000), + ) + + # Initialize random weights + self.initialize_weights() + + # Load pretrained weights if provided + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained global-attention transformer weights from {self.pretrained_checkpoint_path} ...") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + # Apply gradient checkpointing if enabled + if self.gradient_checkpointing: + for i, block in enumerate(self.self_attention_blocks): + self.self_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block) + + def _get_sinusoid_encoding_table(self, n_position, d_hid, base): + "Sinusoid position encoding table" + + def get_position_angle_vec(position): + return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) + + return torch.FloatTensor(sinusoid_table) + + def initialize_weights(self): + "Initialize weights of the transformer." + # Linears and layer norms + self.apply(self._init_weights) + + def _init_weights(self, m): + "Initialize the transformer linear and layer norm weights." + if isinstance(m, nn.Linear): + # We use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + model_input: MultiSetTransformerInput, + ) -> MultiSetTransformerOutput: + """ + Forward interface for the Multi-Set Global-Attention Transformer. + + Args: + model_input (MultiSetTransformerInput): Input to the model. + Expects the features to be a list of size (batch, input_embed_dim, num_tokens), + where each entry corresponds to a different set of tokens and + the number of tokens can be different for each set. + Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token) + which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens). + + Returns: + MultiSetTransformerOutput: Output of the model post information sharing. + """ + # Check that the number of sets matches the input and the features are of expected shape + assert ( + len(model_input.features) <= self.max_num_sets + ), f"Expected less than {self.max_num_sets} sets, got {len(model_input.features)}" + assert all( + set_features.shape[1] == self.input_embed_dim for set_features in model_input.features + ), f"All sets must have input dimension {self.input_embed_dim}" + assert all( + set_features.ndim == 3 for set_features in model_input.features + ), "All sets must have 3 dimensions (N, C, T)" + + # Initialize the multi-set features from the model input and number of sets for current input + multi_set_features = model_input.features + num_of_sets = len(multi_set_features) + batch_size, _, _ = multi_set_features[0].shape + num_of_tokens_per_set = [set_features.shape[2] for set_features in multi_set_features] + + # Permute the multi-set features from (N, C, T) to (N, T, C) + multi_set_features = [set_features.permute(0, 2, 1).contiguous() for set_features in multi_set_features] + + # Stack the multi-set features along the number of tokens dimension + multi_set_features = torch.cat(multi_set_features, dim=1) + + # Process additional input tokens if provided + if model_input.additional_input_tokens is not None: + additional_tokens = model_input.additional_input_tokens + assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)" + assert ( + additional_tokens.shape[1] == self.input_embed_dim + ), f"Additional tokens must have input dimension {self.input_embed_dim}" + assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens" + + # Reshape to channel-last format for transformer processing + additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C) + + # Concatenate the additional tokens to the multi-set features + multi_set_features = torch.cat([multi_set_features, additional_tokens], dim=1) + + # Project input features to the transformer dimension + multi_set_features = self.proj_embed(multi_set_features) + + # Create dummy patch positions for each set + multi_set_positions = [None] * num_of_sets + + # Add positional encoding for reference set (idx 0) + ref_set_pe = self.set_pos_table[0].clone().detach() + ref_set_pe = ref_set_pe.reshape((1, 1, self.dim)) + ref_set_pe = ref_set_pe.repeat(batch_size, num_of_tokens_per_set[0], 1) + ref_set_features = multi_set_features[:, : num_of_tokens_per_set[0], :] + ref_set_features = ref_set_features + ref_set_pe + + # Add positional encoding for non-reference sets (sequential indices starting from idx 1 or random indices which are uniformly sampled) + if self.use_rand_idx_pe_for_non_reference_sets: + non_ref_set_pe_indices = torch.randint(low=1, high=self.max_num_sets, size=(num_of_sets - 1,)) + else: + non_ref_set_pe_indices = torch.arange(1, num_of_sets) + non_ref_set_pe_list = [] + for non_ref_set_idx in range(1, num_of_sets): + non_ref_set_pe_for_idx = self.set_pos_table[non_ref_set_pe_indices[non_ref_set_idx - 1]].clone().detach() + non_ref_set_pe_for_idx = non_ref_set_pe_for_idx.reshape((1, 1, self.dim)) + non_ref_set_pe_for_idx = non_ref_set_pe_for_idx.repeat( + batch_size, num_of_tokens_per_set[non_ref_set_idx], 1 + ) + non_ref_set_pe_list.append(non_ref_set_pe_for_idx) + non_ref_set_pe = torch.cat(non_ref_set_pe_list, dim=1) + non_ref_set_features = multi_set_features[:, num_of_tokens_per_set[0] : sum(num_of_tokens_per_set), :] + non_ref_set_features = non_ref_set_features + non_ref_set_pe + + # Concatenate the reference and non-reference set features + # Handle additional tokens (no set-based positional encoding for them) + if model_input.additional_input_tokens is not None: + additional_features = multi_set_features[:, sum(num_of_tokens_per_set) :, :] + multi_set_features = torch.cat([ref_set_features, non_ref_set_features, additional_features], dim=1) + else: + multi_set_features = torch.cat([ref_set_features, non_ref_set_features], dim=1) + + # Add None positions for additional tokens if they exist + if model_input.additional_input_tokens is not None: + additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[2] + multi_set_positions = multi_set_positions + additional_tokens_positions + + # Loop over the depth of the transformer + for depth_idx in range(self.depth): + # Apply the self-attention block and update the multi-set features + multi_set_features = self.self_attention_blocks[depth_idx](multi_set_features, multi_set_positions) + + # Normalize the output features + output_multi_set_features = self.norm(multi_set_features) + + # Extract additional token features if provided + additional_token_features = None + if model_input.additional_input_tokens is not None: + additional_token_features = output_multi_set_features[:, sum(num_of_tokens_per_set) :, :] + additional_token_features = additional_token_features.permute( + 0, 2, 1 + ).contiguous() # (N, T, C) -> (N, C, T) + # Only keep the set features for reshaping + output_multi_set_features = output_multi_set_features[:, : sum(num_of_tokens_per_set), :] + + # Reshape the output multi-set features from (N, T, C) to (N, C, T) + output_multi_set_features = output_multi_set_features.permute(0, 2, 1).contiguous() + + # Split the output multi-set features into separate sets using the list of number of tokens per set + output_multi_set_features = torch.split(output_multi_set_features, num_of_tokens_per_set, dim=2) + + # Return the output multi-set features with additional token features if provided + return MultiSetTransformerOutput( + features=output_multi_set_features, additional_token_features=additional_token_features + ) + + +def dummy_positional_encoding(x, xpos): + "Dummy function for positional encoding of tokens" + x = x + xpos = xpos + return x + + +if __name__ == "__main__": + # Init multi-view global-attention transformer with no custom positional encoding and run a forward pass + for num_views in [2, 3, 4]: + print(f"Testing MultiViewGlobalAttentionTransformer with {num_views} views ...") + # Sequential idx based positional encoding + model = MultiViewGlobalAttentionTransformer( + name="MV-GAT", input_embed_dim=1024, max_num_views=1000, use_rand_idx_pe_for_non_reference_views=False + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + model_input = MultiViewTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + # Random idx based positional encoding + model = MultiViewGlobalAttentionTransformer( + name="MV-GAT", input_embed_dim=1024, max_num_views=1000, use_rand_idx_pe_for_non_reference_views=True + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + model_input = MultiViewTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + + # Init multi-view global-attention transformer with custom positional encoding and run a forward pass + for num_views in [2, 3, 4]: + print(f"Testing MultiViewGlobalAttentionTransformer with {num_views} views and custom positional encoding ...") + model = MultiViewGlobalAttentionTransformer( + name="MV-GAT", + input_embed_dim=1024, + max_num_views=1000, + use_rand_idx_pe_for_non_reference_views=True, + custom_positional_encoding=dummy_positional_encoding, + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + model_input = MultiViewTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + + print("All multi-view global-attention transformers initialized and tested successfully!") + + # Intermediate Feature Returner Tests + print("Running Intermediate Feature Returner Tests ...") + + # Run the intermediate feature returner with last-n index + model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR( + name="MV-GAT-IFR", + input_embed_dim=1024, + max_num_views=1000, + use_rand_idx_pe_for_non_reference_views=True, + indices=6, # Last 6 layers + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + assert isinstance(output, tuple) + assert isinstance(output[0], MultiViewTransformerOutput) + assert len(output[1]) == 6 + assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1]) + assert len(output[1][0].features) == 2 + + # Run the intermediate feature returner with specific indices + model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR( + name="MV-GAT-IFR", + input_embed_dim=1024, + max_num_views=1000, + use_rand_idx_pe_for_non_reference_views=True, + indices=[0, 2, 4, 6], # Specific indices + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + assert isinstance(output, tuple) + assert isinstance(output[0], MultiViewTransformerOutput) + assert len(output[1]) == 4 + assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1]) + assert len(output[1][0].features) == 2 + + # Test the normalizing of intermediate features + model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR( + name="MV-GAT-IFR", + input_embed_dim=1024, + max_num_views=1000, + use_rand_idx_pe_for_non_reference_views=True, + indices=[-1], # Last layer + norm_intermediate=False, # Disable normalization + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + for view_idx in range(2): + assert not torch.equal( + output[0].features[view_idx], output[1][-1].features[view_idx] + ), "Final features and intermediate features (last layer) must be different." + + model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR( + name="MV-GAT-IFR", + input_embed_dim=1024, + max_num_views=1000, + use_rand_idx_pe_for_non_reference_views=True, + indices=[-1], # Last layer + norm_intermediate=True, + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)] + model_input = MultiViewTransformerInput(features=model_input) + output = model_intermediate_feature_returner(model_input) + for view_idx in range(2): + assert torch.equal( + output[0].features[view_idx], output[1][-1].features[view_idx] + ), "Final features and intermediate features (last layer) must be same." + + print("All Intermediate Feature Returner Tests passed!") + + # Init multi-set global-attention transformer and run a forward pass with different number of sets and set token sizes + import random + + model = GlobalAttentionTransformer( + name="GAT", input_embed_dim=1024, max_num_sets=3, use_rand_idx_pe_for_non_reference_sets=False + ) + for num_sets in [2, 3]: + print(f"Testing GlobalAttentionTransformer with {num_sets} sets ...") + model_input = [torch.rand(1, 1024, random.randint(256, 513)) for _ in range(num_sets)] + model_input = MultiSetTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_sets + for feat, rand_input in zip(model_output.features, model_input.features): + assert feat.shape[2] == rand_input.shape[2] + assert feat.shape[1] == model.dim + assert feat.shape[0] == rand_input.shape[0] + # Random idx based positional encoding + model = GlobalAttentionTransformer( + name="GAT", input_embed_dim=1024, max_num_sets=1000, use_rand_idx_pe_for_non_reference_sets=True + ) + for num_sets in [2, 3, 4]: + print(f"Testing GlobalAttentionTransformer with {num_sets} sets ...") + model_input = [torch.rand(1, 1024, random.randint(256, 513)) for _ in range(num_sets)] + model_input = MultiSetTransformerInput(features=model_input) + model_output = model(model_input) + assert len(model_output.features) == num_sets + for feat, rand_input in zip(model_output.features, model_input.features): + assert feat.shape[2] == rand_input.shape[2] + assert feat.shape[1] == model.dim + assert feat.shape[0] == rand_input.shape[0] + + print("All Global Attention Transformer Tests passed!") + + # Test additional input tokens for MultiViewGlobalAttentionTransformer + print("Testing MultiViewGlobalAttentionTransformer with additional input tokens...") + model = MultiViewGlobalAttentionTransformer( + name="MV-GAT", input_embed_dim=1024, max_num_views=1000, use_rand_idx_pe_for_non_reference_views=False + ) + num_views = 2 + num_additional_tokens = 5 + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + additional_tokens = torch.rand(1, 1024, num_additional_tokens) + model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens) + model_output = model(model_input) + assert len(model_output.features) == num_views + assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features) + assert model_output.additional_token_features is not None + assert model_output.additional_token_features.shape == (1, model.dim, num_additional_tokens) + + # Test additional input tokens for MultiViewGlobalAttentionTransformerIFR + print("Testing MultiViewGlobalAttentionTransformerIFR with additional input tokens...") + model_ifr = MultiViewGlobalAttentionTransformerIFR( + name="MV-GAT-IFR", + input_embed_dim=1024, + max_num_views=1000, + use_rand_idx_pe_for_non_reference_views=True, + indices=[0, 2, 4], + ) + model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)] + additional_tokens = torch.rand(1, 1024, num_additional_tokens) + model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens) + output = model_ifr(model_input) + assert isinstance(output, tuple) + assert isinstance(output[0], MultiViewTransformerOutput) + assert output[0].additional_token_features is not None + assert output[0].additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens) + assert len(output[1]) == 3 + assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1]) + assert all(intermediate.additional_token_features is not None for intermediate in output[1]) + assert all( + intermediate.additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens) + for intermediate in output[1] + ) + + # Test additional input tokens for GlobalAttentionTransformer + print("Testing GlobalAttentionTransformer with additional input tokens...") + model = GlobalAttentionTransformer( + name="GAT", input_embed_dim=1024, max_num_sets=1000, use_rand_idx_pe_for_non_reference_sets=False + ) + num_sets = 3 + num_additional_tokens = 8 + model_input = [torch.rand(1, 1024, random.randint(256, 513)) for _ in range(num_sets)] + additional_tokens = torch.rand(1, 1024, num_additional_tokens) + model_input = MultiSetTransformerInput(features=model_input, additional_input_tokens=additional_tokens) + model_output = model(model_input) + assert len(model_output.features) == num_sets + for feat, rand_input in zip(model_output.features, model_input.features): + assert feat.shape[2] == rand_input.shape[2] + assert feat.shape[1] == model.dim + assert feat.shape[0] == rand_input.shape[0] + assert model_output.additional_token_features is not None + assert model_output.additional_token_features.shape == (1, model.dim, num_additional_tokens) + + print("All tests using additional input tokens passed!") diff --git a/UniCeption/uniception/models/libs/__init__.py b/UniCeption/uniception/models/libs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/__init__.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a08b2c2049a3dc0e8777b93e88b824544ef480c5 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/image_cli.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/image_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..dc065761414d4e1a6963b9d2645c456287cc7553 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/image_cli.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A CLI to run ImageTokenizer on plain images based on torch.jit. + +Usage: + python3 -m cosmos_tokenizer.image_cli \ + --image_pattern 'path/to/input/folder/*.jpg' \ + --output_dir ./reconstructions \ + --checkpoint_enc ./pretrained_ckpts/CosmosCI_f8x8/encoder.jit \ + --checkpoint_dec ./pretrained_ckpts/CosmosCI_f8x8/decoder.jit + + Optionally, you can run the model in pure PyTorch mode: + python3 -m cosmos_tokenizer.image_cli \ + --image_pattern 'path/to/input/folder/*.jpg' \ + --mode torch \ + --tokenizer_type CI \ + --spatial_compression 8 \ + --checkpoint_enc ./pretrained_ckpts/CosmosCI_f8x8/encoder.jit \ + --checkpoint_dec ./pretrained_ckpts/CosmosCI_f8x8/decoder.jit +""" + +import os +import sys +from argparse import ArgumentParser, Namespace +from typing import Any + +import numpy as np + +from uniception.models.libs.cosmos_tokenizer.image_lib import ImageTokenizer +from uniception.models.libs.cosmos_tokenizer.networks import TokenizerConfigs +from uniception.models.libs.cosmos_tokenizer.utils import ( + get_filepaths, + get_output_filepath, + read_image, + resize_image, + write_image, +) + + +def _parse_args() -> tuple[Namespace, dict[str, Any]]: + parser = ArgumentParser(description="A CLI for running ImageTokenizer on plain images.") + parser.add_argument( + "--image_pattern", + type=str, + default="path/to/images/*.jpg", + help="Glob pattern.", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="JIT full Autoencoder model filepath.", + ) + parser.add_argument( + "--checkpoint_enc", + type=str, + default=None, + help="JIT Encoder model filepath.", + ) + parser.add_argument( + "--checkpoint_dec", + type=str, + default=None, + help="JIT Decoder model filepath.", + ) + parser.add_argument( + "--tokenizer_type", + type=str, + choices=["CI", "DI"], + help="Specifies the tokenizer type.", + ) + parser.add_argument( + "--spatial_compression", + type=int, + choices=[8, 16], + default=8, + help="The spatial compression factor.", + ) + parser.add_argument( + "--mode", + type=str, + choices=["torch", "jit"], + default="jit", + help="Specify the backend: native 'torch' or 'jit' (default: 'jit')", + ) + parser.add_argument( + "--short_size", + type=int, + default=None, + help="The size to resample inputs. None, by default.", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + help="Sets the precision. Default bfloat16.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for invoking the model.", + ) + parser.add_argument("--output_dir", type=str, default=None, help="Output directory.") + parser.add_argument( + "--save_input", + action="store_true", + help="If on, the input image will be be outputed too.", + ) + args = parser.parse_args() + return args + + +args = _parse_args() +if args.mode == "torch" and args.tokenizer_type not in ["CI", "DI"]: + sys.exit(1) + + +def _run_eval() -> None: + """Invokes the evaluation pipeline.""" + + if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None: + return + + if args.mode == "torch": + tokenizer_config = TokenizerConfigs[args.tokenizer_type].value + tokenizer_config.update(dict(spatial_compression=args.spatial_compression)) + else: + tokenizer_config = None + + autoencoder = ImageTokenizer( + checkpoint=args.checkpoint, + checkpoint_enc=args.checkpoint_enc, + checkpoint_dec=args.checkpoint_dec, + tokenizer_config=tokenizer_config, + device=args.device, + dtype=args.dtype, + ) + + filepaths = get_filepaths(args.image_pattern) + + for filepath in filepaths: + image = read_image(filepath) + image = resize_image(image, short_size=args.short_size) + batch_image = np.expand_dims(image, axis=0) + + output_image = autoencoder(batch_image)[0] + + output_filepath = get_output_filepath(filepath, output_dir=args.output_dir) + write_image(output_filepath, output_image) + + if args.save_input: + ext = os.path.splitext(output_filepath)[-1] + input_filepath = output_filepath.replace(ext, "_input" + ext) + write_image(input_filepath, image) + + +def main() -> None: + _run_eval() + + +if __name__ == "__main__": + main() diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/image_lib.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/image_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..03556d5390489ed9bc9afb117e6bba68e5a1b436 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/image_lib.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A library for image tokenizers inference.""" + +from typing import Any + +import numpy as np +import torch + +from uniception.models.libs.cosmos_tokenizer.utils import ( + load_decoder_model, + load_encoder_model, + load_model, + numpy2tensor, + pad_image_batch, + tensor2numpy, + unpad_image_batch, +) + + +class ImageTokenizer(torch.nn.Module): + def __init__( + self, + checkpoint: str = None, + checkpoint_enc: str = None, + checkpoint_dec: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", + dtype: str = "bfloat16", + ) -> None: + super().__init__() + self._device = device + self._dtype = getattr(torch, dtype) + self._full_model = ( + load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None + ) + self._enc_model = ( + load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype) + if checkpoint_enc is not None + else None + ) + self._dec_model = ( + load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype) + if checkpoint_dec is not None + else None + ) + + @torch.no_grad() + def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Reconstrcuts a batch of image tensors after embedding into a latent. + + Args: + input_tensor: The input image Bx3xHxW layout, range [-1..1]. + Returns: + The reconstructed tensor, layout Bx3xHxW, range [-1..1]. + """ + if self._full_model is not None: + output_tensor = self._full_model(input_tensor) + output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor + else: + output_latent = self.encode(input_tensor)[0] + output_tensor = self.decode(output_latent) + return output_tensor + + @torch.no_grad() + def decode(self, input_latent: torch.Tensor) -> torch.Tensor: + """Decodes an image from a provided latent embedding. + + Args: + input_latent: The continuous latent Bx16xhxw for CI, + or the discrete indices Bxhxw for DI. + Returns: + The output tensor in Bx3xHxW, range [-1..1]. + """ + return self._dec_model(input_latent) + + @torch.no_grad() + def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]: + """Encodes an image into a latent embedding or code. + + Args: + input_tensor: The input tensor Bx3xHxW layout, range [-1..1]. + Returns: + For continuous image (CI) tokenizer, the tuple contains: + - The latent embedding, Bx16x(h)x(w), where the compression + rate is (H/h x W/w), and channel dimension of 16. + For discrete image (DI) tokenizer, the tuple contains: + - The indices, Bx(h)x(w), from a codebook of size 64K, which + corresponds to FSQ levels of (8,8,8,5,5,5). + - The discrete code, Bx6x(h)x(w), where the compression rate is + again (H/h x W/w), and channel dimension of 6. + """ + output_latent = self._enc_model(input_tensor) + if isinstance(output_latent, torch.Tensor): + return output_latent + return output_latent[:-1] + + @torch.no_grad() + def forward(self, image: np.ndarray) -> np.ndarray: + """Reconstructs an image using a pre-trained tokenizer. + + Args: + image: The input image BxHxWxC layout, range [0..255]. + Returns: + The reconstructed image in range [0..255], layout BxHxWxC. + """ + padded_input_image, crop_region = pad_image_batch(image) + input_tensor = numpy2tensor(padded_input_image, dtype=self._dtype, device=self._device) + output_tensor = self.autoencode(input_tensor) + padded_output_image = tensor2numpy(output_tensor) + return unpad_image_batch(padded_output_image, crop_region) diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/__init__.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3756f4260b2117b413d61c5157c30867ae5b4cb1 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/__init__.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum + +from uniception.models.libs.cosmos_tokenizer.modules.distributions import GaussianDistribution, IdentityDistribution +from uniception.models.libs.cosmos_tokenizer.modules.layers2d import Decoder, Encoder +from uniception.models.libs.cosmos_tokenizer.modules.layers3d import ( + DecoderBase, + DecoderFactorized, + EncoderBase, + EncoderFactorized, +) +from uniception.models.libs.cosmos_tokenizer.modules.quantizers import ( + FSQuantizer, + LFQuantizer, + ResidualFSQuantizer, + VectorQuantizer, +) + + +class EncoderType(Enum): + Default = Encoder + + +class DecoderType(Enum): + Default = Decoder + + +class Encoder3DType(Enum): + BASE = EncoderBase + FACTORIZED = EncoderFactorized + + +class Decoder3DType(Enum): + BASE = DecoderBase + FACTORIZED = DecoderFactorized + + +class ContinuousFormulation(Enum): + VAE = GaussianDistribution + AE = IdentityDistribution + + +class DiscreteQuantizer(Enum): + VQ = VectorQuantizer + LFQ = LFQuantizer + FSQ = FSQuantizer + RESFSQ = ResidualFSQuantizer diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/distributions.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..d3ac656af54b82e2ec8d56b9e4eed8b40d5e3146 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/distributions.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The distribution modes to use for continuous image tokenizers.""" + +import torch + + +class IdentityDistribution(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, parameters): + return parameters, (torch.tensor([0.0]), torch.tensor([0.0])) + + +class GaussianDistribution(torch.nn.Module): + def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0): + super().__init__() + self.min_logvar = min_logvar + self.max_logvar = max_logvar + + def sample(self, mean, logvar): + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + def forward(self, parameters): + mean, logvar = torch.chunk(parameters, 2, dim=1) + logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar) + return self.sample(mean, logvar), (mean, logvar) diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/layers2d.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/layers2d.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca22da290298a7a80ee57d314ab4fa68370aaa2 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/layers2d.py @@ -0,0 +1,326 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The model definition for Continuous 2D layers + +Adapted from: https://github.com/CompVis/stable-diffusion/blob/ +21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py + +[Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors] +https://github.com/CompVis/stable-diffusion/blob/ +21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/LICENSE +""" + +import math + +import numpy as np + +# pytorch_diffusion + derived encoder decoder +import torch +import torch.nn as nn +import torch.nn.functional as F + +from uniception.models.libs.cosmos_tokenizer.modules.patching import Patcher, UnPatcher +from uniception.models.libs.cosmos_tokenizer.modules.utils import Normalize, nonlinearity + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3) + return self.conv(x) + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + return self.conv(x) + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + **kwargs, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = Normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.nin_shortcut = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + + self.norm = Normalize(in_channels) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # TODO (freda): Consider reusing implementations in Attn `imaginaire`, + # since than one is gonna be based on TransformerEngine's attn op, + # w/c could ease CP implementations. + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class Encoder(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size + + # calculate the number of downsample operations + self.num_downsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert ( + self.num_downsamples <= self.num_resolutions + ), f"we can only downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, channels, kernel_size=3, stride=1, padding=1) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level < self.num_downsamples: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level < self.num_downsamples: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: int, + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher = UnPatcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size + + # calculate the number of upsample operations + self.num_upsamples = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert self.num_upsamples <= self.num_resolutions, f"we can only upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level >= (self.num_resolutions - self.num_upsamples): + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level >= (self.num_resolutions - self.num_upsamples): + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher(h) + return h diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/layers3d.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/layers3d.py new file mode 100644 index 0000000000000000000000000000000000000000..b46ae1e85d5597c6c05a261bf34eeff400ea96ef --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/layers3d.py @@ -0,0 +1,965 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The model definition for 3D layers + +Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/ +9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L889 + +[MIT License Copyright (c) 2023 Phil Wang] +https://github.com/lucidrains/magvit2-pytorch/blob/ +9f49074179c912736e617d61b32be367eb5f993a/LICENSE +""" +import math +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from uniception.models.libs.cosmos_tokenizer.modules.patching import Patcher, Patcher3D, UnPatcher, UnPatcher3D +from uniception.models.libs.cosmos_tokenizer.modules.utils import ( + CausalNormalize, + batch2space, + batch2time, + cast_tuple, + is_odd, + nonlinearity, + replication_pad, + space2batch, + time2batch, +) + +_LEGACY_NUM_GROUPS = 32 + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in: int = 1, + chan_out: int = 1, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + pad_mode: str = "constant", + **kwargs, + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) + time_stride = kwargs.pop("time_stride", 1) + time_dilation = kwargs.pop("time_dilation", 1) + padding = kwargs.pop("padding", 1) + + self.pad_mode = pad_mode + time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride) + self.time_pad = time_pad + + self.spatial_pad = (padding, padding, padding, padding) + + stride = (time_stride, stride, stride) + dilation = (time_dilation, dilation, dilation) + self.conv3d = nn.Conv3d( + chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + **kwargs, + ) + + def _replication_pad(self, x: torch.Tensor) -> torch.Tensor: + x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1) + x = torch.cat([x_prev, x], dim=2) + padding = self.spatial_pad + (0, 0) + return F.pad(x, padding, mode=self.pad_mode, value=0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._replication_pad(x) + return self.conv3d(x) + + +class CausalUpsample3d(nn.Module): + def __init__(self, in_channels: int) -> None: + super().__init__() + self.conv = CausalConv3d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + # TODO(freda): Check if this causes temporal inconsistency. + # Shoule reverse the order of the following two ops, + # better perf and better temporal smoothness. + x = self.conv(x) + return x[..., int(time_factor - 1) :, :, :] + + +class CausalDownsample3d(nn.Module): + def __init__(self, in_channels: int) -> None: + super().__init__() + self.conv = CausalConv3d( + in_channels, + in_channels, + kernel_size=3, + stride=2, + time_stride=2, + padding=0, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", value=0) + x = replication_pad(x) + x = self.conv(x) + return x + + +class CausalHybridUpsample3d(nn.Module): + def __init__( + self, + in_channels: int, + spatial_up: bool = True, + temporal_up: bool = True, + **kwargs, + ) -> None: + super().__init__() + self.conv1 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(3, 1, 1), + stride=1, + time_stride=1, + padding=0, + ) + self.conv2 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(1, 3, 3), + stride=1, + time_stride=1, + padding=1, + ) + self.conv3 = CausalConv3d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + time_stride=1, + padding=0, + ) + self.spatial_up = spatial_up + self.temporal_up = temporal_up + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_up and not self.temporal_up: + return x + + # hybrid upsample temporally. + if self.temporal_up: + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + x = x[..., int(time_factor - 1) :, :, :] + x = self.conv1(x) + x + + # hybrid upsample spatially. + if self.spatial_up: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + x = self.conv2(x) + x + + # final 1x1x1 conv. + x = self.conv3(x) + return x + + +class CausalHybridDownsample3d(nn.Module): + def __init__( + self, + in_channels: int, + spatial_down: bool = True, + temporal_down: bool = True, + **kwargs, + ) -> None: + super().__init__() + self.conv1 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(1, 3, 3), + stride=2, + time_stride=1, + padding=0, + ) + self.conv2 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(3, 1, 1), + stride=1, + time_stride=2, + padding=0, + ) + self.conv3 = CausalConv3d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + time_stride=1, + padding=0, + ) + self.spatial_down = spatial_down + self.temporal_down = temporal_down + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_down and not self.temporal_down: + return x + + # hybrid downsample spatially. + if self.spatial_down: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", value=0) + x1 = self.conv1(x) + x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + x = x1 + x2 + + # hybrid downsample temporally. + if self.temporal_down: + x = replication_pad(x) + x1 = self.conv2(x) + x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1)) + x = x1 + x2 + + # final 1x1x1 conv. + x = self.conv3(x) + return x + + +class CausalResnetBlock3d(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + num_groups: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=num_groups) + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalResnetBlockFactorized3d(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + num_groups: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=1) + self.conv1 = nn.Sequential( + CausalConv3d( + in_channels, + out_channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Sequential( + CausalConv3d( + out_channels, + out_channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size = time2batch(q) + k, batch_size = time2batch(k) + v, batch_size = time2batch(v) + + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = batch2time(h_, batch_size) + h_ = self.proj_out(h_) + return x + h_ + + +class CausalTemporalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size, height = space2batch(q) + k, _, _ = space2batch(k) + v, _, _ = space2batch(v) + + bhw, c, t = q.shape + q = q.permute(0, 2, 1) # (bhw, t, c) + k = k.permute(0, 2, 1) # (bhw, t, c) + v = v.permute(0, 2, 1) # (bhw, t, c) + + w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t) + w_ = w_ * (int(c) ** (-0.5)) + + # Apply causal mask + mask = torch.tril(torch.ones_like(w_)) + w_ = w_.masked_fill(mask == 0, float("-inf")) + w_ = F.softmax(w_, dim=2) + + # attend to values + h_ = torch.bmm(w_, v) # (bhw, t, c) + h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t) + + h_ = batch2space(h_, batch_size, height) + h_ = self.proj_out(h_) + return x + h_ + + +class EncoderBase(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher = Patcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size + + # downsampling + self.conv_in = CausalConv3d(in_channels, channels, kernel_size=3, stride=1, padding=1) + + # num of groups for GroupNorm, num_groups=1 for LayerNorm. + num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=num_groups, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = CausalDownsample3d(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) + self.mid.block_2 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=num_groups) + self.conv_out = CausalConv3d(block_in, z_channels, kernel_size=3, stride=1, padding=1) + + def patcher3d(self, x: torch.Tensor) -> torch.Tensor: + x, batch_size = time2batch(x) + x = self.patcher(x) + x = batch2time(x, batch_size) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + else: + # temporal downsample (last level) + time_factor = 1 + 1 * (hs[-1].shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + hs[-1] = replication_pad(hs[-1]) + hs.append( + F.avg_pool3d( + hs[-1], + kernel_size=[time_factor, 1, 1], + stride=[2, 1, 1], + ) + ) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderBase(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher = UnPatcher(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # num of groups for GroupNorm, num_groups=1 for LayerNorm. + num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) + self.mid.block_2 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=num_groups, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = CausalUpsample3d(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=num_groups) + self.conv_out = CausalConv3d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def unpatcher3d(self, x: torch.Tensor) -> torch.Tensor: + x, batch_size = time2batch(x) + x = self.unpatcher(x) + x = batch2time(x, batch_size) + + return x + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + else: + # temporal upsample (last level) + time_factor = 1.0 + 1.0 * (h.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + h = h.repeat_interleave(int(time_factor), dim=2) + h = h[..., int(time_factor - 1) :, :, :] + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h + + +class EncoderFactorized(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int = 16, + temporal_compression: int = 8, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + in_channels = in_channels * patch_size * patch_size * patch_size + + # calculate the number of downsample operations + self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert ( + self.num_spatial_downs <= self.num_resolutions + ), f"Spatially downsample {self.num_resolutions} times at most" + + self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_downs <= self.num_resolutions + ), f"Temporally downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = nn.Sequential( + CausalConv3d( + in_channels, + channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=1, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + ) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + spatial_down = i_level < self.num_spatial_downs + temporal_down = i_level < self.num_temporal_downs + down.downsample = CausalHybridDownsample3d( + block_in, + spatial_down=spatial_down, + temporal_down=temporal_down, + ) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d( + z_channels, + z_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderFactorized(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int = 16, + temporal_compression: int = 8, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher3d = UnPatcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange")) + out_ch = out_channels * patch_size * patch_size * patch_size + + # calculate the number of upsample operations + self.num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size)) + assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most" + self.num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size)) + assert ( + self.num_temporal_ups <= self.num_resolutions + ), f"Temporally upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Sequential( + CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + + legacy_mode = ignore_kwargs.get("legacy_mode", False) + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=1, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + ) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + # The layer index for temporal/spatial downsampling performed + # in the encoder should correspond to the layer index in + # reverse order where upsampling is performed in the decoder. + # If you've a pre-trained model, you can simply finetune. + i_level_reverse = self.num_resolutions - i_level - 1 + if legacy_mode: + temporal_up = i_level_reverse < self.num_temporal_ups + else: + temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1 + spatial_up = temporal_up or ( + i_level_reverse < self.num_spatial_ups and self.num_spatial_ups > self.num_temporal_ups + ) + up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/patching.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/patching.py new file mode 100644 index 0000000000000000000000000000000000000000..ed62a80cf6e0449abeb18bdabc80bbee8ab36ff6 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/patching.py @@ -0,0 +1,310 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The patcher and unpatcher implementation for 2D and 3D data. + +The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions. +One on the rows and one on the columns. +For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2. +We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component. +For H component, we can use a 1D convolution with kernel [1, -1] and stride 2. +Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all + as we need to support downsampling for more than 2x. +For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be. + [3, 256, 256] -> [12, 128, 128] -> [48, 64, 64] +""" + +import torch +import torch.nn.functional as F +from einops import rearrange + +_WAVELETS = { + "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]), + "rearrange": torch.tensor([1.0, 1.0]), +} +_PERSISTENT = False + + +class Patcher(torch.nn.Module): + """A module to convert image tensors into patches using torch operations. + + The main difference from `class Patching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Patching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer( + "_arange", + torch.arange(_WAVELETS[patch_method].shape[0]), + persistent=_PERSISTENT, + ) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._haar(x) + elif self.patch_method == "rearrange": + return self._arrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _dwt(self, x, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2)) + xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2)) + xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1)) + xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1)) + + out = torch.cat([xll, xlh, xhl, xhh], dim=1) + if rescale: + out = out / 2 + return out + + def _haar(self, x): + for _ in self.range: + x = self._dwt(x, rescale=True) + return x + + def _arrange(self, x): + x = rearrange( + x, + "b c (h p1) (w p2) -> b (c p1 p2) h w", + p1=self.patch_size, + p2=self.patch_size, + ).contiguous() + return x + + +class Patcher3D(Patcher): + """A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + self.register_buffer( + "patch_size_buffer", + patch_size * torch.ones([1], dtype=torch.int32), + persistent=_PERSISTENT, + ) + + def _dwt(self, x, wavelet, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + # Handles temporal axis. + x = F.pad(x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + # Handles spatial axes. + xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1) + if rescale: + out = out / (2 * torch.sqrt(torch.tensor(2.0))) + return out + + def _haar(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + for _ in self.range: + x = self._dwt(x, "haar", rescale=True) + return x + + def _arrange(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + x = rearrange( + x, + "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ).contiguous() + return x + + +class UnPatcher(torch.nn.Module): + """A module to convert patches into image tensorsusing torch operations. + + The main difference from `class Unpatching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Unpatching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer( + "_arange", + torch.arange(_WAVELETS[patch_method].shape[0]), + persistent=_PERSISTENT, + ) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._ihaar(x) + elif self.patch_method == "rearrange": + return self._iarrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + n = h.shape[0] + + g = x.shape[1] // 4 + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1) + + # Inverse transform. + yl = torch.nn.functional.conv_transpose2d(xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yl += torch.nn.functional.conv_transpose2d(xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh = torch.nn.functional.conv_transpose2d(xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + yh += torch.nn.functional.conv_transpose2d(xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0)) + y = torch.nn.functional.conv_transpose2d(yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + y += torch.nn.functional.conv_transpose2d(yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2)) + + if rescale: + y = y * 2 + return y + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, "haar", rescale=True) + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.patch_size, + p2=self.patch_size, + ) + return x + + +class UnPatcher3D(UnPatcher): + """A 3D inverse discrete wavelet transform for video wavelet decompositions.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + + def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets + n = h.shape[0] + + g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors. + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1) + hl = hl.to(dtype=dtype) + hh = hh.to(dtype=dtype) + + xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) + + # Height height transposed convolutions. + xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xll += F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlh += F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhl += F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhh += F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + # Handles width transposed convolutions. + xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xl += F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh += F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + # Handles time axis transposed convolutions. + x = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + x += F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + if rescale: + x = x * (2 * torch.sqrt(torch.tensor(2.0))) + return x + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, "haar", rescale=True) + x = x[:, :, self.patch_size - 1 :, ...] + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ) + x = x[:, :, self.patch_size - 1 :, ...] + return x diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/quantizers.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/quantizers.py new file mode 100644 index 0000000000000000000000000000000000000000..b5698b9b4b19e942b4fb750f0178f860118e64fc --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/quantizers.py @@ -0,0 +1,510 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Quantizers for discrete image and video tokenization.""" + +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import reduce + +from uniception.models.libs.cosmos_tokenizer.modules.utils import ( + default, + entropy, + pack_one, + rearrange, + round_ste, + unpack_one, +) + + +class ResidualFSQuantizer(nn.Module): + """Residual Finite Scalar Quantization + + Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, levels: list[int], num_quantizers: int, **ignore_kwargs): + super().__init__() + self.dtype = ignore_kwargs.get("dtype", torch.float32) + self.layers = nn.ModuleList([FSQuantizer(levels=levels) for _ in range(num_quantizers)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + indices_stack = [] + residual = x + quantized_out = 0 + loss_out = 0 + for i, layer in enumerate(self.layers): + quant_indices, z, loss = layer(residual) + indices_stack.append(quant_indices) + residual = residual - z.detach() + quantized_out = quantized_out + z + loss_out = loss_out + loss + self.residual = residual + indices = torch.stack(indices_stack, dim=1) + return indices, quantized_out.to(self.dtype), loss_out.to(self.dtype) + + def indices_to_codes(self, indices_stack: torch.Tensor) -> torch.Tensor: + quantized_out = 0 + for layer, indices in zip(self.layers, indices_stack.transpose(0, 1)): + quantized_out += layer.indices_to_codes(indices) + return quantized_out + + +class FSQuantizer(nn.Module): + """Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 + + Code adapted from Jax version in Appendix A.1. + + Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ + vector_quantize_pytorch/finite_scalar_quantization.py + [Copyright (c) 2020 Phil Wang] + https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE + """ + + def __init__( + self, + levels: list[int], + dim: Optional[int] = None, + num_codebooks=1, + keep_num_codebooks_dim: Optional[bool] = None, + scale: Optional[float] = None, + **ignore_kwargs, + ): + super().__init__() + self.dtype = ignore_kwargs.get("dtype", torch.bfloat16) + _levels = torch.tensor(levels, dtype=torch.int32) + self.register_buffer("_levels", _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32) + self.register_buffer("_basis", _basis, persistent=False) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + has_projections = self.dim != effective_codebook_dim + self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity() + self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity() + self.has_projections = has_projections + + self.codebook_size = self._levels.prod().item() + + implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False) + self.register_buffer("implicit_codebook", implicit_codebook, persistent=False) + + def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor: + """Bound `z`, an array of shape (..., d).""" + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z: torch.Tensor) -> torch.Tensor: + """Quantizes z, returns quantized zhat, same shape as z.""" + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor: + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor: + """Converts a `code` to an index in the codebook.""" + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat).float() + return (zhat * self._basis).sum(dim=-1).to(torch.int32) + + def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor: + """Inverse of `codes_to_indices`.""" + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + indices = rearrange(indices, "... -> ... 1") + codes_non_centered = (indices // self._basis) % self._levels + codes = self._scale_and_shift_inverse(codes_non_centered) + + if self.keep_num_codebooks_dim: + codes = rearrange(codes, "... c d -> ... (c d)") + + if project_out: + codes = self.project_out(codes) + + if is_img_or_video: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes.to(self.dtype) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + is_img_or_video = z.ndim >= 4 + + # standardize image or video into (batch, seq, dimension) + + if is_img_or_video: + z = rearrange(z, "b d ... -> b ... d") + z, ps = pack_one(z, "b * d") + + assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" + + z = self.project_in(z) + + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + codes = self.quantize(z) + indices = self.codes_to_indices(codes) + + codes = rearrange(codes, "b n c d -> b n (c d)") + + out = self.project_out(codes) + + # reconstitute image or video dimensions + + if is_img_or_video: + out = unpack_one(out, ps, "b * d") + out = rearrange(out, "b ... d -> b d ...") + indices = unpack_one(indices, ps, "b * c") + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True)) + else: + dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(1) + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, "... 1 -> ...") + + return (indices, out.to(self.dtype), dummy_loss) + + +class VectorQuantizer(nn.Module): + """Improved version over VectorQuantizer. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + + Adapted from: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/ + taming/modules/vqvae/quantize.py + + [Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer] + https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/License.txt + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + beta: float = 0.25, + remap: str = None, + unknown_index: str = "random", + sane_index_shape: bool = False, + legacy: bool = True, + use_norm=False, + **ignore_kwargs, + ): + super().__init__() + self.n_e = num_embeddings + self.e_dim = embedding_dim + self.beta = beta + self.legacy = legacy + self.norm = lambda x: F.normalize(x, dim=-1) if use_norm else x + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = num_embeddings + + self.sane_index_shape = sane_index_shape + self.dtype = ignore_kwargs.get("dtype", torch.float32) + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits is False, "Only for interface compatible with Gumbel" + assert return_logits is False, "Only for interface compatible with Gumbel" + z = rearrange(z, "b c h w -> b h w c").contiguous() + z_flattened = z.view(-1, self.e_dim) + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", + z_flattened, + rearrange(self.embedding.weight, "n d -> d n"), + ) + ) + + encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + encodings = torch.zeros(encoding_indices.shape[0], self.n_e, device=z.device) + encodings.scatter_(1, encoding_indices, 1) + z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape) + min_encodings = None + + z_q, z = self.norm(z_q), self.norm(z) + + # compute loss for embedding + commit_loss = torch.mean((z_q - z.detach()) ** 2, dim=[1, 2, 3], keepdim=True) + emb_loss = torch.mean((z_q.detach() - z) ** 2, dim=[1, 2, 3], keepdim=True) + if not self.legacy: + loss = self.beta * emb_loss + commit_loss + else: + loss = emb_loss + self.beta * commit_loss + + # preserve gradients + z_q = z + (z_q - z).detach() + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + # reshape back to match original input shape + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() + + if self.remap is not None: + min_encoding_indices = encoding_indices.squeeze(1).reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(encoding_indices.squeeze(1)) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + # TODO: return (indices, z_q, loss) + return ( + z_q, + loss, + ( + encoding_indices.squeeze(1), + min_encodings, + commit_loss.mean().detach(), + self.beta * emb_loss.mean().detach(), + perplexity.mean().detach(), + ), + ) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class LFQuantizer(nn.Module): + """Lookup-Free Quantization + + Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/ + vector_quantize_pytorch/lookup_free_quantization.py + [Copyright (c) 2020 Phil Wang] + https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE + """ + + def __init__( + self, + *, + codebook_size: int, + codebook_dim: int, + embed_dim: Optional[int] = None, # if None, use codebook_dim + entropy_loss_weight=0.1, + commitment_loss_weight=0.25, + default_temp: float = 0.01, + entropy_loss: bool = False, + **ignore_kwargs, + ): + """Lookup-Free Quantization + + Args: + codebook_size (int): The number of entries in the codebook. + codebook_dim (int): The number of bits in each code. + embed_dim (Optional[int], optional): The dimension of the input embedding. Defaults to None. + entropy_loss_weight (float, optional): Whether to use entropy loss. Defaults to 0.1. + commitment_loss_weight (float, optional): Weight for commitment loss. Defaults to 0.25. + default_temp (float, optional): The temprature to use. Defaults to 0.01. + entropy_loss (bool, optional): Flag for entropy loss. Defaults to False. + """ + super().__init__() + self.entropy_loss = entropy_loss + self.codebook_dim = codebook_dim + self.default_temp = default_temp + self.entrop_loss_weight = entropy_loss_weight + self.commitment_loss_weight = commitment_loss_weight + embed_dim = embed_dim or codebook_dim + + has_projections = embed_dim != codebook_dim + self.project_in = nn.Linear(embed_dim, codebook_dim) if has_projections else nn.Identity() + self.project_out = nn.Linear(codebook_dim, embed_dim) if has_projections else nn.Identity() + + self.dtype = ignore_kwargs.get("dtype", torch.float32) + + if entropy_loss: + assert 2**codebook_dim == codebook_size, "codebook size must be 2 ** codebook_dim" + self.codebook_size = codebook_size + + self.register_buffer( + "mask", + 2 ** torch.arange(codebook_dim - 1, -1, -1), + persistent=False, + ) + self.register_buffer("zero", torch.tensor(0.0), persistent=False) + + all_codes = torch.arange(codebook_size) + bits = ((all_codes[..., None].int() & self.mask) != 0).float() + codebook = 2 * bits - 1.0 + + self.register_buffer("codebook", codebook, persistent=False) # [codebook_size, codebook_dim] + + def forward(self, z: torch.Tensor, temp: float = None) -> torch.Tensor: + temp = temp or self.default_temp + + z = rearrange(z, "b d ... -> b ... d") + z, ps = pack_one(z, "b * d") + z = self.project_in(z) + + # split out number of codebooks + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + # quantization + original_input = z + + codebook_value = torch.ones_like(z) + z_q = torch.where(z > 0, codebook_value, -codebook_value) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # commit loss + commit_loss = ((original_input - z_q.detach()) ** 2).mean(dim=[1, 2, 3]) + + z_q = rearrange(z_q, "b n c d -> b n (c d)") + z_q = self.project_out(z_q) + + # reshape + z_q = unpack_one(z_q, ps, "b * d") + z_q = rearrange(z_q, "b ... d -> b d ...") + + loss = self.commitment_loss_weight * commit_loss + + # entropy loss (eq-5) + if self.entropy_loss: + # indices + indices = reduce((z > 0).int() * self.mask.int(), "b n c d -> b n c", "sum") + indices = unpack_one(indices, ps, "b * c") + indices = rearrange(indices, "... 1 -> ...") + + distance = -2 * torch.einsum( + "... i d, j d -> ... i j", + original_input, + self.codebook.to(original_input.dtype), + ) + prob = (-distance / temp).softmax(dim=-1) + per_sample_entropy = entropy(prob).mean(dim=[1, 2]) + avg_prob = reduce(prob, "... c d -> c d", "mean") + codebook_entropy = entropy(avg_prob).mean() + entropy_aux_loss = per_sample_entropy - codebook_entropy + + loss += self.entrop_loss_weight * entropy_aux_loss + + # TODO: return (indices, z_q, loss) + return ( + z_q, + loss.unsqueeze(1).unsqueeze(1).unsqueeze(1), + ( + indices, + self.commitment_loss_weight * commit_loss.mean().detach(), + self.entrop_loss_weight * entropy_aux_loss.mean().detach(), + self.entrop_loss_weight * per_sample_entropy.mean().detach(), + self.entrop_loss_weight * codebook_entropy.mean().detach(), + ), + ) + else: + return ( + z_q, + loss.unsqueeze(1).unsqueeze(1).unsqueeze(1), + self.commitment_loss_weight * commit_loss.mean().detach(), + ) + + +class InvQuantizerJit(nn.Module): + """Use for decoder_jit to trace quantizer in discrete tokenizer""" + + def __init__(self, quantizer): + super().__init__() + self.quantizer = quantizer + + def forward(self, indices: torch.Tensor): + codes = self.quantizer.indices_to_codes(indices) + return codes.to(self.quantizer.dtype) diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/utils.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7b1554498264f2ce99805ee878834eabb7a760 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/modules/utils.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared utilities for the networks module.""" + +from typing import Any + +import torch +from einops import pack, rearrange, unpack + + +def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size = x.shape[0] + return rearrange(x, "b c t h w -> (b t) c h w"), batch_size + + +def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor: + return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + + +def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size, height = x.shape[0], x.shape[-2] + return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height + + +def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor: + return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height) + + +def cast_tuple(t: Any, length: int = 1) -> Any: + return t if isinstance(t, tuple) else ((t,) * length) + + +def replication_pad(x): + return torch.cat([x[:, :, :1, ...], x], dim=2) + + +def divisible_by(num: int, den: int) -> bool: + return (num % den) == 0 + + +def is_odd(n: int) -> bool: + return not divisible_by(n, 2) + + +def nonlinearity(x): + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class CausalNormalize(torch.nn.Module): + def __init__(self, in_channels, num_groups=1): + super().__init__() + self.norm = torch.nn.GroupNorm( + num_groups=num_groups, + num_channels=in_channels, + eps=1e-6, + affine=True, + ) + self.num_groups = num_groups + + def forward(self, x): + # if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose. + # All new models should use num_groups=1, otherwise causality is not guaranteed. + if self.num_groups == 1: + x, batch_size = time2batch(x) + return batch2time(self.norm(x), batch_size) + return self.norm(x) + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def round_ste(z: torch.Tensor) -> torch.Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + + +def log(t, eps=1e-5): + return t.clamp(min=eps).log() + + +def entropy(prob): + return (-prob * log(prob)).sum(dim=-1) diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/__init__.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5cff9e207ca8e27d78d09458a0d918a847841b06 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/__init__.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + +from uniception.models.libs.cosmos_tokenizer.networks.configs import continuous_image as continuous_image_dict +from uniception.models.libs.cosmos_tokenizer.networks.configs import continuous_video as continuous_video_dict +from uniception.models.libs.cosmos_tokenizer.networks.configs import discrete_image as discrete_image_dict +from uniception.models.libs.cosmos_tokenizer.networks.configs import discrete_video as discrete_video_dict +from uniception.models.libs.cosmos_tokenizer.networks.continuous_image import ContinuousImageTokenizer +from uniception.models.libs.cosmos_tokenizer.networks.continuous_video import CausalContinuousVideoTokenizer +from uniception.models.libs.cosmos_tokenizer.networks.discrete_image import DiscreteImageTokenizer +from uniception.models.libs.cosmos_tokenizer.networks.discrete_video import CausalDiscreteVideoTokenizer + + +class TokenizerConfigs(Enum): + CI = continuous_image_dict + DI = discrete_image_dict + CV = continuous_video_dict + DV = discrete_video_dict + + +class TokenizerModels(Enum): + CI = ContinuousImageTokenizer + DI = DiscreteImageTokenizer + CV = CausalContinuousVideoTokenizer + DV = CausalDiscreteVideoTokenizer diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/configs.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..820fbfbb7a9202e9029e2f9575a17b4040bafc9d --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/configs.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The default image and video tokenizer configs.""" + +from uniception.models.libs.cosmos_tokenizer.modules import ( + ContinuousFormulation, + Decoder3DType, + DecoderType, + DiscreteQuantizer, + Encoder3DType, + EncoderType, +) + +continuous_image = dict( + # The attention resolution for res blocks. + attn_resolutions=[32], + # The base number of channels. + channels=128, + # The channel multipler for each resolution. + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + # The spatial compression ratio. + spatial_compression=16, + # The number of layers in each res block. + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The output latent dimension (channels). + latent_channels=16, + # The encoder output channels just before sampling. + # Which is also the decoder's input channels. + z_channels=16, + # A factor over the z_channels, to get the total channels the encoder should output. + # For a VAE for instance, we want to output the mean and variance, so we need 2 * z_channels. + z_factor=1, + name="CI", + # What formulation to use, either "AE" or "VAE". + # Chose VAE here, since the pre-trained ckpt were of a VAE formulation. + formulation=ContinuousFormulation.AE.name, + # Specify type of encoder ["Default", "LiteVAE"] + encoder=EncoderType.Default.name, + # Specify type of decoder ["Default"] + decoder=DecoderType.Default.name, +) + +discrete_image = dict( + # The attention resolution for res blocks. + attn_resolutions=[32], + # The base number of channels. + channels=128, + # The channel multipler for each resolution. + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + # The spatial compression ratio. + spatial_compression=16, + # The number of layers in each res block. + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + # The encoder output channels just before sampling. + z_channels=256, + # A factor over the z_channels, to get the total channels the encoder should output. + # for discrete tokenization, often we directly use the vector, so z_factor=1. + z_factor=1, + # The quantizer of choice, VQ, LFQ, FSQ, or ResFSQ. + quantizer=DiscreteQuantizer.FSQ.name, + # The embedding dimension post-quantization, which is also the input channels of the decoder. + # Which is also the output + embedding_dim=6, + # The number of levels to use for fine-scalar quantization. + levels=[8, 8, 8, 5, 5, 5], + # The number of quantizers to use for residual fine-scalar quantization. + num_quantizers=4, + name="DI", + # Specify type of encoder ["Default", "LiteVAE"] + encoder=EncoderType.Default.name, + # Specify type of decoder ["Default"] + decoder=DecoderType.Default.name, +) + +continuous_video = dict( + attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + latent_channels=16, + z_channels=16, + z_factor=1, + num_groups=1, + legacy_mode=False, + spatial_compression=8, + temporal_compression=8, + formulation=ContinuousFormulation.AE.name, + encoder=Encoder3DType.FACTORIZED.name, + decoder=Decoder3DType.FACTORIZED.name, + name="CV", +) + +discrete_video = dict( + attn_resolutions=[32], + channels=128, + channels_mult=[2, 4, 4], + dropout=0.0, + in_channels=3, + num_res_blocks=2, + out_channels=3, + resolution=1024, + patch_size=4, + patch_method="haar", + z_channels=16, + z_factor=1, + num_groups=1, + legacy_mode=False, + spatial_compression=16, + temporal_compression=8, + quantizer=DiscreteQuantizer.FSQ.name, + embedding_dim=6, + levels=[8, 8, 8, 5, 5, 5], + encoder=Encoder3DType.FACTORIZED.name, + decoder=Decoder3DType.FACTORIZED.name, + name="DV", +) diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/continuous_image.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/continuous_image.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ec8f9eb418660b315b103881e2173e74241c12 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/continuous_image.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The continuous image tokenizer with VAE or AE formulation for 2D data.""" + +from collections import OrderedDict, namedtuple + +import torch +from torch import nn + +from uniception.models.libs.cosmos_tokenizer.modules import ContinuousFormulation, DecoderType, EncoderType + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"]) + + +class ContinuousImageTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, latent_channels: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "ContinuousImageTokenizer") + self.latent_channels = latent_channels + + encoder_name = kwargs.get("encoder", EncoderType.Default.name) + self.encoder = EncoderType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) + + decoder_name = kwargs.get("decoder", DecoderType.Default.name) + self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs) + + self.quant_conv = torch.nn.Conv2d(z_factor * z_channels, z_factor * latent_channels, 1) + self.post_quant_conv = torch.nn.Conv2d(latent_channels, z_channels, 1) + + formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) + self.distribution = ContinuousFormulation[formulation_name].value() + + num_parameters = sum(param.numel() for param in self.parameters()) + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("distribution", self.distribution), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + return self.distribution(moments) + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input) -> dict[str, torch.Tensor] | NetworkEval: + latent, posteriors = self.encode(input) + dec = self.decode(latent) + if self.training: + return dict(reconstructions=dec, posteriors=posteriors, latent=latent) + return NetworkEval(reconstructions=dec, posteriors=posteriors, latent=latent) diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/continuous_video.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/continuous_video.py new file mode 100644 index 0000000000000000000000000000000000000000..3a085a4241f433a8c09885e1a98dfe29bea52b68 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/continuous_video.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The causal continuous video tokenizer with VAE or AE formulation for 3D data..""" +from collections import OrderedDict, namedtuple + +from torch import nn + +from uniception.models.libs.cosmos_tokenizer.modules import ContinuousFormulation, Decoder3DType, Encoder3DType +from uniception.models.libs.cosmos_tokenizer.modules.layers3d import CausalConv3d + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"]) + + +class CausalContinuousVideoTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, latent_channels: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "CausalContinuousVideoTokenizer") + self.latent_channels = latent_channels + + encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name) + self.encoder = Encoder3DType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) + if kwargs.get("temporal_compression", 4) == 4: + kwargs["channels_mult"] = [2, 4] + decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name) + self.decoder = Decoder3DType[decoder_name].value(z_channels=z_channels, **kwargs) + + self.quant_conv = CausalConv3d( + z_factor * z_channels, + z_factor * latent_channels, + kernel_size=1, + padding=0, + ) + self.post_quant_conv = CausalConv3d(latent_channels, z_channels, kernel_size=1, padding=0) + + formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) + self.distribution = ContinuousFormulation[formulation_name].value() + + num_parameters = sum(param.numel() for param in self.parameters()) + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("distribution", self.distribution), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + return self.distribution(moments) + + def decode(self, z): + z = self.post_quant_conv(z) + return self.decoder(z) + + def forward(self, input): + latent, posteriors = self.encode(input) + reconstructions = self.decode(latent) + if self.training: + return dict( + reconstructions=reconstructions, + posteriors=posteriors, + latent=latent, + ) + return NetworkEval( + reconstructions=reconstructions, + posteriors=posteriors, + latent=latent, + ) diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/discrete_image.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/discrete_image.py new file mode 100644 index 0000000000000000000000000000000000000000..7e75e22a1ff214d32b96a65e72ba94a57a04cde5 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/discrete_image.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The network definition for discrete image tokenization with VQ, LFQ, FSQ or ResidualFSQ.""" +from collections import OrderedDict, namedtuple + +import torch +from torch import nn + +from uniception.models.libs.cosmos_tokenizer.modules import DecoderType, DiscreteQuantizer, EncoderType +from uniception.models.libs.cosmos_tokenizer.modules.quantizers import InvQuantizerJit + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) + + +class DiscreteImageTokenizer(nn.Module): + def __init__(self, z_channels: int, embedding_dim: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "DiscreteImageTokenizer") + self.embedding_dim = embedding_dim + + encoder_name = kwargs.get("encoder", EncoderType.Default.name) + self.encoder = EncoderType[encoder_name].value(z_channels=z_channels, **kwargs) + + decoder_name = kwargs.get("decoder", DecoderType.Default.name) + self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs) + self.quant_conv = nn.Conv2d(z_channels, embedding_dim, 1) + self.post_quant_conv = nn.Conv2d(embedding_dim, z_channels, 1) + + quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name) + if quantizer_name == DiscreteQuantizer.VQ.name: + assert "num_embeddings" in kwargs, f"`num_embeddings` must be provided for {quantizer_name}." + kwargs.update(dict(embedding_dim=embedding_dim)) + elif quantizer_name == DiscreteQuantizer.LFQ.name: + assert "codebook_size" in kwargs, f"`codebook_size` must be provided for {quantizer_name}." + assert "codebook_dim" in kwargs, f"`codebook_dim` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.FSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.RESFSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}.name." + assert "num_quantizers" in kwargs, f"`num_quantizers` must be provided for {quantizer_name}." + self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs) + + num_parameters = sum(param.numel() for param in self.parameters()) + + def to(self, *args, **kwargs): + setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) + return super(DiscreteImageTokenizer, self).to(*args, **kwargs) + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("quantizer", self.quantizer), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("inv_quant", InvQuantizerJit(self.quantizer)), + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return self.quantizer(h) + + def decode(self, quant): + quant = self.post_quant_conv(quant) + return self.decoder(quant) + + def decode_code(self, code_b): + quant_b = self.quantizer.indices_to_codes(code_b) + quant_b = self.post_quant_conv(quant_b) + return self.decoder(quant_b) + + def forward(self, input): + quant_info, quant_codes, quant_loss = self.encode(input) + reconstructions = self.decode(quant_codes) + if self.training: + return dict( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) + return NetworkEval( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/discrete_video.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/discrete_video.py new file mode 100644 index 0000000000000000000000000000000000000000..c9268fe7df42c7ed9b9abb18597d55218f56f45e --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/networks/discrete_video.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The network definition for discrete video tokenizer with VQ, LFQ, FSQ or ResidualFSQ.""" +from collections import OrderedDict, namedtuple + +import torch +from torch import nn + +from uniception.models.libs.cosmos_tokenizer.modules import Decoder3DType, DiscreteQuantizer, Encoder3DType +from uniception.models.libs.cosmos_tokenizer.modules.layers3d import CausalConv3d +from uniception.models.libs.cosmos_tokenizer.modules.quantizers import InvQuantizerJit + +NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) + + +class CausalDiscreteVideoTokenizer(nn.Module): + def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None: + super().__init__() + self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer") + self.embedding_dim = embedding_dim + + encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name) + self.encoder = Encoder3DType[encoder_name].value(z_channels=z_factor * z_channels, **kwargs) + + decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name) + self.decoder = Decoder3DType[decoder_name].value(z_channels=z_channels, **kwargs) + + self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0) + self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0) + + quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name) + if quantizer_name == DiscreteQuantizer.VQ.name: + assert "num_embeddings" in kwargs, f"`num_embeddings` must be provided for {quantizer_name}." + kwargs.update(dict(embedding_dim=embedding_dim)) + elif quantizer_name == DiscreteQuantizer.LFQ.name: + assert "codebook_size" in kwargs, f"`codebook_size` must be provided for {quantizer_name}." + assert "codebook_dim" in kwargs, f"`codebook_dim` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.FSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." + elif quantizer_name == DiscreteQuantizer.RESFSQ.name: + assert "levels" in kwargs, f"`levels` must be provided for {quantizer_name}." + assert "num_quantizers" in kwargs, f"`num_quantizers` must be provided for {quantizer_name}." + self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs) + + num_parameters = sum(param.numel() for param in self.parameters()) + + def to(self, *args, **kwargs): + setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) + return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs) + + def encoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("encoder", self.encoder), + ("quant_conv", self.quant_conv), + ("quantizer", self.quantizer), + ] + ) + ) + + def decoder_jit(self): + return nn.Sequential( + OrderedDict( + [ + ("inv_quant", InvQuantizerJit(self.quantizer)), + ("post_quant_conv", self.post_quant_conv), + ("decoder", self.decoder), + ] + ) + ) + + def last_decoder_layer(self): + return self.decoder.conv_out + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return self.quantizer(h) + + def decode(self, quant): + quant = self.post_quant_conv(quant) + return self.decoder(quant) + + def decode_code(self, code_b): + quant_b = self.quantizer.indices_to_codes(code_b) + quant_b = self.post_quant_conv(quant_b) + return self.decoder(quant_b) + + def forward(self, input): + quant_info, quant_codes, quant_loss = self.encode(input) + reconstructions = self.decode(quant_codes) + if self.training: + return dict( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) + return NetworkEval( + reconstructions=reconstructions, + quant_loss=quant_loss, + quant_info=quant_info, + ) diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/utils.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4fb572118377a27274b38a80e89bf0fc78e000b --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/utils.py @@ -0,0 +1,402 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions for the inference libraries.""" + +import os +from glob import glob +from typing import Any + +import mediapy as media +import numpy as np +import torch +from PIL import Image + +from uniception.models.libs.cosmos_tokenizer.networks import TokenizerModels + +_DTYPE, _DEVICE = torch.bfloat16, "cuda" +_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max) +_SPATIAL_ALIGN = 16 +_TEMPORAL_ALIGN = 8 + + +def load_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + full_model.load_state_dict(ckpts.state_dict(), strict=False) + return full_model.eval().to(device) + + +def load_encoder_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + encoder_model = full_model.encoder_jit() + encoder_model.load_state_dict(ckpts.state_dict(), strict=False) + return encoder_model.eval().to(device) + + +def load_decoder_model( + jit_filepath: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", +) -> torch.nn.Module | torch.jit.ScriptModule: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + if tokenizer_config is None: + return load_jit_model(jit_filepath, device) + full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device) + decoder_model = full_model.decoder_jit() + decoder_model.load_state_dict(ckpts.state_dict(), strict=False) + return decoder_model.eval().to(device) + + +def _load_pytorch_model( + jit_filepath: str = None, tokenizer_config: str = None, device: str = "cuda" +) -> torch.nn.Module: + """Loads a torch.nn.Module from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + tokenizer_name = tokenizer_config["name"] + model = TokenizerModels[tokenizer_name].value(**tokenizer_config) + ckpts = torch.jit.load(jit_filepath) + return model, ckpts + + +def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule: + """Loads a torch.jit.ScriptModule from a filepath. + + Args: + jit_filepath: The filepath to the JIT-compiled model. + device: The device to load the model onto, default=cuda. + Returns: + The JIT compiled model loaded to device and on eval mode. + """ + model = torch.jit.load(jit_filepath) + return model.eval().to(device) + + +def save_jit_model( + model: torch.jit.ScriptModule | torch.jit.RecursiveScriptModule = None, + jit_filepath: str = None, +) -> None: + """Saves a torch.jit.ScriptModule or torch.jit.RecursiveScriptModule to file. + + Args: + model: JIT compiled model loaded onto `config.checkpoint.jit.device`. + jit_filepath: The filepath to the JIT-compiled model. + """ + torch.jit.save(model, jit_filepath) + + +def get_filepaths(input_pattern) -> list[str]: + """Returns a list of filepaths from a pattern.""" + filepaths = sorted(glob(str(input_pattern))) + return list(set(filepaths)) + + +def get_output_filepath(filepath: str, output_dir: str = None) -> str: + """Returns the output filepath for the given input filepath.""" + output_dir = output_dir or f"{os.path.dirname(filepath)}/reconstructions" + output_filepath = f"{output_dir}/{os.path.basename(filepath)}" + os.makedirs(output_dir, exist_ok=True) + return output_filepath + + +def read_image(filepath: str) -> np.ndarray: + """Reads an image from a filepath. + + Args: + filepath: The filepath to the image. + + Returns: + The image as a numpy array, layout HxWxC, range [0..255], uint8 dtype. + """ + image = media.read_image(filepath) + # convert the grey scale image to RGB + # since our tokenizers always assume 3-channel RGB image + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + # convert RGBA to RGB + if image.shape[-1] == 4: + image = image[..., :3] + return image + + +def read_video(filepath: str) -> np.ndarray: + """Reads a video from a filepath. + + Args: + filepath: The filepath to the video. + Returns: + The video as a numpy array, layout TxHxWxC, range [0..255], uint8 dtype. + """ + video = media.read_video(filepath) + # convert the grey scale frame to RGB + # since our tokenizers always assume 3-channel video + if video.ndim == 3: + video = np.stack([video] * 3, axis=-1) + # convert RGBA to RGB + if video.shape[-1] == 4: + video = video[..., :3] + return video + + +def resize_image(image: np.ndarray, short_size: int = None) -> np.ndarray: + """Resizes an image to have the short side of `short_size`. + + Args: + image: The image to resize, layout HxWxC, of any range. + short_size: The size of the short side. + Returns: + The resized image. + """ + if short_size is None: + return image + height, width = image.shape[-3:-1] + if height <= width: + height_new, width_new = short_size, int(width * short_size / height + 0.5) + width_new = width_new if width_new % 2 == 0 else width_new + 1 + else: + height_new, width_new = ( + int(height * short_size / width + 0.5), + short_size, + ) + height_new = height_new if height_new % 2 == 0 else height_new + 1 + return media.resize_image(image, shape=(height_new, width_new)) + + +def resize_video(video: np.ndarray, short_size: int = None) -> np.ndarray: + """Resizes a video to have the short side of `short_size`. + + Args: + video: The video to resize, layout TxHxWxC, of any range. + short_size: The size of the short side. + Returns: + The resized video. + """ + if short_size is None: + return video + height, width = video.shape[-3:-1] + if height <= width: + height_new, width_new = short_size, int(width * short_size / height + 0.5) + width_new = width_new if width_new % 2 == 0 else width_new + 1 + else: + height_new, width_new = ( + int(height * short_size / width + 0.5), + short_size, + ) + height_new = height_new if height_new % 2 == 0 else height_new + 1 + return media.resize_video(video, shape=(height_new, width_new)) + + +def write_image(filepath: str, image: np.ndarray): + """Writes an image to a filepath.""" + return media.write_image(filepath, image) + + +def write_video(filepath: str, video: np.ndarray, fps: int = 24) -> None: + """Writes a video to a filepath.""" + return media.write_video(filepath, video, fps=fps) + + +def numpy2tensor( + input_image: np.ndarray, + dtype: torch.dtype = _DTYPE, + device: str = _DEVICE, + range_min: int = -1, +) -> torch.Tensor: + """Converts image(dtype=np.uint8) to `dtype` in range [0..255]. + + Args: + input_image: A batch of images in range [0..255], BxHxWx3 layout. + Returns: + A torch.Tensor of layout Bx3xHxW in range [-1..1], dtype. + """ + ndim = input_image.ndim + indices = list(range(1, ndim))[-1:] + list(range(1, ndim))[:-1] + image = input_image.transpose((0,) + tuple(indices)) / _UINT8_MAX_F + if range_min == -1: + image = 2.0 * image - 1.0 + return torch.from_numpy(image).to(dtype).to(device) + + +def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray: + """Converts tensor in [-1,1] to image(dtype=np.uint8) in range [0..255]. + + Args: + input_tensor: Input image tensor of Bx3xHxW layout, range [-1..1]. + Returns: + A numpy image of layout BxHxWx3, range [0..255], uint8 dtype. + """ + if range_min == -1: + input_tensor = (input_tensor.float() + 1.0) / 2.0 + ndim = input_tensor.ndim + output_image = input_tensor.clamp(0, 1).cpu().numpy() + output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,)) + return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8) + + +def pad_image_batch(batch: np.ndarray, spatial_align: int = _SPATIAL_ALIGN) -> tuple[np.ndarray, list[int]]: + """Pads a batch of images to be divisible by `spatial_align`. + + Args: + batch: The batch of images to pad, layout BxHxWx3, in any range. + align: The alignment to pad to. + Returns: + The padded batch and the crop region. + """ + height, width = batch.shape[1:3] + align = spatial_align + height_to_pad = (align - height % align) if height % align != 0 else 0 + width_to_pad = (align - width % align) if width % align != 0 else 0 + + crop_region = [ + height_to_pad >> 1, + width_to_pad >> 1, + height + (height_to_pad >> 1), + width + (width_to_pad >> 1), + ] + batch = np.pad( + batch, + ( + (0, 0), + (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), + (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), + (0, 0), + ), + mode="constant", + ) + return batch, crop_region + + +def pad_video_batch( + batch: np.ndarray, + temporal_align: int = _TEMPORAL_ALIGN, + spatial_align: int = _SPATIAL_ALIGN, +) -> tuple[np.ndarray, list[int]]: + """Pads a batch of videos to be divisible by `temporal_align` or `spatial_align`. + + Zero pad spatially. Reflection pad temporally to handle causality better. + Args: + batch: The batch of videos to pad., layout BxFxHxWx3, in any range. + align: The alignment to pad to. + Returns: + The padded batch and the crop region. + """ + num_frames, height, width = batch.shape[-4:-1] + align = spatial_align + height_to_pad = (align - height % align) if height % align != 0 else 0 + width_to_pad = (align - width % align) if width % align != 0 else 0 + + align = temporal_align + frames_to_pad = (align - (num_frames - 1) % align) if (num_frames - 1) % align != 0 else 0 + + crop_region = [ + frames_to_pad >> 1, + height_to_pad >> 1, + width_to_pad >> 1, + num_frames + (frames_to_pad >> 1), + height + (height_to_pad >> 1), + width + (width_to_pad >> 1), + ] + batch = np.pad( + batch, + ( + (0, 0), + (0, 0), + (height_to_pad >> 1, height_to_pad - (height_to_pad >> 1)), + (width_to_pad >> 1, width_to_pad - (width_to_pad >> 1)), + (0, 0), + ), + mode="constant", + ) + batch = np.pad( + batch, + ( + (0, 0), + (frames_to_pad >> 1, frames_to_pad - (frames_to_pad >> 1)), + (0, 0), + (0, 0), + (0, 0), + ), + mode="edge", + ) + return batch, crop_region + + +def unpad_video_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: + """Unpads video with `crop_region`. + + Args: + batch: A batch of numpy videos, layout BxFxHxWxC. + crop_region: [f1,y1,x1,f2,y2,x2] first, top, left, last, bot, right crop indices. + + Returns: + np.ndarray: Cropped numpy video, layout BxFxHxWxC. + """ + assert len(crop_region) == 6, "crop_region should be len of 6." + f1, y1, x1, f2, y2, x2 = crop_region + return batch[..., f1:f2, y1:y2, x1:x2, :] + + +def unpad_image_batch(batch: np.ndarray, crop_region: list[int]) -> np.ndarray: + """Unpads image with `crop_region`. + + Args: + batch: A batch of numpy images, layout BxHxWxC. + crop_region: [y1,x1,y2,x2] top, left, bot, right crop indices. + + Returns: + np.ndarray: Cropped numpy image, layout BxHxWxC. + """ + assert len(crop_region) == 4, "crop_region should be len of 4." + y1, x1, y2, x2 = crop_region + return batch[..., y1:y2, x1:x2, :] diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/video_cli.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/video_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..ea3b83b02834c1f76c8f1a8e5232c69e4fd943f2 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/video_cli.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A CLI to run CausalVideoTokenizer on plain videos based on torch.jit. + +Usage: + python3 -m cosmos_tokenizer.video_cli \ + --video_pattern 'path/to/video/samples/*.mp4' \ + --output_dir ./reconstructions \ + --checkpoint_enc ./pretrained_ckpts/CosmosCV_f4x8x8/encoder.jit \ + --checkpoint_dec ./pretrained_ckpts/CosmosCV_f4x8x8/decoder.jit + + Optionally, you can run the model in pure PyTorch mode: + python3 -m cosmos_tokenizer.video_cli \ + --video_pattern 'path/to/video/samples/*.mp4' \ + --mode=torch \ + --tokenizer_type=CV \ + --temporal_compression=4 \ + --spatial_compression=8 \ + --checkpoint_enc ./pretrained_ckpts/CosmosCV_f4x8x8/encoder.jit \ + --checkpoint_dec ./pretrained_ckpts/CosmosCV_f4x8x8/decoder.jit +""" + +import os +import sys +from argparse import ArgumentParser, Namespace +from typing import Any + +import numpy as np + +from uniception.models.libs.cosmos_tokenizer.networks import TokenizerConfigs +from uniception.models.libs.cosmos_tokenizer.utils import ( + get_filepaths, + get_output_filepath, + read_video, + resize_video, + write_video, +) +from uniception.models.libs.cosmos_tokenizer.video_lib import CausalVideoTokenizer + + +def _parse_args() -> tuple[Namespace, dict[str, Any]]: + parser = ArgumentParser(description="A CLI for CausalVideoTokenizer.") + parser.add_argument( + "--video_pattern", + type=str, + default="path/to/videos/*.mp4", + help="Glob pattern.", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="JIT full Autoencoder model filepath.", + ) + parser.add_argument( + "--checkpoint_enc", + type=str, + default=None, + help="JIT Encoder model filepath.", + ) + parser.add_argument( + "--checkpoint_dec", + type=str, + default=None, + help="JIT Decoder model filepath.", + ) + parser.add_argument( + "--tokenizer_type", + type=str, + choices=["CV", "DV"], + help="Specifies the tokenizer type.", + ) + parser.add_argument( + "--spatial_compression", + type=int, + choices=[8, 16], + default=8, + help="The spatial compression factor.", + ) + parser.add_argument( + "--temporal_compression", + type=int, + choices=[4, 8], + default=4, + help="The temporal compression factor.", + ) + parser.add_argument( + "--mode", + type=str, + choices=["torch", "jit"], + default="jit", + help="Specify the backend: native 'torch' or 'jit' (default: 'jit')", + ) + parser.add_argument( + "--short_size", + type=int, + default=None, + help="The size to resample inputs. None, by default.", + ) + parser.add_argument( + "--temporal_window", + type=int, + default=17, + help="The temporal window to operate at a time.", + ) + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + help="Sets the precision, default bfloat16.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device for invoking the model.", + ) + parser.add_argument("--output_dir", type=str, default=None, help="Output directory.") + parser.add_argument( + "--output_fps", + type=float, + default=24.0, + help="Output frames-per-second (FPS).", + ) + parser.add_argument( + "--save_input", + action="store_true", + help="If on, the input video will be be outputted too.", + ) + + args = parser.parse_args() + return args + + +args = _parse_args() +if args.mode == "torch" and args.tokenizer_type not in ["CV", "DV"]: + sys.exit(1) + + +def _run_eval() -> None: + """Invokes JIT-compiled CausalVideoTokenizer on an input video.""" + + if args.checkpoint_enc is None and args.checkpoint_dec is None and args.checkpoint is None: + return + + if args.mode == "torch": + tokenizer_config = TokenizerConfigs[args.tokenizer_type].value + tokenizer_config.update(dict(spatial_compression=args.spatial_compression)) + tokenizer_config.update(dict(temporal_compression=args.temporal_compression)) + else: + tokenizer_config = None + + autoencoder = CausalVideoTokenizer( + checkpoint=args.checkpoint, + checkpoint_enc=args.checkpoint_enc, + checkpoint_dec=args.checkpoint_dec, + tokenizer_config=tokenizer_config, + device=args.device, + dtype=args.dtype, + ) + + filepaths = get_filepaths(args.video_pattern) + + for filepath in filepaths: + video = read_video(filepath) + video = resize_video(video, short_size=args.short_size) + + batch_video = video[np.newaxis, ...] + output_video = autoencoder(batch_video, temporal_window=args.temporal_window)[0] + output_filepath = get_output_filepath(filepath, output_dir=args.output_dir) + write_video(output_filepath, output_video, fps=args.output_fps) + if args.save_input: + ext = os.path.splitext(output_filepath)[-1] + input_filepath = output_filepath.replace(ext, "_input" + ext) + write_video(input_filepath, video, fps=args.output_fps) + + +def main() -> None: + _run_eval() + + +if __name__ == "__main__": + main() diff --git a/UniCeption/uniception/models/libs/cosmos_tokenizer/video_lib.py b/UniCeption/uniception/models/libs/cosmos_tokenizer/video_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..83572346088cc0b25affa76964b43b5bbf904325 --- /dev/null +++ b/UniCeption/uniception/models/libs/cosmos_tokenizer/video_lib.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A library for Causal Video Tokenizer inference.""" + +from typing import Any + +import numpy as np +import torch +from tqdm import tqdm + +from uniception.models.libs.cosmos_tokenizer.utils import ( + load_decoder_model, + load_encoder_model, + load_model, + numpy2tensor, + pad_video_batch, + tensor2numpy, + unpad_video_batch, +) + + +class CausalVideoTokenizer(torch.nn.Module): + def __init__( + self, + checkpoint: str = None, + checkpoint_enc: str = None, + checkpoint_dec: str = None, + tokenizer_config: dict[str, Any] = None, + device: str = "cuda", + dtype: str = "bfloat16", + ) -> None: + super().__init__() + self._device = device + self._dtype = getattr(torch, dtype) + self._full_model = ( + load_model(checkpoint, tokenizer_config, device).to(self._dtype) if checkpoint is not None else None + ) + self._enc_model = ( + load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype) + if checkpoint_enc is not None + else None + ) + self._dec_model = ( + load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype) + if checkpoint_dec is not None + else None + ) + + @torch.no_grad() + def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor: + """Reconstrcuts a batch of video tensors after embedding into a latent. + + Args: + video: The input video Bx3xTxHxW layout, range [-1..1]. + Returns: + The reconstructed video, layout Bx3xTxHxW, range [-1..1]. + """ + if self._full_model is not None: + output_tensor = self._full_model(input_tensor) + output_tensor = output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor + else: + output_latent = self.encode(input_tensor)[0] + output_tensor = self.decode(output_latent) + return output_tensor + + @torch.no_grad() + def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]: + """Encodes a numpy video into a CausalVideo latent or code. + + Args: + input_tensor: The input tensor Bx3xTxHxW layout, range [-1..1]. + Returns: + For causal continuous video (CV) tokenizer, the tuple contains: + - The latent embedding, Bx16x(t)x(h)x(w), where the compression + rate is (T/t x H/h x W/w), and channel dimension of 16. + For causal discrete video (DV) tokenizer, the tuple contains: + 1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which + is formed by FSQ levels of (8,8,8,5,5,5). + 2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate + is again (T/t x H/h x W/w), and channel dimension of 6. + """ + assert input_tensor.ndim == 5, "input video should be of 5D." + + output_latent = self._enc_model(input_tensor) + if isinstance(output_latent, torch.Tensor): + return output_latent + return output_latent[:-1] + + @torch.no_grad() + def decode(self, input_latent: torch.Tensor) -> torch.Tensor: + """Encodes a numpy video into a CausalVideo latent. + + Args: + input_latent: The continuous latent Bx16xtxhxw for CV, + or the discrete indices Bxtxhxw for DV. + Returns: + The reconstructed tensor, layout [B,3,1+(T-1)*8,H*16,W*16] in range [-1..1]. + """ + assert input_latent.ndim >= 4, "input latent should be of 5D for continuous and 4D for discrete." + return self._dec_model(input_latent) + + def forward( + self, + video: np.ndarray, + temporal_window: int = 17, + ) -> np.ndarray: + """Reconstructs video using a pre-trained CausalTokenizer autoencoder. + Given a video of arbitrary length, the forward invokes the CausalVideoTokenizer + in a sliding manner with a `temporal_window` size. + + Args: + video: The input video BxTxHxWx3 layout, range [0..255]. + temporal_window: The length of the temporal window to process, default=25. + Returns: + The reconstructed video in range [0..255], layout BxTxHxWx3. + """ + assert video.ndim == 5, "input video should be of 5D." + num_frames = video.shape[1] # can be of any length. + output_video_list = [] + for idx in tqdm(range(0, (num_frames - 1) // temporal_window + 1)): + # Input video for the current window. + start, end = idx * temporal_window, (idx + 1) * temporal_window + input_video = video[:, start:end, ...] + + # Spatio-temporally pad input_video so it's evenly divisible. + padded_input_video, crop_region = pad_video_batch(input_video) + input_tensor = numpy2tensor(padded_input_video, dtype=self._dtype, device=self._device) + output_tensor = self.autoencode(input_tensor) + padded_output_video = tensor2numpy(output_tensor) + output_video = unpad_video_batch(padded_output_video, crop_region) + + output_video_list.append(output_video) + return np.concatenate(output_video_list, axis=1) diff --git a/UniCeption/uniception/models/libs/croco/__init__.py b/UniCeption/uniception/models/libs/croco/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/UniCeption/uniception/models/libs/croco/blocks.py b/UniCeption/uniception/models/libs/croco/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..77046d827faaa160a4410277f1b1626b057f7f97 --- /dev/null +++ b/UniCeption/uniception/models/libs/croco/blocks.py @@ -0,0 +1,249 @@ +# -------------------------------------------------------- +# Main encoder/decoder blocks for CroCo and DUSt3R +# Adopted from CroCoV2 (Naver Corporation, CC BY-NC-SA 4.0 (non-commercial use only)) +# -------------------------------------------------------- +# References: +# timm +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py +import torch +import torch.nn as nn + +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 +# Use torch.nn.functional.scaled_dot_product_attention instead of the naive PyTorch implementation +from uniception.models.utils.config import use_fused_attn + +use_torch_attn = use_fused_attn() + +import collections.abc +from itertools import repeat + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, torch_attn=use_torch_attn + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + self.torch_attn = torch_attn + self.dropout_p = attn_drop + + def forward(self, x, xpos): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3) + q, k, v = [qkv[:, :, i] for i in range(3)] + # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + if not self.torch_attn: + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + else: + # https://pytorch.org/docs/2.4/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention + x = nn.functional.scaled_dot_product_attention( + q, k, v, dropout_p=(self.dropout_p if self.training else 0.0), scale=self.scale + ) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + rope=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, xpos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, torch_attn=use_torch_attn + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.projq = nn.Linear(dim, dim, bias=qkv_bias) + self.projk = nn.Linear(dim, dim, bias=qkv_bias) + self.projv = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.rope = rope + self.torch_attn = torch_attn + self.dropout_p = attn_drop + + def forward(self, query, key, value, qpos, kpos): + B, Nq, C = query.shape + Nk = key.shape[1] + Nv = value.shape[1] + + q = self.projq(query).reshape(B, Nq, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = self.projk(key).reshape(B, Nk, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.projv(value).reshape(B, Nv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.rope is not None: + q = self.rope(q, qpos) + k = self.rope(k, kpos) + + if not self.torch_attn: + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + else: + # https://pytorch.org/docs/2.4/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention + x = nn.functional.scaled_dot_product_attention( + q, k, v, dropout_p=(self.dropout_p if self.training else 0.0), scale=self.scale + ) + + x = x.transpose(1, 2).reshape(B, Nq, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class DecoderBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + norm_mem=True, + rope=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop + ) + self.cross_attn = CrossAttention( + dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.norm3 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.norm_y = norm_layer(dim) if norm_mem else nn.Identity() + + def forward(self, x, y, xpos, ypos): + x = x + self.drop_path(self.attn(self.norm1(x), xpos)) + y_ = self.norm_y(y) + x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos)) + x = x + self.drop_path(self.mlp(self.norm3(x))) + return x, y diff --git a/UniCeption/uniception/models/libs/croco/curope/__init__.py b/UniCeption/uniception/models/libs/croco/curope/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25e3d48a162760260826080f6366838e83e26878 --- /dev/null +++ b/UniCeption/uniception/models/libs/croco/curope/__init__.py @@ -0,0 +1,4 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +from .curope2d import cuRoPE2D diff --git a/UniCeption/uniception/models/libs/croco/curope/curope.cpp b/UniCeption/uniception/models/libs/croco/curope/curope.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8fc67ca3e01b666c56f96280a12089fa4ec2e2a7 --- /dev/null +++ b/UniCeption/uniception/models/libs/croco/curope/curope.cpp @@ -0,0 +1,69 @@ +/* + Copyright (C) 2022-present Naver Corporation. All rights reserved. + Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +*/ + +#include + +// forward declaration +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); + +void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) +{ + const int B = tokens.size(0); + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3) / 4; + + auto tok = tokens.accessor(); + auto pos = positions.accessor(); + + for (int b = 0; b < B; b++) { + for (int x = 0; x < 2; x++) { // y and then x (2d) + for (int n = 0; n < N; n++) { + + // grab the token position + const int p = pos[b][n][x]; + + for (int h = 0; h < H; h++) { + for (int d = 0; d < D; d++) { + // grab the two values + float u = tok[b][n][h][d+0+x*2*D]; + float v = tok[b][n][h][d+D+x*2*D]; + + // grab the cos,sin + const float inv_freq = fwd * p / powf(base, d/float(D)); + float c = cosf(inv_freq); + float s = sinf(inv_freq); + + // write the result + tok[b][n][h][d+0+x*2*D] = u*c - v*s; + tok[b][n][h][d+D+x*2*D] = v*c + u*s; + } + } + } + } + } +} + +void rope_2d( torch::Tensor tokens, // B,N,H,D + const torch::Tensor positions, // B,N,2 + const float base, + const float fwd ) +{ + TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); + TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); + TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); + TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); + TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); + TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); + + if (tokens.is_cuda()) + rope_2d_cuda( tokens, positions, base, fwd ); + else + rope_2d_cpu( tokens, positions, base, fwd ); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); +} diff --git a/UniCeption/uniception/models/libs/croco/curope/curope2d.py b/UniCeption/uniception/models/libs/croco/curope/curope2d.py new file mode 100644 index 0000000000000000000000000000000000000000..b7272b8f03977ab41204afda489df5dd920dad79 --- /dev/null +++ b/UniCeption/uniception/models/libs/croco/curope/curope2d.py @@ -0,0 +1,39 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +import torch + +try: + import curope as _kernels # run `python setup.py install` +except ModuleNotFoundError: + from . import curope as _kernels # run `python setup.py build_ext --inplace` + + +class cuRoPE2D_func(torch.autograd.Function): + @staticmethod + def forward(ctx, tokens, positions, base, F0=1): + ctx.save_for_backward(positions) + ctx.saved_base = base + ctx.saved_F0 = F0 + # tokens = tokens.clone() # uncomment this if inplace doesn't work + _kernels.rope_2d(tokens, positions, base, F0) + ctx.mark_dirty(tokens) + return tokens + + @staticmethod + def backward(ctx, grad_res): + positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 + _kernels.rope_2d(grad_res, positions, base, -F0) + ctx.mark_dirty(grad_res) + return grad_res, None, None, None + + +class cuRoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + + def forward(self, tokens, positions): + cuRoPE2D_func.apply(tokens.transpose(1, 2), positions, self.base, self.F0) + return tokens diff --git a/UniCeption/uniception/models/libs/croco/curope/kernels.cu b/UniCeption/uniception/models/libs/croco/curope/kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..01604c01a0e832d63bd77ef8c684145fb5d3d11b --- /dev/null +++ b/UniCeption/uniception/models/libs/croco/curope/kernels.cu @@ -0,0 +1,108 @@ +/* + Copyright (C) 2022-present Naver Corporation. All rights reserved. + Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +*/ + +#include +#include +#include +#include + +#define CHECK_CUDA(tensor) {\ + TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ + TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } +void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} + + +template < typename scalar_t > +__global__ void rope_2d_cuda_kernel( + //scalar_t* __restrict__ tokens, + torch::PackedTensorAccessor32 tokens, + const int64_t* __restrict__ pos, + const float base, + const float fwd ) + // const int N, const int H, const int D ) +{ + // tokens shape = (B, N, H, D) + const int N = tokens.size(1); + const int H = tokens.size(2); + const int D = tokens.size(3); + + // each block update a single token, for all heads + // each thread takes care of a single output + extern __shared__ float shared[]; + float* shared_inv_freq = shared + D; + + const int b = blockIdx.x / N; + const int n = blockIdx.x % N; + + const int Q = D / 4; + // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D] + // u_Y v_Y u_X v_X + + // shared memory: first, compute inv_freq + if (threadIdx.x < Q) + shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); + __syncthreads(); + + // start of X or Y part + const int X = threadIdx.x < D/2 ? 0 : 1; + const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X + + // grab the cos,sin appropriate for me + const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; + const float cos = cosf(freq); + const float sin = sinf(freq); + /* + float* shared_cos_sin = shared + D + D/4; + if ((threadIdx.x % (D/2)) < Q) + shared_cos_sin[m+0] = cosf(freq); + else + shared_cos_sin[m+Q] = sinf(freq); + __syncthreads(); + const float cos = shared_cos_sin[m+0]; + const float sin = shared_cos_sin[m+Q]; + */ + + for (int h = 0; h < H; h++) + { + // then, load all the token for this head in shared memory + shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; + __syncthreads(); + + const float u = shared[m]; + const float v = shared[m+Q]; + + // write output + if ((threadIdx.x % (D/2)) < Q) + tokens[b][n][h][threadIdx.x] = u*cos - v*sin; + else + tokens[b][n][h][threadIdx.x] = v*cos + u*sin; + } +} + +void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) +{ + const int B = tokens.size(0); // batch size + const int N = tokens.size(1); // sequence length + const int H = tokens.size(2); // number of heads + const int D = tokens.size(3); // dimension per head + + TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); + TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); + TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); + TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); + + // one block for each layer, one thread per local-max + const int THREADS_PER_BLOCK = D; + const int N_BLOCKS = B * N; // each block takes care of H*D values + const int SHARED_MEM = sizeof(float) * (D + D/4); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { + rope_2d_cuda_kernel <<>> ( + //tokens.data_ptr(), + tokens.packed_accessor32(), + pos.data_ptr(), + base, fwd); //, N, H, D ); + })); +} diff --git a/UniCeption/uniception/models/libs/croco/curope/setup.py b/UniCeption/uniception/models/libs/croco/curope/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..15ed238c07046c08d1aa1fa95e4a419d67cb26ec --- /dev/null +++ b/UniCeption/uniception/models/libs/croco/curope/setup.py @@ -0,0 +1,33 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). + +from setuptools import setup +from torch import cuda +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +# compile for all possible CUDA architectures +all_cuda_archs = cuda.get_gencode_flags().replace("compute=", "arch=").split() +# alternatively, you can list cuda archs that you want, eg: +# all_cuda_archs = [ +# '-gencode', 'arch=compute_70,code=sm_70', +# '-gencode', 'arch=compute_75,code=sm_75', +# '-gencode', 'arch=compute_80,code=sm_80', +# '-gencode', 'arch=compute_86,code=sm_86' +# ] + +setup( + name="curope", + ext_modules=[ + CUDAExtension( + name="curope", + sources=[ + "curope.cpp", + "kernels.cu", + ], + extra_compile_args=dict( + nvcc=["-O3", "--ptxas-options=-v", "--use_fast_math"] + all_cuda_archs, cxx=["-O3"] + ), + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/UniCeption/uniception/models/libs/croco/dpt_block.py b/UniCeption/uniception/models/libs/croco/dpt_block.py new file mode 100644 index 0000000000000000000000000000000000000000..66c77c9e8a0f62404fdc83bdec6a35e619e534d3 --- /dev/null +++ b/UniCeption/uniception/models/libs/croco/dpt_block.py @@ -0,0 +1,530 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# -------------------------------------------------------- +# DPT head for ViTs +# -------------------------------------------------------- +# References: +# https://github.com/isl-org/DPT +# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + + +def make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + scratch.layer_rn = nn.ModuleList( + [ + scratch.layer1_rn, + scratch.layer2_rn, + scratch.layer3_rn, + scratch.layer4_rn, + ] + ) + + return scratch + + +class SineActivation(nn.Module): + def __init__(self, dim=None, on_channels=False): + super().__init__() + + self.dim = dim + self.on_channels = on_channels + + def forward(self, x): + return torch.sin(x) + + +class GaussianActivation(nn.Module): + def __init__(self, dim=None, on_channels=False): + super().__init__() + self.dim = dim + self.on_channels = on_channels + + def forward(self, x): + return torch.exp(-(x**2)) + + +class XCosineXActivation(nn.Module): + def __init__(self, dim=None, on_channels=False): + super().__init__() + self.dim = dim + self.on_channels = on_channels + + def forward(self, x): + return x * torch.cos(x) + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=not self.bn, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + width_ratio=1, + ): + """Init. + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + self.width_ratio = width_ratio + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + if self.width_ratio != 1: + res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode="bilinear") + + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + if self.width_ratio != 1: + # and output.shape[3] < self.width_ratio * output.shape[2] + # size=(image.shape[]) + if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio: + shape = 3 * output.shape[3] + else: + shape = int(self.width_ratio * 2 * output.shape[2]) + output = F.interpolate(output, size=(2 * output.shape[2], shape), mode="bilinear") + else: + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + output = self.out_conv(output) # This causes inconsistent gradient stride. don't know why. + return output + + +def make_nonlinearity(nonlinearity, dim=None, on_channels=False): + if nonlinearity == "relu": + return nn.ReLU(False) + elif nonlinearity == "sine": + return SineActivation(dim=dim, on_channels=on_channels) + elif nonlinearity == "gaussian": + return GaussianActivation(dim=dim, on_channels=on_channels) + elif nonlinearity == "tanh": + return nn.Tanh() + elif nonlinearity == "sigmoid": + return nn.Sigmoid() + elif nonlinearity == "gelu": + return nn.GELU() + elif nonlinearity == "xcosx": + return XCosineXActivation(dim=dim, on_channels=on_channels) + else: + raise ValueError(f"Unknown nonlinearity: {nonlinearity}") + + +def make_fusion_block(features, use_bn, width_ratio=1, nonlinearity="relu"): + + nonlinear_layer = make_nonlinearity(nonlinearity, features, on_channels=True) + + return FeatureFusionBlock_custom( + features, + nonlinear_layer, + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + width_ratio=width_ratio, + ) + + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x + + +class DPTOutputAdapter(nn.Module): + """DPT output adapter. + + :param num_cahnnels: Number of output channels + :param stride_level: tride level compared to the full-sized image. + E.g. 4 for 1/4th the size of the image. + :param patch_size_full: Int or tuple of the patch size over the full image size. + Patch size for smaller inputs will be computed accordingly. + :param hooks: Index of intermediate layers + :param layer_dims: Dimension of intermediate layers + :param feature_dim: Feature dimension + :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression + :param use_bn: If set to True, activates batch norm + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + + def __init__( + self, + num_channels: int = 1, + stride_level: int = 1, + patch_size: Union[int, Tuple[int, int]] = 16, + main_tasks: Iterable[str] = ("rgb",), + hooks: List[int] = [2, 5, 8, 11], + layer_dims: List[int] = [96, 192, 384, 768], + feature_dim: int = 256, + last_dim: int = 32, + use_bn: bool = False, + dim_tokens_enc: Optional[int] = None, + head_type: str = "regression", + output_width_ratio=1, + nonlinearity="relu", + **kwargs, + ): + super().__init__() + self.num_channels = num_channels + self.stride_level = stride_level + self.patch_size = pair(patch_size) + self.main_tasks = main_tasks + self.hooks = hooks + self.layer_dims = layer_dims + self.feature_dim = feature_dim + self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None + self.head_type = head_type + + # Actual patch height and width, taking into account stride of input + self.P_H = max(1, self.patch_size[0] // stride_level) + self.P_W = max(1, self.patch_size[1] // stride_level) + + self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False) + + self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + + if self.head_type == "regression": + # The "DPTDepthModel" head + self.head = nn.Sequential( + nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1), + make_nonlinearity(nonlinearity, dim=last_dim), + nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0), + ) + elif self.head_type == "semseg": + # The "DPTSegmentationModel" head + self.head = nn.Sequential( + nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(), + make_nonlinearity(nonlinearity, dim=feature_dim), + nn.Dropout(0.1, False), + nn.Conv2d(feature_dim, self.num_channels, kernel_size=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + ) + else: + raise ValueError('DPT head_type must be "regression" or "semseg".') + + if self.dim_tokens_enc is not None: + self.init(dim_tokens_enc=dim_tokens_enc) + + def init(self, dim_tokens_enc=768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + Should be called when setting up MultiMAE. + + :param dim_tokens_enc: Dimension of tokens coming from encoder + """ + # print(dim_tokens_enc) + + # Set up activation postprocessing layers + if isinstance(dim_tokens_enc, int): + dim_tokens_enc = 4 * [dim_tokens_enc] + + self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc] + + self.act_1_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[0], + out_channels=self.layer_dims[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[0], + out_channels=self.layer_dims[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + self.act_2_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[1], + out_channels=self.layer_dims[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[1], + out_channels=self.layer_dims[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + self.act_3_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[2], + out_channels=self.layer_dims[2], + kernel_size=1, + stride=1, + padding=0, + ) + ) + + self.act_4_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.dim_tokens_enc[3], + out_channels=self.layer_dims[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=self.layer_dims[3], + out_channels=self.layer_dims[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + self.act_postprocess = nn.ModuleList( + [self.act_1_postprocess, self.act_2_postprocess, self.act_3_postprocess, self.act_4_postprocess] + ) + + def adapt_tokens(self, encoder_tokens): + # Adapt tokens + x = [] + x.append(encoder_tokens[:, :]) + x = torch.cat(x, dim=-1) + return x + + def forward(self, encoder_tokens: List[torch.Tensor], image_size): + # input_info: Dict): + assert self.dim_tokens_enc is not None, "Need to call init(dim_tokens_enc) function first" + H, W = image_size + + # Number of patches in height and width + N_H = H // (self.stride_level * self.P_H) + N_W = W // (self.stride_level * self.P_W) + + # Hook decoder onto 4 layers from specified ViT layers + layers = [encoder_tokens[hook] for hook in self.hooks] + + # Extract only task-relevant tokens and ignore global tokens. + layers = [self.adapt_tokens(l) for l in layers] + + # Reshape tokens to spatial representation + layers = [rearrange(l, "b (nh nw) c -> b c nh nw", nh=N_H, nw=N_W) for l in layers] + + layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + # Project layers to chosen feature dim + layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[3]) + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + path_1 = self.scratch.refinenet1(path_2, layers[0]) + + # Output head + out = self.head(path_1) + + return out diff --git a/UniCeption/uniception/models/libs/croco/patch_embed.py b/UniCeption/uniception/models/libs/croco/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..544df1e18354b6293bae3afc32f5c6ae1b0e8941 --- /dev/null +++ b/UniCeption/uniception/models/libs/croco/patch_embed.py @@ -0,0 +1,127 @@ +# -------------------------------------------------------- +# Patch Embedding for CroCo and DUSt3R +# Adopted from DUSt3R (Naver Corporation, CC BY-NC-SA 4.0 (non-commercial use only)) +# -------------------------------------------------------- +import torch +import torch.nn as nn + +from uniception.models.libs.croco.blocks import to_2tuple + +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + + +def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim): + assert patch_embed_cls in ["PatchEmbedCroCo", "PatchEmbedDust3R", "ManyAR_PatchEmbed"] + patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim) + return patch_embed + + +class PositionGetter(object): + """Return positions of patches""" + + def __init__(self): + self.cache_positions = {} + + def __call__(self, b, h, w, device): + if not (h, w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h, w, 2) + pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone() + return pos + + +class PatchEmbedCroCo(nn.Module): + """Just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + self.position_getter = PositionGetter() + + def forward(self, x, **kw): + B, C, H, W = x.shape + torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, pos + + def _init_weights(self): + w = self.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + +class PatchEmbedDust3R(PatchEmbedCroCo): + def forward(self, x, **kw): + B, C, H, W = x.shape + assert ( + H % self.patch_size[0] == 0 + ), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert ( + W % self.patch_size[1] == 0 + ), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + x = self.proj(x) + pos = self.position_getter(B, x.size(2), x.size(3), x.device) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, pos + + +class ManyAR_PatchEmbed(PatchEmbedCroCo): + """Handle images with non-square aspect ratio. + All images in the same batch have the same aspect ratio. + true_shape = [(height, width) ...] indicates the actual shape of each image. + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + self.embed_dim = embed_dim + super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten) + + def forward(self, img, true_shape): + B, C, H, W = img.shape + assert W >= H, f"img should be in landscape mode, but got {W=} {H=}" + assert ( + H % self.patch_size[0] == 0 + ), f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." + assert ( + W % self.patch_size[1] == 0 + ), f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." + assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}" + + # size expressed in tokens + W //= self.patch_size[0] + H //= self.patch_size[1] + n_tokens = H * W + + height, width = true_shape.T + is_landscape = width >= height + is_portrait = ~is_landscape + + # allocate result + x = img.new_zeros((B, n_tokens, self.embed_dim)) + pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64) + + # linear projection, transposed if necessary + x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float() + x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float() + + pos[is_landscape] = self.position_getter(1, H, W, pos.device) + pos[is_portrait] = self.position_getter(1, W, H, pos.device) + + x = self.norm(x) + return x, pos diff --git a/UniCeption/uniception/models/libs/croco/pos_embed.py b/UniCeption/uniception/models/libs/croco/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..cb51c9a30ef0a4b778a3be89465090a25ec91ef7 --- /dev/null +++ b/UniCeption/uniception/models/libs/croco/pos_embed.py @@ -0,0 +1,155 @@ +# Copyright (C) 2022-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- +import numpy as np +import torch + +torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12 + + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if n_cls_token > 0: + pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +# ---------------------------------------------------------- +# RoPE2D: RoPE implementation in 2D +# ---------------------------------------------------------- +try: + from uniception.models.libs.croco.curope import cuRoPE2D + + RoPE2D = cuRoPE2D +except: + + class RoPE2D(torch.nn.Module): + def __init__(self, freq=100.0, F0=1.0): + super().__init__() + self.base = freq + self.F0 = F0 + self.cache = {} + + def get_cos_sin(self, D, seq_len, device, dtype): + if (D, seq_len, device, dtype) not in self.cache: + inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) + freqs = torch.cat((freqs, freqs), dim=-1) + cos = freqs.cos() # (Seq, Dim) + sin = freqs.sin() + self.cache[D, seq_len, device, dtype] = (cos, sin) + return self.cache[D, seq_len, device, dtype] + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope1d(self, tokens, pos1d, cos, sin): + assert pos1d.ndim == 2 + cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] + sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] + return (tokens * cos) + (self.rotate_half(tokens) * sin) + + def forward(self, tokens, positions): + """ + input: + * tokens: batch_size x nheads x ntokens x dim + * positions: batch_size x ntokens x 2 (y and x position of each token) + output: + * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) + """ + assert tokens.size(3) % 2 == 0, "number of dimensions should be a multiple of two" + D = tokens.size(3) // 2 + assert positions.ndim == 3 and positions.shape[-1] == 2 # Batch, Seq, 2 + cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype) + # split features into two along the feature dimension, and apply rope1d on each half + y, x = tokens.chunk(2, dim=-1) + y = self.apply_rope1d(y, positions[:, :, 0], cos, sin) + x = self.apply_rope1d(x, positions[:, :, 1], cos, sin) + tokens = torch.cat((y, x), dim=-1) + return tokens diff --git a/UniCeption/uniception/models/prediction_heads/README.md b/UniCeption/uniception/models/prediction_heads/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e0873d00afa287e18cc9e179852db1955fdde5a0 --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/README.md @@ -0,0 +1,29 @@ +# UniCeption Prediction Heads + +## Currently Implemented Pathways + +``` +IntermediateFeatureReturner +├── DPTFeature +│ ├── DPTRegressionProcessor +│ │ ├── FlowAdaptor +│ │ ├── DepthAdaptor +| │ ├── PointMapAdaptor +│ │ ├── ConfidenceAdaptor +│ │ ├── ValueWithConfidenceAdaptor +│ │ └── FlowWithConfidenceAdaptor +│ │ └── PointMapWithConfidenceAdaptor +│ └── DPTSegmentationProcessor +│ └── MaskAdaptor +└── LinearFeature +│ └── ...(all adaptors) +└── PoseHead +``` + +The diagram outlines how implemented classes are designed to interact with each other. + +## Developer Guidelines + +Please follow the main UniCeption developer guidelines described in `README.md` when contributing to the prediction heads. Make sure to test your different implementations and add necessary unit tests. + +## Happy Coding! diff --git a/UniCeption/uniception/models/prediction_heads/__init__.py b/UniCeption/uniception/models/prediction_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9551edc7b27b21f37db2802cc3ee635033fa9e --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/__init__.py @@ -0,0 +1,18 @@ +""" +Init UniCeption Prediction Heads +""" + +from uniception.models.prediction_heads.base import ( + AdaptorInput, + AdaptorOutput, + Covariance2DAdaptorOutput, + MaskAdaptorOutput, + PredictionHeadInput, + PredictionHeadOutput, + RegressionAdaptorOutput, + RegressionWithConfidenceAdaptorOutput, + RegressionWithConfidenceAndMaskAdaptorOutput, + RegressionWithMaskAdaptorOutput, + UniCeptionAdaptorBase, + UniCeptionPredictionHeadBase, +) diff --git a/UniCeption/uniception/models/prediction_heads/adaptors.py b/UniCeption/uniception/models/prediction_heads/adaptors.py new file mode 100644 index 0000000000000000000000000000000000000000..0e0c0b95fd296aebf5051939f036a0b6a0df4786 --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/adaptors.py @@ -0,0 +1,1765 @@ +""" +Adaptors for the UniCeption Prediction Heads. +""" + +from functools import lru_cache +from math import isfinite +from typing import List, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + +from uniception.models.prediction_heads import ( + AdaptorInput, + AdaptorOutput, + Covariance2DAdaptorOutput, + MaskAdaptorOutput, + RegressionAdaptorOutput, + RegressionWithConfidenceAdaptorOutput, + RegressionWithConfidenceAndMaskAdaptorOutput, + RegressionWithMaskAdaptorOutput, + UniCeptionAdaptorBase, +) + + +class FlowAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + flow_mean: Union[Tuple[float, float], List[float]], + flow_std: Union[Tuple[float, float], List[float]], + base_shape: Tuple[int, int], + scale_strategy: str, + output_normalized_coordinate: bool = False, + *args, + **kwargs, + ): + """ + Adaptor for the Flow head in UniCeption. + + Args: + name (str): Name of the adaptor. + flow_mean (torch.Tensor): (2,) Mean of the flow. + flow_std (torch.Tensor): (2,) Standard deviation of the flow. + base_shape (Tuple[int, int]): Base shape of the flow mean and std. + scale_strategy (str): Strategy for scaling the flow, either + - none: No scaling, network will be unnormalized with the given mean and std for all input shapes + - scale_width: scale the output for "none" by actual width divided by base width for both X and Y + - scale_height: scale the output for "none" by actual height divided by base height for both X and Y + - scale_both: scale the output for "none" by actual dimension / base dimension individually for X and Y + output_normalized_coordinate (bool): If True, will subtract the (X, Y) coordinate of the output pixel from input x after it is being scaled to pixel coordinates. + In other words, the network will predict the pixel position that the source pixel will land on the target image, rather than the flow. + """ + super().__init__(name, required_channels=2, *args, **kwargs) + + self.name: str = name + + flow_mean = list(flow_mean) + flow_std = list(flow_std) + + # Handle the case where flow_mean and flow_std are passed as tuples + if isinstance(flow_mean, tuple) or isinstance(flow_mean, list): + flow_mean = torch.tensor(flow_mean, dtype=torch.float32) + assert flow_mean.shape == (2,), f"Flow mean must be a 2D tensor, got {flow_mean.shape}" + + if isinstance(flow_std, tuple) or isinstance(flow_std, list): + flow_std = torch.tensor(flow_std, dtype=torch.float32) + assert flow_std.shape == (2,), f"Flow std must be a 2D tensor, got {flow_std.shape}" + + self.register_buffer("flow_mean", flow_mean.view(1, 2, 1, 1)) + self.register_buffer("flow_std", flow_std.view(1, 2, 1, 1)) + + self.base_shape = list(base_shape) + self.scale_strategy = scale_strategy + self.output_normalized_coordinate = output_normalized_coordinate + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the FlowAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W) + + Returns: + AdaptorOutput: Output of the adaptor. + """ + + x = adaptor_input.adaptor_feature + + # Check the number of channels to avoid passing BHWC features + _, C, _, _ = x.shape + assert C == 2, f"FlowAdaptor requires BCHW format with 2 channels, got {C} channels" + + output_shape = adaptor_input.output_shape_hw + + if not self.output_normalized_coordinate: + x_scale, y_scale = self._get_xy_scale(output_shape) + + # Scale the flow by stored mean, std and scaling factors + flow_mean = self.flow_mean * torch.tensor([x_scale, y_scale], dtype=torch.float32, device=x.device).view( + 1, 2, 1, 1 + ) + flow_std = self.flow_std * torch.tensor([x_scale, y_scale], dtype=torch.float32, device=x.device).view( + 1, 2, 1, 1 + ) + + # Unnormalize the flow + x = x * flow_std + flow_mean + else: + # Optionally subtract the coordinate bias + wh_normalizer = torch.tensor( + adaptor_input.output_shape_hw[::-1], dtype=torch.float32, device=x.device + ).view(1, 2, 1, 1) + + x = 0.5 * (x + 1) * wh_normalizer + 0.5 + + coords = self._get_coordinate_bias(output_shape, x.device) + x = x - coords + + return RegressionAdaptorOutput(value=x) + + def _get_xy_scale(self, output_shape: Tuple[int, int]): + """ + Get the scaling factor for the X and Y dimensions. + + Args: + output_shape (Tuple[int, int]): HW Shape of the output. + + Returns: + Tuple[float, float]: Scaling factors for X and Y dimensions. + """ + if self.scale_strategy == "none": + return 1.0, 1.0 + elif self.scale_strategy == "scale_width": + return output_shape[1] / self.base_shape[1], output_shape[1] / self.base_shape[1] + elif self.scale_strategy == "scale_height": + return output_shape[0] / self.base_shape[0], output_shape[0] / self.base_shape[0] + elif self.scale_strategy == "scale_both": + return output_shape[1] / self.base_shape[1], output_shape[0] / self.base_shape[0] + else: + raise ValueError(f"Invalid scaling strategy: {self.scale_strategy}") + + @lru_cache(maxsize=10) + def _get_coordinate_bias(self, output_shape: Tuple[int, int], device: str): + """ + Get the (X, Y) coordinate image for the given output shape. + + Args: + output_shape (Tuple[int, int]): HW Shape of the output. + device: device to store the tensor on + + Returns: + torch.Tensor: (2, H, W) tensor with X and Y coordinates, at device. This coordinate value will + include 0.5 px offset - i.e. the center of the top-left pixel is (0.5, 0.5). + """ + + H, W = output_shape + + coords = torch.stack( + torch.meshgrid( + torch.arange(0, W, device=device, dtype=torch.float32) + 0.5, + torch.arange(0, H, device=device, dtype=torch.float32) + 0.5, + indexing="xy", + ), + dim=0, + ) + + return coords + + +class ScaleAdaptor(UniCeptionAdaptorBase): + def __init__(self, name: str, mode: str, vmin: float = 0, vmax: float = np.inf, *args, **kwargs): + """ + Adaptor for scale prediction in UniCeption. + + Args: + name (str): Name of the adaptor. + mode (str): Mode of the scale prediction, either "linear", "square" or "exp". Scales the predicted scaling factor accordingly. + vmin (float): Minimum value of the scale prediction after scaling. + vmax (float): Maximum value of the scale prediction after scaling. + """ + super().__init__(name, required_channels=1, *args, **kwargs) + + self.mode = mode + self.vmin = vmin + self.vmax = vmax + + self.no_bounds = (vmin == -float("inf")) and (vmax == float("inf")) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the ScaleAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x 1 x ...) + Returns: + AdaptorOutput: Output of the adaptor. + """ + predicted_scale_factor = adaptor_input.adaptor_feature + output_scale_factor = None + + if self.mode == "linear": + output_scale_factor = predicted_scale_factor + elif self.mode == "square": + output_scale_factor = predicted_scale_factor.square() + elif self.mode == "exp": + output_scale_factor = torch.exp(predicted_scale_factor) + + if not self.no_bounds: + output_scale_factor = output_scale_factor.clip(self.vmin, self.vmax) + + return AdaptorOutput(value=output_scale_factor) + + +class DepthAdaptor(UniCeptionAdaptorBase): + def __init__(self, name: str, mode: str, vmin: float = 0, vmax: float = np.inf, *args, **kwargs): + """ + Adaptor for the Depth head in UniCeption. + + Args: + name (str): Name of the adaptor. + mode (str): Mode of the depth, either "linear", "square" or "exp". Scales the depth accordingly. + vmin (float): Minimum value of the depth after scaling. + vmax (float): Maximum value of the depth after scaling. + """ + super().__init__(name, required_channels=1, *args, **kwargs) + + self.mode = mode + self.vmin = vmin + self.vmax = vmax + + self.no_bounds = (vmin == -float("inf")) and (vmax == float("inf")) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the DepthAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W) + Returns: + AdaptorOutput: Output of the adaptor. + """ + x = adaptor_input.adaptor_feature + output_depth = None + + if self.mode == "linear": + output_depth = x + elif self.mode == "square": + output_depth = x**2 + elif self.mode == "exp": + output_depth = torch.exp(x) + else: + raise ValueError(f"Invalid mode: {self.mode}") + + if not self.no_bounds: + output_depth = output_depth.clip(self.vmin, self.vmax) + + return RegressionAdaptorOutput(value=output_depth) + + +class PointMapAdaptor(UniCeptionAdaptorBase): + def __init__(self, name: str, mode: str, vmin: float = -np.inf, vmax: float = np.inf, *args, **kwargs): + """ + Adaptor for the PointMap head in UniCeption. + + Args: + name (str): Name of the adaptor. + mode (str): Mode of the point map, either "linear", "square" or "exp". Scales the distance of the points to the world origin accordingly. + vmin (float): Minimum value of the point map after scaling. + vmax (float): Maximum value of the point map after scaling. + """ + super().__init__(name, required_channels=3, *args, **kwargs) + + self.mode = mode + self.vmin = vmin + self.vmax = vmax + + self.no_bounds = (vmin == -float("inf")) and (vmax == float("inf")) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the PointMapAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W) + Returns: + AdaptorOutput: Output of the adaptor. + """ + xyz = adaptor_input.adaptor_feature + output_xyz = None + + if self.mode != "linear": + if self.mode == "square": + # Compute distance to world origin + d = xyz.norm(dim=1, keepdim=True) + output_xyz = xyz / d.clip(min=1e-8) + # Scale the distance to world origin based on mode + output_xyz = output_xyz * d.square() + elif self.mode == "exp": + # Compute distance to world origin + d = xyz.norm(dim=1, keepdim=True) + output_xyz = xyz / d.clip(min=1e-8) + # Scale the distance to world origin based on mode + output_xyz = output_xyz * torch.expm1(d) + elif self.mode == "z_exp": + xy, z = xyz.split([2, 1], dim=1) + z = torch.exp(z) + output_xyz = torch.cat([xy * z, z], dim=1) + else: + raise ValueError(f"Invalid mode: {self.mode}") + else: + output_xyz = xyz + + if not self.no_bounds: + output_xyz = output_xyz.clip(self.vmin, self.vmax) + + return RegressionAdaptorOutput(value=output_xyz) + + +class RayOriginsAdaptor(UniCeptionAdaptorBase): + def __init__(self, name: str, mode: str, vmin: float = -np.inf, vmax: float = np.inf, *args, **kwargs): + """ + Adaptor for the RayOrigins head in UniCeption. + + Args: + name (str): Name of the adaptor. + mode (str): Mode of the ray origins, either "linear", "square" or "exp". Scales the distance of the ray origins to the world origin accordingly. + vmin (float): Minimum value of the ray origins after scaling. + vmax (float): Maximum value of the ray origins after scaling. + """ + super().__init__(name, required_channels=3, *args, **kwargs) + + self.mode = mode + self.vmin = vmin + self.vmax = vmax + + self.no_bounds = (vmin == -float("inf")) and (vmax == float("inf")) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the RayOriginsAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W) + Returns: + AdaptorOutput: Output of the adaptor. + """ + ray_origins = adaptor_input.adaptor_feature + output_ray_origins = None + + if self.mode != "linear": + # Compute distance to world origin + d = ray_origins.norm(dim=1, keepdim=True) + output_ray_origins = ray_origins / d.clip(min=1e-8) + # Scale the distance to world origin based on mode + if self.mode == "square": + output_ray_origins = output_ray_origins * d.square() + elif self.mode == "exp": + output_ray_origins = output_ray_origins * torch.expm1(d) + else: + raise ValueError(f"Invalid mode: {self.mode}") + else: + output_ray_origins = ray_origins + + if not self.no_bounds: + output_ray_origins = output_ray_origins.clip(self.vmin, self.vmax) + + return RegressionAdaptorOutput(value=output_ray_origins) + + +class RayDirectionsAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + mode: str, + normalize_to_unit_sphere: bool, + normalize_to_unit_image_plane: bool, + vmin: float = -np.inf, + vmax: float = np.inf, + clamp_min_of_z_dir: bool = False, + z_dir_min: float = 1, + *args, + **kwargs, + ): + """ + Adaptor for the RayDirections head in UniCeption. + + Args: + name (str): Name of the adaptor. + mode (str): Mode of the ray directions. Scales the directions accordingly. Currently only supports "linear". + normalize_to_unit_sphere (bool): If True, will normalize the ray directions to unit vectors. + normalize_to_unit_image_plane (bool): If True, will normalize the ray directions so that the z component is 1. + vmin (float): Minimum value of the ray directions after scaling & before any sort of normalization. (default: -inf) + vmax (float): Maximum value of the ray directions after scaling & before any sort of normalization. (default: inf) + clamp_min_of_z_dir (bool): If True, will clamp the z component of the ray directions before normalization. (default: False) + z_dir_min (float): If clamp_min_of_z_dir is True, this minimum value is used for clamping. (default: 1) + """ + super().__init__(name, required_channels=3, *args, **kwargs) + + self.mode = mode + self.normalize_to_unit_sphere = normalize_to_unit_sphere + self.normalize_to_unit_image_plane = normalize_to_unit_image_plane + self.vmin = vmin + self.vmax = vmax + self.clamp_min_of_z_dir = clamp_min_of_z_dir + self.z_dir_min = z_dir_min + + self.no_bounds = (vmin == -float("inf")) and (vmax == float("inf")) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the RayDirectionsAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W) + Returns: + AdaptorOutput: Output of the adaptor. + """ + ray_directions = adaptor_input.adaptor_feature + + if self.mode == "linear": + output_ray_directions = ray_directions + else: + raise ValueError(f"Invalid mode: {self.mode}") + + if not self.no_bounds: + output_ray_directions = output_ray_directions.clip(self.vmin, self.vmax) + + if self.clamp_min_of_z_dir: + # Clamp the z component of ray directions + output_ray_directions_xy = output_ray_directions[:, :2] + clamped_output_ray_directions_z = torch.clamp(output_ray_directions[:, 2:3], min=self.z_dir_min) + output_ray_directions = torch.cat((output_ray_directions_xy, clamped_output_ray_directions_z), dim=1) + + if self.normalize_to_unit_sphere: + # Normalize the ray directions to unit vectors + output_ray_dirs_norm = output_ray_directions.norm(dim=1, keepdim=True).clip(min=1e-8) + output_ray_directions = output_ray_directions / output_ray_dirs_norm + elif self.normalize_to_unit_image_plane: + # Normalize the ray directions so that the z component is 1 + output_ray_directions_z = output_ray_directions[:, 2:3] + output_ray_directions = output_ray_directions / output_ray_directions_z + + return RegressionAdaptorOutput(value=output_ray_directions) + + +class RayDirectionsPlusDepthAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + # Ray directions adaptor + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayDirections + Depth head in UniCeption. + """ + super().__init__(name, required_channels=4, *args, **kwargs) + + self.ray_directions_adaptor = RayDirectionsAdaptor( + name, + ray_directions_mode, + ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin, + ray_directions_vmax, + ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min, + ) + self.depth_adaptor = DepthAdaptor(name, depth_mode, depth_vmin, depth_vmax) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the RayMapPlusDepthAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W) + Returns: + AdaptorOutput: Output of the adaptor. + """ + ray_directions, ray_depths = torch.split(adaptor_input.adaptor_feature, [3, 1], dim=1) + ray_directions_adaptor_input = AdaptorInput( + adaptor_feature=ray_directions, output_shape_hw=adaptor_input.output_shape_hw + ) + depth_adaptor_input = AdaptorInput(adaptor_feature=ray_depths, output_shape_hw=adaptor_input.output_shape_hw) + output_ray_directions = self.ray_directions_adaptor(ray_directions_adaptor_input) + output_depth = self.depth_adaptor(depth_adaptor_input) + output = torch.cat([output_ray_directions.value, output_depth.value], dim=1) + + return RegressionAdaptorOutput(value=output) + + +class CamTranslationAdaptor(UniCeptionAdaptorBase): + def __init__(self, name: str, mode: str, vmin: float = -np.inf, vmax: float = np.inf, *args, **kwargs): + """ + Adaptor for the Camera Translation or Pose head in UniCeption. + + Args: + name (str): Name of the adaptor. + mode (str): Mode of the camera translation, either "linear", "square" or "exp". Scales the distance of the camera to the world origin accordingly. + vmin (float): Minimum value of the camera translation after scaling. + vmax (float): Maximum value of the camera translation after scaling. + """ + super().__init__(name, required_channels=3, *args, **kwargs) + + self.mode = mode + self.vmin = vmin + self.vmax = vmax + + self.no_bounds = (vmin == -float("inf")) and (vmax == float("inf")) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the CamTranslationAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C ...) + Returns: + AdaptorOutput: Output of the adaptor. + """ + cam_trans = adaptor_input.adaptor_feature + output_cam_trans = None + + if self.mode != "linear": + # Compute distance to world origin + d = cam_trans.norm(dim=1, keepdim=True) + output_cam_trans = cam_trans / d.clip(min=1e-8) + # Scale the distance to world origin based on mode + if self.mode == "square": + output_cam_trans = output_cam_trans * d.square() + elif self.mode == "exp": + output_cam_trans = output_cam_trans * torch.expm1(d) + else: + raise ValueError(f"Invalid mode: {self.mode}") + else: + output_cam_trans = cam_trans + + if not self.no_bounds: + output_cam_trans = output_cam_trans.clip(self.vmin, self.vmax) + + return AdaptorOutput(value=output_cam_trans) + + +class QuaternionsAdaptor(UniCeptionAdaptorBase): + def __init__( + self, name: str, mode: str, normalize: bool, vmin: float = -np.inf, vmax: float = np.inf, *args, **kwargs + ): + """ + Adaptor for the Quaternions or Pose head in UniCeption. + Notation of the quaternions: (x, y, z, w) + + Args: + name (str): Name of the adaptor. + mode (str): Mode of the quaternions. Scales the quaternions accordingly before normalization. Currently only supports "linear". + normalize (bool): If True, will normalize the quaternions to unit quaternions. + vmin (float): Minimum value of the quaternions after scaling & before normalization to unit quaternions if required. + vmax (float): Maximum value of the quaternions after scaling & before normalization to unit quaternions if required. + """ + super().__init__(name, required_channels=4, *args, **kwargs) + + self.mode = mode + self.normalize = normalize + self.vmin = vmin + self.vmax = vmax + + self.no_bounds = (vmin == -float("inf")) and (vmax == float("inf")) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the QuaternionsAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C ...) + Returns: + AdaptorOutput: Output of the adaptor. + """ + quaternions = adaptor_input.adaptor_feature + + if self.mode == "linear": + output_quaternions = quaternions + else: + raise ValueError(f"Invalid mode: {self.mode}") + + if not self.no_bounds: + output_quaternions = output_quaternions.clip(self.vmin, self.vmax) + + if self.normalize: + # Normalize the quaternions to unit quaternions + output_quats_norm = output_quaternions.norm(dim=1, keepdim=True).clip(min=1e-8) + output_quaternions = output_quaternions / output_quats_norm + + return AdaptorOutput(value=output_quaternions) + + +class CamTranslationPlusQuatsAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + # Cam translation adaptor + cam_trans_mode: str, + cam_trans_vmin: float, + cam_trans_vmax: float, + # Quaternions adaptor + quaternions_mode: str, + quaternions_normalize: bool, + quaternions_vmin: float, + quaternions_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the Camera Translation + Quaternions head in UniCeption. + """ + super().__init__(name, required_channels=7, *args, **kwargs) + + self.cam_trans_adaptor = CamTranslationAdaptor(name, cam_trans_mode, cam_trans_vmin, cam_trans_vmax) + self.quaternions_adaptor = QuaternionsAdaptor( + name, quaternions_mode, quaternions_normalize, quaternions_vmin, quaternions_vmax + ) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the CamTranslationPlusQuatsAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C ...) + Returns: + AdaptorOutput: Output of the adaptor. + """ + cam_trans, quaternions = torch.split(adaptor_input.adaptor_feature, [3, 4], dim=1) + cam_trans_adaptor_input = AdaptorInput(adaptor_feature=cam_trans, output_shape_hw=adaptor_input.output_shape_hw) + quaternions_adaptor_input = AdaptorInput( + adaptor_feature=quaternions, output_shape_hw=adaptor_input.output_shape_hw + ) + output_cam_trans = self.cam_trans_adaptor(cam_trans_adaptor_input) + output_quaternions = self.quaternions_adaptor(quaternions_adaptor_input) + output = torch.cat([output_cam_trans.value, output_quaternions.value], dim=1) + + return AdaptorOutput(value=output) + + +class RayMapAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + # Ray origins adaptor + ray_origins_mode: str, + ray_origins_vmin: float, + ray_origins_vmax: float, + # Ray directions adaptor + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayMap (RayOrigins + RayDirections) head in UniCeption. + """ + super().__init__(name, required_channels=6, *args, **kwargs) + + self.ray_origins_adaptor = RayOriginsAdaptor(name, ray_origins_mode, ray_origins_vmin, ray_origins_vmax) + self.ray_directions_adaptor = RayDirectionsAdaptor( + name, + ray_directions_mode, + ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin, + ray_directions_vmax, + ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min, + ) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the RayMapAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W) + Returns: + AdaptorOutput: Output of the adaptor. + """ + ray_origins, ray_directions = torch.split(adaptor_input.adaptor_feature, 3, dim=1) + ray_origins_adaptor_input = AdaptorInput( + adaptor_feature=ray_origins, output_shape_hw=adaptor_input.output_shape_hw + ) + ray_directions_adaptor_input = AdaptorInput( + adaptor_feature=ray_directions, output_shape_hw=adaptor_input.output_shape_hw + ) + output_ray_origins = self.ray_origins_adaptor(ray_origins_adaptor_input) + output_ray_directions = self.ray_directions_adaptor(ray_directions_adaptor_input) + output_rays = torch.cat([output_ray_origins.value, output_ray_directions.value], dim=1) + + return RegressionAdaptorOutput(value=output_rays) + + +class RayMapPlusDepthAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + # Ray origins adaptor + ray_origins_mode: str, + ray_origins_vmin: float, + ray_origins_vmax: float, + # Ray directions adaptor + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayMap (RayOrigins + RayDirections) + Depth head in UniCeption. + """ + super().__init__(name, required_channels=7, *args, **kwargs) + + self.ray_origins_adaptor = RayOriginsAdaptor(name, ray_origins_mode, ray_origins_vmin, ray_origins_vmax) + self.ray_directions_adaptor = RayDirectionsAdaptor( + name, + ray_directions_mode, + ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin, + ray_directions_vmax, + ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min, + ) + self.depth_adaptor = DepthAdaptor(name, depth_mode, depth_vmin, depth_vmax) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the RayMapPlusDepthAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W) + Returns: + AdaptorOutput: Output of the adaptor. + """ + ray_origins, ray_directions, ray_depths = torch.split(adaptor_input.adaptor_feature, [3, 3, 1], dim=1) + ray_origins_adaptor_input = AdaptorInput( + adaptor_feature=ray_origins, output_shape_hw=adaptor_input.output_shape_hw + ) + ray_directions_adaptor_input = AdaptorInput( + adaptor_feature=ray_directions, output_shape_hw=adaptor_input.output_shape_hw + ) + depth_adaptor_input = AdaptorInput(adaptor_feature=ray_depths, output_shape_hw=adaptor_input.output_shape_hw) + output_ray_origins = self.ray_origins_adaptor(ray_origins_adaptor_input) + output_ray_directions = self.ray_directions_adaptor(ray_directions_adaptor_input) + output_depth = self.depth_adaptor(depth_adaptor_input) + output = torch.cat([output_ray_origins.value, output_ray_directions.value, output_depth.value], dim=1) + + return RegressionAdaptorOutput(value=output) + + +class RayMapPlusDepthPlusQuatsAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + # Ray origins adaptor + ray_origins_mode: str, + ray_origins_vmin: float, + ray_origins_vmax: float, + # Ray directions adaptor + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + # Quaternions adaptor + quaternions_mode: str, + quaternions_normalize: bool, + quaternions_vmin: float, + quaternions_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayMap (RayOrigins + RayDirections) + Depth + Quaternions head in UniCeption. + """ + super().__init__(name, required_channels=11, *args, **kwargs) + + self.ray_origins_adaptor = RayOriginsAdaptor(name, ray_origins_mode, ray_origins_vmin, ray_origins_vmax) + self.ray_directions_adaptor = RayDirectionsAdaptor( + name, + ray_directions_mode, + ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin, + ray_directions_vmax, + ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min, + ) + self.depth_adaptor = DepthAdaptor(name, depth_mode, depth_vmin, depth_vmax) + self.quaternions_adaptor = QuaternionsAdaptor( + name, quaternions_mode, quaternions_normalize, quaternions_vmin, quaternions_vmax + ) + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the RayMapPlusDepthPlusQuatsAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W) + Returns: + AdaptorOutput: Output of the adaptor. + """ + ray_origins, ray_directions, ray_depths, ray_quaternions = torch.split( + adaptor_input.adaptor_feature, [3, 3, 1, 4], dim=1 + ) + ray_origins_adaptor_input = AdaptorInput( + adaptor_feature=ray_origins, output_shape_hw=adaptor_input.output_shape_hw + ) + ray_directions_adaptor_input = AdaptorInput( + adaptor_feature=ray_directions, output_shape_hw=adaptor_input.output_shape_hw + ) + depth_adaptor_input = AdaptorInput(adaptor_feature=ray_depths, output_shape_hw=adaptor_input.output_shape_hw) + quaternions_adaptor_input = AdaptorInput( + adaptor_feature=ray_quaternions, output_shape_hw=adaptor_input.output_shape_hw + ) + output_ray_origins = self.ray_origins_adaptor(ray_origins_adaptor_input) + output_ray_directions = self.ray_directions_adaptor(ray_directions_adaptor_input) + output_ray_depths = self.depth_adaptor(depth_adaptor_input) + output_ray_quaternions = self.quaternions_adaptor(quaternions_adaptor_input) + output = torch.cat( + [ + output_ray_origins.value, + output_ray_directions.value, + output_ray_depths.value, + output_ray_quaternions.value, + ], + dim=1, + ) + + return RegressionAdaptorOutput(value=output) + + +class ConfidenceAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + confidence_type: str, + vmin: float, + vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the Confidence head in UniCeption. + + Args: + name (str): Name of the adaptor. + confidence_type (str): Type of the confidence, either + - exp: Exponential confidence + - sigmoid: Sigmoid confidence + vmin (float): Minimum value of the confidence. + vmax (float): Maximum value of the confidence. + """ + super().__init__(name, required_channels=1, *args, **kwargs) + + self.confidence_type = confidence_type + self.vmin = vmin + self.vmax = vmax + + assert vmin < vmax, "vmin must be less than vmax" + + if confidence_type == "sigmoid": + assert isfinite(vmin) and isfinite(vmax), "vmin and vmax must be finite for sigmoid confidence" + assert vmin >= 0 + + def forward(self, adaptor_input: AdaptorInput): + """ + Forward pass for the ConfidenceAdaptor. + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. (B x C x H x W) + Returns: + AdaptorOutput: Output of the adaptor. + """ + + x = adaptor_input.adaptor_feature + + if self.confidence_type == "exp": + confidence = self.vmin + x.exp().clip(max=self.vmax - self.vmin) + + return RegressionAdaptorOutput(value=confidence) + + elif self.confidence_type == "sigmoid": + confidence = torch.sigmoid(x) + + confidence = confidence * (self.vmax - self.vmin) + self.vmin + + return RegressionAdaptorOutput(value=confidence) + + elif self.confidence_type == "softmax": + B, C, H, W = x.shape + confidence = torch.nn.functional.softmax(x.reshape(B, C, -1), dim=-1).reshape(B, C, H, W) * (H * W) + + return RegressionAdaptorOutput(value=confidence) + + +class Covariance2DAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + parametrization: str = "exp_tanh", + *args, + **kwargs, + ): + """ + Adaptor for the Covariance2D head in UniCeption. + """ + super().__init__(name, required_channels=3, *args, **kwargs) + self.parametrization = parametrization + + def forward(self, adaptor_input: AdaptorInput): + x = adaptor_input.adaptor_feature + + if self.parametrization == "exp_tanh": + c1, c2, s = torch.split(x, 1, dim=1) + + diag_exponent = (c1 + c2) / 2 + tanh_s = s.tanh() + + cov = torch.cat([c1.exp(), c2.exp(), tanh_s * torch.exp(diag_exponent)], dim=1) + + log_det = c1 + c2 + torch.log(1 - torch.square(tanh_s) + 1e-8) + + inv_coeff = 1 / (1 - torch.square(tanh_s) + 1e-8) + inv_cov = inv_coeff * torch.cat( + [torch.exp(-c1), torch.exp(-c2), -tanh_s * torch.exp(-diag_exponent)], dim=1 + ) + + else: + raise ValueError(f"Invalid parametrization: {self.parametrization}") + + return Covariance2DAdaptorOutput(covariance=cov, log_det=log_det, inv_covariance=inv_cov) + + +class MaskAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + *args, + **kwargs, + ): + """ + Adaptor for the Mask head in UniCeption. + """ + super().__init__(name, required_channels=1, *args, **kwargs) + + def forward(self, adaptor_input: AdaptorInput): + x = adaptor_input.adaptor_feature + + mask = torch.sigmoid(x) + + return MaskAdaptorOutput(logits=x, mask=mask) + + +class ValueWithConfidenceAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + value_adaptor: UniCeptionAdaptorBase, + confidence_adaptor: UniCeptionAdaptorBase, + *args, + **kwargs, + ): + """ + Adaptor for the Value with Confidence head in UniCeption. + + Args: + name (str): Name of the adaptor. + value_adaptor (UniCeptionAdaptorBase): Adaptor for the value. + confidence_adaptor (UniCeptionAdaptorBase): Adaptor for the confidence. + """ + + super().__init__( + name, + required_channels=value_adaptor.required_channels + confidence_adaptor.required_channels, + *args, + **kwargs, + ) + + self.value_adaptor = value_adaptor + self.confidence_adaptor = confidence_adaptor + + def forward(self, adaptor_input: AdaptorInput): + value_input, confidence_input = torch.split( + adaptor_input.adaptor_feature, + [self.value_adaptor.required_channels, self.confidence_adaptor.required_channels], + dim=1, + ) + value_adaptor_input = AdaptorInput(adaptor_feature=value_input, output_shape_hw=adaptor_input.output_shape_hw) + confidence_adaptor_input = AdaptorInput( + adaptor_feature=confidence_input, output_shape_hw=adaptor_input.output_shape_hw + ) + value_output = self.value_adaptor(value_adaptor_input) + confidence_output = self.confidence_adaptor(confidence_adaptor_input) + + return RegressionWithConfidenceAdaptorOutput(value=value_output.value, confidence=confidence_output.value) + + +class FlowWithConfidenceAdaptor(ValueWithConfidenceAdaptor): + def __init__( + self, + name: str, + # Flow adaptor + flow_mean: torch.Tensor, + flow_std: torch.Tensor, + base_shape: Tuple[int, int], + scale_strategy: str, + output_normalized_coordinate: bool, + # Confidence adaptor + confidence_type: str, + vmin: float, + vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the Flow with Confidence head in UniCeption. + """ + flow_adaptor = FlowAdaptor( + name=f"{name}", + flow_mean=flow_mean, + flow_std=flow_std, + base_shape=base_shape, + scale_strategy=scale_strategy, + output_normalized_coordinate=output_normalized_coordinate, + ) + + confidence_adaptor = ConfidenceAdaptor( + name=f"{name}_confidence", confidence_type=confidence_type, vmin=vmin, vmax=vmax + ) + + super().__init__(name, value_adaptor=flow_adaptor, confidence_adaptor=confidence_adaptor, *args, **kwargs) + + +class PointMapWithConfidenceAdaptor(ValueWithConfidenceAdaptor): + def __init__( + self, + name: str, + # Pointmap adaptor + pointmap_mode: str, + pointmap_vmin: float, + pointmap_vmax: float, + # Confidence adaptor + confidence_type: str, + confidence_vmin: float, + confidence_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the PointMap with Confidence head in UniCeption. + """ + pointmap_adaptor = PointMapAdaptor(name=f"{name}", mode=pointmap_mode, vmin=pointmap_vmin, vmax=pointmap_vmax) + + confidence_adaptor = ConfidenceAdaptor( + name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax + ) + + super().__init__(name, value_adaptor=pointmap_adaptor, confidence_adaptor=confidence_adaptor, *args, **kwargs) + + +class RayDirectionsPlusDepthwithConfidenceAdaptor(ValueWithConfidenceAdaptor): + def __init__( + self, + name: str, + # Ray directions adaptor + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + # Confidence adaptor + confidence_type: str, + confidence_vmin: float, + confidence_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayDirections + Depth with Confidence head in UniCeption. + """ + ray_directions_plus_depth_adaptor = RayDirectionsPlusDepthAdaptor( + name=f"{name}", + ray_directions_mode=ray_directions_mode, + ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin=ray_directions_vmin, + ray_directions_vmax=ray_directions_vmax, + ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min=ray_directions_z_dir_min, + depth_mode=depth_mode, + depth_vmin=depth_vmin, + depth_vmax=depth_vmax, + ) + + confidence_adaptor = ConfidenceAdaptor( + name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax + ) + + super().__init__( + name, + value_adaptor=ray_directions_plus_depth_adaptor, + confidence_adaptor=confidence_adaptor, + *args, + **kwargs, + ) + + +class RayMapPlusDepthwithConfidenceAdaptor(ValueWithConfidenceAdaptor): + def __init__( + self, + name: str, + # RayMap adaptor + ray_origins_mode: str, + ray_origins_vmin: float, + ray_origins_vmax: float, + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + # Confidence adaptor + confidence_type: str, + confidence_vmin: float, + confidence_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayMap (RayOrigins + RayDirections) + Depth with Confidence head in UniCeption. + """ + raymap_plus_depth_adaptor = RayMapPlusDepthAdaptor( + name=f"{name}", + ray_origins_mode=ray_origins_mode, + ray_origins_vmin=ray_origins_vmin, + ray_origins_vmax=ray_origins_vmax, + ray_directions_mode=ray_directions_mode, + ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin=ray_directions_vmin, + ray_directions_vmax=ray_directions_vmax, + ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min=ray_directions_z_dir_min, + depth_mode=depth_mode, + depth_vmin=depth_vmin, + depth_vmax=depth_vmax, + ) + + confidence_adaptor = ConfidenceAdaptor( + name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax + ) + + super().__init__( + name, value_adaptor=raymap_plus_depth_adaptor, confidence_adaptor=confidence_adaptor, *args, **kwargs + ) + + +class RayMapPlusDepthPlusQuatswithConfidenceAdaptor(ValueWithConfidenceAdaptor): + def __init__( + self, + name: str, + # RayMap adaptor + ray_origins_mode: str, + ray_origins_vmin: float, + ray_origins_vmax: float, + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + # Quaternions adaptor + quaternions_mode: str, + quaternions_normalize: bool, + quaternions_vmin: float, + quaternions_vmax: float, + # Confidence adaptor + confidence_type: str, + confidence_vmin: float, + confidence_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayMap (RayOrigins + RayDirections) + Depth + Quaternions with Confidence head in UniCeption. + """ + raymap_plus_depth_plus_quats_adaptor = RayMapPlusDepthPlusQuatsAdaptor( + name=f"{name}", + ray_origins_mode=ray_origins_mode, + ray_origins_vmin=ray_origins_vmin, + ray_origins_vmax=ray_origins_vmax, + ray_directions_mode=ray_directions_mode, + ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin=ray_directions_vmin, + ray_directions_vmax=ray_directions_vmax, + ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min=ray_directions_z_dir_min, + depth_mode=depth_mode, + depth_vmin=depth_vmin, + depth_vmax=depth_vmax, + quaternions_mode=quaternions_mode, + quaternions_normalize=quaternions_normalize, + quaternions_vmin=quaternions_vmin, + quaternions_vmax=quaternions_vmax, + ) + + confidence_adaptor = ConfidenceAdaptor( + name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax + ) + + super().__init__( + name, + value_adaptor=raymap_plus_depth_plus_quats_adaptor, + confidence_adaptor=confidence_adaptor, + *args, + **kwargs, + ) + + +class ValueWithMaskAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + value_adaptor: UniCeptionAdaptorBase, + mask_adaptor: UniCeptionAdaptorBase, + *args, + **kwargs, + ): + """ + Adaptor for the Value with Mask head in UniCeption. + + Args: + name (str): Name of the adaptor. + value_adaptor (UniCeptionAdaptorBase): Adaptor for the value. + mask_adaptor (UniCeptionAdaptorBase): Adaptor for the mask. + """ + + super().__init__( + name, + required_channels=value_adaptor.required_channels + mask_adaptor.required_channels, + *args, + **kwargs, + ) + + self.value_adaptor = value_adaptor + self.mask_adaptor = mask_adaptor + + def forward(self, adaptor_input: AdaptorInput): + value_input, mask_input = torch.split( + adaptor_input.adaptor_feature, + [self.value_adaptor.required_channels, self.mask_adaptor.required_channels], + dim=1, + ) + value_adaptor_input = AdaptorInput(adaptor_feature=value_input, output_shape_hw=adaptor_input.output_shape_hw) + mask_adaptor_input = AdaptorInput(adaptor_feature=mask_input, output_shape_hw=adaptor_input.output_shape_hw) + value_output = self.value_adaptor(value_adaptor_input) + mask_output = self.mask_adaptor(mask_adaptor_input) + + return RegressionWithMaskAdaptorOutput( + value=value_output.value, mask=mask_output.mask, logits=mask_output.logits + ) + + +class PointMapWithMaskAdaptor(ValueWithMaskAdaptor): + def __init__( + self, + name: str, + # Pointmap adaptor + pointmap_mode: str, + pointmap_vmin: float, + pointmap_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the PointMap with Confidence head in UniCeption. + """ + pointmap_adaptor = PointMapAdaptor(name=f"{name}", mode=pointmap_mode, vmin=pointmap_vmin, vmax=pointmap_vmax) + + mask_adaptor = MaskAdaptor(name=f"{name}_mask") + + super().__init__(name, value_adaptor=pointmap_adaptor, mask_adaptor=mask_adaptor, *args, **kwargs) + + +class RayDirectionsPlusDepthwithMaskAdaptor(ValueWithMaskAdaptor): + def __init__( + self, + name: str, + # Ray directions adaptor + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayDirections + Depth with Mask head in UniCeption. + """ + ray_directions_plus_depth_adaptor = RayDirectionsPlusDepthAdaptor( + name=f"{name}", + ray_directions_mode=ray_directions_mode, + ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin=ray_directions_vmin, + ray_directions_vmax=ray_directions_vmax, + ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min=ray_directions_z_dir_min, + depth_mode=depth_mode, + depth_vmin=depth_vmin, + depth_vmax=depth_vmax, + ) + + mask_adaptor = MaskAdaptor(name=f"{name}_mask") + + super().__init__( + name, value_adaptor=ray_directions_plus_depth_adaptor, mask_adaptor=mask_adaptor, *args, **kwargs + ) + + +class RayMapPlusDepthwithMaskAdaptor(ValueWithMaskAdaptor): + def __init__( + self, + name: str, + # RayMap adaptor + ray_origins_mode: str, + ray_origins_vmin: float, + ray_origins_vmax: float, + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayMap (RayOrigins + RayDirections) + Depth with Mask head in UniCeption. + """ + raymap_plus_depth_adaptor = RayMapPlusDepthAdaptor( + name=f"{name}", + ray_origins_mode=ray_origins_mode, + ray_origins_vmin=ray_origins_vmin, + ray_origins_vmax=ray_origins_vmax, + ray_directions_mode=ray_directions_mode, + ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin=ray_directions_vmin, + ray_directions_vmax=ray_directions_vmax, + ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min=ray_directions_z_dir_min, + depth_mode=depth_mode, + depth_vmin=depth_vmin, + depth_vmax=depth_vmax, + ) + + mask_adaptor = MaskAdaptor(name=f"{name}_mask") + + super().__init__(name, value_adaptor=raymap_plus_depth_adaptor, mask_adaptor=mask_adaptor, *args, **kwargs) + + +class RayMapPlusDepthPlusQuatswithMaskAdaptor(ValueWithMaskAdaptor): + def __init__( + self, + name: str, + # RayMap adaptor + ray_origins_mode: str, + ray_origins_vmin: float, + ray_origins_vmax: float, + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + # Quaternions adaptor + quaternions_mode: str, + quaternions_normalize: bool, + quaternions_vmin: float, + quaternions_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayMap (RayOrigins + RayDirections) + Depth + Quaternions with Mask head in UniCeption. + """ + raymap_plus_depth_plus_quats_adaptor = RayMapPlusDepthPlusQuatsAdaptor( + name=f"{name}", + ray_origins_mode=ray_origins_mode, + ray_origins_vmin=ray_origins_vmin, + ray_origins_vmax=ray_origins_vmax, + ray_directions_mode=ray_directions_mode, + ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin=ray_directions_vmin, + ray_directions_vmax=ray_directions_vmax, + ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min=ray_directions_z_dir_min, + depth_mode=depth_mode, + depth_vmin=depth_vmin, + depth_vmax=depth_vmax, + quaternions_mode=quaternions_mode, + quaternions_normalize=quaternions_normalize, + quaternions_vmin=quaternions_vmin, + quaternions_vmax=quaternions_vmax, + ) + + mask_adaptor = MaskAdaptor(name=f"{name}_mask") + + super().__init__( + name, value_adaptor=raymap_plus_depth_plus_quats_adaptor, mask_adaptor=mask_adaptor, *args, **kwargs + ) + + +class ValueWithConfidenceAndMaskAdaptor(UniCeptionAdaptorBase): + def __init__( + self, + name: str, + value_adaptor: UniCeptionAdaptorBase, + confidence_adaptor: UniCeptionAdaptorBase, + mask_adaptor: UniCeptionAdaptorBase, + *args, + **kwargs, + ): + """ + Adaptor for the Value with Confidence & Mask head in UniCeption. + + Args: + name (str): Name of the adaptor. + value_adaptor (UniCeptionAdaptorBase): Adaptor for the value. + mask_adaptor (UniCeptionAdaptorBase): Adaptor for the mask. + """ + + super().__init__( + name, + required_channels=value_adaptor.required_channels + + confidence_adaptor.required_channels + + mask_adaptor.required_channels, + *args, + **kwargs, + ) + + self.value_adaptor = value_adaptor + self.confidence_adaptor = confidence_adaptor + self.mask_adaptor = mask_adaptor + + def forward(self, adaptor_input: AdaptorInput): + value_input, confidence_input, mask_input = torch.split( + adaptor_input.adaptor_feature, + [ + self.value_adaptor.required_channels, + self.confidence_adaptor.required_channels, + self.mask_adaptor.required_channels, + ], + dim=1, + ) + value_adaptor_input = AdaptorInput(adaptor_feature=value_input, output_shape_hw=adaptor_input.output_shape_hw) + confidence_adaptor_input = AdaptorInput( + adaptor_feature=confidence_input, output_shape_hw=adaptor_input.output_shape_hw + ) + mask_adaptor_input = AdaptorInput(adaptor_feature=mask_input, output_shape_hw=adaptor_input.output_shape_hw) + value_output = self.value_adaptor(value_adaptor_input) + confidence_output = self.confidence_adaptor(confidence_adaptor_input) + mask_output = self.mask_adaptor(mask_adaptor_input) + + return RegressionWithConfidenceAndMaskAdaptorOutput( + value=value_output.value, + confidence=confidence_output.value, + mask=mask_output.mask, + logits=mask_output.logits, + ) + + +class PointMapWithConfidenceAndMaskAdaptor(ValueWithConfidenceAndMaskAdaptor): + def __init__( + self, + name: str, + # PointMap adaptor + pointmap_mode: str, + pointmap_vmin: float, + pointmap_vmax: float, + # Confidence adaptor + confidence_type: str, + confidence_vmin: float, + confidence_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the PointMap with Confidence & Mask head in UniCeption. + """ + pointmap_adaptor = PointMapAdaptor(name=f"{name}", mode=pointmap_mode, vmin=pointmap_vmin, vmax=pointmap_vmax) + + confidence_adaptor = ConfidenceAdaptor( + name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax + ) + + mask_adaptor = MaskAdaptor(name=f"{name}_mask") + + super().__init__( + name, + value_adaptor=pointmap_adaptor, + confidence_adaptor=confidence_adaptor, + mask_adaptor=mask_adaptor, + *args, + **kwargs, + ) + + +class RayDirectionsPlusDepthwithConfidenceAndMaskAdaptor(ValueWithConfidenceAndMaskAdaptor): + def __init__( + self, + name: str, + # Ray directions adaptor + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + # Confidence adaptor + confidence_type: str, + confidence_vmin: float, + confidence_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayDirections + Depth with Confidence & Mask head in UniCeption. + """ + ray_directions_plus_depth_adaptor = RayDirectionsPlusDepthAdaptor( + name=f"{name}", + ray_directions_mode=ray_directions_mode, + ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin=ray_directions_vmin, + ray_directions_vmax=ray_directions_vmax, + ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min=ray_directions_z_dir_min, + depth_mode=depth_mode, + depth_vmin=depth_vmin, + depth_vmax=depth_vmax, + ) + + confidence_adaptor = ConfidenceAdaptor( + name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax + ) + + mask_adaptor = MaskAdaptor(name=f"{name}_mask") + + super().__init__( + name, + value_adaptor=ray_directions_plus_depth_adaptor, + confidence_adaptor=confidence_adaptor, + mask_adaptor=mask_adaptor, + *args, + **kwargs, + ) + + +class RayMapPlusDepthwithConfidenceAndMaskAdaptor(ValueWithConfidenceAndMaskAdaptor): + def __init__( + self, + name: str, + # RayMap adaptor + ray_origins_mode: str, + ray_origins_vmin: float, + ray_origins_vmax: float, + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + # Confidence adaptor + confidence_type: str, + confidence_vmin: float, + confidence_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayMap (RayOrigins + RayDirections) + Depth with Confidence & Mask head in UniCeption. + """ + raymap_plus_depth_adaptor = RayMapPlusDepthAdaptor( + name=f"{name}", + ray_origins_mode=ray_origins_mode, + ray_origins_vmin=ray_origins_vmin, + ray_origins_vmax=ray_origins_vmax, + ray_directions_mode=ray_directions_mode, + ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin=ray_directions_vmin, + ray_directions_vmax=ray_directions_vmax, + ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min=ray_directions_z_dir_min, + depth_mode=depth_mode, + depth_vmin=depth_vmin, + depth_vmax=depth_vmax, + ) + + confidence_adaptor = ConfidenceAdaptor( + name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax + ) + + mask_adaptor = MaskAdaptor(name=f"{name}_mask") + + super().__init__( + name, + value_adaptor=raymap_plus_depth_adaptor, + confidence_adaptor=confidence_adaptor, + mask_adaptor=mask_adaptor, + *args, + **kwargs, + ) + + +class RayMapPlusDepthPlusQuatswithConfidenceAndMaskAdaptor(ValueWithConfidenceAndMaskAdaptor): + def __init__( + self, + name: str, + # RayMap adaptor + ray_origins_mode: str, + ray_origins_vmin: float, + ray_origins_vmax: float, + ray_directions_mode: str, + ray_directions_normalize_to_unit_sphere: bool, + ray_directions_normalize_to_unit_image_plane: bool, + ray_directions_vmin: float, + ray_directions_vmax: float, + ray_directions_clamp_min_of_z_dir: bool, + ray_directions_z_dir_min: float, + # Depth adaptor + depth_mode: str, + depth_vmin: float, + depth_vmax: float, + # Quaternions adaptor + quaternions_mode: str, + quaternions_normalize: bool, + quaternions_vmin: float, + quaternions_vmax: float, + # Confidence adaptor + confidence_type: str, + confidence_vmin: float, + confidence_vmax: float, + *args, + **kwargs, + ): + """ + Adaptor for the RayMap (RayOrigins + RayDirections) + Depth + Quaternions with Confidence & Mask head in UniCeption. + """ + raymap_plus_depth_plus_quats_adaptor = RayMapPlusDepthPlusQuatsAdaptor( + name=f"{name}", + ray_origins_mode=ray_origins_mode, + ray_origins_vmin=ray_origins_vmin, + ray_origins_vmax=ray_origins_vmax, + ray_directions_mode=ray_directions_mode, + ray_directions_normalize_to_unit_sphere=ray_directions_normalize_to_unit_sphere, + ray_directions_normalize_to_unit_image_plane=ray_directions_normalize_to_unit_image_plane, + ray_directions_vmin=ray_directions_vmin, + ray_directions_vmax=ray_directions_vmax, + ray_directions_clamp_min_of_z_dir=ray_directions_clamp_min_of_z_dir, + ray_directions_z_dir_min=ray_directions_z_dir_min, + depth_mode=depth_mode, + depth_vmin=depth_vmin, + depth_vmax=depth_vmax, + quaternions_mode=quaternions_mode, + quaternions_normalize=quaternions_normalize, + quaternions_vmin=quaternions_vmin, + quaternions_vmax=quaternions_vmax, + ) + + confidence_adaptor = ConfidenceAdaptor( + name=f"{name}_confidence", confidence_type=confidence_type, vmin=confidence_vmin, vmax=confidence_vmax + ) + + mask_adaptor = MaskAdaptor(name=f"{name}_mask") + + super().__init__( + name, + value_adaptor=raymap_plus_depth_plus_quats_adaptor, + confidence_adaptor=confidence_adaptor, + mask_adaptor=mask_adaptor, + *args, + **kwargs, + ) diff --git a/UniCeption/uniception/models/prediction_heads/base.py b/UniCeption/uniception/models/prediction_heads/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d09a1e4cd9a435f50f8cbac48e952f1883c2bef0 --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/base.py @@ -0,0 +1,210 @@ +""" +Base Prediction Head Class for UniCeption +""" + +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +from jaxtyping import Float +from torch import Tensor + + +@dataclass +class PredictionHeadInput: + last_feature: Float[Tensor, "batch_size feat_dim feat_height feat_width"] + + +@dataclass +class PredictionHeadLayeredInput: + list_features: List[Float[Tensor, "batch_size feat_dim feat_height feat_width"]] + target_output_shape: Tuple[int, int] + + +@dataclass +class PredictionHeadTokenInput: + last_feature: Float[Tensor, "batch_size feat_dim num_tokens"] + + +@dataclass +class PixelTaskOutput: + """ + PixelTaskOutput have dense pixel-wise output in BCHW format, + with the same spatial resolution as the input image. + """ + + decoded_channels: Float[Tensor, "batch_size output_channels height width"] + + +@dataclass +class SummaryTaskOutput: + """ + SummaryTaskOutput have a single latent output for each image in BC format. + """ + + decoded_channels: Float[Tensor, "batch_size output_channels"] + + +@dataclass +class AdaptorInput: + adaptor_feature: Float[Tensor, "batch_size sliced_channels height width"] + output_shape_hw: Tuple[int, int] + + +@dataclass +class AdaptorOutput: + value: Float[Tensor, "batch_size sliced_channels ..."] + + +@dataclass +class PredictionHeadOutput: + adaptor_output: Dict[str, AdaptorOutput] + + +@dataclass +class MaskAdaptorOutput: + logits: Float[Tensor, "batch_size 1 height width"] + mask: Float[Tensor, "batch_size 1 height width"] + + +@dataclass +class Covariance2DAdaptorOutput: + covariance: Float[Tensor, "batch_size 3 height width"] # the 3 channels are s_x^2, s_y^2, and rho_xy + log_det: Float[Tensor, "batch_size 1 height width"] # log determinant of the covariance matrix + inv_covariance: Float[ + Tensor, "batch_size 3 height width" + ] # the channels are [0,0], [1,1], and [0,1] of the inverse covariance matrix + + +@dataclass +class RegressionAdaptorOutput: + value: Float[Tensor, "batch_size sliced_channels height width"] + + +@dataclass +class RegressionWithConfidenceAdaptorOutput: + value: Float[Tensor, "batch_size sliced_channels height width"] + confidence: Float[Tensor, "batch_size 1 height width"] + + +@dataclass +class RegressionWithMaskAdaptorOutput: + value: Float[Tensor, "batch_size sliced_channels height width"] + logits: Float[Tensor, "batch_size 1 height width"] + mask: Float[Tensor, "batch_size 1 height width"] + + +@dataclass +class RegressionWithConfidenceAndMaskAdaptorOutput: + value: Float[Tensor, "batch_size sliced_channels height width"] + confidence: Float[Tensor, "batch_size 1 height width"] + logits: Float[Tensor, "batch_size 1 height width"] + mask: Float[Tensor, "batch_size 1 height width"] + + +class UniCeptionPredictionHeadBase(nn.Module): + def __init__( + self, + name: str, + *args, + **kwargs, + ): + """ + Base class for all prediction heads in UniCeption. + """ + super().__init__(*args, **kwargs) + + self.name: str = name + + def forward( + self, + head_input: PredictionHeadInput, + ) -> PredictionHeadOutput: + """ + Forward interface for the UniCeption prediction heads. + + + Args: + head_input (PredictionHeadInput): Input to the prediction head. + + Returns: + head_output (PredictionHeadOutput): Output of the prediction head. + """ + + raise NotImplementedError + + +class UniCeptionAdaptorBase(nn.Module): + def __init__( + self, + name: str, + required_channels: int, + *args, + **kwargs, + ): + """ + Base class for all adaptors in UniCeption. + """ + super().__init__(*args, **kwargs) + + self.name: str = name + self.required_channels: int = required_channels + + def forward( + self, + adaptor_input: AdaptorInput, + ) -> AdaptorOutput: + """ + Forward interface for the UniCeption adaptors. + + + Args: + adaptor_input (AdaptorInput): Input to the adaptor. + + Returns: + adaptor_output (AdaptorOutput): Output of the adaptor. + """ + + raise NotImplementedError + + +class AdaptorMap(nn.Module): + def __init__(self, *adaptors: UniCeptionAdaptorBase): + """ + AdaptorMap slices the input tensor and passes it to the corresponding adaptors. + + Args: + *adaptors (List[UniCeptionAdaptorBase]): List of adaptors in the Adaptor + """ + + super().__init__() + self.adaptors = nn.ModuleDict({adaptor.name: adaptor for adaptor in adaptors}) + + self.required_channels = sum([adaptor.required_channels for adaptor in adaptors]) + + def forward( + self, + adaptor_input: AdaptorInput, + ) -> Dict[str, AdaptorOutput]: + """ + Run the input through the adaptors and return the output. + + Args: + adaptor_input (AdaptorInput): Input to the adaptors. + + Returns: + Dict[str, AdaptorOutput]: Output of the adaptors, from adaptor name to AdaptorOutput. + """ + + # split adaptor input into chunks + adaptor_features = torch.split( + adaptor_input.decoded_channels, [adaptor.required_channels for adaptor in self.adaptors.values()], dim=1 + ) + + result = { + adaptor_name: adaptor(AdaptorInput(adaptor_features[i], adaptor_features[i].shape[2:])) + for i, (adaptor_name, adaptor) in enumerate(self.adaptors.items()) + } + + return result diff --git a/UniCeption/uniception/models/prediction_heads/cosmos.py b/UniCeption/uniception/models/prediction_heads/cosmos.py new file mode 100644 index 0000000000000000000000000000000000000000..dbc43a591ff5bd3b1bff8d425c2a8b060fde4912 --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/cosmos.py @@ -0,0 +1,211 @@ +""" +Cosmos Decoder head implementation +Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width); +""" + +import torch +import torch.nn as nn + +from uniception.models.libs.cosmos_tokenizer.modules import DecoderType +from uniception.models.libs.cosmos_tokenizer.networks import TokenizerConfigs +from uniception.models.prediction_heads.adaptors import ( + Covariance2DAdaptor, + FlowAdaptor, + FlowWithConfidenceAdaptor, + MaskAdaptor, +) +from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadInput + +COSMOS_LATENT_CHANNELS = 16 + +CLASSNAME_TO_ADAPTOR_CLASS = { + "FlowAdaptor": FlowAdaptor, + "FlowWithConfidenceAdaptor": FlowWithConfidenceAdaptor, + "Covariance2DAdaptor": Covariance2DAdaptor, + "MaskAdaptor": MaskAdaptor, +} + + +class CosmosSingleChannel(nn.Module): + """ + This class implements a single cosmos decoder. This decoder takes features and produce + a single channel output in the range of [-1, 1] (not strictly enforced). + """ + + def __init__( + self, + patch_size: int, + pretrained_checkpoint_path: str = None, + *args, + **kwargs, + ): + """ + Initialize the linear feature mapping. + + Args: + input_feature_dim : int, the input feature dimension + output_dim : int, the output feature dimension + patch_size : int, the patch size + """ + + super().__init__(*args, **kwargs) + + self.patch_size = patch_size + + assert self.patch_size in [8, 16], f"Invalid patch size: {self.patch_size}" + + # Init Cosmos Encoder sepecific attributes + tokenizer_config = TokenizerConfigs["CI"].value.copy() + tokenizer_config.update(dict(spatial_compression=self.patch_size)) + + z_channels = tokenizer_config["z_channels"] + latent_channels = tokenizer_config["latent_channels"] + del tokenizer_config["z_channels"] + del tokenizer_config["latent_channels"] + + decoder_name = tokenizer_config.get("decoder", DecoderType.Default.name) + self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **tokenizer_config) + + self.post_quant_conv = torch.nn.Conv2d(latent_channels, z_channels, 1) + + if pretrained_checkpoint_path is not None: + print(f"Loading pretrained cosmos decoder from {pretrained_checkpoint_path}") + ckpt = torch.load(pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def forward(self, x: torch.Tensor): + """ + Forward interface for the linear feature mapping. + + Args: + x : torch.Tensor, the input features + + Returns: + torch.Tensor, the output of the linear feature mapping + """ + + x = self.post_quant_conv(x) + x = self.decoder(x) + + return x + + +class CosmosFeature(nn.Module): + """ + This class implements a linear mapping from the low resolution patch features + to pixel-wise features. + """ + + def __init__( + self, + input_feature_dim: int, + output_dim: int, + patch_size: int, + skip_linear: bool = False, + single_channel_ckpt: str = None, + pretrained_checkpoint_path: str = None, + *args, + **kwargs, + ): + """ + Initialize the linear feature mapping. + + Args: + input_feature_dim : int, the input feature dimension + output_dim : int, the output feature dimension + patch_size : int, the patch size + pretrained_checkpoint_path : str, path to pretrained checkpoint (default: None) + """ + + super().__init__(*args, **kwargs) + + self.input_feature_dim = input_feature_dim + self.output_dim = output_dim + self.patch_size = patch_size + self.skip_linear = skip_linear + self.pretrained_checkpoint_path = pretrained_checkpoint_path + + assert self.patch_size in [8, 16], f"Invalid patch size: {self.patch_size}" + + if not self.skip_linear: + self.linear = nn.Conv2d( + in_channels=self.input_feature_dim, + out_channels=self.output_dim * COSMOS_LATENT_CHANNELS, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + + self.cosmos_decoders = nn.ModuleList( + [ + CosmosSingleChannel( + patch_size=self.patch_size, + pretrained_checkpoint_path=single_channel_ckpt, + *args, + **kwargs, + ) + for _ in range(self.output_dim) + ] + ) + + self.output_scaling = nn.Parameter(torch.ones(1, self.output_dim, 1, 1)) + self.output_bias = nn.Parameter(torch.zeros(1, self.output_dim, 1, 1)) + + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained linear dense feature head from {self.pretrained_checkpoint_path}") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def forward(self, feature_input: PredictionHeadInput): + """ + Forward interface for the linear feature mapping. + + Args: + feature_input : PredictionHeadInput, the input features + - last_feature : torch.Tensor, the last feature tensor + + Returns: + PixelTaskOutput, the output of the linear feature mapping + - decoded_channels : torch.Tensor, the decoded channels + + """ + + x = feature_input.last_feature + + assert ( + x.shape[1] == self.input_feature_dim + ), f"Input feature dimension mismatch: {x.shape[1]} != {self.input_feature_dim}" + + if not self.skip_linear: + x = self.linear(x) + + x_split = list(torch.split(x, COSMOS_LATENT_CHANNELS, dim=1)) + + output = [None] * self.output_dim + for i, decoder in enumerate(self.cosmos_decoders): + output[i] = torch.mean(decoder(x_split[i]), dim=1, keepdim=True) + + # Concatenate the decoded channels + x = torch.cat(output, dim=1) + + # a linear scaling layer to map cosmos output [-1, 1] to arbitrary range + x = x * self.output_scaling + self.output_bias + + return PixelTaskOutput(decoded_channels=x), x_split + + +if __name__ == "__main__": + + x_single_channel = torch.randn(1, 16, 8, 8) + + # Test CosmosSingleChannel + cosmos_single_channel = CosmosSingleChannel(patch_size=8) + cosmos_single_channel(x_single_channel) + + # Test CosmosFeature + cosmos_feature = CosmosFeature(input_feature_dim=1024, output_dim=2, patch_size=8) + x_feature = torch.randn(1, 1024, 8, 8) + + output = cosmos_feature(PredictionHeadInput(last_feature=x_feature)) + print(output.decoded_channels.shape) diff --git a/UniCeption/uniception/models/prediction_heads/dpt.py b/UniCeption/uniception/models/prediction_heads/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..ab65646b77c760cfd8225d2a93f12179db54de4a --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/dpt.py @@ -0,0 +1,676 @@ +""" +DPT head implementation +Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width); +The DPT head implementation is based on DUSt3R and CroCoV2 +References: https://github.com/naver/dust3r +""" + +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +from jaxtyping import Float +from torch import Tensor +from torch.utils.checkpoint import checkpoint + +from uniception.models.libs.croco.dpt_block import make_fusion_block, make_nonlinearity, make_scratch, pair +from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadLayeredInput + + +@dataclass +class DPTFeatureInput: + features_upsampled_8x: Float[Tensor, "batch_size dpt_output_feat_dim feat_height_8x feat_width_8x"] + target_output_shape: Tuple[int, int] + + +# -------------------------------------------------------- DPT Feature -------------------------------------------------------- + + +class DPTFeature(nn.Module): + """ + DPT head implementation based on DUSt3R and CroCoV2 + + Behavior: + In forward, it will take in a list of Feature Tensors in BCHW (B, C, H//P, W//P)format, + and return a upsampled feature tensor of shape (B, C, 8*(H//P), 8*(W//P)). This module + should be used together with DPT[*]Processor to upsample the feature and + interpolate when P is not 2^n to match the image shape exactly. + """ + + def __init__( + self, + patch_size: Union[int, Tuple[int, int]] = 16, + main_tasks: Iterable[str] = ("rgb",), + hooks: List[int] = [2, 5, 8, 11], + input_feature_dims: Optional[Union[int, List[int]]] = 768, + layer_dims: List[int] = [96, 192, 384, 768], + feature_dim: int = 256, + use_bn: bool = False, + output_width_ratio=1, + pretrained_checkpoint_path: str = None, + checkpoint_gradient: bool = False, + nonlinearity: str = "relu", + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.patch_size = pair(patch_size) + self.main_tasks = main_tasks + self.hooks = hooks + self.layer_dims = layer_dims + self.feature_dim = feature_dim + self.checkpoint_gradient = checkpoint_gradient + + if isinstance(input_feature_dims, int): + input_feature_dims = 4 * [input_feature_dims] + else: + input_feature_dims = input_feature_dims + assert isinstance(input_feature_dims, List) and len(input_feature_dims) == 4 + + self.input_feature_dims = input_feature_dims + + self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False) + + self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio, nonlinearity=nonlinearity) + self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio, nonlinearity=nonlinearity) + self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio, nonlinearity=nonlinearity) + self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio, nonlinearity=nonlinearity) + + # delete resconfunit1 in refinement 4 because it is not used, and will cause error in DDP. + del self.scratch.refinenet4.resConfUnit1 + + if self.input_feature_dims is not None: + self.init(input_feature_dims=input_feature_dims) + + self.pretrained_checkpoint_path = pretrained_checkpoint_path + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained DPT dense feature head from {self.pretrained_checkpoint_path}") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def init(self, input_feature_dims: Union[int, List[int]] = 768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + + Args: + input_feature_dims: Dimension of tokens coming from encoder + """ + # Set up activation postprocessing layers + if isinstance(input_feature_dims, int): + input_feature_dims = 4 * [input_feature_dims] + + self.input_feature_dims = [dt * len(self.main_tasks) for dt in input_feature_dims] + + act_1_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.input_feature_dims[0], + out_channels=self.layer_dims[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[0], + out_channels=self.layer_dims[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + act_2_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.input_feature_dims[1], + out_channels=self.layer_dims[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=self.layer_dims[1], + out_channels=self.layer_dims[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + act_3_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.input_feature_dims[2], + out_channels=self.layer_dims[2], + kernel_size=1, + stride=1, + padding=0, + ) + ) + + act_4_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.input_feature_dims[3], + out_channels=self.layer_dims[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=self.layer_dims[3], + out_channels=self.layer_dims[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + act_postprocess = [act_1_postprocess, act_2_postprocess, act_3_postprocess, act_4_postprocess] + + self.input_process = nn.ModuleList( + [nn.Sequential(act_, layer_rn_) for act_, layer_rn_ in zip(act_postprocess, self.scratch.layer_rn)] + ) + + def forward(self, dpt_input: PredictionHeadLayeredInput) -> DPTFeatureInput: + """ + DPT Feature forward pass from 4 layers in the transformer to 8x sampled feature output. + + Args: + dpt_input (PredictionHeadLayeredInput): Input to the DPT feature head + - list_features: List of 4 BCHW Tensors representing the features from 4 layers of the transformer + + Returns: + DPTFeatureInput: Output of the DPT feature head + - features_upsampled_8x: BCHW Tensor representing the 8x upsampled feature. + """ + + assert self.input_feature_dims is not None, "Need to call init(input_feature_dims) function first" + + layered_feats = dpt_input.list_features + + # check input dimensions + for hook_idx, hook in enumerate(self.hooks): + assert ( + layered_feats[hook].shape[1] == self.input_feature_dims[hook_idx] + ), f"Input feature dimension mismatch at hook {hook}. Expected BCHW" + + if not self.checkpoint_gradient: + # Hook decoder onto 4 layers from specified ViT layers + layers = [layered_feats[hook] for hook in self.hooks] + + # layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + # # Project layers to chosen feature dim + # layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + layers = [self.input_process[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[3])[:, :, : layers[2].shape[2], : layers[2].shape[3]] + path_3 = self.scratch.refinenet3(path_4, layers[2]) + path_2 = self.scratch.refinenet2(path_3, layers[1]) + feature_upsampled_8x = self.scratch.refinenet1(path_2, layers[0]) + else: + # Hook decoder onto 4 layers from specified ViT layers + layers = [layered_feats[hook] for hook in self.hooks] + + layers = [checkpoint(self.input_process[idx], l, use_reentrant=False) for idx, l in enumerate(layers)] + + path_4 = checkpoint(self.scratch.refinenet4, layers[3], use_reentrant=False)[ + :, :, : layers[2].shape[2], : layers[2].shape[3] + ] + path_3 = checkpoint(self.scratch.refinenet3, path_4, layers[2], use_reentrant=False) + path_2 = checkpoint(self.scratch.refinenet2, path_3, layers[1], use_reentrant=False) + feature_upsampled_8x = checkpoint(self.scratch.refinenet1, path_2, layers[0], use_reentrant=False) + + return DPTFeatureInput( + features_upsampled_8x=feature_upsampled_8x, target_output_shape=dpt_input.target_output_shape + ) + + +# -------------------------------------------------------- DPT Processors -------------------------------------------------------- + + +class DPTRegressionProcessor(nn.Module): + def __init__( + self, + input_feature_dim: int, + output_dim: int, + hidden_dims: Optional[List[int]] = None, # when not given, use input_feature_dim//2 + pretrained_checkpoint_path: str = None, + checkpoint_gradient: bool = False, + nonlinearity: str = "relu", + *args, + **kwargs, + ): + """ + DPT regression processor, takes 8x upsampled feature from DPT and furture upsamples to target shape + + It will interpolate the feature to match the target shape exactly, handling patch size not 2^n + + Args: + input_feature_dim: Dimension of input feature + output_dim: Dimension of output regression + hidden_dims: [h1, h2] List of 2 hidden dimensions for intermediate. default is [input_feature_dim//2] * 2 + pretrained_checkpoint_path: Path to pretrained checkpoint (default: None) + """ + + super().__init__(*args, **kwargs) + + if hidden_dims is None: + hidden_dims = [input_feature_dim // 2] * 2 + else: + assert isinstance(hidden_dims, List) and len(hidden_dims) == 2 + + self.checkpoint_gradient = checkpoint_gradient + + self.conv1 = nn.Conv2d(input_feature_dim, hidden_dims[0], kernel_size=3, stride=1, padding=1) + # interpolate is dependent on target output size + self.conv2 = nn.Sequential( + nn.Conv2d(hidden_dims[0], hidden_dims[1], kernel_size=3, stride=1, padding=1), + make_nonlinearity(nonlinearity), + nn.Conv2d(hidden_dims[1], output_dim, kernel_size=1, stride=1, padding=0), + ) + + self.pretrained_checkpoint_path = pretrained_checkpoint_path + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained DPT regression processor from {self.pretrained_checkpoint_path}") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def forward(self, dpt_processor_input: DPTFeatureInput): + """ + DPT regression processor, process DPT output into channels to be adapted into regression output. + + Args: + dpt_processor_input (DPTFeatureInput): Input to the processor + - features_upsampled_8x: BCHW Tensor representing the upsampled feature + - target_output_shape: Tuple of (H, W) representing the target output shape + + Returns: + PixelTaskOutput: Output of the processor + - decoded_channels: BCHW Tensor representing the regression output + """ + + x = dpt_processor_input.features_upsampled_8x + output_shape = dpt_processor_input.target_output_shape + + if not self.checkpoint_gradient: + x = self.conv1(x) + x = F.interpolate(x, size=output_shape, mode="bilinear", align_corners=True) + x = self.conv2(x) + else: + x = self.conv1(x) + x = F.interpolate(x, size=output_shape, mode="bilinear", align_corners=True) + x = checkpoint(self.conv2, x, use_reentrant=False) + + return PixelTaskOutput(decoded_channels=x) + + +class DPTSegmentationProcessor(nn.Module): + def __init__( + self, + input_feature_dim: int, + output_dim: int, + hidden_dim: Optional[int] = None, # when not given, use input_feature_dim + use_bn: bool = False, + pretrained_checkpoint_path: str = None, + *args, + **kwargs, + ): + """ + DPT segmentation processor, takes 8x upsampled feature from DPT and furture upsamples to target shape. + This version differs slightly from the regression processor. + + It will interpolate the feature to match the target shape exactly, handling patch size not 2^n + + Args: + input_feature_dim: Dimension of input feature + output_dim: Dimension of output regression + hidden_dim: h1 Hidden dimension for intermediate. default is input_feature_dim + use_bn: Whether to use batch normalization, default is False + pretrained_checkpoint_path: Path to pretrained checkpoint (default: None) + """ + + super().__init__(*args, **kwargs) + + if hidden_dim is None: + hidden_dim = input_feature_dim + + self.conv = nn.Sequential( + nn.Conv2d(input_feature_dim, hidden_dim, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(hidden_dim) if use_bn else nn.Identity(), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(hidden_dim, output_dim, kernel_size=1), + ) + + self.pretrained_checkpoint_path = pretrained_checkpoint_path + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained DPT segmentation processor from {self.pretrained_checkpoint_path}") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def forward(self, dpt_processor_input: DPTFeatureInput): + """ + Forward pass for the DPT segmentation processor, process DPT output into channels + to be adapted into segmentation mask. + + Args: + dpt_processor_input (DPTFeatureInput): Input to the processor + - features_upsampled_8x: BCHW Tensor representing the upsampled feature + - target_output_shape: Tuple of (H, W) representing the target output shape + + Returns: + PixelTaskOutput: Output of the processor + - decoded_channels: BCHW Tensor representing the segmentation mask + """ + + x = dpt_processor_input.features_upsampled_8x + output_shape = dpt_processor_input.target_output_shape + + x = self.conv(x) + x = F.interpolate(x, size=output_shape, mode="bilinear", align_corners=True) + + return PixelTaskOutput(decoded_channels=x) + + +# ---------------------------------------- DPT Feature 2x upsample ---------------------------------------- +class DPTFeatureDoubleUpsampling(nn.Module): + """ + DPT head implementation based on DUSt3R and CroCoV2 + + Behavior: + In forward, it will take in a list of Feature Tensors in BCHW (B, C, H//P, W//P)format, + and return a upsampled feature tensor of shape (B, C, 8*(H//P), 8*(W//P)). This module + should be used together with DPT[*]Processor to upsample the feature and + interpolate when P is not 2^n to match the image shape exactly. + """ + + def __init__( + self, + patch_size: Union[int, Tuple[int, int]] = 16, + main_tasks: Iterable[str] = ("rgb",), + hooks: List[int] = [0, 1], + input_feature_dims: Optional[Union[int, List[int]]] = 768, + layer_dims: List[int] = [384, 768], + feature_dim: int = 256, + use_bn: bool = False, + output_width_ratio=1, + pretrained_checkpoint_path: str = None, + checkpoint_gradient: bool = False, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.patch_size = pair(patch_size) + self.main_tasks = main_tasks + self.hooks = hooks + self.layer_dims = layer_dims + self.feature_dim = feature_dim + self.checkpoint_gradient = checkpoint_gradient + + if isinstance(input_feature_dims, int): + input_feature_dims = 2 * [input_feature_dims] + else: + input_feature_dims = input_feature_dims + assert isinstance(input_feature_dims, List) and len(input_feature_dims) == 2 + + self.input_feature_dims = input_feature_dims + + self.scratch = self.make_scratch_2(layer_dims, feature_dim, groups=1, expand=False) + + self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio) + + # delete resconfunit1 in refinement 4 because it is not used, and will cause error in DDP. + del self.scratch.refinenet4.resConfUnit1 + + if self.input_feature_dims is not None: + self.init(input_feature_dims=input_feature_dims) + + self.pretrained_checkpoint_path = pretrained_checkpoint_path + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained DPT dense feature head from {self.pretrained_checkpoint_path}") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def make_scratch_2(self, in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer3_rn = nn.Conv2d( + in_shape[0], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[1], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + scratch.layer_rn = nn.ModuleList( + [ + scratch.layer3_rn, + scratch.layer4_rn, + ] + ) + + return scratch + + def init(self, input_feature_dims: Union[int, List[int]] = 768): + """ + Initialize parts of decoder that are dependent on dimension of encoder tokens. + + Args: + input_feature_dims: Dimension of tokens coming from encoder + """ + # Set up activation postprocessing layers + if isinstance(input_feature_dims, int): + input_feature_dims = 2 * [input_feature_dims] + + self.input_feature_dims = [dt * len(self.main_tasks) for dt in input_feature_dims] + + act_3_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.input_feature_dims[0], + out_channels=self.layer_dims[0], + kernel_size=1, + stride=1, + padding=0, + ) + ) + + act_4_postprocess = nn.Sequential( + nn.Conv2d( + in_channels=self.input_feature_dims[1], + out_channels=self.layer_dims[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=self.layer_dims[1], + out_channels=self.layer_dims[1], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + act_postprocess = [act_3_postprocess, act_4_postprocess] + + self.input_process = nn.ModuleList( + [nn.Sequential(act_, layer_rn_) for act_, layer_rn_ in zip(act_postprocess, self.scratch.layer_rn)] + ) + + def forward(self, dpt_input: PredictionHeadLayeredInput) -> DPTFeatureInput: + """ + DPT Feature forward pass from 4 layers in the transformer to 8x sampled feature output. + + Args: + dpt_input (PredictionHeadLayeredInput): Input to the DPT feature head + - list_features: List of 4 BCHW Tensors representing the features from 4 layers of the transformer + + Returns: + DPTFeatureInput: Output of the DPT feature head + - features_upsampled_8x: BCHW Tensor representing the 8x upsampled feature. + """ + + assert self.input_feature_dims is not None, "Need to call init(input_feature_dims) function first" + + layered_feats = dpt_input.list_features + + # check input dimensions + for hook_idx, hook in enumerate(self.hooks): + assert ( + layered_feats[hook].shape[1] == self.input_feature_dims[hook_idx] + ), f"Input feature dimension mismatch at hook {hook}. Expected BCHW" + + if not self.checkpoint_gradient: + # Hook decoder onto 4 layers from specified ViT layers + layers = [layered_feats[hook] for hook in self.hooks] + + # layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] + # # Project layers to chosen feature dim + # layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] + layers = [self.input_process[idx](l) for idx, l in enumerate(layers)] + + # Fuse layers using refinement stages + path_4 = self.scratch.refinenet4(layers[1])[:, :, : layers[0].shape[2], : layers[0].shape[3]] + feature_upsampled_2x = self.scratch.refinenet3(path_4, layers[0]) + else: + # Hook decoder onto 4 layers from specified ViT layers + layers = [layered_feats[hook] for hook in self.hooks] + + layers = [checkpoint(self.input_process[idx], l, use_reentrant=False) for idx, l in enumerate(layers)] + + path_4 = checkpoint(self.scratch.refinenet4, layers[1], use_reentrant=False)[ + :, :, : layers[0].shape[2], : layers[0].shape[3] + ] + feature_upsampled_2x = checkpoint(self.scratch.refinenet3, path_4, layers[0], use_reentrant=False) + + return DPTFeatureInput( + features_upsampled_8x=feature_upsampled_2x, target_output_shape=dpt_input.target_output_shape + ) + + +if __name__ == "__main__": + import numpy as np + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # Ensure the model is on GPU + num_runs = 20 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Instantiate the model and move to GPU + dpt_feature_output = DPTFeature( + patch_size=16, + main_tasks=("rgb",), + hooks=[2, 5, 8, 11], + input_feature_dims=[1024, 768, 768, 768], + layer_dims=[96, 192, 384, 768], + feature_dim=256, + use_bn=False, + output_width_ratio=1, + checkpoint_gradient=True, + ).to(device) + + postprocess = DPTRegressionProcessor(input_feature_dim=256, output_dim=3, checkpoint_gradient=True).to(device) + + # Define input shape + image_shape = (560, 420) + batch_size = 12 + patch_size = 14 + + patch_num = (image_shape[0] // patch_size, image_shape[1] // patch_size) + + input_feats = [None for _ in range(12)] + + input_feats[2] = torch.randn(batch_size, 1024, *patch_num, device=device, requires_grad=True) + input_feats[5] = torch.randn(batch_size, 768, *patch_num, device=device, requires_grad=True) + input_feats[8] = torch.randn(batch_size, 768, *patch_num, device=device, requires_grad=True) + input_feats[11] = torch.randn(batch_size, 768, *patch_num, device=device, requires_grad=True) + + # Warm-up to stabilize GPU performance + for _ in range(3): + output = dpt_feature_output( + PredictionHeadLayeredInput(list_features=input_feats, target_output_shape=image_shape) + ) + output2 = postprocess(output) + torch.cuda.synchronize() + + # Clear memory cache + torch.cuda.empty_cache() + + # Lists to store results + forward_times = [] + backward_times = [] + memory_usages = [] + + for _ in range(num_runs): + # Start measuring time + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Reset memory stats + torch.cuda.reset_peak_memory_stats() + memory_before = torch.cuda.max_memory_allocated(device) + + # Forward pass + start_event.record() + output = dpt_feature_output( + PredictionHeadLayeredInput(list_features=input_feats, target_output_shape=image_shape) + ) + output2 = postprocess(output) + end_event.record() + torch.cuda.synchronize() + forward_time = start_event.elapsed_time(end_event) # Time in milliseconds + + # Backward pass + start_event.record() + output = dpt_feature_output( + PredictionHeadLayeredInput(list_features=input_feats, target_output_shape=image_shape) + ) + output2 = postprocess(output) + output2.decoded_channels.sum().backward() + end_event.record() + torch.cuda.synchronize() + backward_time = start_event.elapsed_time(end_event) + + # Memory usage + memory_after = torch.cuda.max_memory_allocated(device) + peak_memory = memory_after - memory_before + + forward_times.append(forward_time) + backward_times.append(backward_time) + memory_usages.append(peak_memory / 1e6) # Convert to MB + + # Compute mean and standard deviation + fwd_mean, fwd_std = np.mean(forward_times), np.std(forward_times) + bwd_mean, bwd_std = np.mean(backward_times), np.std(backward_times) + mem_mean, mem_std = np.mean(memory_usages), np.std(memory_usages) + + print(f"Forward Pass Time: {fwd_mean:.2f} ± {fwd_std:.2f} ms") + print(f"Backward Pass Time: {bwd_mean:.2f} ± {bwd_std:.2f} ms") + print(f"Peak GPU Memory Usage: {mem_mean:.2f} ± {mem_std:.2f} MB") diff --git a/UniCeption/uniception/models/prediction_heads/global_head.py b/UniCeption/uniception/models/prediction_heads/global_head.py new file mode 100644 index 0000000000000000000000000000000000000000..dff36f37d24c0dc6af4a721bfbbcbb46cca8b6dc --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/global_head.py @@ -0,0 +1,142 @@ +""" +Global quantity prediction head implementation +Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width) +""" + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from uniception.models.prediction_heads.base import PredictionHeadInput, SummaryTaskOutput +from uniception.models.prediction_heads.pose_head import ResConvBlock + + +class GlobalHead(nn.Module): + """ + Glboal quantity regression head implementation + """ + + def __init__( + self, + patch_size: int, + input_feature_dim: int, + num_resconv_block: int = 2, + output_representation_dim: int = 1, + pretrained_checkpoint_path: str = None, + *args, + **kwargs, + ): + """ + Initialize the global head. + + Args: + patch_size : int, the patch size of the transformer used to generate the input features + input_feature_dim : int, the input feature dimension + num_resconv_block : int, the number of residual convolution blocks + output_representation_dim : int, the dimension of the output representation + pretrained_checkpoint_path : str, path to pretrained checkpoint (default: None) + """ + super().__init__() + self.patch_size = patch_size + self.input_feature_dim = input_feature_dim + self.num_resconv_block = num_resconv_block + self.output_representation_dim = output_representation_dim + self.pretrained_checkpoint_path = pretrained_checkpoint_path + + # Initialize the hidden dimension of the global head based on the patch size + self.output_dim = 4 * (self.patch_size**2) + + # Initialize the projection layer for the hidden dimension of the global head + self.proj = nn.Conv2d( + in_channels=self.input_feature_dim, + out_channels=self.output_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + + # Initialize sequential layers of the global head + self.res_conv = nn.ModuleList( + [copy.deepcopy(ResConvBlock(self.output_dim, self.output_dim)) for _ in range(self.num_resconv_block)] + ) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.more_mlps = nn.Sequential( + nn.Linear(self.output_dim, self.output_dim), + nn.ReLU(), + nn.Linear(self.output_dim, self.output_dim), + nn.ReLU(), + ) + self.fc_output = nn.Linear(self.output_dim, self.output_representation_dim) + + # Load the pretrained checkpoint if provided + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained global head from {self.pretrained_checkpoint_path}") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def forward(self, feature_input: PredictionHeadInput): + """ + Forward interface for the global quantity prediction head. + The head requires an adapter on the final output. + + Args: + feature_input : PredictionHeadInput, the input features + - last_feature : torch.Tensor, the last feature tensor + + Returns: + SummaryTaskOutput, the output of the global head + - decoded_channels : torch.Tensor, the decoded channels + """ + # Get the patch-level features from the input + feat = feature_input.last_feature # (B, C, H, W) + + # Check the input dimensions + assert ( + feat.shape[1] == self.input_feature_dim + ), f"Input feature dimension {feat.shape[1]} does not match expected dimension {self.input_feature_dim}" + + # Apply the projection layer to the patch-level features + feat = self.proj(feat) # (B, PC, H, W) + + # Apply the residual convolution blocks to the projected features + for i in range(self.num_resconv_block): + feat = self.res_conv[i](feat) + + # Apply the average pooling layer to the residual convolution output + feat = self.avgpool(feat) # (B, PC, 1, 1) + + # Flatten the average pooled features + feat = feat.view(feat.size(0), -1) # (B, PC) + + # Apply the more MLPs to the flattened features + feat = self.more_mlps(feat) # (B, PC) + + # Apply the final linear layers to the more MLPs output + output_feat = self.fc_output(feat) # (B, self.output_representation_dim) + + return SummaryTaskOutput(decoded_channels=output_feat) + + +if __name__ == "__main__": + # Init an example global head + global_head = GlobalHead( + patch_size=14, + input_feature_dim=1024, + num_resconv_block=2, + output_representation_dim=1, + pretrained_checkpoint_path=None, + ) + + # Create a dummy input tensor with shape (B, C, H, W) + dummy_input = torch.randn(4, 1024, 14, 14) # Example input + + # Run dummy forward pass + output = global_head(PredictionHeadInput(last_feature=dummy_input)) + + # Check the output shape + assert output.decoded_channels.shape == (4, 1), "Output shape mismatch" + + print("Global head test passed!") diff --git a/UniCeption/uniception/models/prediction_heads/linear.py b/UniCeption/uniception/models/prediction_heads/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..31ba601648b7d371d732fe48e8cd3f8ee4823f9d --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/linear.py @@ -0,0 +1,95 @@ +""" +Linear head implementation +Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width); +The linear head implementation is based on DUSt3R and CroCoV2 +References: https://github.com/naver/dust3r +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadInput + + +class LinearFeature(nn.Module): + """ + This class implements a linear mapping from the low resolution patch features + to pixel-wise features. + """ + + def __init__( + self, + input_feature_dim: int, + output_dim: int, + patch_size: int, + pretrained_checkpoint_path: str = None, + *args, + **kwargs, + ): + """ + Initialize the linear feature mapping. + + Args: + input_feature_dim : int, the input feature dimension + output_dim : int, the output feature dimension + patch_size : int, the patch size + pretrained_checkpoint_path : str, path to pretrained checkpoint (default: None) + """ + + super().__init__(*args, **kwargs) + + self.input_feature_dim = input_feature_dim + self.output_dim = output_dim + self.patch_size = patch_size + self.pretrained_checkpoint_path = pretrained_checkpoint_path + + self.linear = nn.Conv2d( + in_channels=self.input_feature_dim, + out_channels=self.output_dim * (self.patch_size**2), + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained linear dense feature head from {self.pretrained_checkpoint_path}") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def forward(self, feature_input: PredictionHeadInput): + """ + Forward interface for the linear feature mapping. + + Args: + feature_input : PredictionHeadInput, the input features + - last_feature : torch.Tensor, the last feature tensor + + Returns: + PixelTaskOutput, the output of the linear feature mapping + - decoded_channels : torch.Tensor, the decoded channels + + """ + + x = feature_input.last_feature + + assert ( + x.shape[1] == self.input_feature_dim + ), f"Input feature dimension mismatch: {x.shape[1]} != {self.input_feature_dim}" + + x = self.linear(x) + x = F.pixel_shuffle(x, self.patch_size) + + return PixelTaskOutput(decoded_channels=x) + + +if __name__ == "__main__": + # Init an example linear feature head + linear_prediction_head = LinearFeature(input_feature_dim=768, output_dim=4, patch_size=16) + + # Create a dummy input tensor with shape (B, C, H, W) + dummy_input = torch.randn(1, 768, 14, 14) # Example input + + # Run dummy forward pass + output = linear_prediction_head(PredictionHeadInput(last_feature=dummy_input)) diff --git a/UniCeption/uniception/models/prediction_heads/mlp_feature.py b/UniCeption/uniception/models/prediction_heads/mlp_feature.py new file mode 100644 index 0000000000000000000000000000000000000000..7b23a41d0e162fbac558ffb10c58acc6bb776ac2 --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/mlp_feature.py @@ -0,0 +1,114 @@ +""" +Linear head with MLP implementation +Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width) +""" + +from typing import Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadInput +from uniception.models.utils.transformer_blocks import Mlp + + +class MLPFeature(nn.Module): + """ + This class implements a linear mapping from the low resolution patch features + to pixel-wise features with an additional intermediate MLP layer. + """ + + def __init__( + self, + input_feature_dim: Union[int, str], + patch_size: int, + output_dim: int, + mlp_ratio: int = 4, + act_layer=nn.GELU, + bias=True, + drop=0.0, + pretrained_checkpoint_path: str = None, + *args, + **kwargs, + ): + """ + Initialize the linear feature mapping. + + Args: + input_feature_dim : int, the input feature dimension + output_dim : int, the output feature dimension + patch_size : int, the patch size + pretrained_checkpoint_path : str, path to pretrained checkpoint (default: None) + """ + + super().__init__(*args, **kwargs) + + if isinstance(input_feature_dim, str): + input_feature_dim = eval(input_feature_dim) + + self.input_feature_dim = input_feature_dim + self.output_dim = output_dim + self.patch_size = patch_size + self.pretrained_checkpoint_path = pretrained_checkpoint_path + + self.mlp = Mlp( + in_features=self.input_feature_dim, + hidden_features=int(mlp_ratio * self.input_feature_dim), + act_layer=act_layer, + drop=drop, + bias=bias, + ) + + self.linear = nn.Conv2d( + in_channels=self.input_feature_dim, + out_channels=self.output_dim * (self.patch_size**2), + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained linear dense feature head from {self.pretrained_checkpoint_path}") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def forward(self, feature_input: PredictionHeadInput): + """ + Forward interface for the linear feature mapping. + + Args: + feature_input : PredictionHeadInput, the input features + - last_feature : torch.Tensor, the last feature tensor + + Returns: + PixelTaskOutput, the output of the linear feature mapping + - decoded_channels : torch.Tensor, the decoded channels + + """ + + x = feature_input.last_feature + + assert ( + x.shape[1] == self.input_feature_dim + ), f"Input feature dimension mismatch: {x.shape[1]} != {self.input_feature_dim}" + + x = self.mlp(x.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous() + x = self.linear(x) + x = F.pixel_shuffle(x, self.patch_size) + + return PixelTaskOutput(decoded_channels=x) + + +if __name__ == "__main__": + # Init an example linear feature head + linear_prediction_head = MLPFeature( + input_feature_dim=768, mlp_ratio=4, act_layer=nn.GELU, output_dim=4, patch_size=16 + ) + + # Create a dummy input tensor with shape (B, C, H, W) + dummy_input = torch.randn(1, 768, 14, 14) # Example input + + # Run dummy forward pass + output = linear_prediction_head(PredictionHeadInput(last_feature=dummy_input)) diff --git a/UniCeption/uniception/models/prediction_heads/mlp_head.py b/UniCeption/uniception/models/prediction_heads/mlp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..a64583f1c85c9b038e4bc236e9b03ab6c5302063 --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/mlp_head.py @@ -0,0 +1,114 @@ +""" +MLP head implementation +Downstream heads that coverts a batch of tokens to target representation. +Assumes inputs of size BC (B: batch, C: Channels) +""" + +import torch +import torch.nn as nn + +from uniception.models.prediction_heads.base import PredictionHeadTokenInput, SummaryTaskOutput + + +class MLPHead(nn.Module): + """ + MLP head implementation to convert tokens to target representation + """ + + def __init__( + self, + input_feature_dim: int, + output_dim: int, + num_mlp_layers: int = 2, + hidden_dim: int = 196, + pretrained_checkpoint_path: str = None, + *args, + **kwargs, + ): + """ + Initialize the MLP head. + + Args: + input_feature_dim (int): Input feature dimension. + num_mlp_layers (int): Number of MLP layers. + pretrained_checkpoint_path (str): Path to a pretrained checkpoint. + """ + super().__init__() + self.input_feature_dim = input_feature_dim + self.num_mlp_layers = num_mlp_layers + self.hidden_dim = hidden_dim + + # Initialize the input projection layer for the hidden dimension of the mlp head + self.proj = nn.Linear(self.input_feature_dim, hidden_dim) + + # Initialize the MLP layers + self.mlp = nn.ModuleList() + for _ in range(self.num_mlp_layers): + self.mlp.append(nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU())) + + # Initialize the output projection layer for the target representation + self.output_proj = nn.Linear(self.hidden_dim, output_dim) + + # Load the pretrained checkpoint if provided + if pretrained_checkpoint_path: + print(f"Loading pretrained mlp head from {pretrained_checkpoint_path}") + ckpt = torch.load(pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def forward(self, feature_input: PredictionHeadTokenInput): + """ + Forward interface for the mlp head. + Adapter can be used on output to achieve different types of scaling (linear, log, exp, etc). + + Args: + feature_input : PredictionHeadTokenInput, the input feature tokens + - last_feature : torch.Tensor, the last feature tensor + + Returns: + SummaryTaskOutput, the output of the mlp head + - decoded_channels : torch.Tensor, the decoded channels + """ + # Get the token features + feat = feature_input.last_feature # (B, C, T) + + # Check the input dimensions + assert feat.ndim == 3, f"Input feature tensor must have 3 dimensions (B, C, T), got {feat.ndim}" + assert ( + feat.shape[1] == self.input_feature_dim + ), f"Input feature dimension {feat.shape[1]} does not match expected dimension {self.input_feature_dim}" + + # Apply the projection layer + feat = feat.permute(0, 2, 1) # (B, T, C) + feat = self.proj(feat) # (B, hidden_dim) + + # Apply the MLP layers + for layer in self.mlp: + feat = layer(feat) + + # Apply the output projection layer + output = self.output_proj(feat) + output = output.permute(0, 2, 1) # (B, C, T) + + return SummaryTaskOutput(decoded_channels=output) + + +if __name__ == "__main__": + # Init an example MLP head + mlp_head = MLPHead( + input_feature_dim=768, + output_dim=1, + num_mlp_layers=2, + hidden_dim=196, + pretrained_checkpoint_path=None, + ) + + # Create a dummy input tensor with shape (B, C, T) + dummy_input = torch.randn(4, 768, 3) # Example batch of 4 with 768 features + + # Run dummy forward pass + output = mlp_head(PredictionHeadTokenInput(last_feature=dummy_input)) + + # Check the output shape + assert output.decoded_channels.shape == (4, 1, 3), "Output shape mismatch" + + print("MLP head test passed!") diff --git a/UniCeption/uniception/models/prediction_heads/moge_conv.py b/UniCeption/uniception/models/prediction_heads/moge_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7b9d9931b417f9edc114bacd1912edf5228acd --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/moge_conv.py @@ -0,0 +1,342 @@ +""" +MoGe Conv Decoder Implementation +References: https://github.com/microsoft/MoGe/blob/main/moge/model/v1.py +""" + +from typing import List, Literal, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint + +from uniception.models.prediction_heads.base import PixelTaskOutput, PredictionHeadLayeredInput + + +class ResidualConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + hidden_channels: Optional[int] = None, + padding_mode: str = "replicate", + activation: Literal["relu", "leaky_relu", "silu", "elu"] = "relu", + norm: Literal["group_norm", "layer_norm"] = "group_norm", + ): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation == "relu": + activation_cls = lambda: nn.ReLU(inplace=True) + elif activation == "leaky_relu": + activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True) + elif activation == "silu": + activation_cls = lambda: nn.SiLU(inplace=True) + elif activation == "elu": + activation_cls = lambda: nn.ELU(inplace=True) + else: + raise ValueError(f"Unsupported activation function: {activation}") + + self.layers = nn.Sequential( + nn.GroupNorm(1, in_channels), + activation_cls(), + nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode), + nn.GroupNorm(hidden_channels // 32 if norm == "group_norm" else 1, hidden_channels), + activation_cls(), + nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode), + ) + + self.skip_connection = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +def normalized_view_plane_uv( + width: int, + height: int, + aspect_ratio: Optional[float] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, +) -> torch.Tensor: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio**2) ** 0.5 + span_y = 1 / (1 + aspect_ratio**2) ** 0.5 + + u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device) + v = torch.linspace( + -span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device + ) + u, v = torch.meshgrid(u, v, indexing="xy") + uv = torch.stack([u, v], dim=-1) + return uv + + +class MoGeConvFeature(nn.Module): + def __init__( + self, + patch_size: int, + # MoGe parameters + num_features: int, + input_feature_dims: Union[int, List[int]], + dim_out: List[int], + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 64], + dim_times_res_block_hidden: int = 2, + num_res_blocks: int = 2, + res_block_norm: Literal["group_norm", "layer_norm"] = "group_norm", + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + # UniCeption parameters + pretrained_checkpoint_path: Optional[str] = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.patch_size = patch_size + if isinstance(input_feature_dims, int): + input_feature_dims = [input_feature_dims] * num_features + self.input_feature_dims = input_feature_dims + + self.projects = nn.ModuleList( + [ + nn.Conv2d( + in_channels=self.input_feature_dims[i], + out_channels=dim_proj, + kernel_size=1, + stride=1, + padding=0, + ) + for i in range(num_features) + ] + ) + + self.upsample_blocks = nn.ModuleList( + [ + nn.Sequential( + self._make_upsampler(in_ch + 2, out_ch), + *( + ResidualConvBlock( + out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm + ) + for _ in range(num_res_blocks) + ), + ) + for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample) + ] + ) + + self.output_block = nn.ModuleList( + [ + self._make_output_block( + dim_upsample[-1] + 2, + dim_out_, + dim_times_res_block_hidden, + last_res_blocks, + last_conv_channels, + last_conv_size, + res_block_norm, + ) + for dim_out_ in dim_out + ] + ) + + self.pretrained_checkpoint_path = pretrained_checkpoint_path + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained DPT dense feature head from {self.pretrained_checkpoint_path}") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def _make_upsampler(self, in_channels: int, out_channels: int): + upsampler = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode="replicate"), + ) + upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1] + return upsampler + + def _make_output_block( + self, + dim_in: int, + dim_out: int, + dim_times_res_block_hidden: int, + last_res_blocks: int, + last_conv_channels: int, + last_conv_size: int, + res_block_norm: Literal["group_norm", "layer_norm"], + ): + return nn.Sequential( + nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode="replicate"), + *( + ResidualConvBlock( + last_conv_channels, + last_conv_channels, + dim_times_res_block_hidden * last_conv_channels, + activation="relu", + norm=res_block_norm, + ) + for _ in range(last_res_blocks) + ), + nn.ReLU(inplace=True), + nn.Conv2d( + last_conv_channels, + dim_out, + kernel_size=last_conv_size, + stride=1, + padding=last_conv_size // 2, + padding_mode="replicate", + ), + ) + + # @torch.compile(fullgraph=True, options={}, dynamic=True) + def forward(self, head_input: PredictionHeadLayeredInput) -> PixelTaskOutput: + img_h, img_w = head_input.target_output_shape + patch_h, patch_w = img_h // self.patch_size, img_w // self.patch_size + + # Process the hidden states + x: torch.Tensor = torch.stack( + [proj(feat.contiguous()) for proj, feat in zip(self.projects, head_input.list_features)], dim=1 + ).sum(dim=1) + + # Upsample stage + # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8) + for i, block in enumerate(self.upsample_blocks): + # UV coordinates is for awareness of image aspect ratio + uv = normalized_view_plane_uv( + width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device + ) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + for layer in block: + x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) + + # (patch_h * 8, patch_w * 8) -> (img_h, img_w) + x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) + uv = normalized_view_plane_uv( + width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device + ) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + + if isinstance(self.output_block, nn.ModuleList): + output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block] + else: + raise NotImplementedError() + + return PixelTaskOutput(decoded_channels=torch.cat(output, dim=1)) + + +if __name__ == "__main__": + import time + + import numpy as np + import torch.cuda.profiler as profiler + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # Ensure the model is on GPU + num_runs = 20 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Instantiate the model and move to GPU + head = MoGeConvFeature( + patch_size=14, + num_features=4, + input_feature_dims=[1024, 768, 768, 768], + dim_out=[2, 1], + dim_proj=512, + dim_upsample=[256, 128, 64], + dim_times_res_block_hidden=2, + num_res_blocks=2, + res_block_norm="group_norm", + last_res_blocks=0, + last_conv_channels=32, + last_conv_size=1, + pretrained_checkpoint_path=None, + ).to(device) + + # Define input shape + image_shape = (560, 420) + batch_size = 10 + patch_size = 14 + patch_num = (image_shape[0] // patch_size, image_shape[1] // patch_size) + + # Generate input features and move to GPU + input_feats = [ + torch.randn(batch_size, dim, *patch_num, device=device, requires_grad=True) for dim in [1024, 768, 768, 768] + ] + + # Wrap input into PredictionHeadLayeredInput + model_input = PredictionHeadLayeredInput(list_features=input_feats, target_output_shape=image_shape) + + with torch.autocast("cuda", dtype=torch.float16): + # Warm-up to stabilize GPU performance + for _ in range(3): + output = head(model_input) + output.decoded_channels.sum().backward() + torch.cuda.synchronize() + + # Clear memory cache + torch.cuda.empty_cache() + + # Lists to store results + forward_times = [] + backward_times = [] + memory_usages = [] + + for _ in range(num_runs): + # Start measuring time + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # Reset memory stats + torch.cuda.reset_peak_memory_stats() + memory_before = torch.cuda.max_memory_allocated(device) + + # Forward pass + start_event.record() + output = head(model_input) + end_event.record() + torch.cuda.synchronize() + forward_time = start_event.elapsed_time(end_event) # Time in milliseconds + + # Backward pass + start_event.record() + output.decoded_channels.sum().backward() + end_event.record() + torch.cuda.synchronize() + backward_time = start_event.elapsed_time(end_event) + + # Memory usage + memory_after = torch.cuda.max_memory_allocated(device) + peak_memory = memory_after - memory_before + + forward_times.append(forward_time) + backward_times.append(backward_time) + memory_usages.append(peak_memory / 1e6) # Convert to MB + + # Compute mean and standard deviation + fwd_mean, fwd_std = np.mean(forward_times), np.std(forward_times) + bwd_mean, bwd_std = np.mean(backward_times), np.std(backward_times) + mem_mean, mem_std = np.mean(memory_usages), np.std(memory_usages) + + print(f"Forward Pass Time: {fwd_mean:.2f} ± {fwd_std:.2f} ms") + print(f"Backward Pass Time: {bwd_mean:.2f} ± {bwd_std:.2f} ms") + print(f"Peak GPU Memory Usage: {mem_mean:.2f} ± {mem_std:.2f} MB") diff --git a/UniCeption/uniception/models/prediction_heads/pose_head.py b/UniCeption/uniception/models/prediction_heads/pose_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1b52444bd9d427741eb735595a92852026988c0d --- /dev/null +++ b/UniCeption/uniception/models/prediction_heads/pose_head.py @@ -0,0 +1,181 @@ +""" +Pose head implementation +Downstream heads assume inputs of size BCHW (B: batch, C: channels, H: height, W: width); +The Pose head implementation is based on Reloc3r and MaRePo +References: +https://github.com/ffrivera0/reloc3r/blob/main/reloc3r/pose_head.py +""" + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from uniception.models.prediction_heads.base import PredictionHeadInput, SummaryTaskOutput + + +class ResConvBlock(nn.Module): + """ + 1x1 convolution residual block implementation based on Reloc3r & MaRePo + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + *args, + **kwargs, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.head_skip = ( + nn.Identity() + if self.in_channels == self.out_channels + else nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0) + ) + self.res_conv1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0) + self.res_conv2 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0) + self.res_conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0) + + def forward(self, res): + x = F.relu(self.res_conv1(res)) + x = F.relu(self.res_conv2(x)) + x = F.relu(self.res_conv3(x)) + res = self.head_skip(res) + x + return res + + +class PoseHead(nn.Module): + """ + Pose regression head implementation based on Reloc3r & MaRePo + """ + + def __init__( + self, + patch_size: int, + input_feature_dim: int, + num_resconv_block: int = 2, + rot_representation_dim: int = 4, + pretrained_checkpoint_path: str = None, + *args, + **kwargs, + ): + """ + Initialize the pose head. + + Args: + patch_size : int, the patch size of the transformer used to generate the input features + input_feature_dim : int, the input feature dimension + num_resconv_block : int, the number of residual convolution blocks + rot_representation_dim : int, the dimension of the rotation representation + pretrained_checkpoint_path : str, path to pretrained checkpoint (default: None) + """ + super().__init__() + self.patch_size = patch_size + self.input_feature_dim = input_feature_dim + self.num_resconv_block = num_resconv_block + self.rot_representation_dim = rot_representation_dim + self.pretrained_checkpoint_path = pretrained_checkpoint_path + + # Initialize the hidden dimension of the pose head based on the patch size + self.output_dim = 4 * (self.patch_size**2) + + # Initialize the projection layer for the hidden dimension of the pose head + self.proj = nn.Conv2d( + in_channels=self.input_feature_dim, + out_channels=self.output_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + + # Initialize sequential layers of the pose head + self.res_conv = nn.ModuleList( + [copy.deepcopy(ResConvBlock(self.output_dim, self.output_dim)) for _ in range(self.num_resconv_block)] + ) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.more_mlps = nn.Sequential( + nn.Linear(self.output_dim, self.output_dim), + nn.ReLU(), + nn.Linear(self.output_dim, self.output_dim), + nn.ReLU(), + ) + self.fc_t = nn.Linear(self.output_dim, 3) + self.fc_rot = nn.Linear(self.output_dim, self.rot_representation_dim) + + # Load the pretrained checkpoint if provided + if self.pretrained_checkpoint_path is not None: + print(f"Loading pretrained pose head from {self.pretrained_checkpoint_path}") + ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False) + print(self.load_state_dict(ckpt["model"])) + + def forward(self, feature_input: PredictionHeadInput): + """ + Forward interface for the pose head. + The pose head requires an adapter on the final output to get the pose. + + Args: + feature_input : PredictionHeadInput, the input features + - last_feature : torch.Tensor, the last feature tensor + + Returns: + SummaryTaskOutput, the output of the pose head + - decoded_channels : torch.Tensor, the decoded channels + """ + # Get the patch-level features from the input + feat = feature_input.last_feature # (B, C, H, W) + + # Check the input dimensions + assert ( + feat.shape[1] == self.input_feature_dim + ), f"Input feature dimension {feat.shape[1]} does not match expected dimension {self.input_feature_dim}" + + # Apply the projection layer to the patch-level features + feat = self.proj(feat) # (B, PC, H, W) + + # Apply the residual convolution blocks to the projected features + for i in range(self.num_resconv_block): + feat = self.res_conv[i](feat) + + # Apply the average pooling layer to the residual convolution output + feat = self.avgpool(feat) # (B, PC, 1, 1) + + # Flatten the average pooled features + feat = feat.view(feat.size(0), -1) # (B, PC) + + # Apply the more MLPs to the flattened features + feat = self.more_mlps(feat) # (B, PC) + + # Apply the final linear layers to the more MLPs output + feat_t = self.fc_t(feat) # (B, 3) + feat_rot = self.fc_rot(feat) # (B, self.rot_representation_dim) + + # Concatenate the translation and rotation features + output_feat = torch.cat([feat_t, feat_rot], dim=1) # (B, 3 + self.rot_representation_dim + + return SummaryTaskOutput(decoded_channels=output_feat) + + +if __name__ == "__main__": + # Init an example pose head + pose_head = PoseHead( + patch_size=16, + input_feature_dim=768, + num_resconv_block=2, + rot_representation_dim=4, + pretrained_checkpoint_path=None, + ) + + # Create a dummy input tensor with shape (B, C, H, W) + dummy_input = torch.randn(1, 768, 14, 14) # Example input + + # Run dummy forward pass + output = pose_head(PredictionHeadInput(last_feature=dummy_input)) + + # Check the output shape + assert output.decoded_channels.shape == (1, 7), "Output shape mismatch" + + print("Pose head test passed!") diff --git a/UniCeption/uniception/models/utils/config.py b/UniCeption/uniception/models/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..5756680f3bf64e7751daae2cabaf85e29140b8b9 --- /dev/null +++ b/UniCeption/uniception/models/utils/config.py @@ -0,0 +1,34 @@ +""" +Model Utils Config +""" + +import os +import warnings + +import torch + +__all__ = ["use_fused_attn", "set_fused_attn"] + +# Use torch.scaled_dot_product_attention where possible +_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention") +if "UNICEPTION_FUSED_ATTN" in os.environ: + _USE_FUSED_ATTN = int(os.environ["UNICEPTION_FUSED_ATTN"]) +else: + _USE_FUSED_ATTN = 1 # 0 == off, 1 == on + + +def use_fused_attn() -> bool: + "Return whether to use torch.nn.functional.scaled_dot_product_attention" + return _USE_FUSED_ATTN > 0 + + +def set_fused_attn(enable: bool = True): + "Set whether to use torch.nn.functional.scaled_dot_product_attention" + global _USE_FUSED_ATTN + if not _HAS_FUSED_ATTN: + warnings.warn("This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.") + return + if enable: + _USE_FUSED_ATTN = 1 + else: + _USE_FUSED_ATTN = 0 diff --git a/UniCeption/uniception/models/utils/intermediate_feature_return.py b/UniCeption/uniception/models/utils/intermediate_feature_return.py new file mode 100644 index 0000000000000000000000000000000000000000..29a8b523b7d7fc255c69d065e47d2b24b4cceeee --- /dev/null +++ b/UniCeption/uniception/models/utils/intermediate_feature_return.py @@ -0,0 +1,97 @@ +""" +Utils for Intermediate Feature Returner +References: +HuggingFace PyTorch Image Models (Timm) +""" + +from typing import List, Optional, Tuple, Union + +import torch + +try: + from torch import _assert +except ImportError: + + def _assert(condition: bool, message: str): + assert condition, message + + +class IntermediateFeatureReturner: + "Intermediate Feature Returner Class" + + def __init__( + self, + indices: Optional[Union[int, List[int]]] = None, + norm_intermediate: bool = True, + stop_early: bool = False, + intermediates_only: bool = True, + ): + """ + Init class for returning intermediate features from the encoder. + + Args: + indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to None. Options: + - None: Return all intermediate layers. + - int: Return the last n layers. + - List[int]: Return the intermediate layers at the specified indices. + norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True. + stop_early (bool, optional): Whether to stop early. Defaults to False. + intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True. + """ + self.indices = indices + self.norm_intermediate = norm_intermediate + self.stop_early = stop_early + self.intermediates_only = intermediates_only + + +def feature_take_indices( + num_features: int, + indices: Optional[Union[int, List[int]]] = None, + as_set: bool = False, +) -> Tuple[List[int], int]: + """Determine the absolute feature indices to 'take' from. + + Note: This function can be called in forwar() so must be torchscript compatible, + which requires some incomplete typing and workaround hacks. + + Args: + num_features: total number of features to select from + indices: indices to select, + None -> select all + int -> select last n + list/tuple of int -> return specified (-ve indices specify from end) + as_set: return as a set + + Returns: + List (or set) of absolute (from beginning) indices, Maximum index + """ + if indices is None: + indices = num_features # all features if None + + if isinstance(indices, int): + # convert int -> last n indices + _assert(0 < indices <= num_features, f"last-n ({indices}) is out of range (1 to {num_features})") + take_indices = [num_features - indices + i for i in range(indices)] + else: + take_indices: List[int] = [] + for i in indices: + idx = num_features + i if i < 0 else i + _assert(0 <= idx < num_features, f"feature index {idx} is out of range (0 to {num_features - 1})") + take_indices.append(idx) + + if not torch.jit.is_scripting() and as_set: + return set(take_indices), max(take_indices) + + return take_indices, max(take_indices) + + +class FeatureWrapper: + def __init__(self, tensor): + self.tensor = tensor + self.features = getattr(tensor, "features", tensor) + + def __getattr__(self, attr): + return getattr(self.tensor, attr) + + def __getitem__(self, idx): + return self.tensor[idx] diff --git a/UniCeption/uniception/models/utils/positional_encoding.py b/UniCeption/uniception/models/utils/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..1474d9a03220dbfac8ab0cd9fa77a13002e1ad44 --- /dev/null +++ b/UniCeption/uniception/models/utils/positional_encoding.py @@ -0,0 +1,23 @@ +""" +Helper function for positional encoding in UniCeption +""" + +import torch + + +class PositionGetter(object): + "Helper class to return positions of patches." + + def __init__(self): + "Initialize the position getter." + self.cache_positions = {} + + def __call__(self, b, h, w, device): + "Get the positions for a given batch size, height, and width. Uses caching." + if not (h, w) in self.cache_positions: + x = torch.arange(w, device=device) + y = torch.arange(h, device=device) + self.cache_positions[h, w] = torch.cartesian_prod(y, x) # (h, w, 2) + pos = self.cache_positions[h, w].view(1, h * w, 2).expand(b, -1, 2).clone() + + return pos diff --git a/UniCeption/uniception/models/utils/transformer_blocks.py b/UniCeption/uniception/models/utils/transformer_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..0983f556b0496bcd273b4574fb0acacfafe95fb2 --- /dev/null +++ b/UniCeption/uniception/models/utils/transformer_blocks.py @@ -0,0 +1,964 @@ +""" +Utils for Common Transformer Blocks used in UniCeption +References: +HuggingFace PyTorch Image Models (Timm) +CroCoV2 +""" + +import collections.abc +import math +from itertools import repeat +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.jit import Final + +from uniception.models.utils.config import use_fused_attn + +torch.backends.cuda.matmul.allow_tf32 = True + + +def _ntuple(n): + "Helper function to create n-tuple." + + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.0): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + + return x + + +class Attention(nn.Module): + "Self-Attention Layer" + + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + latent_attn_dim: Optional[int] = None, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + custom_positional_encoding: Callable = None, + ): + """ + Initialize the Attention layer. + + Args: + dim (int): Dimension of input features + latent_attn_dim (int): Dimension of latent attention features (default: None) + num_heads (int): Number of attention heads (default: 8) + qkv_bias (bool): Whether to include bias in qkv projection (default: False) + qk_norm (bool): Whether to normalize q and k (default: False) + attn_drop (float): Dropout rate for attention weights (default: 0.) + proj_drop (float): Dropout rate for output (default: 0.) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + """ + super().__init__() + + if latent_attn_dim is not None: + assert latent_attn_dim % num_heads == 0, "latent_attn_dim should be divisible by num_heads" + self.latent_attn_dim = latent_attn_dim + self.latent_attn = True + else: + self.latent_attn = False + assert dim % num_heads == 0, "dim should be divisible by num_heads" + + self.num_heads = num_heads + self.head_dim = dim // num_heads if not self.latent_attn else latent_attn_dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = use_fused_attn() + + self.qkv = ( + nn.Linear(dim, dim * 3, bias=qkv_bias) + if not self.latent_attn + else nn.Linear(dim, latent_attn_dim * 3, bias=qkv_bias) + ) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) if not self.latent_attn else nn.Linear(latent_attn_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.custom_positional_encoding = custom_positional_encoding + + def forward(self, x: torch.Tensor, xpos: torch.Tensor = None) -> torch.Tensor: + """ + Forward pass of the Attention layer. + + Args: + x (torch.Tensor): Input features + xpos (torch.Tensor): Positions of tokens (required when using custom positional encoding) + + Returns: + torch.Tensor: Output features of same shape as input + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q, k = self.q_norm(q), self.k_norm(k) + + if self.custom_positional_encoding is not None: + assert ( + xpos is not None + ), "Positions of tokens (xpos) are a required input when using custom positional encoding" + q = self.custom_positional_encoding(q, xpos) + k = self.custom_positional_encoding(k, xpos) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=(self.attn_drop.p if self.training else 0.0), scale=self.scale + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossAttention(nn.Module): + "Cross-Attention Layer" + + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + custom_positional_encoding: Callable = None, + ): + """ + Initialize the Cross-Attention layer. + + Args: + dim (int): Dimension of input features + num_heads (int): Number of attention heads (default: 8) + qkv_bias (bool): Whether to include bias in qkv projection (default: False) + qk_norm (bool): Whether to normalize q and k (default: False) + attn_drop (float): Dropout rate for attention weights (default: 0.) + proj_drop (float): Dropout rate for output (default: 0.) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + """ + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = use_fused_attn() + + self.projq = nn.Linear(dim, dim, bias=qkv_bias) + self.projk = nn.Linear(dim, dim, bias=qkv_bias) + self.projv = nn.Linear(dim, dim, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.custom_positional_encoding = custom_positional_encoding + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + qpos: torch.Tensor = None, + kpos: torch.Tensor = None, + ) -> torch.Tensor: + """ + Forward pass of the Cross-Attention layer. + + Args: + query (torch.Tensor): Query features + key (torch.Tensor): Key features + value (torch.Tensor): Value features + qpos (torch.Tensor): Positions of queries (required when using custom positional encoding) + kpos (torch.Tensor): Positions of keys (required when using custom positional encoding) + + Returns: + torch.Tensor: Output features of same shape as input + """ + B, Nq, C = query.shape + Nk = key.shape[1] + Nv = value.shape[1] + + q = self.projq(query).reshape(B, Nq, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + k = self.projk(key).reshape(B, Nk, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + v = self.projv(value).reshape(B, Nv, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + q, k = self.q_norm(q), self.k_norm(k) + + if self.custom_positional_encoding is not None: + assert ( + qpos is not None + ), "Positions of queries (qpos) are a required input when using custom positional encoding" + assert ( + kpos is not None + ), "Positions of keys (kpos) are a required input when using custom positional encoding" + q = self.custom_positional_encoding(q, qpos) + k = self.custom_positional_encoding(k, kpos) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=(self.attn_drop.p if self.training else 0.0), scale=self.scale + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, Nq, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + "Layer Scale Layer" + + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ): + """ + Initialize the Layer Scale layer + + Args: + dim (int): Dimension of input features + init_values (float): Initial value for LayerScale gamma (default: 1e-5) + inplace (bool): Whether to perform inplace operations (default: False) + """ + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + "Forward pass of the Layer Scale layer" + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class SelfAttentionBlock(nn.Module): + "Self-Attention Block" + + def __init__( + self, + dim: int, + num_heads: int, + latent_attn_dim: Optional[int] = None, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + custom_positional_encoding: Callable = None, + ): + """ + Initialize the Self-Attention Block. + + Args: + dim (int): Dimension of input features + num_heads (int): Number of attention heads + mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.) + qkv_bias (bool): Whether to include bias in qkv projection (default: False) + qk_norm (bool): Whether to normalize q and k (default: False) + proj_drop (float): Dropout rate for output (default: 0.) + attn_drop (float): Dropout rate for attention weights (default: 0.) + init_values (float): Initial value for LayerScale gamma (default: None) + drop_path (float): Dropout rate for stochastic depth (default: 0.) + act_layer (nn.Module): Activation layer (default: nn.GELU) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + mlp_layer (nn.Module): MLP layer (default: Mlp) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + latent_attn_dim=latent_attn_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + custom_positional_encoding=custom_positional_encoding, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.custom_positional_encoding = custom_positional_encoding + + def forward(self, x: torch.Tensor, xpos: torch.Tensor = None) -> torch.Tensor: + """ + Forward pass of the Self-Attention Block. + + Args: + x (torch.Tensor): Input features + xpos (torch.Tensor): Positions of tokens (required when using custom positional encoding) + + Returns: + torch.Tensor: Output features of same shape as input + """ + if self.custom_positional_encoding is not None: + assert ( + xpos is not None + ), "Positions of tokens (xpos) are a required input when using custom positional encoding" + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), xpos))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class CrossAttentionBlock(nn.Module): + "Cross-Attention Block" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + custom_positional_encoding: Callable = None, + norm_cross_tokens: bool = True, + ): + """ + Initialize the Cross-Attention Block. + + Args: + dim (int): Dimension of input features + num_heads (int): Number of attention heads + mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.) + qkv_bias (bool): Whether to include bias in qkv projection (default: False) + qk_norm (bool): Whether to normalize q and k (default: False) + proj_drop (float): Dropout rate for output (default: 0.) + attn_drop (float): Dropout rate for attention weights (default: 0.) + init_values (float): Initial value for LayerScale gamma (default: None) + drop_path (float): Dropout rate for stochastic depth (default: 0.) + act_layer (nn.Module): Activation layer (default: nn.GELU) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + mlp_layer (nn.Module): MLP layer (default: Mlp) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + norm_cross_tokens (bool): Whether to normalize cross tokens (default: True) + + Returns: + torch.Tensor: Output features of same shape as input + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + custom_positional_encoding=custom_positional_encoding, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm_y = norm_layer(dim) if norm_cross_tokens else nn.Identity() + self.custom_positional_encoding = custom_positional_encoding + self.norm2 = norm_layer(dim) + self.cross_attn = CrossAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + custom_positional_encoding=custom_positional_encoding, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm3 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls3 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path3 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor, + xpos: torch.Tensor = None, + ypos: torch.Tensor = None, + ) -> torch.Tensor: + """ + Forward pass of the Cross-Attention Block. + + Args: + x (torch.Tensor): Input features + y (torch.Tensor): Cross features + xpos (torch.Tensor): Positions of tokens (required when using custom positional encoding) + ypos (torch.Tensor): Positions of cross tokens (required when using custom positional encoding) + + Returns: + torch.Tensor: Output features of same shape as input + """ + if self.custom_positional_encoding is not None: + assert ( + xpos is not None + ), "Positions of tokens (xpos) are a required input when using custom positional encoding" + assert ( + ypos is not None + ), "Positions of cross tokens (ypos) are a required input when using custom positional encoding" + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), xpos))) + y_ = self.norm_y(y) + x = x + self.drop_path2(self.ls2(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))) + x = x + self.drop_path3(self.ls3(self.mlp(self.norm3(x)))) + return x + + +def dummy_positional_encoding(x, xpos): + "Dummy function for positional encoding of tokens" + x = x + xpos = xpos + return x + + +# copied from DiffTrsformer +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False): + super().__init__() + self.dim = dim + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_parameter("weight", None) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if self.weight is not None: + output = output * self.weight + return output + + def extra_repr(self) -> str: + return f"dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}" + + +def lambda_init_fn(depth): + return 0.8 - 0.6 * math.exp(-0.3 * depth) # copied from DiffTrsformer + + +class DiffAttention(nn.Module): + "Differential Self-Attention Layer" + + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + depth: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + custom_positional_encoding: Callable = None, + ): + """ + Initialize the DiffAttention layer. + + Args: + dim (int): Dimension of input features + depth (int): Depth of the current layer, used in lambda initialization (default: 0) + num_heads (int): Number of attention heads (default: 8) + qkv_bias (bool): Whether to include bias in qkv projection (default: False) + qk_norm (bool): Whether to normalize q and k (default: False) + attn_drop (float): Dropout rate for attention weights (default: 0.) + proj_drop (float): Dropout rate for output (default: 0.) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + """ + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads // 2 + self.scale = self.head_dim**-0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.custom_positional_encoding = custom_positional_encoding + + # DiffTransformer specific + self.lambda_init = lambda_init_fn(depth) + self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) + self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) + self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) + self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) + + self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True) + + def forward(self, x: torch.Tensor, xpos: torch.Tensor = None) -> torch.Tensor: + """ + Forward pass of the Attention layer. + + Args: + x (torch.Tensor): Input features + xpos (torch.Tensor): Positions of tokens (required when using custom positional encoding) + + Returns: + torch.Tensor: Output features of same shape as input + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim * 2) + q, k, v = torch.chunk(qkv, 3, dim=2) # B, N, Nh, Dh + + q = q.view(B, N, 2 * self.num_heads, self.head_dim).permute(0, 2, 1, 3) + k = k.view(B, N, 2 * self.num_heads, self.head_dim).permute(0, 2, 1, 3) + v = v.view(B, N, self.num_heads, 2 * self.head_dim).permute(0, 2, 1, 3) + + q, k = self.q_norm(q), self.k_norm(k) + + if self.custom_positional_encoding is not None: + assert ( + xpos is not None + ), "Positions of tokens (xpos) are a required input when using custom positional encoding" + q = self.custom_positional_encoding(q, xpos) + k = self.custom_positional_encoding(k, xpos) + + q1, q2 = q.chunk(2, dim=1) # split heads dimension into two + k1, k2 = k.chunk(2, dim=1) # split heads dimension into two + + if self.fused_attn: + attn1 = F.scaled_dot_product_attention( + q1, k1, v, dropout_p=(self.attn_drop.p if self.training else 0.0), scale=self.scale + ) + attn2 = F.scaled_dot_product_attention( + q2, k2, v, dropout_p=(self.attn_drop.p if self.training else 0.0), scale=self.scale + ) + else: + q1 = q1 * self.scale + attn = q1 @ k1.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn1 = attn @ v + + q2 = q2 * self.scale + attn = q2 @ k2.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn2 = attn @ v + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + attn = attn1 - lambda_full * attn2 + + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + attn = attn.reshape(B, N, self.num_heads * 2 * self.head_dim) + + x = self.proj(attn) + x = self.proj_drop(x) + return x + + +class DiffCrossAttention(nn.Module): + "Differential Cross-Attention Layer, following https://arxiv.org/abs/2410.05258" + + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + depth: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + custom_positional_encoding: Callable = None, + ): + """ + Initialize the Cross-Attention layer. + + Args: + dim (int): Dimension of input features + depth (int): Depth of the current layer, used in lambda initialization (default: 0) + num_heads (int): Number of attention heads (default: 8) + qkv_bias (bool): Whether to include bias in qkv projection (default: False) + qk_norm (bool): Whether to normalize q and k (default: False) + attn_drop (float): Dropout rate for attention weights (default: 0.) + proj_drop (float): Dropout rate for output (default: 0.) + norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm) + custom_positional_encoding (Callable): Custom positional encoding function (default: None) + """ + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads // 2 + self.scale = self.head_dim**-0.5 + self.fused_attn = use_fused_attn() + + self.projq = nn.Linear(dim, dim, bias=qkv_bias) + self.projk = nn.Linear(dim, dim, bias=qkv_bias) + self.projv = nn.Linear(dim, dim, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + # DiffTransformer specific + self.lambda_init = lambda_init_fn(depth) + self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) + self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) + self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) + self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)) + + self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True) + + self.custom_positional_encoding = custom_positional_encoding + + def lambda_init_fn(self, depth): + return 0.8 - 0.6 * math.exp(-0.3 * depth) # copied from DiffTrsformer + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + qpos: torch.Tensor = None, + kpos: torch.Tensor = None, + ) -> torch.Tensor: + """ + Forward pass of the Cross-Attention layer. + + Args: + query (torch.Tensor): Query features + key (torch.Tensor): Key features + value (torch.Tensor): Value features + qpos (torch.Tensor): Positions of queries (required when using custom positional encoding) + kpos (torch.Tensor): Positions of keys (required when using custom positional encoding) + + Returns: + torch.Tensor: Output features of same shape as input + """ + B, Nq, C = query.shape + Nk = key.shape[1] + Nv = value.shape[1] + + q = self.projq(query).reshape(B, Nq, 2 * self.num_heads, self.head_dim).permute(0, 2, 1, 3) + k = self.projk(key).reshape(B, Nk, 2 * self.num_heads, self.head_dim).permute(0, 2, 1, 3) + v = self.projv(value).reshape(B, Nv, self.num_heads, 2 * self.head_dim).permute(0, 2, 1, 3) + q, k = self.q_norm(q), self.k_norm(k) + + if self.custom_positional_encoding is not None: + assert ( + qpos is not None + ), "Positions of queries (qpos) are a required input when using custom positional encoding" + assert ( + kpos is not None + ), "Positions of keys (kpos) are a required input when using custom positional encoding" + q = self.custom_positional_encoding(q, qpos) + k = self.custom_positional_encoding(k, kpos) + + q1, q2 = q.chunk(2, dim=1) # split heads dimension into two + k1, k2 = k.chunk(2, dim=1) # split heads dimension into two + + if self.fused_attn: + attn1 = F.scaled_dot_product_attention( + q1, k1, v, dropout_p=(self.attn_drop.p if self.training else 0.0), scale=self.scale + ) + attn2 = F.scaled_dot_product_attention( + q2, k2, v, dropout_p=(self.attn_drop.p if self.training else 0.0), scale=self.scale + ) + else: + q1 = q1 * self.scale + attn = q1 @ k1.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn1 = attn @ v + + q2 = q2 * self.scale + attn = q2 @ k2.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn2 = attn @ v + + attn1 = attn1.transpose(1, 2) # B, Nq, Nh, Dh + attn2 = attn2.transpose(1, 2) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + attn = attn1 - lambda_full * attn2 + + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + attn = attn.reshape(B, Nq, self.num_heads * 2 * self.head_dim) + + x = self.proj(attn) + x = self.proj_drop(x) + return x + + +class DiffSelfAttentionBlock(SelfAttentionBlock): + "Differential Self-Attention Block" + + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + custom_positional_encoding: Callable = None, + ): + super().__init__( + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_drop=proj_drop, + attn_drop=attn_drop, + init_values=init_values, + drop_path=drop_path, + act_layer=act_layer, + norm_layer=norm_layer, + mlp_layer=mlp_layer, + custom_positional_encoding=custom_positional_encoding, + ) + + self.attn = DiffAttention( + dim, + depth, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + custom_positional_encoding=custom_positional_encoding, + ) + + +class DiffCrossAttentionBlock(CrossAttentionBlock): + "Differential Cross-Attention Block" + + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + custom_positional_encoding: Callable = None, + norm_cross_tokens: bool = True, + ): + super().__init__( + dim=dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_drop=proj_drop, + attn_drop=attn_drop, + init_values=init_values, + drop_path=drop_path, + act_layer=act_layer, + norm_layer=norm_layer, + mlp_layer=mlp_layer, + custom_positional_encoding=custom_positional_encoding, + norm_cross_tokens=norm_cross_tokens, + ) + + self.cross_attn = DiffCrossAttention( + dim, + depth, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + custom_positional_encoding=custom_positional_encoding, + ) + + +if __name__ == "__main__": + # Init Attention & CrossAttention classes + self_attn = Attention(dim=768, custom_positional_encoding=dummy_positional_encoding) + cross_attn = CrossAttention(dim=768, custom_positional_encoding=dummy_positional_encoding) + + # Perform dummy inference with the Attention classes + dummy_input = torch.randn((1, 256, 768)) + dummy_x = torch.arange(16) + dummy_y = torch.arange(16) + dummy_xpos = torch.cartesian_prod(dummy_y, dummy_x).view(1, 256, 2).expand(1, -1, 2).clone() + self_attn_output = self_attn(dummy_input, dummy_xpos) + cross_attn_output = cross_attn(dummy_input, dummy_input, dummy_input, dummy_xpos, dummy_xpos) + print("Init of Attention & CrossAttention classes is successful!") + + # Init SelfAttentionBlock & CrossAttentionBlock + self_attn_block = SelfAttentionBlock(dim=768, num_heads=16, custom_positional_encoding=dummy_positional_encoding) + cross_attn_block = CrossAttentionBlock(dim=768, num_heads=16, custom_positional_encoding=dummy_positional_encoding) + + # Perform dummy inference with the Attention blocks + self_attn_block_output = self_attn_block(dummy_input, dummy_xpos) + cross_attn_block_output = cross_attn_block(dummy_input, dummy_input, dummy_xpos, dummy_xpos) + print("Init of SelfAttentionBlock & CrossAttentionBlock is successful!") + + # Init DiffAttention & DiffCrossAttention classes + diff_self_attn = DiffAttention(dim=768, depth=0, custom_positional_encoding=dummy_positional_encoding) + diff_cross_attn = DiffCrossAttention(dim=768, depth=0, custom_positional_encoding=dummy_positional_encoding) + + # Perform dummy inference with the DiffAttention classes + diff_self_attn_output = diff_self_attn(dummy_input, dummy_xpos) + diff_cross_attn_output = diff_cross_attn(dummy_input, dummy_input, dummy_input, dummy_xpos, dummy_xpos) + print("Init of DiffAttention & DiffCrossAttention classes is successful!") + + # Init DiffSelfAttentionBlock & DiffCrossAttentionBlock + diff_self_attn_block = DiffSelfAttentionBlock( + dim=768, depth=0, num_heads=8, custom_positional_encoding=dummy_positional_encoding + ) + diff_cross_attn_block = DiffCrossAttentionBlock( + dim=768, depth=0, num_heads=8, custom_positional_encoding=dummy_positional_encoding + ) + + # Perform dummy inference with the DiffAttention blocks + diff_self_attn_block_output = diff_self_attn_block(dummy_input, dummy_xpos) + diff_cross_attn_block_output = diff_cross_attn_block(dummy_input, dummy_input, dummy_xpos, dummy_xpos) + + print("Init of DiffSelfAttentionBlock & DiffCrossAttentionBlock is successful!") diff --git a/UniCeption/uniception/utils/profile.py b/UniCeption/uniception/utils/profile.py new file mode 100644 index 0000000000000000000000000000000000000000..300eb760f7ab92bb8f8f3d9ce557ac9b839ff53c --- /dev/null +++ b/UniCeption/uniception/utils/profile.py @@ -0,0 +1,13 @@ +import torch.utils.benchmark as benchmark + + +def benchmark_torch_function(f, *args, **kwargs): + t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}) + return t0.blocked_autorange().mean * 1e3 # Milliseconds + + +def benchmark_torch_function_with_result(f, *args, **kwargs): + result = f(*args, **kwargs) + t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}) + time_in_ms = t0.blocked_autorange().mean * 1e3 # Milliseconds + return time_in_ms, result diff --git a/UniCeption/uniception/utils/viz.py b/UniCeption/uniception/utils/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..2e7d934ef1840f460f04b60c82c14d9f7c3ca1bc --- /dev/null +++ b/UniCeption/uniception/utils/viz.py @@ -0,0 +1,99 @@ +""" +Utilitary functions for visualizations +""" + +from argparse import ArgumentParser, Namespace +from distutils.util import strtobool + + +def str2bool(v): + return bool(strtobool(v)) + + +def script_add_rerun_args(parser: ArgumentParser) -> None: + """ + Add common Rerun script arguments to `parser`. + + Change Log from https://github.com/rerun-io/rerun/blob/29eb8954b08e59ff96943dc0677f46f7ea4ea734/rerun_py/rerun_sdk/rerun/script_helpers.py#L65: + - Added default portforwarding url for ease of use + - Update parser types + + Parameters + ---------- + parser : ArgumentParser + The parser to add arguments to. + + Returns + ------- + None + """ + parser.add_argument("--headless", type=str2bool, nargs="?", const=True, default=True, help="Don't show GUI") + parser.add_argument( + "--connect", + dest="connect", + type=str2bool, + nargs="?", + const=True, + default=True, + help="Connect to an external viewer", + ) + parser.add_argument( + "--serve", + dest="serve", + type=str2bool, + nargs="?", + const=True, + default=False, + help="Serve a web viewer (WARNING: experimental feature)", + ) + parser.add_argument( + "--url", + type=str, + default="rerun+http://127.0.0.1:/proxy", + help="Connect to this HTTP(S) URL. Replace with the actual port number.", + ) + parser.add_argument("--save", type=str, default=None, help="Save data to a .rrd file at this path") + parser.add_argument( + "-o", + "--stdout", + dest="stdout", + action="store_true", + help="Log data to standard output, to be piped into a Rerun Viewer", + ) + + +def init_rerun_args( + headless=True, connect=True, serve=False, url="rerun+http://127.0.0.1:/proxy", save=None, stdout=False +) -> Namespace: + """ + Initialize common Rerun script arguments. + + Parameters + ---------- + headless : bool, optional + Don't show GUI, by default True + connect : bool, optional + Connect to an external viewer, by default True + serve : bool, optional + Serve a web viewer (WARNING: experimental feature), by default False + url : str, optional + Connect to this HTTP(S) URL, by default "rerun+http://127.0.0.1:/proxy". Replace with the actual port number. + save : str, optional + Save data to a .rrd file at this path, by default None + stdout : bool, optional + Log data to standard output, to be piped into a Rerun Viewer, by default False + + Returns + ------- + Namespace + The parsed arguments. + """ + rerun_args = Namespace() + rerun_args.headless = headless + rerun_args.connect = connect + rerun_args.serve = serve + rerun_args.url = url + rerun_args.save = save + rerun_args.stdout = stdout + + return rerun_args diff --git a/app.py b/app.py index 04cc31aa8d0e06aeaac3b59bb361ed71d831e43f..01675f867ba01d4b48dff19662bae239668b5a0c 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,207 @@ +import cv2 +import flow_vis import gradio as gr +import numpy as np +import torch +from PIL import Image -def greet(name): - return "Hello " + name + "!!" +import sys +sys.path.append("uniflowmatch/") +sys.path.append("UniCeption/uniception") -demo = gr.Interface(fn=greet, inputs="text", outputs="text") -demo.launch() + +from uniflowmatch.models.ufm import ( + UniFlowMatchClassificationRefinement, + UniFlowMatchConfidence, +) +from uniflowmatch.utils.viz import warp_image_with_flow + +# Global model variable +model = None +USE_REFINEMENT_MODEL = False + + +def initialize_model(use_refinement: bool = False): + """Initialize the model - call this once at startup""" + global model, USE_REFINEMENT_MODEL + USE_REFINEMENT_MODEL = use_refinement + + try: + if USE_REFINEMENT_MODEL: + print("Loading UFM Refinement model from infinity1096/UFM-Refine...") + model = UniFlowMatchClassificationRefinement.from_pretrained("infinity1096/UFM-Refine") + else: + print("Loading UFM Base model from infinity1096/UFM-Base...") + model = UniFlowMatchConfidence.from_pretrained("infinity1096/UFM-Base") + + # Set model to evaluation mode + if hasattr(model, "eval"): + model.eval() + + print("Model loaded successfully!") + return True + except Exception as e: + print(f"Error loading model: {e}") + return False + + +def process_images(source_image, target_image, model_type_choice): + """ + Process two uploaded images and return visualizations + """ + if source_image is None or target_image is None: + return None, None, None, "Please upload both images." + + # Reinitialize model if type has changed + current_refinement = model_type_choice == "Refinement Model" + if current_refinement != USE_REFINEMENT_MODEL: + print(f"Switching to {model_type_choice}...") + initialize_model(current_refinement) + + if model is None: + return None, None, None, "Model not loaded. Please restart the application." + + try: + # Convert PIL images to numpy arrays + source_np = np.array(source_image) + target_np = np.array(target_image) + + # Ensure images are RGB + if len(source_np.shape) == 3 and source_np.shape[2] == 3: + source_rgb = source_np + else: + source_rgb = cv2.cvtColor(source_np, cv2.COLOR_BGR2RGB) + + if len(target_np.shape) == 3 and target_np.shape[2] == 3: + target_rgb = target_np + else: + target_rgb = cv2.cvtColor(target_np, cv2.COLOR_BGR2RGB) + + print(f"Processing images with shapes: Source {source_rgb.shape}, Target {target_rgb.shape}") + + # === Predict Correspondences === + with torch.no_grad(): + result = model.predict_correspondences_batched( + source_image=torch.from_numpy(source_rgb), + target_image=torch.from_numpy(target_rgb), + ) + + # Extract results based on your model's output structure + flow_output = result.flow.flow_output[0].cpu().numpy() + covisibility = result.covisibility.mask[0].cpu().numpy() + + print(f"Flow output shape: {flow_output.shape}") + print(f"Covisibility shape: {covisibility.shape}") + + # === Create Visualizations === + + # 1. Flow visualization + flow_vis_image = flow_vis.flow_to_color(flow_output.transpose(1, 2, 0)) + flow_pil = Image.fromarray(flow_vis_image.astype(np.uint8)) + + # 2. Covisibility visualization - direct gray image + covisibility_gray = (covisibility * 255).astype(np.uint8) + covisibility_pil = Image.fromarray(covisibility_gray, mode="L") + + # 3. Warped image using actual warp function + warped_image = warp_image_with_flow(source_rgb, None, target_rgb, flow_output.transpose(1, 2, 0)) + warped_image = covisibility[..., None] * warped_image + (1 - covisibility[..., None]) * 255 * np.ones_like( + warped_image + ) + warped_image = (warped_image / 255.0).clip(0, 1) + warped_pil = Image.fromarray((warped_image * 255).astype(np.uint8)) + + status_msg = f"Processing completed with {model_type_choice}" + + return flow_pil, covisibility_pil, warped_pil, status_msg + + except Exception as e: + error_msg = f"Error processing images: {str(e)}" + print(error_msg) + return None, None, None, error_msg + + +def create_demo(): + """Create the Gradio interface""" + + with gr.Blocks(title="UniFlowMatch Demo") as demo: + gr.Markdown("# UniFlowMatch Demo") + gr.Markdown("Upload two images to see optical flow visualization") + + # Input section + with gr.Row(): + source_input = gr.Image(label="Source Image", type="pil") + target_input = gr.Image(label="Target Image", type="pil") + + # Model selection + model_type = gr.Radio(choices=["Base Model", "Refinement Model"], value="Base Model", label="Model Type") + + # Process button + process_btn = gr.Button("Process Images") + + # Status + status_output = gr.Textbox(label="Status", interactive=False) + + # Output section + with gr.Row(): + flow_output = gr.Image(label="Flow Visualization") + covisibility_output = gr.Image(label="Covisibility Mask") + warped_output = gr.Image(label="Warped Source Image") + + # Example images + gr.Examples( + examples=[ + ["examples/image_pairs/fire_academy_0.png", "examples/image_pairs/fire_academy_1.png"], + ["examples/image_pairs/scene_0.png", "examples/image_pairs/scene_1.png"], + ["examples/image_pairs/bike_0.png", "examples/image_pairs/bike_1.png"], + ["examples/image_pairs/cook_0.png", "examples/image_pairs/cook_1.png"], + ["examples/image_pairs/building_0.png", "examples/image_pairs/building_1.png"], + ], + inputs=[source_input, target_input], + label="Example Image Pairs", + ) + + # Event handlers + process_btn.click( + fn=process_images, + inputs=[source_input, target_input, model_type], + outputs=[flow_output, covisibility_output, warped_output, status_output], + ) + + # Auto-process when both images are uploaded + def auto_process(source, target, model_choice): + if source is not None and target is not None: + return process_images(source, target, model_choice) + return None, None, None, "Upload both images to start processing." + + for input_component in [source_input, target_input, model_type]: + input_component.change( + fn=auto_process, + inputs=[source_input, target_input, model_type], + outputs=[flow_output, covisibility_output, warped_output, status_output], + ) + + return demo + + +if __name__ == "__main__": + # Initialize model + print("Initializing UniFlowMatch model...") + model_loaded = initialize_model(use_refinement=False) # Start with base model + + if not model_loaded: + print("Error: Model failed to load. Please check your model installation and HuggingFace access.") + print("Make sure you have:") + print("1. Installed uniflowmatch package") + print("2. Have internet access for downloading pretrained models") + print("3. All required dependencies installed") + exit(1) + + # Create and launch demo + demo = create_demo() + demo.launch( + share=True, # Set to True to create a public link + server_name="0.0.0.0", # Allow external connections + server_port=7860, # Default Gradio port + show_error=True, + ) diff --git a/examples/example_ufm_output.png b/examples/example_ufm_output.png new file mode 100644 index 0000000000000000000000000000000000000000..bfd92a19ee4f51a89e44a56e3031ac7c8bab8f37 --- /dev/null +++ b/examples/example_ufm_output.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f7bd999d1f0bbec4bf8cf8aa0d712bdda52a4174ab603de4eda557f492dffa5 +size 455843 diff --git a/examples/image_pairs/bike_0.png b/examples/image_pairs/bike_0.png new file mode 100644 index 0000000000000000000000000000000000000000..2cf22b20fd9c76532a317368ae43a8abe41e2326 --- /dev/null +++ b/examples/image_pairs/bike_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7a6b85ecd43e1752faa4b131b948f0d46148cb020898ff06e40eeed68da0b09 +size 1551541 diff --git a/examples/image_pairs/bike_1.png b/examples/image_pairs/bike_1.png new file mode 100644 index 0000000000000000000000000000000000000000..cd10f8b7a3f2f07fc0b3b72c280b62be0cdb4302 --- /dev/null +++ b/examples/image_pairs/bike_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a987fa3f65de09b5847c70a195459ebb63c552dfb747e98aa19a26819048555 +size 987667 diff --git a/examples/image_pairs/building_0.png b/examples/image_pairs/building_0.png new file mode 100644 index 0000000000000000000000000000000000000000..79e8cfcc7499ef97b584831bf166dc78d83be302 --- /dev/null +++ b/examples/image_pairs/building_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a95ad3deaf3b5fcf7dc1591052e50aceaccb0fd2c842decb29f45b0c7c3d70a9 +size 1599006 diff --git a/examples/image_pairs/building_1.png b/examples/image_pairs/building_1.png new file mode 100644 index 0000000000000000000000000000000000000000..18804a1e565841f3890701c1d7d3ff25284f6837 --- /dev/null +++ b/examples/image_pairs/building_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdff780719d400fe5427780ad06cf1a23d2584c3e4fb3e3ca0cfa7756007bcc4 +size 1407132 diff --git a/examples/image_pairs/cook_0.png b/examples/image_pairs/cook_0.png new file mode 100644 index 0000000000000000000000000000000000000000..dfd116e20ac02318601687b459a6e56722a148be --- /dev/null +++ b/examples/image_pairs/cook_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1cc34812cc5241ca71031dc923d2993d9bb7f7e58f6a516c908a7dab52f7adc +size 927821 diff --git a/examples/image_pairs/cook_1.png b/examples/image_pairs/cook_1.png new file mode 100644 index 0000000000000000000000000000000000000000..14b34559ca4d1b1a32a8514d6bff82cc68746859 --- /dev/null +++ b/examples/image_pairs/cook_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd8ac578ea367dfb2c2b999a0f4b183706cad8c5e48476244256c66b8d08f55d +size 1529183 diff --git a/examples/image_pairs/fire_academy_0.png b/examples/image_pairs/fire_academy_0.png new file mode 100644 index 0000000000000000000000000000000000000000..46eafc60acae4298e7d63fab2bd050da30c2e75d --- /dev/null +++ b/examples/image_pairs/fire_academy_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84b15ace1cc5812f5bfda0156831ccc7279757eefb1a88f500c24e70dc33c3ae +size 1167965 diff --git a/examples/image_pairs/fire_academy_1.png b/examples/image_pairs/fire_academy_1.png new file mode 100644 index 0000000000000000000000000000000000000000..50e2f8d95aa0a1cc515f7357ec3808b10cc268d7 --- /dev/null +++ b/examples/image_pairs/fire_academy_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa613448fd6bcbfee4c5ce3ebdf5af50a54f3013916e2db47471f2e45a8bbe7c +size 1274757 diff --git a/examples/image_pairs/scene_0.png b/examples/image_pairs/scene_0.png new file mode 100644 index 0000000000000000000000000000000000000000..aac5b862a1aef0b214514b896cc24d882a1c83c7 --- /dev/null +++ b/examples/image_pairs/scene_0.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:056d12106883bbcb6c1c7cdc0534b83d754e74211a45e4cdbc088e3d41055704 +size 1405585 diff --git a/examples/image_pairs/scene_1.png b/examples/image_pairs/scene_1.png new file mode 100644 index 0000000000000000000000000000000000000000..1078c2721aa315ad7adcd026e304b3a0f5e91561 --- /dev/null +++ b/examples/image_pairs/scene_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e257524443d9ea86e6db2c128a35fb00fcb4cc321782926d10705bfc339bfe80 +size 1426869 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..5bc7bbdbaead627a8a15f8939e9395c3007e0dfb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,48 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "uniflowmatch" +version = "0.1.0" +description = "Your project description" +authors = [{ name = "Yuchen Zhang", email = "yuchenz7@andrew.cmu.edu" }] +dependencies = [ + "torch", + "torchvision", + "torchaudio", + "numpy", + "matplotlib", + "opencv-python", + "flow_vis", + "huggingface_hub", + "einops" +] + +[project.optional-dependencies] +dev = [ + "black", + "pre-commit" +] + +[tool.black] +line-length = 120 +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | cuda + | dist +)/ +''' + +[tool.isort] +profile = "black" +line_length = 120 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9852966ecd8fa85c6dd3ac2a52ddf8a5789e6e06 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +torch +torchvision +torchaudio +numpy +matplotlib +opencv-python +flow_vis +huggingface_hub +einops +gradio \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..e379710f7d7ebba15886ea281365a879b841eff2 --- /dev/null +++ b/setup.py @@ -0,0 +1,17 @@ +"""Package installation setup.""" + +from setuptools import setup + +setup( + name="uniflowmatch", + version="0.0.0", + description="UniFlowMatch Project", + author="AirLab", + license="BSD Clause-3", + packages=["uniception", "uniflowmatch"], # Directly specify the package + package_dir={ + "uniception": "UniCeption/uniception", # Map uniception package + "uniflowmatch": "uniflowmatch", # Map uniflowmatch package + }, + include_package_data=True, +) diff --git a/uniflowmatch/models/base.py b/uniflowmatch/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..308bdb45e47cf149a9491f2665ae3c1235d0ac10 --- /dev/null +++ b/uniflowmatch/models/base.py @@ -0,0 +1,334 @@ +""" +Base class of the UniFlowMatch training system. +""" + +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch + + +@dataclass +class UFMFlowFieldOutput: + """ + Output interface of the flow field prediction network. + """ + + flow_output: torch.Tensor + flow_covariance: Optional[torch.Tensor] = None + flow_covariance_inv: Optional[torch.Tensor] = None + flow_covariance_log_det: Optional[torch.Tensor] = None + + +@dataclass +class UFMMaskFieldOutput: + """ + Output interface of the mask prediction network. + """ + + mask: torch.Tensor + logits: torch.Tensor + + +@dataclass +class UFMClassificationRefinementOutput: + """ + Output interface of the classification refinement network. + """ + + # the flow output of the regression step, with shape [B, 2, H, W]. + # it is the initial flow output, which is used to get the first local feature maps for the residual. + regression_flow_output: torch.Tensor + + # residual is the output of the refinement step, with shape [B, 2, H, W]. + # it is added to the initial flow output to get the final flow output. + residual: torch.Tensor + + # log_softmax is + # the logarithm of + # the softmax of + # similarity of the pixel's feature + # to that of its neighborhood of the flow prediction + # in the other image. + # it have shape [B, H, W, P, P], the similarity of pixel at [b, h, w] to its neighborhood [P, P] centered at regression_flow_output[b, h, w] + log_softmax: torch.Tensor + + feature_map_0: torch.Tensor + feature_map_1: torch.Tensor + + +@dataclass +class UFMOutputInterface: + """ + Output interface of the UniFlowMatch training system. + """ + + flow: Optional[UFMFlowFieldOutput] = None + + # Refinement output (for training and visualization) + classification_refinement: Optional[UFMClassificationRefinementOutput] = None + + # auxiliary ouputs + covisibility: Optional[UFMMaskFieldOutput] = None + + +from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT + +from uniflowmatch.utils.flow_resizing import ( + AutomaticShapeSelection, + ResizeToFixedManipulation, + unmap_predicted_channels, + unmap_predicted_flow, +) + + +class UniFlowMatchModelsBase(torch.nn.Module): + def __init__(self, inference_resolution: Optional[Union[List[Tuple[int, int]], Tuple[int, int]]] = None): + super().__init__() + + if inference_resolution is None: + inference_resolution = [(560, 420)] + + if isinstance(inference_resolution[0], int): # Handle the case for single resolution + inference_resolution = [inference_resolution] + + self.inference_resolution = inference_resolution + + self.image_scaler = AutomaticShapeSelection( + *[ResizeToFixedManipulation((resolution[1], resolution[0])) for resolution in inference_resolution], + strategy="closest_aspect", # will inference on the trained aspect ratio that is closest to the input image 1 + ) + + def forward(self, view1, view2) -> UFMOutputInterface: + """ + Forward interface of correspondence prediction networks. + + Args: + - view1 (Dict[str, Any]): Input view 1 + - img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type + - instance (List[int]): List of instance indices, or id of the input image + - data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT + - view2 (Dict[str, Any]): Input view 2 + - (same structure as view1) + Returns: + - Dict[str, Any]: Output results + - flow [Required] (Dict[str, torch.Tensor]): Flow output + - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW + - [Optional] flow_covariance + - [Optional] flow_covariance_inv + - [Optional] flow_covariance_log_det + - occlusion [Optional] (Dict[str, torch.Tensor]): Occlusion output + - [Optional] mask + - [Optional] logits + """ + raise NotImplementedError("Implement this method in derived classes") + + def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]: + """ + Get parameter groups for optimizer. This methods guides the optimizer + to apply correct learning rate to different parts of the model. + + Returns: + - Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer + """ + + raise NotImplementedError("Implement this method in derived classes") + + def predict_correspondences_batched( + self, + source_image: torch.Tensor, + target_image: torch.Tensor, + data_norm_type: Optional[str] = None, + ) -> UFMOutputInterface: + """ + Predict correspondences between source and target images. + + This method generates random correspondences for demonstration purposes. + + Args: + source_image (torch.Tensor): Tensor of shape BCHW/BHWC/CHW/HWC, dtype of uint8 or float32 The source image tensor. + target_image (torch.Tensor): Tensor of shape BCHW/BHWC/CHW/HWC, dtype of uint8 or float32 The target image tensor. + + Returns: + UFMOutputInterface: + - flow + - flow_output (torch.Tensor): Tensor of shape (B, 2, H, W) representing the flow output in the original image space. + - occlusion + - mask (torch.Tensor): Tensor of shape (B, H, W) representing the covisibility in range [0, 1]. 1 = fully covisible, 0 = fully occluded or out of range. + """ + + assert isinstance(source_image, torch.Tensor) and isinstance( + target_image, torch.Tensor + ), "source_image and target_image must be torch.Tensors" + assert source_image.dim() in [3, 4], "source_image must have dimensions 3 or 4" + assert target_image.dim() in [3, 4], "target_image must have dimensions 3 or 4" + + batched = source_image.dim() == 4 + + if not batched: + # add batch dimension + source_image = source_image.unsqueeze(0) + target_image = target_image.unsqueeze(0) + + # check the channel + if source_image.shape[1] == 3 and target_image.shape[1] == 3: + pass # do nothing because the image is in BCHW format + elif source_image.shape[-1] == 3 and target_image.shape[-1] == 3: + # convert to BCHW + source_image = source_image.permute(0, 3, 1, 2) + target_image = target_image.permute(0, 3, 1, 2) + else: + raise ValueError("source_image and target_image must have 3 channels in either BCHW or BHWC format") + + required_data_norm_type = self.encoder.data_norm_type + + image_device = source_image.device + + if source_image.dtype == torch.float32: + assert data_norm_type is not None, "data_norm_type must be provided for float32 images" + assert ( + data_norm_type in IMAGE_NORMALIZATION_DICT + ), f"data_norm_type must be one of {list(IMAGE_NORMALIZATION_DICT.keys())}" + + if data_norm_type != required_data_norm_type: + # apply transformation to the correct from the old normalization + prev_mean = ( + IMAGE_NORMALIZATION_DICT[data_norm_type].mean.view(1, 3, 1, 1).to(image_device, non_blocking=True) + ) + prev_std = ( + IMAGE_NORMALIZATION_DICT[data_norm_type].std.view(1, 3, 1, 1).to(image_device, non_blocking=True) + ) + mean = ( + IMAGE_NORMALIZATION_DICT[required_data_norm_type] + .mean.view(1, 3, 1, 1) + .to(image_device, non_blocking=True) + ) + std = ( + IMAGE_NORMALIZATION_DICT[required_data_norm_type] + .std.view(1, 3, 1, 1) + .to(image_device, non_blocking=True) + ) + + source_image = source_image * (prev_std / std) + (prev_mean - mean) / std + target_image = target_image * (prev_std / std) + (prev_mean - mean) / std + + elif source_image.dtype == torch.uint8: + # convert into float32 and apply normalization + mean = ( + IMAGE_NORMALIZATION_DICT[required_data_norm_type] + .mean.view(1, 3, 1, 1) + .to(image_device, non_blocking=True) + ) + std = ( + IMAGE_NORMALIZATION_DICT[required_data_norm_type] + .std.view(1, 3, 1, 1) + .to(image_device, non_blocking=True) + ) + + source_image = (source_image.float() / 255.0 - mean) / std + target_image = (target_image.float() / 255.0 - mean) / std + else: + raise ValueError("source_image and target_image must be of type torch.float32 or torch.uint8") + + # Now all the inputs are normalized according to the model's encoder and organized in BCHW format + return self._predict_correspondences_batched(source_image, target_image) + + def _predict_correspondences_batched( + self, + source_image: torch.Tensor, + target_image: torch.Tensor, + ) -> UFMOutputInterface: + assert isinstance(source_image, torch.Tensor), "source_image must be a torch.Tensor" + assert isinstance(target_image, torch.Tensor), "target_image must be a torch.Tensor" + + assert source_image.dim() == 4, "source_image must be of shape (B, 3, H, W)" + assert target_image.dim() == 4, "target_image must be of shape (B, 3, H, W)" + assert source_image.shape[1] == 3, "source_image must be of shape (B, 3, H, W)" + assert target_image.shape[1] == 3, "target_image must be of shape (B, 3, H, W)" + + assert source_image.dtype == torch.float32, "source_image must be of dtype torch.float32" + assert target_image.dtype == torch.float32, "target_image must be of dtype torch.float32" + + source_shape_hw = source_image.shape[2:] + target_shape_hw = target_image.shape[2:] + + # Scale images to one of the model's trained resolution. + ( + scaled_img0, # The scaled source image + scaled_img1, # The scaled target image + img0_region_source, # Where in the source image is captured in the scaled image + img1_region_source, # Where in the target image is captured in the scaled image + img0_region_representation, # Region in the source image is captured in this region in the scaled image + img1_region_representation, # same as above, but for the target image + ) = self.image_scaler(source_image.permute(0, 2, 3, 1), target_image.permute(0, 2, 3, 1)) + + scaled_img0 = scaled_img0.permute(0, 3, 1, 2) + scaled_img1 = scaled_img1.permute(0, 3, 1, 2) + + # Run a forward pass + view1 = {"img": scaled_img0, "symmetrized": False, "data_norm_type": self.encoder.data_norm_type} + view2 = {"img": scaled_img1, "symmetrized": False, "data_norm_type": self.encoder.data_norm_type} + + with torch.no_grad(): + with torch.autocast("cuda", torch.bfloat16): + result = self(view1, view2) + + rescaled_ufm_result = UFMOutputInterface() + + # rescale flow + flow_output = result.flow.flow_output + flow_unmapped, flow_unmap_validity = unmap_predicted_flow( + flow_output, + img0_region_representation, + img1_region_representation, + img0_region_source, + img1_region_source, + source_shape_hw, + target_shape_hw, + ) + + rescaled_ufm_result.flow = UFMFlowFieldOutput( + flow_output=flow_unmapped, + ) + + # rescale covariance if it exists + if result.flow.flow_covariance is not None: + flow_covariance = result.flow.flow_covariance + flow_covariance_unmapped, _ = unmap_predicted_channels( + flow_covariance, + img0_region_representation, + img0_region_source, + source_shape_hw, + ) + + # scale covariance in the correct way + w_pred = scaled_img0.shape[3] + h_pred = scaled_img0.shape[2] + + w_final = source_shape_hw[1] + h_final = source_shape_hw[0] + + w_ratio, h_ratio = w_final / w_pred, h_final / h_pred + + flow_covariance_unmapped *= ( + torch.tensor([w_ratio**2, h_ratio**2, w_ratio * h_ratio]) + .view(1, 3, 1, 1) + .to(flow_covariance_unmapped.device) + ) + + rescaled_ufm_result.flow.flow_covariance = flow_covariance_unmapped + + # rescale occlusion if it exists + if result.covisibility is not None: + occlusion_mask = result.covisibility.mask + covisibility_unmapped, _ = unmap_predicted_channels( + occlusion_mask, + img0_region_representation, + img0_region_source, + source_shape_hw, + ) + + covisibility_unmapped = covisibility_unmapped.squeeze(1) + rescaled_ufm_result.covisibility = UFMMaskFieldOutput(mask=covisibility_unmapped, logits=None) + + return rescaled_ufm_result diff --git a/uniflowmatch/models/ufm.py b/uniflowmatch/models/ufm.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6e54e77d80173f20cda6127897af60505e57ec --- /dev/null +++ b/uniflowmatch/models/ufm.py @@ -0,0 +1,1166 @@ +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +from huggingface_hub import PyTorchModelHubMixin +from torch import nn + +# Only enable flash attention backend +from uniception.models.encoders import ViTEncoderInput, feature_returner_encoder_factory +from uniception.models.info_sharing import INFO_SHARING_CLASSES, MultiViewTransformerInput +from uniception.models.prediction_heads.adaptors import ( + ConfidenceAdaptor, + Covariance2DAdaptor, + FlowAdaptor, + FlowWithConfidenceAdaptor, + MaskAdaptor, +) +from uniception.models.prediction_heads.base import AdaptorMap, PredictionHeadInput, PredictionHeadLayeredInput +from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor +from uniception.models.prediction_heads.mlp_feature import MLPFeature +from uniception.models.prediction_heads.moge_conv import MoGeConvFeature + +from uniflowmatch.models.base import ( + UFMClassificationRefinementOutput, + UFMFlowFieldOutput, + UFMMaskFieldOutput, + UFMOutputInterface, + UniFlowMatchModelsBase, +) +from uniflowmatch.models.unet_encoder import UNet +from uniflowmatch.models.utils import get_meshgrid_torch + +CLASSNAME_TO_ADAPTOR_CLASS = { + "FlowWithConfidenceAdaptor": FlowWithConfidenceAdaptor, + "FlowAdaptor": FlowAdaptor, + "MaskAdaptor": MaskAdaptor, + "Covariance2DAdaptor": Covariance2DAdaptor, + "ConfidenceAdaptor": ConfidenceAdaptor, +} + + +# dust3r data structure for reducing passing duplicate images through the encoder +def is_symmetrized(gt1, gt2): + "Function to check if input pairs are symmetrized, i.e., (a, b) and (b, a) always exist in the input" + x = gt1["instance"] + y = gt2["instance"] + if len(x) == len(y) and len(x) == 1: + return False # special case of batchsize 1 + ok = True + for i in range(0, len(x), 2): + ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i]) + + return ok + + +def interleave(tensor1, tensor2): + "Interleave two tensors along the first dimension (used to avoid redundant encoding for symmetrized pairs)" + res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) + res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) + return res1, res2 + + +def modify_state_dict(original_state_dict, mappings): + """ + Modify state dict keys according to replacement mappings + + Args: + original_state_dict: Loaded checkpoint state dict + mappings: Dictionary of {old_key_substr: new_key_substr_or_None} + + Returns: + Modified state dictionary with updated keys + """ + new_state_dict = {} + + for k, v in original_state_dict.items(): + new_key = None + skip = False + + # Check for all possible replacements + for replace_key, replace_value in mappings.items(): + if replace_key in k: + if replace_value is None: + skip = True + break # Skip this key entirely + else: + new_key = k.replace(replace_key, replace_value) + break # Only apply first matching replacement + + if skip: + continue + + new_state_dict[new_key if new_key is not None else k] = v + + return new_state_dict + + +class UniFlowMatch(UniFlowMatchModelsBase, PyTorchModelHubMixin): + """ + UniFlowMatch model. + """ + + def __init__( + self, + # Encoder configurations + encoder_str: str, + encoder_kwargs: Dict[str, Any], + # Info sharing & output head structure configurations + info_sharing_and_head_structure: str = "dual+single", # only dual+single is supported + # Information sharing configurations + info_sharing_str: str = "global_attention", + info_sharing_kwargs: Dict[str, Any] = {}, + # skip-connections between encoder and info-sharing + encoder_skip_connection: Optional[List[int]] = None, + info_sharing_skip_connection: Optional[List[int]] = None, + # Prediction Heads & Adaptors + head_type: str = "dpt", + feature_head_kwargs: Dict[str, Any] = {}, + adaptors_kwargs: Dict[str, Any] = {}, + # Load Pretrained Weights + pretrained_checkpoint_path: Optional[str] = None, + # Inference Settings + inference_resolution: Optional[Tuple[int, int]] = (560, 420), # WH + *args, + **kwargs, + ): + """ + Initialize the UniFlowMarch Model + + - encoder_str (str): Encoder string + - encoder_kwargs (Dict[str, Any]): Encoder configurations + + - info_sharing_and_head_structure (str): Info sharing and head structure configurations + - "dual+single": Dual view info sharing and single view prediction head + + - info_sharing_str (str): Info sharing method + - "global_attention_transformer": Global attention transformer + - info_sharing_kwargs (Dict[str, Any]): Info sharing configurations + + """ + UniFlowMatchModelsBase.__init__(self, inference_resolution=inference_resolution, *args, **kwargs) + + PyTorchModelHubMixin.__init__(self) + + # assertion on architectures + assert info_sharing_and_head_structure == "dual+single", "Only dual+single is supported now" + + # initialize the skip-connections + self.encoder_skip_connection = encoder_skip_connection + self.info_sharing_skip_connection = info_sharing_skip_connection + + # initialize encoder + self.encoder: nn.Module = feature_returner_encoder_factory(encoder_str, **encoder_kwargs) + + # initialize info-sharing module + assert head_type != "linear", "Linear head is not supported, because it have major disadvantage to DPTs" + self.head_type = head_type + + self.info_sharing: nn.Module = INFO_SHARING_CLASSES[info_sharing_str][1](**info_sharing_kwargs) + + self.head1: nn.Module = self._initialize_prediction_heads(head_type, feature_head_kwargs, adaptors_kwargs) + + # load pretrained weights + if pretrained_checkpoint_path is not None: + ckpt = torch.load(pretrained_checkpoint_path, map_location="cpu") + + if "state_dict" in ckpt: + # we are loading from training checkpoint directly. + model_state_dict = ckpt["state_dict"] + model_state_dict = { + k[6:]: v for k, v in model_state_dict.items() if k.startswith("model.") + } # remove "model." prefix + + model_state_dict = modify_state_dict( + model_state_dict, {"feature_matching_proj": None, "encoder.model.mask_token": None} + ) + + self.load_state_dict(model_state_dict, strict=True) + else: + model_state_dict = ckpt["model"] + + load_result = self.load_state_dict(model_state_dict, strict=False) + assert len(load_result.missing_keys) == 0, f"Missing keys: {load_result.missing_keys}" + + @classmethod + def from_pretrained_ckpt(cls, pretrained_model_name_or_path, strict=True, **kw): + if os.path.isfile(pretrained_model_name_or_path): + ckpt = torch.load(pretrained_model_name_or_path, map_location="cpu") + + # remove base_pretrained_checkpoint_path from the model args + if "base_pretrained_checkpoint_path" in ckpt["model_args"]: + ckpt["model_args"].pop("base_pretrained_checkpoint_path") + + # convert old model args into new definition + if "img_size" in ckpt["model_args"]: + # we are loading from a old benchmark checkpoint + print("Converting from a old benchmark checkpoint") + model_args = { + # Encoder args + "encoder_str": ckpt["model_args"]["encoder_str"], + "encoder_kwargs": ckpt["model_args"]["encoder_kwargs"], + # Info-sharing args + "info_sharing_and_head_structure": "dual+single", + "info_sharing_str": ckpt["model_args"]["info_sharing_type"], + "info_sharing_kwargs": { + "name": "info_sharing", + "input_embed_dim": ckpt["model_args"]["input_embed_dim"], + "num_views": 2, + "use_rand_idx_pe_for_non_reference_views": False, + "depth": ckpt["model_args"]["num_layers"], + "dim": ckpt["model_args"]["transformer_dim"], + "num_heads": ckpt["model_args"]["num_heads"], + "mlp_ratio": ckpt["model_args"]["mlp_ratio"], + "qkv_bias": ckpt["model_args"]["qkv_bias"], + "qk_norm": ckpt["model_args"]["qk_norm"], + "custom_positional_encoding": ckpt["model_args"]["position_encoding"], + "norm_intermediate": ckpt["model_args"]["normalize_intermediate"], + "indices": ckpt["model_args"]["returned_intermediate_layers"], + }, + # flow head args + "head_type": "dpt", + "feature_head_kwargs": ckpt["model_args"]["feature_head_kwargs"], + "adaptors_kwargs": ckpt["model_args"]["adaptors_kwargs"], + } + + if "covocc_feature_head_kwargs" in ckpt["model_args"]: + # if the model has a covocc head, we need to convert it to the new format + model_args["uncertainty_head_type"] = "dpt" + model_args["uncertainty_head_kwargs"] = { + "dpt_feature": ckpt["model_args"]["covocc_feature_head_kwargs"]["dpt_feature"], + "dpt_processor": ckpt["model_args"]["covocc_feature_head_kwargs"]["dpt_regr_processor"], + } + model_args["uncertainty_adaptors_kwargs"] = { + "flow_cov": ckpt["model_args"]["covocc_adaptors_kwargs"]["flow_cov"] + } + + ckpt["model_args"] = model_args + + # Update the old weights into the current format + ckpt["model"] = modify_state_dict( + ckpt["model"], + { + "covocc_head.dpt_feature": "uncertainty_head.0.0", + "covocc_head.dpt_regr_processor": "uncertainty_head.0.1", + "covocc_head.dpt_segm_processor": None, + "feature_matching_proj": None, + "encoder.model.mask_token": None, + }, + ) + + # remove the ket "pretrained_backbone_checkpoint_path" from the model args + if "pretrained_backbone_checkpoint_path" in ckpt["model_args"]: + ckpt["model_args"].pop("pretrained_backbone_checkpoint_path") + + model = cls(**ckpt["model_args"]) + model.load_state_dict(ckpt["model"], strict=strict) + return model + else: + raise ValueError(f"Pretrained model {pretrained_model_name_or_path} not found.") + + def _initialize_prediction_heads( + self, head_type: str, feature_head_kwargs: Dict[str, Any], adaptors_kwargs: Dict[str, Any] + ): + """ + Initialize prediction heads and adaptors + + Args: + - head_type (str): Head type, either "dpt" or "linear" + - feature_head_kwargs (Dict[str, Any]): Feature head configurations + - adaptors_kwargs (Dict[str, Any]): Adaptors configurations + + Returns: + - nn.Module: output head + adaptors + """ + feature_processor: nn.Module + if head_type == "dpt": + feature_processor = nn.Sequential( + DPTFeature(**feature_head_kwargs["dpt_feature"]), + DPTRegressionProcessor(**feature_head_kwargs["dpt_processor"]), + ) + elif head_type == "moge_conv": + feature_processor = MoGeConvFeature(**feature_head_kwargs) + else: + raise ValueError(f"Head type {head_type} not supported.") + + adaptors = self._initialize_adaptors(adaptors_kwargs) + + return nn.Sequential(feature_processor, AdaptorMap(*adaptors.values())) + + def _initialize_adaptors(self, adaptors_kwargs: Dict[str, Any]): + """ + Initialize a dict of adaptors + + Args: + - adaptors_kwargs (Dict[str, Any]): Adaptors configurations + + Returns: + - Dict[str, nn.Module]: dict of adaptors, from adaptor's name to the adaptor + """ + return { + name: CLASSNAME_TO_ADAPTOR_CLASS[configs["class"]](**configs["kwargs"]) + for name, configs in adaptors_kwargs.items() + } + + def _encode_image_pairs(self, img1, img2, data_norm_type): + "Encode two different batches of images (each batch can have different image shape)" + if img1.shape[-2:] == img2.shape[-2:]: + encoder_input = ViTEncoderInput(image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type) + encoder_output = self.encoder(encoder_input) + out_list, out2_list = [], [] + + for encoder_output_ in encoder_output: + out, out2 = encoder_output_.features.chunk(2, dim=0) + out_list.append(out) + out2_list.append(out2) + else: + raise NotImplementedError("Unequal Image sizes are not supported now") + + return out_list, out2_list + + def _encode_symmetrized(self, view1, view2, symmetrized=False): + "Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input" + img1 = view1["img"] + img2 = view2["img"] + + feat1_list, feat2_list = [], [] + + if symmetrized: + # Computing half of forward pass' + # modified in conjunction with UFM for not copying the images again. + # used to be: feat1, feat2 = self._encode_image_pairs(img1[::2], img2[::2], data_norm_type=view1["data_norm_type"]) + # be very carefult with this!!! + feat1_list_, feat2_list_ = self._encode_image_pairs( + img1[::2], img2[::2], data_norm_type=view1["data_norm_type"] + ) + + for feat1, feat2 in zip(feat1_list_, feat2_list_): + feat1, feat2 = interleave(feat1, feat2) + feat1_list.append(feat1) + feat2_list.append(feat2) + else: + feat1_list, feat2_list = self._encode_image_pairs(img1, img2, data_norm_type=view1["data_norm_type"]) + + return feat1_list, feat2_list + + def forward(self, view1, view2) -> UFMOutputInterface: + """ + Forward interface of correspondence prediction networks. + + Args: + - view1 (Dict[str, Any]): Input view 1 + - img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type + - instance (List[int]): List of instance indices, or id of the input image + - data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT + - view2 (Dict[str, Any]): Input view 2 + - (same structure as view1) + Returns: + - Dict[str, Any]: Output results + - flow [Required] (Dict[str, torch.Tensor]): Flow output + - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW + - [Optional] flow_covariance + - [Optional] flow_covariance_inv + - [Optional] flow_covariance_log_det + - covisibility [Optional] (Dict[str, torch.Tensor]): Covisibility output + - [Optional] mask + - [Optional] logits + """ + + # Get input shapes + _, _, height1, width1 = view1["img"].shape + _, _, height2, width2 = view2["img"].shape + shape1 = (int(height1), int(width1)) + shape2 = (int(height2), int(width2)) + + # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width) + feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"]) + + # Pass the features through the info_sharing + info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]]) + + final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing( + info_sharing_input + ) + + info_sharing_outputs = { + "1": [ + feat1_list[-1].float().contiguous(), + intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(), + intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(), + final_info_sharing_multi_view_feat.features[0].float().contiguous(), + ], + "2": [ + feat2_list[-1].float().contiguous(), + intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(), + intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(), + final_info_sharing_multi_view_feat.features[1].float().contiguous(), + ], + } + + result = UFMOutputInterface() + + # The prediction need precision, so we disable any autocasting here + with torch.autocast("cuda", torch.float32): + # run the collected info_sharing features through the prediction heads + head_output1 = self._downstream_head(1, info_sharing_outputs, shape1) + + if "flow" in head_output1: + # output is flow only + result.flow = UFMFlowFieldOutput(flow_output=head_output1["flow"].value) + + if "flow_cov" in head_output1: + result.flow.flow_covariance = head_output1["flow_cov"].covariance + result.flow.flow_covariance_inv = head_output1["flow_cov"].inv_covariance + result.flow.flow_covariance_log_det = head_output1["flow_cov"].log_det + + if "non_occluded_mask" in head_output1: + result.covisibility = UFMMaskFieldOutput( + mask=head_output1["non_occluded_mask"].mask, + logits=head_output1["non_occluded_mask"].logits, + ) + + return result + + def _downstream_head(self, head_num, decout, img_shape): + "Run the respective prediction heads" + # if self.info_sharing_and_head_structure == "dual+single": + + head = getattr(self, f"head{head_num}") + if self.head_type == "linear": + head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"]) + elif self.head_type in ["dpt", "moge_conv"]: + head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape) + + return head(head_input) + + def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]: + """ + Get parameter groups for optimizer. This methods guides the optimizer + to apply correct learning rate to different parts of the model. + + Returns: + - Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer + """ + + return { + "encoder": torch.nn.ParameterList(self.encoder.parameters()), + "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()), + "output_head": torch.nn.ParameterList(self.head1.parameters()), + } + + +class UniFlowMatchConfidence(UniFlowMatch, PyTorchModelHubMixin): + """ + UniFlowMatch model with uncertainty estimation. + """ + + def __init__( + self, + # Encoder configurations + encoder_str: str, + encoder_kwargs: Dict[str, Any], + # Info sharing & output head structure configurations + info_sharing_and_head_structure: str = "dual+single", # only dual+single is supported + # Information sharing configurations + info_sharing_str: str = "global_attention", + info_sharing_kwargs: Dict[str, Any] = {}, + # Prediction Heads & Adaptors + head_type: str = "dpt", + feature_head_kwargs: Dict[str, Any] = {}, + adaptors_kwargs: Dict[str, Any] = {}, + # Uncertainty Heads & Adaptors + detach_uncertainty_head: bool = True, + uncertainty_head_type: str = "dpt", + uncertainty_head_kwargs: Dict[str, Any] = {}, + uncertainty_adaptors_kwargs: Dict[str, Any] = {}, + # Load Pretrained Weights + pretrained_backbone_checkpoint_path: Optional[str] = None, + pretrained_checkpoint_path: Optional[str] = None, + # Inference Settings + inference_resolution: Optional[Tuple[int, int]] = (560, 420), # WH + *args, + **kwargs, + ): + UniFlowMatch.__init__( + self, + encoder_str=encoder_str, + encoder_kwargs=encoder_kwargs, + info_sharing_and_head_structure=info_sharing_and_head_structure, + info_sharing_str=info_sharing_str, + info_sharing_kwargs=info_sharing_kwargs, + head_type=head_type, + feature_head_kwargs=feature_head_kwargs, + adaptors_kwargs=adaptors_kwargs, + pretrained_checkpoint_path=pretrained_backbone_checkpoint_path, + inference_resolution=inference_resolution, + *args, + **kwargs, + ) + + PyTorchModelHubMixin.__init__(self) + + # initialize uncertainty heads + assert uncertainty_head_type == "dpt", "Only DPT is supported for uncertainty head now" + + self.uncertainty_head = self._initialize_prediction_heads( + uncertainty_head_type, uncertainty_head_kwargs, uncertainty_adaptors_kwargs + ) + self.uncertainty_adaptors = self._initialize_adaptors(uncertainty_adaptors_kwargs) + + assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now" + + self.detach_uncertainty_head = detach_uncertainty_head + + def forward(self, view1, view2) -> UFMOutputInterface: + """ + Forward interface of correspondence prediction networks. + + Args: + - view1 (Dict[str, Any]): Input view 1 + - img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type + - instance (List[int]): List of instance indices, or id of the input image + - data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT + - view2 (Dict[str, Any]): Input view 2 + - (same structure as view1) + Returns: + - Dict[str, Any]: Output results + - flow [Required] (Dict[str, torch.Tensor]): Flow output + - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW + - [Optional] flow_covariance + - [Optional] flow_covariance_inv + - [Optional] flow_covariance_log_det + - covisibility [Optional] (Dict[str, torch.Tensor]): Covisibiltiy output + - [Optional] mask + - [Optional] logits + """ + + # Get input shapes + _, _, height1, width1 = view1["img"].shape + _, _, height2, width2 = view2["img"].shape + shape1 = (int(height1), int(width1)) + shape2 = (int(height2), int(width2)) + + # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width) + feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"]) + + # Pass the features through the info_sharing + info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]]) + + final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing( + info_sharing_input + ) + + info_sharing_outputs = { + "1": [ + feat1_list[-1].float().contiguous(), + intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(), + intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(), + final_info_sharing_multi_view_feat.features[0].float().contiguous(), + ], + "2": [ + feat2_list[-1].float().contiguous(), + intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(), + intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(), + final_info_sharing_multi_view_feat.features[1].float().contiguous(), + ], + } + + info_sharing_outputs_detached = { + "1": [ + feat1_list[-1].detach().float().contiguous(), + intermediate_info_sharing_multi_view_feat[0].features[0].detach().float().contiguous(), + intermediate_info_sharing_multi_view_feat[1].features[0].detach().float().contiguous(), + final_info_sharing_multi_view_feat.features[0].detach().float().contiguous(), + ], + "2": [ + feat2_list[-1].detach().float().contiguous(), + intermediate_info_sharing_multi_view_feat[0].features[1].detach().float().contiguous(), + intermediate_info_sharing_multi_view_feat[1].features[1].detach().float().contiguous(), + final_info_sharing_multi_view_feat.features[1].detach().float().contiguous(), + ], + } + + result = UFMOutputInterface() + + # The prediction need precision, so we disable any autocasting here + with torch.autocast("cuda", torch.float32): + # run the collected info_sharing features through the prediction heads + head_output1 = self._downstream_head(1, info_sharing_outputs, shape1) + head_output_uncertainty = self._downstream_head( + "uncertainty", + info_sharing_outputs_detached if self.detach_uncertainty_head else info_sharing_outputs, + shape1, + ) + + result.flow = UFMFlowFieldOutput( + flow_output=head_output1["flow"].value, + ) + + if "flow_cov" in head_output_uncertainty: + result.flow.flow_covariance = head_output_uncertainty["flow_cov"].covariance + result.flow.flow_covariance_inv = head_output_uncertainty["flow_cov"].inv_covariance + result.flow.flow_covariance_log_det = head_output_uncertainty["flow_cov"].log_det + + if "keypoint_confidence" in head_output_uncertainty: + result.keypoint_confidence = head_output_uncertainty["keypoint_confidence"].value.squeeze(1) + + if "non_occluded_mask" in head_output_uncertainty: + result.covisibility = UFMMaskFieldOutput( + mask=head_output_uncertainty["non_occluded_mask"].mask, + logits=head_output_uncertainty["non_occluded_mask"].logits, + ) + + return result + + def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]: + """ + Get parameter groups for optimizer. This methods guides the optimizer + to apply correct learning rate to different parts of the model. + + Returns: + - Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer + """ + + return { + "encoder": torch.nn.ParameterList(self.encoder.parameters()), + "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()), + "output_head": torch.nn.ParameterList(self.head1.parameters()), + "uncertainty_head": torch.nn.ParameterList(self.uncertainty_head.parameters()), + } + + def _downstream_head(self, head_num, decout, img_shape): + "Run the respective prediction heads" + # if self.info_sharing_and_head_structure == "dual+single": + + head = getattr(self, f"head{head_num}") if head_num != "uncertainty" else self.uncertainty_head + + head_num = head_num if head_num != "uncertainty" else 1 # uncertainty head is always from branch 1 + + if self.head_type == "linear": + head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"]) + elif self.head_type in ["dpt", "moge_conv"]: + head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape) + + return head(head_input) + + +class UniFlowMatchClassificationRefinement(UniFlowMatch, PyTorchModelHubMixin): + """ + The variant of UniFlowMatch with local classification for refinement. + """ + + def __init__( + self, + # Encoder configurations + encoder_str: str, + encoder_kwargs: Dict[str, Any], + # Info sharing & output head structure configurations + info_sharing_and_head_structure: str = "dual+single", # only dual+single is supported + # Information sharing configurations + info_sharing_str: str = "global_attention", + info_sharing_kwargs: Dict[str, Any] = {}, + # Prediction Heads & Adaptors + head_type: str = "dpt", + feature_head_kwargs: Dict[str, Any] = {}, + adaptors_kwargs: Dict[str, Any] = {}, + # Uncertainty Heads & Adaptors + detach_uncertainty_head: bool = True, + uncertainty_head_type: str = "dpt", + uncertainty_head_kwargs: Dict[str, Any] = {}, + uncertainty_adaptors_kwargs: Dict[str, Any] = {}, + # Classification Heads & Adaptors + temperature: float = 4, + use_unet_feature: bool = False, + classification_head_type: str = "patch_mlp", + classification_head_kwargs: Dict[str, Any] = {}, + feature_combine_method: str = "conv", + # Refinement Range + refinement_range: int = 5, + # Load Pretrained Weights + pretrained_backbone_checkpoint_path: Optional[str] = None, + pretrained_checkpoint_path: Optional[str] = None, + # Inference Settings + inference_resolution: Optional[Tuple[int, int]] = (560, 420), # WH + *args, + **kwargs, + ): + UniFlowMatch.__init__( + self, + encoder_str=encoder_str, + encoder_kwargs=encoder_kwargs, + info_sharing_and_head_structure=info_sharing_and_head_structure, + info_sharing_str=info_sharing_str, + info_sharing_kwargs=info_sharing_kwargs, + head_type=head_type, + feature_head_kwargs=feature_head_kwargs, + adaptors_kwargs=adaptors_kwargs, + pretrained_checkpoint_path=pretrained_backbone_checkpoint_path, + inference_resolution=inference_resolution, + *args, + **kwargs, + ) + + PyTorchModelHubMixin.__init__(self) + + # initialize uncertainty heads + assert classification_head_type == "patch_mlp", "Only DPT is supported for uncertainty head now" + self.classification_head_type = classification_head_type + + self.classification_head = self._initialize_classification_head(classification_head_kwargs) + + self.refinement_range = refinement_range + self.temperature = temperature + + assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now" + + self.use_unet_feature = use_unet_feature + + self.feature_combine_method = feature_combine_method + + # Unet experiment + if self.use_unet_feature: + self.unet_feature = UNet(in_channels=3, out_channels=16, features=[64, 128, 256, 512]) + + self.conv1 = nn.Conv2d(32, 32, kernel_size=1, stride=1, padding=0) + + if self.feature_combine_method == "conv": + self.conv2 = nn.Conv2d(32, 16, kernel_size=1, stride=1, padding=0) + elif self.feature_combine_method == "modulate": + self.conv2 = nn.Conv2d(16, 16, kernel_size=1, stride=1, padding=0) + + default_attention_bias = torch.zeros(self.refinement_range * self.refinement_range) + self.classification_bias = nn.Parameter(default_attention_bias) + + # initialize uncertainty heads + if len(uncertainty_head_kwargs) > 0: + assert uncertainty_head_type == "dpt", "Only DPT is supported for uncertainty head now" + + self.uncertainty_head = self._initialize_prediction_heads( + uncertainty_head_type, uncertainty_head_kwargs, uncertainty_adaptors_kwargs + ) + self.uncertainty_adaptors = self._initialize_adaptors(uncertainty_adaptors_kwargs) + + assert pretrained_checkpoint_path is None, "Pretrained weights are not supported for now" + + self.detach_uncertainty_head = detach_uncertainty_head + + def forward(self, view1, view2) -> UFMOutputInterface: + """ + Forward interface of correspondence prediction networks. + + Args: + - view1 (Dict[str, Any]): Input view 1 + - img (torch.Tensor): BCHW image tensor normalized according to encoder's data_norm_type + - instance (List[int]): List of instance indices, or id of the input image + - data_norm_type (str): Data normalization type, see uniception.models.encoders.IMAGE_NORMALIZATION_DICT + - view2 (Dict[str, Any]): Input view 2 + - (same structure as view1) + Returns: + - Dict[str, Any]: Output results + - flow [Required] (Dict[str, torch.Tensor]): Flow output + - [Required] flow_output (torch.Tensor): Flow output tensor, BCHW + - [Optional] flow_covariance + - [Optional] flow_covariance_inv + - [Optional] flow_covariance_log_det + - covisibility [Optional] (Dict[str, torch.Tensor]): Covisibility output + - [Optional] mask + - [Optional] logits + - classification [Optional]: Probability and targets of the classification head + """ + + # Get input shapes + _, _, height1, width1 = view1["img"].shape + _, _, height2, width2 = view2["img"].shape + shape1 = (int(height1), int(width1)) + shape2 = (int(height2), int(width2)) + + # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width) + feat1_list, feat2_list = self._encode_symmetrized(view1, view2, view1["symmetrized"]) + + # Pass the features through the info_sharing + info_sharing_input = MultiViewTransformerInput(features=[feat1_list[-1], feat2_list[-1]]) + + final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing( + info_sharing_input + ) + + info_sharing_outputs = { + "1": [ + feat1_list[-1].float().contiguous(), + intermediate_info_sharing_multi_view_feat[0].features[0].float().contiguous(), + intermediate_info_sharing_multi_view_feat[1].features[0].float().contiguous(), + final_info_sharing_multi_view_feat.features[0].float().contiguous(), + ], + "2": [ + feat2_list[-1].float().contiguous(), + intermediate_info_sharing_multi_view_feat[0].features[1].float().contiguous(), + intermediate_info_sharing_multi_view_feat[1].features[1].float().contiguous(), + final_info_sharing_multi_view_feat.features[1].float().contiguous(), + ], + } + + info_sharing_outputs_detached = { + "1": [ + feat1_list[-1].detach().float().contiguous(), + intermediate_info_sharing_multi_view_feat[0].features[0].detach().float().contiguous(), + intermediate_info_sharing_multi_view_feat[1].features[0].detach().float().contiguous(), + final_info_sharing_multi_view_feat.features[0].detach().float().contiguous(), + ], + "2": [ + feat2_list[-1].detach().float().contiguous(), + intermediate_info_sharing_multi_view_feat[0].features[1].detach().float().contiguous(), + intermediate_info_sharing_multi_view_feat[1].features[1].detach().float().contiguous(), + final_info_sharing_multi_view_feat.features[1].detach().float().contiguous(), + ], + } + + # optionally inference for U-Net Features + if self.use_unet_feature: + unet_feat1 = self.unet_feature(view1["img"]) + unet_feat2 = self.unet_feature(view2["img"]) + + result = UFMOutputInterface() + # The prediction need precision, so we disable any autocasting here + with torch.autocast("cuda", torch.float32): + # run the collected info_sharing features through the prediction heads + head_output1 = self._downstream_head(1, info_sharing_outputs, shape1) + flow_prediction = head_output1["flow"].value + + if hasattr(self, "uncertainty_head"): + # run the uncertainty head + head_output_uncertainty = self._downstream_head( + "uncertainty", + info_sharing_outputs_detached if self.detach_uncertainty_head else info_sharing_outputs, + shape1, + ) + + if "flow_cov" in head_output_uncertainty: + result.flow.flow_covariance = head_output_uncertainty["flow_cov"].covariance + result.flow.flow_covariance_inv = head_output_uncertainty["flow_cov"].inv_covariance + result.flow.flow_covariance_log_det = head_output_uncertainty["flow_cov"].log_det + + if "keypoint_confidence" in head_output_uncertainty: + result.keypoint_confidence = head_output_uncertainty["keypoint_confidence"].value.squeeze(1) + + if "non_occluded_mask" in head_output_uncertainty: + result.covisibility = UFMMaskFieldOutput( + mask=head_output_uncertainty["non_occluded_mask"].mask, + logits=head_output_uncertainty["non_occluded_mask"].logits, + ) + + # we run the classification head in the autocast environment bacause it is not regression + if self.classification_head_type == "patch_mlp": + # concatenate the last encoder feature with final info_sharing feature + + # use the first encoder feature, because it captures more low-level information, which is needed + # for refinement of the regressed flow. + classification_feat_1 = torch.cat( + [feat1_list[0].float().contiguous(), info_sharing_outputs["1"][-1]], dim=1 + ) + classification_feat_2 = torch.cat( + [feat2_list[0].float().contiguous(), info_sharing_outputs["2"][-1]], dim=1 + ) + + classification_input = PredictionHeadInput( + torch.cat([classification_feat_1, classification_feat_2], dim=0) + ) + + classification_features = self.classification_head(classification_input).decoded_channels + + if self.use_unet_feature: + + if self.feature_combine_method == "conv": + combined_features = torch.cat( + [classification_features, torch.cat([unet_feat1, unet_feat2], dim=0)], dim=1 + ) + + combined_features = self.conv1(combined_features) + combined_features = nn.functional.relu(combined_features) + combined_features = self.conv2(combined_features) + elif self.feature_combine_method == "modulate": + + combined_features = classification_features * torch.tanh( + torch.cat([unet_feat1, unet_feat2], dim=0) + ) + combined_features = self.conv2(combined_features) + + classification_features = combined_features + + classification_features0, classification_features1 = classification_features.chunk(2, dim=0) + + # refine the flow prediction with features from the classification head + for i in range(1): + residual, log_softmax_attention = self.classification_refinement( + flow_prediction, classification_features + ) + flow_prediction = flow_prediction + residual + + # Fill in the result + # WARNING: based on how the residual is computed, flow_prediction will have gradient cancelled by mathematics, + # so there will be no supervision to the flow prediction at all. We need to use specialized loss function to + # supervise the regression_flow_output. + result.flow = UFMFlowFieldOutput( + flow_output=flow_prediction, + ) + + result.classification_refinement = UFMClassificationRefinementOutput( + regression_flow_output=flow_prediction, + residual=residual, + log_softmax=log_softmax_attention, + feature_map_0=classification_features0, + feature_map_1=classification_features1, + ) + + return result + + # @torch.compile() + def classification_refinement(self, flow_prediction, classification_features) -> Dict[str, Any]: + """ + Use correlation between self feature and features around a local patch of the initial flow prediction + to refine the flow prediction. + + """ + + classification_features1, classification_features2 = classification_features.chunk(2, dim=0) + + neighborhood_features, neighborhood_flow_residual = self.obtain_neighborhood_features( + flow_estimation=flow_prediction, other_features=classification_features2, local_patch=self.refinement_range + ) + + residual, log_softmax_attention = self.compute_refinement_attention( + classification_features1, neighborhood_features, neighborhood_flow_residual + ) + + return residual, log_softmax_attention + + def compute_refinement_attention(self, classification_features1, neighborhood_features, neighborhood_flow_residual): + """ + Compute the attention for the refinement, with special processing + to fit + """ + + B, C, H, W = classification_features1.shape + P = self.refinement_range + + # reshape Q to B, H, W, 1, 1, C + classification_features1 = classification_features1.permute(0, 2, 3, 1).reshape(B * H * W, 1, C) + + # reshape K to B, H, W, 1, P^2, C + assert neighborhood_features.shape[0] == B + assert neighborhood_features.shape[1] == H + assert neighborhood_features.shape[2] == W + assert neighborhood_features.shape[3] == P + assert neighborhood_features.shape[4] == P + assert neighborhood_features.shape[5] == C + + neighborhood_features = neighborhood_features.reshape(B * H * W, P * P, C) + + # reshape V to B, H, W, 1, P^2, 2 + neighborhood_flow_residual = neighborhood_flow_residual.reshape(-1, P * P, 2) + + # compute the attention + attention_score = ( + torch.matmul(classification_features1, neighborhood_features.permute(0, 2, 1)) / self.temperature + ) + attention_score = attention_score + self.classification_bias + + attention = torch.nn.functional.softmax(attention_score, dim=-1) + log_softmax_attention = torch.nn.functional.log_softmax(attention_score, dim=-1) + + # compute the weighted sum + residual = torch.matmul(attention, neighborhood_flow_residual) + + # reshape the residual to B, H, W, 2, then B, 2, H, W + residual = residual.reshape(B, H, W, 2).permute(0, 3, 1, 2) + + return residual, log_softmax_attention.reshape(B, H, W, P, P) + + def _downstream_head(self, head_num, decout, img_shape): + "Run the respective prediction heads" + # if self.info_sharing_and_head_structure == "dual+single": + + head = getattr(self, f"head{head_num}") if head_num != "uncertainty" else self.uncertainty_head + + head_num = head_num if head_num != "uncertainty" else 1 # uncertainty head is always from branch 1 + + if self.head_type == "linear": + head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"]) + elif self.head_type in ["dpt", "moge_conv"]: + head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape) + + return head(head_input) + + def obtain_neighborhood_features( + self, flow_estimation: torch.Tensor, other_features: torch.Tensor, local_patch: int = 5 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Query the other features according to flow estimation. + """ + + assert local_patch % 2 == 1, "local_patch should be odd number" + + P = local_patch + R = (P - 1) // 2 + B, C, H, W = other_features.shape + + device = other_features.device + + # expected_output = torch.zeros(B, H, W, P, P, C, device=other_features.device, dtype=torch.float32) + + neighborhood_grid_ij: torch.Tensor + + i_local, j_local = torch.meshgrid( + torch.arange(-R, R + 1, device=device), torch.arange(-R, R + 1, device=device), indexing="ij" + ) + ij_local = torch.stack((i_local, j_local), dim=0) # 2, P, P tensor + + # compute the indices of the fetch + base_grid_xy = get_meshgrid_torch(W=W, H=H, device=device).permute(2, 0, 1).reshape(1, 2, H, W) + + target_coordinate_xy_float = flow_estimation + base_grid_xy + target_coordinate_xy = target_coordinate_xy_float.view(B, 2, H, W, 1, 1) + target_coordinate_ij = target_coordinate_xy[:, [1, 0], ...] + + # compute the neighborhood grid + neighborhood_grid_ij = target_coordinate_ij + ij_local.view(1, 2, 1, 1, P, P) + + grid_for_sample = neighborhood_grid_ij[:, [1, 0], ...].permute(0, 2, 3, 4, 5, 1).reshape(B, H, W * P * P, 2) + grid_for_sample = (grid_for_sample + 0.5) / torch.tensor([W, H], device=device).view(1, 1, 1, 2) + grid_for_sample = grid_for_sample * 2 - 1 + + expected_output = torch.nn.functional.grid_sample( + other_features, grid=grid_for_sample, mode="bicubic", padding_mode="zeros", align_corners=False + ).view(B, C, H, W, P, P) + + # transform BCHWPP to BHWPPC + expected_output = expected_output.permute(0, 2, 3, 4, 5, 1) + + neighborhood_grid_xy_residual = ij_local[[1, 0], ...].view(1, 2, 1, 1, P, P).to(device).float() + neighborhood_grid_xy_residual = neighborhood_grid_xy_residual.permute(0, 2, 3, 4, 5, 1).float() + + return expected_output, neighborhood_grid_xy_residual + + def _initialize_classification_head(self, classification_head_kwargs: Dict[str, Any]): + """ + Initialize classification head + + Args: + - classification_head_kwargs (Dict[str, Any]): Classification head configurations + + Returns: + - nn.Module: Classification head + """ + + if self.classification_head_type == "patch_mlp": + return MLPFeature(**classification_head_kwargs) + else: + raise ValueError(f"Classification head type {self.classification_head_type} not supported.") + + def get_parameter_groups(self) -> Dict[str, torch.nn.ParameterList]: + """ + Get parameter groups for optimizer. This methods guides the optimizer + to apply correct learning rate to different parts of the model. + + Returns: + - Dict[str, torch.nn.ParameterList]: Parameter groups for optimizer + """ + + if self.use_unet_feature: + params_dict = { + "encoder": torch.nn.ParameterList(self.encoder.parameters()), + "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()), + "output_head": torch.nn.ParameterList(self.head1.parameters()), + "classification_head": torch.nn.ParameterList(self.classification_head.parameters()), + "unet_feature": torch.nn.ParameterList( + list(self.unet_feature.parameters()) + + list(self.conv1.parameters()) + + list(self.conv2.parameters()) + + [self.classification_bias] + ), + } + else: + params_dict = { + "encoder": torch.nn.ParameterList(self.encoder.parameters()), + "info_sharing": torch.nn.ParameterList(self.info_sharing.parameters()), + "output_head": torch.nn.ParameterList(self.head1.parameters()), + "classification_head": torch.nn.ParameterList(self.classification_head.parameters()), + } + + if hasattr(self, "uncertainty_head"): + params_dict["uncertainty_head"] = torch.nn.ParameterList(self.uncertainty_head.parameters()) + + return params_dict + + +if __name__ == "__main__": + import cv2 + import flow_vis + import matplotlib.pyplot as plt + import numpy as np + import torch + + from uniflowmatch.utils.geometry import get_meshgrid_torch + from uniflowmatch.utils.viz import warp_image_with_flow + + USE_REFINEMENT_MODEL = False + + if USE_REFINEMENT_MODEL: + model = UniFlowMatchClassificationRefinement.from_pretrained("infinity1096/UFM-Refine") + else: + model = UniFlowMatchConfidence.from_pretrained("infinity1096/UFM-Base") + + # === Load and Prepare Images === + source_path = "examples/image_pairs/fire_academy_0.png" + target_path = "examples/image_pairs/fire_academy_1.png" + + source_image = cv2.imread(source_path) + target_image = cv2.imread(target_path) + + source_image = cv2.cvtColor(source_image, cv2.COLOR_BGR2RGB) + target_image = cv2.cvtColor(target_image, cv2.COLOR_BGR2RGB) + + # === Predict Correspondences === + result = model.predict_correspondences_batched( + source_image=torch.from_numpy(source_image), + target_image=torch.from_numpy(target_image), + ) + + flow_output = result.flow.flow_output[0].cpu().numpy() + covisibility = result.covisibility.mask[0].cpu().numpy() + + # === Visualize Results === + fig, axs = plt.subplots(2, 3, figsize=(15, 5)) + + axs[0, 0].imshow(source_image) + axs[0, 0].set_title("Source Image") + + axs[0, 1].imshow(target_image) + axs[0, 1].set_title("Target Image") + + # Warp the image using flow + warped_image = warp_image_with_flow(source_image, None, target_image, flow_output.transpose(1, 2, 0)) + warped_image = covisibility[..., None] * warped_image + (1 - covisibility[..., None]) * 255 * np.ones_like( + warped_image + ) + warped_image /= 255.0 + + axs[0, 2].imshow(warped_image) + axs[0, 2].set_title("Warped Image") + + # Flow visualization + flow_vis_image = flow_vis.flow_to_color(flow_output.transpose(1, 2, 0)) + axs[1, 0].imshow(flow_vis_image) + axs[1, 0].set_title("Flow Output (Valid at covisible region)") + + # Covisibility mask + axs[1, 1].imshow(covisibility > 0.5, cmap="gray", vmin=0, vmax=1) + axs[1, 1].set_title("Covisibility Mask (Thresholded by 0.5)") + + heatmap = axs[1, 2].imshow(covisibility, cmap="gray", vmin=0, vmax=1) + axs[1, 2].set_title("Covisibility Mask") + plt.colorbar(heatmap, ax=axs[1, 2]) + + plt.tight_layout() + plt.savefig("ufm_output.png") + plt.show() + print("Saved ufm_output.png") \ No newline at end of file diff --git a/uniflowmatch/models/unet_encoder.py b/uniflowmatch/models/unet_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b242f28f7059bf9b7f5fde42e7ba2c3b9114c768 --- /dev/null +++ b/uniflowmatch/models/unet_encoder.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DoubleConv(nn.Module): + """(Conv2d => ReLU) * 2 with padding""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), # preserve spatial + nn.ReLU(inplace=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), # preserve spatial + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.conv(x) + + +class UNet(nn.Module): + def __init__(self, in_channels, out_channels, features=[64, 128, 256, 512]): + super().__init__() + self.downs = nn.ModuleList() + self.ups = nn.ModuleList() + + # Downsampling part + for feature in features: + self.downs.append(DoubleConv(in_channels, feature)) + in_channels = feature + + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + # Bottleneck + self.bottleneck = DoubleConv(features[-1], features[-1] * 2) + + # Upsampling part + for feature in reversed(features): + self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)) + self.ups.append(DoubleConv(feature * 2, feature)) + + # Final output + self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) + + def forward(self, x): + skip_connections = [] + + # Encoder + for down in self.downs: + x = down(x) + skip_connections.append(x) + x = self.pool(x) + + x = self.bottleneck(x) + + # Decoder + skip_connections = skip_connections[::-1] + for idx in range(0, len(self.ups), 2): + x = self.ups[idx](x) # ConvTranspose2d + skip_connection = skip_connections[idx // 2] + if x.shape != skip_connection.shape: + x = F.interpolate(x, size=skip_connection.shape[2:]) # Fix mismatched shapes + x = torch.cat((skip_connection, x), dim=1) + x = self.ups[idx + 1](x) + + return self.final_conv(x) + + +if __name__ == "__main__": + model = UNet(in_channels=3, out_channels=16) + x = torch.randn(1, 3, 256, 256) + out = model(x) + print(out.shape) diff --git a/uniflowmatch/models/utils.py b/uniflowmatch/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..11c3869654aab9adcdd3ca1b8a883b81a204e371 --- /dev/null +++ b/uniflowmatch/models/utils.py @@ -0,0 +1,12 @@ +from functools import lru_cache + +import torch + + +@lru_cache(maxsize=10) +def get_meshgrid_torch(W, H, device): + u, v = torch.meshgrid(torch.arange(W, device=device).float(), torch.arange(H, device=device).float(), indexing="xy") + + uv = torch.stack((u, v), dim=-1) + + return uv diff --git a/uniflowmatch/utils/flow_resizing.py b/uniflowmatch/utils/flow_resizing.py new file mode 100644 index 0000000000000000000000000000000000000000..30d113b58c1568280bdfdcbd886c726b92a950b7 --- /dev/null +++ b/uniflowmatch/utils/flow_resizing.py @@ -0,0 +1,1087 @@ +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F + + +class ImagePairsManipulationBase: + def __init__(self): + pass + + def __call__( + self, + img0: torch.Tensor, + img1: torch.Tensor, + img0_region_source: torch.Tensor, + img1_region_source: torch.Tensor, + img0_region_representation: torch.Tensor, + img1_region_representation: torch.Tensor, + ): + """ + Apply resizing, cropping, and padding to image pairs while recording correspondence information. + + Args: + - img0: Tensor of shape (B, H, W, C), dtype uint8 representing the first set of images. + - img1: Tensor of shape (B, H, W, C), dtype uint8 representing the second set of images. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + Returns: + - img0: Tensor of image0 after manipulation. + - img1: Tensor of image1 after manipulation. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + """ + raise NotImplementedError + + def output_shape(self, H: int, W: int) -> Tuple[int, int]: + """ + Compute the output shape of the image after the resize operation. + + Args: + - H: Height of the first image. + - W: Width of the first image. + + Returns: + Tuple of (H1, W1, H2, W2) representing the output shape of the images if the manipulation is applied. + """ + + raise NotImplementedError + + def output_shape_pairs(self, H1: int, W1: int, H2: int, W2: int) -> Tuple[int, int, int, int]: + """ + Compute the output shape of the image after the resize operation. + """ + + output1 = self.output_shape(H1, W1) + output2 = self.output_shape(H2, W2) + + return output1[0], output1[1], output2[0], output2[1] + + def check_input(self, H: int, W: int) -> bool: + """ + Check whether the input shapes are correct for the current manipulation. + + Args: + - H: Height of the first image. + - W: Width of the first image. + + Returns: + Whether the manipualtion can run on the given input shapes. + """ + raise NotImplementedError + + def check_input_pairs(self, H1: int, W1: int, H2: int, W2: int) -> bool: + return self.check_input(H1, W1) and self.check_input(H2, W2) + + +class ResizeHorizontalAxisManipulation(ImagePairsManipulationBase): + def __init__(self, horizontal_axis: int): + self.horizontal_axis = horizontal_axis + + def output_shape(self, H: int, W: int) -> Tuple[int, int]: + """ + Compute the output shape of the image after the resize operation. + """ + resize_ratio = self.horizontal_axis / W + + return (int(H * resize_ratio), self.horizontal_axis) + + def check_input(self, H: int, W: int) -> bool: + return True + + def __call__( + self, + img0: torch.Tensor, + img1: torch.Tensor, + img0_region_source: torch.Tensor, + img1_region_source: torch.Tensor, + img0_region_representation: torch.Tensor, + img1_region_representation: torch.Tensor, + ): + """ + Apply resizing, cropping, and padding to image pairs while recording correspondence information. + + Args: + - img0: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the first set of images. + - img1: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the second set of images. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + Returns: + - img0: Tensor of image0 after manipulation. + - img1: Tensor of image1 after manipulation. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + """ + # assert img0.shape == img1.shape, "Image shapes must match" + + _, h0, w0, _ = img0.shape + _, h1, w1, _ = img1.shape + + target_h0, target_w0, target_h1, target_w1 = self.output_shape_pairs(h0, w0, h1, w1) + + assert img0.dtype == img1.dtype, "Image types must match" + is_uint8 = img0.dtype == torch.uint8 + + img0_resized = F.interpolate( + img0.permute(0, 3, 1, 2).float(), size=(target_h0, target_w0), mode="bilinear", align_corners=False + ).permute(0, 2, 3, 1) + img1_resized = F.interpolate( + img1.permute(0, 3, 1, 2).float(), size=(target_h1, target_w1), mode="bilinear", align_corners=False + ).permute(0, 2, 3, 1) + + if is_uint8: + img0_resized = img0_resized.to(torch.uint8) + img1_resized = img1_resized.to(torch.uint8) + + h_mult0 = target_h0 / h0 + w_mult0 = target_w0 / w0 + + multplier0 = torch.tensor([h_mult0, h_mult0, w_mult0, w_mult0]) + + h_mult1 = target_h1 / h1 + w_mult1 = target_w1 / w1 + + multplier1 = torch.tensor([h_mult1, h_mult1, w_mult1, w_mult1]) + + # source region is unchanged + # target region is scaled + img0_region_representation = multplier0 * img0_region_representation + img1_region_representation = multplier1 * img1_region_representation + + return ( + img0_resized, + img1_resized, + img0_region_source, + img1_region_source, + img0_region_representation, + img1_region_representation, + ) + + +class ResizeVerticalAxisManipulation(ImagePairsManipulationBase): + def __init__(self, vertical_axis: int): + self.vertical_axis = vertical_axis + + def output_shape(self, H: int, W: int) -> Tuple[int, int]: + """ + Compute the output shape of the image after the resize operation. + """ + + resize_ratio = self.vertical_axis / H + + return (self.vertical_axis, int(W * resize_ratio)) + + def check_input(self, H: int, W: int) -> bool: + return True + + def __call__( + self, + img0: torch.Tensor, + img1: torch.Tensor, + img0_region_source: torch.Tensor, + img1_region_source: torch.Tensor, + img0_region_representation: torch.Tensor, + img1_region_representation: torch.Tensor, + ): + """ + Apply resizing, cropping, and padding to image pairs while recording correspondence information. + + Args: + - img0: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the first set of images. + - img1: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the second set of images. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + Returns: + - img0: Tensor of image0 after manipulation. + - img1: Tensor of image1 after manipulation. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + """ + # assert img0.shape == img1.shape, "Image shapes must match" + + _, h0, w0, _ = img0.shape + _, h1, w1, _ = img1.shape + + target_h0, target_w0, target_h1, target_w1 = self.output_shape_pairs(h0, w0, h1, w1) + + assert img0.dtype == img1.dtype, "Image types must match" + is_uint8 = img0.dtype == torch.uint8 + + img0_resized = F.interpolate( + img0.permute(0, 3, 1, 2).float(), size=(target_h0, target_w0), mode="bilinear", align_corners=False + ).permute(0, 2, 3, 1) + img1_resized = F.interpolate( + img1.permute(0, 3, 1, 2).float(), size=(target_h1, target_w1), mode="bilinear", align_corners=False + ).permute(0, 2, 3, 1) + + if is_uint8: + img0_resized = img0_resized.to(torch.uint8) + img1_resized = img1_resized.to(torch.uint8) + + h_mult0 = target_h0 / h0 + w_mult0 = target_w0 / w0 + + multplier0 = torch.tensor([h_mult0, h_mult0, w_mult0, w_mult0]) + + h_mult1 = target_h1 / h1 + w_mult1 = target_w1 / w1 + + multplier1 = torch.tensor([h_mult1, h_mult1, w_mult1, w_mult1]) + + # source region is unchanged + # target region is scaled + img0_region_representation = multplier0 * img0_region_representation + img1_region_representation = multplier1 * img1_region_representation + + return ( + img0_resized, + img1_resized, + img0_region_source, + img1_region_source, + img0_region_representation, + img1_region_representation, + ) + + +class ResizeToFixedManipulation(ImagePairsManipulationBase): + def __init__(self, target_shape: Tuple[int, int]): + self.target_shape = target_shape + + def output_shape(self, H: int, W: int) -> Tuple[int, int]: + """ + Compute the output shape of the image after the resize operation. + """ + + return self.target_shape + + def check_input(self, H: int, W: int) -> bool: + return True + + def __call__( + self, + img0: torch.Tensor, + img1: torch.Tensor, + img0_region_source: torch.Tensor, + img1_region_source: torch.Tensor, + img0_region_representation: torch.Tensor, + img1_region_representation: torch.Tensor, + ): + """ + Apply resizing, cropping, and padding to image pairs while recording correspondence information. + + Args: + - img0: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the first set of images. + - img1: Tensor of shape (B, H, W, C), dtype float32 or uint8 representing the second set of images. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + Returns: + - img0: Tensor of image0 after manipulation. + - img1: Tensor of image1 after manipulation. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + """ + # assert img0.shape == img1.shape, "Image shapes must match" + + _, h0, w0, _ = img0.shape + _, h1, w1, _ = img1.shape + + target_h0, target_w0, target_h1, target_w1 = self.output_shape_pairs(h0, w0, h1, w1) + + assert img0.dtype == img1.dtype, "Image types must match" + is_uint8 = img0.dtype == torch.uint8 + + img0_resized = F.interpolate( + img0.permute(0, 3, 1, 2).float(), + size=(target_h0, target_w0), + mode="bilinear", + align_corners=False, + antialias=True, + ).permute(0, 2, 3, 1) + img1_resized = F.interpolate( + img1.permute(0, 3, 1, 2).float(), + size=(target_h1, target_w1), + mode="bilinear", + align_corners=False, + antialias=True, + ).permute(0, 2, 3, 1) + + if is_uint8: + img0_resized = img0_resized.to(torch.uint8) + img1_resized = img1_resized.to(torch.uint8) + + h_mult0 = target_h0 / h0 + w_mult0 = target_w0 / w0 + + multplier0 = torch.tensor([h_mult0, h_mult0, w_mult0, w_mult0]) + + h_mult1 = target_h1 / h1 + w_mult1 = target_w1 / w1 + + multplier1 = torch.tensor([h_mult1, h_mult1, w_mult1, w_mult1]) + + # source region is unchanged + # target region is scaled + img0_region_representation = (multplier0 * img0_region_representation).to(torch.int64) + img1_region_representation = (multplier1 * img1_region_representation).to(torch.int64) + + return ( + img0_resized, + img1_resized, + img0_region_source, + img1_region_source, + img0_region_representation, + img1_region_representation, + ) + + +def scale_axis( + source_low: float, + source_high: float, + reference_low: float, + reference_high: float, + reference_low_new: float, + reference_high_new: float, +): + reference_length = reference_high - reference_low + coordinate_relative_low = (reference_low_new - reference_low) / reference_length + coordinate_relative_high = (reference_high_new - reference_low) / reference_length + + source_length = source_high - source_low + source_low_new = source_low + coordinate_relative_low * source_length + source_high_new = source_low + coordinate_relative_high * source_length + + return source_low_new, source_high_new + + +class CenterCropManipulation(ImagePairsManipulationBase): + def __init__(self, target_size: Tuple[int, int]): + self.target_size = target_size + + def output_shape(self, H: int, W: int) -> Tuple[int, int]: + """ + Compute the output shape of the image after the resize operation. + """ + + return self.target_size + + def check_input(self, H: int, W: int) -> bool: + return H >= self.target_size[0] and W >= self.target_size[1] + + def __call__( + self, + img0: torch.Tensor, + img1: torch.Tensor, + img0_region_source: torch.Tensor, + img1_region_source: torch.Tensor, + img0_region_representation: torch.Tensor, + img1_region_representation: torch.Tensor, + ): + """ + Apply resizing, cropping, and padding to image pairs while recording correspondence information. + + Args: + - img0: Tensor of shape (B, H, W, C), dtype uint8 representing the first set of images. + - img1: Tensor of shape (B, H, W, C), dtype uint8 representing the second set of images. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + Returns: + - img0: Tensor of image0 after manipulation. + - img1: Tensor of image1 after manipulation. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + """ + + B0, H0, W0, C0 = img0.shape + B1, H1, W1, C1 = img1.shape + + target_h, target_w = self.target_size + + assert H0 >= target_h and W0 >= target_w, "Image shapes must be larger than the target size." + assert H1 >= target_h and W1 >= target_w, "Image shapes must be larger than the target size." + + crop_top_0 = (H0 - target_h) // 2 + crop_bottom_0 = H0 - target_h - crop_top_0 + crop_left_0 = (W0 - target_w) // 2 + crop_right_0 = W0 - target_w - crop_left_0 + + crop_top_1 = (H1 - target_h) // 2 + crop_bottom_1 = H1 - target_h - crop_top_1 + crop_left_1 = (W1 - target_w) // 2 + crop_right_1 = W1 - target_w - crop_left_1 + + # apply the crops + img0_cropped = img0[:, crop_top_0 : H0 - crop_bottom_0, crop_left_0 : W0 - crop_right_0, :] + img1_cropped = img1[:, crop_top_1 : H1 - crop_bottom_1, crop_left_1 : W1 - crop_right_1, :] + + # update the representation region accurately. This is complex as we may or may not crop out the valid regions. + remaining_region_0 = torch.tensor( + [ + max(img0_region_representation[0], crop_top_0), + min(img0_region_representation[1], H0 - crop_bottom_0), + max(img0_region_representation[2], crop_left_0), + min(img0_region_representation[3], W0 - crop_right_0), + ] + ) + + remaining_region_1 = torch.tensor( + [ + max(img1_region_representation[0], crop_top_1), + min(img1_region_representation[1], H1 - crop_bottom_1), + max(img1_region_representation[2], crop_left_1), + min(img1_region_representation[3], W1 - crop_right_1), + ] + ) + + # shift the remaining region as the cropped region is removed + img0_region_representation_new = remaining_region_0 - torch.tensor( + [crop_top_0, crop_top_0, crop_left_0, crop_left_0] + ) + img1_region_representation_new = remaining_region_1 - torch.tensor( + [crop_top_1, crop_top_1, crop_left_1, crop_left_1] + ) + + img0_region_representation_new = img0_region_representation_new.to(torch.int64) + img1_region_representation_new = img1_region_representation_new.to(torch.int64) + + # the valid region may or may not be cropped out, so we need to adjust the source region as well + img0_region_source[0], img0_region_source[1] = scale_axis( + img0_region_source[0], + img0_region_source[1], + img0_region_representation[0], + img0_region_representation[1], + remaining_region_0[0], + remaining_region_0[1], + ) + + img0_region_source[2], img0_region_source[3] = scale_axis( + img0_region_source[2], + img0_region_source[3], + img0_region_representation[2], + img0_region_representation[3], + remaining_region_0[2], + remaining_region_0[3], + ) + + img1_region_source[0], img1_region_source[1] = scale_axis( + img1_region_source[0], + img1_region_source[1], + img1_region_representation[0], + img1_region_representation[1], + remaining_region_1[0], + remaining_region_1[1], + ) + + img1_region_source[2], img1_region_source[3] = scale_axis( + img1_region_source[2], + img1_region_source[3], + img1_region_representation[2], + img1_region_representation[3], + remaining_region_1[2], + remaining_region_1[3], + ) + + return ( + img0_cropped, + img1_cropped, + img0_region_source, + img1_region_source, + img0_region_representation_new, + img1_region_representation_new, + ) + + +class ImagePairsManipulationComposite(ImagePairsManipulationBase): + def __init__(self, *manipulations: List[ImagePairsManipulationBase]): + self.manipulations = manipulations + + def output_shape(self, H: int, W: int) -> Tuple[int, int]: + """ + Compute the output shape of the image after the resize operation. + """ + + output_shape = (H, W) + for manipulation in self.manipulations: + output_shape = manipulation.output_shape(*output_shape) + + return output_shape + + def output_shape_pairs(self, H1: int, W1: int, H2: int, W2: int) -> Tuple[int, int, int, int]: + """ + Compute the output shape of the image after the resize operation. + """ + + output_shape = (H1, W1, H2, W2) + for manipulation in self.manipulations: + output_shape = manipulation.output_shape_pairs(*output_shape) + + return output_shape + + def check_input(self, H: int, W: int) -> bool: + current_shape = (H, W) + for manipulation in self.manipulations: + if not manipulation.check_input(*current_shape): + return False + + current_shape = manipulation.output_shape(*current_shape) + + return True + + def check_input_pairs(self, H1: int, W1: int, H2: int, W2: int) -> bool: + current_shape = (H1, W1, H2, W2) + for manipulation in self.manipulations: + if not manipulation.check_input_pairs(*current_shape): + return False + + current_shape = manipulation.output_shape_pairs(*current_shape) + + return True + + def __call__( + self, + img0: torch.Tensor, + img1: torch.Tensor, + img0_region_source: torch.Tensor, + img1_region_source: torch.Tensor, + img0_region_representation: torch.Tensor, + img1_region_representation: torch.Tensor, + ): # -> tuple[Tensor | Any, Tensor | Any, Tensor | Any, Tensor | ...: + """ + Apply resizing, cropping, and padding to image pairs while recording correspondence information. + + Args: + - img0: Tensor of shape (B, H, W, C), dtype uint8 representing the first set of images. + - img1: Tensor of shape (B, H, W, C), dtype uint8 representing the second set of images. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + Returns: + - img0: Tensor of image0 after manipulation. + - img1: Tensor of image1 after manipulation. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + """ + + for manipulation in self.manipulations: + ( + img0, + img1, + img0_region_source, + img1_region_source, + img0_region_representation, + img1_region_representation, + ) = manipulation( + img0, + img1, + img0_region_source, + img1_region_source, + img0_region_representation, + img1_region_representation, + ) + + return ( + img0, + img1, + img0_region_source, + img1_region_source, + img0_region_representation, + img1_region_representation, + ) + + +class AutomaticShapeSelection(ImagePairsManipulationBase): + def __init__(self, *manipulations: ImagePairsManipulationBase, strategy="closest_aspect"): + self.manipulations = manipulations + + if strategy == "closest_aspect": + self.strategy = self._closest_aspect_strategy + else: + raise ValueError("Unknown strategy") + + def output_shape(self, H: int, W: int) -> Tuple[int, int]: + """ + Compute the output shape of the image after the resize operation. + """ + + output_shape, augmentor = self.strategy(H, W) + + if output_shape is None: + raise ValueError("No valid shape found for the given resolution.") + + return output_shape + + def output_shape_pairs(self, H1: int, W1: int, H2: int, W2: int) -> Tuple[int, int, int, int]: + """ + Compute the output shape of the image after the resize operation. + """ + + output_shape, augmentor = self.strategy(H1, W1, H2, W2) + + if output_shape is None: + raise ValueError("No valid shape found for the given resolution.") + + return output_shape + + def check_input(self, H: int, W: int) -> bool: + output_shape, augmentor = self.strategy(H, W) + + if output_shape is None: + return False + + return True + + def check_input_pairs(self, H1: int, W1: int, H2: int, W2: int) -> bool: + output_shape, augmentor = self.strategy(H1, W1, H2, W2) + + if output_shape is None: + return False + + return True + + def _closest_aspect_strategy(self, H: int, W: int, *shape_img1): + # for all caididate sizes, first check if they can run at the given resolution + if shape_img1 is None: + runnable_sizes = [ + (manipulator.output_shape(H, W, *shape_img1), manipulator) + for manipulator in self.manipulations + if manipulator.check_input(H, W, *shape_img1) + ] + else: + runnable_sizes = [ + (manipulator.output_shape_pairs(H, W, *shape_img1), manipulator) + for manipulator in self.manipulations + if manipulator.check_input_pairs(H, W, *shape_img1) + ] + + if len(runnable_sizes) == 0: + return None, None + + # if there are runnable sizes, then select the one that is closest to the given resolution + if shape_img1 is None: + closest_size, closest_augmentor = min(runnable_sizes, key=lambda x: abs(x[0][0] / x[0][1] - H / W)) + else: + closest_size, closest_augmentor = min( + runnable_sizes, + key=lambda x: abs(x[0][0] / x[0][1] - H / W) + abs(x[0][2] / x[0][3] - shape_img1[0] / shape_img1[1]), + ) + + return closest_size, closest_augmentor + + def __call__( + self, + img0: torch.Tensor, + img1: torch.Tensor, + img0_region_source: Optional[torch.Tensor] = None, + img1_region_source: Optional[torch.Tensor] = None, + img0_region_representation: Optional[torch.Tensor] = None, + img1_region_representation: Optional[torch.Tensor] = None, + ): + """ + Apply resizing, cropping, and padding to image pairs while recording correspondence information. + + Args: + - img0: Tensor of shape (B, H, W, C), dtype uint8 representing the first set of images. + - img1: Tensor of shape (B, H, W, C), dtype uint8 representing the second set of images. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + Returns: + - img0: Tensor of image0 after manipulation. + - img1: Tensor of image1 after manipulation. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + """ + + H0, W0 = img0.shape[1], img0.shape[2] + H1, W1 = img1.shape[1], img1.shape[2] + + output_shape, augmentor = self.strategy(H0, W0, H1, W1) + + if output_shape is None: + raise ValueError("No valid shape found for the given resolution.") + + if img0_region_source is None: + assert img1_region_source is None + assert img0_region_representation is None + assert img1_region_representation is None + + img0_region_source = torch.tensor([0, H0, 0, W0]) + img1_region_source = torch.tensor([0, H1, 0, W1]) + img0_region_representation = torch.tensor([0, H0, 0, W0]) + img1_region_representation = torch.tensor([0, H1, 0, W1]) + + return augmentor( + img0, img1, img0_region_source, img1_region_source, img0_region_representation, img1_region_representation + ) + + +# unmap the predicted flow to match the input. Flow is unique semantically as its value changes +# depending on the source and target region. +def unmap_predicted_flow( + flow: torch.Tensor, + img0_region_representation: torch.Tensor, + img1_region_representation: torch.Tensor, + img0_region_source: torch.Tensor, + img1_region_source: torch.Tensor, + img0_source_shape: Tuple[int, int], + img1_source_shape: Tuple[int, int], +): + """ + Unmap the predicted flow to the original image space. + + Args: + - flow: Tensor of shape (B, 2, H, W) representing the predicted flow between the two regions. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + Returns: + - flow: Tensor of shape (B, 2, H, W) representing the predicted flow in the original image space. + """ + + B, C, H, W = flow.shape + + # Step 1: Zero the start of flow representing mapping in model's output space + # the flow end is the source coordinates + the flow + flow_roi = flow[ + ..., + img0_region_representation[0] : img0_region_representation[1], + img0_region_representation[2] : img0_region_representation[3], + ] + + source_offset = torch.tensor([img0_region_source[2], img0_region_source[0]]).to(flow.device) + + target_offset = torch.tensor([img1_region_source[2], img1_region_source[0]]).to(flow.device) + + flow_valid2valid = flow_roi # + (source_offset - target_offset).view(1, 2, 1, 1) + + # Step 2: Represent the flow as pairs of source and target coordinates + source_coordinates = ( + torch.stack( + torch.meshgrid( + torch.arange(0, flow_valid2valid.shape[3]) + 0.5, + torch.arange(0, flow_valid2valid.shape[2]) + 0.5, + indexing="xy", + ), + dim=-1, + ) + .permute(2, 0, 1) + .unsqueeze(0) + .to(flow.device) + ) + + # Step 3: Scale the flow to the source space. Notice that here we can actually assume + # valid representation space have the same shape. + # So it looks like both source and target coordinates are scaled according to the source representation. + + # now we scale the valid2valid flow from representation space to source space + source_valid_shape = torch.tensor( + [img0_region_source[1] - img0_region_source[0], img0_region_source[3] - img0_region_source[2]] + ) + + target_valid_shape = torch.tensor( + [img1_region_source[1] - img1_region_source[0], img1_region_source[3] - img1_region_source[2]] + ) + + # upscale source and target coordinates to the source space + source_coordinates_valid = F.interpolate( + source_coordinates.float(), size=source_valid_shape.tolist(), mode="bilinear", align_corners=False + ) + + # This is equivalently we define "target_coordinates = source_coordinates + flow_valid2valid" and apply the scaling. + # since we have a flow component, we can only do nearest interpolation, but this will cause ~0.5 pixel error + # because we are interpoling the source_coordinates also linearly. + + target_coordinates_valid = ( + F.interpolate(flow_valid2valid.float(), size=source_valid_shape.tolist(), mode="nearest") + + source_coordinates_valid + ) + + # print("Change me to nearest interpolation") + + # apply different scaling to the flow: representation for source maps to source_valid_shape in source space + source_coordinates_valid *= ( + torch.tensor( + [ + source_valid_shape[1] / (img0_region_representation[3] - img0_region_representation[2]), + source_valid_shape[0] / (img0_region_representation[1] - img0_region_representation[0]), + ] + ) + .view(1, 2, 1, 1) + .to(flow.device) + ) + + # target coordinates are scaled to the target source space, which may be different from the source space + target_coordinates_valid *= ( + torch.tensor( + [ + target_valid_shape[1] / (img0_region_representation[3] - img0_region_representation[2]), + target_valid_shape[0] / (img0_region_representation[1] - img0_region_representation[0]), + ] + ) + .view(1, 2, 1, 1) + .to(flow.device) + ) + + # Step 4: Offset the flow from valid source space to the original source space + source_coordinates_valid += ( + torch.tensor([img0_region_source[2], img0_region_source[0]]).view(1, 2, 1, 1).to(flow.device) + ) + + target_coordinates_valid += ( + torch.tensor([img1_region_source[2], img1_region_source[0]]).view(1, 2, 1, 1).to(flow.device) + ) + + # now we can compute the flow in the source space + flow_source = target_coordinates_valid - source_coordinates_valid + + # Step5: Embed the flow in its original space + flow_output = torch.zeros((B, 2, img0_source_shape[0], img0_source_shape[1]), dtype=flow.dtype, device=flow.device) + + flow_output[ + ..., img0_region_source[0] : img0_region_source[1], img0_region_source[2] : img0_region_source[3] + ] = flow_source + + flow_valid = torch.zeros((B, img0_source_shape[0], img0_source_shape[1]), dtype=torch.bool, device=flow.device) + flow_valid[..., img0_region_source[0] : img0_region_source[1], img0_region_source[2] : img0_region_source[3]] = True + + return flow_output, flow_valid + + +# unmap predicted source - target point pairs. +def unmap_predicted_pairs( + source_points: torch.Tensor, + target_points: torch.Tensor, + img0_region_representation: torch.Tensor, + img1_region_representation: torch.Tensor, + img0_region_source: torch.Tensor, + img1_region_source: torch.Tensor, + img0_source_shape: Tuple[int, int], + img1_source_shape: Tuple[int, int], +): + """ + Unmap the predicted flow to the original image space. + + Args: + - source_points: Tensor of shape (B, N, 2) representing the predicted source points. + - target_points: Tensor of shape (B, N, 2) representing the predicted target points. + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img1_region_representation: Tensor of size 4 representing the region of img1 in current representation space corresponding to the source region. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img1_region_source: Tensor of size 4 representing the region of img1 that is the source of the correspondence. + Returns: + - flow: Tensor of shape (B, 2, H, W) representing the predicted flow in the original image space. + """ + + # 1. scale source points & target points from representation space to source space + img0_region_source_shape = torch.tensor( + [img0_region_source[1] - img0_region_source[0], img0_region_source[3] - img0_region_source[2]] + ) + + img1_region_source_shape = torch.tensor( + [img1_region_source[1] - img1_region_source[0], img1_region_source[3] - img1_region_source[2]] + ) + + source_points[:, :, 0], _ = scale_axis( + img0_region_source[2], + img0_region_source[3], + img0_region_representation[2], + img0_region_representation[3], + source_points[:, :, 0], + 0.0, + ) + + source_points[:, :, 1], _ = scale_axis( + img0_region_source[0], + img0_region_source[1], + img0_region_representation[0], + img0_region_representation[1], + source_points[:, :, 1], + 0.0, + ) + + target_points[:, :, 0], _ = scale_axis( + img1_region_source[2], + img1_region_source[3], + img1_region_representation[2], + img1_region_representation[3], + target_points[:, :, 0], + 0.0, + ) + + target_points[:, :, 1], _ = scale_axis( + img1_region_source[0], + img1_region_source[1], + img1_region_representation[0], + img1_region_representation[1], + target_points[:, :, 1], + 0.0, + ) + + return source_points, target_points + + +# unmap normal channels like confidence, depth, etc. +# much simpler than the flow case +def unmap_predicted_channels( + channel: torch.Tensor, + img0_region_representation: torch.Tensor, + img0_region_source: torch.Tensor, + img0_source_shape: Tuple[int, int], +): + """ + Unmap the predicted flow to the original image space. + + Args: + - channel: Tensor of shape (B, C, H, W) representing the predicted values in img0 representation space + - img0_region_representation: Tensor of size 4 representing the region of img0 in current representation space corresponding to the source region. + - img0_region_source: Tensor of size 4 representing the region of img0 that is the source of the correspondence. + - img0_source_shape: Tuple of size 2 representing the shape of the original image. + Returns: + - channel: Tensor of shape (B, C, H, W) representing the predicted flow in the original image space. + - channel_valid: Tensor of shape (B, H, W) representing the valid region of the channel in the original image space. + """ + + B, C, H, W = channel.shape + + # Step 1: Zero the start of flow representing mapping in model's output space + # the flow end is the source coordinates + the flow + channel_roi = channel[ + ..., + img0_region_representation[0] : img0_region_representation[1], + img0_region_representation[2] : img0_region_representation[3], + ] + + # upscale the channel roi into source space roi + img0_valid_shape = torch.tensor( + [img0_region_source[1] - img0_region_source[0], img0_region_source[3] - img0_region_source[2]] + ) + + channel_source_roi = F.interpolate( + channel_roi, + size=img0_valid_shape.tolist(), + mode="nearest", + # align_corners=False + ) + + channel_output = torch.zeros( + (B, C, img0_source_shape[0], img0_source_shape[1]), dtype=channel.dtype, device=channel.device + ) + channel_output[ + ..., img0_region_source[0] : img0_region_source[1], img0_region_source[2] : img0_region_source[3] + ] = channel_source_roi + + channel_valid = torch.zeros( + (B, img0_source_shape[0], img0_source_shape[1]), dtype=torch.bool, device=channel.device + ) + channel_valid[ + ..., img0_region_source[0] : img0_region_source[1], img0_region_source[2] : img0_region_source[3] + ] = True + + return channel_output, channel_valid + + +if __name__ == "__main__": + import sys + + import matplotlib.pyplot as plt + import numpy as np + import torch + import torch.nn.functional as F + + # make a example test image that have flow in only one pixel from (25%, 25%) to (50%, 75%) of the image. + img0 = torch.zeros((1, 145, 256, 3), dtype=torch.uint8) # one below and one above the aspect (288, 512) + img1 = torch.zeros((1, 135, 256, 3), dtype=torch.uint8) + + source_pt = img0.shape[1] * 0.25, img0.shape[2] * 0.25 + target_pt = img1.shape[1] * 0.5, img1.shape[2] * 0.75 + + img0[0, int(source_pt[0]), int(source_pt[1]), :] = 255 + img1[0, int(target_pt[0]), int(target_pt[1]), :] = 255 + + flow_gt = torch.zeros((1, 2, 145, 256)) + flow_gt[0, :, int(source_pt[0]), int(source_pt[1])] = torch.tensor( + [target_pt[1] - source_pt[1], target_pt[0] - source_pt[0]] + ) + + H0, W0 = img0.shape[1], img0.shape[2] + H1, W1 = img1.shape[1], img1.shape[2] + + manipulation = AutomaticShapeSelection( + ImagePairsManipulationComposite(ResizeHorizontalAxisManipulation(512), CenterCropManipulation((288, 512))), + ImagePairsManipulationComposite(ResizeHorizontalAxisManipulation(512), CenterCropManipulation((200, 512))), + ) + + ( + img0_resized, + img1_resized, + img0_region_source, + img1_region_source, + img0_region_representation, + img1_region_representation, + ) = manipulation(img0, img1) + + fig, axs = plt.subplots(2, 3) + + axs[0, 0].imshow(img0[0].numpy()) + axs[0, 1].imshow(img0_resized[0].numpy()) + + axs[1, 0].imshow(img1[0].numpy()) + axs[1, 1].imshow(img1_resized[0].numpy()) + + print(img0_region_source) + print(img1_region_source) + print(img0_region_representation) + print(img1_region_representation) + + flow_pred = torch.zeros((1, 2, 288, 512)) + flow_pred[0, :, 28, 128] = torch.tensor([256, 72]) + + # unmap the flow + flow_unmapped = unmap_predicted_flow( + flow_pred, + img0_region_representation, + img1_region_representation, + img0_region_source, + img1_region_source, + (H0, W0), + (H1, W1), + ) + + flow_unmapped, flow_validity = flow_unmapped + flow_unmapped = flow_unmapped[0] + flow_validity = flow_validity[0] + + import flow_vis + + flow_rgb = flow_vis.flow_to_color(flow_unmapped.permute(1, 2, 0).numpy()) + axs[0, 2].imshow(flow_validity) + + plt.figure() + plt.imshow(flow_rgb) + plt.show() diff --git a/uniflowmatch/utils/geometry.py b/uniflowmatch/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0d507173a07e4f089a6e215dbad04f04e49680 --- /dev/null +++ b/uniflowmatch/utils/geometry.py @@ -0,0 +1,612 @@ +#!/usr/bin/env python3 +# -------------------------------------------------------- +# Utils for geometric calculations +# Adopted from AnyMap(Nilhil Keetha) +# Includes functions from DUSt3R (Naver Corporation, CC BY-NC-SA 4.0 (non-commercial use only)) & GradSLAM (MIT License) +# -------------------------------------------------------- +from functools import lru_cache + +import einops as ein +import numpy as np +import torch + + +def depthmap_to_camera_frame(depthmap, intrinsics): + """ + Convert depth image to a pointcloud in camera frame. + Args: + - depthmap: HxW torch tensor + - camera_intrinsics: 3x3 torch tensor + Returns: + pointmap in camera frame (HxWx3 tensor), and a mask specifying valid pixels. + """ + height, width = depthmap.shape + device = depthmap.device + fx = intrinsics[0, 0] + fy = intrinsics[1, 1] + cx = intrinsics[0, 2] + cy = intrinsics[1, 2] + + # Compute 3D point in camera frame associated with each pixel + x_grid, y_grid = torch.meshgrid( + torch.arange(width).to(device).float(), torch.arange(height).to(device).float(), indexing="xy" + ) + depth_z = depthmap + xx = (x_grid - cx) * depth_z / fx + yy = (y_grid - cy) * depth_z / fy + pts3d_cam = torch.stack((xx, yy, depth_z), dim=-1) + + # Compute mask of valid non-zero depth pixels + valid_mask = depthmap > 0.0 + + return pts3d_cam, valid_mask + + +def depthmap_to_world_frame(depthmap, intrinsics, camera_pose=None): + """ + Convert depth image to a pointcloud in world frame. + + Args: + - depthmap: HxW torch tensor + - camera_intrinsics: 3x3 torch tensor + - camera_pose: 4x4 torch tensor + + Returns: + pointmap in world frame (HxWx3 tensor), and a mask specifying valid pixels. + """ + pts3d_cam, valid_mask = depthmap_to_camera_frame(depthmap, intrinsics) + + if camera_pose is not None: + pts3d_cam_homo = torch.cat([pts3d_cam, torch.ones_like(pts3d_cam[..., :1])], dim=-1) + pts3d_world = ein.einsum(camera_pose, pts3d_cam_homo, "i k, h w k -> h w i") + pts3d_world = pts3d_world[..., :3] + + return pts3d_world, valid_mask + + +def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw): + """Output a (H,W,2) array of int32 + with output[j,i,0] = i + origin[0] + output[j,i,1] = j + origin[1] + """ + if device is None: + # numpy + arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones + else: + # torch + arange = lambda *a, **kw: torch.arange(*a, device=device, **kw) + meshgrid, stack = torch.meshgrid, torch.stack + ones = lambda *a: torch.ones(*a, device=device) + + tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)] + grid = meshgrid(tw, th, indexing="xy") + if homogeneous: + grid = grid + (ones((H, W)),) + if unsqueeze is not None: + grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze)) + if cat_dim is not None: + grid = stack(grid, cat_dim) + + return grid + + +def geotrf(Trf, pts, ncol=None, norm=False): + """Apply a geometric transformation to a list of 3-D points. + + H: 3x3 or 4x4 projection matrix (typically a Homography) + p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) + + ncol: int. number of columns of the result (2 or 3) + norm: float. if != 0, the resut is projected on the z=norm plane. + + Returns an array of projected 2d points. + """ + assert Trf.ndim >= 2 + if isinstance(Trf, np.ndarray): + pts = np.asarray(pts) + elif isinstance(Trf, torch.Tensor): + pts = torch.as_tensor(pts, dtype=Trf.dtype) + + # adapt shape if necessary + output_reshape = pts.shape[:-1] + ncol = ncol or pts.shape[-1] + + # optimized code + if isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and Trf.ndim == 3 and pts.ndim == 4: + d = pts.shape[3] + if Trf.shape[-1] == d: + pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) + elif Trf.shape[-1] == d + 1: + pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d] + else: + raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}") + else: + if Trf.ndim >= 3: + n = Trf.ndim - 2 + assert Trf.shape[:n] == pts.shape[:n], "batch size does not match" + Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) + + if pts.ndim > Trf.ndim: + # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) + pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) + elif pts.ndim == 2: + # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) + pts = pts[:, None, :] + + if pts.shape[-1] + 1 == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] + elif pts.shape[-1] == Trf.shape[-1]: + Trf = Trf.swapaxes(-1, -2) # transpose Trf + pts = pts @ Trf + else: + pts = Trf @ pts.T + if pts.ndim >= 2: + pts = pts.swapaxes(-1, -2) + + if norm: + pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG + if norm != 1: + pts *= norm + + res = pts[..., :ncol].reshape(*output_reshape, ncol) + + return res + + +def inv(mat): + """Invert a torch or numpy matrix""" + if isinstance(mat, torch.Tensor): + return torch.linalg.inv(mat) + if isinstance(mat, np.ndarray): + return np.linalg.inv(mat) + raise ValueError(f"bad matrix type = {type(mat)}") + + +def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_): + """ + Args: + - depthmap (BxHxW array): + - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W] + Returns: + pointmap of absolute coordinates (BxHxWx3 array) + """ + + if len(depth.shape) == 4: + B, H, W, n = depth.shape + else: + B, H, W = depth.shape + n = None + + if len(pseudo_focal.shape) == 3: # [B,H,W] + pseudo_focalx = pseudo_focaly = pseudo_focal + elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W] + pseudo_focalx = pseudo_focal[:, 0] + if pseudo_focal.shape[1] == 2: + pseudo_focaly = pseudo_focal[:, 1] + else: + pseudo_focaly = pseudo_focalx + else: + raise NotImplementedError("Error, unknown input focal shape format.") + + assert pseudo_focalx.shape == depth.shape[:3] + assert pseudo_focaly.shape == depth.shape[:3] + grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None] + + # set principal point + if pp is None: + grid_x = grid_x - (W - 1) / 2 + grid_y = grid_y - (H - 1) / 2 + else: + grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None] + grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None] + + if n is None: + pts3d = torch.empty((B, H, W, 3), device=depth.device) + pts3d[..., 0] = depth * grid_x / pseudo_focalx + pts3d[..., 1] = depth * grid_y / pseudo_focaly + pts3d[..., 2] = depth + else: + pts3d = torch.empty((B, H, W, 3, n), device=depth.device) + pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None] + pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None] + pts3d[..., 2, :] = depth + return pts3d + + +@lru_cache(maxsize=10) +def get_meshgrid(W, H): + u, v = np.meshgrid(np.arange(W), np.arange(H)) + return u, v + + +@lru_cache(maxsize=10) +def get_meshgrid_torch(W, H, device): + u, v = torch.meshgrid(torch.arange(W, device=device).float(), torch.arange(H, device=device).float(), indexing="xy") + + uv = torch.stack((u, v), dim=-1) + + return uv + + +def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + u, v = get_meshgrid(W, H) + + X_cam = np.zeros((H, W, 3), dtype=np.float32) + + X_cam[..., 0] = (u - cu) * depthmap / fu + X_cam[..., 1] = (v - cv) * depthmap / fv + X_cam[..., 2] = depthmap + + # Mask for valid coordinates + valid_mask = depthmap > 0.0 + + return X_cam, valid_mask + + +def z_depthmap_to_norm_depthmap(z_depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - z_depthmap (HxW array) + - camera_intrinsics: a 3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = z_depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + rays = np.ones((H, W, 3), dtype=np.float32) + + u, v = get_meshgrid(W, H) + + rays[..., 0] = (u - cu) / fu + rays[..., 1] = (v - cv) / fv + + ray_norm = np.linalg.norm(rays, axis=-1) + + return z_depthmap * ray_norm + + +def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw): + """ + Args: + - depthmap (HxW array): + - camera_intrinsics: a 3x3 matrix + - camera_pose: a 4x3 or 4x4 cam2world matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.""" + X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics) + + X_world = X_cam # default + if camera_pose is not None: + # R_cam2world = np.float32(camera_params["R_cam2world"]) + # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze() + R_cam2world = camera_pose[:3, :3] + t_cam2world = camera_pose[:3, 3] + + # Express in absolute coordinates (invalid depth values) + # X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :] + X_world = X_cam @ (R_cam2world.T) + t_cam2world[None, None, :] + + return X_world, valid_mask + + +def global_points_to_local(pts, camera_pose): + """ + Args: + - pts: points in world coordinate + - camera_pose: camera to world transformation + """ + + world_to_camera = np.linalg.inv(camera_pose) + R_world2cam = world_to_camera[:3, :3] + t_world2cam = world_to_camera[:3, 3] + + pts_local = np.einsum("ik, vuk -> vui", R_world2cam, pts) + t_world2cam[None, None, :] + + return pts_local + + +def project_points_to_pixels(pts_camera, camera_intrinsics, pseudo_focal=None): + """ + Args: + - pts_camera (HxWx3 array): points in camera coordinates + - camera_intrinsics: a 3x3 matrix + Returns: + pixel coordinates (HxWx2 array), and a mask specifying valid pixels. + """ + camera_intrinsics = np.float32(camera_intrinsics) + H, W = pts_camera.shape[:2] + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + assert camera_intrinsics[0, 1] == 0.0 + assert camera_intrinsics[1, 0] == 0.0 + if pseudo_focal is None: + fu = camera_intrinsics[0, 0] + fv = camera_intrinsics[1, 1] + else: + assert pseudo_focal.shape == (H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[0, 2] + cv = camera_intrinsics[1, 2] + + x, y, z = pts_camera[..., 0], pts_camera[..., 1], pts_camera[..., 2] + + uv = np.zeros((H, W, 2), dtype=np.float32) + + uv[..., 0] = fu * x / z + cu + uv[..., 1] = fv * y / z + cv + + # Mask for valid coordinates + valid_mask = ( + (z > 0.0) & (uv[..., 0] >= -0.5) & (uv[..., 0] < W - 0.5) & (uv[..., 1] >= -0.5) & (uv[..., 1] < H - 0.5) + ) + # valid_mask = (z > 0.0) & (uv[..., 0] >= 0) & (uv[..., 0] < W) & (uv[..., 1] >= 0) & (uv[..., 1] < H) + + return uv, valid_mask + + +def project_points_to_pixels_batched(pts_camera, camera_intrinsics, pseudo_focal=None): + """ + Args: + - pts_camera (BxHxWx3 torch.Tensor): points in camera coordinates + - camera_intrinsics: a Bx3x3 torch.Tensor + Returns: + pixel coordinates (BxHxWx2 torch.Tensor), and a mask (BxHxW) specifying valid pixels. + """ + camera_intrinsics = camera_intrinsics + B, H, W, C = pts_camera.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + assert (camera_intrinsics[..., 0, 1] == 0.0).all() + assert (camera_intrinsics[..., 1, 0] == 0.0).all() + if pseudo_focal is None: + fu = camera_intrinsics[..., 0, 0] + fv = camera_intrinsics[..., 1, 1] + else: + assert pseudo_focal.shape == (B, H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[..., 0, 2] + cv = camera_intrinsics[..., 1, 2] + + x, y, z = pts_camera[..., 0], pts_camera[..., 1], pts_camera[..., 2] + + uv = torch.zeros((B, H, W, 2), dtype=pts_camera.dtype, device=pts_camera.device) + + uv[..., 0] = fu.view(B, 1, 1) * x / z + cu.view(B, 1, 1) + uv[..., 1] = fv.view(B, 1, 1) * y / z + cv.view(B, 1, 1) + + # Mask for valid coordinates + valid_mask = ( + (z > 0.0) & (uv[..., 0] >= -0.5) & (uv[..., 0] < W - 0.5) & (uv[..., 1] >= -0.5) & (uv[..., 1] < H - 0.5) + ) + # valid_mask = (z > 0.0) & (uv[..., 0] >= 0) & (uv[..., 0] < W) & (uv[..., 1] >= 0) & (uv[..., 1] < H) + + return uv, valid_mask + + +def z_depthmap_to_norm_depthmap_batched(z_depthmap, camera_intrinsics, pseudo_focal=None): + """ + Args: + - z_depthmap (BxHxW array) + - camera_intrinsics: a Bx3x3 matrix + Returns: + pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels. + """ + + B, H, W = z_depthmap.shape + + # Compute 3D ray associated with each pixel + # Strong assumption: there are no skew terms + assert (camera_intrinsics[..., 0, 1] == 0.0).all() + assert (camera_intrinsics[..., 1, 0] == 0.0).all() + if pseudo_focal is None: + fu = camera_intrinsics[..., 0, 0] + fv = camera_intrinsics[..., 1, 1] + else: + assert pseudo_focal.shape == (B, H, W) + fu = fv = pseudo_focal + cu = camera_intrinsics[..., 0, 2] + cv = camera_intrinsics[..., 1, 2] + + rays = torch.ones((B, H, W, 3), dtype=z_depthmap.dtype, device=z_depthmap.device) + + uv = get_meshgrid_torch(W, H, device=z_depthmap.device) + + rays[..., 0] = (uv[..., 0].view(1, H, W) - cu.view(B, 1, 1)) / fu.view(B, 1, 1) + rays[..., 1] = (uv[..., 1].view(1, H, W) - cv.view(B, 1, 1)) / fv.view(B, 1, 1) + + ray_norm = torch.linalg.norm(rays, axis=-1) + + return z_depthmap * ray_norm + + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] -= 0.5 + K[1, 2] -= 0.5 + + return K + + +def opencv_to_colmap_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[0, 2] += 0.5 + K[1, 2] += 0.5 + + return K + + +@torch.no_grad() +def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5): + # set invalid points to NaN + _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1) + _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None + _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1 + + # compute median depth overall (ignoring nans) + if quantile == 0.5: + shift_z = torch.nanmedian(_z, dim=-1).values + else: + shift_z = torch.nanquantile(_z, quantile, dim=-1) + + return shift_z # (B,) + + +@torch.no_grad() +def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True): + # set invalid points to NaN + _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3) + _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None + _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1 + + # compute median center + _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3) + if z_only: + _center[..., :2] = 0 # do not center X and Y + + # compute median norm + _norm = ((_pts - _center) if center else _pts).norm(dim=-1) + scale = torch.nanmedian(_norm, dim=1).values + + return _center[:, None, :, :], scale[:, None, None, None] + + +def find_reciprocal_matches(P1, P2): + """ + returns 3 values: + 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match + 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1 + 3 - reciprocal_in_P2.sum(): the number of matches + """ + tree1 = KDTree(P1) + tree2 = KDTree(P2) + + _, nn1_in_P2 = tree2.query(P1, workers=8) + _, nn2_in_P1 = tree1.query(P2, workers=8) + + reciprocal_in_P1 = nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)) + reciprocal_in_P2 = nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)) + assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum() + + return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum() + + +def rotate_vector_with_quaternion( + v: torch.Tensor, quat: torch.Tensor, scalar_first: bool = False, skip_norm=False +) -> torch.Tensor: + """ + Rotate a 3D vector by a quaternion. + + Args: + v (torch.Tensor): A tensor of shape (..., 3) representing the vectors to rotate. + quat (torch.Tensor): A tensor of shape (..., 4) representing the quaternions. + The last dimension is [w, x, y, z] if scalar_first is True, + or [x, y, z, w] if scalar_first is False. + scalar_first (bool): If True, assumes the quaternion is in the format [w, x, y, z]. + Otherwise, assumes the format [x, y, z, w]. + + Returns: + torch.Tensor: A tensor of shape (..., 3) representing the rotated vectors. + """ + if scalar_first: + w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3] + else: + x, y, z, w = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3] + + # Normalize the quaternion to ensure a valid rotation + if not skip_norm: + norm_quat = torch.sqrt(w**2 + x**2 + y**2 + z**2 + 1e-8) + w, x, y, z = w / norm_quat, x / norm_quat, y / norm_quat, z / norm_quat + + # Vector part of the quaternion + q_vec = torch.stack([x, y, z], dim=-1) # Shape (..., 3) + + # Cross product q_vec x v + t = 2 * torch.cross(q_vec, v, dim=-1) # Intermediate vector, shape (..., 3) + + # Ensure proper broadcasting of w + v_rotated = v + w.unsqueeze(-1) * t + torch.cross(q_vec, t, dim=-1) + + return v_rotated + + +def quaternion_to_rot_matrix(quat: torch.Tensor, scalar_first: bool = False) -> torch.Tensor: + if scalar_first: + w, x, y, z = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3] + else: + x, y, z, w = quat[..., 0], quat[..., 1], quat[..., 2], quat[..., 3] + + norm_quat = torch.sqrt(w**2 + x**2 + y**2 + z**2 + 1e-8) + w, x, y, z = w / norm_quat, x / norm_quat, y / norm_quat, z / norm_quat + + xx, yy, zz = x * x, y * y, z * z + xy, xz, yz = x * y, x * z, y * z + wx, wy, wz = w * x, w * y, w * z + + rot_matrix_shape = quat.shape[:-1] + (3, 3) + rot_matrix = torch.empty(rot_matrix_shape, device=quat.device) + + rot_matrix[..., 0, 0] = 1 - 2 * (yy + zz) + rot_matrix[..., 0, 1] = 2 * (xy - wz) + rot_matrix[..., 0, 2] = 2 * (xz + wy) + + rot_matrix[..., 1, 0] = 2 * (xy + wz) + rot_matrix[..., 1, 1] = 1 - 2 * (xx + zz) + rot_matrix[..., 1, 2] = 2 * (yz - wx) + + rot_matrix[..., 2, 0] = 2 * (xz - wy) + rot_matrix[..., 2, 1] = 2 * (yz + wx) + rot_matrix[..., 2, 2] = 1 - 2 * (xx + yy) + + return rot_matrix diff --git a/uniflowmatch/utils/viz.py b/uniflowmatch/utils/viz.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf442d08d7035322f7e6984ec53c0e695db6b17 --- /dev/null +++ b/uniflowmatch/utils/viz.py @@ -0,0 +1,93 @@ +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + + +def warp_image_with_flow(source_image, source_mask, target_image, flow) -> np.ndarray: + """ + Warp the target to source image using the given flow vectors. + Flow vectors indicate the displacement from source to target. + + Args: + source_image: np.ndarray of shape (H, W, 3), normalized to [0, 1] + target_image: np.ndarray of shape (H, W, 3), normalized to [0, 1] + flow: np.ndarray of shape (H, W, 2) + source_mask: non_occluded mask represented in source image. + + Returns: + warped_image: target_image warped according to flow into frame of source image + np.ndarray of shape (H, W, 3), normalized to [0, 1] + + """ + # assert source_image.shape[-1] == 3 + # assert target_image.shape[-1] == 3 + + assert flow.shape[-1] == 2 + + # Get the shape of the source image + height, width = source_image.shape[:2] + target_height, target_width = target_image.shape[:2] + + # Create mesh grid + x, y = np.meshgrid(np.arange(width), np.arange(height)) + + # Apply flow displacements + flow_x, flow_y = flow[..., 0], flow[..., 1] + x_new = np.clip(x + flow_x, 0, target_width - 1) + 0.5 + y_new = np.clip(y + flow_y, 0, target_height - 1) + 0.5 + + x_new = (x_new / target_image.shape[1]) * 2 - 1 + y_new = (y_new / target_image.shape[0]) * 2 - 1 + + warped_image = F.grid_sample( + torch.from_numpy(target_image).permute(2, 0, 1)[None, ...].float(), + torch.from_numpy(np.stack([x_new, y_new], axis=-1)).float()[None, ...], + mode="bilinear", + align_corners=False, + ) + + warped_image = warped_image[0].permute(1, 2, 0).numpy() + + if source_mask is not None: + warped_image = warped_image * (source_mask > 0.5) + + return warped_image + + +def visualize_flow(flow, flow_scale): + """ + Visualize optical flow with direction modulating color and magnitude modulating saturation in HSV color space. + + Args: + flow (np.ndarray): Flow array of shape (H, W, 2), where the first dimension + represents (flow_x, flow_y). + flow_scale (float): The scaling factor for the magnitude of the flow. + + Returns: + np.ndarray: An RGB image visualizing the flow. + """ + # Convert CHW to HWC + flow_hwc = flow + + # Compute the magnitude and angle of the flow + magnitude = np.sqrt(np.square(flow_hwc[..., 0]) + np.square(flow_hwc[..., 1])) + angle = np.arctan2(flow_hwc[..., 1], flow_hwc[..., 0]) # Angle in radians (-pi, pi) + + # Normalize the magnitude with the provided flow scale + magnitude = magnitude / flow_scale + magnitude = np.clip(magnitude, 0, 1) # Clip values to [0, 1] for saturation + + # Convert angle from radians to degrees (used for color hue in HSV) + angle_deg = np.degrees(angle) % 360 # Convert angle to [0, 360] degrees + + # Create an HSV image: hue is based on angle, saturation on magnitude, and value is always 1 + hsv_image = np.zeros((flow_hwc.shape[0], flow_hwc.shape[1], 3), dtype=np.uint8) + hsv_image[..., 0] = angle_deg / 2 # OpenCV expects hue in range [0, 180] + hsv_image[..., 1] = magnitude * 255 # Saturation in range [0, 255] + hsv_image[..., 2] = 255 # Value always max (brightest) + + # Convert HSV image to RGB using OpenCV + rgb_image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR) + + return rgb_image