Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +46 -0
- .gitignore +222 -0
- .gitmodules +6 -0
- .vscode/launch.json +88 -0
- LEGAL.md +7 -0
- LICENSE +202 -0
- Makefile +21 -0
- README.md +330 -0
- assets/LingBot-VLA.pdf +3 -0
- assets/PaliGemmaPI.png +3 -0
- assets/QwenPI.png +3 -0
- assets/QwenPI_PaliGemmaPI.png +3 -0
- assets/Teaser.png +3 -0
- assets/exp-gm-100.png +3 -0
- assets/exp-robotwin.png +3 -0
- assets/norm_stats/libero.json +280 -0
- assets/norm_stats/robotwin_50.json +229 -0
- assets/norm_stats/robotwin_5_customized.json +201 -0
- assets/norm_stats/robotwin_all_new.json +229 -0
- assets/scale_ps.png +3 -0
- assets/scale_sr.png +3 -0
- configs/norm/robotwin_5.yaml +12 -0
- configs/vla/robotwin_load20000h.yaml +42 -0
- configs/vla/robotwin_load20000h_depth.yaml +68 -0
- deploy/__init__.py +0 -0
- deploy/image_tools.py +58 -0
- deploy/lingbot_robotwin_policy.py +506 -0
- deploy/lingbot_robotwin_policy_rep.py +491 -0
- deploy/msgpack_numpy.py +57 -0
- deploy/websocket_client_policy.py +88 -0
- deploy/websocket_policy_server.py +89 -0
- docker/Dockerfile +34 -0
- docs/Makefile +20 -0
- docs/README.md +19 -0
- docs/conf.py +66 -0
- docs/config/config.md +96 -0
- docs/examples/qwen2vl.rst +2 -0
- docs/examples/qwen3_moe.md +125 -0
- docs/index.rst +2 -0
- docs/requirements-docs.txt +9 -0
- docs/start/start.rst +2 -0
- experiment/libero/README.md +18 -0
- experiment/libero/libero/libero_utils.py +112 -0
- experiment/libero/libero/req.txt +6 -0
- experiment/libero/libero/run_libero_eval.py +300 -0
- experiment/libero/robot_utils.py +84 -0
- experiment/robotwin/README.md +85 -0
- lingbotvla/__init__.py +16 -0
- lingbotvla/checkpoint/__init__.py +25 -0
- lingbotvla/checkpoint/checkpointer.py +340 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,49 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/LingBot-VLA.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/PaliGemmaPI.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/QwenPI.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/QwenPI_PaliGemmaPI.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/Teaser.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
assets/exp-gm-100.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
assets/exp-robotwin.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
assets/scale_ps.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
assets/scale_sr.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
lingbotvla/models/vla/vision_models/MoGe/assets/normal_comaprison.jpg filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
lingbotvla/models/vla/vision_models/MoGe/assets/overview_simplified.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
lingbotvla/models/vla/vision_models/MoGe/assets/panorama_pipeline.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
lingbotvla/models/vla/vision_models/MoGe/example_images/01_HouseIndoor.jpg filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
lingbotvla/models/vla/vision_models/MoGe/example_images/02_Office.jpg filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
lingbotvla/models/vla/vision_models/MoGe/example_images/03_Traffic.jpg filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
lingbotvla/models/vla/vision_models/MoGe/example_images/05_Mountain.jpg filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
lingbotvla/models/vla/vision_models/MoGe/example_images/06_MaitreyaBuddha.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
lingbotvla/models/vla/vision_models/MoGe/example_images/07_Breads.jpg filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
lingbotvla/models/vla/vision_models/MoGe/example_images/08_CatGirl.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
lingbotvla/models/vla/vision_models/MoGe/example_images/09_Restaurant.jpg filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
lingbotvla/models/vla/vision_models/MoGe/example_images/10_MedievalVillage.jpg filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
lingbotvla/models/vla/vision_models/MoGe/example_images/panorama/Braunschweig_Panoram.jpg filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/assets/attention/fig-attention-vis.png filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/assets/dataset/diversity_figure.png filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-divided.jpg filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/assets/device/device-full.jpg filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_grasp/fig-grasp-demo.png filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-dynamic-tracking.png filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/assets/downstream_tracking/fig-scene-tracking-crop.png filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/assets/teaser/teaser-crop.png filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/raw_depth.png filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/0/rgb.png filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/raw_depth.png filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/1/rgb.jpg filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/raw_depth.png filter=lfs diff=lfs merge=lfs -text
|
| 71 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/2/rgb.png filter=lfs diff=lfs merge=lfs -text
|
| 72 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/raw_depth.png filter=lfs diff=lfs merge=lfs -text
|
| 73 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/3/rgb.jpg filter=lfs diff=lfs merge=lfs -text
|
| 74 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/raw_depth.png filter=lfs diff=lfs merge=lfs -text
|
| 75 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/4/rgb.png filter=lfs diff=lfs merge=lfs -text
|
| 76 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/raw_depth.png filter=lfs diff=lfs merge=lfs -text
|
| 77 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/5/rgb.png filter=lfs diff=lfs merge=lfs -text
|
| 78 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/6/raw_depth.png filter=lfs diff=lfs merge=lfs -text
|
| 79 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/raw_depth.png filter=lfs diff=lfs merge=lfs -text
|
| 80 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/examples/7/rgb.jpg filter=lfs diff=lfs merge=lfs -text
|
| 81 |
+
lingbotvla/models/vla/vision_models/lingbot-depth/tech-report.pdf filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[codz]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py.cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
#poetry.toml
|
| 110 |
+
|
| 111 |
+
# pdm
|
| 112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 113 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 114 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 115 |
+
#pdm.lock
|
| 116 |
+
#pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# pixi
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 122 |
+
#pixi.lock
|
| 123 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 124 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 125 |
+
.pixi
|
| 126 |
+
|
| 127 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 128 |
+
__pypackages__/
|
| 129 |
+
|
| 130 |
+
# Celery stuff
|
| 131 |
+
celerybeat-schedule
|
| 132 |
+
celerybeat.pid
|
| 133 |
+
|
| 134 |
+
# SageMath parsed files
|
| 135 |
+
*.sage.py
|
| 136 |
+
|
| 137 |
+
# Environments
|
| 138 |
+
.env
|
| 139 |
+
.envrc
|
| 140 |
+
.venv
|
| 141 |
+
env/
|
| 142 |
+
venv/
|
| 143 |
+
ENV/
|
| 144 |
+
env.bak/
|
| 145 |
+
venv.bak/
|
| 146 |
+
|
| 147 |
+
# Spyder project settings
|
| 148 |
+
.spyderproject
|
| 149 |
+
.spyproject
|
| 150 |
+
|
| 151 |
+
# Rope project settings
|
| 152 |
+
.ropeproject
|
| 153 |
+
|
| 154 |
+
# mkdocs documentation
|
| 155 |
+
/site
|
| 156 |
+
|
| 157 |
+
# mypy
|
| 158 |
+
.mypy_cache/
|
| 159 |
+
.dmypy.json
|
| 160 |
+
dmypy.json
|
| 161 |
+
|
| 162 |
+
# Pyre type checker
|
| 163 |
+
.pyre/
|
| 164 |
+
|
| 165 |
+
# pytype static type analyzer
|
| 166 |
+
.pytype/
|
| 167 |
+
|
| 168 |
+
# Cython debug symbols
|
| 169 |
+
cython_debug/
|
| 170 |
+
|
| 171 |
+
# PyCharm
|
| 172 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 173 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 174 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 175 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 176 |
+
#.idea/
|
| 177 |
+
|
| 178 |
+
# Abstra
|
| 179 |
+
# Abstra is an AI-powered process automation framework.
|
| 180 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 181 |
+
# Learn more at https://abstra.io/docs
|
| 182 |
+
.abstra/
|
| 183 |
+
|
| 184 |
+
# Visual Studio Code
|
| 185 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 186 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 187 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 188 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 189 |
+
# .vscode/
|
| 190 |
+
|
| 191 |
+
# Ruff stuff:
|
| 192 |
+
.ruff_cache/
|
| 193 |
+
|
| 194 |
+
# PyPI configuration file
|
| 195 |
+
.pypirc
|
| 196 |
+
|
| 197 |
+
# Cursor
|
| 198 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
| 199 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
| 200 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
| 201 |
+
.cursorignore
|
| 202 |
+
.cursorindexingignore
|
| 203 |
+
|
| 204 |
+
# Marimo
|
| 205 |
+
marimo/_static/
|
| 206 |
+
marimo/_lsp/
|
| 207 |
+
__marimo__/
|
| 208 |
+
|
| 209 |
+
# log
|
| 210 |
+
*log.txt
|
| 211 |
+
ossutil_output/
|
| 212 |
+
.sumi/
|
| 213 |
+
env.sh
|
| 214 |
+
pids_qwenpi.txt
|
| 215 |
+
run.sh
|
| 216 |
+
start_multi_eval.sh
|
| 217 |
+
trash/
|
| 218 |
+
eval.sh
|
| 219 |
+
|
| 220 |
+
# xwc
|
| 221 |
+
output/
|
| 222 |
+
wandb/
|
.gitmodules
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "lingbotvla/models/vla/vision_models/lingbot-depth"]
|
| 2 |
+
path = lingbotvla/models/vla/vision_models/lingbot-depth
|
| 3 |
+
url = https://github.com/Robbyant/lingbot-depth
|
| 4 |
+
[submodule "lingbotvla/models/vla/vision_models/MoGe"]
|
| 5 |
+
path = lingbotvla/models/vla/vision_models/MoGe
|
| 6 |
+
url = https://github.com/microsoft/MoGe.git
|
.vscode/launch.json
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
// Use IntelliSense to learn about possible attributes.
|
| 3 |
+
// Hover to view descriptions of existing attributes.
|
| 4 |
+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
| 5 |
+
"version": "0.2.0",
|
| 6 |
+
"configurations": [
|
| 7 |
+
{
|
| 8 |
+
"name": "deploy lingbotvla (模块方式)",
|
| 9 |
+
"type": "debugpy",
|
| 10 |
+
"request": "launch",
|
| 11 |
+
"module": "deploy.lingbot_robotwin_policy",
|
| 12 |
+
"console": "integratedTerminal",
|
| 13 |
+
"cwd": "${workspaceFolder}",
|
| 14 |
+
"justMyCode": false,
|
| 15 |
+
"args": [
|
| 16 |
+
"--model_path",
|
| 17 |
+
"output/ori_4/checkpoints/global_step_12850/hf_ckpt",
|
| 18 |
+
"--use_length",
|
| 19 |
+
"50",
|
| 20 |
+
"--chunk_ret",
|
| 21 |
+
"true",
|
| 22 |
+
"--debug_infer_once"
|
| 23 |
+
],
|
| 24 |
+
"env": {
|
| 25 |
+
"CUDA_VISIBLE_DEVICES": "0",
|
| 26 |
+
"QWEN25_PATH": "/group/ossdphi_algo_scratch_11/weicxu/huggingface_cache/hub/models--Qwen--Qwen2.5-VL-3B-Instruct/snapshots/66285546d2b821cf421d4f5eb2576359d3770cd3"
|
| 27 |
+
}
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"name": "example_call_robotwin_server",
|
| 31 |
+
"type": "debugpy",
|
| 32 |
+
"request": "launch",
|
| 33 |
+
"module": "deploy.example_call_robotwin_server",
|
| 34 |
+
"console": "integratedTerminal",
|
| 35 |
+
"cwd": "${workspaceFolder}",
|
| 36 |
+
"justMyCode": false,
|
| 37 |
+
"args": [
|
| 38 |
+
"--host",
|
| 39 |
+
"127.0.0.1",
|
| 40 |
+
"--port",
|
| 41 |
+
"8006"
|
| 42 |
+
],
|
| 43 |
+
"env": {
|
| 44 |
+
"CUDA_VISIBLE_DEVICES": "0"
|
| 45 |
+
}
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"name": "train lingbotvla",
|
| 49 |
+
"type": "debugpy",
|
| 50 |
+
"request": "launch",
|
| 51 |
+
"program": "${file}",
|
| 52 |
+
"console": "integratedTerminal",
|
| 53 |
+
"justMyCode": false,
|
| 54 |
+
"args": [
|
| 55 |
+
"configs/vla/robotwin_load20000h.yaml",
|
| 56 |
+
"--model.model_path",
|
| 57 |
+
"robbyant/lingbot-vla-4b",
|
| 58 |
+
"--data.train_path",
|
| 59 |
+
"mixed_robotwin_5tasks_repo_0.1.0",
|
| 60 |
+
"--train.output_dir",
|
| 61 |
+
"output/",
|
| 62 |
+
"--model.tokenizer_path",
|
| 63 |
+
"Qwen/Qwen2.5-VL-3B-Instruct",
|
| 64 |
+
"--train.micro_batch_size",
|
| 65 |
+
"1",
|
| 66 |
+
"--train.global_batch_size",
|
| 67 |
+
"1",
|
| 68 |
+
"--train.enable_full_shard",
|
| 69 |
+
"true",
|
| 70 |
+
"--train.use_compile",
|
| 71 |
+
"false",
|
| 72 |
+
"--train.enable_fp32",
|
| 73 |
+
"false",
|
| 74 |
+
"--train.freeze_vision_encoder",
|
| 75 |
+
"true",
|
| 76 |
+
],
|
| 77 |
+
"env": {
|
| 78 |
+
"CUDA_VISIBLE_DEVICES": "2",
|
| 79 |
+
"LOCAL_RANK": "0",
|
| 80 |
+
"RANK": "0",
|
| 81 |
+
"WORLD_SIZE": "1",
|
| 82 |
+
"MASTER_ADDR": "localhost",
|
| 83 |
+
"MASTER_PORT": "29500",
|
| 84 |
+
"PYDEVD_USE_SYS_MONITORING": "0"
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
]
|
| 88 |
+
}
|
LEGAL.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Legal Disclaimer
|
| 2 |
+
|
| 3 |
+
Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail.
|
| 4 |
+
|
| 5 |
+
法律免责声明
|
| 6 |
+
|
| 7 |
+
关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。
|
LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright [2026] [Robbyant Team]
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
Makefile
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: build commit quality style test
|
| 2 |
+
|
| 3 |
+
check_dirs := tasks tests lingbot docs setup.py
|
| 4 |
+
|
| 5 |
+
build:
|
| 6 |
+
python3 setup.py sdist bdist_wheel
|
| 7 |
+
|
| 8 |
+
commit:
|
| 9 |
+
pre-commit install
|
| 10 |
+
pre-commit run --all-files
|
| 11 |
+
|
| 12 |
+
quality:
|
| 13 |
+
ruff check $(check_dirs)
|
| 14 |
+
ruff format --check $(check_dirs)
|
| 15 |
+
|
| 16 |
+
style:
|
| 17 |
+
ruff check $(check_dirs) --fix
|
| 18 |
+
ruff format $(check_dirs)
|
| 19 |
+
|
| 20 |
+
test:
|
| 21 |
+
pytest tests/
|
README.md
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<h1 align="center">LingBot-VLA: A Pragmatic VLA Foundation Model</h1>
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<a href="assets/LingBot-VLA.pdf"><img src="https://img.shields.io/static/v1?label=Paper&message=PDF&color=red&logo=arxiv"></a>
|
| 5 |
+
<a href="https://technology.robbyant.com/lingbot-vla"><img src="https://img.shields.io/badge/Project-Website-blue"></a>
|
| 6 |
+
<a href="https://huggingface.co/collections/robbyant/lingbot-vla"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Model&message=HuggingFace&color=yellow"></a>
|
| 7 |
+
<a href="https://modelscope.cn/collections/Robbyant/LingBot-VLA"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%96%20Model&message=ModelScope&color=purple"></a>
|
| 8 |
+
<a href="https://huggingface.co/datasets/robbyant/gm100"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20GM-100&message=HuggingFace&color=yellow"></a>
|
| 9 |
+
<a href="LICENSE"><img src="https://img.shields.io/badge/License-Apache--2.0-green"></a>
|
| 10 |
+
</p>
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
<p align="center">
|
| 14 |
+
<img src="assets/Teaser.png" width="100%">
|
| 15 |
+
</p>
|
| 16 |
+
|
| 17 |
+
## 🥳 We are excited to introduce **LingBot-VLA**, a pragmatic Vision-Language-Action foundation model.
|
| 18 |
+
|
| 19 |
+
**LingBot-VLA** has focused on **Pragmatic**:
|
| 20 |
+
- **Large-scale Pre-training Data**: 20,000 hours of real-world
|
| 21 |
+
data from 9 popular dual-arm robot configurations.
|
| 22 |
+
<p align="center">
|
| 23 |
+
<img src="assets/scale_sr.png" width="45%" style="margin: 0 10px;">
|
| 24 |
+
<img src="assets/scale_ps.png" width="45%" style="margin: 0 10px;">
|
| 25 |
+
</p>
|
| 26 |
+
|
| 27 |
+
- **Strong Performance**: Achieve clear superiority over competitors on simulation and real-world benchmarks.
|
| 28 |
+
- **Training Efficiency**: Represent a 1.5 ∼ 2.8× (depending on the relied VLM base model) speedup over existing VLA-oriented codebases.
|
| 29 |
+
|
| 30 |
+
## 🚀 News
|
| 31 |
+
- **[2026-01-27]** LingBot-VLA Technical Report is available on Arxiv.
|
| 32 |
+
- **[2026-01-27]** Weights and code released!
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
## 🛠️ Installation
|
| 39 |
+
Requirements
|
| 40 |
+
- Python 3.12.3
|
| 41 |
+
- Pytorch 2.8.0
|
| 42 |
+
- CUDA 12.8
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
# Install Lerobot
|
| 46 |
+
pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu128
|
| 47 |
+
GIT_LFS_SKIP_SMUDGE=1 git clone https://github.com/huggingface/lerobot.git
|
| 48 |
+
cd lerobot
|
| 49 |
+
git checkout 0cf864870cf29f4738d3ade893e6fd13fbd7cdb5
|
| 50 |
+
pip install -e .
|
| 51 |
+
# Install flash attention
|
| 52 |
+
pip install /path/to/flash_attn-2.8.3+cu12torch2.8cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
|
| 53 |
+
|
| 54 |
+
# Clone the repository
|
| 55 |
+
git clone https://github.com/robbyant/lingbot-vla.git
|
| 56 |
+
cd lingbot-vla/
|
| 57 |
+
git submodule update --remote --recursive
|
| 58 |
+
pip install -e .
|
| 59 |
+
pip install -r requirements.txt
|
| 60 |
+
# Install LingBot-Depth dependency
|
| 61 |
+
cd ./lingbotvla/models/vla/vision_models/lingbot-depth/
|
| 62 |
+
pip install -e . --no-deps
|
| 63 |
+
cd ../MoGe
|
| 64 |
+
pip install -e .
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
---
|
| 68 |
+
|
| 69 |
+
## 📦 Model Download
|
| 70 |
+
We release LingBot-VLA pre-trained weights in two configurations: depth-free version and a depth-distillated version.
|
| 71 |
+
- **Pretrained Checkpoints for Post-Training with and without depth**
|
| 72 |
+
|
| 73 |
+
| Model Name | Huggingface | ModelScope | Description |
|
| 74 |
+
| :--- | :---: | :---: | :---: |
|
| 75 |
+
| LingBot-VLA-4B | [🤗 lingbot-vla-4b](https://huggingface.co/robbyant/lingbot-vla-4b) | [🤖 lingbot-vla-4b](https://modelscope.cn/models/Robbyant/lingbot-vla-4b) | LingBot-VLA *w/o* Depth|
|
| 76 |
+
| LingBot-VLA-4B-Depth | [🤗 lingbot-vla-4b-depth](https://huggingface.co/robbyant/lingbot-vla-4b-depth) | [🤖 lingbot-vla-4b-depth](https://modelscope.cn/models/Robbyant/lingbot-vla-4b-depth) | LingBot-VLA *w/* Depth |
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
To train LingBot with our codebase, weights from [Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct), [MoGe-2-vitb-normal](https://huggingface.co/Ruicheng/moge-2-vitb-normal), and [LingBot-Depth](https://huggingface.co/robbyant/lingbot-depth-pretrain-vitl-14) also need to be prepared.
|
| 82 |
+
- **Run Command**:
|
| 83 |
+
```bash
|
| 84 |
+
python3 scripts/download_hf_model.py --repo_id robbyant/lingbot-vla-4b --local_dir lingbot-vla-4b
|
| 85 |
+
```
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
## 💻 Post-Training Example
|
| 89 |
+
|
| 90 |
+
- **Data Preparation**:
|
| 91 |
+
Please follow [RoboTwin2.0 Preparation](experiment/robotwin/README.md)
|
| 92 |
+
|
| 93 |
+
- **Training Configuration**:
|
| 94 |
+
We provide the mixed post-training configuration in five RoboTwin 2.0 tasks ("open_microwave" "click_bell" "stack_blocks_three" "place_shoe" "put_object_cabinet").
|
| 95 |
+
<details>
|
| 96 |
+
<summary><b>Click to expand full YAML configuration</b></summary>
|
| 97 |
+
|
| 98 |
+
```yaml
|
| 99 |
+
model:
|
| 100 |
+
model_path: "path/to/lingbot_vla_checkpoint" # Path to pre-trained VLA foundation model (w/o or w depth)
|
| 101 |
+
tokenizer_path: "path/to/Qwen2.5-VL-3B-Instruct"
|
| 102 |
+
post_training: true # Enable post-training/fine-tuning mode
|
| 103 |
+
adanorm_time: true
|
| 104 |
+
old_adanorm: true
|
| 105 |
+
|
| 106 |
+
data:
|
| 107 |
+
datasets_type: vla
|
| 108 |
+
data_name: robotwin_5_new
|
| 109 |
+
train_path: "path/to/lerobot_merged_data" # merged data from 5 robotwin2.0 tasks
|
| 110 |
+
num_workers: 8
|
| 111 |
+
norm_type: bounds_99_woclip
|
| 112 |
+
norm_stats_file: assets/norm_stats/robotwin_50.json # file of normalization statistics
|
| 113 |
+
|
| 114 |
+
train:
|
| 115 |
+
output_dir: "path/to/output"
|
| 116 |
+
loss_type: L1_fm # we apply L1 flow-matching loss in robotwin2.0 finetuning
|
| 117 |
+
data_parallel_mode: fsdp2 # Use Fully Sharded Data Parallel (PyTorch FSDP2)
|
| 118 |
+
enable_full_shard: false # Don't apply reshare after forward in FSDP2
|
| 119 |
+
module_fsdp_enable: true
|
| 120 |
+
use_compile: true # Acceleration via torch.compile
|
| 121 |
+
use_wandb: false
|
| 122 |
+
rmpad: false
|
| 123 |
+
rmpad_with_pos_ids: false
|
| 124 |
+
ulysses_parallel_size: 1
|
| 125 |
+
freeze_vision_encoder: false # ViT need to be optimized
|
| 126 |
+
tokenizer_max_length: 24 # token numbers of task prompt
|
| 127 |
+
action_dim: 14 # Target robot action space dimension
|
| 128 |
+
max_action_dim: 75 # action dim in LingBot-VLA
|
| 129 |
+
max_state_dim: 75 # state dim in LingBot-VLA
|
| 130 |
+
lr: 1.0e-4
|
| 131 |
+
lr_decay_style: constant
|
| 132 |
+
num_train_epochs: 69 # finetuning 20k step
|
| 133 |
+
micro_batch_size: 32
|
| 134 |
+
global_batch_size: 256
|
| 135 |
+
max_steps: 220000
|
| 136 |
+
ckpt_manager: dcp
|
| 137 |
+
save_steps: 220000
|
| 138 |
+
save_epochs: 69
|
| 139 |
+
enable_fp32: true
|
| 140 |
+
enable_resume: true # resume training automatically
|
| 141 |
+
# ===========================================================================
|
| 142 |
+
# Depth Injection Parameters
|
| 143 |
+
# (Required only for LingBot-VLA with Depth. Ignore if not using depth)
|
| 144 |
+
# ===========================================================================
|
| 145 |
+
align_params:
|
| 146 |
+
mode: 'query' # Query-based distillation
|
| 147 |
+
num_task_tokens: 8 # Number of learnable task-specific tokens
|
| 148 |
+
use_image_tokens: True
|
| 149 |
+
use_task_tokens: False
|
| 150 |
+
use_text_tokens: False
|
| 151 |
+
use_contrastive: True
|
| 152 |
+
contrastive_loss_weight: 0.3
|
| 153 |
+
depth_loss_weight: 0.002
|
| 154 |
+
llm: # VLM Projection Settings
|
| 155 |
+
dim_out: 2048
|
| 156 |
+
image_token_size: 8
|
| 157 |
+
image_input_size: 224
|
| 158 |
+
depth:
|
| 159 |
+
model_type: MoRGBD
|
| 160 |
+
moge_path: /"path/to/moGe-2-vitb-normal"
|
| 161 |
+
morgbd_path: "path/to/LingBot-Depth"
|
| 162 |
+
num_layers: 1
|
| 163 |
+
num_heads: 4
|
| 164 |
+
dim_head: 32
|
| 165 |
+
ff_mult: 1
|
| 166 |
+
num_backbone_tokens: 256
|
| 167 |
+
token_size: 16
|
| 168 |
+
dim_out: 1024
|
| 169 |
+
input_size: 224
|
| 170 |
+
visual_steps: 10000
|
| 171 |
+
visual_dir: "path/to/output/images" # visualization path of depth distillation
|
| 172 |
+
```
|
| 173 |
+
</details>
|
| 174 |
+
|
| 175 |
+
- **Run Command**:
|
| 176 |
+
```bash
|
| 177 |
+
# without detph
|
| 178 |
+
bash train.sh tasks/vla/train_lingbotvla.py ./configs/vla/robotwin_load20000h.yaml --model.model_path /path/to/LingBot-VLA --data.train_path path/to/mixed_robotwin_5tasks --train.output_dir /path/to/lingbot_robotwin5tasks/ --model.tokenizer_path /path/to/Qwen2.5-VL-3B-Instruct --train.micro_batch_size ${your_batch_size} --train.global_batch_size ${your_batch_size * your_gpu_num}
|
| 179 |
+
|
| 180 |
+
# with depth
|
| 181 |
+
bash train.sh tasks/vla/train_lingbotvla.py ./configs/vla/robotwin_load20000h_depth.yaml --model.model_path /path/to/LingBot-VLA-Depth --data.train_path /path/to/mixed_robotwin_5tasks --train.output_dir /path/to/lingbot_depth_robotwin5tasks --model.tokenizer_path /path/to/Qwen2.5-VL-3B-Instruct --model.moge_path /path/to/moge2-vitb-normal.pt --model.morgbd_path /path/to/LingBot-Depth-Pretrained --train.micro_batch_size ${your_batch_size} --train.global_batch_size ${your_batch_size * your_gpu_num}
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
- **Evaluation**
|
| 185 |
+
```bash
|
| 186 |
+
# robotwin2.0
|
| 187 |
+
export QWEN25_PATH=path_to_Qwen2.5-VL-3B-Instruct
|
| 188 |
+
python -m deploy.lingbot_robotwin_policy \
|
| 189 |
+
--model_path path_to_your_model \
|
| 190 |
+
--use_length 50 \
|
| 191 |
+
--port port
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
- **Customized Post-training**:
|
| 195 |
+
To construct post-training in specified downstream tasks, we have provided an example and please refer to [Custom](lingbotvla/data/vla_data/README.md) for details.
|
| 196 |
+
---
|
| 197 |
+
|
| 198 |
+
## 🏗️ Efficiency
|
| 199 |
+
<p align="center">
|
| 200 |
+
<img src="assets/QwenPI_PaliGemmaPI.png" width="85%">
|
| 201 |
+
</p>
|
| 202 |
+
We evaluate the training efficiency of our codebase against established baselines for both <b>Qwen2.5-VL-3B-π</b> and <b>PaliGemma-3B-pt-224-π</b> models. The results demonstrate that our codebase
|
| 203 |
+
achieved the fastest training speeds in both model settings. The above figures detail the training throughput across configurations of 8, 16, 32, 128, and 256 GPUs, alongside the theoretical linear scaling limit.
|
| 204 |
+
|
| 205 |
+
> **📢 Note on Throughput Metrics:**
|
| 206 |
+
> All throughput values (e.g., 261 samples/sec) represent the **total aggregate throughput across all GPUs**, not per-GPU performance.
|
| 207 |
+
> <br><sup>(Updated: Previously mislabeled as per-GPU in earlier versions. We apologize for the confusion.)</sup>
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
## 📊 Performance
|
| 212 |
+
|
| 213 |
+
Our LingBot-VLA achieves state-of-the-art results on real-world and simulation benchmarks:
|
| 214 |
+
- **GM-100 across 3 robot platforms**
|
| 215 |
+
|
| 216 |
+
<table>
|
| 217 |
+
<thead>
|
| 218 |
+
<tr>
|
| 219 |
+
<th rowspan="2">Platform</th>
|
| 220 |
+
<th colspan="2">WALL-OSS</th>
|
| 221 |
+
<th colspan="2">GR00T N1.6</th>
|
| 222 |
+
<th colspan="2">π<sub>0.5</sub></th>
|
| 223 |
+
<th colspan="2">Ours w/o depth</th>
|
| 224 |
+
<th colspan="2">Ours w/ depth</th>
|
| 225 |
+
</tr>
|
| 226 |
+
<tr>
|
| 227 |
+
<th>SR</th><th>PS</th>
|
| 228 |
+
<th>SR</th><th>PS</th>
|
| 229 |
+
<th>SR</th><th>PS</th>
|
| 230 |
+
<th>SR</th><th>PS</th>
|
| 231 |
+
<th>SR</th><th>PS</th>
|
| 232 |
+
</tr>
|
| 233 |
+
</thead>
|
| 234 |
+
<tbody>
|
| 235 |
+
<tr>
|
| 236 |
+
<td>Agibot G1</td>
|
| 237 |
+
<td>2.99%</td><td>8.75%</td><td>5.23%</td><td>12.63%</td><td>7.77%</td><td>21.98%</td><td><b>12.82%</b></td><td>30.04%</td><td>11.98%</td><td><b>30.47%</b></td>
|
| 238 |
+
</tr>
|
| 239 |
+
<tr>
|
| 240 |
+
<td>AgileX</td>
|
| 241 |
+
<td>2.26%</td><td>8.16%</td><td>3.26%</td><td>10.52%</td><td>17.20%</td><td>34.82%</td><td>15.50%</td><td>36.31%</td><td><b>18.93%</b></td><td><b>40.36%</b></td>
|
| 242 |
+
</tr>
|
| 243 |
+
<tr>
|
| 244 |
+
<td>Galaxea R1Pro</td>
|
| 245 |
+
<td>6.89%</td><td>14.13%</td><td>14.29%</td><td>24.83%</td><td>14.10%</td><td>26.14%</td><td>18.89%</td><td>34.71%</td><td><b>20.98%</b></td><td><b>35.40%</b></td>
|
| 246 |
+
</tr>
|
| 247 |
+
<tr>
|
| 248 |
+
<td><b>Average</b></td>
|
| 249 |
+
<td>4.05%</td><td>10.35%</td><td>7.59%</td><td>15.99%</td><td>13.02%</td><td>27.65%</td><td>15.74%</td><td>33.69%</td><td><b>17.30%</b></td><td><b>35.41%</b></td>
|
| 250 |
+
</tr>
|
| 251 |
+
</tbody>
|
| 252 |
+
</table>
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
- **RoboTwin 2.0 (Clean and Randomized)**
|
| 256 |
+
|
| 257 |
+
<table>
|
| 258 |
+
<thead>
|
| 259 |
+
<tr>
|
| 260 |
+
<th rowspan="2" ><b>Simulation Tasks</b></th>
|
| 261 |
+
<th colspan="2"><b>π<sub>0.5</sub></b></th>
|
| 262 |
+
<th colspan="2"><b>Ours w/o depth</b></th>
|
| 263 |
+
<th colspan="2"><b>Ours w/ depth</b></th>
|
| 264 |
+
</tr>
|
| 265 |
+
<tr>
|
| 266 |
+
<th><b>Clean</b></th>
|
| 267 |
+
<th><b>Rand.</b></th>
|
| 268 |
+
<th><b>Clean</b></th>
|
| 269 |
+
<th><b>Rand.</b></th>
|
| 270 |
+
<th><b>Clean</b></th>
|
| 271 |
+
<th><b>Rand.</b></th>
|
| 272 |
+
</tr>
|
| 273 |
+
</thead>
|
| 274 |
+
<tbody>
|
| 275 |
+
<tr style="border-top: 1px solid #ccc;"> <!-- \midrule -->
|
| 276 |
+
<td><b>Average SR</b></td>
|
| 277 |
+
<td>82.74%</td>
|
| 278 |
+
<td>76.76%</td>
|
| 279 |
+
<td>86.50%</td>
|
| 280 |
+
<td>85.34%</td>
|
| 281 |
+
<td>88.56%</td>
|
| 282 |
+
<td>86.68%</td>
|
| 283 |
+
</tr>
|
| 284 |
+
<!-- 您可以在此处继续添加其他任务行 -->
|
| 285 |
+
</tbody>
|
| 286 |
+
</table>
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
📢 We have released our checkpoints of LingBot-VLA-Posttrain-Robotwin:
|
| 290 |
+
| Model Name | Huggingface | ModelScope | Description |
|
| 291 |
+
| :--- | :---: | :---: | :---: |
|
| 292 |
+
| LingBot-VLA-4B-Posttrain-Robotwin | [🤗 lingbot-vla-4b-posttrain-robotwin](https://huggingface.co/robbyant/lingbot-vla-4b-posttrain-robotwin) | [🤖 lingbot-vla-4b-posttrain-robotwin](https://modelscope.cn/models/Robbyant/lingbot-vla-4b-posttrain-robotwin) | LingBot-VLA-Posttrain-Robotwin *w/o* Depth|
|
| 293 |
+
| LingBot-VLA-4B-Depth-Posttrain-Robotwin | [🤗 lingbot-vla-4b-depth-posttrain-robotwin](https://huggingface.co/robbyant/lingbot-vla-4b-depth-posttrain-robotwin) | [🤖 lingbot-vla-4b-depth-posttrain-robotwin](https://modelscope.cn/models/Robbyant/lingbot-vla-4b-depth-posttrain-robotwin) | LingBot-VLA-Posttrain-Robotwin *w/* Depth |
|
| 294 |
+
|
| 295 |
+
We also provided [evaluation code](deploy/lingbot_robotwin_policy_rep.py) for the community to reproduce the performance of LingBot-VLA on Robotwin 2.0:
|
| 296 |
+
```bash
|
| 297 |
+
export QWEN25_PATH=path_to_Qwen2.5-VL-3B-Instruct
|
| 298 |
+
python -m deploy.lingbot_robotwin_policy_rep \
|
| 299 |
+
--model_path Path_to_LingBot-VLA-Posttrain-Robotwin \
|
| 300 |
+
--use_length 50 \
|
| 301 |
+
--port port
|
| 302 |
+
```
|
| 303 |
+
|
| 304 |
+
<p align="center">
|
| 305 |
+
<img src="assets/exp-gm-100.png" width="45%" style="margin: 0 10px;">
|
| 306 |
+
<img src="assets/exp-robotwin.png" width="45%" style="margin: 0 10px;">
|
| 307 |
+
</p>
|
| 308 |
+
|
| 309 |
+
---
|
| 310 |
+
|
| 311 |
+
## 📝 Citation
|
| 312 |
+
|
| 313 |
+
If you find our work useful in your research, feel free to give us a cite.
|
| 314 |
+
|
| 315 |
+
```bibtex
|
| 316 |
+
@article{wu2026pragmatic,
|
| 317 |
+
title={A Pragmatic VLA Foundation Model},
|
| 318 |
+
author={Wei Wu and Fan Lu and Yunnan Wang and Shuai Yang and Shi Liu and Fangjing Wang and Shuailei Ma and He Sun and Yong Wang and Zhenqi Qiu and Houlong Xiong and Ziyu Wang and Shuai Zhou and Yiyu Ren and Kejia Zhang and Hui Yu and Jingmei Zhao and Qian Zhu and Ran Cheng and Yong-Lu Li and Yongtao Huang and Xing Zhu and Yujun Shen and Kecheng Zheng},
|
| 319 |
+
journal={arXiv preprint arXiv:2601.18692v1},
|
| 320 |
+
year={2026}
|
| 321 |
+
}
|
| 322 |
+
```
|
| 323 |
+
|
| 324 |
+
---
|
| 325 |
+
|
| 326 |
+
## 📄 License Agreement
|
| 327 |
+
This project is licensed under the [Apache-2.0 License](LICENSE).
|
| 328 |
+
|
| 329 |
+
## 😊 Acknowledgement
|
| 330 |
+
We would like to express our sincere gratitude to the developers of [VeOmni](https://arxiv.org/abs/2508.02317) and [LeRobot](https://github.com/huggingface/lerobot#). This project benefits significantly from their outstanding work and contributions to the open-source community.
|
assets/LingBot-VLA.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1b0a361d6084d74afc0bc9fcdee5051375b701a8e41013460107a46902bd0426
|
| 3 |
+
size 10000817
|
assets/PaliGemmaPI.png
ADDED
|
Git LFS Details
|
assets/QwenPI.png
ADDED
|
Git LFS Details
|
assets/QwenPI_PaliGemmaPI.png
ADDED
|
Git LFS Details
|
assets/Teaser.png
ADDED
|
Git LFS Details
|
assets/exp-gm-100.png
ADDED
|
Git LFS Details
|
assets/exp-robotwin.png
ADDED
|
Git LFS Details
|
assets/norm_stats/libero.json
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"norm_stats": {
|
| 3 |
+
"state": {
|
| 4 |
+
"mean": [
|
| 5 |
+
-0.04617275670170784,
|
| 6 |
+
0.034034404903650284,
|
| 7 |
+
0.7647115588188171,
|
| 8 |
+
2.971421480178833,
|
| 9 |
+
-0.2198116034269333,
|
| 10 |
+
-0.1260652393102646,
|
| 11 |
+
0.02694438025355339,
|
| 12 |
+
-0.0272101741284132,
|
| 13 |
+
0.0,
|
| 14 |
+
0.0,
|
| 15 |
+
0.0,
|
| 16 |
+
0.0,
|
| 17 |
+
0.0,
|
| 18 |
+
0.0,
|
| 19 |
+
0.0,
|
| 20 |
+
0.0,
|
| 21 |
+
0.0,
|
| 22 |
+
0.0,
|
| 23 |
+
0.0,
|
| 24 |
+
0.0,
|
| 25 |
+
0.0,
|
| 26 |
+
0.0,
|
| 27 |
+
0.0,
|
| 28 |
+
0.0,
|
| 29 |
+
0.0,
|
| 30 |
+
0.0,
|
| 31 |
+
0.0,
|
| 32 |
+
0.0,
|
| 33 |
+
0.0,
|
| 34 |
+
0.0,
|
| 35 |
+
0.0,
|
| 36 |
+
0.0
|
| 37 |
+
],
|
| 38 |
+
"std": [
|
| 39 |
+
0.1049584373831749,
|
| 40 |
+
0.15187117457389832,
|
| 41 |
+
0.3785041272640228,
|
| 42 |
+
0.3451951742172241,
|
| 43 |
+
0.910057544708252,
|
| 44 |
+
0.3253032863140106,
|
| 45 |
+
0.014151589013636112,
|
| 46 |
+
0.014038060791790485,
|
| 47 |
+
0.0,
|
| 48 |
+
0.0,
|
| 49 |
+
0.0,
|
| 50 |
+
0.0,
|
| 51 |
+
0.0,
|
| 52 |
+
0.0,
|
| 53 |
+
0.0,
|
| 54 |
+
0.0,
|
| 55 |
+
0.0,
|
| 56 |
+
0.0,
|
| 57 |
+
0.0,
|
| 58 |
+
0.0,
|
| 59 |
+
0.0,
|
| 60 |
+
0.0,
|
| 61 |
+
0.0,
|
| 62 |
+
0.0,
|
| 63 |
+
0.0,
|
| 64 |
+
0.0,
|
| 65 |
+
0.0,
|
| 66 |
+
0.0,
|
| 67 |
+
0.0,
|
| 68 |
+
0.0,
|
| 69 |
+
0.0,
|
| 70 |
+
0.0
|
| 71 |
+
],
|
| 72 |
+
"q01": [
|
| 73 |
+
-0.4003246918797493,
|
| 74 |
+
-0.268838057410717,
|
| 75 |
+
0.03963126605004072,
|
| 76 |
+
1.5141939243793487,
|
| 77 |
+
-2.7199491125106814,
|
| 78 |
+
-1.0708919448852539,
|
| 79 |
+
0.0017206525699933989,
|
| 80 |
+
-0.04004273633235134,
|
| 81 |
+
0.0,
|
| 82 |
+
0.0,
|
| 83 |
+
0.0,
|
| 84 |
+
0.0,
|
| 85 |
+
0.0,
|
| 86 |
+
0.0,
|
| 87 |
+
0.0,
|
| 88 |
+
0.0,
|
| 89 |
+
0.0,
|
| 90 |
+
0.0,
|
| 91 |
+
0.0,
|
| 92 |
+
0.0,
|
| 93 |
+
0.0,
|
| 94 |
+
0.0,
|
| 95 |
+
0.0,
|
| 96 |
+
0.0,
|
| 97 |
+
0.0,
|
| 98 |
+
0.0,
|
| 99 |
+
0.0,
|
| 100 |
+
0.0,
|
| 101 |
+
0.0,
|
| 102 |
+
0.0,
|
| 103 |
+
0.0,
|
| 104 |
+
0.0
|
| 105 |
+
],
|
| 106 |
+
"q99": [
|
| 107 |
+
0.1335429027736188,
|
| 108 |
+
0.3378903574764729,
|
| 109 |
+
1.2657122139371932,
|
| 110 |
+
3.2784227243721484,
|
| 111 |
+
2.4147262454509733,
|
| 112 |
+
0.5962245464324951,
|
| 113 |
+
0.04029089962062426,
|
| 114 |
+
-0.001789628425752747,
|
| 115 |
+
0.0,
|
| 116 |
+
0.0,
|
| 117 |
+
0.0,
|
| 118 |
+
0.0,
|
| 119 |
+
0.0,
|
| 120 |
+
0.0,
|
| 121 |
+
0.0,
|
| 122 |
+
0.0,
|
| 123 |
+
0.0,
|
| 124 |
+
0.0,
|
| 125 |
+
0.0,
|
| 126 |
+
0.0,
|
| 127 |
+
0.0,
|
| 128 |
+
0.0,
|
| 129 |
+
0.0,
|
| 130 |
+
0.0,
|
| 131 |
+
0.0,
|
| 132 |
+
0.0,
|
| 133 |
+
0.0,
|
| 134 |
+
0.0,
|
| 135 |
+
0.0,
|
| 136 |
+
0.0,
|
| 137 |
+
0.0,
|
| 138 |
+
0.0
|
| 139 |
+
]
|
| 140 |
+
},
|
| 141 |
+
"actions": {
|
| 142 |
+
"mean": [
|
| 143 |
+
0.06667574495077133,
|
| 144 |
+
0.06483978033065796,
|
| 145 |
+
-0.80384361743927,
|
| 146 |
+
-2.970874071121216,
|
| 147 |
+
0.22662578523159027,
|
| 148 |
+
0.11959122866392136,
|
| 149 |
+
-0.036161474883556366,
|
| 150 |
+
0.0,
|
| 151 |
+
0.0,
|
| 152 |
+
0.0,
|
| 153 |
+
0.0,
|
| 154 |
+
0.0,
|
| 155 |
+
0.0,
|
| 156 |
+
0.0,
|
| 157 |
+
0.0,
|
| 158 |
+
0.0,
|
| 159 |
+
0.0,
|
| 160 |
+
0.0,
|
| 161 |
+
0.0,
|
| 162 |
+
0.0,
|
| 163 |
+
0.0,
|
| 164 |
+
0.0,
|
| 165 |
+
0.0,
|
| 166 |
+
0.0,
|
| 167 |
+
0.0,
|
| 168 |
+
0.0,
|
| 169 |
+
0.0,
|
| 170 |
+
0.0,
|
| 171 |
+
0.0,
|
| 172 |
+
0.0,
|
| 173 |
+
0.0,
|
| 174 |
+
0.0
|
| 175 |
+
],
|
| 176 |
+
"std": [
|
| 177 |
+
0.32812511920928955,
|
| 178 |
+
0.4197826683521271,
|
| 179 |
+
0.6153613924980164,
|
| 180 |
+
0.35168182849884033,
|
| 181 |
+
0.9132273197174072,
|
| 182 |
+
0.3432939946651459,
|
| 183 |
+
0.9993459582328796,
|
| 184 |
+
0.0,
|
| 185 |
+
0.0,
|
| 186 |
+
0.0,
|
| 187 |
+
0.0,
|
| 188 |
+
0.0,
|
| 189 |
+
0.0,
|
| 190 |
+
0.0,
|
| 191 |
+
0.0,
|
| 192 |
+
0.0,
|
| 193 |
+
0.0,
|
| 194 |
+
0.0,
|
| 195 |
+
0.0,
|
| 196 |
+
0.0,
|
| 197 |
+
0.0,
|
| 198 |
+
0.0,
|
| 199 |
+
0.0,
|
| 200 |
+
0.0,
|
| 201 |
+
0.0,
|
| 202 |
+
0.0,
|
| 203 |
+
0.0,
|
| 204 |
+
0.0,
|
| 205 |
+
0.0,
|
| 206 |
+
0.0,
|
| 207 |
+
0.0,
|
| 208 |
+
0.0
|
| 209 |
+
],
|
| 210 |
+
"q01": [
|
| 211 |
+
-0.7088336983919143,
|
| 212 |
+
-0.8786727856397629,
|
| 213 |
+
-2.097322083187103,
|
| 214 |
+
-3.3041505486488343,
|
| 215 |
+
-2.4138620029449465,
|
| 216 |
+
-0.6111064100980759,
|
| 217 |
+
-1.0,
|
| 218 |
+
0.0,
|
| 219 |
+
0.0,
|
| 220 |
+
0.0,
|
| 221 |
+
0.0,
|
| 222 |
+
0.0,
|
| 223 |
+
0.0,
|
| 224 |
+
0.0,
|
| 225 |
+
0.0,
|
| 226 |
+
0.0,
|
| 227 |
+
0.0,
|
| 228 |
+
0.0,
|
| 229 |
+
0.0,
|
| 230 |
+
0.0,
|
| 231 |
+
0.0,
|
| 232 |
+
0.0,
|
| 233 |
+
0.0,
|
| 234 |
+
0.0,
|
| 235 |
+
0.0,
|
| 236 |
+
0.0,
|
| 237 |
+
0.0,
|
| 238 |
+
0.0,
|
| 239 |
+
0.0,
|
| 240 |
+
0.0,
|
| 241 |
+
0.0,
|
| 242 |
+
0.0
|
| 243 |
+
],
|
| 244 |
+
"q99": [
|
| 245 |
+
1.0219826289415357,
|
| 246 |
+
1.0526966882944104,
|
| 247 |
+
0.7265835452556608,
|
| 248 |
+
-1.491220802116394,
|
| 249 |
+
2.7264903316497806,
|
| 250 |
+
1.1191907620668413,
|
| 251 |
+
0.9996,
|
| 252 |
+
0.0,
|
| 253 |
+
0.0,
|
| 254 |
+
0.0,
|
| 255 |
+
0.0,
|
| 256 |
+
0.0,
|
| 257 |
+
0.0,
|
| 258 |
+
0.0,
|
| 259 |
+
0.0,
|
| 260 |
+
0.0,
|
| 261 |
+
0.0,
|
| 262 |
+
0.0,
|
| 263 |
+
0.0,
|
| 264 |
+
0.0,
|
| 265 |
+
0.0,
|
| 266 |
+
0.0,
|
| 267 |
+
0.0,
|
| 268 |
+
0.0,
|
| 269 |
+
0.0,
|
| 270 |
+
0.0,
|
| 271 |
+
0.0,
|
| 272 |
+
0.0,
|
| 273 |
+
0.0,
|
| 274 |
+
0.0,
|
| 275 |
+
0.0,
|
| 276 |
+
0.0
|
| 277 |
+
]
|
| 278 |
+
}
|
| 279 |
+
}
|
| 280 |
+
}
|
assets/norm_stats/robotwin_50.json
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"norm_stats": {
|
| 3 |
+
"action.arm.position": {
|
| 4 |
+
"mean": [
|
| 5 |
+
-0.22649447619915009,
|
| 6 |
+
1.0910465717315674,
|
| 7 |
+
0.8046976923942566,
|
| 8 |
+
-0.3529793620109558,
|
| 9 |
+
0.056382808834314346,
|
| 10 |
+
-0.04518803581595421,
|
| 11 |
+
0.23444592952728271,
|
| 12 |
+
1.1117788553237915,
|
| 13 |
+
0.8302268385887146,
|
| 14 |
+
-0.3584558367729187,
|
| 15 |
+
-0.010058438405394554,
|
| 16 |
+
0.010835078544914722
|
| 17 |
+
],
|
| 18 |
+
"std": [
|
| 19 |
+
0.36951732635498047,
|
| 20 |
+
0.9946224689483643,
|
| 21 |
+
0.7907869219779968,
|
| 22 |
+
0.663685142993927,
|
| 23 |
+
0.24930860102176666,
|
| 24 |
+
0.5646992921829224,
|
| 25 |
+
0.32377511262893677,
|
| 26 |
+
1.0205038785934448,
|
| 27 |
+
0.8121177554130554,
|
| 28 |
+
0.7205489277839661,
|
| 29 |
+
0.25676125288009644,
|
| 30 |
+
0.6210611462593079
|
| 31 |
+
],
|
| 32 |
+
"q01": [
|
| 33 |
+
-0.9676963651657111,
|
| 34 |
+
-0.0003164021181873977,
|
| 35 |
+
-0.0008187678098678652,
|
| 36 |
+
-1.5952941972732544,
|
| 37 |
+
-0.4444093635320664,
|
| 38 |
+
-2.2108209049224854,
|
| 39 |
+
-0.13648582720756508,
|
| 40 |
+
-0.0025135905981064077,
|
| 41 |
+
-0.0016476722434163094,
|
| 42 |
+
-1.7023667912483216,
|
| 43 |
+
-1.0292453282356262,
|
| 44 |
+
-1.6702169750213622
|
| 45 |
+
],
|
| 46 |
+
"q99": [
|
| 47 |
+
0.17045696868896432,
|
| 48 |
+
2.5792064671580563,
|
| 49 |
+
2.4791862522006034,
|
| 50 |
+
1.263499072647095,
|
| 51 |
+
1.2283580561399456,
|
| 52 |
+
1.4622943069458012,
|
| 53 |
+
1.096450059175491,
|
| 54 |
+
2.605947977209091,
|
| 55 |
+
2.5039097490906714,
|
| 56 |
+
1.3104696589708325,
|
| 57 |
+
1.074876550579071,
|
| 58 |
+
2.104229341125489
|
| 59 |
+
],
|
| 60 |
+
"q02": [
|
| 61 |
+
-0.9234203773498537,
|
| 62 |
+
-0.0003164021181873977,
|
| 63 |
+
-0.0008187678098678652,
|
| 64 |
+
-1.509812859249115,
|
| 65 |
+
-0.32799621334075924,
|
| 66 |
+
-1.656348336791992,
|
| 67 |
+
-0.05942733430862468,
|
| 68 |
+
-0.0025135905981064077,
|
| 69 |
+
-0.0016476722434163094,
|
| 70 |
+
-1.6187864029407502,
|
| 71 |
+
-0.8712951603889465,
|
| 72 |
+
-1.5470734649658198
|
| 73 |
+
],
|
| 74 |
+
"q98": [
|
| 75 |
+
0.11836757125854458,
|
| 76 |
+
2.4944407171577216,
|
| 77 |
+
2.3239549394726753,
|
| 78 |
+
1.0776700769424439,
|
| 79 |
+
1.0128444806575776,
|
| 80 |
+
1.2158620544433596,
|
| 81 |
+
0.945415413093567,
|
| 82 |
+
2.5296102081775667,
|
| 83 |
+
2.3580759009346366,
|
| 84 |
+
1.2048114322423933,
|
| 85 |
+
0.6983346325874327,
|
| 86 |
+
1.7523907409667974
|
| 87 |
+
]
|
| 88 |
+
},
|
| 89 |
+
"action.effector.position": {
|
| 90 |
+
"mean": [
|
| 91 |
+
0.6722026467323303,
|
| 92 |
+
0.6737783551216125
|
| 93 |
+
],
|
| 94 |
+
"std": [
|
| 95 |
+
0.45274168252944946,
|
| 96 |
+
0.45141810178756714
|
| 97 |
+
],
|
| 98 |
+
"q01": [
|
| 99 |
+
-1e-10,
|
| 100 |
+
-1e-10
|
| 101 |
+
],
|
| 102 |
+
"q99": [
|
| 103 |
+
0.99980000009996,
|
| 104 |
+
0.99980000009996
|
| 105 |
+
],
|
| 106 |
+
"q02": [
|
| 107 |
+
-1e-10,
|
| 108 |
+
-1e-10
|
| 109 |
+
],
|
| 110 |
+
"q98": [
|
| 111 |
+
0.99980000009996,
|
| 112 |
+
0.99980000009996
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
"observation.state.arm.position": {
|
| 116 |
+
"mean": [
|
| 117 |
+
-0.22545991837978363,
|
| 118 |
+
1.0864390134811401,
|
| 119 |
+
0.8012449741363525,
|
| 120 |
+
-0.3515830338001251,
|
| 121 |
+
0.05604754388332367,
|
| 122 |
+
-0.0445503294467926,
|
| 123 |
+
0.23296862840652466,
|
| 124 |
+
1.1059207916259766,
|
| 125 |
+
0.8258985280990601,
|
| 126 |
+
-0.3568105697631836,
|
| 127 |
+
-0.00992637686431408,
|
| 128 |
+
0.010328034870326519
|
| 129 |
+
],
|
| 130 |
+
"std": [
|
| 131 |
+
0.3688313364982605,
|
| 132 |
+
0.9950565099716187,
|
| 133 |
+
0.7906551957130432,
|
| 134 |
+
0.6622100472450256,
|
| 135 |
+
0.24865445494651794,
|
| 136 |
+
0.5626452565193176,
|
| 137 |
+
0.32314980030059814,
|
| 138 |
+
1.0208053588867188,
|
| 139 |
+
0.8119285702705383,
|
| 140 |
+
0.718558132648468,
|
| 141 |
+
0.25572913885116577,
|
| 142 |
+
0.6181830763816833
|
| 143 |
+
],
|
| 144 |
+
"q01": [
|
| 145 |
+
-0.9676963651657111,
|
| 146 |
+
-0.0003164021181873977,
|
| 147 |
+
-0.0008187678098678652,
|
| 148 |
+
-1.5938075653076171,
|
| 149 |
+
-0.44261839199066166,
|
| 150 |
+
-2.198074409103393,
|
| 151 |
+
-0.13494465734958627,
|
| 152 |
+
-0.0025135905981064077,
|
| 153 |
+
-0.0016476722434163094,
|
| 154 |
+
-1.7015782970190048,
|
| 155 |
+
-1.0292453282356262,
|
| 156 |
+
-1.6682623161315915
|
| 157 |
+
],
|
| 158 |
+
"q99": [
|
| 159 |
+
0.17045696868896432,
|
| 160 |
+
2.5792064671580563,
|
| 161 |
+
2.4782622562915084,
|
| 162 |
+
1.2545792808532719,
|
| 163 |
+
1.2247761130571364,
|
| 164 |
+
1.458045475006104,
|
| 165 |
+
1.0856618701696394,
|
| 166 |
+
2.6036578441381453,
|
| 167 |
+
2.502444082275033,
|
| 168 |
+
1.3057386935949324,
|
| 169 |
+
1.0699406078338622,
|
| 170 |
+
2.0983653644561766
|
| 171 |
+
],
|
| 172 |
+
"q02": [
|
| 173 |
+
-0.9234203773498537,
|
| 174 |
+
-0.0003164021181873977,
|
| 175 |
+
-0.0008187678098678652,
|
| 176 |
+
-1.5083262272834776,
|
| 177 |
+
-0.32799621334075924,
|
| 178 |
+
-1.6499750888824458,
|
| 179 |
+
-0.05942733430862468,
|
| 180 |
+
-0.0025135905981064077,
|
| 181 |
+
-0.0016476722434163094,
|
| 182 |
+
-1.6172094144821167,
|
| 183 |
+
-0.8684746216773986,
|
| 184 |
+
-1.5470734649658198
|
| 185 |
+
],
|
| 186 |
+
"q98": [
|
| 187 |
+
0.11836757125854458,
|
| 188 |
+
2.4944407171577216,
|
| 189 |
+
2.320258955836296,
|
| 190 |
+
1.0754401289939883,
|
| 191 |
+
1.0116504996299742,
|
| 192 |
+
1.2137376384735115,
|
| 193 |
+
0.945415413093567,
|
| 194 |
+
2.528846830487251,
|
| 195 |
+
2.3551445673033595,
|
| 196 |
+
1.2016574553251265,
|
| 197 |
+
0.6969243632316591,
|
| 198 |
+
1.746526764297485
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
"observation.state.effector.position": {
|
| 202 |
+
"mean": [
|
| 203 |
+
0.6734354496002197,
|
| 204 |
+
0.6749846339225769
|
| 205 |
+
],
|
| 206 |
+
"std": [
|
| 207 |
+
0.4522727429866791,
|
| 208 |
+
0.45095184445381165
|
| 209 |
+
],
|
| 210 |
+
"q01": [
|
| 211 |
+
-1e-10,
|
| 212 |
+
-1e-10
|
| 213 |
+
],
|
| 214 |
+
"q99": [
|
| 215 |
+
0.99980000009996,
|
| 216 |
+
0.99980000009996
|
| 217 |
+
],
|
| 218 |
+
"q02": [
|
| 219 |
+
-1e-10,
|
| 220 |
+
-1e-10
|
| 221 |
+
],
|
| 222 |
+
"q98": [
|
| 223 |
+
0.99980000009996,
|
| 224 |
+
0.99980000009996
|
| 225 |
+
]
|
| 226 |
+
}
|
| 227 |
+
},
|
| 228 |
+
"count": 532992
|
| 229 |
+
}
|
assets/norm_stats/robotwin_5_customized.json
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"norm_stats": {
|
| 3 |
+
"action": {
|
| 4 |
+
"mean": [
|
| 5 |
+
-0.32207754254341125,
|
| 6 |
+
1.406205654144287,
|
| 7 |
+
1.1087545156478882,
|
| 8 |
+
-0.6245313882827759,
|
| 9 |
+
-0.027720848098397255,
|
| 10 |
+
-0.035565875470638275,
|
| 11 |
+
0.4717631936073303,
|
| 12 |
+
0.25276312232017517,
|
| 13 |
+
0.8104884624481201,
|
| 14 |
+
0.5522242188453674,
|
| 15 |
+
-0.1358797252178192,
|
| 16 |
+
0.13210205733776093,
|
| 17 |
+
-0.13196010887622833,
|
| 18 |
+
0.7805091738700867
|
| 19 |
+
],
|
| 20 |
+
"std": [
|
| 21 |
+
0.2855374813079834,
|
| 22 |
+
0.9229381084442139,
|
| 23 |
+
0.8118345737457275,
|
| 24 |
+
0.49564430117607117,
|
| 25 |
+
0.16244904696941376,
|
| 26 |
+
0.5517618656158447,
|
| 27 |
+
0.4883338212966919,
|
| 28 |
+
0.40702372789382935,
|
| 29 |
+
1.036325216293335,
|
| 30 |
+
0.7480976581573486,
|
| 31 |
+
0.7034134268760681,
|
| 32 |
+
0.3450477123260498,
|
| 33 |
+
0.7341580390930176,
|
| 34 |
+
0.4033139646053314
|
| 35 |
+
],
|
| 36 |
+
"q01": [
|
| 37 |
+
-0.8213654638230801,
|
| 38 |
+
-5.257390398583084e-7,
|
| 39 |
+
-0.00002296771708643064,
|
| 40 |
+
-1.6557389229632915,
|
| 41 |
+
-0.6564541918039322,
|
| 42 |
+
-1.1997157670021057,
|
| 43 |
+
0.0,
|
| 44 |
+
-0.0013322193384173175,
|
| 45 |
+
0.0,
|
| 46 |
+
-0.0000281171942333458,
|
| 47 |
+
-1.4858032744407654,
|
| 48 |
+
-0.013652276556193832,
|
| 49 |
+
-1.5582030366897581,
|
| 50 |
+
0.0
|
| 51 |
+
],
|
| 52 |
+
"q99": [
|
| 53 |
+
0.01988644998967637,
|
| 54 |
+
2.618066892673189,
|
| 55 |
+
2.8887816588023267,
|
| 56 |
+
-0.00009503023102874764,
|
| 57 |
+
0.39941834962368006,
|
| 58 |
+
1.3274614672660827,
|
| 59 |
+
0.9998,
|
| 60 |
+
1.2499000839233396,
|
| 61 |
+
2.403721238327026,
|
| 62 |
+
2.223998639903084,
|
| 63 |
+
1.3482957191944123,
|
| 64 |
+
1.2036741195514797,
|
| 65 |
+
2.3008846492767336,
|
| 66 |
+
0.9998
|
| 67 |
+
],
|
| 68 |
+
"q02": [
|
| 69 |
+
-0.8116190195694566,
|
| 70 |
+
-5.257390398583084e-7,
|
| 71 |
+
-0.00002296771708643064,
|
| 72 |
+
-1.5653808554142714,
|
| 73 |
+
-0.5909986785650253,
|
| 74 |
+
-0.9318809885978698,
|
| 75 |
+
0.0,
|
| 76 |
+
-0.0013322193384173175,
|
| 77 |
+
0.0,
|
| 78 |
+
-0.0000281171942333458,
|
| 79 |
+
-1.400590261220932,
|
| 80 |
+
-0.005905654035508634,
|
| 81 |
+
-1.5582030366897581,
|
| 82 |
+
0.0
|
| 83 |
+
],
|
| 84 |
+
"q98": [
|
| 85 |
+
0.01988644998967637,
|
| 86 |
+
2.509362170317786,
|
| 87 |
+
2.6153081541584893,
|
| 88 |
+
-0.00009503023102874764,
|
| 89 |
+
0.34549802929162987,
|
| 90 |
+
1.2313367155075077,
|
| 91 |
+
0.9998,
|
| 92 |
+
1.2416952819347378,
|
| 93 |
+
2.374588215923309,
|
| 94 |
+
2.1395174845976728,
|
| 95 |
+
1.328065291595459,
|
| 96 |
+
1.1956508319407702,
|
| 97 |
+
2.172924092388153,
|
| 98 |
+
0.9998
|
| 99 |
+
]
|
| 100 |
+
},
|
| 101 |
+
"observation.state": {
|
| 102 |
+
"mean": [
|
| 103 |
+
-0.320831835269928,
|
| 104 |
+
1.401549220085144,
|
| 105 |
+
1.1045918464660645,
|
| 106 |
+
-0.6217827796936035,
|
| 107 |
+
-0.0279570072889328,
|
| 108 |
+
-0.03499468415975571,
|
| 109 |
+
0.4726906716823578,
|
| 110 |
+
0.2512069344520569,
|
| 111 |
+
0.8065828680992126,
|
| 112 |
+
0.5495453476905823,
|
| 113 |
+
-0.13533149659633636,
|
| 114 |
+
0.13129419088363647,
|
| 115 |
+
-0.1315813809633255,
|
| 116 |
+
0.7816013693809509
|
| 117 |
+
],
|
| 118 |
+
"std": [
|
| 119 |
+
0.28554511070251465,
|
| 120 |
+
0.924691379070282,
|
| 121 |
+
0.8124904036521912,
|
| 122 |
+
0.49545007944107056,
|
| 123 |
+
0.16213101148605347,
|
| 124 |
+
0.5504377484321594,
|
| 125 |
+
0.4883865714073181,
|
| 126 |
+
0.40611740946769714,
|
| 127 |
+
1.035233497619629,
|
| 128 |
+
0.7470027208328247,
|
| 129 |
+
0.7013660073280334,
|
| 130 |
+
0.3439686894416809,
|
| 131 |
+
0.7313857674598694,
|
| 132 |
+
0.4025507867336273
|
| 133 |
+
],
|
| 134 |
+
"q01": [
|
| 135 |
+
-0.8213654638230801,
|
| 136 |
+
-5.257390398583084e-7,
|
| 137 |
+
-0.00002296771708643064,
|
| 138 |
+
-1.6557389229632915,
|
| 139 |
+
-0.6564541918039322,
|
| 140 |
+
-1.1997157670021057,
|
| 141 |
+
0.0,
|
| 142 |
+
-0.0013322193384173175,
|
| 143 |
+
0.0,
|
| 144 |
+
-0.0000281171942333458,
|
| 145 |
+
-1.483351101398468,
|
| 146 |
+
-0.013652276556193832,
|
| 147 |
+
-1.5582030366897581,
|
| 148 |
+
0.0
|
| 149 |
+
],
|
| 150 |
+
"q99": [
|
| 151 |
+
0.01988644998967637,
|
| 152 |
+
2.6186390227908487,
|
| 153 |
+
2.889423615385998,
|
| 154 |
+
-0.00009503023102874764,
|
| 155 |
+
0.39780878782272344,
|
| 156 |
+
1.3274614672660827,
|
| 157 |
+
0.9998,
|
| 158 |
+
1.2499000839233396,
|
| 159 |
+
2.404215018367767,
|
| 160 |
+
2.2201366442319794,
|
| 161 |
+
1.347682675933838,
|
| 162 |
+
1.2036741195514797,
|
| 163 |
+
2.3008846492767336,
|
| 164 |
+
0.9998
|
| 165 |
+
],
|
| 166 |
+
"q02": [
|
| 167 |
+
-0.8116190195694566,
|
| 168 |
+
-5.257390398583084e-7,
|
| 169 |
+
-0.00002296771708643064,
|
| 170 |
+
-1.5653808554142714,
|
| 171 |
+
-0.5909986785650253,
|
| 172 |
+
-0.9318809885978698,
|
| 173 |
+
0.0,
|
| 174 |
+
-0.0013322193384173175,
|
| 175 |
+
0.0,
|
| 176 |
+
-0.0000281171942333458,
|
| 177 |
+
-1.3981380881786347,
|
| 178 |
+
-0.005905654035508634,
|
| 179 |
+
-1.5582030366897581,
|
| 180 |
+
0.0
|
| 181 |
+
],
|
| 182 |
+
"q98": [
|
| 183 |
+
0.01988644998967637,
|
| 184 |
+
2.509362170317786,
|
| 185 |
+
2.61595011074216,
|
| 186 |
+
-0.00009503023102874764,
|
| 187 |
+
0.3452297689914703,
|
| 188 |
+
1.2313367155075077,
|
| 189 |
+
0.9998,
|
| 190 |
+
1.2416952819347378,
|
| 191 |
+
2.374588215923309,
|
| 192 |
+
2.1380692362210083,
|
| 193 |
+
1.328065291595459,
|
| 194 |
+
1.1956508319407702,
|
| 195 |
+
2.1450514958381657,
|
| 196 |
+
0.9998
|
| 197 |
+
]
|
| 198 |
+
}
|
| 199 |
+
},
|
| 200 |
+
"count": 74240
|
| 201 |
+
}
|
assets/norm_stats/robotwin_all_new.json
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"norm_stats": {
|
| 3 |
+
"action.arm.position": {
|
| 4 |
+
"mean": [
|
| 5 |
+
-0.2260681688785553,
|
| 6 |
+
1.090435266494751,
|
| 7 |
+
0.8042582273483276,
|
| 8 |
+
-0.3527189791202545,
|
| 9 |
+
0.056556474417448044,
|
| 10 |
+
-0.04530515521764755,
|
| 11 |
+
0.2346765249967575,
|
| 12 |
+
1.112542748451233,
|
| 13 |
+
0.8304542303085327,
|
| 14 |
+
-0.357768177986145,
|
| 15 |
+
-0.01014612801373005,
|
| 16 |
+
0.010991317220032215
|
| 17 |
+
],
|
| 18 |
+
"std": [
|
| 19 |
+
0.3691432774066925,
|
| 20 |
+
0.994762122631073,
|
| 21 |
+
0.7908730506896973,
|
| 22 |
+
0.6637247800827026,
|
| 23 |
+
0.24963052570819855,
|
| 24 |
+
0.5638052821159363,
|
| 25 |
+
0.32393988966941833,
|
| 26 |
+
1.0204970836639404,
|
| 27 |
+
0.8119731545448303,
|
| 28 |
+
0.7209287285804749,
|
| 29 |
+
0.25776439905166626,
|
| 30 |
+
0.6208906769752502
|
| 31 |
+
],
|
| 32 |
+
"q01": [
|
| 33 |
+
-0.9676963651657111,
|
| 34 |
+
-0.0003164021181873977,
|
| 35 |
+
-0.0026667596280574857,
|
| 36 |
+
-1.596037513256073,
|
| 37 |
+
-0.4467973255872727,
|
| 38 |
+
-2.20232324104309,
|
| 39 |
+
-0.13648582720756508,
|
| 40 |
+
-0.0017502129077910933,
|
| 41 |
+
-0.0023805056512355804,
|
| 42 |
+
-1.703943779706955,
|
| 43 |
+
-1.0264247895240783,
|
| 44 |
+
-1.6682623161315915
|
| 45 |
+
],
|
| 46 |
+
"q99": [
|
| 47 |
+
0.17045696868896432,
|
| 48 |
+
2.5760957974332737,
|
| 49 |
+
2.4727182808369395,
|
| 50 |
+
1.259782492733002,
|
| 51 |
+
1.2253731035709379,
|
| 52 |
+
1.4495478111267097,
|
| 53 |
+
1.0841207003116606,
|
| 54 |
+
2.6036578441381453,
|
| 55 |
+
2.4987799152359367,
|
| 56 |
+
1.3104696589708325,
|
| 57 |
+
1.0692354731559752,
|
| 58 |
+
2.104229341125489
|
| 59 |
+
],
|
| 60 |
+
"q02": [
|
| 61 |
+
-0.9260248472213748,
|
| 62 |
+
-0.0003164021181873977,
|
| 63 |
+
-0.0026667596280574857,
|
| 64 |
+
-1.5090695432662964,
|
| 65 |
+
-0.3291901943683624,
|
| 66 |
+
-1.6520995048522948,
|
| 67 |
+
-0.05942733430862468,
|
| 68 |
+
-0.0017502129077910933,
|
| 69 |
+
-0.0023805056512355804,
|
| 70 |
+
-1.6187864029407502,
|
| 71 |
+
-0.8741156991004944,
|
| 72 |
+
-1.5490281238555905
|
| 73 |
+
],
|
| 74 |
+
"q98": [
|
| 75 |
+
0.1157631013870235,
|
| 76 |
+
2.4936630497265257,
|
| 77 |
+
2.3193349599272013,
|
| 78 |
+
1.0769267609596254,
|
| 79 |
+
1.0140384616851805,
|
| 80 |
+
1.2073643905639653,
|
| 81 |
+
0.9469565829515458,
|
| 82 |
+
2.528083452796936,
|
| 83 |
+
2.3551445673033595,
|
| 84 |
+
1.2071769149303435,
|
| 85 |
+
0.6969243632316591,
|
| 86 |
+
1.7504360820770266
|
| 87 |
+
]
|
| 88 |
+
},
|
| 89 |
+
"action.effector.position": {
|
| 90 |
+
"mean": [
|
| 91 |
+
0.6723259687423706,
|
| 92 |
+
0.6735112071037292
|
| 93 |
+
],
|
| 94 |
+
"std": [
|
| 95 |
+
0.4526418447494507,
|
| 96 |
+
0.4514695405960083
|
| 97 |
+
],
|
| 98 |
+
"q01": [
|
| 99 |
+
0.0,
|
| 100 |
+
0.0
|
| 101 |
+
],
|
| 102 |
+
"q99": [
|
| 103 |
+
0.9998,
|
| 104 |
+
0.9998
|
| 105 |
+
],
|
| 106 |
+
"q02": [
|
| 107 |
+
0.0,
|
| 108 |
+
0.0
|
| 109 |
+
],
|
| 110 |
+
"q98": [
|
| 111 |
+
0.9998,
|
| 112 |
+
0.9998
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
"observation.state.arm.position": {
|
| 116 |
+
"mean": [
|
| 117 |
+
-0.22502799332141876,
|
| 118 |
+
1.0857956409454346,
|
| 119 |
+
0.8007810711860657,
|
| 120 |
+
-0.3513113558292389,
|
| 121 |
+
0.05622035637497902,
|
| 122 |
+
-0.044659487903118134,
|
| 123 |
+
0.23319771885871887,
|
| 124 |
+
1.106688141822815,
|
| 125 |
+
0.82613205909729,
|
| 126 |
+
-0.3561287522315979,
|
| 127 |
+
-0.010010534897446632,
|
| 128 |
+
0.010481182485818863
|
| 129 |
+
],
|
| 130 |
+
"std": [
|
| 131 |
+
0.3684558570384979,
|
| 132 |
+
0.9951919317245483,
|
| 133 |
+
0.7907320857048035,
|
| 134 |
+
0.6622379422187805,
|
| 135 |
+
0.24897389113903046,
|
| 136 |
+
0.5617504119873047,
|
| 137 |
+
0.32331398129463196,
|
| 138 |
+
1.0208075046539307,
|
| 139 |
+
0.8117841482162476,
|
| 140 |
+
0.718940019607544,
|
| 141 |
+
0.25672635436058044,
|
| 142 |
+
0.6180205345153809
|
| 143 |
+
],
|
| 144 |
+
"q01": [
|
| 145 |
+
-0.9676963651657111,
|
| 146 |
+
-0.0003164021181873977,
|
| 147 |
+
-0.0026667596280574857,
|
| 148 |
+
-1.5938075653076171,
|
| 149 |
+
-0.4462003350734711,
|
| 150 |
+
-2.195949993133545,
|
| 151 |
+
-0.13648582720756508,
|
| 152 |
+
-0.0017502129077910933,
|
| 153 |
+
-0.0023805056512355804,
|
| 154 |
+
-1.703943779706955,
|
| 155 |
+
-1.0257196548461915,
|
| 156 |
+
-1.6663076572418207
|
| 157 |
+
],
|
| 158 |
+
"q99": [
|
| 159 |
+
0.16785249881744324,
|
| 160 |
+
2.5760957974332737,
|
| 161 |
+
2.47087028901875,
|
| 162 |
+
1.2516060169219974,
|
| 163 |
+
1.22238815100193,
|
| 164 |
+
1.4495478111267097,
|
| 165 |
+
1.073332511305809,
|
| 166 |
+
2.602131088757515,
|
| 167 |
+
2.494382914789021,
|
| 168 |
+
1.3104696589708325,
|
| 169 |
+
1.0657097997665406,
|
| 170 |
+
2.102274682235718
|
| 171 |
+
],
|
| 172 |
+
"q02": [
|
| 173 |
+
-0.9234203773498537,
|
| 174 |
+
-0.0003164021181873977,
|
| 175 |
+
-0.0026667596280574857,
|
| 176 |
+
-1.5060962793350219,
|
| 177 |
+
-0.3291901943683624,
|
| 178 |
+
-1.6436018409728996,
|
| 179 |
+
-0.05788616445064587,
|
| 180 |
+
-0.0017502129077910933,
|
| 181 |
+
-0.0023805056512355804,
|
| 182 |
+
-1.6164209202528,
|
| 183 |
+
-0.8698848910331727,
|
| 184 |
+
-1.5490281238555905
|
| 185 |
+
],
|
| 186 |
+
"q98": [
|
| 187 |
+
0.1157631013870235,
|
| 188 |
+
2.4928853822953303,
|
| 189 |
+
2.3174869681090113,
|
| 190 |
+
1.0754401289939883,
|
| 191 |
+
1.0122474901437757,
|
| 192 |
+
1.2031155586242681,
|
| 193 |
+
0.945415413093567,
|
| 194 |
+
2.527320075106621,
|
| 195 |
+
2.3522132336720825,
|
| 196 |
+
1.202445949554443,
|
| 197 |
+
0.694808959197998,
|
| 198 |
+
1.7484814231872559
|
| 199 |
+
]
|
| 200 |
+
},
|
| 201 |
+
"observation.state.effector.position": {
|
| 202 |
+
"mean": [
|
| 203 |
+
0.6735715866088867,
|
| 204 |
+
0.6747165322303772
|
| 205 |
+
],
|
| 206 |
+
"std": [
|
| 207 |
+
0.4521658420562744,
|
| 208 |
+
0.4510030150413513
|
| 209 |
+
],
|
| 210 |
+
"q01": [
|
| 211 |
+
0.0,
|
| 212 |
+
0.0
|
| 213 |
+
],
|
| 214 |
+
"q99": [
|
| 215 |
+
0.9998,
|
| 216 |
+
0.9998
|
| 217 |
+
],
|
| 218 |
+
"q02": [
|
| 219 |
+
0.0,
|
| 220 |
+
0.0
|
| 221 |
+
],
|
| 222 |
+
"q98": [
|
| 223 |
+
0.9998,
|
| 224 |
+
0.9998
|
| 225 |
+
]
|
| 226 |
+
}
|
| 227 |
+
},
|
| 228 |
+
"count": 535680
|
| 229 |
+
}
|
assets/scale_ps.png
ADDED
|
Git LFS Details
|
assets/scale_sr.png
ADDED
|
Git LFS Details
|
configs/norm/robotwin_5.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
model_path: /path/to/LingBot-VLA-Depth
|
| 3 |
+
tokenizer_path: /path/to/Qwen2.5-VL-3B-Instruct/
|
| 4 |
+
|
| 5 |
+
data:
|
| 6 |
+
datasets_type: vla
|
| 7 |
+
train_path: /path/to/mixed_robotwin_5tasks
|
| 8 |
+
norm_path: assets/norm_stats/robotwin_5_custom.json
|
| 9 |
+
|
| 10 |
+
train:
|
| 11 |
+
global_batch_size: 512
|
| 12 |
+
output_dir: output/norm
|
configs/vla/robotwin_load20000h.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
model_path: /path/to/LingBot-VLA
|
| 3 |
+
tokenizer_path: /path/to/Qwen2.5-VL-3B-Instruct/
|
| 4 |
+
post_training: true
|
| 5 |
+
adanorm_time: true
|
| 6 |
+
old_adanorm: true
|
| 7 |
+
|
| 8 |
+
data:
|
| 9 |
+
datasets_type: vla
|
| 10 |
+
data_name: robotwin_5_new
|
| 11 |
+
train_path: /path/to/mixed_robotwin_5tasks
|
| 12 |
+
num_workers: 8
|
| 13 |
+
norm_type: bounds_99_woclip
|
| 14 |
+
norm_stats_file: assets/norm_stats/robotwin_50.json
|
| 15 |
+
|
| 16 |
+
train:
|
| 17 |
+
output_dir: /path/to/lingbot_robotwin5tasks/
|
| 18 |
+
loss_type: L1_fm
|
| 19 |
+
data_parallel_mode: fsdp2
|
| 20 |
+
enable_full_shard: false
|
| 21 |
+
module_fsdp_enable: true
|
| 22 |
+
use_compile: true
|
| 23 |
+
use_wandb: false
|
| 24 |
+
rmpad: false
|
| 25 |
+
rmpad_with_pos_ids: false
|
| 26 |
+
ulysses_parallel_size: 1
|
| 27 |
+
freeze_vision_encoder: false
|
| 28 |
+
tokenizer_max_length: 24
|
| 29 |
+
action_dim: 14
|
| 30 |
+
max_action_dim: 75
|
| 31 |
+
max_state_dim: 75
|
| 32 |
+
lr: 1.0e-4
|
| 33 |
+
lr_decay_style: constant
|
| 34 |
+
num_train_epochs: 69
|
| 35 |
+
micro_batch_size: 32
|
| 36 |
+
global_batch_size: 256
|
| 37 |
+
max_steps: 220000
|
| 38 |
+
ckpt_manager: dcp
|
| 39 |
+
save_steps: 220000
|
| 40 |
+
save_epochs: 69
|
| 41 |
+
enable_fp32: true
|
| 42 |
+
enable_resume: true
|
configs/vla/robotwin_load20000h_depth.yaml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
model_path: /path/to/LingBot-VLA-Depth
|
| 3 |
+
tokenizer_path: /path/to/Qwen2.5-VL-3B-Instruct/
|
| 4 |
+
post_training: true
|
| 5 |
+
adanorm_time: true
|
| 6 |
+
old_adanorm: true
|
| 7 |
+
moge_path: /path/to/moge2-vitb-normal
|
| 8 |
+
morgbd_path: /path/to/LingBot-Depth-Pretrained
|
| 9 |
+
|
| 10 |
+
data:
|
| 11 |
+
datasets_type: vla
|
| 12 |
+
data_name: robotwin_5_new
|
| 13 |
+
train_path: /path/to/mixed_robotwin_5tasks
|
| 14 |
+
num_workers: 8
|
| 15 |
+
norm_type: bounds_99_woclip
|
| 16 |
+
norm_stats_file: assets/norm_stats/robotwin_50.json
|
| 17 |
+
|
| 18 |
+
train:
|
| 19 |
+
output_dir: /path/to/lingbot_depth_robotwin5tasks/
|
| 20 |
+
loss_type: L1_fm
|
| 21 |
+
data_parallel_mode: fsdp2
|
| 22 |
+
enable_full_shard: false
|
| 23 |
+
module_fsdp_enable: true
|
| 24 |
+
use_compile: true
|
| 25 |
+
use_wandb: false
|
| 26 |
+
rmpad: false
|
| 27 |
+
rmpad_with_pos_ids: false
|
| 28 |
+
ulysses_parallel_size: 1
|
| 29 |
+
freeze_vision_encoder: false
|
| 30 |
+
tokenizer_max_length: 24
|
| 31 |
+
action_dim: 14
|
| 32 |
+
max_action_dim: 75
|
| 33 |
+
max_state_dim: 75
|
| 34 |
+
lr: 1.0e-4
|
| 35 |
+
lr_decay_style: constant
|
| 36 |
+
num_train_epochs: 69
|
| 37 |
+
micro_batch_size: 32
|
| 38 |
+
global_batch_size: 256
|
| 39 |
+
max_steps: 220000
|
| 40 |
+
ckpt_manager: dcp
|
| 41 |
+
save_steps: 220000
|
| 42 |
+
save_epochs: 69
|
| 43 |
+
enable_fp32: true
|
| 44 |
+
enable_resume: true
|
| 45 |
+
align_params:
|
| 46 |
+
mode: 'query'
|
| 47 |
+
num_task_tokens: 8
|
| 48 |
+
use_image_tokens: True
|
| 49 |
+
use_task_tokens: False
|
| 50 |
+
use_text_tokens: False
|
| 51 |
+
use_contrastive: True
|
| 52 |
+
contrastive_loss_weight: 0.3
|
| 53 |
+
depth_loss_weight: 0.004
|
| 54 |
+
llm:
|
| 55 |
+
dim_out: 2048
|
| 56 |
+
image_token_size: 8
|
| 57 |
+
image_input_size: 224
|
| 58 |
+
depth:
|
| 59 |
+
model_type: MoRGBD
|
| 60 |
+
num_layers: 1
|
| 61 |
+
num_heads: 4
|
| 62 |
+
dim_head: 32
|
| 63 |
+
ff_mult: 1
|
| 64 |
+
num_backbone_tokens: 256
|
| 65 |
+
token_size: 16
|
| 66 |
+
dim_out: 1024
|
| 67 |
+
input_size: 224
|
| 68 |
+
visual_steps: 10000
|
deploy/__init__.py
ADDED
|
File without changes
|
deploy/image_tools.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def convert_to_uint8(img: np.ndarray) -> np.ndarray:
|
| 6 |
+
"""Converts an image to uint8 if it is a float image.
|
| 7 |
+
|
| 8 |
+
This is important for reducing the size of the image when sending it over the network.
|
| 9 |
+
"""
|
| 10 |
+
if np.issubdtype(img.dtype, np.floating):
|
| 11 |
+
img = (255 * img).astype(np.uint8)
|
| 12 |
+
return img
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
|
| 16 |
+
"""Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
images: A batch of images in [..., height, width, channel] format.
|
| 20 |
+
height: The target height of the image.
|
| 21 |
+
width: The target width of the image.
|
| 22 |
+
method: The interpolation method to use. Default is bilinear.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
The resized images in [..., height, width, channel].
|
| 26 |
+
"""
|
| 27 |
+
# If the images are already the correct size, return them as is.
|
| 28 |
+
if images.shape[-3:-1] == (height, width):
|
| 29 |
+
return images
|
| 30 |
+
|
| 31 |
+
original_shape = images.shape
|
| 32 |
+
|
| 33 |
+
images = images.reshape(-1, *original_shape[-3:])
|
| 34 |
+
resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
|
| 35 |
+
return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
|
| 39 |
+
"""Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
|
| 40 |
+
width without distortion by padding with zeros.
|
| 41 |
+
|
| 42 |
+
Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
|
| 43 |
+
"""
|
| 44 |
+
cur_width, cur_height = image.size
|
| 45 |
+
if cur_width == width and cur_height == height:
|
| 46 |
+
return image # No need to resize if the image is already the correct size.
|
| 47 |
+
|
| 48 |
+
ratio = max(cur_width / width, cur_height / height)
|
| 49 |
+
resized_height = int(cur_height / ratio)
|
| 50 |
+
resized_width = int(cur_width / ratio)
|
| 51 |
+
resized_image = image.resize((resized_width, resized_height), resample=method)
|
| 52 |
+
|
| 53 |
+
zero_image = Image.new(resized_image.mode, (width, height), 0)
|
| 54 |
+
pad_height = max(0, int((height - resized_height) / 2))
|
| 55 |
+
pad_width = max(0, int((width - resized_width) / 2))
|
| 56 |
+
zero_image.paste(resized_image, (pad_width, pad_height))
|
| 57 |
+
assert zero_image.size == (width, height)
|
| 58 |
+
return zero_image
|
deploy/lingbot_robotwin_policy.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
import random
|
| 5 |
+
import numpy as np
|
| 6 |
+
from collections import deque
|
| 7 |
+
import torchvision
|
| 8 |
+
import yaml
|
| 9 |
+
from types import SimpleNamespace
|
| 10 |
+
from packaging.version import Version
|
| 11 |
+
from typing import Callable, Dict, List, Optional, Type, Union, Tuple, Any, Sequence
|
| 12 |
+
from glob import glob
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from safetensors import safe_open
|
| 15 |
+
from safetensors.torch import load_file
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from PIL import Image
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from torch import Tensor, nn
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
import transformers
|
| 24 |
+
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
| 25 |
+
from transformers import (
|
| 26 |
+
AutoConfig,
|
| 27 |
+
PretrainedConfig,
|
| 28 |
+
PreTrainedModel,
|
| 29 |
+
AutoProcessor,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
from lerobot.configs.policies import PreTrainedConfig
|
| 33 |
+
from lingbotvla.models.vla.pi0.modeling_pi0 import PI0Policy
|
| 34 |
+
from lingbotvla.models.vla.pi0.modeling_lingbot_vla import LingbotVlaPolicy
|
| 35 |
+
from lingbotvla.data.vla_data.transform import Normalizer, prepare_images, prepare_language, prepare_state
|
| 36 |
+
from lingbotvla.models import build_processor
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def set_seed_everywhere(seed: int):
|
| 40 |
+
"""Sets the random seed for Python, NumPy, and PyTorch functions."""
|
| 41 |
+
torch.manual_seed(seed)
|
| 42 |
+
torch.cuda.manual_seed_all(seed)
|
| 43 |
+
np.random.seed(seed)
|
| 44 |
+
random.seed(seed)
|
| 45 |
+
torch.backends.cudnn.deterministic = True
|
| 46 |
+
torch.backends.cudnn.benchmark = False
|
| 47 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
| 48 |
+
|
| 49 |
+
set_seed_everywhere(42)
|
| 50 |
+
|
| 51 |
+
BASE_MODEL_PATH = {
|
| 52 |
+
'pi0': os.environ.get('PALIGEMMA_PATH', './paligemma-3b-pt-224/'),
|
| 53 |
+
'lingbotvla': os.environ.get('QWEN25_PATH', './Qwen2.5-VL-3B-Instruct/'),
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
def load_model_weights(policy, path_to_pi_model, strict=True):
|
| 57 |
+
all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
|
| 58 |
+
merged_weights = {}
|
| 59 |
+
|
| 60 |
+
for file_path in tqdm(all_safetensors):
|
| 61 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
| 62 |
+
for key in f.keys():
|
| 63 |
+
merged_weights[key] = f.get_tensor(key)
|
| 64 |
+
policy.load_state_dict(merged_weights, strict=strict)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
|
| 68 |
+
crop_scale = 0.9
|
| 69 |
+
side_scale = float(np.sqrt(np.clip(crop_scale, 0.0, 1.0))) # side length scale
|
| 70 |
+
out_size = (224, 224)
|
| 71 |
+
|
| 72 |
+
# Convert input to PIL Image
|
| 73 |
+
if isinstance(image, np.ndarray):
|
| 74 |
+
arr = image
|
| 75 |
+
if arr.dtype.kind == "f":
|
| 76 |
+
# If floats likely in [0,1], map to [0,255]
|
| 77 |
+
if arr.max() <= 1.0 and arr.min() >= 0.0:
|
| 78 |
+
arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)
|
| 79 |
+
else:
|
| 80 |
+
arr = np.clip(arr, 0.0, 255.0).astype(np.uint8)
|
| 81 |
+
elif arr.dtype == np.uint16:
|
| 82 |
+
# Map 16-bit to 8-bit
|
| 83 |
+
arr = (arr / 257).astype(np.uint8)
|
| 84 |
+
elif arr.dtype != np.uint8:
|
| 85 |
+
arr = arr.astype(np.uint8)
|
| 86 |
+
pil = Image.fromarray(arr)
|
| 87 |
+
elif isinstance(image, Image.Image):
|
| 88 |
+
pil = image
|
| 89 |
+
else:
|
| 90 |
+
raise TypeError("image must be a numpy array or PIL.Image.Image")
|
| 91 |
+
|
| 92 |
+
# Force RGB for consistent output
|
| 93 |
+
pil = pil.convert("RGB")
|
| 94 |
+
W, H = pil.size
|
| 95 |
+
|
| 96 |
+
# Compute centered crop box (integer pixels)
|
| 97 |
+
crop_w = max(1, int(round(W * side_scale)))
|
| 98 |
+
crop_h = max(1, int(round(H * side_scale)))
|
| 99 |
+
left = (W - crop_w) // 2
|
| 100 |
+
top = (H - crop_h) // 2
|
| 101 |
+
right = left + crop_w
|
| 102 |
+
bottom = top + crop_h
|
| 103 |
+
|
| 104 |
+
cropped = pil.crop((left, top, right, bottom))
|
| 105 |
+
resized = cropped.resize(out_size, resample=Image.BILINEAR)
|
| 106 |
+
return resized
|
| 107 |
+
|
| 108 |
+
def resize_with_pad(img, width, height, pad_value=-1):
|
| 109 |
+
# assume no-op when width height fits already
|
| 110 |
+
if img.ndim != 4:
|
| 111 |
+
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
| 112 |
+
|
| 113 |
+
# channel last to channel first if necessary
|
| 114 |
+
if img.shape[1] not in (1, 3) and img.shape[-1] in (1, 3):
|
| 115 |
+
img = img.permute(0, 3, 1, 2)
|
| 116 |
+
|
| 117 |
+
cur_height, cur_width = img.shape[2:]
|
| 118 |
+
|
| 119 |
+
ratio = max(cur_width / width, cur_height / height)
|
| 120 |
+
resized_height = int(cur_height / ratio)
|
| 121 |
+
resized_width = int(cur_width / ratio)
|
| 122 |
+
resized_img = F.interpolate(
|
| 123 |
+
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
pad_height = max(0, int(height - resized_height))
|
| 127 |
+
pad_width = max(0, int(width - resized_width))
|
| 128 |
+
|
| 129 |
+
# pad on left and top of image
|
| 130 |
+
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
| 131 |
+
return padded_img
|
| 132 |
+
|
| 133 |
+
class PolicyPreprocessMixin:
|
| 134 |
+
|
| 135 |
+
@torch.no_grad
|
| 136 |
+
def select_action(
|
| 137 |
+
self, observation: dict[str, Tensor], use_bf16: bool = False, vlm_causal: bool = False, noise: Tensor | None = None
|
| 138 |
+
):
|
| 139 |
+
self.eval()
|
| 140 |
+
device = 'cuda'
|
| 141 |
+
if use_bf16:
|
| 142 |
+
dtype = torch.bfloat16
|
| 143 |
+
else:
|
| 144 |
+
dtype = torch.float32
|
| 145 |
+
s1 = time.time()
|
| 146 |
+
|
| 147 |
+
if len(observation['images'].shape) == 4:
|
| 148 |
+
observation['images'] = observation['images'].unsqueeze(0)
|
| 149 |
+
observation['img_masks'] = observation['img_masks'].unsqueeze(0)
|
| 150 |
+
|
| 151 |
+
if 'expert_imgs' in observation:
|
| 152 |
+
actions = self.model.sample_actions(
|
| 153 |
+
observation['images'].to(dtype=dtype, device=device),
|
| 154 |
+
observation['img_masks'].to(device=device),
|
| 155 |
+
observation['lang_tokens'].unsqueeze(0).to(device=device),
|
| 156 |
+
observation['lang_masks'].unsqueeze(0).to(device=device),
|
| 157 |
+
observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
|
| 158 |
+
observation['expert_imgs'].to(dtype=dtype, device=device),
|
| 159 |
+
vlm_causal = vlm_causal
|
| 160 |
+
)
|
| 161 |
+
else:
|
| 162 |
+
actions = self.model.sample_actions(
|
| 163 |
+
observation['images'].to(dtype=dtype, device=device),
|
| 164 |
+
observation['img_masks'].to(device=device),
|
| 165 |
+
observation['lang_tokens'].unsqueeze(0).to(device=device),
|
| 166 |
+
observation['lang_masks'].unsqueeze(0).to(device=device),
|
| 167 |
+
observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
|
| 168 |
+
vlm_causal = vlm_causal
|
| 169 |
+
)
|
| 170 |
+
delta_time = time.time() - s1
|
| 171 |
+
print(f'sample_actions cost {delta_time} s')
|
| 172 |
+
observation['action'] = actions.squeeze(0)[:, :14].to(dtype=torch.float32, device='cpu')
|
| 173 |
+
if use_bf16:
|
| 174 |
+
observation['state'] = observation['state'].to(dtype=torch.float32)
|
| 175 |
+
data = self.normalizer.unnormalize(observation)
|
| 176 |
+
return data
|
| 177 |
+
|
| 178 |
+
class LingBotVlaInferencePolicy(PolicyPreprocessMixin, LingbotVlaPolicy):
|
| 179 |
+
pass # Only combine necessary functions
|
| 180 |
+
|
| 181 |
+
class PI0InfernecePolicy(PolicyPreprocessMixin, PI0Policy):
|
| 182 |
+
pass # Only combine necessary functions
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def merge_qwen_config(policy_config, qwen_config):
|
| 186 |
+
if hasattr(qwen_config, 'to_dict'):
|
| 187 |
+
config_dict = qwen_config.to_dict()
|
| 188 |
+
else:
|
| 189 |
+
config_dict = qwen_config
|
| 190 |
+
|
| 191 |
+
text_keys = {
|
| 192 |
+
"hidden_size",
|
| 193 |
+
"intermediate_size",
|
| 194 |
+
"num_hidden_layers",
|
| 195 |
+
"num_attention_heads",
|
| 196 |
+
"num_key_value_heads",
|
| 197 |
+
"rms_norm_eps",
|
| 198 |
+
"rope_theta",
|
| 199 |
+
"vocab_size",
|
| 200 |
+
"max_position_embeddings",
|
| 201 |
+
"hidden_act",
|
| 202 |
+
"tie_word_embeddings",
|
| 203 |
+
"tokenizer_path",
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
for key in text_keys:
|
| 207 |
+
if key in config_dict:
|
| 208 |
+
setattr(policy_config, key, config_dict[key])
|
| 209 |
+
print(f"✅ Merged LLM: {key} = {config_dict[key]}")
|
| 210 |
+
|
| 211 |
+
if "vision_config" in config_dict:
|
| 212 |
+
policy_config.vision_config = qwen_config.vision_config
|
| 213 |
+
else:
|
| 214 |
+
print("⚠️ Warning: 'vision_config' not found in qwen_config!")
|
| 215 |
+
|
| 216 |
+
return policy_config
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class QwenPiServer:
|
| 220 |
+
'''
|
| 221 |
+
policy wrapper to support action ensemble or chunk execution
|
| 222 |
+
'''
|
| 223 |
+
def __init__(
|
| 224 |
+
self,
|
| 225 |
+
path_to_pi_model="",
|
| 226 |
+
adaptive_ensemble_alpha=0.1,
|
| 227 |
+
action_ensemble_horizon=8,
|
| 228 |
+
use_length=1, # to control the execution length of the action chunk, -1 denotes using action ensemble
|
| 229 |
+
chunk_ret=False,
|
| 230 |
+
use_bf16=True,
|
| 231 |
+
use_fp32=False,
|
| 232 |
+
) -> None:
|
| 233 |
+
assert not (use_bf16 and use_fp32), 'Bfloat16 or Float32!!!'
|
| 234 |
+
self.adaptive_ensemble_alpha = adaptive_ensemble_alpha
|
| 235 |
+
self.use_length = use_length
|
| 236 |
+
self.chunk_ret = chunk_ret
|
| 237 |
+
|
| 238 |
+
self.task_description = None
|
| 239 |
+
|
| 240 |
+
self.vla = self.load_vla(path_to_pi_model)
|
| 241 |
+
self.vla = self.vla.cuda().eval()
|
| 242 |
+
if use_bf16:
|
| 243 |
+
self.vla = self.vla.to(torch.bfloat16)
|
| 244 |
+
elif use_fp32:
|
| 245 |
+
self.vla.model.float()
|
| 246 |
+
self.global_step = 0
|
| 247 |
+
self.last_action_chunk = None
|
| 248 |
+
self.use_bf16 = use_bf16
|
| 249 |
+
self.use_fp32 = use_fp32
|
| 250 |
+
|
| 251 |
+
def load_vla(self, path_to_pi_model) -> LingbotVlaPolicy:
|
| 252 |
+
# load model
|
| 253 |
+
|
| 254 |
+
print(f"loading model from: {path_to_pi_model}")
|
| 255 |
+
config = PreTrainedConfig.from_pretrained(path_to_pi_model)
|
| 256 |
+
|
| 257 |
+
# load training config
|
| 258 |
+
training_config_path = Path(path_to_pi_model).parent.parent.parent/'lingbotvla_cli.yaml'
|
| 259 |
+
with open(training_config_path, 'r') as f:
|
| 260 |
+
training_config = yaml.safe_load(f)
|
| 261 |
+
f.close()
|
| 262 |
+
|
| 263 |
+
# update model config according to training config
|
| 264 |
+
training_model_config = training_config['model']
|
| 265 |
+
training_model_config.update(training_config['train'])
|
| 266 |
+
for k, v in training_model_config.items():
|
| 267 |
+
v = getattr(config, k, training_model_config[k])
|
| 268 |
+
setattr(config, k, v)
|
| 269 |
+
|
| 270 |
+
# Set attention_implementation to 'eager' to speed up evaluation.
|
| 271 |
+
config.attention_implementation = 'eager'
|
| 272 |
+
|
| 273 |
+
# set base model according to training config
|
| 274 |
+
training_base_model = training_config['model']['tokenizer_path']
|
| 275 |
+
if 'paligemma' in training_base_model:
|
| 276 |
+
model_name = 'pi0'
|
| 277 |
+
config.vocab_size = 257152 # set vocab size for paligamma
|
| 278 |
+
elif 'qwen2' in training_base_model.lower():
|
| 279 |
+
model_name = 'lingbotvla'
|
| 280 |
+
else:
|
| 281 |
+
raise ValueError(f"Unsupported base model of {path_to_pi_model}")
|
| 282 |
+
base_model_path = BASE_MODEL_PATH[model_name]
|
| 283 |
+
config.tokenizer_path = base_model_path
|
| 284 |
+
self.model_name = model_name
|
| 285 |
+
|
| 286 |
+
qwen_config = AutoConfig.from_pretrained(base_model_path)
|
| 287 |
+
config = merge_qwen_config(config, qwen_config)
|
| 288 |
+
|
| 289 |
+
if 'vocab_size' in training_config['model'] and training_config['model']['vocab_size'] != 0:
|
| 290 |
+
config.vocab_size = training_config['model']['vocab_size']
|
| 291 |
+
# load processors
|
| 292 |
+
self.processor = build_processor(base_model_path)
|
| 293 |
+
self.language_tokenizer = self.processor.tokenizer
|
| 294 |
+
self.image_processor = self.processor.image_processor
|
| 295 |
+
data_config = SimpleNamespace(**training_config['data'])
|
| 296 |
+
|
| 297 |
+
print('Initializing model ... ')
|
| 298 |
+
|
| 299 |
+
if 'paligemma' in training_base_model:
|
| 300 |
+
policy = PI0InfernecePolicy(config, tokenizer_path=base_model_path)
|
| 301 |
+
else:
|
| 302 |
+
policy = LingBotVlaInferencePolicy(config, tokenizer_path=base_model_path)
|
| 303 |
+
|
| 304 |
+
load_model_weights(policy, path_to_pi_model, strict=True)
|
| 305 |
+
|
| 306 |
+
policy.feature_transform = None
|
| 307 |
+
self.data_config = data_config
|
| 308 |
+
self.config = config
|
| 309 |
+
self.joint_max_dim = training_config['train']['max_action_dim']
|
| 310 |
+
self.action_dim = training_config['train']['action_dim']
|
| 311 |
+
self.chunk_size = training_config['train']['chunk_size']
|
| 312 |
+
policy.action_dim = self.action_dim
|
| 313 |
+
policy.chunk_size = self.chunk_size
|
| 314 |
+
self.norm_stats_file = data_config.norm_stats_file
|
| 315 |
+
if 'align_params' in training_config['train']:
|
| 316 |
+
self.use_depth_align = True
|
| 317 |
+
else: self.use_depth_align = False
|
| 318 |
+
with open(self.norm_stats_file) as f:
|
| 319 |
+
self.norm_stats = json.load(f)
|
| 320 |
+
policy.normalizer = Normalizer(
|
| 321 |
+
norm_stats=self.norm_stats['norm_stats'],
|
| 322 |
+
from_file=True,
|
| 323 |
+
data_type='robotwin',
|
| 324 |
+
norm_type={
|
| 325 |
+
"observation.images.cam_high": "identity",
|
| 326 |
+
"observation.images.cam_left_wrist": "identity",
|
| 327 |
+
"observation.images.cam_right_wrist": "identity",
|
| 328 |
+
"observation.state": self.data_config.norm_type,
|
| 329 |
+
"action": self.data_config.norm_type,
|
| 330 |
+
},
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
print('Model initialized ... ')
|
| 334 |
+
|
| 335 |
+
return policy
|
| 336 |
+
|
| 337 |
+
def reset(self, robo_name, path_to_pi_model = None) -> None:
|
| 338 |
+
|
| 339 |
+
if path_to_pi_model is not None:
|
| 340 |
+
self.vla = self.load_vla(path_to_pi_model)
|
| 341 |
+
self.vla = self.vla.cuda().eval()
|
| 342 |
+
if self.use_bf16:
|
| 343 |
+
self.vla = self.vla.to(torch.bfloat16)
|
| 344 |
+
elif self.use_fp32:
|
| 345 |
+
self.vla.model.float()
|
| 346 |
+
|
| 347 |
+
self.global_step = 0
|
| 348 |
+
self.last_action_chunk = None
|
| 349 |
+
|
| 350 |
+
if getattr(self.data_config, 'norm_type', None) is None:
|
| 351 |
+
self.data_config.norm_type = 'meanstd'
|
| 352 |
+
if getattr(self.config, 'vlm_causal', None) is None:
|
| 353 |
+
self.config.vlm_causal = False
|
| 354 |
+
if getattr(self.config, 'qwenvl_bos', None) is None:
|
| 355 |
+
self.config.qwenvl_bos = False
|
| 356 |
+
|
| 357 |
+
# if update ckpt path
|
| 358 |
+
if path_to_pi_model is not None:
|
| 359 |
+
all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
|
| 360 |
+
merged_weights = {}
|
| 361 |
+
|
| 362 |
+
for file_path in tqdm(all_safetensors):
|
| 363 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
| 364 |
+
for key in f.keys():
|
| 365 |
+
merged_weights[key] = f.get_tensor(key)
|
| 366 |
+
|
| 367 |
+
self.vla.load_state_dict(merged_weights, strict=True)
|
| 368 |
+
|
| 369 |
+
def resize_image(self, observation):
|
| 370 |
+
for image_feature in ['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist']:
|
| 371 |
+
assert image_feature in observation
|
| 372 |
+
assert len(observation[image_feature].shape)==3 and observation[image_feature].shape[-1] == 3
|
| 373 |
+
image = observation[image_feature]
|
| 374 |
+
img_pil = Image.fromarray(image)
|
| 375 |
+
image_size = getattr(self.data_config, 'img_size', 224)
|
| 376 |
+
img_pil = img_pil.resize((image_size, image_size), Image.BILINEAR)
|
| 377 |
+
|
| 378 |
+
# img_resized shape: C*H*W
|
| 379 |
+
img_resized = np.transpose(np.array(img_pil), (2,0,1)) # (3,224,224)
|
| 380 |
+
observation[image_feature] = img_resized / 255.
|
| 381 |
+
|
| 382 |
+
def infer(self, observation, center_crop=True):
|
| 383 |
+
"""Generates an action with the VLA policy."""
|
| 384 |
+
|
| 385 |
+
# (If trained with image augmentations) Center crop image and then resize back up to original size.
|
| 386 |
+
# IMPORTANT: Let's say crop scale == 0.9. To get the new height and width (post-crop), multiply
|
| 387 |
+
# the original height and width by sqrt(0.9) -- not 0.9!
|
| 388 |
+
if 'reset' in observation and observation['reset']:
|
| 389 |
+
self.reset(robo_name=observation['robo_name'], path_to_pi_model=observation['path_to_pi_model'] if 'path_to_pi_model' in observation else None)
|
| 390 |
+
return dict(action = None)
|
| 391 |
+
|
| 392 |
+
self.resize_image(observation)
|
| 393 |
+
for k, v in observation.items():
|
| 394 |
+
if isinstance(v, np.ndarray):
|
| 395 |
+
observation[k] = torch.from_numpy(v)
|
| 396 |
+
|
| 397 |
+
if self.use_length == -1 or self.global_step % self.use_length == 0:
|
| 398 |
+
joint_max_dim = getattr(self, 'joint_max_dim')
|
| 399 |
+
action_dim = getattr(self, 'action_dim')
|
| 400 |
+
chunk_size = getattr(self, 'chunk_size')
|
| 401 |
+
normalized_observation = self.vla.normalizer.normalize(observation)
|
| 402 |
+
base_image = (normalized_observation["observation.images.cam_high"] * 255).to(torch.uint8)
|
| 403 |
+
left_wrist_image = (normalized_observation["observation.images.cam_left_wrist"] * 255).to(
|
| 404 |
+
torch.uint8
|
| 405 |
+
)
|
| 406 |
+
right_wrist_image = (normalized_observation["observation.images.cam_right_wrist"] * 255).to(
|
| 407 |
+
torch.uint8
|
| 408 |
+
)
|
| 409 |
+
obs_dict = {
|
| 410 |
+
"image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
|
| 411 |
+
"state": normalized_observation["observation.state"].to(torch.float32),
|
| 412 |
+
"prompt": [observation["task"]],
|
| 413 |
+
}
|
| 414 |
+
state = prepare_state(self.config, obs_dict)
|
| 415 |
+
lang_tokens, lang_masks = prepare_language(self.config, self.language_tokenizer, obs_dict)
|
| 416 |
+
images, img_masks, _ = prepare_images(self.config, self.image_processor, obs_dict)
|
| 417 |
+
observation = {
|
| 418 |
+
'images': images,
|
| 419 |
+
'img_masks': img_masks,
|
| 420 |
+
'state': state,
|
| 421 |
+
'lang_tokens': lang_tokens,
|
| 422 |
+
'lang_masks': lang_masks,
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
if self.use_bf16:
|
| 426 |
+
observation['state'] = observation['state'].to(torch.bfloat16)
|
| 427 |
+
|
| 428 |
+
org_actions = ['action']
|
| 429 |
+
assert len(org_actions)==1, "Only support single action feature"
|
| 430 |
+
if self.chunk_ret:
|
| 431 |
+
action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]].float().cpu().numpy()
|
| 432 |
+
action = action[:self.use_length, :self.action_dim]
|
| 433 |
+
else:
|
| 434 |
+
if self.use_length == -1 or self.global_step % self.use_length == 0:
|
| 435 |
+
action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]]
|
| 436 |
+
self.last_action_chunk = action.float().cpu().numpy()
|
| 437 |
+
|
| 438 |
+
if self.use_length > 0:
|
| 439 |
+
action = self.last_action_chunk[self.global_step % self.use_length]
|
| 440 |
+
action = action[:, :self.action_dim]
|
| 441 |
+
print(f"on server step: {self.global_step}")
|
| 442 |
+
self.global_step+=1
|
| 443 |
+
|
| 444 |
+
return dict(action = action)
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
import argparse
|
| 448 |
+
from .websocket_policy_server import WebsocketPolicyServer
|
| 449 |
+
|
| 450 |
+
def main():
|
| 451 |
+
parser = argparse.ArgumentParser(description="启动 QwenPi WebSocket 策略服务器")
|
| 452 |
+
|
| 453 |
+
parser.add_argument(
|
| 454 |
+
"--model_path",
|
| 455 |
+
type=str,
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
parser.add_argument(
|
| 459 |
+
"--use_length",
|
| 460 |
+
type=int,
|
| 461 |
+
default=50,
|
| 462 |
+
help="used length of action chunk"
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
parser.add_argument(
|
| 466 |
+
"--chunk_ret",
|
| 467 |
+
type=bool,
|
| 468 |
+
default=True,
|
| 469 |
+
help=" True: The returned action tensor includes the horizon dimension. This allows the model to output a sequence of actions for each horizon step. False: The horizon dimension is omitted. The model selects and returns the next step autonomously based on its policy."
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
parser.add_argument(
|
| 473 |
+
"--port",
|
| 474 |
+
type=int,
|
| 475 |
+
default=8006,
|
| 476 |
+
help="port of WebSocket"
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
parser.add_argument(
|
| 480 |
+
"--debug_infer_once",
|
| 481 |
+
action="store_true",
|
| 482 |
+
help="Run one infer with dummy observation then exit (for debugging infer() without WebSocket client)",
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
args = parser.parse_args()
|
| 486 |
+
|
| 487 |
+
model = QwenPiServer(args.model_path, use_length=args.use_length, chunk_ret=args.chunk_ret)
|
| 488 |
+
if args.debug_infer_once:
|
| 489 |
+
# 调试用:不启动 WebSocket,只跑一次 infer,可在 infer / select_action 里下断点
|
| 490 |
+
dummy_obs = {
|
| 491 |
+
"observation.images.cam_high": np.zeros((224, 224, 3), dtype=np.uint8),
|
| 492 |
+
"observation.images.cam_left_wrist": np.zeros((224, 224, 3), dtype=np.uint8),
|
| 493 |
+
"observation.images.cam_right_wrist": np.zeros((224, 224, 3), dtype=np.uint8),
|
| 494 |
+
"observation.state": np.zeros(model.action_dim, dtype=np.float32),
|
| 495 |
+
"task": "dummy task for debug",
|
| 496 |
+
"reset": False,
|
| 497 |
+
}
|
| 498 |
+
out = model.infer(dummy_obs)
|
| 499 |
+
print("debug_infer_once result keys:", out.keys())
|
| 500 |
+
return
|
| 501 |
+
model_server = WebsocketPolicyServer(model, port=args.port)
|
| 502 |
+
model_server.serve_forever()
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
if __name__ == "__main__":
|
| 506 |
+
main()
|
deploy/lingbot_robotwin_policy_rep.py
ADDED
|
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import time
|
| 4 |
+
import random
|
| 5 |
+
import numpy as np
|
| 6 |
+
from collections import deque
|
| 7 |
+
import torchvision
|
| 8 |
+
import yaml
|
| 9 |
+
from types import SimpleNamespace
|
| 10 |
+
from packaging.version import Version
|
| 11 |
+
from typing import Callable, Dict, List, Optional, Type, Union, Tuple, Any, Sequence
|
| 12 |
+
from glob import glob
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
from safetensors import safe_open
|
| 15 |
+
from safetensors.torch import load_file
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from PIL import Image
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from torch import Tensor, nn
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
import transformers
|
| 24 |
+
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
| 25 |
+
from transformers import (
|
| 26 |
+
AutoConfig,
|
| 27 |
+
PretrainedConfig,
|
| 28 |
+
PreTrainedModel,
|
| 29 |
+
AutoProcessor,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
from lerobot.configs.policies import PreTrainedConfig
|
| 33 |
+
from lingbotvla.models.vla.pi0.modeling_pi0 import PI0Policy
|
| 34 |
+
from lingbotvla.models.vla.pi0.modeling_lingbot_vla import LingbotVlaPolicy
|
| 35 |
+
from lingbotvla.data.vla_data.transform import Normalizer, prepare_images, prepare_language, prepare_state
|
| 36 |
+
from lingbotvla.models import build_processor
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def set_seed_everywhere(seed: int):
|
| 40 |
+
"""Sets the random seed for Python, NumPy, and PyTorch functions."""
|
| 41 |
+
torch.manual_seed(seed)
|
| 42 |
+
torch.cuda.manual_seed_all(seed)
|
| 43 |
+
np.random.seed(seed)
|
| 44 |
+
random.seed(seed)
|
| 45 |
+
torch.backends.cudnn.deterministic = True
|
| 46 |
+
torch.backends.cudnn.benchmark = False
|
| 47 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
| 48 |
+
|
| 49 |
+
set_seed_everywhere(42)
|
| 50 |
+
|
| 51 |
+
BASE_MODEL_PATH = {
|
| 52 |
+
'pi0': os.environ.get('PALIGEMMA_PATH', './paligemma-3b-pt-224/'),
|
| 53 |
+
'lingbotvla': os.environ.get('QWEN25_PATH', './Qwen2.5-VL-3B-Instruct/'),
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
def load_model_weights(policy, path_to_pi_model, strict=True):
|
| 57 |
+
all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
|
| 58 |
+
merged_weights = {}
|
| 59 |
+
|
| 60 |
+
for file_path in tqdm(all_safetensors):
|
| 61 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
| 62 |
+
for key in f.keys():
|
| 63 |
+
merged_weights[key] = f.get_tensor(key)
|
| 64 |
+
policy.load_state_dict(merged_weights, strict=strict)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def center_crop_image(image: Union[np.ndarray, Image.Image]) -> Image.Image:
|
| 68 |
+
crop_scale = 0.9
|
| 69 |
+
side_scale = float(np.sqrt(np.clip(crop_scale, 0.0, 1.0))) # side length scale
|
| 70 |
+
out_size = (224, 224)
|
| 71 |
+
|
| 72 |
+
# Convert input to PIL Image
|
| 73 |
+
if isinstance(image, np.ndarray):
|
| 74 |
+
arr = image
|
| 75 |
+
if arr.dtype.kind == "f":
|
| 76 |
+
# If floats likely in [0,1], map to [0,255]
|
| 77 |
+
if arr.max() <= 1.0 and arr.min() >= 0.0:
|
| 78 |
+
arr = (np.clip(arr, 0.0, 1.0) * 255.0).astype(np.uint8)
|
| 79 |
+
else:
|
| 80 |
+
arr = np.clip(arr, 0.0, 255.0).astype(np.uint8)
|
| 81 |
+
elif arr.dtype == np.uint16:
|
| 82 |
+
# Map 16-bit to 8-bit
|
| 83 |
+
arr = (arr / 257).astype(np.uint8)
|
| 84 |
+
elif arr.dtype != np.uint8:
|
| 85 |
+
arr = arr.astype(np.uint8)
|
| 86 |
+
pil = Image.fromarray(arr)
|
| 87 |
+
elif isinstance(image, Image.Image):
|
| 88 |
+
pil = image
|
| 89 |
+
else:
|
| 90 |
+
raise TypeError("image must be a numpy array or PIL.Image.Image")
|
| 91 |
+
|
| 92 |
+
# Force RGB for consistent output
|
| 93 |
+
pil = pil.convert("RGB")
|
| 94 |
+
W, H = pil.size
|
| 95 |
+
|
| 96 |
+
# Compute centered crop box (integer pixels)
|
| 97 |
+
crop_w = max(1, int(round(W * side_scale)))
|
| 98 |
+
crop_h = max(1, int(round(H * side_scale)))
|
| 99 |
+
left = (W - crop_w) // 2
|
| 100 |
+
top = (H - crop_h) // 2
|
| 101 |
+
right = left + crop_w
|
| 102 |
+
bottom = top + crop_h
|
| 103 |
+
|
| 104 |
+
cropped = pil.crop((left, top, right, bottom))
|
| 105 |
+
resized = cropped.resize(out_size, resample=Image.BILINEAR)
|
| 106 |
+
return resized
|
| 107 |
+
|
| 108 |
+
def resize_with_pad(img, width, height, pad_value=-1):
|
| 109 |
+
# assume no-op when width height fits already
|
| 110 |
+
if img.ndim != 4:
|
| 111 |
+
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
| 112 |
+
|
| 113 |
+
# channel last to channel first if necessary
|
| 114 |
+
if img.shape[1] not in (1, 3) and img.shape[-1] in (1, 3):
|
| 115 |
+
img = img.permute(0, 3, 1, 2)
|
| 116 |
+
|
| 117 |
+
cur_height, cur_width = img.shape[2:]
|
| 118 |
+
|
| 119 |
+
ratio = max(cur_width / width, cur_height / height)
|
| 120 |
+
resized_height = int(cur_height / ratio)
|
| 121 |
+
resized_width = int(cur_width / ratio)
|
| 122 |
+
resized_img = F.interpolate(
|
| 123 |
+
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
pad_height = max(0, int(height - resized_height))
|
| 127 |
+
pad_width = max(0, int(width - resized_width))
|
| 128 |
+
|
| 129 |
+
# pad on left and top of image
|
| 130 |
+
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
| 131 |
+
return padded_img
|
| 132 |
+
|
| 133 |
+
class PolicyPreprocessMixin:
|
| 134 |
+
|
| 135 |
+
@torch.no_grad
|
| 136 |
+
def select_action(
|
| 137 |
+
self, observation: dict[str, Tensor], use_bf16: bool = False, vlm_causal: bool = False, noise: Tensor | None = None
|
| 138 |
+
):
|
| 139 |
+
self.eval()
|
| 140 |
+
device = 'cuda'
|
| 141 |
+
if use_bf16:
|
| 142 |
+
dtype = torch.bfloat16
|
| 143 |
+
else:
|
| 144 |
+
dtype = torch.float32
|
| 145 |
+
s1 = time.time()
|
| 146 |
+
|
| 147 |
+
if len(observation['images'].shape) == 4:
|
| 148 |
+
observation['images'] = observation['images'].unsqueeze(0)
|
| 149 |
+
observation['img_masks'] = observation['img_masks'].unsqueeze(0)
|
| 150 |
+
state_indices = list(range(12)) + list(range(73, 75)) + list(range(12, 14)) + list(range(14, 73))
|
| 151 |
+
observation['state'] = observation['state'][state_indices]
|
| 152 |
+
if 'expert_imgs' in observation:
|
| 153 |
+
actions = self.model.sample_actions(
|
| 154 |
+
observation['images'].to(dtype=dtype, device=device),
|
| 155 |
+
observation['img_masks'].to(device=device),
|
| 156 |
+
observation['lang_tokens'].unsqueeze(0).to(device=device),
|
| 157 |
+
observation['lang_masks'].unsqueeze(0).to(device=device),
|
| 158 |
+
observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
|
| 159 |
+
observation['expert_imgs'].to(dtype=dtype, device=device),
|
| 160 |
+
vlm_causal = vlm_causal
|
| 161 |
+
)
|
| 162 |
+
else:
|
| 163 |
+
actions = self.model.sample_actions(
|
| 164 |
+
observation['images'].to(dtype=dtype, device=device),
|
| 165 |
+
observation['img_masks'].to(device=device),
|
| 166 |
+
observation['lang_tokens'].unsqueeze(0).to(device=device),
|
| 167 |
+
observation['lang_masks'].unsqueeze(0).to(device=device),
|
| 168 |
+
observation['state'].unsqueeze(0).to(dtype=dtype, device=device),
|
| 169 |
+
vlm_causal = vlm_causal
|
| 170 |
+
)
|
| 171 |
+
action_indices = list(range(6)) + [14] + list(range(6, 12)) + [15]
|
| 172 |
+
actions = actions[:, :, action_indices]
|
| 173 |
+
delta_time = time.time() - s1
|
| 174 |
+
print(f'sample_actions cost {delta_time} s')
|
| 175 |
+
observation['action'] = actions.squeeze(0)[:, :14].to(dtype=torch.float32, device='cpu')
|
| 176 |
+
if use_bf16:
|
| 177 |
+
observation['state'] = observation['state'].to(dtype=torch.float32)
|
| 178 |
+
data = self.normalizer.unnormalize(observation)
|
| 179 |
+
return data
|
| 180 |
+
|
| 181 |
+
class LingBotVlaInferencePolicy(PolicyPreprocessMixin, LingbotVlaPolicy):
|
| 182 |
+
pass # Only combine necessary functions
|
| 183 |
+
|
| 184 |
+
class PI0InfernecePolicy(PolicyPreprocessMixin, PI0Policy):
|
| 185 |
+
pass # Only combine necessary functions
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def merge_qwen_config(policy_config, qwen_config):
|
| 189 |
+
if hasattr(qwen_config, 'to_dict'):
|
| 190 |
+
config_dict = qwen_config.to_dict()
|
| 191 |
+
else:
|
| 192 |
+
config_dict = qwen_config
|
| 193 |
+
|
| 194 |
+
text_keys = {
|
| 195 |
+
"hidden_size",
|
| 196 |
+
"intermediate_size",
|
| 197 |
+
"num_hidden_layers",
|
| 198 |
+
"num_attention_heads",
|
| 199 |
+
"num_key_value_heads",
|
| 200 |
+
"rms_norm_eps",
|
| 201 |
+
"rope_theta",
|
| 202 |
+
"vocab_size",
|
| 203 |
+
"max_position_embeddings",
|
| 204 |
+
"hidden_act",
|
| 205 |
+
"tie_word_embeddings",
|
| 206 |
+
"tokenizer_path",
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
for key in text_keys:
|
| 210 |
+
if key in config_dict:
|
| 211 |
+
setattr(policy_config, key, config_dict[key])
|
| 212 |
+
print(f"✅ Merged LLM: {key} = {config_dict[key]}")
|
| 213 |
+
|
| 214 |
+
if "vision_config" in config_dict:
|
| 215 |
+
policy_config.vision_config = qwen_config.vision_config
|
| 216 |
+
else:
|
| 217 |
+
print("⚠️ Warning: 'vision_config' not found in qwen_config!")
|
| 218 |
+
|
| 219 |
+
return policy_config
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class QwenPiServer:
|
| 223 |
+
'''
|
| 224 |
+
policy wrapper to support action ensemble or chunk execution
|
| 225 |
+
'''
|
| 226 |
+
def __init__(
|
| 227 |
+
self,
|
| 228 |
+
path_to_pi_model="",
|
| 229 |
+
adaptive_ensemble_alpha=0.1,
|
| 230 |
+
action_ensemble_horizon=8,
|
| 231 |
+
use_length=1, # to control the execution length of the action chunk, -1 denotes using action ensemble
|
| 232 |
+
chunk_ret=False,
|
| 233 |
+
use_bf16=True,
|
| 234 |
+
use_fp32=False,
|
| 235 |
+
) -> None:
|
| 236 |
+
assert not (use_bf16 and use_fp32), 'Bfloat16 or Float32!!!'
|
| 237 |
+
self.adaptive_ensemble_alpha = adaptive_ensemble_alpha
|
| 238 |
+
self.use_length = use_length
|
| 239 |
+
self.chunk_ret = chunk_ret
|
| 240 |
+
|
| 241 |
+
self.task_description = None
|
| 242 |
+
|
| 243 |
+
self.vla = self.load_vla(path_to_pi_model)
|
| 244 |
+
self.vla = self.vla.cuda().eval()
|
| 245 |
+
if use_bf16:
|
| 246 |
+
self.vla = self.vla.to(torch.bfloat16)
|
| 247 |
+
elif use_fp32:
|
| 248 |
+
self.vla.model.float()
|
| 249 |
+
self.global_step = 0
|
| 250 |
+
self.last_action_chunk = None
|
| 251 |
+
self.use_bf16 = use_bf16
|
| 252 |
+
self.use_fp32 = use_fp32
|
| 253 |
+
|
| 254 |
+
def load_vla(self, path_to_pi_model) -> LingbotVlaPolicy:
|
| 255 |
+
# load model
|
| 256 |
+
print(f"loading model from: {path_to_pi_model}")
|
| 257 |
+
config = PreTrainedConfig.from_pretrained(path_to_pi_model)
|
| 258 |
+
|
| 259 |
+
# load training config
|
| 260 |
+
training_config_path = Path(path_to_pi_model)/'lingbotvla_cli.yaml'
|
| 261 |
+
with open(training_config_path, 'r') as f:
|
| 262 |
+
training_config = yaml.safe_load(f)
|
| 263 |
+
f.close()
|
| 264 |
+
|
| 265 |
+
# update model config according to training config
|
| 266 |
+
training_model_config = training_config['model']
|
| 267 |
+
training_model_config.update(training_config['train'])
|
| 268 |
+
for k, v in training_model_config.items():
|
| 269 |
+
v = getattr(config, k, training_model_config[k])
|
| 270 |
+
setattr(config, k, v)
|
| 271 |
+
|
| 272 |
+
# Set attention_implementation to 'eager' to speed up evaluation.
|
| 273 |
+
config.attention_implementation = 'eager'
|
| 274 |
+
|
| 275 |
+
# set base model according to training config
|
| 276 |
+
training_base_model = os.environ.get('QWEN25_PATH', './Qwen2.5-VL-3B-Instruct/')
|
| 277 |
+
if 'paligemma' in training_base_model:
|
| 278 |
+
model_name = 'pi0'
|
| 279 |
+
config.vocab_size = 257152 # set vocab size for paligamma
|
| 280 |
+
elif 'qwen2' in training_base_model.lower():
|
| 281 |
+
model_name = 'lingbotvla'
|
| 282 |
+
else:
|
| 283 |
+
raise ValueError(f"Unsupported base model of {path_to_pi_model}")
|
| 284 |
+
base_model_path = BASE_MODEL_PATH[model_name]
|
| 285 |
+
config.tokenizer_path = base_model_path
|
| 286 |
+
self.model_name = model_name
|
| 287 |
+
|
| 288 |
+
qwen_config = AutoConfig.from_pretrained(base_model_path)
|
| 289 |
+
config = merge_qwen_config(config, qwen_config)
|
| 290 |
+
|
| 291 |
+
if 'vocab_size' in training_config['model'] and training_config['model']['vocab_size'] != 0:
|
| 292 |
+
config.vocab_size = training_config['model']['vocab_size']
|
| 293 |
+
# load processors
|
| 294 |
+
self.processor = build_processor(base_model_path)
|
| 295 |
+
self.language_tokenizer = self.processor.tokenizer
|
| 296 |
+
self.image_processor = self.processor.image_processor
|
| 297 |
+
data_config = SimpleNamespace(**training_config['data'])
|
| 298 |
+
|
| 299 |
+
print('Initializing model ... ')
|
| 300 |
+
|
| 301 |
+
if 'paligemma' in training_base_model:
|
| 302 |
+
policy = PI0InfernecePolicy(config, tokenizer_path=base_model_path)
|
| 303 |
+
else:
|
| 304 |
+
policy = LingBotVlaInferencePolicy(config, tokenizer_path=base_model_path, eval=True)
|
| 305 |
+
|
| 306 |
+
load_model_weights(policy, path_to_pi_model, strict=True)
|
| 307 |
+
|
| 308 |
+
policy.feature_transform = None
|
| 309 |
+
self.data_config = data_config
|
| 310 |
+
self.config = config
|
| 311 |
+
self.joint_max_dim = training_config['train']['max_action_dim']
|
| 312 |
+
self.action_dim = training_config['train']['action_dim']
|
| 313 |
+
self.chunk_size = training_config['train']['chunk_size']
|
| 314 |
+
policy.action_dim = self.action_dim
|
| 315 |
+
policy.chunk_size = self.chunk_size
|
| 316 |
+
self.norm_stats_file = 'assets/norm_stats/robotwin_all_new.json'
|
| 317 |
+
if 'align_params' in training_config['train']:
|
| 318 |
+
self.use_depth_align = True
|
| 319 |
+
else: self.use_depth_align = False
|
| 320 |
+
with open(self.norm_stats_file) as f:
|
| 321 |
+
self.norm_stats = json.load(f)
|
| 322 |
+
policy.normalizer = Normalizer(
|
| 323 |
+
norm_stats=self.norm_stats['norm_stats'],
|
| 324 |
+
from_file=True,
|
| 325 |
+
data_type='robotwin_rep',
|
| 326 |
+
norm_type={
|
| 327 |
+
"observation.images.cam_high": "identity",
|
| 328 |
+
"observation.images.cam_left_wrist": "identity",
|
| 329 |
+
"observation.images.cam_right_wrist": "identity",
|
| 330 |
+
"observation.state": self.data_config.norm_type,
|
| 331 |
+
"action": self.data_config.norm_type,
|
| 332 |
+
},
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
print('Model initialized ... ')
|
| 336 |
+
|
| 337 |
+
return policy
|
| 338 |
+
|
| 339 |
+
def reset(self, robo_name, path_to_pi_model = None) -> None:
|
| 340 |
+
|
| 341 |
+
if path_to_pi_model is not None:
|
| 342 |
+
self.vla = self.load_vla(path_to_pi_model)
|
| 343 |
+
self.vla = self.vla.cuda().eval()
|
| 344 |
+
if self.use_bf16:
|
| 345 |
+
self.vla = self.vla.to(torch.bfloat16)
|
| 346 |
+
elif self.use_fp32:
|
| 347 |
+
self.vla.model.float()
|
| 348 |
+
|
| 349 |
+
self.global_step = 0
|
| 350 |
+
self.last_action_chunk = None
|
| 351 |
+
|
| 352 |
+
if getattr(self.data_config, 'norm_type', None) is None:
|
| 353 |
+
self.data_config.norm_type = 'meanstd'
|
| 354 |
+
if getattr(self.config, 'vlm_causal', None) is None:
|
| 355 |
+
self.config.vlm_causal = False
|
| 356 |
+
if getattr(self.config, 'qwenvl_bos', None) is None:
|
| 357 |
+
self.config.qwenvl_bos = False
|
| 358 |
+
|
| 359 |
+
# if update ckpt path
|
| 360 |
+
if path_to_pi_model is not None:
|
| 361 |
+
all_safetensors = glob(os.path.join(path_to_pi_model, "*.safetensors"))
|
| 362 |
+
merged_weights = {}
|
| 363 |
+
|
| 364 |
+
for file_path in tqdm(all_safetensors):
|
| 365 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
| 366 |
+
for key in f.keys():
|
| 367 |
+
merged_weights[key] = f.get_tensor(key)
|
| 368 |
+
|
| 369 |
+
self.vla.load_state_dict(merged_weights, strict=True)
|
| 370 |
+
|
| 371 |
+
def resize_image(self, observation):
|
| 372 |
+
for image_feature in ['observation.images.cam_high', 'observation.images.cam_left_wrist', 'observation.images.cam_right_wrist']:
|
| 373 |
+
assert image_feature in observation
|
| 374 |
+
assert len(observation[image_feature].shape)==3 and observation[image_feature].shape[-1] == 3
|
| 375 |
+
image = observation[image_feature]
|
| 376 |
+
img_pil = Image.fromarray(image)
|
| 377 |
+
image_size = getattr(self.data_config, 'img_size', 224)
|
| 378 |
+
img_pil = img_pil.resize((image_size, image_size), Image.BILINEAR)
|
| 379 |
+
|
| 380 |
+
# img_resized shape: C*H*W
|
| 381 |
+
img_resized = np.transpose(np.array(img_pil), (2,0,1)) # (3,224,224)
|
| 382 |
+
observation[image_feature] = img_resized / 255.
|
| 383 |
+
|
| 384 |
+
def infer(self, observation, center_crop=True):
|
| 385 |
+
"""Generates an action with the VLA policy."""
|
| 386 |
+
|
| 387 |
+
# (If trained with image augmentations) Center crop image and then resize back up to original size.
|
| 388 |
+
# IMPORTANT: Let's say crop scale == 0.9. To get the new height and width (post-crop), multiply
|
| 389 |
+
# the original height and width by sqrt(0.9) -- not 0.9!
|
| 390 |
+
if 'reset' in observation and observation['reset']:
|
| 391 |
+
self.reset(robo_name=observation['robo_name'], path_to_pi_model=observation['path_to_pi_model'] if 'path_to_pi_model' in observation else None)
|
| 392 |
+
return dict(action = None)
|
| 393 |
+
|
| 394 |
+
self.resize_image(observation)
|
| 395 |
+
for k, v in observation.items():
|
| 396 |
+
if isinstance(v, np.ndarray):
|
| 397 |
+
observation[k] = torch.from_numpy(v)
|
| 398 |
+
|
| 399 |
+
if self.use_length == -1 or self.global_step % self.use_length == 0:
|
| 400 |
+
joint_max_dim = getattr(self, 'joint_max_dim')
|
| 401 |
+
action_dim = getattr(self, 'action_dim')
|
| 402 |
+
chunk_size = getattr(self, 'chunk_size')
|
| 403 |
+
indices = list(range(6)) + list(range(7, 13)) + [6] + [13]
|
| 404 |
+
observation["observation.state"] = observation["observation.state"][indices]
|
| 405 |
+
normalized_observation = self.vla.normalizer.normalize(observation)
|
| 406 |
+
base_image = (normalized_observation["observation.images.cam_high"] * 255).to(torch.uint8)
|
| 407 |
+
left_wrist_image = (normalized_observation["observation.images.cam_left_wrist"] * 255).to(
|
| 408 |
+
torch.uint8
|
| 409 |
+
)
|
| 410 |
+
right_wrist_image = (normalized_observation["observation.images.cam_right_wrist"] * 255).to(
|
| 411 |
+
torch.uint8
|
| 412 |
+
)
|
| 413 |
+
obs_dict = {
|
| 414 |
+
"image": {"base_0_rgb": base_image, "left_wrist_0_rgb": left_wrist_image, "right_wrist_0_rgb": right_wrist_image},
|
| 415 |
+
"state": normalized_observation["observation.state"].to(torch.float32),
|
| 416 |
+
"prompt": [observation["task"]],
|
| 417 |
+
}
|
| 418 |
+
state = prepare_state(self.config, obs_dict)
|
| 419 |
+
lang_tokens, lang_masks = prepare_language(self.config, self.language_tokenizer, obs_dict)
|
| 420 |
+
images, img_masks, _ = prepare_images(self.config, self.image_processor, obs_dict)
|
| 421 |
+
observation = {
|
| 422 |
+
'images': images,
|
| 423 |
+
'img_masks': img_masks,
|
| 424 |
+
'state': state,
|
| 425 |
+
'lang_tokens': lang_tokens,
|
| 426 |
+
'lang_masks': lang_masks,
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
if self.use_bf16:
|
| 430 |
+
observation['state'] = observation['state'].to(torch.bfloat16)
|
| 431 |
+
|
| 432 |
+
org_actions = ['action']
|
| 433 |
+
assert len(org_actions)==1, "Only support single action feature"
|
| 434 |
+
if self.chunk_ret:
|
| 435 |
+
action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]].float().cpu().numpy()
|
| 436 |
+
action = action[:self.use_length, :self.action_dim]
|
| 437 |
+
else:
|
| 438 |
+
if self.use_length == -1 or self.global_step % self.use_length == 0:
|
| 439 |
+
action = self.vla.select_action(observation, self.use_bf16, self.config.vlm_causal)[org_actions[0]]
|
| 440 |
+
self.last_action_chunk = action.float().cpu().numpy()
|
| 441 |
+
|
| 442 |
+
if self.use_length > 0:
|
| 443 |
+
action = self.last_action_chunk[self.global_step % self.use_length]
|
| 444 |
+
action = action[:, :self.action_dim]
|
| 445 |
+
print(f"on server step: {self.global_step}")
|
| 446 |
+
self.global_step+=1
|
| 447 |
+
|
| 448 |
+
return dict(action = action)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
import argparse
|
| 452 |
+
from .websocket_policy_server import WebsocketPolicyServer
|
| 453 |
+
|
| 454 |
+
def main():
|
| 455 |
+
parser = argparse.ArgumentParser(description="启动 QwenPi WebSocket 策略服务器")
|
| 456 |
+
|
| 457 |
+
parser.add_argument(
|
| 458 |
+
"--model_path",
|
| 459 |
+
type=str,
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
parser.add_argument(
|
| 463 |
+
"--use_length",
|
| 464 |
+
type=int,
|
| 465 |
+
default=50,
|
| 466 |
+
help="used length of action chunk"
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
parser.add_argument(
|
| 470 |
+
"--chunk_ret",
|
| 471 |
+
type=bool,
|
| 472 |
+
default=True,
|
| 473 |
+
help=" True: The returned action tensor includes the horizon dimension. This allows the model to output a sequence of actions for each horizon step. False: The horizon dimension is omitted. The model selects and returns the next step autonomously based on its policy."
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
parser.add_argument(
|
| 477 |
+
"--port",
|
| 478 |
+
type=int,
|
| 479 |
+
default=8006,
|
| 480 |
+
help="port of WebSocket"
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
args = parser.parse_args()
|
| 484 |
+
|
| 485 |
+
model = QwenPiServer(args.model_path, use_length=args.use_length, chunk_ret = args.chunk_ret)
|
| 486 |
+
model_server = WebsocketPolicyServer(model, port=args.port)
|
| 487 |
+
model_server.serve_forever()
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
if __name__ == "__main__":
|
| 491 |
+
main()
|
deploy/msgpack_numpy.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adds NumPy array support to msgpack.
|
| 2 |
+
|
| 3 |
+
msgpack is good for (de)serializing data over a network for multiple reasons:
|
| 4 |
+
- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
|
| 5 |
+
- msgpack is widely used and has good cross-language support
|
| 6 |
+
- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
|
| 7 |
+
languages like Python and JavaScript
|
| 8 |
+
- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
|
| 9 |
+
than pickle for serializing large arrays using the below strategy
|
| 10 |
+
|
| 11 |
+
The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
|
| 12 |
+
that it falls back to pickle for object arrays.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import functools
|
| 16 |
+
|
| 17 |
+
import msgpack
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def pack_array(obj):
|
| 22 |
+
if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
|
| 23 |
+
raise ValueError(f"Unsupported dtype: {obj.dtype}")
|
| 24 |
+
|
| 25 |
+
if isinstance(obj, np.ndarray):
|
| 26 |
+
return {
|
| 27 |
+
b"__ndarray__": True,
|
| 28 |
+
b"data": obj.tobytes(),
|
| 29 |
+
b"dtype": obj.dtype.str,
|
| 30 |
+
b"shape": obj.shape,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
if isinstance(obj, np.generic):
|
| 34 |
+
return {
|
| 35 |
+
b"__npgeneric__": True,
|
| 36 |
+
b"data": obj.item(),
|
| 37 |
+
b"dtype": obj.dtype.str,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
return obj
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def unpack_array(obj):
|
| 44 |
+
if b"__ndarray__" in obj:
|
| 45 |
+
return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
|
| 46 |
+
|
| 47 |
+
if b"__npgeneric__" in obj:
|
| 48 |
+
return np.dtype(obj[b"dtype"]).type(obj[b"data"])
|
| 49 |
+
|
| 50 |
+
return obj
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
Packer = functools.partial(msgpack.Packer, default=pack_array)
|
| 54 |
+
packb = functools.partial(msgpack.packb, default=pack_array)
|
| 55 |
+
|
| 56 |
+
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
|
| 57 |
+
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
|
deploy/websocket_client_policy.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import time
|
| 3 |
+
from typing import Dict, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from typing_extensions import override
|
| 6 |
+
import websockets.sync.client
|
| 7 |
+
from .msgpack_numpy import Packer, unpackb
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class WebsocketClientPolicy:
|
| 11 |
+
"""Implements the Policy interface by communicating with a server over websocket.
|
| 12 |
+
|
| 13 |
+
See WebsocketPolicyServer for a corresponding server implementation.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
|
| 17 |
+
self._uri = f"ws://{host}"
|
| 18 |
+
if port is not None:
|
| 19 |
+
self._uri += f":{port}"
|
| 20 |
+
self._packer = Packer()
|
| 21 |
+
self._api_key = api_key
|
| 22 |
+
self._ws, self._server_metadata = self._wait_for_server()
|
| 23 |
+
|
| 24 |
+
def get_server_metadata(self) -> Dict:
|
| 25 |
+
return self._server_metadata
|
| 26 |
+
|
| 27 |
+
def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
|
| 28 |
+
logging.info(f"Waiting for server at {self._uri}...")
|
| 29 |
+
while True:
|
| 30 |
+
try:
|
| 31 |
+
headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None
|
| 32 |
+
conn = websockets.sync.client.connect(
|
| 33 |
+
self._uri, compression=None, max_size=None, additional_headers=headers
|
| 34 |
+
)
|
| 35 |
+
metadata = unpackb(conn.recv())
|
| 36 |
+
return conn, metadata
|
| 37 |
+
except ConnectionRefusedError:
|
| 38 |
+
logging.info("Still waiting for server...")
|
| 39 |
+
time.sleep(5)
|
| 40 |
+
|
| 41 |
+
@override
|
| 42 |
+
def infer(self, obs: Dict) -> Dict: # noqa: UP006
|
| 43 |
+
data = self._packer.pack(obs)
|
| 44 |
+
self._ws.send(data)
|
| 45 |
+
response = self._ws.recv()
|
| 46 |
+
if isinstance(response, str):
|
| 47 |
+
# we're expecting bytes; if the server sends a string, it's an error.
|
| 48 |
+
raise RuntimeError(f"Error in inference server:\n{response}")
|
| 49 |
+
return unpackb(response)
|
| 50 |
+
|
| 51 |
+
@override
|
| 52 |
+
def reset(self, robo_name: str) -> None:
|
| 53 |
+
self.infer(dict(reset=True, robo_name=robo_name))
|
| 54 |
+
|
| 55 |
+
if __name__ == "__main__":
|
| 56 |
+
policy_on_device = WebsocketClientPolicy(port=8000)
|
| 57 |
+
import torch
|
| 58 |
+
import numpy as np
|
| 59 |
+
from PIL import Image
|
| 60 |
+
from .image_tools import convert_to_uint8
|
| 61 |
+
device = torch.device("cuda")
|
| 62 |
+
|
| 63 |
+
base_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
|
| 64 |
+
left_wrist_0_rgb = np.random.randint(0, 256, size=(1, 3, 224, 224), dtype=np.uint8)
|
| 65 |
+
state = np.random.rand(1,8).astype(np.float32)
|
| 66 |
+
prompt = ["do something"]
|
| 67 |
+
|
| 68 |
+
# observation = {
|
| 69 |
+
# "image": {
|
| 70 |
+
# "base_0_rgb": torch.from_numpy(base_0_rgb).to(device)[None],
|
| 71 |
+
# "left_wrist_0_rgb": torch.from_numpy(left_wrist_0_rgb).to(device)[None],
|
| 72 |
+
# },
|
| 73 |
+
# "state": torch.from_numpy(state).to(device)[None],
|
| 74 |
+
# "prompt": prompt,
|
| 75 |
+
# }
|
| 76 |
+
|
| 77 |
+
observation = {
|
| 78 |
+
"image": {
|
| 79 |
+
"base_0_rgb": convert_to_uint8(base_0_rgb),
|
| 80 |
+
"left_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
|
| 81 |
+
"right_wrist_0_rgb": convert_to_uint8(left_wrist_0_rgb),
|
| 82 |
+
},
|
| 83 |
+
"state": state,
|
| 84 |
+
"prompt": prompt,
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
policy_on_device.infer(observation)
|
| 88 |
+
from IPython import embed;embed()
|
deploy/websocket_policy_server.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import http
|
| 3 |
+
import logging
|
| 4 |
+
import time
|
| 5 |
+
import traceback
|
| 6 |
+
|
| 7 |
+
from .msgpack_numpy import Packer, unpackb
|
| 8 |
+
import websockets.asyncio.server as _server
|
| 9 |
+
import websockets.frames
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class WebsocketPolicyServer:
|
| 15 |
+
"""Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.
|
| 16 |
+
|
| 17 |
+
Currently only implements the `load` and `infer` methods.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
policy,
|
| 23 |
+
host: str = "0.0.0.0",
|
| 24 |
+
port: int | None = None,
|
| 25 |
+
metadata: dict | None = None,
|
| 26 |
+
) -> None:
|
| 27 |
+
self._policy = policy
|
| 28 |
+
self._host = host
|
| 29 |
+
self._port = port
|
| 30 |
+
self._metadata = metadata or {}
|
| 31 |
+
logging.getLogger("websockets.server").setLevel(logging.INFO)
|
| 32 |
+
|
| 33 |
+
def serve_forever(self) -> None:
|
| 34 |
+
asyncio.run(self.run())
|
| 35 |
+
|
| 36 |
+
async def run(self):
|
| 37 |
+
async with _server.serve(
|
| 38 |
+
self._handler,
|
| 39 |
+
self._host,
|
| 40 |
+
self._port,
|
| 41 |
+
compression=None,
|
| 42 |
+
max_size=None,
|
| 43 |
+
process_request=_health_check,
|
| 44 |
+
) as server:
|
| 45 |
+
await server.serve_forever()
|
| 46 |
+
|
| 47 |
+
async def _handler(self, websocket: _server.ServerConnection):
|
| 48 |
+
logger.info(f"Connection from {websocket.remote_address} opened")
|
| 49 |
+
packer = Packer()
|
| 50 |
+
|
| 51 |
+
await websocket.send(packer.pack(self._metadata))
|
| 52 |
+
|
| 53 |
+
prev_total_time = None
|
| 54 |
+
while True:
|
| 55 |
+
try:
|
| 56 |
+
start_time = time.monotonic()
|
| 57 |
+
obs = unpackb(await websocket.recv())
|
| 58 |
+
|
| 59 |
+
infer_time = time.monotonic()
|
| 60 |
+
action = self._policy.infer(obs)
|
| 61 |
+
infer_time = time.monotonic() - infer_time
|
| 62 |
+
|
| 63 |
+
action["server_timing"] = {
|
| 64 |
+
"infer_ms": infer_time * 1000,
|
| 65 |
+
}
|
| 66 |
+
if prev_total_time is not None:
|
| 67 |
+
# We can only record the last total time since we also want to include the send time.
|
| 68 |
+
action["server_timing"]["prev_total_ms"] = prev_total_time * 1000
|
| 69 |
+
|
| 70 |
+
await websocket.send(packer.pack(action))
|
| 71 |
+
prev_total_time = time.monotonic() - start_time
|
| 72 |
+
|
| 73 |
+
except websockets.ConnectionClosed:
|
| 74 |
+
logger.info(f"Connection from {websocket.remote_address} closed")
|
| 75 |
+
break
|
| 76 |
+
except Exception:
|
| 77 |
+
await websocket.send(traceback.format_exc())
|
| 78 |
+
await websocket.close(
|
| 79 |
+
code=websockets.frames.CloseCode.INTERNAL_ERROR,
|
| 80 |
+
reason="Internal server error. Traceback included in previous frame.",
|
| 81 |
+
)
|
| 82 |
+
raise
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
|
| 86 |
+
if request.path == "/healthz":
|
| 87 |
+
return connection.respond(http.HTTPStatus.OK, "OK\n")
|
| 88 |
+
# Continue with the normal request handling.
|
| 89 |
+
return None
|
docker/Dockerfile
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Start from the NVIDIA official image (ubuntu-22.04 + python-3.10)
|
| 2 |
+
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
|
| 3 |
+
FROM nvcr.io/nvidia/pytorch:24.08-py3
|
| 4 |
+
|
| 5 |
+
# Define environments
|
| 6 |
+
ENV MAX_JOBS=32
|
| 7 |
+
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
|
| 8 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 9 |
+
ENV NODE_OPTIONS=""
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Install systemctl and tini
|
| 13 |
+
RUN apt-get update && \
|
| 14 |
+
apt-get install -y -o Dpkg::Options::="--force-confdef" systemd tini && \
|
| 15 |
+
apt-get clean || { echo "Installation failed"; exit 1; }
|
| 16 |
+
|
| 17 |
+
RUN apt-get install -y tzdata \
|
| 18 |
+
&& ln -fs /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
|
| 19 |
+
&& dpkg-reconfigure -f noninteractive tzdata
|
| 20 |
+
|
| 21 |
+
# Change pip source
|
| 22 |
+
RUN python -m pip install --upgrade pip
|
| 23 |
+
|
| 24 |
+
# Install torch-2.5.1 + vllm-0.7.3
|
| 25 |
+
RUN pip install --no-cache-dir vllm==0.7.3 torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 tensordict torchdata \
|
| 26 |
+
transformers>=4.49.0 accelerate datasets peft hf-transfer diffusers \
|
| 27 |
+
codetiming hydra-core pandas pyarrow>=15.0.0 pylatexenc qwen-vl-utils wandb ninja liger-kernel \
|
| 28 |
+
pytest yapf py-spy pyext pre-commit ruff packaging
|
| 29 |
+
|
| 30 |
+
# Install flux
|
| 31 |
+
RUN pip install --no-cache-dir byte-flux
|
| 32 |
+
|
| 33 |
+
# Install flash-attn and triton
|
| 34 |
+
RUN pip install --no-cache-dir flash-attn triton>=3.1.0
|
docs/Makefile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal makefile for Sphinx documentation
|
| 2 |
+
#
|
| 3 |
+
|
| 4 |
+
# You can set these variables from the command line.
|
| 5 |
+
SPHINXOPTS =
|
| 6 |
+
SPHINXBUILD = sphinx-build
|
| 7 |
+
SPHINXPROJ = LingBotVLA
|
| 8 |
+
SOURCEDIR = .
|
| 9 |
+
BUILDDIR = _build
|
| 10 |
+
|
| 11 |
+
# Put it first so that "make" without argument is like "make help".
|
| 12 |
+
help:
|
| 13 |
+
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
| 14 |
+
|
| 15 |
+
.PHONY: help Makefile
|
| 16 |
+
|
| 17 |
+
# Catch-all target: route all unknown targets to Sphinx using the new
|
| 18 |
+
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
| 19 |
+
%: Makefile
|
| 20 |
+
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
docs/README.md
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LingBotVLA documents
|
| 2 |
+
|
| 3 |
+
## Build the docs
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
# Install dependencies.
|
| 7 |
+
pip install -r requirements-docs.txt
|
| 8 |
+
|
| 9 |
+
# Build the docs.
|
| 10 |
+
make clean
|
| 11 |
+
make html
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
## Open the docs with your browser
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
python -m http.server -d _build/html/
|
| 18 |
+
```
|
| 19 |
+
Launch your browser and open localhost:8000.
|
docs/conf.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration file for the Sphinx documentation builder.
|
| 2 |
+
#
|
| 3 |
+
# This file only contains a selection of the most common options. For a full
|
| 4 |
+
# list see the documentation:
|
| 5 |
+
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
| 6 |
+
|
| 7 |
+
# -- Path setup --------------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
# If extensions (or modules to document with autodoc) are in another directory,
|
| 10 |
+
# add these directories to sys.path here. If the directory is relative to the
|
| 11 |
+
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
| 12 |
+
#
|
| 13 |
+
# import os
|
| 14 |
+
# import sys
|
| 15 |
+
# sys.path.insert(0, os.path.abspath('.'))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# -- Project information -----------------------------------------------------
|
| 19 |
+
|
| 20 |
+
project = "LingBotVLA"
|
| 21 |
+
# pylint: disable=W0622
|
| 22 |
+
copyright = "2026 Robbyant Team, based on VeOmni by ByteDance Seed Foundation MLSys Team"
|
| 23 |
+
|
| 24 |
+
# -- General configuration ---------------------------------------------------
|
| 25 |
+
# The master toctree document.
|
| 26 |
+
master_doc = "index"
|
| 27 |
+
|
| 28 |
+
# Add any Sphinx extension module names here, as strings. They can be
|
| 29 |
+
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
| 30 |
+
# ones.
|
| 31 |
+
extensions = [
|
| 32 |
+
"recommonmark",
|
| 33 |
+
"sphinx.ext.autosectionlabel",
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
# The suffix(es) of source filenames.
|
| 37 |
+
# You can specify multiple suffix as a list of string:
|
| 38 |
+
source_suffix = [".rst", "rest", ".md"]
|
| 39 |
+
|
| 40 |
+
# Add any paths that contain templates here, relative to this directory.
|
| 41 |
+
templates_path = ["_templates"]
|
| 42 |
+
|
| 43 |
+
# The language for content autogenerated by Sphinx. Refer to documentation
|
| 44 |
+
# for a list of supported languages.
|
| 45 |
+
#
|
| 46 |
+
# This is also used if you do content translation via gettext catalogs.
|
| 47 |
+
# Usually you set "language" from the command line for these cases.
|
| 48 |
+
language = "en"
|
| 49 |
+
|
| 50 |
+
# List of patterns, relative to source directory, that match files and
|
| 51 |
+
# directories to ignore when looking for source files.
|
| 52 |
+
# This pattern also affects html_static_path and html_extra_path.
|
| 53 |
+
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# -- Options for HTML output -------------------------------------------------
|
| 57 |
+
|
| 58 |
+
# The theme to use for HTML and HTML Help pages. See the documentation for
|
| 59 |
+
# a list of builtin themes.
|
| 60 |
+
#
|
| 61 |
+
html_theme = "sphinx_rtd_theme"
|
| 62 |
+
|
| 63 |
+
# Add any paths that contain custom static files (such as style sheets) here,
|
| 64 |
+
# relative to this directory. They are copied after the builtin static files,
|
| 65 |
+
# so a file named "default.css" will overwrite the builtin "default.css".
|
| 66 |
+
html_static_path = ["_static"]
|
docs/config/config.md
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Config arguments Explanation
|
| 2 |
+
### Model configuration arguments
|
| 3 |
+
| Name | Type | Description | Default Value |
|
| 4 |
+
| --- | --- | --- | --- |
|
| 5 |
+
| model.config_path | str | Path to the model huggingface configuration, like `config.json` | model.model_path |
|
| 6 |
+
| model.model_path | str | Path to the model parameter file. If empty, random initialization will be performed | None |
|
| 7 |
+
| model.tokenizer_path | str | Path to the tokenizer | model.model_path |
|
| 8 |
+
| model.encoders | dict | Configuration file for multi-modal encoders | {} |
|
| 9 |
+
| model.decoders | dict | Configuration file for multi-modal decoders | {} |
|
| 10 |
+
| model.input_encoder | str: {"encoder", "decoder"} | Use the encoder of the encoder or decoder to encode the input image | encoder |
|
| 11 |
+
| model.output_encoder | str: {"encoder", "decoder"} | Use the encoder of the encoder or decoder to encode the output image | decoder |
|
| 12 |
+
| model.encode_target | bool | Used to encode the training data for the diffusion model | False |
|
| 13 |
+
|
| 14 |
+
### Data configuration arguments
|
| 15 |
+
|
| 16 |
+
| Name | Type | Description | Default Value |
|
| 17 |
+
| --- | --- | --- | --- |
|
| 18 |
+
| data.train_path | str | Path of training dataset | Required |
|
| 19 |
+
| data.train_size | int | Total number of tokens in the training set | 10,000,000 |
|
| 20 |
+
| data.data_type | str: {"plaintext", "conversation"} | Dataset type. | conversation |
|
| 21 |
+
| data.dataloader_type | str: {"native"} | Use the pytorch dataloader or | native |
|
| 22 |
+
| data.datasets_type | str: {"mapping", "iterable"} | Dataset type. `IterativeDataset` or `MappingDataset`, or your custom datsets | mapping |
|
| 23 |
+
| data.text_keys | str: {"content_split", "messages"} | The key corresponding to the text samples in the data dictionary. Generally, it is "content_split" for pretraining and "messages" for SFT. | content_split |
|
| 24 |
+
| data.image_keys | str | The key corresponding to the image samples in the data dictionary. Generally, it is "images". | images |
|
| 25 |
+
| data.chat_template | str | Name of the chat template. | default |
|
| 26 |
+
| data.max_seq_len | int | Maximum training length. | 2048 |
|
| 27 |
+
| data.num_workers | int | Number of multi-process loaders for the dataloader. | 4 |
|
| 28 |
+
| data.drop_last | bool | Whether to discard the remaining data at the end. | True |
|
| 29 |
+
| data.pin_memory | bool | Whether to pin the data in the CPU memory. | True |
|
| 30 |
+
| data.prefetch_factor | int | Number of samples preprocessed by the dataloader. | 2 |
|
| 31 |
+
|
| 32 |
+
#### Training configuration arguments
|
| 33 |
+
| Name | Type | Description | Default Value |
|
| 34 |
+
| --- | --- | --- | --- |
|
| 35 |
+
| train.output_dir | str | Path to save the model. | Required |
|
| 36 |
+
| train.lr | float | Maximum learning rate. | 5e - 5 |
|
| 37 |
+
| train.lr_min | float | Minimum learning rate. | 1e - 7 |
|
| 38 |
+
| train.weight_decay | float | Weight decay coefficient. | 0 |
|
| 39 |
+
| train.optimizer | str: {"adamw", "anyprecision_adamw"} | Name of the optimizer. | adamw |
|
| 40 |
+
| train.max_grad_norm | float | Gradient clipping norm. | 1.0 |
|
| 41 |
+
| train.micro_batch_size | int | Number of samples processed simultaneously on each GPU. | 1 |
|
| 42 |
+
| train.global_batch_size | int | Global batch size, which must be a multiple of the number of GPUs. | train.micro_batch_size * n_gpus |
|
| 43 |
+
| train.num_train_epochs | int | Number of training epochs. | 1 |
|
| 44 |
+
| train.rmpad | bool | Whether to use rmpad training based on cu_seqlens. | False |
|
| 45 |
+
| train.rmpad_with_pos_ids | bool | Whether to use rmpad training based on position_ids. | False |
|
| 46 |
+
| train.dyn_bsz_margin | int | Number of pad tokens in the dynamic batch. | 0 |
|
| 47 |
+
| train.dyn_bsz_runtime | str: {"main", "worker"} | Running process of the dynamic batch. | worker |
|
| 48 |
+
| train.bsz_warmup_ratio | float | Proportion of batch size warmup in the total number of steps. | 0 |
|
| 49 |
+
| train.lr_warmup_ratio | float | Proportion of learning rate warmup in the total number of steps. | 0 |
|
| 50 |
+
| train.lr_decay_style | str: {"constant", "linear", "cosine"} | Name of the learning rate scheduler. | cosine |
|
| 51 |
+
| train.lr_decay_ratio | float | Proportion of learning rate decay in the total number of steps | 1.0 |
|
| 52 |
+
| train.use_doptim | bool | Whether to use the distributed optimizer during Vescale training(no use for torch fsdp) | False |
|
| 53 |
+
| train.enable_mixed_precision | bool | Whether to enable mixed precision training (higher memory usage but more stable) | True |
|
| 54 |
+
| train.enable_gradient_checkpointing | bool | Whether to enable gradient checkpointing to reduce memory usage. | True |
|
| 55 |
+
| train.enable_reentrant | bool | Whether to enable reentrant in gradient checkpointing. | True |
|
| 56 |
+
| train.enable_full_shard | bool | Whether to use full sharding FSDP (equivalent to ZeRO3). | True |
|
| 57 |
+
| train.enable_fsdp_offload | bool | Whether to enable FSDP CPU offloading (only supported for FSDP1). | False |
|
| 58 |
+
| train.enable_activation_offload | bool | Whether to enable activation value CPU offloading. | False |
|
| 59 |
+
| train.activation_gpu_limit | float | Size of the activation values retained on the GPU (in GB). | 0.0 |
|
| 60 |
+
| train.enable_manual_eager | bool | Whether to use manual eager during Vescale training. | False |
|
| 61 |
+
| train.init_device: meta | str | "cpu", "cuda", "meta", init device for model initialization. use "meta" or cpu for large model(>30B) | cuda |
|
| 62 |
+
| train.enable_full_determinism | bool | Whether to enable deterministic mode (for bitwise alignment). | False |
|
| 63 |
+
| train.empty_cache_steps | int | Number of steps between two cache clearings. -1 means not enabled. | 500 |
|
| 64 |
+
| train.data_parallel_mode | str: {"ddp", "fsdp1", "fsdp2"} | Data parallel algorithm. | ddp |
|
| 65 |
+
| train.tensor_parallel_size | int | Tensor parallel size (currently only supported for vescale training). | 1 |
|
| 66 |
+
| train.pipeline_parallel_size | int | Pipeline parallel size (currently not supported). | 1 |
|
| 67 |
+
| train.ulysses_parallel_size | int | Ulysses sequence parallel size (currently only supported for P6dense and Qwen2VL). | 1 |
|
| 68 |
+
| train.context_parallel_size | int | Ring sequence parallel size (currently not supported) | 1 |
|
| 69 |
+
| train.expert_parallel_size | int | Expert parallel size (currently only supported DeepseekMOE) | 1 |
|
| 70 |
+
| train.load_checkpoint_path | str | Path to the omnistore checkpoint for resuming training. | None |
|
| 71 |
+
| train.save_steps | int | Number of steps between two checkpoint saves. 0 means invalid. | 0 |
|
| 72 |
+
| train.save_epochs | int | Number of epochs between two checkpoint saves. 0 means invalid. | 1 |
|
| 73 |
+
| train.save_hf_weights | bool | Whether to save the model weights in the huggingface format. It is recommended to set it to False for models > 30B to prevent NCCL timeout. You can convert it after training. | True |
|
| 74 |
+
| train.seed | int | Random seed. | 42 |
|
| 75 |
+
| train.use_wandb | bool | Whether to enable byted wandb experiment logging. | True |
|
| 76 |
+
| train.wandb_project | str | Name of the wandb experiment project. | LingBotVLA |
|
| 77 |
+
| train.wandb_name | str | Name of the wandb experiment. | None |
|
| 78 |
+
| train.enable_profiling | bool | Whether to use torch profiling. | False |
|
| 79 |
+
| train.profile_start_step | int | Starting step of profiling. | 1 |
|
| 80 |
+
| train.profile_end_step | int | Ending step of profiling. | 2 |
|
| 81 |
+
| train.profile_trace_dir | str | Path to save the profiling results. | ./trace |
|
| 82 |
+
| train.profile_record_shapes | bool | Whether to record the shapes of the input tensors. | True |
|
| 83 |
+
| train.profile_profile_memory | bool | Whether to record the memory usage. | True |
|
| 84 |
+
| train.profile_with_stack | bool | Whether to record the stack information. | True |
|
| 85 |
+
| train.max_steps | int | Number of steps per training epoch (only used for debugging). | None |
|
| 86 |
+
|
| 87 |
+
### Inference configuration arguments
|
| 88 |
+
| Name | Type | Description | Default Value |
|
| 89 |
+
| --- | --- | --- | --- |
|
| 90 |
+
| infer.model_path | str | Path to the model parameter file. | Required |
|
| 91 |
+
| infer.tokenizer_path | str | Path to the tokenizer. | model.model_path |
|
| 92 |
+
| infer.seed | int | Random seed. | 42 |
|
| 93 |
+
| infer.do_sample | bool | Whether to enable sampling. | True |
|
| 94 |
+
| infer.temperature | float | Sampling temperature. | 1.0 |
|
| 95 |
+
| infer.top_p | float | Sampling Top P value. | 1.0 |
|
| 96 |
+
| infer.max_tokens | int | Maximum number of tokens generated each time. | 1024 |
|
docs/examples/qwen2vl.rst
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Qwen2VL example
|
| 2 |
+
=========================
|
docs/examples/qwen3_moe.md
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Qwen3 MoE training guide
|
| 2 |
+
|
| 3 |
+
1. Download qwen3 moe model
|
| 4 |
+
|
| 5 |
+
```shell
|
| 6 |
+
python3 scripts/download_hf_model.py \
|
| 7 |
+
--repo_id Qwen/Qwen3-30B-A3B \
|
| 8 |
+
--local_dir .
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
2. Merge qwen3 moe model experts to support GroupGemm optimize
|
| 12 |
+
``` shell
|
| 13 |
+
python3 scripts/moe_ckpt_merge/moe_merge.py --raw_hf_path Qwen3-30B-A3B --merge_hf_path Qwen3-30B-A3B-merge
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
Most of the MoE models in Transformers referenced the open-source implementation of Mixtral MoE. In this implementation, MoE experts are divided into multiple blocks instead of being combined into a single `nn.Parameters`. Additionally, there are cpu-block operators like `torch.where()` and for loop, which are not very friendly for integrating MoE fusion operators.
|
| 17 |
+
|
| 18 |
+
Origin [Qwen3MoeMLP](https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L200C1-L213C25) code
|
| 19 |
+
```python
|
| 20 |
+
class Qwen3MoeMLP(nn.Module):
|
| 21 |
+
def __init__(self, config, intermediate_size=None):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.config = config
|
| 24 |
+
self.hidden_size = config.hidden_size
|
| 25 |
+
self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
|
| 26 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 27 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 28 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 29 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 30 |
+
|
| 31 |
+
def forward(self, x):
|
| 32 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 33 |
+
return down_proj
|
| 34 |
+
|
| 35 |
+
class Qwen3MoeSparseMoeBlock(nn.Module):
|
| 36 |
+
def __init__(self, config):
|
| 37 |
+
|
| 38 |
+
...
|
| 39 |
+
|
| 40 |
+
self.experts = nn.ModuleList(
|
| 41 |
+
[Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 45 |
+
|
| 46 |
+
...
|
| 47 |
+
|
| 48 |
+
final_hidden_states = torch.zeros(
|
| 49 |
+
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
for expert_idx in expert_hitted:
|
| 53 |
+
expert_layer = self.experts[expert_idx]
|
| 54 |
+
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
|
| 55 |
+
|
| 56 |
+
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
|
| 57 |
+
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
| 58 |
+
|
| 59 |
+
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
| 60 |
+
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
| 61 |
+
return final_hidden_states, router_logits
|
| 62 |
+
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
- Combine Qwen3MoeMLP to Qwen3MoeExperts, then use fused moe operator
|
| 66 |
+
|
| 67 |
+
```python
|
| 68 |
+
class Qwen3MoeExperts(nn.Module):
|
| 69 |
+
def __init__(self, config):
|
| 70 |
+
super().__init__()
|
| 71 |
+
self.num_experts = config.num_experts
|
| 72 |
+
self.hidden_dim = config.hidden_size
|
| 73 |
+
self.intermediate_size = config.moe_intermediate_size
|
| 74 |
+
self.gate_proj = torch.nn.Parameter(
|
| 75 |
+
torch.empty(self.num_experts, self.intermediate_size, self.hidden_dim),
|
| 76 |
+
requires_grad=True,
|
| 77 |
+
)
|
| 78 |
+
self.up_proj = torch.nn.Parameter(
|
| 79 |
+
torch.empty(self.num_experts, self.intermediate_size, self.hidden_dim),
|
| 80 |
+
requires_grad=True,
|
| 81 |
+
)
|
| 82 |
+
self.down_proj = torch.nn.Parameter(
|
| 83 |
+
torch.empty(self.num_experts, self.hidden_dim, self.intermediate_size),
|
| 84 |
+
requires_grad=True,
|
| 85 |
+
)
|
| 86 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 87 |
+
|
| 88 |
+
def forward(self, hidden_states, expert_idx=None, cumsum=None):
|
| 89 |
+
gate_proj_out = torch.matmul(hidden_states, self.gate_proj[expert_idx].transpose(0, 1))
|
| 90 |
+
up_proj_out = torch.matmul(hidden_states, self.up_proj[expert_idx].transpose(0, 1))
|
| 91 |
+
|
| 92 |
+
out = self.act_fn(gate_proj_out) * up_proj_out
|
| 93 |
+
out = torch.matmul(out, self.down_proj[expert_idx].transpose(0, 1))
|
| 94 |
+
return out
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class Qwen3MoeSparseFusedMoeBlock(nn.Module):
|
| 98 |
+
def __init__(self, config):
|
| 99 |
+
|
| 100 |
+
...
|
| 101 |
+
|
| 102 |
+
self.experts = Qwen3MoeExperts(config)
|
| 103 |
+
|
| 104 |
+
def forward(self, hidden_states, expert_idx=None, routing_weights=None, selected_experts=None) -> torch.Tensor:
|
| 105 |
+
|
| 106 |
+
...
|
| 107 |
+
|
| 108 |
+
out = fused_moe_forward(
|
| 109 |
+
module=self,
|
| 110 |
+
num_experts=self.num_experts,
|
| 111 |
+
routing_weights=routing_weights,
|
| 112 |
+
selected_experts=selected_experts,
|
| 113 |
+
hidden_states=hidden_states,
|
| 114 |
+
fc1_1_weight=self.gate_proj,
|
| 115 |
+
fc1_2_weight=self.up_proj,
|
| 116 |
+
fc2_weight=self.down_proj,
|
| 117 |
+
)
|
| 118 |
+
return out
|
| 119 |
+
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
3. Train qwen3 moe model
|
| 123 |
+
```
|
| 124 |
+
bash train.sh tasks/train_torch.py configs/pretrain/qwen3-moe.yaml
|
| 125 |
+
```
|
docs/index.rst
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Welcome to LingBotVLA
|
| 2 |
+
=========================
|
docs/requirements-docs.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# markdown suport
|
| 2 |
+
recommonmark
|
| 3 |
+
# markdown table suport
|
| 4 |
+
sphinx-markdown-tables
|
| 5 |
+
|
| 6 |
+
# theme default rtd
|
| 7 |
+
|
| 8 |
+
# crate-docs-theme
|
| 9 |
+
sphinx-rtd-theme
|
docs/start/start.rst
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Getting Started
|
| 2 |
+
=========================
|
experiment/libero/README.md
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Install official LIBERO
|
| 2 |
+
|
| 3 |
+
```bash
|
| 4 |
+
git clone https://github.com/Lifelong-Robot-Learning/LIBERO.git libero # (here)
|
| 5 |
+
cd libero
|
| 6 |
+
pip install -e .
|
| 7 |
+
|
| 8 |
+
cd experiment/libero/libero
|
| 9 |
+
pip install -r req.txt
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
If can not import xxx from libero.libero please add the libero (here) path to the PYTHONPATH variable.
|
| 13 |
+
|
| 14 |
+
The results will be save to /project_root/Libero
|
| 15 |
+
|
| 16 |
+
- release_ensemble/ stores the log files (This directory can be changed by --local_log_dir variable)
|
| 17 |
+
- rollouts stores the videos
|
| 18 |
+
|
experiment/libero/libero/libero_utils.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utils for evaluating policies in LIBERO simulation environments."""
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import imageio
|
| 7 |
+
import numpy as np
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
from libero.libero import get_libero_path
|
| 10 |
+
from libero.libero.envs import OffScreenRenderEnv
|
| 11 |
+
|
| 12 |
+
from experiment.libero.robot_utils import (
|
| 13 |
+
DATE,
|
| 14 |
+
DATE_TIME,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_libero_env(task, model_family, resolution=256):
|
| 19 |
+
"""Initializes and returns the LIBERO environment, along with the task description."""
|
| 20 |
+
task_description = task.language
|
| 21 |
+
task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
|
| 22 |
+
env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
|
| 23 |
+
env = OffScreenRenderEnv(**env_args)
|
| 24 |
+
env.seed(0) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
|
| 25 |
+
return env, task_description
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def get_libero_dummy_action(model_family: str):
|
| 29 |
+
"""Get dummy/no-op action, used to roll out the simulation while the robot does nothing."""
|
| 30 |
+
return [0, 0, 0, 0, 0, 0, -1]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def resize_image(img, resize_size):
|
| 34 |
+
"""
|
| 35 |
+
Takes numpy array corresponding to a single image and returns resized image as numpy array.
|
| 36 |
+
|
| 37 |
+
NOTE (Moo Jin): To make input images in distribution with respect to the inputs seen at training time, we follow
|
| 38 |
+
the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training.
|
| 39 |
+
"""
|
| 40 |
+
assert isinstance(resize_size, tuple)
|
| 41 |
+
# Resize to image size expected by model
|
| 42 |
+
with tf.device('/CPU:0'):
|
| 43 |
+
img = tf.image.encode_jpeg(img) # Encode as JPEG, as done in RLDS dataset builder
|
| 44 |
+
img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8) # Immediately decode back
|
| 45 |
+
img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True)
|
| 46 |
+
img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8)
|
| 47 |
+
img = img.numpy()
|
| 48 |
+
return img
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_libero_image(obs, resize_size):
|
| 52 |
+
"""Extracts image from observations and preprocesses it."""
|
| 53 |
+
assert isinstance(resize_size, int) or isinstance(resize_size, tuple)
|
| 54 |
+
if isinstance(resize_size, int):
|
| 55 |
+
resize_size = (resize_size, resize_size)
|
| 56 |
+
img = obs["agentview_image"]
|
| 57 |
+
img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing
|
| 58 |
+
img = resize_image(img, resize_size)
|
| 59 |
+
return img
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_libero_wrist_image(obs, resize_size):
|
| 63 |
+
"""Extracts wrist camera image from observations and preprocesses it."""
|
| 64 |
+
assert isinstance(resize_size, int) or isinstance(resize_size, tuple)
|
| 65 |
+
if isinstance(resize_size, int):
|
| 66 |
+
resize_size = (resize_size, resize_size)
|
| 67 |
+
img = obs["robot0_eye_in_hand_image"]
|
| 68 |
+
img = img[::-1, ::-1] # IMPORTANT: rotate 180 degrees to match train preprocessing
|
| 69 |
+
img = resize_image(img, resize_size)
|
| 70 |
+
return img
|
| 71 |
+
|
| 72 |
+
def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, ckpt_index=None, task_suite_name=None, task_id=None):
|
| 73 |
+
"""Saves an MP4 replay of an episode."""
|
| 74 |
+
rollout_dir = f"./Libero/rollouts/{ckpt_index}/{task_suite_name}-task{task_id}-{DATE_TIME}-{ckpt_index}"
|
| 75 |
+
os.makedirs(rollout_dir, exist_ok=True)
|
| 76 |
+
processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50]
|
| 77 |
+
mp4_path = f"{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--task={processed_task_description}.mp4"
|
| 78 |
+
video_writer = imageio.get_writer(mp4_path, fps=30)
|
| 79 |
+
for img in rollout_images:
|
| 80 |
+
video_writer.append_data(img)
|
| 81 |
+
video_writer.close()
|
| 82 |
+
print(f"Saved rollout MP4 at path {mp4_path}")
|
| 83 |
+
if log_file is not None:
|
| 84 |
+
log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
|
| 85 |
+
return mp4_path
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def quat2axisangle(quat):
|
| 89 |
+
"""
|
| 90 |
+
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
|
| 91 |
+
|
| 92 |
+
Converts quaternion to axis-angle format.
|
| 93 |
+
Returns a unit vector direction scaled by its angle in radians.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
quat (np.array): (x,y,z,w) vec4 float angles
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
np.array: (ax,ay,az) axis-angle exponential coordinates
|
| 100 |
+
"""
|
| 101 |
+
# clip quaternion
|
| 102 |
+
if quat[3] > 1.0:
|
| 103 |
+
quat[3] = 1.0
|
| 104 |
+
elif quat[3] < -1.0:
|
| 105 |
+
quat[3] = -1.0
|
| 106 |
+
|
| 107 |
+
den = np.sqrt(1.0 - quat[3] * quat[3])
|
| 108 |
+
if math.isclose(den, 0.0):
|
| 109 |
+
# This is (close to) a zero degree rotation, immediately return
|
| 110 |
+
return np.zeros(3)
|
| 111 |
+
|
| 112 |
+
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
experiment/libero/libero/req.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
imageio[ffmpeg]
|
| 2 |
+
robosuite==1.4.1
|
| 3 |
+
bddl
|
| 4 |
+
easydict
|
| 5 |
+
cloudpickle
|
| 6 |
+
gym
|
experiment/libero/libero/run_libero_eval.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
run_libero_eval.py
|
| 3 |
+
|
| 4 |
+
Runs a model in a LIBERO simulation environment.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
# OpenVLA:
|
| 8 |
+
# IMPORTANT: Set `center_crop=True` if model is fine-tuned with augmentations
|
| 9 |
+
python Libero/robot/libero/run_libero_eval.py \
|
| 10 |
+
--model_family openvla \
|
| 11 |
+
--pretrained_checkpoint <CHECKPOINT_PATH> \
|
| 12 |
+
--task_suite_name [ libero_spatial | libero_object | libero_goal | libero_10 | libero_90 ] \
|
| 13 |
+
--center_crop [ True | False ] \
|
| 14 |
+
--run_id_note <OPTIONAL TAG TO INSERT INTO RUN ID FOR LOGGING> \
|
| 15 |
+
--use_wandb [ True | False ] \
|
| 16 |
+
--wandb_project <PROJECT> \
|
| 17 |
+
--wandb_entity <ENTITY>
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import tensorflow as tf
|
| 21 |
+
import os, json, re, io, base64, threading
|
| 22 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
|
| 23 |
+
for g in tf.config.list_physical_devices('GPU'):
|
| 24 |
+
tf.config.experimental.set_memory_growth(g, True)
|
| 25 |
+
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
parent_dir = os.path.dirname(os.getcwd())
|
| 29 |
+
sys.path.insert(0, parent_dir)
|
| 30 |
+
sys.path.insert(0, os.getcwd())
|
| 31 |
+
|
| 32 |
+
from dataclasses import dataclass
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
from typing import Optional, Union
|
| 35 |
+
import torch
|
| 36 |
+
|
| 37 |
+
import draccus
|
| 38 |
+
import numpy as np
|
| 39 |
+
import tqdm
|
| 40 |
+
from libero.libero import benchmark
|
| 41 |
+
|
| 42 |
+
import wandb
|
| 43 |
+
|
| 44 |
+
# Append current directory so that interpreter can find Libero.robot
|
| 45 |
+
from experiment.libero.libero.libero_utils import (
|
| 46 |
+
get_libero_dummy_action,
|
| 47 |
+
get_libero_env,
|
| 48 |
+
get_libero_image,
|
| 49 |
+
get_libero_wrist_image,
|
| 50 |
+
quat2axisangle,
|
| 51 |
+
save_rollout_video,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
from experiment.libero.robot_utils import (
|
| 55 |
+
DATE_TIME,
|
| 56 |
+
get_action,
|
| 57 |
+
get_image_resize_size,
|
| 58 |
+
get_model,
|
| 59 |
+
invert_gripper_action,
|
| 60 |
+
normalize_gripper_action,
|
| 61 |
+
set_seed_everywhere,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@dataclass
|
| 66 |
+
class GenerateConfig:
|
| 67 |
+
# fmt: off
|
| 68 |
+
|
| 69 |
+
#################################################################################################################
|
| 70 |
+
# Model-specific parameters
|
| 71 |
+
#################################################################################################################
|
| 72 |
+
model_family: str = "instruct_vla" # Model family
|
| 73 |
+
pretrained_checkpoint: Union[str, Path] = "" # Pretrained checkpoint path
|
| 74 |
+
unnorm_key: Optional[str] = None
|
| 75 |
+
# image_size: list[int] = [224, 224]
|
| 76 |
+
action_dim: int = 7
|
| 77 |
+
model_port: int = 8012
|
| 78 |
+
|
| 79 |
+
#################################################################################################################
|
| 80 |
+
# LIBERO environment-specific parameters
|
| 81 |
+
#################################################################################################################
|
| 82 |
+
task_suite_name: str = "libero_spatial" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
|
| 83 |
+
task_id: Optional[int] = None
|
| 84 |
+
num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim
|
| 85 |
+
num_trials_per_task: int = 50 # Number of rollouts per task
|
| 86 |
+
|
| 87 |
+
#################################################################################################################
|
| 88 |
+
# Utils
|
| 89 |
+
#################################################################################################################
|
| 90 |
+
run_id_note: Optional[str] = None # Extra note to add in run ID for logging
|
| 91 |
+
local_log_dir: str = "./Libero/logs" # Local directory for eval logs
|
| 92 |
+
|
| 93 |
+
use_wandb: bool = False # Whether to also log results in Weights & Biases
|
| 94 |
+
wandb_project: str = "YOUR_WANDB_PROJECT" # Name of W&B project to log to (use default!)
|
| 95 |
+
wandb_entity: str = "YOUR_WANDB_ENTITY" # Name of entity to log under
|
| 96 |
+
|
| 97 |
+
seed: int = 42 # Random Seed (for reproducibility)
|
| 98 |
+
use_length: int = 8
|
| 99 |
+
# fmt: on
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@draccus.wrap()
|
| 103 |
+
def eval_libero(cfg: GenerateConfig) -> None:
|
| 104 |
+
|
| 105 |
+
ckpt_index = cfg.pretrained_checkpoint.split('/checkpoints/')[0].split('/')[-1]
|
| 106 |
+
# Set random seed
|
| 107 |
+
set_seed_everywhere(cfg.seed)
|
| 108 |
+
|
| 109 |
+
# [OpenVLA] Check that the model contains the action un-normalization key
|
| 110 |
+
if cfg.model_family == "openvla":
|
| 111 |
+
# [OpenVLA] Set action un-normalization key
|
| 112 |
+
cfg.unnorm_key = cfg.task_suite_name
|
| 113 |
+
model, server = get_model(cfg)
|
| 114 |
+
server = None
|
| 115 |
+
# In some cases, the key must be manually modified (e.g. after training on a modified version of the dataset
|
| 116 |
+
# with the suffix "_no_noops" in the dataset name)
|
| 117 |
+
if cfg.unnorm_key not in model.norm_stats and f"{cfg.unnorm_key}_no_noops" in model.norm_stats:
|
| 118 |
+
cfg.unnorm_key = f"{cfg.unnorm_key}_no_noops"
|
| 119 |
+
assert cfg.unnorm_key in model.norm_stats, f"Action un-norm key {cfg.unnorm_key} not found in VLA `norm_stats`!"
|
| 120 |
+
|
| 121 |
+
elif cfg.model_family == "instruct_vla":
|
| 122 |
+
# [OpenVLA] Set action un-normalization key
|
| 123 |
+
cfg.unnorm_key = f"{cfg.task_suite_name}_no_noops"
|
| 124 |
+
model, server = get_model(cfg)
|
| 125 |
+
|
| 126 |
+
# Initialize local logging
|
| 127 |
+
run_id = f"EVAL-{cfg.task_suite_name}-task{cfg.task_id}-{cfg.model_family}-{DATE_TIME}-{ckpt_index}"
|
| 128 |
+
if cfg.run_id_note is not None:
|
| 129 |
+
run_id += f"--{cfg.run_id_note}"
|
| 130 |
+
cfg.local_log_dir = os.path.join(cfg.local_log_dir, ckpt_index)
|
| 131 |
+
os.makedirs(cfg.local_log_dir, exist_ok=True)
|
| 132 |
+
local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt")
|
| 133 |
+
log_file = open(local_log_filepath, "w")
|
| 134 |
+
print(f"Logging to local log file: {local_log_filepath}")
|
| 135 |
+
|
| 136 |
+
# Initialize Weights & Biases logging as well
|
| 137 |
+
if cfg.use_wandb:
|
| 138 |
+
wandb.init(
|
| 139 |
+
entity=cfg.wandb_entity,
|
| 140 |
+
project=cfg.wandb_project,
|
| 141 |
+
name=run_id,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# Initialize LIBERO task suite
|
| 145 |
+
benchmark_dict = benchmark.get_benchmark_dict()
|
| 146 |
+
task_suite = benchmark_dict[cfg.task_suite_name]()
|
| 147 |
+
num_tasks_in_suite = task_suite.n_tasks
|
| 148 |
+
print(f"Task suite: {cfg.task_suite_name}")
|
| 149 |
+
log_file.write(f"Task suite: {cfg.task_suite_name}\n")
|
| 150 |
+
|
| 151 |
+
# Get expected image dimensions
|
| 152 |
+
resize_size = get_image_resize_size(cfg)
|
| 153 |
+
|
| 154 |
+
# Start evaluation
|
| 155 |
+
total_episodes, total_successes = 0, 0
|
| 156 |
+
for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
|
| 157 |
+
# Get task
|
| 158 |
+
if cfg.task_id is not None:
|
| 159 |
+
if cfg.task_suite_name == 'libero_10':
|
| 160 |
+
if task_id != cfg.task_id:
|
| 161 |
+
continue
|
| 162 |
+
task = task_suite.get_task(task_id)
|
| 163 |
+
|
| 164 |
+
# Get default LIBERO initial states
|
| 165 |
+
initial_states = task_suite.get_task_init_states(task_id)
|
| 166 |
+
|
| 167 |
+
# Initialize LIBERO environment and task description
|
| 168 |
+
env, task_description = get_libero_env(task, cfg.model_family, resolution=256)
|
| 169 |
+
|
| 170 |
+
# Start episodes
|
| 171 |
+
task_episodes, task_successes = 0, 0
|
| 172 |
+
for episode_idx in tqdm.tqdm(range(cfg.num_trials_per_task)):
|
| 173 |
+
print(f"\nTask: {task_description}")
|
| 174 |
+
log_file.write(f"\nTask: {task_description}\n")
|
| 175 |
+
|
| 176 |
+
# Reset environment
|
| 177 |
+
env.reset()
|
| 178 |
+
server.reset(robo_name='libero')
|
| 179 |
+
# Set initial states
|
| 180 |
+
obs = env.set_init_state(initial_states[episode_idx])
|
| 181 |
+
|
| 182 |
+
# Setup
|
| 183 |
+
t = 0
|
| 184 |
+
replay_images = []
|
| 185 |
+
if cfg.task_suite_name == "libero_spatial":
|
| 186 |
+
max_steps = 220 # longest training demo has 193 steps
|
| 187 |
+
elif cfg.task_suite_name == "libero_object":
|
| 188 |
+
max_steps = 280 # longest training demo has 254 steps
|
| 189 |
+
elif cfg.task_suite_name == "libero_goal":
|
| 190 |
+
max_steps = 300 # longest training demo has 270 steps
|
| 191 |
+
elif cfg.task_suite_name == "libero_10":
|
| 192 |
+
max_steps = 520 # longest training demo has 505 steps
|
| 193 |
+
elif cfg.task_suite_name == "libero_90":
|
| 194 |
+
max_steps = 400 # longest training demo has 373 steps
|
| 195 |
+
|
| 196 |
+
print(f"Starting episode {task_episodes+1}...")
|
| 197 |
+
log_file.write(f"Starting episode {task_episodes+1}...\n")
|
| 198 |
+
while t < max_steps + cfg.num_steps_wait:
|
| 199 |
+
# try:
|
| 200 |
+
# IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
|
| 201 |
+
# and we need to wait for them to fall
|
| 202 |
+
if t < cfg.num_steps_wait:
|
| 203 |
+
obs, reward, done, info = env.step(get_libero_dummy_action(cfg.model_family))
|
| 204 |
+
t += 1
|
| 205 |
+
continue
|
| 206 |
+
|
| 207 |
+
# Get preprocessed image
|
| 208 |
+
img = get_libero_image(obs, resize_size)
|
| 209 |
+
wrist_img = get_libero_wrist_image(obs, resize_size)
|
| 210 |
+
|
| 211 |
+
# Save preprocessed image for replay video
|
| 212 |
+
replay_images.append(img)
|
| 213 |
+
|
| 214 |
+
# Prepare observations dict
|
| 215 |
+
# Note: OpenVLA does not take proprio state as input
|
| 216 |
+
|
| 217 |
+
state = np.concatenate(
|
| 218 |
+
(obs["robot0_eef_pos"], quat2axisangle(obs["robot0_eef_quat"]), obs["robot0_gripper_qpos"]))
|
| 219 |
+
|
| 220 |
+
observation = {
|
| 221 |
+
"image": img,
|
| 222 |
+
"wrist_image": wrist_img,
|
| 223 |
+
"state": state,
|
| 224 |
+
"task": task_description,
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
# Query model to get action
|
| 228 |
+
action = get_action(
|
| 229 |
+
server, observation
|
| 230 |
+
).copy()
|
| 231 |
+
|
| 232 |
+
# Normalize gripper action [0,1] -> [-1,+1] because the environment expects the latter
|
| 233 |
+
# action = normalize_gripper_action(action, binarize=True)
|
| 234 |
+
action[..., -1] = np.sign(action[..., -1]) # binarize
|
| 235 |
+
|
| 236 |
+
# [OpenVLA] The dataloader flips the sign of the gripper action to align with other datasets
|
| 237 |
+
# (0 = close, 1 = open), so flip it back (-1 = open, +1 = close) before executing the action
|
| 238 |
+
# action = invert_gripper_action(action) # skip since we use raw action
|
| 239 |
+
|
| 240 |
+
print('==>action is',action)
|
| 241 |
+
# Execute action in environment
|
| 242 |
+
obs, reward, done, info = env.step(action.tolist())
|
| 243 |
+
if done:
|
| 244 |
+
task_successes += 1
|
| 245 |
+
total_successes += 1
|
| 246 |
+
break
|
| 247 |
+
t += 1
|
| 248 |
+
|
| 249 |
+
# except Exception as e:
|
| 250 |
+
# print(f"Caught exception: {e}")
|
| 251 |
+
# log_file.write(f"Caught exception: {e}\n")
|
| 252 |
+
# break
|
| 253 |
+
|
| 254 |
+
task_episodes += 1
|
| 255 |
+
total_episodes += 1
|
| 256 |
+
|
| 257 |
+
# Save a replay video of the episode
|
| 258 |
+
save_rollout_video(
|
| 259 |
+
replay_images, total_episodes, success=done, task_description=task_description, log_file=log_file, ckpt_index=ckpt_index, task_suite_name=cfg.task_suite_name, task_id=task_id
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# Log current results
|
| 263 |
+
print(f"Success: {done}")
|
| 264 |
+
print(f"# episodes completed so far: {total_episodes}")
|
| 265 |
+
print(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
|
| 266 |
+
log_file.write(f"Success: {done}\n")
|
| 267 |
+
log_file.write(f"# episodes completed so far: {total_episodes}\n")
|
| 268 |
+
log_file.write(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)\n")
|
| 269 |
+
log_file.flush()
|
| 270 |
+
|
| 271 |
+
# Log final results
|
| 272 |
+
print(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
|
| 273 |
+
print(f"Current total success rate: {float(total_successes) / float(total_episodes)}")
|
| 274 |
+
log_file.write(f"Current task success rate: {float(task_successes) / float(task_episodes)}\n")
|
| 275 |
+
log_file.write(f"Current total success rate: {float(total_successes) / float(total_episodes)}\n")
|
| 276 |
+
log_file.flush()
|
| 277 |
+
if cfg.use_wandb:
|
| 278 |
+
wandb.log(
|
| 279 |
+
{
|
| 280 |
+
f"success_rate/{task_description}": float(task_successes) / float(task_episodes),
|
| 281 |
+
f"num_episodes/{task_description}": task_episodes,
|
| 282 |
+
}
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# Save local log file
|
| 286 |
+
log_file.close()
|
| 287 |
+
|
| 288 |
+
# Push total metrics and local log file to wandb
|
| 289 |
+
if cfg.use_wandb:
|
| 290 |
+
wandb.log(
|
| 291 |
+
{
|
| 292 |
+
"success_rate/total": float(total_successes) / float(total_episodes),
|
| 293 |
+
"num_episodes/total": total_episodes,
|
| 294 |
+
}
|
| 295 |
+
)
|
| 296 |
+
wandb.save(local_log_filepath)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
if __name__ == "__main__":
|
| 300 |
+
eval_libero()
|
experiment/libero/robot_utils.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utils for evaluating robot policies in various environments."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
# Initialize important constants and pretty-printing mode in NumPy.
|
| 11 |
+
ACTION_DIM = 7
|
| 12 |
+
DATE = time.strftime("%Y_%m_%d")
|
| 13 |
+
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
|
| 14 |
+
np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def set_seed_everywhere(seed: int):
|
| 19 |
+
"""Sets the random seed for Python, NumPy, and PyTorch functions."""
|
| 20 |
+
torch.manual_seed(seed)
|
| 21 |
+
torch.cuda.manual_seed_all(seed)
|
| 22 |
+
np.random.seed(seed)
|
| 23 |
+
random.seed(seed)
|
| 24 |
+
torch.backends.cudnn.deterministic = True
|
| 25 |
+
torch.backends.cudnn.benchmark = False
|
| 26 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_model(cfg, wrap_diffusion_policy_for_droid=False):
|
| 30 |
+
"""Load model for evaluation."""
|
| 31 |
+
from deploy.websocket_client_policy import WebsocketClientPolicy
|
| 32 |
+
cronus_server = WebsocketClientPolicy(port=cfg.model_port)
|
| 33 |
+
return None, cronus_server
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_image_resize_size(cfg):
|
| 37 |
+
"""
|
| 38 |
+
Gets image resize size for a model class.
|
| 39 |
+
If `resize_size` is an int, then the resized image will be a square.
|
| 40 |
+
Else, the image will be a rectangle.
|
| 41 |
+
"""
|
| 42 |
+
if cfg.model_family == "openvla" or "instruct_vla" in cfg.model_family:
|
| 43 |
+
resize_size = 224
|
| 44 |
+
else:
|
| 45 |
+
raise ValueError("Unexpected `model_family` found in config.")
|
| 46 |
+
return resize_size
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_action(server, obs):
|
| 50 |
+
"""Queries the model to get an action."""
|
| 51 |
+
|
| 52 |
+
action = server.infer(obs)['action']
|
| 53 |
+
return action
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def normalize_gripper_action(action, binarize=True):
|
| 57 |
+
"""
|
| 58 |
+
Changes gripper action (last dimension of action vector) from [0,1] to [-1,+1].
|
| 59 |
+
Necessary for some environments (not Bridge) because the dataset wrapper standardizes gripper actions to [0,1].
|
| 60 |
+
Note that unlike the other action dimensions, the gripper action is not normalized to [-1,+1] by default by
|
| 61 |
+
the dataset wrapper.
|
| 62 |
+
|
| 63 |
+
Normalization formula: y = 2 * (x - orig_low) / (orig_high - orig_low) - 1
|
| 64 |
+
"""
|
| 65 |
+
# Just normalize the last action to [-1,+1].
|
| 66 |
+
orig_low, orig_high = 0.0, 1.0
|
| 67 |
+
action = np.array(action, copy=True)
|
| 68 |
+
action[..., -1] = 2 * (action[..., -1] - orig_low) / (orig_high - orig_low) - 1
|
| 69 |
+
|
| 70 |
+
if binarize:
|
| 71 |
+
# Binarize to -1 or +1.
|
| 72 |
+
action[..., -1] = np.sign(action[..., -1])
|
| 73 |
+
|
| 74 |
+
return action
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def invert_gripper_action(action):
|
| 78 |
+
"""
|
| 79 |
+
Flips the sign of the gripper action (last dimension of action vector).
|
| 80 |
+
This is necessary for some environments where -1 = open, +1 = close, since
|
| 81 |
+
the RLDS dataloader aligns gripper actions such that 0 = close, 1 = open.
|
| 82 |
+
"""
|
| 83 |
+
action[..., -1] = action[..., -1] * -1.0
|
| 84 |
+
return action
|
experiment/robotwin/README.md
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generate Lerobot Dataset from RoboTwin Data
|
| 2 |
+
|
| 3 |
+
This guide explains how to process raw data from **RoboTwin** and convert it into the **LerobotDataset** format following the official RoboTwin instructions.
|
| 4 |
+
|
| 5 |
+
## 1. Clone the Official RoboTwin Repository
|
| 6 |
+
```bash
|
| 7 |
+
git clone git@github.com:RoboTwin-Platform/RoboTwin.git
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
## 2. Create Required Directories
|
| 11 |
+
Navigate to the `policy/pi0` directory inside the cloned RoboTwin repository and create the folders:
|
| 12 |
+
|
| 13 |
+
```bash
|
| 14 |
+
cd ./policy/pi0
|
| 15 |
+
mkdir processed_data training_data
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
## 3. Convert RoboTwin Raw Data to HDF5
|
| 19 |
+
|
| 20 |
+
Use the provided script [process_data_pi0.sh](https://github.com/RoboTwin-Platform/RoboTwin/blob/main/policy/pi0/process_data_pi0.sh):
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
bash process_data_pi0.sh ${task_name} ${task_config} ${expert_data_num}
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
**Example (clean demo):**
|
| 27 |
+
```bash
|
| 28 |
+
bash process_data_pi0.sh beat_block_hammer demo_clean 50
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
**Example (randomized demo):**
|
| 32 |
+
```bash
|
| 33 |
+
bash process_data_pi0.sh beat_block_hammer demo_randomized 50
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
If successful, the output folder:
|
| 37 |
+
```
|
| 38 |
+
processed_data/${task_name}-${task_config}-${expert_data_num}/
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## 4. Prepare Training Data
|
| 42 |
+
|
| 43 |
+
Copy the required processed datasets into `training_data/${model_name}`:
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
cp -r processed_data/${task_name}-${task_config}-${expert_data_num} \
|
| 47 |
+
training_data/${model_name}/
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## 5. Ensure Sufficient Disk Space
|
| 51 |
+
|
| 52 |
+
The generated **LerobotDataset** will be stored under:
|
| 53 |
+
|
| 54 |
+
```
|
| 55 |
+
$XDG_CACHE_HOME/huggingface/lerobot/${repo_id}
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
By default, `XDG_CACHE_HOME` points to `~/.cache`, which must have sufficient free space.
|
| 59 |
+
If space is low, change the cache location:
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
export XDG_CACHE_HOME=/path/to/your/cache
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## 6. Generate LerobotDataset Format
|
| 66 |
+
|
| 67 |
+
Run [process_data_pi0.sh](https://github.com/RoboTwin-Platform/RoboTwin/blob/main/policy/pi0/generate.sh) to convert the HDF5 datasets to Lerobot.
|
| 68 |
+
|
| 69 |
+
Parameters:
|
| 70 |
+
- **hdf5_path**: Path to the HDF5 training data (e.g., `./training_data/${model_name}/`)
|
| 71 |
+
- **repo_id**: Name for the dataset (e.g., `my_repo`)
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
bash generate.sh ${hdf5_path} ${repo_id}
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
**Example:**
|
| 78 |
+
```bash
|
| 79 |
+
bash generate.sh ./training_data/demo_clean/ demo_clean_repo
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
Output:
|
| 83 |
+
```
|
| 84 |
+
${XDG_CACHE_HOME}/huggingface/lerobot/${repo_id}
|
| 85 |
+
```
|
lingbotvla/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
__version__ = "0.0.1"
|
lingbotvla/checkpoint/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from .checkpointer import build_checkpointer
|
| 17 |
+
from .format_utils import bytecheckpoint_ckpt_to_state_dict, ckpt_to_state_dict, dcp_to_torch_state_dict
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"ckpt_to_state_dict",
|
| 22 |
+
"dcp_to_torch_state_dict",
|
| 23 |
+
"bytecheckpoint_ckpt_to_state_dict",
|
| 24 |
+
"build_checkpointer",
|
| 25 |
+
]
|
lingbotvla/checkpoint/checkpointer.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from abc import ABC, abstractmethod
|
| 18 |
+
from typing import Any, Dict
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.distributed as dist
|
| 22 |
+
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
|
| 23 |
+
from ..utils.import_utils import is_torch_version_greater_than
|
| 24 |
+
from ..utils.logging import get_logger
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
if is_torch_version_greater_than("2.4"):
|
| 28 |
+
import torch.distributed.checkpoint as dcp
|
| 29 |
+
from torch.distributed.checkpoint import (
|
| 30 |
+
FileSystemReader,
|
| 31 |
+
FileSystemWriter,
|
| 32 |
+
)
|
| 33 |
+
from torch.distributed.checkpoint.state_dict import (
|
| 34 |
+
get_model_state_dict,
|
| 35 |
+
get_optimizer_state_dict,
|
| 36 |
+
set_model_state_dict,
|
| 37 |
+
set_optimizer_state_dict,
|
| 38 |
+
)
|
| 39 |
+
from torch.distributed.checkpoint.stateful import Stateful
|
| 40 |
+
else:
|
| 41 |
+
Stateful = ABC
|
| 42 |
+
|
| 43 |
+
logger = get_logger(__name__)
|
| 44 |
+
|
| 45 |
+
_EXTRA_STATE_FORMAT = "extra_state_rank_{}.pt"
|
| 46 |
+
_MODEL_DIR = "model"
|
| 47 |
+
_EMA_DIR = "ema"
|
| 48 |
+
_OPTIMIZER_DIR = "optimizer"
|
| 49 |
+
_EXTRA_STATE_DIR = "extra_state"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ModelState(Stateful):
|
| 53 |
+
"""
|
| 54 |
+
A wrapper around a model to make it stateful.
|
| 55 |
+
Args:
|
| 56 |
+
model (Model): model to wrap.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(self, model):
|
| 60 |
+
self.model = model
|
| 61 |
+
|
| 62 |
+
def state_dict(self):
|
| 63 |
+
model_state_dict = get_model_state_dict(model=self.model)
|
| 64 |
+
return {"model": model_state_dict}
|
| 65 |
+
|
| 66 |
+
def load_state_dict(self, state_dict):
|
| 67 |
+
set_model_state_dict(model=self.model, model_state_dict=state_dict["model"])
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class OptimizerState(Stateful):
|
| 71 |
+
"""
|
| 72 |
+
A wrapper around an optimizer to make it stateful.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
model (Model): model to wrap.
|
| 76 |
+
optimizer (Optimizer): optimizer to wrap.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
def __init__(self, model, optimizer):
|
| 80 |
+
self.model = model
|
| 81 |
+
self.optimizer = optimizer
|
| 82 |
+
|
| 83 |
+
def state_dict(self):
|
| 84 |
+
optimizer_state_dict = get_optimizer_state_dict(model=self.model, optimizers=self.optimizer)
|
| 85 |
+
return {"optim": optimizer_state_dict}
|
| 86 |
+
|
| 87 |
+
def load_state_dict(self, state_dict):
|
| 88 |
+
set_optimizer_state_dict(model=self.model, optimizers=self.optimizer, optim_state_dict=state_dict["optim"])
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def build_checkpointer(
|
| 92 |
+
dist_backend: str = "fsdp1",
|
| 93 |
+
ckpt_manager: str = "bytecheckpoint",
|
| 94 |
+
):
|
| 95 |
+
"""
|
| 96 |
+
create a checkpointer manager with given mode.
|
| 97 |
+
Args:
|
| 98 |
+
dist_backend (str, optional): checkpoint mode. Defaults to "fsdp1".
|
| 99 |
+
fsdp1: FSDP1 checkpoint from bytecheckpoint
|
| 100 |
+
fsdp2-vescale: FSDP2 checkpoint from bytecheckpoint
|
| 101 |
+
fsdp2: FSDP2 checkpoint from bytecheckpoint
|
| 102 |
+
ddp: DDP checkpoint from bytecheckpoint
|
| 103 |
+
dcp: DCP checkpoint from torch.distributed.checkpoint
|
| 104 |
+
ckpt_manager (str, optional): checkpoint manager. Defaults to "bytecheckpoint".
|
| 105 |
+
bytecheckpoint: bytecheckpoint checkpoint manager
|
| 106 |
+
dcp: torch dcp checkpoint manager
|
| 107 |
+
Raises:
|
| 108 |
+
ValueError: if ckpt_manager is not supported
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
Checkpointer: checkpointer with given mode.
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
if ckpt_manager == "bytecheckpoint":
|
| 115 |
+
if dist_backend == "ddp":
|
| 116 |
+
from bytecheckpoint import DDPCheckpointer as Checkpointer
|
| 117 |
+
elif dist_backend == "fsdp1":
|
| 118 |
+
from bytecheckpoint import FSDPCheckpointer as Checkpointer
|
| 119 |
+
elif dist_backend == "fsdp2-vescale":
|
| 120 |
+
from bytecheckpoint import VeScaleCheckpointer as Checkpointer
|
| 121 |
+
elif dist_backend == "fsdp2":
|
| 122 |
+
from bytecheckpoint import FSDP2Checkpointer as Checkpointer
|
| 123 |
+
elif ckpt_manager == "dcp":
|
| 124 |
+
if not is_torch_version_greater_than("2.4"):
|
| 125 |
+
raise ValueError("DCP checkpoint manager requires torch version >= 2.4")
|
| 126 |
+
if dist_backend not in ["ddp", "fsdp1", "fsdp2"]:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
f"Unsupported distributed backend: {dist_backend} for DCP checkpoint manager, supported modes are: ddp, fsdp1, fsdp2"
|
| 129 |
+
)
|
| 130 |
+
Checkpointer = DistributedCheckpointer
|
| 131 |
+
else:
|
| 132 |
+
raise ValueError(
|
| 133 |
+
f"Unknown checkpoint manager: {ckpt_manager}, supported modes are: bytecheckpoint, dcp, native"
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
return Checkpointer
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class CheckpointerBase(ABC):
|
| 140 |
+
"""Base class for checkpointer"""
|
| 141 |
+
|
| 142 |
+
@abstractmethod
|
| 143 |
+
def save(
|
| 144 |
+
cls,
|
| 145 |
+
path: str,
|
| 146 |
+
state: Dict[str, Any],
|
| 147 |
+
):
|
| 148 |
+
return
|
| 149 |
+
|
| 150 |
+
@abstractmethod
|
| 151 |
+
def load(
|
| 152 |
+
cls,
|
| 153 |
+
path: str,
|
| 154 |
+
state: Dict[str, Any],
|
| 155 |
+
):
|
| 156 |
+
return
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class DistributedCheckpointer(CheckpointerBase):
|
| 160 |
+
"""
|
| 161 |
+
Distributed checkpointer for torch.distributed.checkpoint
|
| 162 |
+
"""
|
| 163 |
+
|
| 164 |
+
@classmethod
|
| 165 |
+
def save(
|
| 166 |
+
cls,
|
| 167 |
+
path: str,
|
| 168 |
+
state: Dict[str, Any],
|
| 169 |
+
global_steps: int = None,
|
| 170 |
+
save_async=False,
|
| 171 |
+
) -> None:
|
| 172 |
+
"""
|
| 173 |
+
save training state to distributed checkpoint
|
| 174 |
+
|
| 175 |
+
args:
|
| 176 |
+
path: path to save checkpoint
|
| 177 |
+
state: state to save
|
| 178 |
+
global_steps: global steps
|
| 179 |
+
save_async: whether to save asynchronously
|
| 180 |
+
return:
|
| 181 |
+
None
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
checkpoint_dir = f"{path}/global_step_{global_steps}" if global_steps else path
|
| 185 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 186 |
+
|
| 187 |
+
if "model" not in state:
|
| 188 |
+
raise ValueError("Model must be provided to save a distributed checkpoint.")
|
| 189 |
+
|
| 190 |
+
if save_async:
|
| 191 |
+
model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
|
| 192 |
+
dcp.async_save(
|
| 193 |
+
state_dict={"state": ModelState(state["model"])},
|
| 194 |
+
storage_writer=FileSystemWriter(
|
| 195 |
+
model_dir,
|
| 196 |
+
thread_count=16,
|
| 197 |
+
single_file_per_rank=True,
|
| 198 |
+
sync_files=False,
|
| 199 |
+
),
|
| 200 |
+
)
|
| 201 |
+
if "ema" in state and state["ema"] is not None:
|
| 202 |
+
ema_dir = os.path.join(checkpoint_dir, _EMA_DIR)
|
| 203 |
+
dcp.async_save(
|
| 204 |
+
state_dict={"state": ModelState(state["ema"])},
|
| 205 |
+
storage_writer=FileSystemWriter(
|
| 206 |
+
ema_dir,
|
| 207 |
+
thread_count=16,
|
| 208 |
+
single_file_per_rank=True,
|
| 209 |
+
sync_files=False,
|
| 210 |
+
),
|
| 211 |
+
)
|
| 212 |
+
if "optimizer" in state:
|
| 213 |
+
optimizer_dir = os.path.join(checkpoint_dir, _OPTIMIZER_DIR)
|
| 214 |
+
dcp.async_save(
|
| 215 |
+
state_dict={"state": OptimizerState(model=state["model"], optimizer=state["optimizer"])},
|
| 216 |
+
storage_writer=FileSystemWriter(
|
| 217 |
+
optimizer_dir,
|
| 218 |
+
thread_count=16,
|
| 219 |
+
single_file_per_rank=True,
|
| 220 |
+
sync_files=False,
|
| 221 |
+
),
|
| 222 |
+
)
|
| 223 |
+
else:
|
| 224 |
+
def safe_create_writer(output_dir):
|
| 225 |
+
tmp_path = Path(output_dir) / ".metadata.tmp"
|
| 226 |
+
if tmp_path.exists():
|
| 227 |
+
print(f"Warning: removing existing tmp file: {tmp_path}")
|
| 228 |
+
tmp_path.unlink() # remove .metadata.tmp
|
| 229 |
+
return FileSystemWriter(
|
| 230 |
+
output_dir,
|
| 231 |
+
thread_count=16,
|
| 232 |
+
single_file_per_rank=True,
|
| 233 |
+
sync_files=False,
|
| 234 |
+
)
|
| 235 |
+
model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
|
| 236 |
+
storage_writer = safe_create_writer(model_dir)
|
| 237 |
+
dcp.save(
|
| 238 |
+
state_dict={"state": ModelState(state["model"])},
|
| 239 |
+
storage_writer=storage_writer,
|
| 240 |
+
)
|
| 241 |
+
if "ema" in state and state["ema"] is not None:
|
| 242 |
+
ema_dir = os.path.join(checkpoint_dir, _EMA_DIR)
|
| 243 |
+
storage_writer = safe_create_writer(ema_dir)
|
| 244 |
+
dcp.save(
|
| 245 |
+
state_dict={"state": ModelState(state["ema"])},
|
| 246 |
+
storage_writer=storage_writer,
|
| 247 |
+
)
|
| 248 |
+
if "optimizer" in state:
|
| 249 |
+
optimizer_dir = os.path.join(checkpoint_dir, _OPTIMIZER_DIR)
|
| 250 |
+
dcp.save(
|
| 251 |
+
state_dict={"state": OptimizerState(model=state["model"], optimizer=state["optimizer"])},
|
| 252 |
+
storage_writer=FileSystemWriter(
|
| 253 |
+
optimizer_dir,
|
| 254 |
+
thread_count=16,
|
| 255 |
+
single_file_per_rank=True,
|
| 256 |
+
sync_files=False,
|
| 257 |
+
),
|
| 258 |
+
)
|
| 259 |
+
# dist.barrier()
|
| 260 |
+
|
| 261 |
+
if "extra_state" in state:
|
| 262 |
+
extra_state_dir = os.path.join(checkpoint_dir, _EXTRA_STATE_DIR)
|
| 263 |
+
os.makedirs(extra_state_dir, exist_ok=True)
|
| 264 |
+
extra_state_path = os.path.join(extra_state_dir, _EXTRA_STATE_FORMAT.format(dist.get_rank()))
|
| 265 |
+
torch.save(
|
| 266 |
+
state["extra_state"],
|
| 267 |
+
extra_state_path,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
logger.info_rank0(f"Saved checkpoint to {checkpoint_dir}")
|
| 271 |
+
|
| 272 |
+
@classmethod
|
| 273 |
+
def load(
|
| 274 |
+
cls,
|
| 275 |
+
path: str,
|
| 276 |
+
state: Dict[str, Any],
|
| 277 |
+
process_group=None,
|
| 278 |
+
) -> Dict[str, Any]:
|
| 279 |
+
"""
|
| 280 |
+
load training state from distributed checkpoint
|
| 281 |
+
args:
|
| 282 |
+
path: path to load checkpoint
|
| 283 |
+
state: state to load, "model" are required, "optimizer" and "extra_state" are optional
|
| 284 |
+
|
| 285 |
+
return:
|
| 286 |
+
state: state loaded
|
| 287 |
+
"""
|
| 288 |
+
checkpoint_dir = path
|
| 289 |
+
|
| 290 |
+
if state is None:
|
| 291 |
+
raise ValueError("State dict must be provided to load a distributed checkpoint.")
|
| 292 |
+
|
| 293 |
+
if "model" not in state:
|
| 294 |
+
raise ValueError("Model must be provided to load a distributed checkpoint.")
|
| 295 |
+
|
| 296 |
+
if "ema" in state and state["ema"] is not None:
|
| 297 |
+
ema_dir = os.path.join(checkpoint_dir, _EMA_DIR)
|
| 298 |
+
dcp.load(
|
| 299 |
+
state_dict={"state": ModelState(state["ema"])},
|
| 300 |
+
storage_reader=FileSystemReader(ema_dir),
|
| 301 |
+
process_group=process_group,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
if "optimizer" in state:
|
| 305 |
+
model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
|
| 306 |
+
dcp.load(
|
| 307 |
+
state_dict={"state": ModelState(state["model"])},
|
| 308 |
+
storage_reader=FileSystemReader(model_dir),
|
| 309 |
+
process_group=process_group,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
optimizer_dir = os.path.join(checkpoint_dir, _OPTIMIZER_DIR)
|
| 313 |
+
try:
|
| 314 |
+
dcp.load(
|
| 315 |
+
state_dict={"state": OptimizerState(model=state["model"], optimizer=state["optimizer"])}, # 1043
|
| 316 |
+
storage_reader=FileSystemReader(optimizer_dir), # 1027
|
| 317 |
+
planner = DefaultLoadPlanner(allow_partial_load=True),
|
| 318 |
+
process_group=process_group,
|
| 319 |
+
)
|
| 320 |
+
except:
|
| 321 |
+
logger.info_rank0(f"Skip loading Optimizer from {checkpoint_dir}")
|
| 322 |
+
else:
|
| 323 |
+
model_dir = os.path.join(checkpoint_dir, _MODEL_DIR)
|
| 324 |
+
dcp.load(
|
| 325 |
+
state_dict={"state": ModelState(state["model"])},
|
| 326 |
+
storage_reader=FileSystemReader(model_dir),
|
| 327 |
+
process_group=process_group,
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
if "extra_state" in state:
|
| 331 |
+
extra_state_dir = os.path.join(checkpoint_dir, _EXTRA_STATE_DIR)
|
| 332 |
+
os.makedirs(extra_state_dir, exist_ok=True)
|
| 333 |
+
extra_state_path = os.path.join(extra_state_dir, _EXTRA_STATE_FORMAT.format(dist.get_rank()))
|
| 334 |
+
state["extra_state"] = torch.load(
|
| 335 |
+
extra_state_path,
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
logger.info_rank0(f"Loaded checkpoint from {checkpoint_dir}")
|
| 339 |
+
|
| 340 |
+
return state
|