diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..b84b68b4ba418e3eab445abab6217e40e69d34cb 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,49 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/LingBot-VLA.pdf filter=lfs diff=lfs merge=lfs -text
+assets/PaliGemmaPI.png filter=lfs diff=lfs merge=lfs -text
+assets/QwenPI.png filter=lfs diff=lfs merge=lfs -text
+assets/QwenPI_PaliGemmaPI.png filter=lfs diff=lfs merge=lfs -text
+assets/Teaser.png filter=lfs diff=lfs merge=lfs -text
+assets/exp-gm-100.png filter=lfs diff=lfs merge=lfs -text
+assets/exp-robotwin.png filter=lfs diff=lfs merge=lfs -text
+assets/scale_ps.png filter=lfs diff=lfs merge=lfs -text
+assets/scale_sr.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/assets/normal_comaprison.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/assets/overview_simplified.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/assets/panorama_pipeline.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/example_images/01_HouseIndoor.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/example_images/02_Office.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/example_images/03_Traffic.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/example_images/05_Mountain.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/example_images/06_MaitreyaBuddha.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/example_images/07_Breads.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/example_images/08_CatGirl.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/example_images/09_Restaurant.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/example_images/10_MedievalVillage.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/MoGe/example_images/panorama/Braunschweig_Panoram.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/assets/attention/fig-attention-vis.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/assets/dataset/diversity_figure.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-divided.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-full.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_grasp/fig-grasp-demo.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-dynamic-tracking.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-scene-tracking-crop.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/assets/teaser/teaser-crop.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/raw_depth.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/rgb.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/raw_depth.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/rgb.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/raw_depth.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/rgb.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/raw_depth.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/rgb.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/raw_depth.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/rgb.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/raw_depth.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/rgb.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/raw_depth.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/raw_depth.png filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/rgb.jpg filter=lfs diff=lfs merge=lfs -text
+lingbotvla/models/vla/vision_models/lingbot-depth/tech-report.pdf filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..cb37d28c5edcdb340b886895089299cc2a45f03e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,222 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[codz]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py.cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+#poetry.toml
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+#pdm.lock
+#pdm.toml
+.pdm-python
+.pdm-build/
+
+# pixi
+# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+#pixi.lock
+# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+# in the .venv directory. It is recommended not to include this directory in version control.
+.pixi
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Abstra
+# Abstra is an AI-powered process automation framework.
+# Ignore directories containing user credentials, local state, and settings.
+# Learn more at https://abstra.io/docs
+.abstra/
+
+# Visual Studio Code
+# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+# and can be added to the global gitignore or merged into this file. However, if you prefer,
+# you could uncomment the following to ignore the entire vscode folder
+# .vscode/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Cursor
+# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+# refer to https://docs.cursor.com/context/ignore-files
+.cursorignore
+.cursorindexingignore
+
+# Marimo
+marimo/_static/
+marimo/_lsp/
+__marimo__/
+
+# log
+*log.txt
+ossutil_output/
+.sumi/
+env.sh
+pids_qwenpi.txt
+run.sh
+start_multi_eval.sh
+trash/
+eval.sh
+
+# xwc
+output/
+wandb/
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..2175b72b29bd9e708a6f0fc15c1fd48c6c3e0951
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "lingbotvla/models/vla/vision_models/lingbot-depth"]
+ path = lingbotvla/models/vla/vision_models/lingbot-depth
+ url = https://github.com/Robbyant/lingbot-depth
+[submodule "lingbotvla/models/vla/vision_models/MoGe"]
+ path = lingbotvla/models/vla/vision_models/MoGe
+ url = https://github.com/microsoft/MoGe.git
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000000000000000000000000000000000000..73ccbfa8510d81d046ef66b29ac75c5555b65f46
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,88 @@
+{
+ // Use IntelliSense to learn about possible attributes.
+ // Hover to view descriptions of existing attributes.
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "deploy lingbotvla (模块方式)",
+ "type": "debugpy",
+ "request": "launch",
+ "module": "deploy.lingbot_robotwin_policy",
+ "console": "integratedTerminal",
+ "cwd": "${workspaceFolder}",
+ "justMyCode": false,
+ "args": [
+ "--model_path",
+ "output/ori_4/checkpoints/global_step_12850/hf_ckpt",
+ "--use_length",
+ "50",
+ "--chunk_ret",
+ "true",
+ "--debug_infer_once"
+ ],
+ "env": {
+ "CUDA_VISIBLE_DEVICES": "0",
+ "QWEN25_PATH": "/group/ossdphi_algo_scratch_11/weicxu/huggingface_cache/hub/models--Qwen--Qwen2.5-VL-3B-Instruct/snapshots/66285546d2b821cf421d4f5eb2576359d3770cd3"
+ }
+ },
+ {
+ "name": "example_call_robotwin_server",
+ "type": "debugpy",
+ "request": "launch",
+ "module": "deploy.example_call_robotwin_server",
+ "console": "integratedTerminal",
+ "cwd": "${workspaceFolder}",
+ "justMyCode": false,
+ "args": [
+ "--host",
+ "127.0.0.1",
+ "--port",
+ "8006"
+ ],
+ "env": {
+ "CUDA_VISIBLE_DEVICES": "0"
+ }
+ },
+ {
+ "name": "train lingbotvla",
+ "type": "debugpy",
+ "request": "launch",
+ "program": "${file}",
+ "console": "integratedTerminal",
+ "justMyCode": false,
+ "args": [
+ "configs/vla/robotwin_load20000h.yaml",
+ "--model.model_path",
+ "robbyant/lingbot-vla-4b",
+ "--data.train_path",
+ "mixed_robotwin_5tasks_repo_0.1.0",
+ "--train.output_dir",
+ "output/",
+ "--model.tokenizer_path",
+ "Qwen/Qwen2.5-VL-3B-Instruct",
+ "--train.micro_batch_size",
+ "1",
+ "--train.global_batch_size",
+ "1",
+ "--train.enable_full_shard",
+ "true",
+ "--train.use_compile",
+ "false",
+ "--train.enable_fp32",
+ "false",
+ "--train.freeze_vision_encoder",
+ "true",
+ ],
+ "env": {
+ "CUDA_VISIBLE_DEVICES": "2",
+ "LOCAL_RANK": "0",
+ "RANK": "0",
+ "WORLD_SIZE": "1",
+ "MASTER_ADDR": "localhost",
+ "MASTER_PORT": "29500",
+ "PYDEVD_USE_SYS_MONITORING": "0"
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/LEGAL.md b/LEGAL.md
new file mode 100644
index 0000000000000000000000000000000000000000..f96892081dd58b22ee2199adffd7b188b79e7e7f
--- /dev/null
+++ b/LEGAL.md
@@ -0,0 +1,7 @@
+Legal Disclaimer
+
+Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
+
+法律免责声明
+
+关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..7068f8196e7bbdc9c4904b7e26da166c5c934e75
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [2026] [Robbyant Team]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..cd932772290256b68b29b1f744410ff7ac467145
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,21 @@
+.PHONY: build commit quality style test
+
+check_dirs := tasks tests lingbot docs setup.py
+
+build:
+ python3 setup.py sdist bdist_wheel
+
+commit:
+ pre-commit install
+ pre-commit run --all-files
+
+quality:
+ ruff check $(check_dirs)
+ ruff format --check $(check_dirs)
+
+style:
+ ruff check $(check_dirs) --fix
+ ruff format $(check_dirs)
+
+test:
+ pytest tests/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3ed9271d7e96739c2302db6243081ee596dbd1a3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,330 @@
+
LingBot-VLA: A Pragmatic VLA Foundation Model
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+## 🥳 We are excited to introduce **LingBot-VLA**, a pragmatic Vision-Language-Action foundation model.
+
+**LingBot-VLA** has focused on **Pragmatic**:
+- **Large-scale Pre-training Data**: 20,000 hours of real-world
+data from 9 popular dual-arm robot configurations.
+
+
+
+
+
+- **Strong Performance**: Achieve clear superiority over competitors on simulation and real-world benchmarks.
+- **Training Efficiency**: Represent a 1.5 ∼ 2.8× (depending on the relied VLM base model) speedup over existing VLA-oriented codebases.
+
+## 🚀 News
+- **[2026-01-27]** LingBot-VLA Technical Report is available on Arxiv.
+- **[2026-01-27]** Weights and code released!
+
+
+---
+
+
+## 🛠️ Installation
+Requirements
+ - Python 3.12.3
+ - Pytorch 2.8.0
+ - CUDA 12.8
+
+```bash
+# Install Lerobot
+pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu128
+GIT_LFS_SKIP_SMUDGE=1 git clone https://github.com/huggingface/lerobot.git
+cd lerobot
+git checkout 0cf864870cf29f4738d3ade893e6fd13fbd7cdb5
+pip install -e .
+# Install flash attention
+pip install /path/to/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
+
+# Clone the repository
+git clone https://github.com/robbyant/lingbot-vla.git
+cd lingbot-vla/
+git submodule update --remote --recursive
+pip install -e .
+pip install -r requirements.txt
+# Install LingBot-Depth dependency
+cd ./lingbotvla/models/vla/vision_models/lingbot-depth/
+pip install -e . --no-deps
+cd ../MoGe
+pip install -e .
+```
+
+---
+
+## 📦 Model Download
+We release LingBot-VLA pre-trained weights in two configurations: depth-free version and a depth-distillated version.
+- **Pretrained Checkpoints for Post-Training with and without depth**
+
+| Model Name | Huggingface | ModelScope | Description |
+| :--- | :---: | :---: | :---: |
+| LingBot-VLA-4B | [🤗 lingbot-vla-4b](https://huggingface.co/robbyant/lingbot-vla-4b) | [🤖 lingbot-vla-4b](https://modelscope.cn/models/Robbyant/lingbot-vla-4b) | LingBot-VLA *w/o* Depth|
+| LingBot-VLA-4B-Depth | [🤗 lingbot-vla-4b-depth](https://huggingface.co/robbyant/lingbot-vla-4b-depth) | [🤖 lingbot-vla-4b-depth](https://modelscope.cn/models/Robbyant/lingbot-vla-4b-depth) | LingBot-VLA *w/* Depth |
+
+
+
+
+To train LingBot with our codebase, weights from [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct), [MoGe-2-vitb-normal](https://huggingface.co/Ruicheng/moge-2-vitb-normal), and [LingBot-Depth](https://huggingface.co/robbyant/lingbot-depth-pretrain-vitl-14) also need to be prepared.
+- **Run Command**:
+```bash
+python3 scripts/download_hf_model.py --repo_id robbyant/lingbot-vla-4b --local_dir lingbot-vla-4b
+```
+---
+
+## 💻 Post-Training Example
+
+- **Data Preparation**:
+Please follow [RoboTwin2.0 Preparation](experiment/robotwin/README.md)
+
+- **Training Configuration**:
+We provide the mixed post-training configuration in five RoboTwin 2.0 tasks ("open_microwave" "click_bell" "stack_blocks_three" "place_shoe" "put_object_cabinet").
+
+Click to expand full YAML configuration
+
+```yaml
+model:
+ model_path: "path/to/lingbot_vla_checkpoint" # Path to pre-trained VLA foundation model (w/o or w depth)
+ tokenizer_path: "path/to/Qwen2.5-VL-3B-Instruct"
+ post_training: true # Enable post-training/fine-tuning mode
+ adanorm_time: true
+ old_adanorm: true
+
+data:
+ datasets_type: vla
+ data_name: robotwin_5_new
+ train_path: "path/to/lerobot_merged_data" # merged data from 5 robotwin2.0 tasks
+ num_workers: 8
+ norm_type: bounds_99_woclip
+ norm_stats_file: assets/norm_stats/robotwin_50.json # file of normalization statistics
+
+train:
+ output_dir: "path/to/output"
+ loss_type: L1_fm # we apply L1 flow-matching loss in robotwin2.0 finetuning
+ data_parallel_mode: fsdp2 # Use Fully Sharded Data Parallel (PyTorch FSDP2)
+ enable_full_shard: false # Don't apply reshare after forward in FSDP2
+ module_fsdp_enable: true
+ use_compile: true # Acceleration via torch.compile
+ use_wandb: false
+ rmpad: false
+ rmpad_with_pos_ids: false
+ ulysses_parallel_size: 1
+ freeze_vision_encoder: false # ViT need to be optimized
+ tokenizer_max_length: 24 # token numbers of task prompt
+ action_dim: 14 # Target robot action space dimension
+ max_action_dim: 75 # action dim in LingBot-VLA
+ max_state_dim: 75 # state dim in LingBot-VLA
+ lr: 1.0e-4
+ lr_decay_style: constant
+ num_train_epochs: 69 # finetuning 20k step
+ micro_batch_size: 32
+ global_batch_size: 256
+ max_steps: 220000
+ ckpt_manager: dcp
+ save_steps: 220000
+ save_epochs: 69
+ enable_fp32: true
+ enable_resume: true # resume training automatically
+ # ===========================================================================
+ # Depth Injection Parameters
+ # (Required only for LingBot-VLA with Depth. Ignore if not using depth)
+ # ===========================================================================
+ align_params:
+ mode: 'query' # Query-based distillation
+ num_task_tokens: 8 # Number of learnable task-specific tokens
+ use_image_tokens: True
+ use_task_tokens: False
+ use_text_tokens: False
+ use_contrastive: True
+ contrastive_loss_weight: 0.3
+ depth_loss_weight: 0.002
+ llm: # VLM Projection Settings
+ dim_out: 2048
+ image_token_size: 8
+ image_input_size: 224
+ depth:
+ model_type: MoRGBD
+ moge_path: /"path/to/moGe-2-vitb-normal"
+ morgbd_path: "path/to/LingBot-Depth"
+ num_layers: 1
+ num_heads: 4
+ dim_head: 32
+ ff_mult: 1
+ num_backbone_tokens: 256
+ token_size: 16
+ dim_out: 1024
+ input_size: 224
+ visual_steps: 10000
+ visual_dir: "path/to/output/images" # visualization path of depth distillation
+```
+
+
+- **Run Command**:
+```bash
+# without detph
+bash train.sh tasks/vla/train_lingbotvla.py ./configs/vla/robotwin_load20000h.yaml --model.model_path /path/to/LingBot-VLA --data.train_path path/to/mixed_robotwin_5tasks --train.output_dir /path/to/lingbot_robotwin5tasks/ --model.tokenizer_path /path/to/Qwen2.5-VL-3B-Instruct --train.micro_batch_size ${your_batch_size} --train.global_batch_size ${your_batch_size * your_gpu_num}
+
+# with depth
+bash train.sh tasks/vla/train_lingbotvla.py ./configs/vla/robotwin_load20000h_depth.yaml --model.model_path /path/to/LingBot-VLA-Depth --data.train_path /path/to/mixed_robotwin_5tasks --train.output_dir /path/to/lingbot_depth_robotwin5tasks --model.tokenizer_path /path/to/Qwen2.5-VL-3B-Instruct --model.moge_path /path/to/moge2-vitb-normal.pt --model.morgbd_path /path/to/LingBot-Depth-Pretrained --train.micro_batch_size ${your_batch_size} --train.global_batch_size ${your_batch_size * your_gpu_num}
+```
+
+- **Evaluation**
+```bash
+# robotwin2.0
+export QWEN25_PATH=path_to_Qwen2.5-VL-3B-Instruct
+python -m deploy.lingbot_robotwin_policy \
+ --model_path path_to_your_model \
+ --use_length 50 \
+ --port port
+```
+
+- **Customized Post-training**:
+To construct post-training in specified downstream tasks, we have provided an example and please refer to [Custom](lingbotvla/data/vla_data/README.md) for details.
+---
+
+## 🏗️ Efficiency
+
+
+
+We evaluate the training efficiency of our codebase against established baselines for both Qwen2.5-VL-3B-π and PaliGemma-3B-pt-224-π models. The results demonstrate that our codebase
+achieved the fastest training speeds in both model settings. The above figures detail the training throughput across configurations of 8, 16, 32, 128, and 256 GPUs, alongside the theoretical linear scaling limit.
+
+> **📢 Note on Throughput Metrics:**
+> All throughput values (e.g., 261 samples/sec) represent the **total aggregate throughput across all GPUs**, not per-GPU performance.
+>
(Updated: Previously mislabeled as per-GPU in earlier versions. We apologize for the confusion.)
+
+---
+
+## 📊 Performance
+
+Our LingBot-VLA achieves state-of-the-art results on real-world and simulation benchmarks:
+- **GM-100 across 3 robot platforms**
+
+
+
+
+ | Platform |
+ WALL-OSS |
+ GR00T N1.6 |
+ π0.5 |
+ Ours w/o depth |
+ Ours w/ depth |
+
+
+ | SR | PS |
+ SR | PS |
+ SR | PS |
+ SR | PS |
+ SR | PS |
+
+
+
+
+ | Agibot G1 |
+ 2.99% | 8.75% | 5.23% | 12.63% | 7.77% | 21.98% | 12.82% | 30.04% | 11.98% | 30.47% |
+
+
+ | AgileX |
+ 2.26% | 8.16% | 3.26% | 10.52% | 17.20% | 34.82% | 15.50% | 36.31% | 18.93% | 40.36% |
+
+
+ | Galaxea R1Pro |
+ 6.89% | 14.13% | 14.29% | 24.83% | 14.10% | 26.14% | 18.89% | 34.71% | 20.98% | 35.40% |
+
+
+ | Average |
+ 4.05% | 10.35% | 7.59% | 15.99% | 13.02% | 27.65% | 15.74% | 33.69% | 17.30% | 35.41% |
+
+
+
+
+
+- **RoboTwin 2.0 (Clean and Randomized)**
+
+
+
+
+ | Simulation Tasks |
+ π0.5 |
+ Ours w/o depth |
+ Ours w/ depth |
+
+
+ | Clean |
+ Rand. |
+ Clean |
+ Rand. |
+ Clean |
+ Rand. |
+
+
+
+
+ | Average SR |
+ 82.74% |
+ 76.76% |
+ 86.50% |
+ 85.34% |
+ 88.56% |
+ 86.68% |
+
+
+
+
+
+
+📢 We have released our checkpoints of LingBot-VLA-Posttrain-Robotwin:
+| Model Name | Huggingface | ModelScope | Description |
+| :--- | :---: | :---: | :---: |
+| LingBot-VLA-4B-Posttrain-Robotwin | [🤗 lingbot-vla-4b-posttrain-robotwin](https://huggingface.co/robbyant/lingbot-vla-4b-posttrain-robotwin) | [🤖 lingbot-vla-4b-posttrain-robotwin](https://modelscope.cn/models/Robbyant/lingbot-vla-4b-posttrain-robotwin) | LingBot-VLA-Posttrain-Robotwin *w/o* Depth|
+| LingBot-VLA-4B-Depth-Posttrain-Robotwin | [🤗 lingbot-vla-4b-depth-posttrain-robotwin](https://huggingface.co/robbyant/lingbot-vla-4b-depth-posttrain-robotwin) | [🤖 lingbot-vla-4b-depth-posttrain-robotwin](https://modelscope.cn/models/Robbyant/lingbot-vla-4b-depth-posttrain-robotwin) | LingBot-VLA-Posttrain-Robotwin *w/* Depth |
+
+We also provided [evaluation code](deploy/lingbot_robotwin_policy_rep.py) for the community to reproduce the performance of LingBot-VLA on Robotwin 2.0:
+```bash
+export QWEN25_PATH=path_to_Qwen2.5-VL-3B-Instruct
+python -m deploy.lingbot_robotwin_policy_rep \
+ --model_path Path_to_LingBot-VLA-Posttrain-Robotwin \
+ --use_length 50 \
+ --port port
+```
+
+
+
+
+
+
+---
+
+## 📝 Citation
+
+If you find our work useful in your research, feel free to give us a cite.
+
+```bibtex
+@article{wu2026pragmatic,
+ title={A Pragmatic VLA Foundation Model},
+ author={Wei Wu and Fan Lu and Yunnan Wang and Shuai Yang and Shi Liu and Fangjing Wang and Shuailei Ma and He Sun and Yong Wang and Zhenqi Qiu and Houlong Xiong and Ziyu Wang and Shuai Zhou and Yiyu Ren and Kejia Zhang and Hui Yu and Jingmei Zhao and Qian Zhu and Ran Cheng and Yong-Lu Li and Yongtao Huang and Xing Zhu and Yujun Shen and Kecheng Zheng},
+ journal={arXiv preprint arXiv:2601.18692v1},
+ year={2026}
+}
+```
+
+---
+
+## 📄 License Agreement
+This project is licensed under the [Apache-2.0 License](LICENSE).
+
+## 😊 Acknowledgement
+We would like to express our sincere gratitude to the developers of [VeOmni](https://arxiv.org/abs/2508.02317) and [LeRobot](https://github.com/huggingface/lerobot#). This project benefits significantly from their outstanding work and contributions to the open-source community.
diff --git a/assets/LingBot-VLA.pdf b/assets/LingBot-VLA.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..e7c4d609dfdb0c522c37b47066b713308eab7cd7
--- /dev/null
+++ b/assets/LingBot-VLA.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b0a361d6084d74afc0bc9fcdee5051375b701a8e41013460107a46902bd0426
+size 10000817
diff --git a/assets/PaliGemmaPI.png b/assets/PaliGemmaPI.png
new file mode 100644
index 0000000000000000000000000000000000000000..38767cca1036bed3ab467ca7dea7e391d9cfa8b9
--- /dev/null
+++ b/assets/PaliGemmaPI.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e691d3ffcabb56307bd58397b04b575e03186b6e6f98aa86cd0a00f6327659b8
+size 458344
diff --git a/assets/QwenPI.png b/assets/QwenPI.png
new file mode 100644
index 0000000000000000000000000000000000000000..056c634d2450b29a73ce36b0ac0e3dfb493d91dc
--- /dev/null
+++ b/assets/QwenPI.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f327696f64edd947a3f4b6ce4d81d88420bc8ca756fc80b4db937228d571f150
+size 441610
diff --git a/assets/QwenPI_PaliGemmaPI.png b/assets/QwenPI_PaliGemmaPI.png
new file mode 100644
index 0000000000000000000000000000000000000000..3a1692e66ba5c14d76a1972c9810d3d94d773456
--- /dev/null
+++ b/assets/QwenPI_PaliGemmaPI.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ce326329047abdf297f713ae303693db983de4849f3ad5f32a92c3ca310658d
+size 208888
diff --git a/assets/Teaser.png b/assets/Teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..b8b3a9b9f26597d1cdebb8307044bfc27d95ebac
--- /dev/null
+++ b/assets/Teaser.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7081c4c6c8586c21ade32fbfe7547f0841b201c46302ab495c9537cfc982ab54
+size 9139384
diff --git a/assets/exp-gm-100.png b/assets/exp-gm-100.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ea59061f2c13411f6d32a8c8dbe7695a6e34d1d
--- /dev/null
+++ b/assets/exp-gm-100.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9afddc707eb74534e0c1e3903eed0ee6a2ea24df883f7eb1b2fc8d0c5862068d
+size 515527
diff --git a/assets/exp-robotwin.png b/assets/exp-robotwin.png
new file mode 100644
index 0000000000000000000000000000000000000000..a09e8e3af3a036547691c36b686e64bf915401d1
--- /dev/null
+++ b/assets/exp-robotwin.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1d61317bee06123a946302d358ff14f11cc01640cfb820f31630cbf612373ecc
+size 396348
diff --git a/assets/norm_stats/libero.json b/assets/norm_stats/libero.json
new file mode 100644
index 0000000000000000000000000000000000000000..23bc6f126d158e6f762ade069abbf6447cbae684
--- /dev/null
+++ b/assets/norm_stats/libero.json
@@ -0,0 +1,280 @@
+{
+ "norm_stats": {
+ "state": {
+ "mean": [
+ -0.04617275670170784,
+ 0.034034404903650284,
+ 0.7647115588188171,
+ 2.971421480178833,
+ -0.2198116034269333,
+ -0.1260652393102646,
+ 0.02694438025355339,
+ -0.0272101741284132,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "std": [
+ 0.1049584373831749,
+ 0.15187117457389832,
+ 0.3785041272640228,
+ 0.3451951742172241,
+ 0.910057544708252,
+ 0.3253032863140106,
+ 0.014151589013636112,
+ 0.014038060791790485,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "q01": [
+ -0.4003246918797493,
+ -0.268838057410717,
+ 0.03963126605004072,
+ 1.5141939243793487,
+ -2.7199491125106814,
+ -1.0708919448852539,
+ 0.0017206525699933989,
+ -0.04004273633235134,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "q99": [
+ 0.1335429027736188,
+ 0.3378903574764729,
+ 1.2657122139371932,
+ 3.2784227243721484,
+ 2.4147262454509733,
+ 0.5962245464324951,
+ 0.04029089962062426,
+ -0.001789628425752747,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ]
+ },
+ "actions": {
+ "mean": [
+ 0.06667574495077133,
+ 0.06483978033065796,
+ -0.80384361743927,
+ -2.970874071121216,
+ 0.22662578523159027,
+ 0.11959122866392136,
+ -0.036161474883556366,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "std": [
+ 0.32812511920928955,
+ 0.4197826683521271,
+ 0.6153613924980164,
+ 0.35168182849884033,
+ 0.9132273197174072,
+ 0.3432939946651459,
+ 0.9993459582328796,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "q01": [
+ -0.7088336983919143,
+ -0.8786727856397629,
+ -2.097322083187103,
+ -3.3041505486488343,
+ -2.4138620029449465,
+ -0.6111064100980759,
+ -1.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "q99": [
+ 1.0219826289415357,
+ 1.0526966882944104,
+ 0.7265835452556608,
+ -1.491220802116394,
+ 2.7264903316497806,
+ 1.1191907620668413,
+ 0.9996,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/assets/norm_stats/robotwin_50.json b/assets/norm_stats/robotwin_50.json
new file mode 100644
index 0000000000000000000000000000000000000000..832b84f37f3894f287122a5169e02bd664e1137f
--- /dev/null
+++ b/assets/norm_stats/robotwin_50.json
@@ -0,0 +1,229 @@
+{
+ "norm_stats": {
+ "action.arm.position": {
+ "mean": [
+ -0.22649447619915009,
+ 1.0910465717315674,
+ 0.8046976923942566,
+ -0.3529793620109558,
+ 0.056382808834314346,
+ -0.04518803581595421,
+ 0.23444592952728271,
+ 1.1117788553237915,
+ 0.8302268385887146,
+ -0.3584558367729187,
+ -0.010058438405394554,
+ 0.010835078544914722
+ ],
+ "std": [
+ 0.36951732635498047,
+ 0.9946224689483643,
+ 0.7907869219779968,
+ 0.663685142993927,
+ 0.24930860102176666,
+ 0.5646992921829224,
+ 0.32377511262893677,
+ 1.0205038785934448,
+ 0.8121177554130554,
+ 0.7205489277839661,
+ 0.25676125288009644,
+ 0.6210611462593079
+ ],
+ "q01": [
+ -0.9676963651657111,
+ -0.0003164021181873977,
+ -0.0008187678098678652,
+ -1.5952941972732544,
+ -0.4444093635320664,
+ -2.2108209049224854,
+ -0.13648582720756508,
+ -0.0025135905981064077,
+ -0.0016476722434163094,
+ -1.7023667912483216,
+ -1.0292453282356262,
+ -1.6702169750213622
+ ],
+ "q99": [
+ 0.17045696868896432,
+ 2.5792064671580563,
+ 2.4791862522006034,
+ 1.263499072647095,
+ 1.2283580561399456,
+ 1.4622943069458012,
+ 1.096450059175491,
+ 2.605947977209091,
+ 2.5039097490906714,
+ 1.3104696589708325,
+ 1.074876550579071,
+ 2.104229341125489
+ ],
+ "q02": [
+ -0.9234203773498537,
+ -0.0003164021181873977,
+ -0.0008187678098678652,
+ -1.509812859249115,
+ -0.32799621334075924,
+ -1.656348336791992,
+ -0.05942733430862468,
+ -0.0025135905981064077,
+ -0.0016476722434163094,
+ -1.6187864029407502,
+ -0.8712951603889465,
+ -1.5470734649658198
+ ],
+ "q98": [
+ 0.11836757125854458,
+ 2.4944407171577216,
+ 2.3239549394726753,
+ 1.0776700769424439,
+ 1.0128444806575776,
+ 1.2158620544433596,
+ 0.945415413093567,
+ 2.5296102081775667,
+ 2.3580759009346366,
+ 1.2048114322423933,
+ 0.6983346325874327,
+ 1.7523907409667974
+ ]
+ },
+ "action.effector.position": {
+ "mean": [
+ 0.6722026467323303,
+ 0.6737783551216125
+ ],
+ "std": [
+ 0.45274168252944946,
+ 0.45141810178756714
+ ],
+ "q01": [
+ -1e-10,
+ -1e-10
+ ],
+ "q99": [
+ 0.99980000009996,
+ 0.99980000009996
+ ],
+ "q02": [
+ -1e-10,
+ -1e-10
+ ],
+ "q98": [
+ 0.99980000009996,
+ 0.99980000009996
+ ]
+ },
+ "observation.state.arm.position": {
+ "mean": [
+ -0.22545991837978363,
+ 1.0864390134811401,
+ 0.8012449741363525,
+ -0.3515830338001251,
+ 0.05604754388332367,
+ -0.0445503294467926,
+ 0.23296862840652466,
+ 1.1059207916259766,
+ 0.8258985280990601,
+ -0.3568105697631836,
+ -0.00992637686431408,
+ 0.010328034870326519
+ ],
+ "std": [
+ 0.3688313364982605,
+ 0.9950565099716187,
+ 0.7906551957130432,
+ 0.6622100472450256,
+ 0.24865445494651794,
+ 0.5626452565193176,
+ 0.32314980030059814,
+ 1.0208053588867188,
+ 0.8119285702705383,
+ 0.718558132648468,
+ 0.25572913885116577,
+ 0.6181830763816833
+ ],
+ "q01": [
+ -0.9676963651657111,
+ -0.0003164021181873977,
+ -0.0008187678098678652,
+ -1.5938075653076171,
+ -0.44261839199066166,
+ -2.198074409103393,
+ -0.13494465734958627,
+ -0.0025135905981064077,
+ -0.0016476722434163094,
+ -1.7015782970190048,
+ -1.0292453282356262,
+ -1.6682623161315915
+ ],
+ "q99": [
+ 0.17045696868896432,
+ 2.5792064671580563,
+ 2.4782622562915084,
+ 1.2545792808532719,
+ 1.2247761130571364,
+ 1.458045475006104,
+ 1.0856618701696394,
+ 2.6036578441381453,
+ 2.502444082275033,
+ 1.3057386935949324,
+ 1.0699406078338622,
+ 2.0983653644561766
+ ],
+ "q02": [
+ -0.9234203773498537,
+ -0.0003164021181873977,
+ -0.0008187678098678652,
+ -1.5083262272834776,
+ -0.32799621334075924,
+ -1.6499750888824458,
+ -0.05942733430862468,
+ -0.0025135905981064077,
+ -0.0016476722434163094,
+ -1.6172094144821167,
+ -0.8684746216773986,
+ -1.5470734649658198
+ ],
+ "q98": [
+ 0.11836757125854458,
+ 2.4944407171577216,
+ 2.320258955836296,
+ 1.0754401289939883,
+ 1.0116504996299742,
+ 1.2137376384735115,
+ 0.945415413093567,
+ 2.528846830487251,
+ 2.3551445673033595,
+ 1.2016574553251265,
+ 0.6969243632316591,
+ 1.746526764297485
+ ]
+ },
+ "observation.state.effector.position": {
+ "mean": [
+ 0.6734354496002197,
+ 0.6749846339225769
+ ],
+ "std": [
+ 0.4522727429866791,
+ 0.45095184445381165
+ ],
+ "q01": [
+ -1e-10,
+ -1e-10
+ ],
+ "q99": [
+ 0.99980000009996,
+ 0.99980000009996
+ ],
+ "q02": [
+ -1e-10,
+ -1e-10
+ ],
+ "q98": [
+ 0.99980000009996,
+ 0.99980000009996
+ ]
+ }
+ },
+ "count": 532992
+}
\ No newline at end of file
diff --git a/assets/norm_stats/robotwin_5_customized.json b/assets/norm_stats/robotwin_5_customized.json
new file mode 100644
index 0000000000000000000000000000000000000000..6cc9f79b21460f0b448a9dfa9f12b2f2e4a6d7bf
--- /dev/null
+++ b/assets/norm_stats/robotwin_5_customized.json
@@ -0,0 +1,201 @@
+{
+ "norm_stats": {
+ "action": {
+ "mean": [
+ -0.32207754254341125,
+ 1.406205654144287,
+ 1.1087545156478882,
+ -0.6245313882827759,
+ -0.027720848098397255,
+ -0.035565875470638275,
+ 0.4717631936073303,
+ 0.25276312232017517,
+ 0.8104884624481201,
+ 0.5522242188453674,
+ -0.1358797252178192,
+ 0.13210205733776093,
+ -0.13196010887622833,
+ 0.7805091738700867
+ ],
+ "std": [
+ 0.2855374813079834,
+ 0.9229381084442139,
+ 0.8118345737457275,
+ 0.49564430117607117,
+ 0.16244904696941376,
+ 0.5517618656158447,
+ 0.4883338212966919,
+ 0.40702372789382935,
+ 1.036325216293335,
+ 0.7480976581573486,
+ 0.7034134268760681,
+ 0.3450477123260498,
+ 0.7341580390930176,
+ 0.4033139646053314
+ ],
+ "q01": [
+ -0.8213654638230801,
+ -5.257390398583084e-7,
+ -0.00002296771708643064,
+ -1.6557389229632915,
+ -0.6564541918039322,
+ -1.1997157670021057,
+ 0.0,
+ -0.0013322193384173175,
+ 0.0,
+ -0.0000281171942333458,
+ -1.4858032744407654,
+ -0.013652276556193832,
+ -1.5582030366897581,
+ 0.0
+ ],
+ "q99": [
+ 0.01988644998967637,
+ 2.618066892673189,
+ 2.8887816588023267,
+ -0.00009503023102874764,
+ 0.39941834962368006,
+ 1.3274614672660827,
+ 0.9998,
+ 1.2499000839233396,
+ 2.403721238327026,
+ 2.223998639903084,
+ 1.3482957191944123,
+ 1.2036741195514797,
+ 2.3008846492767336,
+ 0.9998
+ ],
+ "q02": [
+ -0.8116190195694566,
+ -5.257390398583084e-7,
+ -0.00002296771708643064,
+ -1.5653808554142714,
+ -0.5909986785650253,
+ -0.9318809885978698,
+ 0.0,
+ -0.0013322193384173175,
+ 0.0,
+ -0.0000281171942333458,
+ -1.400590261220932,
+ -0.005905654035508634,
+ -1.5582030366897581,
+ 0.0
+ ],
+ "q98": [
+ 0.01988644998967637,
+ 2.509362170317786,
+ 2.6153081541584893,
+ -0.00009503023102874764,
+ 0.34549802929162987,
+ 1.2313367155075077,
+ 0.9998,
+ 1.2416952819347378,
+ 2.374588215923309,
+ 2.1395174845976728,
+ 1.328065291595459,
+ 1.1956508319407702,
+ 2.172924092388153,
+ 0.9998
+ ]
+ },
+ "observation.state": {
+ "mean": [
+ -0.320831835269928,
+ 1.401549220085144,
+ 1.1045918464660645,
+ -0.6217827796936035,
+ -0.0279570072889328,
+ -0.03499468415975571,
+ 0.4726906716823578,
+ 0.2512069344520569,
+ 0.8065828680992126,
+ 0.5495453476905823,
+ -0.13533149659633636,
+ 0.13129419088363647,
+ -0.1315813809633255,
+ 0.7816013693809509
+ ],
+ "std": [
+ 0.28554511070251465,
+ 0.924691379070282,
+ 0.8124904036521912,
+ 0.49545007944107056,
+ 0.16213101148605347,
+ 0.5504377484321594,
+ 0.4883865714073181,
+ 0.40611740946769714,
+ 1.035233497619629,
+ 0.7470027208328247,
+ 0.7013660073280334,
+ 0.3439686894416809,
+ 0.7313857674598694,
+ 0.4025507867336273
+ ],
+ "q01": [
+ -0.8213654638230801,
+ -5.257390398583084e-7,
+ -0.00002296771708643064,
+ -1.6557389229632915,
+ -0.6564541918039322,
+ -1.1997157670021057,
+ 0.0,
+ -0.0013322193384173175,
+ 0.0,
+ -0.0000281171942333458,
+ -1.483351101398468,
+ -0.013652276556193832,
+ -1.5582030366897581,
+ 0.0
+ ],
+ "q99": [
+ 0.01988644998967637,
+ 2.6186390227908487,
+ 2.889423615385998,
+ -0.00009503023102874764,
+ 0.39780878782272344,
+ 1.3274614672660827,
+ 0.9998,
+ 1.2499000839233396,
+ 2.404215018367767,
+ 2.2201366442319794,
+ 1.347682675933838,
+ 1.2036741195514797,
+ 2.3008846492767336,
+ 0.9998
+ ],
+ "q02": [
+ -0.8116190195694566,
+ -5.257390398583084e-7,
+ -0.00002296771708643064,
+ -1.5653808554142714,
+ -0.5909986785650253,
+ -0.9318809885978698,
+ 0.0,
+ -0.0013322193384173175,
+ 0.0,
+ -0.0000281171942333458,
+ -1.3981380881786347,
+ -0.005905654035508634,
+ -1.5582030366897581,
+ 0.0
+ ],
+ "q98": [
+ 0.01988644998967637,
+ 2.509362170317786,
+ 2.61595011074216,
+ -0.00009503023102874764,
+ 0.3452297689914703,
+ 1.2313367155075077,
+ 0.9998,
+ 1.2416952819347378,
+ 2.374588215923309,
+ 2.1380692362210083,
+ 1.328065291595459,
+ 1.1956508319407702,
+ 2.1450514958381657,
+ 0.9998
+ ]
+ }
+ },
+ "count": 74240
+}
\ No newline at end of file
diff --git a/assets/norm_stats/robotwin_all_new.json b/assets/norm_stats/robotwin_all_new.json
new file mode 100644
index 0000000000000000000000000000000000000000..eb48cb8f1bd411fe56ba6be41a23753086a03e56
--- /dev/null
+++ b/assets/norm_stats/robotwin_all_new.json
@@ -0,0 +1,229 @@
+{
+ "norm_stats": {
+ "action.arm.position": {
+ "mean": [
+ -0.2260681688785553,
+ 1.090435266494751,
+ 0.8042582273483276,
+ -0.3527189791202545,
+ 0.056556474417448044,
+ -0.04530515521764755,
+ 0.2346765249967575,
+ 1.112542748451233,
+ 0.8304542303085327,
+ -0.357768177986145,
+ -0.01014612801373005,
+ 0.010991317220032215
+ ],
+ "std": [
+ 0.3691432774066925,
+ 0.994762122631073,
+ 0.7908730506896973,
+ 0.6637247800827026,
+ 0.24963052570819855,
+ 0.5638052821159363,
+ 0.32393988966941833,
+ 1.0204970836639404,
+ 0.8119731545448303,
+ 0.7209287285804749,
+ 0.25776439905166626,
+ 0.6208906769752502
+ ],
+ "q01": [
+ -0.9676963651657111,
+ -0.0003164021181873977,
+ -0.0026667596280574857,
+ -1.596037513256073,
+ -0.4467973255872727,
+ -2.20232324104309,
+ -0.13648582720756508,
+ -0.0017502129077910933,
+ -0.0023805056512355804,
+ -1.703943779706955,
+ -1.0264247895240783,
+ -1.6682623161315915
+ ],
+ "q99": [
+ 0.17045696868896432,
+ 2.5760957974332737,
+ 2.4727182808369395,
+ 1.259782492733002,
+ 1.2253731035709379,
+ 1.4495478111267097,
+ 1.0841207003116606,
+ 2.6036578441381453,
+ 2.4987799152359367,
+ 1.3104696589708325,
+ 1.0692354731559752,
+ 2.104229341125489
+ ],
+ "q02": [
+ -0.9260248472213748,
+ -0.0003164021181873977,
+ -0.0026667596280574857,
+ -1.5090695432662964,
+ -0.3291901943683624,
+ -1.6520995048522948,
+ -0.05942733430862468,
+ -0.0017502129077910933,
+ -0.0023805056512355804,
+ -1.6187864029407502,
+ -0.8741156991004944,
+ -1.5490281238555905
+ ],
+ "q98": [
+ 0.1157631013870235,
+ 2.4936630497265257,
+ 2.3193349599272013,
+ 1.0769267609596254,
+ 1.0140384616851805,
+ 1.2073643905639653,
+ 0.9469565829515458,
+ 2.528083452796936,
+ 2.3551445673033595,
+ 1.2071769149303435,
+ 0.6969243632316591,
+ 1.7504360820770266
+ ]
+ },
+ "action.effector.position": {
+ "mean": [
+ 0.6723259687423706,
+ 0.6735112071037292
+ ],
+ "std": [
+ 0.4526418447494507,
+ 0.4514695405960083
+ ],
+ "q01": [
+ 0.0,
+ 0.0
+ ],
+ "q99": [
+ 0.9998,
+ 0.9998
+ ],
+ "q02": [
+ 0.0,
+ 0.0
+ ],
+ "q98": [
+ 0.9998,
+ 0.9998
+ ]
+ },
+ "observation.state.arm.position": {
+ "mean": [
+ -0.22502799332141876,
+ 1.0857956409454346,
+ 0.8007810711860657,
+ -0.3513113558292389,
+ 0.05622035637497902,
+ -0.044659487903118134,
+ 0.23319771885871887,
+ 1.106688141822815,
+ 0.82613205909729,
+ -0.3561287522315979,
+ -0.010010534897446632,
+ 0.010481182485818863
+ ],
+ "std": [
+ 0.3684558570384979,
+ 0.9951919317245483,
+ 0.7907320857048035,
+ 0.6622379422187805,
+ 0.24897389113903046,
+ 0.5617504119873047,
+ 0.32331398129463196,
+ 1.0208075046539307,
+ 0.8117841482162476,
+ 0.718940019607544,
+ 0.25672635436058044,
+ 0.6180205345153809
+ ],
+ "q01": [
+ -0.9676963651657111,
+ -0.0003164021181873977,
+ -0.0026667596280574857,
+ -1.5938075653076171,
+ -0.4462003350734711,
+ -2.195949993133545,
+ -0.13648582720756508,
+ -0.0017502129077910933,
+ -0.0023805056512355804,
+ -1.703943779706955,
+ -1.0257196548461915,
+ -1.6663076572418207
+ ],
+ "q99": [
+ 0.16785249881744324,
+ 2.5760957974332737,
+ 2.47087028901875,
+ 1.2516060169219974,
+ 1.22238815100193,
+ 1.4495478111267097,
+ 1.073332511305809,
+ 2.602131088757515,
+ 2.494382914789021,
+ 1.3104696589708325,
+ 1.0657097997665406,
+ 2.102274682235718
+ ],
+ "q02": [
+ -0.9234203773498537,
+ -0.0003164021181873977,
+ -0.0026667596280574857,
+ -1.5060962793350219,
+ -0.3291901943683624,
+ -1.6436018409728996,
+ -0.05788616445064587,
+ -0.0017502129077910933,
+ -0.0023805056512355804,
+ -1.6164209202528,
+ -0.8698848910331727,
+ -1.5490281238555905
+ ],
+ "q98": [
+ 0.1157631013870235,
+ 2.4928853822953303,
+ 2.3174869681090113,
+ 1.0754401289939883,
+ 1.0122474901437757,
+ 1.2031155586242681,
+ 0.945415413093567,
+ 2.527320075106621,
+ 2.3522132336720825,
+ 1.202445949554443,
+ 0.694808959197998,
+ 1.7484814231872559
+ ]
+ },
+ "observation.state.effector.position": {
+ "mean": [
+ 0.6735715866088867,
+ 0.6747165322303772
+ ],
+ "std": [
+ 0.4521658420562744,
+ 0.4510030150413513
+ ],
+ "q01": [
+ 0.0,
+ 0.0
+ ],
+ "q99": [
+ 0.9998,
+ 0.9998
+ ],
+ "q02": [
+ 0.0,
+ 0.0
+ ],
+ "q98": [
+ 0.9998,
+ 0.9998
+ ]
+ }
+ },
+ "count": 535680
+}
\ No newline at end of file
diff --git a/assets/scale_ps.png b/assets/scale_ps.png
new file mode 100644
index 0000000000000000000000000000000000000000..c60567db19924e370fa371c839e29c5f7f583db4
--- /dev/null
+++ b/assets/scale_ps.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b23143996c78b30f658b9a81e0d46c96c2231d9dd2646775b0c057773a1fce14
+size 480748
diff --git a/assets/scale_sr.png b/assets/scale_sr.png
new file mode 100644
index 0000000000000000000000000000000000000000..d809e89cf07b513dc53e3146641336fda1dff45f
--- /dev/null
+++ b/assets/scale_sr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3becc2bb6d5355f672dc110a4578277c3eac1cf53f3cba726e5e6277b8d9c413
+size 465811
diff --git a/configs/norm/robotwin_5.yaml b/configs/norm/robotwin_5.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ba3618edb2e392c10f4130e1f53d0cfa926a8b33
--- /dev/null
+++ b/configs/norm/robotwin_5.yaml
@@ -0,0 +1,12 @@
+model:
+ model_path: /path/to/LingBot-VLA-Depth
+ tokenizer_path: /path/to/Qwen2.5-VL-3B-Instruct/
+
+data:
+ datasets_type: vla
+ train_path: /path/to/mixed_robotwin_5tasks
+ norm_path: assets/norm_stats/robotwin_5_custom.json
+
+train:
+ global_batch_size: 512
+ output_dir: output/norm
\ No newline at end of file
diff --git a/configs/vla/robotwin_load20000h.yaml b/configs/vla/robotwin_load20000h.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d0f1c7df3a8b328933bd5649808fc8f5ec951879
--- /dev/null
+++ b/configs/vla/robotwin_load20000h.yaml
@@ -0,0 +1,42 @@
+model:
+ model_path: /path/to/LingBot-VLA
+ tokenizer_path: /path/to/Qwen2.5-VL-3B-Instruct/
+ post_training: true
+ adanorm_time: true
+ old_adanorm: true
+
+data:
+ datasets_type: vla
+ data_name: robotwin_5_new
+ train_path: /path/to/mixed_robotwin_5tasks
+ num_workers: 8
+ norm_type: bounds_99_woclip
+ norm_stats_file: assets/norm_stats/robotwin_50.json
+
+train:
+ output_dir: /path/to/lingbot_robotwin5tasks/
+ loss_type: L1_fm
+ data_parallel_mode: fsdp2
+ enable_full_shard: false
+ module_fsdp_enable: true
+ use_compile: true
+ use_wandb: false
+ rmpad: false
+ rmpad_with_pos_ids: false
+ ulysses_parallel_size: 1
+ freeze_vision_encoder: false
+ tokenizer_max_length: 24
+ action_dim: 14
+ max_action_dim: 75
+ max_state_dim: 75
+ lr: 1.0e-4
+ lr_decay_style: constant
+ num_train_epochs: 69
+ micro_batch_size: 32
+ global_batch_size: 256
+ max_steps: 220000
+ ckpt_manager: dcp
+ save_steps: 220000
+ save_epochs: 69
+ enable_fp32: true
+ enable_resume: true
\ No newline at end of file
diff --git a/configs/vla/robotwin_load20000h_depth.yaml b/configs/vla/robotwin_load20000h_depth.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9dc939fd9c131ff6bb522c30a39eeeb3fb64c38f
--- /dev/null
+++ b/configs/vla/robotwin_load20000h_depth.yaml
@@ -0,0 +1,68 @@
+model:
+ model_path: /path/to/LingBot-VLA-Depth
+ tokenizer_path: /path/to/Qwen2.5-VL-3B-Instruct/
+ post_training: true
+ adanorm_time: true
+ old_adanorm: true
+ moge_path: /path/to/moge2-vitb-normal
+ morgbd_path: /path/to/LingBot-Depth-Pretrained
+
+data:
+ datasets_type: vla
+ data_name: robotwin_5_new
+ train_path: /path/to/mixed_robotwin_5tasks
+ num_workers: 8
+ norm_type: bounds_99_woclip
+ norm_stats_file: assets/norm_stats/robotwin_50.json
+
+train:
+ output_dir: /path/to/lingbot_depth_robotwin5tasks/
+ loss_type: L1_fm
+ data_parallel_mode: fsdp2
+ enable_full_shard: false
+ module_fsdp_enable: true
+ use_compile: true
+ use_wandb: false
+ rmpad: false
+ rmpad_with_pos_ids: false
+ ulysses_parallel_size: 1
+ freeze_vision_encoder: false
+ tokenizer_max_length: 24
+ action_dim: 14
+ max_action_dim: 75
+ max_state_dim: 75
+ lr: 1.0e-4
+ lr_decay_style: constant
+ num_train_epochs: 69
+ micro_batch_size: 32
+ global_batch_size: 256
+ max_steps: 220000
+ ckpt_manager: dcp
+ save_steps: 220000
+ save_epochs: 69
+ enable_fp32: true
+ enable_resume: true
+ align_params:
+ mode: 'query'
+ num_task_tokens: 8
+ use_image_tokens: True
+ use_task_tokens: False
+ use_text_tokens: False
+ use_contrastive: True
+ contrastive_loss_weight: 0.3
+ depth_loss_weight: 0.004
+ llm:
+ dim_out: 2048
+ image_token_size: 8
+ image_input_size: 224
+ depth:
+ model_type: MoRGBD
+ num_layers: 1
+ num_heads: 4
+ dim_head: 32
+ ff_mult: 1
+ num_backbone_tokens: 256
+ token_size: 16
+ dim_out: 1024
+ input_size: 224
+ visual_steps: 10000
\ No newline at end of file
diff --git a/deploy/__init__.py b/deploy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/deploy/image_tools.py b/deploy/image_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a971b9d5f6b1495fd6cdea202ffa607d8b34bf0
--- /dev/null
+++ b/deploy/image_tools.py
@@ -0,0 +1,58 @@
+import numpy as np
+from PIL import Image
+
+
+def convert_to_uint8(img: np.ndarray) -> np.ndarray:
+ """Converts an image to uint8 if it is a float image.
+
+ This is important for reducing the size of the image when sending it over the network.
+ """
+ if np.issubdtype(img.dtype, np.floating):
+ img = (255 * img).astype(np.uint8)
+ return img
+
+
+def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
+ """Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
+
+ Args:
+ images: A batch of images in [..., height, width, channel] format.
+ height: The target height of the image.
+ width: The target width of the image.
+ method: The interpolation method to use. Default is bilinear.
+
+ Returns:
+ The resized images in [..., height, width, channel].
+ """
+ # If the images are already the correct size, return them as is.
+ if images.shape[-3:-1] == (height, width):
+ return images
+
+ original_shape = images.shape
+
+ images = images.reshape(-1, *original_shape[-3:])
+ resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
+ return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
+
+
+def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
+ """Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
+ width without distortion by padding with zeros.
+
+ Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
+ """
+ cur_width, cur_height = image.size
+ if cur_width == width and cur_height == height:
+ return image # No need to resize if the image is already the correct size.
+
+ ratio = max(cur_width / width, cur_height / height)
+ resized_height = int(cur_height / ratio)
+ resized_width = int(cur_width / ratio)
+ resized_image = image.resize((resized_width, resized_height), resample=method)
+
+ zero_image = Image.new(resized_image.mode, (width, height), 0)
+ pad_height = max(0, int((height - resized_height) / 2))
+ pad_width = max(0, int((width - resized_width) / 2))
+ zero_image.paste(resized_image, (pad_width, pad_height))
+ assert zero_image.size == (width, height)
+ return zero_image
diff --git a/deploy/lingbot_robotwin_policy.py b/deploy/lingbot_robotwin_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b2c8d84f0fd7b4562e2829d27c2d1ac12ccb08
--- /dev/null
+++ b/deploy/lingbot_robotwin_policy.py
@@ -0,0 +1,506 @@
+import json
+import os
+import time
+import random
+import numpy as np
+from collections import deque
+import torchvision
+import yaml
+from types import SimpleNamespace
+from packaging.version import Version
+from typing import Callable, Dict, List, Optional, Type, Union, Tuple, Any, Sequence
+from glob import glob
+from tqdm import tqdm
+from safetensors import safe_open
+from safetensors.torch import load_file
+from pathlib import Path
+from PIL import Image
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+import transformers
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers import (
+ AutoConfig,
+ PretrainedConfig,
+ PreTrainedModel,
+ AutoProcessor,
+)
+
+from lerobot.configs.policies import PreTrainedConfig
+from lingbotvla.models.vla.pi0.modeling_pi0 import PI0Policy
+from lingbotvla.models.vla.pi0.modeling_lingbot_vla import LingbotVlaPolicy
+from lingbotvla.data.vla_data.transform import Normalizer, prepare_images, prepare_language, prepare_state
+from lingbotvla.models import build_processor
+
+
+def set_seed_everywhere(seed: int):
+ """Sets the random seed for Python, NumPy, and PyTorch functions."""
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ os.environ["PYTHONHASHSEED"] = str(seed)
+
+set_seed_everywhere(42)
+
+BASE_MODEL_PATH = {
+ 'pi0': os.environ.get('PALIGEMMA_PATH', './paligemma-3b-pt-224/'),
+ 'lingbotvla': os.environ.get('QWEN25_PATH', './Qwen2.5-VL-3B-Instruct/'),
+}
+
+def load_model_weights(policy, path_to_pi_model, strict=True):
+ all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
+ merged_weights = {}
+
+ for file_path in tqdm(all_safetensors):
+ with safe_open(file_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ merged_weights[key] = f.get_tensor(key)
+ policy.load_state_dict(merged_weights, strict=strict)
+
+
+def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
+ crop_scale = 0.9
+ side_scale = float(np.sqrt(np.clip(crop_scale, 0.0, 1.0))) # side length scale
+ out_size = (224, 224)
+
+ # Convert input to PIL Image
+ if isinstance(image, np.ndarray):
+ arr = image
+ if arr.dtype.kind == "f":
+ # If floats likely in [0,1], map to [0,255]
+ if arr.max() <= 1.0 and arr.min() >= 0.0:
+ arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)
+ else:
+ arr = np.clip(arr, 0.0, 255.0).astype(np.uint8)
+ elif arr.dtype == np.uint16:
+ # Map 16-bit to 8-bit
+ arr = (arr / 257).astype(np.uint8)
+ elif arr.dtype != np.uint8:
+ arr = arr.astype(np.uint8)
+ pil = Image.fromarray(arr)
+ elif isinstance(image, Image.Image):
+ pil = image
+ else:
+ raise TypeError("image must be a numpy array or PIL.Image.Image")
+
+ # Force RGB for consistent output
+ pil = pil.convert("RGB")
+ W, H = pil.size
+
+ # Compute centered crop box (integer pixels)
+ crop_w = max(1, int(round(W * side_scale)))
+ crop_h = max(1, int(round(H * side_scale)))
+ left = (W - crop_w) // 2
+ top = (H - crop_h) // 2
+ right = left + crop_w
+ bottom = top + crop_h
+
+ cropped = pil.crop((left, top, right, bottom))
+ resized = cropped.resize(out_size, resample=Image.BILINEAR)
+ return resized
+
+def resize_with_pad(img, width, height, pad_value=-1):
+ # assume no-op when width height fits already
+ if img.ndim != 4:
+ raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
+
+ # channel last to channel first if necessary
+ if img.shape[1] not in (1, 3) and img.shape[-1] in (1, 3):
+ img = img.permute(0, 3, 1, 2)
+
+ cur_height, cur_width = img.shape[2:]
+
+ ratio = max(cur_width / width, cur_height / height)
+ resized_height = int(cur_height / ratio)
+ resized_width = int(cur_width / ratio)
+ resized_img = F.interpolate(
+ img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+ )
+
+ pad_height = max(0, int(height - resized_height))
+ pad_width = max(0, int(width - resized_width))
+
+ # pad on left and top of image
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
+ return padded_img
+
+class PolicyPreprocessMixin:
+
+ @torch.no_grad
+ def select_action(
+ self, observation: dict[str, Tensor], use_bf16: bool = False, vlm_causal: bool = False, noise: Tensor | None = None
+ ):
+ self.eval()
+ device = 'cuda'
+ if use_bf16:
+ dtype = torch.bfloat16
+ else:
+ dtype = torch.float32
+ s1 = time.time()
+
+ if len(observation['images'].shape) == 4:
+ observation['images'] = observation['images'].unsqueeze(0)
+ observation['img_masks'] = observation['img_masks'].unsqueeze(0)
+
+ if 'expert_imgs' in observation:
+ actions = self.model.sample_actions(
+ observation['images'].to(dtype=dtype, device=device),
+ observation['img_masks'].to(device=device),
+ observation['lang_tokens'].unsqueeze(0).to(device=device),
+ observation['lang_masks'].unsqueeze(0).to(device=device),
+ observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
+ observation['expert_imgs'].to(dtype=dtype, device=device),
+ vlm_causal = vlm_causal
+ )
+ else:
+ actions = self.model.sample_actions(
+ observation['images'].to(dtype=dtype, device=device),
+ observation['img_masks'].to(device=device),
+ observation['lang_tokens'].unsqueeze(0).to(device=device),
+ observation['lang_masks'].unsqueeze(0).to(device=device),
+ observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
+ vlm_causal = vlm_causal
+ )
+ delta_time = time.time() - s1
+ print(f'sample_actions cost {delta_time} s')
+ observation['action'] = actions.squeeze(0)[:, :14].to(dtype=torch.float32, device='cpu')
+ if use_bf16:
+ observation['state'] = observation['state'].to(dtype=torch.float32)
+ data = self.normalizer.unnormalize(observation)
+ return data
+
+class LingBotVlaInferencePolicy(PolicyPreprocessMixin, LingbotVlaPolicy):
+ pass # Only combine necessary functions
+
+class PI0InfernecePolicy(PolicyPreprocessMixin, PI0Policy):
+ pass # Only combine necessary functions
+
+
+def merge_qwen_config(policy_config, qwen_config):
+ if hasattr(qwen_config, 'to_dict'):
+ config_dict = qwen_config.to_dict()
+ else:
+ config_dict = qwen_config
+
+ text_keys = {
+ "hidden_size",
+ "intermediate_size",
+ "num_hidden_layers",
+ "num_attention_heads",
+ "num_key_value_heads",
+ "rms_norm_eps",
+ "rope_theta",
+ "vocab_size",
+ "max_position_embeddings",
+ "hidden_act",
+ "tie_word_embeddings",
+ "tokenizer_path",
+ }
+
+ for key in text_keys:
+ if key in config_dict:
+ setattr(policy_config, key, config_dict[key])
+ print(f"✅ Merged LLM: {key} = {config_dict[key]}")
+
+ if "vision_config" in config_dict:
+ policy_config.vision_config = qwen_config.vision_config
+ else:
+ print("⚠️ Warning: 'vision_config' not found in qwen_config!")
+
+ return policy_config
+
+
+class QwenPiServer:
+ '''
+ policy wrapper to support action ensemble or chunk execution
+ '''
+ def __init__(
+ self,
+ path_to_pi_model="",
+ adaptive_ensemble_alpha=0.1,
+ action_ensemble_horizon=8,
+ use_length=1, # to control the execution length of the action chunk, -1 denotes using action ensemble
+ chunk_ret=False,
+ use_bf16=True,
+ use_fp32=False,
+ ) -> None:
+ assert not (use_bf16 and use_fp32), 'Bfloat16 or Float32!!!'
+ self.adaptive_ensemble_alpha = adaptive_ensemble_alpha
+ self.use_length = use_length
+ self.chunk_ret = chunk_ret
+
+ self.task_description = None
+
+ self.vla = self.load_vla(path_to_pi_model)
+ self.vla = self.vla.cuda().eval()
+ if use_bf16:
+ self.vla = self.vla.to(torch.bfloat16)
+ elif use_fp32:
+ self.vla.model.float()
+ self.global_step = 0
+ self.last_action_chunk = None
+ self.use_bf16 = use_bf16
+ self.use_fp32 = use_fp32
+
+ def load_vla(self, path_to_pi_model) -> LingbotVlaPolicy:
+ # load model
+
+ print(f"loading model from: {path_to_pi_model}")
+ config = PreTrainedConfig.from_pretrained(path_to_pi_model)
+
+ # load training config
+ training_config_path = Path(path_to_pi_model).parent.parent.parent/'lingbotvla_cli.yaml'
+ with open(training_config_path, 'r') as f:
+ training_config = yaml.safe_load(f)
+ f.close()
+
+ # update model config according to training config
+ training_model_config = training_config['model']
+ training_model_config.update(training_config['train'])
+ for k, v in training_model_config.items():
+ v = getattr(config, k, training_model_config[k])
+ setattr(config, k, v)
+
+ # Set attention_implementation to 'eager' to speed up evaluation.
+ config.attention_implementation = 'eager'
+
+ # set base model according to training config
+ training_base_model = training_config['model']['tokenizer_path']
+ if 'paligemma' in training_base_model:
+ model_name = 'pi0'
+ config.vocab_size = 257152 # set vocab size for paligamma
+ elif 'qwen2' in training_base_model.lower():
+ model_name = 'lingbotvla'
+ else:
+ raise ValueError(f"Unsupported base model of {path_to_pi_model}")
+ base_model_path = BASE_MODEL_PATH[model_name]
+ config.tokenizer_path = base_model_path
+ self.model_name = model_name
+
+ qwen_config = AutoConfig.from_pretrained(base_model_path)
+ config = merge_qwen_config(config, qwen_config)
+
+ if 'vocab_size' in training_config['model'] and training_config['model']['vocab_size'] != 0:
+ config.vocab_size = training_config['model']['vocab_size']
+ # load processors
+ self.processor = build_processor(base_model_path)
+ self.language_tokenizer = self.processor.tokenizer
+ self.image_processor = self.processor.image_processor
+ data_config = SimpleNamespace(**training_config['data'])
+
+ print('Initializing model ... ')
+
+ if 'paligemma' in training_base_model:
+ policy = PI0InfernecePolicy(config, tokenizer_path=base_model_path)
+ else:
+ policy = LingBotVlaInferencePolicy(config, tokenizer_path=base_model_path)
+
+ load_model_weights(policy, path_to_pi_model, strict=True)
+
+ policy.feature_transform = None
+ self.data_config = data_config
+ self.config = config
+ self.joint_max_dim = training_config['train']['max_action_dim']
+ self.action_dim = training_config['train']['action_dim']
+ self.chunk_size = training_config['train']['chunk_size']
+ policy.action_dim = self.action_dim
+ policy.chunk_size = self.chunk_size
+ self.norm_stats_file = data_config.norm_stats_file
+ if 'align_params' in training_config['train']:
+ self.use_depth_align = True
+ else: self.use_depth_align = False
+ with open(self.norm_stats_file) as f:
+ self.norm_stats = json.load(f)
+ policy.normalizer = Normalizer(
+ norm_stats=self.norm_stats['norm_stats'],
+ from_file=True,
+ data_type='robotwin',
+ norm_type={
+ "observation.images.cam_high": "identity",
+ "observation.images.cam_left_wrist": "identity",
+ "observation.images.cam_right_wrist": "identity",
+ "observation.state": self.data_config.norm_type,
+ "action": self.data_config.norm_type,
+ },
+ )
+
+ print('Model initialized ... ')
+
+ return policy
+
+ def reset(self, robo_name, path_to_pi_model = None) -> None:
+
+ if path_to_pi_model is not None:
+ self.vla = self.load_vla(path_to_pi_model)
+ self.vla = self.vla.cuda().eval()
+ if self.use_bf16:
+ self.vla = self.vla.to(torch.bfloat16)
+ elif self.use_fp32:
+ self.vla.model.float()
+
+ self.global_step = 0
+ self.last_action_chunk = None
+
+ if getattr(self.data_config, 'norm_type', None) is None:
+ self.data_config.norm_type = 'meanstd'
+ if getattr(self.config, 'vlm_causal', None) is None:
+ self.config.vlm_causal = False
+ if getattr(self.config, 'qwenvl_bos', None) is None:
+ self.config.qwenvl_bos = False
+
+ # if update ckpt path
+ if path_to_pi_model is not None:
+ all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
+ merged_weights = {}
+
+ for file_path in tqdm(all_safetensors):
+ with safe_open(file_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ merged_weights[key] = f.get_tensor(key)
+
+ self.vla.load_state_dict(merged_weights, strict=True)
+
+ def resize_image(self, observation):
+ for image_feature in ['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist']:
+ assert image_feature in observation
+ assert len(observation[image_feature].shape)==3 and observation[image_feature].shape[-1] == 3
+ image = observation[image_feature]
+ img_pil = Image.fromarray(image)
+ image_size = getattr(self.data_config, 'img_size', 224)
+ img_pil = img_pil.resize((image_size, image_size), Image.BILINEAR)
+
+ # img_resized shape: C*H*W
+ img_resized = np.transpose(np.array(img_pil), (2,0,1)) # (3,224,224)
+ observation[image_feature] = img_resized / 255.
+
+ def infer(self, observation, center_crop=True):
+ """Generates an action with the VLA policy."""
+
+ # (If trained with image augmentations) Center crop image and then resize back up to original size.
+ # IMPORTANT: Let's say crop scale == 0.9. To get the new height and width (post-crop), multiply
+ # the original height and width by sqrt(0.9) -- not 0.9!
+ if 'reset' in observation and observation['reset']:
+ self.reset(robo_name=observation['robo_name'], path_to_pi_model=observation['path_to_pi_model'] if 'path_to_pi_model' in observation else None)
+ return dict(action = None)
+
+ self.resize_image(observation)
+ for k, v in observation.items():
+ if isinstance(v, np.ndarray):
+ observation[k] = torch.from_numpy(v)
+
+ if self.use_length == -1 or self.global_step % self.use_length == 0:
+ joint_max_dim = getattr(self, 'joint_max_dim')
+ action_dim = getattr(self, 'action_dim')
+ chunk_size = getattr(self, 'chunk_size')
+ normalized_observation = self.vla.normalizer.normalize(observation)
+ base_image = (normalized_observation["observation.images.cam_high"] * 255).to(torch.uint8)
+ left_wrist_image = (normalized_observation["observation.images.cam_left_wrist"] * 255).to(
+ torch.uint8
+ )
+ right_wrist_image = (normalized_observation["observation.images.cam_right_wrist"] * 255).to(
+ torch.uint8
+ )
+ obs_dict = {
+ "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
+ "state": normalized_observation["observation.state"].to(torch.float32),
+ "prompt": [observation["task"]],
+ }
+ state = prepare_state(self.config, obs_dict)
+ lang_tokens, lang_masks = prepare_language(self.config, self.language_tokenizer, obs_dict)
+ images, img_masks, _ = prepare_images(self.config, self.image_processor, obs_dict)
+ observation = {
+ 'images': images,
+ 'img_masks': img_masks,
+ 'state': state,
+ 'lang_tokens': lang_tokens,
+ 'lang_masks': lang_masks,
+ }
+
+ if self.use_bf16:
+ observation['state'] = observation['state'].to(torch.bfloat16)
+
+ org_actions = ['action']
+ assert len(org_actions)==1, "Only support single action feature"
+ if self.chunk_ret:
+ action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]].float().cpu().numpy()
+ action = action[:self.use_length, :self.action_dim]
+ else:
+ if self.use_length == -1 or self.global_step % self.use_length == 0:
+ action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]]
+ self.last_action_chunk = action.float().cpu().numpy()
+
+ if self.use_length > 0:
+ action = self.last_action_chunk[self.global_step % self.use_length]
+ action = action[:, :self.action_dim]
+ print(f"on server step: {self.global_step}")
+ self.global_step+=1
+
+ return dict(action = action)
+
+
+import argparse
+from .websocket_policy_server import WebsocketPolicyServer
+
+def main():
+ parser = argparse.ArgumentParser(description="启动 QwenPi WebSocket 策略服务器")
+
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ )
+
+ parser.add_argument(
+ "--use_length",
+ type=int,
+ default=50,
+ help="used length of action chunk"
+ )
+
+ parser.add_argument(
+ "--chunk_ret",
+ type=bool,
+ default=True,
+ help=" True: The returned action tensor includes the horizon dimension. This allows the model to output a sequence of actions for each horizon step. False: The horizon dimension is omitted. The model selects and returns the next step autonomously based on its policy."
+ )
+
+ parser.add_argument(
+ "--port",
+ type=int,
+ default=8006,
+ help="port of WebSocket"
+ )
+
+ parser.add_argument(
+ "--debug_infer_once",
+ action="store_true",
+ help="Run one infer with dummy observation then exit (for debugging infer() without WebSocket client)",
+ )
+
+ args = parser.parse_args()
+
+ model = QwenPiServer(args.model_path, use_length=args.use_length, chunk_ret=args.chunk_ret)
+ if args.debug_infer_once:
+ # 调试用:不启动 WebSocket,只跑一次 infer,可在 infer / select_action 里下断点
+ dummy_obs = {
+ "observation.images.cam_high": np.zeros((224, 224, 3), dtype=np.uint8),
+ "observation.images.cam_left_wrist": np.zeros((224, 224, 3), dtype=np.uint8),
+ "observation.images.cam_right_wrist": np.zeros((224, 224, 3), dtype=np.uint8),
+ "observation.state": np.zeros(model.action_dim, dtype=np.float32),
+ "task": "dummy task for debug",
+ "reset": False,
+ }
+ out = model.infer(dummy_obs)
+ print("debug_infer_once result keys:", out.keys())
+ return
+ model_server = WebsocketPolicyServer(model, port=args.port)
+ model_server.serve_forever()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/deploy/lingbot_robotwin_policy_rep.py b/deploy/lingbot_robotwin_policy_rep.py
new file mode 100644
index 0000000000000000000000000000000000000000..057aded44fa675746fe83f3d79385f92380691d7
--- /dev/null
+++ b/deploy/lingbot_robotwin_policy_rep.py
@@ -0,0 +1,491 @@
+import json
+import os
+import time
+import random
+import numpy as np
+from collections import deque
+import torchvision
+import yaml
+from types import SimpleNamespace
+from packaging.version import Version
+from typing import Callable, Dict, List, Optional, Type, Union, Tuple, Any, Sequence
+from glob import glob
+from tqdm import tqdm
+from safetensors import safe_open
+from safetensors.torch import load_file
+from pathlib import Path
+from PIL import Image
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+import transformers
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers import (
+ AutoConfig,
+ PretrainedConfig,
+ PreTrainedModel,
+ AutoProcessor,
+)
+
+from lerobot.configs.policies import PreTrainedConfig
+from lingbotvla.models.vla.pi0.modeling_pi0 import PI0Policy
+from lingbotvla.models.vla.pi0.modeling_lingbot_vla import LingbotVlaPolicy
+from lingbotvla.data.vla_data.transform import Normalizer, prepare_images, prepare_language, prepare_state
+from lingbotvla.models import build_processor
+
+
+def set_seed_everywhere(seed: int):
+ """Sets the random seed for Python, NumPy, and PyTorch functions."""
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ os.environ["PYTHONHASHSEED"] = str(seed)
+
+set_seed_everywhere(42)
+
+BASE_MODEL_PATH = {
+ 'pi0': os.environ.get('PALIGEMMA_PATH', './paligemma-3b-pt-224/'),
+ 'lingbotvla': os.environ.get('QWEN25_PATH', './Qwen2.5-VL-3B-Instruct/'),
+}
+
+def load_model_weights(policy, path_to_pi_model, strict=True):
+ all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
+ merged_weights = {}
+
+ for file_path in tqdm(all_safetensors):
+ with safe_open(file_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ merged_weights[key] = f.get_tensor(key)
+ policy.load_state_dict(merged_weights, strict=strict)
+
+
+def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
+ crop_scale = 0.9
+ side_scale = float(np.sqrt(np.clip(crop_scale, 0.0, 1.0))) # side length scale
+ out_size = (224, 224)
+
+ # Convert input to PIL Image
+ if isinstance(image, np.ndarray):
+ arr = image
+ if arr.dtype.kind == "f":
+ # If floats likely in [0,1], map to [0,255]
+ if arr.max() <= 1.0 and arr.min() >= 0.0:
+ arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)
+ else:
+ arr = np.clip(arr, 0.0, 255.0).astype(np.uint8)
+ elif arr.dtype == np.uint16:
+ # Map 16-bit to 8-bit
+ arr = (arr / 257).astype(np.uint8)
+ elif arr.dtype != np.uint8:
+ arr = arr.astype(np.uint8)
+ pil = Image.fromarray(arr)
+ elif isinstance(image, Image.Image):
+ pil = image
+ else:
+ raise TypeError("image must be a numpy array or PIL.Image.Image")
+
+ # Force RGB for consistent output
+ pil = pil.convert("RGB")
+ W, H = pil.size
+
+ # Compute centered crop box (integer pixels)
+ crop_w = max(1, int(round(W * side_scale)))
+ crop_h = max(1, int(round(H * side_scale)))
+ left = (W - crop_w) // 2
+ top = (H - crop_h) // 2
+ right = left + crop_w
+ bottom = top + crop_h
+
+ cropped = pil.crop((left, top, right, bottom))
+ resized = cropped.resize(out_size, resample=Image.BILINEAR)
+ return resized
+
+def resize_with_pad(img, width, height, pad_value=-1):
+ # assume no-op when width height fits already
+ if img.ndim != 4:
+ raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
+
+ # channel last to channel first if necessary
+ if img.shape[1] not in (1, 3) and img.shape[-1] in (1, 3):
+ img = img.permute(0, 3, 1, 2)
+
+ cur_height, cur_width = img.shape[2:]
+
+ ratio = max(cur_width / width, cur_height / height)
+ resized_height = int(cur_height / ratio)
+ resized_width = int(cur_width / ratio)
+ resized_img = F.interpolate(
+ img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+ )
+
+ pad_height = max(0, int(height - resized_height))
+ pad_width = max(0, int(width - resized_width))
+
+ # pad on left and top of image
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
+ return padded_img
+
+class PolicyPreprocessMixin:
+
+ @torch.no_grad
+ def select_action(
+ self, observation: dict[str, Tensor], use_bf16: bool = False, vlm_causal: bool = False, noise: Tensor | None = None
+ ):
+ self.eval()
+ device = 'cuda'
+ if use_bf16:
+ dtype = torch.bfloat16
+ else:
+ dtype = torch.float32
+ s1 = time.time()
+
+ if len(observation['images'].shape) == 4:
+ observation['images'] = observation['images'].unsqueeze(0)
+ observation['img_masks'] = observation['img_masks'].unsqueeze(0)
+ state_indices = list(range(12)) + list(range(73, 75)) + list(range(12, 14)) + list(range(14, 73))
+ observation['state'] = observation['state'][state_indices]
+ if 'expert_imgs' in observation:
+ actions = self.model.sample_actions(
+ observation['images'].to(dtype=dtype, device=device),
+ observation['img_masks'].to(device=device),
+ observation['lang_tokens'].unsqueeze(0).to(device=device),
+ observation['lang_masks'].unsqueeze(0).to(device=device),
+ observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
+ observation['expert_imgs'].to(dtype=dtype, device=device),
+ vlm_causal = vlm_causal
+ )
+ else:
+ actions = self.model.sample_actions(
+ observation['images'].to(dtype=dtype, device=device),
+ observation['img_masks'].to(device=device),
+ observation['lang_tokens'].unsqueeze(0).to(device=device),
+ observation['lang_masks'].unsqueeze(0).to(device=device),
+ observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
+ vlm_causal = vlm_causal
+ )
+ action_indices = list(range(6)) + [14] + list(range(6, 12)) + [15]
+ actions = actions[:, :, action_indices]
+ delta_time = time.time() - s1
+ print(f'sample_actions cost {delta_time} s')
+ observation['action'] = actions.squeeze(0)[:, :14].to(dtype=torch.float32, device='cpu')
+ if use_bf16:
+ observation['state'] = observation['state'].to(dtype=torch.float32)
+ data = self.normalizer.unnormalize(observation)
+ return data
+
+class LingBotVlaInferencePolicy(PolicyPreprocessMixin, LingbotVlaPolicy):
+ pass # Only combine necessary functions
+
+class PI0InfernecePolicy(PolicyPreprocessMixin, PI0Policy):
+ pass # Only combine necessary functions
+
+
+def merge_qwen_config(policy_config, qwen_config):
+ if hasattr(qwen_config, 'to_dict'):
+ config_dict = qwen_config.to_dict()
+ else:
+ config_dict = qwen_config
+
+ text_keys = {
+ "hidden_size",
+ "intermediate_size",
+ "num_hidden_layers",
+ "num_attention_heads",
+ "num_key_value_heads",
+ "rms_norm_eps",
+ "rope_theta",
+ "vocab_size",
+ "max_position_embeddings",
+ "hidden_act",
+ "tie_word_embeddings",
+ "tokenizer_path",
+ }
+
+ for key in text_keys:
+ if key in config_dict:
+ setattr(policy_config, key, config_dict[key])
+ print(f"✅ Merged LLM: {key} = {config_dict[key]}")
+
+ if "vision_config" in config_dict:
+ policy_config.vision_config = qwen_config.vision_config
+ else:
+ print("⚠️ Warning: 'vision_config' not found in qwen_config!")
+
+ return policy_config
+
+
+class QwenPiServer:
+ '''
+ policy wrapper to support action ensemble or chunk execution
+ '''
+ def __init__(
+ self,
+ path_to_pi_model="",
+ adaptive_ensemble_alpha=0.1,
+ action_ensemble_horizon=8,
+ use_length=1, # to control the execution length of the action chunk, -1 denotes using action ensemble
+ chunk_ret=False,
+ use_bf16=True,
+ use_fp32=False,
+ ) -> None:
+ assert not (use_bf16 and use_fp32), 'Bfloat16 or Float32!!!'
+ self.adaptive_ensemble_alpha = adaptive_ensemble_alpha
+ self.use_length = use_length
+ self.chunk_ret = chunk_ret
+
+ self.task_description = None
+
+ self.vla = self.load_vla(path_to_pi_model)
+ self.vla = self.vla.cuda().eval()
+ if use_bf16:
+ self.vla = self.vla.to(torch.bfloat16)
+ elif use_fp32:
+ self.vla.model.float()
+ self.global_step = 0
+ self.last_action_chunk = None
+ self.use_bf16 = use_bf16
+ self.use_fp32 = use_fp32
+
+ def load_vla(self, path_to_pi_model) -> LingbotVlaPolicy:
+ # load model
+ print(f"loading model from: {path_to_pi_model}")
+ config = PreTrainedConfig.from_pretrained(path_to_pi_model)
+
+ # load training config
+ training_config_path = Path(path_to_pi_model)/'lingbotvla_cli.yaml'
+ with open(training_config_path, 'r') as f:
+ training_config = yaml.safe_load(f)
+ f.close()
+
+ # update model config according to training config
+ training_model_config = training_config['model']
+ training_model_config.update(training_config['train'])
+ for k, v in training_model_config.items():
+ v = getattr(config, k, training_model_config[k])
+ setattr(config, k, v)
+
+ # Set attention_implementation to 'eager' to speed up evaluation.
+ config.attention_implementation = 'eager'
+
+ # set base model according to training config
+ training_base_model = os.environ.get('QWEN25_PATH', './Qwen2.5-VL-3B-Instruct/')
+ if 'paligemma' in training_base_model:
+ model_name = 'pi0'
+ config.vocab_size = 257152 # set vocab size for paligamma
+ elif 'qwen2' in training_base_model.lower():
+ model_name = 'lingbotvla'
+ else:
+ raise ValueError(f"Unsupported base model of {path_to_pi_model}")
+ base_model_path = BASE_MODEL_PATH[model_name]
+ config.tokenizer_path = base_model_path
+ self.model_name = model_name
+
+ qwen_config = AutoConfig.from_pretrained(base_model_path)
+ config = merge_qwen_config(config, qwen_config)
+
+ if 'vocab_size' in training_config['model'] and training_config['model']['vocab_size'] != 0:
+ config.vocab_size = training_config['model']['vocab_size']
+ # load processors
+ self.processor = build_processor(base_model_path)
+ self.language_tokenizer = self.processor.tokenizer
+ self.image_processor = self.processor.image_processor
+ data_config = SimpleNamespace(**training_config['data'])
+
+ print('Initializing model ... ')
+
+ if 'paligemma' in training_base_model:
+ policy = PI0InfernecePolicy(config, tokenizer_path=base_model_path)
+ else:
+ policy = LingBotVlaInferencePolicy(config, tokenizer_path=base_model_path, eval=True)
+
+ load_model_weights(policy, path_to_pi_model, strict=True)
+
+ policy.feature_transform = None
+ self.data_config = data_config
+ self.config = config
+ self.joint_max_dim = training_config['train']['max_action_dim']
+ self.action_dim = training_config['train']['action_dim']
+ self.chunk_size = training_config['train']['chunk_size']
+ policy.action_dim = self.action_dim
+ policy.chunk_size = self.chunk_size
+ self.norm_stats_file = 'assets/norm_stats/robotwin_all_new.json'
+ if 'align_params' in training_config['train']:
+ self.use_depth_align = True
+ else: self.use_depth_align = False
+ with open(self.norm_stats_file) as f:
+ self.norm_stats = json.load(f)
+ policy.normalizer = Normalizer(
+ norm_stats=self.norm_stats['norm_stats'],
+ from_file=True,
+ data_type='robotwin_rep',
+ norm_type={
+ "observation.images.cam_high": "identity",
+ "observation.images.cam_left_wrist": "identity",
+ "observation.images.cam_right_wrist": "identity",
+ "observation.state": self.data_config.norm_type,
+ "action": self.data_config.norm_type,
+ },
+ )
+
+ print('Model initialized ... ')
+
+ return policy
+
+ def reset(self, robo_name, path_to_pi_model = None) -> None:
+
+ if path_to_pi_model is not None:
+ self.vla = self.load_vla(path_to_pi_model)
+ self.vla = self.vla.cuda().eval()
+ if self.use_bf16:
+ self.vla = self.vla.to(torch.bfloat16)
+ elif self.use_fp32:
+ self.vla.model.float()
+
+ self.global_step = 0
+ self.last_action_chunk = None
+
+ if getattr(self.data_config, 'norm_type', None) is None:
+ self.data_config.norm_type = 'meanstd'
+ if getattr(self.config, 'vlm_causal', None) is None:
+ self.config.vlm_causal = False
+ if getattr(self.config, 'qwenvl_bos', None) is None:
+ self.config.qwenvl_bos = False
+
+ # if update ckpt path
+ if path_to_pi_model is not None:
+ all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
+ merged_weights = {}
+
+ for file_path in tqdm(all_safetensors):
+ with safe_open(file_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ merged_weights[key] = f.get_tensor(key)
+
+ self.vla.load_state_dict(merged_weights, strict=True)
+
+ def resize_image(self, observation):
+ for image_feature in ['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist']:
+ assert image_feature in observation
+ assert len(observation[image_feature].shape)==3 and observation[image_feature].shape[-1] == 3
+ image = observation[image_feature]
+ img_pil = Image.fromarray(image)
+ image_size = getattr(self.data_config, 'img_size', 224)
+ img_pil = img_pil.resize((image_size, image_size), Image.BILINEAR)
+
+ # img_resized shape: C*H*W
+ img_resized = np.transpose(np.array(img_pil), (2,0,1)) # (3,224,224)
+ observation[image_feature] = img_resized / 255.
+
+ def infer(self, observation, center_crop=True):
+ """Generates an action with the VLA policy."""
+
+ # (If trained with image augmentations) Center crop image and then resize back up to original size.
+ # IMPORTANT: Let's say crop scale == 0.9. To get the new height and width (post-crop), multiply
+ # the original height and width by sqrt(0.9) -- not 0.9!
+ if 'reset' in observation and observation['reset']:
+ self.reset(robo_name=observation['robo_name'], path_to_pi_model=observation['path_to_pi_model'] if 'path_to_pi_model' in observation else None)
+ return dict(action = None)
+
+ self.resize_image(observation)
+ for k, v in observation.items():
+ if isinstance(v, np.ndarray):
+ observation[k] = torch.from_numpy(v)
+
+ if self.use_length == -1 or self.global_step % self.use_length == 0:
+ joint_max_dim = getattr(self, 'joint_max_dim')
+ action_dim = getattr(self, 'action_dim')
+ chunk_size = getattr(self, 'chunk_size')
+ indices = list(range(6)) + list(range(7, 13)) + [6] + [13]
+ observation["observation.state"] = observation["observation.state"][indices]
+ normalized_observation = self.vla.normalizer.normalize(observation)
+ base_image = (normalized_observation["observation.images.cam_high"] * 255).to(torch.uint8)
+ left_wrist_image = (normalized_observation["observation.images.cam_left_wrist"] * 255).to(
+ torch.uint8
+ )
+ right_wrist_image = (normalized_observation["observation.images.cam_right_wrist"] * 255).to(
+ torch.uint8
+ )
+ obs_dict = {
+ "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
+ "state": normalized_observation["observation.state"].to(torch.float32),
+ "prompt": [observation["task"]],
+ }
+ state = prepare_state(self.config, obs_dict)
+ lang_tokens, lang_masks = prepare_language(self.config, self.language_tokenizer, obs_dict)
+ images, img_masks, _ = prepare_images(self.config, self.image_processor, obs_dict)
+ observation = {
+ 'images': images,
+ 'img_masks': img_masks,
+ 'state': state,
+ 'lang_tokens': lang_tokens,
+ 'lang_masks': lang_masks,
+ }
+
+ if self.use_bf16:
+ observation['state'] = observation['state'].to(torch.bfloat16)
+
+ org_actions = ['action']
+ assert len(org_actions)==1, "Only support single action feature"
+ if self.chunk_ret:
+ action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]].float().cpu().numpy()
+ action = action[:self.use_length, :self.action_dim]
+ else:
+ if self.use_length == -1 or self.global_step % self.use_length == 0:
+ action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]]
+ self.last_action_chunk = action.float().cpu().numpy()
+
+ if self.use_length > 0:
+ action = self.last_action_chunk[self.global_step % self.use_length]
+ action = action[:, :self.action_dim]
+ print(f"on server step: {self.global_step}")
+ self.global_step+=1
+
+ return dict(action = action)
+
+
+import argparse
+from .websocket_policy_server import WebsocketPolicyServer
+
+def main():
+ parser = argparse.ArgumentParser(description="启动 QwenPi WebSocket 策略服务器")
+
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ )
+
+ parser.add_argument(
+ "--use_length",
+ type=int,
+ default=50,
+ help="used length of action chunk"
+ )
+
+ parser.add_argument(
+ "--chunk_ret",
+ type=bool,
+ default=True,
+ help=" True: The returned action tensor includes the horizon dimension. This allows the model to output a sequence of actions for each horizon step. False: The horizon dimension is omitted. The model selects and returns the next step autonomously based on its policy."
+ )
+
+ parser.add_argument(
+ "--port",
+ type=int,
+ default=8006,
+ help="port of WebSocket"
+ )
+
+ args = parser.parse_args()
+
+ model = QwenPiServer(args.model_path, use_length=args.use_length, chunk_ret = args.chunk_ret)
+ model_server = WebsocketPolicyServer(model, port=args.port)
+ model_server.serve_forever()
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/deploy/msgpack_numpy.py b/deploy/msgpack_numpy.py
new file mode 100644
index 0000000000000000000000000000000000000000..007f755edf54565579376b077eec7f7f715e1b96
--- /dev/null
+++ b/deploy/msgpack_numpy.py
@@ -0,0 +1,57 @@
+"""Adds NumPy array support to msgpack.
+
+msgpack is good for (de)serializing data over a network for multiple reasons:
+- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
+- msgpack is widely used and has good cross-language support
+- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
+ languages like Python and JavaScript
+- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
+ than pickle for serializing large arrays using the below strategy
+
+The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
+that it falls back to pickle for object arrays.
+"""
+
+import functools
+
+import msgpack
+import numpy as np
+
+
+def pack_array(obj):
+ if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
+ raise ValueError(f"Unsupported dtype: {obj.dtype}")
+
+ if isinstance(obj, np.ndarray):
+ return {
+ b"__ndarray__": True,
+ b"data": obj.tobytes(),
+ b"dtype": obj.dtype.str,
+ b"shape": obj.shape,
+ }
+
+ if isinstance(obj, np.generic):
+ return {
+ b"__npgeneric__": True,
+ b"data": obj.item(),
+ b"dtype": obj.dtype.str,
+ }
+
+ return obj
+
+
+def unpack_array(obj):
+ if b"__ndarray__" in obj:
+ return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
+
+ if b"__npgeneric__" in obj:
+ return np.dtype(obj[b"dtype"]).type(obj[b"data"])
+
+ return obj
+
+
+Packer = functools.partial(msgpack.Packer, default=pack_array)
+packb = functools.partial(msgpack.packb, default=pack_array)
+
+Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
+unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
diff --git a/deploy/websocket_client_policy.py b/deploy/websocket_client_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d6b614b4a9402ab7862033c50c894d9d47a88d6
--- /dev/null
+++ b/deploy/websocket_client_policy.py
@@ -0,0 +1,88 @@
+import logging
+import time
+from typing import Dict, Optional, Tuple
+
+from typing_extensions import override
+import websockets.sync.client
+from .msgpack_numpy import Packer, unpackb
+
+
+class WebsocketClientPolicy:
+ """Implements the Policy interface by communicating with a server over websocket.
+
+ See WebsocketPolicyServer for a corresponding server implementation.
+ """
+
+ def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
+ self._uri = f"ws://{host}"
+ if port is not None:
+ self._uri += f":{port}"
+ self._packer = Packer()
+ self._api_key = api_key
+ self._ws, self._server_metadata = self._wait_for_server()
+
+ def get_server_metadata(self) -> Dict:
+ return self._server_metadata
+
+ def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
+ logging.info(f"Waiting for server at {self._uri}...")
+ while True:
+ try:
+ headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None
+ conn = websockets.sync.client.connect(
+ self._uri, compression=None, max_size=None, additional_headers=headers
+ )
+ metadata = unpackb(conn.recv())
+ return conn, metadata
+ except ConnectionRefusedError:
+ logging.info("Still waiting for server...")
+ time.sleep(5)
+
+ @override
+ def infer(self, obs: Dict) -> Dict: # noqa: UP006
+ data = self._packer.pack(obs)
+ self._ws.send(data)
+ response = self._ws.recv()
+ if isinstance(response, str):
+ # we're expecting bytes; if the server sends a string, it's an error.
+ raise RuntimeError(f"Error in inference server:\n{response}")
+ return unpackb(response)
+
+ @override
+ def reset(self, robo_name: str) -> None:
+ self.infer(dict(reset=True, robo_name=robo_name))
+
+if __name__ == "__main__":
+ policy_on_device = WebsocketClientPolicy(port=8000)
+ import torch
+ import numpy as np
+ from PIL import Image
+ from .image_tools import convert_to_uint8
+ device = torch.device("cuda")
+
+ base_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
+ left_wrist_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
+ state = np.random.rand(1,8).astype(np.float32)
+ prompt = ["do something"]
+
+ # observation = {
+ # "image": {
+ # "base_0_rgb": torch.from_numpy(base_0_rgb).to(device)[None],
+ # "left_wrist_0_rgb": torch.from_numpy(left_wrist_0_rgb).to(device)[None],
+ # },
+ # "state": torch.from_numpy(state).to(device)[None],
+ # "prompt": prompt,
+ # }
+
+ observation = {
+ "image": {
+ "base_0_rgb": convert_to_uint8(base_0_rgb),
+ "left_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
+ "right_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
+ },
+ "state": state,
+ "prompt": prompt,
+ }
+
+ policy_on_device.infer(observation)
+ from IPython import embed;embed()
diff --git a/deploy/websocket_policy_server.py b/deploy/websocket_policy_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..1db3dbfe6f8dde9e52aec6a79bca666c191089fe
--- /dev/null
+++ b/deploy/websocket_policy_server.py
@@ -0,0 +1,89 @@
+import asyncio
+import http
+import logging
+import time
+import traceback
+
+from .msgpack_numpy import Packer, unpackb
+import websockets.asyncio.server as _server
+import websockets.frames
+
+logger = logging.getLogger(__name__)
+
+
+class WebsocketPolicyServer:
+ """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.
+
+ Currently only implements the `load` and `infer` methods.
+ """
+
+ def __init__(
+ self,
+ policy,
+ host: str = "0.0.0.0",
+ port: int | None = None,
+ metadata: dict | None = None,
+ ) -> None:
+ self._policy = policy
+ self._host = host
+ self._port = port
+ self._metadata = metadata or {}
+ logging.getLogger("websockets.server").setLevel(logging.INFO)
+
+ def serve_forever(self) -> None:
+ asyncio.run(self.run())
+
+ async def run(self):
+ async with _server.serve(
+ self._handler,
+ self._host,
+ self._port,
+ compression=None,
+ max_size=None,
+ process_request=_health_check,
+ ) as server:
+ await server.serve_forever()
+
+ async def _handler(self, websocket: _server.ServerConnection):
+ logger.info(f"Connection from {websocket.remote_address} opened")
+ packer = Packer()
+
+ await websocket.send(packer.pack(self._metadata))
+
+ prev_total_time = None
+ while True:
+ try:
+ start_time = time.monotonic()
+ obs = unpackb(await websocket.recv())
+
+ infer_time = time.monotonic()
+ action = self._policy.infer(obs)
+ infer_time = time.monotonic() - infer_time
+
+ action["server_timing"] = {
+ "infer_ms": infer_time * 1000,
+ }
+ if prev_total_time is not None:
+ # We can only record the last total time since we also want to include the send time.
+ action["server_timing"]["prev_total_ms"] = prev_total_time * 1000
+
+ await websocket.send(packer.pack(action))
+ prev_total_time = time.monotonic() - start_time
+
+ except websockets.ConnectionClosed:
+ logger.info(f"Connection from {websocket.remote_address} closed")
+ break
+ except Exception:
+ await websocket.send(traceback.format_exc())
+ await websocket.close(
+ code=websockets.frames.CloseCode.INTERNAL_ERROR,
+ reason="Internal server error. Traceback included in previous frame.",
+ )
+ raise
+
+
+def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
+ if request.path == "/healthz":
+ return connection.respond(http.HTTPStatus.OK, "OK\n")
+ # Continue with the normal request handling.
+ return None
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..52590fe52fa4d8a382d2614c10632bf5af554f3a
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,34 @@
+# Start from the NVIDIA official image (ubuntu-22.04 + python-3.10)
+# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
+FROM nvcr.io/nvidia/pytorch:24.08-py3
+
+# Define environments
+ENV MAX_JOBS=32
+ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
+ENV DEBIAN_FRONTEND=noninteractive
+ENV NODE_OPTIONS=""
+
+
+# Install systemctl and tini
+RUN apt-get update && \
+apt-get install -y -o Dpkg::Options::="--force-confdef" systemd tini && \
+apt-get clean || { echo "Installation failed"; exit 1; }
+
+RUN apt-get install -y tzdata \
+ && ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
+ && dpkg-reconfigure -f noninteractive tzdata
+
+# Change pip source
+RUN python -m pip install --upgrade pip
+
+# Install torch-2.5.1 + vllm-0.7.3
+RUN pip install --no-cache-dir vllm==0.7.3 torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 tensordict torchdata \
+ transformers>=4.49.0 accelerate datasets peft hf-transfer diffusers \
+ codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb ninja liger-kernel \
+ pytest yapf py-spy pyext pre-commit ruff packaging
+
+# Install flux
+RUN pip install --no-cache-dir byte-flux
+
+# Install flash-attn and triton
+RUN pip install --no-cache-dir flash-attn triton>=3.1.0
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..afbff72cd0aa3df3da6abacd06f0e4cb1c9f25fb
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS =
+SPHINXBUILD = sphinx-build
+SPHINXPROJ = LingBotVLA
+SOURCEDIR = .
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0011c5efe7387c17ac21befb3fd95d722ac9e934
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,19 @@
+# LingBotVLA documents
+
+## Build the docs
+
+```bash
+# Install dependencies.
+pip install -r requirements-docs.txt
+
+# Build the docs.
+make clean
+make html
+```
+
+## Open the docs with your browser
+
+```bash
+python -m http.server -d _build/html/
+```
+Launch your browser and open localhost:8000.
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..e74d1762b3552640132b6d741b39be61c533ebac
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,66 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+# import os
+# import sys
+# sys.path.insert(0, os.path.abspath('.'))
+
+
+# -- Project information -----------------------------------------------------
+
+project = "LingBotVLA"
+# pylint: disable=W0622
+copyright = "2026 Robbyant Team, based on VeOmni by ByteDance Seed Foundation MLSys Team"
+
+# -- General configuration ---------------------------------------------------
+# The master toctree document.
+master_doc = "index"
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ "recommonmark",
+ "sphinx.ext.autosectionlabel",
+]
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffix as a list of string:
+source_suffix = [".rst", "rest", ".md"]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = "en"
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = "sphinx_rtd_theme"
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static"]
diff --git a/docs/config/config.md b/docs/config/config.md
new file mode 100644
index 0000000000000000000000000000000000000000..466094b0cb526bce6f96a20b8228853af122396e
--- /dev/null
+++ b/docs/config/config.md
@@ -0,0 +1,96 @@
+## Config arguments Explanation
+### Model configuration arguments
+| Name | Type | Description | Default Value |
+| --- | --- | --- | --- |
+| model.config_path | str | Path to the model huggingface configuration, like `config.json` | model.model_path |
+| model.model_path | str | Path to the model parameter file. If empty, random initialization will be performed | None |
+| model.tokenizer_path | str | Path to the tokenizer | model.model_path |
+| model.encoders | dict | Configuration file for multi-modal encoders | {} |
+| model.decoders | dict | Configuration file for multi-modal decoders | {} |
+| model.input_encoder | str: {"encoder", "decoder"} | Use the encoder of the encoder or decoder to encode the input image | encoder |
+| model.output_encoder | str: {"encoder", "decoder"} | Use the encoder of the encoder or decoder to encode the output image | decoder |
+| model.encode_target | bool | Used to encode the training data for the diffusion model | False |
+
+### Data configuration arguments
+
+| Name | Type | Description | Default Value |
+| --- | --- | --- | --- |
+| data.train_path | str | Path of training dataset | Required |
+| data.train_size | int | Total number of tokens in the training set | 10,000,000 |
+| data.data_type | str: {"plaintext", "conversation"} | Dataset type. | conversation |
+| data.dataloader_type | str: {"native"} | Use the pytorch dataloader or | native |
+| data.datasets_type | str: {"mapping", "iterable"} | Dataset type. `IterativeDataset` or `MappingDataset`, or your custom datsets | mapping |
+| data.text_keys | str: {"content_split", "messages"} | The key corresponding to the text samples in the data dictionary. Generally, it is "content_split" for pretraining and "messages" for SFT. | content_split |
+| data.image_keys | str | The key corresponding to the image samples in the data dictionary. Generally, it is "images". | images |
+| data.chat_template | str | Name of the chat template. | default |
+| data.max_seq_len | int | Maximum training length. | 2048 |
+| data.num_workers | int | Number of multi-process loaders for the dataloader. | 4 |
+| data.drop_last | bool | Whether to discard the remaining data at the end. | True |
+| data.pin_memory | bool | Whether to pin the data in the CPU memory. | True |
+| data.prefetch_factor | int | Number of samples preprocessed by the dataloader. | 2 |
+
+#### Training configuration arguments
+| Name | Type | Description | Default Value |
+| --- | --- | --- | --- |
+| train.output_dir | str | Path to save the model. | Required |
+| train.lr | float | Maximum learning rate. | 5e - 5 |
+| train.lr_min | float | Minimum learning rate. | 1e - 7 |
+| train.weight_decay | float | Weight decay coefficient. | 0 |
+| train.optimizer | str: {"adamw", "anyprecision_adamw"} | Name of the optimizer. | adamw |
+| train.max_grad_norm | float | Gradient clipping norm. | 1.0 |
+| train.micro_batch_size | int | Number of samples processed simultaneously on each GPU. | 1 |
+| train.global_batch_size | int | Global batch size, which must be a multiple of the number of GPUs. | train.micro_batch_size * n_gpus |
+| train.num_train_epochs | int | Number of training epochs. | 1 |
+| train.rmpad | bool | Whether to use rmpad training based on cu_seqlens. | False |
+| train.rmpad_with_pos_ids | bool | Whether to use rmpad training based on position_ids. | False |
+| train.dyn_bsz_margin | int | Number of pad tokens in the dynamic batch. | 0 |
+| train.dyn_bsz_runtime | str: {"main", "worker"} | Running process of the dynamic batch. | worker |
+| train.bsz_warmup_ratio | float | Proportion of batch size warmup in the total number of steps. | 0 |
+| train.lr_warmup_ratio | float | Proportion of learning rate warmup in the total number of steps. | 0 |
+| train.lr_decay_style | str: {"constant", "linear", "cosine"} | Name of the learning rate scheduler. | cosine |
+| train.lr_decay_ratio | float | Proportion of learning rate decay in the total number of steps | 1.0 |
+| train.use_doptim | bool | Whether to use the distributed optimizer during Vescale training(no use for torch fsdp) | False |
+| train.enable_mixed_precision | bool | Whether to enable mixed precision training (higher memory usage but more stable) | True |
+| train.enable_gradient_checkpointing | bool | Whether to enable gradient checkpointing to reduce memory usage. | True |
+| train.enable_reentrant | bool | Whether to enable reentrant in gradient checkpointing. | True |
+| train.enable_full_shard | bool | Whether to use full sharding FSDP (equivalent to ZeRO3). | True |
+| train.enable_fsdp_offload | bool | Whether to enable FSDP CPU offloading (only supported for FSDP1). | False |
+| train.enable_activation_offload | bool | Whether to enable activation value CPU offloading. | False |
+| train.activation_gpu_limit | float | Size of the activation values retained on the GPU (in GB). | 0.0 |
+| train.enable_manual_eager | bool | Whether to use manual eager during Vescale training. | False |
+| train.init_device: meta | str | "cpu", "cuda", "meta", init device for model initialization. use "meta" or cpu for large model(>30B) | cuda |
+| train.enable_full_determinism | bool | Whether to enable deterministic mode (for bitwise alignment). | False |
+| train.empty_cache_steps | int | Number of steps between two cache clearings. -1 means not enabled. | 500 |
+| train.data_parallel_mode | str: {"ddp", "fsdp1", "fsdp2"} | Data parallel algorithm. | ddp |
+| train.tensor_parallel_size | int | Tensor parallel size (currently only supported for vescale training). | 1 |
+| train.pipeline_parallel_size | int | Pipeline parallel size (currently not supported). | 1 |
+| train.ulysses_parallel_size | int | Ulysses sequence parallel size (currently only supported for P6dense and Qwen2VL). | 1 |
+| train.context_parallel_size | int | Ring sequence parallel size (currently not supported) | 1 |
+| train.expert_parallel_size | int | Expert parallel size (currently only supported DeepseekMOE) | 1 |
+| train.load_checkpoint_path | str | Path to the omnistore checkpoint for resuming training. | None |
+| train.save_steps | int | Number of steps between two checkpoint saves. 0 means invalid. | 0 |
+| train.save_epochs | int | Number of epochs between two checkpoint saves. 0 means invalid. | 1 |
+| train.save_hf_weights | bool | Whether to save the model weights in the huggingface format. It is recommended to set it to False for models > 30B to prevent NCCL timeout. You can convert it after training. | True |
+| train.seed | int | Random seed. | 42 |
+| train.use_wandb | bool | Whether to enable byted wandb experiment logging. | True |
+| train.wandb_project | str | Name of the wandb experiment project. | LingBotVLA |
+| train.wandb_name | str | Name of the wandb experiment. | None |
+| train.enable_profiling | bool | Whether to use torch profiling. | False |
+| train.profile_start_step | int | Starting step of profiling. | 1 |
+| train.profile_end_step | int | Ending step of profiling. | 2 |
+| train.profile_trace_dir | str | Path to save the profiling results. | ./trace |
+| train.profile_record_shapes | bool | Whether to record the shapes of the input tensors. | True |
+| train.profile_profile_memory | bool | Whether to record the memory usage. | True |
+| train.profile_with_stack | bool | Whether to record the stack information. | True |
+| train.max_steps | int | Number of steps per training epoch (only used for debugging). | None |
+
+### Inference configuration arguments
+| Name | Type | Description | Default Value |
+| --- | --- | --- | --- |
+| infer.model_path | str | Path to the model parameter file. | Required |
+| infer.tokenizer_path | str | Path to the tokenizer. | model.model_path |
+| infer.seed | int | Random seed. | 42 |
+| infer.do_sample | bool | Whether to enable sampling. | True |
+| infer.temperature | float | Sampling temperature. | 1.0 |
+| infer.top_p | float | Sampling Top P value. | 1.0 |
+| infer.max_tokens | int | Maximum number of tokens generated each time. | 1024 |
diff --git a/docs/examples/qwen2vl.rst b/docs/examples/qwen2vl.rst
new file mode 100644
index 0000000000000000000000000000000000000000..154f50057c4208138dcff01c86e2734132dba37b
--- /dev/null
+++ b/docs/examples/qwen2vl.rst
@@ -0,0 +1,2 @@
+Qwen2VL example
+=========================
diff --git a/docs/examples/qwen3_moe.md b/docs/examples/qwen3_moe.md
new file mode 100644
index 0000000000000000000000000000000000000000..a544ecb83ae74cf8333af0e598de3651e089ac0e
--- /dev/null
+++ b/docs/examples/qwen3_moe.md
@@ -0,0 +1,125 @@
+Qwen3 MoE training guide
+
+1. Download qwen3 moe model
+
+```shell
+python3 scripts/download_hf_model.py \
+ --repo_id Qwen/Qwen3-30B-A3B \
+ --local_dir .
+```
+
+2. Merge qwen3 moe model experts to support GroupGemm optimize
+``` shell
+python3 scripts/moe_ckpt_merge/moe_merge.py --raw_hf_path Qwen3-30B-A3B --merge_hf_path Qwen3-30B-A3B-merge
+```
+
+Most of the MoE models in Transformers referenced the open-source implementation of Mixtral MoE. In this implementation, MoE experts are divided into multiple blocks instead of being combined into a single `nn.Parameters`. Additionally, there are cpu-block operators like `torch.where()` and for loop, which are not very friendly for integrating MoE fusion operators.
+
+Origin [Qwen3MoeMLP](https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L200C1-L213C25) code
+```python
+class Qwen3MoeMLP(nn.Module):
+ def __init__(self, config, intermediate_size=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+class Qwen3MoeSparseMoeBlock(nn.Module):
+ def __init__(self, config):
+
+ ...
+
+ self.experts = nn.ModuleList(
+ [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
+ ...
+
+ final_hidden_states = torch.zeros(
+ (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ for expert_idx in expert_hitted:
+ expert_layer = self.experts[expert_idx]
+ idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+
+ current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+ current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+ return final_hidden_states, router_logits
+
+```
+
+- Combine Qwen3MoeMLP to Qwen3MoeExperts, then use fused moe operator
+
+```python
+class Qwen3MoeExperts(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.num_experts = config.num_experts
+ self.hidden_dim = config.hidden_size
+ self.intermediate_size = config.moe_intermediate_size
+ self.gate_proj = torch.nn.Parameter(
+ torch.empty(self.num_experts, self.intermediate_size, self.hidden_dim),
+ requires_grad=True,
+ )
+ self.up_proj = torch.nn.Parameter(
+ torch.empty(self.num_experts, self.intermediate_size, self.hidden_dim),
+ requires_grad=True,
+ )
+ self.down_proj = torch.nn.Parameter(
+ torch.empty(self.num_experts, self.hidden_dim, self.intermediate_size),
+ requires_grad=True,
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states, expert_idx=None, cumsum=None):
+ gate_proj_out = torch.matmul(hidden_states, self.gate_proj[expert_idx].transpose(0, 1))
+ up_proj_out = torch.matmul(hidden_states, self.up_proj[expert_idx].transpose(0, 1))
+
+ out = self.act_fn(gate_proj_out) * up_proj_out
+ out = torch.matmul(out, self.down_proj[expert_idx].transpose(0, 1))
+ return out
+
+
+class Qwen3MoeSparseFusedMoeBlock(nn.Module):
+ def __init__(self, config):
+
+ ...
+
+ self.experts = Qwen3MoeExperts(config)
+
+ def forward(self, hidden_states, expert_idx=None, routing_weights=None, selected_experts=None) -> torch.Tensor:
+
+ ...
+
+ out = fused_moe_forward(
+ module=self,
+ num_experts=self.num_experts,
+ routing_weights=routing_weights,
+ selected_experts=selected_experts,
+ hidden_states=hidden_states,
+ fc1_1_weight=self.gate_proj,
+ fc1_2_weight=self.up_proj,
+ fc2_weight=self.down_proj,
+ )
+ return out
+
+```
+
+3. Train qwen3 moe model
+```
+bash train.sh tasks/train_torch.py configs/pretrain/qwen3-moe.yaml
+```
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..54feaeeab99cb308c49c33eecb1c1a0966e9fef7
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,2 @@
+Welcome to LingBotVLA
+=========================
diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9822d54d8e49a18257c3bc4043eca64c2dd8a359
--- /dev/null
+++ b/docs/requirements-docs.txt
@@ -0,0 +1,9 @@
+# markdown suport
+recommonmark
+# markdown table suport
+sphinx-markdown-tables
+
+# theme default rtd
+
+# crate-docs-theme
+sphinx-rtd-theme
diff --git a/docs/start/start.rst b/docs/start/start.rst
new file mode 100644
index 0000000000000000000000000000000000000000..1e5e90680b631b76ff56c0e5fb31b2cd17750fca
--- /dev/null
+++ b/docs/start/start.rst
@@ -0,0 +1,2 @@
+Getting Started
+=========================
diff --git a/experiment/libero/README.md b/experiment/libero/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d2c355193699b84f1a907260e5cf16b2092c5867
--- /dev/null
+++ b/experiment/libero/README.md
@@ -0,0 +1,18 @@
+# Install official LIBERO
+
+```bash
+git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git libero # (here)
+cd libero
+pip install -e .
+
+cd experiment/libero/libero
+pip install -r req.txt
+```
+
+If can not import xxx from libero.libero please add the libero (here) path to the PYTHONPATH variable.
+
+The results will be save to /project_root/Libero
+
+- release_ensemble/ stores the log files (This directory can be changed by --local_log_dir variable)
+- rollouts stores the videos
+
diff --git a/experiment/libero/libero/libero_utils.py b/experiment/libero/libero/libero_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c71a1bc6ad24060450a7b859ef92345d9ab343a
--- /dev/null
+++ b/experiment/libero/libero/libero_utils.py
@@ -0,0 +1,112 @@
+"""Utils for evaluating policies in LIBERO simulation environments."""
+
+import math
+import os
+
+import imageio
+import numpy as np
+import tensorflow as tf
+from libero.libero import get_libero_path
+from libero.libero.envs import OffScreenRenderEnv
+
+from experiment.libero.robot_utils import (
+ DATE,
+ DATE_TIME,
+)
+
+
+def get_libero_env(task, model_family, resolution=256):
+ """Initializes and returns the LIBERO environment, along with the task description."""
+ task_description = task.language
+ task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
+ env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
+ env = OffScreenRenderEnv(**env_args)
+ env.seed(0) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
+ return env, task_description
+
+
+def get_libero_dummy_action(model_family: str):
+ """Get dummy/no-op action, used to roll out the simulation while the robot does nothing."""
+ return [0, 0, 0, 0, 0, 0, -1]
+
+
+def resize_image(img, resize_size):
+ """
+ Takes numpy array corresponding to a single image and returns resized image as numpy array.
+
+ NOTE (Moo Jin): To make input images in distribution with respect to the inputs seen at training time, we follow
+ the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training.
+ """
+ assert isinstance(resize_size, tuple)
+ # Resize to image size expected by model
+ with tf.device('/CPU:0'):
+ img = tf.image.encode_jpeg(img) # Encode as JPEG, as done in RLDS dataset builder
+ img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Immediately decode back
+ img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True)
+ img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8)
+ img = img.numpy()
+ return img
+
+
+def get_libero_image(obs, resize_size):
+ """Extracts image from observations and preprocesses it."""
+ assert isinstance(resize_size, int) or isinstance(resize_size, tuple)
+ if isinstance(resize_size, int):
+ resize_size = (resize_size, resize_size)
+ img = obs["agentview_image"]
+ img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing
+ img = resize_image(img, resize_size)
+ return img
+
+
+def get_libero_wrist_image(obs, resize_size):
+ """Extracts wrist camera image from observations and preprocesses it."""
+ assert isinstance(resize_size, int) or isinstance(resize_size, tuple)
+ if isinstance(resize_size, int):
+ resize_size = (resize_size, resize_size)
+ img = obs["robot0_eye_in_hand_image"]
+ img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing
+ img = resize_image(img, resize_size)
+ return img
+
+def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, ckpt_index=None, task_suite_name=None, task_id=None):
+ """Saves an MP4 replay of an episode."""
+ rollout_dir = f"./Libero/rollouts/{ckpt_index}/{task_suite_name}-task{task_id}-{DATE_TIME}-{ckpt_index}"
+ os.makedirs(rollout_dir, exist_ok=True)
+ processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50]
+ mp4_path = f"{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--task={processed_task_description}.mp4"
+ video_writer = imageio.get_writer(mp4_path, fps=30)
+ for img in rollout_images:
+ video_writer.append_data(img)
+ video_writer.close()
+ print(f"Saved rollout MP4 at path {mp4_path}")
+ if log_file is not None:
+ log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
+ return mp4_path
+
+
+def quat2axisangle(quat):
+ """
+ Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
+
+ Converts quaternion to axis-angle format.
+ Returns a unit vector direction scaled by its angle in radians.
+
+ Args:
+ quat (np.array): (x,y,z,w) vec4 float angles
+
+ Returns:
+ np.array: (ax,ay,az) axis-angle exponential coordinates
+ """
+ # clip quaternion
+ if quat[3] > 1.0:
+ quat[3] = 1.0
+ elif quat[3] < -1.0:
+ quat[3] = -1.0
+
+ den = np.sqrt(1.0 - quat[3] * quat[3])
+ if math.isclose(den, 0.0):
+ # This is (close to) a zero degree rotation, immediately return
+ return np.zeros(3)
+
+ return (quat[:3] * 2.0 * math.acos(quat[3])) / den
diff --git a/experiment/libero/libero/req.txt b/experiment/libero/libero/req.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ca0c0467efa33b91ad76d66976da480c1e5f572f
--- /dev/null
+++ b/experiment/libero/libero/req.txt
@@ -0,0 +1,6 @@
+imageio[ffmpeg]
+robosuite==1.4.1
+bddl
+easydict
+cloudpickle
+gym
diff --git a/experiment/libero/libero/run_libero_eval.py b/experiment/libero/libero/run_libero_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..4078655cc89a618ef25cbab005e49163a0c18524
--- /dev/null
+++ b/experiment/libero/libero/run_libero_eval.py
@@ -0,0 +1,300 @@
+"""
+run_libero_eval.py
+
+Runs a model in a LIBERO simulation environment.
+
+Usage:
+ # OpenVLA:
+ # IMPORTANT: Set `center_crop=True` if model is fine-tuned with augmentations
+ python Libero/robot/libero/run_libero_eval.py \
+ --model_family openvla \
+ --pretrained_checkpoint \
+ --task_suite_name [ libero_spatial | libero_object | libero_goal | libero_10 | libero_90 ] \
+ --center_crop [ True | False ] \
+ --run_id_note \
+ --use_wandb [ True | False ] \
+ --wandb_project \
+ --wandb_entity
+"""
+
+import tensorflow as tf
+import os, json, re, io, base64, threading
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
+for g in tf.config.list_physical_devices('GPU'):
+ tf.config.experimental.set_memory_growth(g, True)
+
+import os
+import sys
+parent_dir = os.path.dirname(os.getcwd())
+sys.path.insert(0, parent_dir)
+sys.path.insert(0, os.getcwd())
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Union
+import torch
+
+import draccus
+import numpy as np
+import tqdm
+from libero.libero import benchmark
+
+import wandb
+
+# Append current directory so that interpreter can find Libero.robot
+from experiment.libero.libero.libero_utils import (
+ get_libero_dummy_action,
+ get_libero_env,
+ get_libero_image,
+ get_libero_wrist_image,
+ quat2axisangle,
+ save_rollout_video,
+)
+
+from experiment.libero.robot_utils import (
+ DATE_TIME,
+ get_action,
+ get_image_resize_size,
+ get_model,
+ invert_gripper_action,
+ normalize_gripper_action,
+ set_seed_everywhere,
+)
+
+
+@dataclass
+class GenerateConfig:
+ # fmt: off
+
+ #################################################################################################################
+ # Model-specific parameters
+ #################################################################################################################
+ model_family: str = "instruct_vla" # Model family
+ pretrained_checkpoint: Union[str, Path] = "" # Pretrained checkpoint path
+ unnorm_key: Optional[str] = None
+ # image_size: list[int] = [224, 224]
+ action_dim: int = 7
+ model_port: int = 8012
+
+ #################################################################################################################
+ # LIBERO environment-specific parameters
+ #################################################################################################################
+ task_suite_name: str = "libero_spatial" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
+ task_id: Optional[int] = None
+ num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim
+ num_trials_per_task: int = 50 # Number of rollouts per task
+
+ #################################################################################################################
+ # Utils
+ #################################################################################################################
+ run_id_note: Optional[str] = None # Extra note to add in run ID for logging
+ local_log_dir: str = "./Libero/logs" # Local directory for eval logs
+
+ use_wandb: bool = False # Whether to also log results in Weights & Biases
+ wandb_project: str = "YOUR_WANDB_PROJECT" # Name of W&B project to log to (use default!)
+ wandb_entity: str = "YOUR_WANDB_ENTITY" # Name of entity to log under
+
+ seed: int = 42 # Random Seed (for reproducibility)
+ use_length: int = 8
+ # fmt: on
+
+
+@draccus.wrap()
+def eval_libero(cfg: GenerateConfig) -> None:
+
+ ckpt_index = cfg.pretrained_checkpoint.split('/checkpoints/')[0].split('/')[-1]
+ # Set random seed
+ set_seed_everywhere(cfg.seed)
+
+ # [OpenVLA] Check that the model contains the action un-normalization key
+ if cfg.model_family == "openvla":
+ # [OpenVLA] Set action un-normalization key
+ cfg.unnorm_key = cfg.task_suite_name
+ model, server = get_model(cfg)
+ server = None
+ # In some cases, the key must be manually modified (e.g. after training on a modified version of the dataset
+ # with the suffix "_no_noops" in the dataset name)
+ if cfg.unnorm_key not in model.norm_stats and f"{cfg.unnorm_key}_no_noops" in model.norm_stats:
+ cfg.unnorm_key = f"{cfg.unnorm_key}_no_noops"
+ assert cfg.unnorm_key in model.norm_stats, f"Action un-norm key {cfg.unnorm_key} not found in VLA `norm_stats`!"
+
+ elif cfg.model_family == "instruct_vla":
+ # [OpenVLA] Set action un-normalization key
+ cfg.unnorm_key = f"{cfg.task_suite_name}_no_noops"
+ model, server = get_model(cfg)
+
+ # Initialize local logging
+ run_id = f"EVAL-{cfg.task_suite_name}-task{cfg.task_id}-{cfg.model_family}-{DATE_TIME}-{ckpt_index}"
+ if cfg.run_id_note is not None:
+ run_id += f"--{cfg.run_id_note}"
+ cfg.local_log_dir = os.path.join(cfg.local_log_dir, ckpt_index)
+ os.makedirs(cfg.local_log_dir, exist_ok=True)
+ local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt")
+ log_file = open(local_log_filepath, "w")
+ print(f"Logging to local log file: {local_log_filepath}")
+
+ # Initialize Weights & Biases logging as well
+ if cfg.use_wandb:
+ wandb.init(
+ entity=cfg.wandb_entity,
+ project=cfg.wandb_project,
+ name=run_id,
+ )
+
+ # Initialize LIBERO task suite
+ benchmark_dict = benchmark.get_benchmark_dict()
+ task_suite = benchmark_dict[cfg.task_suite_name]()
+ num_tasks_in_suite = task_suite.n_tasks
+ print(f"Task suite: {cfg.task_suite_name}")
+ log_file.write(f"Task suite: {cfg.task_suite_name}\n")
+
+ # Get expected image dimensions
+ resize_size = get_image_resize_size(cfg)
+
+ # Start evaluation
+ total_episodes, total_successes = 0, 0
+ for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
+ # Get task
+ if cfg.task_id is not None:
+ if cfg.task_suite_name == 'libero_10':
+ if task_id != cfg.task_id:
+ continue
+ task = task_suite.get_task(task_id)
+
+ # Get default LIBERO initial states
+ initial_states = task_suite.get_task_init_states(task_id)
+
+ # Initialize LIBERO environment and task description
+ env, task_description = get_libero_env(task, cfg.model_family, resolution=256)
+
+ # Start episodes
+ task_episodes, task_successes = 0, 0
+ for episode_idx in tqdm.tqdm(range(cfg.num_trials_per_task)):
+ print(f"\nTask: {task_description}")
+ log_file.write(f"\nTask: {task_description}\n")
+
+ # Reset environment
+ env.reset()
+ server.reset(robo_name='libero')
+ # Set initial states
+ obs = env.set_init_state(initial_states[episode_idx])
+
+ # Setup
+ t = 0
+ replay_images = []
+ if cfg.task_suite_name == "libero_spatial":
+ max_steps = 220 # longest training demo has 193 steps
+ elif cfg.task_suite_name == "libero_object":
+ max_steps = 280 # longest training demo has 254 steps
+ elif cfg.task_suite_name == "libero_goal":
+ max_steps = 300 # longest training demo has 270 steps
+ elif cfg.task_suite_name == "libero_10":
+ max_steps = 520 # longest training demo has 505 steps
+ elif cfg.task_suite_name == "libero_90":
+ max_steps = 400 # longest training demo has 373 steps
+
+ print(f"Starting episode {task_episodes+1}...")
+ log_file.write(f"Starting episode {task_episodes+1}...\n")
+ while t < max_steps + cfg.num_steps_wait:
+ # try:
+ # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
+ # and we need to wait for them to fall
+ if t < cfg.num_steps_wait:
+ obs, reward, done, info = env.step(get_libero_dummy_action(cfg.model_family))
+ t += 1
+ continue
+
+ # Get preprocessed image
+ img = get_libero_image(obs, resize_size)
+ wrist_img = get_libero_wrist_image(obs, resize_size)
+
+ # Save preprocessed image for replay video
+ replay_images.append(img)
+
+ # Prepare observations dict
+ # Note: OpenVLA does not take proprio state as input
+
+ state = np.concatenate(
+ (obs["robot0_eef_pos"], quat2axisangle(obs["robot0_eef_quat"]), obs["robot0_gripper_qpos"]))
+
+ observation = {
+ "image": img,
+ "wrist_image": wrist_img,
+ "state": state,
+ "task": task_description,
+ }
+
+ # Query model to get action
+ action = get_action(
+ server, observation
+ ).copy()
+
+ # Normalize gripper action [0,1] -> [-1,+1] because the environment expects the latter
+ # action = normalize_gripper_action(action, binarize=True)
+ action[..., -1] = np.sign(action[..., -1]) # binarize
+
+ # [OpenVLA] The dataloader flips the sign of the gripper action to align with other datasets
+ # (0 = close, 1 = open), so flip it back (-1 = open, +1 = close) before executing the action
+ # action = invert_gripper_action(action) # skip since we use raw action
+
+ print('==>action is',action)
+ # Execute action in environment
+ obs, reward, done, info = env.step(action.tolist())
+ if done:
+ task_successes += 1
+ total_successes += 1
+ break
+ t += 1
+
+ # except Exception as e:
+ # print(f"Caught exception: {e}")
+ # log_file.write(f"Caught exception: {e}\n")
+ # break
+
+ task_episodes += 1
+ total_episodes += 1
+
+ # Save a replay video of the episode
+ save_rollout_video(
+ replay_images, total_episodes, success=done, task_description=task_description, log_file=log_file, ckpt_index=ckpt_index, task_suite_name=cfg.task_suite_name, task_id=task_id
+ )
+
+ # Log current results
+ print(f"Success: {done}")
+ print(f"# episodes completed so far: {total_episodes}")
+ print(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
+ log_file.write(f"Success: {done}\n")
+ log_file.write(f"# episodes completed so far: {total_episodes}\n")
+ log_file.write(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)\n")
+ log_file.flush()
+
+ # Log final results
+ print(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
+ print(f"Current total success rate: {float(total_successes) / float(total_episodes)}")
+ log_file.write(f"Current task success rate: {float(task_successes) / float(task_episodes)}\n")
+ log_file.write(f"Current total success rate: {float(total_successes) / float(total_episodes)}\n")
+ log_file.flush()
+ if cfg.use_wandb:
+ wandb.log(
+ {
+ f"success_rate/{task_description}": float(task_successes) / float(task_episodes),
+ f"num_episodes/{task_description}": task_episodes,
+ }
+ )
+
+ # Save local log file
+ log_file.close()
+
+ # Push total metrics and local log file to wandb
+ if cfg.use_wandb:
+ wandb.log(
+ {
+ "success_rate/total": float(total_successes) / float(total_episodes),
+ "num_episodes/total": total_episodes,
+ }
+ )
+ wandb.save(local_log_filepath)
+
+
+if __name__ == "__main__":
+ eval_libero()
diff --git a/experiment/libero/robot_utils.py b/experiment/libero/robot_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd0a207afab077a99891101234940278dd3a131a
--- /dev/null
+++ b/experiment/libero/robot_utils.py
@@ -0,0 +1,84 @@
+"""Utils for evaluating robot policies in various environments."""
+
+import os
+import random
+import time
+
+import numpy as np
+import torch
+
+# Initialize important constants and pretty-printing mode in NumPy.
+ACTION_DIM = 7
+DATE = time.strftime("%Y_%m_%d")
+DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
+np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
+
+
+
+def set_seed_everywhere(seed: int):
+ """Sets the random seed for Python, NumPy, and PyTorch functions."""
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ os.environ["PYTHONHASHSEED"] = str(seed)
+
+
+def get_model(cfg, wrap_diffusion_policy_for_droid=False):
+ """Load model for evaluation."""
+ from deploy.websocket_client_policy import WebsocketClientPolicy
+ cronus_server = WebsocketClientPolicy(port=cfg.model_port)
+ return None, cronus_server
+
+
+def get_image_resize_size(cfg):
+ """
+ Gets image resize size for a model class.
+ If `resize_size` is an int, then the resized image will be a square.
+ Else, the image will be a rectangle.
+ """
+ if cfg.model_family == "openvla" or "instruct_vla" in cfg.model_family:
+ resize_size = 224
+ else:
+ raise ValueError("Unexpected `model_family` found in config.")
+ return resize_size
+
+
+def get_action(server, obs):
+ """Queries the model to get an action."""
+
+ action = server.infer(obs)['action']
+ return action
+
+
+def normalize_gripper_action(action, binarize=True):
+ """
+ Changes gripper action (last dimension of action vector) from [0,1] to [-1,+1].
+ Necessary for some environments (not Bridge) because the dataset wrapper standardizes gripper actions to [0,1].
+ Note that unlike the other action dimensions, the gripper action is not normalized to [-1,+1] by default by
+ the dataset wrapper.
+
+ Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1
+ """
+ # Just normalize the last action to [-1,+1].
+ orig_low, orig_high = 0.0, 1.0
+ action = np.array(action, copy=True)
+ action[..., -1] = 2 * (action[..., -1] - orig_low) / (orig_high - orig_low) - 1
+
+ if binarize:
+ # Binarize to -1 or +1.
+ action[..., -1] = np.sign(action[..., -1])
+
+ return action
+
+
+def invert_gripper_action(action):
+ """
+ Flips the sign of the gripper action (last dimension of action vector).
+ This is necessary for some environments where -1 = open, +1 = close, since
+ the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open.
+ """
+ action[..., -1] = action[..., -1] * -1.0
+ return action
diff --git a/experiment/robotwin/README.md b/experiment/robotwin/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..abf2e53511210942edea0c629f14e0113e6d4573
--- /dev/null
+++ b/experiment/robotwin/README.md
@@ -0,0 +1,85 @@
+# Generate Lerobot Dataset from RoboTwin Data
+
+This guide explains how to process raw data from **RoboTwin** and convert it into the **LerobotDataset** format following the official RoboTwin instructions.
+
+## 1. Clone the Official RoboTwin Repository
+```bash
+git clone git@github.com:RoboTwin-Platform/RoboTwin.git
+```
+
+## 2. Create Required Directories
+Navigate to the `policy/pi0` directory inside the cloned RoboTwin repository and create the folders:
+
+```bash
+cd ./policy/pi0
+mkdir processed_data training_data
+```
+
+## 3. Convert RoboTwin Raw Data to HDF5
+
+Use the provided script [process_data_pi0.sh](https://github.com/RoboTwin-Platform/RoboTwin/blob/main/policy/pi0/process_data_pi0.sh):
+
+```bash
+bash process_data_pi0.sh ${task_name} ${task_config} ${expert_data_num}
+```
+
+**Example (clean demo):**
+```bash
+bash process_data_pi0.sh beat_block_hammer demo_clean 50
+```
+
+**Example (randomized demo):**
+```bash
+bash process_data_pi0.sh beat_block_hammer demo_randomized 50
+```
+
+If successful, the output folder:
+```
+processed_data/${task_name}-${task_config}-${expert_data_num}/
+```
+
+## 4. Prepare Training Data
+
+Copy the required processed datasets into `training_data/${model_name}`:
+
+```bash
+cp -r processed_data/${task_name}-${task_config}-${expert_data_num} \
+ training_data/${model_name}/
+```
+
+## 5. Ensure Sufficient Disk Space
+
+The generated **LerobotDataset** will be stored under:
+
+```
+$XDG_CACHE_HOME/huggingface/lerobot/${repo_id}
+```
+
+By default, `XDG_CACHE_HOME` points to `~/.cache`, which must have sufficient free space.
+If space is low, change the cache location:
+
+```bash
+export XDG_CACHE_HOME=/path/to/your/cache
+```
+
+## 6. Generate LerobotDataset Format
+
+Run [process_data_pi0.sh](https://github.com/RoboTwin-Platform/RoboTwin/blob/main/policy/pi0/generate.sh) to convert the HDF5 datasets to Lerobot.
+
+Parameters:
+- **hdf5_path**: Path to the HDF5 training data (e.g., `./training_data/${model_name}/`)
+- **repo_id**: Name for the dataset (e.g., `my_repo`)
+
+```bash
+bash generate.sh ${hdf5_path} ${repo_id}
+```
+
+**Example:**
+```bash
+bash generate.sh ./training_data/demo_clean/ demo_clean_repo
+```
+
+Output:
+```
+${XDG_CACHE_HOME}/huggingface/lerobot/${repo_id}
+```
\ No newline at end of file
diff --git a/lingbotvla/__init__.py b/lingbotvla/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..baf703ef287bb65202b683ce6a43a1a1c18229b2
--- /dev/null
+++ b/lingbotvla/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+__version__ = "0.0.1"
diff --git a/lingbotvla/checkpoint/__init__.py b/lingbotvla/checkpoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..44a21fce377be74f33333c87d6976050b0ede257
--- /dev/null
+++ b/lingbotvla/checkpoint/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .checkpointer import build_checkpointer
+from .format_utils import bytecheckpoint_ckpt_to_state_dict, ckpt_to_state_dict, dcp_to_torch_state_dict
+
+
+__all__ = [
+ "ckpt_to_state_dict",
+ "dcp_to_torch_state_dict",
+ "bytecheckpoint_ckpt_to_state_dict",
+ "build_checkpointer",
+]
diff --git a/lingbotvla/checkpoint/checkpointer.py b/lingbotvla/checkpoint/checkpointer.py
new file mode 100644
index 0000000000000000000000000000000000000000..062c7e2fc904abf39d44335e3906c3ecc1a88ec3
--- /dev/null
+++ b/lingbotvla/checkpoint/checkpointer.py
@@ -0,0 +1,340 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+import torch
+import torch.distributed as dist
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
+from ..utils.import_utils import is_torch_version_greater_than
+from ..utils.logging import get_logger
+from pathlib import Path
+
+if is_torch_version_greater_than("2.4"):
+ import torch.distributed.checkpoint as dcp
+ from torch.distributed.checkpoint import (
+ FileSystemReader,
+ FileSystemWriter,
+ )
+ from torch.distributed.checkpoint.state_dict import (
+ get_model_state_dict,
+ get_optimizer_state_dict,
+ set_model_state_dict,
+ set_optimizer_state_dict,
+ )
+ from torch.distributed.checkpoint.stateful import Stateful
+else:
+ Stateful = ABC
+
+logger = get_logger(__name__)
+
+_EXTRA_STATE_FORMAT = "extra_state_rank_{}.pt"
+_MODEL_DIR = "model"
+_EMA_DIR = "ema"
+_OPTIMIZER_DIR = "optimizer"
+_EXTRA_STATE_DIR = "extra_state"
+
+
+class ModelState(Stateful):
+ """
+ A wrapper around a model to make it stateful.
+ Args:
+ model (Model): model to wrap.
+ """
+
+ def __init__(self, model):
+ self.model = model
+
+ def state_dict(self):
+ model_state_dict = get_model_state_dict(model=self.model)
+ return {"model": model_state_dict}
+
+ def load_state_dict(self, state_dict):
+ set_model_state_dict(model=self.model, model_state_dict=state_dict["model"])
+
+
+class OptimizerState(Stateful):
+ """
+ A wrapper around an optimizer to make it stateful.
+
+ Args:
+ model (Model): model to wrap.
+ optimizer (Optimizer): optimizer to wrap.
+ """
+
+ def __init__(self, model, optimizer):
+ self.model = model
+ self.optimizer = optimizer
+
+ def state_dict(self):
+ optimizer_state_dict = get_optimizer_state_dict(model=self.model, optimizers=self.optimizer)
+ return {"optim": optimizer_state_dict}
+
+ def load_state_dict(self, state_dict):
+ set_optimizer_state_dict(model=self.model, optimizers=self.optimizer, optim_state_dict=state_dict["optim"])
+
+
+def build_checkpointer(
+ dist_backend: str = "fsdp1",
+ ckpt_manager: str = "bytecheckpoint",
+):
+ """
+ create a checkpointer manager with given mode.
+ Args:
+ dist_backend (str, optional): checkpoint mode. Defaults to "fsdp1".
+ fsdp1: FSDP1 checkpoint from bytecheckpoint
+ fsdp2-vescale: FSDP2 checkpoint from bytecheckpoint
+ fsdp2: FSDP2 checkpoint from bytecheckpoint
+ ddp: DDP checkpoint from bytecheckpoint
+ dcp: DCP checkpoint from torch.distributed.checkpoint
+ ckpt_manager (str, optional): checkpoint manager. Defaults to "bytecheckpoint".
+ bytecheckpoint: bytecheckpoint checkpoint manager
+ dcp: torch dcp checkpoint manager
+ Raises:
+ ValueError: if ckpt_manager is not supported
+
+ Returns:
+ Checkpointer: checkpointer with given mode.
+ """
+
+ if ckpt_manager == "bytecheckpoint":
+ if dist_backend == "ddp":
+ from bytecheckpoint import DDPCheckpointer as Checkpointer
+ elif dist_backend == "fsdp1":
+ from bytecheckpoint import FSDPCheckpointer as Checkpointer
+ elif dist_backend == "fsdp2-vescale":
+ from bytecheckpoint import VeScaleCheckpointer as Checkpointer
+ elif dist_backend == "fsdp2":
+ from bytecheckpoint import FSDP2Checkpointer as Checkpointer
+ elif ckpt_manager == "dcp":
+ if not is_torch_version_greater_than("2.4"):
+ raise ValueError("DCP checkpoint manager requires torch version >= 2.4")
+ if dist_backend not in ["ddp", "fsdp1", "fsdp2"]:
+ raise ValueError(
+ f"Unsupported distributed backend: {dist_backend} for DCP checkpoint manager, supported modes are: ddp, fsdp1, fsdp2"
+ )
+ Checkpointer = DistributedCheckpointer
+ else:
+ raise ValueError(
+ f"Unknown checkpoint manager: {ckpt_manager}, supported modes are: bytecheckpoint, dcp, native"
+ )
+
+ return Checkpointer
+
+
+class CheckpointerBase(ABC):
+ """Base class for checkpointer"""
+
+ @abstractmethod
+ def save(
+ cls,
+ path: str,
+ state: Dict[str, Any],
+ ):
+ return
+
+ @abstractmethod
+ def load(
+ cls,
+ path: str,
+ state: Dict[str, Any],
+ ):
+ return
+
+
+class DistributedCheckpointer(CheckpointerBase):
+ """
+ Distributed checkpointer for torch.distributed.checkpoint
+ """
+
+ @classmethod
+ def save(
+ cls,
+ path: str,
+ state: Dict[str, Any],
+ global_steps: int = None,
+ save_async=False,
+ ) -> None:
+ """
+ save training state to distributed checkpoint
+
+ args:
+ path: path to save checkpoint
+ state: state to save
+ global_steps: global steps
+ save_async: whether to save asynchronously
+ return:
+ None
+ """
+
+ checkpoint_dir = f"{path}/global_step_{global_steps}" if global_steps else path
+ os.makedirs(checkpoint_dir, exist_ok=True)
+
+ if "model" not in state:
+ raise ValueError("Model must be provided to save a distributed checkpoint.")
+
+ if save_async:
+ model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
+ dcp.async_save(
+ state_dict={"state": ModelState(state["model"])},
+ storage_writer=FileSystemWriter(
+ model_dir,
+ thread_count=16,
+ single_file_per_rank=True,
+ sync_files=False,
+ ),
+ )
+ if "ema" in state and state["ema"] is not None:
+ ema_dir = os.path.join(checkpoint_dir, _EMA_DIR)
+ dcp.async_save(
+ state_dict={"state": ModelState(state["ema"])},
+ storage_writer=FileSystemWriter(
+ ema_dir,
+ thread_count=16,
+ single_file_per_rank=True,
+ sync_files=False,
+ ),
+ )
+ if "optimizer" in state:
+ optimizer_dir = os.path.join(checkpoint_dir, _OPTIMIZER_DIR)
+ dcp.async_save(
+ state_dict={"state": OptimizerState(model=state["model"], optimizer=state["optimizer"])},
+ storage_writer=FileSystemWriter(
+ optimizer_dir,
+ thread_count=16,
+ single_file_per_rank=True,
+ sync_files=False,
+ ),
+ )
+ else:
+ def safe_create_writer(output_dir):
+ tmp_path = Path(output_dir) / ".metadata.tmp"
+ if tmp_path.exists():
+ print(f"Warning: removing existing tmp file: {tmp_path}")
+ tmp_path.unlink() # remove .metadata.tmp
+ return FileSystemWriter(
+ output_dir,
+ thread_count=16,
+ single_file_per_rank=True,
+ sync_files=False,
+ )
+ model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
+ storage_writer = safe_create_writer(model_dir)
+ dcp.save(
+ state_dict={"state": ModelState(state["model"])},
+ storage_writer=storage_writer,
+ )
+ if "ema" in state and state["ema"] is not None:
+ ema_dir = os.path.join(checkpoint_dir, _EMA_DIR)
+ storage_writer = safe_create_writer(ema_dir)
+ dcp.save(
+ state_dict={"state": ModelState(state["ema"])},
+ storage_writer=storage_writer,
+ )
+ if "optimizer" in state:
+ optimizer_dir = os.path.join(checkpoint_dir, _OPTIMIZER_DIR)
+ dcp.save(
+ state_dict={"state": OptimizerState(model=state["model"], optimizer=state["optimizer"])},
+ storage_writer=FileSystemWriter(
+ optimizer_dir,
+ thread_count=16,
+ single_file_per_rank=True,
+ sync_files=False,
+ ),
+ )
+ # dist.barrier()
+
+ if "extra_state" in state:
+ extra_state_dir = os.path.join(checkpoint_dir, _EXTRA_STATE_DIR)
+ os.makedirs(extra_state_dir, exist_ok=True)
+ extra_state_path = os.path.join(extra_state_dir, _EXTRA_STATE_FORMAT.format(dist.get_rank()))
+ torch.save(
+ state["extra_state"],
+ extra_state_path,
+ )
+
+ logger.info_rank0(f"Saved checkpoint to {checkpoint_dir}")
+
+ @classmethod
+ def load(
+ cls,
+ path: str,
+ state: Dict[str, Any],
+ process_group=None,
+ ) -> Dict[str, Any]:
+ """
+ load training state from distributed checkpoint
+ args:
+ path: path to load checkpoint
+ state: state to load, "model" are required, "optimizer" and "extra_state" are optional
+
+ return:
+ state: state loaded
+ """
+ checkpoint_dir = path
+
+ if state is None:
+ raise ValueError("State dict must be provided to load a distributed checkpoint.")
+
+ if "model" not in state:
+ raise ValueError("Model must be provided to load a distributed checkpoint.")
+
+ if "ema" in state and state["ema"] is not None:
+ ema_dir = os.path.join(checkpoint_dir, _EMA_DIR)
+ dcp.load(
+ state_dict={"state": ModelState(state["ema"])},
+ storage_reader=FileSystemReader(ema_dir),
+ process_group=process_group,
+ )
+
+ if "optimizer" in state:
+ model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
+ dcp.load(
+ state_dict={"state": ModelState(state["model"])},
+ storage_reader=FileSystemReader(model_dir),
+ process_group=process_group,
+ )
+
+ optimizer_dir = os.path.join(checkpoint_dir, _OPTIMIZER_DIR)
+ try:
+ dcp.load(
+ state_dict={"state": OptimizerState(model=state["model"], optimizer=state["optimizer"])}, # 1043
+ storage_reader=FileSystemReader(optimizer_dir), # 1027
+ planner = DefaultLoadPlanner(allow_partial_load=True),
+ process_group=process_group,
+ )
+ except:
+ logger.info_rank0(f"Skip loading Optimizer from {checkpoint_dir}")
+ else:
+ model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
+ dcp.load(
+ state_dict={"state": ModelState(state["model"])},
+ storage_reader=FileSystemReader(model_dir),
+ process_group=process_group,
+ )
+
+ if "extra_state" in state:
+ extra_state_dir = os.path.join(checkpoint_dir, _EXTRA_STATE_DIR)
+ os.makedirs(extra_state_dir, exist_ok=True)
+ extra_state_path = os.path.join(extra_state_dir, _EXTRA_STATE_FORMAT.format(dist.get_rank()))
+ state["extra_state"] = torch.load(
+ extra_state_path,
+ )
+
+ logger.info_rank0(f"Loaded checkpoint from {checkpoint_dir}")
+
+ return state
diff --git a/lingbotvla/checkpoint/format_utils.py b/lingbotvla/checkpoint/format_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2642cfb5463084a851f18936e3d8786082a14b15
--- /dev/null
+++ b/lingbotvla/checkpoint/format_utils.py
@@ -0,0 +1,127 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from abc import ABC
+from typing import Any, Dict, Union
+
+import torch
+
+from ..utils.import_utils import is_torch_version_greater_than
+from ..utils.logging import get_logger
+
+
+if is_torch_version_greater_than("2.4"):
+ from torch.distributed.checkpoint import FileSystemReader
+ from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
+ from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
+ from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
+else:
+ STATE_DICT_TYPE = ABC
+
+logger = get_logger(__name__)
+
+_MODEL_DIR = "model"
+_EMA_DIR = "ema"
+
+
+def ckpt_to_state_dict(
+ save_checkpoint_path: Union[str, os.PathLike],
+ output_dir: Union[str, os.PathLike],
+ ckpt_manager: str = "bytecheckpoint",
+ ema: bool = False,
+) -> Dict[str, Any]:
+ """
+ Interface to convert a checkpoint to a state_dict.
+ Supported checkpoint managers:
+ - bytecheckpoint
+ - dcp
+ - native
+
+ Args:
+ save_checkpoint_path: Path to the checkpoint.
+ output_dir: Path to the output directory.
+ ckpt_manager: Checkpoint manager.
+ Returns:
+ state_dict: State dict.
+ """
+ if ckpt_manager == "bytecheckpoint":
+ state_dict = bytecheckpoint_ckpt_to_state_dict(save_checkpoint_path, output_dir)
+ elif ckpt_manager == "dcp":
+ state_dict = dcp_to_torch_state_dict(save_checkpoint_path, ema)
+ elif ckpt_manager == "native":
+ model_dir = os.path.join(save_checkpoint_path, _MODEL_DIR)
+ if os.path.exists(model_dir):
+ save_checkpoint_path = model_dir
+ state_dict = torch.load(save_checkpoint_path)
+ else:
+ raise ValueError(f"Unknown checkpoint manager: {ckpt_manager}")
+ return state_dict
+
+
+def bytecheckpoint_ckpt_to_state_dict(
+ save_checkpoint_path: Union[str, os.PathLike], output_dir: Union[str, os.PathLike]
+):
+ """
+ Given a directory containing an Bytecheckpoint checkpoint, this function will convert it into a
+ Torch state_dict.
+ Args:
+ save_checkpoint_path: Directory containing the Bytecheckpoint checkpoint.
+ output_dir: Directory to save the converted checkpoint.
+ """
+
+ from bytecheckpoint.utilities.ckpt_format.merge_tool import bytecheckpoint_ckpt_to_pytorch_ckpt
+
+ state_dict = bytecheckpoint_ckpt_to_pytorch_ckpt(
+ save_path=save_checkpoint_path,
+ output_path=output_dir,
+ framework="fsdp",
+ model_only=True,
+ return_dict=True,
+ )
+ return state_dict["model"]
+
+
+def dcp_to_torch_state_dict(save_checkpoint_path: Union[str, os.PathLike], ema: bool = False) -> STATE_DICT_TYPE:
+ """
+ Given a directory containing a DCP checkpoint, this function will convert it into a
+ Torch state_dict.
+
+ Args:
+ save_checkpoint_path: Directory containing the DCP checkpoint.
+
+ .. warning::
+ To avoid OOM, it's recommended to only run this function on a single rank.
+ """
+ if ema:
+ model_dir = os.path.join(save_checkpoint_path, _EMA_DIR)
+ else:
+ model_dir = os.path.join(save_checkpoint_path, _MODEL_DIR)
+ if os.path.exists(model_dir):
+ save_checkpoint_path = model_dir
+
+ # Load the state_dict from the DCP checkpoint
+ state_dict: STATE_DICT_TYPE = {}
+
+ _load_state_dict(
+ state_dict,
+ storage_reader=FileSystemReader(save_checkpoint_path),
+ planner=_EmptyStateDictLoadPlanner(),
+ no_dist=True,
+ )
+ if "state" in state_dict:
+ state_dict = state_dict["state"]
+
+ return state_dict["model"]
diff --git a/lingbotvla/data/__init__.py b/lingbotvla/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2db0766a135cfdc0b774a768eefad61542318ee8
--- /dev/null
+++ b/lingbotvla/data/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .chat_template import build_chat_template
+from .data_collator import (
+ CollatePipeline,
+ DataCollatorWithPacking,
+ DataCollatorWithPadding,
+ DataCollatorWithPositionIDs,
+ MakeMicroBatchCollator,
+ TextSequenceShardCollator,
+ UnpackDataCollator,
+)
+from .data_loader import build_dataloader
+from .dataset import build_iterative_dataset, build_mapping_dataset, liberoDataset, RobotwinDataset
+from .data_transform import (
+ VLADataCollatorWithPacking,
+)
diff --git a/lingbotvla/data/batching_strategy.py b/lingbotvla/data/batching_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..1820840d747e2bc5e4058c44207ae78fa2d1f996
--- /dev/null
+++ b/lingbotvla/data/batching_strategy.py
@@ -0,0 +1,223 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict
+
+
+class DynBszBuffer:
+ """
+ A buffer to store samples for dynamic batch size.
+ """
+
+ def __init__(self):
+ self._buffer = []
+ self._buffer_sample_lens = []
+ self.del_idxs = []
+ self.cur_idx = 0
+ self.all_token_cnt = 0
+
+ def append(self, item: Dict[str, Any]):
+ """
+ Append a sample to the buffer.
+ Args:
+ item: a sample to append to the buffer.
+ The sample should be a dict with the following keys:
+ - input_ids: torch.Tensor of shape (seq_len, )
+ - attention_mask: torch.Tensor of shape (seq_len, )
+ """
+ self._buffer.append(item)
+ if 'attention_mask' in item:
+ self._buffer_sample_lens.append(item["attention_mask"].sum())
+ self.all_token_cnt += self._buffer_sample_lens[-1]
+ elif 'lang_masks' in item:
+ self._buffer_sample_lens.append(item["lang_masks"].sum())
+ self.all_token_cnt += self._buffer_sample_lens[-1]
+
+ def get_samples(self, n_token_per_iter: int, force: bool = True):
+ """
+ get samples from the buffer.
+ Args:
+ n_token_per_iter: the number of tokens to get.
+ force: if True, the first sample will be returned even if it is not full.
+ Returns:
+ samples: a list of samples.
+ """
+ cum_seq_len = 0
+ samples = []
+ while self.cur_idx < len(self._buffer) and cum_seq_len < n_token_per_iter:
+ seq_len = self._buffer_sample_lens[self.cur_idx]
+ if self.cur_idx not in self.del_idxs and (
+ (force is True and cum_seq_len == 0) or (seq_len <= n_token_per_iter - cum_seq_len)
+ ):
+ cum_seq_len += seq_len
+ samples.append(self._buffer[self.cur_idx])
+ self.del_idxs.append(self.cur_idx)
+ self.cur_idx += 1
+ assert len(samples) > 0
+ return samples
+
+ def __len__(self):
+ return len(self._buffer)
+
+ def flush(self):
+ """ "
+ Flush the buffer.
+ """
+ self.cur_idx = 0
+ self.all_token_cnt -= sum([self._buffer_sample_lens[idx] for idx in self.del_idxs])
+ buffer_len = len(self._buffer)
+ self._buffer = [self._buffer[idx] for idx in range(buffer_len) if idx not in self.del_idxs]
+ self._buffer_sample_lens = [
+ self._buffer_sample_lens[idx] for idx in range(buffer_len) if idx not in self.del_idxs
+ ]
+ self.del_idxs = []
+
+ def merge(self, buffer_to_merge: "DynBszBuffer"):
+ """ "
+ Merge the buffer with another buffer.
+ Args:
+ buffer_to_merge: the buffer to merge.
+ """
+ self.flush()
+ buffer_to_merge.flush()
+ for item in buffer_to_merge._buffer:
+ self.append(item)
+
+
+class BaseBatchingStrategy:
+ """
+ Base class for batching strategy.s
+ """
+
+ def is_full_filled(self) -> bool:
+ raise NotImplementedError("should implement `is_full_filled`")
+
+ def put_item(self, item: Dict[str, Any]):
+ raise NotImplementedError("should implement `put_item`")
+
+ def get_micro_batch(self, step: int) -> Any:
+ raise NotImplementedError("should implement `get_micro_batch` ")
+
+ def empty(self) -> bool:
+ raise NotImplementedError("should implement `empty`")
+
+
+class IdentityPacker:
+ def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
+ self.token_micro_bsz = token_micro_bsz
+ self.bsz_warmup_steps = bsz_warmup_steps
+ self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
+
+ def __call__(self, samples):
+ return samples
+
+ def get_token_num_to_request(self, cur_step, warmup):
+ return (
+ (self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
+ + self.bsz_warmup_init_mbtoken
+ if warmup
+ else self.token_micro_bsz
+ )
+
+
+class TextBatchingStrategy(BaseBatchingStrategy):
+ """ "
+ Batching strategy for text data.
+ Args:
+ token_micro_bsz: the number of tokens to get for each request.
+ buffer_size: the size of the buffer.
+ bsz_warmup_steps: the number of steps to warm up the batch size.
+ bsz_warmup_init_mbtoken: the initial number of tokens to get for each request.
+ """
+
+ def __init__(
+ self,
+ token_micro_bsz,
+ buffer_size: int = 500,
+ bsz_warmup_steps: int = -1,
+ bsz_warmup_init_mbtoken: int = 200,
+ ) -> None:
+ super().__init__()
+ self._step = 0
+ self.token_micro_bsz = token_micro_bsz
+ self.bsz_warmup_steps = bsz_warmup_steps
+ self.buffer_size = buffer_size # minimum samples in buffer
+ self.buffer = DynBszBuffer()
+ self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
+ assert self.bsz_warmup_init_mbtoken >= 0
+
+ self.packer = IdentityPacker(
+ token_micro_bsz=token_micro_bsz,
+ bsz_warmup_steps=bsz_warmup_steps,
+ bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
+ )
+
+ def is_full_filled(self) -> bool:
+ return len(self.buffer) >= self.buffer_size and self.buffer.all_token_cnt >= self.token_micro_bsz
+
+ def put_item(self, item: Dict[str, Any]):
+ if "input_ids" in item:
+ if len(item["input_ids"]) == 1:
+ print("WARNING: EMPTY STRING.")
+ return
+ elif "lang_tokens" in item:
+ if all (item["lang_tokens"] == 0):
+ print("WARNING: EMPTY STRING.")
+ return
+ self.buffer.append(item)
+
+ def get_token_num_to_request(self):
+ if self.packer is not None:
+ warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
+ return self.packer.get_token_num_to_request(self._step, warmup=warmup)
+ else:
+ return self.get_cur_token_micro_bsz()
+
+ def get_cur_token_micro_bsz(self):
+ warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0
+ if warmup:
+ return (
+ self.token_micro_bsz - self.bsz_warmup_init_mbtoken
+ ) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken
+ else:
+ return self.token_micro_bsz
+
+ def get_micro_batch(self, step) -> Any:
+ """
+ Get a micro batch from the buffer according to the current step.
+ Args:
+ step: the current step.
+ Returns:
+ data: a list of samples.
+ """
+
+ self._step = step
+ n_token_per_iter = self.get_token_num_to_request()
+ cur_token_micro_bsz = self.get_cur_token_micro_bsz()
+ assert cur_token_micro_bsz % n_token_per_iter == 0, (
+ "The token num to get for each request should be divisible by token micro bsz."
+ )
+ n_iter = int(cur_token_micro_bsz // n_token_per_iter)
+ data = []
+ for i in range(n_iter):
+ samples = self.buffer.get_samples(n_token_per_iter)
+ if self.packer:
+ samples = self.packer(samples) # maybe packed into one sample, but wrapped in list.
+ data.extend(samples)
+ self.buffer.flush() # remove the selected samples.
+ return data
+
+ def empty(self) -> bool:
+ return len(self.buffer) == 0
diff --git a/lingbotvla/data/chat_template.py b/lingbotvla/data/chat_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b984752d537faf791f1a87a3be39109e9ef716c
--- /dev/null
+++ b/lingbotvla/data/chat_template.py
@@ -0,0 +1,262 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Dict, List, Sequence
+
+import torch
+
+from ..utils import logging
+from .constants import IGNORE_INDEX
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+ROLE_SUPPORTED = ["system", "user", "assistant", "tool"]
+
+
+class ChatTemplate(ABC):
+ """
+ Abstract class for chat template.
+ """
+
+ def __init__(self, tokenizer: "PreTrainedTokenizer") -> None:
+ self.tokenizer = tokenizer
+
+ def save_pretrained(self, output_dir: str) -> None:
+ self.tokenizer.chat_template = self.get_jinja_template()
+ try:
+ self.tokenizer.save_pretrained(output_dir)
+ except Exception:
+ logger.warning("Failed to save tokenizer.")
+
+ @abstractmethod
+ def encode_messages(self, messages: Sequence[Dict[str, str]], max_seq_len: int = 8192) -> Dict[str, List[int]]:
+ """
+ Encodes messages to a dictionary of input_ids, attention_mask, and labels.
+ """
+ ...
+
+ @abstractmethod
+ def get_jinja_template(self) -> str:
+ """
+ Gets the jinja template for the chat template.
+ """
+ ...
+
+
+class DefaultTemplate(ChatTemplate):
+ def encode_messages(self, messages: Sequence[Dict[str, str]], max_seq_len: int = 8192) -> Dict[str, List[int]]:
+ input_ids, attention_mask, labels = [], [], []
+ for message in messages:
+ content_str = message["role"].title() + ": " + message["content"].strip() + self.tokenizer.eos_token + "\n"
+ content_ids = self.tokenizer.encode(content_str, add_special_tokens=False)
+ input_ids += content_ids
+ attention_mask += [1] * len(content_ids)
+ if message["loss_mask"] == 1:
+ labels += content_ids
+ else:
+ labels += [IGNORE_INDEX] * len(content_ids)
+
+ model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+ model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
+ return model_inputs
+
+ def get_jinja_template(self) -> str:
+ return (
+ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
+ "{% for message in messages %}"
+ "{{ message['role'].title() + ': ' + message['content'] | trim + eos_token + '\n' }}"
+ "{% endfor %}"
+ "{% if add_generation_prompt %}{{ 'Assistant: ' }}{% endif %}"
+ )
+
+
+class Llama2Template(ChatTemplate):
+ def encode_messages(self, messages: Sequence[Dict[str, str]], max_seq_len: int = 8192) -> Dict[str, List[int]]:
+ input_ids, attention_mask, labels = [], [], []
+ for message in messages:
+ if message["role"] == "system":
+ content_str = "<>\n" + message["content"].strip() + "\n<>\n\n"
+ elif message["role"] == "user":
+ content_str = self.tokenizer.bos_token + "[INST] " + message["content"].strip() + " [/INST]"
+ elif message["role"] == "assistant":
+ content_str = " " + message["content"].strip() + " " + self.tokenizer.eos_token
+ elif message["role"] == "tool":
+ content_str = self.tokenizer.bos_token + "[TOOL] " + message["content"].strip() + " [/TOOL]"
+ else:
+ raise ValueError(
+ f"Unknown role {message['role']}, should be one of {{system, user, assistant, tool}}."
+ )
+
+ content_ids = self.tokenizer.encode(content_str, add_special_tokens=False)
+ input_ids += content_ids
+ attention_mask += [1] * len(content_ids)
+ if message["loss_mask"] == 1:
+ labels += content_ids
+ else:
+ labels += [IGNORE_INDEX] * len(content_ids)
+
+ model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+ model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
+ return model_inputs
+
+ def get_jinja_template(self) -> str:
+ return (
+ "{% if messages[0]['role'] == 'system' %}"
+ "{{ '<>\n' + messages[0]['content'] | trim + '\n<>\n\n' }}"
+ "{% set loop_messages = messages[1:] %}"
+ "{% else %}"
+ "{% set loop_messages = messages %}"
+ "{% endif %}"
+ "{% for message in loop_messages %}"
+ "{% set content = message['content'] %}"
+ "{% if message['role'] == 'user' %}"
+ "{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}"
+ "{% elif message['role'] == 'tool' %}"
+ "{{ bos_token + '[TOOL] ' + content | trim + ' [/TOOL]' }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ ' ' + content | trim + ' ' + eos_token }}"
+ "{% endif %}"
+ "{% endfor %}"
+ )
+
+
+class JanusTemplate(ChatTemplate):
+ def encode_messages(
+ self, messages: Sequence[Dict[str, str]], max_seq_len: int = 8192, task_type: str = ""
+ ) -> Dict[str, List[int]]:
+ input_ids, attention_mask, labels = [], [], []
+ images_seq_mask, images_emb_mask = [], []
+ seps = ["\n\n", "<|end▁of▁sentence|>"]
+ assitant_cnt = 0
+ for idx, message in enumerate(messages):
+ if message["content"] == "":
+ content_str = message["role"] + ":"
+ elif (
+ "assistant" in message["role"]
+ and "wikihow_generation" in task_type
+ or "assistant" in message["role"]
+ and "interleave_generation" in task_type
+ ):
+ prefix = "Assistant: " if assitant_cnt == 0 else ""
+ suffix = seps[1] if idx + 1 == len(messages) else seps[0]
+ content_str = prefix + message["content"].strip() + suffix
+ assitant_cnt += 1
+ elif "assistant" in message["role"]:
+ content_str = "Assistant" + ": " + message["content"].strip() + seps[1]
+ elif "user" in message["role"]:
+ content_str = "User" + ": " + message["content"].strip() + seps[0]
+ elif "system" in message["role"] and "wikihow_generation" in task_type:
+ content_str = (
+ message["content"].strip()
+ + seps[0]
+ + "Please generate a step-by-step tutorial with images for the following question."
+ + seps[0]
+ )
+ elif "system" in message["role"]:
+ content_str = message["content"].strip() + seps[0]
+ if "system" in message["role"]:
+ content_ids = self.tokenizer.encode(content_str)
+ else:
+ content_ids = self.tokenizer.encode(content_str, add_special_tokens=False)
+ input_ids += content_ids
+ attention_mask += [1] * len(content_ids)
+ image_token_id = self.tokenizer.vocab.get("")
+ content_ids_tensor = torch.tensor(content_ids)
+ images_seq_mask += (content_ids_tensor == image_token_id).tolist()
+ image_token_id = self.tokenizer.vocab.get("")
+ num_image_tokens = torch.sum(content_ids_tensor == image_token_id).item()
+ n_image = num_image_tokens // 576
+ if n_image > 0:
+ for j, n_image_tokens in enumerate([num_image_tokens]):
+ images_emb_mask.append([True] * n_image_tokens)
+
+ if message["loss_mask"] == 1:
+ if (
+ image_token_id in content_ids
+ and "wikihow_generation" not in task_type
+ and "interleave_generation" not in task_type
+ ):
+ labels += [image_token_id if x == image_token_id else IGNORE_INDEX for x in content_ids]
+ else:
+ labels += content_ids
+ else:
+ labels += [IGNORE_INDEX] * len(content_ids)
+
+ model_inputs = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels": labels,
+ "images_seq_mask": images_seq_mask,
+ "images_emb_mask": images_emb_mask,
+ }
+ model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
+ return model_inputs
+
+ def get_jinja_template(self) -> str:
+ return (
+ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
+ "{% for message in messages %}"
+ "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>\n' }}"
+ "{% endfor %}"
+ "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+ )
+
+
+class ChatmlTemplate(ChatTemplate):
+ def encode_messages(self, messages: Sequence[Dict[str, str]], max_seq_len: int = 8192) -> Dict[str, List[int]]:
+ input_ids, attention_mask, labels = [], [], []
+ for message in messages:
+ content_str = "<|im_start|>" + message["role"] + "\n" + message["content"].strip() + "<|im_end|>\n"
+ content_ids = self.tokenizer.encode(content_str, add_special_tokens=False)
+ input_ids += content_ids
+ attention_mask += [1] * len(content_ids)
+ if message["loss_mask"] == 1:
+ labels += content_ids
+ else:
+ labels += [IGNORE_INDEX] * len(content_ids)
+
+ model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+ model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
+ return model_inputs
+
+ def get_jinja_template(self) -> str:
+ return (
+ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
+ "{% for message in messages %}"
+ "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] | trim + '<|im_end|>\n' }}"
+ "{% endfor %}"
+ "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+ )
+
+
+TEMPLATES = {
+ "default": DefaultTemplate,
+ "llama2": Llama2Template,
+ "chatml": ChatmlTemplate,
+ "Janus": JanusTemplate,
+}
+
+
+def build_chat_template(template_name: str, tokenizer: "PreTrainedTokenizer") -> "ChatTemplate":
+ if template_name not in TEMPLATES:
+ raise ValueError(f"Unknown chat template: {template_name}")
+
+ return TEMPLATES[template_name](tokenizer)
diff --git a/lingbotvla/data/constants.py b/lingbotvla/data/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fb59ceec664606928391028944ca521bfa596ca
--- /dev/null
+++ b/lingbotvla/data/constants.py
@@ -0,0 +1,39 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+IGNORE_INDEX = -100
+
+# input index
+IMAGE_INPUT_INDEX = -200
+VIDEO_INPUT_INDEX = -300
+AUDIO_INPUT_INDEX = -400
+# output index
+IMAGE_OUTPUT_INDEX = -201
+VIDEO_OUTPUT_INDEX = -301
+AUDIO_OUTPUT_INDEX = -401
+
+
+TYPE2INDEX = {
+ "input": {
+ "image": IMAGE_INPUT_INDEX,
+ "video": VIDEO_INPUT_INDEX,
+ "audio": AUDIO_INPUT_INDEX,
+ },
+ "output": {
+ "image": IMAGE_OUTPUT_INDEX,
+ "video": VIDEO_OUTPUT_INDEX,
+ "audio": AUDIO_OUTPUT_INDEX,
+ },
+}
diff --git a/lingbotvla/data/data_collator.py b/lingbotvla/data/data_collator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2958775a8ea125ad2bba981e82e8e412dd077110
--- /dev/null
+++ b/lingbotvla/data/data_collator.py
@@ -0,0 +1,270 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data._utils.collate import default_collate
+
+from ..distributed.parallel_state import get_parallel_state
+from ..utils.seqlen_pos_transform_utils import len2culen, pos2culen
+from .constants import IGNORE_INDEX
+
+
+@dataclass
+class DataCollator(ABC):
+ """
+ Used in dataloader as a collate_fn.
+ """
+
+ @abstractmethod
+ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+ """
+ Converts a list of features to batched tensor dict.
+ """
+ ...
+
+
+class CollatePipeline:
+ def __init__(self, data_collators: Optional[Union[Callable, List[Callable]]] = None):
+ """
+ Args:
+ data_collators: a list of data collators or a single data collator
+ """
+
+ if not isinstance(data_collators, (list, tuple)):
+ data_collators = [data_collators]
+ self.data_collators = data_collators
+
+ def __call__(self, batch: Sequence[Dict[str, Any]]):
+ """
+ process data batch through data collators.
+
+ Args:
+ batch: the original input data batch
+
+ Returns:
+ batch: the processed data batch
+
+ """
+ for data_collator in self.data_collators:
+ batch = data_collator(batch)
+ return batch
+
+
+@dataclass
+class DataCollatorWithPadding(DataCollator):
+ """
+ Data collator with padding.
+ """
+
+ pad_token_id: int = 0
+
+ def __call__(self, features: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
+ batch = defaultdict(list)
+
+ # batching features
+ for feature in features:
+ for key in feature.keys():
+ batch[key].append(feature[key])
+
+ for key in batch.keys():
+ # process padding features
+ if key in ["input_ids", "attention_mask", "position_ids", "images_seq_mask"]:
+ batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=0)
+ elif key in ["labels", "labels_image"]:
+ batch[key] = pad_sequence(batch[key], batch_first=True, padding_value=IGNORE_INDEX)
+ else:
+ batch[key] = default_collate(batch[key])
+
+ return batch
+
+
+@dataclass
+class DataCollatorWithPacking(DataCollator):
+ """
+ Data collator with packing.
+ """
+
+ def __call__(self, features: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
+ seqlens = torch.tensor([len(feature["input_ids"]) for feature in features], dtype=torch.long)
+ batch = {"cu_seqlens": len2culen(seqlens)}
+ for input_name in features[0].keys():
+ if input_name in ("input_ids", "attention_mask", "labels"):
+ batch[input_name] = torch.cat([feature[input_name] for feature in features])
+ else:
+ batch[input_name] = default_collate([feature[input_name] for feature in features])
+
+ return batch
+
+
+@dataclass
+class DataCollatorWithPositionIDs(DataCollator):
+ """
+ Data collator with packing by position ids.
+ """
+
+ def __call__(self, features: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
+ batch = {}
+ for input_name in features[0].keys():
+ if input_name in ("input_ids", "attention_mask", "labels", "position_ids"):
+ batch[input_name] = torch.cat([feature[input_name] for feature in features], dim=-1).unsqueeze(0)
+ else:
+ batch[input_name] = default_collate([feature[input_name] for feature in features])
+
+ if "position_ids" not in batch:
+ batch["position_ids"] = torch.cat(
+ [torch.arange(len(feature["input_ids"])) for feature in features]
+ ).unsqueeze(0)
+
+ if "labels" in batch:
+ cu_seqlens = pos2culen(batch["position_ids"])
+ batch["labels"][:, cu_seqlens[1:-1]] = IGNORE_INDEX
+
+ return batch
+
+
+@dataclass
+class NoopDataCollator(DataCollator):
+ """
+ Data collator with no operation, used in dynamic batch dataloader at main process.
+ """
+
+ def __call__(self, features: Sequence[Dict[str, "torch.Tensor"]]) -> List[Dict[str, "torch.Tensor"]]:
+ return features
+
+
+@dataclass
+class UnpackDataCollator(DataCollator):
+ """
+ Data collator to unpack examples, used in dynamic batch dataloader at worker process.
+ """
+
+ def __call__(self, features: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
+ return features[0]
+
+
+@dataclass
+class MakeMicroBatchCollator(DataCollator):
+ """
+ Data collator to build micro batches, used in mapping dataloader.
+ """
+
+ num_micro_batch: int
+ internal_data_collator: "DataCollator"
+
+ def __call__(self, features: Sequence[Tuple[Dict[str, "torch.Tensor"]]]) -> List[Dict[str, "torch.Tensor"]]:
+ micro_batch_size = len(features) // self.num_micro_batch
+ if isinstance(features[0], list):
+ for i in range(len(features)):
+ features[i] = features[i][0] # 1-to-N inverse transform
+
+ micro_batches = []
+ for i in range(0, len(features), micro_batch_size):
+ micro_batches.append(self.internal_data_collator(features[i : i + micro_batch_size]))
+
+ return micro_batches
+
+
+@dataclass
+class TextSequenceShardCollator(DataCollator):
+ """
+ Data collator to chunk inputs according to sequence parallelism.
+ Args:
+ rmpad: whether the samples is packing or not.
+ rmpad_with_pos_ids: whether the samples is packing by position ids or not.
+ pad_token_id: the id of the padding token.
+ """
+
+ rmpad: bool
+ rmpad_with_pos_ids: bool
+ pad_token_id: int = 0
+
+ def __post_init__(self):
+ self.sp_size = get_parallel_state().sp_size
+ self.sp_rank = get_parallel_state().sp_rank
+
+ def sp_slice(self, tensor: "torch.Tensor", dim: int = -1) -> "torch.Tensor":
+ """
+ Slices a tensor along the specified dimension for sequence parallelism.
+ """
+ seq_length = tensor.size(dim)
+ sp_chunk_size = (seq_length + self.sp_size - 1) // self.sp_size
+ return tensor.narrow(dim, self.sp_rank * sp_chunk_size, sp_chunk_size)
+
+ def sp_padding(
+ self, tensor: "torch.Tensor", dim: int = -1, pad_value: int = 0, pad_length: int = 0
+ ) -> "torch.Tensor":
+ """
+ Pads a tensor with pad_length to aligns tensor with sp size.
+ """
+ if pad_length == 0:
+ return tensor
+
+ pad_shape = list(tensor.shape)
+ pad_shape[dim] = pad_length
+ pad = torch.full(pad_shape, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device)
+ return torch.cat((tensor, pad), dim=dim)
+
+ def __call__(self, batch: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
+ input_ids = batch.pop("input_ids")
+ labels = batch.pop("labels")[..., 1:].contiguous() # shift labels
+ labels = F.pad(labels, (0, 1), "constant", IGNORE_INDEX)
+
+ if self.rmpad_with_pos_ids: # mask the last token of each sequence
+ cu_seqlens = pos2culen(batch["position_ids"])
+ labels[:, cu_seqlens[1:-1] - 1] = IGNORE_INDEX
+ elif self.rmpad:
+ labels = labels.view(-1)
+ labels[batch["cu_seqlens"][1:-1] - 1] = IGNORE_INDEX
+ else:
+ if "position_ids" not in batch: # we should calculate the position ids before chunking
+ batch["position_ids"] = torch.arange(0, input_ids.size(-1)).unsqueeze(0)
+
+ # sp padding
+ seq_length = input_ids.size(-1)
+ sp_chunk_size = (seq_length + self.sp_size - 1) // self.sp_size
+ pad_length = sp_chunk_size * self.sp_size - seq_length
+
+ input_ids = self.sp_padding(input_ids, dim=-1, pad_value=self.pad_token_id, pad_length=pad_length)
+ labels = self.sp_padding(labels, dim=-1, pad_value=IGNORE_INDEX, pad_length=pad_length)
+
+ if self.rmpad_with_pos_ids:
+ batch["attention_mask"] = self.sp_padding(
+ batch["attention_mask"], dim=-1, pad_value=1, pad_length=pad_length
+ )
+ else:
+ batch["attention_mask"] = self.sp_padding(
+ batch["attention_mask"], dim=-1, pad_value=0, pad_length=pad_length
+ )
+
+ if self.rmpad:
+ if pad_length > 0:
+ batch["cu_seqlens"] = F.pad(
+ batch["cu_seqlens"], (0, 1), "constant", batch["cu_seqlens"][-1].item() + pad_length
+ )
+ else:
+ batch["position_ids"] = self.sp_padding(batch["position_ids"], dim=-1, pad_value=0, pad_length=pad_length)
+
+ # sp slice
+ batch["input_ids"] = self.sp_slice(input_ids, dim=-1)
+ batch["labels"] = self.sp_slice(labels, dim=-1)
+
+ return batch
diff --git a/lingbotvla/data/data_loader.py b/lingbotvla/data/data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef95577df72b6e33068c93ebb10e5f28bd125c75
--- /dev/null
+++ b/lingbotvla/data/data_loader.py
@@ -0,0 +1,149 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING, Callable, List, Optional, Union
+
+from torch.utils.data import IterableDataset
+from torchdata.stateful_dataloader import StatefulDataLoader
+from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
+
+from ..distributed.parallel_state import get_parallel_state
+from ..utils import logging
+from .batching_strategy import TextBatchingStrategy
+from .data_collator import (
+ CollatePipeline,
+ DataCollatorWithPacking,
+ DataCollatorWithPadding,
+ DataCollatorWithPositionIDs,
+ MakeMicroBatchCollator,
+ TextSequenceShardCollator,
+ UnpackDataCollator,
+)
+from .dynamic_batching import DynamicBatchSizeDataLoader
+
+
+if TYPE_CHECKING:
+ from torch.utils.data import Dataset
+
+
+logger = logging.get_logger(__name__)
+
+
+class DistributedDataloader(StatefulDataLoader):
+ dataset: "Dataset"
+ sampler: "StatefulDistributedSampler"
+
+ def set_epoch(self, epoch: int) -> None:
+ if self.sampler is not None and hasattr(self.sampler, "set_epoch"):
+ self.sampler.set_epoch(epoch)
+ elif hasattr(self.dataset, "set_epoch"):
+ self.dataset.set_epoch(epoch)
+
+
+def build_dataloader(
+ dataset: "Dataset",
+ micro_batch_size: int,
+ global_batch_size: int,
+ dataloader_batch_size: int,
+ max_seq_len: int,
+ train_steps: int,
+ rmpad: bool = True,
+ rmpad_with_pos_ids: bool = False,
+ bsz_warmup_ratio: float = 0.02,
+ bsz_warmup_init_mbtoken: int = 200,
+ dyn_bsz_buffer_size: int = 500,
+ dyn_bsz_margin: int = 0,
+ collate_fn: Optional[Union[Callable, List[Callable]]] = None,
+ num_workers: int = 8,
+ drop_last: bool = True,
+ pin_memory: bool = True,
+ prefetch_factor: Optional[int] = 2,
+ seed: int = 0,
+) -> "DistributedDataloader":
+ parallel_state = get_parallel_state()
+ token_micro_bsz = micro_batch_size * max_seq_len
+ num_micro_batch = global_batch_size // (
+ micro_batch_size * parallel_state.dp_size
+ ) # num_micro_batch = num accumulation steps
+ bsz_warmup_steps = int(train_steps * bsz_warmup_ratio)
+ use_rmpad = rmpad or rmpad_with_pos_ids
+ logger.info_rank0(
+ f"train_steps: {train_steps}, max_seq_len: {max_seq_len}, use_rmpad: {use_rmpad}, "
+ f"bsz_warmup_steps: {bsz_warmup_steps}, bsz_warmup_init_mbtoken: {bsz_warmup_init_mbtoken}, "
+ f"token_micro_bsz: {token_micro_bsz}, num_micro_batch: {num_micro_batch}, "
+ f"micro_batch_size: {micro_batch_size}, global_batch_size: {global_batch_size}, "
+ f"dp_size: {parallel_state.dp_size}, sp_size: {parallel_state.sp_size}."
+ )
+
+ if collate_fn is None:
+ collate_fn_list = []
+ if rmpad_with_pos_ids:
+ collate_fn_list.append(DataCollatorWithPositionIDs())
+ elif rmpad:
+ collate_fn_list.append(DataCollatorWithPacking())
+ else:
+ collate_fn_list.append(DataCollatorWithPadding())
+
+ if parallel_state.sp_enabled:
+ collate_fn_list.append(TextSequenceShardCollator(rmpad=rmpad, rmpad_with_pos_ids=rmpad_with_pos_ids))
+
+ collate_fn = CollatePipeline(collate_fn_list)
+
+ if isinstance(collate_fn, list):
+ collate_fn = CollatePipeline(collate_fn)
+
+ if use_rmpad:
+ batching_strategy = TextBatchingStrategy(
+ token_micro_bsz=token_micro_bsz - dyn_bsz_margin * max_seq_len,
+ buffer_size=dyn_bsz_buffer_size,
+ bsz_warmup_steps=bsz_warmup_steps if bsz_warmup_steps else -1,
+ bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken,
+ )
+ dyn_bsz_collate_fn = collate_fn
+ collate_fn = UnpackDataCollator()
+ else:
+ collate_fn = MakeMicroBatchCollator(num_micro_batch=num_micro_batch, internal_data_collator=collate_fn)
+
+ sampler = None
+ if not isinstance(dataset, IterableDataset):
+ sampler = StatefulDistributedSampler(
+ dataset,
+ num_replicas=parallel_state.dp_size,
+ rank=parallel_state.dp_rank,
+ shuffle=True,
+ seed=seed,
+ )
+
+ dataloader = DistributedDataloader(
+ dataset,
+ batch_size=dataloader_batch_size,
+ sampler=sampler,
+ num_workers=num_workers,
+ collate_fn=collate_fn,
+ pin_memory=pin_memory,
+ drop_last=drop_last,
+ prefetch_factor=prefetch_factor,
+ )
+ if use_rmpad:
+ dataloader = DynamicBatchSizeDataLoader(
+ dataloader,
+ batching_strategy=batching_strategy,
+ collate_fn=dyn_bsz_collate_fn,
+ num_micro_batch=num_micro_batch,
+ length=train_steps,
+ drop_last=drop_last,
+ )
+
+ return dataloader
diff --git a/lingbotvla/data/data_transform.py b/lingbotvla/data/data_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..f023de6655aa86254493910aade6f8b3454b702e
--- /dev/null
+++ b/lingbotvla/data/data_transform.py
@@ -0,0 +1,136 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
+from dataclasses import dataclass, field
+from torch.utils.data._utils.collate import default_collate
+import torch
+from .data_collator import DataCollator
+
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer
+
+ from .chat_template import ChatTemplate
+
+
+def split_into_chunks(sequence: Sequence[int], chunk_size: int) -> List[List[int]]:
+ """
+ Splits a long sequence into chunks.
+ """
+ total_len = len(sequence)
+ chunks = []
+ for i in range(0, total_len, chunk_size):
+ chunks.append(sequence[i : i + chunk_size])
+
+ return chunks
+
+
+def process_pretrain_example(
+ example: Dict[str, Any],
+ tokenizer: "PreTrainedTokenizer",
+ max_seq_len: int,
+ text_keys: Union[str, List[str]] = "content_split",
+ source_name: Optional[str] = None,
+) -> List[Dict[str, "torch.Tensor"]]:
+ examples = []
+ if isinstance(text_keys, str):
+ text_example = example[text_keys]
+ elif isinstance(text_keys, list):
+ for key in text_keys:
+ if key in example:
+ text_example = example[key]
+ break
+ else:
+ raise ValueError(f"None of the keys {text_keys} are found in the example.")
+ else:
+ raise ValueError(f"text_keys must be a string or a list of strings, but got {type(text_keys)}")
+
+ tokens = tokenizer.encode(text_example, add_special_tokens=False) + [tokenizer.eos_token_id]
+ for input_ids in split_into_chunks(tokens, max_seq_len):
+ examples.append(
+ {
+ "input_ids": torch.tensor(input_ids),
+ "attention_mask": torch.tensor([1] * len(input_ids)),
+ "labels": torch.tensor(input_ids),
+ }
+ )
+
+ return examples
+
+
+def process_sft_example(
+ example: Dict[str, Any],
+ chat_template: "ChatTemplate",
+ max_seq_len: int,
+ text_keys: Union[str, List[str]] = "messages",
+) -> List[Dict[str, "torch.Tensor"]]:
+ if isinstance(text_keys, str):
+ text_example = example[text_keys]
+ elif isinstance(text_keys, list):
+ for key in text_keys:
+ if key in example:
+ text_example = example[key]
+ break
+ else:
+ raise ValueError(f"None of the keys {text_keys} are found in the example.")
+ else:
+ raise ValueError(f"text_keys must be a string or a list of strings, but got {type(text_keys)}")
+
+ tokenized_example = chat_template.encode_messages(text_example, max_seq_len=max_seq_len)
+ tokenized_example = {k: torch.tensor(v) for k, v in tokenized_example.items()}
+ return [tokenized_example]
+
+
+@dataclass
+class VLADataCollatorWithPacking(DataCollator):
+ """
+ Data collator to packing for omni dataset.
+ Args:
+ packing_features: features to packing in batch.
+ concat_features: features to concat in batch.
+ Example:
+ >>> from lingbotvla.data import OmniDataCollatorWithPacking
+ """
+ state_features: List = field(
+ default_factory=lambda: [
+ "state",
+ "images",
+ "img_masks",
+ "lang_tokens",
+ "lang_masks",
+ "action_is_pad",
+ "actions",
+ "joint_mask",
+ "label",
+ "fast_mask"
+ ],
+ metadata={"help": "state features with one chunk."},
+ )
+
+ def __call__(self, features: Sequence[Dict[str, "torch.Tensor"]]) -> Dict[str, "torch.Tensor"]:
+ batch = {}
+ keys = {key for feature in features for key in feature.keys()}
+ for input_name in keys:
+ if input_name in self.state_features:
+ batch[input_name] = torch.cat(
+ [feature[input_name].unsqueeze(0) for feature in features if input_name in feature], dim=0
+ )
+ else:
+ batch[input_name] = default_collate(
+ [feature[input_name] for feature in features if input_name in feature]
+ )
+
+ return batch
\ No newline at end of file
diff --git a/lingbotvla/data/dataset.py b/lingbotvla/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..15a97ab43f0d89386a0456f1521d4eb974007047
--- /dev/null
+++ b/lingbotvla/data/dataset.py
@@ -0,0 +1,196 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from typing import Callable, Dict, List, Literal, Optional
+import numpy as np
+import torch
+from datasets import load_dataset
+from datasets.distributed import split_dataset_by_node
+from torch.utils.data import Dataset, IterableDataset
+from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
+from torchvision.transforms.v2 import Resize
+from transformers import AutoTokenizer, AutoImageProcessor
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+import json
+from ..distributed.parallel_state import get_parallel_state
+from ..utils import logging
+from ..utils.dist_utils import main_process_first
+from .vla_data import *
+from .vla_data.transform import Normalizer, prepare_action, prepare_images, prepare_language, prepare_state
+logger = logging.get_logger(__name__)
+
+try:
+ import datasets.features.features as features
+
+ _OLD_GENERATE_FROM_DICT = features.generate_from_dict
+
+ def _new_generate_from_dict(obj):
+ if isinstance(obj, dict) and obj.get("_type") == "List":
+ obj["_type"] = "Sequence"
+ return _OLD_GENERATE_FROM_DICT(obj)
+
+ features.generate_from_dict = _new_generate_from_dict
+except (ImportError, AttributeError):
+ # If datasets or the function doesn't exist, do nothing.
+ pass
+
+class DummyDataset(Dataset):
+ def __init__(self, size: int, seq_length: int):
+ """
+ Args:
+ size (int): Nums of datasets
+ seq_length (int, optional): seq_length
+ """
+ self.size = size
+ self.seq_length = seq_length
+ self.vocab_size = 32768
+
+ def __len__(self) -> int:
+ return self.size
+
+ def __getitem__(self, index: int) -> List[Dict[str, "torch.Tensor"]]:
+ input_ids = torch.randint(low=0, high=self.vocab_size, size=(self.seq_length,))
+ attention_mask = torch.ones((self.seq_length,), dtype=torch.long)
+ labels = input_ids.clone()
+ return [{"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}]
+
+class MappingDataset(Dataset):
+ """
+ Mapping dataset.
+ Args:
+ data (Dataset): Dataset
+ transform (Optional[Callable]): transform function
+ """
+
+ def __init__(self, data: "Dataset", transform: Optional[Callable] = None):
+ self._data = data
+ self._transform = transform
+
+ def __len__(self) -> int:
+ return len(self._data)
+
+ def __getitem__(self, index: int) -> List[Dict[str, "torch.Tensor"]]:
+ if self._transform is not None:
+ return self._transform(self._data[index])
+ else:
+ return self._data[index]
+
+
+class IterativeDataset(IterableDataset):
+ """
+ Iterative dataset.
+ Args:
+ data (Dataset): Dataset
+ transform (Optional[Callable]): transform function
+ """
+
+ def __init__(self, data: "Dataset", transform: Optional[Callable] = None):
+ self._data = data
+ self._transform = transform
+
+ def __iter__(self):
+ for sample in self._data:
+ if self._transform is not None:
+ yield self._transform(sample)
+ else:
+ yield sample
+
+ def load_state_dict(self, state_dict):
+ self._data.load_state_dict(state_dict["dataset"])
+
+ def state_dict(self):
+ return {"dataset": self._data.state_dict()}
+
+ def set_epoch(self, epoch: int):
+ self._data.set_epoch(epoch)
+
+
+def build_dummy_dataset(size: int, max_seq_len: int) -> "Dataset":
+ return DummyDataset(size=size, seq_length=max_seq_len)
+
+
+def build_mapping_dataset(
+ data_path: str,
+ transform: Optional[Callable] = None,
+ namespace: Literal["train", "test"] = "train",
+) -> "Dataset":
+ """
+ Build mapping dataset.
+ Args:
+ data_path (str): data path
+ transform (Optional[Callable]): transform function
+ namespace (Literal["train", "test"]): dataset namespace
+ Returns:
+ Dataset: mapping dataset
+ """
+ data_files = []
+ data_paths = data_path.split(",")
+ for data_path in data_paths:
+ if os.path.isdir(data_path):
+ data_files.extend([os.path.join(data_path, fn) for fn in os.listdir(data_path)])
+ elif os.path.isfile(data_path):
+ data_files.append(data_path)
+ else:
+ raise FileNotFoundError(f"Dataset {data_path} not exists.")
+ file_extenstion = os.path.splitext(data_files[0])[-1][1:]
+ if file_extenstion not in ["parquet", "jsonl", "json", "csv", "arrow"]:
+ raise ValueError(f"{file_extenstion} files are not supported.")
+
+ file_extenstion = "json" if file_extenstion == "jsonl" else file_extenstion
+ with main_process_first():
+ dataset = load_dataset(file_extenstion, data_files=data_files, split=namespace)
+
+ return MappingDataset(data=dataset, transform=transform)
+
+
+def build_iterative_dataset(
+ data_path: str,
+ transform: Optional[Callable] = None,
+ namespace: Literal["train", "test"] = "train",
+ seed: int = 42,
+) -> "IterableDataset":
+ """ "
+ Build iterative dataset.
+ Args:
+ data_path (str): data path
+ transform (Optional[Callable]): transform function
+ namespace (Literal["train", "test"]): dataset namespace
+ seed (int): random seed
+ Returns:
+ IterableDataset: iterative dataset
+ """
+
+ data_files = []
+ data_paths = data_path.split(",")
+ for data_path in data_paths:
+ if os.path.isdir(data_path):
+ data_files.extend([os.path.join(data_path, fn) for fn in os.listdir(data_path)])
+ elif os.path.isfile(data_path):
+ data_files.append(data_path)
+ else:
+ raise FileNotFoundError(f"Dataset {data_path} not exists.")
+
+ parallel_state = get_parallel_state()
+ file_extenstion = os.path.splitext(data_files[0])[-1][1:]
+ if file_extenstion not in ["parquet", "jsonl", "json", "csv", "arrow"]:
+ raise ValueError(f"{file_extenstion} files are not supported.")
+
+ file_extenstion = "json" if file_extenstion == "jsonl" else file_extenstion
+ dataset = load_dataset(file_extenstion, data_files=data_files, split=namespace, streaming=True)
+ dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
+ dataset = split_dataset_by_node(dataset, parallel_state.dp_rank, parallel_state.dp_size)
+
+ return IterativeDataset(dataset, transform)
\ No newline at end of file
diff --git a/lingbotvla/data/dynamic_batching.py b/lingbotvla/data/dynamic_batching.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc1a2c6a1f931c5370914924d52937200769b603
--- /dev/null
+++ b/lingbotvla/data/dynamic_batching.py
@@ -0,0 +1,186 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import copy
+import sys
+import traceback
+from collections import deque
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterator, Optional
+
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+if TYPE_CHECKING:
+ from .batching_strategy import BaseBatchingStrategy
+
+
+class DynamicBatchSizeDataLoader:
+ """Dynamic batch DataLoader.
+
+ Args:
+ dataloader: torch DataLoader
+ batching_strategy: dynamic batch strategy
+ collate_fn: DataLoader collate_fn, collate data after get data from batching_strategy
+ num_micro_batch: num_micro_batch, if num_micro_batch == 1, return micro_batch for gradient accumulation
+ length: length of dataloader, if length == -1, length = sys.maxsize, default len(dataloader)
+ drop_last: if True, drop last batch if batch size < num_micro_batch
+
+ """
+
+ def __init__(
+ self,
+ dataloader: Any,
+ batching_strategy: "BaseBatchingStrategy",
+ collate_fn: Optional[Callable] = None,
+ num_micro_batch: int = 1,
+ length: int = 0,
+ drop_last: bool = True,
+ ) -> None:
+ self.batching_strategy = batching_strategy
+ self.num_micro_batch = num_micro_batch
+ self.dataloader_item_buffer = deque()
+ self.item_buffer = deque()
+ self.step = 0
+ self._collate_fn = collate_fn
+ self._dataloader = dataloader
+ self._drop_last = drop_last
+ self._data_iter: Iterator
+ self._resume = False
+ self._batch_data_iter: Generator
+
+ if length > 0:
+ self._length = length
+ elif length == -1:
+ self._length = sys.maxsize
+ else:
+ self._length = len(self._dataloader)
+
+ def __len__(self):
+ if self._length:
+ return self._length
+ else:
+ raise RuntimeError("length must set at init. before call len()")
+
+ def __iter__(self) -> Iterator:
+ if not self._resume:
+ self.step = 0
+ self._data_iter = iter(self._dataloader)
+ self._batch_data_iter = self.batch_data_generator()
+ self._resume = False
+ return self
+
+ def __next__(self):
+ return next(self._batch_data_iter)
+
+ def batch_data_generator(self):
+ batch = []
+
+ while True:
+ if self._length and self.step >= self._length:
+ return
+
+ if self.batching_strategy.is_full_filled():
+ micro_batch = self.batching_strategy.get_micro_batch(self.step)
+ if self._collate_fn:
+ micro_batch = self._collate_fn(micro_batch)
+ batch.append(micro_batch)
+ if len(batch) == self.num_micro_batch:
+ yield batch
+ self.step += 1
+ batch = []
+
+ try:
+ processing_item = next(self._data_iter)
+ except Exception as e:
+ if isinstance(e, StopIteration):
+ if self.step < self._length:
+ # call iter until reach length
+ self._data_iter = iter(self._dataloader)
+ processing_item = next(self._data_iter)
+ elif not self._drop_last and not self.batching_strategy.empty():
+ while not self.batching_strategy.empty():
+ micro_batch = self.batching_strategy.get_micro_batch(self.step)
+ if self._collate_fn:
+ micro_batch = self._collate_fn(micro_batch)
+ batch.append(micro_batch)
+ if len(batch) == self.num_micro_batch:
+ yield batch
+ self.step += 1
+ batch = []
+
+ while len(batch) < self.num_micro_batch:
+ padding_batch = copy.deepcopy(micro_batch)
+ padding_batch["padding_flag"] = True
+ batch.append(padding_batch)
+ yield batch
+ self.step += 1
+ return
+ else:
+ return
+ else:
+ logger.error(f"DynamicBatchDataset iter data exception: {e} \n{traceback.format_exc()}")
+ raise
+
+ # put processing_item to buffer
+ if isinstance(processing_item, dict):
+ processing_item = [processing_item]
+
+ for item in processing_item:
+ self.batching_strategy.put_item(item)
+
+ def state_dict(self):
+ # save state
+ state = self.__dict__.copy()
+ # remove internal fields
+ for k in list(state.keys()):
+ if k.startswith("_"):
+ del state[k]
+
+ # save dataloader state
+ if hasattr(self._dataloader, "state_dict"):
+ state["dataloader_state"] = self._dataloader.state_dict()
+ elif hasattr(self._dataloader, "__getstate__"):
+ state["dataloader_state"] = self._dataloader.__getstate__()
+
+ if hasattr(self.batching_strategy, "state_dict"):
+ state["batching_strategy_state"] = self.batching_strategy.state_dict() # type: ignore
+ del state["batching_strategy"]
+
+ return copy.deepcopy(state)
+
+ def load_state_dict(self, state: Dict[str, Any]):
+ if state["num_micro_batch"] != self.num_micro_batch:
+ logger.warning(
+ f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer"
+ )
+ del state["num_micro_batch"]
+ self.__dict__.update(state)
+ self._resume = True
+
+ if hasattr(self._dataloader, "load_state_dict"):
+ self._dataloader.load_state_dict(state["dataloader_state"])
+ elif hasattr(self._dataloader, "__getstate__"):
+ self._dataloader.__setstate__(state["dataloader_state"])
+
+ if "batching_strategy_state" in state:
+ self.batching_strategy.load_state_dict( # type: ignore
+ state["batching_strategy_state"]
+ )
+ del state["batching_strategy_state"]
+
+ self._data_iter = iter(self._dataloader)
+ self._batch_data_iter = self.batch_data_generator()
diff --git a/lingbotvla/data/vla_data/README.md b/lingbotvla/data/vla_data/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9bded04b08885be42e437e451235c409ffa38013
--- /dev/null
+++ b/lingbotvla/data/vla_data/README.md
@@ -0,0 +1,58 @@
+
+---
+
+# Customized Downstream Task Dataset Construction and Deployment
+
+This guide explains how to construct a custom downstream task dataset for post-training and how to deploy it on corresponding downstream tasks. We use the 5 tasks from **RoboTwin 2.0** ("open_microwave", "click_bell", "stack_blocks_three", "place_shoe", "put_object_cabinet") as an example.
+
+## 1. Compute Normalization Statistics
+
+First, you need to calculate the normalization statistics for your custom dataset:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 bash train.sh \
+ scripts/compute_norm_robotwin_5.py \
+ configs/norm/robotwin_5.yaml \
+ --model.model_path /path/to/LingBot-VLA \
+ --model.tokenizer_path /path/to/Qwen2.5-VL-3B-Instruct/ \
+ --data.train_path /path/to/mixed_robotwin_5tasks \
+ --data.norm_path assets/norm_stats/robotwin_5_customized.json
+```
+
+> **Note:**
+> In [`scripts/compute_norm_robotwin_5.py`](../../../scripts/compute_norm_robotwin_5.py) (lines 71–75), specify the original keys of **action** and **state** in the lerobot-formatted data.
+> For RoboTwin2.0, these correspond to:
+> - `action`
+> - `observation.state`
+
+---
+
+## 2. Construct Custom Dataset
+
+The `assets/norm_stats/robotwin_5_customized.json` generated in the previous step stores your normalization statistics. To use this file:
+
+1. **Specify the path** in your [Run Command](../../../README.md) via: `--data.norm_stats_file assets/norm_stats/robotwin_5_customized.json`.
+2. **Replace the Dataset Class:** In [tasks/vla/train_lingbotvla.py](../../../tasks/vla/train_lingbotvla.py), replace the default `RobotwinDataset` with the provided `CustomizedRobotwinDataset`.
+
+### Implementation Details:
+When constructing a custom dataset similar to `CustomizedRobotwinDataset`, ensure the following:
+* **Key Mapping:** When instantiating `self.normalizer` and obtaining `normalized_item`, you must modify the original keys for actions, states, and images from all views. For RoboTwin 2.0, the keys are:
+ * **Action:** `'action'`
+ * **State:** `'observation.state'`
+ * **Images:** `'observation.images.cam_high'`, `'observation.images.cam_left_wrist'`, `'observation.images.cam_right_wrist'`
+* **Data Type:** When instantiating `self.normalizer`, set the parameter **`data_type='customized'`**.
+
+---
+
+## 3. Deployment
+
+To ensure correct results, the data processing logic during the testing phase must **be identical to the training phase**.
+
+Taking [deploy/lingbot_robotwin_policy.py](../../../deploy/lingbot_robotwin_policy.py) as an example:
+You should use the same **action**, **state**, and **image keys** as in the training phase.
+Also, when constructing `policy.normalizer` in line 323, make sure the attribute `data_type` is consistent with your training setup.
+
+For example, if you used `CustomizedRobotwinDataset` during training,
+then line **323** in `deploy/lingbot_robotwin_policy.py` should be **`data_type='customized'`**.
+
+---
\ No newline at end of file
diff --git a/lingbotvla/data/vla_data/__init__.py b/lingbotvla/data/vla_data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..00311bdc56e32881b1dc690f059a64ef98571afc
--- /dev/null
+++ b/lingbotvla/data/vla_data/__init__.py
@@ -0,0 +1 @@
+from .base_dataset import liberoDataset, RobotwinDataset, CustomizedRobotwinDataset
\ No newline at end of file
diff --git a/lingbotvla/data/vla_data/base_dataset.py b/lingbotvla/data/vla_data/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6be25728e428cb21d96b54523e44b6d3949c80e
--- /dev/null
+++ b/lingbotvla/data/vla_data/base_dataset.py
@@ -0,0 +1,385 @@
+# Copyright 2026 Robbyant Team and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from typing import Callable, Dict, List, Literal, Optional
+import numpy as np
+import torch
+from datasets import load_dataset
+from datasets.distributed import split_dataset_by_node
+from torch.utils.data import Dataset, IterableDataset
+from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
+from torchvision.transforms.v2 import Resize
+from transformers import AutoTokenizer, AutoImageProcessor
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
+import json
+import yaml
+from PIL import Image
+from .transform import Normalizer, prepare_action, prepare_images, prepare_language, prepare_state
+
+from ...utils import logging
+
+class VlaDataset(Dataset):
+ def __init__(
+ self,
+ repo_id="path2dataset",
+ config=PI0Config,
+ tokenizer=AutoTokenizer,
+ data_config=None,
+ image_processor=None,
+ use_depth_align=False,
+ action_name="action",
+ ):
+ self.image_processor = image_processor
+ # [i / 30 for i in range(50)] represents action chunks in 50 steps at 30 FPS.
+ # The timestamps are set to 0 for the images and state, as we only use current obs.
+ self.config = config
+ self.tokenizer = tokenizer
+ self.dataset_meta = LeRobotDatasetMetadata(repo_id)
+ delta_timestamps = {
+ action_name: [t / self.dataset_meta.fps for t in range(50)],
+ }
+ self.dataset = LeRobotDataset(
+ repo_id=repo_id,
+ delta_timestamps=delta_timestamps,
+ )
+ self.action_name = action_name
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def getdata(self, idx):
+ item = self.dataset[idx]
+ task = self.dataset_meta.tasks[int(item['task_index'])]
+ assert task == item['task']
+ return item
+
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+ if idx < 0 or idx >= len(self):
+ raise IndexError(f"Index {idx} out of bounds.")
+ max_retries = 200
+ attempts = 0
+ cur = idx
+ last_err = None
+ while attempts < max_retries:
+ try:
+ return self.getdata(cur)
+ except Exception as e:
+ last_err = e
+ attempts += 1
+ cur = np.random.randint(0, len(self))
+ if cur >= len(self):
+ cur = 0
+ continue
+
+ raise RuntimeError(
+ f"Failed to fetch a valid item starting from idx={idx} after {attempts} attempts. "
+ f"Last error: {repr(last_err)}"
+ )
+
+class liberoDataset(Dataset):
+ def __init__(
+ self,
+ repo_id="libero",
+ config=PI0Config,
+ tokenizer=AutoTokenizer,
+ data_config=None,
+ image_processor=None,
+ use_depth_align=False,
+ ):
+ image_transforms = Resize((data_config.img_size, data_config.img_size))
+ self.image_processor = image_processor
+ # [i / 30 for i in range(50)] represents action chunks in 50 steps at 30 FPS.
+ # The timestamps are set to 0 for the images and state, as we only use current obs.
+ self.config = config
+ self.tokenizer = tokenizer
+ self.norm_stats_file = data_config.norm_stats_file
+ self.dataset_meta = LeRobotDatasetMetadata(repo_id)
+ delta_timestamps = {
+ "actions": [t / self.dataset_meta.fps for t in range(50)],
+ }
+ self.dataset = LeRobotDataset(
+ repo_id=repo_id,
+ image_transforms=image_transforms,
+ delta_timestamps=delta_timestamps,
+ )
+ with open(self.norm_stats_file) as f:
+ self.norm_stats = json.load(f)
+ self.normalizer = Normalizer(
+ # norm_stats=self.dataset.meta.stats,
+ norm_stats=self.norm_stats['norm_stats'],
+ from_file=True,
+ data_type='libero',
+ norm_type={
+ "image": "identity",
+ "wrist_image": "identity",
+ "state": data_config.norm_type,
+ "actions": data_config.norm_type,
+ },
+ )
+ self.use_depth_align = use_depth_align
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ item = self.dataset[idx]
+ task = self.dataset_meta.tasks[int(item['task_index'])]
+ assert task == item['task']
+
+ normalized_item = self.normalizer.normalize(item)
+ base_image = (normalized_item["image"] * 255).to(torch.uint8)
+ wrist_image = (normalized_item["wrist_image"] * 255).to(
+ torch.uint8
+ )
+ batch_dict = {
+ "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": wrist_image},
+ "state": normalized_item["state"].to(torch.float32),
+ "action": normalized_item["actions"].to(torch.float32),
+ "action_is_pad": normalized_item["actions_is_pad"],
+ "prompt": [item["task"]],
+ }
+ state = prepare_state(self.config, batch_dict) # bs,8 -> bs,32
+ lang_tokens, lang_masks = prepare_language(self.config, self.tokenizer, batch_dict) # bs, seq_len
+ actions = prepare_action(self.config, batch_dict) # bs,50,7 -> bs,50,32 , 7
+ images, img_masks, pil_images = prepare_images(self.config, self.image_processor, batch_dict, use_depth_align=self.use_depth_align)
+
+ batch_dict = {
+ 'images': images,
+ 'img_masks': img_masks,
+ 'state': state,
+ 'lang_tokens': lang_tokens,
+ 'lang_masks': lang_masks,
+ 'actions': actions,
+ 'action_is_pad': batch_dict['action_is_pad'],
+ }
+
+ if self.use_depth_align: batch_dict['pil_images'] = pil_images
+
+ return batch_dict
+
+class RobotwinDataset(Dataset):
+ def __init__(
+ self,
+ repo_id="robotwin",
+ config=PI0Config,
+ tokenizer=AutoTokenizer,
+ data_config=None,
+ image_processor=None,
+ use_depth_align=False,
+ ):
+ image_transforms = Resize((data_config.img_size, data_config.img_size))
+ self.image_processor = image_processor
+ # [i / 30 for i in range(50)] represents action chunks in 50 steps at 30 FPS.
+ # The timestamps are set to 0 for the images and state, as we only use current obs.
+ self.config = config
+ self.tokenizer = tokenizer
+ self.norm_stats_file = data_config.norm_stats_file
+ self.dataset_meta = LeRobotDatasetMetadata(repo_id)
+ delta_timestamps = {
+ "action": [t / self.dataset_meta.fps for t in range(50)],
+ }
+ self.dataset = LeRobotDataset(
+ repo_id=repo_id,
+ image_transforms=image_transforms,
+ delta_timestamps=delta_timestamps,
+ )
+ with open(self.norm_stats_file) as f:
+ self.norm_stats = json.load(f)
+ self.normalizer = Normalizer(
+ # norm_stats=self.dataset.meta.stats,
+ norm_stats=self.norm_stats['norm_stats'],
+ from_file=True,
+ data_type='robotwin',
+ norm_type={
+ "observation.images.cam_high": "identity",
+ "observation.images.cam_left_wrist": "identity",
+ "observation.images.cam_right_wrist": "identity",
+ "observation.state": data_config.norm_type,
+ "action": data_config.norm_type,
+ },
+ )
+ self.use_depth_align = use_depth_align
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def getdata(self, idx):
+ item = self.dataset[idx]
+ task = self.dataset_meta.tasks[int(item['task_index'])]
+ assert task == item['task']
+
+ normalized_item = self.normalizer.normalize(item)
+ base_image = (normalized_item["observation.images.cam_high"] * 255).to(torch.uint8)
+ left_wrist_image = (normalized_item["observation.images.cam_left_wrist"] * 255).to(
+ torch.uint8
+ )
+ right_wrist_image = (normalized_item["observation.images.cam_right_wrist"] * 255).to(
+ torch.uint8
+ )
+ batch_dict = {
+ "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
+ "state": normalized_item["observation.state"].to(torch.float32),
+ "action": normalized_item["action"].to(torch.float32),
+ "action_is_pad": normalized_item["action_is_pad"],
+ "prompt": [item["task"]],
+ }
+ state = prepare_state(self.config, batch_dict) # bs,8 -> bs,32
+ lang_tokens, lang_masks = prepare_language(self.config, self.tokenizer, batch_dict) # bs, seq_len
+ actions = prepare_action(self.config, batch_dict) # bs,50,7 -> bs,50,32 , 7
+ images, img_masks, pil_images = prepare_images(self.config, self.image_processor, batch_dict, use_depth_align=self.use_depth_align)
+
+ batch_dict = {
+ 'images': images,
+ 'img_masks': img_masks,
+ 'state': state,
+ 'lang_tokens': lang_tokens,
+ 'lang_masks': lang_masks,
+ 'actions': actions,
+ 'action_is_pad': batch_dict['action_is_pad'],
+ }
+ if self.use_depth_align: batch_dict['pil_images'] = pil_images
+
+ return batch_dict
+
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+ if idx < 0 or idx >= len(self):
+ raise IndexError(f"Index {idx} out of bounds.")
+ max_retries = 200
+ attempts = 0
+ cur = idx
+ last_err = None
+ while attempts < max_retries:
+ try:
+ return self.getdata(cur)
+ except Exception as e:
+ last_err = e
+ attempts += 1
+ cur = np.random.randint(0, len(self))
+ if cur >= len(self):
+ cur = 0
+ continue
+
+ raise RuntimeError(
+ f"Failed to fetch a valid item starting from idx={idx} after {attempts} attempts. "
+ f"Last error: {repr(last_err)}"
+ )
+
+class CustomizedRobotwinDataset(Dataset):
+ def __init__(
+ self,
+ repo_id="robotwin",
+ config=PI0Config,
+ tokenizer=AutoTokenizer,
+ data_config=None,
+ image_processor=None,
+ use_depth_align=False,
+ ):
+ image_transforms = Resize((data_config.img_size, data_config.img_size))
+ self.image_processor = image_processor
+ # [i / 30 for i in range(50)] represents action chunks in 50 steps at 30 FPS.
+ # The timestamps are set to 0 for the images and state, as we only use current obs.
+ self.config = config
+ self.tokenizer = tokenizer
+ self.norm_stats_file = data_config.norm_stats_file
+ self.dataset_meta = LeRobotDatasetMetadata(repo_id)
+ delta_timestamps = {
+ "action": [t / self.dataset_meta.fps for t in range(50)],
+ }
+ self.dataset = LeRobotDataset(
+ repo_id=repo_id,
+ image_transforms=image_transforms,
+ delta_timestamps=delta_timestamps,
+ )
+ with open(self.norm_stats_file) as f:
+ self.norm_stats = json.load(f)
+ self.normalizer = Normalizer(
+ # norm_stats=self.dataset.meta.stats,
+ norm_stats=self.norm_stats['norm_stats'],
+ from_file=True,
+ data_type='customized',
+ norm_type={
+ "observation.images.cam_high": "identity",
+ "observation.images.cam_left_wrist": "identity",
+ "observation.images.cam_right_wrist": "identity",
+ "observation.state": data_config.norm_type,
+ "action": data_config.norm_type,
+ },
+ )
+ self.use_depth_align = use_depth_align
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def getdata(self, idx):
+ item = self.dataset[idx]
+ task = self.dataset_meta.tasks[int(item['task_index'])]
+ assert task == item['task']
+
+ normalized_item = self.normalizer.normalize(item)
+ base_image = (normalized_item["observation.images.cam_high"] * 255).to(torch.uint8)
+ left_wrist_image = (normalized_item["observation.images.cam_left_wrist"] * 255).to(
+ torch.uint8
+ )
+ right_wrist_image = (normalized_item["observation.images.cam_right_wrist"] * 255).to(
+ torch.uint8
+ )
+ batch_dict = {
+ "image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
+ "state": normalized_item["observation.state"].to(torch.float32),
+ "action": normalized_item["action"].to(torch.float32),
+ "action_is_pad": normalized_item["action_is_pad"],
+ "prompt": [item["task"]],
+ }
+ state = prepare_state(self.config, batch_dict) # bs,8 -> bs,32
+ lang_tokens, lang_masks = prepare_language(self.config, self.tokenizer, batch_dict) # bs, seq_len
+ actions = prepare_action(self.config, batch_dict) # bs,50,7 -> bs,50,32 , 7
+ images, img_masks, pil_images = prepare_images(self.config, self.image_processor, batch_dict, use_depth_align=self.use_depth_align)
+
+ batch_dict = {
+ 'images': images,
+ 'img_masks': img_masks,
+ 'state': state,
+ 'lang_tokens': lang_tokens,
+ 'lang_masks': lang_masks,
+ 'actions': actions,
+ 'action_is_pad': batch_dict['action_is_pad'],
+ }
+ if self.use_depth_align: batch_dict['pil_images'] = pil_images
+
+ return batch_dict
+
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+ if idx < 0 or idx >= len(self):
+ raise IndexError(f"Index {idx} out of bounds.")
+ max_retries = 200
+ attempts = 0
+ cur = idx
+ last_err = None
+ while attempts < max_retries:
+ try:
+ return self.getdata(cur)
+ except Exception as e:
+ last_err = e
+ attempts += 1
+ cur = np.random.randint(0, len(self))
+ if cur >= len(self):
+ cur = 0
+ continue
+
+ raise RuntimeError(
+ f"Failed to fetch a valid item starting from idx={idx} after {attempts} attempts. "
+ f"Last error: {repr(last_err)}"
+ )
\ No newline at end of file
diff --git a/lingbotvla/data/vla_data/transform.py b/lingbotvla/data/vla_data/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..80a7bc4d067ac56ced0dde9915a3d26c1a977882
--- /dev/null
+++ b/lingbotvla/data/vla_data/transform.py
@@ -0,0 +1,306 @@
+from typing import Dict
+
+import numpy as np
+import torch
+import math
+import einops
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+
+IMAGE_KEYS = (
+ "base_0_rgb",
+ "left_wrist_0_rgb",
+ "right_wrist_0_rgb",
+)
+
+
+def dict_apply(func, d):
+ """
+ Apply a function to all values in a dictionary recursively.
+ If the value is a dictionary, it will apply the function to its values.
+ """
+ for key, value in d.items():
+ if isinstance(value, dict):
+ dict_apply(func, value)
+ else:
+ d[key] = func(value)
+ return d
+
+class Normalizer:
+ def __init__(
+ self,
+ norm_stats: Dict[str, Dict[str, np.ndarray]],
+ from_file: bool=False,
+ data_type: str=None,
+ norm_type: Dict[str, str] | None = None,
+ ):
+ if from_file:
+ if data_type == 'libero':
+ norm_stats['state']['mean'] = np.array(norm_stats['state']['mean'][:8])
+ norm_stats['state']['std'] = np.array(norm_stats['state']['std'][:8])
+ norm_stats['actions']['mean'] = np.array(norm_stats['actions']['mean'][:7])
+ norm_stats['actions']['std'] = np.array(norm_stats['actions']['std'][:7])
+ elif data_type == 'robotwin':
+ norm_stats['observation.state'], norm_stats['action'] = {}, {}
+ norm_stats['observation.state']['q01'] = np.array(norm_stats['observation.state.arm.position']['q01'][:6] + norm_stats['observation.state.effector.position']['q01'][:1] + norm_stats['observation.state.arm.position']['q01'][6:] + norm_stats['observation.state.effector.position']['q01'][1:])
+ norm_stats['observation.state']['q99'] = np.array(norm_stats['observation.state.arm.position']['q99'][:6] + norm_stats['observation.state.effector.position']['q99'][:1] + norm_stats['observation.state.arm.position']['q99'][6:] + norm_stats['observation.state.effector.position']['q99'][1:])
+ norm_stats['action']['q01'] = np.array(norm_stats['action.arm.position']['q01'][:6] + norm_stats['action.effector.position']['q01'][:1] + norm_stats['action.arm.position']['q01'][6:] + norm_stats['action.effector.position']['q01'][1:])
+ norm_stats['action']['q99'] = np.array(norm_stats['action.arm.position']['q99'][:6] + norm_stats['action.effector.position']['q99'][:1] + norm_stats['action.arm.position']['q99'][6:] + norm_stats['action.effector.position']['q99'][1:])
+ elif data_type == 'robotwin_rep':
+ norm_stats['observation.state'], norm_stats['action'] = {}, {}
+ norm_stats['observation.state']['q01'] = np.array(norm_stats['observation.state.arm.position']['q01'] + norm_stats['observation.state.effector.position']['q01'])
+ norm_stats['observation.state']['q99'] = np.array(norm_stats['observation.state.arm.position']['q99'] + norm_stats['observation.state.effector.position']['q99'])
+ norm_stats['action']['q01'] = np.array(norm_stats['action.arm.position']['q01'][:6] + norm_stats['action.effector.position']['q01'][:1] + norm_stats['action.arm.position']['q01'][6:] + norm_stats['action.effector.position']['q01'][1:])
+ norm_stats['action']['q99'] = np.array(norm_stats['action.arm.position']['q99'][:6] + norm_stats['action.effector.position']['q99'][:1] + norm_stats['action.arm.position']['q99'][6:] + norm_stats['action.effector.position']['q99'][1:])
+ elif data_type == 'customized':
+ for key in norm_stats:
+ if isinstance(norm_stats[key], dict):
+ for sub_key in norm_stats[key]:
+ norm_stats[key][sub_key] = np.array(norm_stats[key][sub_key])
+ self.norm_stats = norm_stats
+ else:
+ self.norm_stats = dict_apply(lambda x: x.astype(np.float32), norm_stats)
+ self.norm_type = norm_type or {}
+ self.from_file = from_file
+
+ def normalize(self, data: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
+ normalized_data = {}
+ for key, value in data.items():
+ if key in self.norm_stats:
+ norm_type = self.norm_type.get(key, "identity")
+ if norm_type == "meanstd":
+ mean = self.norm_stats[key]["mean"]
+ std = self.norm_stats[key]["std"]
+ normalized_value = (value - mean) / (std + 1e-6)
+ elif norm_type == "bounds_99_woclip":
+ low = self.norm_stats[key]["q01"]
+ high = self.norm_stats[key]["q99"]
+ normalized_value = (value - low) / (high - low + 1e-6) * 2.0 - 1.0
+ elif norm_type == "std":
+ std = self.norm_stats[key]["std"]
+ normalized_value = value / (std + 1e-6)
+ elif norm_type == "minmax":
+ min_val = self.norm_stats[key]["min"]
+ max_val = self.norm_stats[key]["max"]
+ normalized_value = (value - min_val) / (
+ max_val - min_val + 1e-6
+ ) * 2 - 1
+ elif norm_type == "identity":
+ normalized_value = value
+ else:
+ raise ValueError(
+ f"Unknown normalization type: {norm_type}. Supported types are 'meanstd', 'minmax', and 'identity'."
+ )
+ normalized_data[key] = normalized_value
+ else:
+ # If the key is not in norm_stats, we assume no normalization is needed
+ normalized_data[key] = value
+ return normalized_data
+
+ def unnormalize(self, data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+ """
+ Unnormalize the given data using stored normalization statistics.
+
+ Args:
+ data (Dict[str, np.ndarray]): Dictionary of normalized arrays to unnormalize.
+
+ Returns:
+ Dict[str, np.ndarray]: Dictionary of unnormalized arrays.
+ """
+ unnormalized_data = {}
+ for key, value in data.items():
+ if key in self.norm_stats:
+ norm_type = self.norm_type.get(key, "identity")
+ stats = self.norm_stats[key]
+ if norm_type == "meanstd":
+ mean = stats["mean"]
+ std = stats["std"]
+ unnormalized_value = value * (std + 1e-6) + mean
+ elif norm_type == "bounds_98" or norm_type == 'bounds_98_woclip':
+ low = self.norm_stats[key]["q02"]
+ high = self.norm_stats[key]["q98"]
+ unnormalized_value = ((value + 1.0) / 2.0) * (high - low + 1e-6) + low
+ elif norm_type == "bounds_99" or norm_type == "bounds_99_woclip":
+ low = self.norm_stats[key]["q01"]
+ high = self.norm_stats[key]["q99"]
+ unnormalized_value = ((value + 1.0) / 2.0) * (high - low + 1e-6) + low
+ elif norm_type == "std":
+ std = stats["std"]
+ unnormalized_value = value * (std + 1e-6)
+ elif norm_type == "minmax":
+ min_val = stats["min"]
+ max_val = stats["max"]
+ # Reverse: (x + 1)/2 * (max-min+eps) + min
+ unnormalized_value = (value + 1) / 2.0 * (max_val - min_val + 1e-6) + min_val
+ elif norm_type == "identity":
+ unnormalized_value = value
+ else:
+ raise ValueError(
+ f"Unknown normalization type: {norm_type}. Supported types are 'meanstd', 'minmax', and 'identity'."
+ )
+ unnormalized_data[key] = unnormalized_value
+ else:
+ # If no normalization was applied, return as-is
+ unnormalized_data[key] = value
+ return unnormalized_data
+
+def resize_with_pad_item(img, width, height, pad_value=-1):
+ # assume no-op when width height fits already
+ if img.ndim != 3:
+ raise ValueError(f"(c,h,w) expected, but {img.shape}")
+
+ cur_height, cur_width = img.shape[1:]
+
+ ratio = max(cur_width / width, cur_height / height)
+ resized_height = int(cur_height / ratio)
+ resized_width = int(cur_width / ratio)
+ resized_img = F.interpolate(
+ img.unsqueeze(0), size=(resized_height, resized_width), mode="bilinear", align_corners=False
+ ).squeeze(0)
+
+ pad_height = max(0, int(height - resized_height))
+ pad_width = max(0, int(width - resized_width))
+
+ # pad on left and top of image
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
+ return padded_img
+
+def prepare_images(config, image_processor, observation: dict[str, Tensor], use_depth_align=False):
+ """Normalize, resize, and pad images and stack them into a tensor.
+
+ Args:
+ observation (dict[str, Tensor])
+
+ Returns:
+ images (torch.Tensor): (*b, n, c, h, w) images in range [-1.0, 1.0]
+ img_masks (torch.Tensor): (*b, n) masks for images, True if image is present, False if missing
+ """
+ dtype = observation["state"].dtype
+ images, img_masks = [], []
+ if use_depth_align:
+ pil_images = []
+
+ for key in IMAGE_KEYS:
+ if key in observation["image"]:
+ # resize, pad, and normalize
+ img = observation["image"][key]
+ assert img.ndim == 3, f"Expected 3D image, got {img.shape}"
+ pil_img = img.cpu().numpy()
+ if image_processor is None:
+ img = img.to(dtype) / 127.5 - 1.0 # to [-1, 1]
+ img = resize_with_pad_item(
+ img, *config.resize_imgs_with_padding, pad_value=-1.0
+ )
+ else:
+ img = resize_with_pad_item(
+ img, *config.resize_imgs_with_padding, pad_value=0
+ )
+ img = image_processor(img)['pixel_values']
+ images.append(img)
+ img_masks.append(True)
+ if use_depth_align:
+ pil_images.append(pil_img)
+ else:
+ # zero padding
+ if image_processor is None:
+ img = torch.full_like(img, fill_value=-1.0)
+ if use_depth_align:
+ pil_img = torch.full_like(pil_img, fill_value=-1.0)
+ else:
+ img = np.zeros_like(img)
+ if use_depth_align:
+ pil_img = np.zeros_like(pil_img)
+ images.append(img)
+ if use_depth_align:
+ pil_images.append(pil_img)
+ img_masks.append(False)
+ if isinstance(images[0], torch.Tensor):
+ images = torch.stack(images, dim=0) # (n, c, h, w)
+ elif isinstance(images[0], np.ndarray):
+ images = torch.from_numpy(np.stack(images, axis=0)) # (n, c, h, w)
+ img_masks = torch.tensor(img_masks, dtype=torch.bool) # (*n)
+
+ if use_depth_align:
+ pil_images = torch.from_numpy(np.stack(pil_images, axis=0)) # (n, c, h, w)
+ else:
+ pil_images = []
+
+ return images, img_masks, pil_images
+
+def prepare_state(config, observation: dict[str, Tensor]):
+ """Pad the state to the maximum state dimension.
+
+ Args:
+ observation (dict[str, Tensor])
+
+ Returns:
+ state (torch.Tensor): (*b, max_state_dim) padded state tensor
+ """
+ state = observation["state"]
+ state = F.pad(state, (0, config.max_state_dim - state.shape[-1]))
+ return state
+
+def prepare_action(config, observation: dict[str, Tensor]):
+ """Pad the action to the maximum action dimension.
+
+ Args:
+ observation (dict[str, Tensor])
+
+ Returns:
+ action (torch.Tensor): (*b, n, max_action_dim) padded action tensor
+ action_dim (int): the actual dimension of the action before padding
+ """
+ # ipdb.set_trace()
+ action = observation["action"]
+ action = F.pad(action, (0, config.max_action_dim - action.shape[-1]))
+ return action
+
+def prepare_language(config, language_tokenizer, observation: dict[str, Tensor]):
+ """If `prompt` is provided, modify it to PaliGemma format and tokenize it.
+ If `lang_tokens` and `lang_masks` are provided, use them directly.
+
+ PaliGemma expects prefix prompts to be formatted as:
+ .... prompt , where uses `\\n`.
+ So here we format the prompt to start with `` and end with `\\n`.
+ Later, we will concatenate the images and language tokens into a single sequence.
+
+ Args:
+ observation (dict[str, Tensor])
+
+ Returns:
+ lang_tokens (torch.Tensor): (*b, l) language tokens
+ lang_masks (torch.Tensor): (*b, l) masks for language tokens, True if token is present, False if missing
+ """
+ lang_tokens = observation.get("lang_tokens", None)
+ lang_masks = observation.get("lang_masks", None)
+ prompt = observation.get("prompt", None)
+
+ # either provide `prompt` or (`lang_tokens`, `lang_masks`)
+ if prompt is None and (lang_tokens is None or lang_masks is None):
+ raise ValueError(
+ "Either 'prompt' or ('lang_tokens', 'lang_masks') must be provided in the observation."
+ )
+
+ device = observation["state"].device
+ if prompt is not None and (lang_tokens is None or lang_masks is None):
+ prompt = [p if p.startswith("") else f"{p}" for p in prompt]
+ prompt = [p if p.endswith("\n") else f"{p}\n" for p in prompt]
+ tokenized_prompt = language_tokenizer.__call__(
+ prompt,
+ padding="max_length",
+ padding_side="right",
+ max_length=config.tokenizer_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ lang_tokens = tokenized_prompt["input_ids"].to(device=device)
+ lang_masks = tokenized_prompt["attention_mask"].to(
+ device=device, dtype=torch.bool
+ )
+ else:
+ lang_tokens = observation["lang_tokens"].to(device=device)
+ lang_masks = observation["lang_masks"].to(device=device, dtype=torch.bool)
+
+ return lang_tokens.squeeze(0), lang_masks.squeeze(0)
\ No newline at end of file
diff --git a/lingbotvla/distributed/__init__.py b/lingbotvla/distributed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd1e8433dffa0b3ba420be3e346f4f5cd062014
--- /dev/null
+++ b/lingbotvla/distributed/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/lingbotvla/distributed/checkpoint.py b/lingbotvla/distributed/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..a87ff291bcd0a175537828530eb023a7304c0d97
--- /dev/null
+++ b/lingbotvla/distributed/checkpoint.py
@@ -0,0 +1,137 @@
+import contextlib
+
+import torch
+from torch.distributed.fsdp._common_utils import _get_module_fsdp_state_if_fully_sharded_module, _module_handle
+from torch.distributed.fsdp._runtime_utils import (
+ _post_backward_hook,
+ _pre_backward_hook,
+)
+from torch.utils.checkpoint import (
+ _get_autocast_kwargs,
+ _get_device_module,
+ _infer_device_type,
+ check_backward_validity,
+ detach_variable,
+ get_device_states,
+ set_device_states,
+)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, preserve_rng_state, *args):
+ check_backward_validity(args)
+ ctx.run_function = run_function
+ ctx.preserve_rng_state = preserve_rng_state
+ # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
+ ctx.device = _infer_device_type(*args)
+ ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(ctx.device)
+ if preserve_rng_state:
+ ctx.fwd_cpu_state = torch.get_rng_state()
+ # Don't eagerly initialize the cuda context by accident.
+ # (If the user intends that the context is initialized later, within their
+ # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
+ # we have no way to anticipate this will happen before we run the function.)
+ ctx.had_device_in_fwd = False
+ device_module = _get_device_module(ctx.device)
+ if getattr(device_module, "_initialized", False):
+ ctx.had_device_in_fwd = True
+ ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)
+
+ # Save non-tensor inputs in ctx, keep a placeholder None for tensors
+ # to be filled out during the backward.
+ ctx.inputs = []
+ ctx.tensor_indices = []
+ tensor_inputs = []
+ for i, arg in enumerate(args):
+ if torch.is_tensor(arg):
+ tensor_inputs.append(arg)
+ ctx.tensor_indices.append(i)
+ ctx.inputs.append(None)
+ else:
+ ctx.inputs.append(arg)
+
+ ctx.save_for_backward(*tensor_inputs)
+
+ with torch.no_grad():
+ outputs = run_function(*args)
+
+ # patch code, remove the extra allgather with use_reentrant + ckpt
+ if not isinstance(ctx.run_function, torch.nn.Module):
+ ctx.patch_module = ctx.run_function.__self__
+ else:
+ ctx.patch_module = ctx.run_function
+ state = _get_module_fsdp_state_if_fully_sharded_module(ctx.patch_module)
+ if state:
+ handle = _module_handle(state, ctx.patch_module)
+ if handle:
+ handle._needs_pre_backward_unshard = True
+ return outputs
+
+ @staticmethod
+ def backward(ctx, *args):
+ if not torch.autograd._is_checkpoint_valid():
+ raise RuntimeError(
+ "When use_reentrant=True, torch.utils.checkpoint is incompatible"
+ " with .grad() or passing an `inputs` parameter to .backward()."
+ " To resolve this error, you can either set use_reentrant=False,"
+ " or call .backward() without passing the `inputs` argument."
+ )
+ # patch code, remove the extra allgather with use_reentrant + ckpt
+ handle = None
+ state = _get_module_fsdp_state_if_fully_sharded_module(ctx.patch_module)
+ if state:
+ handle = _module_handle(state, ctx.patch_module)
+ if handle:
+ _pre_backward_hook(state, ctx.patch_module, handle, None)
+
+ # Copy the list to avoid modifying original list.
+ inputs = list(ctx.inputs)
+ tensor_indices = ctx.tensor_indices
+ tensors = ctx.saved_tensors
+
+ # Fill in inputs with appropriate saved tensors.
+ for i, idx in enumerate(tensor_indices):
+ inputs[idx] = tensors[i]
+
+ # Stash the surrounding rng state, and mimic the state that was
+ # present at this time during forward. Restore the surrounding state
+ # when we're done.
+ rng_devices = []
+ if ctx.preserve_rng_state and ctx.had_device_in_fwd:
+ rng_devices = ctx.fwd_devices
+ with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device):
+ if ctx.preserve_rng_state:
+ torch.set_rng_state(ctx.fwd_cpu_state)
+ if ctx.had_device_in_fwd:
+ set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
+ detached_inputs = detach_variable(tuple(inputs))
+
+ device_autocast_ctx = (
+ torch.amp.autocast(device_type=ctx.device, **ctx.device_autocast_kwargs)
+ if torch.amp.is_autocast_available(ctx.device)
+ else contextlib.nullcontext()
+ )
+ with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
+ outputs = ctx.run_function(*detached_inputs)
+
+ if isinstance(outputs, torch.Tensor):
+ outputs = (outputs,)
+
+ # run backward() with only tensor that requires grad
+ outputs_with_grad = []
+ args_with_grad = []
+ for i in range(len(outputs)):
+ if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
+ outputs_with_grad.append(outputs[i])
+ args_with_grad.append(args[i])
+ if len(outputs_with_grad) == 0:
+ raise RuntimeError("none of output has requires_grad=True, this checkpoint() is not necessary")
+ torch.autograd.backward(outputs_with_grad, args_with_grad)
+ grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs)
+
+ # patch code, remove the extra allgather with use_reentrant + ckpt
+ if handle:
+ _post_backward_hook(state, handle, None)
+
+ return (None, None) + grads
diff --git a/lingbotvla/distributed/fsdp/__init__.py b/lingbotvla/distributed/fsdp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..369bb88e1de8b1df25aa74541240974a81000408
--- /dev/null
+++ b/lingbotvla/distributed/fsdp/__init__.py
@@ -0,0 +1,3 @@
+from .clip_grad_norm import clip_grad_norm_
+from .extension import register_checkpoint_extension
+from .initialize import init_fsdp_fn, parallel_init_fsdp_fn, parallel_load_safetensors
diff --git a/lingbotvla/distributed/fsdp/clip_grad_norm.py b/lingbotvla/distributed/fsdp/clip_grad_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ee64aaee099f197cfa2eef8f44048c37d5fd791
--- /dev/null
+++ b/lingbotvla/distributed/fsdp/clip_grad_norm.py
@@ -0,0 +1,137 @@
+import functools
+import math
+import warnings
+
+import torch
+import torch.distributed as dist
+from torch.distributed._tensor import Shard
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import _get_grad_norm
+
+from ...utils.import_utils import is_torch_version_greater_than
+from ..parallel_plan import SpecInfo
+
+
+def clip_grad_norm_(fsdp_model: FSDP, max_norm, norm_type=2.0) -> torch.Tensor:
+ extension = fsdp_model._fsdp_extension
+ ep_mesh = extension.ep_mesh
+ ep_group = None if ep_mesh is None else ep_mesh.get_group()
+
+ if ep_group is None or dist.get_world_size(ep_group) in (1, dist.get_world_size()):
+ return FSDP.clip_grad_norm_(fsdp_model, max_norm, norm_type)
+
+ assert fsdp_model._is_root
+ # use dict as ordered set to make param order consistent among
+ # dp (hsdp) ranks to avoid gnorm difference due to reduction order
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+ fsdp_managed_params = set()
+ sharded_params_for_gnorm = {}
+ ep_fsdp_sharded_params_for_gnorm = {}
+ nonsharded_params_for_gnorm = {}
+ grads_for_clip = []
+ ep_fsdp_process_group = None
+
+ for handle in fsdp_model._all_handles:
+ assert handle.uses_sharded_strategy
+ assert handle._use_orig_params, "tensor parallelism can only work with FSDP using `use_orig_params=True`"
+ for param in handle.flat_param._params:
+ assert hasattr(param, "spec_info")
+ spec_info: SpecInfo = param.spec_info
+ fsdp_managed_params.add(param)
+ if param.grad is not None:
+ grads_for_clip.append(param.grad)
+ # ep param
+ if isinstance(spec_info.placement, Shard):
+ if ep_fsdp_process_group is None:
+ ep_fsdp_process_group = handle.process_group
+ ep_fsdp_sharded_params_for_gnorm.setdefault(param, None)
+ # fsdp param
+ else:
+ sharded_params_for_gnorm.setdefault(param, None)
+ for param in fsdp_model.parameters():
+ not_fsdp_managed = param not in fsdp_managed_params and param not in sharded_params_for_gnorm
+ if not_fsdp_managed:
+ assert hasattr(param, "_spec")
+ raise NotImplementedError(f"param {param._spec.fqn} is not managed by FSDP")
+
+ # Compute local norms (forced to be in FP32)
+ if is_torch_version_greater_than("2.5.0"):
+ grad_norm_kwargs = {
+ "norm_type": norm_type,
+ "zero": torch.tensor(0.0),
+ "device": fsdp_model.compute_device,
+ }
+ else:
+ grad_norm_kwargs = {
+ "norm_type": norm_type,
+ }
+
+ local_sharded_norm = _get_grad_norm(sharded_params_for_gnorm, **grad_norm_kwargs).to(fsdp_model.compute_device)
+ local_ep_fsdp_sharded_norm = (
+ _get_grad_norm(ep_fsdp_sharded_params_for_gnorm, **grad_norm_kwargs).to(fsdp_model.compute_device)
+ if ep_fsdp_sharded_params_for_gnorm
+ else None
+ )
+ local_nonsharded_norm = (
+ _get_grad_norm(nonsharded_params_for_gnorm, **grad_norm_kwargs).to(fsdp_model.compute_device)
+ if nonsharded_params_for_gnorm
+ else None
+ )
+
+ # Reconstruct the total gradient norm depending on the norm type
+ if norm_type == math.inf:
+ total_norm = (
+ torch.maximum(local_sharded_norm, local_nonsharded_norm)
+ if local_nonsharded_norm is not None
+ else local_sharded_norm
+ )
+ dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=fsdp_model.process_group)
+ # allreduce across tp group
+ dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=ep_group)
+ else:
+ total_norm = local_sharded_norm**norm_type
+ dist.all_reduce(total_norm, group=fsdp_model.process_group)
+ if local_ep_fsdp_sharded_norm is not None:
+ total_ep_fsdp_sharded_norm = local_ep_fsdp_sharded_norm**norm_type
+ dist.all_reduce(total_ep_fsdp_sharded_norm, group=ep_fsdp_process_group)
+ dist.all_reduce(total_ep_fsdp_sharded_norm, group=ep_group)
+ total_norm += total_ep_fsdp_sharded_norm
+
+ # All-reducing the local non-sharded norm would count it an extra
+ # world-size-many times
+ if local_nonsharded_norm is not None:
+ total_norm += local_nonsharded_norm**norm_type
+ total_norm = total_norm ** (1.0 / norm_type)
+ if fsdp_model.cpu_offload.offload_params:
+ total_norm = total_norm.cpu()
+
+ clip_coef = max_norm / (total_norm + 1e-6)
+ # Multiplying by the clamped coefficient is meaningless when it is
+ # equal to 1, but it avoids the host-device sync that would result from
+ # `if clip_coef < 1`
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+ for grad in grads_for_clip:
+ grad.mul_(clip_coef_clamped.to(grad.device, grad.dtype))
+ # Use the "largest" dtype by type promotion semantics to use the same
+ # dtype as if we did not force local norm computation to be in FP32
+ if len(grads_for_clip) == 0:
+ # If this rank has no gradients, then we must default to FP32
+ # unless we use additional communication, which we prefer to avoid
+ # since `clip_grad_norm_()` is called in the training loop
+ warnings.warn(
+ f"Called FSDP.clip_grad_norm_() on rank {fsdp_model.rank} with no "
+ "gradients -- returning the total norm in the default dtype "
+ f"{total_norm.dtype}"
+ ) # warn since this is generally unexpected
+ return total_norm
+ total_norm_dtype = functools.reduce(
+ torch.promote_types,
+ [grad.dtype for grad in grads_for_clip],
+ )
+ return total_norm.to(total_norm_dtype)
+
+
+def _is_first_ep_rank(ep_group: dist.ProcessGroup):
+ assert ep_group is not None
+ return dist.get_rank(ep_group) == 0
diff --git a/lingbotvla/distributed/fsdp/extension.py b/lingbotvla/distributed/fsdp/extension.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d324188d905e325532211d4e79e0acc911a947e
--- /dev/null
+++ b/lingbotvla/distributed/fsdp/extension.py
@@ -0,0 +1,450 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import copy
+from functools import partial
+from typing import Any, Dict, List, Tuple, Union
+
+import torch
+from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
+from torch.distributed.fsdp._optim_utils import (
+ FSDPParamInfo,
+ _OptimStateKey,
+ _unflatten_optim_state,
+ sorted_items,
+)
+
+from ...utils import logging
+from ...utils.import_utils import is_torch_version_greater_than
+from ..parallel_plan import SpecInfo
+
+
+logger = logging.get_logger(__name__)
+
+OPTIM_STATE_NO_SHARD_KEY = ["step"]
+orig_optim_state_dict = FSDP.optim_state_dict
+orig_optim_state_dict_to_load = FSDP.optim_state_dict_to_load
+
+
+def _shard_tensor(orgin_tensor: torch.Tensor, device_mesh: DeviceMesh, shard: Shard = Shard(0)):
+ """
+ Shard Tensor to DTensor.
+
+ args:
+ orgin_tensor (torch.Tensor): The orgin tensor.
+ device_mesh (DeviceMesh): The ep device mesh.
+ shard (Shard): The shard info, default Shard(0).
+
+ """
+ assert device_mesh.ndim == 2, f"global_mesh.ndim must be 2, got {device_mesh.ndim}"
+ ep_mesh = device_mesh["ep"]
+
+ if orgin_tensor.__class__.__name__ == "DTensor":
+ dtensor = DTensor.from_local(orgin_tensor._local_tensor, device_mesh=device_mesh, placements=[shard, shard])
+ elif orgin_tensor.__class__.__name__ == "Tensor":
+ dtensor = DTensor.from_local(orgin_tensor, device_mesh=ep_mesh, placements=[shard])
+
+ return dtensor
+
+
+def _shard_dtensor(orgin_dtensor: DTensor, device_mesh: DeviceMesh, shard: Shard = Shard(0)):
+ """
+ Convert DTensor to local Tensor
+
+ args:
+ orgin_dtensor (torch.Tensor): The orgin tensor.
+ device_mesh (DeviceMesh): The ep device mesh.
+ shard (Shard): The shard info, default Shard(0).
+
+ """
+ assert isinstance(orgin_dtensor, DTensor), (
+ f"Only support DTensor, got {type(orgin_dtensor)}, for torch.Tensor, use _shard_dtensor instead."
+ )
+
+ local_tensor = orgin_dtensor.to_local()
+
+ return local_tensor
+
+
+def check_any_unflat_param_names_match(unflat_param_name: str, fqn2spec_info: Dict[str, SpecInfo], prefix: str = None):
+ assert isinstance(unflat_param_name, str), f"unflat_param_name must be a str, got {type(unflat_param_name)}"
+
+ if prefix:
+ assert unflat_param_name.startswith(prefix), (
+ f"unflat_param_name {unflat_param_name} must start with prefix {prefix}"
+ )
+ unflat_param_name = unflat_param_name[len(prefix) :].lstrip(".")
+
+ if unflat_param_name not in fqn2spec_info:
+ logger.warning_rank0(f"unflat_param_name {unflat_param_name} not in fqn2spec_info.")
+ return False
+
+ if isinstance(fqn2spec_info[unflat_param_name].placement, Shard):
+ return True
+
+ return False
+
+
+def check_all_unflat_param_names_match(unflat_param_names: Tuple[str], fqn2spec_info: Dict[str, SpecInfo]):
+ """
+ Check
+ """
+ assert isinstance(unflat_param_names, (list, tuple)), (
+ f"unflat_param_names must be a list or tuple, got {type(unflat_param_names)}"
+ )
+
+ unflat_len = len(unflat_param_names)
+ cnt = 0
+ for names in unflat_param_names:
+ assert names in fqn2spec_info, (
+ f"unflat_param_names {unflat_param_names} must be in fqn2spec_info {fqn2spec_info}"
+ )
+ if isinstance(fqn2spec_info[names].placement, Shard):
+ cnt += 1
+ assert cnt == 0 or cnt == unflat_len, f"unflat_param_names {unflat_param_names} must be all shard or all not shard"
+
+ return cnt == unflat_len
+
+
+class CheckpointExtensions(FSDPExtensions):
+ def __init__(
+ self,
+ ep_fsdp_device_mesh: DeviceMesh,
+ fqn2spec_info: Dict[str, SpecInfo],
+ ):
+ super().__init__()
+ self.ep_fsdp_device_mesh = ep_fsdp_device_mesh
+ self.ep_mesh = ep_fsdp_device_mesh["ep"] if ep_fsdp_device_mesh is not None else None
+ self.fqn2spec_info = fqn2spec_info
+
+ def chunk_dtensor(self, tensor: torch.Tensor, rank: int, device_mesh: DeviceMesh) -> torch.Tensor:
+ """Shards a tensor/DTensor to DTensor and returns the local DTensor."""
+ # We need to explicitly call .detach() to return a new tensor detached from the current graph.
+ tensor = tensor.clone().detach()
+ fsdp_size = device_mesh.size(-1)
+ dimlens = tuple(tensor.size())
+ # by default we use the max-len dimension for sharding
+ selected_dim = dimlens.index(max(dimlens))
+ for dim, dimlen in enumerate(dimlens):
+ if dimlen % fsdp_size == 0:
+ selected_dim = dim
+ break
+ # HSDP placements: [Replicate(), ..., Shard(selected_dim)]
+ replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
+ shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
+ shard_placements[-1] = Shard(selected_dim) # type: ignore[call-overload]
+ dtensor = DTensor.from_local(tensor, device_mesh, replicate_placements, run_check=False).redistribute(
+ placements=shard_placements,
+ )
+
+ return dtensor
+
+ def chunk_tensor(self, tensor, rank, world_size, num_devices_per_node, pg, device=None):
+ # use default
+ raise NotImplementedError("Please init FSDP with device mesh")
+ from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor
+
+ return _ext_chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)
+
+ def pre_flatten_transform(self, tensor):
+ # use default
+ from torch.distributed.fsdp._fsdp_extensions import _ext_pre_flatten_transform
+
+ return _ext_pre_flatten_transform(tensor)
+
+ def pre_load_state_dict_transform(self, tensor):
+ # use default
+ from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform
+
+ return _ext_pre_load_state_dict_transform(tensor)
+
+ def post_unflatten_transform(self, tensor, param_extension):
+ # use default
+ from torch.distributed.fsdp._fsdp_extensions import _ext_post_unflatten_transform
+
+ return _ext_post_unflatten_transform(tensor, param_extension)
+
+ def all_gather_dtensor(self, tensor: DTensor, parent_mesh):
+ # this is required during loading checkpoint (model.load_state_dict)
+ # use default
+ from torch.distributed.fsdp._fsdp_extensions import _ext_all_gather_dtensor
+
+ if is_torch_version_greater_than("2.5.0"):
+ return _ext_all_gather_dtensor(tensor, tensor.device_mesh)
+ else:
+ return _ext_all_gather_dtensor(tensor, None)
+
+ @torch.no_grad()
+ def state_dict_post_hook(
+ self, module, state_dict, prefix, local_metadata, fqn2spec_info: Dict[str, SpecInfo] = None
+ ):
+ """
+ Post state dict when calling `model.state_dict()` for EP cases.
+
+ This will append EP placements to the FSDP DTensor state dicts
+ """
+ assert fqn2spec_info is not None, "if fqn2spec_info is None it should not be patch"
+
+ if self.ep_mesh is None:
+ return
+ # [pp, ep_dp, ep, tp]
+ global_device_mesh = self.ep_fsdp_device_mesh
+ assert global_device_mesh.ndim == 2
+
+ keys = list(state_dict.keys())
+ for name in sorted(keys):
+ if name in fqn2spec_info and isinstance(fqn2spec_info[name].placement, Shard):
+ cur_spec_info = fqn2spec_info[name]
+ tensor = state_dict[name]
+ tensor = _shard_tensor(tensor, cur_spec_info.ep_fsdp_mesh, cur_spec_info.placement)
+ state_dict[name] = tensor
+
+ @torch.no_grad()
+ def load_state_dict_pre_hook(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ fqn2spec_info: Dict[str, SpecInfo] = None,
+ ):
+ """
+ Pre load state dict when calling `model.load_state_dict()` for EP cases.
+
+ This will shard Dtensor from ckpt to tensor state dicts
+ """
+ assert fqn2spec_info is not None, "if fqn2spec_info is None it should not be patch"
+
+ if self.ep_mesh is None:
+ return
+ # [ep, fsdp-ep]
+ global_device_mesh = self.ep_fsdp_device_mesh
+ assert global_device_mesh.ndim == 2
+
+ if self.ep_mesh.size() != global_device_mesh.size():
+ return
+
+ keys = list(state_dict.keys())
+ for name in sorted(keys):
+ tensor = state_dict[name]
+ if check_any_unflat_param_names_match(name, fqn2spec_info, "_fsdp_wrapped_module"):
+ fqn = name.split("_fsdp_wrapped_module.")[-1]
+ cur_spec_info = fqn2spec_info[fqn]
+ tensor = _shard_dtensor(tensor, cur_spec_info.ep_fsdp_mesh, cur_spec_info.placement)
+ state_dict[name] = tensor
+
+ def patch_convert_state_with_flat_params(self):
+ """ """
+
+ # Modify from torch.distributed.fsdp._optim_utils._convert_state_with_flat_params
+ def _convert_state_with_flat_params_patch(
+ all_optim_state_keys: List[_OptimStateKey],
+ optim_state_key_to_param_key: Dict[_OptimStateKey, Union[int, str]],
+ fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
+ optim_state_dict: Dict[Union[str, int], Any],
+ to_save: bool,
+ shard_state: bool,
+ cpu_offload: bool = True,
+ fqn2spec_info: Dict[str, SpecInfo] = None,
+ ) -> Dict[str, Any]:
+ fsdp_osd_state: Dict[str, Any] = {}
+ # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
+ # across ranks
+ for optim_state_key in all_optim_state_keys:
+ param_key: Union[str, int, None] = optim_state_key_to_param_key.get(optim_state_key, None)
+
+ assert param_key is not None, (
+ "If use_orig_params is False, we must be able to find the "
+ f"corresponding param id. {optim_state_key} {param_key}"
+ )
+
+ if optim_state_key.is_fsdp_managed:
+ # If there are multiple unflat_param_names (not use_orig_params),
+ # they share the same FSDPParamInfo. So the first unflat_param_name
+ # is sufficient to fetch the FSDPParamInfo.
+ fqn = optim_state_key.unflat_param_names[0]
+ fsdp_param_info = fqn_to_fsdp_param_info[fqn]
+ if check_all_unflat_param_names_match(optim_state_key.unflat_param_names, fqn2spec_info):
+ unflat_state = _unflatten_optim_state(
+ fsdp_param_info,
+ optim_state_dict[param_key],
+ to_save,
+ False,
+ cpu_offload,
+ )
+ else:
+ unflat_state = _unflatten_optim_state(
+ fsdp_param_info,
+ optim_state_dict[param_key],
+ to_save,
+ shard_state,
+ cpu_offload,
+ )
+ if to_save:
+ assert len(unflat_state) == len(optim_state_key.unflat_param_names)
+ for unflat_param_name, unflat_param_state in zip(
+ optim_state_key.unflat_param_names,
+ unflat_state,
+ ):
+ fsdp_osd_state[unflat_param_name] = unflat_param_state
+ elif to_save:
+ assert len(optim_state_key.unflat_param_names) == 1
+ unflat_param_name = optim_state_key.unflat_param_names[0]
+ fsdp_osd_state[unflat_param_name] = copy.copy(optim_state_dict[param_key])
+ if cpu_offload:
+ for state_name, value in sorted_items(fsdp_osd_state[unflat_param_name]):
+ if not torch.is_tensor(value):
+ continue
+ fsdp_osd_state[unflat_param_name][state_name] = value.cpu()
+
+ return fsdp_osd_state
+
+ # monkey patch
+ torch.distributed.fsdp._optim_utils._convert_state_with_flat_params = partial(
+ _convert_state_with_flat_params_patch, fqn2spec_info=self.fqn2spec_info
+ )
+
+ def patch_fsdp_optim_state_dict(self):
+ """ """
+
+ def fsdp_optim_state_post_patch_fn(
+ model, optim, optim_state_dict=None, fqn2spec_info: Dict[str, SpecInfo] = None
+ ):
+ assert fqn2spec_info is not None, "if fqn2spec_info is None it should not be patch"
+
+ fsdp_mesh = model._device_mesh
+ assert fsdp_mesh is not None, "Please init FSDP module with device_mesh"
+ # NOTE we don't support diverse process group for different FSDP sub-modules
+ fsdp_pg = model.process_group
+ optim_state = orig_optim_state_dict(model, optim, optim_state_dict, fsdp_pg)
+ if self.ep_mesh is None:
+ return optim_state
+
+ global_device_mesh = self.ep_fsdp_device_mesh
+ assert global_device_mesh.ndim == 2
+
+ # extend placements by adding EP placement
+ for fqn in sorted(optim_state["state"].keys()):
+ if fqn in fqn2spec_info and isinstance(fqn2spec_info[fqn].placement, Shard):
+ cur_spec_info = fqn2spec_info[fqn]
+ fqn_state = {}
+ for key, val in optim_state["state"][fqn].items():
+ # key in OPTIM_STATE_NO_SHARD_KEY in optim stat dict is scalar, like'step', should not be sharded
+ if key not in OPTIM_STATE_NO_SHARD_KEY:
+ val = _shard_tensor(val, cur_spec_info.ep_fsdp_mesh, cur_spec_info.placement)
+ fqn_state[key] = val
+ optim_state["state"][fqn] = fqn_state
+ return optim_state
+
+ # monkey patch
+ FSDP.optim_state_dict = staticmethod(partial(fsdp_optim_state_post_patch_fn, fqn2spec_info=self.fqn2spec_info))
+
+ def patch_fsdp_optim_state_dict_to_load(self):
+ """
+ post optimizer state dict hook when calling `FSDP.optim_state_dict(model, optimizer)`
+
+ This will extend the DTensors in optimizer state dict with EP placements
+
+ Args:
+ fsdp_no_shard_param_names: List[str], like
+ """
+
+ def optim_state_dict_to_load_pre_patch_fn(
+ model,
+ optim,
+ optim_state_dict,
+ is_named_optimizer=False,
+ load_directly=False,
+ group=None,
+ fqn2spec_info: Dict[str, SpecInfo] = None,
+ ):
+ """
+ At this point, the `optim_state_dict` is correctly resharded to the current device mesh by `dcp.load`
+ """
+ assert fqn2spec_info is not None, "if fqn2spec_info is None it should not be patch"
+
+ fsdp_mesh = model._device_mesh
+ assert fsdp_mesh is not None, "Please init FSDP module with device_mesh"
+
+ global_device_mesh = self.ep_fsdp_device_mesh
+ assert global_device_mesh.ndim == 2
+
+ # NOTE we don't support diverse process group for different FSDP sub-modules
+ if self.ep_mesh is not None and self.ep_mesh.size() == self.ep_fsdp_device_mesh.size():
+ for fqn in sorted(optim_state_dict["state"].keys()):
+ if check_any_unflat_param_names_match(fqn, fqn2spec_info):
+ fqn_state = {}
+ for key, val in optim_state_dict["state"][fqn].items():
+ # key in OPTIM_STATE_NO_SHARD_KEY in optim stat dict is scalar, like 'step', should not be sharded
+ if key not in OPTIM_STATE_NO_SHARD_KEY:
+ val = _shard_dtensor(val, self.ep_mesh)
+ fqn_state[key] = val
+ optim_state_dict["state"][fqn] = fqn_state
+
+ fsdp_pg = model.process_group
+ optim_state = orig_optim_state_dict_to_load(
+ model, optim, optim_state_dict, is_named_optimizer, load_directly, fsdp_pg
+ )
+ return optim_state
+
+ # monkey patch
+ FSDP.optim_state_dict_to_load = staticmethod(
+ partial(optim_state_dict_to_load_pre_patch_fn, fqn2spec_info=self.fqn2spec_info)
+ )
+
+
+def register_checkpoint_extension(
+ fsdp_model: FSDP,
+ save_hook_mesh: DeviceMesh = None,
+ fqn2spec_info: Dict[str, SpecInfo] = None,
+):
+ """
+ Register dtensor-based hooks for FSDP+EP
+
+ This will:
+
+ 1. Customize the FSDP extension for save / load hooks in EP scenarios.
+ """
+
+ extension = CheckpointExtensions(
+ ep_fsdp_device_mesh=save_hook_mesh,
+ fqn2spec_info=fqn2spec_info,
+ )
+ for fsdp_module in FSDP.fsdp_modules(fsdp_model):
+ fsdp_module._fsdp_extension = extension
+ fsdp_module._handle._fsdp_extension = extension
+ # make sure the root module is also registered
+ fsdp_model._fsdp_extension = extension
+ fsdp_model._handle._fsdp_extension = extension
+
+ # register load / save hook for ep
+ if fqn2spec_info is not None:
+ state_dict_post_hook_fn = partial(extension.state_dict_post_hook, fqn2spec_info=fqn2spec_info)
+ fsdp_model._register_state_dict_hook(state_dict_post_hook_fn)
+
+ load_state_dict_pre_hook_fn = partial(extension.load_state_dict_pre_hook, fqn2spec_info=fqn2spec_info)
+ fsdp_model._register_load_state_dict_pre_hook(load_state_dict_pre_hook_fn)
+
+ # patch load / save functino for ep
+ extension.patch_convert_state_with_flat_params()
+ extension.patch_fsdp_optim_state_dict()
+ extension.patch_fsdp_optim_state_dict_to_load()
+
+ return fsdp_model
diff --git a/lingbotvla/distributed/fsdp/initialize.py b/lingbotvla/distributed/fsdp/initialize.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5ef267f7b63577d49b7dce85b60a24bfb7c2056
--- /dev/null
+++ b/lingbotvla/distributed/fsdp/initialize.py
@@ -0,0 +1,288 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import json
+import math
+import os
+from collections import defaultdict
+from typing import Callable, Dict, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from safetensors.torch import load_file
+from torch.distributed._tensor import Replicate, Shard
+
+from ...utils import logging
+from ..parallel_plan import SpecInfo
+
+
+logger = logging.get_logger(__name__)
+
+
+def parallel_load_safetensors(
+ filepath: str, specific_param_name: list[str] = None, ignore_param_name: list[str] = None
+):
+ assert not (specific_param_name is not None and ignore_param_name is not None)
+
+ dist.barrier()
+
+ safetensors2param = {}
+ index_file = os.path.join(filepath, "model.safetensors.index.json")
+ if os.path.exists(index_file):
+ index = json.load(open(index_file, "rb"))
+ for param_name, filename in index["weight_map"].items():
+ if specific_param_name is not None:
+ if param_name not in specific_param_name:
+ continue
+ elif ignore_param_name is not None:
+ if param_name in ignore_param_name:
+ continue
+ safetensors2param.setdefault(filename, []).append(param_name)
+ else:
+ # in this case, the model is small and we can load it all at once
+ param_file = os.path.join(filepath, "model.safetensors")
+ assert os.path.exists(param_file), f"Cannot find {param_file}"
+ states = load_file(param_file)
+ for param_name in states:
+ safetensors2param.setdefault("model.safetensors", []).append(param_name)
+ del states
+
+ total_files = len(safetensors2param)
+ ckpt_chunks = sorted(safetensors2param.keys())
+ world_size = dist.get_world_size()
+ size = int(math.ceil(total_files / world_size))
+ ckpt_chunks = [ckpt_chunks[i * size : (i + 1) * size] for i in range(world_size)]
+
+ shard_states = {}
+ device = torch.cuda.current_device()
+ for rank, files in enumerate(ckpt_chunks):
+ if rank == dist.get_rank():
+ for file in files:
+ safetensors_file = os.path.join(filepath, file)
+ states = load_file(safetensors_file, device=device)
+ valid_states = {k: v for k, v in states.items() if k in safetensors2param[file]}
+ shard_states.update(valid_states)
+ del states
+ else:
+ for file in files:
+ for param_name in safetensors2param[file]:
+ shard_states[param_name] = rank
+ return shard_states
+
+
+def parallel_init_fsdp_fn(
+ module: torch.nn.Module,
+ shard_states: Dict[str, torch.nn.Parameter],
+ remove_standalone: bool = True,
+ specific_param_name: list[str] = None,
+ ignore_param_name: list[str] = None,
+):
+ """
+ Initialize a module with sharded states in a parallel fashion using Fully Sharded Data Parallel (FSDP).
+
+ Args:
+ module (torch.nn.Module): The module to be initialized.
+ shard_states (Dict[str, torch.nn.Parameter]): A dictionary containing sharded states.
+ remove_standalone (bool, optional): If True, only consider shared states. Defaults to True.
+ specific_param_name (list[str], optional): A list of specific parameter names to consider. Defaults to None.
+ ignore_param_name (list[str], optional): A list of parameter names to ignore. Defaults to None.
+
+ Returns:
+ Callable[[torch.nn.Module], torch.nn.Module]: A function that initializes sub-modules of the given module.
+ """
+ assert not (specific_param_name is not None and ignore_param_name is not None)
+ state2fqn = {}
+ for name, state in itertools.chain(
+ module.named_parameters(remove_duplicate=False), module.named_buffers(remove_duplicate=False)
+ ):
+ if specific_param_name is not None:
+ if name not in specific_param_name:
+ continue
+ elif ignore_param_name is not None:
+ if name in ignore_param_name:
+ continue
+ state2fqn.setdefault(state, []).append(name)
+
+ shared = {s for s, names in state2fqn.items() if len(names) > 1} if remove_standalone else set(state2fqn.keys())
+
+ materialized_states = {}
+
+ def make_full_tensor(param: torch.Tensor, spec_info: SpecInfo):
+ """
+ Create a full tensor from a sharded tensor based on the given specification.
+
+ Args:
+ param (torch.Tensor): The sharded tensor.
+ spec_info (SpecInfo): The specification information.
+
+ Returns:
+ torch.Tensor: The full tensor.
+ """
+ device = torch.cuda.current_device()
+ if isinstance(spec_info.placement, Replicate):
+ return torch.empty_like(param.data, device=device)
+ else:
+ assert isinstance(spec_info.placement, Shard)
+ size = list(param.shape)
+ size[spec_info.placement.dim] *= spec_info.ep_mesh.size()
+ return torch.empty(size, dtype=param.dtype, device=device)
+
+ def copy_to_local(param: torch.Tensor, full_data: torch.Tensor, spec_info: SpecInfo):
+ """
+ Copy data from a full tensor to a local sharded tensor based on the given specification.
+
+ Args:
+ param (torch.Tensor): The local sharded tensor.
+ full_data (torch.Tensor): The full tensor.
+ spec_info (SpecInfo): The specification information.
+ """
+ if isinstance(spec_info.placement, Replicate):
+ param.data.copy_(full_data)
+ else:
+ assert isinstance(spec_info.placement, Shard)
+ local_data = full_data.chunk(spec_info.ep_mesh.size(), dim=spec_info.placement.dim)[
+ spec_info.ep_mesh.get_local_rank()
+ ]
+ param.data.copy_(local_data.contiguous())
+ param.spec_info = spec_info
+
+ @torch.no_grad()
+ def create_and_sync_state(param_name, state, is_param):
+ """
+ Create and synchronize a state tensor across multiple devices.
+
+ Args:
+ param_name (str): The name of the parameter.
+ state (torch.Tensor): The state tensor.
+ is_param (bool): Whether the state is a parameter or a buffer.
+
+ Returns:
+ torch.Tensor: The synchronized state tensor.
+ """
+ device = torch.cuda.current_device()
+ if is_param:
+ param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad)
+ else: # buffer
+ param = torch.empty_like(state.data, device=device)
+ if param_name not in shard_states:
+ logger.warn(f"{param_name} not found in shard states, init it from random")
+ assert is_param
+ if dist.get_rank() == 0:
+ initializer_range = (2.5 * max(state.shape)) ** -0.5
+ size = list(state.size())
+ if hasattr(state, "spec_info"):
+ shard = state.spec_info.placement
+ if isinstance(shard, Shard):
+ size[shard.dim] *= state.spec_info.ep_mesh.size()
+ shard_states[param_name] = torch.nn.Parameter(
+ torch.randn(size, dtype=state.dtype, device=device, requires_grad=state.requires_grad)
+ * initializer_range
+ )
+ else:
+ shard_states[param_name] = 0
+ loaded = shard_states[param_name]
+ if isinstance(loaded, (torch.nn.Parameter, torch.Tensor)):
+ loaded = loaded.to(dtype=param.dtype, device=device)
+ dist.broadcast(loaded.data.to(param.dtype), src=dist.get_rank())
+ if hasattr(state, "spec_info"):
+ copy_to_local(param, loaded.data, state.spec_info)
+ else:
+ param.data.copy_(loaded.data)
+ else:
+ assert isinstance(loaded, int) # the rank that holds the state
+ if hasattr(state, "spec_info"):
+ full_data = make_full_tensor(param, state.spec_info)
+ dist.broadcast(full_data, src=loaded)
+ copy_to_local(param, full_data, state.spec_info)
+ else:
+ dist.broadcast(param.data, src=loaded)
+ shard_states.pop(param_name)
+ del loaded
+ return param
+
+ def init_fn(sub_mod: torch.nn.Module):
+ """
+ Initialize a sub-module with sharded states.
+
+ Args:
+ sub_mod (torch.nn.Module): The sub-module to be initialized.
+
+ Returns:
+ torch.nn.Module: The initialized sub-module.
+ """
+ param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(
+ sub_mod.named_buffers(recurse=False)
+ )
+ for name, state in param_and_buffers:
+ if state not in state2fqn:
+ logger.warning_once(f"{name} in {sub_mod.__class__.__name__} not found in state2fqn, skip it")
+ continue
+ is_param = name in sub_mod._parameters
+ fqn = state2fqn[state].pop(0)
+ if (not is_param) and fqn not in shard_states:
+ if state.is_meta:
+ raise RuntimeError(
+ f"find a non-persistent buffer ({fqn}) initiated with device meta. "
+ "Such buffer is not saved in checkpoint and user should guarantee to init in CPU / GPU device."
+ )
+ continue
+ if state in shared:
+ if state not in materialized_states:
+ materialized_states[state] = create_and_sync_state(fqn, state, is_param)
+ else:
+ if fqn in shard_states:
+ shard_states.pop(fqn)
+ materialize_state = materialized_states[state]
+ else:
+ materialize_state = create_and_sync_state(fqn, state, is_param)
+ if is_param:
+ sub_mod._parameters[name] = materialize_state
+ else:
+ sub_mod._buffers[name] = materialize_state
+ return sub_mod
+
+ return init_fn
+
+
+def init_fsdp_fn(model: nn.Module, device: Union[str, "torch.device"]) -> Callable[[nn.Module], None]:
+ """
+ Gets tensor materialization function that supports shared parameters and buffers.
+ Args:
+ model (nn.Module): the top module that may include shared parameters / buffers.
+ device (Union[str, torch.device]): the device to initialize parameters on.
+
+ Returns:
+ Callable[[nn.Module], None]: initialization method to materialize meta tensors on device.
+ """
+ param_occurrence = defaultdict(int)
+ for _, param in model.named_parameters(remove_duplicate=False):
+ param_occurrence[param] += 1
+
+ duplicated_params = {param for param in param_occurrence.keys() if param_occurrence[param] > 1}
+ materialized_params = {}
+
+ def init_fn(module: "nn.Module"):
+ for name, param in module.named_parameters(recurse=False):
+ if param in duplicated_params:
+ module._parameters[name] = materialized_params.setdefault(
+ param, nn.Parameter(torch.empty_like(param.data, device=device), requires_grad=param.requires_grad)
+ )
+ else:
+ module._parameters[name] = nn.Parameter(
+ torch.empty_like(param.data, device=device), requires_grad=param.requires_grad
+ )
+
+ return init_fn
diff --git a/lingbotvla/distributed/moe/__init__.py b/lingbotvla/distributed/moe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ffef3b20b4187b32f9e2b7a6581a021b13ef501
--- /dev/null
+++ b/lingbotvla/distributed/moe/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .moe_layer import EPGroupGemm, preprocess, token_pre_all2all, tokens_post_all2all
+
+
+__all__ = [
+ "preprocess",
+ "token_pre_all2all",
+ "tokens_post_all2all",
+ "EPGroupGemm",
+ "fused_moe_forward",
+]
diff --git a/lingbotvla/distributed/moe/comm.py b/lingbotvla/distributed/moe/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..26b80fb293cbcccc5d31c9b54109dde207661725
--- /dev/null
+++ b/lingbotvla/distributed/moe/comm.py
@@ -0,0 +1,58 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.distributed as dist
+
+
+class _AllToAll(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, group, input, output_split_sizes, input_split_sizes):
+ ctx.group = group
+ ctx.output_split_sizes = output_split_sizes
+ ctx.input_split_sizes = input_split_sizes
+
+ world_size = dist.get_world_size(group=group)
+
+ if world_size == 1:
+ return input
+
+ input = input.contiguous()
+
+ if output_split_sizes is None:
+ output = torch.empty_like(input)
+ else:
+ output = torch.empty(size=(sum(output_split_sizes), input.size(1)), dtype=input.dtype, device=input.device)
+ dist.all_to_all_single(
+ output,
+ input,
+ output_split_sizes=output_split_sizes,
+ input_split_sizes=input_split_sizes,
+ group=group,
+ )
+ return output
+
+ @staticmethod
+ def backward(ctx, *grad_output):
+ return (
+ None,
+ _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes),
+ None,
+ None,
+ )
+
+
+def all_to_all(group, input, output_split_size=None, input_split_size=None):
+ return _AllToAll.apply(group, input, output_split_size, input_split_size)
diff --git a/lingbotvla/distributed/moe/moe_layer.py b/lingbotvla/distributed/moe/moe_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..16b35719d2cfc5589ce4bc42124db1578d60ff73
--- /dev/null
+++ b/lingbotvla/distributed/moe/moe_layer.py
@@ -0,0 +1,300 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+from ...ops.group_gemm.kernel.group_gemm import group_gemm_same_mn, group_gemm_same_nk
+from .comm import all_to_all
+from .moe_utils import generate_weights_idx, permute, sort_chunks_by_idxs, unpermute
+
+
+def preprocess(
+ expert_mask: torch.Tensor,
+ num_experts: int,
+ ep_group: dist.ProcessGroup,
+) -> torch.Tensor:
+ ep_size = ep_group.size()
+ num_local_experts = num_experts // ep_size
+ rank = dist.get_rank(ep_group)
+ num_local_tokens_per_expert = expert_mask.sum(dim=(1, 2))
+
+ # [ep_size] represent the number of sum tokens in each rank
+ input_splits = num_local_tokens_per_expert.reshape(ep_size, num_local_experts).sum(dim=1).tolist()
+
+ # gather all the number of tokens per expert from all ep ranks
+ # [ep_size, num_experts]
+ num_global_tokens_per_expert = torch.zeros(
+ ep_size,
+ num_local_tokens_per_expert.size(0),
+ dtype=num_local_tokens_per_expert.dtype,
+ device=num_local_tokens_per_expert.device,
+ )
+ dist.all_gather_into_tensor(num_global_tokens_per_expert, num_local_tokens_per_expert, group=ep_group)
+
+ # [ep_size, num_local_experts]
+ start_idx, end_idx = rank * num_local_experts, (rank + 1) * num_local_experts
+ num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, start_idx:end_idx].contiguous()
+
+ # [ep_size]
+ output_splits = num_global_tokens_per_local_expert.sum(dim=1).tolist()
+
+ # [num_local_expert]
+ num_global_sum_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=0).to(
+ torch.device("cpu"), non_blocking=True
+ )
+
+ num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view(-1, num_local_experts).to(
+ torch.device("cpu"), non_blocking=True
+ )
+
+ return input_splits, output_splits, num_global_tokens_per_local_expert, num_global_sum_tokens_per_local_expert
+
+
+def token_pre_all2all(
+ hidden_states: torch.Tensor,
+ expert_mask: torch.Tensor,
+ num_experts: int,
+ input_splits: torch.Tensor,
+ output_splits: torch.Tensor,
+ num_global_tokens_per_local_expert: torch.Tensor,
+ ep_group: Optional[dist.ProcessGroup] = None,
+) -> torch.Tensor:
+ hidden_dim = hidden_states.size(-1)
+ hidden_states = hidden_states.reshape(-1, hidden_dim)
+ org_hidden_states_shape = hidden_states.shape
+ routing_map = expert_mask.sum(dim=1)
+
+ local_permuted_hidden_states, local_input_permutation_mapping = permute(hidden_states, routing_map)
+
+ global_permuted_hidden_states = all_to_all(ep_group, local_permuted_hidden_states, output_splits, input_splits)
+
+ # group tokens together by expert
+ num_local_experts = num_experts // ep_group.size()
+ permute_order = torch.arange(num_experts).reshape(-1, num_local_experts).T.ravel().tolist()
+ global_permuted_hidden_states = sort_chunks_by_idxs(
+ global_permuted_hidden_states,
+ num_global_tokens_per_local_expert.ravel(),
+ permute_order,
+ )
+
+ return global_permuted_hidden_states, routing_map, local_input_permutation_mapping, org_hidden_states_shape
+
+
+def tokens_post_all2all(
+ expert_outputs: torch.Tensor,
+ routing_weights: torch.Tensor,
+ selected_experts: int,
+ num_experts: int,
+ input_splits: torch.Tensor,
+ output_splits: torch.Tensor,
+ num_global_tokens_per_local_expert: torch.Tensor,
+ routing_map: torch.Tensor,
+ local_input_permutation_mapping: torch.Tensor,
+ org_hidden_states_shape: torch.Size,
+ ep_group: Optional[dist.ProcessGroup] = None,
+) -> torch.Tensor:
+ # group tokens together by expert
+ num_local_experts = num_experts // ep_group.size()
+ unpermute_order = torch.arange(num_experts).reshape(num_local_experts, -1).T.ravel().tolist()
+ expert_outputs = sort_chunks_by_idxs(
+ expert_outputs,
+ num_global_tokens_per_local_expert.T.ravel(),
+ unpermute_order,
+ )
+
+ unpermute_outputs = all_to_all(ep_group, expert_outputs, input_splits, output_splits)
+
+ # [tokens, experts]
+ weights_idx = generate_weights_idx(routing_weights, selected_experts, num_experts)
+
+ unpermute_outputs = unpermute(
+ unpermute_outputs,
+ weights_idx,
+ org_hidden_states_shape,
+ local_input_permutation_mapping,
+ routing_map,
+ )
+
+ return unpermute_outputs
+
+
+class EPGroupGemm(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ permute_tokens,
+ cumsum,
+ fc1_1_weight,
+ fc1_2_weight,
+ fc2_weight,
+ ):
+ # permute_tokens: [tokens, hidden_dim]
+ # cumsum: [local_experts]
+
+ # compute linear layer fc1-1
+ fc1_1_output = group_gemm_same_nk(
+ a=permute_tokens,
+ b=fc1_1_weight,
+ cumsum_M=cumsum,
+ max_M=permute_tokens.shape[0],
+ transpose_a=False,
+ transpose_b=True,
+ )
+
+ # compute linear layer fc1-2
+ fc1_2_output = group_gemm_same_nk(
+ a=permute_tokens,
+ b=fc1_2_weight,
+ cumsum_M=cumsum,
+ max_M=permute_tokens.shape[0],
+ transpose_a=False,
+ transpose_b=True,
+ )
+
+ # compute the actication of linear layer fc1-1
+ fc1_1_activation = torch.ops.aten.silu(fc1_1_output)
+
+ # compute final result of linear layer fc1
+ fc1_output = fc1_1_activation * fc1_2_output
+
+ # weighted projection is outside this function
+ # compute linear layer fc2
+ fc2_output = group_gemm_same_nk(
+ a=fc1_output,
+ b=fc2_weight,
+ cumsum_M=cumsum,
+ max_M=permute_tokens.shape[0],
+ transpose_a=False,
+ transpose_b=True,
+ )
+
+ ctx.save_for_backward(
+ permute_tokens,
+ cumsum,
+ fc1_1_weight,
+ fc1_2_weight,
+ fc2_weight,
+ fc1_1_output,
+ fc1_2_output,
+ )
+
+ return fc2_output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ # grad_output: [tokens, hidden_dim]
+ (
+ permute_tokens,
+ cumsum,
+ fc1_1_weight,
+ fc1_2_weight,
+ fc2_weight,
+ fc1_1_output,
+ fc1_2_output,
+ ) = ctx.saved_tensors
+ # permute_tokens: [tokens, hidden_dim]
+ # cumsum: [local_experts]
+
+ # dgrad fc1
+ grad_fc1_output = group_gemm_same_nk(
+ a=grad_output,
+ b=fc2_weight,
+ cumsum_M=cumsum,
+ max_M=grad_output.shape[0],
+ transpose_b=False,
+ )
+
+ # recompute
+ fc1_1_activation = torch.ops.aten.silu(fc1_1_output)
+ fc1_output = fc1_1_activation * fc1_2_output
+
+ # wgrad fc2
+ grad_fc2_weight = None
+ if fc2_weight.requires_grad:
+ grad_fc2_weight = torch.empty_like(fc2_weight)
+ group_gemm_same_mn(
+ a=grad_output,
+ b=fc1_output,
+ c=grad_fc2_weight,
+ cumsum_K=cumsum,
+ max_K=grad_output.shape[0],
+ transpose_a=True,
+ transpose_b=False,
+ )
+
+ grad_fc1_2_output = fc1_1_activation * grad_fc1_output
+ grad_fc1_1_activation = grad_fc1_output * fc1_2_output
+
+ # dgrad output 2
+ grad_scatter_output_2 = group_gemm_same_nk(
+ a=grad_fc1_2_output,
+ b=fc1_2_weight,
+ cumsum_M=cumsum,
+ max_M=grad_output.shape[0],
+ transpose_b=False,
+ )
+
+ # wgrad fc1-2
+ grad_fc1_2_weight = None
+ if fc1_2_weight.requires_grad:
+ grad_fc1_2_weight = torch.empty_like(fc1_2_weight)
+ group_gemm_same_mn(
+ a=grad_fc1_2_output,
+ b=permute_tokens,
+ c=grad_fc1_2_weight,
+ cumsum_K=cumsum,
+ max_K=grad_output.shape[0],
+ transpose_a=True,
+ transpose_b=False,
+ )
+
+ grad_fc1_1_output = torch.ops.aten.silu_backward(grad_fc1_1_activation, fc1_1_output)
+
+ # dgrad output 1
+ grad_scatter_output_1 = group_gemm_same_nk(
+ a=grad_fc1_1_output,
+ b=fc1_1_weight,
+ cumsum_M=cumsum,
+ max_M=grad_output.shape[0],
+ transpose_b=False,
+ )
+
+ # wgrad fc1-1
+ grad_fc1_1_weight = None
+ if fc1_1_weight.requires_grad:
+ grad_fc1_1_weight = torch.empty_like(fc1_1_weight)
+ group_gemm_same_mn(
+ a=grad_fc1_1_output,
+ b=permute_tokens,
+ c=grad_fc1_1_weight,
+ cumsum_K=cumsum,
+ max_K=grad_output.shape[0],
+ transpose_a=True,
+ transpose_b=False,
+ )
+
+ # grad input
+ grad_permute_tokens = grad_scatter_output_1 + grad_scatter_output_2
+
+ return (
+ grad_permute_tokens, # permute_tokens
+ None, # cumsum
+ grad_fc1_1_weight, # fc1_1_weight
+ grad_fc1_2_weight, # fc1_2_weight
+ grad_fc2_weight, # fc2_weight
+ )
diff --git a/lingbotvla/distributed/moe/moe_utils.py b/lingbotvla/distributed/moe/moe_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b420d3c420dcb64dc1a89542f63e7968172eac9a
--- /dev/null
+++ b/lingbotvla/distributed/moe/moe_utils.py
@@ -0,0 +1,99 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+
+def permute(tokens: torch.Tensor, routing_map: torch.Tensor):
+ """
+ Permutes the tokens according to the routing map.
+
+ Args:
+ tokens (torch.Tensor): The input token tensor, [num_tokens, hidden_dim].
+ routing_map (torch.Tensor): The sparse token to expert mapping, [num_experts, tokens].
+
+ """
+ num_tokens, _ = tokens.shape
+ num_experts = routing_map.shape[0]
+
+ # mask [num_tokens, num_experts] -> [num_experts, num_tokens]
+ routing_map = routing_map.bool()
+
+ # Create a dense expert-to-token mapping from the sparse token-to-expert mapping
+ token_indices = torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
+ sorted_indices = token_indices.masked_select(routing_map)
+
+ # use the mapping to permute the tokens
+ permuted_input = tokens.index_select(0, sorted_indices)
+
+ return permuted_input, sorted_indices
+
+
+def unpermute(
+ tokens: torch.Tensor,
+ routing_weights: torch.Tensor,
+ hidden_states_shape: torch.Size,
+ permutation_mapping: torch.Tensor,
+ routing_map: torch.Tensor,
+):
+ """
+ Unpermutes the tokens and apply the weight.
+
+ Args:
+ tokens (torch.Tensor): The input token tensor, [num_tokens, hidden_dim].
+ routing_weights (torch.Tensor): The routing weights, [num_tokens, num_experts].
+ hidden_states_shape (torch.Size): The shape of the hidden states, [num_tokens, hidden_dim].
+ routing_map (torch.Tensor): The sparse token to expert mapping, [num_experts, tokens].
+
+ Returns:
+ torch.Tensor: The unpermuted token tensor, [num_tokens, hidden_dim].
+ """
+ tokens_weight = routing_weights.T.contiguous().masked_select(routing_map.bool())
+
+ tokens = tokens * tokens_weight.unsqueeze(-1)
+ hidden_dim = hidden_states_shape[-1]
+
+ unpermuted_tokens = torch.zeros(hidden_states_shape, device=tokens.device, dtype=tokens.dtype)
+
+ # Scatter add the permuted_input back to the original positions
+ unpermuted_tokens.scatter_add_(0, permutation_mapping.unsqueeze(1).expand(-1, hidden_dim), tokens)
+ return unpermuted_tokens
+
+
+def generate_weights_idx(routing_weights: torch.Tensor, selected_experts: torch.Tensor, num_experts) -> torch.Tensor:
+ """
+ Generate the weight index for the unpermute operation.
+
+ Args:
+ routing_weights (torch.Tensor): The routing weights. shape [num_tokens, topk].
+ selected_experts (torch.Tensor): The selected experts. shape [num_tokens, topk].
+ num_experts (int): The number of experts. shape [num_tokens, num_experts].
+
+ Returns:
+ torch.Tensor: The weight index.
+ """
+ num_tokens, topk = routing_weights.shape
+ weights_idx = torch.zeros((num_tokens, num_experts), dtype=routing_weights.dtype, device=routing_weights.device)
+
+ weights_idx.scatter_add_(1, selected_experts, routing_weights)
+
+ return weights_idx
+
+
+def sort_chunks_by_idxs(input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor):
+ """Split and sort the input tensor based on the split_sizes and sorted indices."""
+ input = torch.split(input, split_sizes.tolist(), dim=0)
+ output = torch.cat([input[i] for i in sorted_idxs], dim=0)
+ return output
diff --git a/lingbotvla/distributed/offloading.py b/lingbotvla/distributed/offloading.py
new file mode 100644
index 0000000000000000000000000000000000000000..d71b0b28541855ca7ea52b02f1c838676fdd07be
--- /dev/null
+++ b/lingbotvla/distributed/offloading.py
@@ -0,0 +1,87 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import enum
+from contextlib import nullcontext
+from typing import Tuple, Union
+
+import torch
+from torch.autograd.graph import saved_tensors_hooks
+
+
+class OffloadPolicy(enum.Enum):
+ OFFLOAD = 0
+ KEEP_ON_GPU = 1
+ IGNORE = 2
+
+
+class custom_save_on_cpu(saved_tensors_hooks):
+ def __init__(self, gpu_limit_in_gb: float = 0, pin_memory: bool = False, min_offload_size: int = 1024) -> None:
+ self.cur_gpu_ram_in_mb = 0.0
+
+ def pack_to_cpu(tensor: torch.Tensor) -> Tuple[OffloadPolicy, torch.device, torch.Tensor]:
+ tensor_num_bytes = tensor.element_size() * tensor.nelement()
+ # heuristic to skip nn.Linear.weight
+ if type(tensor.grad_fn).__name__ == "TBackward0" or tensor_num_bytes <= min_offload_size:
+ return (OffloadPolicy.IGNORE, tensor.device, tensor)
+
+ if self.cur_gpu_ram_in_mb < gpu_limit_in_gb * 1024:
+ self.cur_gpu_ram_in_mb += tensor_num_bytes / 1024 / 1024
+ return (OffloadPolicy.KEEP_ON_GPU, tensor.device, tensor)
+
+ if not pin_memory:
+ return (OffloadPolicy.OFFLOAD, tensor.device, tensor.cpu())
+
+ packed = torch.empty(
+ tensor.size(),
+ dtype=tensor.dtype,
+ layout=tensor.layout,
+ pin_memory=(not tensor.is_sparse),
+ )
+ packed.copy_(tensor)
+ return (OffloadPolicy.OFFLOAD, tensor.device, packed)
+
+ def unpack_from_cpu(packed: Tuple[OffloadPolicy, torch.device, torch.Tensor]) -> torch.Tensor:
+ offload_policy, device, tensor = packed
+
+ if offload_policy == OffloadPolicy.IGNORE:
+ return tensor
+ elif offload_policy == OffloadPolicy.KEEP_ON_GPU:
+ tensor_num_bytes = tensor.element_size() * tensor.nelement()
+ self.cur_gpu_ram_in_mb -= tensor_num_bytes / 1024 / 1024
+ return tensor
+ else:
+ return tensor.to(device, non_blocking=pin_memory)
+
+ super().__init__(pack_to_cpu, unpack_from_cpu)
+
+
+def build_activation_offloading_context(
+ enable_activation_offload: bool = False,
+ enable_gradient_checkpointing: bool = False,
+ activation_gpu_limit: float = 0.0,
+) -> Tuple[Union["saved_tensors_hooks", "nullcontext"], Union["saved_tensors_hooks", "nullcontext"]]:
+ model_fwd_context, model_bwd_context = nullcontext(), nullcontext()
+ if enable_activation_offload:
+ # pin_memory=False since CachingHostAllocator caches pinned memory aggressively.
+ # torch._C._host_emptyCache() can be used after version 2.5.
+ if enable_gradient_checkpointing:
+ # inter-layer activations are always offloaded when enabling gradient checkpointing to avoid potential thrashing
+ model_fwd_context = custom_save_on_cpu(gpu_limit_in_gb=0.0, pin_memory=False)
+ model_bwd_context = custom_save_on_cpu(gpu_limit_in_gb=activation_gpu_limit, pin_memory=False)
+ else:
+ model_fwd_context = custom_save_on_cpu(gpu_limit_in_gb=activation_gpu_limit, pin_memory=False)
+
+ return model_fwd_context, model_bwd_context
diff --git a/lingbotvla/distributed/parallel_plan.py b/lingbotvla/distributed/parallel_plan.py
new file mode 100644
index 0000000000000000000000000000000000000000..23b059a39ba6411e35fb0bbc886dc3a2b42a6666
--- /dev/null
+++ b/lingbotvla/distributed/parallel_plan.py
@@ -0,0 +1,101 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Dict, Union
+
+import torch
+import torch.nn as nn
+from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
+
+from ..utils import logging
+from .utils import check_fqn_match, get_module_from_path, set_module_from_path
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class SpecInfo:
+ ep_fsdp_mesh: DeviceMesh
+ placement: Union[Shard, Replicate]
+ fqn: str
+
+ @property
+ def ep_mesh(self):
+ if self.ep_fsdp_mesh is not None:
+ return self.ep_fsdp_mesh["ep"]
+ else:
+ return None
+
+
+class ParallelPlan:
+ def __init__(self, ep_plan: Dict[str, Shard]):
+ self.ep_plan = ep_plan
+ self.ep_param_suffix = {k.split(".")[-1] for k in ep_plan.keys()}
+ self.fsdp_no_shard_module = {".".join(list(ep_plan.keys())[0].split(".")[:-1])}
+
+ def apply(self, model: nn.Module, ep_fsdp_mesh: DeviceMesh):
+ """
+ ep_fsdp_mesh: [replicate, replicate, ... , shard]
+ """
+ ep_mesh = ep_fsdp_mesh["ep"]
+ # ep_plan
+ fqn2spec_info = {}
+ if self.ep_plan:
+ ep_size = ep_mesh.size(-1)
+ ep_replicate = [Replicate() for _ in range(ep_mesh.ndim)]
+ for fqn, param in model.named_parameters():
+ for fqn_pattern, shard in self.ep_plan.items():
+ if check_fqn_match(fqn_pattern, fqn):
+ assert param.size(shard.dim) % ep_size == 0
+ ep_placement = ep_replicate[:-1] + [shard]
+ dtensor = DTensor.from_local(
+ local_tensor=param.data, device_mesh=ep_mesh, placements=ep_replicate
+ )
+ dtensor = dtensor.redistribute(device_mesh=ep_mesh, placements=ep_placement)
+ local_chunk = torch.nn.Parameter(dtensor.to_local(), requires_grad=param.requires_grad)
+ local_chunk.spec_info = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=shard, fqn=fqn)
+ set_module_from_path(model, fqn, local_chunk)
+ fqn2spec_info[fqn] = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=shard, fqn=fqn)
+ break
+ if fqn not in fqn2spec_info: # not sharded
+ param.spec_info = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=Replicate(), fqn=fqn)
+ fqn2spec_info[fqn] = SpecInfo(ep_fsdp_mesh=ep_fsdp_mesh, placement=Replicate(), fqn=fqn)
+ for param in model.parameters():
+ assert hasattr(param, "spec_info"), f"Internal Error: {param} is omitted"
+
+ return fqn2spec_info
+
+ def get_fsdp_no_shard_info(self, model: nn.Module):
+ if self.fsdp_no_shard_module is None:
+ return None
+
+ fsdp_no_shard_states_fqn_to_module = {}
+ for fqn, param in model.named_modules():
+ for no_shard_pattern in self.fsdp_no_shard_module:
+ if check_fqn_match(no_shard_pattern, fqn):
+ fsdp_no_shard_states_fqn_to_module[fqn] = get_module_from_path(model, fqn)
+ assert len(fsdp_no_shard_states_fqn_to_module) > 0, "no module in model match `fsdp_no_shard_module`"
+
+ return fsdp_no_shard_states_fqn_to_module
+
+ def update_prefix(self, prefix: str):
+ """
+ Update ep_plan when model is wrappered.
+ """
+ self.ep_plan = {prefix + "." + k: v for k, v in self.ep_plan.items()}
+ self.ep_param_suffix = {k.split(".")[-1] for k in self.ep_plan.keys()}
+ self.fsdp_no_shard_module = {".".join(list(self.ep_plan.keys())[0].split(".")[:-1])}
diff --git a/lingbotvla/distributed/parallel_state.py b/lingbotvla/distributed/parallel_state.py
new file mode 100644
index 0000000000000000000000000000000000000000..e78a67b2bb5fb275578ebe9fddb59c4286f2e973
--- /dev/null
+++ b/lingbotvla/distributed/parallel_state.py
@@ -0,0 +1,559 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Adapted from https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py
+
+import math
+import os
+from dataclasses import dataclass
+from functools import wraps
+from typing import TYPE_CHECKING, Callable, Literal, Optional
+
+import torch
+from torch import distributed as dist
+
+from ..utils import logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_version_greater_than
+
+
+if is_torch_version_greater_than("2.4"):
+ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+
+
+if TYPE_CHECKING:
+ from torch.distributed import ProcessGroup
+ from torch.distributed.device_mesh import DeviceMesh
+
+
+logger = logging.get_logger(__name__)
+
+_PARALLEL_STATE: "ParallelState" = None
+
+
+def requires_mesh(fn: Callable) -> Callable:
+ @wraps(fn)
+ def _inner(self: "ParallelState", *args, **kwargs):
+ if self.device_mesh is None:
+ raise ValueError("Device mesh is not initialized.")
+
+ return fn(self, *args, **kwargs)
+
+ return _inner
+
+
+def init_ep_mesh_matrix(ep_size: int, ep_fsdp_size: int, ep_outside: bool = False) -> "DeviceMesh":
+ """
+ Initialize the device mesh matrix for the EP.
+ Args:
+ ep_size (int): The size of the EP.
+ ep_fsdp_size (int): The size of the EP-FSDP.
+ ep_outside (bool): Whether the EP is outside in ep-fsdp group.
+ """
+ if ep_outside:
+ with torch.device("cpu"):
+ mesh = torch.arange(math.prod((ep_size, ep_fsdp_size)), dtype=torch.int).view(ep_size, ep_fsdp_size)
+ else:
+ with torch.device("cpu"):
+ mesh = (
+ torch.arange(math.prod((ep_size, ep_fsdp_size)), dtype=torch.int)
+ .view(ep_fsdp_size, ep_size)
+ .transpose(0, 1)
+ )
+ return mesh
+
+
+@dataclass(frozen=True)
+class ParallelState:
+ dp_size: int = 1
+ dp_replicate_size: int = 1
+ dp_shard_size: int = 1
+ tp_size: int = 1
+ ep_size: int = 1
+ pp_size: int = 1
+ cp_size: int = 1
+ ulysses_size: int = 1
+ dp_mode: Literal["ddp", "fsdp1", "fsdp2"] = "fsdp1"
+ device_type: str = "npu" if is_torch_npu_available() else "cuda"
+ include_sp_in_fsdp: bool = True
+ device_mesh: Optional["DeviceMesh"] = None
+ ep_fsdp_device_mesh: Optional["DeviceMesh"] = None
+
+ def __post_init__(self):
+ if not self.include_sp_in_fsdp:
+ raise NotImplementedError("Decoupled sequence parallel has not been implemented.")
+
+ if self.cp_size > 1:
+ raise NotImplementedError("Ring attention is not supported yet.")
+
+ if self.pp_size * self.dp_size * self.cp_size * self.ulysses_size * self.tp_size != self.world_size:
+ raise ValueError("The product of parallel sizes should be equal to the world size.")
+
+ if self.dp_replicate_size * self.dp_shard_size != self.dp_size:
+ raise ValueError(
+ f"The product of dp_replicate_size: {self.dp_replicate_size} and dp_shard_size: {self.dp_shard_size} should be equal to dp_size: {self.dp_size}."
+ )
+
+ if self.sp_enabled:
+ from ..distributed.sequence_parallel import (
+ init_sequence_parallel,
+ set_context_parallel_group,
+ set_data_parallel_group,
+ set_ulysses_sequence_parallel_group,
+ set_unified_sequence_parallel_group,
+ )
+
+ if self.device_mesh is not None:
+ set_data_parallel_group(self.device_mesh.get_group("dp"))
+ if self.ulysses_size > 1:
+ set_ulysses_sequence_parallel_group(self.device_mesh.get_group("ulysses"))
+ if self.cp_size > 1:
+ set_context_parallel_group(self.device_mesh.get_group("cp"))
+ # set unified sequence parallel group
+ set_unified_sequence_parallel_group(self.device_mesh.get_group("sp"))
+ else:
+ init_sequence_parallel(
+ ulysses_size=self.ulysses_size,
+ sep_dp=True,
+ ulysses_group_key="default",
+ cp_size=self.cp_size,
+ )
+
+ @property
+ def is_initialized(self) -> bool:
+ return dist.is_initialized()
+
+ @property
+ def local_rank(self) -> int:
+ return int(os.getenv("LOCAL_RANK", "-1"))
+
+ @property
+ def global_rank(self) -> int:
+ if self.is_initialized:
+ return dist.get_rank()
+ return -1
+
+ @property
+ def world_size(self) -> int:
+ if self.is_initialized:
+ return dist.get_world_size()
+ return 1
+
+ # ------------------------------ DP ------------------------------ #
+ @property
+ def dp_group(self) -> Optional["ProcessGroup"]:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_group("dp")
+
+ if self.sp_enabled:
+ from ..distributed.sequence_parallel import get_data_parallel_group
+
+ return get_data_parallel_group()
+
+ return self.fsdp_group
+
+ @property
+ def dp_rank(self) -> int:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_local_rank("dp")
+
+ if self.sp_enabled:
+ from ..distributed.sequence_parallel import get_data_parallel_rank
+
+ return get_data_parallel_rank()
+
+ return self.fsdp_rank
+
+ @property
+ @requires_mesh
+ def dp_mesh(self) -> "DeviceMesh":
+ if self.device_mesh is not None:
+ return self.device_mesh["dp"]
+
+ raise self.fsdp_mesh
+
+ @property
+ def dp_enabled(self) -> bool:
+ return self.dp_size > 1
+
+ # ------------------------------ DP replicate ------------------------------ #
+ @property
+ def dp_replicate_group(self) -> Optional["ProcessGroup"]:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_group("dp_replicate")
+
+ @property
+ def dp_replicate_rank(self) -> int:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_local_rank("dp_replicate")
+
+ @property
+ @requires_mesh
+ def dp_replicate_mesh(self) -> "DeviceMesh":
+ if self.device_mesh is not None:
+ return self.device_mesh["dp_replicate"]
+
+ @property
+ def dp_replicate_enabled(self) -> bool:
+ return self.dp_replicate_size > 1
+
+ # ------------------------------ DP shard ------------------------------ #
+ @property
+ def dp_shard_group(self) -> Optional["ProcessGroup"]:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_group("dp_shard")
+
+ @property
+ def dp_shard_rank(self) -> int:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_local_rank("dp_shard")
+
+ @property
+ @requires_mesh
+ def dp_shard_mesh(self) -> "DeviceMesh":
+ if self.device_mesh is not None:
+ return self.device_mesh["dp_shard"]
+
+ @property
+ def dp_shard_enabled(self) -> bool:
+ return self.dp_shard_size >= 1
+
+ # ----------------------------- FSDP ----------------------------- #
+ @property
+ def fsdp_group(self) -> Optional["ProcessGroup"]:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_group("dp_sp")
+
+ @property
+ def fsdp_rank(self) -> int:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_local_rank("dp_sp")
+
+ return self.global_rank
+
+ @property
+ def dp_shard_sp_enabled(self) -> bool:
+ return self.dp_shard_enabled and self.sp_enabled
+
+ @property
+ @requires_mesh
+ def fsdp_mesh(self) -> "DeviceMesh":
+ if self.dp_replicate_enabled:
+ # HSDP
+ if self.dp_shard_sp_enabled:
+ return self.device_mesh["dp_replicate", "dp_shard_sp"]
+ elif self.dp_shard_enabled:
+ return self.device_mesh["dp_replicate", "dp_shard"]
+ else:
+ # DDP
+ return self.device_mesh["dp_replicate"]
+ # FSDP
+ elif self.dp_shard_sp_enabled:
+ return self.device_mesh["dp_shard_sp"]
+ elif self.dp_shard_enabled:
+ return self.device_mesh["dp_shard"]
+ else:
+ return self.device_mesh["dp"]
+
+ @property
+ def fsdp_enabled(self) -> bool:
+ return self.fsdp_size > 1
+
+ @property
+ def fsdp_size(self) -> int:
+ return self.world_size // (self.pp_size * self.tp_size)
+
+ # ------------------------------ TP ------------------------------ #
+ @property
+ @requires_mesh
+ def tp_rank(self) -> int:
+ return self.device_mesh.get_local_rank("tp")
+
+ @property
+ @requires_mesh
+ def tp_mesh(self) -> "DeviceMesh":
+ return self.device_mesh["tp"]
+
+ @property
+ def tp_enabled(self) -> bool:
+ return self.tp_size > 1
+
+ # ------------------------------ PP ------------------------------ #
+ @property
+ @requires_mesh
+ def pp_rank(self) -> int:
+ return self.device_mesh.get_local_rank("pp")
+
+ @property
+ @requires_mesh
+ def pp_mesh(self) -> "DeviceMesh":
+ return self.device_mesh["pp"]
+
+ @property
+ def pp_enabled(self) -> bool:
+ return self.pp_size > 1
+
+ @property
+ @requires_mesh
+ def is_first_pp_stage(self) -> bool:
+ return self.pp_rank == 0
+
+ @property
+ @requires_mesh
+ def is_last_pp_stage(self) -> bool:
+ return self.pp_rank == (self.pp_size - 1)
+
+ # ------------------------------ EP ------------------------------ #
+ @property
+ @requires_mesh
+ def ep_mesh(self) -> "DeviceMesh":
+ return self.ep_fsdp_device_mesh["ep"]
+
+ @property
+ @requires_mesh
+ def ep_fsdp_mesh(self) -> "DeviceMesh":
+ return self.ep_fsdp_device_mesh["ep", "ep_fsdp"]
+
+ @property
+ @requires_mesh
+ def ep_group(self) -> "ProcessGroup":
+ return self.ep_mesh.get_group()
+
+ @property
+ def ep_enabled(self) -> bool:
+ return self.ep_size > 1
+
+ @property
+ def ep_rank(self) -> int:
+ return self.ep_fsdp_device_mesh.get_local_rank("ep")
+
+ # ------------------------------ SP ------------------------------ #
+ @property
+ def sp_group(self) -> Optional["ProcessGroup"]:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_group("sp")
+
+ if self.sp_enabled:
+ from .sequence_parallel import get_unified_sequence_parallel_group
+
+ return get_unified_sequence_parallel_group()
+
+ return None
+
+ @property
+ def sp_rank(self) -> int:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_local_rank("sp")
+
+ if self.sp_enabled:
+ from .sequence_parallel import get_unified_sequence_parallel_rank
+
+ return get_unified_sequence_parallel_rank()
+
+ return -1
+
+ @property
+ def sp_enabled(self) -> bool:
+ return self.cp_size > 1 or self.ulysses_size > 1
+
+ @property
+ def sp_size(self) -> int:
+ return self.ulysses_size * self.cp_size
+
+ @property
+ def ulysses_group(self) -> Optional["ProcessGroup"]:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_group("ulysses")
+
+ if self.sp_enabled:
+ from .sequence_parallel import get_ulysses_sequence_parallel_group
+
+ return get_ulysses_sequence_parallel_group()
+
+ return None
+
+ @property
+ def ulysses_rank(self) -> int:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_local_rank("ulysses")
+
+ if self.sp_enabled:
+ from .sequence_parallel import get_ulysses_sequence_parallel_rank
+
+ return get_ulysses_sequence_parallel_rank()
+
+ return -1
+
+ @property
+ def ulysses_enabled(self) -> bool:
+ return self.ulysses_size > 1
+
+ @property
+ def cp_group(self) -> Optional["ProcessGroup"]:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_group("cp")
+
+ if self.sp_enabled:
+ from .sequence_parallel import get_context_parallel_group
+
+ return get_context_parallel_group()
+
+ return None
+
+ @property
+ def cp_rank(self) -> int:
+ if self.device_mesh is not None:
+ return self.device_mesh.get_local_rank("cp")
+
+ if self.sp_enabled:
+ from .sequence_parallel import get_context_parallel_rank
+
+ return get_context_parallel_rank()
+
+ return -1
+
+ @property
+ def cp_enabled(self) -> bool:
+ return self.cp_size > 1
+
+
+def init_parallel_state(
+ dp_size: int = 1,
+ dp_replicate_size: int = 1,
+ dp_shard_size: int = 1,
+ tp_size: int = 1,
+ ep_size: int = 1,
+ pp_size: int = 1,
+ cp_size: int = 1,
+ ulysses_size: int = 1,
+ dp_mode: Literal["ddp", "fsdp1", "fsdp2"] = "fsdp1",
+ device_type: str = None,
+ include_sp_in_fsdp: bool = True,
+ ep_outside: bool = False,
+) -> None:
+ """
+ Initializes global parallel state.
+ """
+ global _PARALLEL_STATE
+ if _PARALLEL_STATE is not None:
+ logger.warning("Parallel state has already been initialized.")
+ return
+
+ if device_type is None:
+ device_type = "npu" if is_torch_npu_available() else "cuda"
+
+ # Set dp_shard_size to dp_size if dp_shard_size and dp_replicate_size are not set when dp enabled
+ if dp_size > 1 and dp_shard_size == 1 and dp_replicate_size == 1:
+ dp_shard_size = dp_size
+
+ logger.info_rank0(
+ f"Initializing parallel state... dp_size {dp_size}, dp_replicate_size {dp_replicate_size}, dp_shard_size {dp_shard_size},tp_size {tp_size}, pp_size {pp_size}, cp_size {cp_size}, ulysses_size {ulysses_size}"
+ )
+
+ device_mesh, ep_fsdp_device_mesh = None, None
+ if is_torch_version_greater_than("2.4"):
+ mesh_shape = []
+ mesh_dim_names = []
+ for d, name in zip(
+ [pp_size, dp_replicate_size, dp_shard_size, ulysses_size, cp_size, tp_size],
+ ["pp", "dp_replicate", "dp_shard", "ulysses", "cp", "tp"],
+ ):
+ if d > 1 or name in ["dp_shard"]:
+ mesh_shape.append(d)
+ mesh_dim_names.append(name)
+
+ device_mesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=tuple(mesh_shape),
+ mesh_dim_names=tuple(mesh_dim_names),
+ )
+
+ # Mesh for data loading (no communication on this mesh)
+ dp_mesh_dim_names = []
+ # Mesh for param sharding
+ dp_shard_sp_mesh_dim_names = []
+ # Mesh for loss all-reduce
+ dp_sp_mesh_dim_names = []
+ # Mesh for sequence parallel
+ sp_mesh_dim_names = []
+
+ if dp_replicate_size > 1:
+ dp_mesh_dim_names.append("dp_replicate")
+ dp_sp_mesh_dim_names.append("dp_replicate")
+ if dp_shard_size >= 1:
+ dp_mesh_dim_names.append("dp_shard")
+ dp_shard_sp_mesh_dim_names.append("dp_shard")
+ dp_sp_mesh_dim_names.append("dp_shard")
+ if ulysses_size > 1:
+ dp_shard_sp_mesh_dim_names.append("ulysses")
+ sp_mesh_dim_names.append("ulysses")
+ dp_sp_mesh_dim_names.append("ulysses")
+ if cp_size > 1:
+ dp_shard_sp_mesh_dim_names.append("cp")
+ sp_mesh_dim_names.append("cp")
+ dp_sp_mesh_dim_names.append("cp")
+
+ if dp_mesh_dim_names != []:
+ device_mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
+
+ if dp_shard_sp_mesh_dim_names != []:
+ device_mesh[tuple(dp_shard_sp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_sp")
+
+ if dp_sp_mesh_dim_names != []:
+ device_mesh[tuple(dp_sp_mesh_dim_names)]._flatten(mesh_dim_name="dp_sp")
+
+ if sp_mesh_dim_names != []:
+ device_mesh[tuple(sp_mesh_dim_names)]._flatten(mesh_dim_name="sp")
+
+ if ep_size > 1:
+ world_size = dist.get_world_size()
+ assert world_size % ep_size == 0, "ep_size must be a factor of world_size"
+ ep_fsdp_size = world_size // ep_size
+
+ mesh = init_ep_mesh_matrix(ep_size=ep_size, ep_fsdp_size=ep_fsdp_size, ep_outside=ep_outside)
+ ep_fsdp_device_mesh = DeviceMesh(
+ device_type=device_type,
+ mesh=mesh,
+ mesh_dim_names=("ep", "ep_fsdp"),
+ )
+
+ logger.info_rank0(f"Device mesh: {device_mesh}")
+ logger.info_rank0(f"EP FSDP device mesh: {ep_fsdp_device_mesh}")
+
+ _PARALLEL_STATE = ParallelState(
+ dp_size=dp_size,
+ dp_replicate_size=dp_replicate_size,
+ dp_shard_size=dp_shard_size,
+ tp_size=tp_size,
+ ep_size=ep_size,
+ pp_size=pp_size,
+ cp_size=cp_size,
+ ulysses_size=ulysses_size,
+ dp_mode=dp_mode,
+ device_type=device_type,
+ include_sp_in_fsdp=include_sp_in_fsdp,
+ device_mesh=device_mesh,
+ ep_fsdp_device_mesh=ep_fsdp_device_mesh,
+ )
+
+
+def get_parallel_state() -> "ParallelState":
+ """
+ Returns global parallel state.
+ """
+ if _PARALLEL_STATE is None:
+ logger.warning_once("Parallel state has not been initialized. returning default Single-process state.")
+ return ParallelState()
+
+ return _PARALLEL_STATE
diff --git a/lingbotvla/distributed/sequence_parallel/__init__.py b/lingbotvla/distributed/sequence_parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..31af3072517e12880810755b8220d72ed1f3e866
--- /dev/null
+++ b/lingbotvla/distributed/sequence_parallel/__init__.py
@@ -0,0 +1,89 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .async_ulysses import (
+ async_ulysses_output_projection,
+ async_ulysses_qkv_projection,
+ divide_qkv_linear_bias,
+ divide_qkv_linear_weight,
+)
+from .comm import (
+ get_context_parallel_group,
+ get_context_parallel_rank,
+ get_context_parallel_world_size,
+ get_data_parallel_group,
+ get_data_parallel_rank,
+ get_ulysses_sequence_parallel_group,
+ get_ulysses_sequence_parallel_rank,
+ get_ulysses_sequence_parallel_world_size,
+ get_unified_sequence_parallel_group,
+ get_unified_sequence_parallel_rank,
+ get_unified_sequence_parallel_world_size,
+ init_sequence_parallel,
+ set_context_parallel_group,
+ set_data_parallel_group,
+ set_ulysses_sequence_parallel_group,
+ set_unified_sequence_parallel_group,
+)
+from .data import (
+ gather_outputs,
+ sequence_parallel_preprocess,
+ slice_input_tensor,
+ slice_input_tensor_scale_grad,
+ slice_position_embedding,
+)
+from .loss import reduce_sequence_parallel_loss
+from .ulysses import (
+ all_to_all_images,
+ gather_heads_scatter_seq,
+ gather_seq_scatter_heads,
+)
+from .utils import pad_tensor, unpad_tensor, vlm_images_a2a_meta
+
+
+__all__ = [
+ "init_sequence_parallel",
+ "set_data_parallel_group",
+ "get_data_parallel_group",
+ "get_data_parallel_rank",
+ "set_ulysses_sequence_parallel_group",
+ "get_ulysses_sequence_parallel_world_size",
+ "get_ulysses_sequence_parallel_rank",
+ "get_ulysses_sequence_parallel_group",
+ "set_context_parallel_group",
+ "get_context_parallel_group",
+ "get_context_parallel_rank",
+ "get_context_parallel_world_size",
+ "set_unified_sequence_parallel_group",
+ "get_unified_sequence_parallel_group",
+ "get_unified_sequence_parallel_rank",
+ "get_unified_sequence_parallel_world_size",
+ "slice_input_tensor",
+ "slice_input_tensor_scale_grad",
+ "slice_position_embedding",
+ "sequence_parallel_preprocess",
+ "gather_heads_scatter_seq",
+ "gather_seq_scatter_heads",
+ "all_to_all_images",
+ "gather_outputs",
+ "vlm_images_a2a_meta",
+ "pad_tensor",
+ "unpad_tensor",
+ "reduce_sequence_parallel_loss",
+ "async_ulysses_qkv_projection",
+ "async_ulysses_output_projection",
+ "divide_qkv_linear_weight",
+ "divide_qkv_linear_bias",
+]
diff --git a/lingbotvla/distributed/sequence_parallel/async_ulysses.py b/lingbotvla/distributed/sequence_parallel/async_ulysses.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a48806633df9ac4bf8ad082aa6b5b84bd9dac68
--- /dev/null
+++ b/lingbotvla/distributed/sequence_parallel/async_ulysses.py
@@ -0,0 +1,491 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import numbers
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.distributed import ProcessGroup
+
+from .comm import get_ulysses_sequence_parallel_group
+from .ulysses import all_to_all_tensor
+from .utils import padding_tensor_for_seqeunce_parallel, unpadding_tensor_for_seqeunce_parallel
+
+
+fused_layer_norm_cuda = None
+
+
+def divide_qkv_linear_weight(weight: Tensor, dim: int):
+ return weight.chunk(3, dim=dim)
+
+
+def divide_qkv_linear_bias(bias: Tensor, dim: int):
+ if bias is not None:
+ return bias.chunk(3, dim=dim)
+ else:
+ return None, None, None
+
+
+class AsyncUlyssesQKVProjection(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ hidden_states: Tensor,
+ seq_dimension: int,
+ head_dimension: int,
+ q_weight: Tensor,
+ q_bias: Tensor,
+ k_weight: Tensor,
+ k_bias: Tensor,
+ v_weight: Tensor,
+ v_bias: Tensor,
+ norm_type: str,
+ norm_q_weight: Tensor,
+ norm_q_bias: Tensor,
+ norm_k_weight: Tensor,
+ norm_k_bias: Tensor,
+ normalized_shape: int,
+ eps: float,
+ unpadded_dim_size: int,
+ head_dim: int,
+ group: ProcessGroup,
+ ):
+ sp_group = get_ulysses_sequence_parallel_group() if group is None else group
+
+ # q projection
+ q = F.linear(hidden_states, q_weight, q_bias)
+
+ # q communication launch
+ q_res = all_to_all_tensor(
+ q, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
+ )
+
+ # k projection
+ k = F.linear(hidden_states, k_weight, k_bias)
+
+ # k communication launch
+ k_res = all_to_all_tensor(
+ k, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
+ )
+
+ # v projection
+ v = F.linear(hidden_states, v_weight, v_bias)
+
+ # v communication launch
+ v_res = all_to_all_tensor(
+ v, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
+ )
+
+ # q communication collect
+ q = q_res()
+ q = unpadding_tensor_for_seqeunce_parallel(q, seq_dimension, unpadded_dim_size)
+ q = q.reshape(list(q.shape[:-1]) + [-1, head_dim]).contiguous()
+
+ # k communication collect
+ k = k_res()
+ k = unpadding_tensor_for_seqeunce_parallel(k, seq_dimension, unpadded_dim_size)
+ k = k.reshape(list(k.shape[:-1]) + [-1, head_dim]).contiguous()
+
+ # qk normalization (if needed)
+ if norm_type is not None:
+ if isinstance(normalized_shape, numbers.Integral):
+ normalized_shape = (normalized_shape,)
+ normalized_shape = torch.Size(normalized_shape)
+ global fused_layer_norm_cuda
+ if fused_layer_norm_cuda is None:
+ fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+ norm_q_weight = norm_q_weight.contiguous()
+ norm_k_weight = norm_k_weight.contiguous()
+ output_q, mean_q, invvar_q = None, None, None
+ output_k, mean_k, invvar_k = None, None, None
+ if norm_type == "rmsnorm":
+ output_q, invvar_q = fused_layer_norm_cuda.rms_forward_affine(q, normalized_shape, norm_q_weight, eps)
+ output_k, invvar_k = fused_layer_norm_cuda.rms_forward_affine(k, normalized_shape, norm_k_weight, eps)
+ elif norm_type == "layernorm":
+ output_q, mean_q, invvar_q = fused_layer_norm_cuda.forward_affine(
+ q, normalized_shape, norm_q_weight, norm_q_bias, eps
+ )
+ output_k, mean_k, invvar_k = fused_layer_norm_cuda.forward_affine(
+ k, normalized_shape, norm_k_weight, norm_k_bias, eps
+ )
+ else:
+ raise NotImplementedError(f"{norm_type} is not supported in async-ulysses now!")
+ else:
+ output_q = q
+ output_k = k
+ mean_q = None
+ mean_k = None
+ invvar_q = None
+ invvar_k = None
+
+ # v communication collect
+ v = v_res()
+ v = unpadding_tensor_for_seqeunce_parallel(v, seq_dimension, unpadded_dim_size)
+ v = v.reshape(list(v.shape[:-1]) + [-1, head_dim]).contiguous()
+
+ # save ctx for backward
+ ctx.sp_group = sp_group
+ ctx.head_dimension = head_dimension
+ ctx.seq_dimension = seq_dimension
+ ctx.norm_type = norm_type
+ ctx.normalized_shape = normalized_shape
+ ctx.eps = eps
+ ctx.save_for_backward(
+ hidden_states,
+ q_weight,
+ q_bias,
+ k_weight,
+ k_bias,
+ v_weight,
+ v_bias,
+ q,
+ norm_q_weight,
+ norm_q_bias,
+ mean_q,
+ invvar_q,
+ k,
+ norm_k_weight,
+ norm_k_bias,
+ mean_k,
+ invvar_k,
+ )
+
+ return output_q, output_k, v
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: Tensor):
+ # get ctx for backward
+ sp_group = ctx.sp_group
+ seq_dimension = ctx.seq_dimension
+ head_dimension = ctx.head_dimension
+ norm_type = ctx.norm_type
+ normalized_shape = ctx.normalized_shape
+ eps = ctx.eps
+ (
+ hidden_states,
+ q_weight,
+ q_bias,
+ k_weight,
+ k_bias,
+ v_weight,
+ v_bias,
+ q,
+ norm_q_weight,
+ norm_q_bias,
+ mean_q,
+ invvar_q,
+ k,
+ norm_k_weight,
+ norm_k_bias,
+ mean_k,
+ invvar_k,
+ ) = ctx.saved_tensors
+
+ # initialize grads
+ grad_hidden_states = None
+ grad_q_weight = None
+ grad_q_bias = None
+ grad_k_weight = None
+ grad_k_bias = None
+ grad_v_weight = None
+ grad_v_bias = None
+ grad_norm_q_weight = None
+ grad_norm_q_bias = None
+ grad_norm_k_weight = None
+ grad_norm_k_bias = None
+
+ # v grad communication launch
+ grad_v = grad_output[2].contiguous()
+ grad_v = grad_v.reshape(list(grad_v.shape[:-2]) + [-1]).contiguous()
+ grad_v = padding_tensor_for_seqeunce_parallel(grad_v, dim=seq_dimension)
+ grad_v_res = all_to_all_tensor(
+ grad_v,
+ scatter_dim=seq_dimension,
+ gather_dim=head_dimension,
+ group=sp_group,
+ async_op=True,
+ )
+
+ # qk normalization backward (if needed)
+ if norm_type is not None:
+ if norm_type == "rmsnorm":
+ grad_k, grad_norm_k_weight = fused_layer_norm_cuda.rms_backward_affine(
+ grad_output[1].contiguous(),
+ invvar_k,
+ k,
+ normalized_shape,
+ norm_k_weight,
+ eps,
+ False,
+ )
+ grad_q, grad_norm_q_weight = fused_layer_norm_cuda.rms_backward_affine(
+ grad_output[0].contiguous(),
+ invvar_q,
+ q,
+ normalized_shape,
+ norm_q_weight,
+ eps,
+ False,
+ )
+ elif norm_type == "layernorm":
+ grad_k, grad_norm_k_weight, grad_norm_k_bias = fused_layer_norm_cuda.backward_affine(
+ grad_output[1].contiguous(),
+ mean_k,
+ invvar_k,
+ k,
+ normalized_shape,
+ norm_k_weight,
+ norm_k_bias,
+ eps,
+ False,
+ )
+ grad_q, grad_norm_q_weight, grad_norm_q_bias = fused_layer_norm_cuda.backward_affine(
+ grad_output[0].contiguous(),
+ mean_q,
+ invvar_q,
+ q,
+ normalized_shape,
+ norm_q_weight,
+ norm_q_bias,
+ eps,
+ False,
+ )
+ else:
+ raise NotImplementedError(f"{norm_type} is not supported in async-ulysses now!")
+ else:
+ grad_k = grad_output[1].contiguous()
+ grad_q = grad_output[0].contiguous()
+ grad_norm_k_weight = None
+ grad_norm_q_weight = None
+
+ # v grad communication collect
+ grad_v = grad_v_res()
+
+ # k grad communication launch
+ grad_k = grad_k.reshape(list(grad_k.shape[:-2]) + [-1]).contiguous()
+ grad_k = padding_tensor_for_seqeunce_parallel(grad_k, dim=seq_dimension)
+ grad_k_res = all_to_all_tensor(
+ grad_k,
+ scatter_dim=seq_dimension,
+ gather_dim=head_dimension,
+ group=sp_group,
+ async_op=True,
+ )
+
+ # v projection grad
+ grad_v_input = grad_v @ v_weight
+ grad_v_weight = grad_v.transpose(-1, -2) @ hidden_states
+ if v_bias is not None and ctx.needs_input_grad[7]:
+ grad_v_bias = grad_v.sum(0)
+
+ # k grad communication collect
+ grad_k = grad_k_res()
+
+ # q grad communication launch
+ grad_q = grad_q.reshape(list(grad_q.shape[:-2]) + [-1]).contiguous()
+ grad_q = padding_tensor_for_seqeunce_parallel(grad_q, dim=seq_dimension)
+ grad_q_res = all_to_all_tensor(
+ grad_q,
+ scatter_dim=seq_dimension,
+ gather_dim=head_dimension,
+ group=sp_group,
+ async_op=True,
+ )
+
+ # k projection grad
+ grad_k_input = grad_k @ k_weight
+ grad_k_weight = grad_k.transpose(-1, -2) @ hidden_states
+ if k_bias is not None and ctx.needs_input_grad[5]:
+ grad_k_bias = grad_k.sum(0)
+
+ # q grad communication collect
+ grad_q = grad_q_res()
+
+ # q projection grad
+ grad_q_input = grad_q @ q_weight
+ grad_q_weight = grad_q.transpose(-1, -2) @ hidden_states
+ if q_bias is not None and ctx.needs_input_grad[3]:
+ grad_q_bias = grad_q.sum(0)
+
+ # grad
+ grad_hidden_states = grad_q_input + grad_k_input + grad_v_input
+
+ return (
+ grad_hidden_states,
+ None,
+ None,
+ grad_q_weight,
+ grad_q_bias,
+ grad_k_weight,
+ grad_k_bias,
+ grad_v_weight,
+ grad_v_bias,
+ None,
+ grad_norm_q_weight,
+ grad_norm_q_bias,
+ grad_norm_k_weight,
+ grad_norm_k_bias,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class AsyncUlyssesOutputProjection(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ hidden_states: Tensor,
+ seq_dimension: int,
+ head_dimension: int,
+ proj_weight: Tensor,
+ proj_bias: Tensor,
+ unpadded_dim_size: int,
+ group: ProcessGroup,
+ ):
+ sp_group = get_ulysses_sequence_parallel_group() if group is None else group
+
+ # out projection
+ hidden_states = padding_tensor_for_seqeunce_parallel(hidden_states, seq_dimension)
+ hidden_states = all_to_all_tensor(
+ hidden_states, scatter_dim=seq_dimension, gather_dim=head_dimension, group=sp_group
+ )
+ o = F.linear(hidden_states, proj_weight, proj_bias)
+
+ # save ctx for backward
+ ctx.sp_group = sp_group
+ ctx.head_dimension = head_dimension
+ ctx.seq_dimension = seq_dimension
+ ctx.unpadded_dim_size = unpadded_dim_size
+
+ ctx.save_for_backward(
+ hidden_states,
+ proj_weight,
+ proj_bias,
+ )
+
+ return o
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: Tensor):
+ # get ctx for backward
+ sp_group = ctx.sp_group
+ head_dimension = ctx.head_dimension
+ seq_dimension = ctx.seq_dimension
+ unpadded_dim_size = ctx.unpadded_dim_size
+ (
+ hidden_states,
+ proj_weight,
+ proj_bias,
+ ) = ctx.saved_tensors
+
+ # initialize grads
+ grad_o = None
+ grad_proj_weight = None
+ grad_proj_bias = None
+
+ # output grad
+ grad_o = grad_output[0] @ (proj_weight)
+
+ # output grad communication launch
+ grad_out_res = all_to_all_tensor(
+ grad_o, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
+ )
+
+ grad_proj_weight = grad_output[0].transpose(-1, -2) @ (hidden_states)
+ if proj_bias is not None and ctx.needs_input_grad[3]:
+ grad_proj_bias = grad_output[0].sum(0)
+
+ # output grad communication collect
+ grad_o = grad_out_res()
+ grad_o = unpadding_tensor_for_seqeunce_parallel(grad_o, seq_dimension, unpadded_dim_size)
+
+ return (
+ grad_o,
+ None,
+ None,
+ grad_proj_weight,
+ grad_proj_bias,
+ None,
+ None,
+ )
+
+
+def async_ulysses_qkv_projection(
+ hidden_states: Tensor = None,
+ seq_dimension: int = None,
+ head_dimension: int = None,
+ q_weight: Tensor = None,
+ q_bias: Optional[Tensor] = None,
+ k_weight: Tensor = None,
+ k_bias: Optional[Tensor] = None,
+ v_weight: Tensor = None,
+ v_bias: Optional[Tensor] = None,
+ norm_type: str = None,
+ norm_q_weight: Optional[Tensor] = None,
+ norm_q_bias: Optional[Tensor] = None,
+ norm_k_weight: Optional[Tensor] = None,
+ norm_k_bias: Optional[Tensor] = None,
+ normalized_shape: Optional[int] = None,
+ eps: Optional[float] = None,
+ unpadded_dim_size: int = None,
+ head_dim: int = None,
+ group: Optional[ProcessGroup] = None,
+):
+ return AsyncUlyssesQKVProjection.apply(
+ hidden_states,
+ seq_dimension,
+ head_dimension,
+ q_weight,
+ q_bias,
+ k_weight,
+ k_bias,
+ v_weight,
+ v_bias,
+ norm_type,
+ norm_q_weight,
+ norm_q_bias,
+ norm_k_weight,
+ norm_k_bias,
+ normalized_shape,
+ eps,
+ unpadded_dim_size,
+ head_dim,
+ group,
+ )
+
+
+def async_ulysses_output_projection(
+ hidden_states: Optional[Tensor] = None,
+ seq_dimension: int = None,
+ head_dimension: int = None,
+ proj_weight: Optional[Tensor] = None,
+ proj_bias: Optional[Tensor] = None,
+ unpadded_dim_size: Optional[int] = None,
+ group: Optional[ProcessGroup] = None,
+):
+ return AsyncUlyssesOutputProjection.apply(
+ hidden_states,
+ seq_dimension,
+ head_dimension,
+ proj_weight,
+ proj_bias,
+ unpadded_dim_size,
+ group,
+ )
diff --git a/lingbotvla/distributed/sequence_parallel/async_ulysses_dit.py b/lingbotvla/distributed/sequence_parallel/async_ulysses_dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf0a6c77d1aab163fff878c8327984b579cf238c
--- /dev/null
+++ b/lingbotvla/distributed/sequence_parallel/async_ulysses_dit.py
@@ -0,0 +1,509 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import numbers
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.distributed import ProcessGroup
+
+from .comm import get_ulysses_sequence_parallel_group
+from .ulysses import all_to_all_tensor
+from .utils import padding_tensor_for_seqeunce_parallel, unpadding_tensor_for_seqeunce_parallel
+
+
+fused_layer_norm_cuda = None
+
+
+def divide_qkv_linear_weight(weight: Tensor, dim: int):
+ return weight.chunk(3, dim=dim)
+
+
+def divide_qkv_linear_bias(bias: Tensor, dim: int):
+ if bias is not None:
+ return bias.chunk(3, dim=dim)
+ else:
+ return None, None, None
+
+
+class AsyncUlyssesQKVProjection(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ hidden_states: Tensor,
+ seq_dimension: int,
+ head_dimension: int,
+ q_weight: Tensor,
+ q_bias: Tensor,
+ k_weight: Tensor,
+ k_bias: Tensor,
+ v_weight: Tensor,
+ v_bias: Tensor,
+ norm_type: str,
+ norm_q_weight: Tensor,
+ norm_q_bias: Tensor,
+ norm_k_weight: Tensor,
+ norm_k_bias: Tensor,
+ normalized_shape: int,
+ eps: float,
+ unpadded_dim_size: int,
+ head_dim: int,
+ group: ProcessGroup,
+ ):
+ sp_group = get_ulysses_sequence_parallel_group() if group is None else group
+
+ # q projection
+ q = F.linear(hidden_states, q_weight, q_bias)
+
+ # qk normalization (if needed)
+ if norm_type is not None:
+ if isinstance(normalized_shape, numbers.Integral):
+ normalized_shape = (normalized_shape,)
+ normalized_shape = torch.Size(normalized_shape)
+ global fused_layer_norm_cuda
+ if fused_layer_norm_cuda is None:
+ fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+ norm_q_weight = norm_q_weight.contiguous()
+ output_q, mean_q, invvar_q = None, None, None
+ output_k, mean_k, invvar_k = None, None, None
+ if norm_type == "rmsnorm":
+ output_q, invvar_q = fused_layer_norm_cuda.rms_forward_affine(q, normalized_shape, norm_q_weight, eps)
+ elif norm_type == "layernorm":
+ output_q, mean_q, invvar_q = fused_layer_norm_cuda.forward_affine(
+ q, normalized_shape, norm_q_weight, norm_q_bias, eps
+ )
+ else:
+ raise NotImplementedError(f"{norm_type} is not supported in async-ulysses now!")
+ else:
+ output_q = q
+ mean_q = None
+ invvar_q = None
+
+ # q communication launch
+ output_q_res = all_to_all_tensor(
+ output_q, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
+ )
+
+ # k projection
+ k = F.linear(hidden_states, k_weight, k_bias)
+
+ # qk normalization (if needed)
+ if norm_type is not None:
+ if isinstance(normalized_shape, numbers.Integral):
+ normalized_shape = (normalized_shape,)
+ normalized_shape = torch.Size(normalized_shape)
+ if fused_layer_norm_cuda is None:
+ fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
+ norm_k_weight = norm_k_weight.contiguous()
+ output_k, mean_k, invvar_k = None, None, None
+ if norm_type == "rmsnorm":
+ output_k, invvar_k = fused_layer_norm_cuda.rms_forward_affine(k, normalized_shape, norm_k_weight, eps)
+ elif norm_type == "layernorm":
+ output_k, mean_k, invvar_k = fused_layer_norm_cuda.forward_affine(
+ k, normalized_shape, norm_k_weight, norm_k_bias, eps
+ )
+ else:
+ raise NotImplementedError(f"{norm_type} is not supported in async-ulysses now!")
+ else:
+ output_k = k
+ mean_k = None
+ invvar_k = None
+
+ # k communication launch
+ output_k_res = all_to_all_tensor(
+ output_k, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
+ )
+
+ # v projection
+ v = F.linear(hidden_states, v_weight, v_bias)
+
+ # v communication launch
+ v_res = all_to_all_tensor(
+ v, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
+ )
+
+ # q communication collect
+ output_q = output_q_res()
+ output_q = unpadding_tensor_for_seqeunce_parallel(output_q, seq_dimension, unpadded_dim_size)
+
+ # k communication collect
+ output_k = output_k_res()
+ output_k = unpadding_tensor_for_seqeunce_parallel(output_k, seq_dimension, unpadded_dim_size)
+
+ # v communication collect
+ v = v_res()
+ v = unpadding_tensor_for_seqeunce_parallel(v, seq_dimension, unpadded_dim_size)
+
+ # save ctx for backward
+ ctx.sp_group = sp_group
+ ctx.head_dimension = head_dimension
+ ctx.seq_dimension = seq_dimension
+ ctx.norm_type = norm_type
+ ctx.normalized_shape = normalized_shape
+ ctx.eps = eps
+ ctx.save_for_backward(
+ hidden_states,
+ q_weight,
+ q_bias,
+ k_weight,
+ k_bias,
+ v_weight,
+ v_bias,
+ q,
+ norm_q_weight,
+ norm_q_bias,
+ mean_q,
+ invvar_q,
+ k,
+ norm_k_weight,
+ norm_k_bias,
+ mean_k,
+ invvar_k,
+ )
+
+ return output_q, output_k, v
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: Tensor):
+ # get ctx for backward
+ sp_group = ctx.sp_group
+ seq_dimension = ctx.seq_dimension
+ head_dimension = ctx.head_dimension
+ norm_type = ctx.norm_type
+ normalized_shape = ctx.normalized_shape
+ eps = ctx.eps
+ (
+ hidden_states,
+ q_weight,
+ q_bias,
+ k_weight,
+ k_bias,
+ v_weight,
+ v_bias,
+ q,
+ norm_q_weight,
+ norm_q_bias,
+ mean_q,
+ invvar_q,
+ k,
+ norm_k_weight,
+ norm_k_bias,
+ mean_k,
+ invvar_k,
+ ) = ctx.saved_tensors
+
+ # initialize grads
+ grad_hidden_states = None
+ grad_q_weight = None
+ grad_q_bias = None
+ grad_k_weight = None
+ grad_k_bias = None
+ grad_v_weight = None
+ grad_v_bias = None
+ grad_norm_q_weight = None
+ grad_norm_q_bias = None
+ grad_norm_k_weight = None
+ grad_norm_k_bias = None
+
+ # v grad communication launch
+ grad_v = grad_output[2].contiguous()
+
+ grad_v = padding_tensor_for_seqeunce_parallel(grad_v, dim=seq_dimension)
+ grad_v_res = all_to_all_tensor(
+ grad_v,
+ scatter_dim=seq_dimension,
+ gather_dim=head_dimension,
+ group=sp_group,
+ async_op=True,
+ )
+
+ # v grad communication collect
+ grad_v = grad_v_res()
+
+ grad_k = grad_output[1].contiguous()
+ # k grad communication launch
+ grad_k = padding_tensor_for_seqeunce_parallel(grad_k, dim=seq_dimension)
+ grad_k_res = all_to_all_tensor(
+ grad_k,
+ scatter_dim=seq_dimension,
+ gather_dim=head_dimension,
+ group=sp_group,
+ async_op=True,
+ )
+
+ # v projection grad
+ grad_v_input = grad_v @ v_weight
+ grad_v_weight = grad_v.transpose(-1, -2) @ hidden_states
+ if v_bias is not None and ctx.needs_input_grad[7]:
+ grad_v_bias = grad_v.sum(0)
+
+ # qk normalization backward (if needed)
+ if norm_type is not None:
+ if norm_type == "rmsnorm":
+ grad_k, grad_norm_k_weight = fused_layer_norm_cuda.rms_backward_affine(
+ grad_k,
+ invvar_k,
+ k,
+ normalized_shape,
+ norm_k_weight,
+ eps,
+ False,
+ )
+ elif norm_type == "layernorm":
+ grad_k, grad_norm_k_weight, grad_norm_k_bias = fused_layer_norm_cuda.backward_affine(
+ grad_k,
+ mean_k,
+ invvar_k,
+ k,
+ normalized_shape,
+ norm_k_weight,
+ norm_k_bias,
+ eps,
+ False,
+ )
+ else:
+ raise NotImplementedError(f"{norm_type} is not supported in async-ulysses now!")
+ else:
+ grad_norm_k_weight = None
+
+ # k grad communication collect
+ grad_k = grad_k_res()
+
+ grad_q = grad_output[0].contiguous()
+ # q grad communication launch
+
+ grad_q = padding_tensor_for_seqeunce_parallel(grad_q, dim=seq_dimension)
+ grad_q_res = all_to_all_tensor(
+ grad_q,
+ scatter_dim=seq_dimension,
+ gather_dim=head_dimension,
+ group=sp_group,
+ async_op=True,
+ )
+
+ # k projection grad
+ grad_k_input = grad_k @ k_weight
+ grad_k_weight = grad_k.transpose(-1, -2) @ hidden_states
+ if k_bias is not None and ctx.needs_input_grad[5]:
+ grad_k_bias = grad_k.sum(0)
+
+ # q grad communication collect
+ grad_q = grad_q_res()
+
+ # qk normalization backward (if needed)
+ if norm_type is not None:
+ if norm_type == "rmsnorm":
+ grad_q, grad_norm_q_weight = fused_layer_norm_cuda.rms_backward_affine(
+ grad_q,
+ invvar_q,
+ q,
+ normalized_shape,
+ norm_q_weight,
+ eps,
+ False,
+ )
+ elif norm_type == "layernorm":
+ grad_q, grad_norm_q_weight, grad_norm_q_bias = fused_layer_norm_cuda.backward_affine(
+ grad_q,
+ mean_q,
+ invvar_q,
+ q,
+ normalized_shape,
+ norm_q_weight,
+ norm_q_bias,
+ eps,
+ False,
+ )
+ else:
+ raise NotImplementedError(f"{norm_type} is not supported in async-ulysses now!")
+ else:
+ grad_norm_q_weight = None
+
+ # q projection grad
+ grad_q_input = grad_q @ q_weight
+ grad_q_weight = grad_q.transpose(-1, -2) @ hidden_states
+ if q_bias is not None and ctx.needs_input_grad[3]:
+ grad_q_bias = grad_q.sum(0)
+
+ # grad
+ grad_hidden_states = grad_q_input + grad_k_input + grad_v_input
+
+ return (
+ grad_hidden_states,
+ None,
+ None,
+ grad_q_weight,
+ grad_q_bias,
+ grad_k_weight,
+ grad_k_bias,
+ grad_v_weight,
+ grad_v_bias,
+ None,
+ grad_norm_q_weight,
+ grad_norm_q_bias,
+ grad_norm_k_weight,
+ grad_norm_k_bias,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class AsyncUlyssesOutputProjection(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ hidden_states: Tensor,
+ seq_dimension: int,
+ head_dimension: int,
+ proj_weight: Tensor,
+ proj_bias: Tensor,
+ unpadded_dim_size: int,
+ group: ProcessGroup,
+ ):
+ sp_group = get_ulysses_sequence_parallel_group() if group is None else group
+
+ # out projection
+ hidden_states = padding_tensor_for_seqeunce_parallel(hidden_states, seq_dimension)
+ hidden_states = all_to_all_tensor(
+ hidden_states, scatter_dim=seq_dimension, gather_dim=head_dimension, group=sp_group
+ )
+ o = F.linear(hidden_states, proj_weight, proj_bias)
+
+ # save ctx for backward
+ ctx.sp_group = sp_group
+ ctx.head_dimension = head_dimension
+ ctx.seq_dimension = seq_dimension
+ ctx.unpadded_dim_size = unpadded_dim_size
+
+ ctx.save_for_backward(
+ hidden_states,
+ proj_weight,
+ proj_bias,
+ )
+
+ return o
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: Tensor):
+ # get ctx for backward
+ sp_group = ctx.sp_group
+ head_dimension = ctx.head_dimension
+ seq_dimension = ctx.seq_dimension
+ unpadded_dim_size = ctx.unpadded_dim_size
+ (
+ hidden_states,
+ proj_weight,
+ proj_bias,
+ ) = ctx.saved_tensors
+
+ # initialize grads
+ grad_o = None
+ grad_proj_weight = None
+ grad_proj_bias = None
+
+ # output grad
+ grad_o = grad_output[0] @ (proj_weight)
+
+ # output grad communication launch
+ grad_out_res = all_to_all_tensor(
+ grad_o, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
+ )
+
+ grad_proj_weight = grad_output[0].transpose(-1, -2) @ (hidden_states)
+ if proj_bias is not None and ctx.needs_input_grad[3]:
+ grad_proj_bias = grad_output[0].sum(0)
+
+ # output grad communication collect
+ grad_o = grad_out_res()
+ grad_o = unpadding_tensor_for_seqeunce_parallel(grad_o, seq_dimension, unpadded_dim_size)
+
+ return (
+ grad_o,
+ None,
+ None,
+ grad_proj_weight,
+ grad_proj_bias,
+ None,
+ None,
+ )
+
+
+def async_ulysses_qkv_projection(
+ hidden_states: Tensor = None,
+ seq_dimension: int = None,
+ head_dimension: int = None,
+ q_weight: Tensor = None,
+ q_bias: Optional[Tensor] = None,
+ k_weight: Tensor = None,
+ k_bias: Optional[Tensor] = None,
+ v_weight: Tensor = None,
+ v_bias: Optional[Tensor] = None,
+ norm_type: str = None,
+ norm_q_weight: Optional[Tensor] = None,
+ norm_q_bias: Optional[Tensor] = None,
+ norm_k_weight: Optional[Tensor] = None,
+ norm_k_bias: Optional[Tensor] = None,
+ normalized_shape: Optional[int] = None,
+ eps: Optional[float] = None,
+ unpadded_dim_size: int = None,
+ head_dim: int = None,
+ group: Optional[ProcessGroup] = None,
+):
+ return AsyncUlyssesQKVProjection.apply(
+ hidden_states,
+ seq_dimension,
+ head_dimension,
+ q_weight,
+ q_bias,
+ k_weight,
+ k_bias,
+ v_weight,
+ v_bias,
+ norm_type,
+ norm_q_weight,
+ norm_q_bias,
+ norm_k_weight,
+ norm_k_bias,
+ normalized_shape,
+ eps,
+ unpadded_dim_size,
+ head_dim,
+ group,
+ )
+
+
+def async_ulysses_output_projection(
+ hidden_states: Optional[Tensor] = None,
+ seq_dimension: int = None,
+ head_dimension: int = None,
+ proj_weight: Optional[Tensor] = None,
+ proj_bias: Optional[Tensor] = None,
+ unpadded_dim_size: Optional[int] = None,
+ group: Optional[ProcessGroup] = None,
+):
+ return AsyncUlyssesOutputProjection.apply(
+ hidden_states,
+ seq_dimension,
+ head_dimension,
+ proj_weight,
+ proj_bias,
+ unpadded_dim_size,
+ group,
+ )
diff --git a/lingbotvla/distributed/sequence_parallel/comm.py b/lingbotvla/distributed/sequence_parallel/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d621b9bcb8a41fde09da91828291f1bccd8f5287
--- /dev/null
+++ b/lingbotvla/distributed/sequence_parallel/comm.py
@@ -0,0 +1,349 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from contextlib import nullcontext
+from typing import Any, Optional
+
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+
+_DATA_PARALLEL_GROUP = None
+
+_ULYSSES_SEQUENCE_PARALLEL_GROUP = {"default": None}
+_ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP = {"default": None}
+_ULYSSES_GROUP_KEY = "default"
+
+_CONTEXT_PARALLEL_GROUP = None
+
+_UNIFIED_SEQUENCE_PARALLEL_GROUP = None
+_UNIFIED_SEQUENCE_PARALLEL_CPU_GROUP = None
+
+
+# ------------------------------ Data Parallel ------------------------------ #
+def set_data_parallel_group(group: dist.ProcessGroup):
+ """
+ Set data parallel process group.
+ """
+ global _DATA_PARALLEL_GROUP
+ _DATA_PARALLEL_GROUP = group
+
+
+def get_data_parallel_group() -> Optional[dist.ProcessGroup]:
+ """
+ Get data parallel process group.
+ """
+ global _DATA_PARALLEL_GROUP
+ return _DATA_PARALLEL_GROUP
+
+
+def get_data_parallel_rank() -> Optional[dist.ProcessGroup]:
+ """
+ Get data parallel rank.
+ """
+ group = get_data_parallel_group()
+ return dist.get_rank(group)
+
+
+def get_data_parallel_world_size() -> Optional[dist.ProcessGroup]:
+ """
+ Get data parallel world_size.
+ """
+ group = get_data_parallel_group()
+ return dist.get_world_size(group)
+
+
+# ----------------------------- Ulysses Parallel ---------------------------- #
+def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup, group_key: str = "default"):
+ """
+ Set ulysses sequence parallel process group.
+ """
+ global _ULYSSES_SEQUENCE_PARALLEL_GROUP
+ _ULYSSES_SEQUENCE_PARALLEL_GROUP[group_key] = group
+
+
+def set_ulysses_sequence_parallel_cpu_group(group: dist.ProcessGroup, group_key: str = "default"):
+ """
+ Set ulysses sequence parallel process group.
+ """
+ global _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP
+ _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP[group_key] = group
+
+
+def set_ulysses_sequence_parallel_group_key(group_key: str = "default"):
+ """
+ Set ulysses sequence parallel process group key.
+ """
+ global _ULYSSES_GROUP_KEY
+ _ULYSSES_GROUP_KEY = group_key
+
+
+def get_ulysses_sequence_parallel_group_key() -> str:
+ """
+ Get ulysses sequence parallel group key.
+ """
+ global _ULYSSES_GROUP_KEY
+ return _ULYSSES_GROUP_KEY
+
+
+def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
+ """
+ Get ulysses sequence parallel process group.
+ """
+ global _ULYSSES_SEQUENCE_PARALLEL_GROUP
+ group_key = get_ulysses_sequence_parallel_group_key()
+ if group_key not in _ULYSSES_SEQUENCE_PARALLEL_GROUP:
+ raise RuntimeError(f"Unknown key {group_key} in Ulysses sequence parallel group!")
+ return _ULYSSES_SEQUENCE_PARALLEL_GROUP[group_key]
+
+
+def get_ulysses_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
+ """
+ Get ulysses sequence parallel CPU process group.
+ """
+ global _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP
+ group_key = get_ulysses_sequence_parallel_group_key()
+ if group_key not in _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP:
+ raise RuntimeError(f"Unknown key {group_key} in Ulysses sequence parallel group!")
+ return _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP[group_key]
+
+
+def get_ulysses_sequence_parallel_group_by_key(group_key: str = "default") -> Optional[dist.ProcessGroup]:
+ """
+ Get ulysses sequence parallel process group.
+ """
+ global _ULYSSES_SEQUENCE_PARALLEL_GROUP
+ if group_key not in _ULYSSES_SEQUENCE_PARALLEL_GROUP:
+ raise RuntimeError(f"Unknown key {group_key} in Ulysses sequence parallel group!")
+ return _ULYSSES_SEQUENCE_PARALLEL_GROUP[group_key]
+
+
+def get_ulysses_sequence_parallel_cpu_group_by_key(group_key: str = "default") -> Optional[dist.ProcessGroup]:
+ """
+ Get ulysses sequence parallel CPU process group.
+ """
+ global _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP
+ if group_key not in _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP:
+ raise RuntimeError(f"Unknown key {group_key} in Ulysses sequence parallel group!")
+ return _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP[group_key]
+
+
+def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
+ """
+ Get ulysses sequence parallel rank.
+ """
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ return dist.get_rank(group) if group else 0
+
+
+def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:
+ """
+ Get ulysses sequence parallel world size.
+ """
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ return dist.get_world_size(group) if group else 1
+
+
+# ----------------------------- Context Parallel ---------------------------- #
+
+
+def set_context_parallel_group(cp_group: dist.ProcessGroup):
+ """
+ Set context parallel process group.
+ """
+ global _CONTEXT_PARALLEL_GROUP
+ _CONTEXT_PARALLEL_GROUP = cp_group
+
+
+def get_context_parallel_group(check_initialized=True):
+ """Get the context parallel group the caller rank belongs to."""
+ global _CONTEXT_PARALLEL_GROUP
+ if check_initialized:
+ assert _CONTEXT_PARALLEL_GROUP is not None, "context parallel group is not initialized"
+ return _CONTEXT_PARALLEL_GROUP
+
+
+def get_context_parallel_rank():
+ """Return my rank for the context parallel group."""
+
+ if dist.is_available() and dist.is_initialized():
+ return dist.get_rank(group=get_context_parallel_group())
+ else:
+ return 0
+
+
+def get_context_parallel_world_size():
+ """Return world size for the context parallel group."""
+ if dist.is_available() and dist.is_initialized():
+ return dist.get_world_size(group=get_context_parallel_group())
+ else:
+ return 0
+
+
+# ----------------------------- Unified Parallel ---------------------------- #
+def set_unified_sequence_parallel_group(group: dist.ProcessGroup):
+ """
+ Set unified sequence parallel process group.
+ """
+ global _UNIFIED_SEQUENCE_PARALLEL_GROUP
+ _UNIFIED_SEQUENCE_PARALLEL_GROUP = group
+
+
+def set_unified_sequence_parallel_cpu_group(group: dist.ProcessGroup):
+ """
+ Set unified sequence parallel process group.
+ """
+ global _UNIFIED_SEQUENCE_PARALLEL_CPU_GROUP
+ _UNIFIED_SEQUENCE_PARALLEL_CPU_GROUP = group
+
+
+def get_unified_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
+ """
+ Get unified sequence parallel process group.
+ """
+ global _UNIFIED_SEQUENCE_PARALLEL_GROUP
+ return _UNIFIED_SEQUENCE_PARALLEL_GROUP
+
+
+def get_unified_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
+ """
+ Get unified sequence parallel CPU process group.
+ """
+ global _UNIFIED_SEQUENCE_PARALLEL_CPU_GROUP
+ return _UNIFIED_SEQUENCE_PARALLEL_CPU_GROUP
+
+
+def get_unified_sequence_parallel_rank() -> int:
+ """
+ Get unified sequence parallel rank.
+ """
+ group = get_unified_sequence_parallel_group()
+ return dist.get_rank(group) if group else 0
+
+
+def get_unified_sequence_parallel_world_size() -> int:
+ """
+ Get unified sequence parallel world size.
+ """
+ group = get_unified_sequence_parallel_group()
+ return dist.get_world_size(group) if group else 1
+
+
+# ------------------------------- Initialize ------------------------------- #
+def init_sequence_parallel(
+ ulysses_size: int = 1, sep_dp: bool = False, ulysses_group_key: str = "default", cp_size: int = 1
+):
+ """
+ Initialize unified sequence parallel.
+ """
+ global _CONTEXT_PARALLEL_GROUP
+ global _ULYSSES_SEQUENCE_PARALLEL_GROUP
+ global _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP
+
+ set_ulysses_sequence_parallel_group(group=None, group_key="default")
+ set_ulysses_sequence_parallel_cpu_group(group=None, group_key="default")
+
+ if ulysses_size == 1 and cp_size == 1:
+ return
+
+ assert dist.is_initialized()
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ unified_sp_size = ulysses_size * cp_size
+ assert world_size % unified_sp_size == 0
+ data_parallel_size = world_size // unified_sp_size
+
+ if cp_size > 1:
+ assert _CONTEXT_PARALLEL_GROUP is None, "Context parallel group has already been initialized!"
+ if ulysses_size:
+ assert (ulysses_group_key == "default" and _ULYSSES_SEQUENCE_PARALLEL_GROUP[ulysses_group_key] is None) or (
+ ulysses_group_key != "default" and ulysses_group_key not in _ULYSSES_SEQUENCE_PARALLEL_GROUP
+ ), f"Ulysses sequence parallel group ({ulysses_group_key}) has already been initialized!"
+ assert (
+ ulysses_group_key == "default" and _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP[ulysses_group_key] is None
+ ) or (ulysses_group_key != "default" and ulysses_group_key not in _ULYSSES_SEQUENCE_PARALLEL_CPU_GROUP), (
+ f"Ulysses sequence parallel ({ulysses_group_key}) group has already been initialized!"
+ )
+
+ for i in range(data_parallel_size):
+ # build ulysses group
+ if ulysses_size > 1:
+ for j in range(cp_size):
+ start_rank = i * unified_sp_size + j * ulysses_size
+ end_rank = start_rank + ulysses_size
+ ulysses_ranks = range(start_rank, end_rank)
+ ulysses_group = dist.new_group(ulysses_ranks)
+ ulysses_cpu_group = dist.new_group(ulysses_ranks, backend="gloo")
+ if rank in ulysses_ranks:
+ set_ulysses_sequence_parallel_group(group=ulysses_group, group_key=ulysses_group_key)
+ set_ulysses_sequence_parallel_cpu_group(group=ulysses_cpu_group, group_key=ulysses_group_key)
+
+ # build cp group
+ if cp_size > 1:
+ for j in range(ulysses_size):
+ cp_global_ranks = range(i * unified_sp_size + j, (i + 1) * unified_sp_size, ulysses_size)
+ cp_group = dist.new_group(cp_global_ranks)
+ if rank in cp_global_ranks:
+ set_context_parallel_group(cp_group=cp_group)
+
+ # build unified sp group
+ unified_sp_ranks = range(i * unified_sp_size, (i + 1) * unified_sp_size)
+ sp_group = dist.new_group(unified_sp_ranks)
+ sp_cpu_group = dist.new_group(unified_sp_ranks, backend="gloo")
+ if rank in unified_sp_ranks:
+ set_unified_sequence_parallel_group(group=sp_group)
+ set_unified_sequence_parallel_cpu_group(group=sp_cpu_group)
+
+ if sep_dp:
+ for j in range(unified_sp_size):
+ dp_ranks = range(j, world_size, unified_sp_size)
+ dp_group = dist.new_group(dp_ranks)
+ if rank in dp_ranks:
+ set_data_parallel_group(dp_group)
+
+
+class UlyssesGroupKeyManager:
+ def __init__(self, group_key: str):
+ self.group_key = group_key
+
+ def __enter__(self):
+ set_ulysses_sequence_parallel_group_key(group_key=self.group_key)
+
+ def __exit__(self, *args: Any):
+ set_ulysses_sequence_parallel_group_key(group_key="default")
+
+
+def is_ulysses_sequence_parallel_initialized() -> bool:
+ """
+ Check if ulysses sequence parallel is initialized.
+ """
+ return get_ulysses_sequence_parallel_group() is not None
+
+
+def is_context_parallel_initialized() -> bool:
+ """
+ Check if ulysses sequence parallel is initialized.
+ """
+ return get_context_parallel_group() is not None
+
+
+def get_ulysses_group_key_context(group_key: str = "default"):
+ if not isinstance(group_key, str):
+ raise RuntimeError(f"A Ulysses group key must be specified, now get: {group_key}")
+
+ if group_key != "default":
+ return UlyssesGroupKeyManager(group_key)
+ else:
+ return nullcontext()
diff --git a/lingbotvla/distributed/sequence_parallel/data.py b/lingbotvla/distributed/sequence_parallel/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..1351eaa30a3e27b157a9a9eecfa5b9d858b1b9dd
--- /dev/null
+++ b/lingbotvla/distributed/sequence_parallel/data.py
@@ -0,0 +1,147 @@
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor
+from torch.distributed import ProcessGroup
+
+from ...data.constants import IGNORE_INDEX
+from .comm import get_ulysses_sequence_parallel_group, get_unified_sequence_parallel_group
+from .ulysses import _Gather, _Slice
+from .utils import pad_tensor, unpadding_tensor_for_seqeunce_parallel
+
+
+def slice_input_tensor(
+ x: Tensor,
+ dim: int,
+ padding: bool = True,
+ padding_value: int = 0,
+ group: ProcessGroup = None,
+) -> Tensor:
+ """
+ A func to slice the input sequence in sequence parallel
+ """
+ group = get_unified_sequence_parallel_group() if group is None else group
+ if not group:
+ return x
+ sp_rank = dist.get_rank(group)
+ sp_world = dist.get_world_size(group)
+ dim_size = x.shape[dim]
+ unit = (dim_size + sp_world - 1) // sp_world
+ if padding and dim_size % sp_world:
+ padding_size = sp_world - (dim_size % sp_world)
+ x = pad_tensor(x, dim, padding_size, padding_value)
+ slc = [slice(None)] * len(x.shape)
+ slc[dim] = slice(unit * sp_rank, unit * (sp_rank + 1))
+ return x[slc].contiguous()
+
+
+def slice_input_tensor_scale_grad(
+ x: Tensor,
+ dim: int,
+ group: ProcessGroup = None,
+ scale_grad=True,
+):
+ """
+ A func to gather the outputs for the model result in sequence parallel
+ """
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ if not group:
+ return x
+ x = _Slice.apply(group, x, dim, scale_grad)
+ return x
+
+
+def gather_outputs(
+ x: Tensor,
+ gather_dim: int,
+ padding_dim: Optional[int] = None,
+ unpad_dim_size: Optional[int] = None,
+ scale_grad=True,
+ group: ProcessGroup = None,
+):
+ """
+ A func to gather the outputs for the model result in sequence parallel
+ """
+ group = get_unified_sequence_parallel_group() if group is None else group
+ if not group:
+ return x
+ x = _Gather.apply(group, x, gather_dim, scale_grad)
+ if padding_dim is not None:
+ x = unpadding_tensor_for_seqeunce_parallel(x, padding_dim, unpad_dim_size, group)
+ return x
+
+
+def slice_position_embedding(position_embeddings: tuple, dim: int = 1, sp_group: dist.ProcessGroup = None):
+ """
+ Forward hook for LlamaRotaryEmbedding to apply Ulysses tensor slicing.
+
+ Args:
+ position_embeddings: Input tensors to the forward method
+ dim: The dimension to slice
+ sp_group: The sequence parallel group
+ Returns:
+ Modified (cos, sin) tuple with slicing applied if ulysses is enabled
+ """
+ if sp_group is not None:
+ cos, sin = position_embeddings
+ cos = slice_input_tensor(cos, dim=dim, padding=False, group=sp_group)
+ sin = slice_input_tensor(sin, dim=dim, padding=False, group=sp_group)
+ return (cos, sin)
+ return position_embeddings
+
+
+def sequence_parallel_preprocess(
+ input_ids: torch.Tensor,
+ labels: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ sp_group: Optional[ProcessGroup] = None,
+):
+ """
+ Preprocess input_ids and labels for sequence parallel training.
+
+ Args:
+ input_ids: Input token ids
+ labels: Label token ids
+ position_ids: Position ids
+ attention_mask: Attention mask
+ cu_seqlens: Cumulative sequence lengths
+
+ Returns:
+ Preprocessed input_ids, labels, position_ids, attention_mask, cu_seqlens
+ """
+ if sp_group is not None:
+ sp_size = dist.get_world_size(sp_group)
+ padding_size = (sp_size - (input_ids.shape[-1] % sp_size)) % sp_size
+
+ # Slice input_ids among sequence parallel group
+ input_ids = slice_input_tensor(input_ids, dim=-1, padding=True, padding_value=0, group=sp_group)
+
+ # Slice labels among sequence parallel group
+ if labels is not None:
+ labels = labels[..., 1:].contiguous() # shift labels
+ labels = F.pad(labels, (0, 1), "constant", IGNORE_INDEX) # pad to the same length as input_ids
+ labels = slice_input_tensor(labels, dim=-1, padding=True, padding_value=IGNORE_INDEX, group=sp_group)
+
+ # Padding position_ids
+ if position_ids is not None:
+ position_ids = pad_tensor(position_ids, dim=-1, padding_size=padding_size, padding_value=0)
+
+ # Padding attention_mask
+ if attention_mask is not None:
+ attn_mask_padding_value = 1 if position_ids is not None else 0
+ attention_mask = pad_tensor(
+ attention_mask, dim=-1, padding_size=padding_size, padding_value=attn_mask_padding_value
+ )
+
+ # Padding cu_seqlens
+ if cu_seqlens is not None:
+ cu_seqlens_padding_value = cu_seqlens[-1].item() + padding_size
+ cu_seqlens = pad_tensor(
+ cu_seqlens, dim=-1, padding_size=padding_size, padding_value=cu_seqlens_padding_value
+ )
+
+ return input_ids, labels, position_ids, attention_mask, cu_seqlens
diff --git a/lingbotvla/distributed/sequence_parallel/loss.py b/lingbotvla/distributed/sequence_parallel/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..925b9a99854ce7370db8803ed73f0a736d1555e0
--- /dev/null
+++ b/lingbotvla/distributed/sequence_parallel/loss.py
@@ -0,0 +1,51 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Tuple
+
+import torch
+import torch.distributed as dist
+
+from .comm import (
+ get_unified_sequence_parallel_group,
+ get_unified_sequence_parallel_world_size,
+)
+
+
+class ReduceLoss(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx: torch.autograd.Function, loss: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
+ if num_valid_tokens == 0:
+ loss = torch.nan_to_num(loss)
+
+ local_num_tokens = num_valid_tokens.detach().clone()
+ loss *= num_valid_tokens
+ group = get_unified_sequence_parallel_group()
+ dist.all_reduce(loss, group=group)
+ dist.all_reduce(num_valid_tokens, group=group)
+ ctx.save_for_backward(local_num_tokens, num_valid_tokens)
+ return loss / num_valid_tokens
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.Function, grad_output: torch.Tensor
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ local_num_tokens, global_num_tokens = ctx.saved_tensors
+ grad_output = get_unified_sequence_parallel_world_size() * local_num_tokens * grad_output / global_num_tokens
+ return grad_output, None
+
+
+def reduce_sequence_parallel_loss(loss: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
+ return ReduceLoss.apply(loss, num_valid_tokens)
diff --git a/lingbotvla/distributed/sequence_parallel/ulysses.py b/lingbotvla/distributed/sequence_parallel/ulysses.py
new file mode 100644
index 0000000000000000000000000000000000000000..e87167c7d81b059f4aab1b19e72d1adaf0591d30
--- /dev/null
+++ b/lingbotvla/distributed/sequence_parallel/ulysses.py
@@ -0,0 +1,334 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch.distributed import ProcessGroup
+
+from .comm import (
+ get_ulysses_sequence_parallel_group,
+ get_ulysses_sequence_parallel_world_size,
+)
+from .utils import (
+ pad_tensor,
+ unpad_tensor,
+)
+
+
+def _all_gather(
+ x: Tensor,
+ group: dist.ProcessGroup,
+):
+ device = x.device
+ dtype = x.dtype
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ sp_world_size = dist.get_world_size(group)
+ x_size = torch.tensor(x.size()).to(device)
+ size_list = [torch.zeros(x_size.size(), dtype=torch.int64, device=device) for i in range(sp_world_size)]
+ dist.all_gather(size_list, x_size, group=group)
+ tensor_list = [torch.zeros(torch.Size(size_list[i]), dtype=dtype, device=device) for i in range(sp_world_size)]
+ dist.all_gather(tensor_list, x, group=group)
+ return tensor_list, size_list
+
+
+def _all_gather_into_tensor(
+ x: Tensor,
+ group: dist.ProcessGroup,
+):
+ dim_size = list(x.size())
+
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ sp_world_size = dist.get_world_size(group)
+ dim_size[0] = dim_size[0] * sp_world_size
+ output = torch.empty(dim_size, dtype=x.dtype, device=torch.cuda.current_device())
+ dist.all_gather_into_tensor(output, x, group=group)
+ return output
+
+
+def _all_to_all(
+ local_input: Tensor,
+ scatter_dim: int,
+ gather_dim: int,
+ group: Optional[dist.ProcessGroup] = None,
+ async_op: bool = False,
+):
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ seq_world_size = dist.get_world_size(group)
+ input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
+ output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
+ comm = dist.all_to_all(output_list, input_list, group=group, async_op=async_op)
+ if async_op:
+
+ def wait():
+ comm.wait()
+ return torch.cat(output_list, dim=gather_dim).contiguous()
+
+ return wait
+ return torch.cat(output_list, dim=gather_dim).contiguous()
+
+
+def _all_to_all_single(
+ x: Tensor, scatter_dim: int, gather_dim: int, group: Optional[dist.ProcessGroup] = None, async_op: bool = False
+):
+ """
+ A function to do all-to-all on the first two dim
+ """
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ sp_world_size = dist.get_world_size(group)
+ assert scatter_dim <= 1, "scatter_dim must be 0 or 1 when using all_to_all_single!"
+ assert gather_dim <= 1, "gather_dim must be 0 or 1 when using all_to_all_single!"
+ if scatter_dim != 0:
+ gather_dim_bef = x.shape[gather_dim]
+ scatter_dim_bef = x.shape[scatter_dim]
+ x = (
+ x.reshape([gather_dim_bef, sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:]))
+ .transpose(0, 1)
+ .reshape([gather_dim_bef * sp_world_size, scatter_dim_bef // sp_world_size] + list(x.shape[2:]))
+ .contiguous()
+ )
+
+ output = torch.empty_like(x)
+ comm = dist.all_to_all_single(output, x.contiguous(), group=group, async_op=async_op)
+
+ if async_op:
+
+ def wait():
+ comm.wait()
+ if scatter_dim == 0:
+ return torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim)
+ else:
+ return output
+
+ return wait
+
+ if scatter_dim == 0:
+ output = torch.cat(output.split(x.size(0) // sp_world_size), dim=gather_dim)
+ return output
+
+
+def all_to_all_tensor(
+ x: Tensor,
+ scatter_dim: int,
+ gather_dim: int,
+ group: dist.ProcessGroup,
+ async_op: bool = False,
+):
+ if scatter_dim <= 1 and gather_dim <= 1:
+ return _all_to_all_single(x, scatter_dim, gather_dim, group, async_op)
+ else:
+ return _all_to_all(x, scatter_dim, gather_dim, group, async_op)
+
+
+class _SeqAllToAll(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ group: dist.ProcessGroup,
+ local_input: Tensor,
+ scatter_dim: int,
+ gather_dim: int,
+ async_op: bool,
+ ) -> Tensor:
+ ctx.group = group
+ ctx.scatter_dim = scatter_dim
+ ctx.gather_dim = gather_dim
+ ctx.async_op = async_op
+ return all_to_all_tensor(local_input, scatter_dim, gather_dim, group, async_op)
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
+ if ctx.async_op:
+ input_t = torch.cat(grad_output[1:], dim=ctx.gather_dim).contiguous()
+ else:
+ input_t = grad_output[0]
+ return (
+ None,
+ all_to_all_tensor(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group, False),
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class _Slice(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int, scale_grad: bool) -> Tensor:
+ ctx.group = group
+ ctx.rank = dist.get_rank(group)
+ seq_world_size = dist.get_world_size(group)
+ ctx.seq_world_size = seq_world_size
+ ctx.dim = dim
+ ctx.scale_grad = scale_grad
+ dim_size = local_input.shape[dim]
+ return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous()
+
+ @staticmethod
+ def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]:
+ dim_size = list(grad_output.size())
+ split_size = dim_size[0]
+ output = _all_gather_into_tensor(grad_output, group=ctx.group)
+ if ctx.scale_grad:
+ output = output / ctx.seq_world_size
+ return (None, torch.cat(output.split(split_size), dim=ctx.dim), None, None)
+
+
+class _Gather(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ group: dist.ProcessGroup,
+ local_input: Tensor,
+ dim: int,
+ grad_scale: Optional[bool] = False,
+ ) -> Tensor:
+ ctx.group = group
+ ctx.rank = dist.get_rank(group)
+ ctx.dim = dim
+ ctx.grad_scale = grad_scale
+ seq_world_size = dist.get_world_size(group)
+ ctx.seq_world_size = seq_world_size
+ output, size_list = _all_gather(local_input.contiguous(), group=ctx.group)
+ dim_size_list = [size_list[i][dim].item() for i in range(seq_world_size)]
+ ctx.dim_size_list = dim_size_list
+ return torch.cat(output, dim=dim)
+
+ @staticmethod
+ def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
+ if ctx.grad_scale:
+ grad_output = grad_output * ctx.seq_world_size
+ return (
+ None,
+ grad_output.split(ctx.dim_size_list, dim=ctx.dim)[ctx.rank].contiguous(),
+ None,
+ None,
+ )
+
+
+def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int, group: ProcessGroup = None) -> Tensor:
+ """
+ A func to sync attention result with alltoall in sequence parallel
+ """
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ if not group:
+ return x
+ dim_size = x.size(seq_dim)
+ sp_world = get_ulysses_sequence_parallel_world_size(group)
+ if dim_size % sp_world != 0:
+ padding_size = sp_world - (dim_size % sp_world)
+ x = pad_tensor(x, seq_dim, padding_size)
+ return _SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
+
+
+def gather_seq_scatter_heads(
+ x: Tensor,
+ seq_dim: int,
+ head_dim: int,
+ unpadded_dim_size: int = 0,
+ async_op: bool = False,
+ group: ProcessGroup = None,
+) -> Tensor:
+ """
+ A func to sync embedding input with alltoall in sequence parallel
+ """
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ if not group:
+ return x
+ sp_world = get_ulysses_sequence_parallel_world_size(group)
+ if async_op:
+ return _SeqAllToAll.apply(group, x, head_dim, seq_dim, async_op)
+ else:
+ x = _SeqAllToAll.apply(group, x, head_dim, seq_dim, async_op)
+ if unpadded_dim_size and unpadded_dim_size % sp_world != 0:
+ padding_size = x.size(seq_dim) - unpadded_dim_size
+ x = unpad_tensor(x, seq_dim, padding_size)
+ return x
+
+
+def gather_seq_scatter_heads_qkv(
+ qkv_tensor: Tensor,
+ seq_dim: int,
+ unpadded_dim_size: Optional[int] = None,
+ restore_shape: bool = True,
+ async_op: bool = False,
+ group: ProcessGroup = None,
+) -> Tensor:
+ """
+ A func to sync splited qkv tensor
+ qkv_tensor: the tensor we want to do alltoall with. The last dim must
+ be the projection_idx, which we will split into 3 part. After
+ spliting, the gather idx will be projecttion_idx + 1
+ seq_dim: gather_dim for all2all comm
+ restore_shape: if True, output will has the same shape length as input
+ """
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ if not group:
+ return qkv_tensor
+ sp_world = get_ulysses_sequence_parallel_world_size(group)
+ orig_shape = qkv_tensor.shape
+ scatter_dim = qkv_tensor.dim()
+ bef_all2all_shape = list(orig_shape)
+ qkv_proj_dim = bef_all2all_shape[-1]
+ bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3]
+ qkv_tensor = qkv_tensor.view(bef_all2all_shape)
+ if async_op:
+ return _SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
+ else:
+ qkv_tensor = _SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, async_op)
+
+ if restore_shape:
+ out_shape = list(orig_shape)
+ out_shape[seq_dim] *= sp_world
+ out_shape[-1] = qkv_proj_dim // sp_world
+ qkv_tensor = qkv_tensor.view(out_shape)
+
+ # remove padding
+ if unpadded_dim_size and unpadded_dim_size % sp_world != 0:
+ padding_size = qkv_tensor.size(seq_dim) - unpadded_dim_size
+ qkv_tensor = unpad_tensor(qkv_tensor, seq_dim, padding_size)
+
+ return qkv_tensor
+
+
+class _AlltoAllRegion(torch.autograd.Function):
+ """balance the intermediate tensors in the sequence parallel region"""
+
+ @staticmethod
+ def forward(ctx, group, x, input_splits, output_splits):
+ ctx.group = group
+ ctx.input_splits = input_splits
+ ctx.output_splits = output_splits
+ input_tensor_list = list(x.split(input_splits, dim=0))
+ input_tensor_list = [t.contiguous() for t in input_tensor_list]
+ output_tensor_list = [torch.empty([o, *x.shape[1:]], dtype=x.dtype, device=x.device) for o in output_splits]
+ dist.all_to_all(output_tensor_list, input_tensor_list, group=group)
+ return torch.cat(output_tensor_list, dim=0)
+
+ def backward(ctx, dy):
+ dx_list = [torch.empty([i, *dy.shape[1:]], dtype=dy.dtype, device=dy.device) for i in ctx.input_splits]
+ dy_list = list(dy.split(ctx.output_splits, dim=0))
+ dist.all_to_all(dx_list, dy_list, group=ctx.group)
+ return None, torch.cat(dx_list, dim=0), None, None
+
+
+def all_to_all_images(image_embeds, in_splits, out_splits):
+ if not in_splits:
+ return image_embeds
+ image_embeds = image_embeds[: sum(in_splits)]
+ group = get_ulysses_sequence_parallel_group()
+ return _AlltoAllRegion.apply(group, image_embeds, in_splits, out_splits)
diff --git a/lingbotvla/distributed/sequence_parallel/utils.py b/lingbotvla/distributed/sequence_parallel/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2fd31df545cf1d39dafe08238db7dffc3a842d5
--- /dev/null
+++ b/lingbotvla/distributed/sequence_parallel/utils.py
@@ -0,0 +1,145 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Tuple
+
+import torch
+from torch import Tensor
+from torch.distributed import ProcessGroup
+
+from .comm import (
+ get_ulysses_sequence_parallel_group,
+ get_ulysses_sequence_parallel_rank,
+ get_ulysses_sequence_parallel_world_size,
+)
+
+
+def unpadding_tensor_for_seqeunce_parallel(x: Tensor, dim: int, unpadded_dim_size: int, group: ProcessGroup = None):
+ """
+ A func to remove the padding part of the tensor based on its original shape
+ """
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ if not group:
+ return x
+ sp_world = get_ulysses_sequence_parallel_world_size(group)
+ if unpadded_dim_size % sp_world == 0:
+ return x
+ padding_size = sp_world - (unpadded_dim_size % sp_world)
+ assert (padding_size + unpadded_dim_size) % sp_world == 0
+ return unpad_tensor(x, dim=dim, padding_size=padding_size)
+
+
+def padding_tensor_for_seqeunce_parallel(x: Tensor, dim: int, group: ProcessGroup = None) -> Tensor:
+ """
+ A func to remove the padding part of the tensor based on its original shape
+ """
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ if not group:
+ return x
+ sp_world = get_ulysses_sequence_parallel_world_size(group)
+ dim_size = x.shape[dim]
+ if dim_size % sp_world:
+ padding_size = sp_world - (dim_size % sp_world)
+ x = pad_tensor(x, dim, padding_size)
+ return x
+
+
+def pad_tensor(x: Tensor, dim: int, padding_size: int, padding_value: int = 0) -> Tensor:
+ shape = list(x.shape)
+ shape[dim] = padding_size
+ pad = torch.full(shape, padding_value, dtype=x.dtype, device=x.device)
+ return torch.cat([x, pad], dim=dim)
+
+
+def unpad_tensor(x: Tensor, dim: int, padding_size: int) -> Tensor:
+ slc = [slice(None)] * len(x.shape)
+ slc[dim] = slice(0, -padding_size)
+ return x[slc]
+
+
+def remove_last_rank_padding(x: Tensor, dim: int, unpad_dim_size: int, group: ProcessGroup = None) -> Tensor:
+ group = get_ulysses_sequence_parallel_group() if group is None else group
+ if not group:
+ return x
+ sp_rank = get_ulysses_sequence_parallel_rank(group)
+ sp_world = get_ulysses_sequence_parallel_world_size(group)
+ if unpad_dim_size % sp_world == 0 and sp_rank + 1 != sp_world:
+ return x
+ pad = sp_world - (unpad_dim_size % sp_world)
+ assert (pad + x.shape[dim]) % sp_world == 0
+ slc = [slice(None)] * len(x.shape)
+ slc[dim] = slice(0, -pad)
+ return x[slc]
+
+
+def has_overlap(x1, x2, y1, y2) -> Tuple[bool, int]:
+ """
+ A func to judge if two intervals have overlaps, and return the length of overlaps
+ """
+ max_value = max(x1, y1)
+ min_value = min(x2, y2)
+ return max_value < min_value, min_value - max_value
+
+
+def all2all_splits(image_lens: List, image_lens_per_rank: List, sp_size: int, sp_rank: int) -> Tuple[List, List]:
+ """
+ A func to generate splits for all2all communication
+ """
+ assert sum(image_lens) == sum(image_lens_per_rank)
+ num_images = len(image_lens)
+ sp_step = (num_images + sp_size - 1) // sp_size
+ in_splits, out_splits = [0 for _ in range(sp_size)], [0 for _ in range(sp_size)]
+ cu_seqlens = [0] + [sum(image_lens_per_rank[: i + 1]) for i in range(sp_size)]
+ rank = 0
+ num_tokens = 0
+ for image_idx, image_lens in enumerate(image_lens):
+ src_rank = image_idx // sp_step
+ tokens_split = []
+ for rank in range(sp_size):
+ overlap, overlap_len = has_overlap(
+ num_tokens, num_tokens + image_lens, cu_seqlens[rank], cu_seqlens[rank + 1]
+ )
+ if overlap:
+ tokens_split.append(overlap_len)
+ if rank == sp_rank:
+ out_splits[src_rank] += overlap_len
+ if src_rank == sp_rank:
+ in_splits[rank] += overlap_len
+ assert sum(tokens_split) == image_lens
+
+ num_tokens += image_lens
+
+ return in_splits, out_splits
+
+
+def vlm_images_a2a_meta(
+ sp_rank: int, sp_size: int, image_lens: List, image_masks: torch.Tensor
+) -> Tuple[List, List, torch.Tensor]:
+ """
+ A func to generate metadata for all2all communication after we balance the computaion in vision encoder
+ Usually we will split the batches of images for vision encoder in sp group. However, before we feed images
+ tokens into language model, we need to use all2all communication to gather necessary tokens into the current rank.
+ """
+ assert sum(image_lens) == image_masks.sum().item(), (
+ f"The sum of image_lens must be equal to the number of tokens, {image_lens} vs {image_masks.sum().item()}"
+ )
+ seq_len = image_masks.shape[1]
+ step = (seq_len + sp_size - 1) // sp_size
+ sequence_per_rank = [min(step * (i + 1), seq_len) - min(step * i, seq_len) for i in range(sp_size)]
+ mask_per_rank = image_masks.split(sequence_per_rank, dim=1)
+ image_lens_per_rank = [mask_per_rank[i].sum().item() for i in range(sp_size)]
+ in_splits, out_splits = all2all_splits(image_lens, image_lens_per_rank, sp_size, sp_rank)
+ local_image_masks = mask_per_rank[sp_rank]
+ return in_splits, out_splits, local_image_masks
diff --git a/lingbotvla/distributed/torch_parallelize.py b/lingbotvla/distributed/torch_parallelize.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecbf9f2467665973e8126852c7c800b2f8128306
--- /dev/null
+++ b/lingbotvla/distributed/torch_parallelize.py
@@ -0,0 +1,378 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import types
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn as nn
+from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel, MixedPrecision, ShardingStrategy
+from torch.distributed.fsdp._common_utils import _get_module_fsdp_state_if_fully_sharded_module
+from torch.distributed.fsdp._runtime_utils import _lazy_init
+from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.checkpoint import create_selective_checkpoint_contexts, noop_context_fn
+
+from ..models import load_model_weights
+from ..utils import logging
+from ..utils.import_utils import is_torch_version_greater_than
+from .checkpoint import CheckpointFunction
+from .fsdp import (
+ clip_grad_norm_,
+ init_fsdp_fn,
+ parallel_init_fsdp_fn,
+ parallel_load_safetensors,
+ register_checkpoint_extension,
+)
+from .parallel_state import get_parallel_state
+from .utils import get_module_from_path, set_module_from_path
+
+
+if is_torch_version_greater_than("2.4"):
+ from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
+ from torch.distributed.tensor.parallel import parallelize_module
+
+
+logger = logging.get_logger(__name__)
+
+
+def verbose_fsdp_grouping(model, prefix="", depth=0):
+ indent = " " * depth
+
+ for name, child in model.named_children():
+ if isinstance(child, FullyShardedDataParallel):
+ module_names = [m_name for m_name, _ in child.named_modules()][1:] # [1:] 排除自身
+ strategy = child.sharding_strategy
+ logger.debug_rank0(f"{indent}├── [FSDP Group] {prefix}{name}")
+ logger.debug_rank0(
+ f"{indent}│ ├── Sharding Strategy: {strategy}, Mixed Precision: {child.mixed_precision}"
+ )
+ logger.debug_rank0(f"{indent}│ └── Contains Modules: {module_names}")
+
+ verbose_fsdp_grouping(child, prefix=f"{prefix}{name}.", depth=depth + 1)
+ else:
+ verbose_fsdp_grouping(child, prefix=f"{prefix}{name}.", depth=depth)
+
+
+def build_parallelize_model(
+ model: "nn.Module",
+ weights_path: Optional[str] = None,
+ sharding_plan: Optional[Dict[str, Any]] = None,
+ enable_full_shard: bool = True,
+ enable_mixed_precision: bool = True,
+ enable_fp32: bool = False,
+ enable_gradient_checkpointing: bool = True,
+ basic_modules: Optional[List[str]] = None,
+ fsdp_llm_blocks: bool = True,
+ ignore_norm: bool = False,
+ use_depth_align: bool = False,
+ ignore_depth: bool = False,
+ **kwargs,
+) -> "nn.Module":
+ """
+ Applies parallel strategies to the model.
+ """
+ parallel_state = get_parallel_state()
+ fsdp_no_shard_states = None
+
+ if not parallel_state.fsdp_enabled:
+ if kwargs.get("init_device") != "cuda":
+ raise ValueError("Only FSDP training supports `init_device=cpu` or `init_device=meta`.")
+ if kwargs.pop("enable_fsdp_offload", False):
+ raise ValueError("Only FSDP training supports `enable_fsdp_offload`.")
+
+ if enable_mixed_precision: # upcast to float32 before feed it to optimizer
+ model = model.float()
+
+ if enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
+ logger.info_rank0("Enable gradient checkpointing.")
+ use_reentrant = kwargs.pop("enable_reentrant", False)
+ if use_reentrant:
+ torch.utils.checkpoint.CheckpointFunction = CheckpointFunction
+
+ ops_to_save = kwargs.pop("ops_to_save", None)
+ context_fn = (
+ partial(create_selective_checkpoint_contexts, ops_to_save) if ops_to_save is not None else noop_context_fn
+ )
+ model.gradient_checkpointing_enable(
+ gradient_checkpointing_kwargs={"use_reentrant": use_reentrant, "context_fn": context_fn}
+ )
+
+ if parallel_state.tp_enabled:
+ logger.info_rank0("Apply tensor parallel to the model.")
+ model = parallelize_module(
+ model,
+ device_mesh=parallel_state.tp_mesh,
+ )
+
+ if parallel_state.ep_enabled:
+ parallel_plan = model.get_parallel_plan()
+ ep_param_suffix = parallel_plan.ep_param_suffix
+
+ fqn2spec_info = parallel_plan.apply(model, parallel_state.ep_fsdp_device_mesh)
+ fsdp_no_shard_states_fqn_to_module = parallel_plan.get_fsdp_no_shard_info(model)
+
+ fsdp_no_shard_states = list(fsdp_no_shard_states_fqn_to_module.values())
+ fsdp_no_shard_states_fqn = list(fsdp_no_shard_states_fqn_to_module.keys())
+ logger.info_rank0(f"Apply expert parallel to the model successfully.\nEP modules: {fsdp_no_shard_states_fqn}.")
+ else:
+ fqn2spec_info = None
+ ep_param_suffix = None
+ fsdp_no_shard_states = None
+ fsdp_no_shard_states_fqn = None
+
+ if parallel_state.fsdp_enabled:
+ logger.info_rank0(f"Apply data parallel to the model: {parallel_state.dp_mode}.")
+ if parallel_state.dp_mode == "fsdp2":
+ fsdp_kwargs = {
+ "mesh": parallel_state.fsdp_mesh,
+ "reshard_after_forward": enable_full_shard,
+ **kwargs.pop("fsdp_kwargs", {}),
+ }
+ if enable_mixed_precision and not enable_fp32:
+ logger.info_rank0("Enable mixed precision training.")
+ mp_policy = MixedPrecisionPolicy(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.float32,
+ output_dtype=torch.bfloat16,
+ )
+ fsdp_kwargs["mp_policy"] = mp_policy
+ elif enable_fp32:
+ mp_policy = MixedPrecisionPolicy(
+ param_dtype=torch.float32,
+ reduce_dtype=torch.float32,
+ output_dtype=torch.float32,
+ )
+ fsdp_kwargs["mp_policy"] = mp_policy
+ if ignore_norm:
+ ignored_modules = set()
+ for layer in model.model.qwenvl_with_expert.qwenvl.language_model.model.layers:
+ ignored_modules.add(layer.input_layernorm.weight)
+ ignored_modules.add(layer.post_attention_layernorm.weight)
+ for expert_layers in model.model.qwenvl_with_expert.qwen_expert.model.layers:
+ ignored_modules.add(expert_layers.input_layernorm.weight)
+ ignored_modules.add(expert_layers.post_attention_layernorm.weight)
+ fsdp_kwargs["ignored_params"] = ignored_modules
+
+ mp_fsdp_kwargs = {
+ "mesh": parallel_state.fsdp_mesh,
+ "reshard_after_forward": enable_full_shard,
+ **kwargs.pop("fsdp_kwargs", {}),
+ }
+ if use_depth_align and ignore_depth:
+ model.model.dav2_backbone.to(torch.bfloat16)
+ model.model.dav2_head.to(torch.bfloat16)
+ model.model.dav2_backbone.eval()
+ model.model.dav2_head.eval()
+
+ ignored_modules = set()
+ for param in model.model.dav2_backbone.parameters():
+ param.requires_grad = False
+ ignored_modules.add(param)
+ for param in model.model.dav2_head.parameters():
+ param.requires_grad = False
+ ignored_modules.add(param)
+ mp_fsdp_kwargs["ignored_params"] = ignored_modules
+
+ mp_fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.float32,
+ output_dtype=torch.bfloat16,
+ )
+ ignore_modules_in_mixed_precision = tuple()
+ if hasattr(model, "get_ignore_modules_in_mixed_precision"):
+ ignore_modules_in_mixed_precision = model.get_ignore_modules_in_mixed_precision()
+
+ def apply_fsdp_to_decoder_blocks(module: "nn.Module") -> None:
+ if module.__class__.__name__ in basic_modules or module.__class__ in ignore_modules_in_mixed_precision:
+ logger.debug(f"Apply FSDP2 to {module.__class__.__name__}.")
+ if module.__class__ in ignore_modules_in_mixed_precision:
+ fully_shard(module, **{k: v for k, v in fsdp_kwargs.items() if k != "mp_policy"})
+ else:
+ fully_shard(module, **fsdp_kwargs)
+
+ if basic_modules:
+ model.apply(apply_fsdp_to_decoder_blocks)
+ elif fsdp_llm_blocks:
+ layers = model.model.qwenvl_with_expert.qwenvl.language_model.model.layers
+ expert_layers = model.model.qwenvl_with_expert.qwen_expert.model.layers
+ if not hasattr(layers, '__iter__') or not hasattr(expert_layers, '__iter__'):
+ raise TypeError("Expected 'layers' to be a module list or container.")
+ logger.info_rank0(f"Applying FSDP to {len(layers)} transformer layers in Paligemma and Gemma decoder.")
+ for i, layer in enumerate(layers):
+ logger.debug(f"Sharding layer {i} ({layer.__class__.__name__})")
+ fully_shard(layer, **fsdp_kwargs)
+ for i, layer in enumerate(expert_layers):
+ logger.debug(f"Sharding layer {i} ({layer.__class__.__name__})")
+ fully_shard(layer, **fsdp_kwargs)
+
+ fully_shard(model, **mp_fsdp_kwargs)
+
+ if kwargs.get("init_device") == "meta":
+ if weights_path is None:
+ # shard init empty model with fsdp2
+ model.to_empty(device="cuda")
+ model.init_weights()
+ else:
+ from torch.distributed.tensor import distribute_tensor
+
+ load_model_weights(model, weights_path, "cuda", dtensor_factory=distribute_tensor)
+
+ elif parallel_state.dp_mode == "fsdp1":
+ wrap_policy = partial(
+ lambda_auto_wrap_policy, lambda_fn=lambda module: module.__class__.__name__ in basic_modules
+ )
+
+ # set fsdp/hsdp sharding strategy
+ if parallel_state.fsdp_mesh.ndim > 1 and parallel_state.fsdp_mesh.size() > 1:
+ strategy = ShardingStrategy.HYBRID_SHARD
+ else:
+ strategy = ShardingStrategy.FULL_SHARD
+
+ fsdp_kwargs = {
+ "auto_wrap_policy": wrap_policy,
+ "ignored_states": fsdp_no_shard_states,
+ "device_id": torch.cuda.current_device(),
+ "sharding_strategy": strategy if enable_full_shard else ShardingStrategy.NO_SHARD,
+ "use_orig_params": True,
+ }
+
+ fsdp_kwargs["device_mesh"] = parallel_state.fsdp_mesh
+
+ fsdp_kwargs.update(kwargs.pop("fsdp_kwargs", {}))
+
+ if enable_mixed_precision:
+ logger.info_rank0("Enable mixed precision training.")
+ mixed_precision = MixedPrecision(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.float32,
+ buffer_dtype=torch.float32,
+ )
+ if hasattr(model, "get_ignore_modules_in_mixed_precision"):
+ mixed_precision._module_classes_to_ignore += model.get_ignore_modules_in_mixed_precision()
+
+ fsdp_kwargs["mixed_precision"] = mixed_precision
+
+ if kwargs.get("init_device") == "cpu":
+ logger.info_rank0("Enable rank0-only initialization.")
+ fsdp_kwargs["sync_module_states"] = True
+ if parallel_state.global_rank != 0:
+ fsdp_kwargs["param_init_fn"] = init_fsdp_fn(model, device="cuda")
+ elif kwargs.get("init_device") == "meta":
+ # assert weights_path is not None, "`weights_path` must be provided when `init_device=meta` for fsdp1."
+
+ logger.info_rank0("Enable meta initialization.")
+ if weights_path is None:
+ logger.info_rank0("weights_path is None during meta initialization.")
+
+ ignore_param_names = (
+ [".".join([fqn, k]) for fqn in fsdp_no_shard_states_fqn for k in ep_param_suffix]
+ if fsdp_no_shard_states_fqn is not None
+ else None
+ )
+ shard_states = (
+ parallel_load_safetensors(weights_path, ignore_param_name=ignore_param_names)
+ if weights_path
+ else kwargs.get("state_dict", {})
+ )
+ fsdp_kwargs["param_init_fn"] = parallel_init_fsdp_fn(
+ model, shard_states, ignore_param_name=ignore_param_names
+ )
+
+ if kwargs.pop("enable_fsdp_offload", False):
+ logger.info_rank0("Enable offloading for parameters & gradients & optimizer states.")
+ fsdp_kwargs["cpu_offload"] = CPUOffload(offload_params=True)
+
+ if kwargs.pop("enable_forward_prefetch", False):
+ fsdp_kwargs["forward_prefetch"] = True
+ else:
+ fsdp_kwargs["forward_prefetch"] = False
+ fsdp_kwargs["backward_prefetch"] = None
+
+ # FULLY_SHARD first
+ model = FullyShardedDataParallel(model, **fsdp_kwargs)
+
+ if fsdp_no_shard_states is not None:
+ # apply NO_SHARD the ignored_states, but wrap into DDP
+ if parallel_state.ep_fsdp_mesh["ep_fsdp"].size() == 1:
+ moe_sharding_strategy = ShardingStrategy.NO_SHARD
+ ep_fsdp_device_mesh = parallel_state.fsdp_mesh
+ else:
+ moe_sharding_strategy = ShardingStrategy.FULL_SHARD
+ ep_fsdp_device_mesh = parallel_state.ep_fsdp_mesh["ep_fsdp"]
+
+ logger.info_rank0(f"Apply {moe_sharding_strategy} states on '{fsdp_no_shard_states_fqn}'.")
+ fsdp_kwargs.pop("ignored_states", None)
+ fsdp_kwargs.pop("auto_wrap_policy", None)
+ fsdp_kwargs["sharding_strategy"] = moe_sharding_strategy
+ fsdp_kwargs["device_mesh"] = ep_fsdp_device_mesh
+ logger.info_rank0(f"{ep_fsdp_device_mesh=}")
+ for fqn in fsdp_no_shard_states_fqn:
+ no_shard_module = get_module_from_path(model, fqn)
+ if kwargs.get("init_device") == "meta":
+ specific_param_name = [".".join([fqn, k]) for k in ep_param_suffix]
+ shard_states = (
+ parallel_load_safetensors(weights_path, specific_param_name=specific_param_name)
+ if weights_path
+ else {}
+ )
+ if weights_path:
+ for suffix in ep_param_suffix:
+ shard_states[suffix] = shard_states.pop(".".join([fqn, suffix]))
+ fsdp_kwargs["param_init_fn"] = parallel_init_fsdp_fn(
+ no_shard_module, shard_states, specific_param_name=ep_param_suffix
+ )
+ fsdp_module = FullyShardedDataParallel(no_shard_module, **fsdp_kwargs)
+ fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(fsdp_module)
+ fsdp_state._gradient_postdivide_factor *= parallel_state.ep_size
+ set_module_from_path(model, fqn, fsdp_module)
+
+ _lazy_init(model, model)
+
+ # Apply fsdp extension to FSDP model
+ save_hook_mesh = parallel_state.ep_fsdp_device_mesh if parallel_state.ep_enabled else None
+ logger.info_rank0("Register Checkpoints Extension hook to the model")
+ register_checkpoint_extension(
+ fsdp_model=model,
+ save_hook_mesh=save_hook_mesh,
+ fqn2spec_info=fqn2spec_info,
+ )
+
+ if parallel_state.ep_enabled:
+ model.clip_grad_norm_ = types.MethodType(clip_grad_norm_, model)
+
+ verbose_fsdp_grouping(model)
+ else:
+ ddp_kwargs = {"device_ids": [parallel_state.local_rank]}
+ if enable_mixed_precision:
+ logger.info_rank0("Enable mixed precision training.")
+ if enable_fp32:
+ mixed_precision = MixedPrecision(
+ param_dtype=torch.float32,
+ reduce_dtype=torch.float32,
+ buffer_dtype=torch.float32,
+ )
+ else:
+ mixed_precision = MixedPrecision(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.float32,
+ buffer_dtype=torch.bfloat16,
+ )
+ ddp_kwargs["mixed_precision"] = mixed_precision
+
+ model = DDP(model, **ddp_kwargs)
+
+ return model
diff --git a/lingbotvla/distributed/utils.py b/lingbotvla/distributed/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..de87390c2502d222038c26f9dfe9a87b17f1148f
--- /dev/null
+++ b/lingbotvla/distributed/utils.py
@@ -0,0 +1,113 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import List
+
+import torch.nn as nn
+
+
+def set_module_from_path(model: nn.Module, path: str, value: any):
+ attrs = path.split(".")
+ if len(attrs) == 1:
+ setattr(model, attrs[0], value)
+ else:
+ next_obj = getattr(model, attrs[0])
+ set_module_from_path(next_obj, ".".join(attrs[1:]), value)
+
+
+def get_module_from_path(model: nn.Module, path: str):
+ attrs = path.split(".")
+ if len(attrs) == 1:
+ return getattr(model, attrs[0])
+ else:
+ next_obj = getattr(model, attrs[0])
+ return get_module_from_path(next_obj, ".".join(attrs[1:]))
+
+
+def check_all_fqn_match(path_patterns: List[str], path_keys: List[str]):
+ """
+ Check
+ """
+ assert isinstance(path_patterns, list), f"path_patterns must be a list, got {type(path_patterns)}"
+ assert isinstance(path_keys, (list, tuple)), f"path_keys must be a list or tuple, got {type(path_keys)}"
+
+ if len(path_patterns) != len(path_keys):
+ return False
+
+ regex_list = []
+ for pattern in path_patterns:
+ regex_str = re.escape(pattern).replace(r"\*", r"(\d+)")
+ regex_str = f"^{regex_str}$"
+ regex_list.append((pattern, re.compile(regex_str)))
+
+ used_patterns = set()
+ expected_num = None # the first matched number
+
+ for key in path_keys:
+ matched = False
+ for p, regex in regex_list:
+ if p in used_patterns:
+ continue
+ match = regex.match(key)
+ if match:
+ current_num = match.group(1)
+ if expected_num is None:
+ expected_num = current_num
+ elif current_num != expected_num:
+ return False
+ used_patterns.add(p)
+ matched = True
+ break
+ if not matched:
+ return False
+
+ return True
+
+
+def check_any_fqn_match(path_patterns: List[str], path_key: str, return_idx: bool = False, prefix: str = None):
+ assert isinstance(path_patterns, list), f"path_patterns must be a list, got {type(path_patterns)}"
+ assert isinstance(path_key, str), f"path_key must be a str, got {type(path_key)}"
+
+ if prefix:
+ path_patterns = [".".join([prefix, pattern]) for pattern in path_patterns]
+
+ regex_list = []
+ for pattern in path_patterns:
+ regex_str = re.escape(pattern).replace(r"\*", r"(\d+)")
+ regex_str = f"^{regex_str}$"
+ regex_list.append(re.compile(regex_str))
+
+ for idx, regex in enumerate(regex_list):
+ match = regex.match(path_key)
+ if match:
+ return idx if return_idx else True
+
+ return -1 if return_idx else False
+
+
+def check_fqn_match(fqn_pattern: str, fqn: str, prefix: str = None):
+ assert isinstance(fqn_pattern, str), f"fqn_pattern must be a str, got {type(fqn_pattern)}"
+ assert isinstance(fqn, str), f"fqn must be a str, got {type(fqn)}"
+
+ if prefix:
+ fqn_pattern = [".".join([prefix, pattern]) for pattern in fqn_pattern]
+
+ regex_str = re.escape(fqn_pattern).replace(r"\*", r"(\d+)")
+ regex_str = f"^{regex_str}$"
+ regex = re.compile(regex_str)
+
+ match = regex.match(fqn)
+
+ return match
diff --git a/lingbotvla/distributed/vescale_parallelize.py b/lingbotvla/distributed/vescale_parallelize.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc8599c7201fece56c66916913d58644334f57ce
--- /dev/null
+++ b/lingbotvla/distributed/vescale_parallelize.py
@@ -0,0 +1,139 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import gc
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+import torch
+
+from ..utils import logging
+from .parallel_state import get_parallel_state
+
+
+if TYPE_CHECKING:
+ from torch import nn
+ from vescale import DeviceMesh
+
+
+logger = logging.get_logger(__name__)
+
+
+def build_parallelize_model(
+ model: "nn.Module",
+ dp_mode: str,
+ hf_weight_path: Optional[str] = None,
+ enable_full_shard: bool = True,
+ enable_fsdp_offload: bool = False,
+ enable_mixed_precision: bool = True,
+ enable_gradient_checkpointing: bool = True,
+ basic_modules: Optional[List[str]] = None,
+ enable_reentrant: bool = True,
+ use_pin_mem_for_offload: bool = True,
+) -> Tuple["nn.Module", "DeviceMesh"]:
+ """
+ Build a parallelized model with Vescale.
+ """
+ logger.info_rank0("Apply vescale parallel to the model.")
+ parallel_state = get_parallel_state()
+
+ assert dp_mode in ["fsdp2", "fsdp2-vescale"]
+ params_stored_in_dtensor = dp_mode == "fsdp2"
+ mesh = parallel_state.fsdp_mesh
+
+ if enable_mixed_precision:
+ model.float()
+
+ module_init_fn = lambda sub_mod, *_: sub_mod # noqa: E731
+ if hf_weight_path is not None:
+ from vescale.initialize.hf_utils import parallel_init_module_fn, parallel_load_safetensors
+
+ shard_states = parallel_load_safetensors(hf_weight_path)
+ module_init_fn = parallel_init_module_fn(model, shard_states)
+
+ from vescale import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy, fully_shard
+
+ if enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
+ logger.info_rank0("Enable gradient checkpointing.")
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": enable_reentrant})
+
+ # mp policy
+ mp_policy = MixedPrecisionPolicy()
+ if enable_mixed_precision:
+ mp_policy = MixedPrecisionPolicy(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.float32,
+ output_dtype=torch.bfloat16,
+ )
+
+ # cpu off load policy
+ cpu_offload_policy = OffloadPolicy()
+ if enable_fsdp_offload:
+ cpu_offload_policy = CPUOffloadPolicy(pin_memory=use_pin_mem_for_offload)
+
+ last_fsdp_module = None
+ for module in model.modules():
+ sub_mod_cls_name = module.__class__.__name__
+ if (sub_mod_cls_name in basic_modules) or (sub_mod_cls_name in model._no_split_modules):
+ module_init_fn(module)
+ if enable_fsdp_offload:
+ module.cpu()
+ gc.collect()
+ torch.cuda.empty_cache()
+ else:
+ model.cuda()
+ fully_shard(
+ module,
+ mesh=mesh,
+ reshard_after_forward=enable_full_shard,
+ mp_policy=mp_policy,
+ params_stored_in_dtensor=params_stored_in_dtensor,
+ offload_policy=cpu_offload_policy,
+ )
+ # explicit prefetch
+ if last_fsdp_module is not None:
+ last_fsdp_module.set_modules_to_forward_prefetch([module])
+ module.set_modules_to_backward_prefetch([last_fsdp_module])
+ last_fsdp_module = module
+
+ module_init_fn(model)
+ model = fully_shard(
+ model,
+ mesh=mesh,
+ reshard_after_forward=enable_full_shard,
+ mp_policy=mp_policy,
+ params_stored_in_dtensor=params_stored_in_dtensor,
+ offload_policy=cpu_offload_policy,
+ )
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # NOTE: uncomment below for saving memory fragmentation
+ model._set_unshard_async_op(True)
+
+ # for root module, we don't need to reshard after backward since forward will imediately use it
+ # model.set_reshard_after_backward(False, recurse=False)
+ # NOTE: the above line is WRONG in torch-native fsdp2's senmantic, as resulting logic follows:
+ # -) after backward, it is gradient clip to normalize model.parameters()'s grad
+ # -) at this time, model.parameters is unsharded param, which has already moved .grad to shard_param.grad, so unshard param.grad is always None
+ # -) then None grad disable gradient clip, which is WRONG!
+ # -) Even if we have no clip gradient, the optimizer step gives updated weight, which is never used in the next forward; as optimizer step only updates sharded_param, not unshard param
+ # -) but next forward of root is already in unsharded state, so never allgather from updated sharded param, which is WRONG again!
+
+ if not hasattr(mesh, "ndevice"):
+ # bytecheckpoint vescale ckpt use vescale device mesh, but here we have torch-native devicemesh, which does not have ndevice attribute
+ ndevice_func = lambda self: torch.numel(self.mesh) # noqa: E731
+ mesh.__class__.ndevice = property(ndevice_func)
+
+ return model, mesh
diff --git a/lingbotvla/distributed/vescale_plan.py b/lingbotvla/distributed/vescale_plan.py
new file mode 100644
index 0000000000000000000000000000000000000000..01c1d4cde8f889872c5c11379b6e6cf787a67105
--- /dev/null
+++ b/lingbotvla/distributed/vescale_plan.py
@@ -0,0 +1,107 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from vescale.plan import ParallelType, VescalePlan
+
+
+# TODO: add more model type
+SET_TP_SHARD_PLAN_FUNC = {}
+
+
+def build_vescale_plan(
+ model_config: dict,
+ tp_size: int = 1,
+ pp_size: int = 1,
+ use_doptim: bool = False,
+ use_fsdp: bool = False,
+ use_manual_eager: bool = False,
+ use_mixed_precision: bool = True,
+ clip_grad: float = 0.0,
+):
+ """Build parallel plan for P6 model.
+
+ Args:
+ model_config (dict): model config dict
+ tp_size (int): size of tensor parallelism
+ pp_size (int): size of pipeline parallelism
+ use_doptim (bool): whether to use DistributedOptimizer (zero)
+ use_manual_eager (bool): whether to use manual eager for tensor parallelism
+ use_mixed_precision (bool): whether to enable mixed precision, where parameters will be saved
+ and updated in additional fp32 copy, and gradients will be accumulated with fp32.
+ clip_grad (float): gradient clipping threshould
+ """
+ model_type = model_config.model_type
+ if use_doptim and use_fsdp:
+ raise RuntimeError("Cannot simutaneously use FSDP and DistributedOptimizer.")
+
+ plan = VescalePlan()
+
+ # get device mesh
+ ngpus = torch.distributed.get_world_size()
+ if ngpus % (tp_size * pp_size) != 0:
+ raise ValueError("total gpu number must be divisible by tp_size * pp_size ")
+ if pp_size > 1:
+ raise NotImplementedError("pp size only support 1")
+ dp_size = ngpus // (tp_size * pp_size)
+ print(f"creating {tp_size} tp, {pp_size} pp, {dp_size} dp...")
+
+ mesh = {}
+ # setup dp mesh
+ dp_name = ParallelType.FSDP if use_fsdp else ParallelType.DP
+ mesh[dp_name] = dp_size
+ # setup tp mesh
+ if tp_size > 1:
+ mesh[ParallelType.TP] = tp_size
+ # setup pp mesh
+ if pp_size > 1:
+ mesh[ParallelType.PP] = pp_size
+
+ plan.set_global_mesh("cuda", mesh)
+
+ # tensor parallel
+ if tp_size > 1:
+ plan = SET_TP_SHARD_PLAN_FUNC[model_type](plan, tp_size, model_config, use_manual_eager)
+
+ # dist optimizer: this must go before setting up data parallel
+ # due to `use_distributed_optimizer field`
+ if use_doptim:
+ plan.dist_optimizer(grad_to_fp32=use_mixed_precision, overlap_param_gather=False, clip_grad=clip_grad)
+
+ # data parallel fsdp / ddp
+ if use_fsdp:
+ if use_doptim:
+ raise ValueError("fsdp and doptim can not be used together")
+ if tp_size > 1:
+ raise NotImplementedError("vescale FSDP cannot work with TP for now")
+
+ from vescale.fsdp.api import MixedPrecision, ShardingStrategy
+
+ mp = None
+ if use_mixed_precision:
+ mp = MixedPrecision(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.float32,
+ )
+ plan.dist_fsdp(
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
+ mixed_precision=mp,
+ )
+ else:
+ plan.dist_data_parallel(
+ grad_in_fp32=use_mixed_precision,
+ overlap_grad_reduce=False,
+ )
+ return plan
diff --git a/lingbotvla/models/__init__.py b/lingbotvla/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03f3b416da75a12fd277037afb32756aa7916e9e
--- /dev/null
+++ b/lingbotvla/models/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .auto import build_foundation_model, build_processor, build_tokenizer
+from .module_utils import init_empty_weights, load_model_weights, save_model_assets, save_model_weights
+
+
+__all__ = [
+ "build_foundation_model",
+ "build_processor",
+ "build_tokenizer",
+ "init_empty_weights",
+ "load_model_weights",
+ "save_model_assets",
+ "save_model_weights",
+]
diff --git a/lingbotvla/models/auto.py b/lingbotvla/models/auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2e2a67f49c3427523e277d0c26242444c7cff6d
--- /dev/null
+++ b/lingbotvla/models/auto.py
@@ -0,0 +1,152 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional
+
+import torch
+from transformers import (
+ AutoConfig,
+ AutoProcessor,
+ AutoTokenizer,
+ PreTrainedModel,
+)
+from lerobot.configs.policies import PreTrainedConfig
+from ..distributed.parallel_state import get_parallel_state
+from ..utils import logging
+from .loader import BaseModelLoader, get_loader
+
+if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizer, ProcessorMixin
+
+logger = logging.get_logger(__name__)
+
+
+def build_tokenizer(tokenizer_path: str) -> "PreTrainedTokenizer":
+ """
+ Builds the tokenizer.
+ """
+ return AutoTokenizer.from_pretrained(tokenizer_path, padding_side="right", trust_remote_code=True)
+
+
+def build_processor(processor_path: str) -> "ProcessorMixin":
+ """
+ Builds the processor.
+ """
+ return AutoProcessor.from_pretrained(processor_path, padding_side="right", trust_remote_code=True)
+
+
+def build_foundation_model(
+ config_path: str,
+ weights_path: Optional[str] = None,
+ torch_dtype: Literal["float16", "bfloat16", "float32"] = "bfloat16",
+ attn_implementation: Optional[Literal["eager", "sdpa", "flash_attention_2", "flex"]] = "flash_attention_2",
+ moe_implementation: Optional[Literal["eager", "fused"]] = None,
+ init_device: Literal["cpu", "cuda", "meta"] = "cuda",
+ freeze_vision_encoder: Optional[bool] = False,
+ tokenizer_max_length: Optional[int] = 48,
+ vocab_size: Optional[int] = 0,
+ use_lm_head: Optional[bool] = False,
+ config_kwargs: Optional[Dict[str, Any]] = None,
+ force_use_huggingface: bool = False,
+) -> "PreTrainedModel":
+ """
+ Builds the foundation model.
+
+ If weights_path is provided, it loads the pre-trained weights, otherwise it initializes weights.
+ """
+ if config_kwargs is None:
+ config_kwargs = {}
+ vlm_repo_id = config_kwargs['vlm_repo_id'] if 'vlm_repo_id' in config_kwargs else None
+ expert_vision_path = config_kwargs['expert_vision_path'] if 'expert_vision_path' in config_kwargs else None
+ tokenizer_path = config_kwargs['tokenizer_path'] if 'tokenizer_path' in config_kwargs else None
+ post_training = config_kwargs['post_training']
+ adanorm_time = config_kwargs['adanorm_time']
+ assert not (config_kwargs['split_gate_liner'] and config_kwargs['nosplit_gate_liner']), 'split_gate_liner and nosplit_gate_liner can not be both True.'
+ enable_expert_vision = config_kwargs['enable_expert_vision']
+ incremental_training = config_kwargs['incremental_training']
+ depth_incremental_training = config_kwargs['depth_incremental_training']
+ norm_qkv = config_kwargs['norm_qkv']
+ loss_type = config_kwargs['loss_type']
+ config = PreTrainedConfig.from_pretrained(config_path)
+ config.train_state_proj = True
+ config.adanorm_time = adanorm_time
+ config.split_gate_liner = config_kwargs['split_gate_liner']
+ config.nosplit_gate_liner = config_kwargs['nosplit_gate_liner']
+ config.separate_time_proj = config_kwargs['separate_time_proj']
+ config.old_adanorm = config_kwargs['old_adanorm']
+ config.final_norm_adanorm = config_kwargs['final_norm_adanorm']
+ config.freeze_vision_encoder = freeze_vision_encoder
+ config.tokenizer_max_length = tokenizer_max_length
+ config.attention_implementation = 'flex' # TODO
+ config.enable_expert_vision = config_kwargs['enable_expert_vision']
+ config.expert_vision_type = config_kwargs['expert_vision_type']
+ config.action_dim = config_kwargs['action_dim']
+ config.max_action_dim = config_kwargs['max_action_dim']
+ config.max_state_dim = config_kwargs['max_state_dim']
+ config.n_action_steps = config_kwargs['chunk_size']
+ config.vlm_repo_id = vlm_repo_id
+ config.expert_vision_path = expert_vision_path
+ config.tokenizer_path = tokenizer_path
+ config.loss_type = loss_type
+ config.align_params = config_kwargs['align_params']
+ config.norm_qkv = config_kwargs['norm_qkv']
+ config.use_lm_head = use_lm_head
+ if vocab_size == 0:
+ if vlm_repo_id and 'paligemma' in vlm_repo_id:
+ config.vocab_size = 257216
+ # elif vlm_repo_id and 'qwen' in vlm_repo_id.lower() and 'fast' in vlm_repo_id.lower():
+ # config.vocab_size = 153715
+ elif vlm_repo_id and 'qwen' in vlm_repo_id.lower():
+ config.vocab_size = 151936
+ else:
+ config.vocab_size = 257152
+ else:
+ config.vocab_size = vocab_size
+
+ if moe_implementation is not None:
+ if moe_implementation not in ["eager", "fused"]:
+ raise ValueError(f"Invalid moe_implementation: {moe_implementation}")
+ config._moe_implementation = moe_implementation
+ logger.info_rank0(f"Moe implementation: {moe_implementation}")
+
+ loader: Optional[BaseModelLoader] = get_loader(config, force_use_huggingface)
+ init_kwargs = {
+ "config": config,
+ "torch_dtype": getattr(torch, torch_dtype),
+ "attn_implementation": attn_implementation,
+ "ckpt_path": weights_path,
+ "trust_remote_code": True,
+ }
+
+ if (init_device == "cpu" and get_parallel_state().global_rank != 0) or init_device == "meta":
+ empty_init = True
+ else:
+ empty_init = False
+ weights_path = vlm_repo_id if vlm_repo_id else weights_path
+ model = loader.load_model(
+ init_kwargs=init_kwargs,
+ weights_path=weights_path,
+ empty_init=empty_init,
+ init_device=init_device,
+ vlm_repo_id=vlm_repo_id,
+ expert_vision_path=expert_vision_path,
+ post_training=post_training,
+ adanorm_time=adanorm_time,
+ incremental_training=incremental_training,
+ depth_incremental_training=depth_incremental_training,
+ norm_qkv=norm_qkv,
+ enable_expert_vision=enable_expert_vision,
+ )
+ return model
diff --git a/lingbotvla/models/loader.py b/lingbotvla/models/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..9006172f37e3c9614f05edeeecdb812525a385f4
--- /dev/null
+++ b/lingbotvla/models/loader.py
@@ -0,0 +1,190 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_loader/loader.py
+
+from abc import ABC
+
+import torch
+from transformers import AutoModel, AutoModelForCausalLM, AutoModelForVision2Seq, PreTrainedModel
+from transformers.modeling_utils import no_init_weights
+from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
+from ..utils import logging
+from ..utils.import_utils import is_torch_npu_available, is_vescale_available
+from .module_utils import init_empty_weights, load_model_weights
+from .registry import get_registry
+
+
+logger = logging.get_logger(__name__)
+
+
+class BaseModelLoader(ABC):
+ def __init__(self):
+ pass
+
+ def load_model(self, model_config, **kwargs):
+ raise NotImplementedError
+
+
+class HuggingfaceLoader(BaseModelLoader):
+ def __init__(self):
+ super().__init__()
+
+ def load_model(self, init_kwargs: dict, **kwargs):
+ model_config = init_kwargs["config"]
+ architecture = _get_model_arch_from_config(model_config)
+
+ if type(model_config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models
+ load_class = AutoModelForVision2Seq
+ elif "ForCausalLM" in architecture and type(model_config) in AutoModelForCausalLM._model_mapping.keys():
+ load_class = AutoModelForCausalLM
+ else:
+ load_class = AutoModel
+
+ init_device = kwargs.pop("init_device", "cuda")
+ weights_path = kwargs.pop("weights_path", None)
+ empty_init = kwargs.pop("empty_init", False)
+
+ logger.info_rank0(
+ f"Loading model from Huggingface modeling.\n"
+ f"init_device: {init_device}\n"
+ f"empty_init: {empty_init}\n"
+ f"weights_path: {weights_path}"
+ )
+
+ if weights_path is None: # init empty model from config
+ if is_torch_npu_available() and init_device == "cuda":
+ init_device = "npu"
+ if init_device == "meta":
+ with torch.device(init_device), no_init_weights():
+ logger.info_rank0("Init empty model on meta device from config without init_weights.")
+ model = load_class.from_config(**init_kwargs)
+ else:
+ with torch.device(init_device):
+ logger.info_rank0("Init empty model from config.")
+ model = load_class.from_config(**init_kwargs)
+ else:
+ if is_vescale_available() and init_device == "meta":
+ from vescale.initialize.meta_init import meta_device_init
+
+ with meta_device_init():
+ model = load_class.from_config(**init_kwargs)
+ else:
+ with init_empty_weights(), no_init_weights():
+ model = load_class.from_config(**init_kwargs)
+ if not empty_init:
+ load_model_weights(model, weights_path, init_device)
+
+ return model
+
+
+class CustomizedModelingLoader(BaseModelLoader):
+ def __init__(self, model_cls: PreTrainedModel):
+ super().__init__()
+ self.model_cls = model_cls # model class from code_path
+
+ def load_model(self, init_kwargs: dict, **kwargs):
+ init_kwargs.pop("trust_remote_code", True)
+
+ init_device = kwargs.pop("init_device", "cuda")
+ weights_path = kwargs.pop("weights_path", None)
+ empty_init = kwargs.pop("empty_init", False)
+ vlm_repo_id = kwargs.pop("vlm_repo_id", None)
+ enable_expert_vision = kwargs.pop("enable_expert_vision", False)
+ expert_vision_path = kwargs.pop("expert_vision_path", None)
+ post_training = kwargs.pop("post_training", False)
+ adanorm_time = kwargs.pop("adanorm_time", False)
+ incremental_training = kwargs.pop("incremental_training", False)
+ depth_incremental_training = kwargs.pop("depth_incremental_training", False)
+ norm_qkv = kwargs.pop("norm_qkv", False)
+
+ logger.info_rank0(
+ f"Loading model from customized modeling.\n"
+ f"init_device: {init_device}\n"
+ f"empty_init: {empty_init}\n"
+ f"weights_path: {weights_path}"
+ )
+
+ if weights_path is None: # init empty model from config
+ if is_torch_npu_available() and init_device == "cuda":
+ init_device = "npu"
+ if init_device == "meta":
+ with torch.device(init_device), no_init_weights():
+ logger.info_rank0("Init empty model on meta device from config without init_weights.")
+ model = self.model_cls._from_config(**init_kwargs)
+ else:
+ with torch.device(init_device):
+ logger.info_rank0("Init empty model from config.")
+ model = self.model_cls._from_config(**init_kwargs)
+ else:
+ load_vlm_only = False
+ if is_vescale_available() and init_device == "meta":
+ from vescale.initialize.meta_init import meta_device_init
+
+ with meta_device_init():
+ model = self.model_cls._from_config(**init_kwargs)
+ else:
+ with init_empty_weights(), no_init_weights():
+ if (self.model_cls.__name__ == "PI0Policy" and
+ self.model_cls.__module__ == "lingbotvla.models.vla.pi0.modeling_pi0"):
+ model = self.model_cls(config=init_kwargs['config'], tokenizer_path=init_kwargs['config'].tokenizer_path).to(init_kwargs['torch_dtype'])
+ if vlm_repo_id is not None:
+ load_vlm_only = True
+ elif (self.model_cls.__name__ == "LingbotVlaPolicy" and
+ self.model_cls.__module__ == "lingbotvla.models.vla.pi0.modeling_lingbot_vla"):
+ model = self.model_cls(config=init_kwargs['config'], tokenizer_path=init_kwargs['config'].tokenizer_path).to(init_kwargs['torch_dtype'])
+ if vlm_repo_id is not None and incremental_training:
+ load_vlm_only = True
+ else:
+ model = self.model_cls._from_config(**init_kwargs)
+
+ if not empty_init:
+ load_model_weights(model, weights_path, init_device, load_vlm_only=load_vlm_only, enable_expert_vision=enable_expert_vision, expert_vision_path=expert_vision_path, post_training=post_training, incremental_training=incremental_training, depth_incremental_training=depth_incremental_training, norm_qkv=norm_qkv, adanorm_time=adanorm_time)
+
+ # we should tie embeddings after loading weights because init_empty_weights() leads to untied weights,
+ if getattr(model.config, "tie_word_embeddings", True):
+ try:
+ input_embeddings = model.get_input_embeddings()
+ output_embeddings = model.get_output_embeddings()
+ output_embeddings._parameters["weight"] = input_embeddings._parameters["weight"]
+ except Exception as e:
+ logger.info_rank0(f"Failed to tie embeddings: {e}")
+
+ return model
+
+
+def _get_model_arch_from_config(model_config):
+ arch_name = model_config.architectures
+ if isinstance(arch_name, list):
+ arch_name = arch_name[0]
+ return arch_name
+
+
+def get_loader(model_config, force_use_huggingface):
+ if isinstance(model_config, PI0Config):
+ if 'qwen' not in model_config.tokenizer_path.lower():
+ model_arch = 'PI0Policy'
+ elif 'qwen2' in model_config.tokenizer_path.lower():
+ model_arch = 'LingbotVlaPolicy'
+ else:
+ model_arch = _get_model_arch_from_config(model_config) # Qwen2VLForConditionalGeneration
+ loader = HuggingfaceLoader()
+ if not force_use_huggingface:
+ model_registry = get_registry()
+ if model_arch in model_registry.supported_models:
+ model_cls = model_registry.get_model_cls_from_model_arch(model_arch)
+ loader = CustomizedModelingLoader(model_cls=model_cls)
+
+ return loader
diff --git a/lingbotvla/models/modeling_layers.py b/lingbotvla/models/modeling_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..57be2d8e0d7dc7af50f847f699bf80b707bbb98e
--- /dev/null
+++ b/lingbotvla/models/modeling_layers.py
@@ -0,0 +1,48 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import torch.nn as nn
+
+
+class GradientCheckpointingLayer(nn.Module):
+ """Base class for layers with gradient checkpointing.
+
+ This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled
+ (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is
+ enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`.
+
+ Important:
+
+ When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states)
+ must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients.
+
+ Example:
+
+ ```python
+ >>> # Correct - hidden_states passed as positional arg
+ >>> out = self.layer(hidden_states, attention_mask=attention_mask)
+
+ >>> # Incorrect - hidden_states passed as keyword arg
+ >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
+ ```
+ """
+
+ gradient_checkpointing = False
+
+ def __call__(self, *args, **kwargs):
+ if self.gradient_checkpointing and self.training:
+ return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
+ return super().__call__(*args, **kwargs)
diff --git a/lingbotvla/models/module_utils.py b/lingbotvla/models/module_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..074d838b05ce61cd69646b41675f670ce00aea0d
--- /dev/null
+++ b/lingbotvla/models/module_utils.py
@@ -0,0 +1,494 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+import os
+from collections import OrderedDict
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Optional, Sequence, Tuple, Union
+
+import torch
+from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME as DIFFUSERS_SAFE_WEIGHTS_INDEX_NAME
+from diffusers.utils import SAFETENSORS_WEIGHTS_NAME as DIFFUSERS_SAFETENSORS_WEIGHTS_NAME
+from torch import distributed as dist
+from torch import nn
+from tqdm import tqdm
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
+from transformers.utils.hub import cached_file, get_checkpoint_shard_files
+from transformers.utils.import_utils import is_safetensors_available
+
+from ..utils import logging
+from ..utils.helper import empty_cache, get_dtype_size
+
+
+if is_safetensors_available():
+ from safetensors import safe_open
+ from safetensors.torch import save_file
+
+
+if TYPE_CHECKING:
+ from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+
+ ModelAssets = Union[GenerationConfig, PretrainedConfig, PreTrainedTokenizer, ProcessorMixin]
+
+
+logger = logging.get_logger(__name__)
+
+
+@contextmanager
+def init_empty_weights():
+ """
+ A context manager under which models are initialized with all parameters on the meta device.
+
+ Borrowed from: https://github.com/huggingface/accelerate/blob/v1.0.0rc1/src/accelerate/big_modeling.py#L57
+ """
+ old_register_parameter = nn.Module.register_parameter
+
+ def register_empty_parameter(module: "nn.Module", name: str, param: "nn.Parameter"):
+ old_register_parameter(module, name, param)
+ if param is not None:
+ param_cls = type(module._parameters[name])
+ kwargs = module._parameters[name].__dict__
+ kwargs["requires_grad"] = param.requires_grad
+ module._parameters[name] = param_cls(module._parameters[name].to("meta"), **kwargs)
+
+ try:
+ nn.Module.register_parameter = register_empty_parameter
+ yield
+ finally:
+ nn.Module.register_parameter = old_register_parameter
+
+
+@dataclass
+class StateDictIterator:
+ filepath: str
+ prefix: str = ''
+
+ def __iter__(self) -> Generator[Tuple[str, "torch.Tensor"], None, None]:
+ if self.filepath.endswith(".safetensors"):
+ with safe_open(self.filepath, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ yield key, f.get_tensor(key)
+
+ else:
+ state_dict = torch.load(self.filepath, map_location="cpu", weights_only=True, mmap=True)
+ for key in state_dict.keys():
+ yield key, state_dict[key]
+
+
+def _load_state_dict(weights_path: str, expert_vision_path: str | None = None, **kwargs) -> List["StateDictIterator"]:
+ """
+ Loads (sharded) state dict in transformers' format.
+ """
+ cache_kwargs = {"_raise_exceptions_for_missing_entries": False, **kwargs}
+ resolved_weight_file = cached_file(weights_path, SAFE_WEIGHTS_NAME, **cache_kwargs)
+ if resolved_weight_file:
+ return [StateDictIterator(resolved_weight_file)]
+
+ resolved_weight_file = cached_file(weights_path, SAFE_WEIGHTS_INDEX_NAME, **cache_kwargs)
+ if resolved_weight_file:
+ if expert_vision_path is not None:
+ shard_files, _ = get_checkpoint_shard_files(expert_vision_path, resolved_weight_file, **kwargs)
+ else:
+ shard_files, _ = get_checkpoint_shard_files(weights_path, resolved_weight_file, **kwargs)
+ return [StateDictIterator(shard_file) for shard_file in shard_files]
+
+ resolved_weight_file = cached_file(weights_path, DIFFUSERS_SAFETENSORS_WEIGHTS_NAME, **cache_kwargs)
+ if resolved_weight_file:
+ return [StateDictIterator(resolved_weight_file)]
+
+ resolved_weight_file = cached_file(weights_path, DIFFUSERS_SAFE_WEIGHTS_INDEX_NAME, **cache_kwargs)
+ if resolved_weight_file:
+ shard_files, _ = get_checkpoint_shard_files(weights_path, resolved_weight_file, **kwargs)
+ return [StateDictIterator(shard_file) for shard_file in shard_files]
+
+ resolved_weight_file = cached_file(weights_path, WEIGHTS_NAME, **cache_kwargs)
+ if resolved_weight_file:
+ return [StateDictIterator(resolved_weight_file)]
+
+ resolved_weight_file = cached_file(weights_path, WEIGHTS_INDEX_NAME, **cache_kwargs)
+ if resolved_weight_file:
+ shard_files, _ = get_checkpoint_shard_files(weights_path, resolved_weight_file, **kwargs)
+ return [StateDictIterator(shard_file) for shard_file in shard_files]
+
+ raise ValueError(f"Cannot find checkpoint files in {weights_path}.")
+
+
+def _find_submodule(module: "nn.Module", name: str) -> Tuple["nn.Module", str]:
+ """
+ Finds the leaf module according to the name.
+ """
+ pieces = name.split(".")
+ for piece in pieces[:-1]:
+ if not hasattr(module, piece):
+ raise ValueError(f"Cannot find {piece} in {module}.")
+
+ module = getattr(module, piece)
+
+ return module, pieces[-1]
+
+
+def _dispatch_parameter(
+ module: "nn.Module",
+ name: str,
+ tensor: "torch.Tensor",
+ dtensor_factory: Optional[Callable[["torch.Tensor", Any, Any], "torch.Tensor"]] = None,
+) -> None:
+ """
+ Assigns parameter to an empty model.
+
+ NOTE: FSDP module must use in-place operators.
+ """
+ module, name = _find_submodule(module, name)
+ orig_tensor = module._parameters[name].data
+ tensor = tensor.to(orig_tensor)
+ if hasattr(orig_tensor, "device_mesh"): # dtensor
+ if orig_tensor.device.type == "cpu":
+ raise ValueError("Cannot load dtensor on CPU.")
+
+ device_mesh = getattr(orig_tensor, "device_mesh")
+ placements = getattr(orig_tensor, "placements")
+ module._parameters[name].data.copy_(dtensor_factory(tensor, device_mesh, placements))
+ else: # not dtensor
+ module._parameters[name].data.copy_(tensor)
+
+
+def _dispatch_buffer(
+ module: "nn.Module",
+ name: str,
+ buffer: "torch.Tensor",
+) -> None:
+ """
+ Assigns buffer to an empty model.
+ """
+ module, name = _find_submodule(module, name)
+ orig_tensor = module._buffers[name].data
+ module._buffers[name] = buffer.to(orig_tensor)
+
+
+def _init_parameter(
+ module: "nn.Module",
+ name: str,
+) -> None:
+ """
+ Initializes parameter in model.
+ """
+ pieces = name.split(".")
+ init_func = None
+
+ for piece in pieces[:-1]:
+ if not hasattr(module, piece):
+ raise ValueError(f"Cannot find {piece} in {module}.")
+
+ if hasattr(module, "_init_weights"):
+ init_func = getattr(module, "_init_weights")
+
+ module = getattr(module, piece)
+
+ if init_func is None:
+ print(module)
+ raise ValueError(f"Cannot retrieve `_init_weights` function in the parents of {module}.")
+
+ module.apply(init_func)
+
+def get_model_prefix(parameter_names):
+ vlm_prefix = ''
+ for param_name in parameter_names:
+ parts = param_name.split('.')
+ if parts[1]=='qwenvl_with_expert' and 'expert' not in parts[2]:
+ vlm_prefix='.'.join(parts[:3])+'.'
+ break
+ return vlm_prefix
+
+@torch.no_grad()
+def load_model_weights(
+ model: Union["nn.Module", "PreTrainedModel"],
+ weights_path: str,
+ init_device: Literal["cpu", "cuda"] = "cuda",
+ dtensor_factory: Optional[Callable[["torch.Tensor", Any, Any], "torch.Tensor"]] = None,
+ load_vlm_only: bool = False,
+ enable_expert_vision: bool = False,
+ expert_vision_path: str | None = None,
+ post_training: bool = False,
+ incremental_training: bool = False,
+ depth_incremental_training: bool = False,
+ norm_qkv: bool = False,
+ adanorm_time: bool = False,
+) -> None:
+ """
+ Loads pre-trained model states in transformers' format.
+ """
+ buffer_dict = {name: buffer.clone() for name, buffer in model.named_buffers()}
+ parameter_names = {name for name, _ in model.named_parameters()}
+ vlm_parameter_names = {name for name, _ in model.model.qwenvl_with_expert.qwenvl.named_parameters()}
+ print(f'====vlm contains {len(vlm_parameter_names)} paras=====')
+ if expert_vision_path is not None or enable_expert_vision:
+ dino_parameter_names = {name for name, _ in model.model.qwenvl_with_expert.expert_visual.named_parameters()}
+ print(f'====dino contains {len(dino_parameter_names)} paras=====')
+ model.to_empty(device=init_device)
+ if post_training:
+ logger.info_rank0(f">>> Doing Post-Training now, no need to load LLM's embedding weight.")
+ elif incremental_training:
+ logger.info_rank0(f">>> Load pretrained weights for incremental training.")
+ elif load_vlm_only:
+ logger.info_rank0(f">>> Doing Pre-Training now.")
+ else:
+ logger.info_rank0(f">>> Fine-tuneing based on PI0 now.")
+ # TODO
+ state_dict_iterators = _load_state_dict(weights_path, expert_vision_path)
+ vlm_perfix = get_model_prefix(parameter_names) if load_vlm_only else ''
+ for state_dict_iterator in tqdm(
+ state_dict_iterators, desc="Loading checkpoint shards", disable=int(os.getenv("LOCAL_RANK", "-1")) > 0
+ ):
+ for name, tensor in state_dict_iterator:
+ if 'expert_visual.' in name and not post_training:
+ name = 'model.qwenvl_with_expert.'+name
+ else:
+ name = vlm_perfix+name
+ if name in buffer_dict.keys(): # persistent buffers
+ buffer_dict[name] = tensor.clone()
+ elif name in parameter_names:
+ if incremental_training:
+ try:
+ _dispatch_parameter(model, name, tensor, dtensor_factory)
+ parameter_names.remove(name)
+ except:
+ logger.info_rank0(f">>>The {name} weight need to be reinitialized.")
+ else:
+ parameter_names.remove(name)
+ _dispatch_parameter(model, name, tensor, dtensor_factory)
+ else:
+ if post_training:
+ error_msg = f"Unexpected key '{name}' found in state dict during Post-Training. This is not allowed!!!"
+ logger.info_rank0(error_msg)
+ raise KeyError(error_msg)
+ if expert_vision_path is not None or enable_expert_vision:
+ assert '.expert_visual.' not in name, "vision encoder need to be inited for action expert!"
+ logger.info_rank0(f"Unexpected key in state dict: {name}.")
+
+ del state_dict_iterator
+ empty_cache()
+
+ for name, buffer in buffer_dict.items():
+ _dispatch_buffer(model, name, buffer)
+ if post_training:
+ assert len(parameter_names) == 0, f"Missing {parameter_names} during Post-Training. This is not allowed!!!"
+ if len(parameter_names) > 0:
+ if load_vlm_only and (expert_vision_path is not None or enable_expert_vision) and not incremental_training:
+ num_missing_vlm_para, num_missing_dino_para = 0, 0
+ for name in parameter_names:
+ if '.paligemma.' in name or '.qwenvl.' in name:
+ num_missing_vlm_para += 1
+ elif '.expert_visual.' in name:
+ num_missing_dino_para += 1
+ print(f'====Missing {num_missing_vlm_para} paras in vlm====')
+ print(f'====Missing {num_missing_dino_para} paras in DINO====')
+ assert (all('.paligemma.' not in name for name in parameter_names) or all('.qwenvl.' not in name for name in parameter_names)) and all('.expert_visual.' not in name for name in parameter_names), "Parameters in VLM and Expert_Visual are not loaded when PreTraining!!!"
+ elif incremental_training and not depth_incremental_training:
+ if norm_qkv:
+ assert all('_proj.' in name or '_layernorm.' in name for name in parameter_names), "Only MLP weight can be reinitialized when IncrementalTraining!!!"
+ else:
+ assert all('_proj.' in name or 'gate' in name for name in parameter_names), "Only MLP weight can be reinitialized when IncrementalTraining!!!"
+ elif depth_incremental_training:
+ assert all('depth_align_head.' in name for name in parameter_names), "Only depth align head weight can be reinitialized when IncrementalTraining with Depth Model!!!"
+ elif load_vlm_only:
+ num_missing_vlm_para = 0
+ for name in parameter_names:
+ if '.paligemma.' in name or '.qwenvl.' in name:
+ num_missing_vlm_para += 1
+ print(f'====Missing {num_missing_vlm_para} paras in vlm====')
+ assert all('.paligemma.' not in name for name in parameter_names) or all('.qwenvl.' not in name for name in parameter_names), \
+ "Parameters in VLM are not loaded when PreTraining!!!"
+ logger.info_rank0(f"Find missing key(s) in state dict: {parameter_names}, initialize them.")
+ if adanorm_time:
+ logger.info_rank0(">>> Parameters in AdaNorm has been ZERO initialized.")
+ exclude_keywords = [
+ "input_layernorm.gamma_beta_gate",
+ "post_attention_layernorm.gamma_beta_gate",
+ "norm.gamma_beta_gate",
+ "input_layernorm.gamma",
+ "post_attention_layernorm.gamma",
+ "norm.gamma",
+ "input_layernorm.beta",
+ "post_attention_layernorm.beta",
+ "norm.beta",
+ "input_layernorm.gate",
+ "post_attention_layernorm.gate",
+ "norm.gate",
+ ]
+ for name in parameter_names:
+ if not adanorm_time:
+ _init_parameter(model, name)
+ else:
+ if not any(keyword in name for keyword in exclude_keywords):
+ _init_parameter(model, name)
+
+ # we should tie embeddings after loading weights because to_empty() leads to untied weights,
+ # except for fsdp1 (custom init) and fsdp2 (swap tensor) contexts.
+ if getattr(model.config, "tie_word_embeddings", True):
+ try:
+ input_embeddings = model.get_input_embeddings()
+ output_embeddings = model.get_output_embeddings()
+ output_embeddings._parameters["weight"] = input_embeddings._parameters["weight"]
+ except Exception as e:
+ logger.info_rank0(f"Failed to tie embeddings: {e}")
+
+
+def _get_shard_info(
+ state_dict: Dict[str, "torch.Tensor"],
+ save_dtype: Optional[Union[str, "torch.dtype"]],
+ shard_size: int,
+ safe_serialization: bool,
+) -> Tuple[bool, int, Dict[str, str]]:
+ """
+ Gets the shard information, should be executed at rank 0.
+ """
+ current_size, total_size = 0, 0
+ current_shard, shard_list = [], []
+ for name, tensor in state_dict.items():
+ if isinstance(save_dtype, str):
+ dtype = getattr(torch, save_dtype)
+ elif isinstance(save_dtype, torch.dtype):
+ dtype = save_dtype
+ else:
+ dtype = tensor.dtype
+ tensor_size = tensor.numel() * get_dtype_size(dtype) # dtensor's numel == tensor's numel
+ if current_size != 0 and current_size + tensor_size > shard_size:
+ total_size += current_size
+ shard_list.append(current_shard)
+ current_size = 0
+ current_shard = []
+
+ current_size += tensor_size
+ current_shard.append(name)
+
+ if current_size != 0:
+ total_size += current_size
+ shard_list.append(current_shard)
+
+ weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+ num_shards = len(shard_list)
+ weight_map = OrderedDict()
+ is_sharded = None
+ if num_shards == 1:
+ is_sharded = False
+ for name in shard_list[0]:
+ weight_map[name] = weights_name
+ else:
+ is_sharded = True
+ for shard_idx, shard in enumerate(shard_list):
+ prefix, extension = weights_name.rsplit(".", maxsplit=1)
+ file_name = f"{prefix}-{shard_idx + 1:05d}-of-{num_shards:05d}.{extension}"
+ for name in shard:
+ weight_map[name] = file_name
+
+ return is_sharded, total_size, weight_map
+
+
+def _save_state_dict(
+ state_dict: Dict[str, "torch.Tensor"],
+ path_to_save: "os.PathLike",
+ safe_serialization: bool,
+) -> None:
+ """
+ Save function.
+ """
+ if os.path.exists(path_to_save):
+ os.remove(path_to_save)
+ if safe_serialization:
+ save_file(state_dict, path_to_save, metadata={"format": "pt"})
+ else:
+ torch.save(state_dict, path_to_save)
+
+
+@torch.no_grad()
+def save_model_weights(
+ output_dir: Union[str, "os.PathLike"],
+ state_dict: Dict[str, "torch.Tensor"],
+ global_rank: Optional[int] = None,
+ save_dtype: Optional[Union[str, "torch.dtype"]] = "bfloat16",
+ shard_size: int = 5_000_000_000,
+ safe_serialization: bool = True,
+ model_assets: Optional[Sequence["ModelAssets"]] = None,
+) -> None:
+ """
+ Saves full model weights. The model parameters should be either tensor or dtensor.
+
+ If global_rank is given, it will assume it is executed on all ranks.
+ """
+
+ os.makedirs(output_dir, exist_ok=True)
+ is_sharded, total_size, weight_map = _get_shard_info(state_dict, save_dtype, shard_size, safe_serialization)
+ full_state_dict = OrderedDict()
+ prev_file_name = None
+ for name, tensor in state_dict.items():
+ if hasattr(tensor.data, "full_tensor"): # dtensor
+ tensor = tensor.data.full_tensor()
+ else:
+ tensor = tensor.data
+
+ if save_dtype:
+ tensor = tensor.to(dtype=getattr(torch, save_dtype) if isinstance(save_dtype, str) else save_dtype)
+
+ if prev_file_name is not None and weight_map[name] != prev_file_name:
+ if global_rank is None or global_rank == 0:
+ _save_state_dict(full_state_dict, os.path.join(output_dir, prev_file_name), safe_serialization)
+ full_state_dict = OrderedDict()
+
+ empty_cache()
+ if global_rank is not None and dist.is_initialized(): # avoid process hanging
+ torch.cuda.synchronize()
+ dist.barrier()
+
+ if global_rank is None or global_rank == 0:
+ full_state_dict[name] = tensor.detach().cpu()
+
+ prev_file_name = weight_map[name]
+ del tensor
+
+ if global_rank is None or global_rank == 0:
+ if len(full_state_dict):
+ _save_state_dict(full_state_dict, os.path.join(output_dir, prev_file_name), safe_serialization)
+
+ if is_sharded:
+ index = {
+ "metadata": {"total_size": total_size},
+ "weight_map": weight_map,
+ }
+
+ index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+ with open(os.path.join(output_dir, index_file), "w", encoding="utf-8") as f:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ f.write(content)
+
+ logger.info(f"Model weight splits saved in {output_dir}.")
+ else:
+ logger.info(f"Model weights saved at {os.path.join(output_dir, prev_file_name)}.")
+
+ if model_assets is not None:
+ for model_asset in model_assets:
+ if hasattr(model_asset, "save_pretrained"):
+ model_asset.save_pretrained(output_dir)
+ else:
+ logger.warning(f"Model asset {model_asset} should implement `save_pretrained`.")
+
+
+def save_model_assets(output_dir: Union[str, "os.PathLike"], model_assets: Sequence["ModelAssets"]):
+ for model_asset in model_assets:
+ if hasattr(model_asset, "save_pretrained"):
+ model_asset.save_pretrained(output_dir)
+ else:
+ logger.warning(f"Model asset {model_asset} should implement `save_pretrained`.")
diff --git a/lingbotvla/models/registry.py b/lingbotvla/models/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c2e734274410bbef5b3301cb5d9f98d85c5ccb2
--- /dev/null
+++ b/lingbotvla/models/registry.py
@@ -0,0 +1,76 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/registry.py
+
+import importlib
+import pkgutil
+from dataclasses import dataclass, field
+from functools import lru_cache
+from typing import Dict, List, Type, Union
+
+import torch.nn as nn
+
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+MODELING_PATH = ["lingbotvla.models.vla"]
+
+@dataclass
+class _ModelRegistry:
+ # Keyed by model_arch
+ modeling_path: List[str] = field(default_factory=list)
+ model_arch_name_to_cls: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
+
+ def __post_init__(self):
+ for modeling_path in self.modeling_path:
+ self._mapping_model_arch_name_to_cls(modeling_path)
+
+ @property
+ def supported_models(self) -> Dict[str, Type[nn.Module]]:
+ return self.model_arch_name_to_cls.keys()
+
+ def get_model_cls_from_model_arch(self, model_arch: str) -> Type[nn.Module]:
+ return self.model_arch_name_to_cls[model_arch]
+
+ def _mapping_model_arch_name_to_cls(self, modeling_path: str):
+ package = importlib.import_module(modeling_path)
+ for _, name, ispkg in pkgutil.walk_packages(package.__path__, modeling_path + "."):
+ if not ispkg:
+ try:
+ module = importlib.import_module(name)
+ except Exception as e:
+ logger.warning(f"Ignore import error when loading {name}. {e}")
+ continue
+ if hasattr(module, "ModelClass"):
+ entry = module.ModelClass
+ if isinstance(entry, list):
+ for tmp in entry:
+ assert tmp.__name__ not in self.model_arch_name_to_cls, (
+ f"Duplicated model implementation for {tmp.__name__}"
+ )
+ self.model_arch_name_to_cls[tmp.__name__] = tmp
+ else:
+ assert entry.__name__ not in self.model_arch_name_to_cls, (
+ f"Duplicated model implementation for {entry.__name__}"
+ )
+ self.model_arch_name_to_cls[entry.__name__] = entry
+
+
+@lru_cache
+def get_registry():
+ return _ModelRegistry(modeling_path=MODELING_PATH)
diff --git a/lingbotvla/models/vla/__init__.py b/lingbotvla/models/vla/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..048e827f681d247a0e7fc9aa4d86e876c85e8235
--- /dev/null
+++ b/lingbotvla/models/vla/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2026 Robbyant Team and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import pi0
+
+
+__all__ = ["pi0"]
diff --git a/lingbotvla/models/vla/pi0/__init__.py b/lingbotvla/models/vla/pi0/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b070999eebec96f6e367cafe41229755d1dd5fc1
--- /dev/null
+++ b/lingbotvla/models/vla/pi0/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2026 Robbyant Team and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/lingbotvla/models/vla/pi0/flex_attention.py b/lingbotvla/models/vla/pi0/flex_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3e32b8cd10b688d8c2df3a685f2a647fc9d6e42
--- /dev/null
+++ b/lingbotvla/models/vla/pi0/flex_attention.py
@@ -0,0 +1,148 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn.functional as F # noqa: N812
+from packaging.version import Version
+import einops
+import ipdb
+
+if Version(torch.__version__) > Version("2.5.0"):
+ # Ffex attention is only available from torch 2.5 onwards
+ from torch.nn.attention.flex_attention import (
+ _mask_mod_signature,
+ _round_up_to_multiple,
+ create_block_mask,
+ create_mask,
+ flex_attention,
+ )
+
+# @torch.compile(dynamic=False)
+def flex_attention_forward(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ scaling=None,
+):
+ """
+ This is defined out of classes to make compile happy.
+ """
+ batch_size, seq_len, num_att_heads, head_dim = query_states.shape # head_dim=256
+ original_dtype = query_states.dtype
+ num_key_value_heads = key_states.shape[2] # 1
+ num_key_value_groups = num_att_heads // num_key_value_heads # 8 // 1
+ # key_states = key_states[:, :, :, None, :]
+ # key_states = key_states.expand(
+ # batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
+ # )
+ # key_states = key_states.reshape(
+ # batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
+ # )
+
+ # value_states = value_states[:, :, :, None, :]
+ # value_states = value_states.expand(
+ # batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
+ # )
+ # value_states = value_states.reshape(
+ # batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
+ # )
+
+ key_states = einops.repeat(
+ key_states, "b l h d -> b l (h g) d", g=num_key_value_groups
+ )
+ value_states = einops.repeat(
+ value_states, "b l h d -> b l (h g) d", g=num_key_value_groups
+ )
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ query_states = query_states.to(torch.float32)
+ key_states = key_states.to(torch.float32)
+ value_states = value_states.to(torch.float32)
+
+ causal_mask = attention_mask
+ if causal_mask is not None:
+ causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
+
+ if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
+ causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
+
+ def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
+ def mask_mod(b, h, q_idx, kv_idx):
+ # Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
+ return precomputed_mask[b][h][q_idx][kv_idx]
+
+ return mask_mod
+
+ b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
+ # ipdb.set_trace()
+ block_size = 128
+ q_len_rounded = _round_up_to_multiple(q_len, block_size)
+ kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
+
+ # *CRITICAL* we do need to expand here, else we get a CUDA index error
+
+ pad_q = q_len_rounded - q_len
+ pad_k = kv_len_rounded - kv_len
+
+ if pad_q > 0:
+ query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) # [B, H, q_len_rounded, D]
+ if pad_k > 0:
+ key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0)
+ value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0)
+ padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
+ mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
+
+ mask_4d = create_mask(
+ mod_fn=mask_mod_fn_orig,
+ B=b_mask,
+ H=h_mask,
+ Q_LEN=q_len_rounded,
+ KV_LEN=kv_len_rounded,
+ device=causal_mask.device,
+ )
+
+ mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
+ block_mask = create_block_mask(
+ mask_mod=mask_mod_fn_padded,
+ B=b_mask,
+ H=h_mask,
+ Q_LEN=q_len_rounded,
+ KV_LEN=kv_len_rounded,
+ BLOCK_SIZE=block_size,
+ device=causal_mask.device,
+ _compile=False,
+ )
+
+ # mask is applied inside the kernel, ideally more efficiently than score_mod.
+ attn_output, attention_weights = flex_attention(
+ query_states,
+ key_states,
+ value_states,
+ block_mask=block_mask,
+ enable_gqa=True, # because we shaped query/key states for GQA
+ scale=head_dim**-0.5 if scaling is None else scaling,
+ return_lse=True,
+ )
+ attn_output = attn_output[:, :, :seq_len, :].to(dtype=original_dtype)
+ attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
+ attn_output = attn_output.reshape(
+ batch_size,
+ -1,
+ attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
+ )
+ return attn_output
diff --git a/lingbotvla/models/vla/pi0/modeling_lingbot_vla.py b/lingbotvla/models/vla/pi0/modeling_lingbot_vla.py
new file mode 100644
index 0000000000000000000000000000000000000000..46a6b232276174293a9c7605bbaa556471812f2a
--- /dev/null
+++ b/lingbotvla/models/vla/pi0/modeling_lingbot_vla.py
@@ -0,0 +1,2061 @@
+import einops
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
+from lerobot.common.policies.pretrained import PreTrainedPolicy
+from torch import Tensor, nn
+from typing import List, Optional, Tuple, Union, Callable, Dict, Any
+from functools import partial
+from transformers import (
+ AutoConfig,
+ PretrainedConfig,
+ PreTrainedModel,
+)
+from transformers.models.auto import CONFIG_MAPPING
+from transformers import AutoTokenizer
+from dataclasses import dataclass
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel, ALL_ATTENTION_FUNCTIONS
+from transformers.utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+ LossKwargs,
+ can_return_tuple,
+ is_torch_flex_attn_available,
+)
+from transformers.utils.deprecation import deprecate_kwarg
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs, is_flash_attn_available
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.processing_utils import Unpack
+import torch.distributed._tensor as dt
+from .qwenvl_in_vla import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel, Qwen2_5_VLPreTrainedModel
+from .vla_flash_attn_policy import use_flash_attention_2_for_vla
+
+try:
+ from dinov3.hub.backbones import (
+ dinov3_vits16,
+ dinov3_vits16plus,
+ dinov3_vitb16,
+ )
+except: pass
+from .utils import (
+ create_sinusoidal_pos_embedding,
+ make_att_2d_masks,
+ resize_with_pad,
+ sample_beta,
+)
+from .utils import apply_rope, our_eager_attention_forward
+from .flex_attention import flex_attention_forward
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
+_CONFIG_FOR_DOC = "Qwen2Config"
+
+
+from lingbotvla.models.vla.vision_models.align_heads.depth_head import DepthHead, TaskTokenDepthHead
+
+
+class Qwen2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Qwen2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Qwen2Config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ sliding_window = None
+ if (
+ self.config.use_sliding_window
+ and getattr(self.config, "sliding_window", None) is not None
+ and self.layer_idx >= self.config.max_window_layers
+ ):
+ sliding_window = self.config.sliding_window
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=sliding_window, # main diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+class Qwen2RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Qwen2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+class FixQwen2RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ FixQwen2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+class Qwen2DecoderLayer(nn.Module):
+ def __init__(self, config: Qwen2Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
+ self.mlp = Qwen2MLP(config)
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ if config.norm_qkv:
+ self.q_layernorm = Qwen2RMSNorm(self.self_attn.head_dim, eps=config.rms_norm_eps)
+ self.k_layernorm = Qwen2RMSNorm(self.self_attn.head_dim, eps=config.rms_norm_eps)
+
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ att_output: Optional[torch.Tensor] = None,
+ start: Optional[int] = 0,
+ end: Optional[int] = 0,
+ compute_kqv: bool = False,
+ norm_qkv: bool = False,
+ old_adanorm: bool = False,
+ output_atten: bool = False,
+ ada_cond: Optional[torch.Tensor] = None,
+ gate: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ if compute_kqv:
+ if ada_cond is not None:
+ if old_adanorm:
+ hidden_states = self.input_layernorm(hidden_states, ada_cond)
+ gate = None
+ else:
+ hidden_states, gate = self.input_layernorm(hidden_states, ada_cond)
+ else:
+ hidden_states = self.input_layernorm(hidden_states)
+ gate = None
+ hidden_shape = (*hidden_states.shape[:-1], -1, self.self_attn.head_dim)
+
+ query_state = self.self_attn.q_proj(hidden_states).view(hidden_shape)
+ key_state = self.self_attn.k_proj(hidden_states).view(hidden_shape)
+ value_state = self.self_attn.v_proj(hidden_states).view(hidden_shape)
+ if norm_qkv:
+ query_state = self.q_layernorm(query_state)
+ key_state = self.k_layernorm(key_state)
+
+ return query_state, key_state, value_state, gate
+
+ elif output_atten:
+ if att_output.dtype != self.self_attn.o_proj.weight.dtype:
+ att_output = att_output.to(self.self_attn.o_proj.weight.dtype)
+ out_emb = self.self_attn.o_proj(att_output[:, start:end])
+
+ # first residual
+ if gate is not None:
+ out_emb = out_emb * gate + hidden_states
+ else:
+ out_emb += hidden_states
+ after_first_residual = out_emb.clone()
+ if ada_cond is not None:
+ if old_adanorm:
+ out_emb = self.post_attention_layernorm(out_emb, ada_cond)
+ after_gate= None
+ else:
+ out_emb, after_gate = self.post_attention_layernorm(out_emb, ada_cond)
+ else:
+ out_emb = self.post_attention_layernorm(out_emb)
+ after_gate = None
+ out_emb = self.mlp(out_emb)
+
+ # second residual
+ if after_gate is not None:
+ out_emb = out_emb * after_gate + after_first_residual
+ else:
+ out_emb += after_first_residual
+
+ return out_emb
+
+ else:
+ raise ValueError(f"Invaild Operation compute_kqv={compute_kqv} and output_atten={output_atten} with Qwen2DecoderLayer in LingBot-VLA")
+
+class Qwen2RotaryEmbedding(nn.Module):
+ def __init__(self, config: Qwen2Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+QWEN2_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Qwen2Config`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+@add_start_docstrings(
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
+ QWEN2_START_DOCSTRING,
+)
+class Qwen2PreTrainedModel(PreTrainedModel):
+ config_class = Qwen2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen2DecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+QWEN2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+@add_start_docstrings(
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
+ QWEN2_START_DOCSTRING,
+)
+class Qwen2Model(Qwen2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
+
+ Args:
+ config: Qwen2Config
+ """
+
+ def __init__(self, config: Qwen2Config, eval=False):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = FixQwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ if eval:
+ self._init_weights = lambda module: None
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> BaseModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
+ if not isinstance(past_key_values, (type(None), Cache)):
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ partial(decoder_layer.__call__, **flash_attn_kwargs),
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **flash_attn_kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and past_key_values is not None:
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+ if is_padding_right:
+ raise ValueError(
+ "You are attempting to perform batched generation with padding_side='right'"
+ " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+ )
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if (
+ self.config._attn_implementation == "sdpa"
+ and not (using_static_cache or using_sliding_window_cache)
+ and not output_attentions
+ ):
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ sliding_window=self.config.sliding_window,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ # SlidingWindowCache or StaticCache
+ if using_sliding_window_cache or using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ # DynamicCache or no cache
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ config=self.config,
+ past_key_values=past_key_values,
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ config: Qwen2Config,
+ past_key_values: Cache,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to place the 4D attention mask on.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ config (`Qwen2Config`):
+ The model's configuration class
+ past_key_values (`Cache`):
+ The cache class that is being used currently to generate
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+ )
+ diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ if config.sliding_window is not None:
+ # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
+ # the check is needed to verify is current checkpoint was trained with sliding window or not
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+ sliding_attend_mask = torch.arange(target_length, device=device) <= (
+ cache_position.reshape(-1, 1) - config.sliding_window
+ )
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+ causal_mask *= diagonal_attend_mask
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ if attention_mask.shape[-1] > target_length:
+ attention_mask = attention_mask[:, :target_length]
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+ return causal_mask
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config, eval):
+ super().__init__(config)
+ self.model = Qwen2Model(config, eval)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
+
+ >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+class QwenvlWithExpertConfig(PretrainedConfig):
+ model_type = "QwenvlWithExpertModel"
+ sub_configs = {"qwenvl_config": AutoConfig, "qwen_expert_config": AutoConfig}
+
+ def __init__(
+ self,
+ qwenvl_config: dict | None = None,
+ qwen_expert_config: dict | None = None,
+ freeze_vision_encoder: bool = True,
+ train_expert_only: bool = True,
+ vocab_size: int = 257152,
+ use_lm_head: bool = False,
+ attention_implementation: str = "eager",
+ tokenizer_path: str | None = None,
+ enable_expert_vision: bool = False,
+ expert_vision_type: str | None = None,
+ **kwargs,
+ ):
+ self.freeze_vision_encoder = freeze_vision_encoder
+ self.train_expert_only = train_expert_only
+ self.attention_implementation = attention_implementation
+ self.tokenizer_path = tokenizer_path
+ self.enable_expert_vision = enable_expert_vision
+ self.expert_vision_type = expert_vision_type
+ self.vocab_size = vocab_size
+ self.use_lm_head = use_lm_head
+ if qwenvl_config is None:
+ self.qwenvl_config = CONFIG_MAPPING["qwen2_5_vl"](
+ attention_dropout=0.0,
+ bos_token_id=151643,
+ eos_token_id=151645,
+ vision_start_token_id=151652,
+ vision_end_token_id=151653,
+ vision_token_id=151654,
+ image_token_id=151655,
+ video_token_id=151656,
+ hidden_act="silu",
+ hidden_size=2048,
+ initializer_range=0.02,
+ intermediate_size=11008,
+ max_position_embeddings=128000,
+ max_window_layers=70,
+ model_type="qwen2_5_vl",
+ num_attention_heads=16,
+ num_hidden_layers=36,
+ num_key_value_heads=2,
+ rms_norm_eps=1e-06,
+ rope_theta=1000000.0,
+ sliding_window=32768,
+ tie_word_embeddings=True,
+ torch_dtype="bfloat16",
+ transformers_version="4.41.2",
+ use_cache=True,
+ use_sliding_window=False,
+ vision_config={
+ "depth": 32,
+ "hidden_act": "silu",
+ "hidden_size": 1280,
+ "intermediate_size": 3420,
+ "num_heads": 16,
+ "in_chans": 3,
+ "out_hidden_size": 2048,
+ "patch_size": 14,
+ "spatial_merge_size": 2,
+ "spatial_patch_size": 14,
+ "window_size": 112,
+ "fullatt_block_indexes": [
+ 7,
+ 15,
+ 23,
+ 31
+ ],
+ "tokens_per_second": 2,
+ "temporal_patch_size": 2
+ },
+ rope_scaling={
+ "type": "mrope",
+ "mrope_section": [
+ 16,
+ 24,
+ 24
+ ]
+ },
+ vocab_size=151936,
+ )
+ elif isinstance(self.qwenvl_config, dict):
+ if "model_type" not in qwen_expert_config:
+ qwenvl_config["model_type"] = "qwen2_5_vl"
+
+ cfg_cls = CONFIG_MAPPING[qwenvl_config["model_type"]]
+ self.qwenvl_config = cfg_cls(**qwenvl_config)
+
+ if qwen_expert_config is None:
+ self.qwen_expert_config = CONFIG_MAPPING["qwen2"](
+ attention_dropout=0.0,
+ bos_token_id=151643,
+ eos_token_id=151645,
+ hidden_act="silu",
+ hidden_size=768,
+ head_dim=128,
+ initializer_range=0.02,
+ intermediate_size=2752,
+ max_position_embeddings=32768,
+ max_window_layers=21,
+ model_type="qwen2",
+ num_attention_heads=16,
+ num_hidden_layers=36,
+ num_key_value_heads=2,
+ rms_norm_eps=1e-06,
+ rope_theta=1000000.0,
+ sliding_window=32768,
+ tie_word_embeddings=True,
+ torch_dtype="bfloat16",
+ transformers_version="4.43.1",
+ use_cache=True,
+ use_sliding_window=False,
+ vocab_size=151936,
+ )
+ elif isinstance(self.qwen_expert_config, dict):
+ if "model_type" not in qwen_expert_config:
+ qwen_expert_config["model_type"] = "qwen2"
+
+ cfg_cls = CONFIG_MAPPING[qwenvl_config["model_type"]]
+ self.qwen_expert_config = cfg_cls(**qwen_expert_config)
+
+ super().__init__(**kwargs)
+
+ def __post_init__(self):
+ super().__post_init__()
+ if self.train_expert_only and not self.freeze_vision_encoder:
+ raise ValueError(
+ "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
+ )
+
+ if self.attention_implementation not in ["eager", "fa2", "flex"]:
+ raise ValueError(
+ f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
+ )
+
+class OldAdaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, cond_dim, eps=1e-6):
+ """
+ AdaRMSNorm: RMSNorm + FiLM
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+ self.gamma = nn.Linear(cond_dim, hidden_size)
+ self.beta = nn.Linear(cond_dim, hidden_size)
+
+ # DiT style init: gamma.weight=0, gamma.bias=1; beta.weight=0, beta.bias=0
+ nn.init.zeros_(self.gamma.weight)
+ nn.init.zeros_(self.gamma.bias)
+ nn.init.zeros_(self.beta.weight)
+ nn.init.zeros_(self.beta.bias)
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ nn.init.constant_(module.weight, 0.0)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0.0)
+
+ def forward(self, hidden_states, cond):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ hidden_states = self.weight * hidden_states
+ gamma = self.gamma(cond).unsqueeze(1) # [B, 1, H]
+ beta = self.beta(cond).unsqueeze(1) # [B, 1, H]
+ hidden_states = (1 + gamma.to(torch.float32)) * hidden_states + beta.to(torch.float32)
+ return hidden_states.to(input_dtype)
+
+ # def extra_repr(self):
+ # return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+class AdaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, cond_dim, split_gate_liner, no_split_gate_liner, eps=1e-6):
+ """
+ AdaRMSNorm: RMSNorm + FiLM
+ """
+ super().__init__()
+ if not (split_gate_liner or no_split_gate_liner):
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+ self.use_gate = split_gate_liner or no_split_gate_liner
+ if not no_split_gate_liner:
+ self.gamma = nn.Linear(cond_dim, hidden_size)
+ self.beta = nn.Linear(cond_dim, hidden_size)
+ if self.use_gate:
+ self.gate = nn.Linear(cond_dim, hidden_size)
+ nn.init.zeros_(self.gate.weight)
+ nn.init.zeros_(self.gate.bias)
+
+ # DiT style init: gamma.weight=0, gamma.bias=1; beta.weight=0, beta.bias=0
+ nn.init.zeros_(self.gamma.weight)
+ nn.init.zeros_(self.gamma.bias)
+ nn.init.zeros_(self.beta.weight)
+ nn.init.zeros_(self.beta.bias)
+ else:
+ self.gamma_beta_gate = nn.Linear(cond_dim, hidden_size * 3, bias=True)
+ nn.init.zeros_(self.gamma_beta_gate.weight)
+ nn.init.zeros_(self.gamma_beta_gate.bias)
+ self.no_split_gate_liner = no_split_gate_liner
+ self.split_gate_liner = split_gate_liner
+
+ def forward(self, hidden_states, cond):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ if not (self.split_gate_liner or self.no_split_gate_liner):
+ hidden_states = self.weight * hidden_states
+ if not self.no_split_gate_liner:
+ gamma = self.gamma(cond).unsqueeze(1) # [B, 1, H]
+ beta = self.beta(cond).unsqueeze(1) # [B, 1, H]
+ if self.use_gate:
+ gate = self.gate(cond).unsqueeze(1) # [B, 1, H]
+ else:
+ gate = None
+ hidden_states = (1 + gamma.to(torch.float32)) * hidden_states + beta.to(torch.float32)
+ else:
+ modulation = self.gamma_beta_gate(cond)
+ if len(hidden_states.shape) == 3: # [batch, seq, features]
+ modulation = modulation.unsqueeze(1)
+ gamma, beta, gate = torch.chunk(modulation, 3, dim=-1)
+ hidden_states = (1 + gamma.to(torch.float32)) * hidden_states + beta.to(torch.float32)
+ return hidden_states.to(input_dtype), gate
+
+ # def extra_repr(self):
+ # return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+class FixAdaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, cond_dim, split_gate_liner, no_split_gate_liner, eps=1e-6):
+ """
+ AdaRMSNorm: RMSNorm + FiLM
+ """
+ super().__init__()
+ if not (split_gate_liner or no_split_gate_liner):
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+ self.use_gate = split_gate_liner or no_split_gate_liner
+ if not no_split_gate_liner:
+ self.gamma = nn.Linear(cond_dim, hidden_size)
+ self.beta = nn.Linear(cond_dim, hidden_size)
+ if self.use_gate:
+ self.gate = nn.Linear(cond_dim, hidden_size)
+ nn.init.zeros_(self.gate.weight)
+ nn.init.zeros_(self.gate.bias)
+
+ # DiT style init: gamma.weight=0, gamma.bias=1; beta.weight=0, beta.bias=0
+ nn.init.zeros_(self.gamma.weight)
+ nn.init.zeros_(self.gamma.bias)
+ nn.init.zeros_(self.beta.weight)
+ nn.init.zeros_(self.beta.bias)
+ else:
+ self.gamma_beta_gate = nn.Linear(cond_dim, hidden_size * 3, bias=True)
+ nn.init.zeros_(self.gamma_beta_gate.weight)
+ nn.init.zeros_(self.gamma_beta_gate.bias)
+ self.no_split_gate_liner = no_split_gate_liner
+ self.split_gate_liner = split_gate_liner
+
+ def forward(self, hidden_states, cond):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ if not (self.split_gate_liner or self.no_split_gate_liner):
+ hidden_states = self.weight * hidden_states
+ if not self.no_split_gate_liner:
+ gamma = self.gamma(cond).unsqueeze(1) # [B, 1, H]
+ beta = self.beta(cond).unsqueeze(1) # [B, 1, H]
+ if self.use_gate:
+ gate = self.gate(cond).unsqueeze(1) # [B, 1, H]
+ else:
+ gate = None
+ hidden_states = (1 + gamma.to(torch.float32)) * hidden_states + beta.to(torch.float32)
+ else:
+ modulation = self.gamma_beta_gate(cond)
+ if len(hidden_states.shape) == 3: # [batch, seq, features]
+ modulation = modulation.unsqueeze(1)
+ gamma, beta, gate = torch.chunk(modulation, 3, dim=-1)
+ hidden_states = (1 + gamma.to(torch.float32)) * hidden_states + beta.to(torch.float32)
+ return hidden_states.to(input_dtype), gate
+
+# HACK: show directly use this norm during initialization
+def replace_lnorm_with_adanorm(module, hidden_size, cond_dim, split_gate_liner, no_split_gate_liner, final_norm_adanorm, old_adanorm):
+ if old_adanorm:
+ for name, child in module.named_children():
+ if isinstance(child, Qwen2RMSNorm):
+ if 'q_layernorm' not in name and 'k_layernorm' not in name:
+ setattr(module, name, OldAdaRMSNorm(hidden_size, cond_dim))
+ else:
+ replace_lnorm_with_adanorm(child, hidden_size, cond_dim, split_gate_liner, no_split_gate_liner, final_norm_adanorm, old_adanorm)
+ else:
+ for name, child in module.named_children():
+ if final_norm_adanorm:
+ if isinstance(child, Qwen2RMSNorm):
+ if 'q_layernorm' not in name and 'k_layernorm' not in name:
+ setattr(module, name, AdaRMSNorm(hidden_size, cond_dim, split_gate_liner, no_split_gate_liner))
+ elif isinstance(child, FixQwen2RMSNorm):
+ if 'q_layernorm' not in name and 'k_layernorm' not in name:
+ setattr(module, name, FixAdaRMSNorm(hidden_size, cond_dim, split_gate_liner, no_split_gate_liner))
+ else:
+ replace_lnorm_with_adanorm(child, hidden_size, cond_dim, split_gate_liner, no_split_gate_liner, final_norm_adanorm, old_adanorm)
+ else:
+ if isinstance(child, Qwen2RMSNorm):
+ if 'q_layernorm' not in name and 'k_layernorm' not in name:
+ setattr(module, name, AdaRMSNorm(hidden_size, cond_dim, split_gate_liner, no_split_gate_liner))
+ else:
+ replace_lnorm_with_adanorm(child, hidden_size, cond_dim, split_gate_liner, no_split_gate_liner, final_norm_adanorm, old_adanorm)
+
+class QwenvlWithExpertModel(PreTrainedModel):
+ config_class = QwenvlWithExpertConfig
+
+ def __init__(self, config: QwenvlWithExpertConfig, eval=False):
+ super().__init__(config=config)
+ self.config = config
+ vlm_config = AutoConfig.from_pretrained(self.config.tokenizer_path)
+ vlm_config.vision_config.initializer_range = 0.02
+ vlm_config.norm_qkv = self.config.norm_qkv
+ if self.config.vocab_size != 0 and self.config.vocab_size != 257152 and vlm_config.vocab_size != self.config.vocab_size:
+ vlm_config.vocab_size = self.config.vocab_size
+ _fa2 = use_flash_attention_2_for_vla()
+ self.qwenvl = Qwen2_5_VLForConditionalGeneration._from_config(vlm_config, use_flash_attention_2=_fa2)
+ if self.config.use_lm_head:
+ self.qwenvl.tie_weights()
+ self.config.qwen_expert_config.norm_qkv = self.config.norm_qkv
+ self.qwen_expert = Qwen2ForCausalLM._from_config(
+ self.config.qwen_expert_config, use_flash_attention_2=_fa2, eval=eval
+ )
+
+ if getattr(self.config, 'adanorm_time', False):
+ replace_lnorm_with_adanorm(self.qwen_expert, self.config.qwen_expert_config.hidden_size, self.config.qwen_expert_config.hidden_size, config.split_gate_liner, config.no_split_gate_liner, config.final_norm_adanorm, config.old_adanorm)
+ # Remove unused embed_tokens
+ del self.qwen_expert.model.embed_tokens
+ if self.config.enable_expert_vision:
+ if 'dinov3_vitb16' in self.config.expert_vision_type:
+ self.expert_visual = dinov3_vitb16(pretrained=False)
+ self.expert_visual_mlp = nn.Sequential(
+ nn.Linear(self.expert_visual.embed_dim, self.expert_visual.embed_dim * 2),
+ nn.GELU(),
+ nn.Linear(self.expert_visual.embed_dim * 2, self.config.qwen_expert_config.hidden_size),
+ )
+ self.attention_interface = self.get_attention_interface()
+
+ # self.to_bfloat16_like_physical_intelligence()
+ self.set_requires_grad()
+
+ def set_requires_grad(self):
+ """sets the requires_grad attribute of the model parameters based on the configuration.
+ If `freeze_vision_encoder` is True, the vision tower parameters are frozen.
+ If `train_expert_only` is True, the entire Qwenvl model is frozen.
+ """
+ if self.config.freeze_vision_encoder:
+ self.qwenvl.visual.eval()
+ for params in self.qwenvl.visual.parameters():
+ params.requires_grad = False
+
+ if self.config.train_expert_only:
+ self.qwenvl.eval()
+ for params in self.qwenvl.parameters():
+ params.requires_grad = False
+
+ def train(self, mode: bool = True):
+ super().train(mode)
+ if self.config.freeze_vision_encoder:
+ self.qwenvl.visual.eval()
+ if self.config.train_expert_only:
+ self.qwenvl.eval()
+
+ def to_bfloat16_like_physical_intelligence(self):
+ """casts the model to bfloat16.
+
+ Modules not casted to bfloat16:
+ - qwenvl.model.embed_tokens.weight
+ - qwenvl.model.norm.weight
+ - qwen_expert.model.norm.weight
+ - qwen_expert.lm_head.weight
+ """
+ self.qwenvl = self.qwenvl.to(dtype=torch.bfloat16)
+
+ params_to_change_dtype = [
+ "qwenvl.model.layers",
+ "qwen_expert.model.layers",
+ "visual",
+ "multi_modal",
+ ]
+ for name, param in self.named_parameters():
+ if any(selector in name for selector in params_to_change_dtype):
+ param.data = param.data.to(dtype=torch.bfloat16)
+
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+ """
+ Encodes images into continuous embeddings that can be forwarded to the language model.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ """
+ image_embeds = self.qwenvl.visual(pixel_values, grid_thw=image_grid_thw)
+ split_sizes = (image_grid_thw.prod(-1) // self.qwenvl.visual.spatial_merge_size**2).tolist()
+ image_embeds = torch.split(image_embeds, split_sizes)
+ image_embeds = torch.stack(image_embeds, dim=0)
+ return image_embeds
+
+ def embed_image(self, image: torch.Tensor, patch_size=14, temporal_patch_size=2):
+ h = w = int(image.shape[1] ** 0.5)
+ image_grid_thw = torch.tensor([[1, h, w]]*image.shape[0], device=image.device)
+ image_embeds = self.get_image_features(image, image_grid_thw=image_grid_thw)
+ return image_embeds
+ # return torch.randn(72, 64, 2048).to(device=image.device, dtype=torch.bfloat16)
+
+ def embed_language_tokens(self, tokens: torch.Tensor):
+ return self.qwenvl.model.embed_tokens(tokens)
+
+ def handle_kv_cache(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+ use_cache: Optional[bool] = None,
+ fill_kv_cache: Optional[bool] = None,
+ ):
+ if use_cache:
+ if past_key_values is None:
+ past_key_values = {}
+
+ if fill_kv_cache:
+ past_key_values[layer_idx] = {
+ "key_states": key_states,
+ "value_states": value_states,
+ }
+ else:
+ key_states = torch.cat(
+ [past_key_values[layer_idx]["key_states"], key_states], dim=1
+ )
+ value_states = torch.cat(
+ [past_key_values[layer_idx]["value_states"], value_states],
+ dim=1,
+ )
+ return key_states, value_states, past_key_values
+
+ def forward(
+ self,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ vlm_position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+ inputs_embeds: List[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ fill_kv_cache: Optional[bool] = None,
+ ada_cond: List[torch.FloatTensor] = None,
+ use_ki: bool = False,
+ norm_qkv: bool = False,
+ ):
+ """
+ Args:
+ attention_mask (Optional[torch.Tensor], optional):
+ Attention mask with shape (b, seq_len, seq_len). Defaults to None.
+ position_ids (Optional[torch.LongTensor], optional):
+ Position indices for applying RoPE. Defaults to None.
+ past_key_values (Optional[Union[List[torch.FloatTensor], Cache]], optional):
+ Optional kv cache. Defaults to None.
+ inputs_embeds (List[torch.FloatTensor], optional):
+ Input embeddings. Defaults to None.
+ use_cache (Optional[bool], optional):
+ Whether to use kv cache. Defaults to None.
+ fill_kv_cache (Optional[bool], optional):
+ Whether to return kv tensors in this forward pass as cache. Defaults to None.
+
+ Returns:
+ outputs_embeds (torch.Tensor): Output embeddings.
+ past_key_values (Optional[Union[List[torch.FloatTensor], Cache]]):
+ Optional kv cache.
+ """
+ models = [self.qwenvl.model, self.qwen_expert.model]
+
+ # RMSNorm
+ num_layers = self.qwenvl.config.num_hidden_layers # 36
+ for layer_idx in range(num_layers):
+ query_states = []
+ key_states = []
+ value_states = []
+ gates = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ if hidden_states is None:
+ continue
+ if i == 1: # For action expert
+ query_state, key_state, value_state, gate = models[i].layers[layer_idx](hidden_states, compute_kqv=True, ada_cond = ada_cond, norm_qkv=norm_qkv, old_adanorm=self.config.old_adanorm)
+ else: # For VLM
+ query_state, key_state, value_state = models[i].layers[layer_idx](hidden_states, compute_kqv=True, norm_qkv=norm_qkv)
+ gate = None
+ if use_ki:
+ query_state, key_state, value_state = query_state.detach(), key_state.detach(), value_state.detach()
+
+ if query_state.dtype != torch.float32:
+ query_state, key_state, value_state = query_state.to(torch.float32), key_state.to(torch.float32), value_state.to(torch.float32)
+ query_states.append(query_state)
+ key_states.append(key_state)
+ value_states.append(value_state)
+ gates.append(gate)
+
+ # B,L,H,D with L sequence length (img, lang, state, action), H number of heads, D head dim
+ # concatenate on the number of embeddings/tokens
+ query_states = torch.cat(query_states, dim=1)
+ key_states = torch.cat(key_states, dim=1)
+ value_states = torch.cat(value_states, dim=1)
+
+ query_states = apply_rope(query_states, position_ids)
+ key_states = apply_rope(key_states, position_ids)
+
+ key_states, value_states, past_key_values = self.handle_kv_cache(
+ key_states,
+ value_states,
+ layer_idx,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ fill_kv_cache=fill_kv_cache,
+ )
+ # ipdb.set_trace()
+ att_output = self.attention_interface(query_states, key_states, value_states, attention_mask)
+
+ # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
+ outputs_embeds = []
+ start = 0
+ for i, hidden_states in enumerate(inputs_embeds):
+ if hidden_states is not None:
+ end = start + hidden_states.shape[1]
+ if i == 1:
+ out_emb = models[i].layers[layer_idx](hidden_states, att_output, start, end, output_atten=True, ada_cond = ada_cond, gate=(gates[0] if len(gates) == 1 else gates[i]), old_adanorm=self.config.old_adanorm)
+ else:
+ out_emb = models[i].layers[layer_idx](hidden_states, att_output, start, end, output_atten=True)
+ outputs_embeds.append(out_emb)
+ start = end
+ else:
+ outputs_embeds.append(None)
+
+ inputs_embeds = outputs_embeds
+
+ # final norm
+ outputs_embeds = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ if hidden_states is not None:
+ if self.config.final_norm_adanorm:
+ if i == 1:
+ out_emb, _ = models[i].norm(hidden_states, ada_cond)
+ else:
+ out_emb = models[i].norm(hidden_states)
+ else:
+ out_emb = models[i].norm(hidden_states)
+ outputs_embeds.append(out_emb)
+ else:
+ outputs_embeds.append(None)
+
+ return outputs_embeds, past_key_values
+
+ def get_attention_interface(self):
+ if self.config.attention_implementation == "fa2":
+ raise NotImplementedError("FA2 is not implemented (yet)")
+ elif self.config.attention_implementation == "flex":
+ attention_interface = flex_attention_forward
+ elif self.config.attention_implementation == "eager":
+ attention_interface = our_eager_attention_forward
+ elif self.config.attention_implementation == "xformer":
+ # attention_interface = xformer_attention_forward
+ raise NotImplementedError("Xformer attention is not implemented (yet)")
+ else:
+ raise ValueError(
+ f"Invalid attention implementation: {self.config.attention_implementation}. "
+ "Expected one of ['fa2', 'flex', 'eager', 'xformer']."
+ )
+ return attention_interface
+
+class QwenVLA_Config(PI0Config):
+ model_type = "torch_qwenvla"
+ architectures = ["LingbotVlaPolicy"]
+
+class LingbotVlaPolicy(PreTrainedPolicy):
+ config_class = QwenVLA_Config
+ name = "torch_lingbot_vla"
+ _no_split_modules = ["Qwen2DecoderLayer", "FixQwen2RMSNorm", "FixAdaRMSNorm"]
+ def __init__(
+ self,
+ config: PI0Config,
+ tokenizer_path: str,
+ eval: bool=False,
+ ):
+ """
+ Args:
+ config: Policy configuration class instance or None, in which case the default instantiation of
+ the configuration class is used.
+ """
+
+ super().__init__(config)
+ self.config = config
+ self.language_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ self.model = FlowMatching(config, eval)
+
+ if not getattr(self.config,"use_lm_head", False):
+ del self.model.qwenvl_with_expert.qwenvl.lm_head
+ del self.model.qwenvl_with_expert.qwen_expert.lm_head
+
+ self.reset()
+ torch.set_float32_matmul_precision("high")
+
+ def reset(self):
+ return None
+
+ def get_optim_params(self) -> dict:
+ return self.parameters()
+
+ @torch.no_grad
+ def select_action(
+ self, observation: dict[str, Tensor], noise: Tensor | None = None
+ ):
+ pass
+
+ def forward(
+ self, images, img_masks, state, lang_tokens, lang_masks, actions, joint_mask=None, action_is_pad=None, expert_imgs=None, label=None, noise=None, time=None, vlm_causal=False, use_ki=False, depth_targets=None, norm_qkv=False
+ ) -> tuple[Tensor, dict[str, Tensor]]:
+ loss_dict = {}
+ losses, loss_depth, depth_preds = self.model.forward(
+ images, img_masks, lang_tokens, lang_masks, state, actions, expert_imgs, label, noise, time, vlm_causal, self.config.loss_type, use_ki, depth_targets, norm_qkv
+ )
+ batch_mean_losses = losses.mean(dim=(1, 2))
+
+ loss_dict["batch_mean_losses"] = batch_mean_losses.clone()
+
+ losses = losses[:, :, :self.config.action_dim]
+ loss_dict["losses"] = losses.clone()
+
+ # For backward pass
+ loss_vla = losses.mean()
+ # For logging
+ loss_dict["l2_loss"] = loss_vla.item()
+ if not isinstance(loss_depth, int):
+ loss_dict["depth_loss"] = loss_depth.item()
+ total_loss = loss_vla + loss_depth
+
+ return total_loss, loss_vla, loss_depth, loss_dict, depth_preds
+
+class FlowMatching(nn.Module):
+ def __init__(self, config, eval):
+ super().__init__()
+ self.config = config
+
+ # qwenvl with action expert
+ qwenvl_with_export_config = QwenvlWithExpertConfig(
+ freeze_vision_encoder=self.config.freeze_vision_encoder,
+ train_expert_only=self.config.train_expert_only,
+ vocab_size=getattr(self.config,"vocab_size", 0),
+ use_lm_head=getattr(self.config,"use_lm_head", False),
+ attention_implementation=self.config.attention_implementation,
+ tokenizer_path=self.config.tokenizer_path,
+ enable_expert_vision=self.config.enable_expert_vision,
+ expert_vision_type=self.config.expert_vision_type,
+ )
+ qwenvl_with_export_config.adanorm_time = getattr(config, "adanorm_time", False)
+ qwenvl_with_export_config.split_gate_liner = getattr(config, "split_gate_liner", False)
+ qwenvl_with_export_config.no_split_gate_liner = getattr(config, "nosplit_gate_liner", False)
+ qwenvl_with_export_config.separate_time_proj = getattr(config, "separate_time_proj", False)
+ qwenvl_with_export_config.old_adanorm = getattr(config, "old_adanorm", False)
+ qwenvl_with_export_config.final_norm_adanorm = getattr(config, "final_norm_adanorm", False)
+ qwenvl_with_export_config.norm_qkv = getattr(config, "norm_qkv", False)
+ self.qwenvl_with_expert = QwenvlWithExpertModel(
+ qwenvl_with_export_config, eval
+ )
+ self.config.proj_width = qwenvl_with_export_config.qwen_expert_config.hidden_size
+ self.config.initializer_range = getattr(qwenvl_with_export_config.qwen_expert_config, "initializer_range", None)
+ # projection layers
+ self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
+ self.action_in_proj = nn.Linear(
+ self.config.max_action_dim, self.config.proj_width
+ )
+ self.action_out_proj = nn.Linear(
+ self.config.proj_width, self.config.max_action_dim
+ )
+ if getattr(config, "separate_time_proj", False):
+ self.time_mlp_in = nn.Linear(self.config.proj_width, self.config.proj_width)
+ self.time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
+ else:
+ self.action_time_mlp_in = nn.Linear(
+ self.config.proj_width * 2, self.config.proj_width
+ )
+ self.action_time_mlp_out = nn.Linear(
+ self.config.proj_width, self.config.proj_width
+ )
+ self.config.align_params = getattr(self.config, 'align_params', {})
+ if self.config.align_params != {}:
+ self.steps=0
+ self.use_depth_align = True
+ self.init_depth_heads(self.config.align_params)
+ else:
+ self.use_depth_align = False
+
+ self.set_requires_grad()
+
+ def init_depth_heads(self, config):
+ self.llm_image_token_size = config['llm']['image_token_size']
+ self.llm_image_input_size = config['llm']['image_input_size']
+ self.depth_token_size = config['depth']['token_size']
+ self.depth_input_size = config['depth']['input_size']
+ self.align_type = config.get('mode', None)
+ self.model_type = config['depth']['model_type']
+
+ if self.align_type == "direct":
+ self.depth_align_head = nn.Sequential(
+ nn.Linear(config['llm']['dim_out'], config['depth']['dim_out']*2),
+ nn.GELU(),
+ nn.Linear(config['depth']['dim_out']*2, config['depth']['dim_out'])
+ )
+ for p in self.depth_align_head.parameters():
+ p.requires_grad = True
+ elif self.align_type == "query":
+ self.num_task_tokens = config['num_task_tokens']
+ assert config['depth']['num_backbone_tokens'] % self.num_task_tokens == 0
+ self.depth_align_embs = nn.Parameter(
+ torch.randn(
+ config['depth']['num_backbone_tokens'], config['llm']['dim_out']
+ )
+ ).to(dtype=torch.bfloat16)
+ self.depth_align_embs.requires_grad_ = True
+
+ self.depth_align_head = TaskTokenDepthHead(config['depth'], llm_hidden_size=config['llm']['dim_out']).to(dtype=torch.bfloat16)
+
+ for p in self.depth_align_head.parameters():
+ p.requires_grad = True
+
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def set_requires_grad(self):
+ for params in self.state_proj.parameters():
+ params.requires_grad = self.config.train_state_proj
+
+ def sample_time(self, bsize, device):
+ time_beta = sample_beta(1.5, 1.0, bsize, device)
+ time = time_beta * 0.999 + 0.001
+ return time.to(dtype=torch.float32, device=device)
+
+ def embed_prefix(
+ self, images, img_masks, lang_tokens, lang_masks, vlm_causal, label=None
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ bsize = images.shape[0]
+ device = images.device
+ dtype = images.dtype
+
+ # embed image
+ if images.ndim == 5:
+ images = einops.rearrange(images, "b n c h w -> (b n) c h w")
+ elif images.ndim == 4:
+ images = einops.rearrange(images, "b n l d -> (b n) l d")
+ elif images.ndim == 3: # For inference bs=1
+ bsize = 1
+ # ipdb.set_trace()
+ img_emb = self.qwenvl_with_expert.embed_image(images)
+ num_patch = img_emb.shape[1]
+ img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d", b=bsize) # bsize = 24
+ num_img_embs = img_emb.shape[1]
+ if img_masks.ndim ==1: # For inference bs=1
+ img_masks = img_masks.unsqueeze(0)
+ if self.use_depth_align and self.align_type == "query":
+ align_masks = einops.repeat(img_masks, "b n -> b (n l)", l=self.num_task_tokens)
+ img_masks = einops.repeat(img_masks, "b n -> b (n l)", l=num_patch)
+
+ # embed language
+ lang_emb = self.qwenvl_with_expert.embed_language_tokens(lang_tokens)
+ num_lang_embs = lang_emb.shape[1]
+
+ if self.use_depth_align and self.align_type == "query":
+ def _get_align_tokens(tokens):
+ tk_weights = tokens.view(self.num_task_tokens, tokens.shape[0] // self.num_task_tokens, tokens.shape[1])
+ tk_weights = tk_weights.mean(dim=1)
+ return tk_weights
+
+ align_embs = _get_align_tokens(self.depth_align_embs).repeat(img_emb.size(0), 1, 1).to(img_emb.device, img_emb.dtype)
+ # align_masks = einops.rearrange(img_masks, "b (n l) -> b n l", n=3)
+ # align_masks = align_masks[:, :, 0]
+ # align_masks = einops.repeat(align_masks, "b n -> b (n l)", l=self.num_task_tokens)
+ embs = torch.cat([img_emb, align_embs, align_embs, align_embs, lang_emb], dim=1)
+ pad_masks = torch.cat([img_masks, align_masks, lang_masks], dim=1)
+ else:
+ # assemble embeddings
+ embs = torch.cat([img_emb, lang_emb], dim=1)
+ pad_masks = torch.cat([img_masks, lang_masks], dim=1)
+
+ # (see `make_att_2d_masks` to understand why zeros means bidirection)
+ if not vlm_causal:
+ if self.use_depth_align and self.align_type == "query":
+ att_masks = torch.zeros(
+ (img_emb.size(0), num_img_embs + 3 * self.num_task_tokens + num_lang_embs), device=device, dtype=torch.bool
+ )
+ else:
+ att_masks = torch.zeros(
+ (img_emb.size(0), num_img_embs + num_lang_embs), device=device, dtype=torch.bool
+ )
+ else:
+ if self.use_depth_align and self.align_type == "query":
+ att_masks = torch.ones(
+ (img_emb.size(0), num_img_embs + 3 * self.num_task_tokens + num_lang_embs), device=device, dtype=torch.bool
+ )
+ else:
+ att_masks = torch.ones(
+ (img_emb.size(0), num_img_embs + num_lang_embs), device=device, dtype=torch.bool
+ )
+ return embs, pad_masks, att_masks
+
+ def embed_suffix(self, state, noisy_actions, timestep, expert_imgs=None):
+ bsize = state.shape[0] # state_bs = img_bs
+ device = state.device
+ dtype = state.dtype
+ # embed state
+ state_emb = self.state_proj(state) # torch.Size([state_bs, 1024])
+
+ # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
+ time_emb = create_sinusoidal_pos_embedding( # 1, 1024
+ timestep, # torch.Size([1]))
+ self.config.proj_width, # 1024
+ min_period=4e-3,
+ max_period=4.0,
+ device=device,
+ )
+ time_emb = time_emb.type(dtype=dtype)
+
+ time_emb_ori = time_emb
+
+ # Fuse timestep + action information using an MLP
+ action_emb = self.action_in_proj(noisy_actions) # torch.Size([1, state_bs*50, 1024])
+ if getattr(self.config, "separate_time_proj", False):
+ time_emb = self.time_mlp_in(time_emb)
+ time_emb = F.silu(time_emb)
+ time_emb_ori = F.silu(self.time_mlp_out(time_emb)) # [1, 1024]
+ action_time_emb = action_emb
+ else:
+ time_emb = einops.repeat(time_emb, "b d -> b n d", n=action_emb.shape[1]) # [1, 1024] -> [1, state_bs*50, 1024]
+ action_time_emb = torch.cat([action_emb, time_emb], dim=-1) # [1, state_bs*50, 2048]
+
+ action_time_emb = self.action_time_mlp_in(action_time_emb)
+ action_time_emb = F.silu(action_time_emb) # swish == silu
+ action_time_emb = self.action_time_mlp_out(action_time_emb) # [1, state_bs*50, 1024]
+ action_time_dim = action_time_emb.shape[1]
+
+ if expert_imgs is not None:
+ if expert_imgs.ndim == 5:
+ expert_imgs = einops.rearrange(expert_imgs, "b n c h w -> (b n) c h w")
+ elif expert_imgs.ndim == 4:
+ bsize=1
+ expert_img_emb = self.qwenvl_with_expert.expert_visual.forward_features(expert_imgs)["x_norm_clstoken"].unsqueeze(1)
+ expert_img_emb = self.qwenvl_with_expert.expert_visual_mlp(expert_img_emb)
+ expert_img_emb = einops.rearrange(expert_img_emb, "(b n) l d -> b (n l) d", b=bsize) # bsize = 24
+ embs = torch.cat([expert_img_emb, state_emb[:, None], action_time_emb], dim=1)
+ num_expert_img_emb = expert_img_emb.shape[1]
+ pad_masks = torch.ones(
+ (bsize, action_time_dim + 1 + num_expert_img_emb), device=device, dtype=torch.bool
+ )
+ att_masks = torch.zeros(
+ (bsize, action_time_dim + 1 + num_expert_img_emb), device=device, dtype=torch.bool
+ )
+ att_masks[:, [0, num_expert_img_emb, num_expert_img_emb + 1]] = True
+
+ else:
+ embs = torch.cat([state_emb[:, None], action_time_emb], dim=1)
+ pad_masks = torch.ones(
+ (bsize, action_time_dim + 1), device=device, dtype=torch.bool
+ )
+
+ # Set attention masks for suffix tokens so that prefix tokens cannot attend to suffix tokens.
+ # And state token cannot attend action tokens.
+ # Action tokens use a bidirectional attention.
+ att_masks = torch.zeros(
+ (bsize, action_time_dim + 1), device=device, dtype=torch.bool
+ )
+ att_masks[:, :2] = True
+
+ return time_emb_ori, embs, pad_masks, att_masks
+
+ def forward(
+ self,
+ images,
+ img_masks,
+ lang_tokens,
+ lang_masks,
+ state,
+ actions,
+ expert_imgs,
+ label=None,
+ noise=None,
+ time=None,
+ vlm_causal=False,
+ loss_type='fm',
+ use_ki=False,
+ depth_targets=None,
+ norm_qkv=False
+ ) -> Tensor:
+ dtype = state.dtype
+ device = state.device
+ if noise is None:
+ # actions_shape = (
+ # bsize,
+ # self.config.n_action_steps, # 50
+ # self.config.max_action_dim, # 32
+ # )
+ noise = torch.randn(actions.shape, device=device, dtype=dtype)
+
+ if time is None:
+ time = self.sample_time(actions.size(0), device).to(dtype)
+
+ time_expanded = time[:, None, None]
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
+ u_t = noise - actions
+
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
+ images, img_masks, lang_tokens, lang_masks, vlm_causal, label
+ ) # 1,bs_img*(768+48),2048 1,bs_img*(768+48) 1,bs_img*(768+48)
+ time_embs, suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
+ state, x_t, time, expert_imgs
+ ) # [1, state_bs*(50+1), 1024], [1, state_bs*(50+1)], [1, state_bs*(50+1)] state_bs=bs_img
+
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
+
+ # pad_masks = pad_masks.reshape(state.size(0), -1)
+ # att_masks = att_masks.reshape(state.size(0), -1)
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
+ if self.use_depth_align and self.align_type == "query":
+ att_2d_masks = self.make_att_2d_masks_with_query(att_2d_masks, prefix_pad_masks.shape[-1], img_masks)
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1
+ vlm_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
+
+ # prefix_embs = prefix_embs.reshape(state.size(0), -1, prefix_embs.size(-1))
+ # suffix_embs = suffix_embs.reshape(state.size(0), -1, suffix_embs.size(-1))
+ (outputs_embeds, suffix_out), _ = self.qwenvl_with_expert.forward(
+ attention_mask=att_2d_masks,
+ position_ids=position_ids,
+ vlm_position_ids=vlm_position_ids,
+ past_key_values=None,
+ inputs_embeds=[prefix_embs, suffix_embs],
+ use_cache=True,
+ fill_kv_cache=True,
+ ada_cond = time_embs if getattr(self.config, 'adanorm_time', False) else None,
+ use_ki=use_ki,
+ norm_qkv=norm_qkv
+ )
+ if self.config.align_params != {}:
+ loss_depth, depth_preds = self.depth_emb_forward(outputs_embeds, depth_targets, img_masks)
+ loss_depth = loss_depth * self.config.align_params['depth_loss_weight']
+ self.steps+=1
+ else:
+ loss_depth = 0
+ depth_preds = None
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
+ if suffix_out.dtype != self.action_out_proj.weight.dtype:
+ suffix_out = suffix_out.to(self.action_out_proj.weight.dtype)
+ v_t = self.action_out_proj(suffix_out)
+ # u_t = u_t.reshape(images.size(0), -1, u_t.size(-1))
+ if loss_type == 'fm':
+ losses = F.mse_loss(u_t, v_t, reduction="none")
+ # losses = torch.mean((v_t - u_t)**2, dim=-1)
+ elif loss_type == 'L1_fm':
+ losses = F.l1_loss(u_t, v_t, reduction="none")
+
+ return losses, loss_depth, depth_preds
+
+ def sample_actions(
+ self, images, img_masks, lang_tokens, lang_masks, state, expert_imgs=None, vlm_causal=False, noise=None
+ ) -> Tensor:
+ """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
+ bsize = state.shape[0]
+ device = state.device
+ dtype = state.dtype
+
+ if noise is None:
+ actions_shape = (
+ bsize,
+ self.config.n_action_steps,
+ self.config.max_action_dim,
+ )
+ noise = torch.randn(actions_shape, device=device, dtype=dtype)
+
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
+ images, img_masks, lang_tokens, lang_masks, vlm_causal
+ )
+ prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) # bs, prefix_len, prefix_len
+ if self.use_depth_align and self.align_type == "query":
+ prefix_att_2d_masks = self.make_att_2d_masks_with_query(prefix_att_2d_masks, prefix_pad_masks.shape[-1], img_masks)
+ prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
+
+ # Compute image and language key value cache
+ _, past_key_values = self.qwenvl_with_expert.forward(
+ attention_mask=prefix_att_2d_masks,
+ position_ids=prefix_position_ids,
+ past_key_values=None,
+ inputs_embeds=[prefix_embs, None],
+ use_cache=self.config.use_cache,
+ fill_kv_cache=True,
+ )
+
+ dt = torch.tensor(-1.0 / self.config.num_steps, dtype=dtype, device=device)
+ x_t = noise
+ time = torch.tensor(1.0, dtype=dtype, device=device)
+ count = 0
+ while time >= -dt / 2:
+ count += 1
+ expanded_time = time.expand(bsize)
+
+ v_t = self.predict_velocity(
+ state, prefix_pad_masks, past_key_values, x_t, expert_imgs, expanded_time
+ )
+
+ # Euler step
+ x_t += dt * v_t
+ time += dt
+ return x_t
+
+ def predict_velocity(self, state, prefix_pad_masks, past_key_values, x_t, expert_imgs, timestep):
+ """predict velocity at time t using the suffix model."""
+ time_embs, suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
+ state, x_t, timestep, expert_imgs
+ )
+
+ suffix_len = suffix_pad_masks.shape[1]
+ batch_size = prefix_pad_masks.shape[0]
+ prefix_len = prefix_pad_masks.shape[1]
+ prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
+ batch_size, suffix_len, prefix_len
+ )
+
+ suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
+
+ full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) # bs, suffix_len, prefix_len+suffix_len
+
+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
+ position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
+
+ outputs_embeds, _ = self.qwenvl_with_expert.forward(
+ attention_mask=full_att_2d_masks,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=[None, suffix_embs],
+ use_cache=self.config.use_cache,
+ fill_kv_cache=False,
+ ada_cond = time_embs if getattr(self.config, 'adanorm_time', False) else None,
+ )
+ suffix_out = outputs_embeds[1]
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
+ v_t = self.action_out_proj(suffix_out)
+ return v_t
+
+ def make_att_2d_masks_with_query(self, att_2d_masks, prefix_len, img_masks):
+ if img_masks.ndim == 1:
+ img_masks = img_masks.unsqueeze(0)
+
+ num_image_tokens = self.llm_image_token_size * self.llm_image_token_size
+
+ att_2d_masks[:, num_image_tokens * 0 : num_image_tokens * 3, num_image_tokens * 3 + self.num_task_tokens * 0: num_image_tokens * 3 + self.num_task_tokens * 3] = False
+ att_2d_masks[:, num_image_tokens * 3 + self.num_task_tokens * 3: prefix_len, num_image_tokens * 3 + self.num_task_tokens * 0: num_image_tokens * 3 + self.num_task_tokens * 3] = False
+
+ att_2d_masks[:, num_image_tokens * 3 + self.num_task_tokens * 0: num_image_tokens * 3 + self.num_task_tokens * 3, :] = False
+
+ att_2d_masks[img_masks[:, 0], num_image_tokens * 3 + self.num_task_tokens * 0: num_image_tokens * 3 + self.num_task_tokens * 1, num_image_tokens * 0: num_image_tokens * 1] = True
+ att_2d_masks[img_masks[:, 1], num_image_tokens * 3 + self.num_task_tokens * 1: num_image_tokens * 3 + self.num_task_tokens * 2, num_image_tokens * 1: num_image_tokens * 2] = True
+ att_2d_masks[img_masks[:, 2], num_image_tokens * 3 + self.num_task_tokens * 2: num_image_tokens * 3 + self.num_task_tokens * 3, num_image_tokens * 2: num_image_tokens * 3] = True
+
+ att_2d_masks[img_masks[:, 0], num_image_tokens * 3 + self.num_task_tokens * 0: num_image_tokens * 3 + self.num_task_tokens * 1, num_image_tokens * 3 + self.num_task_tokens * 0: num_image_tokens * 3 + self.num_task_tokens * 1] = True
+ att_2d_masks[img_masks[:, 1], num_image_tokens * 3 + self.num_task_tokens * 1: num_image_tokens * 3 + self.num_task_tokens * 2, num_image_tokens * 3 + self.num_task_tokens * 1: num_image_tokens * 3 + self.num_task_tokens * 2] = True
+ att_2d_masks[img_masks[:, 2], num_image_tokens * 3 + self.num_task_tokens * 2: num_image_tokens * 3 + self.num_task_tokens * 3, num_image_tokens * 3 + self.num_task_tokens * 2: num_image_tokens * 3 + self.num_task_tokens * 3] = True
+
+ return att_2d_masks
+
+ def depth_emb_forward(self, hidden_states, depth_targets=None, img_masks=None):
+ chunk_size = self.llm_image_token_size * self.llm_image_token_size
+ img_masks = einops.rearrange(img_masks, 'b n -> (b n)')
+ if self.align_type == 'direct':
+
+ image_embs = hidden_states[:, chunk_size * 0 : chunk_size * 3, :]
+ image_embs = einops.rearrange(image_embs, 'b (n l) c -> (b n) l c', n=3)
+
+ depth_preds = self.depth_align_head(image_embs).contiguous().float()
+ elif self.align_type == 'query':
+ align_embs = hidden_states[:, chunk_size * 3 : chunk_size * 3 + self.num_task_tokens * 3, :]
+ align_embs = einops.rearrange(align_embs, 'b (n l) c -> (b n) l c', n=3)
+
+ image_embs = hidden_states[:, chunk_size * 0 : chunk_size * 3, :]
+ image_embs = einops.rearrange(image_embs, 'b (n l) c -> (b n) l c', n=3)
+
+ align_embs = torch.cat([image_embs, align_embs], dim=1)
+ depth_preds = self.depth_align_embs.repeat(align_embs.shape[0], 1, 1).to(dtype=align_embs.dtype, device=align_embs.device)
+
+ depth_preds = self.depth_align_head(align_embs, depth_preds).contiguous().float()
+
+ loss = self._emb_loss(depth_preds[img_masks], depth_targets[img_masks])
+
+ return loss, depth_preds
+
+ def _emb_loss(self, emb_preds, emb_targets):
+ if self.align_type == "direct":
+ S, L, D = emb_targets.shape
+ emb_preds = emb_preds.contiguous().view(S * self.llm_image_token_size * self.llm_image_token_size, emb_preds.shape[-1])
+
+ emb_targets = emb_targets.to(emb_preds.dtype).to(emb_preds.device)
+ emb_targets = emb_targets.view(S, self.depth_token_size, self.depth_token_size, D).permute(0, 3, 1, 2).contiguous()
+ emb_targets = F.adaptive_avg_pool2d(emb_targets, (self.llm_image_token_size, self.llm_image_token_size)).view(S, D, self.llm_image_token_size, self.llm_image_token_size)
+ emb_targets = emb_targets.view(S, D, self.llm_image_token_size*self.llm_image_token_size).permute(0, 2, 1).contiguous().view(S * self.llm_image_token_size * self.llm_image_token_size, D)
+
+ l1_loss = F.l1_loss(emb_preds.float(), emb_targets.float().detach(), reduction="none")
+
+ emb_preds_norm = F.normalize(emb_preds.float(), p=2, dim=-1, eps=1e-6)
+ emb_targets_norm = F.normalize(emb_targets.float(), p=2, dim=-1, eps=1e-6)
+ emb_preds_matrix = torch.matmul(emb_preds_norm, emb_preds_norm.transpose(0, 1))
+ emb_targets_matrix = torch.matmul(emb_targets_norm, emb_targets_norm.transpose(0, 1))
+
+ sim_loss = F.l1_loss(emb_preds_matrix.float(), emb_targets_matrix.float().detach(), reduction="none")
+
+ emb_loss = sim_loss.mean() + l1_loss.mean()
+ elif self.align_type == "query":
+ l1_loss = F.smooth_l1_loss(emb_preds.float(), emb_targets.float().detach(), reduction="none")
+ emb_loss = l1_loss.mean()
+ if self.model_type == 'MoRGBD':
+ emb_loss = emb_loss
+ return emb_loss
+
+ModelClass = LingbotVlaPolicy
+
+__all__ = ["LingbotVlaPolicy", "Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2ForCausalLM", "Qwen2_5_VLPreTrainedModel"]
\ No newline at end of file
diff --git a/lingbotvla/models/vla/pi0/modeling_pi0.py b/lingbotvla/models/vla/pi0/modeling_pi0.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c8c83fb6d8e493455b115dbdf2859b648e8c52f
--- /dev/null
+++ b/lingbotvla/models/vla/pi0/modeling_pi0.py
@@ -0,0 +1,2189 @@
+from logging import raiseExceptions
+import einops
+import numpy as np
+import torch
+from torch import nn
+import torch.nn.functional as F
+from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
+from lerobot.common.policies.pretrained import PreTrainedPolicy
+from torch import Tensor, nn
+from typing import List, Optional, Tuple, Union
+from transformers import (
+ AutoConfig,
+ PretrainedConfig,
+ PreTrainedModel,
+)
+from transformers.models.auto import CONFIG_MAPPING
+from transformers import AutoTokenizer
+from dataclasses import dataclass
+from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig
+from transformers.models.gemma.configuration_gemma import GemmaConfig
+from transformers.cache_utils import Cache, HybridCache, StaticCache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel, ALL_ATTENTION_FUNCTIONS
+from transformers.utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_torchdynamo_compiling,
+ logging,
+ replace_return_docstrings,
+ LossKwargs,
+ can_return_tuple,
+ is_torch_flex_attn_available,
+)
+from transformers.utils.deprecation import deprecate_kwarg
+from transformers.models.auto import AutoModel
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.processing_utils import Unpack
+import torch.distributed._tensor as dt
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
+from .utils import (
+ create_sinusoidal_pos_embedding,
+ make_att_2d_masks,
+ resize_with_pad,
+ sample_beta,
+)
+from .utils import apply_rope, our_eager_attention_forward
+from .flex_attention import flex_attention_forward
+import ipdb
+IMAGE_KEYS = (
+ "base_0_rgb",
+ "left_wrist_0_rgb",
+ "right_wrist_0_rgb",
+)
+
+logger = logging.get_logger(__name__)
+_CONFIG_FOR_DOC = "PaliGemmaConfig"
+
+@dataclass
+class PaliGemmaCausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for PaliGemmacausal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+class PaliGemmaMultiModalProjector(nn.Module):
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__()
+ self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)
+
+ def forward(self, image_features):
+ hidden_states = self.linear(image_features)
+
+ return hidden_states
+
+
+PALIGEMMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`PaliGemmaConfig`] or [`PaliGemmaVisionConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ PALIGEMMA_START_DOCSTRING,
+)
+class PaliGemmaPreTrainedModel(PreTrainedModel):
+ config_class = PaliGemmaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["PaliGemmaMultiModalProjector"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def _init_weights(self, module):
+ # important: this ported version of PaliGemmaisn't meant for training from scratch - only
+ # inference and fine-tuning
+ std = (
+ self.config.initializer_range
+ if hasattr(self.config, "initializer_range")
+ else self.config.text_config.initializer_range
+ )
+
+ if hasattr(module, "class_embedding"):
+ module.class_embedding.data.normal_(mean=0.0, std=std)
+
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+PALIGEMMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
+ The tensors corresponding to the input images. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`PaliGemmaProcessor`] uses
+ [`SiglipImageProcessor`] for processing images).
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+@add_start_docstrings(
+ """The PALIGEMMA model which consists of a vision backbone and a language model.""",
+ PALIGEMMA_START_DOCSTRING,
+)
+class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
+ _no_split_modules = ["PaliGemmaMultiModalProjector", "GemmaDecoderLayer"]
+ def __init__(self, config: PaliGemmaConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config=config.vision_config)
+ self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
+ self.vocab_size = config.text_config.vocab_size
+
+ language_model = GemmaForCausalLM(config=config.text_config, vlm=True)
+
+ if language_model._tied_weights_keys is not None:
+ self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
+ self.language_model = language_model
+
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.post_init()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma
+ def get_output_embeddings(self):
+ return self.language_model.get_output_embeddings()
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma
+ def set_output_embeddings(self, new_embeddings):
+ self.language_model.set_output_embeddings(new_embeddings)
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma
+ def set_decoder(self, decoder):
+ self.language_model.set_decoder(decoder)
+
+ # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma
+ def get_decoder(self):
+ return self.language_model.get_decoder()
+
+ def _update_causal_mask(
+ self,
+ attention_mask,
+ token_type_ids=None,
+ past_key_values=None,
+ cache_position=None,
+ input_tensor=None,
+ is_training: Optional[bool] = None,
+ ):
+ if self.config.text_config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+ is_training = is_training if is_training is not None else self.training
+ using_static_cache = isinstance(past_key_values, StaticCache)
+ min_dtype = torch.finfo(self.dtype).min
+ if input_tensor is None:
+ input_tensor = attention_mask
+
+ inputs_lead_dim, sequence_length = input_tensor.shape[:2]
+ if using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ elif isinstance(past_key_values, HybridCache):
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else cache_position[0] + sequence_length + 1
+ )
+
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ return attention_mask
+
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
+ )
+ # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
+ if sequence_length != 1:
+ if is_training:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ else:
+ causal_mask[:, :sequence_length] = 0.0
+
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+
+ # First unmask prefix tokens during training
+ if is_training:
+ if token_type_ids is None:
+ raise ValueError("Token type ids must be provided during training")
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
+ )
+
+ # Then apply padding mask (will mask pad tokens)
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+ def get_image_features(self, pixel_values: torch.FloatTensor):
+ image_outputs = self.vision_tower(pixel_values)
+ selected_image_feature = image_outputs.last_hidden_state
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = image_features / (self.config.text_config.hidden_size**0.5)
+ return image_features
+
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+ @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **lm_kwargs,
+ ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
+
+ >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
+ >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
+
+ >>> prompt = "Where is the cat standing?"
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(**inputs,)
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Where is the cat standing?\nsnow"
+ ```"""
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ is_training = token_type_ids is not None and labels is not None
+
+ # Replace image id woth PAD if the image token if OOV, to avoid index-errors
+ if input_ids is not None and self.config.image_token_index >= self.vocab_size:
+ special_image_mask = input_ids == self.config.image_token_index
+ llm_input_ids = input_ids.clone()
+ llm_input_ids[special_image_mask] = 0
+ else:
+ llm_input_ids = input_ids
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
+
+ # Merge text and images
+ if pixel_values is not None:
+ image_features = self.get_image_features(pixel_values)
+
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
+ )
+ else:
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+ raise ValueError(
+ f"Number of images does not match number of special image tokens in the input text. "
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
+ "tokens from image embeddings."
+ )
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ # mask out pad-token-ids in labels for BC
+ if labels is not None and self.pad_token_id in labels:
+ logger.warning_once(
+ "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
+ "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
+ )
+ labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
+ )
+ outputs: CausalLMOutputWithPast = self.language_model(
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **lm_kwargs,
+ )
+
+ logits = outputs[0]
+ loss = None
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ shift_logits = logits[..., :-1, :]
+ shift_labels = labels[..., 1:]
+ if attention_mask is not None:
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+ else:
+ shift_logits = shift_logits.contiguous()
+ shift_labels = shift_labels.contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
+ loss = loss_fct(flat_logits, flat_labels)
+
+ output = PaliGemmaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+ return output if return_dict else output.to_tuple()
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ pixel_values=None,
+ attention_mask=None,
+ token_type_ids=None,
+ use_cache=True,
+ logits_to_keep=None,
+ labels=None,
+ **kwargs,
+ ):
+ # Overwritten -- custom `position_ids` and `pixel_values` handling
+ model_inputs = self.language_model.prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ use_cache=use_cache,
+ logits_to_keep=logits_to_keep,
+ token_type_ids=token_type_ids,
+ **kwargs,
+ )
+
+ # position_ids in Paligemma are 1-indexed
+ if model_inputs.get("position_ids") is not None:
+ model_inputs["position_ids"] += 1
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+ # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
+ if cache_position[0] == 0:
+ model_inputs["pixel_values"] = pixel_values
+ is_training = token_type_ids is not None and labels is not None
+ if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+ input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
+ causal_mask = self._update_causal_mask(
+ attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
+ )
+ model_inputs["attention_mask"] = causal_mask
+
+ return model_inputs
+
+_CHECKPOINT_FOR_DOC = "google/gemma-7b"
+_CONFIG_FOR_DOC = "GemmaConfig"
+
+class GemmaRMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.zeros(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ output = self._norm(x.float())
+ # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
+ # See https://github.com/huggingface/transformers/pull/29402
+ output = output * (1.0 + self.weight.float())
+ return output.type_as(x)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+class FixedGemmaRMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.zeros(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ output = self._norm(x.float())
+ # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
+ # See https://github.com/huggingface/transformers/pull/29402
+ output = output * (1.0 + self.weight.float())
+ return output.type_as(x)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+class GemmaMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+class GemmaRotaryEmbedding(nn.Module):
+ def __init__(self, config: GemmaConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class GemmaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: GemmaConfig, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_value: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+ logger.warning_once(
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ else:
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+
+class GemmaDecoderLayer(nn.Module):
+ def __init__(self, config: GemmaConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = GemmaMLP(config)
+ self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ att_output: Optional[torch.Tensor] = None,
+ start: Optional[int] = 0,
+ end: Optional[int] = 0,
+ compute_kqv: bool = False,
+ output_atten: bool = False,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ if compute_kqv:
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_shape = (*hidden_states.shape[:-1], -1, self.self_attn.head_dim)
+
+ query_state = self.self_attn.q_proj(hidden_states).view(hidden_shape)
+ key_state = self.self_attn.k_proj(hidden_states).view(hidden_shape)
+ value_state = self.self_attn.v_proj(hidden_states).view(hidden_shape)
+
+ return query_state, key_state, value_state
+
+ elif output_atten:
+ if att_output.dtype != self.self_attn.o_proj.weight.dtype:
+ att_output = att_output.to(self.self_attn.o_proj.weight.dtype)
+ out_emb = self.self_attn.o_proj(att_output[:, start:end])
+
+ # first residual
+ out_emb += hidden_states
+ after_first_residual = out_emb.clone()
+
+ out_emb = self.post_attention_layernorm(out_emb)
+ out_emb = self.mlp(out_emb)
+
+ # second residual
+ out_emb += after_first_residual
+
+ return out_emb
+
+ else:
+ raise ValueError(f"Invaild Operation compute_kqv={compute_kqv} and output_atten={output_atten} with GemmaDecoderLayer in PI0")
+
+class MpGemmaDecoderLayer(nn.Module):
+ def __init__(self, config: GemmaConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx)
+
+ self.mlp = GemmaMLP(config)
+ self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ att_output: Optional[torch.Tensor] = None,
+ start: Optional[int] = 0,
+ end: Optional[int] = 0,
+ compute_kqv: bool = False,
+ output_atten: bool = False,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ if compute_kqv:
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_shape = (*hidden_states.shape[:-1], -1, self.self_attn.head_dim)
+
+ query_state = self.self_attn.q_proj(hidden_states).view(hidden_shape)
+ key_state = self.self_attn.k_proj(hidden_states).view(hidden_shape)
+ value_state = self.self_attn.v_proj(hidden_states).view(hidden_shape)
+
+ return query_state, key_state, value_state
+
+ elif output_atten:
+ if att_output.dtype != self.self_attn.o_proj.weight.dtype:
+ att_output = att_output.to(self.self_attn.o_proj.weight.dtype)
+ out_emb = self.self_attn.o_proj(att_output[:, start:end])
+
+ # first residual
+ out_emb += hidden_states
+ after_first_residual = out_emb.clone()
+
+ out_emb = self.post_attention_layernorm(out_emb)
+ out_emb = self.mlp(out_emb)
+
+ # second residual
+ out_emb += after_first_residual
+
+ return out_emb
+
+ else:
+ raise ValueError(f"Invaild Operation compute_kqv={compute_kqv} and output_atten={output_atten} with GemmaDecoderLayer in PI0")
+
+GEMMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`GemmaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
+ GEMMA_START_DOCSTRING,
+)
+class GemmaPreTrainedModel(PreTrainedModel):
+ config_class = GemmaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["GemmaDecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+ _supports_cache_class = True
+ _supports_quantized_cache = True
+ _supports_static_cache = True
+ _supports_attention_backend = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+GEMMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+ "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
+ GEMMA_START_DOCSTRING,
+)
+class GemmaModel(GemmaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`]
+
+ Args:
+ config: GemmaConfig
+ """
+
+ def __init__(self, config: GemmaConfig, **kwargs):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ in_expert = kwargs.get("expert", False)
+ in_vlm = kwargs.get("vlm", False)
+ assert not (in_expert and in_vlm), "expert and vlm cannot be True at the same time"
+ assert (in_expert or in_vlm), "expert or vlm must be True"
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ if in_expert:
+ self.layers = nn.ModuleList(
+ [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ elif in_vlm:
+ self.layers = nn.ModuleList(
+ [MpGemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ if in_expert:
+ self.norm = FixedGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ elif in_vlm:
+ self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = GemmaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs, # NOOP kwarg for now
+ ) -> BaseModelOutputWithPast:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache()
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # normalized
+ # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+ # See https://github.com/huggingface/transformers/pull/29402
+ normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
+ hidden_states = hidden_states * normalizer
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ if isinstance(attention_mask, BlockMask):
+ return attention_mask
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ sequence_length = input_tensor.shape[1]
+ if using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to place the 4D attention mask on.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config, **kwargs):
+ super().__init__(config)
+ self.model = GemmaModel(config, **kwargs)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+ @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, GemmaForCausalLM
+
+ >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+class PaliGemmaWithExpertConfig(PretrainedConfig):
+ model_type = "PaliGemmaWithExpertModel"
+ sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
+
+ def __init__(
+ self,
+ paligemma_config: dict | None = None,
+ gemma_expert_config: dict | None = None,
+ freeze_vision_encoder: bool = True,
+ train_expert_only: bool = True,
+ vocab_size: int = 257152,
+ attention_implementation: str = "eager",
+ **kwargs,
+ ):
+ self.freeze_vision_encoder = freeze_vision_encoder
+ self.train_expert_only = train_expert_only
+ self.attention_implementation = attention_implementation
+
+ if paligemma_config is None:
+ # Default config from Pi0
+ self.paligemma_config = CONFIG_MAPPING["paligemma"](
+ transformers_version="4.48.1",
+ _vocab_size=257152,
+ bos_token_id=2,
+ eos_token_id=1,
+ hidden_size=2048,
+ image_token_index=257152,
+ model_type="paligemma",
+ pad_token_id=0,
+ projection_dim=2048,
+ text_config={
+ "hidden_activation": "gelu_pytorch_tanh",
+ "hidden_size": 2048,
+ "intermediate_size": 16384,
+ "model_type": "gemma",
+ "num_attention_heads": 8,
+ "num_hidden_layers": 18,
+ "num_image_tokens": 256,
+ "num_key_value_heads": 1,
+ "torch_dtype": "float32",
+ "vocab_size": vocab_size,
+ },
+ vision_config={
+ "hidden_size": 1152,
+ "intermediate_size": 4304,
+ "model_type": "siglip_vision_model",
+ "num_attention_heads": 16,
+ "num_hidden_layers": 27,
+ "num_image_tokens": 256,
+ "patch_size": 14,
+ "projection_dim": 2048,
+ "projector_hidden_act": "gelu_fast",
+ "torch_dtype": "float32",
+ "vision_use_head": False,
+ },
+ )
+ elif isinstance(self.paligemma_config, dict):
+ # Override Pi0 default config for PaliGemma
+ if "model_type" not in gemma_expert_config:
+ paligemma_config["model_type"] = "paligemma"
+
+ cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
+ self.paligemma_config = cfg_cls(**paligemma_config)
+
+ if gemma_expert_config is None:
+ # Default config from Pi0
+ self.gemma_expert_config = CONFIG_MAPPING["gemma"](
+ attention_bias=False,
+ attention_dropout=0.0,
+ bos_token_id=2,
+ eos_token_id=1,
+ head_dim=256,
+ hidden_act="gelu_pytorch_tanh",
+ hidden_activation="gelu_pytorch_tanh",
+ hidden_size=1024,
+ initializer_range=0.02,
+ intermediate_size=4096,
+ max_position_embeddings=8192,
+ model_type="gemma",
+ num_attention_heads=8,
+ num_hidden_layers=18,
+ num_key_value_heads=1,
+ pad_token_id=0,
+ rms_norm_eps=1e-06,
+ rope_theta=10000.0,
+ torch_dtype="float32",
+ transformers_version="4.48.1",
+ use_cache=True,
+ vocab_size=257152,
+ )
+ elif isinstance(self.gemma_expert_config, dict):
+ # Override Pi0 default config for Gemma Expert
+ if "model_type" not in gemma_expert_config:
+ gemma_expert_config["model_type"] = "gemma"
+
+ cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
+ self.gemma_expert_config = cfg_cls(**gemma_expert_config)
+
+ super().__init__(**kwargs)
+
+ def __post_init__(self):
+ super().__post_init__()
+ if self.train_expert_only and not self.freeze_vision_encoder:
+ raise ValueError(
+ "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
+ )
+
+ if self.attention_implementation not in ["eager", "fa2", "flex"]:
+ raise ValueError(
+ f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
+ )
+
+class PaliGemmaWithExpertModel(PreTrainedModel):
+ config_class = PaliGemmaWithExpertConfig
+
+ def __init__(self, config: PaliGemmaWithExpertConfig):
+ super().__init__(config=config)
+ self.config = config
+ self.paligemma = PaliGemmaForConditionalGeneration(
+ config=config.paligemma_config
+ )
+ self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config, expert=True)
+ # Remove unused embed_tokens
+ del self.gemma_expert.model.embed_tokens
+
+ self.attention_interface = self.get_attention_interface()
+
+ # self.to_bfloat16_like_physical_intelligence()
+ self.set_requires_grad()
+
+ def set_requires_grad(self):
+ """sets the requires_grad attribute of the model parameters based on the configuration.
+ If `freeze_vision_encoder` is True, the vision tower parameters are frozen.
+ If `train_expert_only` is True, the entire PaliGemma model is frozen.
+ """
+ if self.config.freeze_vision_encoder:
+ self.paligemma.vision_tower.eval()
+ for params in self.paligemma.vision_tower.parameters():
+ params.requires_grad = False
+
+ if self.config.train_expert_only:
+ self.paligemma.eval()
+ for params in self.paligemma.parameters():
+ params.requires_grad = False
+
+ def train(self, mode: bool = True):
+ super().train(mode)
+ if self.config.freeze_vision_encoder:
+ self.paligemma.vision_tower.eval()
+ if self.config.train_expert_only:
+ self.paligemma.eval()
+
+ def to_bfloat16_like_physical_intelligence(self):
+ """casts the model to bfloat16.
+
+ Modules not casted to bfloat16:
+ - paligemma.language_model.model.embed_tokens.weight
+ - paligemma.language_model.model.norm.weight
+ - gemma_expert.model.norm.weight
+ - gemma_expert.lm_head.weight
+ """
+ self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
+
+ params_to_change_dtype = [
+ "language_model.model.layers",
+ "gemma_expert.model.layers",
+ "vision_tower",
+ "multi_modal",
+ ]
+ for name, param in self.named_parameters():
+ if any(selector in name for selector in params_to_change_dtype):
+ param.data = param.data.to(dtype=torch.bfloat16)
+
+ def embed_image(self, image: torch.Tensor):
+ return self.paligemma.get_image_features(image)
+
+ def embed_language_tokens(self, tokens: torch.Tensor):
+ return self.paligemma.language_model.model.embed_tokens(tokens)
+
+ def handle_kv_cache(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+ use_cache: Optional[bool] = None,
+ fill_kv_cache: Optional[bool] = None,
+ ):
+ if use_cache:
+ if past_key_values is None:
+ past_key_values = {}
+
+ if fill_kv_cache:
+ past_key_values[layer_idx] = {
+ "key_states": key_states,
+ "value_states": value_states,
+ }
+ else:
+ key_states = torch.cat(
+ [past_key_values[layer_idx]["key_states"], key_states], dim=1
+ )
+ value_states = torch.cat(
+ [past_key_values[layer_idx]["value_states"], value_states],
+ dim=1,
+ )
+ return key_states, value_states, past_key_values
+
+ def forward(
+ self,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+ inputs_embeds: List[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ fill_kv_cache: Optional[bool] = None,
+ ):
+ """
+ Args:
+ attention_mask (Optional[torch.Tensor], optional):
+ Attention mask with shape (b, seq_len, seq_len). Defaults to None.
+ position_ids (Optional[torch.LongTensor], optional):
+ Position indices for applying RoPE. Defaults to None.
+ past_key_values (Optional[Union[List[torch.FloatTensor], Cache]], optional):
+ Optional kv cache. Defaults to None.
+ inputs_embeds (List[torch.FloatTensor], optional):
+ Input embeddings. Defaults to None.
+ use_cache (Optional[bool], optional):
+ Whether to use kv cache. Defaults to None.
+ fill_kv_cache (Optional[bool], optional):
+ Whether to return kv tensors in this forward pass as cache. Defaults to None.
+
+ Returns:
+ outputs_embeds (torch.Tensor): Output embeddings.
+ past_key_values (Optional[Union[List[torch.FloatTensor], Cache]]):
+ Optional kv cache.
+ """
+ models = [self.paligemma.language_model.model, self.gemma_expert.model]
+
+ # RMSNorm
+ num_layers = self.paligemma.config.text_config.num_hidden_layers
+ for layer_idx in range(num_layers):
+ query_states = []
+ key_states = []
+ value_states = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ if hidden_states is None:
+ continue
+ query_state, key_state, value_state = models[i].layers[layer_idx](hidden_states, compute_kqv=True)
+ if query_state.dtype != torch.float32:
+ query_state, key_state, value_state = query_state.to(torch.float32), key_state.to(torch.float32), value_state.to(torch.float32)
+ # layer = models[i].layers[layer_idx]
+ # hidden_states = layer.input_layernorm(hidden_states)
+ # hidden_shape = (*hidden_states.shape[:-1], -1, layer.self_attn.head_dim)
+
+ # query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
+ # key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
+ # value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
+
+ query_states.append(query_state)
+ key_states.append(key_state)
+ value_states.append(value_state)
+
+ # B,L,H,D with L sequence length, H number of heads, D head dim
+ # concatenate on the number of embeddings/tokens
+ query_states = torch.cat(query_states, dim=1)
+ key_states = torch.cat(key_states, dim=1)
+ value_states = torch.cat(value_states, dim=1)
+
+ query_states = apply_rope(query_states, position_ids)
+ key_states = apply_rope(key_states, position_ids)
+
+ key_states, value_states, past_key_values = self.handle_kv_cache(
+ key_states,
+ value_states,
+ layer_idx,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ fill_kv_cache=fill_kv_cache,
+ )
+ # ipdb.set_trace()
+ att_output = self.attention_interface(query_states, key_states, value_states, attention_mask)
+
+ # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
+ outputs_embeds = []
+ start = 0
+ for i, hidden_states in enumerate(inputs_embeds):
+ # layer = models[i].layers[layer_idx]
+
+ if hidden_states is not None:
+ end = start + hidden_states.shape[1]
+ out_emb = models[i].layers[layer_idx](hidden_states, att_output, start, end, output_atten=True)
+ # if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
+ # att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
+ # out_emb = layer.self_attn.o_proj(att_output[:, start:end])
+
+ # # first residual
+ # out_emb += hidden_states
+ # after_first_residual = out_emb.clone()
+ # out_emb = layer.post_attention_layernorm(out_emb)
+ # out_emb = layer.mlp(out_emb)
+ # # second residual
+ # out_emb += after_first_residual
+
+ outputs_embeds.append(out_emb)
+
+ start = end
+ else:
+ outputs_embeds.append(None)
+
+ inputs_embeds = outputs_embeds
+
+ # final norm
+ outputs_embeds = []
+ for i, hidden_states in enumerate(inputs_embeds):
+ if hidden_states is not None:
+ out_emb = models[i].norm(hidden_states)
+ outputs_embeds.append(out_emb)
+ else:
+ outputs_embeds.append(None)
+
+ return outputs_embeds, past_key_values
+
+ def get_attention_interface(self):
+ if self.config.attention_implementation == "fa2":
+ raise NotImplementedError("FA2 is not implemented (yet)")
+ elif self.config.attention_implementation == "flex":
+ print('=====Using Flex Attn=====')
+ attention_interface = flex_attention_forward
+ elif self.config.attention_implementation == "eager":
+ print('=====Using Eager Attn=====')
+ attention_interface = our_eager_attention_forward
+ elif self.config.attention_implementation == "xformer":
+ # attention_interface = xformer_attention_forward
+ raise NotImplementedError("Xformer attention is not implemented (yet)")
+ else:
+ raise ValueError(
+ f"Invalid attention implementation: {self.config.attention_implementation}. "
+ "Expected one of ['fa2', 'flex', 'eager', 'xformer']."
+ )
+ return attention_interface
+
+class PI0_Omni_Config(PI0Config):
+ model_type = "torch_pi0_omni"
+ architectures = ["PI0Policy"]
+
+class PI0Policy(PreTrainedPolicy):
+ config_class = PI0_Omni_Config
+ name = "torch_pi0"
+ _no_split_modules = ["GemmaDecoderLayer", "FixedGemmaRMSNorm"]
+ def __init__(
+ self,
+ config: PI0Config,
+ tokenizer_path: str,
+ ):
+ """
+ Args:
+ config: Policy configuration class instance or None, in which case the default instantiation of
+ the configuration class is used.
+ """
+
+ super().__init__(config)
+ self.config = config
+ self.language_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
+ self.model = PI0FlowMatching(config)
+
+ del self.model.paligemma_with_expert.paligemma.language_model.lm_head
+ del self.model.paligemma_with_expert.gemma_expert.lm_head
+
+ self.reset()
+ torch.set_float32_matmul_precision("high")
+
+ def reset(self):
+ return None
+
+ def get_optim_params(self) -> dict:
+ return self.parameters()
+
+ def forward(
+ self, images, img_masks, state, lang_tokens, lang_masks, actions, joint_mask=None, action_is_pad=None, noise=None, time=None
+ ) -> tuple[Tensor, dict[str, Tensor]]:
+ # ipdb.set_trace()
+ # noise = batch.get("noise", None)
+ # time = batch.get("time", None)
+ # lang_tokens, lang_masks, actions, action_is_pad = lang_tokens.reshape(images.size(0),-1), lang_masks.reshape(images.size(0),-1), actions.reshape(images.size(0),-1, actions.size(-1)), action_is_pad.reshape(images.size(0),-1)
+ loss_dict = {}
+ losses = self.model.forward(
+ images, img_masks, lang_tokens, lang_masks, state, actions, noise, time
+ )
+
+ # action_is_pad = action_is_pad.reshape(images.size(0), -1)
+ if action_is_pad is not None:
+ in_episode_bound = ~action_is_pad
+ losses = losses * in_episode_bound.unsqueeze(-1)
+ loss_dict["losses_after_in_ep_bound"] = losses.clone()
+
+ # Remove padding
+ losses = losses[:, :, :self.config.action_dim]
+ loss_dict["losses"] = losses.clone()
+
+ # For backward pass
+ loss = losses.mean()
+ # For logging
+ loss_dict["l2_loss"] = loss.item()
+
+ return loss, loss_dict
+
+class PI0FlowMatching(nn.Module):
+ """
+ Designed by Physical Intelligence. Ported from Jax by Hugging Face.
+ ┌──────────────────────────────┐
+ │ actions │
+ │ ▲ │
+ │ ┌┴─────┐ │
+ │ kv cache │Gemma │ │
+ │ ┌──────────►│Expert│ │
+ │ │ │ │ │
+ │ ┌┴────────┐ │x 10 │ │
+ │ │ │ └▲──▲──┘ │
+ │ │PaliGemma│ │ │ │
+ │ │ │ │ robot state │
+ │ │ │ noise │
+ │ └▲──▲─────┘ │
+ │ │ │ │
+ │ │ image(s) │
+ │ language tokens │
+ └──────────────────────────────┘
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+
+ # paligemma with action expert
+ paligemma_with_export_config = PaliGemmaWithExpertConfig(
+ freeze_vision_encoder=self.config.freeze_vision_encoder,
+ train_expert_only=self.config.train_expert_only,
+ vocab_size=self.config.vocab_size,
+ attention_implementation=self.config.attention_implementation,
+ )
+ self.paligemma_with_expert = PaliGemmaWithExpertModel(
+ paligemma_with_export_config
+ )
+ self.config.initializer_range = getattr(paligemma_with_export_config.gemma_expert_config, "initializer_range", None)
+ # projection layers
+ self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
+ self.action_in_proj = nn.Linear(
+ self.config.max_action_dim, self.config.proj_width
+ )
+ self.action_out_proj = nn.Linear(
+ self.config.proj_width, self.config.max_action_dim
+ )
+
+ self.action_time_mlp_in = nn.Linear(
+ self.config.proj_width * 2, self.config.proj_width
+ )
+ self.action_time_mlp_out = nn.Linear(
+ self.config.proj_width, self.config.proj_width
+ )
+
+ self.set_requires_grad()
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def set_requires_grad(self):
+ for params in self.state_proj.parameters():
+ params.requires_grad = self.config.train_state_proj
+
+ def sample_time(self, bsize, device):
+ time_beta = sample_beta(1.5, 1.0, bsize, device)
+ time = time_beta * 0.999 + 0.001
+ return time.to(dtype=torch.float32, device=device)
+
+ def embed_prefix(
+ self, images, img_masks, lang_tokens, lang_masks
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Embed images with SigLIP and language tokens with embedding layer to prepare
+ for PaliGemma transformer processing.
+
+ Args:
+ images (torch.Tensor): float (*b, n, c, h, w) images in range [-1.0, 1.0]
+ img_masks (torch.Tensor): bool (*b, n) masks for images
+ lang_tokens (torch.Tensor): int (*b, l) language tokens
+ lang_masks (torch.Tensor): bool (*b, l) masks for language tokens
+ """
+ bsize = images.shape[0]
+ device = images.device
+ dtype = images.dtype
+
+ # embed image
+ images = einops.rearrange(images, "b n c h w -> (b n) c h w")
+ # ipdb.set_trace()
+ img_emb = self.paligemma_with_expert.embed_image(images) # torch.Size([72, 3, 224, 224]) -> torch.Size([72, 256, 2048])
+ num_patch = img_emb.shape[1]
+ img_emb = einops.rearrange(img_emb, "(b n) l d -> b (n l) d", b=bsize)
+ ######## multi v
+ img_emb = img_emb.to(dtype=dtype) * (img_emb.shape[-1] ** 0.5)
+ num_img_embs = img_emb.shape[1]
+ img_masks = einops.repeat(img_masks, "b n -> b (n l)", l=num_patch)
+
+ # embed language
+ lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
+ num_lang_embs = lang_emb.shape[1]
+ lang_emb = lang_emb.to(dtype=dtype) * np.sqrt(lang_emb.shape[-1])
+
+ # assemble embeddings
+ # img_emb = img_emb.reshape(1, -1, img_emb.size(-1))
+ # num_img_embs = img_emb.shape[1]
+ # img_masks = img_masks.reshape(1, -1)
+ embs = torch.cat([img_emb, lang_emb], dim=1)
+ pad_masks = torch.cat([img_masks, lang_masks], dim=1)
+
+ # PaliGemma uses bidirectional attention for prefix tokens,
+ # so we set 1D `att_masks` to zeros.
+ # (see `make_att_2d_masks` to understand why zeros means bidirection)
+ att_masks = torch.zeros(
+ (img_emb.size(0), num_img_embs + num_lang_embs), device=device, dtype=torch.bool
+ ) # 1, bs_img*(768+48)
+ return embs, pad_masks, att_masks
+
+ def embed_suffix(self, state, noisy_actions, timestep): # (torch.Size([state_bs, 32]), torch.Size([1, state_bs*50, 32]), torch.Size([1]))
+ """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.
+
+ Args:
+ state (torch.Tensor): float32 (*b, s) robot state
+ noisy_actions (torch.Tensor): float32 (*b, n, m) noisy actions
+ timestep (torch.Tensor): float32 (*b,) timestep in [0, 1] range
+ """
+ bsize = state.shape[0] # state_bs = img_bs
+ device = state.device
+ dtype = state.dtype
+ # embed state
+ state_emb = self.state_proj(state) # torch.Size([state_bs, 1024])
+
+ # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
+ time_emb = create_sinusoidal_pos_embedding( # 1, 1024
+ timestep, # torch.Size([1]))
+ self.config.proj_width, # 1024
+ min_period=4e-3,
+ max_period=4.0,
+ device=device,
+ )
+ time_emb = time_emb.type(dtype=dtype)
+
+ # Fuse timestep + action information using an MLP
+ action_emb = self.action_in_proj(noisy_actions) # torch.Size([1, state_bs*50, 1024])
+ time_emb = einops.repeat(time_emb, "b d -> b n d", n=action_emb.shape[1]) # [1, 1024] -> [1, state_bs*50, 1024]
+ action_time_emb = torch.cat([action_emb, time_emb], dim=-1) # [1, state_bs*50, 2048]
+
+ action_time_emb = self.action_time_mlp_in(action_time_emb)
+ action_time_emb = F.silu(action_time_emb) # swish == silu
+ action_time_emb = self.action_time_mlp_out(action_time_emb) # [1, state_bs*50, 1024]
+ action_time_dim = action_time_emb.shape[1]
+ # action_chunk_num = int(action_time_dim / bsize) # 50
+
+ # # Add to input tokens
+ # embs = torch.cat([state_emb[None, :, :], action_time_emb], dim=1) # [1, state_bs*(50+1), 1024]
+ # pad_masks = torch.ones(
+ # (bsize, action_chunk_num + 1), device=device, dtype=torch.bool
+ # ) # state_bs, 51
+ # pad_masks = pad_masks.reshape(1, -1) # 1, state_bs*(50+1)
+
+ # # Set attention masks for suffix tokens so that prefix tokens cannot attend to suffix tokens.
+ # # And state token cannot attend action tokens.
+ # # Action tokens use a bidirectional attention.
+ # att_masks = torch.zeros(
+ # (bsize, action_chunk_num + 1), device=device, dtype=torch.bool
+ # ) # state_bs, 51
+ # att_masks[:, :2] = True
+ # att_masks = att_masks.reshape(1, -1) # 1, state_bs*(50+1)
+
+ # Add to input tokens
+ embs = torch.cat([state_emb[:, None], action_time_emb], dim=1)
+ pad_masks = torch.ones(
+ (bsize, action_time_dim + 1), device=device, dtype=torch.bool
+ )
+
+ # Set attention masks for suffix tokens so that prefix tokens cannot attend to suffix tokens.
+ # And state token cannot attend action tokens.
+ # Action tokens use a bidirectional attention.
+ att_masks = torch.zeros(
+ (bsize, action_time_dim + 1), device=device, dtype=torch.bool
+ )
+ att_masks[:, :2] = True
+
+ return embs, pad_masks, att_masks
+
+ def forward(
+ self,
+ images,
+ img_masks,
+ lang_tokens,
+ lang_masks,
+ state,
+ actions,
+ noise=None,
+ time=None,
+ ) -> Tensor:
+ dtype = state.dtype
+ device = state.device
+ # ipdb.set_trace()
+ """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
+ # def set_tensor_ones(x: torch.Tensor) -> torch.Tensor:
+ # return torch.ones_like(x)
+ # images, img_masks, lang_tokens, lang_masks, state, actions = set_tensor_ones(images),set_tensor_ones(img_masks),set_tensor_ones(lang_tokens),set_tensor_ones(lang_masks),set_tensor_ones(state),set_tensor_ones(actions)
+ if noise is None:
+ # actions_shape = (
+ # bsize,
+ # self.config.n_action_steps, # 50
+ # self.config.max_action_dim, # 32
+ # )
+ noise = torch.randn(actions.shape, device=device, dtype=dtype)
+
+ if time is None:
+ time = self.sample_time(actions.size(0), device).to(dtype)
+
+ time_expanded = time[:, None, None]
+ x_t = time_expanded * noise + (1 - time_expanded) * actions
+ u_t = noise - actions
+
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
+ images, img_masks, lang_tokens, lang_masks
+ ) # 1,bs_img*(768+48),2048 1,bs_img*(768+48) 1,bs_img*(768+48)
+ suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
+ state, x_t, time
+ ) # [1, state_bs*(50+1), 1024], [1, state_bs*(50+1)], [1, state_bs*(50+1)] state_bs=bs_img
+
+ pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) # 1,state_bs*(768+48+50+1)
+ att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)# 1,state_bs*(768+48+50+1)
+
+ # pad_masks = pad_masks.reshape(state.size(0), -1)
+ # att_masks = att_masks.reshape(state.size(0), -1)
+ att_2d_masks = make_att_2d_masks(pad_masks, att_masks) # torch.Size([state_bs, 768+48+50+1, 768+48+50+1])
+ position_ids = torch.cumsum(pad_masks, dim=1) - 1 # torch.Size([state_bs, 768+48+50+1])
+
+ # prefix_embs = prefix_embs.reshape(state.size(0), -1, prefix_embs.size(-1))
+ # suffix_embs = suffix_embs.reshape(state.size(0), -1, suffix_embs.size(-1))
+ (_, suffix_out), _ = self.paligemma_with_expert.forward(
+ attention_mask=att_2d_masks,
+ position_ids=position_ids,
+ past_key_values=None,
+ inputs_embeds=[prefix_embs, suffix_embs], # bs_img,(768+48),2048 [state_bs, (50+1), 1024]
+ use_cache=True,
+ fill_kv_cache=True,
+ )
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
+ if suffix_out.dtype != self.action_out_proj.weight.dtype:
+ suffix_out = suffix_out.to(self.action_out_proj.weight.dtype)
+ v_t = self.action_out_proj(suffix_out)
+ # u_t = u_t.reshape(images.size(0), -1, u_t.size(-1))
+ losses = F.mse_loss(u_t, v_t, reduction="none")
+ # losses = torch.mean((v_t - u_t)**2, dim=-1)
+ return losses
+
+ def sample_actions(
+ self, images, img_masks, lang_tokens, lang_masks, state, noise=None
+ ) -> Tensor:
+ """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
+ bsize = state.shape[0]
+ device = state.device
+ dtype = state.dtype
+
+ if noise is None:
+ actions_shape = (
+ bsize,
+ self.config.n_action_steps,
+ self.config.max_action_dim,
+ )
+ noise = torch.randn(actions_shape, device=device, dtype=dtype)
+
+ prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
+ images, img_masks, lang_tokens, lang_masks
+ )
+ prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
+ prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
+
+ # Compute image and language key value cache
+ _, past_key_values = self.paligemma_with_expert.forward(
+ attention_mask=prefix_att_2d_masks,
+ position_ids=prefix_position_ids,
+ past_key_values=None,
+ inputs_embeds=[prefix_embs, None],
+ use_cache=self.config.use_cache,
+ fill_kv_cache=True,
+ )
+
+ dt = torch.tensor(-1.0 / self.config.num_steps, dtype=dtype, device=device)
+ x_t = noise
+ time = torch.tensor(1.0, dtype=dtype, device=device)
+ while time >= -dt / 2:
+ expanded_time = time.expand(bsize)
+
+ v_t = self.predict_velocity(
+ state, prefix_pad_masks, past_key_values, x_t, expanded_time
+ )
+
+ # Euler step
+ x_t += dt * v_t
+ time += dt
+
+ return x_t
+
+ def predict_velocity(self, state, prefix_pad_masks, past_key_values, x_t, timestep):
+ """predict velocity at time t using the suffix model."""
+ suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(
+ state, x_t, timestep
+ )
+
+ suffix_len = suffix_pad_masks.shape[1]
+ batch_size = prefix_pad_masks.shape[0]
+ prefix_len = prefix_pad_masks.shape[1]
+ prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(
+ batch_size, suffix_len, prefix_len
+ )
+
+ suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
+
+ full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
+
+ prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
+ position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
+
+ outputs_embeds, _ = self.paligemma_with_expert.forward(
+ attention_mask=full_att_2d_masks,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=[None, suffix_embs],
+ use_cache=self.config.use_cache,
+ fill_kv_cache=False,
+ )
+ suffix_out = outputs_embeds[1]
+ suffix_out = suffix_out[:, -self.config.n_action_steps :]
+ v_t = self.action_out_proj(suffix_out)
+ return v_t
+
+ModelClass = PI0Policy
+
+__all__ = ["PI0Policy", "PaliGemmaForConditionalGeneration", "GemmaForCausalLM", "GemmaModel", "PaliGemmaPreTrainedModel", "GemmaPreTrainedModel"]
\ No newline at end of file
diff --git a/lingbotvla/models/vla/pi0/qwenvl_in_vla.py b/lingbotvla/models/vla/pi0/qwenvl_in_vla.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0583a0c9a77f622b9f1a3f681bd65601cf3d3f3
--- /dev/null
+++ b/lingbotvla/models/vla/pi0/qwenvl_in_vla.py
@@ -0,0 +1,1998 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from torch import Tensor, nn
+from typing import List, Optional, Tuple, Union, Callable, Dict, Any
+import math
+from transformers import (
+ PreTrainedModel,
+)
+from dataclasses import dataclass
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig
+from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs, flash_attn_supports_top_left_mask, is_flash_attn_available
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.processing_utils import Unpack
+import torch.distributed._tensor as dt
+
+if is_flash_attn_available():
+ from transformers.modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
+if is_flash_attn_available():
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
+
+
+from .vla_flash_attn_policy import use_flash_attention_2_for_vla
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "Qwen2_5_VLConfig"
+
+
+def rotate_half(x: torch.Tensor) -> torch.Tensor:
+ """Split last dim in half and swap with sign flip (RoPE). Used by eager/sdpa paths; flash path uses apply_rotary_emb."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+class Qwen2_5_VLMLP(nn.Module):
+ def __init__(self, config, bias: bool = False):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class Qwen2_5_VisionPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ embed_dim: int = 1152,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.in_channels = in_channels
+ self.embed_dim = embed_dim
+
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.view(
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+ )
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+ return hidden_states
+
+
+class Qwen2_5_VisionRotaryEmbedding(nn.Module):
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+class Qwen2RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Qwen2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+class Qwen2_5_VLPatchMerger(nn.Module):
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
+ super().__init__()
+ self.hidden_size = context_dim * (spatial_merge_size**2)
+ self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+ self.mlp = nn.Sequential(
+ nn.Linear(self.hidden_size, self.hidden_size),
+ nn.GELU(),
+ nn.Linear(self.hidden_size, dim),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+ return x
+
+
+def apply_rotary_pos_emb_flashatt(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ cos = cos.chunk(2, dim=-1)[0].contiguous()
+ sin = sin.chunk(2, dim=-1)[0].contiguous()
+ q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
+ k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
+ return q_embed, k_embed
+
+
+class Qwen2_5_VLVisionAttention(nn.Module):
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ else:
+ cos, sin = position_embeddings
+ q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+
+ attention_mask = torch.full(
+ [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+ )
+ for i in range(1, len(cu_seqlens)):
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+
+ q = q.transpose(0, 1)
+ k = k.transpose(0, 1)
+ v = v.transpose(0, 1)
+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+ attn_weights = attn_weights + attention_mask
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+ attn_output = torch.matmul(attn_weights, v)
+ attn_output = attn_output.transpose(0, 1)
+ attn_output = attn_output.reshape(seq_length, -1)
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+class Qwen2_5_VLVisionSdpaAttention(Qwen2_5_VLVisionAttention):
+ """Packed vision tokens attention via :func:`torch.nn.functional.scaled_dot_product_attention`."""
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ else:
+ cos, sin = position_embeddings
+ q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+
+ attention_mask = torch.full(
+ [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+ )
+ for i in range(1, len(cu_seqlens)):
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+
+ q = q.transpose(0, 1).unsqueeze(0).contiguous()
+ k = k.transpose(0, 1).unsqueeze(0).contiguous()
+ v = v.transpose(0, 1).unsqueeze(0).contiguous()
+ attn_mask = attention_mask.unsqueeze(1)
+
+ attn_output = F.scaled_dot_product_attention(
+ q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
+ )
+ attn_output = attn_output.squeeze(0).transpose(0, 1).reshape(seq_length, -1)
+ return self.proj(attn_output)
+
+
+class Qwen2_5_VLVisionFlashAttention2(nn.Module):
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ else:
+ cos, sin = position_embeddings
+ q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
+ q = q.squeeze(0)
+ k = k.squeeze(0)
+
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ out_fp32_atten = False
+ if k.dtype == torch.float32:
+ out_fp32_atten = True
+ q, k, v = q.to(torch.bfloat16), k.to(torch.bfloat16), v.to(torch.bfloat16)
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+ seq_length, -1
+ )
+ if out_fp32_atten:
+ attn_output = attn_output.to(torch.float32)
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
+def apply_rotary_pos_emb_vision(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ orig_q_dtype = q.dtype
+ orig_k_dtype = k.dtype
+ q, k = q.float(), k.float()
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ q_embed = q_embed.to(orig_q_dtype)
+ k_embed = k_embed.to(orig_k_dtype)
+ return q_embed, k_embed
+
+
+QWEN2_5_VL_VISION_ATTENTION_CLASSES = {
+ "eager": Qwen2_5_VLVisionAttention,
+ "sdpa": Qwen2_5_VLVisionSdpaAttention,
+ "flash_attention_2": Qwen2_5_VLVisionFlashAttention2,
+}
+
+class Qwen2_5_VLVisionBlock(nn.Module):
+ def __init__(self, config, attn_implementation: str = "flash_attention_2") -> None:
+ super().__init__()
+ self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
+ self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
+ self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation](
+ config.hidden_size, num_heads=config.num_heads
+ )
+ self.mlp = Qwen2_5_VLMLP(config, bias=True)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: Optional[torch.Tensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ position_embeddings=position_embeddings,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+Qwen2_5_VL_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Qwen2_5_VLConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.",
+ Qwen2_5_VL_START_DOCSTRING,
+)
+class Qwen2_5_VLPreTrainedModel(PreTrainedModel):
+ config_class = Qwen2_5_VLConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+ _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
+
+ # def _init_weights(self, module):
+ # std = self.config.initializer_range
+ # if isinstance(module, (nn.Linear, nn.Conv3d)):
+ # module.weight.data.normal_(mean=0.0, std=std)
+ # if module.bias is not None:
+ # module.bias.data.zero_()
+ # elif isinstance(module, nn.Embedding):
+ # module.weight.data.normal_(mean=0.0, std=std)
+ # if module.padding_idx is not None:
+ # module.weight.data[module.padding_idx].zero_()
+
+
+class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
+ config_class = Qwen2_5_VLVisionConfig
+ _no_split_modules = ["Qwen2_5_VLVisionBlock"]
+
+ def __init__(self, config, *inputs, **kwargs) -> None:
+ super().__init__(config, *inputs, **kwargs)
+ self.spatial_merge_size = config.spatial_merge_size
+ self.patch_size = config.patch_size
+ self.fullatt_block_indexes = config.fullatt_block_indexes
+ self.window_size = config.window_size
+ self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+ self.patch_embed = Qwen2_5_VisionPatchEmbed(
+ patch_size=config.patch_size,
+ temporal_patch_size=config.temporal_patch_size,
+ in_channels=config.in_channels,
+ embed_dim=config.hidden_size,
+ )
+
+ head_dim = config.hidden_size // config.num_heads
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList(
+ [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
+ )
+ self.merger = Qwen2_5_VLPatchMerger(
+ dim=config.out_hidden_size,
+ context_dim=config.hidden_size,
+ spatial_merge_size=config.spatial_merge_size,
+ )
+ self.gradient_checkpointing = False
+
+ def rot_pos_emb(self, grid_thw):
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def get_window_index(self, grid_thw):
+ window_index: list = []
+ cu_window_seqlens: list = [0]
+ window_index_id = 0
+ vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
+
+ for grid_t, grid_h, grid_w in grid_thw:
+ llm_grid_h, llm_grid_w = (
+ grid_h // self.spatial_merge_size,
+ grid_w // self.spatial_merge_size,
+ )
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+ index_padded = index_padded.reshape(
+ grid_t,
+ num_windows_h,
+ vit_merger_window_size,
+ num_windows_w,
+ vit_merger_window_size,
+ )
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+ grid_t,
+ num_windows_h * num_windows_w,
+ vit_merger_window_size,
+ vit_merger_window_size,
+ )
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+ index_padded = index_padded.reshape(-1)
+ index_new = index_padded[index_padded != -100]
+ window_index.append(index_new + window_index_id)
+ cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+ window_index = torch.cat(window_index, dim=0)
+
+ return window_index, cu_window_seqlens
+
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
+ The final hidden states of the model.
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
+ The temporal, height and width of feature shape of each image in LLM.
+
+ Returns:
+ `torch.Tensor`: hidden_states.
+ """
+ hidden_states = self.patch_embed(hidden_states)
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+ cu_window_seqlens = torch.tensor(
+ cu_window_seqlens,
+ device=hidden_states.device,
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+ seq_len, _ = hidden_states.size()
+ hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ hidden_states = hidden_states[window_index, :, :]
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+ for layer_num, blk in enumerate(self.blocks):
+ if layer_num in self.fullatt_block_indexes:
+ cu_seqlens_now = cu_seqlens
+ else:
+ cu_seqlens_now = cu_window_seqlens
+ if self.gradient_checkpointing and self.training:
+ hidden_states = self._gradient_checkpointing_func(
+ blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings
+ )
+ else:
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)
+
+ hidden_states = self.merger(hidden_states)
+ reverse_indices = torch.argsort(window_index)
+ hidden_states = hidden_states[reverse_indices, :]
+
+ return hidden_states
+
+
+class Qwen2_5_VLRotaryEmbedding(nn.Module):
+ def __init__(self, config: Qwen2_5_VLConfig, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ # In contrast to other models, Qwen2_5_VL has different position ids for the grids
+ # So we expand the inv_freq to shape (3, ...)
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
+
+ Explanation:
+ Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
+ sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
+ vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
+ Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
+ For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
+ height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
+ difference with modern LLMs.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`):
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+ used to pass offsetted position ids when working with a KV-cache.
+ mrope_section(`List(int)`):
+ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ mrope_section = mrope_section * 2
+ cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+ sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+ unsqueeze_dim
+ )
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """Expand kv heads to match query heads (GQA). Same layout as HF Llama."""
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class Qwen2_5_VLAttention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+ """
+
+ def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+ self.rope_scaling = config.rope_scaling
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+ self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # Fix precision issues in Qwen2-VL float16 inference
+ # Replace inf values with zeros in attention weights to prevent NaN propagation
+ if query_states.dtype == torch.float16:
+ attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention):
+ """Decoder self-attention via :func:`torch.nn.functional.scaled_dot_product_attention`."""
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ return super().forward(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=True,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ is_causal = causal_mask is None and self.is_causal
+
+ attn_output = F.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=dropout_rate,
+ is_causal=is_causal,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, None, past_key_value
+
+
+class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention):
+ """
+ Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention`
+ as the weights of the module stays untouched. The only required change would be on the forward pass
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
+ config.max_window_layers layers.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ cos, sin = position_embeddings
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reashape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ if (
+ self.config.use_sliding_window
+ and getattr(self.config, "sliding_window", None) is not None
+ and self.layer_idx >= self.config.max_window_layers
+ ):
+ sliding_window = self.config.sliding_window
+ else:
+ sliding_window = None
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ sliding_window=sliding_window,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+class Qwen2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+QWEN2_5_VL_ATTENTION_CLASSES = {
+ "eager": Qwen2_5_VLAttention,
+ "sdpa": Qwen2_5_VLSdpaAttention,
+ "flash_attention_2": Qwen2_5_VLFlashAttention2,
+}
+
+class Qwen2_5_VLDecoderLayer(nn.Module):
+ def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
+ logger.warning_once(
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
+ "unexpected results may be encountered."
+ )
+ self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
+
+ self.mlp = Qwen2MLP(config)
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ if config.norm_qkv:
+ self.q_layernorm = Qwen2RMSNorm(self.self_attn.head_dim, eps=config.rms_norm_eps)
+ self.k_layernorm = Qwen2RMSNorm(self.self_attn.head_dim, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ att_output: Optional[torch.Tensor] = None,
+ start: Optional[int] = 0,
+ end: Optional[int] = 0,
+ compute_kqv: bool = False,
+ norm_qkv: bool = False,
+ output_atten: bool = False,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+
+ if compute_kqv:
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_shape = (*hidden_states.shape[:-1], -1, self.self_attn.head_dim)
+
+ query_state = self.self_attn.q_proj(hidden_states).view(hidden_shape)
+ key_state = self.self_attn.k_proj(hidden_states).view(hidden_shape)
+ value_state = self.self_attn.v_proj(hidden_states).view(hidden_shape)
+
+ if norm_qkv:
+ query_state = self.q_layernorm(query_state)
+ key_state = self.k_layernorm(key_state)
+
+ return query_state, key_state, value_state
+
+ elif output_atten:
+ if att_output.dtype != self.self_attn.o_proj.weight.dtype:
+ att_output = att_output.to(self.self_attn.o_proj.weight.dtype)
+ out_emb = self.self_attn.o_proj(att_output[:, start:end])
+
+ # first residual
+ out_emb += hidden_states
+ after_first_residual = out_emb.clone()
+
+ out_emb = self.post_attention_layernorm(out_emb)
+ out_emb = self.mlp(out_emb)
+
+ # second residual
+ out_emb += after_first_residual
+
+ return out_emb
+
+ else:
+ raise ValueError(f"Invaild Operation compute_kqv={compute_kqv} and output_atten={output_atten} with Qwen2_5_VLDecoderLayer in LingBot-VLA")
+
+
+@add_start_docstrings(
+ "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.",
+ Qwen2_5_VL_START_DOCSTRING,
+)
+class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
+ def __init__(self, config: Qwen2_5_VLConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self._init_weights = lambda module: None
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # torch.jit.trace() doesn't support cache objects in the output
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
+ past_key_values = DynamicCache()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ # the hard coded `3` is for temporal, height and width.
+ if position_ids is None:
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
+ elif position_ids.dim() == 2:
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool = False,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and past_key_values is not None:
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+ if is_padding_right:
+ raise ValueError(
+ "You are attempting to perform batched generation with padding_side='right'"
+ " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to "
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+ )
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_static_cache = isinstance(past_key_values, StaticCache)
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if (
+ self.config._attn_implementation == "sdpa"
+ and not (using_static_cache or using_sliding_window_cache)
+ and not output_attentions
+ ):
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ sliding_window=self.config.sliding_window,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ # SlidingWindowCache or StaticCache
+ if using_sliding_window_cache or using_static_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ # DynamicCache or no cache
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ device=device,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ config=self.config,
+ past_key_values=past_key_values,
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu"]
+ and not output_attentions
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ config: Qwen2_5_VLConfig,
+ past_key_values: Cache,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to place the 4D attention mask on.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ config (`Qwen2_5_VLConfig`):
+ The model's configuration class
+ past_key_values (`Cache`):
+ The cache class that is being used currently to generate
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+ )
+ diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ if config.sliding_window is not None:
+ # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
+ # the check is needed to verify is current checkpoint was trained with sliding window or not
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+ sliding_attend_mask = torch.arange(target_length, device=device) <= (
+ cache_position.reshape(-1, 1) - config.sliding_window
+ )
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+ causal_mask *= diagonal_attend_mask
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ if attention_mask.shape[-1] > target_length:
+ attention_mask = attention_mask[:, :target_length]
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+ return causal_mask
+
+
+@dataclass
+class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
+ """
+ Base class for Qwen2_5_VL causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ rope_deltas: Optional[torch.LongTensor] = None
+
+
+QWEN2_5_VL_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)):
+ The tensors corresponding to the input images. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+ [`Qwen2_5_VLImageProcessor`] for processing images.
+ pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+ The tensors corresponding to the input videos. Pixel values can be obtained using
+ [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+ [`Qwen2_5_VLImageProcessor`] for processing videos.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+ The rope index difference between sequence length and multimodal rope.
+"""
+
+
+class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ config_class = Qwen2_5_VLConfig
+ _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(
+ config.vision_config, use_flash_attention_2=use_flash_attention_2_for_vla()
+ )
+ self.model = Qwen2_5_VLModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.rope_deltas = None # cache rope_deltas here
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ def get_rope_index(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
+
+ Explanation:
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
+
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
+ Examples:
+ input_ids: [T T T T T], here T is for text.
+ temporal position_ids: [0, 1, 2, 3, 4]
+ height position_ids: [0, 1, 2, 3, 4]
+ width position_ids: [0, 1, 2, 3, 4]
+
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+ and 1D rotary position embedding for text part.
+ Examples:
+ Temporal (Time): 3 patches, representing different segments of the video in time.
+ Height: 2 patches, dividing each frame vertically.
+ Width: 2 patches, dividing each frame horizontally.
+ We also have some important parameters:
+ fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+ tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+ temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+ interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs.
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+ text temporal position_ids: [101, 102, 103, 104, 105]
+ text height position_ids: [101, 102, 103, 104, 105]
+ text width position_ids: [101, 102, 103, 104, 105]
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+ The temporal, height and width of feature shape of each image in LLM.
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+ The temporal, height and width of feature shape of each video in LLM.
+ second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+ The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ Returns:
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
+ """
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+ mrope_position_deltas = []
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = torch.ones_like(total_input_ids)
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
+ )
+ image_index, video_index = 0, 0
+ attention_mask = attention_mask.to(total_input_ids.device)
+ for i, input_ids in enumerate(total_input_ids):
+ input_ids = input_ids[attention_mask[i] == 1]
+ image_nums, video_nums = 0, 0
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ second_per_grid_t = 0
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ if second_per_grid_ts is not None:
+ second_per_grid_t = second_per_grid_ts[video_index]
+ else:
+ second_per_grid_t = 1.0
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+ text_len = ed - st
+
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+ expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
+
+ time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
+
+ time_tensor_long = time_tensor.long()
+ t_index = time_tensor_long.flatten()
+
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+ else:
+ position_ids = (
+ torch.arange(input_ids.shape[1], device=input_ids.device)
+ .view(1, 1, -1)
+ .expand(3, input_ids.shape[0], -1)
+ )
+ mrope_position_deltas = torch.zeros(
+ [input_ids.shape[0], 1],
+ device=input_ids.device,
+ dtype=input_ids.dtype,
+ )
+
+ return position_ids, mrope_position_deltas
+
+ @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ rope_deltas: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+ >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+ >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+ >>> messages = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image"},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+ ]
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+ >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if inputs_embeds is None:
+ inputs_embeds = self.model.embed_tokens(input_ids)
+ if pixel_values is not None:
+ pixel_values = pixel_values.type(self.visual.dtype)
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+ n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+ n_image_features = image_embeds.shape[0]
+ if n_image_tokens != n_image_features:
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+
+ mask = input_ids == self.config.image_token_id
+ mask_unsqueezed = mask.unsqueeze(-1)
+ mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+ image_mask = mask_expanded.to(inputs_embeds.device)
+
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if pixel_values_videos is not None:
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+ n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+ n_video_features = video_embeds.shape[0]
+ if n_video_tokens != n_video_features:
+ raise ValueError(
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+ )
+
+ mask = input_ids == self.config.video_token_id
+ mask_unsqueezed = mask.unsqueeze(-1)
+ mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+ video_mask = mask_expanded.to(inputs_embeds.device)
+
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(inputs_embeds.device)
+
+ # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+ if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+ # calculate RoPE index once per generation in the pre-fill stage only
+ if (
+ (cache_position is not None and cache_position[0] == 0)
+ or self.rope_deltas is None
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
+ ):
+ position_ids, rope_deltas = self.get_rope_index(
+ input_ids,
+ image_grid_thw,
+ video_grid_thw,
+ second_per_grid_ts,
+ attention_mask,
+ )
+ self.rope_deltas = rope_deltas
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
+ else:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ delta = (
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+ if cache_position is not None
+ else 0
+ )
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+ if cache_position is not None: # otherwise `deltas` is an int `0`
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+ position_ids = position_ids.add(delta)
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
+ outputs = self.model(
+ input_ids=None,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return Qwen2_5_VLCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ rope_deltas=self.rope_deltas,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ second_per_grid_ts=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ # Qwen2-5-VL position_ids are prepareed with rope_deltas in forward
+ model_inputs["position_ids"] = None
+
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+ model_inputs["pixel_values_videos"] = None
+
+ return model_inputs
+
+ def _get_image_nums_and_video_nums(
+ self,
+ input_ids: Optional[torch.LongTensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+ These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Returns:
+ image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
+ video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
+ """
+ image_token_id = self.config.image_token_id
+ video_token_id = self.config.video_token_id
+ vision_start_token_id = self.config.vision_start_token_id
+
+ vision_start_mask = input_ids == vision_start_token_id
+ vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+ image_mask = input_ids == image_token_id
+ video_mask = input_ids == video_token_id
+ image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+ video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+ return image_nums, video_nums
+
+ def _expand_inputs_for_generation(
+ self,
+ expand_size: int = 1,
+ is_encoder_decoder: bool = False,
+ input_ids: Optional[torch.LongTensor] = None,
+ **model_kwargs,
+ ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
+ # Overwritten -- Support for expanding tensors without a batch size dimension
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
+ # image_grid_thw.shape[0] is sum(num_images for samples)
+
+ if expand_size == 1:
+ return input_ids, model_kwargs
+
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+ def _expand_dict_for_generation_visual(dict_to_expand):
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
+ image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
+
+ def _repeat_interleave_samples(x, lengths, repeat_times):
+ samples = torch.split(x, lengths)
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+ return result
+
+ for key in dict_to_expand:
+ if key == "pixel_values":
+ # split images into samples
+ samples = torch.split(image_grid_thw, list(image_nums))
+ # compute the sequence length of images for each sample
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "image_grid_thw":
+ # get the num of images for each sample
+ lengths = list(image_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "pixel_values_videos":
+ samples = torch.split(video_grid_thw, list(video_nums))
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "video_grid_thw":
+ lengths = list(video_nums)
+ dict_to_expand[key] = _repeat_interleave_samples(
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+ )
+ elif key == "second_per_grid_ts":
+ if not isinstance(dict_to_expand[key], list):
+ raise TypeError(
+ f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
+ )
+ tensor = torch.tensor(dict_to_expand[key])
+ lengths = list(video_nums)
+ tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
+ dict_to_expand[key] = tensor.tolist()
+ return dict_to_expand
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if (
+ key != "cache_position"
+ and dict_to_expand[key] is not None
+ and isinstance(dict_to_expand[key], torch.Tensor)
+ and key not in visual_keys
+ ):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ # input_ids is required for expanding visual inputs
+ # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
+ if input_ids is not None and input_ids.numel() != 0:
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
\ No newline at end of file
diff --git a/lingbotvla/models/vla/pi0/utils.py b/lingbotvla/models/vla/pi0/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..892ff8f5b9090edae948a738e9c51dd2648c2141
--- /dev/null
+++ b/lingbotvla/models/vla/pi0/utils.py
@@ -0,0 +1,207 @@
+import math
+
+import einops
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from packaging.version import Version
+import ipdb
+# from xformers.ops import memory_efficient_attention
+
+
+def find_next_divisible_by_8_numpy(n: np.ndarray) -> np.ndarray:
+ """
+ Finds the smallest integers greater than each element in a NumPy array 'n'
+ that are divisible by 8. Assumes non-negative integers.
+
+ Args:
+ n: A NumPy array of integers.
+
+ Returns:
+ A NumPy array containing the smallest integers greater than each input element
+ that are divisible by 8.
+ """
+ remainder = n % 8
+ # Calculate the amount to add: 0 if already divisible, otherwise 8 - remainder
+ # np.where is efficient for conditional operations on arrays
+ amount_to_add = np.where(remainder == 0, 8, 8 - remainder)
+ return n + amount_to_add
+
+
+def create_sinusoidal_pos_embedding(
+ time: torch.tensor,
+ dimension: int,
+ min_period: float,
+ max_period: float,
+ device="cpu",
+) -> Tensor:
+ """Computes sine-cosine positional embedding vectors for scalar positions."""
+ if dimension % 2 != 0:
+ raise ValueError(f"dimension ({dimension}) must be divisible by 2")
+
+ if time.ndim != 1:
+ raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
+
+ fraction = torch.linspace(
+ 0.0, 1.0, dimension // 2, dtype=torch.float32, device=device
+ )
+ period = min_period * (max_period / min_period) ** fraction
+
+ # Compute the outer product
+ scaling_factor = 1.0 / period * 2 * math.pi
+ sin_input = scaling_factor[None, :] * time[:, None]
+ pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
+ return pos_emb
+
+
+def sample_beta(alpha, beta, bsize, device):
+ gamma1 = torch.rand((bsize,), device=device).pow(1 / alpha)
+ gamma2 = torch.rand((bsize,), device=device).pow(1 / beta)
+ return gamma1 / (gamma1 + gamma2)
+
+
+def make_att_2d_masks(pad_masks, att_masks):
+ """Copied from big_vision.
+
+ Tokens can attend to valid inputs tokens which have a cumulative mask_ar
+ smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
+ setup several types of attention, for example:
+
+ [[1 1 1 1 1 1]]: pure causal attention.
+
+ [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
+ themselves and the last 3 tokens have a causal attention. The first
+ entry could also be a 1 without changing behaviour.
+
+ [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
+ block can attend all previous blocks and all tokens on the same block.
+
+ Args:
+ input_mask: bool[B, N] true if its part of the input, false if padding.
+ mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
+ it and 0 where it shares the same attention mask as the previous token.
+ """
+ if att_masks.ndim != 2:
+ raise ValueError(att_masks.ndim)
+ if pad_masks.ndim != 2:
+ raise ValueError(pad_masks.ndim)
+
+ cumsum = torch.cumsum(att_masks, dim=1)
+ att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
+ pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
+ att_2d_masks = att_2d_masks & pad_2d_masks
+ return att_2d_masks
+
+
+def resize_with_pad(img, width, height, pad_value=-1):
+ # assume no-op when width height fits already
+ if img.ndim != 4:
+ raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
+
+ cur_height, cur_width = img.shape[2:]
+
+ ratio = max(cur_width / width, cur_height / height)
+ resized_height = int(cur_height / ratio)
+ resized_width = int(cur_width / ratio)
+ resized_img = F.interpolate(
+ img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+ )
+
+ pad_height = max(0, int(height - resized_height))
+ pad_width = max(0, int(width - resized_width))
+
+ # pad on left and top of image
+ padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
+ return padded_img
+
+
+def our_eager_attention_forward(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+):
+ """
+ Performs eager attention, optimized with torch.einsum.
+
+ Args:
+ query_states: Query tensor of shape [batch_size, seq_len, num_attention_heads, head_dim].
+ key_states: Key tensor of shape [batch_size, seq_len, num_key_value_heads, head_dim].
+ value_states: Value tensor of shape [batch_size, seq_len, num_key_value_heads, head_dim].
+ attention_mask: Attention mask tensor, typically [batch_size, 1, seq_len, seq_len] or [batch_size, seq_len, seq_len].
+
+ Returns:
+ Output tensor of shape [batch_size, seq_len, num_attention_heads * head_dim].
+ """
+ # ipdb.set_trace()
+ bsize, seq_len, num_att_heads, head_dim = query_states.shape
+ num_key_value_heads = key_states.shape[2]
+ num_key_value_groups = num_att_heads // num_key_value_heads
+
+ key_states = einops.repeat(
+ key_states, "b l h d -> b l (h g) d", g=num_key_value_groups
+ )
+ value_states = einops.repeat(
+ value_states, "b l h d -> b l (h g) d", g=num_key_value_groups
+ )
+
+ query_states_permuted = torch.einsum("blhd->bhld", query_states)
+ key_states_permuted = torch.einsum("blhd->bhld", key_states)
+
+ att_weights = torch.einsum(
+ "bhqd,bhkd->bhqk", query_states_permuted, key_states_permuted
+ )
+ att_weights *= head_dim**-0.5
+
+ big_neg = -2.3819763e38
+ masked_att_weights = torch.where(
+ attention_mask[:, None, :, :], att_weights, big_neg
+ )
+
+ probs = nn.functional.softmax(masked_att_weights, dim=-1)
+ probs = probs.to(dtype=value_states.dtype)
+
+ value_states_permuted = torch.einsum("blhd->bhld", value_states) # [B, H, L_v, D]
+ att_output = torch.einsum(
+ "bhqk,bhkv->bhqv", probs, value_states_permuted
+ ) # [B, H, L_q, D]
+ att_output = torch.einsum("bhld->blhd", att_output) # [B, L, H, D]
+ att_output = att_output.reshape(bsize, seq_len, num_att_heads * head_dim)
+
+ return att_output
+
+
+# @torch.jit.script
+def apply_rope(
+ x: torch.Tensor,
+ positions: torch.Tensor,
+ max_wavelength: float = 10_000.0,
+ dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+ """Applies RoPE positions [B, L] to x [B, L, H, D]."""
+ # ipdb.set_trace()
+ original_dtype = x.dtype # bf16
+ d = x.shape[-1]
+ d_half = d // 2
+ device = x.device
+
+ # Cast input to compute_dtype for all internal operations
+ x_casted = x.to(dtype)
+ positions_casted = positions.to(dtype)
+
+ freq_exponents = (2.0 / d) * torch.arange(d_half, dtype=dtype, device=device)
+ timescale = max_wavelength**freq_exponents
+ radians = torch.einsum("bl,h->blh", positions_casted, 1.0 / timescale) # fp32 -> bf16
+
+ radians = radians[..., None, :] # [B, L, 1, D_half]
+
+ sin = torch.sin(radians) # bf16
+ cos = torch.cos(radians) # bf16
+
+ x1, x2 = x_casted.split(d_half, dim=-1) # fp32
+
+ res = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1) # fp32
+
+ return res.to(original_dtype) # bf16
\ No newline at end of file
diff --git a/lingbotvla/models/vla/pi0/vla_flash_attn_policy.py b/lingbotvla/models/vla/pi0/vla_flash_attn_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..79204eea510a0769c6d9f454700183529f782cbd
--- /dev/null
+++ b/lingbotvla/models/vla/pi0/vla_flash_attn_policy.py
@@ -0,0 +1,19 @@
+"""Whether to enable HF Flash Attention 2 when building Qwen2.5-VL / expert (by GPU type)."""
+
+from __future__ import annotations
+
+import torch
+
+__all__ = ["use_flash_attention_2_for_vla"]
+
+
+def use_flash_attention_2_for_vla() -> bool:
+ """
+ Return True to pass ``use_flash_attention_2=True`` into ``_from_config``, else False.
+
+ Currently: Tesla V100 (Volta) → False; any other detected CUDA device name → True.
+ No CUDA → False.
+ """
+ if "v100" in torch.cuda.get_device_name(0).lower():
+ return False
+ return True
diff --git a/lingbotvla/models/vla/vision_models/MoGe/.gitignore b/lingbotvla/models/vla/vision_models/MoGe/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..eee13ac82659361f18f3595742fd0355de307e39
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/.gitignore
@@ -0,0 +1,423 @@
+## Ignore Visual Studio temporary files, build results, and
+## files generated by popular Visual Studio add-ons.
+##
+## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
+
+# User-specific files
+*.rsuser
+*.suo
+*.user
+*.userosscache
+*.sln.docstates
+
+# User-specific files (MonoDevelop/Xamarin Studio)
+*.userprefs
+
+# Mono auto generated files
+mono_crash.*
+
+# Build results
+[Dd]ebug/
+[Dd]ebugPublic/
+[Rr]elease/
+[Rr]eleases/
+x64/
+x86/
+[Ww][Ii][Nn]32/
+[Aa][Rr][Mm]/
+[Aa][Rr][Mm]64/
+bld/
+[Bb]in/
+[Oo]bj/
+[Ll]og/
+[Ll]ogs/
+
+# Visual Studio 2015/2017 cache/options directory
+.vs/
+# Uncomment if you have tasks that create the project's static files in wwwroot
+#wwwroot/
+
+# Visual Studio 2017 auto generated files
+Generated\ Files/
+
+# MSTest test Results
+[Tt]est[Rr]esult*/
+[Bb]uild[Ll]og.*
+
+# NUnit
+*.VisualState.xml
+TestResult.xml
+nunit-*.xml
+
+# Build Results of an ATL Project
+[Dd]ebugPS/
+[Rr]eleasePS/
+dlldata.c
+
+# Benchmark Results
+BenchmarkDotNet.Artifacts/
+
+# .NET Core
+project.lock.json
+project.fragment.lock.json
+artifacts/
+
+# ASP.NET Scaffolding
+ScaffoldingReadMe.txt
+
+# StyleCop
+StyleCopReport.xml
+
+# Files built by Visual Studio
+*_i.c
+*_p.c
+*_h.h
+*.ilk
+*.meta
+*.obj
+*.iobj
+*.pch
+*.pdb
+*.ipdb
+*.pgc
+*.pgd
+*.rsp
+*.sbr
+*.tlb
+*.tli
+*.tlh
+*.tmp
+*.tmp_proj
+*_wpftmp.csproj
+*.log
+*.tlog
+*.vspscc
+*.vssscc
+.builds
+*.pidb
+*.svclog
+*.scc
+
+# Chutzpah Test files
+_Chutzpah*
+
+# Visual C++ cache files
+ipch/
+*.aps
+*.ncb
+*.opendb
+*.opensdf
+*.sdf
+*.cachefile
+*.VC.db
+*.VC.VC.opendb
+
+# Visual Studio profiler
+*.psess
+*.vsp
+*.vspx
+*.sap
+
+# Visual Studio Trace Files
+*.e2e
+
+# TFS 2012 Local Workspace
+$tf/
+
+# Guidance Automation Toolkit
+*.gpState
+
+# ReSharper is a .NET coding add-in
+_ReSharper*/
+*.[Rr]e[Ss]harper
+*.DotSettings.user
+
+# TeamCity is a build add-in
+_TeamCity*
+
+# DotCover is a Code Coverage Tool
+*.dotCover
+
+# AxoCover is a Code Coverage Tool
+.axoCover/*
+!.axoCover/settings.json
+
+# Coverlet is a free, cross platform Code Coverage Tool
+coverage*.json
+coverage*.xml
+coverage*.info
+
+# Visual Studio code coverage results
+*.coverage
+*.coveragexml
+
+# NCrunch
+_NCrunch_*
+.*crunch*.local.xml
+nCrunchTemp_*
+
+# MightyMoose
+*.mm.*
+AutoTest.Net/
+
+# Web workbench (sass)
+.sass-cache/
+
+# Installshield output folder
+[Ee]xpress/
+
+# DocProject is a documentation generator add-in
+DocProject/buildhelp/
+DocProject/Help/*.HxT
+DocProject/Help/*.HxC
+DocProject/Help/*.hhc
+DocProject/Help/*.hhk
+DocProject/Help/*.hhp
+DocProject/Help/Html2
+DocProject/Help/html
+
+# Click-Once directory
+publish/
+
+# Publish Web Output
+*.[Pp]ublish.xml
+*.azurePubxml
+# Note: Comment the next line if you want to checkin your web deploy settings,
+# but database connection strings (with potential passwords) will be unencrypted
+*.pubxml
+*.publishproj
+
+# Microsoft Azure Web App publish settings. Comment the next line if you want to
+# checkin your Azure Web App publish settings, but sensitive information contained
+# in these scripts will be unencrypted
+PublishScripts/
+
+# NuGet Packages
+*.nupkg
+# NuGet Symbol Packages
+*.snupkg
+# The packages folder can be ignored because of Package Restore
+**/[Pp]ackages/*
+# except build/, which is used as an MSBuild target.
+!**/[Pp]ackages/build/
+# Uncomment if necessary however generally it will be regenerated when needed
+#!**/[Pp]ackages/repositories.config
+# NuGet v3's project.json files produces more ignorable files
+*.nuget.props
+*.nuget.targets
+
+# Microsoft Azure Build Output
+csx/
+*.build.csdef
+
+# Microsoft Azure Emulator
+ecf/
+rcf/
+
+# Windows Store app package directories and files
+AppPackages/
+BundleArtifacts/
+Package.StoreAssociation.xml
+_pkginfo.txt
+*.appx
+*.appxbundle
+*.appxupload
+
+# Visual Studio cache files
+# files ending in .cache can be ignored
+*.[Cc]ache
+# but keep track of directories ending in .cache
+!?*.[Cc]ache/
+
+# Others
+ClientBin/
+~$*
+*~
+*.dbmdl
+*.dbproj.schemaview
+*.jfm
+*.pfx
+*.publishsettings
+orleans.codegen.cs
+
+# Including strong name files can present a security risk
+# (https://github.com/github/gitignore/pull/2483#issue-259490424)
+#*.snk
+
+# Since there are multiple workflows, uncomment next line to ignore bower_components
+# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
+#bower_components/
+
+# RIA/Silverlight projects
+Generated_Code/
+
+# Backup & report files from converting an old project file
+# to a newer Visual Studio version. Backup files are not needed,
+# because we have git ;-)
+_UpgradeReport_Files/
+Backup*/
+UpgradeLog*.XML
+UpgradeLog*.htm
+ServiceFabricBackup/
+*.rptproj.bak
+
+# SQL Server files
+*.mdf
+*.ldf
+*.ndf
+
+# Business Intelligence projects
+*.rdl.data
+*.bim.layout
+*.bim_*.settings
+*.rptproj.rsuser
+*- [Bb]ackup.rdl
+*- [Bb]ackup ([0-9]).rdl
+*- [Bb]ackup ([0-9][0-9]).rdl
+
+# Microsoft Fakes
+FakesAssemblies/
+
+# GhostDoc plugin setting file
+*.GhostDoc.xml
+
+# Node.js Tools for Visual Studio
+.ntvs_analysis.dat
+node_modules/
+
+# Visual Studio 6 build log
+*.plg
+
+# Visual Studio 6 workspace options file
+*.opt
+
+# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
+*.vbw
+
+# Visual Studio 6 auto-generated project file (contains which files were open etc.)
+*.vbp
+
+# Visual Studio 6 workspace and project file (working project files containing files to include in project)
+*.dsw
+*.dsp
+
+# Visual Studio 6 technical files
+*.ncb
+*.aps
+
+# Visual Studio LightSwitch build output
+**/*.HTMLClient/GeneratedArtifacts
+**/*.DesktopClient/GeneratedArtifacts
+**/*.DesktopClient/ModelManifest.xml
+**/*.Server/GeneratedArtifacts
+**/*.Server/ModelManifest.xml
+_Pvt_Extensions
+
+# Paket dependency manager
+.paket/paket.exe
+paket-files/
+
+# FAKE - F# Make
+.fake/
+
+# CodeRush personal settings
+.cr/personal
+
+# Python Tools for Visual Studio (PTVS)
+__pycache__/
+*.pyc
+
+# Cake - Uncomment if you are using it
+# tools/**
+# !tools/packages.config
+
+# Tabs Studio
+*.tss
+
+# Telerik's JustMock configuration file
+*.jmconfig
+
+# BizTalk build output
+*.btp.cs
+*.btm.cs
+*.odx.cs
+*.xsd.cs
+
+# OpenCover UI analysis results
+OpenCover/
+
+# Azure Stream Analytics local run output
+ASALocalRun/
+
+# MSBuild Binary and Structured Log
+*.binlog
+
+# NVidia Nsight GPU debugger configuration file
+*.nvuser
+
+# MFractors (Xamarin productivity tool) working folder
+.mfractor/
+
+# Local History for Visual Studio
+.localhistory/
+
+# Visual Studio History (VSHistory) files
+.vshistory/
+
+# BeatPulse healthcheck temp database
+healthchecksdb
+
+# Backup folder for Package Reference Convert tool in Visual Studio 2017
+MigrationBackup/
+
+# Ionide (cross platform F# VS Code tools) working folder
+.ionide/
+
+# Fody - auto-generated XML schema
+FodyWeavers.xsd
+
+# VS Code files for those working on multiple tools
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+
+# Local History for Visual Studio Code
+.history/
+
+# Windows Installer files from build outputs
+*.cab
+*.msi
+*.msix
+*.msm
+*.msp
+
+# JetBrains Rider
+*.sln.iml
+
+# Python
+*.egg-info/
+/build
+
+# MoGe
+/data*
+/download
+/extract
+/debug
+/workspace
+/mlruns
+/infer_output
+/video_output
+/eval_output
+/.blobcache
+/test_images
+/test_videos
+/vis
+/videos
+/blobmnt
+/eval_dump
+/pretrained
+/.gradio
+/tmp
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/CHANGELOG.md b/lingbotvla/models/vla/vision_models/MoGe/CHANGELOG.md
new file mode 100644
index 0000000000000000000000000000000000000000..1359169c38b5fd8235f84333cbcbcae178b1f1c8
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/CHANGELOG.md
@@ -0,0 +1,40 @@
+## 2024-11-28
+### Added
+- Supported user-provided camera FOV. See [scripts/infer.py](scripts/infer.py) --fov_x.
+ - Related issues: [#25](https://github.com/microsoft/MoGe/issues/25) and [#24](https://github.com/microsoft/MoGe/issues/24).
+- Added inference scripts for panorama images. See [scripts/infer_panorama.py](scripts/infer_panorama.py).
+ - Related issue: [#19](https://github.com/microsoft/MoGe/issues/19).
+
+### Fixed
+- Suppressed unnecessary numpy runtime warnings.
+- Specified recommended versions of requirements.
+ - Related issue: [#21](https://github.com/microsoft/MoGe/issues/21).
+
+### Changed
+- Moved `app.py` and `infer.py` to [scripts/](scripts/)
+- Improved edge removal.
+
+## 2025-03-18
+### Added
+- Training and evaluation code. See [docs/train.md](docs/train.md) and [docs/eval.md](docs/eval.md).
+- Supported installation via pip. Thanks to @fabiencastan and @jgoueslard
+ for commits in the [#47](https://github.com/microsoft/MoGe/pull/47)
+- Supported command-line usage when installed.
+
+### Changed
+- Moved `scripts/` into `moge/` for package installation and command-line usage.
+- Renamed `moge.model.moge_model` to `moge.model.v1` for version management.
+ Now you can import the model class through `from moge.model.v1 import MoGeModel` or `from moge.model import import_model_class_by_version; MoGeModel = import_model_class_by_version('v1')`.
+- Exposed `num_tokens` parameter in MoGe model.
+
+## 2025-06-10
+### Added
+- Released MoGe-2.
+
+## 2025-10-16
+### Added
+- Update training code for MoGe-2.
+
+### Changed
+- Refactored training dataloader code for better readability.
+- Removed Git LFS for convenience.
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/CODE_OF_CONDUCT.md b/lingbotvla/models/vla/vision_models/MoGe/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..f9ba8cf65f3e3104dd061c178066ec8247811f33
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/CODE_OF_CONDUCT.md
@@ -0,0 +1,9 @@
+# Microsoft Open Source Code of Conduct
+
+This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+
+Resources:
+
+- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
+- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
+- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
diff --git a/lingbotvla/models/vla/vision_models/MoGe/LICENSE b/lingbotvla/models/vla/vision_models/MoGe/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..3458b5ccd398afed340e17a4d0615c9a8666bb5d
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/LICENSE
@@ -0,0 +1,224 @@
+ MIT License
+
+ Copyright (c) Microsoft Corporation.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE
+
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/lingbotvla/models/vla/vision_models/MoGe/README.md b/lingbotvla/models/vla/vision_models/MoGe/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3e5996e8b97f9f198092c67d349830a35465df5b
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/README.md
@@ -0,0 +1,295 @@
+# MoGe: Accurate Monocular Geometry Estimation
+
+MoGe is a powerful model for recovering 3D geometry from monocular open-domain images, including metric point maps, metric depth maps, normal maps and camera FOV. ***Check our websites ([MoGe-1](https://wangrc.site/MoGePage), [MoGe-2](https://wangrc.site/MoGe2Page)) for videos and interactive results!***
+
+## 📖 Publications
+
+### MoGe-2: Accurate Monocular Geometry with Metric Scale and Sharp Details
+
+
+

+

+
-blue)
+
+https://github.com/user-attachments/assets/8f9ae680-659d-4f7f-82e2-b9ed9d6b988a
+
+
+
+### MoGe: Unlocking Accurate Monocular Geometry Estimation for Open-Domain Images with Optimal Training Supervision
+
+
+
+
+
+
+## 🌟 Features
+
+* **Accurate 3D geometry estimation**: Estimate point maps & depth maps & [normal maps](docs/normal.md) from open-domain single images with high precision -- all capabilities in one model, one forward pass.
+* **Optional ground-truth FOV input**: Enhance model accuracy further by providing the true field of view.
+* **Flexible resolution support**: Works seamlessly with various resolutions and aspect ratios, from 2:1 to 1:2.
+* **Optimized for speed**: Achieves 60ms latency per image (A100 or RTX3090, FP16, ViT-L). Adjustable inference resolution for even faster speed.
+
+## ✨ News
+
+***(2025-10-16)***
+* Updated training code for MoGe-2.
+
+***(2025-06-10)***
+
+* ❗**Released MoGe-2**, a state-of-the-art model for monocular geometry, with these new capabilities in one unified model:
+ * point map prediction in **metric scale**;
+ * comparable and even better performance over MoGe-1;
+ * significant improvement of **visual sharpness**;
+ * high-quality [**normal map** estimation](docs/normal.md);
+ * lower inference latency.
+
+## 📦 Installation
+
+### Install via pip
+
+```bash
+pip install git+https://github.com/microsoft/MoGe.git
+```
+
+### Or clone this repository
+
+```bash
+git clone https://github.com/microsoft/MoGe.git
+cd MoGe
+pip install -r requirements.txt # install the requirements
+```
+
+Note: MoGe should be compatible with most requirements versions. Please check the `requirements.txt` for more details if you encounter any dependency issues.
+
+## 🤗 Pretrained Models
+
+Our pretrained models are available on the huggingface hub:
+
+
+
+
+> NOTE: `moge-2-vitl-normal` has full capabilities, with almost the same level of performance as `moge-2-vitl` plus extra normal map estimation.
+
+You may import the `MoGeModel` class of the matched version, then load the pretrained weights via `MoGeModel.from_pretrained("HUGGING_FACE_MODEL_REPO_NAME")` with automatic downloading.
+If loading a local checkpoint, replace the model name with the local path.
+
+For ONNX support, please refer to [docs/onnx.md](docs/onnx.md).
+
+## 💡 Minimal Code Example
+
+Here is a minimal example for loading the model and inferring on a single image.
+
+```python
+import cv2
+import torch
+# from moge.model.v1 import MoGeModel
+from moge.model.v2 import MoGeModel # Let's try MoGe-2
+
+device = torch.device("cuda")
+
+# Load the model from huggingface hub (or load from local).
+model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(device)
+
+# Read the input image and convert to tensor (3, H, W) with RGB values normalized to [0, 1]
+input_image = cv2.cvtColor(cv2.imread("PATH_TO_IMAGE.jpg"), cv2.COLOR_BGR2RGB)
+input_image = torch.tensor(input_image / 255, dtype=torch.float32, device=device).permute(2, 0, 1)
+
+# Infer
+output = model.infer(input_image)
+"""
+`output` has keys "points", "depth", "mask", "normal" (optional) and "intrinsics",
+The maps are in the same size as the input image.
+{
+ "points": (H, W, 3), # point map in OpenCV camera coordinate system (x right, y down, z forward). For MoGe-2, the point map is in metric scale.
+ "depth": (H, W), # depth map
+ "normal": (H, W, 3) # normal map in OpenCV camera coordinate system. (available for MoGe-2-normal)
+ "mask": (H, W), # a binary mask for valid pixels.
+ "intrinsics": (3, 3), # normalized camera intrinsics
+}
+"""
+```
+For more usage details, see the `MoGeModel.infer()` docstring.
+
+## 💡 Usage
+
+### Gradio demo | `moge app`
+
+> The demo for MoGe-1 is also available at our [Hugging Face Space](https://huggingface.co/spaces/Ruicheng/MoGe).
+
+```bash
+# Using the command line tool
+moge app # will run MoGe-2 demo by default.
+
+# In this repo
+python moge/scripts/app.py # --share for Gradio public sharing
+```
+
+See also [`moge/scripts/app.py`](moge/scripts/app.py)
+
+
+### Inference | `moge infer`
+
+Run the script `moge/scripts/infer.py` via the following command:
+
+```bash
+# Save the output [maps], [glb] and [ply] files
+moge infer -i IMAGES_FOLDER_OR_IMAGE_PATH --o OUTPUT_FOLDER --maps --glb --ply
+
+# Show the result in a window (requires pyglet < 2.0, e.g. pip install pyglet==1.5.29)
+moge infer -i IMAGES_FOLDER_OR_IMAGE_PATH --o OUTPUT_FOLDER --show
+```
+
+For detailed options, run `moge infer --help`:
+
+```
+Usage: moge infer [OPTIONS]
+
+ Inference script
+
+Options:
+ -i, --input PATH Input image or folder path. "jpg" and "png" are
+ supported.
+ --fov_x FLOAT If camera parameters are known, set the
+ horizontal field of view in degrees. Otherwise,
+ MoGe will estimate it.
+ -o, --output PATH Output folder path
+ --pretrained TEXT Pretrained model name or path. If not provided,
+ the corresponding default model will be chosen.
+ --version [v1|v2] Model version. Defaults to "v2"
+ --device TEXT Device name (e.g. "cuda", "cuda:0", "cpu").
+ Defaults to "cuda"
+ --fp16 Use fp16 precision for much faster inference.
+ --resize INTEGER Resize the image(s) & output maps to a specific
+ size. Defaults to None (no resizing).
+ --resolution_level INTEGER An integer [0-9] for the resolution level for
+ inference. Higher value means more tokens and
+ the finer details will be captured, but
+ inference can be slower. Defaults to 9. Note
+ that it is irrelevant to the output size, which
+ is always the same as the input size.
+ `resolution_level` actually controls
+ `num_tokens`. See `num_tokens` for more details.
+ --num_tokens INTEGER number of tokens used for inference. A integer
+ in the (suggested) range of `[1200, 2500]`.
+ `resolution_level` will be ignored if
+ `num_tokens` is provided. Default: None
+ --threshold FLOAT Threshold for removing edges. Defaults to 0.01.
+ Smaller value removes more edges. "inf" means no
+ thresholding.
+ --maps Whether to save the output maps (image, point
+ map, depth map, normal map, mask) and fov.
+ --glb Whether to save the output as a.glb file. The
+ color will be saved as a texture.
+ --ply Whether to save the output as a.ply file. The
+ color will be saved as vertex colors.
+ --show Whether show the output in a window. Note that
+ this requires pyglet<2 installed as required by
+ trimesh.
+ --help Show this message and exit.
+```
+
+See also [`moge/scripts/infer.py`](moge/scripts/infer.py)
+
+### 360° panorama images | `moge infer_panorama`
+
+> *NOTE: This is an experimental extension of MoGe.*
+
+The script will split the 360-degree panorama image into multiple perspective views and infer on each view separately.
+The output maps will be combined to produce a panorama depth map and point map.
+
+Note that the panorama image must have spherical parameterization (e.g., environment maps or equirectangular images). Other formats must be converted to spherical format before using this script. Run `moge infer_panorama --help` for detailed options.
+
+
+
+

+
+The photo is from [this URL](https://commons.wikimedia.org/wiki/Category:360%C2%B0_panoramas_with_equirectangular_projection#/media/File:Braunschweig_Sankt-%C3%84gidien_Panorama_02.jpg)
+
+
+See also [`moge/scripts/infer_panorama.py`](moge/scripts/infer_panorama.py)
+
+## 🏋️♂️ Training & Finetuning
+
+See [docs/train.md](docs/train.md)
+
+## 🧪 Evaluation
+
+See [docs/eval.md](docs/eval.md)
+
+## ⚖️ License
+
+MoGe code is released under the MIT license, except for DINOv2 code in `moge/model/dinov2` which is released by Meta AI under the Apache 2.0 license.
+See [LICENSE](LICENSE) for more details.
+
+
+## 📜 Citation
+
+If you find our work useful in your research, we gratefully request that you consider citing our paper:
+
+```
+@inproceedings{wang2025moge,
+ title={Moge: Unlocking accurate monocular geometry estimation for open-domain images with optimal training supervision},
+ author={Wang, Ruicheng and Xu, Sicheng and Dai, Cassie and Xiang, Jianfeng and Deng, Yu and Tong, Xin and Yang, Jiaolong},
+ booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
+ pages={5261--5271},
+ year={2025}
+}
+
+@misc{wang2025moge2,
+ title={MoGe-2: Accurate Monocular Geometry with Metric Scale and Sharp Details},
+ author={Ruicheng Wang and Sicheng Xu and Yue Dong and Yu Deng and Jianfeng Xiang and Zelong Lv and Guangzhong Sun and Xin Tong and Jiaolong Yang},
+ year={2025},
+ eprint={2507.02546},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2507.02546},
+}
+```
diff --git a/lingbotvla/models/vla/vision_models/MoGe/SECURITY.md b/lingbotvla/models/vla/vision_models/MoGe/SECURITY.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3c89efc852e22f71eabf5dfbc6ac62493425eb6
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/SECURITY.md
@@ -0,0 +1,41 @@
+
+
+## Security
+
+Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
+
+If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
+
+## Reporting Security Issues
+
+**Please do not report security vulnerabilities through public GitHub issues.**
+
+Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
+
+If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
+
+You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
+
+Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
+
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
+ * Full paths of source file(s) related to the manifestation of the issue
+ * The location of the affected source code (tag/branch/commit or direct URL)
+ * Any special configuration required to reproduce the issue
+ * Step-by-step instructions to reproduce the issue
+ * Proof-of-concept or exploit code (if possible)
+ * Impact of the issue, including how an attacker might exploit the issue
+
+This information will help us triage your report more quickly.
+
+If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
+
+## Preferred Languages
+
+We prefer all communications to be in English.
+
+## Policy
+
+Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
+
+
diff --git a/lingbotvla/models/vla/vision_models/MoGe/SUPPORT.md b/lingbotvla/models/vla/vision_models/MoGe/SUPPORT.md
new file mode 100644
index 0000000000000000000000000000000000000000..291d4d43733f4c15a81ff598ec1c99fd6c18f64c
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/SUPPORT.md
@@ -0,0 +1,25 @@
+# TODO: The maintainer of this repo has not yet edited this file
+
+**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
+
+- **No CSS support:** Fill out this template with information about how to file issues and get help.
+- **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
+- **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
+
+*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
+
+# Support
+
+## How to file issues and get help
+
+This project uses GitHub Issues to track bugs and feature requests. Please search the existing
+issues before filing new issues to avoid duplicates. For new issues, file your bug or
+feature request as a new Issue.
+
+For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
+FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
+CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
+
+## Microsoft Support Policy
+
+Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
diff --git a/lingbotvla/models/vla/vision_models/MoGe/assets/normal_comaprison.jpg b/lingbotvla/models/vla/vision_models/MoGe/assets/normal_comaprison.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6f7653596dd0ac774c90fa8a80f62a28377b8a9d
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/assets/normal_comaprison.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c653f5e14d29ede3964aa038d9c9529c18de71cacc8aee34cfddf215589c5ebe
+size 2552556
diff --git a/lingbotvla/models/vla/vision_models/MoGe/assets/overview_simplified.png b/lingbotvla/models/vla/vision_models/MoGe/assets/overview_simplified.png
new file mode 100644
index 0000000000000000000000000000000000000000..60a958eb46578b30a14fec1cfaea1289df88391e
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/assets/overview_simplified.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7025a671e863bddbc22e79dc3e2eca8b7aeaf35fe93f6ef7f2b18f4fc9e093e6
+size 414314
diff --git a/lingbotvla/models/vla/vision_models/MoGe/assets/panorama_pipeline.png b/lingbotvla/models/vla/vision_models/MoGe/assets/panorama_pipeline.png
new file mode 100644
index 0000000000000000000000000000000000000000..334354c8a68ed7a9865c424f9890a72468b0a198
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/assets/panorama_pipeline.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed28c5309162bddda016ca600307ecc73f7e6415f9eaaefb9f6fffadf6951aaa
+size 738233
diff --git a/lingbotvla/models/vla/vision_models/MoGe/baselines/da_v2.py b/lingbotvla/models/vla/vision_models/MoGe/baselines/da_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bca560a75514bdfa38c9a28c8d36ea0e006dab1e
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/baselines/da_v2.py
@@ -0,0 +1,88 @@
+# Reference: https://github.com/DepthAnything/Depth-Anything-V2
+import os
+import sys
+from typing import *
+from pathlib import Path
+
+import click
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+
+from moge.test.baseline import MGEBaselineInterface
+
+
+class Baseline(MGEBaselineInterface):
+ def __init__(self, repo_path: str, backbone: str, num_tokens: int, device: Union[torch.device, str]):
+ # Create from repo
+ repo_path = os.path.abspath(repo_path)
+ if repo_path not in sys.path:
+ sys.path.append(repo_path)
+ if not Path(repo_path).exists():
+ raise FileNotFoundError(f'Cannot find the Depth-Anything repository at {repo_path}. Please clone the repository and provide the path to it using the --repo option.')
+ from depth_anything_v2.dpt import DepthAnythingV2
+
+ device = torch.device(device)
+
+ # Instantiate model
+ model = DepthAnythingV2(encoder=backbone, features=256, out_channels=[256, 512, 1024, 1024])
+
+ # Load checkpoint
+ checkpoint_path = os.path.join(repo_path, f'checkpoints/depth_anything_v2_{backbone}.pth')
+ if not os.path.exists(checkpoint_path):
+ raise FileNotFoundError(f'Cannot find the checkpoint file at {checkpoint_path}. Please download the checkpoint file and place it in the checkpoints directory.')
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
+ model.load_state_dict(checkpoint)
+
+ model.to(device).eval()
+ self.model = model
+ self.num_tokens = num_tokens
+ self.device = device
+
+ @click.command()
+ @click.option('--repo', 'repo_path', type=click.Path(), default='../Depth-Anything-V2', help='Path to the Depth-Anything repository.')
+ @click.option('--backbone', type=click.Choice(['vits', 'vitb', 'vitl']), default='vitl', help='Encoder architecture.')
+ @click.option('--num_tokens', type=int, default=None, help='Number of tokens to use for the input image.')
+ @click.option('--device', type=str, default='cuda', help='Device to use for inference.')
+ @staticmethod
+ def load(repo_path: str, backbone, num_tokens: int, device: torch.device = 'cuda'):
+ return Baseline(repo_path, backbone, num_tokens, device)
+
+ @torch.inference_mode()
+ def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+ original_height, original_width = image.shape[-2:]
+
+ assert intrinsics is None, "Depth-Anything-V2 does not support camera intrinsics input"
+
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+ omit_batch_dim = True
+ else:
+ omit_batch_dim = False
+
+ if self.num_tokens is None:
+ resize_factor = 518 / min(original_height, original_width)
+ expected_width = round(original_width * resize_factor / 14) * 14
+ expected_height = round(original_height * resize_factor / 14) * 14
+ else:
+ aspect_ratio = original_width / original_height
+ tokens_rows = round((self.num_tokens * aspect_ratio) ** 0.5)
+ tokens_cols = round((self.num_tokens / aspect_ratio) ** 0.5)
+ expected_width = tokens_cols * 14
+ expected_height = tokens_rows * 14
+ image = TF.resize(image, (expected_height, expected_width), interpolation=T.InterpolationMode.BICUBIC, antialias=True)
+
+ image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ disparity = self.model(image)
+
+ disparity = F.interpolate(disparity[:, None], size=(original_height, original_width), mode='bilinear', align_corners=False, antialias=False)[:, 0]
+
+ if omit_batch_dim:
+ disparity = disparity.squeeze(0)
+
+ return {
+ 'disparity_affine_invariant': disparity
+ }
+
diff --git a/lingbotvla/models/vla/vision_models/MoGe/baselines/da_v2_metric.py b/lingbotvla/models/vla/vision_models/MoGe/baselines/da_v2_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee4c70d8c6634babf165d2982a692230f5adeac6
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/baselines/da_v2_metric.py
@@ -0,0 +1,99 @@
+# Reference https://github.com/DepthAnything/Depth-Anything-V2/metric_depth
+import os
+import sys
+from typing import *
+from pathlib import Path
+
+import click
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+import cv2
+
+from moge.test.baseline import MGEBaselineInterface
+
+
+class Baseline(MGEBaselineInterface):
+
+ def __init__(self, repo_path: str, backbone: str, domain: str, num_tokens: int, device: str):
+ device = torch.device(device)
+ repo_path = os.path.abspath(repo_path)
+ if not Path(repo_path).exists():
+ raise FileNotFoundError(f'Cannot find the Depth-Anything repository at {repo_path}. Please clone the repository and provide the path to it using the --repo option.')
+ sys.path.append(os.path.join(repo_path, 'metric_depth'))
+ from depth_anything_v2.dpt import DepthAnythingV2
+
+ model_configs = {
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
+ }
+
+ if domain == 'indoor':
+ dataset = 'hypersim'
+ max_depth = 20
+ elif domain == 'outdoor':
+ dataset = 'vkitti'
+ max_depth = 80
+ else:
+ raise ValueError(f"Invalid domain: {domain}")
+
+ model = DepthAnythingV2(**model_configs[backbone], max_depth=max_depth)
+ checkpoint_path = os.path.join(repo_path, f'checkpoints/depth_anything_v2_metric_{dataset}_{backbone}.pth')
+ if not os.path.exists(checkpoint_path):
+ raise FileNotFoundError(f'Cannot find the checkpoint file at {checkpoint_path}. Please download the checkpoint file and place it in the checkpoints directory.')
+ model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True))
+ model.eval().to(device)
+
+ self.model = model
+ self.num_tokens = num_tokens
+ self.device = device
+
+ @click.command()
+ @click.option('--repo', 'repo_path', type=click.Path(), default='../Depth-Anything-V2', help='Path to the Depth-Anything repository.')
+ @click.option('--backbone', type=click.Choice(['vits', 'vitb', 'vitl']), default='vitl', help='Backbone architecture.')
+ @click.option('--domain', type=click.Choice(['indoor', 'outdoor']), help='Domain of the dataset.')
+ @click.option('--num_tokens', type=int, default=None, help='Number of tokens for the ViT model')
+ @click.option('--device', type=str, default='cuda', help='Device to use for inference.')
+ @staticmethod
+ def load(repo_path: str, backbone: str, domain: str, num_tokens: int, device: str):
+ return Baseline(repo_path, backbone, domain, num_tokens, device)
+
+ @torch.inference_mode()
+ def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+ original_height, original_width = image.shape[-2:]
+
+ assert intrinsics is None, "Depth-Anything-V2 does not support camera intrinsics input"
+
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+ omit_batch_dim = True
+ else:
+ omit_batch_dim = False
+
+ if self.num_tokens is None:
+ resize_factor = 518 / min(original_height, original_width)
+ expected_width = round(original_width * resize_factor / 14) * 14
+ expected_height = round(original_height * resize_factor / 14) * 14
+ else:
+ aspect_ratio = original_width / original_height
+ tokens_rows = round((self.num_tokens * aspect_ratio) ** 0.5)
+ tokens_cols = round((self.num_tokens / aspect_ratio) ** 0.5)
+ expected_width = tokens_cols * 14
+ expected_height = tokens_rows * 14
+ image = TF.resize(image, (expected_height, expected_width), interpolation=T.InterpolationMode.BICUBIC, antialias=True)
+
+ image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ depth = self.model(image)
+
+ depth = F.interpolate(depth[:, None], size=(original_height, original_width), mode='bilinear', align_corners=False, antialias=False)[:, 0]
+
+ if omit_batch_dim:
+ depth = depth.squeeze(0)
+
+ return {
+ 'depth_metric': depth
+ }
+
diff --git a/lingbotvla/models/vla/vision_models/MoGe/baselines/metric3d_v2.py b/lingbotvla/models/vla/vision_models/MoGe/baselines/metric3d_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..661ed5ddfd3bf6f34f53dcb19bcc5a9889a8e377
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/baselines/metric3d_v2.py
@@ -0,0 +1,117 @@
+# Reference: https://github.com/YvanYin/Metric3D
+import os
+import sys
+from typing import *
+
+import click
+import torch
+import torch.nn.functional as F
+import cv2
+
+from moge.test.baseline import MGEBaselineInterface
+
+
+class Baseline(MGEBaselineInterface):
+ def __init__(self, backbone: Literal['vits', 'vitl', 'vitg'], device):
+ backbone_map = {
+ 'vits': 'metric3d_vit_small',
+ 'vitl': 'metric3d_vit_large',
+ 'vitg': 'metric3d_vit_giant2'
+ }
+
+ device = torch.device(device)
+ model = torch.hub.load('yvanyin/metric3d', backbone_map[backbone], pretrain=True)
+ model.to(device).eval()
+
+ self.model = model
+ self.device = device
+
+ @click.command()
+ @click.option('--backbone', type=click.Choice(['vits', 'vitl', 'vitg']), default='vitl', help='Encoder architecture.')
+ @click.option('--device', type=str, default='cuda', help='Device to use.')
+ @staticmethod
+ def load(backbone: str = 'vitl', device: torch.device = 'cuda'):
+ return Baseline(backbone, device)
+
+ @torch.inference_mode()
+ def inference_one_image(self, image: torch.Tensor, intrinsics: torch.Tensor = None):
+ # Reference: https://github.com/YvanYin/Metric3D/blob/main/mono/utils/do_test.py
+
+ # rgb_origin: RGB, 0-255, uint8
+ rgb_origin = image.cpu().numpy().transpose((1, 2, 0)) * 255
+
+ # keep ratio resize
+ input_size = (616, 1064) # for vit model
+ h, w = rgb_origin.shape[:2]
+ scale = min(input_size[0] / h, input_size[1] / w)
+ rgb = cv2.resize(rgb_origin, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
+ if intrinsics is not None:
+ focal = intrinsics[0, 0] * int(w * scale)
+
+ # padding to input_size
+ padding = [123.675, 116.28, 103.53]
+ h, w = rgb.shape[:2]
+ pad_h = input_size[0] - h
+ pad_w = input_size[1] - w
+ pad_h_half = pad_h // 2
+ pad_w_half = pad_w // 2
+ rgb = cv2.copyMakeBorder(rgb, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=padding)
+ pad_info = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]
+
+ # normalize rgb
+ mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None]
+ std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None]
+ rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
+ rgb = torch.div((rgb - mean), std)
+ rgb = rgb[None, :, :, :].cuda()
+
+ # inference
+ pred_depth, confidence, output_dict = self.model.inference({'input': rgb})
+
+ # un pad
+ pred_depth = pred_depth.squeeze()
+ pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]]
+ pred_depth = pred_depth.clamp_min(0.5) # clamp to 0.5m, since metric3d could yield very small depth values, resulting in crashed the scale shift alignment.
+
+ # upsample to original size
+ pred_depth = F.interpolate(pred_depth[None, None, :, :], image.shape[-2:], mode='bilinear').squeeze()
+
+ if intrinsics is not None:
+ # de-canonical transform
+ canonical_to_real_scale = focal / 1000.0 # 1000.0 is the focal length of canonical camera
+ pred_depth = pred_depth * canonical_to_real_scale # now the depth is metric
+ pred_depth = torch.clamp(pred_depth, 0, 300)
+
+ pred_normal, normal_confidence = output_dict['prediction_normal'].split([3, 1], dim=1) # see https://arxiv.org/abs/2109.09881 for details
+
+ # un pad and resize to some size if needed
+ pred_normal = pred_normal.squeeze(0)
+ pred_normal = pred_normal[:, pad_info[0] : pred_normal.shape[1] - pad_info[1], pad_info[2] : pred_normal.shape[2] - pad_info[3]]
+
+ # you can now do anything with the normal
+ pred_normal = F.interpolate(pred_normal[None, :, :, :], image.shape[-2:], mode='bilinear').squeeze(0)
+ pred_normal = F.normalize(pred_normal, p=2, dim=0)
+
+ return pred_depth, pred_normal.permute(1, 2, 0)
+
+ @torch.inference_mode()
+ def infer(self, image: torch.Tensor, intrinsics: torch.Tensor = None):
+ # image: (B, H, W, 3) or (H, W, 3)
+ if image.ndim == 3:
+ pred_depth, pred_normal = self.inference_one_image(image, intrinsics)
+ else:
+ for i in range(image.shape[0]):
+ pred_depth_i, pred_normal_i = self.inference_one_image(image[i], intrinsics[i] if intrinsics is not None else None)
+ pred_depth.append(pred_depth_i)
+ pred_normal.append(pred_normal_i)
+ pred_depth = torch.stack(pred_depth, dim=0)
+ pred_normal = torch.stack(pred_normal, dim=0)
+
+ if intrinsics is not None:
+ return {
+ "depth_metric": pred_depth,
+ }
+ else:
+ return {
+ "depth_scale_invariant": pred_depth,
+ }
diff --git a/lingbotvla/models/vla/vision_models/MoGe/baselines/moge.py b/lingbotvla/models/vla/vision_models/MoGe/baselines/moge.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd66d69183db090cf783166a19fc62bd611f0e2b
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/baselines/moge.py
@@ -0,0 +1,83 @@
+import os
+import sys
+from typing import *
+import importlib
+
+import click
+import torch
+import utils3d
+
+from moge.test.baseline import MGEBaselineInterface
+
+
+class Baseline(MGEBaselineInterface):
+
+ def __init__(self, num_tokens: int, resolution_level: int, pretrained_model_name_or_path: str, use_fp16: bool, device: str = 'cuda:0', version: str = 'v1'):
+ super().__init__()
+ from moge.model import import_model_class_by_version
+ MoGeModel = import_model_class_by_version(version)
+ self.version = version
+
+ self.model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval()
+
+ self.device = torch.device(device)
+ self.num_tokens = num_tokens
+ self.resolution_level = resolution_level
+ self.use_fp16 = use_fp16
+
+ @click.command()
+ @click.option('--num_tokens', type=int, default=None)
+ @click.option('--resolution_level', type=int, default=9)
+ @click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl')
+ @click.option('--fp16', 'use_fp16', is_flag=True)
+ @click.option('--device', type=str, default='cuda:0')
+ @click.option('--version', type=str, default='v1')
+ @staticmethod
+ def load(num_tokens: int, resolution_level: int, pretrained_model_name_or_path: str, use_fp16: bool, device: str = 'cuda:0', version: str = 'v1'):
+ return Baseline(num_tokens, resolution_level, pretrained_model_name_or_path, use_fp16, device, version)
+
+ # Implementation for inference
+ @torch.inference_mode()
+ def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.FloatTensor] = None):
+ if intrinsics is not None:
+ fov_x, _ = utils3d.pt.intrinsics_to_fov(intrinsics)
+ fov_x = torch.rad2deg(fov_x)
+ else:
+ fov_x = None
+ output = self.model.infer(image, fov_x=fov_x, apply_mask=True, num_tokens=self.num_tokens)
+
+ if self.version == 'v1':
+ return {
+ 'points_scale_invariant': output['points'],
+ 'depth_scale_invariant': output['depth'],
+ 'intrinsics': output['intrinsics'],
+ }
+ else:
+ return {
+ 'points_metric': output['points'],
+ 'depth_metric': output['depth'],
+ 'intrinsics': output['intrinsics'],
+ }
+
+ @torch.inference_mode()
+ def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: torch.FloatTensor = None):
+ if intrinsics is not None:
+ fov_x, _ = utils3d.pt.intrinsics_to_fov(intrinsics)
+ fov_x = torch.rad2deg(fov_x)
+ else:
+ fov_x = None
+ output = self.model.infer(image, fov_x=fov_x, apply_mask=False, num_tokens=self.num_tokens, use_fp16=self.use_fp16)
+
+ if self.version == 'v1':
+ return {
+ 'points_scale_invariant': output['points'],
+ 'depth_scale_invariant': output['depth'],
+ 'intrinsics': output['intrinsics'],
+ }
+ else:
+ return {
+ 'points_metric': output['points'],
+ 'depth_metric': output['depth'],
+ 'intrinsics': output['intrinsics'],
+ }
+
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/eval/all_benchmarks.json b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/all_benchmarks.json
new file mode 100644
index 0000000000000000000000000000000000000000..94c0fc4605f3a3472d7d39d4d8e40eb9e3d784b7
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/all_benchmarks.json
@@ -0,0 +1,78 @@
+{
+ "NYUv2": {
+ "path": "data/eval/NYUv2",
+ "width": 640,
+ "height": 480,
+ "split": ".index.txt",
+ "depth_unit": 1.0
+ },
+ "KITTI": {
+ "path": "data/eval/KITTI",
+ "width": 750,
+ "height": 375,
+ "split": ".index.txt",
+ "depth_unit": 1
+ },
+ "ETH3D": {
+ "path": "data/eval/ETH3D",
+ "width": 2048,
+ "height": 1365,
+ "split": ".index.txt",
+ "include_segmentation": true,
+ "depth_unit": 1
+ },
+ "iBims-1": {
+ "path": "data/eval/iBims-1",
+ "width": 640,
+ "height": 480,
+ "split": ".index.txt",
+ "has_sharp_boundary": true,
+ "include_segmentation": true,
+ "depth_unit": 1.0
+ },
+ "GSO": {
+ "path": "data/eval/GSO",
+ "width": 512,
+ "height": 512,
+ "split": ".index.txt"
+ },
+ "Sintel": {
+ "path": "data/eval/Sintel",
+ "width": 872,
+ "height": 436,
+ "split": ".index.txt",
+ "has_sharp_boundary": true,
+ "include_segmentation": true
+ },
+ "DDAD": {
+ "path": "data/eval/DDAD",
+ "width": 1400,
+ "height": 700,
+ "include_segmentation": true,
+ "split": ".index.txt",
+ "depth_unit": 1.0
+ },
+ "DIODE": {
+ "path": "data/eval/DIODE",
+ "width": 1024,
+ "height": 768,
+ "split": ".index.txt",
+ "include_segmentation": true,
+ "depth_unit": 1.0
+ },
+ "Spring": {
+ "path": "data/eval/Spring",
+ "width": 1920,
+ "height": 1080,
+ "split": ".index.txt",
+ "has_sharp_boundary": true
+ },
+ "HAMMER": {
+ "path": "data/eval/HAMMER",
+ "width": 1664,
+ "height": 832,
+ "split": ".index.txt",
+ "depth_unit": 1,
+ "has_sharp_boundary": true
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/ddad.json b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/ddad.json
new file mode 100644
index 0000000000000000000000000000000000000000..09dd4d74bbccbb46a4013afd9fee1e717d606a53
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/ddad.json
@@ -0,0 +1,9 @@
+{
+ "DDAD": {
+ "path": "data/eval/DDAD",
+ "width": 1400,
+ "height": 700,
+ "include_segmentation": true,
+ "split": ".index.txt"
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/diode.json b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/diode.json
new file mode 100644
index 0000000000000000000000000000000000000000..679ca6ee13ddf5e5bcab93f453b2f11279781a2f
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/diode.json
@@ -0,0 +1,9 @@
+{
+ "DIODE": {
+ "path": "data/eval/DIODE",
+ "width": 1024,
+ "height": 768,
+ "split": ".index.txt",
+ "include_segmentation": true
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/eth3d.json b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/eth3d.json
new file mode 100644
index 0000000000000000000000000000000000000000..88a3a1b291dcde3f2959c0d36d7ebbc33213fc84
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/eth3d.json
@@ -0,0 +1,10 @@
+{
+ "ETH3D": {
+ "path": "data/eval/ETH3D",
+ "width": 2048,
+ "height": 1365,
+ "split": ".index.txt",
+ "include_segmentation": true,
+ "depth_unit": 1
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/gso.json b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/gso.json
new file mode 100644
index 0000000000000000000000000000000000000000..ee1aefff7ae3453b0cdddf7ab3369301d2e8d924
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/gso.json
@@ -0,0 +1,8 @@
+{
+ "GSO": {
+ "path": "data/eval/GSO",
+ "width": 512,
+ "height": 512,
+ "split": ".index.txt"
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/hammer.json b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/hammer.json
new file mode 100644
index 0000000000000000000000000000000000000000..41838db6bfcf2ea6f3ed230b6c7ee3315ec3fbfe
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/hammer.json
@@ -0,0 +1,10 @@
+{
+ "HAMMER": {
+ "path": "data/eval/HAMMER",
+ "width": 1664,
+ "height": 832,
+ "split": ".index.txt",
+ "depth_unit": 1,
+ "has_sharp_boundary": true
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/ibims-1.json b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/ibims-1.json
new file mode 100644
index 0000000000000000000000000000000000000000..a6f0a0387891deb09bcae61bc4e4098e04db7307
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/ibims-1.json
@@ -0,0 +1,10 @@
+{
+ "iBims-1": {
+ "path": "data/eval/iBims-1",
+ "width": 640,
+ "height": 480,
+ "split": ".index.txt",
+ "include_segmentation": true,
+ "has_sharp_boundary": true
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/kitti.json b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/kitti.json
new file mode 100644
index 0000000000000000000000000000000000000000..10ca7c3eb560649ce25edbf4ed5c835e90396cb8
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/kitti.json
@@ -0,0 +1,9 @@
+{
+ "KITTI": {
+ "path": "data/eval/KITTI",
+ "width": 750,
+ "height": 375,
+ "split": ".index.txt",
+ "depth_unit": 1
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/nyu.json b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/nyu.json
new file mode 100644
index 0000000000000000000000000000000000000000..62841335b17f508ca903634b51b70f3e8a576186
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/nyu.json
@@ -0,0 +1,8 @@
+{
+ "NYUv2": {
+ "path": "data/eval/NYUv2",
+ "width": 640,
+ "height": 480,
+ "split": ".test.txt"
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/sintel.json b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/sintel.json
new file mode 100644
index 0000000000000000000000000000000000000000..fde872e282e260f987208168bdcd166a104732d3
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/sintel.json
@@ -0,0 +1,10 @@
+{
+ "Sintel": {
+ "path": "data/eval/Sintel",
+ "width": 872,
+ "height": 436,
+ "split": ".index.txt",
+ "include_segmentation": true,
+ "has_sharp_boundary": true
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/spring.json b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/spring.json
new file mode 100644
index 0000000000000000000000000000000000000000..a18e51a969fe5b605c03ed1f0a4714dec9379539
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/eval/benchmarks/spring.json
@@ -0,0 +1,9 @@
+{
+ "Spring": {
+ "path": "data/eval/Spring",
+ "width": 1920,
+ "height": 1080,
+ "split": ".test.txt",
+ "has_sharp_boundary": true
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/train/v1.json b/lingbotvla/models/vla/vision_models/MoGe/configs/train/v1.json
new file mode 100644
index 0000000000000000000000000000000000000000..f87f38944129e60da58eed20395d36f4404d7164
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/train/v1.json
@@ -0,0 +1,77 @@
+{
+ "data": {
+ "aspect_ratio_range": [0.5, 2.0],
+ "area_range": [250000, 1000000],
+ "clamp_max_depth": 1000.0,
+ "center_augmentation": 0.5,
+ "fov_range_absolute": [1, 179],
+ "fov_range_relative": [0.01, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring"],
+ "datasets": [
+ {
+ "name": "TartanAir",
+ "path": "blobmnt/data_v3/TartanAir",
+ "label_type": "synthetic",
+ "index": ".index.txt",
+ "depth": "depth.png",
+ "weight": 4.8,
+ "center_augmentation": 0.25,
+ "fov_range_absolute": [30, 150],
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise"]
+ }
+ ]
+ },
+ "model_version": "v1",
+ "model": {
+ "encoder": "dinov2_vitl14",
+ "remap_output": "exp",
+ "intermediate_layers": 4,
+ "dim_upsample": [256, 128, 64],
+ "dim_times_res_block_hidden": 2,
+ "num_res_blocks": 2,
+ "num_tokens_range": [1200, 2500],
+ "last_conv_channels": 32,
+ "last_conv_size": 1
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": [
+ {"params": {"include": ["*"], "exclude": ["*backbone.*"]}, "lr": 1e-4},
+ {"params": {"include": ["*backbone.*"]}, "lr": 1e-5}
+ ]
+ },
+ "lr_scheduler": {
+ "type": "SequentialLR",
+ "params": {
+ "schedulers": [
+ {"type": "LambdaLR", "params": {"lr_lambda": ["1.0", "max(0.0, min(1.0, (epoch - 1000) / 1000))"]}},
+ {"type": "StepLR", "params": {"step_size": 25000, "gamma": 0.5}}
+ ],
+ "milestones": [2000]
+ }
+ },
+ "low_resolution_training_steps": 50000,
+ "loss": {
+ "invalid": {},
+ "synthetic": {
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
+ "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}},
+ "patch_64": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 64, "align_resolution": 4, "num_patches": 4096}},
+ "normal": {"function": "normal_loss", "weight": 1.0},
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
+ },
+ "sfm": {
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
+ "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}},
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
+ },
+ "lidar": {
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
+ }
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/configs/train/v2.json b/lingbotvla/models/vla/vision_models/MoGe/configs/train/v2.json
new file mode 100644
index 0000000000000000000000000000000000000000..b52db6895fc4d99cc8c463976bc98d0f0df37a28
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/configs/train/v2.json
@@ -0,0 +1,332 @@
+{
+ "data": {
+ "aspect_ratio_range": [0.5, 2.0],
+ "area_range": [250000, 1000000],
+ "clamp_max_depth": 1000.0,
+ "center_augmentation": 0.5,
+ "fov_range_absolute": [1, 179],
+ "fov_range_relative": [0.01, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring"],
+ "datasets": [
+ {
+ "name": "A2D2",
+ "path": "path/to//A2D2/",
+ "label_type": "C",
+ "weight": 0.8,
+ "depth_unit": 1,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "dof"],
+ "depth": "depth_completed.png"
+ },
+ {
+ "name": "ARKitScenes",
+ "path": "path/to//ARKitScenes/",
+ "label_type": "B",
+ "weight": 8.6,
+ "depth_unit": 0.001,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "dof"],
+ "depth": "depth_completed.png"
+ },
+ {
+ "name": "Argoverse2",
+ "path": "path/to//Argoverse2/",
+ "label_type": "C",
+ "weight": 7.4,
+ "depth_unit": 1,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "dof"],
+ "depth": "depth_completed.png"
+ },
+ {
+ "name": "MegaDepth",
+ "path": "path/to//MegaDepth_840/",
+ "label_type": "C",
+ "weight": 5.6,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "dof"],
+ "depth": "depth_completed.png"
+ },
+ {
+ "name": "Taskonomy",
+ "path": "path/to//Taskonomy/",
+ "label_type": "B",
+ "weight": 10.0,
+ "depth_unit": 1,
+ "fov_range_relative": [0.75, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "dof"],
+ "finite_depth_mask": "only_known",
+ "depth": "depth_completed.png"
+ },
+ {
+ "name": "Waymo",
+ "path": "path/to//Waymo/",
+ "label_type": "C",
+ "weight": 6.4,
+ "depth_unit": 1,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "dof"],
+ "depth": "depth_completed.png"
+ },
+ {
+ "name": "ScanNetpp",
+ "path": "path/to//ScanNetpp/",
+ "label_type": "B",
+ "weight": 4.8,
+ "depth_unit": 1,
+ "fov_range_relative": [0.33, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "dof"],
+ "depth": "depth_completed.png"
+ },
+ {
+ "name": "BlendedMVS",
+ "path": "path/to//BlendedMVS/",
+ "label_type": "B",
+ "weight": 12.0,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "dof"],
+ "depth": "depth_completed.png"
+ },
+ {
+ "name": "ObjaverseV1",
+ "path": "path/to//ObjaverseV1/",
+ "label_type": "A",
+ "weight": 4.8,
+ "center_augmentation": 0.25,
+ "fov_range_relative": [0.7, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise"]
+ },
+ {
+ "name": "GTA-SfM",
+ "path": "path/to//GTA-SfM/",
+ "label_type": "A",
+ "weight": 2.8,
+ "depth_unit": 1,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ },
+ {
+ "name": "Hypersim",
+ "path": "path/to//Hypersim/",
+ "label_type": "A",
+ "weight": 5.0,
+ "depth_unit": 1,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"],
+ "finite_depth_mask": "only_known"
+ },
+ {
+ "name": "IRS",
+ "path": "path/to//IRS/",
+ "label_type": "A",
+ "weight": 5.6,
+ "depth_unit": 1,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"],
+ "finite_depth_mask": "only_known"
+ },
+ {
+ "name": "KenBurns",
+ "path": "path/to//KenBurns/",
+ "label_type": "A",
+ "weight": 1.6,
+ "fov_range_relative": [0.75, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ },
+ {
+ "name": "MatrixCity",
+ "path": "path/to//MatrixCity/",
+ "label_type": "A",
+ "depth_unit": 1,
+ "weight": 1.3,
+ "fov_range_relative": [0.33, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ },
+ {
+ "name": "MidAir",
+ "path": "path/to//MidAir/",
+ "label_type": "A",
+ "depth_unit": 1,
+ "weight": 4.0,
+ "fov_range_relative": [0.33, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ },
+ {
+ "name": "MVS-Synth",
+ "path": "path/to//MVS-Synth/",
+ "label_type": "A",
+ "depth_unit": 0.1,
+ "weight": 1.2,
+ "fov_range_relative": [0.33, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ },
+ {
+ "name": "Structured3D",
+ "path": "path/to//Structured3D/",
+ "label_type": "A",
+ "weight": 4.8,
+ "depth_unit": 0.001,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"],
+ "finite_depth_mask": "only_known"
+ },
+ {
+ "name": "Synthia",
+ "path": "path/to//Synthia/",
+ "label_type": "A",
+ "depth_unit": 1,
+ "weight": 1.2,
+ "fov_range_relative": [0.75, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ },
+ {
+ "name": "TartanAir",
+ "path": "path/to//TartanAir/",
+ "label_type": "A",
+ "depth_unit": 1.0,
+ "weight": 5.0,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ },
+ {
+ "name": "UrbanSyn",
+ "path": "path/to//UrbanSyn/",
+ "label_type": "A",
+ "weight": 2.1,
+ "depth_unit": 1,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ },
+ {
+ "name": "ApolloSynthetic",
+ "path": "path/to//ApolloSynthetic/",
+ "label_type": "A",
+ "weight": 4.0,
+ "depth_unit": 1,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ },
+ {
+ "name": "Synscapes",
+ "path": "path/to//Synscapes/",
+ "label_type": "A",
+ "weight": 2.0,
+ "depth_unit": 1,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ },
+ {
+ "name": "UnrealStereo4K",
+ "path": "path/to//UnrealStereo4K/",
+ "label_type": "A",
+ "weight": 1.7,
+ "depth_unit": 1,
+ "fov_range_relative": [0.33, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ },
+ {
+ "name": "EDEN",
+ "path": "path/to//EDEN/",
+ "label_type": "A",
+ "weight": 1.2,
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise", "dof"]
+ }
+ ]
+ },
+ "model_version": "v2",
+ "model": {
+ "encoder": {
+ "backbone": "dinov2_vitl14",
+ "intermediate_layers": [5, 11, 17, 23],
+ "dim_out": 1024
+ },
+ "neck": {
+ "dim_in": [1026, 2, 2, 2, 2],
+ "dim_out": null,
+ "dim_res_blocks": [1024, 256, 128, 64, 32],
+ "num_res_blocks": [0, 2, 2, 2, 0],
+ "res_block_in_norm": "none",
+ "res_block_hidden_norm": "none",
+ "resamplers": ["conv_transpose", "conv_transpose", "conv_transpose", "bilinear"]
+ },
+ "points_head": {
+ "dim_in": [1024, 256, 128, 64, 32],
+ "dim_out": [null, null, null, null, 3],
+ "dim_res_blocks": [1024, 256, 128, 64, 32],
+ "num_res_blocks": [0, 1, 1, 1, 0],
+ "res_block_in_norm": "none",
+ "res_block_hidden_norm": "none",
+ "resamplers": ["conv_transpose", "conv_transpose", "conv_transpose", "bilinear"]
+ },
+ "normal_head": {
+ "dim_in": [1024, 256, 128, 64, 32],
+ "dim_out": [null, null, null, null, 3],
+ "dim_res_blocks": [1024, 256, 128, 64, 32],
+ "num_res_blocks": [0, 1, 1, 1, 0],
+ "res_block_in_norm": "none",
+ "res_block_hidden_norm": "none",
+ "resamplers": ["conv_transpose", "conv_transpose", "conv_transpose", "bilinear"]
+ },
+ "mask_head": {
+ "dim_in": [1024, 256, 128, 64, 32],
+ "dim_out": [null, null, null, null, 1],
+ "dim_res_blocks": [1024, 256, 128, 64, 32],
+ "num_res_blocks": [0, 1, 1, 1, 0],
+ "res_block_in_norm": "none",
+ "res_block_hidden_norm": "none",
+ "resamplers": ["conv_transpose", "conv_transpose", "conv_transpose", "bilinear"]
+ },
+ "scale_head": {
+ "dims": [1024, 1024, 1024, 1]
+ },
+ "remap_output": "exp",
+ "num_tokens_range": [1200, 3600]
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": [
+ {"params": {"include": ["*"], "exclude": ["*.backbone.*"]}, "lr": 1e-4},
+ {"params": {"include": ["*.backbone.*"]}, "lr": 1e-5}
+ ]
+ },
+ "lr_scheduler": {
+ "type": "SequentialLR",
+ "params": {
+ "schedulers": [
+ {"type": "LambdaLR", "params": {"lr_lambda": ["1.0", "max(0.0, min(1.0, (epoch - 1000) / 1000))"]}},
+ {"type": "StepLR", "params": {"step_size": 25000, "gamma": 0.5}}
+ ],
+ "milestones": [2000]
+ }
+ },
+ "low_resolution_training_steps": 50000,
+ "loss": {
+ "invalid": {},
+ "A": {
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 48}},
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 24, "num_patches": 16}},
+ "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 12, "num_patches": 256}},
+ "patch_64": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 64, "align_resolution": 6, "num_patches": 4096}},
+ "normal": {"function": "edge_loss", "weight": 1.0},
+ "normal_map": {"function": "normal_map_loss", "weight": 0.1},
+ "metric_scale": {"function": "metric_scale_loss", "weight": 0.1},
+ "mask": {"function": "mask_bce_loss", "weight": 0.1}
+ },
+ "B": {
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 48}},
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 24, "num_patches": 16}},
+ "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 12, "num_patches": 256}},
+ "metric_scale": {"function": "metric_scale_loss", "weight": 0.1},
+ "normal": {"function": "edge_loss", "weight": 1.0},
+ "normal_map": {"function": "normal_map_loss", "weight": 0.1},
+ "mask": {"function": "mask_bce_loss", "weight": 0.1}
+ },
+ "C": {
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 48}},
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 24, "num_patches": 16}},
+ "metric_scale": {"function": "metric_scale_loss", "weight": 0.1},
+ "mask": {"function": "mask_bce_loss", "weight": 0.1}
+ }
+ }
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/docs/eval.md b/lingbotvla/models/vla/vision_models/MoGe/docs/eval.md
new file mode 100644
index 0000000000000000000000000000000000000000..a9d93e4a540c6df1c06aaa5694c8377e67ba468f
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/docs/eval.md
@@ -0,0 +1,77 @@
+# Evaluation
+
+We provide a unified evaluation script that runs baselines on multiple benchmarks. It takes a baseline model and evaluation configurations, evaluates on-the-fly, and reports results instantly in a JSON file.
+
+## Benchmarks
+
+Donwload the processed datasets from [Huggingface Datasets](https://huggingface.co/datasets/Ruicheng/monocular-geometry-evaluation) and put them in the `data/eval` directory, using `huggingface-cli`:
+
+```bash
+mkdir -p data/eval
+huggingface-cli download Ruicheng/monocular-geometry-evaluation --repo-type dataset --local-dir data/eval --local-dir-use-symlinks False
+```
+
+Then unzip the downloaded files:
+
+```bash
+cd data/eval
+unzip '*.zip'
+# rm *.zip # if you don't keep the zip files
+```
+
+## Configuration
+
+See [`configs/eval/all_benchmarks.json`](../configs/eval/all_benchmarks.json) for an example of evaluation configurations on all benchmarks. You can modify this file to evaluate on different benchmarks or different baselines.
+
+## Baseline
+
+Some examples of baselines are provided in [`baselines/`](../baselines/). Pass the path to the baseline model python code to the `--baseline` argument of the evaluation script.
+
+## Run Evaluation
+
+Run the script [`moge/scripts/eval_baseline.py`](../moge/scripts/eval_baseline.py).
+For example,
+
+```bash
+# Evaluate MoGe on the 10 benchmarks
+python moge/scripts/eval_baseline.py --baseline baselines/moge.py --config configs/eval/all_benchmarks.json --output eval_output/moge.json --pretrained Ruicheng/moge-vitl --resolution_level 9
+
+# Evaluate Depth Anything V2 on the 10 benchmarks. (NOTE: affine disparity)
+python moge/scripts/eval_baseline.py --baseline baselines/da_v2.py --config configs/eval/all_benchmarks.json --output eval_output/da_v2.json
+```
+
+The `--baselies` `--input` `--output` arguments are for the inference script. The rest arguments, e.g. `--pretrained` `--resolution_level`, are custormized for loading the baseline model.
+
+Details of the arguments:
+
+```
+Usage: eval_baseline.py [OPTIONS]
+
+ Evaluation script.
+
+Options:
+ --baseline PATH Path to the baseline model python code.
+ --config PATH Path to the evaluation configurations. Defaults to
+ "configs/eval/all_benchmarks.json".
+ --output PATH Path to the output json file.
+ --oracle Use oracle mode for evaluation, i.e., use the GT intrinsics
+ input.
+ --dump_pred Dump predition results.
+ --dump_gt Dump ground truth.
+ --help Show this message and exit.
+```
+
+
+
+## Wrap a Customized Baseline
+
+Wrap any baseline method with [`moge.test.baseline.MGEBaselineInterface`](../moge/test/baseline.py).
+See [`baselines/`](../baselines/) for more examples.
+
+It is a good idea to check the correctness of the baseline implementation by running inference on a small set of images via [`moge/scripts/infer_baselines.py`](../moge/scripts/infer_baselines.py):
+
+```base
+python moge/scripts/infer_baselines.py --baseline baselines/moge.py --input example_images/ --output infer_outupt/moge --pretrained Ruicheng/moge-vitl --maps --ply
+```
+
+
diff --git a/lingbotvla/models/vla/vision_models/MoGe/docs/normal.md b/lingbotvla/models/vla/vision_models/MoGe/docs/normal.md
new file mode 100644
index 0000000000000000000000000000000000000000..8d7bc4626375b87ac3677933fd2fa7c174801f7f
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/docs/normal.md
@@ -0,0 +1,16 @@
+# MoGe-2 Normal Estimation
+
+
+
+
+> NOTE: Normal estimation was implemented after the submission of the MoGe-2 paper and is therefore not included in the original publication. This feature required minimal additional effort, and we do not claim any novel technical contribution.
+
+We added a lightweight convolutional head and trained the normal output using a squared angular loss:
+
+$$
+\mathcal L_{\rm normal} = {1\over |\mathcal M|}\sum_{i\in\mathcal M} \angle (\hat{\mathbf n}_i,\mathbf n_i)^2
+$$
+
+where $\hat{\mathbf{n}}_i$ is the predicted normal, $\mathbf{n}_i$ is the ground-truth normal, and $\mathcal{M}$ denotes the set of valid pixels. For convenience, we did not collect ground-truth normal maps for training. Instead, we derived surface normals from the depth map and camera intrinsics. The resulting estimates are visually and numerically satisfactory.
diff --git a/lingbotvla/models/vla/vision_models/MoGe/docs/onnx.md b/lingbotvla/models/vla/vision_models/MoGe/docs/onnx.md
new file mode 100644
index 0000000000000000000000000000000000000000..d6f7b49126a1fe9f0057d7e09086df2467c81fa2
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/docs/onnx.md
@@ -0,0 +1,89 @@
+# MoGe ONNX Support
+
+MoGe-2 is compatible with the ONNX format (opset version ≥ 14). We have exported several models for use in ONNXRuntime or deployment on other compatible inference engines.
+
+> **Important Note:** The `.infer()` method in our PyTorch code includes some post-processing logic (e.g., recovering focal and shift and reprojection) that cannot be exported to ONNX. The ONNX model only includes the raw forward() pass, which outputs intermediate predictions (affine point map, normal map, floating point mask, metric scale). You will need to implement any required post-processing steps separately if replicating the full inference pipeline.
+
+The exported models are in **FP32** precision, with **dynamic input resolution** and **variable-length** token support. You can further optimize these models based on your target deployment platform.
+
+
+
+## Customized Exportation
+
+### Dynamic Shape & Variable Number of Tokens
+```python
+import os
+os.environ['XFORMERS_DISABLED'] = '1' # Disable xformers
+import numpy as np
+import torch
+from moge.model.v2 import MoGeModel
+
+PRETRAINED_MODEL = 'Ruicheng/moge-2-vits-normal.pt'
+ONNX_FILE = 'moge-2-vits-normal.onnx'
+
+model = MoGeModel.from_pretrained(PRETRAINED_MODEL)
+model.onnx_compatible_mode = True # Enable ONNX compatible mode
+
+torch.onnx.export(
+ model,
+ (torch.rand(1, 3, 518, 518), torch.tensor(1800)),
+ ONNX_FILE,
+ input_names=['image', 'num_tokens'],
+ output_names=['points', 'normal', 'mask', 'metric_scale'],
+ dynamic_axes={
+ 'image': {0: 'batch_size', 2: 'height', 3: 'width'},
+ },
+ opset_version=14
+)
+```
+
+### Static Shape & Fixed Number of Tokens
+
+```python
+import os
+os.environ['XFORMERS_DISABLED'] = '1' # Disable xformers
+import numpy as np
+import torch
+from moge.model.v2 import MoGeModel
+
+class MoGeStatic(MoGeModel):
+ def forward(self, image: torch.Tensor):
+ return super().forward(image, NUM_TOKENS)
+
+NUM_TOKENS = 1800
+FIXED_IMAGE_INPUT = torch.rand(1, 3, 518, 518)
+PRETRAINED_MODEL = 'Ruicheng/moge-2-vits-normal.pt'
+ONNX_FILE = 'moge-2-vits-normal.onnx'
+
+model = MoGeStatic.from_pretrained(PRETRAINED_MODEL)
+model.onnx_compatible_mode = True # Enable ONNX compatible mode
+
+torch.onnx.export(
+ model,
+ (FIXED_IMAGE_INPUT,),
+ ONNX_FILE,
+ input_names=['image'],
+ output_names=['points', 'normal', 'mask', 'metric_scale'],
+ dynamic_axes=None,
+ opset_version=14
+)
+```
diff --git a/lingbotvla/models/vla/vision_models/MoGe/docs/train.md b/lingbotvla/models/vla/vision_models/MoGe/docs/train.md
new file mode 100644
index 0000000000000000000000000000000000000000..170abb80e08ac5eb25badedc2b05138c21bb33f2
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/docs/train.md
@@ -0,0 +1,181 @@
+
+# Training
+
+This document provides instructions for training and finetuning the MoGe model.
+
+## Additional Requirements
+
+The following packages other than those listed in [`pyproject.toml`](../pyproject.toml) are required for training and finetuning the MoGe model:
+
+```
+accelerate
+sympy
+mlflow
+```
+
+## Data preparation
+
+### Dataset format
+
+Each dataset should be organized as follows:
+
+```
+somedataset
+├── .index.txt # A list of instance paths
+├── folder1
+│ ├── instance1 # Each instance is in a folder
+│ │ ├── image.jpg # RGB image.
+│ │ ├── depth.png # 16-bit depth. See moge/utils/io.py for details
+│ │ ├── meta.json # Stores "intrinsics" as a 3x3 matrix
+│ │ └── ... # Other componests such as segmentation mask, normal map etc.
+...
+```
+
+* `.index.txt` is placed at top directory to store a list of instance paths in this dataset. The dataloader will look for instances in this list. You may also use a custom split, e.g. `.train.txt`, `.val.txt` and specify it in the configuration file.
+
+* For depth images, it is recommended to use `read_depth()` and `write_depth()` in [`moge/utils/io.py`](../moge/utils/io.py) to read and write depth images. The depth is stored in logarithmic scale in 16-bit PNG format, offering a balanced precision, dynamic range and compression ratio compared to 16-bit and 32-bit EXR and linear depth formats. It also encodes `NaN` and `Inf` values for invalid depth values.
+
+* The `meta.json` should be a dictionary containing the key `intrinsics`, which are **normalized** camera parameters. You may put more metadata.
+
+* We also support reading and storing segementation masks for evaluation data (see paper evaluation of local points), which are saved in PNG format with semantic labels stored in png metadata as JSON strings. See `read_segmentation()` and `write_segmentation()` in [`moge/utils/io.py`](../moge/utils/io.py) for details.
+
+
+### Visual inspection
+
+We provide a script to visualize the data and check the data quality. It will export the instance as a PLY file for visualization of point cloud.
+
+```bash
+python moge/scripts/vis_data.py PATH_TO_INSTANCE --ply [-o SOMEWHERE_ELSE_TO_SAVE_VIS]
+```
+
+### DataLoader
+
+Our training dataloaders is customized to handle loading data, performing perspective crop, and augmentation in a multithreading pipeline. Please refer to [`moge/train/dataloader.py`](../moge/train/dataloader.py) if you have any concern.
+
+
+## Configuration
+
+See [`configs/train/v1.json`](../configs/train/v1.json) for an example configuration file. The configuration file defines the hyperparameters for training the MoGe model.
+Here is a commented configuration for reference:
+
+```json
+{
+ "data": {
+ "aspect_ratio_range": [0.5, 2.0], # Range of aspect ratio of sampled images
+ "area_range": [250000, 1000000], # Range of sampled image area in pixels
+ "clamp_max_depth": 1000.0, # Maximum far/near
+ "center_augmentation": 0.5, # Ratio of center crop augmentation
+ "fov_range_absolute": [1, 179], # Absolute range of FOV in degrees
+ "fov_range_relative": [0.01, 1.0], # Relative range of FOV to the original FOV
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring"], # List of image augmentation techniques
+ "datasets": [
+ {
+ "name": "TartanAir", # Name of the dataset. Name it as you like.
+ "path": "data/TartanAir", # Path to the dataset
+ "label_type": "synthetic", # Label type for this dataset. Losses will be applied accordingly. see "loss" config
+ "weight": 4.8, # Probability of sampling this dataset
+ "index": ".index.txt", # File name of the index file. Defaults to .index.txt
+ "depth": "depth.png", # File name of depth images. Defaults to depth.png
+ "center_augmentation": 0.25, # Below are dataset-specific hyperparameters. Overriding the global ones above.
+ "fov_range_absolute": [30, 150],
+ "fov_range_relative": [0.5, 1.0],
+ "image_augmentation": ["jittering", "jpeg_loss", "blurring", "shot_noise"]
+ }
+ ]
+ },
+ "model_version": "v1", # Model version. If you have multiple model variants, you can use this to switch between them.
+ "model": { # Model hyperparameters. Will be passed to Model __init__() as kwargs.
+ "encoder": "dinov2_vitl14",
+ "remap_output": "exp",
+ "intermediate_layers": 4,
+ "dim_upsample": [256, 128, 64],
+ "dim_times_res_block_hidden": 2,
+ "num_res_blocks": 2,
+ "num_tokens_range": [1200, 2500],
+ "last_conv_channels": 32,
+ "last_conv_size": 1
+ },
+ "optimizer": { # Reflection-like optimizer configurations. See moge.train.utils.py build_optimizer() for details.
+ "type": "AdamW",
+ "params": [
+ {"params": {"include": ["*"], "exclude": ["*backbone.*"]}, "lr": 1e-4},
+ {"params": {"include": ["*backbone.*"]}, "lr": 1e-5}
+ ]
+ },
+ "lr_scheduler": { # Reflection-like lr_scheduler configurations. See moge.train.utils.py build_lr_scheduler() for details.
+ "type": "SequentialLR",
+ "params": {
+ "schedulers": [
+ {"type": "LambdaLR", "params": {"lr_lambda": ["1.0", "max(0.0, min(1.0, (epoch - 1000) / 1000))"]}},
+ {"type": "StepLR", "params": {"step_size": 25000, "gamma": 0.5}}
+ ],
+ "milestones": [2000]
+ }
+ },
+ "low_resolution_training_steps": 50000, # Total number of low-resolution training steps. It makes the early stage training faster. Later stage training on varying size images will be slower.
+ "loss": {
+ "invalid": {}, # invalid instance due to runtime error when loading data
+ "synthetic": { # Below are loss hyperparameters
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
+ "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}},
+ "patch_64": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 64, "align_resolution": 4, "num_patches": 4096}},
+ "normal": {"function": "normal_loss", "weight": 1.0},
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
+ },
+ "sfm": {
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
+ "patch_16": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 16, "align_resolution": 8, "num_patches": 256}},
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
+ },
+ "lidar": {
+ "global": {"function": "affine_invariant_global_loss", "weight": 1.0, "params": {"align_resolution": 32}},
+ "patch_4": {"function": "affine_invariant_local_loss", "weight": 1.0, "params": {"level": 4, "align_resolution": 16, "num_patches": 16}},
+ "mask": {"function": "mask_l2_loss", "weight": 1.0}
+ }
+ }
+}
+```
+
+## Run Training
+
+Launch the training script [`moge/scripts/train.py`](../moge/scripts/train.py). Note that we use [`accelerate`](https://github.com/huggingface/accelerate) for distributed training.
+
+```bash
+accelerate launch \
+ --num_processes 8 \
+ moge/scripts/train.py \
+ --config configs/train/v1.json \
+ --workspace workspace/debug \
+ --gradient_accumulation_steps 2 \
+ --batch_size_forward 2 \
+ --checkpoint latest \
+ --enable_gradient_checkpointing True \
+ --vis_every 1000 \
+ --enable_mlflow True
+```
+
+
+## Finetuning
+
+To finetune the pre-trained MoGe model, download the model checkpoint and put it in a local directory, e.g. `pretrained/moge-vitl.pt`.
+
+> NOTE: when finetuning pretrained MoGe model, a much lower learning rate is required.
+The suggested learning rate for finetuning is not greater than 1e-5 for the head and 1e-6 for the backbone.
+And the batch size is recommended to be 32 at least.
+The settings in default configuration are not optimal for specific datasets and may require further tuning.
+
+```bash
+accelerate launch \
+ --num_processes 8 \
+ moge/scripts/train.py \
+ --config configs/train/v1.json \
+ --workspace workspace/debug \
+ --gradient_accumulation_steps 2 \
+ --batch_size_forward 2 \
+ --checkpoint pretrained/moge-vitl.pt \
+ --enable_gradient_checkpointing True \
+ --vis_every 1000 \
+ --enable_mlflow True
+```
diff --git a/lingbotvla/models/vla/vision_models/MoGe/example_images/01_HouseIndoor.jpg b/lingbotvla/models/vla/vision_models/MoGe/example_images/01_HouseIndoor.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..eee8b1f17491b5d5602a54b257e55fe3d09a3d20
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/example_images/01_HouseIndoor.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3eb519bc68d4262af0c68166ca69e786cac5f6656a1083f4c585c4a94005c859
+size 322353
diff --git a/lingbotvla/models/vla/vision_models/MoGe/example_images/02_Office.jpg b/lingbotvla/models/vla/vision_models/MoGe/example_images/02_Office.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3a21eec3de0c64ed4a8ce9cc612145673882d07d
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/example_images/02_Office.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28767640002f93b703b24a34a6d75ca24b1ef093a19f52ef0f9d3b074ef68c61
+size 197508
diff --git a/lingbotvla/models/vla/vision_models/MoGe/example_images/03_Traffic.jpg b/lingbotvla/models/vla/vision_models/MoGe/example_images/03_Traffic.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..457784f7e0371cdf2aa5b2d37dd959dbb3bc4c36
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/example_images/03_Traffic.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4fa8b46849dd3de5b3b0a141d6aafe98e190f578ccec0c9dacc440cd8434db11
+size 1125098
diff --git a/lingbotvla/models/vla/vision_models/MoGe/example_images/04_BunnyCake.jpg b/lingbotvla/models/vla/vision_models/MoGe/example_images/04_BunnyCake.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7939a1073e13cf4d600ce138919265e27e11828d
Binary files /dev/null and b/lingbotvla/models/vla/vision_models/MoGe/example_images/04_BunnyCake.jpg differ
diff --git a/lingbotvla/models/vla/vision_models/MoGe/example_images/05_Mountain.jpg b/lingbotvla/models/vla/vision_models/MoGe/example_images/05_Mountain.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..df9c2c8686c175cfce2273d8c0254485528399de
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/example_images/05_Mountain.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:670d322f6588713f7d9c7349091de0aacb2a5b0b37c7b7433995e110fb2bcfbc
+size 665958
diff --git a/lingbotvla/models/vla/vision_models/MoGe/example_images/06_MaitreyaBuddha.png b/lingbotvla/models/vla/vision_models/MoGe/example_images/06_MaitreyaBuddha.png
new file mode 100644
index 0000000000000000000000000000000000000000..72193f4b66cb3d2f5583a6128bdcb5f10037d486
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/example_images/06_MaitreyaBuddha.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:396c5fd722bf5a21b931cbb70b883d6b1d5f9bab439cc426ec2f606fc2b7872d
+size 1224680
diff --git a/lingbotvla/models/vla/vision_models/MoGe/example_images/07_Breads.jpg b/lingbotvla/models/vla/vision_models/MoGe/example_images/07_Breads.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0029d0c04179f2863f79a5429460122d41943560
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/example_images/07_Breads.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a95c2cab81412e252ee5a56a6100df31bb83de0f117607ca8476478f7f152a7b
+size 156435
diff --git a/lingbotvla/models/vla/vision_models/MoGe/example_images/08_CatGirl.png b/lingbotvla/models/vla/vision_models/MoGe/example_images/08_CatGirl.png
new file mode 100644
index 0000000000000000000000000000000000000000..664ef2a6bf02e1c1720f1ed19e00e57d2a839927
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/example_images/08_CatGirl.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57fa6d587d598e7a428e8997b86d5c3a06e0e18529bfad8bab78ae03a1f5820f
+size 1689759
diff --git a/lingbotvla/models/vla/vision_models/MoGe/example_images/09_Restaurant.jpg b/lingbotvla/models/vla/vision_models/MoGe/example_images/09_Restaurant.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..87aa321a35339878b095c791e2f90aa49c0ba6be
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/example_images/09_Restaurant.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2bb7b5a1e91a174101109b0976b8ae2a4d6bb7d6eadad6569106ed102d0d5a6
+size 794391
diff --git a/lingbotvla/models/vla/vision_models/MoGe/example_images/10_MedievalVillage.jpg b/lingbotvla/models/vla/vision_models/MoGe/example_images/10_MedievalVillage.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9eb958edb1b7a632bc91a4087acee52c8b557005
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/example_images/10_MedievalVillage.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:718ed1aeb1e0010194c5cf0e95371e6a29d45b84e93efbed63ff4cc60e74508b
+size 465285
diff --git a/lingbotvla/models/vla/vision_models/MoGe/example_images/panorama/Braunschweig_Panoram.jpg b/lingbotvla/models/vla/vision_models/MoGe/example_images/panorama/Braunschweig_Panoram.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..847fe2715173a2569dda1203e3e68ec85150b607
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/example_images/panorama/Braunschweig_Panoram.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:abc31b78f03a0b5254f3735bc3201c28d21b6855708f971ce4b6a740dfbddcba
+size 562674
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/__init__.py b/lingbotvla/models/vla/vision_models/MoGe/moge/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/__init__.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c919e3be42c0005752e8c800129bd5f724b47ff9
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/__init__.py
@@ -0,0 +1,18 @@
+import importlib
+from typing import *
+
+if TYPE_CHECKING:
+ from .v1 import MoGeModel as MoGeModelV1
+ from .v2 import MoGeModel as MoGeModelV2
+
+
+def import_model_class_by_version(version: str) -> Type[Union['MoGeModelV1', 'MoGeModelV2']]:
+ assert version in ['v1', 'v2'], f'Unsupported model version: {version}'
+
+ try:
+ module = importlib.import_module(f'.{version}', __package__)
+ except ModuleNotFoundError:
+ raise ValueError(f'Model version "{version}" not found.')
+
+ cls = getattr(module, 'MoGeModel')
+ return cls
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/__init__.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+__version__ = "0.0.1"
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/hub/__init__.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/hub/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/hub/backbones.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/hub/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/hub/backbones.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+ *,
+ arch_name: str = "vit_large",
+ img_size: int = 518,
+ patch_size: int = 14,
+ init_values: float = 1.0,
+ ffn_layer: str = "mlp",
+ block_chunks: int = 0,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.LVD142M,
+ **kwargs,
+):
+ from ..models import vision_transformer as vits
+
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=patch_size,
+ init_values=init_values,
+ ffn_layer=ffn_layer,
+ block_chunks=block_chunks,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ )
+ vit_kwargs.update(**kwargs)
+ model = vits.__dict__[arch_name](**vit_kwargs)
+
+ if pretrained:
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ model.load_state_dict(state_dict, strict=True)
+
+ return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_small",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_base",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_large",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/hub/utils.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/hub/utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import itertools
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
+ compact_arch_name = arch_name.replace("_", "")[:4]
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+class CenterPadding(nn.Module):
+ def __init__(self, multiple):
+ super().__init__()
+ self.multiple = multiple
+
+ def _get_pad(self, size):
+ new_size = math.ceil(size / self.multiple) * self.multiple
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ @torch.inference_mode()
+ def forward(self, x):
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+ output = F.pad(x, pads)
+ return output
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/__init__.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/attention.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9f79d471fc099b1dcaa512dfdbdec8a9fc5908f
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/attention.py
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+import torch.nn.functional as F
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Attention)")
+ else:
+ # warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ # # Deprecated implementation, extremely slow
+ # def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ # B, N, C = x.shape
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ # q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ # attn = q @ k.transpose(-2, -1)
+ # attn = attn.softmax(dim=-1)
+ # attn = self.attn_drop(attn)
+ # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ # x = self.proj(x)
+ # x = self.proj_drop(x)
+ # return x
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
+
+ q, k, v = qkv.unbind(0) # (B, H, N, C // H)
+
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/block.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd5b8a7bb8527b74186af7c1e060e37bdb52c73d
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/block.py
@@ -0,0 +1,259 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Block)")
+ else:
+ # warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Block)")
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/dino_head.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/dino_head.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/drop_path.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/layer_scale.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/mlp.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/patch_embed.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/swiglu_ffn.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (SwiGLU)")
+ else:
+ # warnings.warn("xFormers is disabled (SwiGLU)")
+ raise ImportError
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+ # warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/models/__init__.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/models/__init__.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+from . import vision_transformer as vits
+
+
+logger = logging.getLogger("dinov2")
+
+
+def build_model(args, only_teacher=False, img_size=224):
+ args.arch = args.arch.removesuffix("_memeff")
+ if "vit" in args.arch:
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=args.patch_size,
+ init_values=args.layerscale,
+ ffn_layer=args.ffn_layer,
+ block_chunks=args.block_chunks,
+ qkv_bias=args.qkv_bias,
+ proj_bias=args.proj_bias,
+ ffn_bias=args.ffn_bias,
+ num_register_tokens=args.num_register_tokens,
+ interpolate_offset=args.interpolate_offset,
+ interpolate_antialias=args.interpolate_antialias,
+ )
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
+ if only_teacher:
+ return teacher, teacher.embed_dim
+ student = vits.__dict__[args.arch](
+ **vit_kwargs,
+ drop_path_rate=args.drop_path_rate,
+ drop_path_uniform=args.drop_path_uniform,
+ )
+ embed_dim = student.embed_dim
+ return student, teacher, embed_dim
+
+
+def build_model_from_cfg(cfg, only_teacher=False):
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/models/vision_transformer.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0bed9d0b7cdcff2b5e129121251c58e41c4c61d
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/models/vision_transformer.py
@@ -0,0 +1,407 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable, Optional, List
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ @property
+ def onnx_compatible_mode(self):
+ return getattr(self, "_onnx_compatible_mode", False)
+
+ @onnx_compatible_mode.setter
+ def onnx_compatible_mode(self, value: bool):
+ self._onnx_compatible_mode = value
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, h, w):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ batch_size = x.shape[0]
+ N = self.pos_embed.shape[1] - 1
+ if not self.onnx_compatible_mode and npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0, :]
+ patch_pos_embed = pos_embed[:, 1:, :]
+ dim = x.shape[-1]
+ h0, w0 = h // self.patch_size, w // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if not self.onnx_compatible_mode and self.interpolate_offset > 0:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sy, sx)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (h0, w0)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+
+ assert (h0, w0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
+ return torch.cat((class_pos_embed[:, None, :].expand(patch_pos_embed.shape[0], -1, -1), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, h, w = x.shape
+ x = self.patch_embed(x)
+
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, h, w)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks, ar in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/__init__.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/cluster.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/cluster.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+
+class ClusterType(Enum):
+ AWS = "aws"
+ FAIR = "fair"
+ RSC = "rsc"
+
+
+def _guess_cluster_type() -> ClusterType:
+ uname = os.uname()
+ if uname.sysname == "Linux":
+ if uname.release.endswith("-aws"):
+ # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
+ return ClusterType.AWS
+ elif uname.nodename.startswith("rsc"):
+ # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
+ return ClusterType.RSC
+
+ return ClusterType.FAIR
+
+
+def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
+ if cluster_type is None:
+ return _guess_cluster_type()
+
+ return cluster_type
+
+
+def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ CHECKPOINT_DIRNAMES = {
+ ClusterType.AWS: "checkpoints",
+ ClusterType.FAIR: "checkpoint",
+ ClusterType.RSC: "checkpoint/dino",
+ }
+ return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
+
+
+def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ checkpoint_path = get_checkpoint_path(cluster_type)
+ if checkpoint_path is None:
+ return None
+
+ username = os.environ.get("USER")
+ assert username is not None
+ return checkpoint_path / username
+
+
+def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ SLURM_PARTITIONS = {
+ ClusterType.AWS: "learnlab",
+ ClusterType.FAIR: "learnlab",
+ ClusterType.RSC: "learn",
+ }
+ return SLURM_PARTITIONS[cluster_type]
+
+
+def get_slurm_executor_parameters(
+ nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
+) -> Dict[str, Any]:
+ # create default parameters
+ params = {
+ "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
+ "gpus_per_node": num_gpus_per_node,
+ "tasks_per_node": num_gpus_per_node, # one task per GPU
+ "cpus_per_task": 10,
+ "nodes": nodes,
+ "slurm_partition": get_slurm_partition(cluster_type),
+ }
+ # apply cluster-specific adjustments
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type == ClusterType.AWS:
+ params["cpus_per_task"] = 12
+ del params["mem_gb"]
+ elif cluster_type == ClusterType.RSC:
+ params["cpus_per_task"] = 12
+ # set additional parameters / apply overrides
+ params.update(kwargs)
+ return params
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/config.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/config.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import math
+import logging
+import os
+
+from omegaconf import OmegaConf
+
+import dinov2.distributed as distributed
+from dinov2.logging import setup_logging
+from dinov2.utils import utils
+from dinov2.configs import dinov2_default_config
+
+
+logger = logging.getLogger("dinov2")
+
+
+def apply_scaling_rules_to_cfg(cfg): # to fix
+ if cfg.optim.scaling_rule == "sqrt_wrt_1024":
+ base_lr = cfg.optim.base_lr
+ cfg.optim.lr = base_lr
+ cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
+ logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
+ else:
+ raise NotImplementedError
+ return cfg
+
+
+def write_config(cfg, output_dir, name="config.yaml"):
+ logger.info(OmegaConf.to_yaml(cfg))
+ saved_cfg_path = os.path.join(output_dir, name)
+ with open(saved_cfg_path, "w") as f:
+ OmegaConf.save(config=cfg, f=f)
+ return saved_cfg_path
+
+
+def get_cfg_from_args(args):
+ args.output_dir = os.path.abspath(args.output_dir)
+ args.opts += [f"train.output_dir={args.output_dir}"]
+ default_cfg = OmegaConf.create(dinov2_default_config)
+ cfg = OmegaConf.load(args.config_file)
+ cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
+ return cfg
+
+
+def default_setup(args):
+ distributed.enable(overwrite=True)
+ seed = getattr(args, "seed", 0)
+ rank = distributed.get_global_rank()
+
+ global logger
+ setup_logging(output=args.output_dir, level=logging.INFO)
+ logger = logging.getLogger("dinov2")
+
+ utils.fix_random_seeds(seed + rank)
+ logger.info("git:\n {}\n".format(utils.get_sha()))
+ logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
+
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg_from_args(args)
+ os.makedirs(args.output_dir, exist_ok=True)
+ default_setup(args)
+ apply_scaling_rules_to_cfg(cfg)
+ write_config(cfg, args.output_dir)
+ return cfg
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/dtype.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/dtype.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/dtype.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+from typing import Dict, Union
+
+import numpy as np
+import torch
+
+
+TypeSpec = Union[str, np.dtype, torch.dtype]
+
+
+_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
+ np.dtype("bool"): torch.bool,
+ np.dtype("uint8"): torch.uint8,
+ np.dtype("int8"): torch.int8,
+ np.dtype("int16"): torch.int16,
+ np.dtype("int32"): torch.int32,
+ np.dtype("int64"): torch.int64,
+ np.dtype("float16"): torch.float16,
+ np.dtype("float32"): torch.float32,
+ np.dtype("float64"): torch.float64,
+ np.dtype("complex64"): torch.complex64,
+ np.dtype("complex128"): torch.complex128,
+}
+
+
+def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
+ if isinstance(dtype, torch.dtype):
+ return dtype
+ if isinstance(dtype, str):
+ dtype = np.dtype(dtype)
+ assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}"
+ return _NUMPY_TO_TORCH_DTYPE[dtype]
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/param_groups.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/param_groups.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/param_groups.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+import logging
+
+
+logger = logging.getLogger("dinov2")
+
+
+def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
+ """
+ Calculate lr decay rate for different ViT blocks.
+ Args:
+ name (string): parameter name.
+ lr_decay_rate (float): base lr decay rate.
+ num_layers (int): number of ViT blocks.
+ Returns:
+ lr decay rate for the given parameter.
+ """
+ layer_id = num_layers + 1
+ if name.startswith("backbone") or force_is_backbone:
+ if (
+ ".pos_embed" in name
+ or ".patch_embed" in name
+ or ".mask_token" in name
+ or ".cls_token" in name
+ or ".register_tokens" in name
+ ):
+ layer_id = 0
+ elif force_is_backbone and (
+ "pos_embed" in name
+ or "patch_embed" in name
+ or "mask_token" in name
+ or "cls_token" in name
+ or "register_tokens" in name
+ ):
+ layer_id = 0
+ elif ".blocks." in name and ".residual." not in name:
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+ elif chunked_blocks and "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
+ elif "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
+
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
+
+
+def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
+ chunked_blocks = False
+ if hasattr(model, "n_blocks"):
+ logger.info("chunked fsdp")
+ n_blocks = model.n_blocks
+ chunked_blocks = model.chunked_blocks
+ elif hasattr(model, "blocks"):
+ logger.info("first code branch")
+ n_blocks = len(model.blocks)
+ elif hasattr(model, "backbone"):
+ logger.info("second code branch")
+ n_blocks = len(model.backbone.blocks)
+ else:
+ logger.info("else code branch")
+ n_blocks = 0
+ all_param_groups = []
+
+ for name, param in model.named_parameters():
+ name = name.replace("_fsdp_wrapped_module.", "")
+ if not param.requires_grad:
+ continue
+ decay_rate = get_vit_lr_decay_rate(
+ name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
+ )
+ d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
+
+ if "last_layer" in name:
+ d.update({"is_last_layer": True})
+
+ if name.endswith(".bias") or "norm" in name or "gamma" in name:
+ d.update({"wd_multiplier": 0.0})
+
+ if "patch_embed" in name:
+ d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
+
+ all_param_groups.append(d)
+ logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
+
+ return all_param_groups
+
+
+def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
+ fused_params_groups = defaultdict(lambda: {"params": []})
+ for d in all_params_groups:
+ identifier = ""
+ for k in keys:
+ identifier += k + str(d[k]) + "_"
+
+ for k in keys:
+ fused_params_groups[identifier][k] = d[k]
+ fused_params_groups[identifier]["params"].append(d["params"])
+
+ return fused_params_groups.values()
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/utils.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/dinov2/utils/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import random
+import subprocess
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
+ if urlparse(pretrained_weights).scheme: # If it looks like an URL
+ state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
+ else:
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
+ if checkpoint_key is not None and checkpoint_key in state_dict:
+ logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
+ state_dict = state_dict[checkpoint_key]
+ # remove `module.` prefix
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ # remove `backbone.` prefix induced by multicrop wrapper
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+ msg = model.load_state_dict(state_dict, strict=False)
+ logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
+
+
+def fix_random_seeds(seed=31):
+ """
+ Fix random seeds.
+ """
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+ sha = "N/A"
+ diff = "clean"
+ branch = "N/A"
+ try:
+ sha = _run(["git", "rev-parse", "HEAD"])
+ subprocess.check_output(["git", "diff"], cwd=cwd)
+ diff = _run(["git", "diff-index", "HEAD"])
+ diff = "has uncommitted changes" if diff else "clean"
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+class CosineScheduler(object):
+ def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
+ super().__init__()
+ self.final_value = final_value
+ self.total_iters = total_iters
+
+ freeze_schedule = np.zeros((freeze_iters))
+
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+ iters = np.arange(total_iters - warmup_iters - freeze_iters)
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
+ self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
+
+ assert len(self.schedule) == self.total_iters
+
+ def __getitem__(self, it):
+ if it >= self.total_iters:
+ return self.final_value
+ else:
+ return self.schedule[it]
+
+
+def has_batchnorms(model):
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
+ for name, module in model.named_modules():
+ if isinstance(module, bn_types):
+ return True
+ return False
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/modules.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..b36ad48d40a8715da375eb15c74416f34f4f9c04
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/modules.py
@@ -0,0 +1,254 @@
+from typing import *
+from numbers import Number
+import importlib
+import itertools
+import functools
+import sys
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .dinov2.models.vision_transformer import DinoVisionTransformer
+from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
+from ..utils.geometry_torch import normalized_view_plane_uv
+
+
+class ResidualConvBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int = None,
+ hidden_channels: int = None,
+ kernel_size: int = 3,
+ padding_mode: str = 'replicate',
+ activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
+ in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm',
+ hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm',
+ ):
+ super(ResidualConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ if hidden_channels is None:
+ hidden_channels = in_channels
+
+ if activation =='relu':
+ activation_cls = nn.ReLU
+ elif activation == 'leaky_relu':
+ activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2)
+ elif activation =='silu':
+ activation_cls = nn.SiLU
+ elif activation == 'elu':
+ activation_cls = nn.ELU
+ else:
+ raise ValueError(f'Unsupported activation function: {activation}')
+
+ self.layers = nn.Sequential(
+ nn.GroupNorm(in_channels // 32, in_channels) if in_norm == 'group_norm' else \
+ nn.GroupNorm(1, in_channels) if in_norm == 'layer_norm' else \
+ nn.InstanceNorm2d(in_channels) if in_norm == 'instance_norm' else \
+ nn.Identity(),
+ activation_cls(),
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode),
+ nn.GroupNorm(hidden_channels // 32, hidden_channels) if hidden_norm == 'group_norm' else \
+ nn.GroupNorm(1, hidden_channels) if hidden_norm == 'layer_norm' else \
+ nn.InstanceNorm2d(hidden_channels) if hidden_norm == 'instance_norm' else\
+ nn.Identity(),
+ activation_cls(),
+ nn.Conv2d(hidden_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode)
+ )
+
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
+
+ def forward(self, x):
+ skip = self.skip_connection(x)
+ x = self.layers(x)
+ x = x + skip
+ return x
+
+
+class DINOv2Encoder(nn.Module):
+ "Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]."
+ backbone: DinoVisionTransformer
+ image_mean: torch.Tensor
+ image_std: torch.Tensor
+ dim_features: int
+
+ def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, **deprecated_kwargs):
+ super(DINOv2Encoder, self).__init__()
+
+ self.intermediate_layers = intermediate_layers
+
+ # Load the backbone
+ self.hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), backbone)
+ self.backbone_name = backbone
+ self.backbone = self.hub_loader(pretrained=False)
+
+ self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
+ self.num_features = intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers)
+
+ self.output_projections = nn.ModuleList([
+ nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0,)
+ for _ in range(self.num_features)
+ ])
+
+ self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ @property
+ def onnx_compatible_mode(self):
+ return getattr(self, "_onnx_compatible_mode", False)
+
+ @onnx_compatible_mode.setter
+ def onnx_compatible_mode(self, value: bool):
+ self._onnx_compatible_mode = value
+ self.backbone.onnx_compatible_mode = value
+
+ def init_weights(self):
+ pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
+ self.backbone.load_state_dict(pretrained_backbone_state_dict)
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.backbone.blocks)):
+ wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
+
+ def enable_pytorch_native_sdpa(self):
+ for i in range(len(self.backbone.blocks)):
+ wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
+
+ def forward(self, image: torch.Tensor, token_rows: Union[int, torch.LongTensor], token_cols: Union[int, torch.LongTensor], return_class_token: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+ image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=not self.onnx_compatible_mode)
+ image_14 = (image_14 - self.image_mean) / self.image_std
+
+ # Get intermediate layers from the backbone
+ features = self.backbone.get_intermediate_layers(image_14, n=self.intermediate_layers, return_class_token=True)
+
+ # Project features to the desired dimensionality
+ x = torch.stack([
+ proj(feat.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous())
+ for proj, (feat, clstoken) in zip(self.output_projections, features)
+ ], dim=1).sum(dim=1)
+
+ if return_class_token:
+ return x, features[-1][1]
+ else:
+ return x
+
+
+class Resampler(nn.Sequential):
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'],
+ scale_factor: int = 2,
+ ):
+ if type_ == 'pixel_shuffle':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.PixelShuffle(scale_factor),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ for i in range(1, scale_factor ** 2):
+ self[0].weight.data[i::scale_factor ** 2] = self[0].weight.data[0::scale_factor ** 2]
+ self[0].bias.data[i::scale_factor ** 2] = self[0].bias.data[0::scale_factor ** 2]
+ elif type_ in ['nearest', 'bilinear']:
+ nn.Sequential.__init__(self,
+ nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None),
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ elif type_ == 'conv_transpose':
+ nn.Sequential.__init__(self,
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
+ elif type_ == 'pixel_unshuffle':
+ nn.Sequential.__init__(self,
+ nn.PixelUnshuffle(scale_factor),
+ nn.Conv2d(in_channels * (scale_factor ** 2), out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ elif type_ == 'avg_pool':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
+ )
+ elif type_ == 'max_pool':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
+ )
+ else:
+ raise ValueError(f'Unsupported resampler type: {type_}')
+
+class MLP(nn.Sequential):
+ def __init__(self, dims: Sequence[int]):
+ nn.Sequential.__init__(self,
+ *itertools.chain(*[
+ (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True))
+ for dim_in, dim_out in zip(dims[:-2], dims[1:-1])
+ ]),
+ nn.Linear(dims[-2], dims[-1]),
+ )
+
+
+class ConvStack(nn.Module):
+ def __init__(self,
+ dim_in: List[Optional[int]],
+ dim_res_blocks: List[int],
+ dim_out: List[Optional[int]],
+ resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ res_block_in_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'layer_norm',
+ res_block_hidden_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'group_norm',
+ activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
+ ):
+ super().__init__()
+ self.input_blocks = nn.ModuleList([
+ nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity()
+ for dim_in_, dim_res_block_ in zip(dim_in if isinstance(dim_in, Sequence) else itertools.repeat(dim_in), dim_res_blocks)
+ ])
+ self.resamplers = nn.ModuleList([
+ Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
+ for i, (dim_prev, dim_succ, resampler) in enumerate(zip(
+ dim_res_blocks[:-1],
+ dim_res_blocks[1:],
+ resamplers if isinstance(resamplers, Sequence) else itertools.repeat(resamplers)
+ ))
+ ])
+ self.res_blocks = nn.ModuleList([
+ nn.Sequential(
+ *(
+ ResidualConvBlock(
+ dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_,
+ activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm
+ ) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks)
+ )
+ ) for i, dim_res_block_ in enumerate(dim_res_blocks)
+ ])
+ self.output_blocks = nn.ModuleList([
+ nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity()
+ for dim_out_, dim_res_block_ in zip(dim_out if isinstance(dim_out, Sequence) else itertools.repeat(dim_out), dim_res_blocks)
+ ])
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.resamplers)):
+ self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i])
+ for i in range(len(self.res_blocks)):
+ for j in range(len(self.res_blocks[i])):
+ self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j])
+
+ def forward(self, in_features: List[torch.Tensor]):
+ out_features = []
+ for i in range(len(self.res_blocks)):
+ feature = self.input_blocks[i](in_features[i])
+ if i == 0:
+ x = feature
+ elif feature is not None:
+ x = x + feature
+ x = self.res_blocks[i](x)
+ out_features.append(self.output_blocks[i](x))
+ if i < len(self.res_blocks) - 1:
+ x = self.resamplers[i](x)
+ return out_features
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/utils.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c50761d8740d9d0a0284e129503b8931c6fe08c4
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/utils.py
@@ -0,0 +1,49 @@
+from typing import *
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def wrap_module_with_gradient_checkpointing(module: nn.Module):
+ from torch.utils.checkpoint import checkpoint
+ class _CheckpointingWrapper(module.__class__):
+ _restore_cls = module.__class__
+ def forward(self, *args, **kwargs):
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
+
+ module.__class__ = _CheckpointingWrapper
+ return module
+
+
+def unwrap_module_with_gradient_checkpointing(module: nn.Module):
+ module.__class__ = module.__class__._restore_cls
+
+
+def wrap_dinov2_attention_with_sdpa(module: nn.Module):
+ assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
+ class _AttentionWrapper(module.__class__):
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
+
+ q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
+
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+ module.__class__ = _AttentionWrapper
+ return module
+
+
+def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]:
+ group_to_use = torch.distributed.group.WORLD
+ world_size = group_to_use.size()
+ grad = bucket.buffer()
+ grad.div_(world_size)
+ torch.distributed.all_reduce(grad, group=group_to_use)
+ fut = torch.futures.Future()
+ fut.set_result(grad)
+ return fut
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/v1.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..2513b863252e62b124253100ccc7f54c534949f3
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/v1.py
@@ -0,0 +1,392 @@
+from typing import *
+from numbers import Number
+from functools import partial
+from pathlib import Path
+import importlib
+import warnings
+import json
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.version
+import utils3d
+from huggingface_hub import hf_hub_download
+
+
+from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d, dilate_with_mask
+from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
+from ..utils.tools import timeit
+
+
+class ResidualConvBlock(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'):
+ super(ResidualConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ if hidden_channels is None:
+ hidden_channels = in_channels
+
+ if activation =='relu':
+ activation_cls = lambda: nn.ReLU(inplace=True)
+ elif activation == 'leaky_relu':
+ activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ elif activation =='silu':
+ activation_cls = lambda: nn.SiLU(inplace=True)
+ elif activation == 'elu':
+ activation_cls = lambda: nn.ELU(inplace=True)
+ else:
+ raise ValueError(f'Unsupported activation function: {activation}')
+
+ self.layers = nn.Sequential(
+ nn.GroupNorm(1, in_channels),
+ activation_cls(),
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
+ nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels),
+ activation_cls(),
+ nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode)
+ )
+
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
+
+ def forward(self, x):
+ skip = self.skip_connection(x)
+ x = self.layers(x)
+ x = x + skip
+ return x
+
+
+class Head(nn.Module):
+ def __init__(
+ self,
+ num_features: int,
+ dim_in: int,
+ dim_out: List[int],
+ dim_proj: int = 512,
+ dim_upsample: List[int] = [256, 128, 128],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
+ last_res_blocks: int = 0,
+ last_conv_channels: int = 32,
+ last_conv_size: int = 1
+ ):
+ super().__init__()
+
+ self.projects = nn.ModuleList([
+ nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features)
+ ])
+
+ self.upsample_blocks = nn.ModuleList([
+ nn.Sequential(
+ self._make_upsampler(in_ch + 2, out_ch),
+ *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks))
+ ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
+ ])
+
+ self.output_block = nn.ModuleList([
+ self._make_output_block(
+ dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,
+ ) for dim_out_ in dim_out
+ ])
+
+ def _make_upsampler(self, in_channels: int, out_channels: int):
+ upsampler = nn.Sequential(
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
+ return upsampler
+
+ def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
+ return nn.Sequential(
+ nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
+ )
+
+ def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
+ img_h, img_w = image.shape[-2:]
+ patch_h, patch_w = img_h // 14, img_w // 14
+
+ # Process the hidden states
+ x = torch.stack([
+ proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
+ for proj, (feat, clstoken) in zip(self.projects, hidden_states)
+ ], dim=1).sum(dim=1)
+
+ # Upsample stage
+ # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
+ for i, block in enumerate(self.upsample_blocks):
+ # UV coordinates is for awareness of image aspect ratio
+ uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
+ x = torch.cat([x, uv], dim=1)
+ for layer in block:
+ x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
+
+ # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
+ x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
+ uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
+ x = torch.cat([x, uv], dim=1)
+
+ if isinstance(self.output_block, nn.ModuleList):
+ output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block]
+ else:
+ output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False)
+
+ return output
+
+
+class MoGeModel(nn.Module):
+ image_mean: torch.Tensor
+ image_std: torch.Tensor
+
+ def __init__(self,
+ encoder: str = 'dinov2_vitb14',
+ intermediate_layers: Union[int, List[int]] = 4,
+ dim_proj: int = 512,
+ dim_upsample: List[int] = [256, 128, 128],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
+ num_tokens_range: Tuple[Number, Number] = [1200, 2500],
+ last_res_blocks: int = 0,
+ last_conv_channels: int = 32,
+ last_conv_size: int = 1,
+ mask_threshold: float = 0.5,
+ **deprecated_kwargs
+ ):
+ super(MoGeModel, self).__init__()
+
+ if deprecated_kwargs:
+ # Process legacy arguments
+ if 'trained_area_range' in deprecated_kwargs:
+ num_tokens_range = [deprecated_kwargs['trained_area_range'][0] // 14 ** 2, deprecated_kwargs['trained_area_range'][1] // 14 ** 2]
+ del deprecated_kwargs['trained_area_range']
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
+
+ self.encoder = encoder
+ self.remap_output = remap_output
+ self.intermediate_layers = intermediate_layers
+ self.num_tokens_range = num_tokens_range
+ self.mask_threshold = mask_threshold
+
+ # NOTE: We have copied the DINOv2 code in torchhub to this repository.
+ # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues.
+ hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
+ self.backbone = hub_loader(pretrained=False)
+ dim_feature = self.backbone.blocks[0].attn.qkv.in_features
+
+ self.head = Head(
+ num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers),
+ dim_in=dim_feature,
+ dim_out=[3, 1],
+ dim_proj=dim_proj,
+ dim_upsample=dim_upsample,
+ dim_times_res_block_hidden=dim_times_res_block_hidden,
+ num_res_blocks=num_res_blocks,
+ res_block_norm=res_block_norm,
+ last_res_blocks=last_res_blocks,
+ last_conv_channels=last_conv_channels,
+ last_conv_size=last_conv_size
+ )
+
+ image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
+ image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
+
+ self.register_buffer("image_mean", image_mean)
+ self.register_buffer("image_std", image_std)
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return next(self.parameters()).dtype
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
+ """
+ Load a model from a checkpoint file.
+
+ ### Parameters:
+ - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
+ - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
+ - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
+
+ ### Returns:
+ - A new instance of `MoGe` with the parameters loaded from the checkpoint.
+ """
+ if Path(pretrained_model_name_or_path).exists():
+ checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True)
+ else:
+ cached_checkpoint_path = hf_hub_download(
+ repo_id=pretrained_model_name_or_path,
+ repo_type="model",
+ filename="model.pt",
+ **hf_kwargs
+ )
+ checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True)
+ model_config = checkpoint['model_config']
+ if model_kwargs is not None:
+ model_config.update(model_kwargs)
+ model = cls(**model_config)
+ model.load_state_dict(checkpoint['model'])
+ return model
+
+ def init_weights(self):
+ "Load the backbone with pretrained dinov2 weights from torch hub"
+ state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
+ self.backbone.load_state_dict(state_dict)
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.backbone.blocks)):
+ self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
+
+ def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
+ if self.remap_output == 'linear':
+ pass
+ elif self.remap_output =='sinh':
+ points = torch.sinh(points)
+ elif self.remap_output == 'exp':
+ xy, z = points.split([2, 1], dim=-1)
+ z = torch.exp(z)
+ points = torch.cat([xy * z, z], dim=-1)
+ elif self.remap_output =='sinh_exp':
+ xy, z = points.split([2, 1], dim=-1)
+ points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
+ else:
+ raise ValueError(f"Invalid remap output type: {self.remap_output}")
+ return points
+
+ def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
+ original_height, original_width = image.shape[-2:]
+
+ # Resize to expected resolution defined by num_tokens
+ resize_factor = ((num_tokens * 14 ** 2) / (original_height * original_width)) ** 0.5
+ resized_width, resized_height = int(original_width * resize_factor), int(original_height * resize_factor)
+ image = F.interpolate(image, (resized_height, resized_width), mode="bicubic", align_corners=False, antialias=True)
+
+ # Apply image transformation for DINOv2
+ image = (image - self.image_mean) / self.image_std
+ image_14 = F.interpolate(image, (resized_height // 14 * 14, resized_width // 14 * 14), mode="bilinear", align_corners=False, antialias=True)
+
+ # Get intermediate layers from the backbone
+ features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)
+
+ # Predict points (and mask)
+ output = self.head(features, image)
+ points, mask = output
+
+ # Make sure fp32 precision for output
+ with torch.autocast(device_type=image.device.type, dtype=torch.float32):
+ # Resize to original resolution
+ points = F.interpolate(points, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)
+ mask = F.interpolate(mask, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)
+
+ # Post-process points and mask
+ points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
+ points = self._remap_points(points) # slightly improves the performance in case of very large output values
+
+ return_dict = {'points': points, 'mask': mask}
+ return return_dict
+
+ @torch.inference_mode()
+ def infer(
+ self,
+ image: torch.Tensor,
+ fov_x: Union[Number, torch.Tensor] = None,
+ resolution_level: int = 9,
+ num_tokens: int = None,
+ apply_mask: bool = True,
+ force_projection: bool = True,
+ use_fp16: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ """
+ User-friendly inference function
+
+ ### Parameters
+ - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)\
+ - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
+ - `resolution_level`: An integer [0-9] for the resolution level for inference.
+ The higher, the finer details will be captured, but slower. Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size.
+ `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.
+ - `num_tokens`: number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`.
+ `resolution_level` will be ignored if `num_tokens` is provided. Default: None
+ - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
+ - `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True
+ - `use_fp16`: if True, use mixed precision to speed up inference. Default: True
+
+ ### Returns
+
+ A dictionary containing the following keys:
+ - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
+ - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
+ - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
+ """
+ if image.dim() == 3:
+ omit_batch_dim = True
+ image = image.unsqueeze(0)
+ else:
+ omit_batch_dim = False
+ image = image.to(dtype=self.dtype, device=self.device)
+
+ original_height, original_width = image.shape[-2:]
+ aspect_ratio = original_width / original_height
+
+ if num_tokens is None:
+ min_tokens, max_tokens = self.num_tokens_range
+ num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
+
+ with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
+ output = self.forward(image, num_tokens)
+ points, mask = output['points'], output['mask']
+
+ # Always process the output in fp32 precision
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
+ points, mask, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, mask, fov_x])
+
+ mask_binary = mask > self.mask_threshold
+
+ # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
+ if fov_x is None:
+ focal, shift = recover_focal_shift(points, mask_binary)
+ else:
+ focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
+ if focal.ndim == 0:
+ focal = focal[None].expand(points.shape[0])
+ _, shift = recover_focal_shift(points, mask_binary, focal=focal)
+ fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
+ fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
+ intrinsics = utils3d.pt.intrinsics_from_focal_center(fx, fy, torch.tensor(0.5, device=points.device, dtype=points.dtype), torch.tensor(0.5, device=points.device, dtype=points.dtype))
+ depth = points[..., 2] + shift[..., None, None]
+
+ # If projection constraint is forced, recompute the point map using the actual depth map
+ if force_projection:
+ points = utils3d.pt.depth_map_to_point_map(depth, intrinsics=intrinsics)
+ else:
+ points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :]
+
+ # Apply mask if needed
+ if apply_mask:
+ points = torch.where(mask_binary[..., None], points, torch.inf)
+ depth = torch.where(mask_binary, depth, torch.inf)
+
+ return_dict = {
+ 'points': points,
+ 'intrinsics': intrinsics,
+ 'depth': depth,
+ 'mask': mask_binary,
+ }
+
+ if omit_batch_dim:
+ return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
+
+ return return_dict
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/model/v2.py b/lingbotvla/models/vla/vision_models/MoGe/moge/model/v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf802805f04db87f91b37a87f91c31d09b37fec
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/model/v2.py
@@ -0,0 +1,303 @@
+from typing import *
+from numbers import Number
+from functools import partial
+from pathlib import Path
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.amp
+import torch.version
+import utils3d
+from huggingface_hub import hf_hub_download
+
+from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, angle_diff_vec3
+from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
+from .modules import DINOv2Encoder, MLP, ConvStack
+
+
+class MoGeModel(nn.Module):
+ encoder: DINOv2Encoder
+ neck: ConvStack
+ points_head: ConvStack
+ mask_head: ConvStack
+ scale_head: MLP
+ onnx_compatible_mode: bool
+
+ def __init__(self,
+ encoder: Dict[str, Any],
+ neck: Dict[str, Any],
+ points_head: Dict[str, Any] = None,
+ mask_head: Dict[str, Any] = None,
+ normal_head: Dict[str, Any] = None,
+ scale_head: Dict[str, Any] = None,
+ remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
+ num_tokens_range: List[int] = [1200, 3600],
+ **deprecated_kwargs
+ ):
+ super(MoGeModel, self).__init__()
+ if deprecated_kwargs:
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
+
+ self.remap_output = remap_output
+ self.num_tokens_range = num_tokens_range
+
+ self.encoder = DINOv2Encoder(**encoder)
+ self.neck = ConvStack(**neck)
+ if points_head is not None:
+ self.points_head = ConvStack(**points_head)
+ if mask_head is not None:
+ self.mask_head = ConvStack(**mask_head)
+ if normal_head is not None:
+ self.normal_head = ConvStack(**normal_head)
+ if scale_head is not None:
+ self.scale_head = MLP(**scale_head)
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return next(self.parameters()).dtype
+
+ @property
+ def onnx_compatible_mode(self) -> bool:
+ return getattr(self, "_onnx_compatible_mode", False)
+
+ @onnx_compatible_mode.setter
+ def onnx_compatible_mode(self, value: bool):
+ self._onnx_compatible_mode = value
+ self.encoder.onnx_compatible_mode = value
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
+ """
+ Load a model from a checkpoint file.
+
+ ### Parameters:
+ - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
+ - `compiled`
+ - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
+ - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
+
+ ### Returns:
+ - A new instance of `MoGe` with the parameters loaded from the checkpoint.
+ """
+ if Path(pretrained_model_name_or_path).exists():
+ checkpoint_path = pretrained_model_name_or_path
+ else:
+ checkpoint_path = hf_hub_download(
+ repo_id=pretrained_model_name_or_path,
+ repo_type="model",
+ filename="model.pt",
+ **hf_kwargs
+ )
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
+
+ model_config = checkpoint['model_config']
+ if model_kwargs is not None:
+ model_config.update(model_kwargs)
+ model = cls(**model_config)
+ model.load_state_dict(checkpoint['model'], strict=False)
+
+ return model
+
+ def init_weights(self):
+ self.encoder.init_weights()
+
+ def enable_gradient_checkpointing(self):
+ self.encoder.enable_gradient_checkpointing()
+ self.neck.enable_gradient_checkpointing()
+ for head in ['points_head', 'normal_head', 'mask_head']:
+ if hasattr(self, head):
+ getattr(self, head).enable_gradient_checkpointing()
+
+ def enable_pytorch_native_sdpa(self):
+ self.encoder.enable_pytorch_native_sdpa()
+
+ def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
+ if self.remap_output == 'linear':
+ pass
+ elif self.remap_output =='sinh':
+ points = torch.sinh(points)
+ elif self.remap_output == 'exp':
+ xy, z = points.split([2, 1], dim=-1)
+ z = torch.exp(z)
+ points = torch.cat([xy * z, z], dim=-1)
+ elif self.remap_output =='sinh_exp':
+ xy, z = points.split([2, 1], dim=-1)
+ points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
+ else:
+ raise ValueError(f"Invalid remap output type: {self.remap_output}")
+ return points
+
+ def forward(self, image: torch.Tensor, num_tokens: Union[int, torch.LongTensor]) -> Dict[str, torch.Tensor]:
+ batch_size, _, img_h, img_w = image.shape
+ device, dtype = image.device, image.dtype
+
+ aspect_ratio = img_w / img_h
+ base_h, base_w = (num_tokens / aspect_ratio) ** 0.5, (num_tokens * aspect_ratio) ** 0.5
+ if isinstance(base_h, torch.Tensor):
+ base_h, base_w = base_h.round().long(), base_w.round().long()
+ else:
+ base_h, base_w = round(base_h), round(base_w)
+
+ # Backbones encoding
+ features, cls_token = self.encoder(image, base_h, base_w, return_class_token=True)
+ features = [features, None, None, None, None]
+
+ # Concat UVs for aspect ratio input
+ for level in range(5):
+ uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
+ if features[level] is None:
+ features[level] = uv
+ else:
+ features[level] = torch.concat([features[level], uv], dim=1)
+
+ # Shared neck
+ features = self.neck(features)
+
+ # Heads decoding
+ points, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['points_head', 'normal_head', 'mask_head'])
+ metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None
+
+ # Resize
+ points, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [points, normal, mask])
+
+ # Remap output
+ if points is not None:
+ points = points.permute(0, 2, 3, 1)
+ points = self._remap_points(points) # slightly improves the performance in case of very large output values
+ if normal is not None:
+ normal = normal.permute(0, 2, 3, 1)
+ normal = F.normalize(normal, dim=-1)
+ if mask is not None:
+ mask = mask.squeeze(1).sigmoid()
+ if metric_scale is not None:
+ metric_scale = metric_scale.squeeze(1).exp()
+
+ return_dict = {
+ 'points': points,
+ 'normal': normal,
+ 'mask': mask,
+ 'metric_scale': metric_scale
+ }
+ return_dict = {k: v for k, v in return_dict.items() if v is not None}
+
+ return return_dict
+
+ @torch.inference_mode()
+ def infer(
+ self,
+ image: torch.Tensor,
+ num_tokens: int = None,
+ resolution_level: int = 9,
+ force_projection: bool = True,
+ apply_mask: bool = True,
+ fov_x: Optional[Union[Number, torch.Tensor]] = None,
+ use_fp16: bool = True,
+ ) -> Dict[str, torch.Tensor]:
+ """
+ User-friendly inference function
+
+ ### Parameters
+ - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
+ - `num_tokens`: the number of base ViT tokens to use for inference, `'least'` or `'most'` or an integer. Suggested range: 1200 ~ 2500.
+ More tokens will result in significantly higher accuracy and finer details, but slower inference time. Default: `'most'`.
+ - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True
+ - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
+ - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
+ - `use_fp16`: if True, use mixed precision to speed up inference. Default: True
+
+ ### Returns
+
+ A dictionary containing the following keys:
+ - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
+ - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
+ - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
+ """
+ if image.dim() == 3:
+ omit_batch_dim = True
+ image = image.unsqueeze(0)
+ else:
+ omit_batch_dim = False
+ image = image.to(dtype=self.dtype, device=self.device)
+
+ original_height, original_width = image.shape[-2:]
+ area = original_height * original_width
+ aspect_ratio = original_width / original_height
+
+ # Determine the number of base tokens to use
+ if num_tokens is None:
+ min_tokens, max_tokens = self.num_tokens_range
+ num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
+
+ # Forward pass
+ with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
+ output = self.forward(image, num_tokens=num_tokens)
+ points, normal, mask, metric_scale = (output.get(k, None) for k in ['points', 'normal', 'mask', 'metric_scale'])
+
+ # Always process the output in fp32 precision
+ points, normal, mask, metric_scale, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, normal, mask, metric_scale, fov_x])
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
+ if mask is not None:
+ mask_binary = mask > 0.5
+ else:
+ mask_binary = None
+
+ if points is not None:
+ # Convert affine point map to camera-space. Recover depth and intrinsics from point map.
+ # NOTE: Focal here is the focal length relative to half the image diagonal
+ if fov_x is None:
+ # Recover focal and shift from predicted point map
+ focal, shift = recover_focal_shift(points, mask_binary)
+ else:
+ # Focal is known, recover shift only
+ focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
+ if focal.ndim == 0:
+ focal = focal[None].expand(points.shape[0])
+ _, shift = recover_focal_shift(points, mask_binary, focal=focal)
+ fx, fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio, focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
+ intrinsics = utils3d.pt.intrinsics_from_focal_center(fx, fy, torch.tensor(0.5, device=points.device, dtype=points.dtype), torch.tensor(0.5, device=points.device, dtype=points.dtype))
+ points[..., 2] += shift[..., None, None]
+ if mask_binary is not None:
+ mask_binary &= points[..., 2] > 0 # in case depth is contains negative values (which should never happen in practice)
+ depth = points[..., 2].clone()
+ else:
+ depth, intrinsics = None, None
+
+ # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics
+ if force_projection and depth is not None:
+ points = utils3d.pt.depth_map_to_point_map(depth, intrinsics=intrinsics)
+
+ # Apply metric scale
+ if metric_scale is not None:
+ if points is not None:
+ points *= metric_scale[:, None, None, None]
+ if depth is not None:
+ depth *= metric_scale[:, None, None]
+
+ # Apply mask
+ if apply_mask and mask_binary is not None:
+ points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None
+ depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None
+ normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) if normal is not None else None
+
+ return_dict = {
+ 'points': points,
+ 'intrinsics': intrinsics,
+ 'depth': depth,
+ 'mask': mask_binary,
+ 'normal': normal
+ }
+ return_dict = {k: v for k, v in return_dict.items() if v is not None}
+
+ if omit_batch_dim:
+ return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
+
+ return return_dict
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/__init__.py b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/app.py b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a63e626b13f98ad1e790f6eb59da79d5a196c19
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/app.py
@@ -0,0 +1,301 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+import sys
+from pathlib import Path
+if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
+ sys.path.insert(0, _package_root)
+import time
+import uuid
+import tempfile
+import itertools
+from typing import *
+import atexit
+from concurrent.futures import ThreadPoolExecutor
+import shutil
+
+import click
+
+
+@click.command(help='Web demo')
+@click.option('--share', is_flag=True, help='Whether to run the app in shared mode.')
+@click.option('--pretrained', 'pretrained_model_name_or_path', default=None, help='The name or path of the pre-trained model.')
+@click.option('--version', 'model_version', default='v2', help='The version of the model.')
+@click.option('--fp16', 'use_fp16', is_flag=True, help='Whether to use fp16 inference.')
+def main(share: bool, pretrained_model_name_or_path: str, model_version: str, use_fp16: bool):
+ print("Import modules...")
+ # Lazy import
+ import cv2
+ import torch
+ import numpy as np
+ import trimesh
+ import trimesh.visual
+ from PIL import Image
+ import gradio as gr
+ try:
+ import spaces # This is for deployment at huggingface.co/spaces
+ HUGGINFACE_SPACES_INSTALLED = True
+ except ImportError:
+ HUGGINFACE_SPACES_INSTALLED = False
+
+ import utils3d
+ from moge.utils.io import write_normal
+ from moge.utils.vis import colorize_depth, colorize_normal
+ from moge.model import import_model_class_by_version
+ from moge.utils.geometry_numpy import depth_occlusion_edge_numpy
+ from moge.utils.tools import timeit
+
+ print("Load model...")
+ if pretrained_model_name_or_path is None:
+ DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION = {
+ "v1": "Ruicheng/moge-vitl",
+ "v2": "Ruicheng/moge-2-vitl-normal",
+ }
+ pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version]
+ model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).cuda().eval()
+ if use_fp16:
+ model.half()
+ thread_pool_executor = ThreadPoolExecutor(max_workers=1)
+
+ def delete_later(path: Union[str, os.PathLike], delay: int = 300):
+ def _delete():
+ try:
+ os.remove(path)
+ except FileNotFoundError:
+ pass
+ def _wait_and_delete():
+ time.sleep(delay)
+ _delete(path)
+ thread_pool_executor.submit(_wait_and_delete)
+ atexit.register(_delete)
+
+ # Inference on GPU.
+ @(spaces.GPU if HUGGINFACE_SPACES_INSTALLED else lambda x: x)
+ def run_with_gpu(image: np.ndarray, resolution_level: int, apply_mask: bool) -> Dict[str, np.ndarray]:
+ image_tensor = torch.tensor(image, dtype=torch.float32 if not use_fp16 else torch.float16, device=torch.device('cuda')).permute(2, 0, 1) / 255
+ output = model.infer(image_tensor, apply_mask=apply_mask, resolution_level=resolution_level, use_fp16=use_fp16)
+ output = {k: v.cpu().numpy() for k, v in output.items()}
+ return output
+
+ # Full inference pipeline
+ def run(image: np.ndarray, max_size: int = 800, resolution_level: str = 'High', apply_mask: bool = True, remove_edge: bool = True, request: gr.Request = None):
+ larger_size = max(image.shape[:2])
+ if larger_size > max_size:
+ scale = max_size / larger_size
+ image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
+
+ height, width = image.shape[:2]
+
+ resolution_level_int = {'Low': 0, 'Medium': 5, 'High': 9, 'Ultra': 30}.get(resolution_level, 9)
+ output = run_with_gpu(image, resolution_level_int, apply_mask)
+
+ points, depth, mask, normal = output['points'], output['depth'], output['mask'], output.get('normal', None)
+
+ if remove_edge:
+ mask_cleaned = mask & ~utils3d.np.depth_map_edge(depth, rtol=0.04)
+ else:
+ mask_cleaned = mask
+
+ results = {
+ **output,
+ 'mask_cleaned': mask_cleaned,
+ 'image': image
+ }
+
+ # depth & normal visualization
+ depth_vis = colorize_depth(depth)
+ if normal is not None:
+ normal_vis = colorize_normal(normal)
+ else:
+ normal_vis = gr.update(label="Normal map (not avalable for this model)")
+
+ # mesh & pointcloud
+ if normal is None:
+ faces, vertices, vertex_colors, vertex_uvs = utils3d.np.build_mesh_from_map(
+ points,
+ image.astype(np.float32) / 255,
+ utils3d.np.uv_map(height, width),
+ mask=mask_cleaned,
+ tri=True
+ )
+ vertex_normals = None
+ else:
+ faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.np.build_mesh_from_map(
+ points,
+ image.astype(np.float32) / 255,
+ utils3d.np.uv_map(height, width),
+ normal,
+ mask=mask_cleaned,
+ tri=True
+ )
+ vertices = vertices * np.array([1, -1, -1], dtype=np.float32)
+ vertex_uvs = vertex_uvs * np.array([1, -1], dtype=np.float32) + np.array([0, 1], dtype=np.float32)
+ if vertex_normals is not None:
+ vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32)
+
+ tempdir = Path(tempfile.gettempdir(), 'moge')
+ tempdir.mkdir(exist_ok=True)
+ output_path = Path(tempdir, request.session_hash)
+ shutil.rmtree(output_path, ignore_errors=True)
+ output_path.mkdir(exist_ok=True, parents=True)
+ trimesh.Trimesh(
+ vertices=vertices,
+ faces=faces,
+ visual = trimesh.visual.texture.TextureVisuals(
+ uv=vertex_uvs,
+ material=trimesh.visual.material.PBRMaterial(
+ baseColorTexture=Image.fromarray(image),
+ metallicFactor=0.5,
+ roughnessFactor=1.0
+ )
+ ),
+ vertex_normals=vertex_normals,
+ process=False
+ ).export(output_path / 'mesh.glb')
+ pointcloud = trimesh.PointCloud(
+ vertices=vertices,
+ colors=vertex_colors,
+ )
+ pointcloud.vertex_normals = vertex_normals
+ pointcloud.export(output_path / 'pointcloud.ply', vertex_normal=True)
+ trimesh.PointCloud(
+ vertices=vertices,
+ colors=vertex_colors,
+ ).export(output_path / 'pointcloud.glb', include_normals=True)
+ cv2.imwrite(str(output_path /'mask.png'), mask.astype(np.uint8) * 255)
+ cv2.imwrite(str(output_path / 'depth.exr'), depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+ cv2.imwrite(str(output_path / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+ if normal is not None:
+ cv2.imwrite(str(output_path / 'normal.exr'), cv2.cvtColor(normal.astype(np.float32) * np.array([1, -1, -1], dtype=np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
+
+ files = ['mesh.glb', 'pointcloud.ply', 'depth.exr', 'points.exr', 'mask.png']
+ if normal is not None:
+ files.append('normal.exr')
+
+ for f in files:
+ delete_later(output_path / f)
+
+ # FOV
+ intrinsics = results['intrinsics']
+ fov_x, fov_y = utils3d.np.intrinsics_to_fov(intrinsics)
+ fov_x, fov_y = np.rad2deg([fov_x, fov_y])
+
+ # messages
+ viewer_message = f'**Note:** Inference has been completed. It may take a few seconds to download the 3D model.'
+ if resolution_level != 'Ultra':
+ depth_message = f'**Note:** Want sharper depth map? Try increasing the `maximum image size` and setting the `inference resolution level` to `Ultra` in the settings.'
+ else:
+ depth_message = ""
+
+ return (
+ results,
+ depth_vis,
+ normal_vis,
+ output_path / 'pointcloud.glb',
+ [(output_path / f).as_posix() for f in files if (output_path / f).exists()],
+ f'- **Horizontal FOV: {fov_x:.1f}°**. \n - **Vertical FOV: {fov_y:.1f}°**',
+ viewer_message,
+ depth_message
+ )
+
+ def reset_measure(results: Dict[str, np.ndarray]):
+ return [results['image'], [], ""]
+
+
+ def measure(results: Dict[str, np.ndarray], measure_points: List[Tuple[int, int]], event: gr.SelectData):
+ point2d = event.index[0], event.index[1]
+ measure_points.append(point2d)
+
+ image = results['image'].copy()
+ for p in measure_points:
+ image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
+
+ depth_text = ""
+ for i, p in enumerate(measure_points):
+ d = results['depth'][p[1], p[0]]
+ depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
+
+ if len(measure_points) == 2:
+ point1, point2 = measure_points
+ image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
+ distance = np.linalg.norm(results['points'][point1[1], point1[0]] - results['points'][point2[1], point2[0]])
+ measure_points = []
+
+ distance_text = f"- **Distance: {distance:.2f}m**"
+
+ text = depth_text + distance_text
+ return [image, measure_points, text]
+ else:
+ return [image, measure_points, depth_text]
+
+ print("Create Gradio app...")
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+ gr.Markdown(
+f'''
+
+
Turn a 2D image into 3D with MoGe
+
+''')
+ results = gr.State(value=None)
+ measure_points = gr.State(value=[])
+
+ with gr.Row():
+ with gr.Column():
+ input_image = gr.Image(type="numpy", image_mode="RGB", label="Input Image")
+ with gr.Accordion(label="Settings", open=False):
+ max_size_input = gr.Number(value=800, label="Maximum Image Size", precision=0, minimum=256, maximum=2048)
+ resolution_level = gr.Dropdown(['Low', 'Medium', 'High', 'Ultra'], label="Inference Resolution Level", value='High')
+ apply_mask = gr.Checkbox(value=True, label="Apply mask")
+ remove_edges = gr.Checkbox(value=True, label="Remove edges")
+ submit_btn = gr.Button("Submit", variant='primary')
+
+ with gr.Column():
+ with gr.Tabs():
+ with gr.Tab("3D View"):
+ viewer_message = gr.Markdown("")
+ model_3d = gr.Model3D(display_mode="solid", label="3D Point Map", clear_color=[1.0, 1.0, 1.0, 1.0], height="60vh")
+ fov = gr.Markdown()
+ with gr.Tab("Depth"):
+ depth_message = gr.Markdown("")
+ depth_map = gr.Image(type="numpy", label="Colorized Depth Map", format='png', interactive=False)
+ with gr.Tab("Normal", interactive=hasattr(model, 'normal_head')):
+ normal_map = gr.Image(type="numpy", label="Normal Map", format='png', interactive=False)
+ with gr.Tab("Measure", interactive=hasattr(model, 'scale_head')):
+ gr.Markdown("### Click on the image to measure the distance between two points. \n"
+ "**Note:** Metric scale is most reliable for typical indoor or street scenes, and may degrade for contents unfamiliar to the model (e.g., stylized or close-up images).")
+ measure_image = gr.Image(type="numpy", show_label=False, format='webp', interactive=False, sources=[])
+ measure_text = gr.Markdown("")
+ with gr.Tab("Download"):
+ files = gr.File(type='filepath', label="Output Files")
+
+ if Path('example_images').exists():
+ example_image_paths = sorted(list(itertools.chain(*[Path('example_images').glob(f'*.{ext}') for ext in ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']])))
+ examples = gr.Examples(
+ examples = example_image_paths,
+ inputs=input_image,
+ label="Examples"
+ )
+
+ submit_btn.click(
+ fn=lambda: [None, None, None, None, None, "", "", ""],
+ outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message]
+ ).then(
+ fn=run,
+ inputs=[input_image, max_size_input, resolution_level, apply_mask, remove_edges],
+ outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message]
+ ).then(
+ fn=reset_measure,
+ inputs=[results],
+ outputs=[measure_image, measure_points, measure_text]
+ )
+
+ measure_image.select(
+ fn=measure,
+ inputs=[results, measure_points],
+ outputs=[measure_image, measure_points, measure_text]
+ )
+
+ demo.launch(share=share)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/cli.py b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c3b9006bf56306e403f8da5b6d5068215221ee
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/cli.py
@@ -0,0 +1,27 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+from pathlib import Path
+import sys
+if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
+ sys.path.insert(0, _package_root)
+
+import click
+
+
+@click.group(help='MoGe command line interface.')
+def cli():
+ pass
+
+def main():
+ from moge.scripts import app, infer, infer_baseline, infer_panorama, eval_baseline, vis_data
+ cli.add_command(app.main, name='app')
+ cli.add_command(infer.main, name='infer')
+ cli.add_command(infer_baseline.main, name='infer_baseline')
+ cli.add_command(infer_panorama.main, name='infer_panorama')
+ cli.add_command(eval_baseline.main, name='eval_baseline')
+ cli.add_command(vis_data.main, name='vis_data')
+ cli()
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/eval_baseline.py b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/eval_baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..8217d9e6500b1d72a00e1a0a225ba4c2134b892e
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/eval_baseline.py
@@ -0,0 +1,165 @@
+import os
+import sys
+from pathlib import Path
+if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
+ sys.path.insert(0, _package_root)
+import json
+from typing import *
+import importlib
+import importlib.util
+
+import click
+
+
+@click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Evaluation script.')
+@click.option('--baseline', 'baseline_code_path', type=click.Path(), required=True, help='Path to the baseline model python code.')
+@click.option('--config', 'config_path', type=click.Path(), default='configs/eval/all_benchmarks.json', help='Path to the evaluation configurations. '
+ 'Defaults to "configs/eval/all_benchmarks.json".')
+@click.option('--output', '-o', 'output_path', type=click.Path(), required=True, help='Path to the output json file.')
+@click.option('--oracle', 'oracle_mode', is_flag=True, help='Use oracle mode for evaluation, i.e., use the GT intrinsics input.')
+@click.option('--dump_pred', is_flag=True, help='Dump predition results.')
+@click.option('--dump_gt', is_flag=True, help='Dump ground truth.')
+@click.pass_context
+def main(ctx: click.Context, baseline_code_path: str, config_path: str, oracle_mode: bool, output_path: Union[str, Path], dump_pred: bool, dump_gt: bool):
+ # Lazy import
+ import cv2
+ import numpy as np
+ from tqdm import tqdm
+ import torch
+ import torch.nn.functional as F
+ import utils3d
+
+ from moge.test.baseline import MGEBaselineInterface
+ from moge.test.dataloader import EvalDataLoaderPipeline
+ from moge.test.metrics import compute_metrics
+ from moge.utils.geometry_torch import intrinsics_to_fov
+ from moge.utils.vis import colorize_depth, colorize_normal
+ from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module
+
+ # Load the baseline model
+ module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem)
+ baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline')
+ baseline : MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False)
+
+ # Load the evaluation configurations
+ with open(config_path, 'r') as f:
+ config = json.load(f)
+
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+ all_metrics = {}
+ # Iterate over the dataset
+ for benchmark_name, benchmark_config in tqdm(list(config.items()), desc='Benchmarks'):
+ filenames, metrics_list = [], []
+ with (
+ EvalDataLoaderPipeline(**benchmark_config) as eval_data_pipe,
+ tqdm(total=len(eval_data_pipe), desc=benchmark_name, leave=False) as pbar
+ ):
+ # Iterate over the samples in the dataset
+ for i in range(len(eval_data_pipe)):
+ sample = eval_data_pipe.get()
+ sample = {k: v.to(baseline.device) if isinstance(v, torch.Tensor) else v for k, v in sample.items()}
+ image = sample['image']
+ gt_intrinsics = sample['intrinsics']
+
+ # Inference
+ torch.cuda.synchronize()
+ with torch.inference_mode(), timeit('_inference_timer', verbose=False) as timer:
+ if oracle_mode:
+ pred = baseline.infer_for_evaluation(image, gt_intrinsics)
+ else:
+ pred = baseline.infer_for_evaluation(image)
+ torch.cuda.synchronize()
+
+ # Compute metrics
+ metrics, misc = compute_metrics(pred, sample, vis=dump_pred or dump_gt)
+ metrics['inference_time'] = timer.time
+ metrics_list.append(metrics)
+
+ # Dump results
+ dump_path = Path(output_path.replace(".json", f"_dump"), f'{benchmark_name}', sample['filename'].replace('.zip', ''))
+ if dump_pred:
+ dump_path.joinpath('pred').mkdir(parents=True, exist_ok=True)
+ cv2.imwrite(str(dump_path / 'pred' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
+
+ with Path(dump_path, 'pred', 'metrics.json').open('w') as f:
+ json.dump(metrics, f, indent=4)
+
+ if 'pred_points' in misc:
+ points = misc['pred_points'].cpu().numpy()
+ cv2.imwrite(str(dump_path / 'pred' / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+
+ if 'pred_depth' in misc:
+ depth = misc['pred_depth'].cpu().numpy()
+ if 'mask' in pred:
+ mask = pred['mask'].cpu().numpy()
+ depth = np.where(mask, depth, np.inf)
+ cv2.imwrite(str(dump_path / 'pred' / 'depth.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR))
+
+ if 'mask' in pred:
+ mask = pred['mask'].cpu().numpy()
+ cv2.imwrite(str(dump_path / 'pred' / 'mask.png'), (mask * 255).astype(np.uint8))
+
+ if 'normal' in pred:
+ normal = pred['normal'].cpu().numpy()
+ cv2.imwrite(str(dump_path / 'pred' / 'normal.png'), cv2.cvtColor(colorize_normal(normal), cv2.COLOR_RGB2BGR))
+
+ if 'intrinsics' in pred:
+ intrinsics = pred['intrinsics']
+ fov_x, fov_y = intrinsics_to_fov(intrinsics)
+ with open(dump_path / 'pred' / 'fov.json', 'w') as f:
+ json.dump({
+ 'fov_x': np.rad2deg(fov_x.item()),
+ 'fov_y': np.rad2deg(fov_y.item()),
+ 'intrinsics': intrinsics.cpu().numpy().tolist(),
+ }, f)
+
+ if dump_gt:
+ dump_path.joinpath('gt').mkdir(parents=True, exist_ok=True)
+ cv2.imwrite(str(dump_path / 'gt' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
+
+ if 'points' in sample:
+ points = sample['points']
+ cv2.imwrite(str(dump_path / 'gt' / 'points.exr'), cv2.cvtColor(points.cpu().numpy().astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+
+ if 'depth' in sample:
+ depth = sample['depth']
+ mask = sample['depth_mask']
+ cv2.imwrite(str(dump_path / 'gt' / 'depth.png'), cv2.cvtColor(colorize_depth(depth.cpu().numpy(), mask=mask.cpu().numpy()), cv2.COLOR_RGB2BGR))
+
+ if 'normal' in sample:
+ normal = sample['normal']
+ cv2.imwrite(str(dump_path / 'gt' / 'normal.png'), cv2.cvtColor(colorize_normal(normal.cpu().numpy()), cv2.COLOR_RGB2BGR))
+
+ if 'depth_mask' in sample:
+ mask = sample['depth_mask']
+ cv2.imwrite(str(dump_path / 'gt' /'mask.png'), (mask.cpu().numpy() * 255).astype(np.uint8))
+
+ if 'intrinsics' in sample:
+ intrinsics = sample['intrinsics']
+ fov_x, fov_y = intrinsics_to_fov(intrinsics)
+ with open(dump_path / 'gt' / 'info.json', 'w') as f:
+ json.dump({
+ 'fov_x': np.rad2deg(fov_x.item()),
+ 'fov_y': np.rad2deg(fov_y.item()),
+ 'intrinsics': intrinsics.cpu().numpy().tolist(),
+ }, f)
+
+ # Save intermediate results
+ if i % 100 == 0 or i == len(eval_data_pipe) - 1:
+ Path(output_path).write_text(
+ json.dumps({
+ **all_metrics,
+ benchmark_name: key_average(metrics_list)
+ }, indent=4)
+ )
+ pbar.update(1)
+
+ all_metrics[benchmark_name] = key_average(metrics_list)
+
+ # Save final results
+ all_metrics['mean'] = key_average(list(all_metrics.values()))
+ Path(output_path).write_text(json.dumps(all_metrics, indent=4))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/infer.py b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..09990f34907cd77ddf5b57aeda4fcf27e07a5254
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/infer.py
@@ -0,0 +1,170 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+from pathlib import Path
+import sys
+if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
+ sys.path.insert(0, _package_root)
+from typing import *
+import itertools
+import json
+import warnings
+
+import click
+
+
+@click.command(help='Inference script')
+@click.option('--input', '-i', 'input_path', type=click.Path(exists=True), help='Input image or folder path. "jpg" and "png" are supported.')
+@click.option('--fov_x', 'fov_x_', type=float, default=None, help='If camera parameters are known, set the horizontal field of view in degrees. Otherwise, MoGe will estimate it.')
+@click.option('--output', '-o', 'output_path', default='./output', type=click.Path(), help='Output folder path')
+@click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default=None, help='Pretrained model name or path. If not provided, the corresponding default model will be chosen.')
+@click.option('--version', 'model_version', type=click.Choice(['v1', 'v2']), default='v2', help='Model version. Defaults to "v2"')
+@click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"')
+@click.option('--fp16', 'use_fp16', is_flag=True, help='Use fp16 precision for much faster inference.')
+@click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Defaults to None (no resizing).')
+@click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level for inference. \
+Higher value means more tokens and the finer details will be captured, but inference can be slower. \
+Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size. \
+`resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.')
+@click.option('--num_tokens', type=int, default=None, help='number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`. \
+`resolution_level` will be ignored if `num_tokens` is provided. Default: None')
+@click.option('--threshold', type=float, default=0.04, help='Threshold for removing edges. Defaults to 0.01. Smaller value removes more edges. "inf" means no thresholding.')
+@click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps (image, point map, depth map, normal map, mask) and fov.')
+@click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.')
+@click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.')
+@click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. Note that this requires pyglet<2 installed as required by trimesh.')
+def main(
+ input_path: str,
+ fov_x_: float,
+ output_path: str,
+ pretrained_model_name_or_path: str,
+ model_version: str,
+ device_name: str,
+ use_fp16: bool,
+ resize_to: int,
+ resolution_level: int,
+ num_tokens: int,
+ threshold: float,
+ save_maps_: bool,
+ save_glb_: bool,
+ save_ply_: bool,
+ show: bool,
+):
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+ from tqdm import tqdm
+ import click
+
+ from moge.model import import_model_class_by_version
+ from moge.utils.io import save_glb, save_ply
+ from moge.utils.vis import colorize_depth, colorize_normal
+ from moge.utils.geometry_numpy import depth_occlusion_edge_numpy
+ import utils3d
+
+ device = torch.device(device_name)
+
+ include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
+ if Path(input_path).is_dir():
+ image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices)))
+ else:
+ image_paths = [Path(input_path)]
+
+ if len(image_paths) == 0:
+ raise FileNotFoundError(f'No image files found in {input_path}')
+
+ if pretrained_model_name_or_path is None:
+ DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION = {
+ "v1": "Ruicheng/moge-vitl",
+ "v2": "Ruicheng/moge-2-vitl-normal",
+ }
+ pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version]
+ model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).to(device).eval()
+ if use_fp16:
+ model.half()
+
+ if not any([save_maps_, save_glb_, save_ply_]):
+ warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.')
+ save_maps_ = save_glb_ = save_ply_ = True
+
+ for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)):
+ if not image_path.exists():
+ raise FileNotFoundError(f'File {image_path} does not exist.')
+ image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
+ height, width = image.shape[:2]
+ if resize_to is not None:
+ height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height))
+ image = cv2.resize(image, (width, height), cv2.INTER_AREA)
+ image_tensor = torch.tensor(image / 255, dtype=torch.float32, device=device).permute(2, 0, 1)
+
+ # Inference
+ output = model.infer(image_tensor, fov_x=fov_x_, resolution_level=resolution_level, num_tokens=num_tokens, use_fp16=use_fp16)
+ points, depth, mask, intrinsics = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy(), output['intrinsics'].cpu().numpy()
+ normal = output['normal'].cpu().numpy() if 'normal' in output else None
+
+ save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem)
+ save_path.mkdir(exist_ok=True, parents=True)
+
+ # Save images / maps
+ if save_maps_:
+ cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+ cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR))
+ cv2.imwrite(str(save_path / 'depth.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+ cv2.imwrite(str(save_path / 'mask.png'), (mask * 255).astype(np.uint8))
+ cv2.imwrite(str(save_path / 'points.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+ if normal is not None:
+ cv2.imwrite(str(save_path / 'normal.png'), cv2.cvtColor(colorize_normal(normal), cv2.COLOR_RGB2BGR))
+ fov_x, fov_y = utils3d.np.intrinsics_to_fov(intrinsics)
+ with open(save_path / 'fov.json', 'w') as f:
+ json.dump({
+ 'fov_x': round(float(np.rad2deg(fov_x)), 2),
+ 'fov_y': round(float(np.rad2deg(fov_y)), 2),
+ }, f)
+
+ # Export mesh & visulization
+ if save_glb_ or save_ply_ or show:
+ mask_cleaned = mask & ~utils3d.np.depth_map_edge(depth, rtol=threshold)
+ if normal is None:
+ faces, vertices, vertex_colors, vertex_uvs = utils3d.np.build_mesh_from_map(
+ points,
+ image.astype(np.float32) / 255,
+ utils3d.np.uv_map(height, width),
+ mask=mask_cleaned,
+ tri=True
+ )
+ vertex_normals = None
+ else:
+ faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.np.build_mesh_from_map(
+ points,
+ image.astype(np.float32) / 255,
+ utils3d.np.uv_map(height, width),
+ normal,
+ mask=mask_cleaned,
+ tri=True
+ )
+ # When exporting the model, follow the OpenGL coordinate conventions:
+ # - world coordinate system: x right, y up, z backward.
+ # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top.
+ vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1]
+ if normal is not None:
+ vertex_normals = vertex_normals * [1, -1, -1]
+
+ if save_glb_:
+ save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image, vertex_normals)
+
+ if save_ply_:
+ save_ply(save_path / 'pointcloud.ply', vertices, np.zeros((0, 3), dtype=np.int32), vertex_colors, vertex_normals)
+
+ if show:
+ import trimesh
+ trimesh.Trimesh(
+ vertices=vertices,
+ vertex_colors=vertex_colors,
+ vertex_normals=vertex_normals,
+ faces=faces,
+ process=False
+ ).show()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/infer_baseline.py b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/infer_baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef81bc4792fe8e860b190bdc8265d73984f7911b
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/infer_baseline.py
@@ -0,0 +1,140 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+from pathlib import Path
+import sys
+if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
+ sys.path.insert(0, _package_root)
+import json
+from pathlib import Path
+from typing import *
+import itertools
+import warnings
+
+import click
+
+
+@click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Inference script for wrapped baselines methods')
+@click.option('--baseline', 'baseline_code_path', required=True, type=click.Path(), help='Path to the baseline model python code.')
+@click.option('--input', '-i', 'input_path', type=str, required=True, help='Input image or folder')
+@click.option('--output', '-o', 'output_path', type=str, default='./output', help='Output folder')
+@click.option('--size', 'image_size', type=int, default=None, help='Resize input image')
+@click.option('--skip', is_flag=True, help='Skip existing output')
+@click.option('--maps', 'save_maps_', is_flag=True, help='Save output point / depth maps')
+@click.option('--ply', 'save_ply_', is_flag=True, help='Save mesh in PLY format')
+@click.option('--glb', 'save_glb_', is_flag=True, help='Save mesh in GLB format')
+@click.option('--threshold', type=float, default=0.03, help='Depth edge detection threshold for saving mesh')
+@click.pass_context
+def main(ctx: click.Context, baseline_code_path: str, input_path: str, output_path: str, image_size: int, skip: bool, save_maps_, save_ply_: bool, save_glb_: bool, threshold: float):
+ # Lazy import
+ import cv2
+ import numpy as np
+ from tqdm import tqdm
+ import torch
+ import utils3d
+
+ from moge.utils.io import save_ply, save_glb
+ from moge.utils.geometry_numpy import intrinsics_to_fov_numpy
+ from moge.utils.vis import colorize_depth, colorize_depth_affine, colorize_disparity
+ from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module
+ from moge.test.baseline import MGEBaselineInterface
+
+ # Load the baseline model
+ module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem)
+ baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline')
+ baseline : MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False)
+
+ # Input images list
+ include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
+ if Path(input_path).is_dir():
+ image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices)))
+ else:
+ image_paths = [Path(input_path)]
+
+ if not any([save_maps_, save_glb_, save_ply_]):
+ warnings.warn('No output format specified. Defaults to saving maps only. Please use "--maps", "--glb", or "--ply" to specify the output.')
+ save_maps_ = True
+
+ for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)):
+ # Load one image at a time
+ image_np = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
+ height, width = image_np.shape[:2]
+ if image_size is not None and max(image_np.shape[:2]) > image_size:
+ height, width = min(image_size, int(image_size * height / width)), min(image_size, int(image_size * width / height))
+ image_np = cv2.resize(image_np, (width, height), cv2.INTER_AREA)
+ image = torch.from_numpy(image_np.astype(np.float32) / 255.0).permute(2, 0, 1).to(baseline.device)
+
+ # Inference
+ torch.cuda.synchronize()
+ with torch.inference_mode(), (timer := timeit('Inference', verbose=False, average=True)):
+ output = baseline.infer(image)
+ torch.cuda.synchronize()
+
+ inference_time = timer.average_time
+ pbar.set_postfix({'average inference time': f'{inference_time:.3f}s'})
+
+ # Save the output
+ save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem)
+ if skip and save_path.exists():
+ continue
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ if save_maps_:
+ cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
+
+ if 'mask' in output:
+ mask = output['mask'].cpu().numpy()
+ cv2.imwrite(str(save_path /'mask.png'), (mask * 255).astype(np.uint8))
+
+ for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant']:
+ if k in output:
+ points = output[k].cpu().numpy()
+ cv2.imwrite(str(save_path / f'{k}.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+
+ for k in ['depth_metric', 'depth_scale_invariant', 'depth_affine_invariant', 'disparity_affine_invariant']:
+ if k in output:
+ depth = output[k].cpu().numpy()
+ cv2.imwrite(str(save_path / f'{k}.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+ if k in ['depth_metric', 'depth_scale_invariant']:
+ depth_vis = colorize_depth(depth)
+ elif k == 'depth_affine_invariant':
+ depth_vis = colorize_depth_affine(depth)
+ elif k == 'disparity_affine_invariant':
+ depth_vis = colorize_disparity(depth)
+ cv2.imwrite(str(save_path / f'{k}_vis.png'), cv2.cvtColor(depth_vis, cv2.COLOR_RGB2BGR))
+
+ if 'intrinsics' in output:
+ intrinsics = output['intrinsics'].cpu().numpy()
+ fov_x, fov_y = intrinsics_to_fov_numpy(intrinsics)
+ with open(save_path / 'fov.json', 'w') as f:
+ json.dump({
+ 'fov_x': float(np.rad2deg(fov_x)),
+ 'fov_y': float(np.rad2deg(fov_y)),
+ 'intrinsics': intrinsics.tolist()
+ }, f, indent=4)
+
+ # Export mesh & visulization
+ if save_ply_ or save_glb_:
+ assert any(k in output for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant']), 'No point map found in output'
+ points = next(output[k] for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant'] if k in output).cpu().numpy()
+ mask = output['mask'] if 'mask' in output else np.ones_like(points[..., 0], dtype=bool)
+ normals, normals_mask = utils3d.np.point_map_to_normal_map(points, mask=mask)
+ faces, vertices, vertex_colors, vertex_uvs = utils3d.np.build_mesh_from_map(
+ points,
+ image_np.astype(np.float32) / 255,
+ utils3d.np.uv_map(height, width),
+ mask=mask & ~(utils3d.np.depth_map_edge(depth, rtol=threshold, mask=mask) & utils3d.np.normal_map_edge(normals, tol=5, mask=normals_mask)),
+ tri=True
+ )
+ # When exporting the model, follow the OpenGL coordinate conventions:
+ # - world coordinate system: x right, y up, z backward.
+ # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top.
+ vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1]
+
+ if save_glb_:
+ save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image_np)
+
+ if save_ply_:
+ save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors)
+
+if __name__ == '__main__':
+ main()
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/infer_panorama.py b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/infer_panorama.py
new file mode 100644
index 0000000000000000000000000000000000000000..525a8ad711417c0c4c7fd85e659105d89afbee12
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/infer_panorama.py
@@ -0,0 +1,162 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+from pathlib import Path
+import sys
+if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
+ sys.path.insert(0, _package_root)
+from typing import *
+import itertools
+import json
+import warnings
+
+import click
+
+
+@click.command(help='Inference script for panorama images')
+@click.option('--input', '-i', 'input_path', type=click.Path(exists=True), required=True, help='Input image or folder path. "jpg" and "png" are supported.')
+@click.option('--output', '-o', 'output_path', type=click.Path(), default='./output', help='Output folder path')
+@click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl', help='Pretrained model name or path. Defaults to "Ruicheng/moge-vitl"')
+@click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"')
+@click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. Defaults to None (no resizing).')
+@click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level of inference. The higher, the better but slower. Defaults to 9. Note that it is irrelevant to the output resolution.')
+@click.option('--threshold', type=float, default=0.03, help='Threshold for removing edges. Defaults to 0.03. Smaller value removes more edges. "inf" means no thresholding.')
+@click.option('--batch_size', type=int, default=4, help='Batch size for inference. Defaults to 4.')
+@click.option('--splitted', 'save_splitted', is_flag=True, help='Whether to save the splitted images. Defaults to False.')
+@click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps and fov(image, depth, mask, points, fov).')
+@click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.')
+@click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.')
+@click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. Note that this requires pyglet<2 installed as required by trimesh.')
+def main(
+ input_path: str,
+ output_path: str,
+ pretrained_model_name_or_path: str,
+ device_name: str,
+ resize_to: int,
+ resolution_level: int,
+ threshold: float,
+ batch_size: int,
+ save_splitted: bool,
+ save_maps_: bool,
+ save_glb_: bool,
+ save_ply_: bool,
+ show: bool,
+):
+ # Lazy import
+ import cv2
+ import numpy as np
+ from numpy import ndarray
+ import torch
+ from PIL import Image
+ from tqdm import tqdm, trange
+ import trimesh
+ import trimesh.visual
+ from scipy.sparse import csr_array, hstack, vstack
+ from scipy.ndimage import convolve
+ from scipy.sparse.linalg import lsmr
+
+ import utils3d
+ from moge.model.v1 import MoGeModel
+ from moge.utils.io import save_glb, save_ply
+ from moge.utils.vis import colorize_depth
+ from moge.utils.panorama import spherical_uv_to_directions, get_panorama_cameras, split_panorama_image, merge_panorama_depth
+
+
+ device = torch.device(device_name)
+
+ include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']
+ if Path(input_path).is_dir():
+ image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices)))
+ else:
+ image_paths = [Path(input_path)]
+
+ if len(image_paths) == 0:
+ raise FileNotFoundError(f'No image files found in {input_path}')
+
+ # Write outputs
+ if not any([save_maps_, save_glb_, save_ply_]):
+ warnings.warn('No output format specified. Defaults to saving all. Please use "--maps", "--glb", or "--ply" to specify the output.')
+ save_maps_ = save_glb_ = save_ply_ = True
+
+ model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval()
+
+ for image_path in (pbar := tqdm(image_paths, desc='Total images', disable=len(image_paths) <= 1)):
+ image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
+ height, width = image.shape[:2]
+ if resize_to is not None:
+ height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height))
+ image = cv2.resize(image, (width, height), cv2.INTER_AREA)
+
+ splitted_extrinsics, splitted_intriniscs = get_panorama_cameras()
+ splitted_resolution = 512
+ splitted_images = split_panorama_image(image, splitted_extrinsics, splitted_intriniscs, splitted_resolution)
+
+ # Infer each view
+ print('Inferring...') if pbar.disable else pbar.set_postfix_str(f'Inferring')
+
+ splitted_distance_maps, splitted_masks = [], []
+ for i in trange(0, len(splitted_images), batch_size, desc='Inferring splitted views', disable=len(splitted_images) <= batch_size, leave=False):
+ image_tensor = torch.tensor(np.stack(splitted_images[i:i + batch_size]) / 255, dtype=torch.float32, device=device).permute(0, 3, 1, 2)
+ fov_x, fov_y = np.rad2deg(utils3d.np.intrinsics_to_fov(np.array(splitted_intriniscs[i:i + batch_size])))
+ fov_x = torch.tensor(fov_x, dtype=torch.float32, device=device)
+ output = model.infer(image_tensor, fov_x=fov_x, apply_mask=False)
+ distance_map, mask = output['points'].norm(dim=-1).cpu().numpy(), output['mask'].cpu().numpy()
+ splitted_distance_maps.extend(list(distance_map))
+ splitted_masks.extend(list(mask))
+
+ # Save splitted
+ if save_splitted:
+ splitted_save_path = Path(output_path, image_path.stem, 'splitted')
+ splitted_save_path.mkdir(exist_ok=True, parents=True)
+ for i in range(len(splitted_images)):
+ cv2.imwrite(str(splitted_save_path / f'{i:02d}.jpg'), cv2.cvtColor(splitted_images[i], cv2.COLOR_RGB2BGR))
+ cv2.imwrite(str(splitted_save_path / f'{i:02d}_distance_vis.png'), cv2.cvtColor(colorize_depth(splitted_distance_maps[i], splitted_masks[i]), cv2.COLOR_RGB2BGR))
+
+ # Merge
+ print('Merging...') if pbar.disable else pbar.set_postfix_str(f'Merging')
+
+ merging_width, merging_height = min(1920, width), min(960, height)
+ panorama_depth, panorama_mask = merge_panorama_depth(merging_width, merging_height, splitted_distance_maps, splitted_masks, splitted_extrinsics, splitted_intriniscs)
+ panorama_depth = panorama_depth.astype(np.float32)
+ panorama_depth = cv2.resize(panorama_depth, (width, height), cv2.INTER_LINEAR)
+ panorama_mask = cv2.resize(panorama_mask.astype(np.uint8), (width, height), cv2.INTER_NEAREST) > 0
+ points = panorama_depth[:, :, None] * spherical_uv_to_directions(utils3d.np.uv_map(height, width))
+
+ # Write outputs
+ print('Writing outputs...') if pbar.disable else pbar.set_postfix_str(f'Inferring')
+ save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem)
+ save_path.mkdir(exist_ok=True, parents=True)
+ if save_maps_:
+ cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+ cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(panorama_depth, mask=panorama_mask), cv2.COLOR_RGB2BGR))
+ cv2.imwrite(str(save_path / 'depth.exr'), panorama_depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+ cv2.imwrite(str(save_path / 'points.exr'), points, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+ cv2.imwrite(str(save_path /'mask.png'), (panorama_mask * 255).astype(np.uint8))
+
+ # Export mesh & visulization
+ if save_glb_ or save_ply_ or show:
+ normals, normals_mask = utils3d.np.point_map_to_normal_map(points, panorama_mask)
+ faces, vertices, vertex_colors, vertex_uvs = utils3d.np.build_mesh_from_map(
+ points,
+ image.astype(np.float32) / 255,
+ utils3d.np.uv_map(height, width),
+ mask=panorama_mask & ~(utils3d.np.depth_map_edge(panorama_depth, rtol=threshold) & utils3d.np.normal_map_edge(normals, tol=5, mask=normals_mask)),
+ tri=True
+ )
+
+ if save_glb_:
+ save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image)
+
+ if save_ply_:
+ save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors)
+
+ if show:
+ trimesh.Trimesh(
+ vertices=vertices,
+ vertex_colors=vertex_colors,
+ faces=faces,
+ process=False
+ ).show()
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/train.py b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d810cda18c3503eb7f56379ce5b84177e8c0398
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/train.py
@@ -0,0 +1,461 @@
+import os
+from pathlib import Path
+import sys
+if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
+ sys.path.insert(0, _package_root)
+import json
+import time
+import random
+from typing import *
+import itertools
+from contextlib import nullcontext
+from concurrent.futures import ThreadPoolExecutor
+import io
+
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.version
+import accelerate
+from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate.utils import set_seed
+import utils3d
+import click
+from tqdm import tqdm, trange
+import mlflow
+torch.backends.cudnn.benchmark = False # Varying input size, make sure cudnn benchmark is disabled
+
+from moge.train.dataloader import TrainDataLoaderPipeline
+from moge.train.losses import (
+ affine_invariant_global_loss,
+ affine_invariant_local_loss,
+ edge_loss,
+ normal_loss,
+ mask_l2_loss,
+ mask_bce_loss,
+ metric_scale_loss,
+ normal_map_loss,
+ monitoring,
+)
+from moge.train.utils import build_optimizer, build_lr_scheduler
+from moge.utils.geometry_torch import intrinsics_to_fov
+from moge.utils.vis import colorize_depth, colorize_normal
+from moge.utils.tools import key_average, recursive_replace, CallbackOnException, flatten_nested_dict
+from moge.test.metrics import compute_metrics
+
+
+@click.command()
+@click.option('--config', 'config_path', type=str, default='configs/debug.json')
+@click.option('--workspace', type=str, default='workspace/debug', help='Path to the workspace')
+@click.option('--checkpoint', 'checkpoint_path', type=str, default=None, help='Path to the checkpoint to load. "latest" to load latest checkpoint in workspace, integer to load by step number')
+@click.option('--batch_size_forward', type=int, default=8, help='Batch size for each forward pass on each device')
+@click.option('--gradient_accumulation_steps', type=int, default=1, help='Number of steps to accumulate gradients')
+@click.option('--enable_gradient_checkpointing', type=bool, default=True, help='Use gradient checkpointing in backbone')
+@click.option('--enable_mixed_precision', type=bool, default=False, help='Use mixed precision training. Backbone is converted to FP16')
+@click.option('--enable_ema', type=bool, default=True, help='Maintain an exponential moving average of the model weights')
+@click.option('--num_iterations', type=int, default=1000000, help='Number of iterations to train the model')
+@click.option('--save_every', type=int, default=10000, help='Save checkpoint every n iterations')
+@click.option('--log_every', type=int, default=1000, help='Log metrics every n iterations')
+@click.option('--vis_every', type=int, default=0, help='Visualize every n iterations')
+@click.option('--num_vis_images', type=int, default=32, help='Number of images to visualize, must be a multiple of divided batch size')
+@click.option('--enable_mlflow', type=bool, default=True, help='Log metrics to MLFlow')
+@click.option('--seed', type=int, default=0, help='Random seed')
+def main(
+ config_path: str,
+ workspace: str,
+ checkpoint_path: str,
+ batch_size_forward: int,
+ gradient_accumulation_steps: int,
+ enable_gradient_checkpointing: bool,
+ enable_mixed_precision: bool,
+ enable_ema: bool,
+ num_iterations: int,
+ save_every: int,
+ log_every: int,
+ vis_every: int,
+ num_vis_images: int,
+ enable_mlflow: bool,
+ seed: Optional[int],
+):
+ # Load config
+ with open(config_path, 'r') as f:
+ config = json.load(f)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ mixed_precision='fp16' if enable_mixed_precision else None,
+ kwargs_handlers=[
+ DistributedDataParallelKwargs(find_unused_parameters=True)
+ ]
+ )
+ device = accelerator.device
+ batch_size_total = batch_size_forward * gradient_accumulation_steps * accelerator.num_processes
+
+ # Log config
+ if accelerator.is_main_process:
+ if enable_mlflow:
+ try:
+ mlflow.log_params({
+ **click.get_current_context().params,
+ 'batch_size_total': batch_size_total,
+ })
+ except:
+ print('Failed to log config to MLFlow')
+ Path(workspace).mkdir(parents=True, exist_ok=True)
+ with Path(workspace).joinpath('config.json').open('w') as f:
+ json.dump(config, f, indent=4)
+
+ # Set seed
+ if seed is not None:
+ set_seed(seed, device_specific=True)
+
+ # Initialize model
+ print('Initialize model')
+ with accelerator.local_main_process_first():
+ from moge.model import import_model_class_by_version
+ MoGeModel = import_model_class_by_version(config['model_version'])
+ model = MoGeModel(**config['model'])
+ count_total_parameters = sum(p.numel() for p in model.parameters())
+ print(f'Total parameters: {count_total_parameters}')
+
+ # Set up EMA model
+ if enable_ema and accelerator.is_main_process:
+ ema_avg_fn = lambda averaged_model_parameter, model_parameter, num_averaged: 0.999 * averaged_model_parameter + 0.001 * model_parameter
+ ema_model = torch.optim.swa_utils.AveragedModel(model, device=accelerator.device, avg_fn=ema_avg_fn)
+
+ # Set gradient checkpointing
+ if enable_gradient_checkpointing:
+ model.enable_gradient_checkpointing()
+ import warnings
+ warnings.filterwarnings("ignore", category=FutureWarning, module="torch.utils.checkpoint")
+
+ # Initalize optimizer & lr scheduler
+ optimizer = build_optimizer(model, config['optimizer'])
+ lr_scheduler = build_lr_scheduler(optimizer, config['lr_scheduler'])
+
+ count_grouped_parameters = [sum(p.numel() for p in param_group['params'] if p.requires_grad) for param_group in optimizer.param_groups]
+ for i, count in enumerate(count_grouped_parameters):
+ print(f'- Group {i}: {count} parameters')
+
+ # Attempt to load checkpoint
+ checkpoint: Dict[str, Any]
+ with accelerator.local_main_process_first():
+ if checkpoint_path is None:
+ # - No checkpoint
+ checkpoint = None
+ elif checkpoint_path.endswith('.pt'):
+ # - Load specific checkpoint file
+ print(f'Load checkpoint: {checkpoint_path}')
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
+ elif checkpoint_path == "latest":
+ # - Load latest checkpoint
+ checkpoint_path = Path(workspace, 'checkpoint', 'latest.pt')
+ if checkpoint_path.exists():
+ print(f'Load checkpoint: {checkpoint_path}')
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
+ i_step = checkpoint['step']
+ if 'model' not in checkpoint and (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists():
+ print(f'Load model checkpoint: {checkpoint_model_path}')
+ checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model']
+ if 'optimizer' not in checkpoint and (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists():
+ print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}')
+ checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True))
+ if enable_ema and accelerator.is_main_process:
+ if 'ema_model' not in checkpoint and (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists():
+ print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}')
+ checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model']
+ else:
+ print(f'No latest checkpoint found. Start from scratch.')
+ checkpoint = None
+ else:
+ # - Load by step number
+ i_step = int(checkpoint_path)
+ checkpoint = {'step': i_step}
+ if (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists():
+ print(f'Load model checkpoint: {checkpoint_model_path}')
+ checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model']
+ if (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists():
+ print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}')
+ checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True))
+ if enable_ema and accelerator.is_main_process:
+ if (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists():
+ print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}')
+ checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model']
+
+ if checkpoint is None:
+ # Initialize model weights
+ print('Initialize model weights')
+ with accelerator.local_main_process_first():
+ model.init_weights()
+ initial_step = 0
+ else:
+ model.load_state_dict(checkpoint['model'], strict=False)
+ if 'step' in checkpoint:
+ initial_step = checkpoint['step'] + 1
+ else:
+ initial_step = 0
+ if 'optimizer' in checkpoint:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ if enable_ema and accelerator.is_main_process and 'ema_model' in checkpoint:
+ ema_model.module.load_state_dict(checkpoint['ema_model'], strict=False)
+ if 'lr_scheduler' in checkpoint:
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+
+ del checkpoint
+
+ model, optimizer = accelerator.prepare(model, optimizer)
+ if torch.version.hip and isinstance(model, torch.nn.parallel.DistributedDataParallel):
+ # Hacking potential gradient synchronization issue in ROCm backend
+ from moge.model.utils import sync_ddp_hook
+ model.register_comm_hook(None, sync_ddp_hook)
+
+ # Initialize training data pipeline
+ with accelerator.local_main_process_first():
+ train_data_pipe = TrainDataLoaderPipeline(config['data'], batch_size_forward)
+
+ def _write_bytes_retry_loop(save_path: Path, data: bytes):
+ while True:
+ try:
+ save_path.write_bytes(data)
+ break
+ except Exception as e:
+ print('Error while saving checkpoint, retrying in 1 minute: ', e)
+ time.sleep(60)
+
+ # Ready to train
+ records = []
+ model.train()
+ with (
+ train_data_pipe,
+ tqdm(initial=initial_step, total=num_iterations, desc='Training', disable=not accelerator.is_main_process) as pbar,
+ ThreadPoolExecutor(max_workers=1) as save_checkpoint_executor,
+ ):
+ # Get some batches for visualization
+ if accelerator.is_main_process:
+ batches_for_vis: List[Dict[str, torch.Tensor]] = []
+ num_vis_images = num_vis_images // batch_size_forward * batch_size_forward
+ for _ in range(num_vis_images // batch_size_forward):
+ batch = train_data_pipe.get()
+ batches_for_vis.append(batch)
+
+ # Visualize GT
+ if vis_every > 0 and accelerator.is_main_process and initial_step == 0:
+ save_dir = Path(workspace).joinpath('vis/gt')
+ for i_batch, batch in enumerate(tqdm(batches_for_vis, desc='Visualize GT', leave=False)):
+ image, gt_depth, gt_normal, gt_intrinsics, info = batch['image'], batch['depth'], batch['normal'], batch['intrinsics'], batch['info']
+ gt_points = utils3d.pt.depth_map_to_point_map(gt_depth, intrinsics=gt_intrinsics)
+ for i_instance in range(batch['image'].shape[0]):
+ idx = i_batch * batch_size_forward + i_instance
+ image_i = (image[i_instance].numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
+ gt_depth_i = gt_depth[i_instance].numpy()
+ gt_points_i = gt_points[i_instance].numpy()
+ gt_normal_i = gt_normal[i_instance].numpy()
+ save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True)
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image_i, cv2.COLOR_RGB2BGR))
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(gt_points_i, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(gt_depth_i), cv2.COLOR_RGB2BGR))
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/normal.png')), cv2.cvtColor(colorize_normal(gt_normal_i), cv2.COLOR_RGB2BGR))
+ with save_dir.joinpath(f'{idx:04d}/info.json').open('w') as f:
+ json.dump(info[i_instance], f)
+
+ # Reset seed to avoid training on the same data when resuming training
+ if seed is not None:
+ set_seed(seed + initial_step, device_specific=True)
+
+ # Training loop
+ for i_step in range(initial_step, num_iterations):
+
+ i_accumulate, weight_accumulate = 0, 0
+ while i_accumulate < gradient_accumulation_steps:
+ # Load batch
+ batch = train_data_pipe.get()
+ image, gt_depth, gt_normal, gt_mask_fin, gt_mask_inf, gt_intrinsics, label_type, is_metric = batch['image'], batch['depth'], batch['normal'], batch['depth_mask_fin'], batch['depth_mask_inf'], batch['intrinsics'], batch['label_type'], batch['is_metric']
+ image, gt_depth, gt_normal, gt_mask_fin, gt_mask_inf, gt_intrinsics = image.to(device), gt_depth.to(device), gt_normal.to(device), gt_mask_fin.to(device), gt_mask_inf.to(device), gt_intrinsics.to(device)
+ current_batch_size = image.shape[0]
+ if all(label == 'invalid' for label in label_type):
+ continue # NOTE: Skip all-invalid batches to avoid messing up the optimizer.
+
+ gt_points = utils3d.pt.depth_map_to_point_map(gt_depth, intrinsics=gt_intrinsics)
+ gt_focal = 1 / (1 / gt_intrinsics[..., 0, 0] ** 2 + 1 / gt_intrinsics[..., 1, 1] ** 2) ** 0.5
+
+ with accelerator.accumulate(model):
+ # Forward
+ if i_step <= config.get('low_resolution_training_steps', 0):
+ num_tokens = config['model']['num_tokens_range'][0]
+ else:
+ num_tokens = accelerate.utils.broadcast_object_list([random.randint(*config['model']['num_tokens_range'])])[0]
+ with torch.autocast(device_type=accelerator.device.type, dtype=torch.float16, enabled=enable_mixed_precision):
+ output = model(image, num_tokens=num_tokens)
+ pred_points, pred_mask, pred_normal, pred_metric_scale = (output.get(k, None) for k in ['points', 'mask', 'normal', 'metric_scale'])
+
+ # Compute loss (per instance)
+ loss_list, weight_list = [], []
+ for i in range(current_batch_size):
+ gt_metric_scale = None
+ loss_dict, weight_dict, misc_dict = {}, {}, {}
+ misc_dict['monitoring'] = monitoring(pred_points[i])
+ for k, v in config['loss'][label_type[i]].items():
+ weight_dict[k] = v['weight']
+ if v['function'] == 'affine_invariant_global_loss':
+ loss_dict[k], misc_dict[k], gt_metric_scale = affine_invariant_global_loss(pred_points[i], gt_points[i], **v['params'])
+ elif v['function'] == 'affine_invariant_local_loss':
+ loss_dict[k], misc_dict[k] = affine_invariant_local_loss(pred_points[i], gt_points[i], gt_focal[i], gt_metric_scale, **v['params'])
+ elif v['function'] == 'normal_loss':
+ loss_dict[k], misc_dict[k] = normal_loss(pred_points[i], gt_points[i])
+ elif v['function'] == 'edge_loss':
+ loss_dict[k], misc_dict[k] = edge_loss(pred_points[i], gt_points[i])
+ elif v['function'] == 'normal_map_loss':
+ loss_dict[k], misc_dict[k] = normal_map_loss(pred_normal[i], gt_normal[i])
+ elif v['function'] == 'mask_bce_loss':
+ loss_dict[k], misc_dict[k] = mask_bce_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i])
+ elif v['function'] == 'mask_l2_loss':
+ loss_dict[k], misc_dict[k] = mask_l2_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i])
+ elif v['function'] == 'metric_scale_loss':
+ if is_metric[i] and pred_metric_scale is not None:
+ loss_dict[k], misc_dict[k] = metric_scale_loss(pred_metric_scale[i], gt_metric_scale)
+ else:
+ raise ValueError(f'Undefined loss function: {v["function"]}')
+ weight_dict = {'.'.join(k): v for k, v in flatten_nested_dict(weight_dict).items()}
+ loss_dict = {'.'.join(k): v for k, v in flatten_nested_dict(loss_dict).items()}
+ loss_ = sum([weight_dict[k] * loss_dict[k] for k in loss_dict], start=torch.tensor(0.0, device=device))
+ loss_list.append(loss_)
+
+ if torch.isnan(loss_).item():
+ pbar.write(f'NaN loss in process {accelerator.process_index}')
+ pbar.write(str(loss_dict))
+
+ misc_dict = {'.'.join(k): v for k, v in flatten_nested_dict(misc_dict).items()}
+ records.append({
+ **{k: v.item() for k, v in loss_dict.items()},
+ **misc_dict,
+ })
+
+ loss = sum(loss_list) / len(loss_list)
+
+ # Backward & update
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ if not enable_mixed_precision and any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None):
+ if accelerator.is_main_process:
+ pbar.write(f'NaN gradients, skip update')
+ optimizer.zero_grad()
+ continue
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ i_accumulate += 1
+
+ lr_scheduler.step()
+
+ # EMA update
+ if enable_ema and accelerator.is_main_process and accelerator.sync_gradients:
+ ema_model.update_parameters(model)
+
+ # Log metrics
+ if i_step == initial_step or i_step % log_every == 0:
+ records = [key_average(records)]
+ records = accelerator.gather_for_metrics(records, use_gather_object=True)
+ if accelerator.is_main_process:
+ records = key_average(records)
+ if enable_mlflow:
+ try:
+ mlflow.log_metrics(records, step=i_step)
+ except Exception as e:
+ print(f'Error while logging metrics to mlflow: {e}')
+ records = []
+
+ # Save model weight checkpoint
+ if accelerator.is_main_process and (i_step % save_every == 0):
+ # NOTE: Writing checkpoint is done in a separate thread to avoid blocking the main process
+ pbar.write(f'Save checkpoint: {i_step:08d}')
+ Path(workspace, 'checkpoint').mkdir(parents=True, exist_ok=True)
+
+ # Model checkpoint
+ with io.BytesIO() as f:
+ torch.save({
+ 'model_config': config['model'],
+ 'model': accelerator.unwrap_model(model).state_dict(),
+ }, f)
+ checkpoint_bytes = f.getvalue()
+ save_checkpoint_executor.submit(
+ _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}.pt'), checkpoint_bytes
+ )
+
+ # Optimizer checkpoint
+ with io.BytesIO() as f:
+ torch.save({
+ 'model_config': config['model'],
+ 'step': i_step,
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ }, f)
+ checkpoint_bytes = f.getvalue()
+ save_checkpoint_executor.submit(
+ _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt'), checkpoint_bytes
+ )
+
+ # EMA model checkpoint
+ if enable_ema:
+ with io.BytesIO() as f:
+ torch.save({
+ 'model_config': config['model'],
+ 'model': ema_model.module.state_dict(),
+ }, f)
+ checkpoint_bytes = f.getvalue()
+ save_checkpoint_executor.submit(
+ _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt'), checkpoint_bytes
+ )
+
+ # Latest checkpoint
+ with io.BytesIO() as f:
+ torch.save({
+ 'model_config': config['model'],
+ 'step': i_step,
+ }, f)
+ checkpoint_bytes = f.getvalue()
+ save_checkpoint_executor.submit(
+ _write_bytes_retry_loop, Path(workspace, 'checkpoint', 'latest.pt'), checkpoint_bytes
+ )
+
+ # Visualize
+ if vis_every > 0 and accelerator.is_main_process and (i_step == initial_step or i_step % vis_every == 0):
+ unwrapped_model = accelerator.unwrap_model(model)
+ save_dir = Path(workspace).joinpath(f'vis/step_{i_step:08d}')
+ save_dir.mkdir(parents=True, exist_ok=True)
+ with torch.inference_mode():
+ for i_batch, batch in enumerate(tqdm(batches_for_vis, desc=f'Visualize: {i_step:08d}', leave=False)):
+ image, gt_depth, gt_intrinsics = batch['image'], batch['depth'], batch['intrinsics']
+ image, gt_depth, gt_intrinsics = image.to(device), gt_depth.to(device), gt_intrinsics.to(device)
+
+ output = unwrapped_model.infer(image)
+ pred_points = output['points'].cpu().numpy() if 'points' in output else None
+ pred_depth = output['depth'].cpu().numpy() if 'depth' in output else None
+ pred_mask = output['mask'].cpu().numpy() if 'mask' in output else None
+ pred_normal = output['normal'].cpu().numpy() if 'normal' in output else None
+ pred_uncertainty = output['uncertainty'].cpu().numpy() if 'uncertainty' in output else None
+ image = (image.cpu().numpy().transpose(0, 2, 3, 1) * 255).astype(np.uint8)
+
+ for i_instance in range(image.shape[0]):
+ idx = i_batch * batch_size_forward + i_instance
+ save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True)
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image[i_instance], cv2.COLOR_RGB2BGR))
+ if pred_points is not None:
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(pred_points[i_instance], cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])
+ if pred_mask is not None:
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask.png')), pred_mask[i_instance] * 255)
+ if pred_depth is not None:
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(pred_depth[i_instance], pred_mask[i_instance] if pred_mask is not None else None), cv2.COLOR_RGB2BGR))
+ if pred_normal is not None:
+ cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/normal_vis.png')), cv2.cvtColor(colorize_normal(pred_normal[i_instance], pred_mask[i_instance] if pred_mask is not None else None), cv2.COLOR_RGB2BGR))
+
+ pbar.set_postfix({'loss': loss.item()}, refresh=False)
+ pbar.update(1)
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/vis_data.py b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/vis_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcca724b194cc534483304d5e9b09c53d21868e4
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/scripts/vis_data.py
@@ -0,0 +1,84 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+import sys
+from pathlib import Path
+if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path:
+ sys.path.insert(0, _package_root)
+
+import click
+
+
+@click.command()
+@click.argument('folder_or_path', type=click.Path(exists=True))
+@click.option('--output', '-o', 'output_folder', type=click.Path(), help='Path to output folder')
+@click.option('--max_depth', '-m', type=float, default=float('inf'), help='max depth')
+@click.option('--fov', type=float, default=None, help='field of view in degrees')
+@click.option('--show', 'show', is_flag=True, help='show point cloud')
+@click.option('--depth', 'depth_filename', type=str, default='depth.png', help='depth image file name')
+@click.option('--ply', 'save_ply', is_flag=True, help='save point cloud as PLY file')
+@click.option('--depth_vis', 'save_depth_vis', is_flag=True, help='save depth image')
+@click.option('--inf', 'inf_mask', is_flag=True, help='use infinity mask')
+@click.option('--version', 'version', type=str, default='v3', help='version of rgbd data')
+def main(
+ folder_or_path: str,
+ output_folder: str,
+ max_depth: float,
+ fov: float,
+ depth_filename: str,
+ show: bool,
+ save_ply: bool,
+ save_depth_vis: bool,
+ inf_mask: bool,
+ version: str
+):
+ # Lazy import
+ import cv2
+ import numpy as np
+ import utils3d
+ from tqdm import tqdm
+ import trimesh
+
+ from moge.utils.io import read_image, read_depth, read_json
+ from moge.utils.vis import colorize_depth, colorize_normal
+
+ filepaths = sorted(p.parent for p in Path(folder_or_path).rglob('meta.json'))
+
+ for filepath in tqdm(filepaths):
+ image = read_image(Path(filepath, 'image.jpg'))
+ depth = read_depth(Path(filepath, depth_filename))
+ meta = read_json(Path(filepath,'meta.json'))
+ depth_mask = np.isfinite(depth)
+ depth_mask_inf = (depth == np.inf)
+ intrinsics = np.array(meta['intrinsics'])
+
+ extrinsics = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=float) # OpenGL's identity camera
+ verts = utils3d.np.unproject_cv(utils3d.np.uv_map(image.shape[:2]), depth, extrinsics=extrinsics, intrinsics=intrinsics)
+
+ depth_mask_ply = depth_mask & (depth < depth[depth_mask].min() * max_depth)
+ point_cloud = trimesh.PointCloud(verts[depth_mask_ply], image[depth_mask_ply] / 255)
+
+ if show:
+ point_cloud.show()
+
+ if output_folder is None:
+ output_path = filepath
+ else:
+ output_path = Path(output_folder, filepath.name)
+ output_path.mkdir(exist_ok=True, parents=True)
+
+ if inf_mask:
+ depth = np.where(depth_mask_inf, np.inf, depth)
+ depth_mask = depth_mask | depth_mask_inf
+
+ if save_depth_vis:
+ p = output_path.joinpath('depth_vis.png')
+ cv2.imwrite(str(p), cv2.cvtColor(colorize_depth(depth, depth_mask), cv2.COLOR_RGB2BGR))
+ print(f"{p}")
+
+ if save_ply:
+ p = output_path.joinpath('pointcloud.ply')
+ point_cloud.export(p)
+ print(f"{p}")
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/test/__init__.py b/lingbotvla/models/vla/vision_models/MoGe/moge/test/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/test/baseline.py b/lingbotvla/models/vla/vision_models/MoGe/moge/test/baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..05980aaf96870304534fcec6532225e870351a66
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/test/baseline.py
@@ -0,0 +1,43 @@
+from typing import *
+
+import click
+import torch
+
+
+class MGEBaselineInterface:
+ """
+ Abstract class for model wrapper to uniformize the interface of loading and inference across different models.
+ """
+ device: torch.device
+
+ @click.command()
+ @staticmethod
+ def load(*args, **kwargs) -> "MGEBaselineInterface":
+ """
+ Customized static method to create an instance of the model wrapper from command line arguments. Decorated by `click.command()`
+ """
+ raise NotImplementedError(f"{type(self).__name__} has not implemented the load method.")
+
+ def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+ """
+ ### Parameters
+ `image`: [B, 3, H, W] or [3, H, W], RGB values in range [0, 1]
+ `intrinsics`: [B, 3, 3] or [3, 3], camera intrinsics. Optional.
+
+ ### Returns
+ A dictionary containing:
+ - `points_*`. point map output in OpenCV identity camera space.
+ Supported suffixes: `metric`, `scale_invariant`, `affine_invariant`.
+ - `depth_*`. depth map output
+ Supported suffixes: `metric` (in meters), `scale_invariant`, `affine_invariant`.
+ - `disparity_affine_invariant`. affine disparity map output
+ """
+ raise NotImplementedError(f"{type(self).__name__} has not implemented the infer method.")
+
+ def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
+ """
+ If the model has a special evaluation mode, override this method to provide the evaluation mode inference.
+
+ By default, this method simply calls `infer()`.
+ """
+ return self.infer(image, intrinsics)
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/test/dataloader.py b/lingbotvla/models/vla/vision_models/MoGe/moge/test/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..97a9298d887b11f6977cefcfcf54d74e2c1bf34c
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/test/dataloader.py
@@ -0,0 +1,221 @@
+import os
+from typing import *
+from pathlib import Path
+import math
+
+import numpy as np
+import torch
+from PIL import Image
+import cv2
+import utils3d
+import pipeline
+
+from ..utils.geometry_numpy import focal_to_fov_numpy, norm3d
+from ..utils.io import *
+from ..utils.tools import timeit
+
+
+class EvalDataLoaderPipeline:
+
+ def __init__(
+ self,
+ path: str,
+ width: int,
+ height: int,
+ split: int = '.index.txt',
+ drop_max_depth: float = 1000.,
+ num_load_workers: int = 4,
+ num_process_workers: int = 8,
+ include_segmentation: bool = False,
+ include_normal: bool = False,
+ depth_to_normal: bool = False,
+ max_segments: int = 100,
+ min_seg_area: int = 1000,
+ depth_unit: str = None,
+ has_sharp_boundary = False,
+ subset: int = None,
+ ):
+ filenames = Path(path).joinpath(split).read_text(encoding='utf-8').splitlines()
+ filenames = filenames[::subset]
+ self.width = width
+ self.height = height
+ self.drop_max_depth = drop_max_depth
+ self.path = Path(path)
+ self.filenames = filenames
+ self.include_segmentation = include_segmentation
+ self.include_normal = include_normal
+ self.max_segments = max_segments
+ self.min_seg_area = min_seg_area
+ self.depth_to_normal = depth_to_normal
+ self.depth_unit = depth_unit
+ self.has_sharp_boundary = has_sharp_boundary
+
+ self.rng = np.random.default_rng(seed=0)
+
+ self.pipeline = pipeline.Sequential([
+ self._generator,
+ pipeline.Parallel([self._load_instance] * num_load_workers),
+ pipeline.Parallel([self._process_instance] * num_process_workers),
+ pipeline.Buffer(4)
+ ])
+
+ def __len__(self):
+ return math.ceil(len(self.filenames))
+
+ def _generator(self):
+ for idx in range(len(self)):
+ yield idx
+
+ def _load_instance(self, idx):
+ if idx >= len(self.filenames):
+ return None
+
+ path = self.path.joinpath(self.filenames[idx])
+
+ instance = {
+ 'filename': self.filenames[idx],
+ 'width': self.width,
+ 'height': self.height,
+ }
+ instance['image'] = read_image(Path(path, 'image.jpg'))
+
+ depth = read_depth(Path(path, 'depth.png')) # ignore depth unit from depth file, use config instead
+ instance.update({
+ 'depth': np.nan_to_num(depth, nan=1, posinf=1, neginf=1),
+ 'depth_mask': np.isfinite(depth),
+ 'depth_mask_inf': np.isinf(depth),
+ })
+
+ if self.include_segmentation:
+ segmentation_mask, segmentation_labels = read_segmentation(Path(path,'segmentation.png'))
+ instance.update({
+ 'segmentation_mask': segmentation_mask,
+ 'segmentation_labels': segmentation_labels,
+ })
+
+ meta = read_meta(Path(path, 'meta.json'))
+ instance['intrinsics'] = np.array(meta['intrinsics'], dtype=np.float32)
+
+ return instance
+
+ def _process_instance(self, instance: dict):
+ if instance is None:
+ return None
+
+ image, depth, depth_mask, intrinsics = instance['image'], instance['depth'], instance['depth_mask'], instance['intrinsics']
+ segmentation_mask, segmentation_labels = instance.get('segmentation_mask', None), instance.get('segmentation_labels', None)
+
+ raw_height, raw_width = image.shape[:2]
+ raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
+ raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height
+ tgt_width, tgt_height = instance['width'], instance['height']
+ tgt_aspect = tgt_width / tgt_height
+
+ # set expected target view field
+ tgt_horizontal = min(raw_horizontal, raw_vertical * tgt_aspect)
+ tgt_vertical = tgt_horizontal / tgt_aspect
+
+ # set target view direction
+ cu, cv = 0.5, 0.5
+ direction = utils3d.np.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0]
+ R = utils3d.np.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32))
+
+ # restrict target view field within the raw view
+ corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32)
+ corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane
+ corners = corners[:, :2] / corners[:, 2:3]
+
+ warp_horizontal, warp_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
+ for i in range(4):
+ intersection, _ = utils3d.np.ray_intersection(
+ np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]),
+ corners[i - 1], corners[i] - corners[i - 1],
+ )
+ warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min())
+ tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical)
+
+ # get target view intrinsics
+ fx, fy = 1.0 / tgt_horizontal, 1.0 / tgt_vertical
+ tgt_intrinsics = utils3d.np.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32)
+
+ # do homogeneous transformation with the rotation and intrinsics
+ # 4.1 The image and depth is resized first to approximately the same pixel size as the target image with PIL's antialiasing resampling
+ tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes)
+ rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h)
+ image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS))
+
+ depth, depth_mask = utils3d.np.masked_nearest_resize(depth, mask=depth_mask, size=(rescaled_h, rescaled_w))
+ distance = norm3d(utils3d.np.depth_map_to_point_map(depth, intrinsics=intrinsics))
+ segmentation_mask = cv2.resize(segmentation_mask, (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) if segmentation_mask is not None else None
+
+ # 4.2 calculate homography warping
+ transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics)
+ uv_tgt = utils3d.np.uv_map(tgt_height, tgt_width)
+ pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T
+ uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12)
+ pixel_remap = utils3d.np.uv_to_pixel(uv_remap, (rescaled_h, rescaled_w)).astype(np.float32)
+
+ tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR)
+ tgt_distance = cv2.remap(distance, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST)
+ tgt_ray_length = utils3d.np.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics)
+ tgt_ray_length = (tgt_ray_length[:, :, 0] ** 2 + tgt_ray_length[:, :, 1] ** 2 + tgt_ray_length[:, :, 2] ** 2) ** 0.5
+ tgt_depth = tgt_distance / (tgt_ray_length + 1e-12)
+ tgt_depth_mask = cv2.remap(depth_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
+ tgt_segmentation_mask = cv2.remap(segmentation_mask, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) if segmentation_mask is not None else None
+
+ # drop depth greater than drop_max_depth
+ max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.drop_max_depth
+ tgt_depth_mask &= tgt_depth <= max_depth
+ tgt_depth = np.nan_to_num(tgt_depth, nan=0.0)
+
+ if self.depth_unit is not None:
+ tgt_depth *= self.depth_unit
+
+ if not np.any(tgt_depth_mask):
+ # always make sure that mask is not empty, otherwise the loss calculation will crash
+ tgt_depth_mask = np.ones_like(tgt_depth_mask)
+ tgt_depth = np.ones_like(tgt_depth)
+ instance['label_type'] = 'invalid'
+
+ tgt_pts = utils3d.np.unproject_cv(uv_tgt, tgt_depth, intrinsics=tgt_intrinsics)
+
+ # Process segmentation labels
+ if self.include_segmentation and segmentation_mask is not None:
+ for k in ['undefined', 'unannotated', 'background', 'sky']:
+ if k in segmentation_labels:
+ del segmentation_labels[k]
+ seg_id2count = dict(zip(*np.unique(tgt_segmentation_mask, return_counts=True)))
+ sorted_labels = sorted(segmentation_labels.keys(), key=lambda x: seg_id2count.get(segmentation_labels[x], 0), reverse=True)
+ segmentation_labels = {k: segmentation_labels[k] for k in sorted_labels[:self.max_segments] if seg_id2count.get(segmentation_labels[k], 0) >= self.min_seg_area}
+
+ instance.update({
+ 'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1),
+ 'depth': torch.from_numpy(tgt_depth).float(),
+ 'depth_mask': torch.from_numpy(tgt_depth_mask).bool(),
+ 'intrinsics': torch.from_numpy(tgt_intrinsics).float(),
+ 'points': torch.from_numpy(tgt_pts).float(),
+ 'segmentation_mask': torch.from_numpy(tgt_segmentation_mask).long() if tgt_segmentation_mask is not None else None,
+ 'segmentation_labels': segmentation_labels,
+ 'is_metric': self.depth_unit is not None,
+ 'has_sharp_boundary': self.has_sharp_boundary,
+ })
+
+ instance = {k: v for k, v in instance.items() if v is not None}
+
+ return instance
+
+ def start(self):
+ self.pipeline.start()
+
+ def stop(self):
+ self.pipeline.stop()
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.stop()
+
+ def get(self):
+ return self.pipeline.get()
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/test/metrics.py b/lingbotvla/models/vla/vision_models/MoGe/moge/test/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c79c3378199602734e345a999228da2607fcd9b
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/test/metrics.py
@@ -0,0 +1,342 @@
+from typing import *
+from numbers import Number
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import utils3d
+
+from ..utils.geometry_torch import (
+ weighted_mean,
+ intrinsics_to_fov
+)
+from ..utils.alignment import (
+ align_points_scale_z_shift,
+ align_points_scale_xyz_shift,
+ align_points_xyz_shift,
+ align_affine_lstsq,
+ align_depth_scale,
+ align_depth_affine,
+ align_points_scale,
+)
+from ..utils.tools import key_average, timeit
+
+
+def rel_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
+ rel = (torch.abs(pred - gt) / (gt + eps)).mean()
+ return rel.item()
+
+
+def delta1_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
+ delta1 = (torch.maximum(gt / pred, pred / gt) < 1.25).float().mean()
+ return delta1.item()
+
+
+def rel_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
+ dist_gt = torch.norm(gt, dim=-1)
+ dist_err = torch.norm(pred - gt, dim=-1)
+ rel = (dist_err / (dist_gt + eps)).mean()
+ return rel.item()
+
+
+def delta1_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6):
+ dist_pred = torch.norm(pred, dim=-1)
+ dist_gt = torch.norm(gt, dim=-1)
+ dist_err = torch.norm(pred - gt, dim=-1)
+
+ delta1 = (dist_err < 0.25 * torch.minimum(dist_gt, dist_pred)).float().mean()
+ return delta1.item()
+
+
+def rel_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor):
+ dist_err = torch.norm(pred - gt, dim=-1)
+ rel = (dist_err / diameter).mean()
+ return rel.item()
+
+
+def delta1_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor):
+ dist_err = torch.norm(pred - gt, dim=-1)
+ delta1 = (dist_err < 0.25 * diameter).float().mean()
+ return delta1.item()
+
+
+def boundary_f1(pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor, radius: int = 1):
+ neighbor_x, neight_y = torch.meshgrid(
+ torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device),
+ torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device),
+ indexing='xy'
+ )
+ neighbor_mask = (neighbor_x ** 2 + neight_y ** 2) <= radius ** 2 + 1e-5
+
+ pred_window = utils3d.pt.sliding_window_2d(pred, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
+ gt_window = utils3d.pt.sliding_window_2d(gt, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
+ mask_window = neighbor_mask & utils3d.pt.sliding_window_2d(mask, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1]
+
+ pred_rel = pred_window / pred[radius:-radius, radius:-radius, None, None]
+ gt_rel = gt_window / gt[radius:-radius, radius:-radius, None, None]
+ valid = mask[radius:-radius, radius:-radius, None, None] & mask_window
+
+ f1_list = []
+ w_list = t_list = torch.linspace(0.05, 0.25, 10).tolist()
+
+ for t in t_list:
+ pred_label = pred_rel > 1 + t
+ gt_label = gt_rel > 1 + t
+ TP = (pred_label & gt_label & valid).float().sum()
+ precision = TP / (gt_label & valid).float().sum().clamp_min(1e-12)
+ recall = TP / (pred_label & valid).float().sum().clamp_min(1e-12)
+ f1 = 2 * precision * recall / (precision + recall).clamp_min(1e-12)
+ f1_list.append(f1.item())
+
+ f1_avg = sum(w * f1 for w, f1 in zip(w_list, f1_list)) / sum(w_list)
+ return f1_avg
+
+
+def compute_metrics(
+ pred: Dict[str, torch.Tensor],
+ gt: Dict[str, torch.Tensor],
+ vis: bool = False
+) -> Tuple[Dict[str, Dict[str, Number]], Dict[str, torch.Tensor]]:
+ """
+ A unified function to compute metrics for different types of predictions and ground truths.
+
+ #### Supported keys in pred:
+ - `disparity_affine_invariant`: disparity map predicted by a depth estimator with scale and shift invariant.
+ - `depth_scale_invariant`: depth map predicted by a depth estimator with scale invariant.
+ - `depth_affine_invariant`: depth map predicted by a depth estimator with scale and shift invariant.
+ - `depth_metric`: depth map predicted by a depth estimator with no scale or shift.
+ - `points_scale_invariant`: point map predicted by a point estimator with scale invariant.
+ - `points_affine_invariant`: point map predicted by a point estimator with scale and xyz shift invariant.
+ - `points_metric`: point map predicted by a point estimator with no scale or shift.
+ - `intrinsics`: normalized camera intrinsics matrix.
+
+ #### Required keys in gt:
+ - `depth`: depth map ground truth (in metric units if `depth_metric` is used)
+ - `points`: point map ground truth in camera coordinates.
+ - `mask`: mask indicating valid pixels in the ground truth.
+ - `intrinsics`: normalized ground-truth camera intrinsics matrix.
+ - `is_metric`: whether the depth is in metric units.
+ """
+ metrics = {}
+ misc = {}
+
+ mask = gt['depth_mask']
+ gt_depth = gt['depth']
+ gt_points = gt['points']
+
+ height, width = mask.shape[-2:]
+ lr_mask, lr_index = utils3d.pt.masked_nearest_resize(mask=mask, size=(64, 64), return_index=True)
+
+ only_depth = not any('point' in k for k in pred)
+ pred_depth_aligned, pred_points_aligned = None, None
+
+ # Metric depth
+ if 'depth_metric' in pred and gt['is_metric']:
+ pred_depth, gt_depth = pred['depth_metric'], gt['depth']
+ metrics['depth_metric'] = {
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
+ }
+
+ if pred_depth_aligned is None:
+ pred_depth_aligned = pred_depth
+
+ # Scale-invariant depth
+ if 'depth_scale_invariant' in pred:
+ pred_depth_scale_invariant = pred['depth_scale_invariant']
+ elif 'depth_metric' in pred:
+ pred_depth_scale_invariant = pred['depth_metric']
+ else:
+ pred_depth_scale_invariant = None
+
+ if pred_depth_scale_invariant is not None:
+ pred_depth = pred_depth_scale_invariant
+
+ pred_depth_lr_masked, gt_depth_lr_masked = pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask]
+ scale = align_depth_scale(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked)
+ pred_depth = pred_depth * scale
+
+ metrics['depth_scale_invariant'] = {
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
+ }
+
+ if pred_depth_aligned is None:
+ pred_depth_aligned = pred_depth
+
+ # Affine-invariant depth
+ if 'depth_affine_invariant' in pred:
+ pred_depth_affine_invariant = pred['depth_affine_invariant']
+ elif 'depth_scale_invariant' in pred:
+ pred_depth_affine_invariant = pred['depth_scale_invariant']
+ elif 'depth_metric' in pred:
+ pred_depth_affine_invariant = pred['depth_metric']
+ else:
+ pred_depth_affine_invariant = None
+
+ if pred_depth_affine_invariant is not None:
+ pred_depth = pred_depth_affine_invariant
+
+ pred_depth_lr_masked, gt_depth_lr_masked = pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask]
+ scale, shift = align_depth_affine(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked)
+ pred_depth = pred_depth * scale + shift
+
+ metrics['depth_affine_invariant'] = {
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
+ }
+
+ if pred_depth_aligned is None:
+ pred_depth_aligned = pred_depth
+
+ # Affine-invariant disparity
+ if 'disparity_affine_invariant' in pred:
+ pred_disparity_affine_invariant = pred['disparity_affine_invariant']
+ elif 'depth_scale_invariant' in pred:
+ pred_disparity_affine_invariant = 1 / pred['depth_scale_invariant']
+ elif 'depth_metric' in pred:
+ pred_disparity_affine_invariant = 1 / pred['depth_metric']
+ else:
+ pred_disparity_affine_invariant = None
+
+ if pred_disparity_affine_invariant is not None:
+ pred_disp = pred_disparity_affine_invariant
+
+ scale, shift = align_affine_lstsq(pred_disp[mask], 1 / gt_depth[mask])
+ pred_disp = pred_disp * scale + shift
+
+ # NOTE: The alignment is done on the disparity map could introduce extreme outliers at disparities close to 0.
+ # Therefore we clamp the disparities by minimum ground truth disparity.
+ pred_depth = 1 / pred_disp.clamp_min(1 / gt_depth[mask].max().item())
+
+ metrics['disparity_affine_invariant'] = {
+ 'rel': rel_depth(pred_depth[mask], gt_depth[mask]),
+ 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask])
+ }
+
+ if pred_depth_aligned is None:
+ pred_depth_aligned = 1 / pred_disp.clamp_min(1e-6)
+
+ # Metric points
+ if 'points_metric' in pred and gt['is_metric']:
+ pred_points = pred['points_metric']
+
+ pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask]
+ shift = align_points_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
+ pred_points = pred_points + shift
+
+ metrics['points_metric'] = {
+ 'rel': rel_point(pred_points[mask], gt_points[mask]),
+ 'delta1': delta1_point(pred_points[mask], gt_points[mask])
+ }
+
+ if pred_points_aligned is None:
+ pred_points_aligned = pred['points_metric']
+
+ # Scale-invariant points (in camera space)
+ if 'points_scale_invariant' in pred:
+ pred_points_scale_invariant = pred['points_scale_invariant']
+ elif 'points_metric' in pred:
+ pred_points_scale_invariant = pred['points_metric']
+ else:
+ pred_points_scale_invariant = None
+
+ if pred_points_scale_invariant is not None:
+ pred_points = pred_points_scale_invariant
+
+ pred_points_lr_masked, gt_points_lr_masked = pred_points_scale_invariant[lr_index][lr_mask], gt_points[lr_index][lr_mask]
+ scale = align_points_scale(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
+ pred_points = pred_points * scale
+
+ metrics['points_scale_invariant'] = {
+ 'rel': rel_point(pred_points[mask], gt_points[mask]),
+ 'delta1': delta1_point(pred_points[mask], gt_points[mask])
+ }
+
+ if vis and pred_points_aligned is None:
+ pred_points_aligned = pred['points_scale_invariant'] * scale
+
+ # Affine-invariant points
+ if 'points_affine_invariant' in pred:
+ pred_points_affine_invariant = pred['points_affine_invariant']
+ elif 'points_scale_invariant' in pred:
+ pred_points_affine_invariant = pred['points_scale_invariant']
+ elif 'points_metric' in pred:
+ pred_points_affine_invariant = pred['points_metric']
+ else:
+ pred_points_affine_invariant = None
+
+ if pred_points_affine_invariant is not None:
+ pred_points = pred_points_affine_invariant
+
+ pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask]
+ scale, shift = align_points_scale_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1))
+ pred_points = pred_points * scale + shift
+
+ metrics['points_affine_invariant'] = {
+ 'rel': rel_point(pred_points[mask], gt_points[mask]),
+ 'delta1': delta1_point(pred_points[mask], gt_points[mask])
+ }
+
+ if vis and pred_points_aligned is None:
+ pred_points_aligned = pred['points_affine_invariant'] * scale + shift
+
+ # Local points
+ if 'segmentation_mask' in gt and 'points' in gt and any('points' in k for k in pred.keys()):
+ pred_points = next(pred[k] for k in pred.keys() if 'points' in k)
+ gt_points = gt['points']
+ segmentation_mask = gt['segmentation_mask']
+ segmentation_labels = gt['segmentation_labels']
+ segmentation_mask_lr = segmentation_mask[lr_index]
+ local_points_metrics = []
+ for _, seg_id in segmentation_labels.items():
+ valid_mask = (segmentation_mask == seg_id) & mask
+
+ pred_points_masked = pred_points[valid_mask]
+ gt_points_masked = gt_points[valid_mask]
+
+ valid_mask_lr = (segmentation_mask_lr == seg_id) & lr_mask
+ if valid_mask_lr.sum().item() < 10:
+ continue
+ pred_points_masked_lr = pred_points[lr_index][valid_mask_lr]
+ gt_points_masked_lr = gt_points[lr_index][valid_mask_lr]
+ diameter = (gt_points_masked.max(dim=0).values - gt_points_masked.min(dim=0).values).max()
+ scale, shift = align_points_scale_xyz_shift(pred_points_masked_lr, gt_points_masked_lr, 1 / diameter.expand(gt_points_masked_lr.shape[0]))
+ pred_points_masked = pred_points_masked * scale + shift
+
+ local_points_metrics.append({
+ 'rel': rel_point_local(pred_points_masked, gt_points_masked, diameter),
+ 'delta1': delta1_point_local(pred_points_masked, gt_points_masked, diameter),
+ })
+
+ metrics['local_points'] = key_average(local_points_metrics)
+
+ # FOV. NOTE: If there is no random augmentation applied to the input images, all GT FOV are generallly the same.
+ # Fair evaluation of FOV requires random augmentation.
+ if 'intrinsics' in pred and 'intrinsics' in gt:
+ pred_intrinsics = pred['intrinsics']
+ gt_intrinsics = gt['intrinsics']
+ pred_fov_x, pred_fov_y = intrinsics_to_fov(pred_intrinsics)
+ gt_fov_x, gt_fov_y = intrinsics_to_fov(gt_intrinsics)
+ metrics['fov_x'] = {
+ 'mae': torch.rad2deg(pred_fov_x - gt_fov_x).abs().mean().item(),
+ 'deviation': torch.rad2deg(pred_fov_x - gt_fov_x).item(),
+ }
+
+ # Boundary F1
+ if pred_depth_aligned is not None and gt['has_sharp_boundary']:
+ metrics['boundary'] = {
+ 'radius1_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=1),
+ 'radius2_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=2),
+ 'radius3_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=3),
+ }
+
+ if vis:
+ if pred_points_aligned is not None:
+ misc['pred_points'] = pred_points_aligned
+ if only_depth:
+ misc['pred_points'] = utils3d.pt.depth_map_to_point_map(pred_depth_aligned, intrinsics=gt['intrinsics'])
+ if pred_depth_aligned is not None:
+ misc['pred_depth'] = pred_depth_aligned
+
+ return metrics, misc
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/train/__init__.py b/lingbotvla/models/vla/vision_models/MoGe/moge/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/train/dataloader.py b/lingbotvla/models/vla/vision_models/MoGe/moge/train/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08846d0f8cc4b74208ca55edc4dbea942b1a4a4
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/train/dataloader.py
@@ -0,0 +1,258 @@
+import os
+from pathlib import Path
+import json
+import time
+import random
+from typing import *
+import traceback
+import itertools
+from numbers import Number
+import io
+
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+import torchvision.transforms.v2.functional as TF
+import utils3d
+import pipeline
+from tqdm import tqdm
+
+from ..utils.io import *
+from ..utils.geometry_numpy import harmonic_mean_numpy, norm3d, depth_occlusion_edge_numpy
+from ..utils.data_augmentation import sample_perspective, warp_perspective, image_color_augmentation
+
+
+class TrainDataLoaderPipeline:
+ def __init__(self, config: dict, batch_size: int, num_load_workers: int = 4, num_process_workers: int = 8, buffer_size: int = 8):
+ self.config = config
+
+ self.batch_size = batch_size
+ self.clamp_max_depth = config['clamp_max_depth']
+ self.fov_range_absolute = config.get('fov_range_absolute', 0.0)
+ self.fov_range_relative = config.get('fov_range_relative', 0.0)
+ self.center_augmentation = config.get('center_augmentation', 0.0)
+ self.image_augmentation = config.get('image_augmentation', [])
+ self.depth_interpolation = config.get('depth_interpolation', 'bilinear')
+
+ if 'image_sizes' in config:
+ self.image_size_strategy = 'fixed'
+ self.image_sizes = config['image_sizes']
+ elif 'aspect_ratio_range' in config and 'area_range' in config:
+ self.image_size_strategy = 'aspect_area'
+ self.aspect_ratio_range = config['aspect_ratio_range']
+ self.area_range = config['area_range']
+ else:
+ raise ValueError('Invalid image size configuration')
+
+ # Load datasets
+ self.datasets = {}
+ for dataset in tqdm(config['datasets'], desc='Loading datasets'):
+ name = dataset['name']
+ content = Path(dataset['path'], dataset.get('index', '.index.txt')).joinpath().read_text()
+ filenames = content.splitlines()
+ self.datasets[name] = {
+ **dataset,
+ 'path': dataset['path'],
+ 'filenames': filenames,
+ }
+ self.dataset_names = [dataset['name'] for dataset in config['datasets']]
+ self.dataset_weights = [dataset['weight'] for dataset in config['datasets']]
+
+ # Build pipeline
+ self.pipeline = pipeline.Sequential([
+ self._sample_batch,
+ pipeline.Unbatch(),
+ pipeline.Parallel([self._load_instance] * num_load_workers),
+ pipeline.Parallel([self._process_instance] * num_process_workers),
+ pipeline.Batch(self.batch_size),
+ self._collate_batch,
+ pipeline.Buffer(buffer_size),
+ ])
+
+ self.invalid_instance = {
+ 'intrinsics': np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]], dtype=np.float32),
+ 'image': np.zeros((256, 256, 3), dtype=np.uint8),
+ 'depth': np.ones((256, 256), dtype=np.float32),
+ 'depth_mask': np.ones((256, 256), dtype=bool),
+ 'depth_mask_inf': np.zeros((256, 256), dtype=bool),
+ 'label_type': 'invalid',
+ }
+
+ def _sample_batch(self):
+ batch_id = 0
+ last_area = None
+ while True:
+ # Depending on the sample strategy, choose a dataset and a filename
+ batch_id += 1
+ batch = []
+
+ # Sample instances
+ for _ in range(self.batch_size):
+ dataset_name = random.choices(self.dataset_names, weights=self.dataset_weights)[0]
+ filename = random.choice(self.datasets[dataset_name]['filenames'])
+
+ path = Path(self.datasets[dataset_name]['path'], filename)
+
+ instance = {
+ 'batch_id': batch_id,
+ 'seed': random.randint(0, 2 ** 32 - 1),
+ 'dataset': dataset_name,
+ 'filename': filename,
+ 'path': path,
+ 'label_type': self.datasets[dataset_name]['label_type'],
+ }
+ batch.append(instance)
+
+ # Decide the image size for this batch
+ if self.image_size_strategy == 'fixed':
+ width, height = random.choice(self.config['image_sizes'])
+ elif self.image_size_strategy == 'aspect_area':
+ area = random.uniform(*self.area_range)
+ aspect_ratio_ranges = [self.datasets[instance['dataset']].get('aspect_ratio_range', self.aspect_ratio_range) for instance in batch]
+ aspect_ratio_range = (min(r[0] for r in aspect_ratio_ranges), max(r[1] for r in aspect_ratio_ranges))
+ aspect_ratio = random.uniform(*aspect_ratio_range)
+ width, height = int((area * aspect_ratio) ** 0.5), int((area / aspect_ratio) ** 0.5)
+ else:
+ raise ValueError('Invalid image size strategy')
+
+ for instance in batch:
+ instance['width'], instance['height'] = width, height
+
+ yield batch
+
+ def _load_instance(self, instance: dict):
+ try:
+ image = read_image(Path(instance['path'], 'image.jpg'))
+ depth = read_depth(Path(instance['path'], self.datasets[instance['dataset']].get('depth', 'depth.png')))
+ meta = read_json(Path(instance['path'], 'meta.json'))
+ intrinsics = np.array(meta['intrinsics'], dtype=np.float32)
+ data = {
+ 'image': image,
+ 'depth': depth,
+ 'intrinsics': intrinsics
+ }
+ instance.update({
+ **data,
+ })
+ except Exception as e:
+ print(f"Failed to load instance {instance['dataset']}/{instance['filename']} because of exception:", e)
+ instance.update(self.invalid_instance)
+ return instance
+
+ def _process_instance(self, instance: Dict[str, Union[np.ndarray, str, float, bool]]):
+ raw_image, raw_depth, raw_intrinsics, label_type = instance['image'], instance['depth'], instance['intrinsics'], instance['label_type']
+ raw_normal, raw_normal_mask = utils3d.np.depth_map_to_normal_map(raw_depth, intrinsics=raw_intrinsics, mask=np.isfinite(raw_depth), edge_threshold=88)
+ raw_normal = np.where(raw_normal_mask[..., None], raw_normal, np.nan)
+ depth_unit = self.datasets[instance['dataset']].get('depth_unit', None)
+
+ raw_height, raw_width = raw_image.shape[:2]
+ raw_fov_x, raw_fov_y = utils3d.np.intrinsics_to_fov(raw_intrinsics)
+ tgt_width, tgt_height = instance['width'], instance['height']
+ tgt_aspect = tgt_width / tgt_height
+
+ rng = np.random.default_rng(instance['seed'])
+
+ # Sample perspective transformation
+ tgt_intrinsics, R = sample_perspective(
+ raw_intrinsics,
+ tgt_aspect=tgt_aspect,
+ center_augmentation=self.datasets[instance['dataset']].get('center_augmentation', self.center_augmentation),
+ fov_range_absolute=self.datasets[instance['dataset']].get('fov_range_absolute', self.fov_range_absolute),
+ fov_range_relative=self.datasets[instance['dataset']].get('fov_range_relative', self.fov_range_relative),
+ rng=rng
+ )
+
+ # Warp
+ transform = tgt_intrinsics @ R @ np.linalg.inv(raw_intrinsics)
+ # - Warp image
+ tgt_image = warp_perspective(raw_image, transform, tgt_size=(tgt_height, tgt_width), interpolation='lanczos')
+ # - Warp depth
+ depth_edge_mask = utils3d.np.depth_map_edge(raw_depth, mask=np.isfinite(raw_depth), kernel_size=5, ltol=0.01)
+ depth_bilinear_mask = np.isfinite(raw_depth) & ~depth_edge_mask
+ warped_depth_bilinear_mask = warp_perspective(depth_bilinear_mask.astype(np.float32), transform, (tgt_height, tgt_width), interpolation='bilinear')
+ warped_depth_nearest = warp_perspective(raw_depth, transform, (tgt_height, tgt_width), interpolation='nearest', sparse_mask=~np.isnan(raw_depth))
+ warped_depth_bilinear = 1 / warp_perspective(1 / raw_depth, transform, (tgt_height, tgt_width), interpolation='bilinear') # NOTE: Bilinear intepolation in disparity space maintains planar surfaces.
+ warped_depth = np.where(warped_depth_bilinear_mask == 1., warped_depth_bilinear, warped_depth_nearest)
+ tgt_uvhomo = np.concatenate([utils3d.np.uv_map((tgt_height, tgt_width)), np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1)
+ tgt_depth = warped_depth / np.dot(tgt_uvhomo, np.linalg.inv(transform)[2, :])
+ # - Warp normal
+ warped_normal = warp_perspective(raw_normal, transform, (tgt_height, tgt_width), interpolation='bilinear')
+ tgt_normal = warped_normal @ R.T
+
+ # always make sure that mask is not empty
+ if np.isfinite(tgt_depth).sum() / tgt_depth.size < 0.001:
+ tgt_depth = np.ones_like(tgt_depth)
+ instance['label_type'] = 'invalid'
+
+ # Flip augmentation
+ if rng.choice([True, False]):
+ tgt_image = np.flip(tgt_image, axis=1).copy()
+ tgt_depth = np.flip(tgt_depth, axis=1).copy()
+ tgt_normal = np.flip(tgt_normal, axis=1).copy() * [-1, 1, 1]
+ # NOTE: if cx != 0.5, flip intrinsics accordingly.
+
+ # Color augmentation
+ image_augmentation = self.datasets[instance['dataset']].get('image_augmentation', self.image_augmentation)
+ tgt_image = image_color_augmentation(
+ tgt_image,
+ augmentations=image_augmentation,
+ rng=rng,
+ depth=tgt_depth,
+ )
+
+ # Set metric flag if depth is in metric unit
+ if depth_unit is not None:
+ tgt_depth *= depth_unit
+ instance['is_metric'] = True
+ else:
+ instance['is_metric'] = False
+
+ # Clip maximum depth
+ max_depth = np.nanquantile(np.where(np.isfinite(tgt_depth), tgt_depth, np.nan), 0.01) * self.clamp_max_depth
+ tgt_depth = np.where(np.isfinite(tgt_depth), np.clip(tgt_depth, 0, max_depth), tgt_depth)
+
+ tgt_depth_mask_inf = np.isinf(tgt_depth)
+ if self.datasets[instance['dataset']].get('finite_depth_mask', None) == "only_known":
+ tgt_depth_mask_fin = np.isfinite(tgt_depth)
+ else:
+ tgt_depth_mask_fin = ~tgt_depth_mask_inf
+
+ instance.update({
+ 'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1),
+ 'depth': torch.from_numpy(tgt_depth).float(),
+ 'depth_mask_fin': torch.from_numpy(tgt_depth_mask_fin).bool(),
+ 'depth_mask_inf': torch.from_numpy(tgt_depth_mask_inf).bool(),
+ "normal": torch.from_numpy(tgt_normal).float(),
+ 'intrinsics': torch.from_numpy(tgt_intrinsics).float(),
+ })
+ return instance
+
+ def _collate_batch(self, instances: List[Dict[str, Any]]):
+ batch = {k: torch.stack([instance[k] for instance in instances], dim=0) for k in ['image', 'depth', 'depth_mask_fin', 'depth_mask_inf', 'normal', 'intrinsics']}
+ batch = {
+ 'label_type': [instance['label_type'] for instance in instances],
+ 'is_metric': [instance['is_metric'] for instance in instances],
+ 'info': [{'dataset': instance['dataset'], 'filename': instance['filename']} for instance in instances],
+ **batch,
+ }
+ return batch
+
+ def get(self) -> Dict[str, Union[torch.Tensor, str]]:
+ return self.pipeline.get()
+
+ def start(self):
+ self.pipeline.start()
+
+ def stop(self):
+ self.pipeline.stop()
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.pipeline.stop()
+ return False
+
+
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/train/losses.py b/lingbotvla/models/vla/vision_models/MoGe/moge/train/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..a568adf72d9191d2309bc59cd5eb066266b1a958
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/train/losses.py
@@ -0,0 +1,293 @@
+from typing import *
+import math
+
+import torch
+import torch.nn.functional as F
+import utils3d
+
+from ..utils.geometry_torch import (
+ weighted_mean,
+ harmonic_mean,
+ geometric_mean,
+ normalized_view_plane_uv,
+ angle_diff_vec3
+)
+from ..utils.alignment import (
+ align_points_scale_z_shift,
+ align_points_scale,
+ align_points_scale_xyz_shift,
+ align_points_z_shift,
+)
+
+
+def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
+ if beta == 0:
+ return err
+ else:
+ return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)
+
+
+def affine_invariant_global_loss(
+ pred_points: torch.Tensor,
+ gt_points: torch.Tensor,
+ align_resolution: int = 64,
+ beta: float = 0.0,
+ trunc: float = 1.0,
+ sparsity_aware: bool = False
+):
+ device = pred_points.device
+
+ mask = torch.isfinite(gt_points).all(dim=-1)
+ gt_points = torch.where(mask[..., None], gt_points, 1)
+
+ # Align
+ pred_points_lr, gt_points_lr, lr_mask = utils3d.pt.masked_nearest_resize(pred_points, gt_points, mask=mask, size=(align_resolution, align_resolution))
+ scale, shift = align_points_scale_z_shift(pred_points_lr.flatten(-3, -2), gt_points_lr.flatten(-3, -2), lr_mask.flatten(-2, -1) / gt_points_lr[..., 2].flatten(-2, -1).clamp_min(1e-2), trunc=trunc)
+ valid = scale > 0
+ scale, shift = torch.where(valid, scale, 0), torch.where(valid[..., None], shift, 0)
+
+ pred_points = scale[..., None, None, None] * pred_points + shift[..., None, None, :]
+
+ # Compute loss
+ weight = (valid[..., None, None] & mask).float() / gt_points[..., 2].clamp_min(1e-5)
+ weight = weight.clamp_max(10.0 * weighted_mean(weight, mask, dim=(-2, -1), keepdim=True)) # In case your data contains extremely small depth values
+ loss = _smooth((pred_points - gt_points).abs() * weight[..., None], beta=beta).mean(dim=(-3, -2, -1))
+
+ if sparsity_aware:
+ # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
+ sparsity = mask.float().mean(dim=(-2, -1)) / lr_mask.float().mean(dim=(-2, -1))
+ loss = loss / (sparsity + 1e-7)
+
+ err = (pred_points.detach() - gt_points).norm(dim=-1) / gt_points[..., 2]
+
+ # Record any scalar metric
+ misc = {
+ 'truncated_error': weighted_mean(err.clamp_max(1.0), mask).item(),
+ 'delta': weighted_mean((err < 1).float(), mask).item()
+ }
+
+ return loss, misc, scale.detach()
+
+
+def monitoring(points: torch.Tensor):
+ return {
+ 'std': points.std().item(),
+ }
+
+
+def compute_anchor_sampling_weight(
+ points: torch.Tensor,
+ mask: torch.Tensor,
+ radius_2d: torch.Tensor,
+ radius_3d: torch.Tensor,
+ num_test: int = 64
+) -> torch.Tensor:
+ # Importance sampling to balance the sampled probability of fine strutures.
+ # NOTE: MoGe-1 uses uniform random sampling instead of importance sampling.
+ # This is an incremental trick introduced later than the publication of MoGe-1 paper.
+
+ height, width = points.shape[-3:-1]
+
+ pixel_i, pixel_j = torch.meshgrid(
+ torch.arange(height, device=points.device),
+ torch.arange(width, device=points.device),
+ indexing='ij'
+ )
+
+ test_delta_i = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device) # [num_test]
+ test_delta_j = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device) # [num_test]
+ test_i, test_j = pixel_i[..., None] + test_delta_i, pixel_j[..., None] + test_delta_j # [height, width, num_test]
+ test_mask = (test_i >= 0) & (test_i < height) & (test_j >= 0) & (test_j < width) # [height, width, num_test]
+ test_i, test_j = test_i.clamp(0, height - 1), test_j.clamp(0, width - 1) # [height, width, num_test]
+ test_mask = test_mask & mask[..., test_i, test_j] # [..., height, width, num_test]
+ test_points = points[..., test_i, test_j, :] # [..., height, width, num_test, 3]
+ test_dist = (test_points - points[..., None, :]).norm(dim=-1) # [..., height, width, num_test]
+
+ weight = 1 / ((test_dist <= radius_3d[..., None]) & test_mask).float().sum(dim=-1).clamp_min(1)
+ weight = torch.where(mask, weight, 0)
+ weight = weight / weight.sum(dim=(-2, -1), keepdim=True).add(1e-7) # [..., height, width]
+ return weight
+
+
+def affine_invariant_local_loss(
+ pred_points: torch.Tensor,
+ gt_points: torch.Tensor,
+ focal: torch.Tensor,
+ global_scale: torch.Tensor,
+ level: Literal[4, 16, 64],
+ align_resolution: int = 32,
+ num_patches: int = 16,
+ beta: float = 0.0,
+ trunc: float = 1.0,
+ sparsity_aware: bool = False
+):
+ device, dtype = pred_points.device, pred_points.dtype
+ *batch_shape, height, width, _ = pred_points.shape
+ batch_size = math.prod(batch_shape)
+
+ gt_mask = torch.isfinite(gt_points).all(dim=-1)
+ gt_points = torch.where(gt_mask[..., None], gt_points, 1)
+ pred_points, gt_points, gt_mask, focal, global_scale = pred_points.reshape(-1, height, width, 3), gt_points.reshape(-1, height, width, 3), gt_mask.reshape(-1, height, width), focal.reshape(-1), global_scale.reshape(-1) if global_scale is not None else None
+
+ # Sample patch anchor points indices [num_total_patches]
+ radius_2d = math.ceil(0.5 / level * (height ** 2 + width ** 2) ** 0.5)
+ radius_3d = 0.5 / level / focal * gt_points[..., 2]
+ anchor_sampling_weights = compute_anchor_sampling_weight(gt_points, gt_mask, radius_2d, radius_3d, num_test=64)
+ where_mask = torch.where(gt_mask)
+ random_selection = torch.multinomial(anchor_sampling_weights[where_mask], num_patches * batch_size, replacement=True)
+ patch_batch_idx, patch_anchor_i, patch_anchor_j = [indices[random_selection] for indices in where_mask] # [num_total_patches]
+
+ # Get patch indices [num_total_patches, patch_h, patch_w]
+ patch_i, patch_j = torch.meshgrid(
+ torch.arange(-radius_2d, radius_2d + 1, device=device),
+ torch.arange(-radius_2d, radius_2d + 1, device=device),
+ indexing='ij'
+ )
+ patch_i, patch_j = patch_i + patch_anchor_i[:, None, None], patch_j + patch_anchor_j[:, None, None]
+ patch_mask = (patch_i >= 0) & (patch_i < height) & (patch_j >= 0) & (patch_j < width)
+ patch_i, patch_j = patch_i.clamp(0, height - 1), patch_j.clamp(0, width - 1)
+
+ # Get patch mask and gt patch points
+ gt_patch_anchor_points = gt_points[patch_batch_idx, patch_anchor_i, patch_anchor_j]
+ gt_patch_radius_3d = 0.5 / level / focal[patch_batch_idx] * gt_patch_anchor_points[:, 2]
+ gt_patch_points = gt_points[patch_batch_idx[:, None, None], patch_i, patch_j]
+ gt_patch_dist = (gt_patch_points - gt_patch_anchor_points[:, None, None, :]).norm(dim=-1)
+ patch_mask &= gt_mask[patch_batch_idx[:, None, None], patch_i, patch_j]
+ patch_mask &= gt_patch_dist <= gt_patch_radius_3d[:, None, None]
+
+ # Pick only non-empty patches
+ MINIMUM_POINTS_PER_PATCH = 32
+ nonempty = torch.where(patch_mask.sum(dim=(-2, -1)) >= MINIMUM_POINTS_PER_PATCH)
+ num_nonempty_patches = nonempty[0].shape[0]
+ if num_nonempty_patches == 0:
+ return torch.tensor(0.0, dtype=dtype, device=device), {}
+
+ # Finalize all patch variables
+ patch_batch_idx, patch_i, patch_j = patch_batch_idx[nonempty], patch_i[nonempty], patch_j[nonempty]
+ patch_mask = patch_mask[nonempty] # [num_nonempty_patches, patch_h, patch_w]
+ gt_patch_points = gt_patch_points[nonempty] # [num_nonempty_patches, patch_h, patch_w, 3]
+ gt_patch_radius_3d = gt_patch_radius_3d[nonempty] # [num_nonempty_patches]
+ gt_patch_anchor_points = gt_patch_anchor_points[nonempty] # [num_nonempty_patches, 3]
+ pred_patch_points = pred_points[patch_batch_idx[:, None, None], patch_i, patch_j]
+
+ # Align patch points
+ pred_patch_points_lr, gt_patch_points_lr, patch_lr_mask = utils3d.pt.masked_nearest_resize(pred_patch_points, gt_patch_points, mask=patch_mask, size=(align_resolution, align_resolution))
+ local_scale, local_shift = align_points_scale_xyz_shift(pred_patch_points_lr.flatten(-3, -2), gt_patch_points_lr.flatten(-3, -2), patch_lr_mask.flatten(-2) / gt_patch_radius_3d[:, None].add(1e-7), trunc=trunc)
+ if global_scale is not None:
+ scale_differ = local_scale / global_scale[patch_batch_idx]
+ patch_valid = (scale_differ > 0.1) & (scale_differ < 10.0) & (global_scale > 0)
+ else:
+ patch_valid = local_scale > 0
+ local_scale, local_shift = torch.where(patch_valid, local_scale, 0), torch.where(patch_valid[:, None], local_shift, 0)
+ patch_mask &= patch_valid[:, None, None]
+
+ pred_patch_points = local_scale[:, None, None, None] * pred_patch_points + local_shift[:, None, None, :] # [num_patches_nonempty, patch_h, patch_w, 3]
+
+ # Compute loss
+ gt_mean = harmonic_mean(gt_points[..., 2], gt_mask, dim=(-2, -1))
+ patch_weight = patch_mask.float() / gt_patch_points[..., 2].clamp_min(0.1 * gt_mean[patch_batch_idx, None, None]) # [num_patches_nonempty, patch_h, patch_w]
+ loss = _smooth((pred_patch_points - gt_patch_points).abs() * patch_weight[..., None], beta=beta).mean(dim=(-3, -2, -1)) # [num_patches_nonempty]
+
+ if sparsity_aware:
+ # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
+ sparsity = patch_mask.float().mean(dim=(-2, -1)) / patch_lr_mask.float().mean(dim=(-2, -1))
+ loss = loss / (sparsity + 1e-7)
+ loss = torch.scatter_reduce(torch.zeros(batch_size, dtype=dtype, device=device), dim=0, index=patch_batch_idx, src=loss, reduce='sum') / num_patches
+ loss = loss.reshape(batch_shape)
+
+ err = (pred_patch_points.detach() - gt_patch_points).norm(dim=-1) / gt_patch_radius_3d[..., None, None]
+
+ # Record any scalar metric
+ misc = {
+ 'truncated_error': weighted_mean(err.clamp_max(1), patch_mask).item(),
+ 'delta': weighted_mean((err < 1).float(), patch_mask).item()
+ }
+
+ return loss, misc
+
+
+def normal_loss(points: torch.Tensor, gt_points: torch.Tensor) -> torch.Tensor:
+ device, dtype = points.device, points.dtype
+ height, width = points.shape[-3:-1]
+
+ mask = torch.isfinite(gt_points).all(dim=-1)
+ gt_points = torch.where(mask[..., None], gt_points, 1)
+
+ leftup, rightup, leftdown, rightdown = points[..., :-1, :-1, :], points[..., :-1, 1:, :], points[..., 1:, :-1, :], points[..., 1:, 1:, :]
+ upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1)
+ leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1)
+ downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1)
+ rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1)
+
+ gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = gt_points[..., :-1, :-1, :], gt_points[..., :-1, 1:, :], gt_points[..., 1:, :-1, :], gt_points[..., 1:, 1:, :]
+ gt_upxleft = torch.cross(gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1)
+ gt_leftxdown = torch.cross(gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1)
+ gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1)
+ gt_rightxup = torch.cross(gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1)
+
+ mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = mask[..., :-1, :-1], mask[..., :-1, 1:], mask[..., 1:, :-1], mask[..., 1:, 1:]
+ mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown
+ mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup
+ mask_downxright = mask_leftdown & mask_rightup & mask_leftup
+ mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown
+
+ MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3)
+
+ loss = mask_upxleft * _smooth(angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
+ + mask_leftxdown * _smooth(angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
+ + mask_downxright * _smooth(angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
+ + mask_rightxup * _smooth(angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
+
+ loss = loss.mean() / (4 * max(points.shape[-3:-1]))
+
+ return loss, {}
+
+
+def edge_loss(points: torch.Tensor, gt_points: torch.Tensor) -> torch.Tensor:
+ device, dtype = points.device, points.dtype
+ height, width = points.shape[-3:-1]
+
+ mask = torch.isfinite(gt_points).all(dim=-1)
+ gt_points = torch.where(mask[..., None], gt_points, 1)
+
+ dx = points[..., :-1, :, :] - points[..., 1:, :, :]
+ dy = points[..., :, :-1, :] - points[..., :, 1:, :]
+
+ gt_dx = gt_points[..., :-1, :, :] - gt_points[..., 1:, :, :]
+ gt_dy = gt_points[..., :, :-1, :] - gt_points[..., :, 1:, :]
+
+ mask_dx = mask[..., :-1, :] & mask[..., 1:, :]
+ mask_dy = mask[..., :, :-1] & mask[..., :, 1:]
+
+ MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(0.1), math.radians(90), math.radians(3)
+
+ loss_dx = mask_dx * _smooth(angle_diff_vec3(dx, gt_dx).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
+ loss_dy = mask_dy * _smooth(angle_diff_vec3(dy, gt_dy).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
+ loss = (loss_dx.mean(dim=(-2, -1)) + loss_dy.mean(dim=(-2, -1))) / (2 * max(points.shape[-3:-1]))
+
+ return loss, {}
+
+
+def mask_l2_loss(pred_mask: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> torch.Tensor:
+ loss = gt_mask_neg.float() * pred_mask.square() + gt_mask_pos.float() * (1 - pred_mask).square()
+ loss = loss.mean(dim=(-2, -1))
+ return loss, {}
+
+
+def mask_bce_loss(pred_mask_prob: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> torch.Tensor:
+ loss = (gt_mask_pos | gt_mask_neg) * F.binary_cross_entropy(pred_mask_prob, gt_mask_pos.float(), reduction='none')
+ loss = loss.mean(dim=(-2, -1))
+ return loss, {}
+
+
+def metric_scale_loss(scale_pred: torch.Tensor, scale_gt: torch.Tensor):
+ valid = scale_gt > 0
+ return torch.where(valid, F.mse_loss(scale_pred.log(), torch.where(valid, scale_gt.log(), 0), reduction='none'), 0), {}
+
+
+def normal_map_loss(pred_normal: torch.Tensor, gt_normal: torch.Tensor) -> torch.Tensor:
+ mask = torch.isfinite(gt_normal).all(dim=-1)
+ gt_normal = torch.where(mask[..., None], gt_normal, 1)
+
+ loss = (mask * utils3d.pt.angle_between(pred_normal, gt_normal).square()).mean(dim=(-2, -1))
+ return loss, {}
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/train/utils.py b/lingbotvla/models/vla/vision_models/MoGe/moge/train/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f21e00876b927991381bf2f777a68b02c5b38cc
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/train/utils.py
@@ -0,0 +1,57 @@
+from typing import *
+import fnmatch
+
+import sympy
+import torch
+import torch.nn as nn
+
+
+def any_match(s: str, patterns: List[str]) -> bool:
+ return any(fnmatch.fnmatch(s, pat) for pat in patterns)
+
+
+def build_optimizer(model: nn.Module, optimizer_config: Dict[str, Any]) -> torch.optim.Optimizer:
+ named_param_groups = [
+ {
+ k: p for k, p in model.named_parameters() if any_match(k, param_group_config['params']['include']) and not any_match(k, param_group_config['params'].get('exclude', []))
+ } for param_group_config in optimizer_config['params']
+ ]
+ excluded_params = [k for k, p in model.named_parameters() if p.requires_grad and not any(k in named_params for named_params in named_param_groups)]
+ assert len(excluded_params) == 0, f'The following parameters require grad but are excluded from the optimizer: {excluded_params}'
+ optimizer_cls = getattr(torch.optim, optimizer_config['type'])
+ optimizer = optimizer_cls([
+ {
+ **param_group_config,
+ 'params': list(params.values()),
+ } for param_group_config, params in zip(optimizer_config['params'], named_param_groups)
+ ])
+ return optimizer
+
+
+def parse_lr_lambda(s: str) -> Callable[[int], float]:
+ epoch = sympy.symbols('epoch')
+ lr_lambda = sympy.sympify(s)
+ return sympy.lambdify(epoch, lr_lambda, 'math')
+
+
+def build_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_config: Dict[str, Any]) -> torch.optim.lr_scheduler._LRScheduler:
+ if scheduler_config['type'] == "SequentialLR":
+ child_schedulers = [
+ build_lr_scheduler(optimizer, child_scheduler_config)
+ for child_scheduler_config in scheduler_config['params']['schedulers']
+ ]
+ return torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=child_schedulers, milestones=scheduler_config['params']['milestones'])
+ elif scheduler_config['type'] == "LambdaLR":
+ lr_lambda = scheduler_config['params']['lr_lambda']
+ if isinstance(lr_lambda, str):
+ lr_lambda = parse_lr_lambda(lr_lambda)
+ elif isinstance(lr_lambda, list):
+ lr_lambda = [parse_lr_lambda(l) for l in lr_lambda]
+ return torch.optim.lr_scheduler.LambdaLR(
+ optimizer,
+ lr_lambda=lr_lambda,
+ )
+ else:
+ scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_config['type'])
+ scheduler = scheduler_cls(optimizer, **scheduler_config.get('params', {}))
+ return scheduler
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/__init__.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/alignment.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d6bb78766ec1a43a89a4fc931b64f70c5201e2d
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/alignment.py
@@ -0,0 +1,416 @@
+from typing import *
+import math
+from collections import namedtuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.types
+import utils3d
+
+
+def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
+ "Scatter the minimum value along the given dimension of `input` into `src` at the indices specified in `index`."
+ shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
+ minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
+ minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index))
+ indices = torch.full(shape, -1, dtype=torch.long, device=src.device)
+ indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim]
+ return torch.return_types.min((minimum, indices))
+
+
+def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
+ batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
+ n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
+ splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
+ splited_kwargs = {k: [v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks] for k, v in kwargs.items()}
+ results = []
+ for i in range(n_chunks):
+ chunk_args = tuple(arg[i] for arg in splited_args)
+ chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
+ results.append(fn(*chunk_args, **chunk_kwargs))
+
+ if isinstance(results[0], tuple):
+ return tuple(torch.cat(r, dim=0) for r in zip(*results))
+ else:
+ return torch.cat(results, dim=0)
+
+
+def _pad_inf(x_: torch.Tensor):
+ return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)
+
+
+def _pad_cumsum(cumsum: torch.Tensor):
+ return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)
+
+
+def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
+ return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)
+
+
+def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
+ """
+ If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.
+
+ w_i must be >= 0.
+
+ ### Parameters:
+ - `x`: tensor of shape (..., n)
+ - `y`: tensor of shape (..., n)
+ - `w`: tensor of shape (..., n)
+ - `trunc`: optional, float or tensor of shape (..., n) or None
+
+ ### Returns:
+ - `a`: tensor of shape (...), differentiable
+ - `loss`: tensor of shape (...), value of loss function at `a`, detached
+ - `index`: tensor of shape (...), where a = y[idx] / x[idx]
+ """
+ if trunc is None:
+ x, y, w = torch.broadcast_tensors(x, y, w)
+ sign = torch.sign(x)
+ x, y = x * sign, y * sign
+ y_div_x = y / x.clamp_min(eps)
+ y_div_x, argsort = y_div_x.sort(dim=-1)
+
+ wx = torch.gather(x * w, dim=-1, index=argsort)
+ derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)
+ search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)
+
+ a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
+ index = argsort.gather(dim=-1, index=search).squeeze(-1)
+ loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)
+
+ else:
+ # Reshape to (batch_size, n) for simplicity
+ x, y, w = torch.broadcast_tensors(x, y, w)
+ batch_shape = x.shape[:-1]
+ batch_size = math.prod(batch_shape)
+ x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])
+
+ sign = torch.sign(x)
+ x, y = x * sign, y * sign
+ wx, wy = w * x, w * y
+ xyw = torch.stack([x, y, w], dim=-1) # Stacked for convenient gathering
+
+ y_div_x = A = y / x.clamp_min(eps)
+ B = (wy - trunc) / wx.clamp_min(eps)
+ C = (wy + trunc) / wx.clamp_min(eps)
+ with torch.no_grad():
+ # Caculate prefix sum by orders of A, B, C
+ A, A_argsort = A.sort(dim=-1)
+ Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
+ A, Q_A = _pad_inf(A), _pad_cumsum(Q_A) # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.
+
+ B, B_argsort = B.sort(dim=-1)
+ Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
+ B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)
+
+ C, C_argsort = C.sort(dim=-1)
+ Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
+ C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)
+
+ # Caculate left and right derivative of A
+ j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
+ j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
+ j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
+ left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
+ j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
+ j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
+ j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
+ right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
+
+ # Find extrema
+ is_extrema = (left_derivative < 0) & (right_derivative >= 0)
+ is_extrema[..., 0] |= ~is_extrema.any(dim=-1) # In case all derivatives are zero, take the first one as extrema.
+ where_extrema_batch, where_extrema_index = torch.where(is_extrema)
+
+ # Calculate objective value at extrema
+ extrema_a = y_div_x[where_extrema_batch, where_extrema_index] # (num_extrema,)
+ MAX_ELEMENTS = 4096 ** 2 # Split into small batches to avoid OOM in case there are too many extrema.(~1G)
+ SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
+ extrema_value = torch.cat([
+ _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
+ for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
+ ]) # (num_extrema,)
+
+ # Find minima among corresponding extrema
+ minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value) # (batch_size,)
+ index = where_extrema_index[indices]
+
+ a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
+ a = a.reshape(batch_shape)
+ loss = minima.reshape(batch_shape)
+ index = index.reshape(batch_shape)
+
+ return a, loss, index
+
+
+def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights.
+
+ ### Parameters:
+ - `depth_src: torch.Tensor` of shape (..., N)
+ - `depth_tgt: torch.Tensor` of shape (..., N)
+
+ """
+ scale, _, _ = align(depth_src, depth_tgt, weight, trunc)
+
+ return scale
+
+
+def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights.
+
+ ### Parameters:
+ - `depth_src: torch.Tensor` of shape (..., N)
+ - `depth_tgt: torch.Tensor` of shape (..., N)
+ - `weight: torch.Tensor` of shape (..., N)
+ - `trunc: float` or tensor of shape (..., N) or None
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (...).
+ """
+ dtype, device = depth_src.dtype, depth_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
+ batch_size = math.prod(batch_shape)
+ depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)
+
+ # Here, we take anchors only for non-zero weights.
+ # Although the results will be still correct even anchor points have zero weight,
+ # it is wasting computation and may cause instability in some cases, e.g. too many extrema.
+ anchors_where_batch, anchors_where_n = torch.where(weight > 0)
+
+ # Stop gradient when solving optimal anchors
+ with torch.no_grad():
+ depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n] # (anchors)
+ depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n] # (anchors)
+
+ depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None] # (anchors, n)
+ depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None] # (anchors, n)
+ weight_anchored = weight[anchors_where_batch, :] # (anchors, n)
+
+ scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc) # (anchors)
+
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss) # (batch_size,)
+
+ # Reproduce by indexing for shorter compute graph
+ index_1 = anchors_where_n[index_anchor] # (batch_size,)
+ index_2 = index[index_anchor] # (batch_size,)
+
+ tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
+ tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
+ shift = tgt_1 - scale * src_1
+
+ scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)
+
+ return scale, shift
+
+def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
+ """
+ Align `depth_src` to `depth_tgt` with given constant weights using IRLS.
+ """
+ dtype, device = depth_src.dtype, depth_src.device
+
+ w = weight
+ x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)
+ y = depth_tgt
+
+ for i in range(max_iter):
+ beta = (x.transpose(-1, -2) @ (w * y)) @ (x.transpose(-1, -2) @ (w[..., None] * x)).inverse().transpose(-2, -1)
+ w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)
+
+ return beta[..., 0], beta[..., 1]
+
+
+def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weight: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `a: torch.Tensor` of shape (...). Only positive solutions are garunteed. You should filter out negative scales before using it.
+ - `b: torch.Tensor` of shape (...)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight[..., None].expand_as(points_src).flatten(-2), trunc)
+
+ return scale
+
+
+def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
+ """
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
+ batch_size = math.prod(batch_shape)
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
+
+ # Take anchors
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
+ with torch.no_grad():
+ zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
+ points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
+ points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1) # (anchors, 3)
+
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
+
+ # Solve optimal scale and shift for each anchor
+ MAX_ELEMENTS = 2 ** 20
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
+
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
+
+ # Reproduce by indexing for shorter compute graph
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
+
+ zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
+ points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
+ tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
+ tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
+ shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
+
+ return scale, shift
+
+
+def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
+ It is similar to `align_affine` but scale and shift are applied to different dimensions.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ # Flatten batch dimensions for simplicity
+ batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
+ batch_size = math.prod(batch_shape)
+ points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)
+
+ # Take anchors
+ anchor_where_batch, anchor_where_n = torch.where(weight > 0)
+
+ with torch.no_grad():
+ points_src_anchor = points_src[anchor_where_batch, anchor_where_n] # (anchors, 3)
+ points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n] # (anchors, 3)
+
+ points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :] # (anchors, n, 3)
+ points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :] # (anchors, n, 3)
+ weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3) # (anchors, n, 3)
+
+ # Solve optimal scale and shift for each anchor
+ MAX_ELEMENTS = 2 ** 20
+ scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // 2, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc) # (anchors,)
+
+ # Get optimal scale and shift for each batch element
+ loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss) # (batch_size,)
+
+ index_2 = index[index_anchor] # (batch_size,) [0, 3n)
+ index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3 # (batch_size,) [0, 3n)
+
+ src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
+ src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)
+
+ scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
+ shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
+
+ scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)
+
+ return scale, shift
+
+
+def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a Z-axis shift.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ shift, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc)
+ shift = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)
+
+ return shift
+
+
+def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
+ """
+ Align `points_src` to `points_tgt` with respect to a Z-axis shift.
+
+ ### Parameters:
+ - `points_src: torch.Tensor` of shape (..., N, 3)
+ - `points_tgt: torch.Tensor` of shape (..., N, 3)
+ - `weights: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `scale: torch.Tensor` of shape (...).
+ - `shift: torch.Tensor` of shape (..., 3)
+ """
+ dtype, device = points_src.dtype, points_src.device
+
+ shift, _, _ = align(torch.ones_like(points_src).swapaxes(-2, -1), (points_tgt - points_src).swapaxes(-2, -1), weight[..., None, :], trunc)
+
+ return shift
+
+
+def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Solve `min sum_i w_i * (a * x_i + b - y_i ) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.
+
+ ### Parameters:
+ - `x: torch.Tensor` of shape (..., N)
+ - `y: torch.Tensor` of shape (..., N)
+ - `w: torch.Tensor` of shape (..., N)
+
+ ### Returns:
+ - `a: torch.Tensor` of shape (...,)
+ - `b: torch.Tensor` of shape (...,)
+ """
+ w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
+ A = torch.stack([w_sqrt * x, torch.ones_like(x)], dim=-1)
+ B = (w_sqrt * y)[..., None]
+ a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
+ return a, b
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/data_augmentation.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/data_augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fc4c9da8b64fa3e020436796ba63596711a275c
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/data_augmentation.py
@@ -0,0 +1,250 @@
+import os
+import json
+import time
+import random
+from typing import *
+import itertools
+from numbers import Number
+import io
+
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+import torchvision.transforms.v2.functional as TF
+import utils3d
+from scipy.signal import fftconvolve
+
+from ..utils.geometry_numpy import harmonic_mean_numpy, norm3d, depth_occlusion_edge_numpy
+
+
+def sample_perspective(
+ src_intrinsics: np.ndarray,
+ tgt_aspect: float,
+ center_augmentation: float,
+ fov_range_absolute: Tuple[float, float],
+ fov_range_relative: Tuple[float, float],
+ rng: np.random.Generator = None
+) -> Tuple[np.ndarray, np.ndarray]:
+ raw_horizontal, raw_vertical = abs(1.0 / src_intrinsics[0, 0]), abs(1.0 / src_intrinsics[1, 1])
+ raw_fov_x, raw_fov_y = utils3d.np.intrinsics_to_fov(src_intrinsics)
+
+ # 1. set target fov
+ fov_range_absolute_min, fov_range_absolute_max = fov_range_absolute
+ fov_range_relative_min, fov_range_relative_max = fov_range_relative
+ tgt_fov_x_min = min(fov_range_relative_min * raw_fov_x, utils3d.focal_to_fov(utils3d.fov_to_focal(fov_range_relative_min * raw_fov_y) / tgt_aspect))
+ tgt_fov_x_max = min(fov_range_relative_max * raw_fov_x, utils3d.focal_to_fov(utils3d.fov_to_focal(fov_range_relative_max * raw_fov_y) / tgt_aspect))
+ tgt_fov_x_min, tgt_fov_max = max(np.deg2rad(fov_range_absolute_min), tgt_fov_x_min), min(np.deg2rad(fov_range_absolute_max), tgt_fov_x_max)
+ tgt_fov_x = rng.uniform(min(tgt_fov_x_min, tgt_fov_x_max), tgt_fov_x_max)
+ tgt_fov_y = utils3d.focal_to_fov(utils3d.np.fov_to_focal(tgt_fov_x) * tgt_aspect)
+
+ # 2. set target image center (principal point) and the corresponding z-direction in raw camera space
+ center_dtheta = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_x - tgt_fov_x)
+ center_dphi = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_y - tgt_fov_y)
+ cu, cv = 0.5 + 0.5 * np.tan(center_dtheta) / np.tan(raw_fov_x / 2), 0.5 + 0.5 * np.tan(center_dphi) / np.tan(raw_fov_y / 2)
+ direction = utils3d.np.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=src_intrinsics)[0]
+
+ # 3. obtain the rotation matrix for homography warping (new_ext = R * old_ext)
+ R = utils3d.np.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32))
+
+ # 4. shrink the target view to fit into the warped image
+ corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32)
+ corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(src_intrinsics).T @ R.T) # corners in viewport's camera plane
+ corners = corners[:, :2] / corners[:, 2:3]
+ tgt_horizontal, tgt_vertical = np.tan(tgt_fov_x / 2) * 2, np.tan(tgt_fov_y / 2) * 2
+ warp_horizontal, warp_vertical = float('inf'), float('inf')
+ for i in range(4):
+ intersection, _ = utils3d.np.ray_intersection(
+ np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]),
+ corners[i - 1], corners[i] - corners[i - 1],
+ )
+ warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min())
+ tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical)
+
+ # 5. obtain the target intrinsics
+ fx, fy = 1 / tgt_horizontal, 1 / tgt_vertical
+ tgt_intrinsics = utils3d.np.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32)
+
+ return tgt_intrinsics, R
+
+
+def warp_perspective(
+ src_map: np.ndarray = None,
+ transform: np.ndarray = None,
+ tgt_size: Tuple[int, int] = None,
+ interpolation: Literal['nearest', 'bilinear', 'lanczos'] = 'nearest',
+ sparse_mask: np.ndarray = None,
+):
+ """Perspective warping with careful resampling.
+ - For `lanczos`, use PIL to resize first to reduce aliasing.
+ - For `nearest` with sparse input, use mask-aware nearest resize to avoid losing points.
+ - For `bilinear` or `nearest` with dense input, directly use cv2.remap.
+
+ - `transform` is the matrix that transforms homogeneous pixel coordinates of source image to those of target image, i.e., `p_tgt = transform @ p_src`.
+ """
+
+ tgt_height, tgt_width = tgt_size
+ src_height, src_width = src_map.shape[:2]
+
+ # source to target transform
+ transform_pixel = np.array([[tgt_width, 0, -0.5], [0, tgt_height, -0.5], [0, 0, 1]], dtype=np.float32) @ transform @ np.array([[1 / src_width, 0, 0.5 / src_width], [0, 1 / src_height, 0.5 / src_height], [0, 0, 1]], dtype=np.float32)
+ # Get scale factor at the target center
+ w = np.dot(np.linalg.inv(transform_pixel)[2, :], np.array([tgt_width / 2, tgt_height / 2, 1], dtype=np.float32))
+ scale_x, scale_y = w * np.linalg.norm(transform_pixel[:2, :2], axis=0)
+
+ if interpolation == 'lanczos' and (scale_x < 0.8 or scale_y < 0.8):
+ # If lanczos & downsampling, use PIL to resize first to reduce aliasing
+ src_height, src_width = max(round(src_height * scale_y * 1.25), 16), max(round(src_width * scale_x * 1.25), 16)
+ src_map = np.array(Image.fromarray(src_map).resize((src_width, src_height), Image.Resampling.LANCZOS))
+ elif interpolation == 'nearest' and sparse_mask is not None and (scale_x < 1 or scale_y < 1):
+ # If nearest and sparse, use mask-aware nearest resize first to avoid losing points
+ src_height, src_width = max(round(src_height * scale_y), 16), max(round(src_width * scale_x), 16)
+ src_map, _ = utils3d.np.masked_nearest_resize(src_map, mask=sparse_mask, size=(src_height, src_width))
+
+ # Recompute the pixel-space transform after resizing
+ transform_pixel = np.array([[tgt_width, 0, -0.5], [0, tgt_height, -0.5], [0, 0, 1]], dtype=np.float32) @ transform @ np.array([[1 / src_width, 0, 0.5 / src_width], [0, 1 / src_height, 0.5 / src_height], [0, 0, 1]], dtype=np.float32)
+
+ # Remap
+ cv2_interpolation = {'nearest': cv2.INTER_NEAREST, 'bilinear': cv2.INTER_LINEAR, 'lanczos': cv2.INTER_LANCZOS4}[interpolation]
+ tgt_map = cv2.warpPerspective(src_map, transform_pixel, (tgt_width, tgt_height), flags=cv2_interpolation)
+
+ return tgt_map
+
+
+def image_color_augmentation(image: np.ndarray, augmentations: List[Dict[str, Any]], rng: np.random.Generator = None, depth: np.ndarray = None):
+ height, width = image.shape[:2]
+ if rng is None:
+ rng = np.random.default_rng()
+ if 'jittering' in augmentations:
+ image = torch.from_numpy(image).permute(2, 0, 1)
+ image = TF.adjust_brightness(image, rng.uniform(0.9, 1.1))
+ image = TF.adjust_contrast(image, rng.uniform(0.9, 1.1))
+ image = TF.adjust_saturation(image, rng.uniform(0.9, 1.1))
+ image = TF.adjust_hue(image, rng.uniform(-0.05, 0.05))
+ image = TF.adjust_gamma(image, rng.uniform(0.9, 1.1))
+ image = image.permute(1, 2, 0).numpy()
+ if 'dof' in augmentations:
+ assert depth is not None, 'Depth map is required for DOF augmentation'
+ if rng.uniform() < 0.5:
+ dof_strength = rng.integers(12)
+ disp = 1 / depth
+ finite_mask = np.isfinite(depth)
+ disp_min, disp_max = disp[finite_mask].min(), disp[finite_mask].max()
+ disp = cv2.inpaint(np.nan_to_num(disp, nan=1), np.isnan(disp).astype(np.uint8), 3, cv2.INPAINT_TELEA).clip(0, disp_max)
+ dof_focus = rng.uniform(disp_min, disp_max)
+ image = depth_of_field(image, disp, dof_focus, dof_strength)
+ if 'shot_noise' in augmentations:
+ if rng.uniform() < 0.5:
+ k = np.exp(rng.uniform(np.log(100), np.log(10000))) / 255
+ image = (rng.poisson(image * k) / k).clip(0, 255).astype(np.uint8)
+ if 'blurring' in augmentations:
+ if rng.uniform() < 0.5:
+ ratio = rng.uniform(0.25, 1)
+ image = cv2.resize(cv2.resize(image, (int(width * ratio), int(height * ratio)), interpolation=cv2.INTER_AREA), (width, height), interpolation=rng.choice([cv2.INTER_LINEAR_EXACT, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]))
+ if 'jpeg_loss' in augmentations:
+ if rng.uniform() < 0.5:
+ image = cv2.imdecode(cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, rng.integers(20, 100)])[1], cv2.IMREAD_COLOR)
+
+ return image
+
+
+
+def disk_kernel(radius: int) -> np.ndarray:
+ """
+ Generate disk kernel with given radius.
+
+ Args:
+ radius (int): Radius of the disk (in pixels).
+
+ Returns:
+ np.ndarray: (2*radius+1, 2*radius+1) normalized convolution kernel.
+ """
+ # Create coordinate grid centered at (0,0)
+ L = np.arange(-radius, radius + 1)
+ X, Y = np.meshgrid(L, L)
+ # Generate disk: region inside circle with radius R is 1
+ kernel = ((X**2 + Y**2) <= radius**2).astype(np.float32)
+ # Normalize the kernel
+ kernel /= np.sum(kernel)
+ return kernel
+
+
+def disk_blur(image: np.ndarray, radius: int) -> np.ndarray:
+ """
+ Apply disk blur to an image using FFT convolution.
+
+ Args:
+ image (np.ndarray): Input image, can be grayscale or color.
+ radius (int): Blur radius (in pixels).
+
+ Returns:
+ np.ndarray: Blurred image.
+ """
+ if radius == 0:
+ return image
+ kernel = disk_kernel(radius)
+ if image.ndim == 2:
+ blurred = fftconvolve(image, kernel, mode='same')
+ elif image.ndim == 3:
+ channels = []
+ for i in range(image.shape[2]):
+ blurred_channel = fftconvolve(image[..., i], kernel, mode='same')
+ channels.append(blurred_channel)
+ blurred = np.stack(channels, axis=-1)
+ else:
+ raise ValueError("Image must be 2D or 3D.")
+ return blurred
+
+
+def depth_of_field(
+ img: np.ndarray,
+ disp: np.ndarray,
+ focus_disp : float,
+ max_blur_radius : int = 10,
+) -> np.ndarray:
+ """
+ Apply depth of field effect to an image.
+
+ Args:
+ img (numpy.ndarray): (H, W, 3) input image.
+ depth (numpy.ndarray): (H, W) depth map of the scene.
+ focus_depth (float): Focus depth of the lens.
+ strength (float): Strength of the depth of field effect.
+ max_blur_radius (int): Maximum blur radius (in pixels).
+
+ Returns:
+ numpy.ndarray: (H, W, 3) output image with depth of field effect applied.
+ """
+ # Precalculate dialated depth map for each blur radius
+ max_disp = np.max(disp)
+ disp = disp / max_disp
+ focus_disp = focus_disp / max_disp
+ dilated_disp = []
+ for radius in range(max_blur_radius + 1):
+ dilated_disp.append(cv2.dilate(disp, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * radius + 1, 2 * radius + 1)), iterations=1))
+
+ # Determine the blur radius for each pixel based on the depth map
+ blur_radii = np.clip(np.abs(disp - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
+ for radius in range(max_blur_radius + 1):
+ dialted_blur_radii = np.clip(np.abs(dilated_disp[radius] - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
+ mask = (dialted_blur_radii >= radius) & (dialted_blur_radii >= blur_radii) & (dilated_disp[radius] > disp)
+ blur_radii[mask] = dialted_blur_radii[mask]
+ blur_radii = np.clip(blur_radii, 0, max_blur_radius)
+ blur_radii = cv2.blur(blur_radii, (5, 5))
+
+ # Precalculate the blured image for each blur radius
+ unique_radii = np.unique(blur_radii)
+ precomputed = {}
+ for radius in range(max_blur_radius + 1):
+ if radius not in unique_radii:
+ continue
+ precomputed[radius] = disk_blur(img, radius)
+
+ # Composit the blured image for each pixel
+ output = np.zeros_like(img)
+ for r in unique_radii:
+ mask = blur_radii == r
+ output[mask] = precomputed[r][mask]
+
+ return output
+
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/download.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..886edbccc81cc0c3daed4d858f641097bdfceee2
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/download.py
@@ -0,0 +1,55 @@
+from pathlib import Path
+from typing import *
+import requests
+
+from tqdm import tqdm
+
+
+__all__ = ["download_file", "download_bytes"]
+
+
+def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None:
+ # Ensure headers is a dict if not provided
+ headers = headers or {}
+
+ # Initialize local variables
+ file_path = Path(filepath)
+ downloaded_bytes = 0
+
+ # Check if we should resume the download
+ if resume and file_path.exists():
+ downloaded_bytes = file_path.stat().st_size
+ headers['Range'] = f"bytes={downloaded_bytes}-"
+
+ # Make a GET request to fetch the file
+ with requests.get(url, stream=True, headers=headers) as response:
+ response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
+
+ # Calculate the total size to download
+ total_size = downloaded_bytes + int(response.headers.get('content-length', 0))
+
+ # Display a progress bar while downloading
+ with (
+ tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar,
+ open(file_path, 'ab') as file,
+ ):
+ # Set the initial position of the progress bar
+ pbar.update(downloaded_bytes)
+
+ # Write the content to the file in chunks
+ for chunk in response.iter_content(chunk_size=4096):
+ file.write(chunk)
+ pbar.update(len(chunk))
+
+
+def download_bytes(url: str, headers: dict = None) -> bytes:
+ # Ensure headers is a dict if not provided
+ headers = headers or {}
+
+ # Make a GET request to fetch the file
+ with requests.get(url, stream=True, headers=headers) as response:
+ response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
+
+ # Read the content of the response
+ return response.content
+
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/geometry_numpy.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/geometry_numpy.py
new file mode 100644
index 0000000000000000000000000000000000000000..99de45cd34c85f6ebc73637321f1b2be576b75bb
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/geometry_numpy.py
@@ -0,0 +1,261 @@
+from typing import *
+from functools import partial
+import math
+
+import cv2
+import numpy as np
+from scipy.signal import fftconvolve
+import numpy as np
+import utils3d
+
+from .tools import timeit
+
+
+def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
+ if w is None:
+ return np.mean(x, axis=axis)
+ else:
+ w = w.astype(x.dtype)
+ return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None)
+
+
+def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
+ if w is None:
+ return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis)
+ else:
+ w = w.astype(x.dtype)
+ return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
+
+
+def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+ if aspect_ratio is None:
+ aspect_ratio = width / height
+
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+ u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
+ v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
+ u, v = np.meshgrid(u, v, indexing='xy')
+ uv = np.stack([u, v], axis=-1)
+ return uv
+
+
+def focal_to_fov_numpy(focal: np.ndarray):
+ return 2 * np.arctan(0.5 / focal)
+
+
+def fov_to_focal_numpy(fov: np.ndarray):
+ return 0.5 / np.tan(fov / 2)
+
+
+def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
+ fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
+ return fov_x, fov_y
+
+
+def point_map_to_depth_legacy_numpy(points: np.ndarray):
+ height, width = points.shape[-3:-1]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+ uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype) # (H, W, 2)
+ _, uv = np.broadcast_arrays(points[..., :2], uv)
+
+ # Solve least squares problem
+ b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2)
+ A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2)
+
+ M = A.swapaxes(-2, -1) @ A
+ solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
+ focal, shift = solution
+
+ depth = points[..., 2] + shift[..., None, None]
+ fov_x = np.arctan(width / diagonal / focal) * 2
+ fov_y = np.arctan(height / diagonal / focal) * 2
+ return depth, fov_x, fov_y, shift
+
+
+def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
+ from scipy.optimize import least_squares
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+ xy_proj = xy / (z + shift)[: , None]
+ f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+ err = (f * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+ optim_shift = solution['x'].squeeze().astype(np.float32)
+
+ xy_proj = xy / (z + optim_shift)[: , None]
+ optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+
+ return optim_shift, optim_focal
+
+
+def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
+ from scipy.optimize import least_squares
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+ xy_proj = xy / (z + shift)[: , None]
+ err = (focal * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+ optim_shift = solution['x'].squeeze().astype(np.float32)
+
+ return optim_shift
+
+
+def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
+ import cv2
+ assert points.shape[-1] == 3, "Points should (H, W, 3)"
+
+ height, width = points.shape[-3], points.shape[-2]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+
+ uv = normalized_view_plane_uv_numpy(width=width, height=height)
+
+ if mask is None:
+ points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
+ uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
+ else:
+ points_lr, uv_lr, mask_lr = utils3d.np.masked_nearest_resize(points, uv, mask=mask, size=downsample_size)
+
+ if points_lr.size < 2:
+ return 1., 0.
+
+ if focal is None:
+ shift,focal = solve_optimal_focal_shift(uv_lr, points_lr)
+ else:
+ shift = solve_optimal_shift(uv_lr, points_lr, focal)
+
+ return focal, shift
+
+
+def norm3d(x: np.ndarray) -> np.ndarray:
+ "Faster `np.linalg.norm(x, axis=-1)` for 3D vectors"
+ return np.sqrt(np.square(x[..., 0]) + np.square(x[..., 1]) + np.square(x[..., 2]))
+
+
+def depth_occlusion_edge_numpy(depth: np.ndarray, mask: np.ndarray, thickness: int = 1, tol: float = 0.1):
+ disp = np.where(mask, 1 / depth, 0)
+ disp_pad = np.pad(disp, (thickness, thickness), constant_values=0)
+ mask_pad = np.pad(mask, (thickness, thickness), constant_values=False)
+ kernel_size = 2 * thickness + 1
+ disp_window = utils3d.np.sliding_window(disp_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]
+ mask_window = utils3d.np.sliding_window(mask_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]
+
+ disp_mean = weighted_mean_numpy(disp_window, mask_window, axis=(-2, -1))
+ fg_edge_mask = mask & (disp > (1 + tol) * disp_mean)
+ bg_edge_mask = mask & (disp_mean > (1 + tol) * disp)
+
+ edge_mask = (cv2.dilate(fg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0) \
+ & (cv2.dilate(bg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0)
+
+ return edge_mask
+
+
+def disk_kernel(radius: int) -> np.ndarray:
+ """
+ Generate disk kernel with given radius.
+
+ Args:
+ radius (int): Radius of the disk (in pixels).
+
+ Returns:
+ np.ndarray: (2*radius+1, 2*radius+1) normalized convolution kernel.
+ """
+ # Create coordinate grid centered at (0,0)
+ L = np.arange(-radius, radius + 1)
+ X, Y = np.meshgrid(L, L)
+ # Generate disk: region inside circle with radius R is 1
+ kernel = ((X**2 + Y**2) <= radius**2).astype(np.float32)
+ # Normalize the kernel
+ kernel /= np.sum(kernel)
+ return kernel
+
+
+def disk_blur(image: np.ndarray, radius: int) -> np.ndarray:
+ """
+ Apply disk blur to an image using FFT convolution.
+
+ Args:
+ image (np.ndarray): Input image, can be grayscale or color.
+ radius (int): Blur radius (in pixels).
+
+ Returns:
+ np.ndarray: Blurred image.
+ """
+ if radius == 0:
+ return image
+ kernel = disk_kernel(radius)
+ if image.ndim == 2:
+ blurred = fftconvolve(image, kernel, mode='same')
+ elif image.ndim == 3:
+ channels = []
+ for i in range(image.shape[2]):
+ blurred_channel = fftconvolve(image[..., i], kernel, mode='same')
+ channels.append(blurred_channel)
+ blurred = np.stack(channels, axis=-1)
+ else:
+ raise ValueError("Image must be 2D or 3D.")
+ return blurred
+
+
+def depth_of_field(
+ img: np.ndarray,
+ disp: np.ndarray,
+ focus_disp : float,
+ max_blur_radius : int = 10,
+) -> np.ndarray:
+ """
+ Apply depth of field effect to an image.
+
+ Args:
+ img (numpy.ndarray): (H, W, 3) input image.
+ depth (numpy.ndarray): (H, W) depth map of the scene.
+ focus_depth (float): Focus depth of the lens.
+ strength (float): Strength of the depth of field effect.
+ max_blur_radius (int): Maximum blur radius (in pixels).
+
+ Returns:
+ numpy.ndarray: (H, W, 3) output image with depth of field effect applied.
+ """
+ # Precalculate dialated depth map for each blur radius
+ max_disp = np.max(disp)
+ disp = disp / max_disp
+ focus_disp = focus_disp / max_disp
+ dilated_disp = []
+ for radius in range(max_blur_radius + 1):
+ dilated_disp.append(cv2.dilate(disp, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*radius+1, 2*radius+1)), iterations=1))
+
+ # Determine the blur radius for each pixel based on the depth map
+ blur_radii = np.clip(abs(disp - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
+ for radius in range(max_blur_radius + 1):
+ dialted_blur_radii = np.clip(abs(dilated_disp[radius] - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
+ mask = (dialted_blur_radii >= radius) & (dialted_blur_radii >= blur_radii) & (dilated_disp[radius] > disp)
+ blur_radii[mask] = dialted_blur_radii[mask]
+ blur_radii = np.clip(blur_radii, 0, max_blur_radius)
+ blur_radii = cv2.blur(blur_radii, (5, 5))
+
+ # Precalculate the blured image for each blur radius
+ unique_radii = np.unique(blur_radii)
+ precomputed = {}
+ for radius in range(max_blur_radius + 1):
+ if radius not in unique_radii:
+ continue
+ precomputed[radius] = disk_blur(img, radius)
+
+ # Composit the blured image for each pixel
+ output = np.zeros_like(img)
+ for r in unique_radii:
+ mask = blur_radii == r
+ output[mask] = precomputed[r][mask]
+
+ return output
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/geometry_torch.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/geometry_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..20b5632aa8605a7409a0cec97244ebd89865e3a6
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/geometry_torch.py
@@ -0,0 +1,234 @@
+from typing import *
+import math
+from collections import namedtuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.types
+import utils3d
+
+from .tools import timeit
+from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift
+
+
+def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.mean(dim=dim, keepdim=keepdim)
+ else:
+ w = w.to(x.dtype)
+ return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)
+
+
+def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal()
+ else:
+ w = w.to(x.dtype)
+ return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()
+
+
+def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.add(eps).log().mean(dim=dim).exp()
+ else:
+ w = w.to(x.dtype)
+ return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
+
+
+def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+ if aspect_ratio is None:
+ aspect_ratio = width / height
+
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+ u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
+ v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
+ u, v = torch.meshgrid(u, v, indexing='xy')
+ uv = torch.stack([u, v], dim=-1)
+ return uv
+
+
+def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
+ kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
+ kernel = kernel / kernel.sum()
+ kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
+ input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
+ input = F.conv2d(input, kernel, groups=input.shape[1])
+ return input
+
+
+def focal_to_fov(focal: torch.Tensor):
+ return 2 * torch.atan(0.5 / focal)
+
+
+def fov_to_focal(fov: torch.Tensor):
+ return 0.5 / torch.tan(fov / 2)
+
+
+def angle_diff_vec3(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12):
+ return torch.atan2(torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps, (v1 * v2).sum(dim=-1))
+
+def intrinsics_to_fov(intrinsics: torch.Tensor):
+ """
+ Returns field of view in radians from normalized intrinsics matrix.
+ ### Parameters:
+ - intrinsics: torch.Tensor of shape (..., 3, 3)
+
+ ### Returns:
+ - fov_x: torch.Tensor of shape (...)
+ - fov_y: torch.Tensor of shape (...)
+ """
+ focal_x = intrinsics[..., 0, 0]
+ focal_y = intrinsics[..., 1, 1]
+ return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y)
+
+
+def point_map_to_depth_legacy(points: torch.Tensor):
+ height, width = points.shape[-3:-1]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
+
+ # Solve least squares problem
+ b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2)
+ A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2)
+
+ M = A.transpose(-2, -1) @ A
+ solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
+ focal, shift = solution.unbind(-1)
+
+ depth = points[..., 2] + shift[..., None, None]
+ fov_x = torch.atan(width / diagonal / focal) * 2
+ fov_y = torch.atan(height / diagonal / focal) * 2
+ return depth, fov_x, fov_y, shift
+
+
+def view_plane_uv_to_focal(uv: torch.Tensor):
+ normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
+ focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12)
+ return focal
+
+
+def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
+ """
+ Recover the depth map and FoV from a point map with unknown z shift and focal.
+
+ Note that it assumes:
+ - the optical center is at the center of the map
+ - the map is undistorted
+ - the map is isometric in the x and y directions
+
+ ### Parameters:
+ - `points: torch.Tensor` of shape (..., H, W, 3)
+ - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.
+
+ ### Returns:
+ - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
+ - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
+ """
+ shape = points.shape
+ height, width = points.shape[-3], points.shape[-2]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+
+ points = points.reshape(-1, *shape[-3:])
+ mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
+ focal = focal.reshape(-1) if focal is not None else None
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
+
+ points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
+ uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
+ mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0
+
+ uv_lr_np = uv_lr.cpu().numpy()
+ points_lr_np = points_lr.detach().cpu().numpy()
+ focal_np = focal.cpu().numpy() if focal is not None else None
+ mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
+ optim_shift, optim_focal = [], []
+ for i in range(points.shape[0]):
+ points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
+ uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
+ if uv_lr_i_np.shape[0] < 2:
+ optim_focal.append(1)
+ optim_shift.append(0)
+ continue
+ if focal is None:
+ optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
+ optim_focal.append(float(optim_focal_i))
+ else:
+ optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
+ optim_shift.append(float(optim_shift_i))
+ optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+
+ if focal is None:
+ optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+ else:
+ optim_focal = focal.reshape(shape[:-3])
+
+ return optim_focal, optim_shift
+
+
+def theshold_depth_change(depth: torch.Tensor, mask: torch.Tensor, pooler: Literal['min', 'max'], rtol: float = 0.2, kernel_size: int = 3):
+ *batch_shape, height, width = depth.shape
+ depth = depth.reshape(-1, 1, height, width)
+ mask = mask.reshape(-1, 1, height, width)
+ if pooler =='max':
+ pooled_depth = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
+ output_mask = pooled_depth > depth * (1 + rtol)
+ elif pooler =='min':
+ pooled_depth = -F.max_pool2d(-torch.where(mask, depth, torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
+ output_mask = pooled_depth < depth * (1 - rtol)
+ else:
+ raise ValueError(f'Unsupported pooler: {pooler}')
+ output_mask = output_mask.reshape(*batch_shape, height, width)
+ return output_mask
+
+
+def dilate_with_mask(input: torch.Tensor, mask: torch.BoolTensor, filter: Literal['min', 'max', 'mean', 'median'] = 'mean', iterations: int = 1) -> torch.Tensor:
+ kernel = torch.tensor([[False, True, False], [True, True, True], [False, True, False]], device=input.device, dtype=torch.bool)
+ for _ in range(iterations):
+ input_window = utils3d.pt.sliding_window(F.pad(input, (1, 1, 1, 1), mode='constant', value=0), window_size=3, stride=1, dim=(-2, -1))
+ mask_window = kernel & utils3d.pt.sliding_window(F.pad(mask, (1, 1, 1, 1), mode='constant', value=False), window_size=3, stride=1, dim=(-2, -1))
+ if filter =='min':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.inf).min(dim=(-2, -1)).values)
+ elif filter =='max':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, -torch.inf).max(dim=(-2, -1)).values)
+ elif filter == 'mean':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).nanmean(dim=(-2, -1)))
+ elif filter =='median':
+ input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).flatten(-2).nanmedian(dim=-1).values)
+ mask = mask_window.any(dim=(-2, -1))
+ return input, mask
+
+
+def refine_depth_with_normal(depth: torch.Tensor, normal: torch.Tensor, intrinsics: torch.Tensor, iterations: int = 10, damp: float = 1e-3, eps: float = 1e-12, kernel_size: int = 5) -> torch.Tensor:
+ device, dtype = depth.device, depth.dtype
+ height, width = depth.shape[-2:]
+ radius = kernel_size // 2
+
+ duv = torch.stack(torch.meshgrid(torch.linspace(-radius / width, radius / width, kernel_size, device=device, dtype=dtype), torch.linspace(-radius / height, radius / height, kernel_size, device=device, dtype=dtype), indexing='xy'), dim=-1).to(dtype=dtype, device=device)
+
+ log_depth = depth.clamp_min_(eps).log()
+ log_depth_diff = utils3d.pt.sliding_window(log_depth, window_size=kernel_size, stride=1, dim=(-2, -1)) - log_depth[..., radius:-radius, radius:-radius, None, None]
+
+ weight = torch.exp(-(log_depth_diff / duv.norm(dim=-1).clamp_min_(eps) / 10).square())
+ tot_weight = weight.sum(dim=(-2, -1)).clamp_min_(eps)
+
+ uv = utils3d.pt.uv_map((height, width), device=device, dtype=dtype)
+ K_inv = torch.inverse(intrinsics)
+
+ grad = -(normal[..., None, :2] @ K_inv[..., None, None, :2, :2]).squeeze(-2) \
+ / (normal[..., None, 2:] + normal[..., None, :2] @ (K_inv[..., None, None, :2, :2] @ uv[..., :, None] + K_inv[..., None, None, :2, 2:])).squeeze(-2)
+ laplacian = (weight * ((utils3d.pt.sliding_window(grad, window_size=kernel_size, stride=1, dim=(-3, -2)) + grad[..., radius:-radius, radius:-radius, :, None, None]) * (duv.permute(2, 0, 1) / 2)).sum(dim=-3)).sum(dim=(-2, -1))
+
+ laplacian = laplacian.clamp(-0.1, 0.1)
+ log_depth_refine = log_depth.clone()
+
+ for _ in range(iterations):
+ log_depth_refine[..., radius:-radius, radius:-radius] = 0.1 * log_depth_refine[..., radius:-radius, radius:-radius] + 0.9 * (damp * log_depth[..., radius:-radius, radius:-radius] - laplacian + (weight * utils3d.pt.sliding_window_2d(log_depth_refine, window_size=kernel_size, stride=1, dim=(-2, -1))).sum(dim=(-2, -1))) / (tot_weight + damp)
+
+ depth_refine = log_depth_refine.exp()
+
+ return depth_refine
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/io.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..47b16413e6635b6aa0949ba588dbf51c5d4be3e4
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/io.py
@@ -0,0 +1,271 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+from typing import IO
+import zipfile
+import json
+import io
+from typing import *
+from pathlib import Path
+import re
+from PIL import Image, PngImagePlugin
+
+import numpy as np
+import cv2
+
+from .tools import timeit
+
+
+def save_glb(
+ save_path: Union[str, os.PathLike],
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ vertex_uvs: np.ndarray,
+ texture: np.ndarray,
+ vertex_normals: Optional[np.ndarray] = None,
+):
+ import trimesh
+ import trimesh.visual
+ from PIL import Image
+
+ trimesh.Trimesh(
+ vertices=vertices,
+ vertex_normals=vertex_normals,
+ faces=faces,
+ visual = trimesh.visual.texture.TextureVisuals(
+ uv=vertex_uvs,
+ material=trimesh.visual.material.PBRMaterial(
+ baseColorTexture=Image.fromarray(texture),
+ metallicFactor=0.5,
+ roughnessFactor=1.0
+ )
+ ),
+ process=False
+ ).export(save_path)
+
+
+def save_ply(
+ save_path: Union[str, os.PathLike],
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ vertex_colors: np.ndarray,
+ vertex_normals: Optional[np.ndarray] = None,
+):
+ import trimesh
+ import trimesh.visual
+ from PIL import Image
+
+ trimesh.Trimesh(
+ vertices=vertices,
+ faces=faces,
+ vertex_colors=vertex_colors,
+ vertex_normals=vertex_normals,
+ process=False
+ ).export(save_path)
+
+
+def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray:
+ """
+ Read a image, return uint8 RGB array of shape (H, W, 3).
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ image = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
+ return image
+
+
+def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95):
+ """
+ Write a image, input uint8 RGB array of shape (H, W, 3).
+ """
+ data = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes()
+ if isinstance(path, (str, os.PathLike)):
+ Path(path).write_bytes(data)
+ else:
+ path.write(data)
+
+
+def read_depth(path: Union[str, os.PathLike, IO]) -> np.ndarray:
+ """
+ Read a depth image, return float32 depth array of shape (H, W).
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ pil_image = Image.open(io.BytesIO(data))
+ near = float(pil_image.info.get('near'))
+ far = float(pil_image.info.get('far'))
+ depth = np.array(pil_image)
+ mask_nan, mask_inf = depth == 0, depth == 65535
+ depth = (depth.astype(np.float32) - 1) / 65533
+ depth = near ** (1 - depth) * far ** depth
+ if 'unit' in pil_image.info: # Legacy support for depth units
+ unit = float(pil_image.info.get('unit'))
+ depth = depth * unit
+ depth[mask_nan] = np.nan
+ depth[mask_inf] = np.inf
+ return depth
+
+
+def write_depth(
+ path: Union[str, os.PathLike, IO],
+ depth: np.ndarray,
+ max_range: float = 1e5,
+ compression_level: int = 7,
+):
+ """
+ Encode and write a depth image as 16-bit PNG format.
+ ## Parameters:
+ - `path: Union[str, os.PathLike, IO]`
+ The file path or file object to write to.
+ - `depth: np.ndarray`
+ The depth array, float32 array of shape (H, W).
+ May contain `NaN` for invalid values and `Inf` for infinite values.
+
+ Depth values are encoded as follows:
+ - 0: unknown
+ - 1 ~ 65534: depth values in logarithmic
+ - 65535: infinity
+
+ metadata is stored in the PNG file as text fields:
+ - `near`: the minimum depth value
+ - `far`: the maximum depth value
+ """
+ mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth),np.isinf(depth)
+
+ depth = depth.astype(np.float32)
+ mask_finite = depth
+ near = max(depth[mask_values].min(), 1e-5)
+ far = max(near * 1.1, min(depth[mask_values].max(), near * max_range))
+ depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16) # 1~65534
+ depth[mask_nan] = 0
+ depth[mask_inf] = 65535
+
+ pil_image = Image.fromarray(depth)
+ pnginfo = PngImagePlugin.PngInfo()
+ pnginfo.add_text('near', str(near))
+ pnginfo.add_text('far', str(far))
+ pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
+
+
+def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]:
+ """
+ Read a segmentation mask
+ ### Parameters:
+ - `path: Union[str, os.PathLike, IO]`
+ The file path or file object to read from.
+ ### Returns:
+ - `Tuple[np.ndarray, Dict[str, int]]`
+ A tuple containing:
+ - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W).
+ - `labels`: Dict[str, int]. The label mapping, a dictionary of {label_name: label_id}.
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ pil_image = Image.open(io.BytesIO(data))
+ labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None
+ mask = np.array(pil_image)
+ return mask, labels
+
+
+def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7):
+ """
+ Write a segmentation mask and label mapping, as PNG format.
+ ### Parameters:
+ - `path: Union[str, os.PathLike, IO]`
+ The file path or file object to write to.
+ - `mask: np.ndarray`
+ The segmentation mask, uint8 or uint16 array of shape (H, W).
+ - `labels: Dict[str, int] = None`
+ The label mapping, a dictionary of {label_name: label_id}.
+ - `compression_level: int = 7`
+ The compression level for PNG compression.
+ """
+ assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}"
+ pil_image = Image.fromarray(mask)
+ pnginfo = PngImagePlugin.PngInfo()
+ if labels is not None:
+ labels_json = json.dumps(labels, ensure_ascii=True, separators=(',', ':'))
+ pnginfo.add_text('labels', labels_json)
+ pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
+
+
+
+def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray:
+ """
+ Read a normal image, return float32 normal array of shape (H, W, 3).
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
+ mask_nan = np.all(normal == 0, axis=-1)
+ normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
+ normal = normal / (np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2])) + 1e-12)
+ normal[mask_nan] = np.nan
+ return normal
+
+
+def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> np.ndarray:
+ """
+ Write a normal image, input float32 normal array of shape (H, W, 3).
+ """
+ mask_nan = np.isnan(normal).any(axis=-1)
+ normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
+ normal[mask_nan] = 0
+ data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
+ if isinstance(path, (str, os.PathLike)):
+ Path(path).write_bytes(data)
+ else:
+ path.write(data)
+
+
+def read_mask(path: Union[str, os.PathLike, IO[bytes]]) -> np.ndarray:
+ """
+ Read a binary mask, return bool array of shape (H, W).
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ mask = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED)
+ if len(mask.shape) == 3:
+ mask = mask[..., 0]
+ return mask > 0
+
+
+def write_mask(path: Union[str, os.PathLike, IO[bytes]], mask: np.ndarray, compression_level: int = 7):
+ """
+ Write a binary mask, input bool array of shape (H, W).
+ """
+ assert mask.dtype == bool, f"Mask must be bool array, got {mask.dtype}"
+ mask = (mask.astype(np.uint8) * 255).astype(np.uint8)
+ data = cv2.imencode('.png', mask, [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
+ if isinstance(path, (str, os.PathLike)):
+ Path(path).write_bytes(data)
+ else:
+ path.write(data)
+
+
+JSON_TYPE = Union[str, int, float, bool, None, Dict[str, "JSON"], List["JSON"]]
+
+
+def read_json(path: Union[str, os.PathLike, IO[str]]) -> JSON_TYPE:
+ if isinstance(path, (str, os.PathLike)):
+ text = Path(path).read_text()
+ else:
+ text = path.read()
+ return json.loads(text)
+
+
+def write_json(path: Union[str, os.PathLike, IO[str]], content: JSON_TYPE):
+ text = json.dumps(content)
+ if isinstance(path, (str, os.PathLike)):
+ Path(path).write_text(text)
+ else:
+ path.write(text)
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/panorama.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/panorama.py
new file mode 100644
index 0000000000000000000000000000000000000000..42d915ad324424bf4faf1a47b55d596753d9d3c6
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/panorama.py
@@ -0,0 +1,191 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+from pathlib import Path
+from typing import *
+import itertools
+import json
+import warnings
+
+import cv2
+import numpy as np
+from numpy import ndarray
+from tqdm import tqdm, trange
+from scipy.sparse import csr_array, hstack, vstack
+from scipy.ndimage import convolve
+from scipy.sparse.linalg import lsmr
+
+import utils3d
+
+
+def get_panorama_cameras():
+ vertices, _ = utils3d.np.create_icosahedron_mesh()
+ intrinsics = utils3d.np.intrinsics_from_fov(fov_x=np.deg2rad(90), fov_y=np.deg2rad(90))
+ extrinsics = utils3d.np.extrinsics_look_at([0, 0, 0], vertices, [0, 0, 1]).astype(np.float32)
+ return extrinsics, [intrinsics] * len(vertices)
+
+
+def spherical_uv_to_directions(uv: np.ndarray):
+ theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi
+ directions = np.stack([np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)], axis=-1)
+ return directions
+
+
+def directions_to_spherical_uv(directions: np.ndarray):
+ directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
+ u = 1 - np.arctan2(directions[..., 1], directions[..., 0]) / (2 * np.pi) % 1.0
+ v = np.arccos(directions[..., 2]) / np.pi
+ return np.stack([u, v], axis=-1)
+
+
+def split_panorama_image(image: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray, resolution: int):
+ height, width = image.shape[:2]
+ uv = utils3d.np.uv_map((resolution, resolution))
+ splitted_images = []
+ for i in range(len(extrinsics)):
+ spherical_uv = directions_to_spherical_uv(utils3d.np.unproject_cv(uv, np.ones_like(uv[..., 0]), extrinsics=extrinsics[i], intrinsics=intrinsics[i]))
+ pixels = utils3d.np.uv_to_pixel(spherical_uv, (height, width)).astype(np.float32)
+
+ splitted_image = cv2.remap(image, pixels[..., 0], pixels[..., 1], interpolation=cv2.INTER_LINEAR)
+ splitted_images.append(splitted_image)
+ return splitted_images
+
+
+def poisson_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> Tuple[csr_array, ndarray]:
+ grid_index = np.arange(height * width).reshape(height, width)
+ grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode='wrap' if wrap_x else 'edge')
+ grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode='wrap' if wrap_y else 'edge')
+
+ data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(height * width, axis=0).reshape(-1)
+ indices = np.stack([
+ grid_index[1:-1, 1:-1],
+ grid_index[:-2, 1:-1], # up
+ grid_index[2:, 1:-1], # down
+ grid_index[1:-1, :-2], # left
+ grid_index[1:-1, 2:] # right
+ ], axis=-1).reshape(-1)
+ indptr = np.arange(0, height * width * 5 + 1, 5)
+ A = csr_array((data, indices, indptr), shape=(height * width, height * width))
+
+ return A
+
+
+def grad_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> Tuple[csr_array, np.ndarray]:
+ grid_index = np.arange(width * height).reshape(height, width)
+ if wrap_x:
+ grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode='wrap')
+ if wrap_y:
+ grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode='wrap')
+
+ data = np.concatenate([
+ np.concatenate([
+ np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), # x[i,j]
+ -np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1), # x[i,j-1]
+ ], axis=1).reshape(-1),
+ np.concatenate([
+ np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), # x[i,j]
+ -np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1), # x[i-1,j]
+ ], axis=1).reshape(-1),
+ ])
+ indices = np.concatenate([
+ np.concatenate([
+ grid_index[:, :-1].reshape(-1, 1),
+ grid_index[:, 1:].reshape(-1, 1),
+ ], axis=1).reshape(-1),
+ np.concatenate([
+ grid_index[:-1, :].reshape(-1, 1),
+ grid_index[1:, :].reshape(-1, 1),
+ ], axis=1).reshape(-1),
+ ])
+ indptr = np.arange(0, grid_index.shape[0] * (grid_index.shape[1] - 1) * 2 + (grid_index.shape[0] - 1) * grid_index.shape[1] * 2 + 1, 2)
+ A = csr_array((data, indices, indptr), shape=(grid_index.shape[0] * (grid_index.shape[1] - 1) + (grid_index.shape[0] - 1) * grid_index.shape[1], height * width))
+
+ return A
+
+
+def merge_panorama_depth(width: int, height: int, distance_maps: List[np.ndarray], pred_masks: List[np.ndarray], extrinsics: List[np.ndarray], intrinsics: List[np.ndarray]):
+ if max(width, height) > 256:
+ panorama_depth_init, _ = merge_panorama_depth(width // 2, height // 2, distance_maps, pred_masks, extrinsics, intrinsics)
+ panorama_depth_init = cv2.resize(panorama_depth_init, (width, height), cv2.INTER_LINEAR)
+ else:
+ panorama_depth_init = None
+
+ uv = utils3d.np.uv_map(height, width)
+ spherical_directions = spherical_uv_to_directions(uv)
+
+ # Warp each view to the panorama
+ panorama_log_distance_grad_maps, panorama_grad_masks = [], []
+ panorama_log_distance_laplacian_maps, panorama_laplacian_masks = [], []
+ panorama_pred_masks = []
+ for i in range(len(distance_maps)):
+ projected_uv, projected_depth = utils3d.np.project_cv(spherical_directions, extrinsics=extrinsics[i], intrinsics=intrinsics[i])
+ projection_valid_mask = (projected_depth > 0) & (projected_uv > 0).all(axis=-1) & (projected_uv < 1).all(axis=-1)
+
+ projected_pixels = utils3d.np.uv_to_pixel(np.clip(projected_uv, 0, 1), distance_maps[i].shape).astype(np.float32)
+
+ log_splitted_distance = np.log(distance_maps[i])
+ panorama_log_distance_map = np.where(projection_valid_mask, cv2.remap(log_splitted_distance, projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE), 0)
+ panorama_pred_mask = projection_valid_mask & (cv2.remap(pred_masks[i].astype(np.uint8), projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE) > 0)
+
+ # calculate gradient map
+ padded = np.pad(panorama_log_distance_map, ((0, 0), (0, 1)), mode='wrap')
+ grad_x, grad_y = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :]
+
+ padded = np.pad(panorama_pred_mask, ((0, 0), (0, 1)), mode='wrap')
+ mask_x, mask_y = padded[:, :-1] & padded[:, 1:], padded[:-1, :] & padded[1:, :]
+
+ panorama_log_distance_grad_maps.append((grad_x, grad_y))
+ panorama_grad_masks.append((mask_x, mask_y))
+
+ # calculate laplacian map
+ padded = np.pad(panorama_log_distance_map, ((1, 1), (0, 0)), mode='edge')
+ padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap')
+ laplacian = convolve(padded, np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32))[1:-1, 1:-1]
+
+ padded = np.pad(panorama_pred_mask, ((1, 1), (0, 0)), mode='edge')
+ padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap')
+ mask = convolve(padded.astype(np.uint8), np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8))[1:-1, 1:-1] == 5
+
+ panorama_log_distance_laplacian_maps.append(laplacian)
+ panorama_laplacian_masks.append(mask)
+
+ panorama_pred_masks.append(panorama_pred_mask)
+
+ panorama_log_distance_grad_x = np.stack([grad_map[0] for grad_map in panorama_log_distance_grad_maps], axis=0)
+ panorama_log_distance_grad_y = np.stack([grad_map[1] for grad_map in panorama_log_distance_grad_maps], axis=0)
+ panorama_grad_mask_x = np.stack([mask_map[0] for mask_map in panorama_grad_masks], axis=0)
+ panorama_grad_mask_y = np.stack([mask_map[1] for mask_map in panorama_grad_masks], axis=0)
+
+ panorama_log_distance_grad_x = np.sum(panorama_log_distance_grad_x * panorama_grad_mask_x, axis=0) / np.sum(panorama_grad_mask_x, axis=0).clip(1e-3)
+ panorama_log_distance_grad_y = np.sum(panorama_log_distance_grad_y * panorama_grad_mask_y, axis=0) / np.sum(panorama_grad_mask_y, axis=0).clip(1e-3)
+
+ panorama_laplacian_maps = np.stack(panorama_log_distance_laplacian_maps, axis=0)
+ panorama_laplacian_masks = np.stack(panorama_laplacian_masks, axis=0)
+ panorama_laplacian_map = np.sum(panorama_laplacian_maps * panorama_laplacian_masks, axis=0) / np.sum(panorama_laplacian_masks, axis=0).clip(1e-3)
+
+ grad_x_mask = np.any(panorama_grad_mask_x, axis=0).reshape(-1)
+ grad_y_mask = np.any(panorama_grad_mask_y, axis=0).reshape(-1)
+ grad_mask = np.concatenate([grad_x_mask, grad_y_mask])
+ laplacian_mask = np.any(panorama_laplacian_masks, axis=0).reshape(-1)
+
+ # Solve overdetermined system
+ A = vstack([
+ grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask],
+ poisson_equation(width, height, wrap_x=True, wrap_y=False)[laplacian_mask],
+ ])
+ b = np.concatenate([
+ panorama_log_distance_grad_x.reshape(-1)[grad_x_mask],
+ panorama_log_distance_grad_y.reshape(-1)[grad_y_mask],
+ panorama_laplacian_map.reshape(-1)[laplacian_mask]
+ ])
+ x, *_ = lsmr(
+ A, b,
+ atol=1e-5, btol=1e-5,
+ x0=np.log(panorama_depth_init).reshape(-1) if panorama_depth_init is not None else None,
+ show=False,
+ )
+
+ panorama_depth = np.exp(x).reshape(height, width).astype(np.float32)
+ panorama_mask = np.any(panorama_pred_masks, axis=0)
+
+ return panorama_depth, panorama_mask
+
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/tools.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..3687f6938fe34433d149a1a8405be7eed5f23c37
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/tools.py
@@ -0,0 +1,289 @@
+from typing import *
+import time
+from pathlib import Path
+from numbers import Number
+from functools import wraps
+import warnings
+import math
+import json
+import os
+import importlib
+import importlib.util
+
+
+def catch_exception(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ import traceback
+ print(f"Exception in {fn.__name__}", end='r')
+ # print({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})
+ traceback.print_exc(chain=False)
+ time.sleep(0.1)
+ return None
+ return wrapper
+
+
+class CallbackOnException:
+ def __init__(self, callback: Callable, exception: type):
+ self.exception = exception
+ self.callback = callback
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if isinstance(exc_val, self.exception):
+ self.callback()
+ return True
+ return False
+
+def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
+ for k, v in d.items():
+ if isinstance(v, dict):
+ for sub_key in traverse_nested_dict_keys(v):
+ yield (k, ) + sub_key
+ else:
+ yield (k, )
+
+
+def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
+ for k in keys:
+ d = d.get(k, default)
+ if d is None:
+ break
+ return d
+
+def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
+ for k in keys[:-1]:
+ d = d.setdefault(k, {})
+ d[keys[-1]] = value
+
+
+def key_average(list_of_dicts: list) -> Dict[str, Any]:
+ """
+ Returns a dictionary with the average value of each key in the input list of dictionaries.
+ """
+ _nested_dict_keys = set()
+ for d in list_of_dicts:
+ _nested_dict_keys.update(traverse_nested_dict_keys(d))
+ _nested_dict_keys = sorted(_nested_dict_keys)
+ result = {}
+ for k in _nested_dict_keys:
+ values = []
+ for d in list_of_dicts:
+ v = get_nested_dict(d, k)
+ if v is not None and not math.isnan(v):
+ values.append(v)
+ avg = sum(values) / len(values) if values else float('nan')
+ set_nested_dict(result, k, avg)
+ return result
+
+
+def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
+ """
+ Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
+ """
+ items = []
+ if parent_key is None:
+ parent_key = ()
+ for k, v in d.items():
+ new_key = parent_key + (k, )
+ if isinstance(v, MutableMapping):
+ items.extend(flatten_nested_dict(v, new_key).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
+
+
+def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Unflattens a single-level dictionary into a nested dictionary, with keys as tuples.
+ """
+ result = {}
+ for k, v in d.items():
+ sub_dict = result
+ for k_ in k[:-1]:
+ if k_ not in sub_dict:
+ sub_dict[k_] = {}
+ sub_dict = sub_dict[k_]
+ sub_dict[k[-1]] = v
+ return result
+
+
+def read_jsonl(file):
+ import json
+ with open(file, 'r') as f:
+ data = f.readlines()
+ return [json.loads(line) for line in data]
+
+
+def write_jsonl(data: List[dict], file):
+ import json
+ with open(file, 'w') as f:
+ for item in data:
+ f.write(json.dumps(item) + '\n')
+
+
+def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
+ import pandas as pd
+ data = [flatten_nested_dict(d) for d in data]
+ df = pd.DataFrame(data)
+ df = df.sort_index(axis=1)
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
+ return df
+
+
+def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
+ if isinstance(d, str):
+ for old, new in mapping.items():
+ d = d.replace(old, new)
+ elif isinstance(d, list):
+ for i, item in enumerate(d):
+ d[i] = recursive_replace(item, mapping)
+ elif isinstance(d, dict):
+ for k, v in d.items():
+ d[k] = recursive_replace(v, mapping)
+ return d
+
+
+class timeit:
+ _history: Dict[str, List['timeit']] = {}
+
+ def __init__(self, name: str = None, verbose: bool = True, average: bool = False):
+ self.name = name
+ self.verbose = verbose
+ self.start = None
+ self.end = None
+ self.average = average
+ if average and name not in timeit._history:
+ timeit._history[name] = []
+
+ def __call__(self, func: Callable):
+ import inspect
+ if inspect.iscoroutinefunction(func):
+ async def wrapper(*args, **kwargs):
+ with timeit(self.name or func.__qualname__):
+ ret = await func(*args, **kwargs)
+ return ret
+ return wrapper
+ else:
+ def wrapper(*args, **kwargs):
+ with timeit(self.name or func.__qualname__):
+ ret = func(*args, **kwargs)
+ return ret
+ return wrapper
+
+ def __enter__(self):
+ self.start = time.time()
+ return self
+
+ @property
+ def time(self) -> float:
+ assert self.start is not None, "Time not yet started."
+ assert self.end is not None, "Time not yet ended."
+ return self.end - self.start
+
+ @property
+ def average_time(self) -> float:
+ assert self.average, "Average time not available."
+ return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])
+
+ @property
+ def history(self) -> List['timeit']:
+ return timeit._history.get(self.name, [])
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.end = time.time()
+ if self.average:
+ timeit._history[self.name].append(self)
+ if self.verbose:
+ if self.average:
+ avg = self.average_time
+ print(f"{self.name or 'It'} took {avg:.6f} seconds in average.")
+ else:
+ print(f"{self.name or 'It'} took {self.time:.6f} seconds.")
+
+
+def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
+ first = strings[0]
+
+ for start in range(len(first)):
+ if any(s[start] != strings[0][start] for s in strings):
+ break
+
+ for end in range(1, min(len(s) for s in strings)):
+ if any(s[-end] != first[-end] for s in strings):
+ break
+
+ return [s[start:len(s) - end + 1] for s in strings]
+
+
+def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
+ from concurrent.futures import ThreadPoolExecutor
+ from contextlib import nullcontext
+ from tqdm import tqdm
+
+ if pbar is not None:
+ pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
+ else:
+ pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)
+
+ def decorator(fn: Callable):
+ with (
+ ThreadPoolExecutor(max_workers=num_workers) as executor,
+ pbar
+ ):
+ pbar.refresh()
+ @catch_exception
+ @suppress_traceback
+ def _fn(input):
+ ret = fn(input)
+ pbar.update()
+ return ret
+ executor.map(_fn, inputs)
+ executor.shutdown(wait=True)
+
+ return decorator
+
+
+def suppress_traceback(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ e.__traceback__ = e.__traceback__.tb_next.tb_next
+ raise
+ return wrapper
+
+
+class no_warnings:
+ def __init__(self, action: str = 'ignore', **kwargs):
+ self.action = action
+ self.filter_kwargs = kwargs
+
+ def __call__(self, fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ with warnings.catch_warnings():
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+ return fn(*args, **kwargs)
+ return wrapper
+
+ def __enter__(self):
+ self.warnings_manager = warnings.catch_warnings()
+ self.warnings_manager.__enter__()
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
+
+
+def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str):
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ return module
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/vis.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb9c2378b58ec26ac5067b7ffcbd749a8ad968ce
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/vis.py
@@ -0,0 +1,65 @@
+from typing import *
+
+import numpy as np
+import matplotlib
+
+
+def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
+ if mask is None:
+ depth = np.where(depth > 0, depth, np.nan)
+ else:
+ depth = np.where((depth > 0) & mask, depth, np.nan)
+ disp = 1 / depth
+ if normalize:
+ min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99)
+ disp = (disp - min_disp) / (max_disp - min_disp)
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], 0)
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+ return colored
+
+
+def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray:
+ if mask is not None:
+ depth = np.where(mask, depth, np.nan)
+
+ min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999)
+ depth = (depth - min_depth) / (max_depth - min_depth)
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](depth)[..., :3], 0)
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+ return colored
+
+
+def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
+ if mask is not None:
+ disparity = np.where(mask, disparity, np.nan)
+
+ if normalize:
+ min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999)
+ disparity = (disparity - min_disp) / (max_disp - min_disp)
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity)[..., :3], 0)
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+ return colored
+
+
+def colorize_segmentation(segmentation: np.ndarray, cmap: str = 'Set1') -> np.ndarray:
+ colored = matplotlib.colormaps[cmap]((segmentation % 20) / 20)[..., :3]
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+ return colored
+
+
+def colorize_normal(normal: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
+ if mask is not None:
+ normal = np.where(mask[..., None], normal, 0)
+ normal = normal * [0.5, -0.5, -0.5] + 0.5
+ normal = (normal.clip(0, 1) * 255).astype(np.uint8)
+ return normal
+
+
+def colorize_error_map(error_map: np.ndarray, mask: np.ndarray = None, cmap: str = 'plasma', value_range: Tuple[float, float] = None):
+ vmin, vmax = value_range if value_range is not None else (np.nanmin(error_map), np.nanmax(error_map))
+ cmap = matplotlib.colormaps[cmap]
+ colorized_error_map = cmap(((error_map - vmin) / (vmax - vmin)).clip(0, 1))[..., :3]
+ if mask is not None:
+ colorized_error_map = np.where(mask[..., None], colorized_error_map, 0)
+ colorized_error_map = np.ascontiguousarray((colorized_error_map.clip(0, 1) * 255).astype(np.uint8))
+ return colorized_error_map
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/webfile.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/webfile.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e98abf8413e1c9f408849b74f4d2025d25511b6
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/webfile.py
@@ -0,0 +1,73 @@
+import requests
+from typing import *
+
+__all__ = ["WebFile"]
+
+
+class WebFile:
+ def __init__(self, url: str, session: Optional[requests.Session] = None, headers: Optional[Dict[str, str]] = None, size: Optional[int] = None):
+ self.url = url
+ self.session = session or requests.Session()
+ self.session.headers.update(headers or {})
+ self._offset = 0
+ self.size = size if size is not None else self._fetch_size()
+
+ def _fetch_size(self):
+ with self.session.get(self.url, stream=True) as response:
+ response.raise_for_status()
+ content_length = response.headers.get("Content-Length")
+ if content_length is None:
+ raise ValueError("Missing Content-Length in header")
+ return int(content_length)
+
+ def _fetch_data(self, offset: int, n: int) -> bytes:
+ headers = {"Range": f"bytes={offset}-{min(offset + n - 1, self.size)}"}
+ response = self.session.get(self.url, headers=headers)
+ response.raise_for_status()
+ return response.content
+
+ def seekable(self) -> bool:
+ return True
+
+ def tell(self) -> int:
+ return self._offset
+
+ def available(self) -> int:
+ return self.size - self._offset
+
+ def seek(self, offset: int, whence: int = 0) -> None:
+ if whence == 0:
+ new_offset = offset
+ elif whence == 1:
+ new_offset = self._offset + offset
+ elif whence == 2:
+ new_offset = self.size + offset
+ else:
+ raise ValueError("Invalid value for whence")
+
+ self._offset = max(0, min(new_offset, self.size))
+
+ def read(self, n: Optional[int] = None) -> bytes:
+ if n is None or n < 0:
+ n = self.available()
+ else:
+ n = min(n, self.available())
+
+ if n == 0:
+ return b''
+
+ data = self._fetch_data(self._offset, n)
+ self._offset += len(data)
+
+ return data
+
+ def close(self) -> None:
+ pass
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
+
+
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/moge/utils/webzipfile.py b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/webzipfile.py
new file mode 100644
index 0000000000000000000000000000000000000000..25ed1d3cd34720335eb001d77a278539ffef569b
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/moge/utils/webzipfile.py
@@ -0,0 +1,128 @@
+from typing import *
+import io
+import os
+from zipfile import (
+ ZipInfo, BadZipFile, ZipFile, ZipExtFile,
+ sizeFileHeader, structFileHeader, stringFileHeader,
+ _FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS,
+ _MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED
+)
+import struct
+from requests import Session
+
+from .webfile import WebFile
+
+
+class _SharedWebFile(WebFile):
+ def __init__(self, webfile: WebFile, pos: int):
+ super().__init__(webfile.url, webfile.session, size=webfile.size)
+ self.seek(pos)
+
+
+class WebZipFile(ZipFile):
+ "Lock-free version of ZipFile that reads from a WebFile, allowing for concurrent reads."
+ def __init__(self, url: str, session: Optional[Session] = None, headers: Optional[Dict[str, str]] = None):
+ """Open the ZIP file with mode read 'r', write 'w', exclusive create 'x',
+ or append 'a'."""
+ webf = WebFile(url, session=session, headers=headers)
+ super().__init__(webf, mode='r')
+
+ def open(self, name, mode="r", pwd=None, *, force_zip64=False):
+ """Return file-like object for 'name'.
+
+ name is a string for the file name within the ZIP file, or a ZipInfo
+ object.
+
+ mode should be 'r' to read a file already in the ZIP file, or 'w' to
+ write to a file newly added to the archive.
+
+ pwd is the password to decrypt files (only used for reading).
+
+ When writing, if the file size is not known in advance but may exceed
+ 2 GiB, pass force_zip64 to use the ZIP64 format, which can handle large
+ files. If the size is known in advance, it is best to pass a ZipInfo
+ instance for name, with zinfo.file_size set.
+ """
+ if mode not in {"r", "w"}:
+ raise ValueError('open() requires mode "r" or "w"')
+ if pwd and (mode == "w"):
+ raise ValueError("pwd is only supported for reading files")
+ if not self.fp:
+ raise ValueError(
+ "Attempt to use ZIP archive that was already closed")
+
+ assert mode == "r", "Only read mode is supported for now"
+
+ # Make sure we have an info object
+ if isinstance(name, ZipInfo):
+ # 'name' is already an info object
+ zinfo = name
+ elif mode == 'w':
+ zinfo = ZipInfo(name)
+ zinfo.compress_type = self.compression
+ zinfo._compresslevel = self.compresslevel
+ else:
+ # Get info object for name
+ zinfo = self.getinfo(name)
+
+ if mode == 'w':
+ return self._open_to_write(zinfo, force_zip64=force_zip64)
+
+ if self._writing:
+ raise ValueError("Can't read from the ZIP file while there "
+ "is an open writing handle on it. "
+ "Close the writing handle before trying to read.")
+
+ # Open for reading:
+ self._fileRefCnt += 1
+ zef_file = _SharedWebFile(self.fp, zinfo.header_offset)
+
+ try:
+ # Skip the file header:
+ fheader = zef_file.read(sizeFileHeader)
+ if len(fheader) != sizeFileHeader:
+ raise BadZipFile("Truncated file header")
+ fheader = struct.unpack(structFileHeader, fheader)
+ if fheader[_FH_SIGNATURE] != stringFileHeader:
+ raise BadZipFile("Bad magic number for file header")
+
+ fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
+ if fheader[_FH_EXTRA_FIELD_LENGTH]:
+ zef_file.seek(fheader[_FH_EXTRA_FIELD_LENGTH], whence=1)
+
+ if zinfo.flag_bits & _MASK_COMPRESSED_PATCH:
+ # Zip 2.7: compressed patched data
+ raise NotImplementedError("compressed patched data (flag bit 5)")
+
+ if zinfo.flag_bits & _MASK_STRONG_ENCRYPTION:
+ # strong encryption
+ raise NotImplementedError("strong encryption (flag bit 6)")
+
+ if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & _MASK_UTF_FILENAME:
+ # UTF-8 filename
+ fname_str = fname.decode("utf-8")
+ else:
+ fname_str = fname.decode(self.metadata_encoding or "cp437")
+
+ if fname_str != zinfo.orig_filename:
+ raise BadZipFile(
+ 'File name in directory %r and header %r differ.'
+ % (zinfo.orig_filename, fname))
+
+ # check for encrypted flag & handle password
+ is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED
+ if is_encrypted:
+ if not pwd:
+ pwd = self.pwd
+ if pwd and not isinstance(pwd, bytes):
+ raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__)
+ if not pwd:
+ raise RuntimeError("File %r is encrypted, password "
+ "required for extraction" % name)
+ else:
+ pwd = None
+
+ return ZipExtFile(zef_file, mode, zinfo, pwd, True)
+ except:
+ zef_file.close()
+ raise
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/pyproject.toml b/lingbotvla/models/vla/vision_models/MoGe/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..27a761307a2da7cfc4a48e79f7a5929ff7d6cb6b
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/pyproject.toml
@@ -0,0 +1,36 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "moge"
+version = "2.0.0"
+description = "MoGe: Unlocking Accurate Monocular Geometry Estimation for Open-Domain Images with Optimal Training Supervision"
+readme = "README.md"
+license = {text = "MIT"}
+dependencies = [
+ "click",
+ "opencv-python",
+ "scipy",
+ "matplotlib",
+ "trimesh",
+ "pillow",
+ "huggingface_hub",
+ "numpy",
+ "torch>=2.0.0",
+ "torchvision",
+ "gradio",
+ "utils3d @ git+https://github.com/EasternJournalist/utils3d.git@3fab839f0be9931dac7c8488eb0e1600c236e183",
+ "pipeline @ git+https://github.com/EasternJournalist/pipeline.git@866f059d2a05cde05e4a52211ec5051fd5f276d6"
+]
+requires-python = ">=3.9"
+
+[project.urls]
+Homepage = "https://github.com/microsoft/MoGe"
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["moge*"]
+
+[project.scripts]
+moge = "moge.scripts.cli:main"
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/pyrightconfig.json b/lingbotvla/models/vla/vision_models/MoGe/pyrightconfig.json
new file mode 100644
index 0000000000000000000000000000000000000000..deb3aa62afbda00a7c7413b9eefa6f0ec18fb72b
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/pyrightconfig.json
@@ -0,0 +1,10 @@
+{
+ "include": [
+ "moge",
+ "scripts",
+ "baselines"
+ ],
+ "ignore": [
+ "**"
+ ]
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/MoGe/requirements.txt b/lingbotvla/models/vla/vision_models/MoGe/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4169725b1272e7c470bf3bc1c7b70b25723ac7d9
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/MoGe/requirements.txt
@@ -0,0 +1,14 @@
+# The versions are not specified since MoGe should be compatible with most versions of the packages.
+# If incompatibilities are found, consider upgrading to latest versions or installing the following recommended version of the package.
+torch # >= 2.0.0
+torchvision
+gradio # ==2.8.13
+click # ==8.1.7
+opencv-python # ==4.10.0.84
+scipy # ==1.14.1
+matplotlib # ==3.9.2
+trimesh # ==4.5.1
+pillow # ==10.4.0
+huggingface_hub # ==0.25.2
+git+https://github.com/EasternJournalist/utils3d.git@3fab839f0be9931dac7c8488eb0e1600c236e183
+git+https://github.com/EasternJournalist/pipeline.git@866f059d2a05cde05e4a52211ec5051fd5f276d6
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/align_heads/__init_.py b/lingbotvla/models/vla/vision_models/align_heads/__init_.py
new file mode 100644
index 0000000000000000000000000000000000000000..072c22736ad24520a234079047edf05a018aa7ce
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/align_heads/__init_.py
@@ -0,0 +1 @@
+from .depth_head import DepthHead, TaskTokenDepthHead
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/align_heads/depth_head.py b/lingbotvla/models/vla/vision_models/align_heads/depth_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5aa6e0b7c01102d6618b94f996292fcb7ed259e
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/align_heads/depth_head.py
@@ -0,0 +1,66 @@
+
+import torch.nn as nn
+import torch.nn.functional as F
+from .resampler import Resampler, TaskTokenResampler
+
+def build_mlp(in_hidden_size, hidden_size):
+ modules = [nn.Linear(in_hidden_size, hidden_size)]
+ modules.append(nn.ReLU())
+ modules.append(nn.Linear(hidden_size, hidden_size))
+ return nn.Sequential(*modules)
+
+def build_expand_mlp(in_hidden_size, hidden_size, out_size):
+ modules = [nn.Linear(in_hidden_size, hidden_size)]
+ modules.append(nn.ReLU())
+ modules.append(nn.Linear(hidden_size, hidden_size))
+ modules.append(nn.ReLU())
+ modules.append(nn.Linear(hidden_size, out_size))
+ return nn.Sequential(*modules)
+
+class DepthHead(nn.Module):
+ def __init__(
+ self,
+ proj_config=None,
+ llm_hidden_size=4096,
+ use_intermediate_depth=False,
+ ):
+ super(DepthHead, self).__init__()
+
+ self.projector = Resampler(
+ dim_in=llm_hidden_size,
+ dim_mid=llm_hidden_size,
+ dim_head=proj_config["dim_head"],
+ dim_out=proj_config["dim_out"],
+ num_layers=proj_config["num_layers"],
+ num_heads=proj_config["num_heads"],
+ num_queries=proj_config["num_backbone_tokens"],
+ ff_mult=proj_config["ff_mult"],
+ )
+
+ def forward(self, llm_feats):
+ queries = self.projector(llm_feats)
+ return queries
+
+class TaskTokenDepthHead(nn.Module):
+ def __init__(
+ self,
+ proj_config=None,
+ llm_hidden_size=4096,
+ use_intermediate_depth=False,
+ ):
+ super(TaskTokenDepthHead, self).__init__()
+
+ self.projector = TaskTokenResampler(
+ dim_in=llm_hidden_size,
+ dim_mid=llm_hidden_size,
+ dim_head=proj_config["dim_head"],
+ dim_out=proj_config["dim_out"],
+ num_layers=proj_config["num_layers"],
+ num_heads=proj_config["num_heads"],
+ num_queries=proj_config["num_backbone_tokens"],
+ ff_mult=proj_config["ff_mult"],
+ )
+
+ def forward(self, llm_feats, queries):
+ queries = self.projector(llm_feats, queries)
+ return queries
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/align_heads/resampler.py b/lingbotvla/models/vla/vision_models/align_heads/resampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..05116c1c1e9a8310e62b3c2ca9d0e471955fe986
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/align_heads/resampler.py
@@ -0,0 +1,346 @@
+# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# FFN
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+def reshape_tensor(x, heads):
+ bs, length, width = x.shape
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, length, heads, -1)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+ x = x.reshape(bs, heads, length, -1)
+ return x
+
+
+class PerceiverAttention(nn.Module):
+
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+ latent (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
+class AttentionPool2d(nn.Module):
+
+ def __init__(self, seq_len: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(seq_len + 1, embed_dim) / embed_dim**0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x, return_all_tokens=False):
+ # x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = x.permute(1, 0, 2) # (N(HW)C) => (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(query=x,
+ key=x,
+ value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False)
+ if return_all_tokens:
+ return x
+ else:
+ return x[0]
+
+
+class Resampler(nn.Module):
+
+ def __init__(
+ self,
+ dim_in=768,
+ dim_mid=1024,
+ dim_head=64,
+ dim_out=1024,
+ num_layers=8,
+ num_queries=8,
+ num_heads=16,
+ ff_mult=4,
+ ):
+ super().__init__()
+
+ self.queries = nn.Parameter(torch.randn(1, num_queries, dim_in) / dim_mid ** 0.5)
+
+ self.proj_in = nn.Linear(dim_in, dim_mid)
+ self.proj_out = nn.Linear(dim_mid, dim_out)
+ self.norm_out = nn.LayerNorm(dim_out)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(num_layers):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim_mid, dim_head=dim_head, heads=num_heads),
+ FeedForward(dim=dim_mid, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, x):
+ queries = self.queries.repeat(x.size(0), 1, 1)
+ x = self.proj_in(x)
+
+ for attn, ff in self.layers:
+ queries = attn(x, queries) + queries
+ queries = ff(queries) + queries
+
+ queries = self.proj_out(queries)
+ queries = self.norm_out(queries)
+ return queries
+
+class TaskTokenResampler(nn.Module):
+
+ def __init__(
+ self,
+ dim_in=768,
+ dim_mid=1024,
+ dim_head=64,
+ dim_out=1024,
+ num_layers=8,
+ num_queries=8,
+ num_heads=16,
+ ff_mult=4,
+ ):
+ super().__init__()
+
+ self.num_queries = num_queries
+ self.proj_in1 = nn.Linear(dim_in, dim_mid)
+ self.proj_in2 = nn.Linear(dim_in, dim_mid)
+ self.proj_out = nn.Linear(dim_mid, dim_out)
+ self.norm_out = nn.LayerNorm(dim_out)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(num_layers):
+ self.layers.append(
+ nn.ModuleList([
+ PerceiverAttention(dim=dim_mid, dim_head=dim_head, heads=num_heads),
+ FeedForward(dim=dim_mid, mult=ff_mult),
+ ]))
+
+ def forward(self, x, queries):
+ queries = self.proj_in1(queries)
+ x = self.proj_in2(x)
+
+ for attn, ff in self.layers:
+ queries = attn(x, queries) + queries
+ queries = ff(queries) + queries
+
+ queries = self.proj_out(queries)
+ queries = self.norm_out(queries)
+ return queries
+
+
+class ResamplerXL(nn.Module):
+
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output1_dim=768,
+ output2_dim=1280,
+ ff_mult=4,
+ ):
+ super().__init__()
+
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+
+ self.proj_in = nn.Linear(embedding_dim, dim)
+
+ # self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(dim)
+
+ self.in_dim = dim
+ self.out_dim = output1_dim + output2_dim
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList([
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]))
+
+ self.unet_proj_1 = nn.Linear(self.in_dim, output1_dim)
+ self.unet_proj_2 = nn.Linear(self.in_dim, output2_dim)
+ self.unet_attnpool = AttentionPool2d(num_queries, self.in_dim, heads, output2_dim)
+
+ def forward(self, x):
+
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ hidden_embeds = self.norm_out(latents)
+
+ encoder_hidden_1 = self.unet_proj_1(hidden_embeds) # [bs, 256, 768]
+ encoder_hidden_2 = self.unet_proj_2(hidden_embeds) # [bs, 256, 1280]
+ prompt_embeds = torch.cat([encoder_hidden_1, encoder_hidden_2], dim=-1) # [bs, 256, 2048]
+ pooled_prompt_embeds = self.unet_attnpool(hidden_embeds) # [bs, 1280]
+
+ return prompt_embeds, pooled_prompt_embeds
+
+
+class ResamplerXLV2(nn.Module):
+
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output1_dim=768,
+ output2_dim=1280,
+ ff_mult=4,
+ normalize=True
+ ):
+ super().__init__()
+
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+
+ self.normalize = normalize
+ self.proj_in = nn.Linear(embedding_dim, dim)
+
+ # self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(dim)
+
+ self.in_dim = dim
+ self.out_dim = output1_dim + output2_dim
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList([
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]))
+
+ self.unet_proj_1 = nn.Linear(self.in_dim, output1_dim)
+ self.unet_proj_2 = nn.Linear(self.in_dim, output2_dim)
+ self.unet_attnpool = AttentionPool2d(num_queries, self.in_dim, heads, output2_dim)
+
+ def forward(self, x,pooled_text_embeds=None):
+
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ if self.normalize:
+ x = F.normalize(x)
+
+ x = self.proj_in(x)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ hidden_embeds = self.norm_out(latents)
+
+ encoder_hidden_1 = self.unet_proj_1(hidden_embeds) # [bs, 256, 768]
+ encoder_hidden_2 = self.unet_proj_2(hidden_embeds) # [bs, 256, 1280]
+ prompt_embeds = torch.cat([encoder_hidden_1, encoder_hidden_2], dim=-1) # [bs, 256, 2048]
+ pooled_prompt_embeds = self.unet_attnpool(hidden_embeds) # [bs, 1280]
+
+ return prompt_embeds, pooled_prompt_embeds
+
+class ResamplerXLIdentity(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, x, pooled_text_embeds=None):
+ return x, pooled_text_embeds
+
+
+if __name__ == '__main__':
+ image_proj_model = Resampler(dim=1024,
+ depth=4,
+ dim_head=64,
+ heads=12,
+ num_queries=1024,
+ embedding_dim=1024,
+ output_dim=1024,
+ ff_mult=4)
+ numel = 0
+ for name, param in image_proj_model.named_parameters():
+ numel += param.numel()
+
+ print(f'Total params: {numel}')
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/.gitignore b/lingbotvla/models/vla/vision_models/lingbot-depth/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2b1c8453b898042abc0d04fab680cee8deb53054
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/.gitignore
@@ -0,0 +1,9 @@
+*.pyc
+*.egg-info
+*.egg
+__pycache__
+.idea
+.vscode
+workspace*/
+result*/
+ckpt/
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/LEGAL.md b/lingbotvla/models/vla/vision_models/lingbot-depth/LEGAL.md
new file mode 100644
index 0000000000000000000000000000000000000000..f96892081dd58b22ee2199adffd7b188b79e7e7f
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/LEGAL.md
@@ -0,0 +1,7 @@
+Legal Disclaimer
+
+Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
+
+法律免责声明
+
+关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/LICENSE b/lingbotvla/models/vla/vision_models/lingbot-depth/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..3d4076fa61eb68f5ad8a5c3e61075005b6d8d6d6
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2026 LingBot-Depth Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/README.md b/lingbotvla/models/vla/vision_models/lingbot-depth/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0ce9b3d0b26b6fefc7441f835623f51232c41734
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/README.md
@@ -0,0 +1,268 @@
+# LingBot-Depth: Masked Depth Modeling for Spatial Perception
+
+
+[](LICENSE)
+[](https://www.python.org/downloads/)
+[](https://pytorch.org/)
+
+📄 **[Technical Report](https://github.com/Robbyant/lingbot-depth/blob/main/tech-report.pdf)** |
+📄 **[arXiv](https://arxiv.org/abs/2601.17895)** |
+🌐 **[Project Page](https://technology.robbyant.com/lingbot-depth)** |
+💻 **[Code](https://github.com/robbyant/lingbot-depth)** |
+🤗 **[Hugging Face](https://huggingface.co/collections/robbyant/lingbot-depth)** |
+🤖 **[ModelScope](https://www.modelscope.cn/collections/Robbyant/LingBot-Depth)**
+
+
+**LingBot-Depth** transforms incomplete and noisy depth sensor data into high-quality, metric-accurate 3D measurements. By jointly aligning RGB appearance and depth geometry in a unified latent space, our model serves as a powerful spatial perception foundation for robot learning and 3D vision applications.
+
+
+
+
+
+Our approach refines raw sensor depth into clean, complete measurements, enabling:
+- **Depth Completion & Refinement**: Fills missing regions with metric accuracy and improved quality
+- **Scene Reconstruction**: High-fidelity indoor mapping with a strong depth prior
+- **4D Point Tracking**: Accurate dynamic tracking in metric space for robot learning
+- **Dexterous Manipulation**: Robust grasping with precise geometric understanding
+
+## Artifacts Release
+
+
+### Model Zoo
+
+We provide pretrained models for different scenarios:
+
+| Model | Hugging Face Model | ModelScope Model | Description |
+|-------|-----------|-----------|-------------|
+| LingBot-Depth | [robbyant/lingbot-depth-pretrain-vitl-14](https://huggingface.co/robbyant/lingbot-depth-pretrain-vitl-14/tree/main) | [robbyant/lingbot-depth-pretrain-vitl-14](https://www.modelscope.cn/models/Robbyant/lingbot-depth-pretrain-vitl-14)| General-purpose depth refinement |
+| LingBot-Depth-DC | [robbyant/lingbot-depth-postrain-dc-vitl14](https://huggingface.co/robbyant/lingbot-depth-postrain-dc-vitl14/tree/main) | [robbyant/lingbot-depth-postrain-dc-vitl14](https://www.modelscope.cn/models/Robbyant/lingbot-depth-postrain-dc-vitl14)| Optimized for sparse depth completion |
+
+### Data Release (Coming Soon)
+- The curated 3M RGB-D dataset will be released upon completion of the necessary licensing and approval procedures.
+- Expected release: **mid-March 2026**.
+
+## Installation
+
+### Requirements
+
+• Python ≥ 3.9 • PyTorch ≥ 2.0.0 • CUDA-capable GPU (recommended)
+
+### From source
+
+```bash
+git clone https://github.com/robbyant/lingbot-depth
+cd lingbot-depth
+pip install -e .
+```
+
+
+## Quick Start
+
+**Inference:**
+
+```python
+import torch
+import cv2
+import numpy as np
+from mdm.model.v2 import MDMModel
+
+# Load model
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = MDMModel.from_pretrained('robbyant/lingbot-depth-pretrain-vitl-14').to(device)
+
+# Load and prepare inputs
+image = cv2.cvtColor(cv2.imread('examples/0/rgb.png'), cv2.COLOR_BGR2RGB)
+h, w = image.shape[:2]
+image = torch.tensor(image / 255, dtype=torch.float32, device=device).permute(2, 0, 1)[None]
+
+depth = cv2.imread('examples/0/raw_depth.png', cv2.IMREAD_UNCHANGED).astype(np.float32) / 1000.0
+depth = torch.tensor(depth, dtype=torch.float32, device=device)[None]
+
+intrinsics = np.loadtxt('examples/0/intrinsics.txt')
+intrinsics[0] /= w # Normalize fx and cx by width
+intrinsics[1] /= h # Normalize fy and cy by height
+intrinsics = torch.tensor(intrinsics, dtype=torch.float32, device=device)[None]
+
+# Run inference
+output = model.infer(
+ image,
+ depth_in=depth,
+ intrinsics=intrinsics)
+
+depth_pred = output['depth'] # Refined depth map
+points = output['points'] # 3D point cloud
+```
+
+**Run example:**
+
+Download the model weight from [Hugging Face](https://huggingface.co/robbyant/lingbot-depth-pretrain-vitl-14/tree/main) and put it in the `ckpt` folder. Then run:
+
+```bash
+python example.py
+```
+
+This processes the example data from `examples/0/` and saves visualizations to `result/`.
+
+## Method
+
+We introduce a masked depth modeling approach that learns robust RGB-D representations through self-supervised learning. The model employs a Vision Transformer encoder with specialized depth-aware attention mechanisms to jointly process RGB and depth inputs.
+
+
+
+
+
+**Depth-aware attention visualization.** Visualizing attention from depth queries (Q1–Q3, marked with ⋆) to RGB tokens in two scenes: (a) aquarium and (b) indoor shelf. Each row shows masked input depth, attention weights on RGB, and refined output. Different queries attend to spatially corresponding regions, demonstrating cross-modal alignment.
+
+**Key Innovations:**
+- **Masked Depth Modeling**: Self-supervised pre-training via depth reconstruction
+- **Cross-Modal Attention**: Joint RGB-Depth alignment in unified latent space
+- **Metric-Scale Preservation**: Maintains real-world measurements for downstream tasks
+
+## Training Data
+
+Our model is trained on a large-scale diverse dataset combining real-world and simulated RGB-D captures:
+
+
+
+
+
+**Training dataset.** 2M real-world and 1M simulated samples spanning diverse indoor environments (top). Representative RGB-D inputs with ground truth depth (bottom).
+
+**Dataset Composition:**
+- **Real Captures**: 2M samples from residential, office, and commercial environments
+- **Simulated Data**: 1M photo-realistic renders with perfect ground truth
+- **Modalities**: RGB images, raw depth, refined ground truth depth
+- **Diversity**: Multiple sensor types, lighting conditions, and scene complexities
+
+## Applications
+
+### 4D Point Tracking
+
+LingBot-Depth provides metric-accurate 3D geometry essential for tracking dynamic targets:
+
+
+
+
+
+**4D point tracking.** Robust tracking in gym environments with dynamic human motion. Top: query point selection. Middle: 3D tracking on deforming geometry. Bottom: refined depth maps. Demonstrated on scooter, rowing machine, gym equipment, and pull-up bar.
+
+### Dexterous Manipulation
+
+High-quality geometric understanding enables reliable robotic grasping across diverse objects and materials:
+
+
+
+
+
+**Dexterous grasping.** Robust manipulation enabled by refined depth. Top: point cloud reconstruction. Bottom: successful grasps on steel cup, glass cup, storage box, and toy car.
+
+## Hardware Setup
+
+We developed a scalable RGB-D capture system for large-scale data collection:
+
+
+
+
+
+**RGB-D capture system.** Multi-sensor setup with Intel RealSense, Orbbec Gemini, and Azure Kinect for scalable real-world data collection.
+
+## Model Details
+
+### Architecture
+
+- **Encoder**: Vision Transformer (Large) with RGB-D fusion
+- **Decoder**: Multi-scale feature pyramid with specialized heads
+- **Heads**: Depth regression
+- **Training**: Masked depth modeling with reconstruction objective
+
+### Input Format
+
+**RGB Image:**
+- Shape: `[B, 3, H, W]` normalized to [0, 1]
+- Format: PyTorch tensor, float32
+
+**Depth Map:**
+- Shape: `[B, H, W]`
+- Unit: Meters (configurable via scale parameter)
+- Invalid regions: 0 or NaN
+
+**Camera Intrinsics:**
+- Shape: `[B, 3, 3]`
+- Normalized format: `fx'=fx/W, fy'=fy/H, cx'=cx/W, cy'=cy/H`
+- Example:
+ ```
+ [[fx/W, 0, cx/W],
+ [ 0, fy/H, cy/H],
+ [ 0, 0, 1 ]]
+ ```
+
+### Output Format
+
+The model returns a dictionary:
+
+```python
+{
+ 'depth': torch.Tensor, # Refined depth [B, H, W]
+ 'points': torch.Tensor, # Point cloud [B, H, W, 3] in camera space
+}
+```
+
+### Inference Parameters
+
+```python
+model.infer(
+ image, # RGB tensor [B, 3, H, W]
+ depth_in=None, # Input depth [B, H, W]
+ use_fp16=True, # Mixed precision inference
+ intrinsics=None, # Camera intrinsics [B, 3, 3]
+)
+```
+
+## Citation
+
+If you find this work useful for your research, please cite:
+
+```bibtex
+@article{lingbot-depth2026,
+ title={Masked Depth Modeling for Spatial Perception},
+ author={Tan, Bin and Sun, Changjiang and Qin, Xiage and Adai, Hanat and Fu, Zelin and Zhou, Tianxiang and Zhang, Han and Xu, Yinghao and Zhu, Xing and Shen, Yujun and Xue, Nan},
+ journal={arXiv preprint arXiv:[2601.17895]},
+ year={2026}
+}
+```
+
+Please also consider citing DINOv2, which serves as our backbone:
+
+```bibtex
+@article{oquab2023dinov2,
+ title={DINOv2: Learning Robust Visual Features without Supervision},
+ author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and others},
+ journal={Transactions on Machine Learning Research},
+ year={2024}
+}
+```
+
+## License
+
+This project is released under the Apache License 2.0. See [LICENSE](LICENSE) file for details.
+
+## Acknowledgments
+
+This work builds upon several excellent open-source projects:
+
+- [DINOv2](https://github.com/facebookresearch/dinov2) - Self-supervised vision transformer backbone
+- [Masked Autoencoders](https://github.com/facebookresearch/mae) - Self-supervised learning framework
+- The broader open-source computer vision and robotics communities
+
+## Contact
+
+For questions, discussions, or collaborations:
+
+- **Issues**: Open an [issue](https://github.com/robbyant/lingbot-depth/issues) on GitHub
+- **Email**: Contact Dr. [Bin Tan](https://https://icetttb.github.io/) (tanbin.tan@antgroup.com) or Dr. [Nan Xue](https://xuenan.net) (xuenan.xue@antgroup.com)
+
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/assets/attention/fig-attention-vis.png b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/attention/fig-attention-vis.png
new file mode 100644
index 0000000000000000000000000000000000000000..982ae361dfc568bf355ab616201c84f99e813e8d
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/attention/fig-attention-vis.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a745828bd4dca8c459e8c8f65208861e4836475ec3408cc9a4d26cdf65d64bca
+size 5304739
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/assets/dataset/diversity_figure.png b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/dataset/diversity_figure.png
new file mode 100644
index 0000000000000000000000000000000000000000..245ac8658f664c6f2e096b961813b7dad4029114
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/dataset/diversity_figure.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a424d1a5b3c1579126879a539617947e5b8c3af403026284c49a4a77ee7aabf
+size 887222
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-divided.jpg b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-divided.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..444ed0bcd6b12d6ca22808af6b3fafb6eb94b2bf
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-divided.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ebb466c2c5beed4dcf0919dab34d85f837f9e11d58b864d99029cabb2af5c953
+size 164334
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-full.jpg b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-full.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4d9ea4ffa0bffc1b222e569b8c19d233ef26c9a0
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-full.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c812502a85f7089740c030c92624e944997bd648c78b84d75cd6ff4389679c00
+size 145609
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_grasp/fig-grasp-demo.png b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_grasp/fig-grasp-demo.png
new file mode 100644
index 0000000000000000000000000000000000000000..88749a6f364b08dd17d54108821822f93b0c333a
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_grasp/fig-grasp-demo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18bb0a341b082db86c24cb26909bb331d00eade606fa4c736de40ec6edf25f9c
+size 1566673
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-dynamic-tracking.png b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-dynamic-tracking.png
new file mode 100644
index 0000000000000000000000000000000000000000..11f6233e87461965a5295739de1bc780983ae583
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-dynamic-tracking.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:26b88afd813cb6f3d3442d1e94a40dd9dbe7b77ee308e8e76b3534acecdfae74
+size 8646746
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-scene-tracking-crop.png b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-scene-tracking-crop.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f3adaf71b715e53882199501b44ae54aa1aebfd
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-scene-tracking-crop.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffc9755a91376789c385c656b4518237d0279834d61e7e38e33f511ec13832ae
+size 1097675
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/assets/teaser/teaser-crop.png b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/teaser/teaser-crop.png
new file mode 100644
index 0000000000000000000000000000000000000000..4edbb182eaf879ff04d29fb2f61c7a3786f9fcc2
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/assets/teaser/teaser-crop.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b0dfa20815a4fdceb0b5255234693ecebf40127be8986bf6297af0e5ba490044
+size 2276197
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/example.py b/lingbotvla/models/vla/vision_models/lingbot-depth/example.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c5e94d1238f48010c55096abb919a9c92bfe77
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/example.py
@@ -0,0 +1,116 @@
+import cv2
+import torch
+import numpy as np
+import trimesh
+import os
+from pathlib import Path
+from mdm.model.v2 import MDMModel as v2
+
+def preprocess_input_image(image_path, device):
+ """
+ Preprocess input image
+
+ Args:
+ image_path (str): Image path
+ device (torch.device): Device
+
+ Returns:
+ tuple: (numpy_image, tensor_image) Image in numpy and tensor format
+ """
+ # Read image and convert to RGB format
+ image_np = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+ # Convert to tensor and normalize to [0, 1] range
+ image_tensor = torch.tensor(image_np / 255, dtype=torch.float32, device=device).permute(2, 0, 1)[None]
+ return image_np, image_tensor
+
+def load_depth_map(depth_path, scale=1000.0):
+ """
+ Load depth map and convert to meters
+
+ Args:
+ depth_path (str): Depth map path
+
+ Returns:
+ np.ndarray: Depth map (in meters)
+ """
+ # Read depth map and convert to meters (original unit is millimeters)
+ depth_map = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / scale
+ depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=0.0, neginf=0.0)
+
+ return depth_map
+
+def depth_to_color_opencv(depth_map, vmin=None, vmax=None, colormap=cv2.COLORMAP_TURBO):
+ """
+ Convert depth map using OpenCV colormap (faster)
+
+ Args:
+ depth_map: (H, W) numpy array
+ colormap: cv2.COLORMAP_TURBO, cv2.COLORMAP_JET, cv2.COLORMAP_VIRIDIS, etc.
+
+ Returns:
+ (H, W, 3) numpy array, BGR, 0-255
+ """
+ # Handle invalid values
+ valid_mask = np.isfinite(depth_map)
+ depth_clean = depth_map.copy()
+ depth_clean[~valid_mask] = 0
+
+ if vmin is None:
+ vmin = depth_clean[valid_mask].min() if valid_mask.any() else 0
+ if vmax is None:
+ vmax = depth_clean[valid_mask].max() if valid_mask.any() else 1
+
+ # Normalize to [0, 255]
+ depth_normalized = np.clip((depth_clean - vmin) / (vmax - vmin + 1e-8) * 255, 0, 255).astype(np.uint8)
+
+ # Apply colormap
+ depth_colored = cv2.applyColorMap(depth_normalized, colormap)
+
+ # Handle invalid values
+ depth_colored[~valid_mask] = [0, 0, 0]
+
+ return depth_colored
+
+DEVICE = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
+
+ckpt_path = 'ckpt/model.pt'
+rgb_path = 'examples/0/rgb.png'
+depth_path = 'examples/0/raw_depth.png'
+intrinsics_path = 'examples/0/intrinsics.txt'
+
+
+intrinsics = np.loadtxt(intrinsics_path)
+intrinsics = torch.tensor(intrinsics, dtype=torch.float32, device=DEVICE)
+image_np, image_tensor = preprocess_input_image(rgb_path, DEVICE)
+depth_np = load_depth_map(depth_path)
+depth_tensor = torch.tensor(depth_np, dtype=torch.float32, device=DEVICE)
+
+h, w = image_np.shape[:2]
+intrinsics[0] /= w
+intrinsics[1] /= h
+
+model = v2.from_pretrained(ckpt_path).to(DEVICE)
+
+output = model.infer(
+ image_tensor,
+ depth_in=depth_tensor,
+ apply_mask=True,
+ intrinsics=intrinsics[None]
+ )
+
+depth_pred = output['depth'].squeeze().cpu().numpy()
+
+res_dir = Path('result')
+res_dir.mkdir(exist_ok=True)
+# save depth map
+depth_raw_color = depth_to_color_opencv(depth_np)
+depth_pred_color = depth_to_color_opencv(depth_pred)
+depth_concat = np.concatenate([depth_raw_color, depth_pred_color], axis=1)
+cv2.imwrite(res_dir/'res.png', depth_concat)
+
+# save pcd
+points_pred = output['points'].squeeze().cpu().numpy()
+verts = points_pred.reshape(-1, 3)[::2]
+verts_color = image_np.reshape(-1, 3)[::2]
+point_cloud = trimesh.PointCloud(verts, verts_color)
+point_cloud.export(res_dir/'pcd.ply')
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/intrinsics.txt b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/intrinsics.txt
new file mode 100644
index 0000000000000000000000000000000000000000..78cfb6cd2d9b6182735d1f2b6c8a5de8259a1cd2
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/intrinsics.txt
@@ -0,0 +1,3 @@
+460.139587 0.000000 319.656128
+0.000000 460.199005 237.396271
+0.000000 0.000000 1.000000
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/raw_depth.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/raw_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..7ffcfe971a40c4a5150843639672f8316f010c03
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/raw_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cdd19fa8ff03737f1a498a3596d0fa4f9337047b082c722e9bdc3e7ca1dfde33
+size 215840
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/raw_depth_color.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/raw_depth_color.png
new file mode 100644
index 0000000000000000000000000000000000000000..237b60bf78e202033ecc6a1797d0bf6ee54d87aa
Binary files /dev/null and b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/raw_depth_color.png differ
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/rgb.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/rgb.png
new file mode 100644
index 0000000000000000000000000000000000000000..9f87e564b53ac01fc3ef43ac88dc9701bd79347c
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/rgb.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca9885351c5ec4df2a3d8dd46fde15851bead31746d632d3d54334ed9fca10cd
+size 395219
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/intrinsics.txt b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/intrinsics.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fb1e47fe8a41b2f1dc7e4901fc519fe432fb55c4
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/intrinsics.txt
@@ -0,0 +1,3 @@
+461.2757263183594 0.000000 324.9846496582031
+0.000000 461.51947021484375 241.5243377685547
+0.000000 0.000000 1.000000
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/raw_depth.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/raw_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..8addf2734cd92750c2bb83c1ae8e8942fc6a26bf
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/raw_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ff88d57f9b1c5a6b1e7b8b5b1909bb9d7782094bdf1e86b7d2b4120cfd8ca8f
+size 177294
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/rgb.jpg b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/rgb.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..968ec58d51c6882195a1997518984b3b08e5347f
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/rgb.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:507c1818fca7f5a0270d146006e87744f084200447c9f5a7803b5352804733e8
+size 100145
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/intrinsics.txt b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/intrinsics.txt
new file mode 100644
index 0000000000000000000000000000000000000000..78cfb6cd2d9b6182735d1f2b6c8a5de8259a1cd2
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/intrinsics.txt
@@ -0,0 +1,3 @@
+460.139587 0.000000 319.656128
+0.000000 460.199005 237.396271
+0.000000 0.000000 1.000000
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/raw_depth.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/raw_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..16cbad1bae033d6b753ce1624a1512e0f7e034d0
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/raw_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e17f5db495253601f002a26703a28ecee01397fc1bd10a9b24f31e1fc4c402fd
+size 125956
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/rgb.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/rgb.png
new file mode 100644
index 0000000000000000000000000000000000000000..cef5ace3009c00860ce73c51d81e6bf7d06c6c38
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/rgb.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:805119bb2fed22ee481b80c7315b37b65f9063374db9d5c4f897dd3151e5b501
+size 415738
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/intrinsics.txt b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/intrinsics.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fb1e47fe8a41b2f1dc7e4901fc519fe432fb55c4
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/intrinsics.txt
@@ -0,0 +1,3 @@
+461.2757263183594 0.000000 324.9846496582031
+0.000000 461.51947021484375 241.5243377685547
+0.000000 0.000000 1.000000
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/raw_depth.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/raw_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..3d00442f772d5ed65ff4f3d1e2741961c06f6390
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/raw_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1712c731126437aaf178618f8e6ca21ed68fc0c865bdb21c2c369670c742480d
+size 183587
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/rgb.jpg b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/rgb.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9e231543e377f9cd113b8481d77ac85fb8fb8a21
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/rgb.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9eb065ce6d55edde4a2e60612e0ba0ae02b68bab2d3537dc619b3f35e9ef7fd
+size 126665
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/intrinsics.txt b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/intrinsics.txt
new file mode 100644
index 0000000000000000000000000000000000000000..78cfb6cd2d9b6182735d1f2b6c8a5de8259a1cd2
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/intrinsics.txt
@@ -0,0 +1,3 @@
+460.139587 0.000000 319.656128
+0.000000 460.199005 237.396271
+0.000000 0.000000 1.000000
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/raw_depth.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/raw_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..ef13603550857f5472d8315f7abccc200e4680a4
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/raw_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e71b43da332fa7b907fcbc07df5f90a7160db939ad3117ad9cfe8a3316104ea8
+size 137558
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/rgb.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/rgb.png
new file mode 100644
index 0000000000000000000000000000000000000000..b2b1313137b8634ee340007d36c9c4fafbfc9166
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/rgb.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed8d1289bc035e15a40a4967da849b16120e6656c999dbc9f721e150981625a3
+size 353301
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/intrinsics.txt b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/intrinsics.txt
new file mode 100644
index 0000000000000000000000000000000000000000..78cfb6cd2d9b6182735d1f2b6c8a5de8259a1cd2
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/intrinsics.txt
@@ -0,0 +1,3 @@
+460.139587 0.000000 319.656128
+0.000000 460.199005 237.396271
+0.000000 0.000000 1.000000
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/raw_depth.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/raw_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..0c0b97e0701cc39b869ead7a6696a2637d6db25a
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/raw_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:77d744908421a5219db44bc0f72e1cafbe1ed0979e5f7a50ebfce4f75abc8cd5
+size 184927
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/rgb.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/rgb.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc2eb7614dacecea5bf31d9ad1dd14e49fdb5ec0
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/rgb.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e29e72873a9176ccd3c51142c402c3b6fc87e1c44983e6e5b53e751f69ce5bc0
+size 434530
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/intrinsics.txt b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/intrinsics.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fb1e47fe8a41b2f1dc7e4901fc519fe432fb55c4
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/intrinsics.txt
@@ -0,0 +1,3 @@
+461.2757263183594 0.000000 324.9846496582031
+0.000000 461.51947021484375 241.5243377685547
+0.000000 0.000000 1.000000
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/raw_depth.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/raw_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..7fa7031ee11cbb1393e87885f9842908fb1ce792
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/raw_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a3871bfea86149baf60e658473a90889802195dcd88bbb0c048dc8e89a5f3c1
+size 156168
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/rgb.jpg b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/rgb.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..98adbf02717c0f6a33041e96082153d69524c242
Binary files /dev/null and b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/rgb.jpg differ
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/intrinsics.txt b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/intrinsics.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fb1e47fe8a41b2f1dc7e4901fc519fe432fb55c4
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/intrinsics.txt
@@ -0,0 +1,3 @@
+461.2757263183594 0.000000 324.9846496582031
+0.000000 461.51947021484375 241.5243377685547
+0.000000 0.000000 1.000000
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/raw_depth.png b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/raw_depth.png
new file mode 100644
index 0000000000000000000000000000000000000000..53ab050fb6cd7f3b642c89974749871b43606637
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/raw_depth.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:068a5aaf9334d7d5a1baa2680c6807165248276ed383d7773bee87f27d4ea607
+size 267301
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/rgb.jpg b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/rgb.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d943c8dd0e04d528edae116ca0a8b794c9889065
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/rgb.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8bc01589c575a1e5f5363743e4fed922f8093dcc7ca60858d0ac4e7fae1f17ff
+size 121283
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/__init__.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..99420f50e6c63ab275d94d6bb2d2c9fc39cdd3e6
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/__init__.py
@@ -0,0 +1,15 @@
+import importlib
+from typing import *
+
+if TYPE_CHECKING:
+ from .v2 import MDMModel as MDMModelV2
+
+def import_model_class_by_version(version: str) -> Type[Union['MDMModelV2']]:
+ assert version in ['v2'], f'Unsupported model version: {version}'
+
+ try:
+ module = importlib.import_module(f'.{version}', __package__)
+ except ModuleNotFoundError:
+ raise ValueError(f'Model version "{version}" not found.')
+ cls = getattr(module, 'MDMModel')
+ return cls
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/__init__.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+__version__ = "0.0.1"
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/hub/__init__.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/hub/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/hub/backbones.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/hub/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f81215aaab11548425fee4f1b199048e164cec8
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/hub/backbones.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+ *,
+ arch_name: str = "vit_large",
+ img_size: int = 518,
+ patch_size: int = 14,
+ init_values: float = 1.0,
+ ffn_layer: str = "mlp",
+ block_chunks: int = 0,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.LVD142M,
+ **kwargs,
+):
+ from ..models import vision_transformer as vits
+
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=patch_size,
+ init_values=init_values,
+ ffn_layer=ffn_layer,
+ block_chunks=block_chunks,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ )
+ vit_kwargs.update(**kwargs)
+ model = vits.__dict__[arch_name](**vit_kwargs)
+
+ if pretrained:
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ model.load_state_dict(state_dict, strict=True)
+
+ return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+def dinov2_vitl16(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ # kwargs.update({'img_size': 224, 'patch_size': 16, })
+ return _make_dinov2_model(arch_name="vit_large", pretrained=False, weights=weights, **kwargs)
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_small",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_base",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_large",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/hub/utils.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/hub/utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import itertools
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
+ compact_arch_name = arch_name.replace("_", "")[:4]
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+class CenterPadding(nn.Module):
+ def __init__(self, multiple):
+ super().__init__()
+ self.multiple = multiple
+
+ def _get_pad(self, size):
+ new_size = math.ceil(size / self.multiple) * self.multiple
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ @torch.inference_mode()
+ def forward(self, x):
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+ output = F.pad(x, pads)
+ return output
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/__init__.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ca939cbc945791c12b8c4e4088e0c0ecb7c0fef
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
+from .patch_embed_mlp import PatchEmbed as PatchEmbedMLP
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/attention.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9f79d471fc099b1dcaa512dfdbdec8a9fc5908f
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/attention.py
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+import torch.nn.functional as F
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Attention)")
+ else:
+ # warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ # # Deprecated implementation, extremely slow
+ # def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ # B, N, C = x.shape
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ # q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ # attn = q @ k.transpose(-2, -1)
+ # attn = attn.softmax(dim=-1)
+ # attn = self.attn_drop(attn)
+ # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ # x = self.proj(x)
+ # x = self.proj_drop(x)
+ # return x
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
+
+ q, k, v = qkv.unbind(0) # (B, H, N, C // H)
+
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/block.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..de6faacca49fe7cd263ce12f5c9fcf46fc7e3770
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/block.py
@@ -0,0 +1,259 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Block)")
+ else:
+ # warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Block)")
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/dino_head.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/dino_head.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/drop_path.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/layer_scale.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/mlp.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/patch_embed.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/patch_embed_mlp.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/patch_embed_mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..26938ac088a04bd20ea4032f84dc7904efc202bc
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/patch_embed_mlp.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+import torch.nn.functional as F
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+class PixelUnshuffle (nn.Module):
+ def __init__(self, downscale_factor):
+ super().__init__()
+ self.downscale_factor = downscale_factor
+
+ def forward(self, input):
+ if input.numel() == 0:
+ # this is not in the original torch implementation
+ C,H,W = input.shape[-3:]
+ assert H and W and H % self.downscale_factor == W%self.downscale_factor == 0
+ return input.view(*input.shape[:-3], C*self.downscale_factor**2, H//self.downscale_factor, W//self.downscale_factor)
+ else:
+ return F.pixel_unshuffle(input, self.downscale_factor)
+
+class Permute(nn.Module):
+ dims: tuple[int, ...]
+ def __init__(self, dims: tuple[int, ...]) -> None:
+ super().__init__()
+ self.dims = tuple(dims)
+
+ def __repr__(self):
+ return f"Permute{self.dims}"
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return input.permute(*self.dims)
+
+from itertools import repeat
+import collections.abc
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+to_2tuple = _ntuple(2)
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Sequential(
+ PixelUnshuffle(patch_size),
+ Permute((0,2,3,1)),
+ Mlp(in_chans * patch_size * patch_size, 4*embed_dim, embed_dim),
+ Permute((0,3,1,2)),
+ )
+
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/swiglu_ffn.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (SwiGLU)")
+ else:
+ # warnings.warn("xFormers is disabled (SwiGLU)")
+ raise ImportError
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+ # warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/models/__init__.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c00e05581b29e31e936a55ee7791dbe2cf85f37
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/models/__init__.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+'''
+Docstring for MDM.mdm.model.dinov2_rgbd.models_vlmae
+=======================================================
+This version is modified from the original DINOv2 to support the MIM(masked image modeling) of RGBD input.
+(The original DINOv2 is available at https://github.com/facebookresearch/dinov2.)
+
+Core Changes:
+1. We add the depth input into the original DINOv2 transformer encoder.
+
+2. We support the Variable Mask Ratio MAE for both RGB and Depth input.
+'''
+
+import logging
+
+from . import vision_transformer as vits
+
+logger = logging.getLogger("dinov2")
+
+
+def build_model(args, only_teacher=False, img_size=224):
+ args.arch = args.arch.removesuffix("_memeff")
+ if "vit" in args.arch:
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=args.patch_size,
+ init_values=args.layerscale,
+ ffn_layer=args.ffn_layer,
+ block_chunks=args.block_chunks,
+ qkv_bias=args.qkv_bias,
+ proj_bias=args.proj_bias,
+ ffn_bias=args.ffn_bias,
+ num_register_tokens=args.num_register_tokens,
+ interpolate_offset=args.interpolate_offset,
+ interpolate_antialias=args.interpolate_antialias,
+ )
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
+ if only_teacher:
+ return teacher, teacher.embed_dim
+ student = vits.__dict__[args.arch](
+ **vit_kwargs,
+ drop_path_rate=args.drop_path_rate,
+ drop_path_uniform=args.drop_path_uniform,
+ )
+ embed_dim = student.embed_dim
+ return student, teacher, embed_dim
+
+
+def build_model_from_cfg(cfg, only_teacher=False):
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
+
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/models/mask_utils.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/models/mask_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0491007de14d21da8e9e81e508d36717190de8bb
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/models/mask_utils.py
@@ -0,0 +1,137 @@
+import torch
+def depth_masking(
+ x,
+ patch_num_h,
+ patch_num_w,
+ depth_values,
+ depth_mask_threshold_ratio=None,
+ depth_mask_threshold_num=None,
+ valid_depth_range=(0.1, 10.0),
+):
+ """
+ Perform patch masking based on depth validity
+
+ Args:
+ x: [B, N, D] input features (after patch embedding)
+ patch_num_h: int, height of the patch grid
+ patch_num_w: int, width of the patch grid
+ depth_values: [B, 1, H_img, W_img], raw depth map
+ depth_mask_threshold_ratio: float or list, valid depth ratio threshold (0-1)
+ depth_mask_threshold_num: int or list, valid depth pixel count threshold
+ valid_depth_range: tuple, valid depth range (min, max)
+
+ Returns:
+ visible_list: list of [N_visible_i, D], visible patches for each sample
+ mask_info: dict, containing masking information
+ """
+ B, N, D = x.shape
+ device = x.device
+
+ assert N == patch_num_h * patch_num_w, \
+ f"N={N} must equal patch_num_h * patch_num_w = {patch_num_h * patch_num_w}"
+
+ # Compute depth invalid mask
+ depth_invalid_mask = _compute_depth_invalid_mask(
+ depth_values,
+ patch_num_h,
+ patch_num_w,
+ depth_mask_threshold_ratio,
+ depth_mask_threshold_num,
+ valid_depth_range
+ ) # [B, N], True indicates this patch is invalid
+
+ # Process each sample separately
+ visible_list = []
+ mask_info = {
+ 'visible_indices': [],
+ 'mask_indices': [],
+ 'num_visible': [],
+ }
+
+ for i in range(B):
+ # Get valid patch indices
+ valid_mask = ~depth_invalid_mask[i] # [N]
+ visible_indices = torch.where(valid_mask)[0]
+ masked_indices = torch.where(depth_invalid_mask[i])[0]
+
+ # Extract visible patches
+ visible = x[i, visible_indices] # [N_visible, D]
+ visible_list.append(visible)
+
+ # Record information
+ mask_info['visible_indices'].append(visible_indices)
+ mask_info['mask_indices'].append(masked_indices)
+ mask_info['num_visible'].append(len(visible_indices))
+
+ return visible_list, mask_info
+
+def _compute_depth_invalid_mask(
+ depth_values,
+ H_patch,
+ W_patch,
+ threshold_ratio,
+ threshold_num,
+ valid_range
+):
+ """
+ Compute depth validity for each patch
+
+ Args:
+ depth_values: [B, 1, H_img, W_img] raw depth map
+ H_patch, W_patch: patch grid dimensions
+ threshold_ratio: float or list, valid depth ratio threshold
+ threshold_num: int or list, valid depth pixel count threshold
+ valid_range: tuple, (min_depth, max_depth)
+
+ Returns:
+ invalid_mask: [B, N] bool tensor, True indicates this patch is invalid
+ """
+ B, _, H_img, W_img = depth_values.shape
+ N = H_patch * W_patch
+ device = depth_values.device
+
+ min_depth, max_depth = valid_range
+
+ # Calculate pixel size for each patch
+ patch_h = H_img // H_patch
+ patch_w = W_img // W_patch
+
+ assert H_img % H_patch == 0 and W_img % W_patch == 0, \
+ f"Image size ({H_img}, {W_img}) must be divisible by patch grid ({H_patch}, {W_patch})"
+
+ # Reshape depth map into patches: [B, 1, H_img, W_img] -> [B, H_patch, patch_h, W_patch, patch_w]
+ depth_reshaped = depth_values.view(B, 1, H_patch, patch_h, W_patch, patch_w)
+
+ # Transpose and flatten: [B, H_patch, W_patch, patch_h, patch_w] -> [B, N, patch_h*patch_w]
+ depth_reshaped = depth_reshaped.permute(0, 2, 4, 1, 3, 5).reshape(B, N, -1)
+
+ # Calculate valid depth
+ valid_depth = (depth_reshaped >= min_depth) & (depth_reshaped <= max_depth)
+ valid_depth_ratio = valid_depth.float().mean(dim=-1) # [B, N]
+ valid_depth_num = valid_depth.float().sum(dim=-1) # [B, N]
+
+ # Handle list-form thresholds (different thresholds for each sample in batch)
+ if isinstance(threshold_ratio, list) or isinstance(threshold_num, list):
+ invalid_mask = torch.zeros(B, N, dtype=torch.bool, device=device)
+
+ for i in range(B):
+ tr = threshold_ratio[i] if isinstance(threshold_ratio, list) else threshold_ratio
+ tn = threshold_num[i] if isinstance(threshold_num, list) else threshold_num
+
+ sample_mask = torch.zeros(N, dtype=torch.bool, device=device)
+ if tr is not None:
+ sample_mask |= (valid_depth_ratio[i] < tr)
+ if tn is not None:
+ sample_mask |= (valid_depth_num[i] < tn)
+
+ invalid_mask[i] = sample_mask
+ else:
+ # Uniform threshold
+ invalid_mask = torch.zeros(B, N, dtype=torch.bool, device=device)
+
+ if threshold_ratio is not None:
+ invalid_mask |= (valid_depth_ratio < threshold_ratio)
+ if threshold_num is not None:
+ invalid_mask |= (valid_depth_num < threshold_num)
+
+ return invalid_mask
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/models/vision_transformer.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..07900ce480a1407f08e0a212dfa264cf88c59f8a
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/models/vision_transformer.py
@@ -0,0 +1,479 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable, Optional, List
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+from ..layers import PatchEmbedMLP
+
+from .mask_utils import depth_masking
+
+logger = logging.getLogger("dinov2_rgbd")
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ img_depth_fuse_mode='',
+ depth_mask_ratio:Union[float, List[float]]=0.6,
+ img_mask_ratio:Union[float, List[float]]=0.0,
+ depth_mask_patch_grid_size: int=1,
+ img_mask_patch_grid_size: int=1,
+ depth_emb_mode='',
+ # depth_emb_mode='conv_1c'
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.depth_emb_mode = depth_emb_mode
+ if self.depth_emb_mode == 'conv_1c':
+ self.depth_patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=1, embed_dim=embed_dim)
+ else:
+ self.depth_patch_embed = None
+
+ self.img_depth_fuse_mode = img_depth_fuse_mode
+
+ self.depth_mask_patch_grid_size = depth_mask_patch_grid_size
+ self.img_mask_patch_grid_size = img_mask_patch_grid_size
+ assert self.depth_mask_patch_grid_size == 1, "depth_mask_patch_grid_size must be 1 in current version"
+ assert self.img_mask_patch_grid_size == 1, "img_mask_patch_grid_size must be 1 in current version"
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ @property
+ def onnx_compatible_mode(self):
+ return getattr(self, "_onnx_compatible_mode", False)
+
+ @onnx_compatible_mode.setter
+ def onnx_compatible_mode(self, value: bool):
+ self._onnx_compatible_mode = value
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, h, w):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ batch_size = x.shape[0]
+ N = self.pos_embed.shape[1] - 1
+ if not self.onnx_compatible_mode and npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0, :]
+ patch_pos_embed = pos_embed[:, 1:, :]
+ dim = x.shape[-1]
+ h0, w0 = h // self.patch_size, w // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if not self.onnx_compatible_mode and self.interpolate_offset > 0:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sy, sx)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (h0, w0)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+
+ assert (h0, w0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
+ return torch.cat((class_pos_embed[:, None, :].expand(patch_pos_embed.shape[0], -1, -1), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def interpolate_pos_encoding_without_cls(self, x, h, w, input_pos_embed):
+ previous_dtype = x.dtype
+ npatch = x.shape[1]
+ batch_size = x.shape[0]
+ N = input_pos_embed.shape[1]
+ if not self.onnx_compatible_mode and npatch == N and w == h:
+ return input_pos_embed
+ patch_pos_embed = input_pos_embed.float()
+ dim = x.shape[-1]
+ h0, w0 = h // self.patch_size, w // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if not self.onnx_compatible_mode and self.interpolate_offset > 0:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sy, sx)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (h0, w0)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (h0, w0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
+ return patch_pos_embed.to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, masks=None, **kwargs):
+ assert masks is None, "extra masks are not supported for this model."
+ B, nc, h_img, w_img = x_img.shape
+ _, _, h_depth, w_depth = x_depth.shape
+ x_depth_raw = x_depth.clone()
+ x_depth_raw[x_depth_raw == 0] = -10
+
+ depth_patch_num_h, depth_patch_num_w = h_depth // self.patch_size, w_depth // self.patch_size
+
+ # patchify, embed image tokens and depth tokens
+ x_img = self.patch_embed(x_img) # batch, length_img, dim
+ assert self.depth_patch_embed is not None
+ x_depth = self.depth_patch_embed(x_depth) # batch, length_depth, dim
+ assert depth_patch_num_h * depth_patch_num_w == x_depth.shape[1]
+
+ # get full pose enc of img and depth
+ # 1-> img data type enc
+ # 2-> depth data type enc
+ img_pose_enc = 1 + self.interpolate_pos_encoding_without_cls(x_img, h_img, w_img, self.pos_embed[:, 1:]).repeat(B, 1, 1)
+ depth_pose_enc = 2 + self.interpolate_pos_encoding_without_cls(x_depth, h_depth, w_depth, self.pos_embed[:, 1:]).repeat(B, 1, 1)
+
+ # add pose enc to img and depth
+ x_img = x_img + img_pose_enc
+ x_depth = x_depth + depth_pose_enc
+
+ ## mask depth tokens
+ if kwargs.get('enable_depth_mask', True):
+ x_depth_masked, depth_mask_info = depth_masking(
+ x_depth,
+ depth_patch_num_h,
+ depth_patch_num_w,
+ depth_values=x_depth_raw,
+ depth_mask_threshold_num=[1]*B,
+ valid_depth_range=(-9.5, 200.0)
+ )
+ else:
+ x_depth_masked = x_depth
+ depth_mask_info = None
+
+ ## mask image tokens
+ x_img_masked = x_img
+ img_mask_info = None
+
+ # get cls token
+ x_cls = self.cls_token.squeeze(0) + self.pos_embed.squeeze(0)[:1] # 1, dim
+
+ # cat cls, img and depth tokens
+ assert self.img_depth_fuse_mode == 'cat_token', "Only cat_token mode is supported for this model."
+ x_masked_list = []
+ for i in range(B):
+ if self.register_tokens is not None:
+ x_mased = torch.cat([x_cls, self.register_tokens.squeeze(0), x_img_masked[i], x_depth_masked[i]], dim=0) # 1 + num_register_tokens + length_img + length_depth, dim
+ else:
+ x_mased = torch.cat([x_cls, x_img_masked[i], x_depth_masked[i]], dim=0) # 1 + length_img + length_depth, dim
+ x_mased = x_mased.unsqueeze(0) # 1, 1 + num_register_tokens + length_img + length_depth, dim
+ x_masked_list.append(x_mased)
+
+ return x_masked_list
+
+ def _get_intermediate_layers_not_chunked(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, n=1, return_mae_aux=False, **kwargs):
+ x = self.prepare_tokens_with_masks(x_img, x_depth, x_img_mask, x_depth_mask, **kwargs)
+
+ if not kwargs.get('enable_depth_mask', True):
+ x = torch.cat(x, dim=0)
+
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+
+ if not kwargs.get('enable_depth_mask', True):
+ output = [list(torch.split(out, 1, dim=0)) for out in output]
+ return output
+
+ def _get_intermediate_layers_chunked(self, x_img, x_depth, x_img_mask=None, x_depth_mask=None, n=1, return_mae_aux=False, **kwargs):
+ x = self.prepare_tokens_with_masks(x_img, x_depth, x_img_mask, x_depth_mask, **kwargs)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+
+ return output
+
+ def extract_features(self, outputs, norm=True):
+ feat_outputs = []
+ class_tokens = []
+ feat_start_idx = 1 + self.num_register_tokens
+
+ def process_output(out):
+ normed = self.norm(out) if norm else out
+ return normed[:, feat_start_idx:], normed[:, 0]
+
+ for output in outputs:
+ if isinstance(output, list):
+ feats, tokens = zip(*[process_output(out) for out in output])
+ feat_outputs.append(list(feats))
+ class_tokens.append(list(tokens))
+ else:
+ feat, token = process_output(output)
+ feat_outputs.append(feat)
+ class_tokens.append(token)
+
+ return feat_outputs, class_tokens
+
+ def get_intermediate_layers_mae(
+ self,
+ x_img: torch.Tensor,
+ x_depth: torch.Tensor,
+ x_img_mask: torch.Tensor=None,
+ x_depth_mask: torch.Tensor=None,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ return_mae_aux=True,
+ **kwargs
+ ):
+ assert reshape is False, "reshape is not supported for now"
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x_img, x_depth, x_img_mask, x_depth_mask, n, return_mae_aux=return_mae_aux,**kwargs)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x_img, x_depth, x_img_mask, x_depth_mask, n, return_mae_aux=return_mae_aux,**kwargs)
+
+ feat_outputs, class_tokens = self.extract_features(outputs, norm)
+
+ if return_class_token:
+ return tuple(zip(feat_outputs, class_tokens))
+ return tuple(feat_outputs)
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/__init__.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/cluster.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/cluster.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+
+class ClusterType(Enum):
+ AWS = "aws"
+ FAIR = "fair"
+ RSC = "rsc"
+
+
+def _guess_cluster_type() -> ClusterType:
+ uname = os.uname()
+ if uname.sysname == "Linux":
+ if uname.release.endswith("-aws"):
+ # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
+ return ClusterType.AWS
+ elif uname.nodename.startswith("rsc"):
+ # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
+ return ClusterType.RSC
+
+ return ClusterType.FAIR
+
+
+def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
+ if cluster_type is None:
+ return _guess_cluster_type()
+
+ return cluster_type
+
+
+def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ CHECKPOINT_DIRNAMES = {
+ ClusterType.AWS: "checkpoints",
+ ClusterType.FAIR: "checkpoint",
+ ClusterType.RSC: "checkpoint/dino",
+ }
+ return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
+
+
+def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ checkpoint_path = get_checkpoint_path(cluster_type)
+ if checkpoint_path is None:
+ return None
+
+ username = os.environ.get("USER")
+ assert username is not None
+ return checkpoint_path / username
+
+
+def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ SLURM_PARTITIONS = {
+ ClusterType.AWS: "learnlab",
+ ClusterType.FAIR: "learnlab",
+ ClusterType.RSC: "learn",
+ }
+ return SLURM_PARTITIONS[cluster_type]
+
+
+def get_slurm_executor_parameters(
+ nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
+) -> Dict[str, Any]:
+ # create default parameters
+ params = {
+ "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
+ "gpus_per_node": num_gpus_per_node,
+ "tasks_per_node": num_gpus_per_node, # one task per GPU
+ "cpus_per_task": 10,
+ "nodes": nodes,
+ "slurm_partition": get_slurm_partition(cluster_type),
+ }
+ # apply cluster-specific adjustments
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type == ClusterType.AWS:
+ params["cpus_per_task"] = 12
+ del params["mem_gb"]
+ elif cluster_type == ClusterType.RSC:
+ params["cpus_per_task"] = 12
+ # set additional parameters / apply overrides
+ params.update(kwargs)
+ return params
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/config.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/config.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import math
+import logging
+import os
+
+from omegaconf import OmegaConf
+
+import dinov2.distributed as distributed
+from dinov2.logging import setup_logging
+from dinov2.utils import utils
+from dinov2.configs import dinov2_default_config
+
+
+logger = logging.getLogger("dinov2")
+
+
+def apply_scaling_rules_to_cfg(cfg): # to fix
+ if cfg.optim.scaling_rule == "sqrt_wrt_1024":
+ base_lr = cfg.optim.base_lr
+ cfg.optim.lr = base_lr
+ cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
+ logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
+ else:
+ raise NotImplementedError
+ return cfg
+
+
+def write_config(cfg, output_dir, name="config.yaml"):
+ logger.info(OmegaConf.to_yaml(cfg))
+ saved_cfg_path = os.path.join(output_dir, name)
+ with open(saved_cfg_path, "w") as f:
+ OmegaConf.save(config=cfg, f=f)
+ return saved_cfg_path
+
+
+def get_cfg_from_args(args):
+ args.output_dir = os.path.abspath(args.output_dir)
+ args.opts += [f"train.output_dir={args.output_dir}"]
+ default_cfg = OmegaConf.create(dinov2_default_config)
+ cfg = OmegaConf.load(args.config_file)
+ cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
+ return cfg
+
+
+def default_setup(args):
+ distributed.enable(overwrite=True)
+ seed = getattr(args, "seed", 0)
+ rank = distributed.get_global_rank()
+
+ global logger
+ setup_logging(output=args.output_dir, level=logging.INFO)
+ logger = logging.getLogger("dinov2")
+
+ utils.fix_random_seeds(seed + rank)
+ logger.info("git:\n {}\n".format(utils.get_sha()))
+ logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
+
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg_from_args(args)
+ os.makedirs(args.output_dir, exist_ok=True)
+ default_setup(args)
+ apply_scaling_rules_to_cfg(cfg)
+ write_config(cfg, args.output_dir)
+ return cfg
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/dtype.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/dtype.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/dtype.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+from typing import Dict, Union
+
+import numpy as np
+import torch
+
+
+TypeSpec = Union[str, np.dtype, torch.dtype]
+
+
+_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
+ np.dtype("bool"): torch.bool,
+ np.dtype("uint8"): torch.uint8,
+ np.dtype("int8"): torch.int8,
+ np.dtype("int16"): torch.int16,
+ np.dtype("int32"): torch.int32,
+ np.dtype("int64"): torch.int64,
+ np.dtype("float16"): torch.float16,
+ np.dtype("float32"): torch.float32,
+ np.dtype("float64"): torch.float64,
+ np.dtype("complex64"): torch.complex64,
+ np.dtype("complex128"): torch.complex128,
+}
+
+
+def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
+ if isinstance(dtype, torch.dtype):
+ return dtype
+ if isinstance(dtype, str):
+ dtype = np.dtype(dtype)
+ assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}"
+ return _NUMPY_TO_TORCH_DTYPE[dtype]
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/param_groups.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/param_groups.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/param_groups.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+import logging
+
+
+logger = logging.getLogger("dinov2")
+
+
+def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
+ """
+ Calculate lr decay rate for different ViT blocks.
+ Args:
+ name (string): parameter name.
+ lr_decay_rate (float): base lr decay rate.
+ num_layers (int): number of ViT blocks.
+ Returns:
+ lr decay rate for the given parameter.
+ """
+ layer_id = num_layers + 1
+ if name.startswith("backbone") or force_is_backbone:
+ if (
+ ".pos_embed" in name
+ or ".patch_embed" in name
+ or ".mask_token" in name
+ or ".cls_token" in name
+ or ".register_tokens" in name
+ ):
+ layer_id = 0
+ elif force_is_backbone and (
+ "pos_embed" in name
+ or "patch_embed" in name
+ or "mask_token" in name
+ or "cls_token" in name
+ or "register_tokens" in name
+ ):
+ layer_id = 0
+ elif ".blocks." in name and ".residual." not in name:
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+ elif chunked_blocks and "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
+ elif "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
+
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
+
+
+def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
+ chunked_blocks = False
+ if hasattr(model, "n_blocks"):
+ logger.info("chunked fsdp")
+ n_blocks = model.n_blocks
+ chunked_blocks = model.chunked_blocks
+ elif hasattr(model, "blocks"):
+ logger.info("first code branch")
+ n_blocks = len(model.blocks)
+ elif hasattr(model, "backbone"):
+ logger.info("second code branch")
+ n_blocks = len(model.backbone.blocks)
+ else:
+ logger.info("else code branch")
+ n_blocks = 0
+ all_param_groups = []
+
+ for name, param in model.named_parameters():
+ name = name.replace("_fsdp_wrapped_module.", "")
+ if not param.requires_grad:
+ continue
+ decay_rate = get_vit_lr_decay_rate(
+ name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
+ )
+ d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
+
+ if "last_layer" in name:
+ d.update({"is_last_layer": True})
+
+ if name.endswith(".bias") or "norm" in name or "gamma" in name:
+ d.update({"wd_multiplier": 0.0})
+
+ if "patch_embed" in name:
+ d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
+
+ all_param_groups.append(d)
+ logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
+
+ return all_param_groups
+
+
+def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
+ fused_params_groups = defaultdict(lambda: {"params": []})
+ for d in all_params_groups:
+ identifier = ""
+ for k in keys:
+ identifier += k + str(d[k]) + "_"
+
+ for k in keys:
+ fused_params_groups[identifier][k] = d[k]
+ fused_params_groups[identifier]["params"].append(d["params"])
+
+ return fused_params_groups.values()
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/utils.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/dinov2_rgbd/utils/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import random
+import subprocess
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
+ if urlparse(pretrained_weights).scheme: # If it looks like an URL
+ state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
+ else:
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
+ if checkpoint_key is not None and checkpoint_key in state_dict:
+ logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
+ state_dict = state_dict[checkpoint_key]
+ # remove `module.` prefix
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ # remove `backbone.` prefix induced by multicrop wrapper
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+ msg = model.load_state_dict(state_dict, strict=False)
+ logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
+
+
+def fix_random_seeds(seed=31):
+ """
+ Fix random seeds.
+ """
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+ sha = "N/A"
+ diff = "clean"
+ branch = "N/A"
+ try:
+ sha = _run(["git", "rev-parse", "HEAD"])
+ subprocess.check_output(["git", "diff"], cwd=cwd)
+ diff = _run(["git", "diff-index", "HEAD"])
+ diff = "has uncommitted changes" if diff else "clean"
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+class CosineScheduler(object):
+ def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
+ super().__init__()
+ self.final_value = final_value
+ self.total_iters = total_iters
+
+ freeze_schedule = np.zeros((freeze_iters))
+
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+ iters = np.arange(total_iters - warmup_iters - freeze_iters)
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
+ self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
+
+ assert len(self.schedule) == self.total_iters
+
+ def __getitem__(self, it):
+ if it >= self.total_iters:
+ return self.final_value
+ else:
+ return self.schedule[it]
+
+
+def has_batchnorms(model):
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
+ for name, module in model.named_modules():
+ if isinstance(module, bn_types):
+ return True
+ return False
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/modules_decoder.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/modules_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dbbc46e9754f5d6946000e3a6926d31ea553570
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/modules_decoder.py
@@ -0,0 +1,185 @@
+from typing import *
+from numbers import Number
+import importlib
+import itertools
+import functools
+import sys
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .utils import wrap_module_with_gradient_checkpointing
+
+
+class ResidualConvBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int = None,
+ hidden_channels: int = None,
+ kernel_size: int = 3,
+ padding_mode: str = 'replicate',
+ activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
+ in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm',
+ hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm',
+ ):
+ super(ResidualConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ if hidden_channels is None:
+ hidden_channels = in_channels
+
+ if activation =='relu':
+ activation_cls = nn.ReLU
+ elif activation == 'leaky_relu':
+ activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2)
+ elif activation =='silu':
+ activation_cls = nn.SiLU
+ elif activation == 'elu':
+ activation_cls = nn.ELU
+ else:
+ raise ValueError(f'Unsupported activation function: {activation}')
+
+ self.layers = nn.Sequential(
+ nn.GroupNorm(in_channels // 32, in_channels) if in_norm == 'group_norm' else \
+ nn.GroupNorm(1, in_channels) if in_norm == 'layer_norm' else \
+ nn.InstanceNorm2d(in_channels) if in_norm == 'instance_norm' else \
+ nn.Identity(),
+ activation_cls(),
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode),
+ nn.GroupNorm(hidden_channels // 32, hidden_channels) if hidden_norm == 'group_norm' else \
+ nn.GroupNorm(1, hidden_channels) if hidden_norm == 'layer_norm' else \
+ nn.InstanceNorm2d(hidden_channels) if hidden_norm == 'instance_norm' else\
+ nn.Identity(),
+ activation_cls(),
+ nn.Conv2d(hidden_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode)
+ )
+
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
+
+ def forward(self, x):
+ skip = self.skip_connection(x)
+ x = self.layers(x)
+ x = x + skip
+ return x
+
+
+class Resampler(nn.Sequential):
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'],
+ scale_factor: int = 2,
+ ):
+ if type_ == 'pixel_shuffle':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.PixelShuffle(scale_factor),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ for i in range(1, scale_factor ** 2):
+ self[0].weight.data[i::scale_factor ** 2] = self[0].weight.data[0::scale_factor ** 2]
+ self[0].bias.data[i::scale_factor ** 2] = self[0].bias.data[0::scale_factor ** 2]
+ elif type_ in ['nearest', 'bilinear']:
+ nn.Sequential.__init__(self,
+ nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None),
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ elif type_ == 'conv_transpose':
+ nn.Sequential.__init__(self,
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
+ elif type_ == 'pixel_unshuffle':
+ nn.Sequential.__init__(self,
+ nn.PixelUnshuffle(scale_factor),
+ nn.Conv2d(in_channels * (scale_factor ** 2), out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
+ elif type_ == 'avg_pool':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
+ )
+ elif type_ == 'max_pool':
+ nn.Sequential.__init__(self,
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
+ )
+ else:
+ raise ValueError(f'Unsupported resampler type: {type_}')
+
+
+class MLP(nn.Sequential):
+ def __init__(self, dims: Sequence[int]):
+ nn.Sequential.__init__(self,
+ *itertools.chain(*[
+ (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True))
+ for dim_in, dim_out in zip(dims[:-2], dims[1:-1])
+ ]),
+ nn.Linear(dims[-2], dims[-1]),
+ )
+
+
+class ConvStack(nn.Module):
+ def __init__(self,
+ dim_in: List[Optional[int]],
+ dim_res_blocks: List[int],
+ dim_out: List[Optional[int]],
+ resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ res_block_in_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'layer_norm',
+ res_block_hidden_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'group_norm',
+ activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
+ ):
+ super().__init__()
+ self.input_blocks = nn.ModuleList([
+ nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity()
+ for dim_in_, dim_res_block_ in zip(dim_in if isinstance(dim_in, Sequence) else itertools.repeat(dim_in), dim_res_blocks)
+ ])
+ self.resamplers = nn.ModuleList([
+ Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
+ for i, (dim_prev, dim_succ, resampler) in enumerate(zip(
+ dim_res_blocks[:-1],
+ dim_res_blocks[1:],
+ resamplers if isinstance(resamplers, Sequence) else itertools.repeat(resamplers)
+ ))
+ ])
+ self.res_blocks = nn.ModuleList([
+ nn.Sequential(
+ *(
+ ResidualConvBlock(
+ dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_,
+ activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm
+ ) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks)
+ )
+ ) for i, dim_res_block_ in enumerate(dim_res_blocks)
+ ])
+ self.output_blocks = nn.ModuleList([
+ nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity()
+ for dim_out_, dim_res_block_ in zip(dim_out if isinstance(dim_out, Sequence) else itertools.repeat(dim_out), dim_res_blocks)
+ ])
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.resamplers)):
+ self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i])
+ for i in range(len(self.res_blocks)):
+ for j in range(len(self.res_blocks[i])):
+ self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j])
+
+ def forward(self, in_features: List[torch.Tensor]):
+ out_features = []
+ for i in range(len(self.res_blocks)):
+ feature = self.input_blocks[i](in_features[i])
+ if i == 0:
+ x = feature
+ elif feature is not None:
+ x = x + feature
+ x = self.res_blocks[i](x)
+ out_features.append(self.output_blocks[i](x))
+ if i < len(self.res_blocks) - 1:
+ x = self.resamplers[i](x)
+ return out_features
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/modules_rgbd_encoder.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/modules_rgbd_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8712b3f6674c5ba5ff1e3a242ac65beeef4c53b8
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/modules_rgbd_encoder.py
@@ -0,0 +1,152 @@
+from typing import *
+from numbers import Number
+import importlib
+import itertools
+import functools
+import sys
+
+import torch
+from torch import Tensor
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .dinov2_rgbd.models.vision_transformer import DinoVisionTransformer
+from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing
+
+
+class DINOv2_RGBD_Encoder(nn.Module):
+ backbone: DinoVisionTransformer
+ image_mean: torch.Tensor
+ image_std: torch.Tensor
+ dim_features: int
+
+ def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, ignore_layers: Union[str, List[str]]=[], in_chans: int=3, strict: bool=True, img_depth_fuse_mode='', depth_emb_mode='', depth_mask_ratio=0.6, img_mask_ratio=0.0, **deprecated_kwargs):
+ super(DINOv2_RGBD_Encoder, self).__init__()
+
+ self.intermediate_layers = intermediate_layers
+ self.strict = strict
+ self.ignore_layers = ignore_layers
+ self.img_mask_ratio = img_mask_ratio
+ # Load the backbone
+ self.hub_loader = getattr(importlib.import_module(".dinov2_rgbd.hub.backbones", __package__), backbone)
+ self.backbone_name = backbone
+ self.backbone = self.hub_loader(pretrained=False,
+ in_chans=in_chans,
+ img_depth_fuse_mode=img_depth_fuse_mode,
+ depth_emb_mode=depth_emb_mode,
+ depth_mask_ratio=depth_mask_ratio,
+ img_mask_ratio=img_mask_ratio)
+
+ self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
+ self.num_features = intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers)
+
+ if img_mask_ratio > 0:
+ self.mask_token_mae = nn.Parameter(torch.zeros(1, 1, self.dim_features))
+ torch.nn.init.normal_(self.mask_token_mae, std=.02)
+
+ self.output_projections = nn.ModuleList([
+ nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0,)
+ for _ in range(self.num_features)
+ ])
+
+ self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ @property
+ def onnx_compatible_mode(self):
+ return getattr(self, "_onnx_compatible_mode", False)
+
+ @onnx_compatible_mode.setter
+ def onnx_compatible_mode(self, value: bool):
+ self._onnx_compatible_mode = value
+ self.backbone.onnx_compatible_mode = value
+
+ def init_weights(self):
+ pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
+ ignore_layers = []
+ if isinstance(self.ignore_layers, str):
+ ignore_layers = [self.ignore_layers]
+ else:
+ ignore_layers = self.ignore_layers
+
+ if len(ignore_layers) == 0:
+ self.backbone.load_state_dict(pretrained_backbone_state_dict, strict=self.strict)
+ else:
+ state_dict = {}
+ for k, v in pretrained_backbone_state_dict.items():
+ is_ignore = False
+ for ig_k in ignore_layers:
+ if ig_k in k:
+ is_ignore = True
+ break
+ if not is_ignore:
+ state_dict[k] = v
+ self.backbone.load_state_dict(state_dict, strict=self.strict)
+
+ def enable_gradient_checkpointing(self):
+ for i in range(len(self.backbone.blocks)):
+ wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
+
+ def enable_pytorch_native_sdpa(self):
+ for i in range(len(self.backbone.blocks)):
+ wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
+
+ def forward(self,
+ image: torch.Tensor,
+ depth: torch.Tensor,
+ token_rows: Union[int, torch.LongTensor],
+ token_cols: Union[int, torch.LongTensor],
+ return_class_token: bool = False,
+ remap_depth_in: str='linear',
+ **kwargs):
+ image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=not self.onnx_compatible_mode)
+ image_14 = (image_14 - self.image_mean) / self.image_std
+
+ depth_14 = F.interpolate(depth, (token_rows * 14, token_cols * 14), mode="nearest")
+
+ # set invalid depth value to zero
+ depth_14[torch.isinf(depth_14)] = 0.0
+ depth_14[torch.isnan(depth_14)] = 0.0
+ dmask_14 = (depth_14 > 0.01).detach()
+ depth_14 = depth_14 * dmask_14.float()
+
+ if remap_depth_in == 'linear':
+ pass # do nothing
+ elif remap_depth_in == 'log':
+ depth_14 = torch.log(depth_14)
+ depth_14[~dmask_14] = 0.0
+ depth_14 = torch.nan_to_num(depth_14, nan=0.0, posinf=0.0, neginf=0.0)
+ else:
+ raise NotImplementedError
+
+ # Get intermediate layers from the backbone
+ features = self.backbone.get_intermediate_layers_mae(
+ x_img=image_14,
+ x_depth=depth_14,
+ n=self.intermediate_layers,
+ return_class_token=True,
+ **kwargs)
+
+ assert self.img_mask_ratio == 0, "img_mask_ratio is not supported in this encoder"
+
+ if isinstance(features[0][0], list):
+ num_valid_tokens = token_rows * token_cols
+ features = tuple(
+ (
+ torch.cat([feat[:, :num_valid_tokens].contiguous() for feat in feats], dim=0),
+ torch.cat(cls_tokens, dim=0)
+ )
+ for feats, cls_tokens in features
+ )
+
+ # Project features to the desired dimensionality
+ x = torch.stack([
+ proj(feat.permute(0, 2, 1)[:, :, :token_rows*token_cols].unflatten(2, (token_rows, token_cols)).contiguous())
+ for proj, (feat, clstoken) in zip(self.output_projections, features)
+ ], dim=1).sum(dim=1)
+ cls_token = features[-1][1]
+
+ if return_class_token:
+ return x, cls_token, None, None
+ else:
+ return x, None, None
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/utils.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aca85509a7957a7e29bc7dffee76c7950cf8e79
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/utils.py
@@ -0,0 +1,127 @@
+from typing import *
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def wrap_module_with_gradient_checkpointing(module: nn.Module):
+ from torch.utils.checkpoint import checkpoint
+ class _CheckpointingWrapper(module.__class__):
+ _restore_cls = module.__class__
+ def forward(self, *args, **kwargs):
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
+
+ module.__class__ = _CheckpointingWrapper
+ return module
+
+
+def unwrap_module_with_gradient_checkpointing(module: nn.Module):
+ module.__class__ = module.__class__._restore_cls
+
+
+def wrap_dinov2_attention_with_sdpa(module: nn.Module):
+ assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
+ class _AttentionWrapper(module.__class__):
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
+
+ q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
+
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+ module.__class__ = _AttentionWrapper
+ return module
+
+def wrap_dinov3_attention_with_sdpa(module: nn.Module):
+ assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
+ class _AttentionWrapper(module.__class__):
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
+
+ q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
+
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+ module.__class__ = _AttentionWrapper
+ return module
+
+def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]:
+ group_to_use = torch.distributed.group.WORLD
+ world_size = group_to_use.size()
+ grad = bucket.buffer()
+ grad.div_(world_size)
+ torch.distributed.all_reduce(grad, group=group_to_use)
+ fut = torch.futures.Future()
+ fut.set_result(grad)
+ return fut
+
+def depth_to_pointcloud(depth, intrinsic_normalized, depth_scale=1.0):
+ """
+ Convert depth map to point cloud (pure Tensor version, no point filtering)
+
+ Args:
+ depth: torch.Tensor, shape (H, W) or (B, H, W), depth map
+ intrinsic_normalized: torch.Tensor, shape (3, 3) or (B, 3, 3), normalized intrinsic matrix
+ Normalized intrinsics: fx' = fx/W, fy' = fy/H, cx' = cx/W, cy' = cy/H
+ depth_scale: float, depth scale factor, default 1000.0
+
+ Returns:
+ points: torch.Tensor, shape (H, W, 3) or (B, H, W, 3), point cloud coordinates (x, y, z)
+ """
+ # Handle batch dimension
+ if depth.dim() == 2:
+ depth = depth.unsqueeze(0) # (1, H, W)
+ intrinsic_normalized = intrinsic_normalized.unsqueeze(0) # (1, 3, 3)
+ squeeze_output = True
+ else:
+ squeeze_output = False
+
+ B, H, W = depth.shape
+ device = depth.device
+
+ # Denormalize intrinsics
+ fx = intrinsic_normalized[:, 0, 0] * W # (B,)
+ fy = intrinsic_normalized[:, 1, 1] * H
+ cx = intrinsic_normalized[:, 0, 2] * W
+ cy = intrinsic_normalized[:, 1, 2] * H
+
+ # Create pixel coordinate grid (H, W)
+ v, u = torch.meshgrid(
+ torch.arange(H, device=device, dtype=torch.float32),
+ torch.arange(W, device=device, dtype=torch.float32),
+ indexing='ij'
+ )
+
+ # Expand to batch dimension (B, H, W)
+ u = u.unsqueeze(0).expand(B, -1, -1)
+ v = v.unsqueeze(0).expand(B, -1, -1)
+
+ # Backproject to 3D space
+ z = depth / depth_scale # (B, H, W)
+
+ # Expand intrinsic dimensions for broadcasting (B, 1, 1)
+ fx = fx.view(B, 1, 1)
+ fy = fy.view(B, 1, 1)
+ cx = cx.view(B, 1, 1)
+ cy = cy.view(B, 1, 1)
+
+ x = (u - cx) * z / fx # (B, H, W)
+ y = (v - cy) * z / fy # (B, H, W)
+
+ # Stack coordinates (B, H, W, 3)
+ points = torch.stack([x, y, z], dim=-1)
+
+ if squeeze_output:
+ points = points.squeeze(0) # (H, W, 3)
+
+ return points
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/v2.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7582b703b9496be711b4cb6cea6af4c6c2d4b71
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/model/v2.py
@@ -0,0 +1,297 @@
+from typing import *
+from numbers import Number
+from functools import partial
+from pathlib import Path
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.amp
+import torch.version
+from huggingface_hub import hf_hub_download
+
+from .modules_rgbd_encoder import DINOv2_RGBD_Encoder
+from .modules_decoder import MLP, ConvStack
+from ..utils.geo import depth_to_pointcloud, normalized_view_plane_uv
+
+
+class MDMModel(nn.Module):
+ encoder: Union[DINOv2_RGBD_Encoder]
+ neck: ConvStack
+ points_head: ConvStack
+ mask_head: ConvStack
+ scale_head: MLP
+ onnx_compatible_mode: bool
+
+ def __init__(self,
+ encoder: Dict[str, Any],
+ neck: Dict[str, Any],
+ depth_head: Dict[str, Any] = None,
+ mask_head: Dict[str, Any] = None,
+ normal_head: Dict[str, Any] = None,
+ scale_head: Dict[str, Any] = None,
+ remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
+ remap_depth_in: Literal['linear', 'log'] = 'log',
+ remap_depth_out: Literal['linear', 'exp'] = 'exp',
+ num_tokens_range: List[int] = [1200, 3600],
+ **deprecated_kwargs
+ ):
+ super(MDMModel, self).__init__()
+ if deprecated_kwargs:
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
+
+ self.remap_output = remap_output
+ self.num_tokens_range = num_tokens_range
+ self.remap_depth_in = remap_depth_in
+ self.remap_depth_out = remap_depth_out
+
+ self.encoder = DINOv2_RGBD_Encoder(**encoder)
+
+ self.neck = ConvStack(**neck)
+ if depth_head is not None:
+ self.depth_head = ConvStack(**depth_head)
+ if mask_head is not None:
+ self.mask_head = ConvStack(**mask_head)
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return next(self.parameters()).dtype
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, Path, IO[bytes]],
+ model_kwargs: Optional[Dict[str, Any]] = None,
+ **hf_kwargs) -> 'MDMModel':
+ if Path(pretrained_model_name_or_path).exists():
+ checkpoint_path = pretrained_model_name_or_path
+ else:
+ checkpoint_path = hf_hub_download(
+ repo_id=pretrained_model_name_or_path,
+ repo_type="model",
+ filename="model.pt",
+ **hf_kwargs
+ )
+ checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
+
+ model_config = checkpoint['model_config']
+ if model_kwargs is not None:
+ model_config.update(model_kwargs)
+ model = cls(**model_config)
+ model.load_state_dict(checkpoint['model'], strict=False)
+
+ return model
+
+ def init_weights(self):
+ self.encoder.init_weights()
+
+ def enable_pytorch_native_sdpa(self):
+ self.encoder.enable_pytorch_native_sdpa()
+
+ def forward(self,
+ image: torch.Tensor,
+ num_tokens: Union[int, torch.LongTensor],
+ depth: Union[None, torch.Tensor]=None,
+ **kwargs) -> Dict[str, torch.Tensor]:
+ batch_size, _, img_h, img_w = image.shape
+ device, dtype = image.device, image.dtype
+
+ assert depth is not None # in this version, depth is required
+ if depth.dim() == 3:
+ depth = depth.unsqueeze(1) # from (B, H, W) to (B, 1, H, W)
+
+ aspect_ratio = img_w / img_h
+ base_h, base_w = (num_tokens / aspect_ratio) ** 0.5, (num_tokens * aspect_ratio) ** 0.5
+ if isinstance(base_h, torch.Tensor):
+ base_h, base_w = base_h.round().long(), base_w.round().long()
+ else:
+ base_h, base_w = round(base_h), round(base_w)
+
+ # Backbones encoding
+ features, cls_token, _, _ = self.encoder(image, depth, base_h, base_w, return_class_token=True, remap_depth_in=self.remap_depth_in, **kwargs)
+
+ features = features + cls_token[..., None, None]
+ features = [features, None, None, None, None]
+
+ # Concat UVs for aspect ratio input
+ for level in range(5):
+ uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
+ if features[level] is None:
+ features[level] = uv
+ else:
+ features[level] = torch.concat([features[level], uv], dim=1)
+
+ # Shared neck
+ features = self.neck(features)
+
+ # Heads decoding
+ depth_reg, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['depth_head', 'normal_head', 'mask_head'])
+ metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None
+
+ # Resize
+ depth_reg, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [depth_reg, normal, mask])
+
+ # Remap output
+ if depth_reg is not None:
+ if self.remap_depth_out == 'exp':
+ depth_reg = depth_reg.exp().squeeze(1)
+ elif self.remap_depth_out == 'linear':
+ depth_reg = depth_reg.squeeze(1)
+ else:
+ raise ValueError(f"Invalid remap_depth_out: {self.remap_depth_out}")
+ if normal is not None:
+ normal = normal.permute(0, 2, 3, 1)
+ normal = F.normalize(normal, dim=-1)
+ if mask is not None:
+ mask_prob = mask.squeeze(1).sigmoid()
+ # mask_logits = mask.squeeze(1)
+ else:
+ mask_prob = None
+ if metric_scale is not None:
+ metric_scale = metric_scale.squeeze(1).exp()
+
+ return_dict = {
+ 'depth_reg': depth_reg,
+ 'normal': normal,
+ 'mask': mask_prob,
+ }
+ return_dict = {k: v for k, v in return_dict.items() if v is not None}
+
+ return return_dict
+
+ @torch.inference_mode()
+ def infer(
+ self,
+ image: torch.Tensor,
+ depth_in: torch.Tensor = None,
+ num_tokens: int = None,
+ resolution_level: int = 9,
+ apply_mask: bool = True,
+ use_fp16: bool = True,
+ intrinsics: Optional[torch.Tensor] = None,
+ **kwargs
+ ) -> Dict[str, torch.Tensor]:
+ if image.dim() == 3:
+ omit_batch_dim = True
+ image = image.unsqueeze(0)
+ else:
+ omit_batch_dim = False
+ image = image.to(dtype=self.dtype, device=self.device)
+
+ if (depth_in is not None) and (depth_in.dim() == 2):
+ depth_in = depth_in.unsqueeze(0).to(dtype=self.dtype, device=self.device)
+
+ original_height, original_width = image.shape[-2:]
+ area = original_height * original_width
+ aspect_ratio = original_width / original_height
+
+ # Determine the number of base tokens to use
+ if num_tokens is None:
+ min_tokens, max_tokens = self.num_tokens_range
+ num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
+
+ # Forward pass
+ with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=use_fp16 and self.dtype != torch.bfloat16):
+ output = self.forward(image, num_tokens=num_tokens, depth=depth_in, **kwargs)
+ depth_reg, mask = (output.get(k, None) for k in ['depth_reg', 'mask'])
+
+ # Always process the output in fp32 precision
+ depth_reg, mask = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [depth_reg, mask])
+ with torch.autocast(device_type=self.device.type, dtype=torch.float32):
+ if mask is not None:
+ mask_binary = mask > 0.5
+ else:
+ mask_binary = None
+
+ depth = depth_reg
+ if intrinsics is not None:
+ points = depth_to_pointcloud(depth, intrinsics)
+ else:
+ points = None
+
+ # Apply mask
+ if apply_mask and mask_binary is not None:
+ points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None
+ depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None
+
+ return_dict = {
+ 'points': points,
+ 'depth': depth,
+ 'mask': mask_binary,
+ }
+ return_dict = {k: v for k, v in return_dict.items() if v is not None}
+
+ if omit_batch_dim:
+ return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}
+
+ return return_dict
+
+ def forward_feat(self,
+ image: torch.Tensor,
+ num_tokens: Union[int, torch.LongTensor],
+ depth: Union[None, torch.Tensor]=None,
+ **kwargs) -> Dict[str, torch.Tensor]:
+ batch_size, _, img_h, img_w = image.shape
+ device, dtype = image.device, image.dtype
+
+ assert depth is not None # in this version, depth is required
+ if depth.dim() == 3:
+ depth = depth.unsqueeze(1) # from (B, H, W) to (B, 1, H, W)
+
+ aspect_ratio = img_w / img_h
+ base_h, base_w = (num_tokens / aspect_ratio) ** 0.5, (num_tokens * aspect_ratio) ** 0.5
+ if isinstance(base_h, torch.Tensor):
+ base_h, base_w = base_h.round().long(), base_w.round().long()
+ else:
+ base_h, base_w = round(base_h), round(base_w)
+
+ # Backbones encoding
+ features, cls_token, _, _ = self.encoder(image, depth, base_h, base_w, return_class_token=True, remap_depth_in=self.remap_depth_in, **kwargs)
+
+ return features, cls_token
+
+
+ @torch.inference_mode()
+ def infer_feat(
+ self,
+ image: torch.Tensor,
+ depth_in: torch.Tensor = None,
+ num_tokens: int = None,
+ resolution_level: int = 9,
+ apply_mask: bool = True,
+ use_fp16: bool = True,
+ intrinsics: Optional[torch.Tensor] = None,
+ **kwargs
+ ):
+ if image.dim() == 3:
+ omit_batch_dim = True
+ image = image.unsqueeze(0)
+ else:
+ omit_batch_dim = False
+ image = image.to(dtype=self.dtype, device=self.device)
+
+ if (depth_in is not None) and (depth_in.dim() == 2):
+ depth_in = depth_in.unsqueeze(0).to(dtype=self.dtype, device=self.device)
+
+ original_height, original_width = image.shape[-2:]
+ area = original_height * original_width
+ aspect_ratio = original_width / original_height
+
+ # Determine the number of base tokens to use
+ if num_tokens is None:
+ min_tokens, max_tokens = self.num_tokens_range
+ num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))
+
+ # Forward pass
+ with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=use_fp16 and self.dtype != torch.bfloat16):
+ features, cls_token = self.forward_feat(image, num_tokens=num_tokens, depth=depth_in, **kwargs)
+
+ return features, cls_token
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/__init__.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/geo.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/geo.py
new file mode 100644
index 0000000000000000000000000000000000000000..13fc0bd86248dd521b3beb571ed2418356620875
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/geo.py
@@ -0,0 +1,105 @@
+import torch
+
+def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+ if aspect_ratio is None:
+ aspect_ratio = width / height
+
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+ u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
+ v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
+ u, v = torch.meshgrid(u, v, indexing='xy')
+ uv = torch.stack([u, v], dim=-1)
+ return uv
+
+def depth_to_pointcloud(depth, intrinsic_normalized, depth_scale=1.0):
+ """
+ Convert depth map to point cloud (pure Tensor version, no point filtering)
+
+ Args:
+ depth: torch.Tensor, shape (H, W) or (B, H, W), depth map
+ intrinsic_normalized: torch.Tensor, shape (3, 3) or (B, 3, 3), normalized intrinsic matrix
+ Normalized intrinsics: fx' = fx/W, fy' = fy/H, cx' = cx/W, cy' = cy/H
+ depth_scale: float, depth scale factor, default 1000.0
+
+ Returns:
+ points: torch.Tensor, shape (H, W, 3) or (B, H, W, 3), point cloud coordinates (x, y, z)
+ """
+ # Handle batch dimension
+ if depth.dim() == 2:
+ depth = depth.unsqueeze(0) # (1, H, W)
+ intrinsic_normalized = intrinsic_normalized.unsqueeze(0) # (1, 3, 3)
+ squeeze_output = True
+ else:
+ squeeze_output = False
+
+ B, H, W = depth.shape
+ device = depth.device
+
+ # Denormalize intrinsics
+ fx = intrinsic_normalized[:, 0, 0] * W # (B,)
+ fy = intrinsic_normalized[:, 1, 1] * H
+ cx = intrinsic_normalized[:, 0, 2] * W
+ cy = intrinsic_normalized[:, 1, 2] * H
+
+ # Create pixel coordinate grid (H, W)
+ v, u = torch.meshgrid(
+ torch.arange(H, device=device, dtype=torch.float32),
+ torch.arange(W, device=device, dtype=torch.float32),
+ indexing='ij'
+ )
+
+ # Expand to batch dimension (B, H, W)
+ u = u.unsqueeze(0).expand(B, -1, -1)
+ v = v.unsqueeze(0).expand(B, -1, -1)
+
+ # Backproject to 3D space
+ z = depth / depth_scale # (B, H, W)
+
+ # Expand intrinsic dimensions for broadcasting (B, 1, 1)
+ fx = fx.view(B, 1, 1)
+ fy = fy.view(B, 1, 1)
+ cx = cx.view(B, 1, 1)
+ cy = cy.view(B, 1, 1)
+
+ x = (u - cx) * z / fx # (B, H, W)
+ y = (v - cy) * z / fy # (B, H, W)
+
+ # Stack coordinates (B, H, W, 3)
+ points = torch.stack([x, y, z], dim=-1)
+
+ if squeeze_output:
+ points = points.squeeze(0) # (H, W, 3)
+
+ return points
+
+
+# Usage example
+if __name__ == "__main__":
+ # Single image
+ depth = torch.rand(480, 640) * 5000 # Depth values
+ intrinsic_norm = torch.tensor([
+ [525.0/640, 0, 319.5/640],
+ [0, 525.0/480, 239.5/480],
+ [0, 0, 1]
+ ])
+
+ points = depth_to_pointcloud(depth, intrinsic_norm)
+ print(f"Point cloud shape: {points.shape}") # (480, 640, 3)
+
+ # Batch processing
+ depth_batch = torch.rand(4, 480, 640) * 5000
+ intrinsic_batch = intrinsic_norm.unsqueeze(0).expand(4, -1, -1)
+
+ points_batch = depth_to_pointcloud(depth_batch, intrinsic_batch)
+ print(f"Batch point cloud shape: {points_batch.shape}") # (4, 480, 640, 3)
+
+ # Flatten to (N, 3) format if needed
+ points_flat = points.reshape(-1, 3)
+ print(f"Flattened shape: {points_flat.shape}") # (480*640, 3)
+
+ # Batch flatten to (B, N, 3)
+ points_batch_flat = points_batch.reshape(4, -1, 3)
+ print(f"Batch flattened shape: {points_batch_flat.shape}") # (4, 480*640, 3)
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/io.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf40b327a32ab36fa597d9b868b0a6aa6cfbad53
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/io.py
@@ -0,0 +1,270 @@
+import os
+os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+from typing import IO
+import zipfile
+import json
+import io
+from typing import *
+from pathlib import Path
+import re
+from PIL import Image, PngImagePlugin
+
+import numpy as np
+import cv2
+
+from .tools import timeit
+
+
+def save_glb(
+ save_path: Union[str, os.PathLike],
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ vertex_uvs: np.ndarray,
+ texture: np.ndarray,
+ vertex_normals: Optional[np.ndarray] = None,
+):
+ import trimesh
+ import trimesh.visual
+ from PIL import Image
+
+ trimesh.Trimesh(
+ vertices=vertices,
+ vertex_normals=vertex_normals,
+ faces=faces,
+ visual = trimesh.visual.texture.TextureVisuals(
+ uv=vertex_uvs,
+ material=trimesh.visual.material.PBRMaterial(
+ baseColorTexture=Image.fromarray(texture),
+ metallicFactor=0.5,
+ roughnessFactor=1.0
+ )
+ ),
+ process=False
+ ).export(save_path)
+
+
+def save_ply(
+ save_path: Union[str, os.PathLike],
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ vertex_colors: np.ndarray,
+ vertex_normals: Optional[np.ndarray] = None,
+):
+ import trimesh
+ import trimesh.visual
+ from PIL import Image
+
+ trimesh.Trimesh(
+ vertices=vertices,
+ faces=faces,
+ vertex_colors=vertex_colors,
+ vertex_normals=vertex_normals,
+ process=False
+ ).export(save_path)
+
+
+def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray:
+ """
+ Read a image, return uint8 RGB array of shape (H, W, 3).
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ image = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
+ return image
+
+
+def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95):
+ """
+ Write a image, input uint8 RGB array of shape (H, W, 3).
+ """
+ data = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes()
+ if isinstance(path, (str, os.PathLike)):
+ Path(path).write_bytes(data)
+ else:
+ path.write(data)
+
+
+def read_depth(path: Union[str, os.PathLike, IO]) -> np.ndarray:
+ """
+ Read a depth image, return float32 depth array of shape (H, W).
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ pil_image = Image.open(io.BytesIO(data))
+ near = float(pil_image.info.get('near'))
+ far = float(pil_image.info.get('far'))
+ depth = np.array(pil_image)
+ mask_nan, mask_inf = depth == 0, depth == 65535
+ depth = (depth.astype(np.float32) - 1) / 65533
+ depth = near ** (1 - depth) * far ** depth
+ if 'unit' in pil_image.info: # Legacy support for depth units
+ unit = float(pil_image.info.get('unit'))
+ depth = depth * unit
+ depth[mask_nan] = np.nan
+ depth[mask_inf] = np.inf
+ return depth
+
+def write_depth(
+ path: Union[str, os.PathLike, IO],
+ depth: np.ndarray,
+ max_range: float = 1e5,
+ compression_level: int = 7,
+):
+ """
+ Encode and write a depth image as 16-bit PNG format.
+ ## Parameters:
+ - `path: Union[str, os.PathLike, IO]`
+ The file path or file object to write to.
+ - `depth: np.ndarray`
+ The depth array, float32 array of shape (H, W).
+ May contain `NaN` for invalid values and `Inf` for infinite values.
+
+ Depth values are encoded as follows:
+ - 0: unknown
+ - 1 ~ 65534: depth values in logarithmic
+ - 65535: infinity
+
+ metadata is stored in the PNG file as text fields:
+ - `near`: the minimum depth value
+ - `far`: the maximum depth value
+ """
+ mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth),np.isinf(depth)
+
+ depth = depth.astype(np.float32)
+ mask_finite = depth
+ near = max(depth[mask_values].min(), 1e-5)
+ far = max(near * 1.1, min(depth[mask_values].max(), near * max_range))
+ depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16) # 1~65534
+ depth[mask_nan] = 0
+ depth[mask_inf] = 65535
+
+ pil_image = Image.fromarray(depth)
+ pnginfo = PngImagePlugin.PngInfo()
+ pnginfo.add_text('near', str(near))
+ pnginfo.add_text('far', str(far))
+ pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
+
+
+def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]:
+ """
+ Read a segmentation mask
+ ### Parameters:
+ - `path: Union[str, os.PathLike, IO]`
+ The file path or file object to read from.
+ ### Returns:
+ - `Tuple[np.ndarray, Dict[str, int]]`
+ A tuple containing:
+ - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W).
+ - `labels`: Dict[str, int]. The label mapping, a dictionary of {label_name: label_id}.
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ pil_image = Image.open(io.BytesIO(data))
+ labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None
+ mask = np.array(pil_image)
+ return mask, labels
+
+
+def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7):
+ """
+ Write a segmentation mask and label mapping, as PNG format.
+ ### Parameters:
+ - `path: Union[str, os.PathLike, IO]`
+ The file path or file object to write to.
+ - `mask: np.ndarray`
+ The segmentation mask, uint8 or uint16 array of shape (H, W).
+ - `labels: Dict[str, int] = None`
+ The label mapping, a dictionary of {label_name: label_id}.
+ - `compression_level: int = 7`
+ The compression level for PNG compression.
+ """
+ assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}"
+ pil_image = Image.fromarray(mask)
+ pnginfo = PngImagePlugin.PngInfo()
+ if labels is not None:
+ labels_json = json.dumps(labels, ensure_ascii=True, separators=(',', ':'))
+ pnginfo.add_text('labels', labels_json)
+ pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
+
+
+
+def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray:
+ """
+ Read a normal image, return float32 normal array of shape (H, W, 3).
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
+ mask_nan = np.all(normal == 0, axis=-1)
+ normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
+ normal = normal / (np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2])) + 1e-12)
+ normal[mask_nan] = np.nan
+ return normal
+
+
+def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> np.ndarray:
+ """
+ Write a normal image, input float32 normal array of shape (H, W, 3).
+ """
+ mask_nan = np.isnan(normal).any(axis=-1)
+ normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
+ normal[mask_nan] = 0
+ data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
+ if isinstance(path, (str, os.PathLike)):
+ Path(path).write_bytes(data)
+ else:
+ path.write(data)
+
+
+def read_mask(path: Union[str, os.PathLike, IO[bytes]]) -> np.ndarray:
+ """
+ Read a binary mask, return bool array of shape (H, W).
+ """
+ if isinstance(path, (str, os.PathLike)):
+ data = Path(path).read_bytes()
+ else:
+ data = path.read()
+ mask = cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED)
+ if len(mask.shape) == 3:
+ mask = mask[..., 0]
+ return mask > 0
+
+
+def write_mask(path: Union[str, os.PathLike, IO[bytes]], mask: np.ndarray, compression_level: int = 7):
+ """
+ Write a binary mask, input bool array of shape (H, W).
+ """
+ assert mask.dtype == bool, f"Mask must be bool array, got {mask.dtype}"
+ mask = (mask.astype(np.uint8) * 255).astype(np.uint8)
+ data = cv2.imencode('.png', mask, [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
+ if isinstance(path, (str, os.PathLike)):
+ Path(path).write_bytes(data)
+ else:
+ path.write(data)
+
+
+JSON_TYPE = Union[str, int, float, bool, None, Dict[str, "JSON"], List["JSON"]]
+
+
+def read_json(path: Union[str, os.PathLike, IO[str]]) -> JSON_TYPE:
+ if isinstance(path, (str, os.PathLike)):
+ text = Path(path).read_text()
+ else:
+ text = path.read()
+ return json.loads(text)
+
+
+def write_json(path: Union[str, os.PathLike, IO[str]], content: JSON_TYPE):
+ text = json.dumps(content)
+ if isinstance(path, (str, os.PathLike)):
+ Path(path).write_text(text)
+ else:
+ path.write(text)
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/tools.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..3687f6938fe34433d149a1a8405be7eed5f23c37
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/tools.py
@@ -0,0 +1,289 @@
+from typing import *
+import time
+from pathlib import Path
+from numbers import Number
+from functools import wraps
+import warnings
+import math
+import json
+import os
+import importlib
+import importlib.util
+
+
+def catch_exception(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ import traceback
+ print(f"Exception in {fn.__name__}", end='r')
+ # print({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})
+ traceback.print_exc(chain=False)
+ time.sleep(0.1)
+ return None
+ return wrapper
+
+
+class CallbackOnException:
+ def __init__(self, callback: Callable, exception: type):
+ self.exception = exception
+ self.callback = callback
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if isinstance(exc_val, self.exception):
+ self.callback()
+ return True
+ return False
+
+def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
+ for k, v in d.items():
+ if isinstance(v, dict):
+ for sub_key in traverse_nested_dict_keys(v):
+ yield (k, ) + sub_key
+ else:
+ yield (k, )
+
+
+def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
+ for k in keys:
+ d = d.get(k, default)
+ if d is None:
+ break
+ return d
+
+def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
+ for k in keys[:-1]:
+ d = d.setdefault(k, {})
+ d[keys[-1]] = value
+
+
+def key_average(list_of_dicts: list) -> Dict[str, Any]:
+ """
+ Returns a dictionary with the average value of each key in the input list of dictionaries.
+ """
+ _nested_dict_keys = set()
+ for d in list_of_dicts:
+ _nested_dict_keys.update(traverse_nested_dict_keys(d))
+ _nested_dict_keys = sorted(_nested_dict_keys)
+ result = {}
+ for k in _nested_dict_keys:
+ values = []
+ for d in list_of_dicts:
+ v = get_nested_dict(d, k)
+ if v is not None and not math.isnan(v):
+ values.append(v)
+ avg = sum(values) / len(values) if values else float('nan')
+ set_nested_dict(result, k, avg)
+ return result
+
+
+def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
+ """
+ Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
+ """
+ items = []
+ if parent_key is None:
+ parent_key = ()
+ for k, v in d.items():
+ new_key = parent_key + (k, )
+ if isinstance(v, MutableMapping):
+ items.extend(flatten_nested_dict(v, new_key).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
+
+
+def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Unflattens a single-level dictionary into a nested dictionary, with keys as tuples.
+ """
+ result = {}
+ for k, v in d.items():
+ sub_dict = result
+ for k_ in k[:-1]:
+ if k_ not in sub_dict:
+ sub_dict[k_] = {}
+ sub_dict = sub_dict[k_]
+ sub_dict[k[-1]] = v
+ return result
+
+
+def read_jsonl(file):
+ import json
+ with open(file, 'r') as f:
+ data = f.readlines()
+ return [json.loads(line) for line in data]
+
+
+def write_jsonl(data: List[dict], file):
+ import json
+ with open(file, 'w') as f:
+ for item in data:
+ f.write(json.dumps(item) + '\n')
+
+
+def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
+ import pandas as pd
+ data = [flatten_nested_dict(d) for d in data]
+ df = pd.DataFrame(data)
+ df = df.sort_index(axis=1)
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
+ return df
+
+
+def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
+ if isinstance(d, str):
+ for old, new in mapping.items():
+ d = d.replace(old, new)
+ elif isinstance(d, list):
+ for i, item in enumerate(d):
+ d[i] = recursive_replace(item, mapping)
+ elif isinstance(d, dict):
+ for k, v in d.items():
+ d[k] = recursive_replace(v, mapping)
+ return d
+
+
+class timeit:
+ _history: Dict[str, List['timeit']] = {}
+
+ def __init__(self, name: str = None, verbose: bool = True, average: bool = False):
+ self.name = name
+ self.verbose = verbose
+ self.start = None
+ self.end = None
+ self.average = average
+ if average and name not in timeit._history:
+ timeit._history[name] = []
+
+ def __call__(self, func: Callable):
+ import inspect
+ if inspect.iscoroutinefunction(func):
+ async def wrapper(*args, **kwargs):
+ with timeit(self.name or func.__qualname__):
+ ret = await func(*args, **kwargs)
+ return ret
+ return wrapper
+ else:
+ def wrapper(*args, **kwargs):
+ with timeit(self.name or func.__qualname__):
+ ret = func(*args, **kwargs)
+ return ret
+ return wrapper
+
+ def __enter__(self):
+ self.start = time.time()
+ return self
+
+ @property
+ def time(self) -> float:
+ assert self.start is not None, "Time not yet started."
+ assert self.end is not None, "Time not yet ended."
+ return self.end - self.start
+
+ @property
+ def average_time(self) -> float:
+ assert self.average, "Average time not available."
+ return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])
+
+ @property
+ def history(self) -> List['timeit']:
+ return timeit._history.get(self.name, [])
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.end = time.time()
+ if self.average:
+ timeit._history[self.name].append(self)
+ if self.verbose:
+ if self.average:
+ avg = self.average_time
+ print(f"{self.name or 'It'} took {avg:.6f} seconds in average.")
+ else:
+ print(f"{self.name or 'It'} took {self.time:.6f} seconds.")
+
+
+def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
+ first = strings[0]
+
+ for start in range(len(first)):
+ if any(s[start] != strings[0][start] for s in strings):
+ break
+
+ for end in range(1, min(len(s) for s in strings)):
+ if any(s[-end] != first[-end] for s in strings):
+ break
+
+ return [s[start:len(s) - end + 1] for s in strings]
+
+
+def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
+ from concurrent.futures import ThreadPoolExecutor
+ from contextlib import nullcontext
+ from tqdm import tqdm
+
+ if pbar is not None:
+ pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
+ else:
+ pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)
+
+ def decorator(fn: Callable):
+ with (
+ ThreadPoolExecutor(max_workers=num_workers) as executor,
+ pbar
+ ):
+ pbar.refresh()
+ @catch_exception
+ @suppress_traceback
+ def _fn(input):
+ ret = fn(input)
+ pbar.update()
+ return ret
+ executor.map(_fn, inputs)
+ executor.shutdown(wait=True)
+
+ return decorator
+
+
+def suppress_traceback(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ e.__traceback__ = e.__traceback__.tb_next.tb_next
+ raise
+ return wrapper
+
+
+class no_warnings:
+ def __init__(self, action: str = 'ignore', **kwargs):
+ self.action = action
+ self.filter_kwargs = kwargs
+
+ def __call__(self, fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ with warnings.catch_warnings():
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+ return fn(*args, **kwargs)
+ return wrapper
+
+ def __enter__(self):
+ self.warnings_manager = warnings.catch_warnings()
+ self.warnings_manager.__enter__()
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
+
+
+def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str):
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ return module
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/vis.py b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..e17edfa9e576b9c2b182bf39bfa289a4480bc9e3
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/mdm/utils/vis.py
@@ -0,0 +1,65 @@
+from typing import *
+
+import numpy as np
+import matplotlib
+import trimesh
+import random
+import torch
+import torch.nn.functional as F
+import os
+
+def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
+ depth = depth.copy()
+ if mask is None:
+ depth = np.where(depth > 0, depth, np.nan)
+ else:
+ depth = np.where((depth > 0) & mask, depth, np.nan)
+ disp = 1 / depth
+ if normalize:
+ min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99)
+ disp = (disp - min_disp) / (max_disp - min_disp)
+
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], 0)
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+ return colored
+
+
+def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray:
+ if mask is not None:
+ depth = np.where(mask, depth, np.nan)
+
+ min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999)
+ depth = (depth - min_depth) / (max_depth - min_depth)
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](depth)[..., :3], 0)
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+ return colored
+
+
+def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
+ if mask is not None:
+ disparity = np.where(mask, disparity, np.nan)
+
+ if normalize:
+ min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999)
+ disparity = (disparity - min_disp) / (max_disp - min_disp)
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity)[..., :3], 0)
+ colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
+ return colored
+
+
+def colorize_normal(normal: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
+ if mask is not None:
+ normal = np.where(mask[..., None], normal, 0)
+ normal = normal * [0.5, -0.5, -0.5] + 0.5
+ normal = (normal.clip(0, 1) * 255).astype(np.uint8)
+ return normal
+
+
+def colorize_error_map(error_map: np.ndarray, mask: np.ndarray = None, cmap: str = 'plasma', value_range: Tuple[float, float] = None):
+ vmin, vmax = value_range if value_range is not None else (np.nanmin(error_map), np.nanmax(error_map))
+ cmap = matplotlib.colormaps[cmap]
+ colorized_error_map = cmap(((error_map - vmin) / (vmax - vmin)).clip(0, 1))[..., :3]
+ if mask is not None:
+ colorized_error_map = np.where(mask[..., None], colorized_error_map, 0)
+ colorized_error_map = np.ascontiguousarray((colorized_error_map.clip(0, 1) * 255).astype(np.uint8))
+ return colorized_error_map
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/pyproject.toml b/lingbotvla/models/vla/vision_models/lingbot-depth/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..04142f03538b10692d2ceed3c5cb24f2baf62949
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/pyproject.toml
@@ -0,0 +1,26 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "mdm"
+version = "1.0.0"
+readme = "README.md"
+dependencies = [
+ "click",
+ "opencv-python",
+ "scipy",
+ "matplotlib",
+ "trimesh",
+ "pillow",
+ "huggingface_hub",
+ "numpy",
+ "torch==2.6.0",
+ "torchvision",
+ "xformers==v0.0.29.post2",
+]
+requires-python = ">=3.9"
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["mdm*"]
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/pyrightconfig.json b/lingbotvla/models/vla/vision_models/lingbot-depth/pyrightconfig.json
new file mode 100644
index 0000000000000000000000000000000000000000..02a3179f072f632f31bc07589c5f0dae23fb8a13
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/pyrightconfig.json
@@ -0,0 +1,8 @@
+{
+ "include": [
+ "mdm"
+ ],
+ "ignore": [
+ "**"
+ ]
+}
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/requirements.txt b/lingbotvla/models/vla/vision_models/lingbot-depth/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2fdeda1c8f0001005a3881790c484d864dc709de
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/requirements.txt
@@ -0,0 +1,11 @@
+# The versions are not specified since MoGe should be compatible with most versions of the packages.
+# If incompatibilities are found, consider upgrading to latest versions or installing the following recommended version of the package.
+torch # >= 2.0.0
+torchvision
+click # ==8.1.7
+opencv-python # ==4.10.0.84
+scipy # ==1.14.1
+matplotlib # ==3.9.2
+trimesh # ==4.5.1
+pillow # ==10.4.0
+huggingface_hub # ==0.25.2
\ No newline at end of file
diff --git a/lingbotvla/models/vla/vision_models/lingbot-depth/tech-report.pdf b/lingbotvla/models/vla/vision_models/lingbot-depth/tech-report.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..fb56edc62d09b1ddbaa79cc4756bea3aa1f94c25
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/lingbot-depth/tech-report.pdf
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db1a4bd0435608eafbaad610862e90e20062a642cdcaca4e2038f4ba64a5ea63
+size 9129324
diff --git a/lingbotvla/models/vla/vision_models/module_utils.py b/lingbotvla/models/vla/vision_models/module_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..308fd033a6c4e981114a7072dd05be931a9040d0
--- /dev/null
+++ b/lingbotvla/models/vla/vision_models/module_utils.py
@@ -0,0 +1,133 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import os
+import numpy as np
+import matplotlib
+import einops
+from PIL import Image, ImageDraw
+
+try:
+ from mdm.model.v2 import MDMModel as v2_morgbd
+
+ from moge.model.v2 import MoGeModel as v2
+ from moge.utils.vis import colorize_depth
+except:
+ print('Load MoGe Module Failed!!')
+
+def make_grid(images, pil_images):
+ # Assuming each image is the same size
+
+ new_images = []
+ new_captions = []
+ for image, pil_image in zip(images, pil_images):
+ new_images.append(image)
+ pil_image = pil_image.resize((image.size[0], image.size[1]))
+ new_images.append(pil_image)
+ new_captions.append("Predicted")
+ new_captions.append("GT")
+
+ images = new_images
+ captions = new_captions
+
+ width, height = images[0].size
+ font_size = 14
+ caption_height = font_size + 10
+
+ # Calculate the size of the final image
+ images_per_row = min(len(images), 16) # Round up for odd number of images
+ row_count = (len(images) + 1) // images_per_row
+ total_width = width * images_per_row
+ total_height = (height + caption_height) * row_count
+
+ # Create a new blank image
+ new_image = Image.new("RGB", (total_width, total_height), "white")
+
+ draw = ImageDraw.Draw(new_image)
+
+ for i, (image, caption) in enumerate(zip(images, captions)):
+ row = i // images_per_row
+ col = i % images_per_row
+ x_offset = col * width
+ y_offset = row * (height + caption_height)
+
+ new_image.paste(image, (x_offset, y_offset))
+ text_position = (x_offset + 10, y_offset + height)
+ draw.text(text_position, caption, fill="red", font_size=font_size)
+
+ return new_image
+
+def build_depth_model(config):
+
+ moge_model = v2.from_pretrained(config['depth']['moge_path'])
+ for p in moge_model.parameters():
+ p.requires_grad = False
+ moge_model.cuda()
+ moge_model.eval()
+
+ morgbd_model = v2_morgbd.from_pretrained(config['depth']['morgbd_path'])
+ for p in morgbd_model.parameters():
+ p.requires_grad = False
+ morgbd_model.cuda()
+ morgbd_model.eval()
+ return moge_model, morgbd_model
+
+def get_depth_target(model_type, depth_model, pil_images):
+ device = pil_images.device
+ B, _, C, H, W = pil_images.shape
+ images = einops.rearrange(pil_images, 'b n c h w -> (b n) c h w', n=3).contiguous().float()
+
+ input_images = images / 255.0
+ moge_model, morgbd_model = depth_model
+ output_moge = moge_model.infer(input_images, resolution_level=3, num_tokens=256, apply_mask=False)
+ depth_pred = output_moge['depth'].squeeze().detach().clone() # moge2
+ depth_pred = torch.nan_to_num(depth_pred, nan=0.0, posinf=0.0, neginf=0.0)
+ depth_pred *= 1
+ depth_down_scale = 1
+ depth_target, cls_token = morgbd_model.infer_feat(input_images, depth_pred,
+ depth_down_scale=depth_down_scale,
+ resolution_level=3,
+ num_tokens=256,
+ enable_depth_mask=False)
+ depth_target = depth_target.permute(0, 2, 3, 1)
+ depth_target = depth_target.view(depth_target.shape[0], -1, depth_target.shape[-1])
+
+ return depth_target.to(dtype=torch.bfloat16), cls_token
+
+def log_depth(vis_head, depth_pred_feats, depth_target_feats=None, steps=0, config=None, cls_token=None):
+ model_type = config['depth']['model_type']
+ llm_image_token_size = config['llm']['image_token_size']
+ depth_token_size = config['depth']['token_size']
+ visual_dir = config['visual_dir']
+
+ if config['mode'] == "direct":
+ depth_pred_feats = depth_pred_feats.view(depth_pred_feats.shape[0], llm_image_token_size, llm_image_token_size, depth_pred_feats.shape[-1])
+ depth_pred_feats = depth_pred_feats.permute(0, 3, 1, 2)
+ depth_pred_feats = F.interpolate(depth_pred_feats, size=(depth_token_size, depth_token_size), mode="bilinear", align_corners=False)
+ elif config['mode'] == "query":
+ depth_pred_feats = depth_pred_feats.view(depth_pred_feats.shape[0], depth_token_size, depth_token_size, depth_pred_feats.shape[-1])
+ depth_pred_feats = depth_pred_feats.permute(0, 3, 1, 2)
+
+ import cv2
+ morgbd_model = vis_head
+ depth_target_feats = depth_target_feats.view(depth_target_feats.shape[0], depth_token_size, depth_token_size, depth_target_feats.shape[-1])
+ depth_target_feats = depth_target_feats.permute(0, 3, 1, 2)
+
+ output_morgbd_preds = morgbd_model.dec_depth(depth_pred_feats, cls_token, num_tokens=256, resolution_level=3, img_h=224, img_w=224)
+ output_morgbd_targets = morgbd_model.dec_depth(depth_target_feats, cls_token, num_tokens=256, resolution_level=3, img_h=224, img_w=224)
+
+ output_morgbd_preds = output_morgbd_preds['depth_reg'].squeeze().cpu().numpy()
+ output_morgbd_targets = output_morgbd_targets['depth_reg'].squeeze().cpu().numpy()
+
+ for idx, (output_morgbd_target, output_morgbd_pred) in enumerate(zip(output_morgbd_targets, output_morgbd_preds)):
+
+ depth_list = [output_morgbd_target, output_morgbd_pred]
+ depth_color_list = [cv2.cvtColor(colorize_depth(depth_raw), cv2.COLOR_RGB2BGR) for depth_raw in depth_list]
+
+ depth_concat = np.concatenate(depth_color_list, axis=1)
+
+ dst_path = os.path.join(visual_dir, f"depth_morgbd_{steps}_{idx}.png")
+ cv2.imwrite(dst_path,depth_concat)
+
+
diff --git a/lingbotvla/ops/__init__.py b/lingbotvla/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f91eff59e2a263eb25b22fbd4243a7fffc94d0a0
--- /dev/null
+++ b/lingbotvla/ops/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .attention import flash_attention_forward
+from .fused_moe import fused_moe_forward
+from .loss import causallm_loss_function
+
+
+__all__ = [
+ "fused_moe_forward",
+ "causallm_loss_function",
+]
diff --git a/lingbotvla/ops/attention.py b/lingbotvla/ops/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2d497b14496efa831d53b2fe9db435b07eb2325
--- /dev/null
+++ b/lingbotvla/ops/attention.py
@@ -0,0 +1,277 @@
+from typing import Optional, Tuple, Literal
+
+import torch
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+import torch.nn.functional as F # noqa: N812
+from packaging.version import Version
+import einops
+from ..distributed.parallel_state import get_parallel_state
+from ..distributed.sequence_parallel import (
+ gather_heads_scatter_seq,
+ gather_seq_scatter_heads,
+)
+from ..utils import logging
+from ..utils.import_utils import is_seed_kernels_available
+
+if is_seed_kernels_available():
+ from seed_kernels.transformers.functional import seed_flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def flash_attention_forward(
+ module: torch.nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ dropout: float = 0.0,
+ scaling: Optional[float] = None,
+ sliding_window: Optional[int] = None,
+ softcap: Optional[float] = None,
+ implementation: Optional[Literal["fa2", "lego", "fa3"]] = None,
+ skip_ulysses: bool = False, # Skip ulysses for some ViT cases like internvl3.5
+ **kwargs,
+) -> Tuple[torch.Tensor, None]:
+ if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None:
+ logger.warning_once(
+ "`flash_attention_2` does not support `output_attentions=True` or `head_mask`."
+ " Please set your attention to `eager` if you want any of these features."
+ )
+
+ # FA2 uses non-transposed inputs
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
+ kwargs.pop("is_causal", None)
+
+ # This is for Qwen2VL's mrope
+ position_ids = kwargs.pop("position_ids", None)
+ if position_ids is not None and position_ids.dim() == 3:
+ position_ids = position_ids[0]
+
+ # Ulysses patch
+ ulysses_enabled = get_parallel_state().ulysses_enabled
+ if ulysses_enabled and not skip_ulysses:
+ ulysses_group = get_parallel_state().ulysses_group
+ # Sanity Check & Repeat Key & Value
+ ulysses_size = get_parallel_state().ulysses_size
+ q_head_num = query.shape[2]
+ kv_head_num = key.shape[2]
+ unpadded_seq_len = None
+
+ assert q_head_num % ulysses_size == 0, (
+ f"num_query_heads ({q_head_num}) must be divisible by ulysses_size ({ulysses_size})"
+ )
+ if ulysses_size > kv_head_num:
+ assert ulysses_size % kv_head_num == 0, (
+ f"ulysses_size ({ulysses_size}) must be divisible by num_key_value_heads ({kv_head_num})"
+ )
+ n_repeat = ulysses_size // kv_head_num
+ key = repeat_kv(key, n_repeat)
+ value = repeat_kv(value, n_repeat)
+
+ if query.ndim == 4 and query.size(0) == 1:
+ query, key, value = query.squeeze(0), key.squeeze(0), value.squeeze(0)
+ query = gather_seq_scatter_heads(
+ query, seq_dim=0, head_dim=1, group=ulysses_group, unpadded_dim_size=unpadded_seq_len
+ )
+ key = gather_seq_scatter_heads(
+ key, seq_dim=0, head_dim=1, group=ulysses_group, unpadded_dim_size=unpadded_seq_len
+ )
+ value = gather_seq_scatter_heads(
+ value, seq_dim=0, head_dim=1, group=ulysses_group, unpadded_dim_size=unpadded_seq_len
+ )
+ query, key, value = query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0)
+ else:
+ query = gather_seq_scatter_heads(
+ query, seq_dim=1, head_dim=2, group=ulysses_group, unpadded_dim_size=unpadded_seq_len
+ )
+ key = gather_seq_scatter_heads(
+ key, seq_dim=1, head_dim=2, group=ulysses_group, unpadded_dim_size=unpadded_seq_len
+ )
+ value = gather_seq_scatter_heads(
+ value, seq_dim=1, head_dim=2, group=ulysses_group, unpadded_dim_size=unpadded_seq_len
+ )
+
+ # Only after all_to_all we got the full seq_len
+ seq_len = query.shape[1]
+
+ if is_seed_kernels_available() and implementation is not None:
+ attn_output: torch.Tensor = seed_flash_attention_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ query_length=seq_len,
+ is_causal=module.is_causal,
+ dropout=dropout,
+ position_ids=position_ids,
+ softmax_scale=scaling,
+ sliding_window=sliding_window,
+ softcap=softcap,
+ use_top_left_mask=False,
+ implementation=implementation,
+ cu_seqlens=kwargs.get("cu_seq_lens_q", None),
+ max_seqlen=kwargs.get("max_length_q", None),
+ **kwargs,
+ )
+ else:
+ assert implementation is None, (
+ f"You set {implementation=} but seed_kernels is not installed. Check --model.attn_implementation."
+ )
+ attn_output: torch.Tensor = _flash_attention_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ query_length=seq_len,
+ is_causal=module.is_causal,
+ dropout=dropout,
+ position_ids=position_ids,
+ softmax_scale=scaling,
+ sliding_window=sliding_window,
+ softcap=softcap,
+ use_top_left_mask=False,
+ implementation="flash_attention_2",
+ **kwargs,
+ )
+
+ # Ulysses patch
+ if ulysses_enabled and not skip_ulysses:
+ ulysses_group = get_parallel_state().ulysses_group
+ if attn_output.ndim == 4 and attn_output.size(0) == 1:
+ attn_output = attn_output.squeeze(0)
+ attn_output = gather_heads_scatter_seq(attn_output, seq_dim=0, head_dim=1, group=ulysses_group)
+ attn_output = attn_output.unsqueeze(0)
+ else:
+ attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2, group=ulysses_group)
+
+ return attn_output, None
+
+if Version(torch.__version__) > Version("2.5.0"):
+ # Ffex attention is only available from torch 2.5 onwards
+ from torch.nn.attention.flex_attention import (
+ _mask_mod_signature,
+ _round_up_to_multiple,
+ create_block_mask,
+ create_mask,
+ flex_attention,
+ )
+
+# @torch.compile(dynamic=False)
+def flex_attention_forward(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ scaling=None,
+):
+ """
+ This is defined out of classes to make compile happy.
+ """
+ batch_size, seq_len, num_att_heads, head_dim = query_states.shape # head_dim=256
+ original_dtype = query_states.dtype
+ num_key_value_heads = key_states.shape[2] # 1
+ num_key_value_groups = num_att_heads // num_key_value_heads # 8 // 1
+
+ key_states = einops.repeat(
+ key_states, "b l h d -> b l (h g) d", g=num_key_value_groups
+ )
+ value_states = einops.repeat(
+ value_states, "b l h d -> b l (h g) d", g=num_key_value_groups
+ )
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ query_states = query_states.to(torch.float32)
+ key_states = key_states.to(torch.float32)
+ value_states = value_states.to(torch.float32)
+
+ causal_mask = attention_mask
+ if causal_mask is not None:
+ causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
+
+ if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
+ causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
+
+ def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
+ def mask_mod(b, h, q_idx, kv_idx):
+ # Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
+ return precomputed_mask[b][h][q_idx][kv_idx]
+
+ return mask_mod
+
+ b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
+ # ipdb.set_trace()
+ block_size = 128
+ q_len_rounded = _round_up_to_multiple(q_len, block_size)
+ kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
+
+ # *CRITICAL* we do need to expand here, else we get a CUDA index error
+
+ pad_q = q_len_rounded - q_len
+ pad_k = kv_len_rounded - kv_len
+
+ if pad_q > 0:
+ query_states = F.pad(query_states, (0, 0, 0, pad_q), value=0.0) # [B, H, q_len_rounded, D]
+ if pad_k > 0:
+ key_states = F.pad(key_states, (0, 0, 0, pad_k), value=0.0)
+ value_states = F.pad(value_states, (0, 0, 0, pad_k), value=0.0)
+ padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
+ mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
+
+ mask_4d = create_mask(
+ mod_fn=mask_mod_fn_orig,
+ B=b_mask,
+ H=h_mask,
+ Q_LEN=q_len_rounded,
+ KV_LEN=kv_len_rounded,
+ device=causal_mask.device,
+ )
+
+ mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
+ block_mask = create_block_mask(
+ mask_mod=mask_mod_fn_padded,
+ B=b_mask,
+ H=h_mask,
+ Q_LEN=q_len_rounded,
+ KV_LEN=kv_len_rounded,
+ BLOCK_SIZE=block_size,
+ device=causal_mask.device,
+ _compile=False,
+ )
+
+ # mask is applied inside the kernel, ideally more efficiently than score_mod.
+ attn_output, attention_weights = flex_attention(
+ query_states,
+ key_states,
+ value_states,
+ block_mask=block_mask,
+ enable_gqa=True, # because we shaped query/key states for GQA
+ scale=head_dim**-0.5 if scaling is None else scaling,
+ return_lse=True,
+ )
+ attn_output = attn_output[:, :, :seq_len, :].to(dtype=original_dtype)
+ attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
+ attn_output = attn_output.reshape(
+ batch_size,
+ -1,
+ attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
+ )
+ return attn_output
\ No newline at end of file
diff --git a/lingbotvla/ops/dit/rope_wan/rotary.py b/lingbotvla/ops/dit/rope_wan/rotary.py
new file mode 100644
index 0000000000000000000000000000000000000000..0455e447519feafd80e42fce03d3e9027e9c7b29
--- /dev/null
+++ b/lingbotvla/ops/dit/rope_wan/rotary.py
@@ -0,0 +1,281 @@
+# Copyright (c) 2023, Tri Dao.
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py
+from typing import Optional, Union
+
+import torch
+import triton
+import triton.language as tl
+from einops import rearrange
+
+
+@triton.jit
+def rotary_interleaved_kernel(
+ OUT, # Pointers to matrices
+ X,
+ COS,
+ SIN,
+ CU_SEQLENS,
+ SEQLEN_OFFSETS, # this could be int or a pointer
+ # Matrix dimensions
+ seqlen,
+ rotary_dim,
+ seqlen_ro,
+ # strides
+ stride_out_batch,
+ stride_out_seqlen,
+ stride_out_nheads,
+ stride_out_headdim,
+ stride_x_batch,
+ stride_x_seqlen,
+ stride_x_nheads,
+ stride_x_headdim,
+ # Meta-parameters
+ BLOCK_K: tl.constexpr,
+ IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
+ IS_VARLEN: tl.constexpr,
+ CONJUGATE: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+):
+ pid_m = tl.program_id(axis=0)
+ pid_batch = tl.program_id(axis=1)
+ pid_head = tl.program_id(axis=2)
+ rotary_dim_half = rotary_dim // 2
+
+ if not IS_VARLEN:
+ X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
+ OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
+ else:
+ start_idx = tl.load(CU_SEQLENS + pid_batch)
+ seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
+ X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
+ OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads
+
+ if pid_m * BLOCK_M >= seqlen:
+ return
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ if not IS_SEQLEN_OFFSETS_TENSOR:
+ rm_cs = rm + SEQLEN_OFFSETS
+ else:
+ rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
+ rk = tl.arange(0, BLOCK_K)
+
+ # We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
+ # Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
+ # Loading x0 will be fast but x1 will be slow.
+ # Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
+ # Then we do the calculation and use tl.where to pick put the right outputs for the even
+ # and for the odd indices.
+ rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...
+ rk_repeat = tl.arange(0, BLOCK_K) // 2
+ X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
+ X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
+ COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
+ SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
+ cos = tl.load(
+ COS,
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
+ other=1.0,
+ )
+ sin = tl.load(
+ SIN,
+ mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
+ other=0.0,
+ )
+ x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float64)
+ x1 = tl.load(X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float64)
+ if CONJUGATE:
+ sin = -sin
+ x0_cos = x0 * cos
+ x1_sin = x1 * sin
+ out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
+ OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
+ out = out.to(tl.float32)
+ tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
+
+
+def apply_rotary_interleaved(
+ x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ inplace=False,
+ conjugate=False,
+) -> torch.Tensor:
+ """
+ Arguments:
+ x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
+ else (total_seqlen, nheads, headdim).
+ cos: (seqlen_ro, rotary_dim / 2)
+ sin: (seqlen_ro, rotary_dim / 2)
+ seqlen_offsets: integer or integer tensor of size (batch,)
+ cu_seqlens: (batch + 1,) or None
+ max_seqlen: int
+ Returns:
+ y: (batch, seqlen, nheads, headdim)
+ """
+ is_varlen = cu_seqlens is not None
+ if not is_varlen:
+ batch, seqlen, nheads, headdim = x.shape
+ else:
+ assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
+ total_seqlen, nheads, headdim = x.shape
+ batch_p_1 = cu_seqlens.shape[0]
+ batch = batch_p_1 - 1
+ seqlen = max_seqlen
+ seqlen_ro, rotary_dim = cos.shape
+ assert sin.shape == cos.shape
+ rotary_dim *= 2
+ assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
+ assert headdim <= 256, "Only support headdim <= 256"
+ assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
+
+ assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
+ # assert (
+ # x.dtype == cos.dtype
+ # ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
+
+ cos, sin = cos.contiguous(), sin.contiguous()
+ if isinstance(seqlen_offsets, torch.Tensor):
+ assert seqlen_offsets.shape == (batch,)
+ assert seqlen_offsets.dtype in [torch.int32, torch.int64]
+ seqlen_offsets = seqlen_offsets.contiguous()
+ else:
+ assert seqlen_offsets + seqlen <= seqlen_ro
+
+ output = torch.empty_like(x) if not inplace else x
+ if rotary_dim < headdim and not inplace:
+ output[..., rotary_dim:].copy_(x[..., rotary_dim:])
+
+ BLOCK_K = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
+ grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads) # noqa
+ BLOCK_M = 4
+
+ # Need this, otherwise Triton tries to launch from cuda:0 and we get
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+ with torch.cuda.device(x.device.index):
+ rotary_interleaved_kernel[grid](
+ output, # data ptrs
+ x,
+ cos,
+ sin,
+ cu_seqlens,
+ seqlen_offsets,
+ seqlen, # shapes
+ rotary_dim,
+ seqlen_ro,
+ output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
+ output.stride(-3), # seqlen_stride or total_seqlen_stride
+ output.stride(-2), # nheads_stride
+ output.stride(-1), # headdim_stride
+ x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0
+ x.stride(-3), # seqlen stride or total_seqlen_stride
+ x.stride(-2), # nheads stride
+ x.stride(-1), # headdim stride
+ BLOCK_K,
+ isinstance(seqlen_offsets, torch.Tensor),
+ is_varlen,
+ conjugate,
+ BLOCK_M,
+ )
+ return output
+
+
+class ApplyRotaryEmb(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ x,
+ cos,
+ sin,
+ inplace=False,
+ seqlen_offsets: Union[int, torch.Tensor] = 0,
+ cu_seqlens: Optional[torch.Tensor] = None,
+ max_seqlen: Optional[int] = None,
+ ):
+ out = apply_rotary_interleaved(
+ x,
+ cos,
+ sin,
+ seqlen_offsets=seqlen_offsets,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ inplace=inplace,
+ )
+ if isinstance(seqlen_offsets, int):
+ ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
+ ctx.seqlen_offsets = seqlen_offsets
+ else:
+ ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
+ ctx.seqlen_offsets = None
+ ctx.inplace = inplace
+ ctx.max_seqlen = max_seqlen
+ return out if not inplace else x
+
+ @staticmethod
+ def backward(ctx, do):
+ seqlen_offsets = ctx.seqlen_offsets
+ if seqlen_offsets is None:
+ cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
+ else:
+ cos, sin, cu_seqlens = ctx.saved_tensors
+ # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
+ # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
+ dx = apply_rotary_interleaved(
+ do,
+ cos,
+ sin,
+ seqlen_offsets=seqlen_offsets,
+ cu_seqlens=cu_seqlens,
+ max_seqlen=ctx.max_seqlen,
+ inplace=ctx.inplace,
+ conjugate=True,
+ )
+ return dx, None, None, None, None, None, None, None
+
+
+def apply_rotary_emb(x, **kwargs):
+ """
+ Arguments:
+ x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+ else (total_seqlen, nheads, headdim)
+ cos, sin: (seqlen_rotary, rotary_dim / 2)
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+ of 1st half and 2nd half (GPT-NeoX style).
+ inplace: if True, apply rotary embedding in-place.
+ seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
+ Most commonly used in inference when we have KV cache.
+ cu_seqlens: (batch + 1,) or None
+ max_seqlen: int
+ Return:
+ out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+ else (total_seqlen, nheads, headdim)
+ rotary_dim must be <= headdim
+ Apply rotary embedding to the first rotary_dim of x.
+ """
+ cos = kwargs.pop("cos")
+ sin = kwargs.pop("sin")
+ inplace = kwargs.pop("inplace", False)
+ seqlen_offsets = kwargs.pop("seqlen_offsets", 0)
+ cu_seqlens = kwargs.pop("cu_seqlens", None)
+ max_seqlen = kwargs.pop("max_seqlen", None)
+
+ head_dim = kwargs.pop("head_dim")
+ x = rearrange(x, "b s (n d) -> b s n d", d=head_dim)
+ return ApplyRotaryEmb.apply(x, cos, sin, inplace, seqlen_offsets, cu_seqlens, max_seqlen).flatten(2)
diff --git a/lingbotvla/ops/fused_moe.py b/lingbotvla/ops/fused_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..35ff0abc15b457294470a44a4d293e34372a4c7e
--- /dev/null
+++ b/lingbotvla/ops/fused_moe.py
@@ -0,0 +1,348 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from ..distributed.moe import EPGroupGemm, preprocess, token_pre_all2all, tokens_post_all2all
+from ..distributed.parallel_state import get_parallel_state
+from ..utils.import_utils import is_fused_moe_available
+
+
+if is_fused_moe_available():
+ from .group_gemm.kernel.group_gemm import group_gemm_same_mn, group_gemm_same_nk
+ from .group_gemm.kernel.moe import expert_histogram, moe_gather, moe_scatter
+
+
+class FusedMoeExpertFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ num_experts,
+ gate_weights,
+ expert_index,
+ hidden_states,
+ fc1_1_weight,
+ fc1_2_weight,
+ fc2_weight,
+ ):
+ # MOE Step 3: dispatch input tokens to the experts
+ # result shape is (batch_size * sequence_len * topk, hidden_size)
+ # MOE Step 3-1: compute the token num for each expert
+ # splits shape (num_experts)
+ splits = expert_histogram(expert_index, num_experts)
+
+ # MOE Step 3-2: compute the each token's index in result
+ # scatter_index shape (batch_size * sequence_len, topk)
+ # TODO(wenyawei): opt it
+ scatter_index = expert_index.flatten().argsort(stable=True).argsort().int().view(expert_index.shape)
+
+ # MOE Step 3-3: compute the result, select tokens by scatter_index, and put them together
+ # scatter_output shape (batch_size * sequence_len * topk, hidden_size)
+ scatter_output = moe_scatter(hidden_states, scatter_index)
+
+ # MOE Step 4: compute linear layer 1-1
+ # Not consistent.
+ cumsum_t = torch.cumsum(splits, dim=0)
+ fc1_1_output = group_gemm_same_nk(
+ a=scatter_output,
+ b=fc1_1_weight,
+ cumsum_M=cumsum_t,
+ max_M=scatter_output.shape[0],
+ transpose_a=False,
+ transpose_b=True,
+ )
+
+ # MOE Step 6: compute linear layer 1-2
+ # fc1_2_output shape is (batch_size * sequence_len * topk, ffn_dim)
+ fc1_2_output = group_gemm_same_nk(
+ a=scatter_output,
+ b=fc1_2_weight,
+ cumsum_M=cumsum_t,
+ max_M=scatter_output.shape[0],
+ transpose_a=False,
+ transpose_b=True,
+ )
+
+ # MOE Step 5: compute the actication of linear layer 1-1
+ # TODO(wenyawei): act function
+ # fc1_1_activation shape is (batch_size * sequence_len * topk, ffn_dim)
+ fc1_1_activation = torch.ops.aten.silu(fc1_1_output)
+
+ # MOE Step 7: compute final result of linear layer 1
+ fc1_activation = fc1_1_activation * fc1_2_output
+
+ # MOE Step 8: compute the the weighted linear layer 1 result
+ # MOE Step 8-1: compute scattered_gate_weight, shape is (batch_size * sequence_len * topk)
+ reshaped_gate_weight = gate_weights.reshape(-1, 1)
+ scattered_gate_weight = torch.empty_like(reshaped_gate_weight)
+ scattered_gate_weight[scatter_index.flatten()] = reshaped_gate_weight
+
+ # MOE Step 8-2: multiply activate with scattered_gate_weight
+ # fc1_weighted_output shape is (batch_size * sequence_len * topk, ffn_dim)
+ fc1_weighted_output = fc1_activation * scattered_gate_weight
+
+ # MOE Step 9: compute linear layer 2
+ # result shape is (batch_size * sequence_len * topk, hidden_size)
+ fc2_output = group_gemm_same_nk(
+ a=fc1_weighted_output,
+ b=fc2_weight,
+ cumsum_M=cumsum_t,
+ max_M=scatter_output.shape[0],
+ transpose_a=False,
+ transpose_b=True,
+ )
+
+ # MOE Step 10: gather the final token result by averate the the topk token results
+ expert_output = moe_gather(fc2_output, scatter_index)
+
+ # reshape the output with input shape
+ output = expert_output.reshape(hidden_states.shape)
+
+ ctx.num_experts = num_experts
+ ctx.save_for_backward(
+ gate_weights,
+ fc1_1_weight,
+ fc1_2_weight,
+ fc2_weight,
+ hidden_states,
+ scatter_index,
+ scatter_output,
+ cumsum_t,
+ fc1_1_output,
+ fc1_2_output,
+ fc1_activation,
+ scattered_gate_weight,
+ fc1_weighted_output,
+ )
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ (
+ gate_weights,
+ fc1_1_weight,
+ fc1_2_weight,
+ fc2_weight,
+ hidden_states,
+ scatter_index,
+ scatter_output,
+ cumsum_t,
+ fc1_1_output,
+ fc1_2_output,
+ fc1_activation,
+ scattered_gate_weight,
+ fc1_weighted_output,
+ ) = ctx.saved_tensors
+ hidden_dim = grad_output.shape[-1]
+ grad_output = grad_output.view(-1, hidden_dim)
+
+ # MOE Step 10
+ grad_fc2_output = moe_scatter(grad_output, scatter_index)
+
+ # MOE Step 9
+ # grad_fc1_weighted_output = torch.empty_like(fc1_weighted_output)
+
+ # dgrad
+ grad_fc1_weighted_output = group_gemm_same_nk(
+ a=grad_fc2_output,
+ b=fc2_weight,
+ cumsum_M=cumsum_t,
+ max_M=grad_output.shape[0],
+ transpose_b=False,
+ )
+
+ # wgrad
+ grad_fc2_weight = None
+ if fc2_weight.requires_grad:
+ grad_fc2_weight = torch.empty_like(fc2_weight)
+ group_gemm_same_mn(
+ a=grad_fc2_output,
+ b=fc1_weighted_output,
+ c=grad_fc2_weight,
+ cumsum_K=cumsum_t,
+ max_K=grad_output.shape[0],
+ transpose_a=True,
+ transpose_b=False,
+ )
+
+ # MOE Step 8
+ # MOE Step 8-2
+ grad_fc1_activation = grad_fc1_weighted_output * scattered_gate_weight
+
+ # MOE Step 8-1
+ grad_scattered_gate_weight = torch.sum(fc1_activation * grad_fc1_weighted_output, dim=-1)
+ grad_gate_weight = grad_scattered_gate_weight[scatter_index.flatten()]
+ grad_gate_weight = grad_gate_weight.reshape(gate_weights.shape)
+
+ # recompute during backward
+ fc1_1_activation = torch.ops.aten.silu(fc1_1_output)
+
+ # MOE Step 7
+ grad_fc1_1_activation = grad_fc1_activation * fc1_2_output
+ grad_fc1_2_output = fc1_1_activation * grad_fc1_activation
+
+ # MOE Step 6
+ # grad_scatter_output_2 = torch.empty_like(scatter_output)
+
+ # dgrad
+ grad_scatter_output_2 = group_gemm_same_nk(
+ a=grad_fc1_2_output,
+ b=fc1_2_weight,
+ cumsum_M=cumsum_t,
+ max_M=grad_output.shape[0],
+ transpose_b=False,
+ )
+
+ # wgrad
+ grad_fc1_2_weight = None
+ if fc1_2_weight.requires_grad:
+ grad_fc1_2_weight = torch.empty_like(fc1_2_weight)
+ group_gemm_same_mn(
+ a=grad_fc1_2_output,
+ b=scatter_output,
+ c=grad_fc1_2_weight,
+ cumsum_K=cumsum_t,
+ max_K=grad_output.shape[0],
+ transpose_a=True,
+ transpose_b=False,
+ )
+
+ # MOE Step 5
+ grad_fc1_1_output = torch.ops.aten.silu_backward(grad_fc1_1_activation, fc1_1_output)
+
+ # MOE Step 4
+ # grad_scatter_output_1 = torch.empty_like(scatter_output)
+
+ # dgrad
+ grad_scatter_output_1 = group_gemm_same_nk(
+ a=grad_fc1_1_output,
+ b=fc1_1_weight,
+ cumsum_M=cumsum_t,
+ max_M=grad_output.shape[0],
+ transpose_b=False,
+ )
+
+ # wgrad
+ grad_fc1_1_weight = None
+ if fc1_1_weight.requires_grad:
+ grad_fc1_1_weight = torch.empty_like(fc1_1_weight)
+ group_gemm_same_mn(
+ a=grad_fc1_1_output,
+ b=scatter_output,
+ c=grad_fc1_1_weight,
+ cumsum_K=cumsum_t,
+ max_K=grad_output.shape[0],
+ transpose_a=True,
+ transpose_b=False,
+ )
+
+ # MOE Step 3
+ # MOE Step 3-3
+ grad_scatter_output = grad_scatter_output_1 + grad_scatter_output_2
+ grad_hidden_states = moe_gather(grad_scatter_output, scatter_index)
+
+ # MOE Step 3-2: no grad
+ # MOE Step 3-1: no grad
+
+ # reshape the result with input shape
+ grad_hidden_states = grad_hidden_states.reshape(hidden_states.shape)
+
+ return (
+ None, # num_experts
+ grad_gate_weight, # gate_weights
+ None, # expert_index
+ grad_hidden_states, # hidden_states
+ grad_fc1_1_weight, # fc1_1_weight
+ grad_fc1_2_weight, # fc1_2_weight
+ grad_fc2_weight, # fc2_weight
+ )
+
+
+def fused_moe_forward(
+ module: torch.nn.Module,
+ num_experts: int,
+ routing_weights: torch.Tensor,
+ selected_experts: torch.Tensor,
+ hidden_states: torch.Tensor,
+ fc1_1_weight: torch.Tensor,
+ fc1_2_weight: torch.Tensor,
+ fc2_weight: torch.Tensor,
+):
+ if module.training and get_parallel_state().ep_enabled:
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
+ # preprocess, permute token for ep
+ input_splits, output_splits, num_global_tokens_per_local_expert, num_global_sum_tokens_per_local_expert = (
+ preprocess(
+ expert_mask=expert_mask,
+ num_experts=num_experts,
+ ep_group=get_parallel_state().ep_group,
+ )
+ )
+ permute_tokens, routing_map, local_input_permutation_mapping, org_hidden_states_shape = token_pre_all2all(
+ hidden_states=hidden_states,
+ expert_mask=expert_mask,
+ num_experts=num_experts,
+ input_splits=input_splits,
+ output_splits=output_splits,
+ num_global_tokens_per_local_expert=num_global_tokens_per_local_expert,
+ ep_group=get_parallel_state().ep_group,
+ )
+
+ final_permute_tokens = torch.zeros(
+ (permute_tokens.shape),
+ dtype=permute_tokens.dtype,
+ device=permute_tokens.device,
+ )
+
+ cumsum = torch.cumsum(num_global_sum_tokens_per_local_expert, dim=0).to(permute_tokens.device)
+
+ final_permute_tokens = EPGroupGemm.apply(
+ permute_tokens,
+ cumsum,
+ fc1_1_weight,
+ fc1_2_weight,
+ fc2_weight,
+ )
+
+ # unpermute with routing_weight
+ final_hidden_states = tokens_post_all2all(
+ expert_outputs=final_permute_tokens,
+ routing_weights=routing_weights,
+ selected_experts=selected_experts,
+ num_experts=num_experts,
+ input_splits=input_splits,
+ output_splits=output_splits,
+ num_global_tokens_per_local_expert=num_global_tokens_per_local_expert,
+ routing_map=routing_map,
+ local_input_permutation_mapping=local_input_permutation_mapping,
+ org_hidden_states_shape=org_hidden_states_shape,
+ ep_group=get_parallel_state().ep_group,
+ )
+
+ else:
+ routing_weights = routing_weights.bfloat16()
+ hidden_states = hidden_states.bfloat16()
+ final_hidden_states = FusedMoeExpertFunction.apply(
+ num_experts,
+ routing_weights,
+ selected_experts,
+ hidden_states,
+ fc1_1_weight,
+ fc1_2_weight,
+ fc2_weight,
+ )
+
+ return final_hidden_states
diff --git a/lingbotvla/ops/group_gemm/__init__.py b/lingbotvla/ops/group_gemm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd1e8433dffa0b3ba420be3e346f4f5cd062014
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/lingbotvla/ops/group_gemm/kernel/__init__.py b/lingbotvla/ops/group_gemm/kernel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd1e8433dffa0b3ba420be3e346f4f5cd062014
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/kernel/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/lingbotvla/ops/group_gemm/kernel/group_gemm.py b/lingbotvla/ops/group_gemm/kernel/group_gemm.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdb75c7d0cd5691b9457eadeb78ea7d1b5066856
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/kernel/group_gemm.py
@@ -0,0 +1,396 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+from ..utils.pretuned import algo_key_scaled, pretuned
+from .triton_utils.activation import (
+ ActivationType,
+ activation_fwd,
+)
+from .triton_utils.memory import (
+ load_block_with_pred_2d,
+ load_with_pred_1d,
+ load_with_pred_2d,
+ store_block_with_pred_2d,
+ store_with_pred_2d,
+)
+from .triton_utils.utils import (
+ get_pid_mn,
+ make_blocked,
+)
+
+
+def _get_cuda_autotune_config():
+ return [
+ triton.Config(
+ {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP": 8},
+ num_stages=3,
+ num_warps=8,
+ ),
+ ]
+
+
+# @triton.autotune(
+# configs=_get_cuda_autotune_config(),
+# key=["total_M", "N", "K"],
+# )
+@pretuned(
+ algo_key=algo_key_scaled(["total_M", "N", "K"], [5000, 1, 1], ["TRANSPOSE_A", "TRANSPOSE_B"]),
+ fallback={"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8},
+)
+@triton.heuristics(
+ values={
+ "N_ALIGNED": lambda args: args["N"] % args["BLOCK_N"] == 0,
+ "K_ALIGNED": lambda args: args["K"] % args["BLOCK_K"] == 0,
+ "HAS_ACTIVATION": lambda args: args["ACTIVATION"] is not None,
+ }
+)
+@triton.jit
+def group_gemm_same_nk_kernel(
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ act_ptr,
+ cumsum_M,
+ max_M,
+ total_M, # Used for generating algo. key only.
+ G: tl.constexpr,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ TRANSPOSE_A: tl.constexpr,
+ TRANSPOSE_B: tl.constexpr,
+ # No need to support TRANPOSE_C, just ask user to calculate `c.t()` as `b.t() @ a.t()`.
+ ACCUMULATE_TO_C: tl.constexpr,
+ GROUP: tl.constexpr,
+ N_ALIGNED: tl.constexpr,
+ K_ALIGNED: tl.constexpr,
+ ACTIVATION: tl.constexpr,
+ HAS_ACTIVATION: tl.constexpr,
+ SAVE_ACTIVATION: tl.constexpr,
+):
+ m, n = get_pid_mn(tl.program_id(axis=0), max_M, N, BLOCK_M, BLOCK_N, GROUP)
+ gid = tl.program_id(1).to(tl.uint64)
+ gtid_start = tl.load(cumsum_M + gid - 1, mask=gid > 0, other=0)
+ gtid_end = tl.load(cumsum_M + gid)
+ m_size = (gtid_end - gtid_start).to(tl.uint64)
+
+ if m * BLOCK_M >= m_size:
+ return
+
+ a_ptr += gtid_start * K
+ b_ptr += gid * K * N
+ c_ptr += gtid_start * N
+
+ offs_m = m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ offs_am = offs_m % m_size.to(tl.int64)
+ offs_bn = offs_n % N
+
+ blk_k = tl.arange(0, BLOCK_K)
+
+ stride_am, stride_ak = (K, 1) if not TRANSPOSE_A else (1, m_size)
+ stride_bk, stride_bn = (N, 1) if not TRANSPOSE_B else (1, K)
+
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + blk_k[None, :] * stride_ak)
+ b_ptrs = b_ptr + (blk_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+ c_ptrs = c_ptr + N * offs_m[:, None] + 1 * offs_n[None, :]
+
+ if ACCUMULATE_TO_C:
+ c = load_with_pred_2d(
+ c_ptrs,
+ False,
+ N_ALIGNED,
+ offs_m[:, None] < m_size,
+ offs_n[None, :] < N,
+ other=0,
+ )
+ else:
+ c = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ # Really loading a 2D block. Using `load_with_pred_1d` as we only have one predicate.
+ a = load_with_pred_1d(a_ptrs, K_ALIGNED, blk_k[None, :] < K - k * BLOCK_K, other=0)
+ b = load_with_pred_1d(b_ptrs, K_ALIGNED, blk_k[:, None] < K - k * BLOCK_K, other=0)
+
+ c = tl.dot(a, b, c)
+
+ a_ptrs += BLOCK_K * stride_ak
+ b_ptrs += BLOCK_K * stride_bk
+
+ if HAS_ACTIVATION:
+ # Makes GELU_APPROX faster, not sure why..
+ c = make_blocked(c, c_ptr.dtype.element_ty)
+ if SAVE_ACTIVATION:
+ store_with_pred_2d(
+ act_ptr + gtid_start * N + N * offs_m[:, None] + offs_n[None, :],
+ c,
+ False,
+ N_ALIGNED,
+ offs_m[:, None] < m_size,
+ offs_n[None, :] < N,
+ )
+ c = activation_fwd(c, ACTIVATION)
+
+ store_with_pred_2d(c_ptrs, c, False, N_ALIGNED, offs_m[:, None] < m_size, offs_n[None, :] < N)
+
+
+def group_gemm_same_nk(
+ a: torch.Tensor,
+ b: torch.Tensor,
+ cumsum_M: torch.Tensor,
+ max_M: int,
+ transpose_a: bool = False,
+ transpose_b: bool = False,
+ activation: Optional[ActivationType] = None,
+ save_activation: bool = False,
+ c: Optional[torch.Tensor] = None,
+):
+ """Grouped gemm for same nk
+
+ Keyword arguments:
+ a -- lhs matrixs to be matrix multiplied
+ b -- rhs matrixs to be matrix multiplied
+ cumsum_M -- matrixs's size cumsum on M
+ max_M -- matrixs's max size on M
+ transpose_a -- transpose `a` or not
+ transpose_b -- transpose `b` or not
+ activation -- activation type if needed
+ save_activation -- return the activation's input or not
+ c -- which tensor accumulate to, c = c + ggemm(a, b)
+ """
+ if transpose_b:
+ G, N, K = b.shape
+ else:
+ G, K, N = b.shape
+
+ assert not transpose_a, "Transpose A not tested yet."
+
+ assert a.dtype in [torch.bfloat16, torch.float16], a.dtype
+ assert b.dtype in [torch.bfloat16, torch.float16], b.dtype
+
+ assert a.device == b.device, f"a.device = {a.device}, b.device = {b.device}"
+
+ assert len(cumsum_M) == b.shape[0]
+
+ assert activation is None or activation in list(ActivationType), f"Not implemented: activation is {activation}."
+ assert activation or not save_activation, "Can't save activation since activation type is None"
+
+ assert a.is_contiguous() and b.is_contiguous(), "Not implemented: Noncontiguous input."
+
+ c_is_none = c is None
+ if c_is_none:
+ c = torch.empty((a.shape[1] if transpose_a else a.shape[0], N), dtype=a.dtype, device=a.device)
+
+ if save_activation:
+ act = torch.empty_like(c)
+
+ with torch.cuda.device(a.device):
+ group_gemm_same_nk_kernel[
+ lambda x: (
+ triton.cdiv(max_M, x["BLOCK_M"]) * triton.cdiv(N, x["BLOCK_N"]),
+ x["G"],
+ )
+ ](
+ a_ptr=a,
+ b_ptr=b,
+ c_ptr=c,
+ act_ptr=act if save_activation else None,
+ cumsum_M=cumsum_M,
+ max_M=max_M,
+ total_M=a.shape[0],
+ G=G,
+ K=K,
+ N=N,
+ TRANSPOSE_A=transpose_a,
+ TRANSPOSE_B=transpose_b,
+ ACCUMULATE_TO_C=not c_is_none,
+ ACTIVATION=activation,
+ SAVE_ACTIVATION=save_activation,
+ )
+
+ if save_activation:
+ return c, act
+
+ return c
+
+
+# @triton.autotune(
+# configs=_get_cuda_autotune_config(),
+# key=["total_K", "M", "N"],
+# )
+@pretuned(
+ algo_key=algo_key_scaled(["M", "N", "total_K"], [1, 1, 5000], ["TRANSPOSE_A", "TRANSPOSE_B"]),
+ fallback={"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP": 8},
+)
+@triton.heuristics(
+ values={
+ "M_ALIGNED": lambda args: args["M"] % args["BLOCK_M"] == 0,
+ "N_ALIGNED": lambda args: args["N"] % args["BLOCK_N"] == 0,
+ }
+)
+@triton.jit
+def group_gemm_same_mn_kernel(
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ cumsum_K,
+ total_K, # Used for generating algo. key only.
+ G: tl.constexpr,
+ M: tl.constexpr,
+ N: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ TRANSPOSE_A: tl.constexpr,
+ TRANSPOSE_B: tl.constexpr,
+ ACCUMULATE_TO_C: tl.constexpr,
+ GROUP: tl.constexpr,
+ M_ALIGNED: tl.constexpr,
+ N_ALIGNED: tl.constexpr,
+):
+ m, n = get_pid_mn(tl.program_id(axis=0), M, N, BLOCK_M, BLOCK_N, GROUP)
+ gid = tl.program_id(1).to(tl.uint64)
+ gtid_start = tl.load(cumsum_K + gid - 1, mask=gid > 0, other=0)
+ gtid_end = tl.load(cumsum_K + gid)
+ k = (gtid_end - gtid_start).to(tl.uint64)
+
+ if TRANSPOSE_A:
+ a_block_ptr = tl.make_block_ptr(
+ base=a_ptr + gtid_start * M,
+ shape=(M, k),
+ strides=(1, M),
+ offsets=(m * BLOCK_M, 0),
+ block_shape=(BLOCK_M, BLOCK_K),
+ order=(0, 1),
+ )
+ else:
+ a_block_ptr = tl.make_block_ptr(
+ base=a_ptr + gtid_start * M,
+ shape=(M, k),
+ strides=(k, 1),
+ offsets=(m * BLOCK_M, 0),
+ block_shape=(BLOCK_M, BLOCK_K),
+ order=(1, 0),
+ )
+ if TRANSPOSE_B:
+ b_block_ptr = tl.make_block_ptr(
+ base=b_ptr + gtid_start * N,
+ shape=(k, N),
+ strides=(1, k),
+ offsets=(0, n * BLOCK_N),
+ block_shape=(BLOCK_K, BLOCK_N),
+ order=(0, 1),
+ )
+ else:
+ b_block_ptr = tl.make_block_ptr(
+ base=b_ptr + gtid_start * N,
+ shape=(k, N),
+ strides=(N, 1),
+ offsets=(0, n * BLOCK_N),
+ block_shape=(BLOCK_K, BLOCK_N),
+ order=(1, 0),
+ )
+ c_block_ptr = tl.make_block_ptr(
+ base=c_ptr + gid * M * N,
+ shape=(M, N),
+ strides=(N, 1),
+ offsets=(m * BLOCK_M, n * BLOCK_N),
+ block_shape=(BLOCK_M, BLOCK_N),
+ order=(1, 0),
+ )
+
+ # Special case: no GEMM needed.
+ if k == 0:
+ if not ACCUMULATE_TO_C:
+ # Zero out the corresponding output region.
+ store_block_with_pred_2d(
+ c_block_ptr,
+ # tl.zeros(..., dtype=c_block_ptr.dtype.element_ty) raises "not implemented".
+ tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32).to(c_block_ptr.dtype.element_ty),
+ M_ALIGNED,
+ N_ALIGNED,
+ )
+ else:
+ # Nothing to do then, just leave the kernel.
+ pass
+
+ return
+
+ if ACCUMULATE_TO_C:
+ out = tl.load(c_block_ptr).to(tl.float32)
+ else:
+ out = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+
+ # FIXME: Weird type conversion.
+ for _ in range(tl.cdiv(k.to(tl.int64), BLOCK_K)):
+ a = load_block_with_pred_2d(a_block_ptr, M_ALIGNED, False)
+ b = load_block_with_pred_2d(b_block_ptr, False, N_ALIGNED)
+
+ out += tl.dot(a, b)
+
+ a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))
+ b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))
+
+ store_block_with_pred_2d(c_block_ptr, out.to(c_block_ptr.dtype.element_ty), M_ALIGNED, N_ALIGNED)
+
+
+def group_gemm_same_mn(
+ a: torch.Tensor,
+ b: torch.Tensor,
+ c: torch.Tensor,
+ cumsum_K: torch.Tensor,
+ max_K: int,
+ transpose_a: bool = False,
+ transpose_b: bool = False,
+):
+ G, M, N = c.shape
+
+ assert a.dtype in [torch.bfloat16, torch.float16], a.dtype
+ assert b.dtype in [torch.bfloat16, torch.float16], b.dtype
+
+ assert a.device == b.device, f"a.device = {a.device}, b.device = {b.device}"
+ assert a.device == c.device, f"a.device = {a.device}, c.device = {c.device}"
+
+ # TODO(wenyawei):
+ assert c is not None, c
+ assert len(cumsum_K) == c.shape[0], f"{len(cumsum_K), c.shape}"
+ assert a.is_contiguous() and b.is_contiguous() and c.is_contiguous(), "Not implemented: Noncontiguous input."
+
+ with torch.cuda.device(a.device):
+ group_gemm_same_mn_kernel[
+ lambda x: (
+ triton.cdiv(M, x["BLOCK_M"]) * triton.cdiv(N, x["BLOCK_N"]),
+ x["G"],
+ )
+ ](
+ a_ptr=a,
+ b_ptr=b,
+ c_ptr=c,
+ cumsum_K=cumsum_K,
+ total_K=b.shape[0],
+ G=G,
+ M=M,
+ N=N,
+ TRANSPOSE_A=transpose_a,
+ TRANSPOSE_B=transpose_b,
+ ACCUMULATE_TO_C=False,
+ )
diff --git a/lingbotvla/ops/group_gemm/kernel/moe.py b/lingbotvla/ops/group_gemm/kernel/moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c3d00f7ced4556293ce0c96f027ca083a2a499
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/kernel/moe.py
@@ -0,0 +1,417 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import triton
+import triton.language as tl
+
+from .triton_utils.memory import (
+ load_with_pred_1d,
+ store_with_pred_1d,
+)
+
+
+@triton.heuristics(values={"BLOCK_ALIGNED": lambda args: args["num_elts"] % args["BLOCK_SIZE"] == 0})
+@triton.jit
+def _expert_histogram_kernel(
+ out_ptr,
+ x_ptr,
+ num_elts,
+ num_bins,
+ NUM_BINS_LAST_UNUSED: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ BLOCK_ALIGNED: tl.constexpr,
+):
+ pid = tl.program_id(0)
+
+ in_off = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ data = load_with_pred_1d(x_ptr + in_off, BLOCK_ALIGNED, in_off < num_elts, NUM_BINS_LAST_UNUSED - 1).to(tl.int32)
+
+ tl.device_assert(
+ data < num_bins or data == NUM_BINS_LAST_UNUSED - 1,
+ "Out-of-bound element found.",
+ )
+ count = tl.histogram(data, NUM_BINS_LAST_UNUSED)
+
+ out_off = tl.arange(0, NUM_BINS_LAST_UNUSED)
+ tl.atomic_add(out_ptr + out_off, count, mask=out_off < num_bins, sem="relaxed")
+
+
+def expert_histogram(input: torch.Tensor, num_bins: int) -> torch.Tensor:
+ """Returns histogram of `input`, with bin width 1. Note that for each individual `num_bins`,
+ a separate Triton kernel is generated (mostly). So if `num_bins` varies between calls, you
+ probably should go for some other histogram method.
+ """
+
+ assert input.is_cuda
+ assert input.dtype == torch.int32 or input.dtype == torch.int64
+ assert input.numel() < (1 << 31) - 1, "Too many elements."
+ flattened = input.flatten().contiguous()
+
+ # An extra slot is needed, our kernel uses the extra slot to handle possible OoO reads.
+ # Wastes a lot of slots but hopefully the kernel can still saturate memory B/W.
+ NUM_BINS_LAST_UNUSED = triton.next_power_of_2(num_bins + 1)
+ out = torch.zeros([num_bins], dtype=torch.int32, device=input.device)
+
+ BLOCK_SIZE = 1024
+ num_elts = flattened.numel()
+ grid = (triton.cdiv(num_elts, BLOCK_SIZE),)
+ with torch.cuda.device(input.device):
+ _expert_histogram_kernel[grid](
+ out_ptr=out,
+ x_ptr=flattened,
+ num_elts=num_elts,
+ num_bins=num_bins,
+ NUM_BINS_LAST_UNUSED=NUM_BINS_LAST_UNUSED,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+
+ return out[:num_bins]
+
+
+@triton.heuristics(values={"N_ALIGNED": lambda args: args["N"] % args["BLOCK_N"] == 0})
+@triton.jit
+def _moe_gather_kernel(
+ X,
+ Y,
+ index,
+ num_elts_in,
+ num_elts_out,
+ N: tl.constexpr, # hidden size
+ TOPK: tl.constexpr,
+ STRIDE_XM: tl.constexpr,
+ STRIDE_XN: tl.constexpr,
+ STRIDE_OM: tl.constexpr,
+ STRIDE_ON: tl.constexpr,
+ STRIDE_IM: tl.constexpr,
+ STRIDE_IN: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ N_ALIGNED: tl.constexpr,
+):
+ r"""
+ X: m * topk x n
+ Y: m x n
+ index: m x topk
+ code:
+ repeated-X: m * topk x n -> reduce(sum_over_topk) -> m x n
+ Y: Y[arange(m)] = sum_over_topk(repeated-X[arange(m) * topk])
+ """
+ pid_m = tl.program_id(axis=0).to(tl.int64) # m
+ block_idx = tl.program_id(axis=1).to(tl.int64)
+ n = block_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+ y = tl.zeros([BLOCK_N], dtype=tl.float32)
+ for i in tl.static_range(TOPK):
+ x_index = tl.load(index + pid_m.to(tl.int64) * STRIDE_IM + i * STRIDE_IN)
+ tl.device_assert(x_index < num_elts_in, "Input OOB")
+ x = load_with_pred_1d(
+ X + x_index.to(tl.int64) * STRIDE_XM + n.to(tl.int64) * STRIDE_XN, N_ALIGNED, mask=n < N, other=0
+ )
+ y += x
+ # save one line
+ tl.device_assert(pid_m < num_elts_out, "Output OOB")
+ Y = Y + pid_m.to(tl.int64) * STRIDE_OM + n.to(tl.int64) * STRIDE_ON # noqa
+ store_with_pred_1d(Y, y, N_ALIGNED, mask=n < N)
+
+
+def moe_gather(x: torch.Tensor, index: torch.Tensor, out_dtype=None):
+ assert x.is_cuda and index.is_cuda
+ M, topk = index.shape
+ assert x.shape[0] == M * topk
+ N = x.shape[1]
+
+ assert x.device == index.device, f"x.device = {x.device}, index.device = {index.device}"
+
+ out_dtype = out_dtype or x.dtype
+ out = torch.empty(M, N, dtype=out_dtype, device=x.device)
+
+ grid = lambda meta: (M, triton.cdiv(N, meta["BLOCK_N"])) # noqa
+ with torch.cuda.device(x.device):
+ _moe_gather_kernel[grid](
+ x,
+ out,
+ index,
+ num_elts_in=M * topk,
+ num_elts_out=M,
+ N=N,
+ TOPK=topk,
+ STRIDE_XM=x.stride(0),
+ STRIDE_XN=x.stride(1),
+ STRIDE_OM=out.stride(0),
+ STRIDE_ON=out.stride(1),
+ STRIDE_IM=index.stride(0),
+ STRIDE_IN=index.stride(1),
+ BLOCK_N=1024,
+ )
+
+ return out
+
+
+@triton.heuristics(values={"N_ALIGNED": lambda args: args["N"] % args["BLOCK_N"] == 0})
+@triton.jit
+def _moe_add_gather_kernel(
+ X,
+ Y,
+ Z,
+ index,
+ num_elts_in,
+ num_elts_out,
+ N: tl.constexpr, # hidden size
+ TOPK: tl.constexpr,
+ STRIDE_XM: tl.constexpr,
+ STRIDE_XN: tl.constexpr,
+ STRIDE_YM: tl.constexpr,
+ STRIDE_YN: tl.constexpr,
+ STRIDE_OM: tl.constexpr,
+ STRIDE_ON: tl.constexpr,
+ STRIDE_IM: tl.constexpr,
+ STRIDE_IN: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ N_ALIGNED: tl.constexpr,
+):
+ r"""
+ X: m * topk x n
+ Y: m * topk x n
+ Z: m x n
+ index: m x topk
+
+ code:
+ repeated-(X + Y): m * topk x n -> reduce(sum_over_topk) -> m x n
+ Z: Z[arange(m)] = sum_over_topk(repeated-(X+Y)[arange(m) * topk])
+ """
+ pid_m = tl.program_id(axis=0) # m
+ block_idx = tl.program_id(axis=1)
+
+ n = block_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+ z = tl.zeros([BLOCK_N], dtype=tl.float32)
+ for i in tl.static_range(TOPK):
+ x_index = tl.load(index + pid_m * STRIDE_IM + i * STRIDE_IN)
+ tl.device_assert(x_index < num_elts_in, "Input OOB")
+ x = load_with_pred_1d(X + x_index * STRIDE_XM + n * STRIDE_XN, N_ALIGNED, mask=n < N, other=0)
+ y = load_with_pred_1d(Y + x_index * STRIDE_YM + n * STRIDE_YN, N_ALIGNED, mask=n < N, other=0)
+ z += x + y
+
+ # save one line
+ tl.device_assert(pid_m < num_elts_out, "Output OOB")
+ Z = Z + pid_m * STRIDE_OM + n * STRIDE_ON # noqa
+ store_with_pred_1d(Z, z, N_ALIGNED, mask=n < N)
+
+
+def moe_add_gather(x: torch.Tensor, y: torch.Tensor, index: torch.Tensor, out_dtype=None):
+ assert x.is_cuda and y.is_cuda and index.is_cuda
+ assert x.shape == y.shape
+ assert x.dtype == y.dtype
+ M, topk = index.shape
+ assert x.shape[0] == M * topk
+ N = x.shape[1]
+
+ assert x.device == y.device, f"x.device = {x.device}, y.device = {y.device}"
+ assert x.device == index.device, f"x.device = {x.device}, index.device = {index.device}"
+
+ out_dtype = out_dtype or x.dtype
+ out = torch.empty(M, N, dtype=out_dtype, device=x.device)
+
+ grid = lambda meta: (M, triton.cdiv(N, meta["BLOCK_N"])) # noqa
+ with torch.cuda.device(x.device):
+ _moe_add_gather_kernel[grid](
+ x,
+ y,
+ out,
+ index,
+ num_elts_in=M * topk,
+ num_elts_out=M,
+ N=N,
+ TOPK=topk,
+ STRIDE_XM=x.stride(0),
+ STRIDE_XN=x.stride(1),
+ STRIDE_YM=y.stride(0),
+ STRIDE_YN=y.stride(1),
+ STRIDE_OM=out.stride(0),
+ STRIDE_ON=out.stride(1),
+ STRIDE_IM=index.stride(0),
+ STRIDE_IN=index.stride(1),
+ BLOCK_N=1024,
+ )
+
+ return out
+
+
+@triton.heuristics(values={"N_ALIGNED": lambda args: args["N"] % args["BLOCK_N"] == 0})
+@triton.jit
+def _moe_scatter_kernel(
+ X,
+ O, # noqa
+ index,
+ num_elts_in,
+ num_elts_out,
+ N: tl.constexpr, # hidden size
+ TOPK: tl.constexpr,
+ STRIDE_XM: tl.constexpr,
+ STRIDE_XN: tl.constexpr,
+ STRIDE_OM: tl.constexpr,
+ STRIDE_ON: tl.constexpr,
+ STRIDE_IM: tl.constexpr,
+ STRIDE_IN: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ N_ALIGNED: tl.constexpr,
+):
+ r"""
+ X: m x n
+ O: m * topk x n
+ index: m x topk
+
+ code:
+ X: m x n -> repeat -> m x topk x n -> m * topk x n
+ X[arange(m) * topk] = X[arange(m)]
+
+ O[index] = X
+ O[index[arange(m) * topk]] = X[arange(m) * topk]
+ """
+
+ pid_m = tl.program_id(axis=0) # m
+ block_idx = tl.program_id(axis=1)
+ n = block_idx * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ tl.device_assert(pid_m < num_elts_in, "Input OOB.")
+ X = X + pid_m * STRIDE_XM + n * STRIDE_XN
+ x = load_with_pred_1d(X, N_ALIGNED, mask=n < N, other=0)
+
+ for i in tl.static_range(TOPK):
+ o_index = tl.load(index + pid_m * STRIDE_IM + i * STRIDE_IN)
+ tl.device_assert(o_index < num_elts_out, "Output OOB.")
+ tmp_index = o_index.to(tl.int64) * STRIDE_OM
+ # tl.device_print("tmp_index", tmp_index)
+ out = O + tmp_index + n * STRIDE_ON
+
+ # save one line
+ store_with_pred_1d(out, x, N_ALIGNED, mask=n < N)
+
+
+def moe_scatter(x: torch.Tensor, index: torch.Tensor, out_dtype=None):
+ assert x.is_cuda and index.is_cuda
+ assert x.shape[0] == index.shape[0]
+
+ assert x.device == index.device, f"x.device = {x.device}, index.device = {index.device}"
+
+ M, N = x.shape
+ topk = index.shape[1]
+ out_dtype = out_dtype or x.dtype
+ out = torch.empty(M * topk, N, dtype=out_dtype, device=x.device)
+ assert lambda: index.unique().numel() == M * topk, "Holes in output?"
+
+ grid = lambda meta: (M, triton.cdiv(N, meta["BLOCK_N"])) # noqa
+ with torch.cuda.device(x.device):
+ _moe_scatter_kernel[grid](
+ x,
+ out,
+ index,
+ num_elts_in=M,
+ num_elts_out=M * topk,
+ N=N,
+ TOPK=topk,
+ STRIDE_XM=x.stride(0),
+ STRIDE_XN=x.stride(1),
+ STRIDE_OM=out.stride(0),
+ STRIDE_ON=out.stride(1),
+ STRIDE_IM=index.stride(0),
+ STRIDE_IN=index.stride(1),
+ BLOCK_N=1024,
+ )
+
+ return out
+
+
+@triton.jit
+def _moe_index_compute_kernel(
+ indices_ptr,
+ experts_for_tokens_ptr,
+ temp_histogram_cumsum_ptr,
+ num_elts,
+ NUM_EXPERTS: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr, # Unlikely to be aligned, so we don't test for alignment.
+):
+ _OOB_EXPERT_ID: tl.constexpr = 1023
+ tl.static_assert(_OOB_EXPERT_ID > NUM_EXPERTS, "Too many experts for me.")
+
+ start_pos = tl.program_id(0)
+ processing_range = start_pos * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ expert_ids = tl.load(
+ experts_for_tokens_ptr + processing_range,
+ processing_range < num_elts,
+ _OOB_EXPERT_ID,
+ )
+ assert expert_ids < NUM_EXPERTS or expert_ids == _OOB_EXPERT_ID
+
+ indices = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
+ for expert_id in tl.static_range(NUM_EXPERTS):
+ mask = expert_ids == expert_id
+ one_if_expert_id_matches = mask.to(tl.int32)
+
+ # Tokens allocated to this expert.
+ slots_to_reserve = tl.sum(one_if_expert_id_matches)
+ slot_ids = (
+ # Reserve last `slots_to_reserve` slots for us.
+ tl.atomic_add(temp_histogram_cumsum_ptr + expert_id, -slots_to_reserve, sem="relaxed")
+ # `atomic_add` returns old value, so we need to do substraction again.
+ - slots_to_reserve
+ # Local offset for each token in `expert_ids`.
+ + tl.cumsum(one_if_expert_id_matches)
+ # Result of `cumsum` is "1-based".
+ - 1
+ )
+ assigned_slot_or_zero = tl.where(mask, slot_ids, 0)
+ indices += assigned_slot_or_zero.to(tl.int32)
+
+ tl.store(indices_ptr + processing_range, indices, processing_range < num_elts)
+
+
+def moe_index_compute(experts_for_tokens: torch.Tensor, expert_histogram_cumsum: torch.Tensor) -> torch.Tensor:
+ """Calculate row number into activation passed to MoE fc1 for each token.
+
+ Arguments:
+
+ - experts_for_tokens: [n_tokens, expert_topk] experts assigned to each token.
+ - expert_histogram_cumsum: [n_experts]: cumsum of number of tokens allocated to each expert,
+ with last element being number of tokens. NOTE: This is usually calculated as part of gemm
+ grouped, so you can just reuse it.
+
+ Returns:
+
+ - [n_tokens, expert_topk] row number into activation passed to MoE fc1 for each token. Each
+ token should be duplicated `expert_topk` times.
+ """
+ # No noncontiguous input.
+ assert experts_for_tokens.is_contiguous()
+ assert experts_for_tokens.numel() < (1 << 31) - 1
+ assert expert_histogram_cumsum.is_contiguous()
+ assert experts_for_tokens.device == expert_histogram_cumsum.device, (
+ f"experts_for_tokens.device = {experts_for_tokens.device}, expert_histogram_cumsum.device = {expert_histogram_cumsum.device}"
+ )
+
+ BLOCK_SIZE = 128 # Faster than 1024, not sure why. May be better occupancy?
+
+ histogram_cumsum_copy = expert_histogram_cumsum.clone().detach() # Temporary workspace.
+ indices = torch.empty_like(experts_for_tokens, dtype=int)
+
+ with torch.cuda.device(experts_for_tokens.device):
+ _moe_index_compute_kernel[(triton.cdiv(experts_for_tokens.numel(), BLOCK_SIZE),)](
+ indices_ptr=indices,
+ experts_for_tokens_ptr=experts_for_tokens,
+ temp_histogram_cumsum_ptr=histogram_cumsum_copy,
+ num_elts=experts_for_tokens.numel(),
+ NUM_EXPERTS=histogram_cumsum_copy.numel(),
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+
+ return indices
diff --git a/lingbotvla/ops/group_gemm/kernel/triton_utils/__init__.py b/lingbotvla/ops/group_gemm/kernel/triton_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd1e8433dffa0b3ba420be3e346f4f5cd062014
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/kernel/triton_utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/lingbotvla/ops/group_gemm/kernel/triton_utils/activation.py b/lingbotvla/ops/group_gemm/kernel/triton_utils/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..77c74f59e38cb12c97b52b91d74f7a9adc2a930d
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/kernel/triton_utils/activation.py
@@ -0,0 +1,115 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from enum import Enum
+
+import triton
+import triton.language as tl
+
+
+class ActivationType(str, Enum):
+ GELU = "gelu"
+ GELU_NEW = "gelu_new" # gelu with tanh approximation
+ SILU = "silu"
+
+
+@triton.jit
+def activation_fwd(x: tl.tensor, ACTIVATION: tl.constexpr):
+ orig_dtype = x.dtype
+ x = x.to(tl.float32)
+ if ACTIVATION == "gelu":
+ y = gelu(x)
+ elif ACTIVATION == "gelu_new":
+ y = gelu_new(x)
+ elif ACTIVATION == "silu":
+ y = silu(x)
+ else:
+ tl.static_assert(False, f"Unsupported activation of {ACTIVATION}")
+ return y.to(orig_dtype)
+
+
+@triton.jit
+def activation_bwd(dy: tl.tensor, x: tl.tensor, ACTIVATION: tl.constexpr):
+ orig_dtype = dy.dtype
+ x = x.to(tl.float32)
+ dy = dy.to(tl.float32)
+ if ACTIVATION == "gelu":
+ dx = dy * gelu_grad(x)
+ elif ACTIVATION == "gelu_new":
+ dx = dy * gelu_new_grad(x)
+ elif ACTIVATION == "silu":
+ dx = dy * silu_grad(x)
+ else:
+ tl.static_assert(False, f"Unsupported activation of {ACTIVATION}")
+ return dx.to(orig_dtype)
+
+
+_sqrt2pi: triton.language.constexpr = math.sqrt(2.0 / math.pi)
+_sqrt1_2: triton.language.constexpr = math.sqrt(1.0 / 2)
+_gaussian_pdf_normalization: triton.language.constexpr = 1.0 / math.sqrt(2 * math.pi)
+
+
+@triton.jit
+def tanh(x):
+ # Tanh is just a scaled sigmoid
+ return 2 * tl.sigmoid(2 * x) - 1
+
+
+@triton.jit
+def gelu(x):
+ """Gaussian Error Linear Unit (GELU)"""
+ x = x.to(tl.float32)
+ return x * 0.5 * (1.0 + tl.erf(x * _sqrt1_2))
+
+
+@triton.jit
+def gelu_grad(x):
+ x = x.to(tl.float32)
+ cdf = 0.5 * (1.0 + tl.erf(x * _sqrt1_2))
+ pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
+ return cdf + x * pdf
+
+
+@triton.jit
+def gelu_new(x):
+ """
+ GeLU_ activation - Gaussian error linear unit, with tanh approximation
+
+ .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
+ """
+ return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))
+
+
+@triton.jit
+def gelu_new_grad(x):
+ # CREDITS: Fast implementation proposed in
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
+ x = x.to(tl.float32)
+ tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+ return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
+
+
+@triton.jit
+def silu(x):
+ """https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html"""
+ x = x.to(tl.float32)
+ return x * tl.sigmoid(x)
+
+
+@triton.jit
+def silu_grad(x):
+ x = x.to(tl.float32)
+ f = tl.sigmoid(x)
+ return f + x * (f - f * f)
diff --git a/lingbotvla/ops/group_gemm/kernel/triton_utils/memory.py b/lingbotvla/ops/group_gemm/kernel/triton_utils/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..a219f18f9cdccf625273a0941675178d01d9b9e4
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/kernel/triton_utils/memory.py
@@ -0,0 +1,94 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def load_with_pred_1d(ptr, skip_boundary_check: tl.constexpr, mask: tl.tensor, other=0):
+ if not skip_boundary_check:
+ return tl.load(ptr, mask, other=other)
+ else:
+ return tl.load(ptr)
+
+
+@triton.jit
+def store_with_pred_1d(ptr, value, skip_boundary_check: tl.constexpr, mask: tl.tensor):
+ if not skip_boundary_check:
+ tl.store(ptr, value, mask)
+ else:
+ tl.store(ptr, value)
+
+
+@triton.jit
+def load_with_pred_2d(
+ ptr,
+ skip_boundary_check_0: tl.constexpr,
+ skip_boundary_check_1: tl.constexpr,
+ mask_0: tl.tensor,
+ mask_1: tl.tensor,
+ other=0,
+):
+ if not skip_boundary_check_0 and not skip_boundary_check_1:
+ return tl.load(ptr, mask_0 and mask_1, other=other)
+ elif not skip_boundary_check_0 and skip_boundary_check_1:
+ return tl.load(ptr, mask_0, other=other)
+ elif skip_boundary_check_0 and not skip_boundary_check_1:
+ return tl.load(ptr, mask_1, other=other)
+ else:
+ return tl.load(ptr)
+
+
+@triton.jit
+def store_with_pred_2d(
+ ptr,
+ value,
+ skip_boundary_check_0: tl.constexpr,
+ skip_boundary_check_1: tl.constexpr,
+ mask_0: tl.tensor,
+ mask_1: tl.tensor,
+):
+ if not skip_boundary_check_0 and not skip_boundary_check_1:
+ tl.store(ptr, value, mask_0 and mask_1)
+ elif not skip_boundary_check_0 and skip_boundary_check_1:
+ tl.store(ptr, value, mask_0)
+ elif skip_boundary_check_0 and not skip_boundary_check_1:
+ tl.store(ptr, value, mask_1)
+ else:
+ tl.store(ptr, value)
+
+
+@triton.jit
+def load_block_with_pred_2d(ptr, skip_boundary_check_0: tl.constexpr, skip_boundary_check_1: tl.constexpr):
+ if not skip_boundary_check_0 and not skip_boundary_check_1:
+ return tl.load(ptr, boundary_check=(0, 1))
+ elif not skip_boundary_check_0 and skip_boundary_check_1:
+ return tl.load(ptr, boundary_check=(0,))
+ elif skip_boundary_check_0 and not skip_boundary_check_1:
+ return tl.load(ptr, boundary_check=(1,))
+ else:
+ return tl.load(ptr)
+
+
+@triton.jit
+def store_block_with_pred_2d(ptr, value, skip_boundary_check_0: tl.constexpr, skip_boundary_check_1: tl.constexpr):
+ if not skip_boundary_check_0 and not skip_boundary_check_1:
+ tl.store(ptr, value, boundary_check=(0, 1))
+ elif not skip_boundary_check_0 and skip_boundary_check_1:
+ tl.store(ptr, value, boundary_check=(0,))
+ elif skip_boundary_check_0 and not skip_boundary_check_1:
+ tl.store(ptr, value, boundary_check=(1,))
+ else:
+ tl.store(ptr, value)
diff --git a/lingbotvla/ops/group_gemm/kernel/triton_utils/utils.py b/lingbotvla/ops/group_gemm/kernel/triton_utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..19fa35e07c67692fa3abe531f698246ef90301ca
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/kernel/triton_utils/utils.py
@@ -0,0 +1,53 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import triton
+import triton.language as tl
+
+
+# FIXME: Maybe we should allow different `GROUP_SIZE` along `M` and `N`. Needs more investigation
+# on PTX produced.
+@triton.jit
+def get_pid_mn(pid, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_SIZE: tl.constexpr):
+ num_pid_m = tl.cdiv(M, BLOCK_M)
+ num_pid_n = tl.cdiv(N, BLOCK_N)
+ num_pid_in_group = GROUP_SIZE * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+ return pid_m, pid_n
+
+
+@triton.jit
+def make_blocked(t: tl.tensor, intermediate_type: tl.dtype) -> tl.tensor:
+ """Forcibly convert tensor (from "mma" layout) into "blocked" layout.
+
+ `intermediate_type` affects performance. Usually `tl.bfloat16` or `tl.float16` should be used.
+ INTERNALLY `t` IS CONVERTED TO `intermediate_type` AND BACK SO THE PRECISION CAN DROP.
+
+ ATM Triton does such conversion prior to storing tensor into global memory. This usually doesn't
+ matter as we usually only store the accumulator once. However, if we'd like to perform some
+ element-wise operation on the accumulator and save both pre-op and post-op results, Triton will
+ do the conversion twice, and hence hurt performance.
+
+ In such cases, forcibly convert tensor eagerly can help performance. This is not guaranteed, so
+ be sure to benchmark before applying this "optimization".
+
+ NOTE: Once Triton can optimize away multiple layout conversions, this hack should be removed.
+ """
+ # This really relies on Triton's internal implementation.. See implementation of `expand_dims`
+ # op, it triggers emission of `triton_gpu.convert_layout`.
+ return t.to(intermediate_type).expand_dims(0).reshape(t.shape)
diff --git a/lingbotvla/ops/group_gemm/utils/__init__.py b/lingbotvla/ops/group_gemm/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd1e8433dffa0b3ba420be3e346f4f5cd062014
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/lingbotvla/ops/group_gemm/utils/benchmark_utils.py b/lingbotvla/ops/group_gemm/utils/benchmark_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa9f88820cc61b80cde910d107448a4ff308b010
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/utils/benchmark_utils.py
@@ -0,0 +1,167 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from typing import Callable, Optional
+
+import torch
+import torch.testing
+
+from . import envvars
+from . import logger as blog
+
+
+_BENCHMARK_RESULT_FILE = "benchmark_results.txt"
+
+
+def _benchmark_fn(f, repeats):
+ warmup_repeats = 100
+
+ if envvars.testing_is_ci_env():
+ repeats = min(10, repeats)
+
+ if envvars.benchmarking_minimal_run():
+ # Mostly used together w/ Nsight Compute. Nishgt compute itself will run kernel multiple
+ # times, so we don't bother repeat launching kernel here.
+ repeats = 1
+ warmup_repeats = 0
+
+ start_event = [torch.cuda.Event(enable_timing=True) for _ in range(repeats)]
+ end_event = [torch.cuda.Event(enable_timing=True) for _ in range(repeats)]
+ for _ in range(warmup_repeats):
+ f()
+
+ if not envvars.benchmarking_minimal_run():
+ # Tens of milliseconds, should be sufficient for CPU to catch up.
+ torch.cuda._sleep(50_000_000)
+
+ for i in range(repeats):
+ start_event[i].record()
+ f()
+ end_event[i].record()
+ torch.cuda.synchronize()
+
+ durations = sorted([start_event[i].elapsed_time(end_event[i]) for i in range(repeats)])
+ if repeats >= 10: # We only preserve 25% to 75% timings.
+ durations = durations[int(len(durations) * 0.25) : int(len(durations) * 0.75)]
+
+ elapsed = sum(durations) * 1e-3 # ms -> s
+ return elapsed, len(durations) / elapsed
+
+
+def _append_result_to_on_disk_file(result):
+ current = []
+
+ if os.path.exists(_BENCHMARK_RESULT_FILE):
+ with open(_BENCHMARK_RESULT_FILE) as f:
+ current = json.loads(f.read())
+
+ current.append(result)
+
+ with open(_BENCHMARK_RESULT_FILE, "w") as f:
+ f.write(json.dumps(current, indent=4))
+
+
+def _report_benchmark_result(
+ name,
+ iters_per_sec,
+ elapsed_secs,
+ measurement,
+ measurement_unit,
+ is_baseline,
+ key_metric,
+):
+ if is_baseline:
+ name = name + " [baseline]" # ...
+
+ msec_per_iter = 1000 / iters_per_sec
+ blog.logging.info(
+ f"{name}: used {elapsed_secs:.2f} seconds ({msec_per_iter:.2f} ms per iter), "
+ f"{measurement:.2f} {measurement_unit}/s"
+ )
+
+ if envvars.benchmarking_write_report():
+ _append_result_to_on_disk_file(
+ {
+ "name": name,
+ "elapsed_secs": elapsed_secs,
+ "measurement": measurement,
+ "measurement_unit": measurement_unit,
+ "msec_per_iter": msec_per_iter,
+ "is_baseline": is_baseline,
+ "key_metric": key_metric,
+ }
+ )
+
+
+def benchmark_tflops(name, flops, run_func=None, baseline=None, key_metric=False, repeats=1000):
+ assert run_func is not None or baseline is not None
+
+ if baseline is not None and not envvars.benchmarking_no_baseline():
+ elapsed, iters_per_sec = _benchmark_fn(baseline, repeats)
+ _report_benchmark_result(
+ name,
+ iters_per_sec,
+ elapsed,
+ flops * iters_per_sec / 1e12,
+ "TFlops",
+ True,
+ key_metric,
+ )
+ if run_func is not None:
+ elapsed, iters_per_sec = _benchmark_fn(run_func, repeats)
+ _report_benchmark_result(
+ name,
+ iters_per_sec,
+ elapsed,
+ flops * iters_per_sec / 1e12,
+ "TFlops",
+ False,
+ key_metric,
+ )
+
+
+def benchmark_gibps(
+ name: str,
+ bytes: int,
+ run_func: Optional[Callable] = None,
+ baseline: Optional[Callable] = None,
+ key_metric: bool = False,
+ repeats: int = 100,
+):
+ assert run_func is not None or baseline is not None
+
+ if baseline is not None and not envvars.benchmarking_no_baseline():
+ elapsed, iters_per_sec = _benchmark_fn(baseline, repeats)
+ _report_benchmark_result(
+ name,
+ iters_per_sec,
+ elapsed,
+ bytes * iters_per_sec / 2**30,
+ "GiB",
+ True,
+ key_metric,
+ )
+ if run_func is not None:
+ elapsed, iters_per_sec = _benchmark_fn(run_func, repeats)
+ _report_benchmark_result(
+ name,
+ iters_per_sec,
+ elapsed,
+ bytes * iters_per_sec / 2**30,
+ "GiB",
+ False,
+ key_metric,
+ )
diff --git a/lingbotvla/ops/group_gemm/utils/config.py b/lingbotvla/ops/group_gemm/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4d0f604371a076b051f9447e364e71304182810
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/utils/config.py
@@ -0,0 +1,71 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from typing import Any, Dict
+
+from .path import (
+ get_bpex_root,
+ get_config_dedicated_file_for,
+ get_config_path_prefix_for,
+)
+
+
+def load_all_configs(path_prefix: str) -> Dict:
+ """Load all configs in specified directory and merge them into a single dictionary."""
+ res = {}
+
+ try:
+ with open(f"{path_prefix}.bpex") as f:
+ res = json.loads(f.read())
+ except FileNotFoundError:
+ pass
+
+ try:
+ dir = path_prefix
+ algos = [
+ f[: -len(".bpex")] for f in os.listdir(dir) if os.path.isfile(os.path.join(dir, f)) and f.endswith(".bpex")
+ ]
+ for algo_key in algos:
+ with open(f"{dir}/{algo_key}.bpex") as f:
+ t = json.loads(f.read())
+ res.update({algo_key: t})
+ except FileNotFoundError:
+ pass
+
+ return res
+
+
+def load_all_configs_for(kernel: Any) -> Dict:
+ """Load configs for all pre-tuned algo-key for a given kernel and merge them into a single
+ dictionary. Device and Triton version is assumed the same as the calling environment.
+ """
+ path_prefix = get_config_path_prefix_for(kernel)
+ configs = load_all_configs(path_prefix)
+ return configs
+
+
+def write_config_into_dedicated_file_for(dir_prefix: str, kernel: Any, algo_key: str, configs: Dict):
+ """Write config for the given kernel and algo_key into `dir_prefix`. Internal directory
+ hierarchy used by bpex is preserved inside `dir_prefix`."""
+ rel = os.path.relpath(get_config_dedicated_file_for(kernel, algo_key), get_bpex_root())
+ path = f"{dir_prefix}/{rel}"
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ with open(path, "w+") as f:
+ f.write(format_config_to_str(configs))
+
+
+def format_config_to_str(configs: Dict):
+ return json.dumps(configs, indent=2, sort_keys=True) + "\n"
diff --git a/lingbotvla/ops/group_gemm/utils/device.py b/lingbotvla/ops/group_gemm/utils/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..978621e65442cc21896134b13a7953f03b936566
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/utils/device.py
@@ -0,0 +1,31 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import lru_cache
+
+
+@lru_cache
+def get_device_key() -> str:
+ import torch
+
+ if torch.cuda.get_device_capability() == (8, 0):
+ return "A100" # A30 is treated the same way as A100 for the moment.
+
+ if torch.cuda.get_device_capability() == (9, 0):
+ return "H100"
+
+ name = torch.cuda.get_device_name()
+ if name.startswith("NVIDIA "):
+ name = name[len("NVIDIA ") :]
+ return name
diff --git a/lingbotvla/ops/group_gemm/utils/envvars.py b/lingbotvla/ops/group_gemm/utils/envvars.py
new file mode 100644
index 0000000000000000000000000000000000000000..093a48e6e220e901d693edef1821cc8810acf42f
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/utils/envvars.py
@@ -0,0 +1,65 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from functools import lru_cache
+
+
+@lru_cache
+def is_env_option_enabled(opt: str) -> bool:
+ return int(os.getenv(opt, "0"))
+
+
+def is_assertion_enabled():
+ return is_env_option_enabled("BPEX_DEBUG")
+
+
+def is_untuned_warning_suppressed():
+ return is_env_option_enabled("BPEX_NO_WARN_ON_UNTUNED_CASE") or testing_is_ci_env()
+
+
+def debugging_fake_benchmark_result():
+ return is_env_option_enabled("BPEX_DEBUGGING_FAKE_BENCHMARK_RESULT")
+
+
+def debugging_is_verbose():
+ return is_env_option_enabled("BPEX_DEBUGGING_VERBOSE")
+
+
+def testing_is_ci_env():
+ return is_env_option_enabled("BPEX_TESTING_IS_CI_ENV")
+
+
+def testing_no_noncontiguous_tensors():
+ return is_env_option_enabled("BPEX_TESTING_NO_NONCONTIGUOUS_TENSORS")
+
+
+def benchmarking_minimal_run():
+ return is_env_option_enabled("BPEX_BENCHMARKING_MINIMAL_RUN") or benchmarking_using_ncu()
+
+
+def benchmarking_no_baseline():
+ return is_env_option_enabled("BPEX_BENCHMARKING_NO_BASELINE") or benchmarking_using_ncu()
+
+
+def benchmarking_using_ncu():
+ return is_env_option_enabled("BPEX_BENCHMARKING_USE_NCU")
+
+
+def benchmarking_write_report():
+ return is_env_option_enabled("BPEX_BENCHMARKING_WRITE_REPORT")
+
+
+def tuning_correctness_check_only():
+ return is_env_option_enabled("BPEX_TUNING_CORRECTNESS_CHECK_ONLY")
diff --git a/lingbotvla/ops/group_gemm/utils/kernel.py b/lingbotvla/ops/group_gemm/utils/kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..d10e714c1c192232e95f733690c77d062e1bf533
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/utils/kernel.py
@@ -0,0 +1,25 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import triton
+
+
+def innermost_fn(fn: triton.KernelInterface):
+ while hasattr(fn, "fn"):
+ fn = fn.fn
+ return fn
+
+
+def qualified_name(fn: triton.KernelInterface) -> str:
+ return innermost_fn(fn).__qualname__
diff --git a/lingbotvla/ops/group_gemm/utils/path.py b/lingbotvla/ops/group_gemm/utils/path.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7d2070cab74158ad84d9fa3d73550830f356c67
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/utils/path.py
@@ -0,0 +1,45 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import os
+
+import triton
+from packaging import version
+
+from .device import get_device_key
+from .kernel import qualified_name
+
+
+def _get_relative_dir_of_triton_kernel(kernel) -> str:
+ path = os.path.relpath(inspect.getfile(kernel), get_bpex_root())
+ return path
+
+
+def get_bpex_root() -> str:
+ path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ return path
+
+
+def get_config_path_prefix_for(kernel) -> str:
+ v = version.parse(triton.__version__)
+ return (
+ f"{get_bpex_root()}/config/{v.major}.{v.minor}/{get_device_key()}/"
+ f"{_get_relative_dir_of_triton_kernel(kernel)}/{qualified_name(kernel)}"
+ )
+
+
+def get_config_dedicated_file_for(kernel, algo_key) -> str:
+ # The only reason the extension is used is to avoid JSON lint..
+ return f"{get_config_path_prefix_for(kernel)}/{algo_key}.bpex"
diff --git a/lingbotvla/ops/group_gemm/utils/pretuned.py b/lingbotvla/ops/group_gemm/utils/pretuned.py
new file mode 100644
index 0000000000000000000000000000000000000000..d21714a6ae62104e80bfd6a42d1ad241f7cd748e
--- /dev/null
+++ b/lingbotvla/ops/group_gemm/utils/pretuned.py
@@ -0,0 +1,117 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import triton
+
+from ....utils import logging
+from . import envvars
+from .config import load_all_configs_for
+from .kernel import innermost_fn, qualified_name
+
+
+logger = logging.get_logger(__name__)
+
+CATCH_ALL_ALGO_KEY = "__CATCH_ALL__"
+
+
+def algo_key_scaled(names, scales, rest_key=None):
+ def key_maker(**kwargs):
+ lower_names = [name.lower() for name in names]
+ temp = []
+ for i, name in enumerate(lower_names):
+ t = name + str(kwargs[names[i]] // scales[i])
+ if scales[i] != 1:
+ t += f"x{scales[i]}"
+ temp.append(t)
+ res = "_".join(temp)
+
+ if rest_key is not None:
+ for k in rest_key:
+ res += f"_{kwargs[k]}"
+ return res
+
+ return key_maker
+
+
+class Pretuned(triton.KernelInterface):
+ def __init__(self, fn, algo_key_maker, configs):
+ self.fn = fn # In case the outer decorator cares.
+ self.kernel_name = qualified_name(fn)
+ self.algo_key_maker = algo_key_maker
+ self.configs = configs
+
+ assert CATCH_ALL_ALGO_KEY in self.configs
+
+ def run(self, *args, **kwargs):
+ algo_key = self.algo_key_maker(**kwargs)
+ if algo_key not in self.configs:
+ if not envvars.is_untuned_warning_suppressed():
+ logger.debug(
+ f"Untuned case (using algo-key [{algo_key}]) is seen when invoking "
+ f"kernel [{qualified_name(self)}], performance may suffer."
+ )
+ extra_kwargs = self.configs[CATCH_ALL_ALGO_KEY]
+ else:
+ extra_kwargs = self.configs[algo_key]
+ return self.fn.run(*args, **kwargs, **extra_kwargs)
+
+
+# TODO: Support using `triton.autotune` as an fallback.
+def pretuned(*, algo_key=None, fallback=None):
+ """Decorator to annotate a Triton kernel as pre-tuned. Hyperparameters are loaded from `PRETUNED`
+ in the same folder as the kernel being defined.
+
+ By default we look up pre-tuned hyperparameters via `kernel_name, device_name`. However, users
+ are allowed to provide `algo_key` option by providing a lambda that converts arguments passed
+ to kernel to a string that's used as a third level key in looking up pre-tuned hyperparameters.
+
+ Note that ONLY named arguments (but not positional arguments) are passed to `algo_key` callback.
+ """
+
+ if algo_key is None:
+
+ def catch_all(**kwargs):
+ return CATCH_ALL_ALGO_KEY
+
+ algo_key = catch_all
+
+ def decorator(fn: triton.KernelInterface):
+ nonlocal algo_key
+ nonlocal fallback
+
+ name = qualified_name(fn)
+ configs = load_all_configs_for(innermost_fn(fn))
+
+ if CATCH_ALL_ALGO_KEY not in configs:
+ # We'd like to find a fallback hyperparameter for each `device`. This is not the same one
+ # as `fallback` provided to `pretuned`. The latter is used when we're running on an untuned
+ # device, while the former is just a catch-all for a specific device.
+ if not envvars.is_untuned_warning_suppressed():
+ import torch
+
+ logger.debug(
+ f"No pre-tuned hyperparameter for kernel [{name}], using fallback config, "
+ "performance may suffer. You may have triton version or device name mismatch. "
+ f"You have triton=={triton.__version__} and device name [{torch.cuda.get_device_name()}]",
+ )
+ configs.update({CATCH_ALL_ALGO_KEY: fallback})
+
+ assert configs[CATCH_ALL_ALGO_KEY] is not None, "No usable fallback hyperparameter for kernel {name}"
+ return Pretuned(
+ fn,
+ algo_key_maker=algo_key,
+ configs=configs,
+ )
+
+ return decorator
diff --git a/lingbotvla/ops/loss.py b/lingbotvla/ops/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c02fa9e646b03f6e6bd48014b806a83281c76cf
--- /dev/null
+++ b/lingbotvla/ops/loss.py
@@ -0,0 +1,85 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..data.constants import IGNORE_INDEX
+from ..distributed.parallel_state import get_parallel_state
+from ..distributed.sequence_parallel import reduce_sequence_parallel_loss
+from ..utils import logging
+from ..utils.import_utils import is_liger_kernel_available
+
+
+logger = logging.get_logger(__name__)
+
+
+def fixed_cross_entropy(
+ source: torch.Tensor,
+ target: torch.Tensor,
+ num_items_in_batch: Optional[torch.Tensor] = None,
+ ignore_index: int = -100,
+ **kwargs,
+) -> torch.Tensor:
+ reduction = "sum" if num_items_in_batch is not None else "mean"
+ loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
+ if reduction == "sum":
+ # just in case users pass an int for num_items_in_batch, which could be the case for custom trainer
+ if torch.is_tensor(num_items_in_batch):
+ num_items_in_batch = num_items_in_batch.to(loss.device)
+ loss = loss / num_items_in_batch
+ return loss
+
+
+fused_linear_cross_entropy = None
+
+if is_liger_kernel_available():
+ from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss # type: ignore
+
+ fused_linear_cross_entropy = LigerFusedLinearCrossEntropyLoss(reduction="mean")
+
+
+def causallm_loss_function(
+ hidden_states: torch.Tensor,
+ weight: torch.Tensor,
+ labels: torch.Tensor,
+ vocab_size: Optional[int] = None,
+ num_items_in_batch: Optional[int] = None,
+ ignore_index: int = -100,
+ shift_labels: Optional[torch.Tensor] = None,
+ **kwargs,
+) -> torch.Tensor:
+ # We don't use shift_labels in causallm
+ assert shift_labels is None
+
+ loss = None
+ logits = None
+
+ if labels is None:
+ logits = F.linear(hidden_states, weight)
+ return loss, logits
+
+ sp_enabled = get_parallel_state().sp_enabled
+
+ # Shift the labels and hidden_states so that tokens < n predict n
+ if not sp_enabled:
+ labels = labels[..., 1:].contiguous()
+ hidden_states = hidden_states[..., :-1, :].contiguous()
+
+ # Flatten the labels and hidden_states
+ labels = labels.view(-1)
+ hidden_states = hidden_states.view(-1, hidden_states.size(-1))
+
+ # Calculate loss
+ if fused_linear_cross_entropy is not None: # use liger kernels
+ loss = fused_linear_cross_entropy(weight, hidden_states, labels)
+ else:
+ logits = F.linear(hidden_states, weight).float()
+ loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
+
+ # Reduce loss when using sp
+ if sp_enabled:
+ num_valid_tokens = (labels != IGNORE_INDEX).sum()
+ loss = reduce_sequence_parallel_loss(loss, num_valid_tokens)
+
+ return loss, logits
diff --git a/lingbotvla/optim/__init__.py b/lingbotvla/optim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e14ac09b27178bbdb5dfddd5261c4fdb76a82989
--- /dev/null
+++ b/lingbotvla/optim/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .lr_scheduler import build_lr_scheduler
+from .optimizer import build_optimizer
+
+
+__all__ = ["build_lr_scheduler", "build_optimizer"]
diff --git a/lingbotvla/optim/lr_scheduler.py b/lingbotvla/optim/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd7bc913edef3c2e7241de0c1d4590ccf29d9108
--- /dev/null
+++ b/lingbotvla/optim/lr_scheduler.py
@@ -0,0 +1,185 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import math
+from typing import TYPE_CHECKING, Literal
+
+from torch.optim.lr_scheduler import LambdaLR
+
+from ..utils import logging
+
+
+if TYPE_CHECKING:
+ from torch.optim import Optimizer
+
+
+logger = logging.get_logger(__name__)
+
+
+def build_lr_scheduler(
+ optimizer: "Optimizer",
+ train_steps: int,
+ lr: float = 1e-3,
+ lr_decay_style: Literal["constant", "linear", "cosine", "two_stage"] = "constant",
+ lr_decay_ratio: float = 1.0,
+ lr_warmup_ratio: float = 0.0,
+ lr_min: float = 1e-7,
+ lr_start: float = 0.0,
+):
+ lr_warmup_steps = int(train_steps * lr_warmup_ratio)
+ if lr_decay_style == "constant":
+ return get_constant_schedule_with_warmup(
+ optimizer=optimizer,
+ num_warmup_steps=lr_warmup_steps,
+ lr_start=lr_start,
+ init_lr=lr,
+ )
+
+ if lr_decay_style == "linear":
+ return get_linear_schedule_with_warmup(
+ optimizer=optimizer,
+ num_warmup_steps=lr_warmup_steps,
+ num_training_steps=train_steps,
+ init_lr=lr,
+ lr_start=lr_start,
+ )
+
+ if lr_decay_style == "cosine":
+ return get_cosine_schedule_with_warmup(
+ optimizer=optimizer,
+ num_warmup_steps=lr_warmup_steps,
+ num_training_steps=train_steps,
+ init_lr=lr,
+ lr_decay_ratio=lr_decay_ratio,
+ min_lr=lr_min,
+ lr_start=lr_start,
+ )
+
+ if lr_decay_style == "two_stage":
+ return get_two_stage_constant_schedule_with_warmup(
+ optimizer=optimizer,
+ init_lr=lr,
+ lr_start=lr_start,
+ )
+
+ raise ValueError(f"Unknown learning rate decay style: {lr_decay_style}.")
+
+
+def get_constant_schedule_with_warmup(
+ optimizer: "Optimizer",
+ num_warmup_steps: int,
+ init_lr: float,
+ last_epoch: int = -1,
+ lr_start: float = 0.0,
+):
+ """
+ Creates a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+ increases linearly between 0 and the initial lr set in the optimizer.
+ """
+
+ def _lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return (lr_start + (init_lr - lr_start) * current_step / max(1, num_warmup_steps)) / init_lr
+
+ return 1.0
+
+ return LambdaLR(optimizer, _lr_lambda, last_epoch=last_epoch)
+
+def get_two_stage_constant_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps: int = 100,
+ init_lr: float = 0.0,
+ decay_steps: int = 10_000,
+ decay_ratio: float = 0.2,
+ last_epoch: int = -1,
+ lr_start: float = 0.0,
+):
+ """
+ Two stages constant learning rate schedule with warm up follows OpenVLA-OFT.
+ Defaultly, the learning rate is 0.0 for the first 100 steps and then decay to 0.2 * init_lr after 10_000 steps.
+ """
+ def _lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ warmup_lr = lr_start + (init_lr - lr_start) * (current_step / max(1, num_warmup_steps))
+ return warmup_lr / init_lr
+
+ if current_step >= decay_steps:
+ return decay_ratio
+
+ return 1.0
+
+ return LambdaLR(optimizer, _lr_lambda, last_epoch=last_epoch)
+
+
+def get_linear_schedule_with_warmup(
+ optimizer: "Optimizer",
+ num_warmup_steps: int,
+ num_training_steps: int,
+ init_lr: float,
+ last_epoch: int = -1,
+ min_lr: float = 1e-7,
+ lr_start: float = 0.0,
+):
+ """
+ Creates a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0,
+ after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+ """
+
+ def _lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return (lr_start + (init_lr - lr_start) * current_step / max(1, num_warmup_steps)) / init_lr
+
+ min_lr_ratio = min_lr / init_lr
+ return max(
+ min_lr_ratio,
+ float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)),
+ )
+
+ return LambdaLR(optimizer, _lr_lambda, last_epoch)
+
+
+def get_cosine_schedule_with_warmup(
+ optimizer: "Optimizer",
+ num_warmup_steps: int,
+ num_training_steps: int,
+ init_lr: float,
+ num_cycles: float = 0.5,
+ last_epoch: int = -1,
+ lr_decay_ratio: float = 1.0,
+ min_lr: float = 1e-7,
+ lr_start: float = 0.0,
+):
+ """
+ Creates a schedule with a learning rate that decreases following the values of the cosine function between
+ the initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0
+ and the initial lr set in the optimizer.
+ """
+
+ def lr_lambda(current_step: int):
+ lr_decay_steps = int(num_training_steps * lr_decay_ratio)
+ if current_step < num_warmup_steps:
+ return (lr_start + (init_lr - lr_start) * current_step / max(1, num_warmup_steps)) / init_lr
+
+ min_lr_ratio = min_lr / init_lr
+ if current_step > lr_decay_steps:
+ return min_lr_ratio
+
+ progress = float(current_step - num_warmup_steps) / float(max(1, lr_decay_steps - num_warmup_steps))
+ assert 0 <= progress <= 1
+ factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+ factor = factor * (1 - min_lr_ratio) + min_lr_ratio
+ return max(0, factor)
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
diff --git a/lingbotvla/optim/optimizer.py b/lingbotvla/optim/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f8ec479d8a3b49ed8817f727318dfc95763e0b7
--- /dev/null
+++ b/lingbotvla/optim/optimizer.py
@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, Optional, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+from torch.optim import AdamW
+from torch.optim.optimizer import Optimizer
+
+from ..utils.import_utils import is_torch_npu_available
+
+
+# https://github.com/meta-llama/llama-recipes/blob/v0.0.4/src/llama_recipes/policies/anyprecision_optimizer.py
+class AnyPrecisionAdamW(Optimizer):
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.95),
+ eps=1e-8,
+ weight_decay=0.0,
+ use_kahan_summation=True,
+ momentum_dtype=torch.bfloat16,
+ variance_dtype=torch.bfloat16,
+ compensation_buffer_dtype=torch.bfloat16,
+ ):
+ defaults = {
+ "lr": lr,
+ "betas": betas,
+ "eps": eps,
+ "weight_decay": weight_decay,
+ "use_kahan_summation": use_kahan_summation,
+ "momentum_dtype": momentum_dtype,
+ "variance_dtype": variance_dtype,
+ "compensation_buffer_dtype": compensation_buffer_dtype,
+ }
+ super().__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """
+ Performs a single optimization step.
+
+ Args:
+ closure (callable, optional): A closure that reevaluates the model and returns the loss.
+ """
+
+ if closure is not None:
+ with torch.enable_grad():
+ closure()
+
+ for group in self.param_groups:
+ beta1, beta2 = group["betas"]
+ lr = group["lr"]
+ weight_decay = group["weight_decay"]
+ eps = group["eps"]
+ use_kahan_summation = group["use_kahan_summation"]
+
+ momentum_dtype = group["momentum_dtype"]
+ variance_dtype = group["variance_dtype"]
+ compensation_buffer_dtype = group["compensation_buffer_dtype"]
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+
+ if p.grad.is_sparse:
+ raise RuntimeError("AnyPrecisionAdamW does not support sparse gradients.")
+
+ state = self.state[p]
+ # State initialization
+ if len(state) == 0:
+ state["step"] = torch.tensor(0.0)
+
+ # momentum - EMA of gradient values
+ state["exp_avg"] = torch.zeros_like(p, dtype=momentum_dtype)
+
+ # variance uncentered - EMA of squared gradient values
+ state["exp_avg_sq"] = torch.zeros_like(p, dtype=variance_dtype)
+
+ # optional Kahan summation - accumulated error tracker
+ if use_kahan_summation:
+ state["compensation"] = torch.zeros_like(p, dtype=compensation_buffer_dtype)
+
+ # Main processing
+ # update the steps for each param group update
+ state["step"] += 1
+ step = state["step"]
+
+ exp_avg = state["exp_avg"]
+ exp_avg_sq = state["exp_avg_sq"]
+ grad = p.grad
+
+ if weight_decay: # weight decay, AdamW style
+ p.data.mul_(1 - lr * weight_decay)
+
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # update momentum
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # update uncentered variance
+
+ bias_correction1 = 1 - beta1**step # adjust using bias1
+ step_size = lr / bias_correction1
+
+ denom_correction = (1 - beta2**step) ** 0.5 # adjust using bias2 and avoids math import
+ centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_(eps, alpha=1)
+
+ if use_kahan_summation: # lr update to compensation
+ compensation = state["compensation"]
+ compensation.addcdiv_(exp_avg, centered_variance, value=-step_size)
+
+ # update weights with compensation (Kahan summation)
+ # save error back to compensation for next iteration
+ temp_buffer = p.detach().clone()
+ p.data.add_(compensation)
+ compensation.add_(temp_buffer.sub_(p.data))
+ else: # usual AdamW updates
+ p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)
+
+
+def build_optimizer(
+ model: "nn.Module",
+ lr: float = 1e-3,
+ betas: Tuple[float, float] = (0.9, 0.95),
+ eps: float = 1e-8,
+ weight_decay: float = 1e-2,
+ fused: bool = False,
+ optimizer_type: str = "adamw",
+ param_groups: Optional[Sequence[Dict[str, Any]]] = None,
+ post_training=False,
+) -> "torch.optim.Optimizer":
+ if param_groups is None:
+ align_parameters = [
+ name for name, _ in model.named_parameters() if "depth" in name
+ ]
+
+ if len(align_parameters) > 0:
+ lr_gain = 10.0 if not post_training else 1.0
+ param_groups = [
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if (p.requires_grad and n not in align_parameters)
+ ],
+ "lr": lr,
+ },
+ {
+ "params": [
+ p
+ for n, p in model.named_parameters()
+ if (p.requires_grad and n in align_parameters)
+ ],
+ "lr": lr * lr_gain,
+ }
+ ]
+ else:
+ param_groups = filter(lambda p: p.requires_grad, model.parameters())
+
+ if optimizer_type == "adamw":
+ foreach = False if is_torch_npu_available() else (not fused)
+ fused = False if is_torch_npu_available() else fused
+ optim = AdamW(param_groups, lr, betas, eps, weight_decay, fused=fused, foreach=foreach)
+ elif optimizer_type == "anyprecision_adamw":
+ optim = AnyPrecisionAdamW(param_groups, lr, betas, eps, weight_decay)
+ else:
+ raise ValueError("Only adamw and anyprecision_adamw are supported as optimizers.")
+
+ return optim
diff --git a/lingbotvla/schedulers/flow_match.py b/lingbotvla/schedulers/flow_match.py
new file mode 100644
index 0000000000000000000000000000000000000000..f116ab151896f3ea492a0ff11402b65d4c0ce1b3
--- /dev/null
+++ b/lingbotvla/schedulers/flow_match.py
@@ -0,0 +1,98 @@
+# Copyright 2023 Zhongjie Duan
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+class FlowMatchScheduler:
+ def __init__(
+ self,
+ num_inference_steps=100,
+ num_train_timesteps=1000,
+ shift=3.0,
+ sigma_max=1.0,
+ sigma_min=0.003 / 1.002,
+ inverse_timesteps=False,
+ extra_one_step=False,
+ reverse_sigmas=False,
+ ):
+ self.num_train_timesteps = num_train_timesteps
+ self.shift = shift
+ self.sigma_max = sigma_max
+ self.sigma_min = sigma_min
+ self.inverse_timesteps = inverse_timesteps
+ self.extra_one_step = extra_one_step
+ self.reverse_sigmas = reverse_sigmas
+ self.set_timesteps(num_inference_steps)
+
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None):
+ if shift is not None:
+ self.shift = shift
+ sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
+ if self.extra_one_step:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
+ else:
+ self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
+ if self.inverse_timesteps:
+ self.sigmas = torch.flip(self.sigmas, dims=[0])
+ self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
+ if self.reverse_sigmas:
+ self.sigmas = 1 - self.sigmas
+ self.timesteps = self.sigmas * self.num_train_timesteps
+ if training:
+ x = self.timesteps
+ y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
+ y_shifted = y - y.min()
+ bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
+ self.linear_timesteps_weights = bsmntw_weighing
+ self.training = True
+ else:
+ self.training = False
+
+ def step(self, model_output, timestep, sample, to_final=False, **kwargs):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.cpu()
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ if to_final or timestep_id + 1 >= len(self.timesteps):
+ sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
+ else:
+ sigma_ = self.sigmas[timestep_id + 1]
+ prev_sample = sample + model_output * (sigma_ - sigma)
+ return prev_sample
+
+ def return_to_timestep(self, timestep, sample, sample_stablized):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.cpu()
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ model_output = (sample - sample_stablized) / sigma
+ return model_output
+
+ def add_noise(self, original_samples, noise, timestep, micro_batch_size, enable_mixed_precision):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.cpu()
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
+ sigma = self.sigmas[timestep_id]
+ sample = (1 - sigma) * original_samples + sigma * noise
+ return sample
+
+ def training_target(self, sample, noise, timestep):
+ target = noise - sample
+ return target
+
+ def training_weight(self, timestep, micro_batch_size):
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
+ weights = self.linear_timesteps_weights[timestep_id]
+ return weights
diff --git a/lingbotvla/utils/__init__.py b/lingbotvla/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd1e8433dffa0b3ba420be3e346f4f5cd062014
--- /dev/null
+++ b/lingbotvla/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/lingbotvla/utils/arguments.py b/lingbotvla/utils/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3bddb629f6af3596c58fd96b747f42dc3635e85
--- /dev/null
+++ b/lingbotvla/utils/arguments.py
@@ -0,0 +1,851 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Argument utils"""
+
+import argparse
+import json
+import math
+import os
+import sys
+import types
+from collections import defaultdict
+from dataclasses import MISSING, asdict, dataclass, field, fields
+from enum import Enum
+from inspect import isclass
+from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union, get_type_hints
+
+import yaml
+
+from . import logging
+
+
+T = TypeVar("T")
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class ModelArguments:
+ config_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the model config. Defaults to `model_path`."},
+ )
+ model_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the pre-trained model. If unspecified, use random init."},
+ )
+ tokenizer_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the tokenizer. Defaults to `config_path`."},
+ )
+ vlm_repo_id: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the VLM. Defaults to None."},
+ )
+ post_training: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Whether to use post training."},
+ )
+ vocab_size: Optional[int] = field(
+ default=0,
+ metadata={"help": "Vocab size. 257152 is for paligemma in initial pi0."},
+ )
+ incremental_training: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Whether to apply incremental training."},
+ )
+ depth_incremental_training: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Whether to re-init depth_align_head."},
+ )
+ adanorm_time: Optional[bool] = field(
+ default=False,
+ metadata={"help": "Whether to apply extra time embed to ada_norm in expert."},
+ )
+ encoders: Dict[Literal["image"], Dict[str, str]] = field(
+ default_factory=dict,
+ metadata={"help": "Multimodal encoder config and weights."},
+ )
+ decoders: Dict[Literal["image"], Dict[str, str]] = field(
+ default_factory=dict,
+ metadata={"help": "Multimodal decoder config and weights."},
+ )
+ input_encoder: Literal["encoder", "decoder"] = field(
+ default="encoder",
+ metadata={"help": "Use encoder to encode input images or use decoder.encoder to encode input images."},
+ )
+ output_encoder: Literal["encoder", "decoder"] = field(
+ default="decoder",
+ metadata={"help": "Use encoder to encode output images or use decoder.encoder to encode output images."},
+ )
+ encode_target: bool = field(
+ default=False,
+ metadata={"help": "Whether to encode target with decoder. Only supports stable diffusion as decoder."},
+ )
+ attn_implementation: Optional[Literal["eager", "sdpa", "flash_attention_2", "flash_attention_3"]] = field(
+ default="flash_attention_2",
+ metadata={"help": "Attention implementation to use."},
+ )
+ moe_implementation: Optional[Literal[None, "eager", "fused"]] = field(
+ default=None,
+ metadata={"help": "MoE implementation to use."},
+ )
+ basic_modules: Optional[List[str]] = field(
+ default_factory=list,
+ metadata={"help": "Basic modules beyond model._no_split_modules to be sharded in FSDP."},
+ )
+ force_use_huggingface: bool = field(
+ default=False,
+ metadata={"help": "Force loading model from huggingface."},
+ )
+ use_lm_head: bool = field(
+ default=False,
+ metadata={"help": "Whether to use lm_head."},
+ )
+ split_gate_liner: bool = field(
+ default=False,
+ metadata={"help": "Whether to split gate liner in adanorm."},
+ )
+ nosplit_gate_liner: bool = field(
+ default=False,
+ metadata={"help": "Whether to nosplit gate liner in adanorm."},
+ )
+ separate_time_proj: bool = field(
+ default=False,
+ metadata={"help": "Whether to split time proj in embed_suffix."},
+ )
+ final_norm_adanorm: bool = field(
+ default=False,
+ metadata={"help": "Whether to use adanorm in final norm."},
+ )
+ old_adanorm: bool = field(
+ default=False,
+ metadata={"help": "Whether to use old adanorm."},
+ )
+ moge_path: str = field(
+ default=None,
+ metadata={"help": "path of MgGe."},
+ )
+ morgbd_path: str = field(
+ default=None,
+ metadata={"help": "path of LingBot-Depth."},
+ )
+
+ def __post_init__(self):
+ if self.config_path is None and self.model_path is None:
+ raise ValueError("`config_path` must be specified when `model_path` is None.")
+
+ if self.config_path is None:
+ self.config_path = self.model_path
+
+ if self.tokenizer_path is None:
+ self.tokenizer_path = self.config_path
+
+ for encoder_type, encoder_args in self.encoders.items():
+ if encoder_type not in ["image"]:
+ raise ValueError(f"Unsupported encoder type: {encoder_type}. Should be one of {{image}}.")
+
+ if encoder_args.get("config_path") is None and encoder_args.get("model_path") is None:
+ raise ValueError("`config_path` and `model_path` cannot be both empty.")
+
+ if encoder_args.get("config_path") is None:
+ encoder_args["config_path"] = encoder_args["model_path"]
+
+ for decoder_type, decoder_args in self.decoders.items():
+ if decoder_type not in ["image"]:
+ raise ValueError(f"Unsupported decoder type: {decoder_type}. Should be one of {{image}}.")
+
+ if decoder_args.get("config_path") is None and decoder_args.get("model_path") is None:
+ raise ValueError("`config_path` and `model_path` cannot be both empty.")
+
+ if decoder_args.get("config_path") is None:
+ decoder_args["config_path"] = decoder_args["model_path"]
+
+
+@dataclass
+class DataArguments:
+ train_path: str = field(
+ metadata={"help": "Path of the training data. Use comma to separate multiple datasets."},
+ )
+ train_size: int = field(
+ default=10_000_000,
+ metadata={"help": "Number of tokens for training to compute training steps for dynamic batch dataloader."},
+ )
+ data_type: Literal["plaintext", "conversation", "diffusion"] = field(
+ default="conversation",
+ metadata={"help": "Type of the training data."},
+ )
+ dataloader_type: Literal["native"] = field(
+ default="native",
+ metadata={"help": "Type of the dataloader."},
+ )
+ datasets_type: Literal["mapping", "iterable", "vla"] = field(
+ default="mapping",
+ metadata={"help": "Type of the datasets."},
+ )
+ data_name: str = field(
+ default=None,
+ metadata={"help": "Dataset name for multimodal training."},
+ )
+ data_root: str = field(
+ default=None,
+ metadata={"help": "Root path of datasets."},
+ )
+ data_tag: Literal["default", "mmtag"] = field(
+ default="default",
+ metadata={"help": "Dataset tag for multimodal training."},
+ )
+ text_keys: str = field(
+ default=None,
+ metadata={"help": "Key to get text from the training data."},
+ )
+ image_keys: str = field(
+ default="images",
+ metadata={"help": "Key to get images from the training data."},
+ )
+ chat_template: str = field(
+ default="default",
+ metadata={"help": "Chat template to use."},
+ )
+ max_seq_len: int = field(
+ default=2048,
+ metadata={"help": "Maximum sequence length in training."},
+ )
+ num_workers: int = field(
+ default=20,
+ metadata={"help": "Number of workers to load data."},
+ )
+ prefetch_factor: int = field(
+ default=4,
+ metadata={"help": "Number of batches loaded in advance by each worker."},
+ )
+ drop_last: bool = field(
+ default=True,
+ metadata={"help": "Whether to drop the last incomplete batch."},
+ )
+ pin_memory: bool = field(
+ default=True,
+ metadata={"help": "Whether to pin memory for dataloader."},
+ )
+
+ def __post_init__(self):
+ if self.text_keys is None:
+ if self.data_type == "plaintext":
+ self.text_keys = "content_split"
+ elif self.data_type == "conversation":
+ self.text_keys = "messages"
+ else:
+ raise ValueError(f"Unknown data type: {self.data_type}")
+
+
+@dataclass
+class TrainingArguments:
+ output_dir: str = field(
+ metadata={"help": "Path to save model checkpoints."},
+ )
+ lr: float = field(
+ default=5e-5,
+ metadata={"help": "Maximum learning rate or defult learning rate, or init learning rate for warmup."},
+ )
+ lr_min: float = field(
+ default=1e-7,
+ metadata={"help": "Minimum learning rate."},
+ )
+ lr_start: float = field(
+ default=0.0,
+ metadata={"help": "Learning rate for warmup start. Default to 0.0."},
+ )
+ weight_decay: float = field(
+ default=0,
+ metadata={"help": "L2 regularization strength."},
+ )
+ optimizer: Literal["adamw", "anyprecision_adamw"] = field(
+ default="adamw",
+ metadata={"help": "Optimizer. Default to adamw."},
+ )
+ max_grad_norm: float = field(
+ default=1.0,
+ metadata={"help": "Clip value for gradient norm."},
+ )
+ micro_batch_size: int = field(
+ default=1,
+ metadata={"help": "Micro batch size. The number of samples per iteration on each device."},
+ )
+ global_batch_size: Optional[int] = field(
+ default=None,
+ metadata={"help": "Global batch size. If None, use `micro_batch_size` * `data_parallel_size`."},
+ )
+ num_train_epochs: int = field(
+ default=1,
+ metadata={"help": "Epochs to train."},
+ )
+ rmpad: bool = field(
+ default=True,
+ metadata={"help": "Enable padding-free training by using the cu_seqlens."},
+ )
+ rmpad_with_pos_ids: bool = field(
+ default=False,
+ metadata={"help": "Enable padding-free training by using the position_ids."},
+ )
+ dyn_bsz: bool = field(
+ default=True,
+ metadata={"help": "Enable dynamic batch size for padding-free training."},
+ )
+ dyn_bsz_margin: int = field(
+ default=0,
+ metadata={"help": "Number of pad tokens in dynamic batch."},
+ )
+ dyn_bsz_buffer_size: int = field(
+ default=200,
+ metadata={"help": "Buffer size for dynamic batch size."},
+ )
+ bsz_warmup_ratio: float = field(
+ default=0,
+ metadata={"help": "Ratio of batch size warmup steps."},
+ )
+ bsz_warmup_init_mbtoken: int = field(
+ default=200,
+ metadata={"help": "Initial number of tokens in a batch in warmup phase."},
+ )
+ lr_warmup_ratio: float = field(
+ default=0,
+ metadata={"help": "Ratio of learning rate warmup steps."},
+ )
+ lr_decay_style: str = field(
+ default="constant",
+ metadata={"help": "Name of the learning rate scheduler."},
+ )
+ lr_decay_ratio: float = field(
+ default=1.0,
+ metadata={"help": "Ratio of learning rate decay steps."},
+ )
+ use_doptim: bool = field(
+ default=False,
+ metadata={"help": "Use veScale's ZeRO optimizer."},
+ )
+ enable_mixed_precision: bool = field(
+ default=True,
+ metadata={"help": "Enable mixed precision training."}, # false -> torch_dtype when loading model is bf16
+ )
+ enable_fp32: bool = field(
+ default=False,
+ metadata={"help": "Enable fp32 training."},
+ )
+ enable_resume: bool = field(
+ default=False,
+ metadata={"help": "Whether to automatically resume training from a checkpoint."},
+ )
+ enable_gradient_checkpointing: bool = field(
+ default=True,
+ metadata={"help": "Enable gradient checkpointing."},
+ )
+ enable_reentrant: bool = field(
+ default=False,
+ metadata={"help": "Use reentrant gradient checkpointing."},
+ )
+ enable_full_shard: bool = field(
+ default=True,
+ metadata={"help": "Enable fully shard for FSDP training (ZeRO-3)."},
+ )
+ enable_forward_prefetch: bool = field(
+ default=True,
+ metadata={"help": "Enable forward prefetch for FSDP1."},
+ )
+ enable_fsdp_offload: bool = field(
+ default=False,
+ metadata={"help": "Enable CPU offload for FSDP1."},
+ )
+ enable_activation_offload: bool = field(
+ default=False,
+ metadata={"help": "Enable activation offload to CPU."},
+ )
+ activation_gpu_limit: float = field(
+ default=0.0,
+ metadata={
+ "help": "When enabling activation offload, `activation_gpu_limit` GB activations are allowed to reserve on GPU."
+ },
+ )
+ enable_manual_eager: bool = field(
+ default=False,
+ metadata={"help": "Enable veScale's manual eager."},
+ )
+ init_device: Literal["cpu", "cuda", "meta"] = field(
+ default="cuda",
+ metadata={
+ "help": "Device to initialize model weights. 1. `cpu`: Init parameters on CPU in rank0 only. 2. `cuda`: Init parameters on GPU. 3. `meta`: Init parameters on meta."
+ },
+ )
+ enable_full_determinism: bool = field(
+ default=False,
+ metadata={"help": "Enable full determinism."},
+ )
+ empty_cache_steps: int = field(
+ default=500,
+ metadata={"help": "Number of steps between two empty cache operations."},
+ )
+ data_parallel_mode: Literal["ddp", "fsdp1", "fsdp2", "fsdp2-vescale"] = field(
+ default="ddp",
+ metadata={"help": "Data parallel mode."},
+ )
+ use_compile: bool = field(
+ default=False,
+ metadata={"help": "wether to enable torch.compile."},
+ )
+ module_fsdp_enable: bool = field(
+ default=True,
+ metadata={"help": "Enable FSDP for module."},)
+ data_parallel_replicate_size: int = field(
+ default=-1,
+ metadata={"help": "Data parallel replicate size."},
+ )
+ data_parallel_shard_size: int = field(
+ default=-1,
+ metadata={"help": "Data parallel shard degree."},
+ )
+ tensor_parallel_size: int = field(
+ default=1,
+ metadata={"help": "Tensor parallel size."},
+ )
+ expert_parallel_size: int = field(
+ default=1,
+ metadata={"help": "Expert parallel size."},
+ )
+ pipeline_parallel_size: int = field(
+ default=1,
+ metadata={"help": "Pipeline parallel size."},
+ )
+ ulysses_parallel_size: int = field(
+ default=1,
+ metadata={"help": "Ulysses sequence parallel size."},
+ )
+ context_parallel_size: int = field(
+ default=1,
+ metadata={"help": "Ring-attn context parallel size."},
+ )
+ ckpt_manager: Literal["bytecheckpoint", "dcp"] = field(
+ default="dcp",
+ metadata={"help": "Checkpoint manager."},
+ )
+ load_checkpoint_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to bytecheckpoint checkpoint to resume from."},
+ )
+ save_steps: int = field(
+ default=0,
+ metadata={"help": "Number of steps between two checkpoint saves."},
+ )
+ save_epochs: int = field(
+ default=1,
+ metadata={"help": "Number of epochs between two checkpoint saves."},
+ )
+ save_hf_weights: bool = field(
+ default=True,
+ metadata={"help": "Save the huggingface format weights to the last checkpoint dir."},
+ )
+ seed: int = field(
+ default=42,
+ metadata={"help": "Random seed."},
+ )
+ use_wandb: bool = field(
+ default=True,
+ metadata={"help": "Use wandb to log experiment."},
+ )
+ wandb_project: str = field(
+ default="LingBotVLA",
+ metadata={"help": "Wandb project name."},
+ )
+ wandb_name: Optional[str] = field(
+ default=None,
+ metadata={"help": "Wandb experiment name."},
+ )
+ enable_profiling: bool = field(
+ default=False,
+ metadata={"help": "Enable profiling."},
+ )
+ profile_start_step: int = field(
+ default=1,
+ metadata={"help": "Start step for profiling."},
+ )
+ profile_end_step: int = field(
+ default=2,
+ metadata={"help": "End step for profiling."},
+ )
+ profile_trace_dir: str = field(
+ default="./trace",
+ metadata={"help": "Direction to export the profiling result."},
+ )
+ profile_record_shapes: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to record the shapes of the input tensors."},
+ )
+ profile_profile_memory: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to profile the memory usage."},
+ )
+ profile_with_stack: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to record the stack traces."},
+ )
+ max_steps: Optional[int] = field(
+ default=None,
+ metadata={"help": "Max training steps per epoch. (for debug)"},
+ )
+
+ def __post_init__(self):
+ self._train_steps = -1
+ self.local_rank = int(os.getenv("LOCAL_RANK"))
+ self.global_rank = int(os.getenv("RANK"))
+ self.world_size = int(os.getenv("WORLD_SIZE"))
+ if (
+ self.world_size
+ % (
+ self.pipeline_parallel_size
+ * self.ulysses_parallel_size
+ * self.context_parallel_size
+ * self.tensor_parallel_size
+ )
+ != 0
+ ):
+ raise ValueError(
+ f"World size should be a multiple of pipeline_parallel_size: {self.pipeline_parallel_size}, ulysses_parallel_size: {self.ulysses_parallel_size}, context_parallel_size: {self.context_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}."
+ )
+ assert self.tensor_parallel_size == 1, "Tensor parallel size not supported yet."
+ assert self.pipeline_parallel_size == 1, "Pipeline parallel size not supported yet."
+ self.data_parallel_size = self.world_size // (
+ self.pipeline_parallel_size
+ * self.ulysses_parallel_size
+ * self.context_parallel_size
+ * self.tensor_parallel_size
+ )
+ # configure data parallel size
+ if self.data_parallel_replicate_size > 0 and self.data_parallel_shard_size > 0:
+ assert self.data_parallel_size == self.data_parallel_replicate_size * self.data_parallel_shard_size, (
+ f"data_parallel_size should be equal to data_parallel_replicate_size: {self.data_parallel_replicate_size} * data_parallel_shard_size: {self.data_parallel_shard_size}."
+ )
+ elif self.data_parallel_replicate_size > 0:
+ if self.data_parallel_size % self.data_parallel_replicate_size != 0:
+ raise ValueError("data_parallel_size should be a multiple of data_parallel_replicate_size.")
+ self.data_parallel_shard_size = self.data_parallel_size // self.data_parallel_replicate_size
+ elif self.data_parallel_shard_size > 0:
+ if self.data_parallel_size % self.data_parallel_shard_size != 0:
+ raise ValueError("data_parallel_size should be a multiple of data_parallel_shard_size.")
+ self.data_parallel_replicate_size = self.data_parallel_size // self.data_parallel_shard_size
+ else:
+ self.data_parallel_replicate_size = 1
+ self.data_parallel_shard_size = self.data_parallel_size
+
+ if self.rmpad and self.rmpad_with_pos_ids:
+ raise ValueError("`rmpad` and `rmpad_with_pos_ids` cannot be both True.")
+
+ # init method check
+ assert self.expert_parallel_size == 1 or self.init_device != "cpu", (
+ "cpu init is not supported when enable ep. Please use `init_device = cuda` or `init_device = meta` instead."
+ )
+
+ # calculate gradient accumulation steps
+ if self.global_batch_size is None:
+ self.global_batch_size = self.micro_batch_size * self.data_parallel_size
+ self.gradient_accumulation_steps = 1
+ logger.info_rank0("`global_batch_size` is None, disable gradient accumulation.")
+ elif self.global_batch_size % (self.micro_batch_size * self.data_parallel_size) == 0:
+ self.gradient_accumulation_steps = self.global_batch_size // (
+ self.micro_batch_size * self.data_parallel_size
+ )
+ logger.info_rank0(f"Set gradient accumulation to {self.gradient_accumulation_steps}.")
+ else:
+ raise ValueError(
+ f"`global_batch_size` should be a multiple of {self.micro_batch_size * self.data_parallel_size}."
+ )
+
+ if self.gradient_accumulation_steps > 1 and self.enable_fsdp_offload:
+ raise ValueError("Gradient accumulation is not supported with FSDP offload.")
+
+ self.dataloader_batch_size = self.global_batch_size // self.data_parallel_size # = micro bsz * grad accu
+
+ # merlin save paths
+ self.save_checkpoint_path = os.path.join(self.output_dir, "checkpoints")
+ self.model_assets_dir = os.path.join(self.output_dir, "model_assets")
+
+ def compute_train_steps(
+ self, max_seq_len: Optional[int] = None, train_size: Optional[int] = None, dataset_length: Optional[int] = None
+ ) -> None:
+ """
+ Computes the training steps per epoch according to the data length.
+ """
+ if self.rmpad or self.rmpad_with_pos_ids:
+ assert max_seq_len is not None and train_size is not None, "max_seq_len and train_size are required."
+ token_micro_bsz = self.micro_batch_size * max_seq_len
+ train_size = int(train_size * (1 + self.bsz_warmup_ratio / 2))
+ eff_token_rate = (token_micro_bsz - self.dyn_bsz_margin) / token_micro_bsz
+ self._train_steps = math.ceil(train_size / (self.global_batch_size * max_seq_len * eff_token_rate))
+ elif dataset_length is not None:
+ self._train_steps = math.floor(dataset_length / (self.dataloader_batch_size * self.world_size)) # assuming drop_last is true
+ elif self.max_steps is not None:
+ self._train_steps = self.max_steps
+ else:
+ raise ValueError("Please provide `dataset_length` or `max_steps`!")
+
+ @property
+ def train_steps(self) -> int:
+ if self.max_steps is not None and self._train_steps >= self.max_steps:
+ logger.warning_once(f"Set train_steps to {self.max_steps}. It should be for debug purpose only.")
+ return self.max_steps
+
+ if self._train_steps == -1:
+ raise ValueError("Please run `compute_train_steps` first!")
+
+ return self._train_steps
+
+
+@dataclass
+class InferArguments:
+ model_path: str = field(
+ metadata={"help": "Path to the pre-trained model."},
+ )
+ tokenizer_path: Optional[str] = field(
+ default=None,
+ metadata={"help": "Path to the tokenizer. Defaults to `config_path`."},
+ )
+ seed: int = field(
+ default=42,
+ metadata={"help": "Random seed."},
+ )
+ do_sample: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to use sampling in decoding."},
+ )
+ temperature: float = field(
+ default=1.0,
+ metadata={"help": "The temperature value of decoding."},
+ )
+ top_p: float = field(
+ default=1.0,
+ metadata={"help": "The top_p value of decoding."},
+ )
+ max_tokens: int = field(
+ default=1024,
+ metadata={"help": "Max tokens to generate."},
+ )
+
+ def __post_init__(self):
+ if self.tokenizer_path is None:
+ self.tokenizer_path = self.model_path
+
+
+def _string_to_bool(value: Union[bool, str]) -> bool:
+ """
+ Converts a string input to bool value.
+
+ Taken from: https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
+ """
+ if isinstance(value, bool):
+ return value
+ if value.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ if value.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ raise argparse.ArgumentTypeError(
+ f"Truthy value expected: got {value} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
+ )
+
+
+def _convert_str_dict(input_dict: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Safely checks that a passed value is a dictionary and converts any string values to their appropriate types.
+
+ Taken from: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/training_args.py#L189
+ """
+ for key, value in input_dict.items():
+ if isinstance(value, dict):
+ input_dict[key] = _convert_str_dict(value)
+ elif isinstance(value, str):
+ if value.lower() in ("true", "false"): # check for bool
+ input_dict[key] = value.lower() == "true"
+ elif value.isdigit(): # check for digit
+ input_dict[key] = int(value)
+ elif value.replace(".", "", 1).isdigit():
+ input_dict[key] = float(value)
+
+ return input_dict
+
+
+def _make_choice_type_function(choices: List[Any]) -> Callable[[str], Any]:
+ """
+ Creates a mapping function from each choices string representation to the actual value. Used to support multiple
+ value types for a single argument.
+
+ Based on: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/hf_argparser.py#L48
+
+ Args:
+ choices (list): List of choices.
+
+ Returns:
+ Callable[[str], Any]: Mapping function from string representation to actual value for each choice.
+ """
+ str_to_choice = {str(choice): choice for choice in choices}
+ return lambda arg: str_to_choice.get(arg, arg)
+
+
+def parse_args(rootclass: T) -> T:
+ """
+ Parses the root argument class using the CLI inputs or yaml inputs.
+
+ Based on: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/hf_argparser.py#L266
+ """
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ base_to_subclass = {}
+ dict_fields = set()
+ list_fields = set()
+ for subclass in fields(rootclass):
+ base = subclass.name
+ base_to_subclass[base] = subclass.default_factory
+ try:
+ type_hints: Dict[str, type] = get_type_hints(subclass.default_factory)
+ except Exception:
+ raise RuntimeError(f"Type resolution failed for {subclass.default_factory}.")
+
+ for attr in fields(subclass.default_factory):
+ if not attr.init:
+ continue
+
+ attr_type = type_hints[attr.name]
+ origin_type = getattr(attr_type, "__origin__", attr_type)
+ if isinstance(attr_type, str):
+ raise RuntimeError(f"Cannot resolve type {attr.type} of {attr.name}.")
+
+ if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)):
+ if len(attr_type.__args__) != 2 or type(None) not in attr_type.__args__: # only allows Optional[X]
+ raise RuntimeError(f"Cannot resolve type {attr.type} of {attr.name}.")
+
+ if bool not in attr_type.__args__: # except for `Union[bool, NoneType]`
+ attr_type = (
+ attr_type.__args__[0] if isinstance(None, attr_type.__args__[1]) else attr_type.__args__[1]
+ )
+ origin_type = getattr(attr_type, "__origin__", attr_type)
+
+ parser_kwargs = attr.metadata.copy()
+ if origin_type is Literal or (isinstance(attr_type, type) and issubclass(attr_type, Enum)):
+ if origin_type is Literal:
+ parser_kwargs["choices"] = attr_type.__args__
+ else:
+ parser_kwargs["choices"] = [x.value for x in attr_type]
+
+ parser_kwargs["type"] = _make_choice_type_function(parser_kwargs["choices"])
+
+ if attr.default is not MISSING:
+ parser_kwargs["default"] = attr.default
+ else:
+ parser_kwargs["required"] = True
+
+ elif attr_type is bool or attr_type == Optional[bool]:
+ parser_kwargs["type"] = _string_to_bool
+ if attr_type is bool or (attr.default is not None and attr.default is not MISSING):
+ parser_kwargs["default"] = False if attr.default is MISSING else attr.default
+ parser_kwargs["nargs"] = "?"
+ parser_kwargs["const"] = True
+
+ elif isclass(origin_type) and issubclass(origin_type, list):
+ parser_kwargs["type"] = attr_type.__args__[0]
+ parser_kwargs["nargs"] = "+"
+ list_fields.add(f"{base}.{attr.name}")
+ if attr.default_factory is not MISSING:
+ parser_kwargs["default"] = attr.default_factory()
+ elif attr.default is MISSING:
+ parser_kwargs["required"] = True
+
+ elif isclass(origin_type) and issubclass(origin_type, dict):
+ parser_kwargs["type"] = str # parse dict inputs with json string
+ dict_fields.add(f"{base}.{attr.name}")
+ if attr.default_factory is not MISSING:
+ parser_kwargs["default"] = str(attr.default_factory())
+ elif attr.default is MISSING:
+ parser_kwargs["required"] = True
+
+ else:
+ parser_kwargs["type"] = attr_type
+ if attr.default is not MISSING:
+ parser_kwargs["default"] = attr.default
+ elif attr.default_factory is not MISSING:
+ parser_kwargs["default"] = attr.default_factory()
+ else:
+ parser_kwargs["required"] = True
+
+ parser.add_argument(f"--{base}.{attr.name}", **parser_kwargs)
+
+ cmd_args = sys.argv[1:]
+ cmd_args_string = "=".join(cmd_args) # use `=` to mark the end of arg name
+ input_data = {}
+ if cmd_args[0].endswith(".yaml") or cmd_args[0].endswith(".yml"):
+ input_path = cmd_args.pop(0)
+ with open(os.path.abspath(input_path), encoding="utf-8") as f:
+ input_data: Dict[str, Dict[str, Any]] = yaml.safe_load(f)
+
+ elif cmd_args[0].endswith(".json"):
+ input_path = cmd_args.pop(0)
+ with open(os.path.abspath(input_path), encoding="utf-8") as f:
+ input_data: Dict[str, Dict[str, Any]] = json.load(f)
+
+ for base, arg_dict in input_data.items():
+ for arg_name, arg_value in arg_dict.items():
+ if f"--{base}.{arg_name}=" not in cmd_args_string: # lower priority
+ # Skip list fields with None values to use default
+ if f"{base}.{arg_name}" in list_fields and arg_value is None:
+ continue
+
+ cmd_args.append(f"--{base}.{arg_name}")
+ if f"{base}.{arg_name}" in list_fields and isinstance(arg_value, list):
+ # For list fields, extend the arguments with individual elements
+ cmd_args.extend([str(item) for item in arg_value])
+ else:
+ cmd_args.append(arg_value if isinstance(arg_value, str) else json.dumps(arg_value))
+
+ args, remaining_args = parser.parse_known_args(cmd_args)
+ if remaining_args:
+ raise ValueError(f"Some specified arguments are not used by the ArgumentParser: {remaining_args}")
+
+ parse_result = defaultdict(dict)
+ for key, value in vars(args).items():
+ if key in dict_fields:
+ if isinstance(value, str) and value.startswith("{"):
+ value = _convert_str_dict(json.loads(value))
+ else:
+ raise ValueError(f"Expect a json string for dict argument, but got {value}")
+
+ base, name = key.split(".", maxsplit=1)
+ parse_result[base][name] = value
+
+ data_classes = {}
+ for base, subclass_type in base_to_subclass.items():
+ data_classes[base] = subclass_type(**parse_result.get(base, {}))
+
+ return rootclass(**data_classes)
+
+
+def save_args(args: T, output_path: str) -> None:
+ """
+ Saves arguments to a json file.
+
+ Args:
+ args (dataclass): Arguments.
+ output_path (str): Output path.
+ """
+
+ local_dir = output_path
+
+ os.makedirs(local_dir, exist_ok=True)
+ local_path = os.path.join(local_dir, "lingbotvla_cli.yaml")
+ with open(local_path, "w") as f:
+ f.write(yaml.safe_dump(asdict(args), default_flow_style=False))
diff --git a/lingbotvla/utils/count_flops.py b/lingbotvla/utils/count_flops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b96a30c15e6abbf492bd172f740d1649401837d5
--- /dev/null
+++ b/lingbotvla/utils/count_flops.py
@@ -0,0 +1,502 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from transformers import PretrainedConfig
+
+from . import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_device_flops(unit="T"):
+ def unit_convert(number, level):
+ units = ["B", "K", "M", "G", "T", "P"]
+ if number <= 0:
+ return number
+ ptr = 0
+ while ptr < len(units) and units[ptr] != level:
+ number /= 1000
+ ptr += 1
+ return number
+
+ device_name = torch.cuda.get_device_name()
+ flops = float("inf") # INF flops for unkown gpu type
+ if "H100" in device_name or "H800" in device_name or "NVIDIA L20X" in device_name:
+ flops = 989e12
+ elif "A100" in device_name or "A800" in device_name:
+ flops = 312e12
+ elif "L40" in device_name:
+ flops = 181.05e12
+ elif "L20" in device_name:
+ flops = 119.5e12
+ elif "H20" in device_name:
+ flops = 148e12
+ elif "910B" in device_name:
+ flops = 354e12
+ flops_unit = unit_convert(flops, unit)
+ return flops_unit
+
+
+class LingBotFlopsCounter:
+ """
+ Used to count mfu during training loop
+
+ Example:
+ flops_counter = LingBotFlopsCounter(config)
+ flops_achieved, flops_promised = flops_counter.estimate_flops(batch_seqlens, delta_time)
+
+ """
+
+ def __init__(self, config: PretrainedConfig):
+ self.estimate_func = {
+ "qwen2_vl": self._estimate_qwen2_vl_flops,
+ "pi0": self._estimate_qwenpi0_flops, # TODO
+ "deepseek_v3": self._estimate_deepseek_v3_flops,
+ "qwen3_moe": self._estimate_qwen3_moe_flops,
+ "llama": self._estimate_llama_flops,
+ "qwen2": self._estimate_qwen2_flops,
+ }
+ self.config = config
+
+ def _estimate_unknown_flops(self, tokens_sum, batch_seqlens, delta_time, **kwargs):
+ return 0
+
+ def compute_llm_flops(self, hidden_size, vocab_size, num_hidden_layers, num_key_value_heads, num_attention_heads, intermediate_size):
+ head_dim = hidden_size // num_attention_heads
+ q_size = num_attention_heads * head_dim
+ k_size = num_key_value_heads * head_dim
+ v_size = num_key_value_heads * head_dim
+ # non-attn per layer parm
+ mlp_N = hidden_size * intermediate_size * 3
+ attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
+ emd_and_lm_head_N = vocab_size * hidden_size * 2
+ # non-attn all_layer parm
+ dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
+ # non-attn all_layer & all_token fwd & bwd flops
+ model_attn_flops = head_dim * num_attention_heads * num_hidden_layers
+ return dense_N, model_attn_flops
+
+ def _estimate_deepseek_v3_flops(self, tokens_sum, batch_seqlens, delta_time):
+ hidden_size = self.config.hidden_size
+ vocab_size = self.config.vocab_size
+ moe_intermediate_size = self.config.moe_intermediate_size
+ num_hidden_layers = self.config.num_hidden_layers
+ first_k_dense_replace = self.config.first_k_dense_replace
+ num_query_heads = self.config.num_attention_heads
+ moe_num_expert = self.config.n_routed_experts
+ moe_topk = self.config.num_experts_per_tok
+ share_expert_num = self.config.n_shared_experts
+ # non-attn per layer parm
+ moe_gata_N = hidden_size * moe_num_expert
+ # moe has fc1_1, fc1_2 and fc2 using SwiGLU in ExpertMlp layer & shared experts
+ moe_expertmlp_N = hidden_size * moe_intermediate_size * (moe_topk + share_expert_num) * 3
+ # MLA attn
+ attn_linear_N = 0
+ q_head_dim = self.config.qk_nope_head_dim + self.config.qk_rope_head_dim
+ if self.config.q_lora_rank is None:
+ attn_linear_N += hidden_size * num_query_heads * q_head_dim
+ else:
+ attn_linear_N += hidden_size * self.config.q_lora_rank
+ attn_linear_N += num_query_heads * q_head_dim * self.config.q_lora_rank
+ attn_linear_N += hidden_size * (self.config.kv_lora_rank + self.config.qk_rope_head_dim)
+ attn_linear_N += (
+ num_query_heads
+ * (q_head_dim - self.config.qk_rope_head_dim + self.config.v_head_dim)
+ * self.config.kv_lora_rank
+ )
+ attn_linear_N += num_query_heads * self.config.v_head_dim * hidden_size
+ emd_and_lm_head_N = vocab_size * hidden_size * 2
+ # non-attn all_layer parm
+ moe_N = (
+ (moe_gata_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers - first_k_dense_replace)
+ + (hidden_size * self.config.intermediate_size * 3 + attn_linear_N) * first_k_dense_replace
+ + emd_and_lm_head_N
+ )
+ # non-attn all_layer & all_token fwd & bwd flops
+ dense_N_flops = 6 * moe_N * tokens_sum
+ # attn all_layer & all_token fwd & bwd flops
+ seqlen_square_sum = 0
+ for seqlen in batch_seqlens:
+ seqlen_square_sum += seqlen * seqlen * num_hidden_layers
+ attn_qkv_flops = 12 * seqlen_square_sum * q_head_dim * num_query_heads
+ # all_layer & all_token fwd & bwk flops
+ flops_all_token = dense_N_flops + attn_qkv_flops
+ flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
+ return flops_achieved
+
+ def _estimate_qwen3_moe_flops(self, tokens_sum, batch_seqlens, delta_time):
+ hidden_size = self.config.hidden_size
+ vocab_size = self.config.vocab_size
+ moe_intermediate_size = self.config.moe_intermediate_size
+ num_hidden_layers = self.config.num_hidden_layers
+ num_key_value_heads = self.config.num_key_value_heads
+ num_attention_heads = self.config.num_attention_heads
+ moe_intermediate_size = self.config.moe_intermediate_size
+ moe_num_expert = self.config.num_experts
+ moe_topk = self.config.num_experts_per_tok
+
+ head_dim = hidden_size // num_attention_heads
+ q_size = num_attention_heads * head_dim
+ k_size = num_key_value_heads * head_dim
+ v_size = num_key_value_heads * head_dim
+
+ # non-attn per layer parm
+ moe_gata_N = hidden_size * moe_num_expert
+ # moe has gate_proj, up_proj and down_proj using SwiGLU in ExpertMlp layer & shared experts
+ moe_expertmlp_N = hidden_size * moe_intermediate_size * (moe_topk) * 3
+ attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
+ emd_and_lm_head_N = vocab_size * hidden_size * 2
+ # non-attn all_layer parm
+ moe_N = (moe_gata_N + moe_expertmlp_N + attn_linear_N) * (num_hidden_layers) + emd_and_lm_head_N
+ # non-attn all_layer & all_token fwd & bwd flops
+ dense_N_flops = 6 * moe_N * tokens_sum
+
+ # attn all_layer & all_token fwd & bwd flops
+ seqlen_square_sum = 0
+ for seqlen in batch_seqlens:
+ seqlen_square_sum += seqlen * seqlen
+ attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+
+ # all_layer & all_token fwd & bwk flops
+ flops_all_token = dense_N_flops + attn_qkv_flops
+ flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
+ return flops_achieved
+
+ def _estimate_qwen2_flops(self, tokens_sum, batch_seqlens, delta_time):
+ hidden_size = self.config.hidden_size
+ vocab_size = self.config.vocab_size
+ num_hidden_layers = self.config.num_hidden_layers
+ num_key_value_heads = self.config.num_key_value_heads
+ num_attention_heads = self.config.num_attention_heads
+ intermediate_size = self.config.intermediate_size
+
+ head_dim = hidden_size // num_attention_heads
+ q_size = num_attention_heads * head_dim
+ k_size = num_key_value_heads * head_dim
+ v_size = num_key_value_heads * head_dim
+
+ # non-attn per layer parm
+ # llama use SwiGelu, gate, having up and down linear layer in mlp
+ mlp_N = hidden_size * intermediate_size * 3
+ attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
+ emd_and_lm_head_N = vocab_size * hidden_size * 2
+ # non-attn all_layer parm
+ dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
+ # non-attn all_layer & all_token fwd & bwd flops
+ dense_N_flops = 6 * dense_N * tokens_sum
+
+ # attn all_layer & all_token fwd & bwd flops
+ seqlen_square_sum = 0
+ for seqlen in batch_seqlens:
+ seqlen_square_sum += seqlen * seqlen
+ attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+
+ # all_layer & all_token fwd & bwd flops
+ flops_all_token = dense_N_flops + attn_qkv_flops
+ flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
+ return flops_achieved
+
+ def _estimate_llama_flops(self, tokens_sum, batch_seqlens, delta_time):
+ hidden_size = self.config.hidden_size
+ vocab_size = self.config.vocab_size
+ num_hidden_layers = self.config.num_hidden_layers
+ num_key_value_heads = self.config.num_key_value_heads
+ num_attention_heads = self.config.num_attention_heads
+ intermediate_size = self.config.intermediate_size
+
+ head_dim = hidden_size // num_attention_heads
+ q_size = num_attention_heads * head_dim
+ k_size = num_key_value_heads * head_dim
+ v_size = num_key_value_heads * head_dim
+
+ # non-attn per layer parm
+ # llama use SwiGelu, gate, having up and down linear layer in mlp
+ mlp_N = hidden_size * intermediate_size * 3
+ attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
+ emd_and_lm_head_N = vocab_size * hidden_size * 2
+ # non-attn all_layer parm
+ dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
+ # non-attn all_layer & all_token fwd & bwd flops
+ dense_N_flops = 6 * dense_N * tokens_sum
+
+ # attn all_layer & all_token fwd & bwd flops
+ seqlen_square_sum = 0
+ for seqlen in batch_seqlens:
+ seqlen_square_sum += seqlen * seqlen
+ attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+
+ # all_layer & all_token fwd & bwd flops
+ flops_all_token = dense_N_flops + attn_qkv_flops
+ flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
+ return flops_achieved
+
+ def _estimate_pi0_flops(self, tokens_sum, batch_seqlens, delta_time, **kargs):
+ llm_dense_N, llm_model_attn_flops = self.compute_llm_flops(hidden_size = 2048,
+ vocab_size = 257152,
+ num_hidden_layers = 18,
+ num_key_value_heads = 1,
+ num_attention_heads = 8,
+ intermediate_size = 16384,)
+
+ expert_dense_N, expert_model_attn_flops = self.compute_llm_flops(hidden_size = 1024,
+ vocab_size = 0,
+ num_hidden_layers = 18,
+ num_key_value_heads = 1,
+ num_attention_heads = 8,
+ intermediate_size = 4096,)
+ dense_N_flops = 6 * (llm_dense_N + expert_dense_N) * tokens_sum
+ seqlen_square_sum = 0
+ for seqlen in batch_seqlens:
+ seqlen_square_sum += seqlen * seqlen
+ attn_qkv_flops = 12 * seqlen_square_sum * (llm_model_attn_flops + expert_model_attn_flops)
+ # vit flops
+ image_seqlens = kargs.get("image_seqlens", None)
+ if image_seqlens is not None:
+ vit_flops = self.estimate_pi0_vit_flop(image_seqlens)
+ else:
+ vit_flops = 0
+ state_action_seqlens = kargs.get("state_action_seqlens", None)
+ if state_action_seqlens is not None:
+ state_action_dense_N_flops = 6 * (llm_dense_N + expert_dense_N) * sum(state_action_seqlens)
+ state_action_seqlen_square_sum = 0
+ for seqlen in state_action_seqlens:
+ state_action_seqlen_square_sum += seqlen * seqlen
+ state_action_attn_qkv_flops = 12 * state_action_seqlen_square_sum * (llm_model_attn_flops + expert_model_attn_flops)
+ else:
+ state_action_dense_N_flops, state_action_attn_qkv_flops = 0, 0
+ # all_layer & all_token fwd & bwd flops
+ flops_all_token = dense_N_flops + attn_qkv_flops + vit_flops + state_action_dense_N_flops + state_action_attn_qkv_flops
+ flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
+ return flops_achieved
+
+ def _estimate_qwenpi0_flops(self, tokens_sum, batch_seqlens, delta_time, **kargs):
+ llm_dense_N, llm_model_attn_flops = self.compute_llm_flops(hidden_size = 2048,
+ vocab_size = 151936,
+ num_hidden_layers = 36,
+ num_key_value_heads = 2,
+ num_attention_heads = 16,
+ intermediate_size = 11008,)
+
+ expert_dense_N, expert_model_attn_flops = self.compute_llm_flops(hidden_size = 768,
+ vocab_size = 0,
+ num_hidden_layers = 36, # same
+ num_key_value_heads = 2, # same
+ num_attention_heads = 16, # same
+ intermediate_size = 2752,) # /4
+ dense_N_flops = 6 * (llm_dense_N + expert_dense_N) * tokens_sum
+ seqlen_square_sum = 0
+ for seqlen in batch_seqlens:
+ seqlen_square_sum += seqlen * seqlen
+ attn_qkv_flops = 12 * seqlen_square_sum * (llm_model_attn_flops + expert_model_attn_flops)
+ # vit flops
+ image_seqlens = kargs.get("image_seqlens", None)
+ if image_seqlens is not None:
+ vit_flops = self.estimate_qwen2_5vlvit_flop(image_seqlens)
+ else:
+ vit_flops = 0
+ state_action_seqlens = kargs.get("state_action_seqlens", None)
+ if state_action_seqlens is not None:
+ state_action_dense_N_flops = 6 * (llm_dense_N + expert_dense_N) * sum(state_action_seqlens)
+ state_action_seqlen_square_sum = 0
+ for seqlen in state_action_seqlens:
+ state_action_seqlen_square_sum += seqlen * seqlen
+ state_action_attn_qkv_flops = 12 * state_action_seqlen_square_sum * (llm_model_attn_flops + expert_model_attn_flops)
+ else:
+ state_action_dense_N_flops, state_action_attn_qkv_flops = 0, 0
+ # all_layer & all_token fwd & bwd flops
+ flops_all_token = dense_N_flops + attn_qkv_flops + vit_flops + state_action_dense_N_flops + state_action_attn_qkv_flops
+ flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
+ return flops_achieved
+
+ def _estimate_qwen2_vl_flops(self, tokens_sum, batch_seqlens, delta_time, **kargs):
+ hidden_size = self.config.hidden_size
+ vocab_size = self.config.vocab_size
+ num_hidden_layers = self.config.num_hidden_layers
+ num_key_value_heads = self.config.num_key_value_heads
+ num_attention_heads = self.config.num_attention_heads
+ intermediate_size = self.config.intermediate_size
+
+ head_dim = hidden_size // num_attention_heads
+ q_size = num_attention_heads * head_dim
+ k_size = num_key_value_heads * head_dim
+ v_size = num_key_value_heads * head_dim
+
+ # non-attn per layer parm
+ mlp_N = hidden_size * intermediate_size * 3
+ attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
+ emd_and_lm_head_N = vocab_size * hidden_size * 2
+ # non-attn all_layer parm
+ dense_N = (mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
+ # non-attn all_layer & all_token fwd & bwd flops
+ dense_N_flops = 6 * dense_N * tokens_sum
+
+ # attn all_layer & all_token fwd & bwd flops
+ seqlen_square_sum = 0
+ for seqlen in batch_seqlens:
+ seqlen_square_sum += seqlen * seqlen
+ attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+
+ # vit flops
+ image_seqlens = kargs.get("image_seqlens", None)
+ if image_seqlens is not None:
+ vit_flops = self.estimate_vit_flop(image_seqlens, self.config.vision_config)
+ else:
+ vit_flops = 0
+
+ # all_layer & all_token fwd & bwd flops
+ flops_all_token = dense_N_flops + attn_qkv_flops + vit_flops
+ flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
+ return flops_achieved
+
+ def estimate_qwen2_5vlvit_flop(self, image_seqlens):
+ """
+ Estimate the FLOPS of the vision encoder for Qwen2 and Qwen2.5
+ """
+
+ tokens_sum = sum(image_seqlens)
+
+ num_heads = 16
+ depth = 32
+
+ # In Qwen2 VL and Qwen2.5VL, the parameters naming are different:
+ #
+ # Parameter | Qwen2 VL | Qwen2.5 VL
+ # --------------------------|------------------|------------------
+ # ViT hidden dimension | embed_dim | hidden_size
+ # ViT output dimension | hidden_size | out_hidden_size
+ # ViT MLP intermediate dim | embed_dim * mlp_ratio | intermediate_size
+ #
+ # See https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/config.json
+ # and https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/config.json for an example.
+ dim = 1280
+ mlp_hidden_dim = 3420
+ out_hidden_size = 2048
+
+ spatial_merge_size = 2
+ head_dim = dim // num_heads
+
+ # Qwen 2.5 VL uses SiLU, thus 3.
+ mlp_N = dim * mlp_hidden_dim * 3
+ attn_linear_N = dim * (4 * dim) # qkv and output proj
+ patch_embed_and_merger_N = (out_hidden_size + (dim * (spatial_merge_size**2))) * (
+ dim * (spatial_merge_size**2)
+ )
+
+ # non-attn all_layer parm
+ dense_N = (mlp_N + attn_linear_N) * depth + patch_embed_and_merger_N
+
+ # non-attn all_layer & all_token fwd & bwd flops
+ dense_N_flops = 6 * dense_N * tokens_sum
+
+ # In Qwen2.5 VL, windowed attention is used in some layers.
+ full_attn_layer_num = 4
+ window_attn_layer_num = 32 - full_attn_layer_num
+
+ # full attn layer & all_token fwd & bwd flops
+ seqlen_square_sum = 0
+ for seqlen in image_seqlens:
+ seqlen_square_sum += seqlen * seqlen
+ attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_heads * full_attn_layer_num
+
+ # If window attention is used, add the window attention flops
+ if window_attn_layer_num > 0:
+ window_attn_compute_flops = 12 * tokens_sum * (112**2) * head_dim * num_heads
+ attn_qkv_flops += window_attn_compute_flops * window_attn_layer_num
+
+ vit_flops = dense_N_flops + attn_qkv_flops
+
+ return vit_flops
+
+ def estimate_vit_flop(self, image_seqlens, config):
+ if config is None:
+ return 0
+ tokens_sum = sum(image_seqlens)
+
+ num_heads = config.num_heads
+ depth = config.depth
+ dim = config.embed_dim
+ hidden_size = config.hidden_size
+ spatial_merge_size = config.spatial_merge_size
+ head_dim = dim // num_heads
+ mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
+
+ mlp_N = dim * mlp_hidden_dim * 2
+ attn_linear_N = dim * (4 * dim) # qkv and output proj
+ patch_embed_and_merger_N = (hidden_size + (dim * (spatial_merge_size**2))) * (dim * (spatial_merge_size**2))
+
+ # non-attn all_layer parm
+ dense_N = (mlp_N + attn_linear_N) * depth + patch_embed_and_merger_N
+
+ # non-attn all_layer & all_token fwd & bwd flops
+ dense_N_flops = 6 * dense_N * tokens_sum
+
+ # attn all_layer & all_token fwd & bwd flops
+ seqlen_square_sum = 0
+ for seqlen in image_seqlens:
+ seqlen_square_sum += seqlen * seqlen
+ attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_heads * depth
+
+ vit_flops = dense_N_flops + attn_qkv_flops
+
+ return vit_flops
+
+ def estimate_pi0_vit_flop(self, image_seqlens):
+ tokens_sum = sum(image_seqlens)
+
+ num_heads = 16
+ depth = 27
+ dim = 2048
+ head_dim = dim // num_heads
+ mlp_hidden_dim = 4304
+
+ mlp_N = dim * mlp_hidden_dim * 2
+ attn_linear_N = dim * (4 * dim) # qkv and output proj
+ patch_embed_and_merger_N = (dim + dim) * dim
+
+ # non-attn all_layer parm
+ dense_N = (mlp_N + attn_linear_N) * depth + patch_embed_and_merger_N
+
+ # non-attn all_layer & all_token fwd & bwd flops
+ dense_N_flops = 6 * dense_N * tokens_sum
+
+ # attn all_layer & all_token fwd & bwd flops
+ seqlen_square_sum = 0
+ for seqlen in image_seqlens:
+ seqlen_square_sum += seqlen * seqlen
+ attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_heads * depth
+
+ vit_flops = dense_N_flops + attn_qkv_flops
+
+ return vit_flops
+
+ def estimate_flops(self, batch_seqlens, delta_time, **kwargs):
+ """
+ Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
+
+ Args:
+ batch_seqlens (List[int]): A list where each element represents the number of valid tokens in the current batch.
+ delta_time (float): The time taken to process the batch, in seconds.
+
+ Returns:
+ estimated_flops (float): The estimated FLOPS based on the input tokens and time.
+ promised_flops (float): The expected FLOPS of the current device.
+ """
+ tokens_sum = sum(batch_seqlens)
+ func = self.estimate_func.get(self.config.model_type, self._estimate_unknown_flops)
+ estimated_flops = func(tokens_sum, batch_seqlens, delta_time, **kwargs)
+ promised_flops = get_device_flops()
+ return estimated_flops, promised_flops
diff --git a/lingbotvla/utils/dist_utils.py b/lingbotvla/utils/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcd5a15d7f8a20ea63a52d3479786642231b8fc6
--- /dev/null
+++ b/lingbotvla/utils/dist_utils.py
@@ -0,0 +1,94 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Union
+
+import torch
+from torch import distributed as dist
+
+
+if TYPE_CHECKING:
+ from torch.distributed import ProcessGroup
+
+
+def all_gather(tensor: "torch.Tensor", world_size: int) -> "torch.Tensor":
+ """
+ Gathers the tensor from all ranks and concats them along the first dim.
+ """
+ output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device="cuda")
+ dist.all_gather_into_tensor(output_tensor, tensor)
+ return output_tensor.view(-1, *tensor.size()[1:])
+
+
+def all_reduce(
+ data: Union[int, float, List[Union[int, float]], "torch.Tensor"],
+ op: Literal["mean", "sum", "max"] = "mean",
+ group: Optional["ProcessGroup"] = None,
+) -> Union[int, float, List[Union[int, float]]]:
+ """
+ Performs all reduce in the given process group.
+ """
+ if not isinstance(data, torch.Tensor):
+ data = torch.tensor(data, dtype=torch.float, device="cuda")
+
+ reduce_ops = {"mean": dist.ReduceOp.SUM, "sum": dist.ReduceOp.SUM, "max": dist.ReduceOp.MAX}
+ dist.all_reduce(data, op=reduce_ops[op], group=group)
+ if op == "mean": # ReduceOp.AVG is not supported by the NPU backend
+ data /= dist.get_world_size(group=group)
+
+ if data.numel() == 1:
+ return data.item()
+ else:
+ return data.tolist()
+
+
+@contextmanager
+def main_process_first(local_only: bool = True) -> None:
+ """
+ A context manager for torch distributed environment to do something on the main process firstly.
+ """
+ if int(os.getenv("WORLD_SIZE", "1")) > 1:
+ is_main_process = int(os.getenv("LOCAL_RANK")) == 0 if local_only else int(os.getenv("RANK")) == 0
+ try:
+ if not is_main_process:
+ dist.barrier()
+ yield
+ finally:
+ if is_main_process:
+ dist.barrier()
+ else:
+ yield
+
+
+def execute_in_order(task: Callable, *, local_only: bool = True, **kwargs) -> Any:
+ """
+ Executes the task in the order of rank.
+ """
+ world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1") if local_only else os.getenv("WORLD_SIZE", "1"))
+ rank = int(os.getenv("LOCAL_RANK", "1") if local_only else os.getenv("RANK", "1"))
+ if world_size > 1:
+ dist.barrier()
+ for i in range(world_size):
+ if rank == i:
+ result = task(**kwargs)
+ dist.barrier()
+ else:
+ dist.barrier()
+
+ return result
+ else:
+ return task(**kwargs)
diff --git a/lingbotvla/utils/dit_utils.py b/lingbotvla/utils/dit_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5138e05961b4aee296025ba016a59f277aed8b3
--- /dev/null
+++ b/lingbotvla/utils/dit_utils.py
@@ -0,0 +1,238 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+
+from .helper import EnvironMeter as OriginalEnvironMeter
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig
+
+from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
+from torch import distributed as dist
+from transformers.utils.import_utils import is_safetensors_available
+
+from ..models.module_utils import _save_state_dict
+from . import logging
+from .helper import empty_cache, get_dtype_size
+
+
+if is_safetensors_available():
+ pass
+
+
+if TYPE_CHECKING:
+ from transformers import GenerationConfig, PretrainedConfig, PreTrainedTokenizer, ProcessorMixin
+
+ ModelAssets = Union[GenerationConfig, PretrainedConfig, PreTrainedTokenizer, ProcessorMixin]
+
+
+logger = logging.get_logger(__name__)
+
+
+def _compute_wan_seqlens(
+ micro_batch: Dict[str, "torch.Tensor"], rmpad: bool, rmpad_with_pos_ids: bool
+) -> Tuple[List[int], Optional[List[int]]]:
+ """
+ Computes the sequence lengths of the current batch.
+
+ Args:
+ micro_batch (Dict[str, Tensor]): The current batch.
+ rmpad (bool): Whether to remove the padding tokens.
+ rmpad_with_pos_ids (bool): Whether to remove the padding tokens using the position ids.
+ """
+ latent_shape = micro_batch["latents"].shape
+ if len(latent_shape) == 5:
+ B = latent_shape[0]
+ else:
+ B = 1
+ C, T, H, W = latent_shape[-4:]
+ T_out = int((T - 1) / 1 + 1)
+ H_out = int((H - 2) / 2 + 1)
+ W_out = int((W - 2) / 2 + 1)
+ seqlens = B * T_out * H_out * W_out
+ return [seqlens]
+
+
+def _compute_flux_seqlens(micro_batch: Dict[str, "torch.Tensor"]) -> Tuple[List[int], Optional[List[int]]]:
+ """
+ Computes the sequence lengths of the current batch.
+
+ Args:
+ micro_batch (Dict[str, Tensor]): The current batch.
+ """
+ B, C, H, W = micro_batch.shape
+ H_out = int((H - 2) / 2 + 1)
+ W_out = int((W - 2) / 2 + 1)
+ seqlens = B * H_out * W_out
+ return [seqlens]
+
+
+class EnvironMeter(OriginalEnvironMeter):
+ """
+ Computes the metrics about the training efficiency.
+
+ Args:
+ config (PretrainedConfig): The configuration of the model.
+ global_batch_size (int): The global batch size.
+ empty_cache_steps (int, optional): The number of steps to empty the cache. Defaults to 500.
+ """
+
+ def __init__(
+ self,
+ config: "PretrainedConfig",
+ global_batch_size: int,
+ empty_cache_steps: int = 500,
+ ) -> None:
+ super().__init__(config, global_batch_size, empty_cache_steps=empty_cache_steps)
+
+ def add(self, micro_batch: Dict[str, "torch.Tensor"], model_type: Optional[str] = None) -> None:
+ if model_type == "wan":
+ seqlens = _compute_wan_seqlens(micro_batch, self.rmpad, self.rmpad_with_pos_ids)
+ elif model_type == "flux":
+ seqlens = _compute_flux_seqlens(micro_batch)
+ else:
+ raise ValueError(f"model_type {model_type} not supported")
+
+ self.batch_seqlens.extend(seqlens)
+
+
+def _get_shard_info(
+ state_dict: Dict[str, "torch.Tensor"],
+ save_dtype: Optional[Union[str, "torch.dtype"]],
+ shard_size: int,
+ safe_serialization: bool,
+) -> Tuple[bool, int, Dict[str, str]]:
+ """
+ Gets the shard information, should be executed at rank 0.
+ """
+ current_size, total_size = 0, 0
+ current_shard, shard_list = [], []
+ for name, tensor in state_dict.items():
+ if isinstance(save_dtype, str):
+ dtype = getattr(torch, save_dtype)
+ elif isinstance(save_dtype, torch.dtype):
+ dtype = save_dtype
+ else:
+ dtype = tensor.dtype
+ tensor_size = tensor.numel() * get_dtype_size(dtype) # dtensor's numel == tensor's numel
+ if current_size != 0 and current_size + tensor_size > shard_size:
+ total_size += current_size
+ shard_list.append(current_shard)
+ current_size = 0
+ current_shard = []
+
+ current_size += tensor_size
+ current_shard.append(name)
+
+ if current_size != 0:
+ total_size += current_size
+ shard_list.append(current_shard)
+
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+
+ num_shards = len(shard_list)
+ weight_map = OrderedDict()
+ is_sharded = None
+ if num_shards == 1:
+ is_sharded = False
+ for name in shard_list[0]:
+ weight_map[name] = weights_name
+ else:
+ is_sharded = True
+ for shard_idx, shard in enumerate(shard_list):
+ prefix, extension = weights_name.rsplit(".", maxsplit=1)
+ file_name = f"{prefix}-{shard_idx + 1:05d}-of-{num_shards:05d}.{extension}"
+ for name in shard:
+ weight_map[name] = file_name
+
+ return is_sharded, total_size, weight_map
+
+
+@torch.no_grad()
+def save_model_weights(
+ output_dir: Union[str, "os.PathLike"],
+ state_dict: Dict[str, "torch.Tensor"],
+ global_rank: Optional[int] = None,
+ save_dtype: Optional[Union[str, "torch.dtype"]] = "bfloat16",
+ shard_size: int = 5_000_000_000,
+ safe_serialization: bool = True,
+ model_assets: Optional[Sequence["ModelAssets"]] = None,
+) -> None:
+ """
+ Saves full model weights. The model parameters should be either tensor or dtensor.
+
+ If global_rank is given, it will assume it is executed on all ranks.
+ """
+
+ os.makedirs(output_dir, exist_ok=True)
+ is_sharded, total_size, weight_map = _get_shard_info(state_dict, save_dtype, shard_size, safe_serialization)
+ full_state_dict = OrderedDict()
+ prev_file_name = None
+ for name, tensor in state_dict.items():
+ if hasattr(tensor.data, "full_tensor"): # dtensor
+ tensor = tensor.data.full_tensor()
+ else:
+ tensor = tensor.data
+
+ if save_dtype:
+ tensor = tensor.to(dtype=getattr(torch, save_dtype) if isinstance(save_dtype, str) else save_dtype)
+
+ if prev_file_name is not None and weight_map[name] != prev_file_name:
+ if global_rank is None or global_rank == 0:
+ _save_state_dict(full_state_dict, os.path.join(output_dir, prev_file_name), safe_serialization)
+ full_state_dict = OrderedDict()
+
+ empty_cache()
+ if global_rank is not None and dist.is_initialized(): # avoid process hanging
+ torch.cuda.synchronize()
+ dist.barrier()
+
+ if global_rank is None or global_rank == 0:
+ full_state_dict[name] = tensor.detach().cpu()
+
+ prev_file_name = weight_map[name]
+ del tensor
+
+ if global_rank is None or global_rank == 0:
+ if len(full_state_dict):
+ _save_state_dict(full_state_dict, os.path.join(output_dir, prev_file_name), safe_serialization)
+
+ if is_sharded:
+ index = {
+ "metadata": {"total_size": total_size},
+ "weight_map": weight_map,
+ }
+
+ index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+ with open(os.path.join(output_dir, index_file), "w", encoding="utf-8") as f:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ f.write(content)
+
+ logger.info(f"Model weight splits saved in {output_dir}.")
+ else:
+ logger.info(f"Model weights saved at {os.path.join(output_dir, prev_file_name)}.")
+
+ if model_assets is not None:
+ for model_asset in model_assets:
+ if hasattr(model_asset, "save_pretrained"):
+ model_asset.save_pretrained(output_dir)
+ else:
+ logger.warning(f"Model asset {model_asset} should implement `save_pretrained`.")
diff --git a/lingbotvla/utils/ema.py b/lingbotvla/utils/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6018031bd4907628a45ef763fd8f2051870a05
--- /dev/null
+++ b/lingbotvla/utils/ema.py
@@ -0,0 +1,12 @@
+import torch
+import torch.nn as nn
+
+def ema_update(model_dest: nn.Module, model_src: nn.Module, rate):
+ param_dict_src = dict(model_src.named_parameters())
+ for p_name, p_dest in model_dest.named_parameters():
+ # p_src = param_dict_src[p_name].clone()
+ p_src = param_dict_src[p_name]
+ assert p_src is not p_dest
+ assert p_dest.data.dtype == torch.float32
+ p_dest.data.mul_(rate).add_((1 - rate) * p_src.data.float())
+ # p_dest.data.mul_(rate).add_((1 - rate) * p_src.data)
\ No newline at end of file
diff --git a/lingbotvla/utils/helper.py b/lingbotvla/utils/helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..feb3e062dca459b28b9d89a01130476b65ae1ab4
--- /dev/null
+++ b/lingbotvla/utils/helper.py
@@ -0,0 +1,390 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Helper utils"""
+
+import gc
+import logging as builtin_logging
+import os
+import sys
+from functools import lru_cache
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import psutil
+import torch
+import torch.distributed as dist
+import transformers
+from torch import nn
+from transformers import enable_full_determinism
+from transformers import set_seed as set_seed_func
+
+from ..distributed.parallel_state import get_parallel_state
+from . import logging
+from .count_flops import LingBotFlopsCounter
+from .dist_utils import all_reduce
+from .import_utils import is_torch_npu_available
+from .seqlen_pos_transform_utils import culen2len, pos2culen
+
+
+if is_torch_npu_available():
+ import torch_npu # noqa: F401 # type: ignore
+ from torch_npu.contrib import transfer_to_npu # noqa: F401 # type: ignore
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig
+
+
+logger = logging.get_logger(__name__)
+
+CACHE_DIR = os.path.expanduser(os.getenv("CACHE_DIR", os.path.join("~/.cache", "lingbotvla")))
+
+
+def _compute_seqlens(
+ micro_batch: Dict[str, "torch.Tensor"], rmpad: bool, rmpad_with_pos_ids: bool
+) -> Tuple[List[int], Optional[List[int]]]:
+ """
+ Computes the sequence lengths of the current batch.
+
+ Args:
+ micro_batch (Dict[str, Tensor]): The current batch.
+ rmpad (bool): Whether to remove the padding tokens.
+ rmpad_with_pos_ids (bool): Whether to remove the padding tokens using the position ids.
+ """
+ if 'attention_mask' in micro_batch:
+ attention_mask = micro_batch["attention_mask"]
+ if rmpad:
+ seqlens = culen2len(micro_batch["cu_seqlens"]).tolist()
+ seqlens = seqlens[:-1] if (attention_mask == 0).any().item() else seqlens
+ elif rmpad_with_pos_ids:
+ seqlens = culen2len(pos2culen(micro_batch["position_ids"])).tolist()
+ seqlens = seqlens[:-1] if (attention_mask == 0).any().item() else seqlens
+ else:
+ seqlens = attention_mask.sum(-1).tolist()
+ elif 'lang_masks' in micro_batch:
+ attention_mask = micro_batch["lang_masks"]
+ seqlens = attention_mask.sum(-1).tolist()
+
+ return seqlens
+
+
+class EnvironMeter:
+ """
+ Computes the metrics about the training efficiency.
+
+ Args:
+ config (PretrainedConfig): The configuration of the model.
+ global_batch_size (int): The global batch size.
+ rmpad (bool, optional): Whether to remove the padding tokens. Defaults to False.
+ rmpad_with_pos_ids (bool, optional): Whether to remove the padding tokens using the position ids. Defaults to False.
+ empty_cache_steps (int, optional): The number of steps to empty the cache. Defaults to 500.
+ """
+
+ def __init__(
+ self,
+ config: "PretrainedConfig",
+ global_batch_size: int,
+ rmpad: bool = False,
+ rmpad_with_pos_ids: bool = False,
+ empty_cache_steps: int = 500,
+ ) -> None:
+ self.config = config
+ self.global_batch_size = global_batch_size
+ self.rmpad = rmpad
+ self.rmpad_with_pos_ids = rmpad_with_pos_ids
+ self.empty_cache_steps = empty_cache_steps
+ self.world_size = dist.get_world_size()
+ self.consume_tokens = 0
+ self.batch_seqlens = []
+ self.image_seqlens = []
+
+ self.estimate_flops = LingBotFlopsCounter(config).estimate_flops
+
+ def state_dict(self) -> Dict[str, Any]:
+ state_dict = {"consume_tokens": self.consume_tokens}
+
+ return state_dict
+
+ def load_state_dict(self, state_dict: Dict[str, Any]):
+ self.consume_tokens = state_dict["consume_tokens"]
+
+ def add(self, micro_batch: Dict[str, "torch.Tensor"]) -> None:
+ seqlens = _compute_seqlens(micro_batch, self.rmpad, self.rmpad_with_pos_ids)
+ if "image_grid_thw" in micro_batch:
+ image_grid_thw = micro_batch["image_grid_thw"]
+ image_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0])
+ self.image_seqlens.extend(image_seqlens.tolist())
+ elif 'lang_masks' in micro_batch:
+ self.image_seqlens = [768] * micro_batch['images'].size(0)
+ self.state_action_seqlens = [51] * micro_batch['images'].size(0)
+ self.batch_seqlens.extend(seqlens)
+
+ def step(self, delta_time: float, global_step: int) -> Dict[str, Any]:
+ if len(self.image_seqlens) > 0:
+ flops_achieved, flops_promised = self.estimate_flops(
+ self.batch_seqlens, delta_time, image_seqlens=self.image_seqlens
+ )
+ else:
+ flops_achieved, flops_promised = self.estimate_flops(self.batch_seqlens, delta_time)
+
+ flops_achieved, batch_tokens, real_global_batch_size = all_reduce(
+ (flops_achieved, sum(self.batch_seqlens), len(self.batch_seqlens)),
+ op="sum",
+ group=get_parallel_state().dp_group,
+ )
+ flops_promised = flops_promised * self.world_size
+ mfu = flops_achieved / flops_promised
+
+ # calculate average effective len and tokens per second
+ avg_effective_len = batch_tokens / self.global_batch_size
+ avg_sample_seq_len = batch_tokens / real_global_batch_size
+ tokens_per_second = batch_tokens / delta_time
+ self.consume_tokens += batch_tokens
+
+ # cuda memory
+ allocated_memory = torch.cuda.max_memory_allocated()
+ reserved_memory = torch.cuda.max_memory_reserved()
+ num_alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"]
+ allocated_memory, reserved_memory, num_alloc_retries = all_reduce(
+ (allocated_memory, reserved_memory, num_alloc_retries), op="max"
+ )
+
+ # cpu memory
+ cpu_memory_info = psutil.virtual_memory()
+
+ metrics = {
+ "flops_achieved(T)": flops_achieved,
+ "flops_promised(T)": flops_promised,
+ "mfu": mfu,
+ "training/avg_effective_len": avg_effective_len,
+ "training/avg_sample_seq_len": avg_sample_seq_len,
+ "tokens_per_second(M)": tokens_per_second / 1e6,
+ "consume_tokens(M)": self.consume_tokens / 1e6,
+ "consume_tokens(B)": self.consume_tokens / 1e9,
+ "max_memory_allocated(GB)": allocated_memory / (1024**3),
+ "max_memory_reserved(GB)": reserved_memory / (1024**3),
+ "cpu_used_memory(GB)": cpu_memory_info.used / (1024**3),
+ "cpu_available_memory(GB)": cpu_memory_info.available / (1024**3),
+ "cpu_memory_usage(%)": cpu_memory_info.percent,
+ "num_alloc_retries": num_alloc_retries,
+ }
+
+ if self.empty_cache_steps > 0 and global_step % self.empty_cache_steps == 0:
+ empty_cache()
+
+ self.batch_seqlens = []
+ self.image_seqlens = []
+
+ return metrics
+
+
+def enable_high_precision_for_bf16():
+ """
+ Set high accumulation dtype for matmul and reduction.
+ """
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+
+
+def set_seed(seed: int, full_determinism: bool = False) -> None:
+ """
+ Sets a manual seed on all devices.
+ """
+ if full_determinism:
+ enable_full_determinism(seed)
+ else:
+ set_seed_func(seed)
+
+
+def create_logger(name: Optional[str] = None) -> "logging._Logger":
+ """
+ Creates a pretty logger for the third-party program.
+ """
+ logger = builtin_logging.getLogger(name)
+ formatter = builtin_logging.Formatter(
+ fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+ )
+ handler = builtin_logging.StreamHandler(sys.stdout)
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+ logger.setLevel(builtin_logging.INFO)
+ logger.propagate = False
+ return logger
+
+
+def enable_third_party_logging() -> None:
+ """
+ Enables explicit logger of the third-party libraries.
+ """
+ transformers.logging.set_verbosity_info()
+ transformers.logging.enable_default_handler()
+ transformers.logging.enable_explicit_format()
+
+
+def print_device_mem_info(prompt: str = "VRAM usage") -> None:
+ """
+ Logs VRAM info.
+ """
+ memory_allocated = torch.cuda.memory_allocated() / (1024**3)
+ max_memory_allocated = torch.cuda.max_memory_allocated() / (1024**3)
+ logger.info_rank0(f"{prompt}: cur {memory_allocated:.2f}GB, max {max_memory_allocated:.2f}GB.")
+
+
+def print_cpu_memory_info():
+ cpu_usage = psutil.cpu_percent(interval=1) # 1 秒间隔
+ logger.info_rank0(f"CPU Usage: {cpu_usage}%")
+
+ memory_info = psutil.virtual_memory()
+ logger.info_rank0(f"Total Memory: {memory_info.total / (1024**3):.2f} GB")
+ logger.info_rank0(f"Available Memory: {memory_info.available / (1024**3):.2f} GB")
+ logger.info_rank0(f"Used Memory: {memory_info.used / (1024**3):.2f} GB")
+ logger.info_rank0(f"Memory Usage: {memory_info.percent}%")
+
+
+def empty_cache() -> None:
+ """
+ Collects system memory.
+ """
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+
+def get_cache_dir(path: Optional[str] = None) -> str:
+ """
+ Returns the cache directory for the given path.
+ """
+ if path is None:
+ return CACHE_DIR
+
+ path = os.path.normpath(path)
+ if not os.path.splitext(path)[-1]: # is a dir
+ path = os.path.join(path, "")
+
+ path = os.path.split(os.path.dirname(path))[-1]
+ return os.path.join(CACHE_DIR, path, "") # must endswith os.path.sep
+
+
+@lru_cache
+def get_dtype_size(dtype: "torch.dtype") -> int:
+ """
+ Taken from https://github.com/huggingface/safetensors/blob/v0.4.5/bindings/python/py_src/safetensors/torch.py#L350
+ """
+ _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
+ _float8_e5m2 = getattr(torch, "float8_e5m2", None)
+ _SIZE = {
+ torch.int64: 8,
+ torch.float32: 4,
+ torch.int32: 4,
+ torch.bfloat16: 2,
+ torch.float16: 2,
+ torch.int16: 2,
+ torch.uint8: 1,
+ torch.int8: 1,
+ torch.bool: 1,
+ torch.float64: 8,
+ _float8_e4m3fn: 1,
+ _float8_e5m2: 1,
+ }
+ return _SIZE[dtype]
+
+
+def unwrap_model(model: "nn.Module") -> "nn.Module":
+ """
+ Recursively unwraps a model from potential containers (as used in distributed training).
+
+ Taken from: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py#L4808
+ """
+ if hasattr(model, "module"):
+ return unwrap_model(getattr(model, "module"))
+ else:
+ return model
+
+
+def print_example(example: Dict[str, "torch.Tensor"], rank: int) -> None:
+ """
+ Logs a single example to screen.
+ """
+ for key, value in example.items():
+ if isinstance(value, torch.Tensor):
+ logger.info(f"[rank {rank}]: {key}'s shape: {value.shape}, device: {value.device}, {value}")
+ else:
+ logger.info(f"[rank {rank}]: {key}: {value}")
+ # logger.info(f"[rank {rank}]: {key}'s shape: {value.shape}, device: {value.device},")
+
+
+def dict2device(input_dict: dict):
+ """
+ Move a dict of Tensor to GPUs.
+ """
+ output_dict = {}
+ for k, v in input_dict.items():
+ if isinstance(v, torch.Tensor):
+ output_dict[k] = v.cuda()
+ elif isinstance(v, dict):
+ output_dict[k] = dict2device(v)
+ else:
+ output_dict[k] = v
+ return output_dict
+
+
+def make_list(item):
+ if isinstance(item, List) or isinstance(item, np.ndarray):
+ return item
+ return [item]
+
+
+def create_profiler(
+ start_step: int, end_step: int, trace_dir: str, record_shapes: bool, profile_memory: bool, with_stack: bool
+):
+ """
+ Creates a profiler to record the CPU and CUDA activities. Default export to trace.json.
+ Profile steps in [start_step, end_step).
+
+ Args:
+ start_step (int): The step to start recording.
+ end_step (int): The step to end recording.
+ trace_dir (str): The path to save the profiling result.
+ record_shapes (bool): Whether to record the shapes of the tensors.
+ profile_memory (bool): Whether to profile the memory usage.
+ with_stack (bool): Whether to include the stack trace.
+ """
+
+ def handler_fn(p):
+ torch.profiler.tensorboard_trace_handler(trace_dir)(p)
+ logger.info(f"Profiling result saved at {trace_dir}.")
+
+ activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
+
+ warmup = 0 if start_step == 1 else 1
+ wait = start_step - warmup - 1
+ active = end_step - start_step
+ logger.info(f"build profiler schedule - wait: {wait}, warmup: {warmup}, active: {active}.")
+ profiler = torch.profiler.profile(
+ activities=activities,
+ schedule=torch.profiler.schedule(
+ wait=wait,
+ warmup=warmup,
+ active=active,
+ repeat=1,
+ ),
+ on_trace_ready=handler_fn,
+ record_shapes=record_shapes,
+ profile_memory=profile_memory,
+ with_modules=True,
+ with_stack=with_stack,
+ )
+ return profiler
diff --git a/lingbotvla/utils/import_utils.py b/lingbotvla/utils/import_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e97cd91ce7179fcd8510c02a6e3ebc2ac54c153
--- /dev/null
+++ b/lingbotvla/utils/import_utils.py
@@ -0,0 +1,101 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Import utils"""
+
+import importlib.metadata
+import importlib.util
+from functools import lru_cache
+from typing import TYPE_CHECKING, Dict
+
+from packaging import version
+
+
+if TYPE_CHECKING:
+ from packaging.version import Version
+
+
+def _is_package_available(name: str) -> bool:
+ return importlib.util.find_spec(name) is not None
+
+
+def _get_package_version(name: str) -> "Version":
+ try:
+ return version.parse(importlib.metadata.version(name))
+ except Exception:
+ return version.parse("0.0.0")
+
+
+_PACKAGE_FLAGS: Dict[str, bool] = {
+ "flash_attn": _is_package_available("flash_attn"),
+ "liger_kernel": _is_package_available("liger_kernel"),
+ "torch_npu": _is_package_available("torch_npu"),
+ "vescale": _is_package_available("vescale"),
+ "seed_kernels": _is_package_available("seed_kernels"),
+ "bytecheckpoint": _is_package_available("bytecheckpoint"),
+ "diffusers": _is_package_available("diffusers"),
+ "av": _is_package_available("av"),
+ "librosa": _is_package_available("librosa"),
+ "soundfile": _is_package_available("soundfile"),
+ "triton": _is_package_available("triton"),
+}
+
+
+def is_flash_attn_2_available() -> bool:
+ return _PACKAGE_FLAGS["flash_attn"]
+
+
+def is_liger_kernel_available() -> bool:
+ return _PACKAGE_FLAGS["liger_kernel"]
+
+
+def is_torch_npu_available() -> bool:
+ return _PACKAGE_FLAGS["torch_npu"]
+
+
+def is_vescale_available() -> bool:
+ return _PACKAGE_FLAGS["vescale"]
+
+
+def is_seed_kernels_available() -> bool:
+ return _PACKAGE_FLAGS["seed_kernels"]
+
+
+def is_bytecheckpoint_available() -> bool:
+ return _PACKAGE_FLAGS["bytecheckpoint"]
+
+
+def is_diffusers_available() -> bool:
+ return _PACKAGE_FLAGS["diffusers"]
+
+
+def is_fused_moe_available() -> bool:
+ import torch
+
+ return torch.cuda.is_available() and _PACKAGE_FLAGS["triton"]
+
+
+def is_video_audio_available() -> bool:
+ return _PACKAGE_FLAGS["av"] and _PACKAGE_FLAGS["librosa"] and _PACKAGE_FLAGS["soundfile"]
+
+
+@lru_cache
+def is_torch_version_greater_than(value: str) -> bool:
+ return _get_package_version("torch") >= version.parse(value)
+
+
+@lru_cache
+def is_transformers_version_greater_or_equal_to(value: str) -> bool:
+ return _get_package_version("transformers") > version.parse(value)
diff --git a/lingbotvla/utils/logging.py b/lingbotvla/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..c972064b9213ac583e68a7ad1e4e26513084db74
--- /dev/null
+++ b/lingbotvla/utils/logging.py
@@ -0,0 +1,142 @@
+# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team. and Bytedance Ltd. and/or its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Logging utils"""
+# Based on: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
+
+import logging
+import os
+import sys
+import threading
+from functools import lru_cache
+from typing import Optional
+
+
+_thread_lock = threading.RLock()
+_default_handler: Optional["logging.Handler"] = None
+_default_log_level: "logging._Level" = logging.INFO
+
+
+class _Logger(logging.Logger):
+ """
+ A logger that supports info_rank0.
+ """
+
+ def info_rank0(self, msg: str) -> None:
+ self.info(msg)
+
+ def warning_rank0(self, msg: str) -> None:
+ self.warning(msg)
+
+ def warning_once(self, msg: str) -> None:
+ self.warning_once(msg)
+
+ def debug_rank0(self, msg: str) -> None:
+ self.debug(msg)
+
+
+def _get_default_logging_level() -> "logging._Level":
+ global _default_log_level
+
+ env_lever_str = os.getenv("VEOMNI_VERBOSITY", None)
+ if env_lever_str:
+ if env_lever_str.upper() in logging._nameToLevel:
+ return logging._nameToLevel[env_lever_str.upper()]
+ else:
+ raise ValueError(f"Unknown verbosity: {env_lever_str}")
+
+ return _default_log_level
+
+
+def _get_library_name() -> str:
+ return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> "logging.Logger":
+ return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+ """
+ Configures root logger using a stdout stream handler with an explicit format.
+ """
+ global _default_handler
+
+ with _thread_lock:
+ if _default_handler:
+ return
+
+ formatter = logging.Formatter(
+ fmt="[%(levelname)s][%(name)s:%(lineno)s] %(asctime)s >> %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ )
+ _default_handler = logging.StreamHandler(sys.stdout)
+ _default_handler.setFormatter(formatter)
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_get_default_logging_level())
+ library_root_logger.propagate = False
+
+
+def get_logger(name: Optional[str] = None) -> "_Logger":
+ """
+ Returns a logger with the specified name. It is not supposed to be accessed by external scripts.
+ """
+ if name is None:
+ name = _get_library_name()
+
+ _configure_library_root_logger()
+ return logging.getLogger(name)
+
+
+def set_verbosity_info() -> None:
+ """
+ Sets the verbosity to the `INFO` level.
+ """
+ _configure_library_root_logger()
+ _get_library_root_logger().setLevel(logging.INFO)
+
+
+def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+ if int(os.getenv("LOCAL_RANK", "0")) == 0:
+ self.info(*args, **kwargs)
+
+
+logging.Logger.info_rank0 = info_rank0
+
+
+def debug_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+ if int(os.getenv("LOCAL_RANK", "0")) == 0:
+ self.debug(*args, **kwargs)
+
+
+logging.Logger.debug_rank0 = debug_rank0
+
+
+def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+ if int(os.getenv("LOCAL_RANK", "0")) == 0:
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_rank0 = warning_rank0
+
+
+@lru_cache(None)
+def warning_once(self, *args, **kwargs) -> None:
+ if int(os.getenv("LOCAL_RANK", "0")) == 0:
+ self.warning_rank0(*args, **kwargs)
+
+
+logging.Logger.warning_once = warning_once
diff --git a/lingbotvla/utils/lora_utils.py b/lingbotvla/utils/lora_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac16fdc2a2c55be3cc857b986d70df0cd4c3488b
--- /dev/null
+++ b/lingbotvla/utils/lora_utils.py
@@ -0,0 +1,99 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from peft import LoraConfig, inject_adapter_in_model
+from safetensors import safe_open
+
+
+def freeze_parameters(model: nn.Module):
+ # Freeze parameters
+ model.requires_grad_(False)
+ model.eval()
+ model.train()
+
+
+def add_lora_to_model(
+ model: nn.Module,
+ lora_rank=4,
+ lora_alpha=4,
+ lora_target_modules="q,k,v,o,ffn.0,ffn.2",
+ init_lora_weights="kaiming",
+ pretrained_lora_path=None,
+ state_dict_converter=None,
+ lora_target_modules_support=None,
+):
+ model.lora_alpha = lora_alpha
+ if init_lora_weights == "kaiming":
+ init_lora_weights = True
+
+ lora_config = LoraConfig(
+ r=lora_rank,
+ lora_alpha=lora_alpha,
+ init_lora_weights=init_lora_weights,
+ target_modules=lora_target_modules.split(","),
+ )
+
+ for lora_target_module in lora_config.target_modules:
+ if lora_target_module not in lora_target_modules_support:
+ raise ValueError(f"lora_target_module {lora_target_module} not in lora_target_modules_support")
+
+ model = inject_adapter_in_model(lora_config, model)
+ for param in model.parameters():
+ if param.requires_grad:
+ param.data = param.to(torch.float32)
+
+ for name, param in model.named_parameters():
+ if "lora" in name:
+ param.data = param.data.to(dtype=torch.float32)
+
+ # Lora pretrained lora weights
+ if pretrained_lora_path is not None:
+ state_dict = load_state_dict(pretrained_lora_path)
+ if state_dict_converter is not None:
+ state_dict = state_dict_converter(state_dict)
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+ all_keys = [i for i, _ in model.named_parameters()]
+ num_updated_keys = len(all_keys) - len(missing_keys)
+ num_unexpected_keys = len(unexpected_keys)
+ print(
+ f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected."
+ )
+
+
+def load_state_dict(file_path, torch_dtype=None):
+ if file_path.endswith(".safetensors"):
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
+ else:
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
+
+
+def load_state_dict_from_safetensors(file_path, torch_dtype=None):
+ state_dict = {}
+ with safe_open(file_path, framework="pt", device="cpu") as f:
+ for k in f.keys():
+ state_dict[k] = f.get_tensor(k)
+ if torch_dtype is not None:
+ state_dict[k] = state_dict[k].to(torch_dtype)
+ return state_dict
+
+
+def load_state_dict_from_bin(file_path, torch_dtype=None):
+ state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
+ if torch_dtype is not None:
+ for i in state_dict:
+ if isinstance(state_dict[i], torch.Tensor):
+ state_dict[i] = state_dict[i].to(torch_dtype)
+ return state_dict
diff --git a/lingbotvla/utils/model_utils.py b/lingbotvla/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2084ef3818cd5e18eec7d3c6a10bd16e29d35580
--- /dev/null
+++ b/lingbotvla/utils/model_utils.py
@@ -0,0 +1,66 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+import torch.nn as nn
+
+from . import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def pretty_print_trainable_parameters(model: nn.Module):
+ trainable_parameters = []
+ for n, p in model.named_parameters():
+ if p.requires_grad:
+ trainable_parameters.append(n)
+
+ printable_results = {}
+ for p in trainable_parameters:
+ param_split = p.split(".")
+ param_name = ""
+ digit_index = 0
+ layer_index_list = []
+ for split_item in param_split:
+ if split_item.isdigit():
+ param_name += f"<{digit_index}>."
+ layer_index_list.append(int(split_item))
+ digit_index += 1
+ else:
+ param_name += f"{split_item}."
+ param_name = param_name[:-1]
+
+ if param_name not in printable_results:
+ printable_results[param_name] = []
+ printable_results[param_name].append(layer_index_list)
+
+ train_param_info = "\n**** trainable parameters ****"
+ for param_key in printable_results.keys():
+ layer_idxs = np.array(printable_results[param_key])
+ if layer_idxs.shape[-1] == 0:
+ train_param_info += "\n" + param_key
+ continue
+ layer_min = layer_idxs.min(axis=0)
+ layer_max = layer_idxs.max(axis=0)
+ print_pattern = param_key
+ for index in range(len(layer_min)):
+ if layer_min[index] == layer_max[index]:
+ print_pattern = print_pattern.replace(f"<{index}>", f"[{layer_min[index]}]")
+ else:
+ print_pattern = print_pattern.replace(f"<{index}>", f"[{layer_min[index]}-{layer_max[index]}]")
+ train_param_info += "\n" + print_pattern
+ train_param_info += "\n**** trainable parameters ****"
+ logger.info_rank0(train_param_info)
diff --git a/lingbotvla/utils/normalize.py b/lingbotvla/utils/normalize.py
new file mode 100644
index 0000000000000000000000000000000000000000..e998dac62e551d645ca7ca100694cb563854849a
--- /dev/null
+++ b/lingbotvla/utils/normalize.py
@@ -0,0 +1,163 @@
+import json
+import pathlib
+
+import numpy as np
+import numpydantic
+import pydantic
+
+
+@pydantic.dataclasses.dataclass
+class NormStats:
+ mean: numpydantic.NDArray
+ std: numpydantic.NDArray
+ q01: numpydantic.NDArray | None = None # 1st quantile
+ q99: numpydantic.NDArray | None = None # 99th quantile
+ q02: numpydantic.NDArray | None = None # 2nd quantile
+ q98: numpydantic.NDArray | None = None # 98th quantile
+
+
+class RunningStats:
+ """Compute running statistics of a batch of vectors."""
+
+ def __init__(self):
+ self._count = 0
+ self._mean = None
+ self._mean_of_squares = None
+ self._min = None
+ self._max = None
+ self._histograms = None
+ self._bin_edges = None
+ self._num_quantile_bins = 5000 # for computing quantiles on the fly
+
+ def update(self, batch: np.ndarray) -> None:
+ """
+ Update the running statistics with a batch of vectors.
+
+ Args:
+ vectors (np.ndarray): A 2D array where each row is a new vector.
+ """
+ if batch.ndim == 1:
+ batch = batch.reshape(-1, 1)
+
+ num_elements, vector_length = batch.shape
+
+ if self._count == 0:
+ self._mean = np.mean(batch, axis=0)
+ self._mean_of_squares = np.mean(batch**2, axis=0)
+ self._min = np.min(batch, axis=0)
+ self._max = np.max(batch, axis=0)
+ self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
+ self._bin_edges = [
+ np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
+ for i in range(vector_length)
+ ]
+ else:
+ if vector_length != self._mean.size:
+ raise ValueError("The length of new vectors does not match the initialized vector length.")
+ new_max = np.max(batch, axis=0)
+ new_min = np.min(batch, axis=0)
+ max_changed = np.any(new_max > self._max)
+ min_changed = np.any(new_min < self._min)
+ self._max = np.maximum(self._max, new_max)
+ self._min = np.minimum(self._min, new_min)
+
+ if max_changed or min_changed:
+ self._adjust_histograms()
+
+ self._count += num_elements
+
+ batch_mean = np.mean(batch, axis=0)
+ batch_mean_of_squares = np.mean(batch**2, axis=0)
+
+ # Update running mean and mean of squares.
+ self._mean += (batch_mean - self._mean) * (num_elements / self._count)
+ self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (num_elements / self._count)
+
+ self._update_histograms(batch)
+
+ def get_statistics(self, chunk_size=None) -> NormStats:
+ """
+ Compute and return the statistics of the vectors processed so far.
+
+ Returns:
+ dict: A dictionary containing the computed statistics.
+ """
+ if self._count < 2:
+ raise ValueError("Cannot compute statistics for less than 2 vectors.")
+
+ variance = self._mean_of_squares - self._mean**2
+ stddev = np.sqrt(np.maximum(0, variance))
+ q01, q99 = self._compute_quantiles([0.01, 0.99])
+ q02, q98 = self._compute_quantiles([0.02, 0.98])
+
+ if chunk_size is not None:
+ assert isinstance(chunk_size, int)
+ self._mean = self._mean.reshape(chunk_size, -1)
+ stddev = stddev.reshape(chunk_size, -1)
+ q01 = q01.reshape(chunk_size, -1)
+ q99 = q99.reshape(chunk_size, -1)
+ q02 = q02.reshape(chunk_size, -1)
+ q98 = q98.reshape(chunk_size, -1)
+
+ return NormStats(mean=self._mean, std=stddev, q01=q01, q99=q99, q02=q02, q98=q98)
+
+ def _adjust_histograms(self):
+ """Adjust histograms when min or max changes."""
+ for i in range(len(self._histograms)):
+ old_edges = self._bin_edges[i]
+ new_edges = np.linspace(self._min[i], self._max[i], self._num_quantile_bins + 1)
+
+ # Redistribute the existing histogram counts to the new bins
+ new_hist, _ = np.histogram(old_edges[:-1], bins=new_edges, weights=self._histograms[i])
+
+ self._histograms[i] = new_hist
+ self._bin_edges[i] = new_edges
+
+ def _update_histograms(self, batch: np.ndarray) -> None:
+ """Update histograms with new vectors."""
+ for i in range(batch.shape[1]):
+ hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
+ self._histograms[i] += hist
+
+ def _compute_quantiles(self, quantiles):
+ """Compute quantiles based on histograms."""
+ results = []
+ for q in quantiles:
+ target_count = q * self._count
+ q_values = []
+ for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
+ cumsum = np.cumsum(hist)
+ idx = np.searchsorted(cumsum, target_count)
+ q_values.append(edges[idx])
+ results.append(np.array(q_values))
+ return results
+
+
+class _NormStatsDict(pydantic.BaseModel):
+ norm_stats: dict[str, NormStats]
+ count: int
+
+
+def serialize_json(norm_stats: dict[str, NormStats], count: int) -> str:
+ """Serialize the running statistics to a JSON string."""
+ return _NormStatsDict(norm_stats=norm_stats, count=count).model_dump_json(indent=2)
+
+
+def deserialize_json(data: str) -> dict[str, NormStats]:
+ """Deserialize the running statistics from a JSON string."""
+ return _NormStatsDict(**json.loads(data)).norm_stats
+
+
+def save(directory: pathlib.Path | str, norm_stats: dict[str, NormStats], count: int) -> None:
+ """Save the normalization stats to a directory."""
+ path = pathlib.Path(directory)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ path.write_text(serialize_json(norm_stats, count))
+
+
+def load(directory: pathlib.Path | str) -> dict[str, NormStats]:
+ """Load the normalization stats from a directory."""
+ path = pathlib.Path(directory) / "norm_stats.json"
+ if not path.exists():
+ raise FileNotFoundError(f"Norm stats file not found at: {path}")
+ return deserialize_json(path.read_text())
diff --git a/lingbotvla/utils/recompute_utils.py b/lingbotvla/utils/recompute_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a0b77606900fb6baebd3735bfff256dc2eff08d
--- /dev/null
+++ b/lingbotvla/utils/recompute_utils.py
@@ -0,0 +1,136 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List
+
+import torch
+
+from lingbotvla.utils import helper
+
+
+logger = helper.create_logger(__name__)
+
+
+def string_to_op(op_string: str) -> Any:
+ """
+ Convert a single operation string to PyTorch operation object
+
+ Args:
+ op_string: e.g. "aten.addmm.default" or "torch.ops.flash_attn._flash_attn_forward.default"
+
+ Returns:
+ PyTorch operation object
+ """
+ global torch
+ # Clean the string
+ clean_string = op_string.strip()
+
+ # Remove torch.ops. prefix (if exists)
+ if clean_string.startswith("torch.ops."):
+ clean_string = clean_string[len("torch.ops.") :]
+
+ # Split path and access level by level
+ parts = clean_string.split(".")
+
+ # Check if torch.ops is available
+ if not hasattr(torch, "ops"):
+ raise AttributeError("torch.ops not available in this PyTorch version")
+
+ current = torch.ops
+
+ # Special handling: ensure accessing aten operations by first trying to trigger registration
+ if parts[0] == "aten":
+ try:
+ # Try to access a basic aten operation to trigger module loading
+ _ = torch.ops.aten.add
+ except AttributeError:
+ # If cannot access aten, may need to import related modules
+ try:
+ import torch._C._dispatch
+ except ImportError:
+ pass
+
+ for i, part in enumerate(parts):
+ if hasattr(current, part):
+ current = getattr(current, part)
+ else:
+ # More detailed error information, including current path
+ current_path = ".".join(parts[:i])
+ available_attrs = dir(current) if hasattr(current, "__dict__") else []
+ raise AttributeError(
+ f"Operation '{op_string}' not found. "
+ f"Missing attribute: '{part}' at path 'torch.ops.{current_path}'. "
+ f"Available attributes: {available_attrs[:10]}{'...' if len(available_attrs) > 10 else ''}"
+ )
+
+ return current
+
+
+def convert_ops_to_objects(ops_strings: List[str]) -> List[Any]:
+ """
+ Convert operation string list to operation object list
+ Args:
+ ops_strings: String list
+
+ Returns:
+ PyTorch operation object list
+ """
+ ops_objects = []
+ failed_ops = []
+
+ # First perform environment check
+ _check_torch_ops_availability()
+
+ for op_str in ops_strings:
+ try:
+ op_obj = string_to_op(op_str)
+ ops_objects.append(op_obj)
+ logger.info_rank0(f"✓ Conversion successful: {op_str}")
+ assert isinstance(op_obj, torch._ops.OpOverload), "Please check if the ops is end with .default"
+ except (AttributeError, TypeError) as e:
+ logger.info_rank0(f"✗ Conversion failed: {op_str} - {e}")
+ failed_ops.append(op_str)
+ except Exception as e:
+ logger.info_rank0(f"✗ Conversion failed: {op_str} - {e}")
+ raise e
+
+ if failed_ops:
+ logger.info_rank0(f"\nWarning: {len(failed_ops)} operations failed to convert")
+ logger.info_rank0("Possible reasons:")
+ logger.info_rank0("1. PyTorch version does not support certain operations")
+ logger.info_rank0("2. Missing related extension modules (e.g. flash_attn)")
+ logger.info_rank0("3. Operation name spelling error")
+
+ return ops_objects
+
+
+def _check_torch_ops_availability():
+ global torch
+ # Check if torch.ops is available
+ if not hasattr(torch, "ops"):
+ raise RuntimeError("torch.ops is not available in current PyTorch version")
+
+ # Check basic aten operations
+ try:
+ _ = torch.ops.aten.add
+ logger.info_rank0("✓ torch.ops.aten available")
+ except AttributeError as e:
+ logger.info_rank0(f"✗ torch.ops.aten not available: {e}")
+ logger.info_rank0("Trying to import necessary modules...")
+ try:
+ import torch._C._dispatch
+
+ logger.info_rank0("✓ Successfully imported torch._C._dispatch")
+ except ImportError as e:
+ logger.info_rank0(f"✗ Cannot import torch._C._dispatch: {e}")
diff --git a/lingbotvla/utils/seqlen_pos_transform_utils.py b/lingbotvla/utils/seqlen_pos_transform_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e43919cbcf3f63ed04de1a7248126109a890440
--- /dev/null
+++ b/lingbotvla/utils/seqlen_pos_transform_utils.py
@@ -0,0 +1,54 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+import torch.nn.functional as F
+
+
+def len2culen(seqlens: "torch.Tensor") -> "torch.Tensor":
+ """
+ Converts the sequence lengths to cumulative sequence lengths.
+
+ NOTE: flash attention only accepts int32 cu_seqlens.
+ """
+ return F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).type(torch.int32)
+
+
+def culen2len(cu_seqlens: "torch.Tensor") -> "torch.Tensor":
+ """
+ Converts the cumulative sequence lengths to sequence lengths.
+ """
+ return cu_seqlens.diff()
+
+
+def pos2culen(position_ids: "torch.Tensor") -> "torch.Tensor":
+ """
+ Converts the position ids to cumulative sequence lengths.
+ """
+ if position_ids.dim() == 3: # (batch_size, dim, seq_length):
+ position_ids = position_ids[:, 0, :]
+
+ position_ids = position_ids.flatten()
+ indices_q = torch.arange(position_ids.size(0), dtype=torch.int32, device=position_ids.device)
+ return F.pad(indices_q[position_ids == 0], (0, 1), "constant", position_ids.size(0))
+
+
+def culen2pos(cu_seqlens: "torch.Tensor") -> "torch.Tensor":
+ """
+ Converts the cumulative sequence lengths to position ids.
+ """
+ seqlens = culen2len(cu_seqlens).cpu()
+ position_ids = torch.cat([torch.arange(length, dtype=torch.long, device=cu_seqlens.device) for length in seqlens])
+ return position_ids.unsqueeze(0)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..8e40ac010b358cac2859708c4e67e2bb5eaafb06
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,38 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "lingbotvla"
+dynamic = ["dependencies", "optional-dependencies"]
+version = "0.0.1"
+authors = [
+ { name="Robbyant Team", email="lf419501@antgroup.com" },
+]
+description = "LingBot-VLA: A Pragmatic VLA Foundation Model"
+requires-python = ">=3.8"
+license = "Apache-2.0"
+license-files = ["LICENSE"]
+
+[tool.ruff]
+target-version = "py38"
+line-length = 119
+indent-width = 4
+
+[tool.ruff.lint]
+ignore = ["C901", "E501", "E741", "W605", "C408"]
+select = ["C", "E", "F", "I", "W"]
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+known-first-party = ["lingbotvla"]
+known-third-party = ["torch", "transformers", "wandb"]
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+skip-magic-trailing-comma = false
+line-ending = "auto"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1c257706081334e9edffdc299339edd05d6dfacb
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+ipdb
+torchcodec==0.6.0
+pytest
+datasets==3.6.0
+transformers==4.51.3
+numpy==1.26.4
+numpydantic
+tensorboard==2.16.2
+msgpack
+websockets
+matplotlib
diff --git a/scripts/compute_norm_robotwin_5.py b/scripts/compute_norm_robotwin_5.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9f1db13e093525a450d0880a1f838b8d7829fd5
--- /dev/null
+++ b/scripts/compute_norm_robotwin_5.py
@@ -0,0 +1,104 @@
+import json
+import numpy as np
+import os
+import re
+import time
+from pathlib import Path
+from dataclasses import asdict, dataclass, field
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+
+import torch
+from tqdm import trange, tqdm
+from torch.utils.data import DataLoader
+from lingbotvla.models import build_processor
+from lingbotvla.utils import helper
+from lingbotvla.utils.arguments import DataArguments, ModelArguments, TrainingArguments, parse_args
+import lingbotvla.utils.normalize as normalize
+from lingbotvla.data.vla_data.base_dataset import VlaDataset
+
+
+if TYPE_CHECKING:
+ from transformers import ProcessorMixin
+
+ from lingbotvla.data.chat_template import ChatTemplate
+
+logger = helper.create_logger(__name__)
+
+
+@dataclass
+class MyDataArguments(DataArguments):
+ norm_path: str = field(
+ default=None,
+ metadata={"help": "Path to save norm stats."},
+ )
+ chunk_size: int = field(
+ default=50,
+ metadata={"help": "Chunk size of action."},
+ )
+
+
+@dataclass
+class Arguments:
+ model: "ModelArguments" = field(default_factory=ModelArguments)
+ data: "MyDataArguments" = field(default_factory=MyDataArguments)
+ train: "TrainingArguments" = field(default_factory=TrainingArguments)
+
+def compute_norm(dataset, task_id, batch_size, stats, state_norm_keys, acton_norm_keys, delta_norm):
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=16, shuffle=False, drop_last=True)
+ success = True
+ try:
+ for batch in tqdm(data_loader, desc=f"Computing stats of {task_id}"):
+ for key in state_norm_keys:
+ values = np.asarray(batch[key])
+ # values = batch[key]
+ stats[key].update(values.reshape(-1, values.shape[-1]))
+ for key in acton_norm_keys:
+ values = np.asarray(batch[key][:,0]) if not delta_norm[key] else np.asarray(batch[key].reshape(batch[key].shape[0], -1))
+ stats[key].update(values.reshape(-1, values.shape[-1]))
+ except: success = False
+ return success
+
+
+def main():
+ args = parse_args(Arguments)
+ logger.info(f"Process rank: {args.train.global_rank}, world size: {args.train.world_size}")
+ logger.info_rank0(json.dumps(asdict(args), indent=2))
+
+ logger.info_rank0("Prepare data")
+ stats = None
+
+ assert args.data.datasets_type == 'vla'
+ dataset = VlaDataset(repo_id=args.data.train_path, action_name='action')
+
+ state_norm_keys = ['observation.state']
+ acton_norm_keys = ['action']
+ delta_norm = {'action': False} # all action dims do not need to minus state in Robotwin
+ stats = {key: normalize.RunningStats() for key in acton_norm_keys+state_norm_keys}
+
+ chunk_size = args.data.chunk_size
+
+ try:
+ success = compute_norm(dataset, args.data.train_path, args.train.global_batch_size, stats, state_norm_keys, acton_norm_keys, delta_norm)
+ except Exception as e:
+ fail_info = f"{args.data.train_path} {e}"
+ print(fail_info)
+
+
+
+ if success:
+ norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}
+ norm_stats = {}
+ for key, stats in stats.items():
+ if key in delta_norm and delta_norm[key]==True:
+ norm_stats[key] = stats.get_statistics(chunk_size=chunk_size)
+ else:
+ norm_stats[key] = stats.get_statistics()
+
+ output_path = Path(args.data.norm_path)
+ print(f"Writing stats to: {output_path}")
+ normalize.save(output_path, norm_stats, stats._count)
+
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/scripts/download_hf_data.py b/scripts/download_hf_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..941a6513b063310034d8da13bb795e44472ef8d6
--- /dev/null
+++ b/scripts/download_hf_data.py
@@ -0,0 +1,27 @@
+import argparse
+
+from huggingface_hub import snapshot_download
+
+
+"""
+python3 scripts/download_hf_data.py --repo_id HuggingFaceFW/fineweb --local_dir ./fineweb/ --allow_patterns sample/10BT/*
+"""
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--repo_id", type=str, default="HuggingFaceFW/fineweb")
+ parser.add_argument("--local_dir", type=str, default="./fineweb/")
+ parser.add_argument("--allow_patterns", type=str, default=None)
+ args = parser.parse_args()
+
+ repo_id = args.repo_id
+ local_dir = args.local_dir
+ allow_patterns = args.allow_patterns
+
+ folder = snapshot_download(
+ repo_id,
+ repo_type="dataset",
+ local_dir=local_dir,
+ allow_patterns=allow_patterns,
+ )
diff --git a/scripts/download_hf_model.py b/scripts/download_hf_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a11d6c101688511ed5baddfb50b9ed204d5cb194
--- /dev/null
+++ b/scripts/download_hf_model.py
@@ -0,0 +1,26 @@
+import argparse
+import os
+
+from huggingface_hub import snapshot_download
+
+
+"""
+python3 scripts/download_hf_model.py --repo_id deepseek-ai/Janus-1.3B --local_dir Janus-1.3B
+"""
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--repo_id", type=str, default="deepseek-ai/Janus-1.3B")
+ parser.add_argument("--local_dir", type=str, default="./Janus-1.3B")
+ parser.add_argument("--local_dir_use_symlinks", type=bool, default=False)
+ args = parser.parse_args()
+
+ repo_id = args.repo_id
+ local_dir = args.local_dir
+ local_dir_use_symlinks = args.local_dir_use_symlinks
+
+ snapshot_download(
+ repo_id=repo_id,
+ local_dir=os.path.join(local_dir, repo_id.split("/")[1]),
+ local_dir_use_symlinks=local_dir_use_symlinks,
+ )
diff --git a/scripts/mereg_dcp_to_hf.py b/scripts/mereg_dcp_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..8882c1d2ee1056aadae852620493db21cd044fef
--- /dev/null
+++ b/scripts/mereg_dcp_to_hf.py
@@ -0,0 +1,40 @@
+import argparse
+import os
+
+from transformers import AutoConfig, AutoProcessor
+
+from lingbotvla.checkpoint import bytecheckpoint_ckpt_to_state_dict
+from lingbotvla.models import save_model_weights
+from lingbotvla.utils import helper
+
+
+logger = helper.create_logger(__name__)
+
+
+def merge_to_hf_pt(load_dir: str, save_path: str, model_assets_dir: str = None):
+ # save model in huggingface's format
+ state_dict = bytecheckpoint_ckpt_to_state_dict(
+ save_checkpoint_path=load_dir,
+ output_dir=save_path,
+ )
+ if model_assets_dir is not None:
+ config = AutoConfig.from_pretrained(model_assets_dir)
+ processor = AutoProcessor.from_pretrained(model_assets_dir, trust_remote_code=True)
+
+ save_model_weights(save_path, state_dict, model_assets=[config, processor])
+ else:
+ save_model_weights(save_path, state_dict)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--load-dir", type=str, required=True)
+ parser.add_argument("--save-dir", type=str, default=None)
+ parser.add_argument("--model_assets_dir", type=str, default=None)
+ args = parser.parse_args()
+ load_dir = args.load_dir
+ save_dir = os.path.join(load_dir, "hf_ckpt") if args.save_dir is None else args.save_dir
+ model_assets_dir = args.model_assets_dir
+ logger.info(f"Merge Args: {args}")
+ merge_to_hf_pt(load_dir, save_dir, model_assets_dir)
+ logger.info(f"Merge to hf pt success! Save to: {save_dir}")
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..667cf00375369513c31f9e6a040c2808540c1391
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,70 @@
+import importlib.metadata
+import importlib.util
+import os
+import re
+from typing import List
+
+from setuptools import find_packages, setup
+
+
+def _is_package_available(name: str) -> bool:
+ return importlib.util.find_spec(name) is not None
+
+
+def _is_torch_npu_available() -> bool:
+ return _is_package_available("torch_npu")
+
+
+def _is_torch_available() -> bool:
+ return _is_package_available("torch")
+
+
+def _is_torch_cuda_available() -> bool:
+ if _is_torch_available():
+ import torch
+
+ return torch.cuda.is_available()
+ else:
+ return False
+
+
+def get_version() -> str:
+ with open(os.path.join("lingbotvla", "__init__.py"), encoding="utf-8") as f:
+ file_content = f.read()
+ pattern = r"{}\W*=\W*\"([^\"]+)\"".format("__version__")
+ (version,) = re.findall(pattern, file_content)
+ return version
+
+
+def get_requires() -> List[str]:
+ with open("requirements.txt", encoding="utf-8") as f:
+ file_content = f.read()
+ lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
+ return lines
+
+BASE_REQUIRE = [
+ "torchdata>=0.8.0,<1.0",
+ "blobfile>=3.0.0",
+]
+
+def main():
+ # Update install_requires and extras_require
+ install_requires = BASE_REQUIRE
+
+ setup(
+ name="lingbotvla",
+ version=get_version(),
+ python_requires=">=3.8.0",
+ packages=find_packages(exclude=["scripts", "tasks", "tests"]),
+ url="https://www.robbyant.com",
+ license="Apache 2.0",
+ author="Robbyant Team",
+ author_email="lf419501@antgroup.com",
+ description="LingBot-VLA: A Pragmatic VLA Foundation Model",
+ install_requires=install_requires,
+ include_package_data=False,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tasks/vla/train_lingbotvla.py b/tasks/vla/train_lingbotvla.py
new file mode 100644
index 0000000000000000000000000000000000000000..518e7fe5da58378439e1915b97f2b8c23dc67752
--- /dev/null
+++ b/tasks/vla/train_lingbotvla.py
@@ -0,0 +1,822 @@
+import json
+from copy import deepcopy
+import os
+import re
+import time
+from dataclasses import asdict, dataclass, field
+from functools import partial
+from io import BytesIO
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Literal
+from collections import defaultdict
+import numpy as np
+import torch
+import torch.distributed as dist
+import wandb
+from PIL import Image
+from tqdm import trange
+from torch.utils.tensorboard import SummaryWriter
+from lingbotvla.checkpoint import build_checkpointer, ckpt_to_state_dict
+from lingbotvla.data import (
+ VLADataCollatorWithPacking,
+ build_dataloader,
+)
+from lingbotvla.data.vla_data import liberoDataset, RobotwinDataset, CustomizedRobotwinDataset
+from lingbotvla.distributed.offloading import build_activation_offloading_context
+from lingbotvla.distributed.parallel_state import get_parallel_state, init_parallel_state
+from lingbotvla.distributed.torch_parallelize import build_parallelize_model
+from lingbotvla.models import build_foundation_model, build_processor, save_model_assets, save_model_weights, build_tokenizer
+from lingbotvla.optim import build_lr_scheduler, build_optimizer
+from lingbotvla.utils import helper
+from lingbotvla.utils.ema import ema_update
+from lingbotvla.utils.arguments import DataArguments, ModelArguments, TrainingArguments, parse_args, save_args
+from lingbotvla.utils.dist_utils import all_reduce
+
+from lingbotvla.models.vla.vision_models.module_utils import build_depth_model, get_depth_target, log_depth
+
+if TYPE_CHECKING:
+ from transformers import ProcessorMixin
+
+ from lingbotvla.data.chat_template import ChatTemplate
+
+
+logger = helper.create_logger(__name__)
+# try:
+# from aistudio_tracking import training_tracking as wandb
+# except Exception as e:
+# logger.info_rank0(f"Failed to import aistudio_tracking: {repr(e)}.")
+
+def get_param_groups(model: "torch.nn.Module", default_lr: float, vit_lr: float):
+ vit_params, other_params = [], []
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ if "visual" in name:
+ vit_params.append(param)
+ else:
+ other_params.append(param)
+
+ return [{"params": vit_params, "lr": vit_lr}, {"params": other_params, "lr": default_lr}]
+
+@dataclass
+class MyTrainingArguments(TrainingArguments):
+ freeze_vit: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to freeze the vit parameters."},
+ )
+ vit_lr: float = field(
+ default=1e-6,
+ metadata={"help": "Maximum learning rate for vit parameters."},
+ )
+ freeze_vision_encoder: bool = field(
+ default=False,
+ metadata={"help": "Whether or not to freeze the vision encoder in VLA model."},
+ )
+ tokenizer_max_length: int = field(
+ default=48,
+ metadata={"help": "Maximum length of the tokenizer."},
+ )
+ enable_expert_vision: bool = field(
+ default=False,
+ metadata={"help": "Whether to enable expert vision."},
+ )
+ expert_vision_type: str | None = field(
+ default=None,
+ metadata={"help": "Type of expert vision. Currently only support vit."},
+ )
+ expert_vision_path: str | None = field(
+ default=None,
+ metadata={"help": "Path to expert vision model."},
+ )
+ action_dim: int = field(
+ default=7,
+ metadata={"help": "Action dimension."},
+ )
+ max_action_dim: int = field(
+ default=32,
+ metadata={"help": "Action dimension after padding."},
+ )
+ max_state_dim: int = field(
+ default=32,
+ metadata={"help": "State dimension after padding."},
+ )
+ chunk_size: int = field(
+ default=50,
+ metadata={"help": "Chunk size of action."},
+ )
+ vlm_causal: bool = field(
+ default=False,
+ metadata={"help": "Whether to use causal atten for img anb lang tokens in vlm."},
+ )
+ use_ema: bool = field(
+ default=False,
+ metadata={"help": "Whether to use EMA."},
+ )
+ qwenvl_bos: bool = field(
+ default=False,
+ metadata={"help": "Whether to use qwenvl bos."},
+ )
+ ema_rate: float = field(
+ default=0.9999,
+ metadata={"help": "Rate of EMA."},
+ )
+ pre_train: bool = field(
+ default=False,
+ metadata={"help": "Whether to apply pretraining."},
+ )
+ loss_type: str = field(
+ default='fm',
+ metadata={"help": "Which loss to use."},
+ )
+ align_params: Optional[Dict[str, Any]] = field(
+ default_factory=dict,
+ metadata={"help": "The config of vaco"},
+ )
+ use_ki: bool = field(
+ default=False,
+ metadata={"help": "Whether to apply knowledge insulating."},
+ )
+ ignore_depth: bool = field(
+ default=False,
+ metadata={"help": "Whether to ignore depth model in FSDP2."},
+ )
+ my_tokenizer_max_length: int = field(
+ default=72,
+ metadata={"help": ""},
+ )
+ use_subtask: bool = field(
+ default=False,
+ metadata={"help": "Whether to predict subtask from vlm."},
+ )
+ use_state: bool = field(
+ default=False,
+ metadata={"help": "Whether to use stringfy state in prefix."},
+ )
+ use_fast_action: bool = field(
+ default=False,
+ metadata={"help": "Whether to use fast action prediction."},
+ )
+ skip_max_norm: bool = field(
+ default=False,
+ metadata={"help": "Whether to skip batch with too large grad norm."},
+ )
+ decayed_max_grad_norm: float = field(
+ default=1.0,
+ metadata={"help": "Maximum norm for the decayed gradients."},
+ )
+ stable_train_steps: int = field(
+ default=100000,
+ metadata={"help": "Training steps for stable training, after this step, the decayed_max_grad_norm will be applied."},
+ )
+ resume_dataloader_state: bool = field(
+ default=True,
+ metadata={"help": "Whether to resume dataloader."},
+ )
+ norm_qkv: bool = field(
+ default=False,
+ metadata={"help": "Whether to apply RMSNorm for qkv."},
+ )
+ use_prompt: bool = field(
+ default=False,
+ metadata={"help": "Whether to use prompt condition."},
+ )
+ embodiment_name: str = field(
+ default=None,
+ metadata={"help": "Name of the embodiment type."},
+ )
+
+@dataclass
+class MyDataArguments(DataArguments):
+ source_name: str = field(
+ default=None,
+ metadata={"help": "Source name of dataset."},
+ )
+ robot_config_root: str = field(
+ default=None,
+ metadata={"help": "Path to get all robot configs."},
+ )
+ joints: Optional[List[str]] = field(
+ default=None,
+ metadata={"help": "The order of joints and their dim"},
+ )
+ cameras:Optional[List[str]] = field(
+ default=None,
+ metadata={"help": "The order of used images"},
+ )
+ norm_type:Literal["meanstd", "bounds_99", "bounds_98", "bounds_98_woclip", "bounds_99_woclip"] = field(
+ default="bounds_99",
+ metadata={"help": "Type of the normalization."},
+ )
+ img_size: int = field(
+ default=224,
+ metadata={"help": "Size of the image."},
+ )
+ norm_stats_file: str = field(
+ default=None,
+ metadata={"help": "Path to the normalization stats file."},
+ )
+
+
+@dataclass
+class Arguments:
+
+ model: "ModelArguments" = field(default_factory=ModelArguments)
+ data: "MyDataArguments" = field(default_factory=MyDataArguments)
+ train: "MyTrainingArguments" = field(default_factory=MyTrainingArguments)
+
+
+def main():
+ args = parse_args(Arguments)
+ logger.info(f"Process rank: {args.train.global_rank}, world size: {args.train.world_size}")
+ logger.info_rank0(json.dumps(asdict(args), indent=2))
+ torch.cuda.set_device(f"cuda:{args.train.local_rank}")
+ dist.init_process_group(backend="nccl")
+ helper.set_seed(args.train.seed, args.train.enable_full_determinism)
+ if args.train.local_rank == 0:
+ helper.enable_third_party_logging()
+
+ if args.train.global_rank == 0:
+ save_args(args, args.train.output_dir)
+
+ Checkpointer = build_checkpointer(dist_backend=args.train.data_parallel_mode, ckpt_manager=args.train.ckpt_manager)
+
+ init_parallel_state(
+ dp_size=args.train.data_parallel_size,
+ dp_replicate_size=args.train.data_parallel_replicate_size,
+ dp_shard_size=args.train.data_parallel_shard_size,
+ tp_size=args.train.tensor_parallel_size,
+ ep_size=args.train.expert_parallel_size,
+ pp_size=args.train.pipeline_parallel_size,
+ cp_size=args.train.context_parallel_size,
+ ulysses_size=args.train.ulysses_parallel_size,
+ dp_mode=args.train.data_parallel_mode,
+ )
+
+ logger.info_rank0("Prepare model")
+ config_kwargs = {'vlm_repo_id': getattr(args.model, "vlm_repo_id", None)}
+ config_kwargs['action_dim'] = getattr(args.train, "action_dim", 7)
+ config_kwargs['max_action_dim'] = getattr(args.train, "max_action_dim", 32)
+ config_kwargs['max_state_dim'] = getattr(args.train, "max_state_dim", 32)
+ config_kwargs['chunk_size'] = getattr(args.train, "chunk_size", 50)
+ config_kwargs['tokenizer_path'] = getattr(args.model, "tokenizer_path", None)
+ config_kwargs['post_training'] = getattr(args.model, "post_training", False)
+ config_kwargs['incremental_training'] = getattr(args.model, "incremental_training", False)
+ config_kwargs['depth_incremental_training'] = getattr(args.model, "depth_incremental_training", False)
+ config_kwargs['norm_qkv'] = getattr(args.train, "norm_qkv", False)
+ config_kwargs['enable_expert_vision'] = args.train.enable_expert_vision
+ config_kwargs['expert_vision_type'] = getattr(args.train, "expert_vision_type", None)
+ config_kwargs['expert_vision_path'] = getattr(args.train, "expert_vision_path", None)
+ config_kwargs['adanorm_time'] = getattr(args.model, "adanorm_time", False)
+ if not getattr(args.model, "adanorm_time", False):
+ assert not getattr(args.model, "separate_time_proj", False), 'separate_time_proj should be dropped when we do not apply adanorm_time!!'
+ config_kwargs['split_gate_liner'] = getattr(args.model, "split_gate_liner", False)
+ config_kwargs['nosplit_gate_liner'] = getattr(args.model, "nosplit_gate_liner", False)
+ config_kwargs['separate_time_proj'] = getattr(args.model, "separate_time_proj", False)
+ config_kwargs['old_adanorm'] = getattr(args.model, "old_adanorm", False)
+ if getattr(args.model, "old_adanorm", False):
+ assert getattr(args.model, "adanorm_time", False), 'Apply old_adanorm should apply adanorm_time!!'
+ config_kwargs['final_norm_adanorm'] = getattr(args.model, "final_norm_adanorm", False)
+ config_kwargs['loss_type'] = getattr(args.train, "loss_type", 'fm')
+ config_kwargs['align_params'] = getattr(args.train, "align_params", None)
+ if args.train.enable_expert_vision and not args.model.post_training:
+ assert args.train.expert_vision_path is not None, "expert_vision_path is required when enable_expert_vision is True!!!"
+ model = build_foundation_model(
+ config_path=args.model.config_path,
+ weights_path=args.model.model_path,
+ torch_dtype="float32" if args.train.enable_mixed_precision else "bfloat16",
+ init_device=args.train.init_device,
+ freeze_vision_encoder=args.train.freeze_vision_encoder,
+ tokenizer_max_length=args.train.tokenizer_max_length,
+ vocab_size=args.model.vocab_size,
+ use_lm_head=args.model.use_lm_head,
+ force_use_huggingface=args.model.force_use_huggingface,
+ config_kwargs=config_kwargs,
+ )
+ use_depth_align = True if args.train.align_params != {} else False
+ depth_model_type = None
+ if use_depth_align:
+ assert args.model.moge_path is not None and args.model.morgbd_path is not None, 'Depth models need to be loaded when uing LingBot-VLA-Depth!!!'
+ args.train.align_params['visual_dir'] = os.path.join(args.train.output_dir, 'images')
+ args.train.align_params['depth']['moge_path'] = args.model.moge_path
+ args.train.align_params['depth']['morgbd_path'] = args.model.morgbd_path
+ depth_model_type = args.train.align_params['depth']['model_type']
+ moge_model, morgbd_model = build_depth_model(args.train.align_params)
+ if args.train.use_compile:
+ moge_model = torch.compile(moge_model)
+ morgbd_model = torch.compile(morgbd_model)
+ os.makedirs(args.train.align_params['visual_dir'], exist_ok=True)
+ model_config = model.config
+ helper.print_device_mem_info("VRAM usage after building model")
+
+ logger.info_rank0("Prepare data")
+ processor = build_processor(args.model.tokenizer_path) # if use build_processor, tokenizer is processor.tokenizer
+
+ if args.train.rmpad:
+ raise ValueError("Qwen2-VL does not support rmpad. Use `rmpad_with_pos_ids` instead.")
+
+ data_collate_fn = []
+ if args.data.datasets_type == 'vla':
+ data_collate_fn.append(VLADataCollatorWithPacking())
+ else:
+ if args.train.rmpad_with_pos_ids:
+ data_collate_fn.append(OmniDataCollatorWithPacking()) # TODO 8.21
+ else:
+ data_collate_fn.append(OmniDataCollatorWithPadding())
+
+ if args.data.dataloader_type == "native":
+ if args.data.datasets_type == 'vla':
+ logger.info_rank0("Start building VLA dataset")
+ args.data.chunk_size = args.train.chunk_size
+ if args.data.data_name == 'libero':
+ train_dataset = liberoDataset(repo_id=args.data.train_path, config=model.config, tokenizer=processor.tokenizer, data_config=args.data, image_processor=processor.image_processor if 'qwen' in args.model.tokenizer_path.lower() else None,use_depth_align=use_depth_align)
+ elif 'robotwin' in args.data.data_name.lower():
+ train_dataset = RobotwinDataset(repo_id=args.data.train_path, config=model.config, tokenizer=processor.tokenizer, data_config=args.data, image_processor=processor.image_processor if 'qwen' in args.model.tokenizer_path.lower() else None, use_depth_align=use_depth_align)
+ args.train.compute_train_steps(args.data.max_seq_len, args.data.train_size, len(train_dataset))
+
+ train_dataloader = build_dataloader(
+ dataset=train_dataset,
+ micro_batch_size=args.train.micro_batch_size,
+ global_batch_size=args.train.global_batch_size,
+ dataloader_batch_size=args.train.dataloader_batch_size,
+ seed=args.train.seed,
+ collate_fn=data_collate_fn,
+ max_seq_len=args.data.max_seq_len,
+ train_steps=args.train.train_steps,
+ rmpad=args.train.rmpad,
+ rmpad_with_pos_ids=args.train.rmpad_with_pos_ids,
+ bsz_warmup_ratio=args.train.bsz_warmup_ratio,
+ dyn_bsz_margin=args.train.dyn_bsz_margin,
+ dyn_bsz_buffer_size=args.train.dyn_bsz_buffer_size,
+ num_workers=args.data.num_workers,
+ drop_last=args.data.drop_last,
+ pin_memory=args.data.pin_memory,
+ prefetch_factor=args.data.prefetch_factor if args.data.num_workers > 0 else None,
+ )
+ else:
+ raise NotImplementedError(f"Unsupported dataloader type: {args.data.dataloader_type}.")
+
+ fsdp_kwargs = {}
+ if args.train.freeze_vit:
+ model.visual.requires_grad_(False)
+ if args.train.data_parallel_mode == "fsdp1":
+ fsdp_kwargs["use_orig_params"] = True
+
+ if args.train.use_ema:
+ model_ema = deepcopy(model).eval()
+ else:
+ model_ema = None
+
+ model = build_parallelize_model(
+ model,
+ enable_full_shard=args.train.enable_full_shard,
+ enable_mixed_precision=args.train.enable_mixed_precision,
+ enable_fp32=args.train.enable_fp32,
+ enable_gradient_checkpointing=args.train.enable_gradient_checkpointing,
+ init_device=args.train.init_device,
+ enable_fsdp_offload=args.train.enable_fsdp_offload,
+ fsdp_kwargs=fsdp_kwargs,
+ basic_modules=model._no_split_modules if args.train.module_fsdp_enable else None,
+ enable_reentrant=args.train.enable_reentrant,
+ enable_forward_prefetch=args.train.enable_forward_prefetch,
+ fsdp_llm_blocks=False,
+ ignore_norm=False,
+ use_depth_align=use_depth_align,
+ ignore_depth=args.train.ignore_depth,
+ )
+ if model_ema is not None:
+ model_ema = build_parallelize_model(
+ model_ema,
+ enable_full_shard=args.train.enable_full_shard,
+ enable_mixed_precision=args.train.enable_mixed_precision,
+ enable_fp32=args.train.enable_fp32,
+ enable_gradient_checkpointing=args.train.enable_gradient_checkpointing,
+ init_device=args.train.init_device,
+ enable_fsdp_offload=args.train.enable_fsdp_offload,
+ fsdp_kwargs=fsdp_kwargs,
+ basic_modules=model_ema._no_split_modules if args.train.module_fsdp_enable else None,
+ enable_reentrant=args.train.enable_reentrant,
+ enable_forward_prefetch=args.train.enable_forward_prefetch,
+ fsdp_llm_blocks=False,
+ ignore_norm=False,
+ use_depth_align=use_depth_align,
+ ignore_depth=args.train.ignore_depth,
+ )
+ if args.train.use_compile:
+ model = torch.compile(model)
+ if model_ema is not None: model_ema = torch.compile(model_ema)
+
+ if args.train.use_ema:
+ ema_update(model_ema, model, 0)
+
+ optimizer = build_optimizer(
+ model,
+ lr=args.train.lr,
+ weight_decay=args.train.weight_decay,
+ fused=False,
+ optimizer_type=args.train.optimizer,
+ post_training=args.model.post_training,
+ )
+ lr_scheduler = build_lr_scheduler(
+ optimizer,
+ train_steps=args.train.train_steps * args.train.num_train_epochs,
+ lr=args.train.lr,
+ lr_min=args.train.lr_min,
+ lr_decay_style=args.train.lr_decay_style,
+ lr_decay_ratio=args.train.lr_decay_ratio,
+ lr_warmup_ratio=args.train.lr_warmup_ratio,
+ lr_start=args.train.lr_start,
+ )
+
+ if args.train.global_rank == 0:
+ log_dir=f"{args.train.output_dir}/runs/"
+ writer = SummaryWriter(log_dir=log_dir)
+ if args.train.use_wandb:
+ wandb.init(
+ name=args.train.wandb_name,
+ config={**vars(args.model), **vars(args.data), **vars(args.train)}, # flatten dict
+ )
+
+ if args.train.enable_profiling:
+ profiler = helper.create_profiler(
+ start_step=args.train.profile_start_step,
+ end_step=args.train.profile_end_step,
+ trace_dir=args.train.profile_trace_dir,
+ record_shapes=args.train.profile_record_shapes,
+ profile_memory=args.train.profile_profile_memory,
+ with_stack=args.train.profile_with_stack,
+ )
+ profiler.start()
+
+ model_assets = [model_config, processor]
+ save_model_assets(args.train.model_assets_dir, model_assets)
+
+ start_epoch, start_step, global_step = 0, 0, 0
+ save_checkpoint_path = None
+ environ_meter = helper.EnvironMeter(
+ config=model_config,
+ global_batch_size=args.train.global_batch_size,
+ rmpad=args.train.rmpad,
+ rmpad_with_pos_ids=args.train.rmpad_with_pos_ids,
+ empty_cache_steps=args.train.empty_cache_steps,
+ )
+
+ load_checkpoint_path = None
+ candidates = []
+ if args.train.load_checkpoint_path or args.train.enable_resume:
+ if args.train.load_checkpoint_path:
+ load_checkpoint_path = args.train.load_checkpoint_path
+ candidates = [load_checkpoint_path]
+ elif args.train.enable_resume:
+ checkpoint_dir = f'{args.train.output_dir}/checkpoints'
+ if os.path.exists(checkpoint_dir):
+ pattern = re.compile(r"global_step_(\d+)")
+ tmp = []
+ for dirname in os.listdir(checkpoint_dir):
+ match = pattern.fullmatch(dirname)
+ if match:
+ step = int(match.group(1))
+ tmp.append((step, os.path.join(checkpoint_dir, dirname)))
+ tmp.sort(key=lambda x: x[0], reverse=True)
+ candidates = [p for _, p in tmp]
+ if candidates:
+ load_checkpoint_path = candidates[0]
+ else:
+ logger.info_rank0(f"No checkpoints in {args.train.output_dir} now!")
+ if candidates:
+ last_err = None
+ loaded = False
+ for cp in candidates:
+ state = {"model": model, "ema": model_ema, "optimizer": optimizer, "extra_state": {}} # cannot be None
+ try:
+ Checkpointer.load(cp, state)
+ global_step = state["extra_state"]["global_step"]
+ start_epoch = global_step // args.train.train_steps
+ start_step = global_step % args.train.train_steps
+ lr_scheduler.load_state_dict(state["extra_state"]["lr_scheduler"])
+ if start_step > 0 and args.train.resume_dataloader_state:
+ train_dataloader.load_state_dict(state["extra_state"]["train_dataloader"])
+ environ_meter.load_state_dict(state["extra_state"]["environ_meter"])
+ torch.set_rng_state(state["extra_state"]["torch_rng_state"])
+ if start_step == 0: # resume at the end of epoch
+ iter(train_dataloader) # clear resume state and prefetch data
+ dist.barrier()
+ logger.info_rank0(f"Load distributed checkpoint from {cp} successfully!")
+ loaded = True
+ break
+ except Exception as e:
+ last_err = e
+ logger.info_rank0(f"Failed to load checkpoint {cp}: {repr(e)}. Trying older one...")
+ continue
+ if not loaded:
+ logger.info_rank0("Starting training from scratch. No valid checkpoint could be loaded.")
+ else:
+ logger.info_rank0("Starting training from scratch.")
+
+ helper.empty_cache()
+ model_fwd_context, model_bwd_context = build_activation_offloading_context(
+ args.train.enable_activation_offload, args.train.enable_gradient_checkpointing, args.train.activation_gpu_limit
+ )
+ model.train()
+ logger.info(
+ f"rank{args.train.local_rank} Start training, train_steps: {args.train.train_steps}, epochs: {args.train.num_train_epochs}"
+ )
+ if model_ema is not None:
+ model_ema.eval()
+ # create the path in advance to save loss log
+ if args.train.global_rank == 0:
+ os.makedirs(args.train.save_checkpoint_path, exist_ok=True)
+ for epoch in range(start_epoch, args.train.num_train_epochs):
+ if hasattr(train_dataloader, "set_epoch"):
+ train_dataloader.set_epoch(epoch)
+
+ data_loader_tqdm = trange(
+ args.train.train_steps,
+ desc=f"Epoch {epoch + 1}/{args.train.num_train_epochs}",
+ total=args.train.train_steps,
+ initial=start_step,
+ disable=args.train.local_rank != 0,
+ )
+ data_iterator = iter(train_dataloader)
+ for _ in range(start_step, args.train.train_steps):
+ global_step += 1
+ try:
+ micro_batches: List[Dict[str, Any]] = next(data_iterator)
+ except StopIteration:
+ logger.info(f"epoch:{epoch} Dataloader finished with drop_last {args.data.drop_last}")
+ break
+
+ if global_step == 1:
+ helper.print_example(example=micro_batches[0], rank=args.train.local_rank)
+
+ total_loss = 0
+ total_vla_loss = 0
+ total_depth_loss = 0
+ depth_targets = None
+ depth_preds = None
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for micro_batch in micro_batches:
+ dataset_names = micro_batch.pop('rep_id', None)
+ environ_meter.add(micro_batch)
+
+ micro_batch = {
+ k: v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in micro_batch.items()
+ }
+ depth_forward_time = 0
+ if use_depth_align:
+ with torch.no_grad():
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ pil_images = micro_batch.pop('pil_images', None)
+ depth_targets, cls_token = get_depth_target(depth_model_type, (moge_model, morgbd_model), pil_images)
+
+ with model_fwd_context:
+ # torch.cuda.synchronize()
+ loss, vla_loss, depth_loss, loss_log, depth_preds = model(**micro_batch, vlm_causal = args.train.vlm_causal, use_ki = args.train.use_ki, depth_targets=depth_targets)
+ # torch.cuda.synchronize()
+
+ loss = loss / len(micro_batches)
+ vla_loss = vla_loss / len(micro_batches)
+ depth_loss = depth_loss / len(micro_batches)
+
+ with model_bwd_context:
+ loss.backward()
+
+ total_loss += loss.item()
+ total_vla_loss += vla_loss.item()
+ if not (isinstance(depth_loss, int) or isinstance(depth_loss, float)):
+ total_depth_loss += depth_loss.item()
+ del micro_batch
+ if global_step > args.train.stable_train_steps:
+ max_grad_norm = args.train.decayed_max_grad_norm
+ else:
+ max_grad_norm = args.train.max_grad_norm
+ if args.train.data_parallel_mode == "fsdp1":
+ grad_norm = model.clip_grad_norm_(max_grad_norm).item()
+ else:
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm, foreach=True)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+ if hasattr(grad_norm, "full_tensor"):
+ grad_norm = grad_norm.full_tensor().item()
+
+ # collect mean loss across data parallel group
+ total_loss, total_vla_loss, total_depth_loss, grad_norm = all_reduce((total_loss, total_vla_loss, total_depth_loss, grad_norm), group=get_parallel_state().fsdp_group)
+ if model_ema is not None:
+ ema_update(model_ema, model, args.train.ema_rate)
+ torch.cuda.synchronize()
+ delta_time = time.time() - start_time
+ lr = max(lr_scheduler.get_last_lr())
+ data_loader_tqdm.update()
+ logger.info_rank0(
+ f"Step {global_step}/{args.train.train_steps}, "
+ f"Epoch {epoch+1}, "
+ f"Loss {total_loss:.4f}, "
+ f"VLA_Loss {total_vla_loss:.4f}, "
+ f"Depth_Loss {total_depth_loss:.4f}, "
+ f"GradNorm {grad_norm:.4f}, "
+ f"LR {lr:.2e}, "
+ f"StepTime {delta_time:.3f}s, "
+ )
+
+
+ if args.train.global_rank == 0:
+ writer.add_scalar("training/loss", total_loss, global_step)
+ writer.add_scalar("training/vla_loss", total_vla_loss, global_step)
+ writer.add_scalar("training/depth_loss", total_depth_loss, global_step)
+ writer.add_scalar("training/grad_norm", grad_norm, global_step)
+ writer.add_scalar("training/lr", lr, global_step)
+ writer.add_scalar("steptime", delta_time, global_step)
+
+ # Log to wandb
+ if args.train.use_wandb:
+ wandb.log({
+ "training/loss": total_loss,
+ "training/vla_loss": total_vla_loss,
+ "training/depth_loss": total_depth_loss,
+ "training/grad_norm": grad_norm,
+ "training/lr": lr,
+ "steptime": delta_time,
+ "epoch": epoch + 1,
+ }, step=global_step)
+
+ # we only log the last mini batch if grad acc is activated
+ if dataset_names is not None and 'batch_mean_losses' in loss_log:
+ batch_mean_losses = loss_log['batch_mean_losses'] # shape (B,)
+ if hasattr(batch_mean_losses, "detach"):
+ batch_mean_losses = batch_mean_losses.detach().cpu()
+
+ group_losses = defaultdict(list)
+ for name, loss_value in zip(dataset_names, batch_mean_losses):
+ group_losses[name].append(loss_value.item() if hasattr(loss_value, "item") else float(loss_value))
+
+ detailed_loss_dict = {}
+ for name, values in group_losses.items():
+ mean_loss = sum(values) / len(values)
+ writer.add_scalar(f"detailed_loss/{name}", mean_loss, global_step)
+ detailed_loss_dict[f"detailed_loss/{name}"] = mean_loss
+
+ # Log detailed losses to wandb
+ if args.train.use_wandb and detailed_loss_dict:
+ wandb.log(detailed_loss_dict, step=global_step)
+
+ if args.train.enable_profiling and global_step <= args.train.profile_end_step:
+ profiler.step()
+ if global_step == args.train.profile_end_step:
+ profiler.stop()
+ helper.upload_trace(
+ args.train.wandb_project, args.train.wandb_name, args.train.profile_trace_dir
+ )
+
+ loss_record = {
+ "step": global_step,
+ "epoch": epoch + 1,
+ "loss": total_loss,
+ "grad_norm": grad_norm,
+ "lr": lr,
+ "step_time": delta_time
+ }
+ loss_file_path = os.path.join(args.train.save_checkpoint_path, "loss.jsonl")
+ try:
+ with open(loss_file_path, "a", encoding="utf-8") as f:
+ f.write(json.dumps(loss_record, ensure_ascii=False) + "\n")
+ except Exception as e:
+ logger.info_rank0(f"⚠️ Failed to write loss.jsonl: {e}")
+
+ # if use_depth_align:
+ # if global_step % args.train.align_params['visual_steps'] == 0:
+ # with torch.no_grad():
+ # with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ # log_depth(morgbd_model, depth_preds, depth_targets, steps=global_step, config=args.train.align_params, cls_token=cls_token)
+
+ if args.train.save_steps and global_step % args.train.save_steps == 0:
+ helper.empty_cache()
+ save_checkpoint_path = os.path.join(args.train.save_checkpoint_path, f"global_step_{global_step}")
+
+ state = {
+ "model": model,
+ "ema": model_ema,
+ "optimizer": optimizer,
+ "extra_state": {
+ "global_step": global_step,
+ "lr_scheduler": lr_scheduler.state_dict(),
+ "train_dataloader": train_dataloader.state_dict(),
+ "environ_meter": environ_meter.state_dict(),
+ "torch_rng_state": torch.get_rng_state(),
+ },
+ }
+ Checkpointer.save(args.train.save_checkpoint_path, state, global_steps=global_step)
+ dist.barrier()
+ logger.info_rank0(f"Distributed checkpoint saved at {save_checkpoint_path} successfully!")
+ if args.train.global_rank == 0:
+ if args.train.save_hf_weights and save_checkpoint_path is not None:
+ hf_weights_path = os.path.join(save_checkpoint_path, "hf_ckpt")
+ model_state_dict = ckpt_to_state_dict(
+ save_checkpoint_path=save_checkpoint_path,
+ output_dir=args.train.output_dir,
+ ckpt_manager=args.train.ckpt_manager,
+ )
+ if args.train.enable_fp32:
+ save_model_weights(hf_weights_path, model_state_dict, model_assets=model_assets, save_dtype=torch.float32)
+ else:
+ save_model_weights(hf_weights_path, model_state_dict, model_assets=model_assets)
+ logger.info_rank0(f"Huggingface checkpoint saved at {hf_weights_path} successfully!")
+ if "ema" in state and state["ema"] is not None:
+ ema_hf_weights_path = os.path.join(save_checkpoint_path, "ema_hf_ckpt")
+ ema_model_state_dict = ckpt_to_state_dict(
+ save_checkpoint_path=save_checkpoint_path,
+ output_dir=args.train.output_dir,
+ ckpt_manager=args.train.ckpt_manager,
+ ema=True
+ )
+ if args.train.enable_fp32:
+ save_model_weights(ema_hf_weights_path, ema_model_state_dict, model_assets=model_assets, save_dtype=torch.float32)
+ else:
+ save_model_weights(ema_hf_weights_path, ema_model_state_dict, model_assets=model_assets)
+ logger.info_rank0(f"Huggingface EMA checkpoint saved at {ema_hf_weights_path} successfully!")
+
+ data_loader_tqdm.close()
+ start_step = 0
+ helper.print_device_mem_info(f"VRAM usage after epoch {epoch + 1}")
+ if args.train.save_epochs and (epoch + 1) % args.train.save_epochs == 0:
+ helper.empty_cache()
+ save_checkpoint_path = os.path.join(args.train.save_checkpoint_path, f"global_step_{global_step}")
+ state = {
+ "model": model,
+ "ema": model_ema,
+ "optimizer": optimizer,
+ "extra_state": {
+ "global_step": global_step,
+ "lr_scheduler": lr_scheduler.state_dict(),
+ "train_dataloader": train_dataloader.state_dict(),
+ "environ_meter": environ_meter.state_dict(),
+ "torch_rng_state": torch.get_rng_state(),
+ },
+ }
+ Checkpointer.save(args.train.save_checkpoint_path, state, global_steps=global_step)
+ dist.barrier()
+ logger.info_rank0(f"Distributed checkpoint saved at {save_checkpoint_path} successfully!")
+ if args.train.global_rank == 0:
+ if args.train.save_hf_weights and save_checkpoint_path is not None:
+ hf_weights_path = os.path.join(save_checkpoint_path, "hf_ckpt")
+ model_state_dict = ckpt_to_state_dict(
+ save_checkpoint_path=save_checkpoint_path,
+ output_dir=args.train.output_dir,
+ ckpt_manager=args.train.ckpt_manager,
+ )
+ if args.train.enable_fp32:
+ save_model_weights(hf_weights_path, model_state_dict, model_assets=model_assets, save_dtype=torch.float32)
+ else:
+ save_model_weights(hf_weights_path, model_state_dict, model_assets=model_assets)
+ logger.info_rank0(f"Huggingface checkpoint saved at {hf_weights_path} successfully!")
+ if "ema" in state and state["ema"] is not None:
+ ema_hf_weights_path = os.path.join(save_checkpoint_path, "ema_hf_ckpt")
+ ema_model_state_dict = ckpt_to_state_dict(
+ save_checkpoint_path=save_checkpoint_path,
+ output_dir=args.train.output_dir,
+ ckpt_manager=args.train.ckpt_manager,
+ ema=True
+ )
+ if args.train.enable_fp32:
+ save_model_weights(ema_hf_weights_path, ema_model_state_dict, model_assets=model_assets, save_dtype=torch.float32)
+ else:
+ save_model_weights(ema_hf_weights_path, ema_model_state_dict, model_assets=model_assets)
+ logger.info_rank0(f"Huggingface EMA checkpoint saved at {ema_hf_weights_path} successfully!")
+
+ torch.cuda.synchronize()
+ # release memory
+ del optimizer, lr_scheduler
+ helper.empty_cache()
+ # save model in huggingface's format
+ if args.train.global_rank == 0:
+ if args.train.save_hf_weights and save_checkpoint_path is not None:
+ hf_weights_path = os.path.join(save_checkpoint_path, "hf_ckpt")
+ model_state_dict = ckpt_to_state_dict(
+ save_checkpoint_path=save_checkpoint_path,
+ output_dir=args.train.output_dir,
+ ckpt_manager=args.train.ckpt_manager,
+ )
+ if args.train.enable_fp32:
+ save_model_weights(hf_weights_path, model_state_dict, model_assets=model_assets, save_dtype=torch.float32)
+ else:
+ save_model_weights(hf_weights_path, model_state_dict, model_assets=model_assets)
+ logger.info_rank0(f"Huggingface checkpoint saved at {hf_weights_path} successfully!")
+ if "ema" in state and state["ema"] is not None:
+ ema_hf_weights_path = os.path.join(save_checkpoint_path, "ema_hf_ckpt")
+ ema_model_state_dict = ckpt_to_state_dict(
+ save_checkpoint_path=save_checkpoint_path,
+ output_dir=args.train.output_dir,
+ ckpt_manager=args.train.ckpt_manager,
+ ema=True
+ )
+ if args.train.enable_fp32:
+ save_model_weights(ema_hf_weights_path, ema_model_state_dict, model_assets=model_assets, save_dtype=torch.float32)
+ else:
+ save_model_weights(ema_hf_weights_path, ema_model_state_dict, model_assets=model_assets)
+ logger.info_rank0(f"Huggingface EMA checkpoint saved at {ema_hf_weights_path} successfully!")
+
+ dist.barrier()
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2e2a433f3c42dbe77a2f97dd8ef1f7d3525d7ae3
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+set -x
+
+export TOKENIZERS_PARALLELISM=false
+if [ -z "$CUDA_VISIBLE_DEVICES" ]; then
+ NPROC_PER_NODE=$(nvidia-smi -L | wc -l)
+else
+ # 可见 GPU 数量
+ NPROC_PER_NODE=$(echo $CUDA_VISIBLE_DEVICES | tr ',' '\n' | wc -l)
+fi
+echo "Using NPROC_PER_NODE=$NPROC_PER_NODE GPUs"
+NNODES=${NNODES:=1}
+NPROC_PER_NODE=${NPROC_PER_NODE:=$NPROC_PER_NODE}
+NODE_RANK=${NODE_RANK:=0}
+MASTER_ADDR=${MASTER_ADDR:=0.0.0.0}
+MASTER_PORT=${MASTER_PORT:=62500}
+
+
+torchrun --nnodes=$NNODES --nproc-per-node $NPROC_PER_NODE --node-rank $NODE_RANK \
+ --master-addr=$MASTER_ADDR --master-port=$MASTER_PORT $@ 2>&1 | tee log.txt